
A Short & Practical Guide to Callbacks in Neural Network Training

Custom Callbacks, LR Reduction, Checkpoints, Termination, & More

Andre Ye · Published in Analytics Vidhya · 5 min read · Aug 22, 2020

Callbacks are an important part of neural network training. These are actions that can be performed at various stages of training — perhaps in between epochs, after a batch is processed, or even if a condition is met. Callbacks can be used in many creative ways to improve training and performance, conserve computational resources, and give insights on what is happening inside the neural network.

This article will detail the rationale and code for important callbacks, as well as the process of creating custom callbacks.

ReduceLROnPlateau is a callback included by default in Keras. The learning rate scales the size of each gradient update; too high a learning rate causes the optimizer to overshoot an optimum, while too low a learning rate makes training take far too long. It’s hard to find a static, unchanging learning rate that works well.

As its name suggests, ‘Reduce Learning Rate On Plateau’ reduces the learning rate when the loss metric stops improving, or reaches a plateau. When learning stagnates, models often benefit from reducing the learning rate by a factor of 2 to 10, which helps home in on the optimal parameter values.

To use ReduceLROnPlateau, a callback object must first be created. Four parameters are important: monitor, the metric to watch for plateauing; factor, the multiplier applied to the current learning rate when it is reduced; patience, the number of plateaued epochs the callback waits before activating; and min_lr, the minimum learning rate it can be reduced to, which prevents unnecessary and unhelpful decreases.
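A minimal sketch of creating the callback with the four parameters above (the specific values and the choice of `val_loss` are illustrative, not prescriptive):

```python
from tensorflow import keras

# Values here are illustrative; tune them for your own model.
reduce_lr = keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',  # metric to watch for a plateau
    factor=0.2,          # new_lr = old_lr * factor
    patience=5,          # epochs of no improvement before reducing
    min_lr=1e-6,         # floor the learning rate never drops below
)
```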

Then, when the model is fit, a callbacks parameter can be specified. Note that this can take in a list, so multiple callbacks can be scheduled.
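For instance, with a toy model (the dataset, architecture, and epoch count below are placeholders purely for illustration), the callback is passed to fit in a list:

```python
import numpy as np
from tensorflow import keras

# Tiny synthetic dataset and model, purely for illustration.
x = np.random.rand(64, 8)
y = np.random.rand(64, 1)

model = keras.Sequential([
    keras.layers.Input(shape=(8,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.2,
                                              patience=5, min_lr=1e-6)

# callbacks accepts a list, so several callbacks can run side by side.
history = model.fit(x, y, epochs=2, verbose=0, callbacks=[reduce_lr])
```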

LearningRateScheduler is an alternative to ReduceLROnPlateau, and allows the user to schedule the learning rate based on the epoch. If you know, perhaps from previous research or experiments, that the network should have a learning rate of x from epochs 1–10, y from epochs 10–20 (arbitrary numbers), LearningRateScheduler can help implement these changes.

Creating a learning rate scheduler requires a user-defined function that takes in the epoch and learning rate as parameters. The return object should be the new learning rate. There’s lots of freedom to experiment with values of the learning rate and the current epoch.
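One possible scheduler function, with an arbitrary schedule chosen only to show the shape of the function (take the epoch and current learning rate in, return the new learning rate):

```python
def scheduler(epoch, lr):
    # Illustrative schedule: hold the initial LR for the first
    # 10 epochs, then decay it by 10% each subsequent epoch.
    if epoch < 10:
        return lr
    return lr * 0.9
```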

Then, after it is converted into a Keras callback, it can be used in the training of the model. These schedulers can be very helpful and allow for control over the network, but it’s advised to use ReduceLROnPlateau at least the first time a network is trained, since it’s more adaptive. Then, visualizing the model can give ideas on how to construct an adequate LR scheduler.

Additionally, you could use ReduceLROnPlateau and LearningRateScheduler simultaneously, for instance, using the scheduler to hard-code some learning rate requirements (such as no changes before the tenth epoch) while using the adaptive power of reducing the learning rate based on a plateau in performance.
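A sketch of that combination, with an assumed hard-coded rule (freeze the LR at 0.001 before the tenth epoch, then defer to ReduceLROnPlateau); the model and data are again synthetic placeholders:

```python
import numpy as np
from tensorflow import keras

def hold_then_release(epoch, lr):
    # Hard-coded requirement: no LR changes before the tenth epoch;
    # afterwards, leave the LR alone so ReduceLROnPlateau can adapt it.
    return 0.001 if epoch < 10 else lr

x, y = np.random.rand(64, 8), np.random.rand(64, 1)
model = keras.Sequential([
    keras.layers.Input(shape=(8,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

callbacks = [
    keras.callbacks.LearningRateScheduler(hold_then_release),
    keras.callbacks.ReduceLROnPlateau(monitor='loss', factor=0.5, patience=3),
]
history = model.fit(x, y, epochs=2, verbose=0, callbacks=callbacks)
```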

EarlyStopping can be very helpful for preventing redundant runtime while training a model, which can run up high costs. When the network has not improved for a given number of epochs, training stops and computational resources are freed. Like ReduceLROnPlateau, early stopping requires a metric to monitor and a patience value.
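A minimal sketch (the patience value is arbitrary; restore_best_weights is an optional extra that rolls the model back to its best epoch):

```python
from tensorflow import keras

# Stop once val_loss has failed to improve for 10 straight epochs;
# restore_best_weights rolls the model back to its best epoch.
early_stop = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)
```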

TerminateOnNaN helps guard against exploding-gradient problems in training, where the loss goes to NaN and the rest of the network follows. Keras does not stop the network from training when this happens, so the remaining epochs simply waste computing power on a broken model. Adding TerminateOnNaN is a good safeguard against this.
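It takes no parameters; just add it to the callbacks list:

```python
from tensorflow import keras

# Halts training the moment a batch produces a NaN (or infinite) loss.
nan_guard = keras.callbacks.TerminateOnNaN()
```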

ModelCheckpoint saves the weights of the model at some frequency (perhaps every 10 or so epochs), which is useful for several reasons:

  • If the platform the model is being trained on suddenly cuts out, the model will not need to be retrained from scratch.
  • If, say, in epoch 30, the model begins to show signs of overfitting or another problem like exploding gradients, we can reload the model with the most recently saved weights (say at epoch 25) and adjust the parameters such that the problem is avoided, without needing to redo most of the training.
  • Being able to extract the weights at some epoch and reload them into another model could be beneficial for transfer learning.

In the scenario below, ModelCheckpoint is used to store the weights of the model with the best performance. At each epoch, if the model performed better than in any previously recorded epoch, its weights are stored in a file (overwriting the previous best). At the end of training, we load the weights of the best model using model.load_weights.
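A sketch of that best-weights workflow, assuming a toy model and an illustrative filename (`best.weights.h5`):

```python
import numpy as np
from tensorflow import keras

x, y = np.random.rand(64, 8), np.random.rand(64, 1)
model = keras.Sequential([
    keras.layers.Input(shape=(8,)),
    keras.layers.Dense(16, activation='relu'),
    keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

# Overwrite the file only when the monitored metric improves.
checkpoint = keras.callbacks.ModelCheckpoint(
    'best.weights.h5',        # illustrative filename
    monitor='loss',
    save_best_only=True,
    save_weights_only=True,
)
model.fit(x, y, epochs=2, verbose=0, callbacks=[checkpoint])

# At the end of training, reload the best weights.
model.load_weights('best.weights.h5')
```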

Alternatively, if you want frequency-based saving rather than metric-based saving, set save_best_only=False and use save_freq. Note that in tf.keras, save_freq is measured in batches, not epochs, so saving every 5 epochs means passing save_freq=5 * steps_per_epoch (the older period argument, which counted epochs, is deprecated).
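For instance, assuming 64 training samples with a batch size of 32 (so 2 steps per epoch):

```python
from tensorflow import keras

# Assumption: batch_size 32 over 64 samples -> 2 steps per epoch.
steps_per_epoch = 2

# save_freq counts *batches*, so multiply by steps per epoch
# to checkpoint every 5 epochs, regardless of the metric.
periodic = keras.callbacks.ModelCheckpoint(
    'epoch_{epoch:02d}.weights.h5',  # illustrative filename pattern
    save_weights_only=True,
    save_best_only=False,
    save_freq=5 * steps_per_epoch,
)
```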

Writing custom callbacks is one of the best features Keras includes, and it allows for highly specific actions. Note, however, that constructing one is more work than using a default callback.

Our custom callback will take the form of a class. Similarly to constructing a neural network in PyTorch, we can inherit from keras.callbacks.Callback, which is a base callback class. Our class can have many functions — which must have the given names listed below — indicating when those functions will be run. For instance, the function on_epoch_begin will be run at the start of each epoch. Below are all the functions Keras will read from your custom callback, but additional ‘helper’ functions can be added.
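A skeleton showing those reserved hook names (the class name is arbitrary; the bodies are left empty so you can fill in whichever hooks you need):

```python
from tensorflow import keras

class MyCallback(keras.callbacks.Callback):
    # Hook names Keras recognizes; each receives a `logs` dict of metrics.
    def on_train_begin(self, logs=None): pass
    def on_train_end(self, logs=None): pass
    def on_epoch_begin(self, epoch, logs=None): pass
    def on_epoch_end(self, epoch, logs=None): pass
    def on_train_batch_begin(self, batch, logs=None): pass
    def on_train_batch_end(self, batch, logs=None): pass
    def on_test_begin(self, logs=None): pass
    def on_test_end(self, logs=None): pass
    def on_test_batch_begin(self, batch, logs=None): pass
    def on_test_batch_end(self, batch, logs=None): pass
    def on_predict_begin(self, logs=None): pass
    def on_predict_end(self, logs=None): pass
    def on_predict_batch_begin(self, batch, logs=None): pass
    def on_predict_batch_end(self, batch, logs=None): pass
```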

Depending on the function, you will have access to different variables. For instance, in on_epoch_begin, the function has access both to the epoch number and to a dictionary of current metrics, logs. If additional information, like the learning rate, is desired, look into Keras’ backend support with keras.backend.get_value.
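For example, a small sketch of a callback that reads the learning rate this way (the class name is made up; self.model is attached automatically once the callback is passed to fit):

```python
from tensorflow import keras

class LrLogger(keras.callbacks.Callback):
    """Records the optimizer's learning rate at the start of each epoch."""

    def __init__(self):
        super().__init__()
        self.lrs = []

    def on_epoch_begin(self, epoch, logs=None):
        # self.model is attached automatically when the callback is in use.
        lr = keras.backend.get_value(self.model.optimizer.learning_rate)
        self.lrs.append(float(lr))
```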

Then, the callback can be treated like any other callback.

Some ideas for custom callbacks:

  • Record the results of training in a JSON or CSV file.
  • Email the results of training every ten epochs.
  • Add more complex functionality in deciding when to save model weights.
  • Train a simple machine learning model (perhaps with sklearn) to learn when to increase or decrease the learning rate by setting it as a class variable and taking data in the form of (x: action, y: change in performance).
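As a sketch of the first idea, a callback that records each epoch’s metrics to a JSON file (the class name and default path are illustrative):

```python
import json
from tensorflow import keras

class JsonHistory(keras.callbacks.Callback):
    """Appends each epoch's metrics to a JSON file (path is illustrative)."""

    def __init__(self, path='history.json'):
        super().__init__()
        self.path = path
        self.records = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Store the epoch number alongside every metric Keras reports.
        self.records.append({'epoch': epoch,
                             **{k: float(v) for k, v in logs.items()}})
        with open(self.path, 'w') as f:
            json.dump(self.records, f, indent=2)
```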

When callbacks are used in neural networks, the creator’s control increases and the neural network becomes more tamable.
