
Write a custom training routine for your Keras model

When simplicity and ease of use start holding you back


Photo by Barn Images on Unsplash

For people getting started with deep learning, the Keras toolbox has no equal. It has everything you need, with the confusing low-level details kept to a minimum. The API is intuitive and lets you focus on the important parts of designing a network, allowing for fast experimentation without much hassle. As an example, the network used in this guide is specified and trained in fewer than 25 lines of Python code.

There will come times, however, when the ease of use of basic Keras functions becomes limiting. Many more advanced training schemes and loss functions become unnecessarily complicated to code up in native Keras. In this guide, I aim to show how the basic Keras way of training a neural network can be broken up into its underlying parts, opening up the possibility to change each part as you see fit. I do not incorporate any custom parts in the example; this guide only aims to give you the tools to experiment more by yourself.


For this guide, I will be using the FashionMNIST dataset to set up and show two different ways of training a neural network to classify pictures of different clothing items. The dataset contains 70,000 images from ten object classes, with a pre-defined split into 60,000 training images and 10,000 validation images. This example is adapted from the Keras tutorials, where you can find many more interesting tutorials.

Some example images of the FashionMNIST data set

Classical Keras

Building a neural network using Keras is super straightforward. You define it layer by layer, specifying the layers’ properties as you go along. These properties are the number of convolutional filters, the kernel size, activation function, etc. The example model consists of two blocks of 3×3 convolutions, with a 4×4 maxpooling operation in between. The final feature map is run through a global average pooling operation, and a final dense layer with softmax activation provides the class predictions. The function below builds the neural network and returns a tf.keras.Model object that contains all the layers. This object is used later for training and testing the neural network.
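
A minimal sketch of what such a builder function might look like is given below; the filter counts (32 and 64) and the use of the functional API are illustrative assumptions on my part, not the article’s exact code.

```python
import tensorflow as tf

def build_model():
    # Two blocks of 3x3 convolutions with a 4x4 max-pooling operation in
    # between, followed by global average pooling and a softmax classifier
    # over the ten FashionMNIST classes.
    inputs = tf.keras.Input(shape=(28, 28, 1))
    x = tf.keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs)
    x = tf.keras.layers.MaxPooling2D(pool_size=4)(x)
    x = tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu")(x)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

model = build_model()
```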

Before the model can be trained, Keras requires us to specify some details of the training process, such as the optimizer and the loss function. For this example, we also tell Keras to track the network’s accuracy during training.
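
A sketch of that compilation step follows; the choice of the Adam optimizer and sparse categorical cross-entropy is an assumption for illustration.

```python
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],  # track accuracy during training
)
```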

The real magic happens now, with the training of the network. Training a neural network involves minimizing the loss function over a large dataset of training examples. This loss function is minimized using small batches of examples taken from the larger dataset: the loss is computed for each batch, and for every batch the network’s parameters are updated with a small step of a gradient descent algorithm. Keras handles all of this with a single call to the ‘fit’ function, given the proper arguments.
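
A sketch of that call, assuming FashionMNIST is loaded through ‘tf.keras.datasets’ and using the variable names referred to in the text:

```python
# Load FashionMNIST, add a channel dimension, and scale pixels to [0, 1].
(x_train, y_train), (x_val, y_val) = tf.keras.datasets.fashion_mnist.load_data()
x_train = (x_train[..., None] / 255.0).astype("float32")
x_val = (x_val[..., None] / 255.0).astype("float32")

history = model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=10,
    validation_data=(x_val, y_val),
)
```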

This tells Keras to train our network on the training dataset ‘x_train’ with corresponding labels ‘y_train’. Each batch contains 64 images. Our network will train for 10 epochs, i.e. take 10 passes over the full training dataset. At the end of each epoch, performance is computed on the validation data, allowing us to monitor the network’s generalization during training.

By default, Keras shows you valuable information about the training process while it is running, such as the loss and any metrics you told it to track, and tells you how fast it is working through your dataset.

The custom way

If you think about it, the ‘fit’ function takes care of a lot of details for you. It composes the batches, computes the loss function, deduces which direction in parameter space we should move in, and tracks validation performance. For most use cases, this is all you need. These details do not change much from use case to use case, and leaving them to the API frees up time to spend tuning and tinkering with what matters. When the details do change, TensorFlow provides enough tools; you will, however, need to do more of the work yourself.

First, there is the batching of the dataset. TensorFlow comes with two ways of tackling this: the ‘tf.data’ API and the ‘tf.keras.utils.Sequence’ class. This example uses the latter. ‘tf.data’ has the potential to offer better performance, but for many of the "custom" training routines I have written myself, the ‘Sequence’ class was easier to use. You have to create your own subclass that implements several methods for use during training (a minimal sketch follows the list):

  • __init__, to initialise an object of the subclass
  • __len__, to specify the number of batches
  • __getitem__, to code up the instructions for drawing a batch from the full training set (this is also where you often perform some form of data augmentation)
  • on_epoch_end, which can be called at the end of a training epoch to, for example, shuffle the data so that the ordering of images changes for the next epoch
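
A minimal ‘Sequence’ subclass along these lines might look as follows; the class name and the shuffling strategy are assumptions for illustration.

```python
import math
import numpy as np
import tensorflow as tf

class FashionSequence(tf.keras.utils.Sequence):
    def __init__(self, images, labels, batch_size=64, shuffle=True):
        # Store the full dataset and an index array used to draw batches.
        self.images = images
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(images))

    def __len__(self):
        # Number of batches per epoch.
        return math.ceil(len(self.images) / self.batch_size)

    def __getitem__(self, idx):
        # Draw one batch; data augmentation would go here.
        batch = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.images[batch], self.labels[batch]

    def on_epoch_end(self):
        # Reshuffle so batches differ between epochs.
        if self.shuffle:
            np.random.shuffle(self.indices)
```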

Next up is the set-up for the actual training. You have to specify your optimizer and get an instance of your loss function. You probably also want to initialize some bookkeeping variables; here, I track the loss and accuracy for the training and validation datasets. You can view this as writing your own alternative to the Keras ‘compile’ step.
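
A sketch of this set-up, again assuming Adam and sparse categorical cross-entropy, with Keras metric objects serving as the bookkeeping variables:

```python
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()

# Bookkeeping: running loss and accuracy for training and validation.
train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss = tf.keras.metrics.Mean()
val_acc = tf.keras.metrics.SparseCategoricalAccuracy()
```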

Finally, we arrive at the key step: training the network. TensorFlow allows us to use the same model built with the Keras API for the custom training loop. Everything else, however, will change. Instead of one single function call, training now requires two nested for loops. The outer loop tracks the epochs, and the inner loop iterates over batches. For every batch iteration step, we generate the corresponding batch using our custom ‘Sequence’ subclass. The forward pass of this batch through the network is watched by TensorFlow using ‘tf.GradientTape’, so we can later use the gradient of the loss to determine the necessary changes to the network parameters. At the end of every epoch, the average training loss and accuracy are stored in our bookkeeping variables. This is also the time to determine the network’s performance on the validation data.
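
A sketch of such a training loop, built on the pieces above; the batch size, epoch count, and the way results are reported are illustrative choices.

```python
train_seq = FashionSequence(x_train, y_train, batch_size=64)

for epoch in range(10):
    # Inner loop: iterate over the batches produced by the Sequence.
    for step in range(len(train_seq)):
        images, labels = train_seq[step]
        with tf.GradientTape() as tape:
            # Forward pass is recorded so gradients can be computed.
            predictions = model(images, training=True)
            loss = loss_fn(labels, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss.update_state(loss)
        train_acc.update_state(labels, predictions)
    train_seq.on_epoch_end()

    # Validation pass at the end of the epoch.
    val_predictions = model(x_val, training=False)
    val_loss.update_state(loss_fn(y_val, val_predictions))
    val_acc.update_state(y_val, val_predictions)

    print(f"epoch {epoch + 1}: "
          f"loss {float(train_loss.result()):.3f}, "
          f"acc {float(train_acc.result()):.3f}, "
          f"val_loss {float(val_loss.result()):.3f}, "
          f"val_acc {float(val_acc.result()):.3f}")

    # Reset the running metrics for the next epoch.
    for metric in (train_loss, train_acc, val_loss, val_acc):
        metric.reset_state()
```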

And that is all there is to it: all the steps needed to pull apart the different pieces of the training process of a neural network. Writing custom training loops allows you to more easily track various exotic performance measures, or to build loss functions that rely on information outside your direct batch of training inputs and labels. Consider it an extension to your neural network toolbox.

For completeness, you can find a Jupyter notebook below that contains all the different steps in a single package to test everything for yourself.

