Freezing a Keras model

How to freeze a model for serving and other applications

Joseph Aylett-Bullock
Towards Data Science

--

“landscape photography of snowy mountain and body of water” by Joel & Jasmin Førestbird on Unsplash

Introduction

Keras has become one of the dominant APIs for developing and testing neural networks across industry and academia. It combines user-friendly syntax with flexibility thanks to its choice of backends: you write your model in Keras and it can run on TensorFlow, Theano or CNTK.

Once you have designed a network using Keras, you may want to serve it in another API, on the web, or in some other medium. One of the easiest ways to do many of these is to use the pre-built TensorFlow libraries (such as the TensorFlow C++ API for model inference in a C++ environment). To do this you will most likely have to ‘freeze’ your trained Keras model because of the way the backends of these APIs work. Unfortunately, to do this easily you will have to retrain your model in the TensorFlow implementation of Keras. Fortunately, this process is very simple if you know what you’re aiming for, thanks to the way Keras is integrated into TensorFlow and to the other materials TensorFlow provides for this task.

When I was originally researching how to go about this task for my own requirement of serving a model in a C++ environment, I came across several possible answers in similar posts on places like Stack Overflow, as well as here on Medium. Some people have also made a variety of great open source packages to do just this, such as keras2cpp. Yet I found the articles lacking in some crucial details, and they frequently used complicated techniques that were necessary when they were written but have since been significantly simplified by TensorFlow updates.

My hope is that this article provides a simple, up-to-date walkthrough of how to freeze a Keras model for any general requirement, not just for serving in a C++ environment. Although the methodology described here is written specifically for converting a model written in the Keras API (natively or in the TensorFlow implementation), it can also be used for any model written in native TensorFlow.

Why do I need to convert my model?

If your model is trained and saved in the Keras API then you will probably have saved an HDF5 (.h5) file in which the network architecture and weights are stored together. This file can be loaded back into the Keras API for inference. Sadly, however, this file type is not recognised by the TensorFlow APIs and is also unnecessarily large to store, load, and perform inference on.

Similarly, if you write a model in the TensorFlow Python API, then the training procedure will save a TensorFlow graph, using Google’s ProtoBuf library, and a series of checkpoint files. The graph stores the information about the architecture of the network with Variable ops, whereas the checkpoint files contain the values of the weights at various stages of training (depending on how regularly your session checkpoints during training). These can normally be loaded for inference in a TensorFlow Python API session, during which the weights from the checkpoint files are inserted into the Variable ops in the graph. Yet this is inefficient when just performing inference. A saved Keras .h5 file, on the other hand, is simply the graph and the final-state checkpoint combined, while still keeping the weights stored as Variable ops. (Note: a detailed understanding of the above is not necessary. I add it only for those wanting a more detailed explanation.)

Freezing the model means producing a single file containing the information from both the graph and the checkpoint variables, with the weights saved as constants within the graph structure. This eliminates the additional information saved in the checkpoint files, such as optimizer state, which is included so that the model can be reloaded and training continued from where you left off. As this is not needed when serving a model purely for inference, it is discarded in freezing. A frozen model is a single file of the ProtoBuf .pb file type.
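As a toy illustration of the idea only (plain Python dictionaries, not TensorFlow's actual file formats), freezing can be pictured as replacing every Variable node in the graph with a Const node holding the value read from the checkpoint. The node layout and the `freeze` helper below are hypothetical and exist purely to show the transformation.

```python
def freeze(graph_nodes, checkpoint_values):
    """Return a copy of the graph in which Variable nodes become Const nodes."""
    frozen = []
    for node in graph_nodes:
        if node["op"] == "Variable":
            # Bake the trained value into the graph itself.
            frozen.append({"name": node["name"], "op": "Const",
                           "value": checkpoint_values[node["name"]]})
        else:
            # Ordinary ops (MatMul, Softmax, ...) are copied unchanged.
            frozen.append(dict(node))
    return frozen


graph = [
    {"name": "dense/kernel", "op": "Variable"},
    {"name": "dense/MatMul", "op": "MatMul", "inputs": ["input", "dense/kernel"]},
]
# Training-only entries in the checkpoint are never referenced, so they are dropped.
checkpoint = {"dense/kernel": [[0.1, -0.2]], "Adam/dense/kernel/m": [[0.0, 0.0]]}

frozen_graph = freeze(graph, checkpoint)
print([node["op"] for node in frozen_graph])  # ['Const', 'MatMul']
```

The real tools do the same thing on TensorFlow's graph definition, and additionally prune any node that is not needed to compute the requested outputs.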

Requirements

The requirements for freezing your model for inference are simple. However, depending on your application, you will probably need to install various other packages to actually perform inference:

  • Use Keras with a TensorFlow backend
  • Make sure that TensorBoard is working on your computer. It should work automatically once TensorFlow is installed.
  • Copy the freeze_graph.py Python script from the TensorFlow repository to your working directory. Alternatively, you can use a custom-designed freeze_graph function, which we will see later.

Converting the model

To generate a .pb file containing the necessary information, TensorFlow provides the freeze_graph.py script which, when called, merges the TensorFlow graph and checkpoint files. Using this Python script is often advisable since it was written by the TensorFlow team, who ensure that it works with their in-house file formats. However, in order to use it you must first have a graph and checkpoint files in the correct format. The Keras API does not generate these files automatically, so you will need to retrain the model in order to produce them.

There are two relatively simple ways to go about acquiring a saved trained model in the correct format for freezing. If you are using the Keras API directly, then you will be required to change to the Keras API implemented in a TensorFlow environment.

Note: This should just require a change at the importing stage e.g. instead of from keras.layers import Convolution2D you would have from tensorflow.keras.layers import Convolution2D.

The first method calls the model and then converts it to a TensorFlow Estimator, which will handle the training. The second requires an understanding of how TensorFlow sessions work, as this method trains the network as you would one written in native TensorFlow. After implementing either of these training methods, you will be able to run freeze_graph in order to get the correct output.

Let’s work through an example in order to demonstrate this process. Here we shall convert a simple Keras model for digit classification on the MNIST dataset into an Estimator.

First, we build the model in the Keras API implemented in TensorFlow. Be sure to name your input and output layers, as we will need these names later:

Defining the model, note we do not define an input layer here since the type of input layer depends on which method you choose for training
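The original listing is not reproduced here, so what follows is only a minimal sketch of what such a model body might look like. The function name `build_layers`, the layer sizes and the `flatten`/`hidden` layer names are my own assumptions. The output layer name `output_node` is the one used later in this article. The input layer is deliberately left out and passed in as an argument.

```python
OUTPUT_LAYER_NAME = "output_node"  # name of the final layer, needed again when freezing


def build_layers(inputs):
    """Stack the hidden and output layers on top of a caller-supplied input tensor."""
    # Imported inside the function so that the TensorFlow implementation of Keras
    # is the one used and importing this module stays cheap.
    from tensorflow.keras.layers import Dense, Flatten

    x = Flatten(name="flatten")(inputs)                   # 28x28 image -> 784 vector
    x = Dense(128, activation="relu", name="hidden")(x)   # hypothetical hidden layer
    # Ten-way softmax over the digit classes.
    return Dense(10, activation="softmax", name=OUTPUT_LAYER_NAME)(x)
```

Keeping the input outside the function means the same layers can sit on top of either a Keras `Input` layer (Option 1) or a TensorFlow placeholder (Option 2).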

Option 1: Converting to an Estimator

This method is probably the simpler of the two in general. However, it can quickly become more complex, depending on how you want to train the model, due to the nature of handling custom Estimators. If you are only familiar with training models in the native Keras API then this is the most similar way to train your model. The ability to convert a Keras model into a TensorFlow Estimator was introduced in TensorFlow 1.4 and is described in this tutorial.

To convert the model as written above into an estimator, first compile the model using the normal Keras API implemented in TensorFlow and then use the model_to_estimator() function:

Converting the model into an Estimator, here the checkpoint files and graph will be saved in the model directory
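The original listing is not reproduced here. A minimal sketch of the conversion might look as follows, assuming a TensorFlow release that still ships the Estimator API (it is deprecated in recent versions). The small dense network, the `input_node` name, the optimizer and the loss are illustrative assumptions. `model_to_estimator()` and the `./Keras_MNIST` directory come from this article.

```python
MODEL_DIR = "./Keras_MNIST"  # model directory used in this article


def build_estimator(model_dir=MODEL_DIR):
    """Compile a small Keras model and wrap it as a TensorFlow Estimator."""
    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Flatten, Input
    from tensorflow.keras.models import Model

    inputs = Input(shape=(28, 28), name="input_node")  # hypothetical input name
    x = Flatten()(inputs)
    x = Dense(128, activation="relu")(x)               # hypothetical layer size
    outputs = Dense(10, activation="softmax", name="output_node")(x)

    model = Model(inputs=inputs, outputs=outputs)
    # categorical_crossentropy expects one-hot encoded labels.
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["accuracy"])

    # Checkpoints and graph.pbtxt are written to model_dir once training starts.
    return tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=model_dir)
```

A call such as `estimator_model = build_estimator()` then gives the object referred to below.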

Now estimator_model behaves like any other TensorFlow Estimator and we can train it accordingly. For a guide on how to train an Estimator, see the documentation.

For training, define an input function and train the model accordingly:
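The original listing is not reproduced here. One way to write such an input function, sketched under the assumption that your TensorFlow version lets an input function return a `tf.data.Dataset`, is shown below. The helper name `make_input_fn`, the batch size and the epoch count are hypothetical. The feature key has to match the name of the Keras input layer.

```python
def make_input_fn(images, labels, input_name, batch_size=128, epochs=1):
    """Return an input function that feeds (features, labels) batches to an Estimator."""
    def input_fn():
        # Imported here because the Estimator calls input_fn inside its own graph.
        import tensorflow as tf

        # Features are passed as a dict keyed by the Keras input layer name.
        dataset = tf.data.Dataset.from_tensor_slices(({input_name: images}, labels))
        return dataset.shuffle(len(images)).repeat(epochs).batch(batch_size)

    return input_fn


# Usage sketch (x_train and y_train are assumed to be NumPy arrays, with labels
# encoded the way the compiled loss expects):
#   estimator_model.train(input_fn=make_input_fn(x_train, y_train, "input_node"))
```

With no `steps` argument, `train()` simply runs until the dataset is exhausted, that is, for the number of epochs set above.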

The model is now trained and the graph.pbtxt and checkpoint .ckpt files will be saved in the ./Keras_MNIST model directory.

Option 2: Training like a native TensorFlow model

An alternative approach is to train the model by initiating a TensorFlow session and training within that session. This requires a more detailed understanding of how sessions work, which is beyond the scope of this article, and I refer the reader to this MNIST tutorial as an example of such training. However, unlike in that example, you do not need to write the model in native TensorFlow. Instead we can call our model above and just change the input layer:
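The original listing is not reproduced here. A minimal sketch of the input layer change, assuming TensorFlow 1.x-style graph mode (placeholders cannot be created under eager execution), could look like this. The placeholder name `input_node` and the layer sizes are hypothetical. `tf.compat.v1.placeholder` is the same call as `tf.placeholder` in TensorFlow 1.x.

```python
INPUT_NAME = "input_node"  # hypothetical name for the input placeholder


def build_model_on_placeholder():
    """Build the Keras layers on top of a TensorFlow placeholder (graph mode only)."""
    import tensorflow as tf
    from tensorflow.keras.layers import Dense, Flatten, Input

    # A native TensorFlow placeholder, fed through feed_dict inside the session.
    images = tf.compat.v1.placeholder(tf.float32, shape=(None, 28, 28), name=INPUT_NAME)

    # The only change to the model: the input layer wraps the placeholder.
    inputs = Input(tensor=images)
    x = Flatten()(inputs)
    x = Dense(128, activation="relu")(x)  # hypothetical layer size
    predictions = Dense(10, activation="softmax", name="output_node")(x)
    return images, predictions
```

The returned `predictions` tensor can then be plugged into a loss and optimizer exactly as the output of a native TensorFlow model would be.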

The rest can be followed from the tutorial.

Freezing the model

Now that the model has been trained and the graph and checkpoint files made we can use TensorFlow’s freeze_graph.py to merge these together.

Note: Make sure the freeze_graph.py is in the same directory as the checkpoint and graph files you’d like to freeze.
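As a rough sketch of what calling the script involves, the helper below assembles the command line and runs it from inside the model directory. The flags are those defined by freeze_graph.py. The helper names and the default output file name `frozen_model.pb` are my own assumptions, as is the checkpoint name in the usage note, which depends on how many steps you trained for.

```python
import subprocess
import sys


def freeze_graph_command(checkpoint_name, output_node_names,
                         graph_name="graph.pbtxt", output_name="frozen_model.pb"):
    """Build the argument list for TensorFlow's freeze_graph.py script."""
    return [
        sys.executable, "freeze_graph.py",
        "--input_graph=" + graph_name,            # text-format graph written during training
        "--input_checkpoint=" + checkpoint_name,  # checkpoint prefix, without file suffix
        "--output_graph=" + output_name,          # where the frozen .pb file is written
        "--output_node_names=" + ",".join(output_node_names),
    ]


def run_freeze_graph(model_dir, checkpoint_name, output_node_names):
    """Run the script from inside model_dir, where freeze_graph.py has been copied."""
    subprocess.run(freeze_graph_command(checkpoint_name, output_node_names),
                   cwd=model_dir, check=True)
```

For the example in this article the call would be something like `run_freeze_graph("./Keras_MNIST", "model.ckpt-1000", ["output_node/Softmax"])`, where the step number 1000 is a placeholder and the output node name is the one identified further down.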

Alternatively, I find the simplified version developed by Morgan to be less prone to flagging errors.

Simplified freeze_graph implementation by Morgan

This function starts a temporary session with a fresh graph in order to restore the most recent checkpoint. Another advantage of using this implementation is that you don’t have to make sure you’re specifying the correct checkpoint file, or handle the more syntactically dense freeze_graph.py inputs.
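The original listing is not reproduced here. The sketch below follows the behaviour just described but is not Morgan's exact code: it is my own reconstruction written against `tf.compat.v1` (in TensorFlow 1.x the same calls are available without the `compat.v1` prefix). The helper `frozen_graph_path` and the file name `frozen_model.pb` are assumptions.

```python
import os


def frozen_graph_path(checkpoint_path, filename="frozen_model.pb"):
    """Place the frozen graph next to the checkpoint it was built from."""
    return os.path.join(os.path.dirname(checkpoint_path), filename)


def freeze_graph(model_dir, output_node_names):
    """Freeze the most recent checkpoint in model_dir into a single .pb file."""
    # Imported lazily: TensorFlow is slow to import and only needed here.
    import tensorflow as tf
    tf1 = tf.compat.v1

    # Prefix of the most recent checkpoint, found from the "checkpoint" index file.
    checkpoint_path = tf.train.latest_checkpoint(model_dir)
    if checkpoint_path is None:
        raise ValueError("No checkpoint found in " + model_dir)

    graph = tf.Graph()
    with graph.as_default():
        # Rebuild the graph from the .meta file saved alongside the checkpoint.
        saver = tf1.train.import_meta_graph(checkpoint_path + ".meta", clear_devices=True)

    with tf1.Session(graph=graph) as sess:
        saver.restore(sess, checkpoint_path)
        # Replace each Variable with a constant holding its restored value,
        # keeping only the nodes needed to compute the requested outputs.
        frozen = tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), output_node_names)

    output_path = frozen_graph_path(checkpoint_path)
    with open(output_path, "wb") as f:
        f.write(frozen.SerializeToString())
    return output_path
```

Here `output_node_names` is a list of node names, so a call takes the form `freeze_graph("./Keras_MNIST", ["<output node name>"])`.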

We need to know the name of the output node to give as a reference point to the function. This is the same whether or not we use the simplified version of freeze_graph. The name you gave to the final layer is not enough here. Instead, open TensorBoard to get a visualisation of the graph.

TensorBoard output showing input to ‘metrics’ node

In this example we want the tensor which feeds into the ‘metrics’ node, called output_node/Softmax.
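If you want to cross-check what TensorBoard shows, the node names can also be read straight out of the text-format graph.pbtxt file. The sketch below relies on the layout of that text format, in which each node block begins with `node {` followed by its `name` field. The helper `node_names` and the sample text are hypothetical.

```python
import re

# In the text format, each graph node starts with "node {" followed by its name.
NODE_NAME = re.compile(r'^node \{\n\s+name: "([^"]+)"', re.MULTILINE)


def node_names(pbtxt_text, suffix=""):
    """Return node names from a text-format graph, optionally filtered by suffix."""
    return [name for name in NODE_NAME.findall(pbtxt_text) if name.endswith(suffix)]


# A tiny hand-written stand-in for the contents of graph.pbtxt.
sample = '''node {
  name: "input_node"
  op: "Placeholder"
}
node {
  name: "output_node/Softmax"
  op: "Softmax"
  input: "output_node/BiasAdd"
}
'''
print(node_names(sample, suffix="Softmax"))  # ['output_node/Softmax']
```

On a real model you would pass in the contents of the graph.pbtxt file from the model directory instead of `sample`.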

Running the freeze function with this output node name will now generate a frozen graph in the model directory.

Key takeaways

  1. Keras models can be trained in a TensorFlow environment or, more conveniently, turned into an Estimator with little syntactic change.
  2. To freeze a model you first need to generate the checkpoint and graph files on which you can call freeze_graph.py or the simplified version above.
  3. There are many issues flagged on the TensorFlow and Keras GitHub repositories, as well as on Stack Overflow, about freezing models. A large number of them can be resolved by understanding which files need to be generated and how to specify the output node.
  4. The output node can be easily found from the TensorBoard visualisation but it is NOT always just the name of the final layer you specified.

--

PhD Researcher at the Institute of Particle Physics Phenomenology, Durham University, specialising in Data Intensive Science/ Machine Learning