
Why You Should Use Callbacks in TensorFlow 2

Customize your training of deep neural networks – a practical guide

Photo by John Schnobrich on Unsplash

Callbacks are essential when you want to control the training of a model.

And you do want to control the training…

Callbacks help us prevent overfitting, visualize our training progress, save checkpoints and much more.

TensorFlow

But why TensorFlow?

TensorFlow is the preferred deep learning framework at many companies around the world.

Maintained and heavily used internally by Google, it is state-of-the-art technology.

One of the reasons for this success is undoubtedly the huge ecosystem of tools and libraries that gives companies and individuals easy access to research and technology for their products and systems.

Of course, it helps that it is a Python library since Python is the most used programming language for Data Science.

Today we are taking a look at one particularly useful tool, one that will make your life considerably easier if used correctly.

Building a Model

Callbacks are hooks that you can plug into your model’s training loop.

You can use them, for example, to stop the training once the accuracy gets high enough, or when your loss curve begins to flatten out.

In this article, I will not focus much on the actual model but rather on how we can control the training process through callbacks.

I assume that the reader knows how to build a simple model in TensorFlow, so I won’t go into detail about the code.

Let’s warm up by coding a simple CNN that predicts hand-written digits.

Let’s import some libraries and load the data. We will use the MNIST dataset here.
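Something along these lines will do (a minimal sketch; loading MNIST via tf.keras.datasets is an assumption about the original setup):

import tensorflow as tf

# Load the MNIST dataset: 60,000 training and 10,000 test images of digits
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Reshape to (samples, height, width, channels) and normalize to [0, 1]
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0

# One-hot encode the targets, e.g. 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)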

The above is pretty standard stuff. We load in the (more or less) raw data, reshape it and normalize it by dividing all the pixel values by the maximum value of 255.

Note that when you divide a NumPy array by a number, you divide all entries by that number.

Next, we one-hot encode the targets, which basically means converting the number n to a vector with zeros everywhere except at the n-th entry.

Let us define the model.
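A sketch along these lines matches the description below (the exact filter counts and dropout rates are my own choices):

model = tf.keras.Sequential([
    # First conv-pool pair
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Dropout(0.25),

    # Second conv-pool pair
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Dropout(0.25),

    # Dense head with ReLU and softmax activations
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])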

I won’t go into details about convolutional neural networks, but this is a pretty standard one, using pairwise conv-pool layers with some dropout layers in between to avoid overfitting.

After the convolutional layers have generated features, these are fed into a dense neural network with ReLU and softmax activations, respectively.

The last softmax layer outputs 10 values corresponding to the one-hot encoded targets.

Instead of just training the model as is, we want to add a callback.

Callbacks

In TensorFlow, there are two ways of using callbacks. We can use the built-in callbacks available out of the box or we can build a callback ourselves.

I’ll show you how to do both via some examples.

LearningRateScheduler

Let’s say that we want to have more control over the learning rate during training.

We know that in the pursuit of reaching the minimum of the loss function using gradients, we might take jumps that are too wide once we get close to it.

We can adjust these jumps by changing the learning rate while training!

We do that using the LearningRateScheduler available in tensorflow.keras.callbacks.

Take a look at the following snippet.
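A sketch of what that can look like (the exponential decay factor is my own choice):

def scheduler(epoch, lr):
    # Keep the initial learning rate for the first 3 epochs,
    # then let it decay exponentially (decay factor assumed)
    if epoch < 3:
        return lr
    return lr * tf.math.exp(-0.1)

lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

callbacks = [lr_scheduler]
history = model.fit(x_train, y_train, epochs=10, callbacks=callbacks)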

In this way, we adjust the learning rate after 3 epochs.

But still, we don’t know when we reach an acceptable accuracy. What if we want to stop the training after reaching, say, an accuracy of 98%?

Custom callback

Let us build a custom callback class for this scenario.
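A minimal sketch, assuming we monitor the training accuracy that Keras logs each epoch (the class name is my own):

class AccuracyThresholdCallback(tf.keras.callbacks.Callback):
    """Stop training once the accuracy reaches 98%."""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        if logs.get('accuracy', 0) >= 0.98:
            print(f"\nReached 98% accuracy after epoch {epoch + 1}, stopping training.")
            self.model.stop_training = True

callbacks.append(AccuracyThresholdCallback())
history = model.fit(x_train, y_train, epochs=10, callbacks=callbacks)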

Note that I have appended the custom callback to the callbacks list, which now contains two callbacks. This is totally fine.

If you run this code, you will see the training stop as soon as the accuracy crosses the 98% threshold.

Very nice!

When using custom callbacks you can specify where in the training you want to tap in and change things.

We can do this by using the functions on_train_begin, on_epoch_begin, on_epoch_end, on_train_end.
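For instance, here is a tiny, purely illustrative sketch of a callback that taps into several of these hooks (the TrainingLogger name is hypothetical):

class TrainingLogger(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        print("Starting training...")

    def on_epoch_begin(self, epoch, logs=None):
        print(f"Starting epoch {epoch + 1}")

    def on_epoch_end(self, epoch, logs=None):
        print(f"Finished epoch {epoch + 1}, loss: {logs['loss']:.4f}")

    def on_train_end(self, logs=None):
        print("Training finished.")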

EarlyStopping

Most often we don’t have a specific metric target; rather, we want to avoid overfitting.

One way to do that is to use the built-in EarlyStopping callback, which stops the training once a monitored metric, by default the validation loss, has stopped improving for a given number of epochs (the patience).
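A minimal sketch, assuming validation data is passed to fit:

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',  # watch the validation loss
    patience=1,          # allow 1 epoch without improvement before stopping
)

history = model.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    epochs=20,
                    callbacks=[early_stopping])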

ModelCheckpoint

It is fine to use the EarlyStopping callback, but it sucks if the last epochs before training stops have already messed up, through overfitting, the weights that we fought so hard to get.

Here’s where ModelCheckpoint comes into play. Basically, this callback lets you save a version of the model after each epoch so that you can always go back to the one with the highest validation accuracy, for example.

The syntax is similar to the other ones: tf.keras.callbacks.ModelCheckpoint. You need to know where you want to put the saved models and you can specify that using the filepath parameter.

You could for instance do something like:

checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.h5')

saving the models in the current folder. A wiser choice is probably to save the checkpoints in a dedicated folder called e.g. checkpoints, as in the sketch below.
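A sketch of that (the folder name and the choice to keep only the best model are assumptions on my part):

import os

os.makedirs('checkpoints', exist_ok=True)

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath='checkpoints/model.{epoch:02d}-{val_loss:.2f}.h5',
    monitor='val_accuracy',  # track validation accuracy
    save_best_only=True,     # only overwrite when the model improves
)

history = model.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    epochs=10,
                    callbacks=[checkpoint])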

TensorBoard

The last callback I will show you is TensorBoard.

TensorBoard is used to visualize your training progress, which is very handy when experimenting.
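To feed it data, you attach a TensorBoard callback that writes logs to a folder during training. A minimal sketch (the folder name logs matches the --logdir flag used below):

tensorboard = tf.keras.callbacks.TensorBoard(log_dir='logs', histogram_freq=1)

history = model.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    epochs=10,
                    callbacks=[tensorboard])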

If you start the training and then, from the terminal or cmd, cd into the folder where your model is located and type tensorboard --logdir='logs', the terminal will guide you to http://localhost:6006/.

There you’ll see the training progress like the one below.

Image by author

Conclusion

If you have followed along but forgot to code while reading, I have collected the code for you here:
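The following is a consolidated sketch assembled from the snippets above (epoch counts, layer sizes, and the decay factor remain my own choices):

import tensorflow as tf

# Load and preprocess MNIST
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

# Build a simple CNN
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Learning rate schedule: constant for 3 epochs, then exponential decay
def scheduler(epoch, lr):
    return lr if epoch < 3 else lr * tf.math.exp(-0.1)

# Custom callback: stop as soon as training accuracy reaches 98%
class AccuracyThresholdCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        if (logs or {}).get('accuracy', 0) >= 0.98:
            self.model.stop_training = True

callbacks = [
    tf.keras.callbacks.LearningRateScheduler(scheduler),
    AccuracyThresholdCallback(),
    tf.keras.callbacks.TensorBoard(log_dir='logs'),
]

# Train; the returned History object stores the metrics per epoch
history = model.fit(x_train, y_train,
                    validation_data=(x_test, y_test),
                    epochs=10,
                    callbacks=callbacks)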

Note that we don’t really use the output from the training, but we store it in history. You are welcome to code on and save the model, plot the metrics and so on.

There are many more callbacks than the ones I have introduced here, but we have to stop at some point. That doesn’t mean the rest aren’t important though.

They are! Take a look at the full list in the tf.keras.callbacks documentation.

To summarize

A callback is a powerful tool for customizing and controlling the training process in your deep learning toolbox.

From debugging and stopping training based on some custom logic to visualizations, callbacks should be every data scientist’s best friend when it comes to model building and experimenting.



If you have any questions, comments or concerns, please reach out to me on LinkedIn:

Kasper Müller – Senior Consultant, Data and Analytics, FS, Technology Consulting – EY | LinkedIn

