tf.keras and TensorFlow: Batch Normalization to train deep neural networks faster

Chris Rawles
Towards Data Science
Feb 26, 2018

Training deep neural networks can be time-consuming. In particular, training can be significantly impeded by vanishing gradients, which occur when a network stops updating because the gradients, particularly in earlier layers, have approached zero. Incorporating Xavier weight initialization and ReLU activation functions helps counter the vanishing gradient problem. These techniques also help with the opposite, yet closely related, issue of exploding gradients, where the gradients become extremely large, making the updates unstable and preventing the model from converging.

Perhaps the most powerful tool for combating vanishing and exploding gradients is Batch Normalization. Batch Normalization works like this: for each unit in a given layer, first compute the z score (subtract the batch mean and divide by the batch standard deviation), and then apply a linear transformation using two trained variables 𝛾 and 𝛽. Batch Normalization is typically applied before the non-linear activation function (see below figure); however, applying it after the activation function can also be beneficial. Check out this lecture for more detail on how the technique works.

During backpropagation, gradients tend to get smaller at lower layers, slowing down weight updates and thus training. Batch Normalization helps combat this so-called vanishing gradient problem.
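The per-unit computation described above can be sketched in NumPy. This minimal illustration of the training-time forward pass assumes a 2-D input of shape (batch, units) and adds a small epsilon to the variance for numerical stability; the helper name batch_norm_forward and the epsilon value are illustrative choices, not TensorFlow's own code.

```python
import numpy as np

def batch_norm_forward(x, gamma, beta, eps=1e-3):
    """Hypothetical helper: training-time Batch Normalization for a (batch, units) array."""
    mean = x.mean(axis=0)                  # per-unit batch mean
    var = x.var(axis=0)                    # per-unit batch variance
    z = (x - mean) / np.sqrt(var + eps)    # z score; eps guards against division by zero
    return gamma * z + beta                # trained scale (gamma) and shift (beta)

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(500, 4))   # 500 examples, 4 units
gamma = np.array([1.0, 2.0, 0.5, 1.0])              # would be learned; fixed here
beta = np.array([0.0, -1.0, 3.0, 10.0])             # would be learned; fixed here

y = batch_norm_forward(x, gamma, beta)
print(y.mean(axis=0).round(3))   # close to beta
print(y.std(axis=0).round(3))    # close to gamma (slightly below it because of eps)
```

Whatever the scale and offset of the incoming values, each unit's output ends up with mean 𝛽 and standard deviation 𝛾, which the network can then learn.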

There are three ways to implement Batch Normalization in TensorFlow, using:

  1. tf.keras.layers.BatchNormalization
  2. tf.layers.batch_normalization
  3. tf.nn.batch_normalization

08/18/2018 update: The DNNClassifier and DNNRegressor now have a batch_norm parameter, which makes it possible and easy to do batch normalization with a canned estimator.

11/12/2019 update: This has gotten even easier with TF 2.0 and tf.keras: you can simply add a BatchNormalization layer and do not need to worry about control_dependencies.
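As a minimal sketch of that TF 2.x style, assuming TensorFlow 2.x with its bundled Keras, the model below places a BatchNormalization layer between a Dense layer and its ReLU activation. The layer sizes and the random input are arbitrary illustrations. Calling the model with training=True normalizes with batch statistics and updates the moving averages directly, so no control_dependencies are involved.

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()   # learns gamma/beta, tracks moving mean/variance

# Dense -> BatchNormalization -> ReLU: normalization before the non-linearity.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(500),             # pre-activation values
    bn,
    tf.keras.layers.Activation("relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

x = np.random.default_rng(0).normal(size=(32, 784)).astype("float32")

# training=True: normalize with batch statistics and update the moving averages.
train_out = model(x, training=True)
# training=False: normalize with the stored moving averages (prediction behaviour).
infer_out = model(x, training=False)

print(train_out.shape)                  # (32, 10)
print(len(bn.trainable_weights))        # 2: gamma and beta
print(len(bn.non_trainable_weights))    # 2: moving_mean and moving_variance
```

For a compiled model, model.fit runs the layers in training mode and model.predict runs them in inference mode automatically, so the flag only has to be passed by hand in custom loops like this one.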

The tf.keras module became part of the core TensorFlow API in version 1.4 and provides a high-level API for building TensorFlow models, so I will show you how to do it in Keras. The tf.layers.batch_normalization function has similar functionality, but Keras often proves to be an easier way to write model functions in TensorFlow.

Note the training variable in the Batch Normalization function. It is required because Batch Normalization behaves differently during training than at inference: during training, the z score is computed using the batch mean and variance, while at inference it is computed using a mean and variance estimated from the entire training set.
In TensorFlow, Batch Normalization can be implemented as an additional layer using tf.keras.layers.
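The effect of the training flag described above can be seen in a small NumPy sketch. It simply passes in assumed training-set statistics instead of estimating them, and the helper batch_norm is hypothetical. In training mode, a batch containing a single example normalizes to all zeros (the example is its own batch mean), whereas inference mode produces a meaningful per-example output; this is why prediction must use the stored statistics.

```python
import numpy as np

def batch_norm(x, gamma, beta, pop_mean, pop_var, training, eps=1e-3):
    """Hypothetical helper contrasting the two modes of Batch Normalization."""
    if training:
        mean, var = x.mean(axis=0), x.var(axis=0)   # statistics of the current batch
    else:
        mean, var = pop_mean, pop_var               # statistics estimated from the training set
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

gamma, beta = np.ones(3), np.zeros(3)
pop_mean = np.array([2.0, 0.0, -1.0])   # assumed training-set means
pop_var = np.array([4.0, 1.0, 0.25])    # assumed training-set variances

single = np.array([[4.0, 1.0, -0.5]])   # one example at prediction time

train_out = batch_norm(single, gamma, beta, pop_mean, pop_var, training=True)
infer_out = batch_norm(single, gamma, beta, pop_mean, pop_var, training=False)
print(train_out)   # all zeros: a batch of one is its own mean
print(infer_out)   # roughly 1 in each unit: one population standard deviation above the mean
```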

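The second code block, with tf.GraphKeys.UPDATE_OPS, is important. When you use tf.keras.layers.BatchNormalization, TensorFlow continually estimates, for each unit in the network, the mean and variance of that unit's values over the training dataset (these are statistics of the values being normalized, not of the weights). The operations that update these running estimates are placed in the tf.GraphKeys.UPDATE_OPS collection; in graph mode they do not run on their own, which is why the training op is made to depend on them through control_dependencies. After training, the stored estimates are used to apply Batch Normalization at prediction time. The training-set mean and variance for each unit can be observed by printing extra_ops, which holds these update ops for every Batch Normalization layer in the network: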

print(extra_ops)
[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(500,) dtype=float32_ref>,    # layer 1 mean values
 <tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(500,) dtype=float32_ref>,  # layer 1 variances
 ...]
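The running estimates behind those AssignMovingAvg ops are exponential moving averages. The NumPy sketch below mimics the update, assuming a momentum of 0.99 (the default in both tf.keras.layers.BatchNormalization and tf.layers.batch_normalization) and a synthetic stream of batches of 500 examples; the moving mean and variance drift toward the true statistics of the data. It illustrates the idea rather than reproducing TensorFlow's exact implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
units, momentum = 3, 0.99
true_mean = np.array([1.0, -2.0, 0.5])    # synthetic "training set" statistics
true_std = np.array([1.0, 3.0, 0.2])

moving_mean = np.zeros(units)             # initial values, as in tf.keras
moving_var = np.ones(units)

for step in range(2000):
    batch = rng.normal(true_mean, true_std, size=(500, units))
    # Keep 99% of the old estimate and blend in 1% of the current batch statistics.
    moving_mean = momentum * moving_mean + (1 - momentum) * batch.mean(axis=0)
    moving_var = momentum * moving_var + (1 - momentum) * batch.var(axis=0)

print(moving_mean.round(2))   # close to true_mean
print(moving_var.round(2))    # close to true_std ** 2
```

A momentum close to 1 averages over many batches, so the estimate is smooth but needs many steps to forget its initial values.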

Batch Normalization is also available in the tf.nn module, but it requires extra bookkeeping because the mean and variance are required arguments of the function. The user therefore has to compute the mean and variance manually, at both the batch level and the training-set level. Because tf.nn sits at a lower level of abstraction than tf.keras.layers or tf.layers, avoid the tf.nn implementation.
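To illustrate that bookkeeping, here is a sketch assuming TensorFlow 2.x in eager mode. The caller computes the batch statistics with tf.nn.moments and creates the scale (gamma) and offset (beta) variables by hand before calling tf.nn.batch_normalization; keeping training-set statistics for inference would be further work left to the caller.

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.random.default_rng(0).normal(3.0, 2.0, size=(500, 4)).astype("float32"))

# With tf.nn, the caller owns every piece of state.
gamma = tf.Variable(tf.ones([4]))             # scale, to be trained by the caller
beta = tf.Variable(tf.zeros([4]))             # offset, to be trained by the caller
mean, variance = tf.nn.moments(x, axes=[0])   # batch statistics, computed manually

y = tf.nn.batch_normalization(x, mean, variance, offset=beta, scale=gamma,
                              variance_epsilon=1e-3)

# The same result written out: gamma * (x - mean) / sqrt(variance + eps) + beta.
expected = gamma * (x - mean) / tf.sqrt(variance + 1e-3) + beta
print(np.abs(y.numpy() - expected.numpy()).max())   # ~0, up to float32 rounding
```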

Batch Normalization on MNIST

Below, I apply Batch Normalization to the prominent MNIST dataset using TensorFlow. Check out the code here. MNIST is an easy dataset to analyze and doesn’t require many layers to achieve low classification error. However, we can still build a deep network and observe how Batch Normalization affects convergence.

Let’s build a custom estimator using the tf.estimator API. First we build the model:

After we define our model function, let’s build the custom estimator and train and evaluate our model:
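The custom estimator code is not reproduced here. As a rough stand-in, the following sketch uses tf.keras (assuming TensorFlow 2.x) instead of tf.estimator: it builds a fully connected classifier with a configurable number of hidden layers and optional Batch Normalization, then trains and evaluates it briefly. The build_model helper, the 500-unit width, the Adam optimizer, and the random data with MNIST's shapes are all illustrative assumptions; tf.keras.datasets.mnist.load_data() can supply the real dataset.

```python
import numpy as np
import tensorflow as tf

def build_model(n_hidden, use_batch_norm, units=500):
    """Hypothetical builder: n_hidden Dense+ReLU layers, optionally with BN before each ReLU."""
    model = tf.keras.Sequential([tf.keras.Input(shape=(784,))])
    for _ in range(n_hidden):
        model.add(tf.keras.layers.Dense(units))
        if use_batch_norm:
            model.add(tf.keras.layers.BatchNormalization())
        model.add(tf.keras.layers.Activation("relu"))
    model.add(tf.keras.layers.Dense(10, activation="softmax"))   # 10 digit classes
    model.compile(optimizer="adam",                               # assumed optimizer
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model

# Random stand-in data with MNIST's shapes: 784 pixel values and an integer label per example.
rng = np.random.default_rng(0)
x_train = rng.random((2000, 784), dtype=np.float32)
y_train = rng.integers(0, 10, size=2000)

model = build_model(n_hidden=7, use_batch_norm=True)
history = model.fit(x_train, y_train, batch_size=500, epochs=2, verbose=0)  # 4 iterations per epoch
results = model.evaluate(x_train, y_train, batch_size=500, verbose=0, return_dict=True)
print(results)
```

Comparing build_model(n, True) with build_model(n, False) for several values of n is the kind of depth sweep the following experiments run at larger scale.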

Let’s test how Batch Normalization impacts models of varying depths. After we wrap our code into a Python package, we can fire off multiple experiments in parallel using Cloud ML Engine:

The plot below shows the number of training iterations (each iteration processes a batch of 500 examples) required to reach 90% test accuracy, an easy target, as a function of network depth. Batch Normalization clearly speeds up training for the deeper networks: without it, the number of training steps increases with each additional layer, but with it, the number of training steps stays nearly constant. In practice, on more difficult datasets, more layers are often a prerequisite for success.

Without Batch Normalization the number of training iterations required to hit 90% accuracy increases with the number of layers, likely due to the vanishing gradient effect.

Similarly, as shown below, for a fully connected network with 7 hidden layers, convergence is significantly slower without Batch Normalization.

The above experiments use the common ReLU activation function. Though clearly not immune to the vanishing gradient effect, as shown above, ReLU fares much better than the sigmoid or tanh activation functions. The vulnerability of the sigmoid activation function to vanishing gradients is fairly intuitive: at large magnitudes (very positive or very negative inputs), the sigmoid function “saturates”, i.e. its derivative approaches zero. When many nodes saturate, the weight updates shrink toward zero and the network stops training.
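The saturation argument can be checked numerically. The sigmoid's derivative is σ'(x) = σ(x)(1 − σ(x)), which peaks at 0.25 at x = 0 and shrinks rapidly as |x| grows. Because backpropagation multiplies roughly one such factor per layer, the sketch also shows how small that product gets across 7 layers even in the best case (it deliberately ignores the weights, a simplification).

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)   # derivative of the sigmoid

for x in (0.0, 2.0, 5.0, 10.0):
    print(f"x = {x:5.1f}   sigmoid'(x) = {sigmoid_grad(x):.2e}")

# Best case: all 7 layers sit at x = 0, where the derivative is largest.
best_case_7_layers = sigmoid_grad(0.0) ** 7
print(f"0.25 ** 7 = {best_case_7_layers:.2e}")   # about 6.1e-05
```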

The same 7-layer network trains significantly more slowly with sigmoid activation functions when Batch Normalization is not used. With Batch Normalization, the sigmoid network converges in a similar number of iterations as when using ReLU.

Alternatively, activation functions such as the exponential linear unit (ELU) or leaky ReLU can help combat the vanishing gradient issue, as they have non-zero derivatives for negative inputs and, like ReLU, do not saturate for large positive inputs.
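Comparing the derivatives makes this concrete. The sketch assumes α = 1.0 for ELU and a negative slope of 0.01 for leaky ReLU; libraries use different defaults. ReLU's gradient is exactly zero for every negative input, leaky ReLU keeps a constant small slope, and ELU's gradient α·eˣ stays positive for negative inputs, although it shrinks as the input becomes very negative.

```python
import math

def relu_grad(x):
    return 1.0 if x > 0 else 0.0                    # zero gradient for all negative inputs

def leaky_relu_grad(x, slope=0.01):
    return 1.0 if x > 0 else slope                  # constant small slope for x <= 0

def elu_grad(x, alpha=1.0):
    return 1.0 if x > 0 else alpha * math.exp(x)    # derivative of alpha * (exp(x) - 1)

for x in (-5.0, -1.0, 0.5, 5.0):
    print(f"x = {x:5.1f}  relu' = {relu_grad(x):.3f}  "
          f"leaky' = {leaky_relu_grad(x):.3f}  elu' = {elu_grad(x):.3f}")
```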

Finally, note that Batch Normalization adds a time cost to training. Although it typically reduces the number of training steps needed to converge, each step takes longer because it introduces an additional operation and two new trained parameters (𝛾 and 𝛽) per unit.
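To put the extra parameters in proportion, the arithmetic below counts them for a hypothetical MNIST classifier with 7 hidden layers of 500 units. The 500-unit width is an assumption based on the shape=(500,) tensors printed earlier, and all hidden layers are assumed to share it. Each Batch Normalization layer adds 𝛾 and 𝛽 per unit (trained) plus a moving mean and moving variance per unit (not trained).

```python
inputs, units, hidden_layers, classes = 784, 500, 7, 10

# Dense weights + biases: input layer, the remaining hidden layers, and the output layer.
dense_params = ((inputs * units + units)
                + (hidden_layers - 1) * (units * units + units)
                + (units * classes + classes))

bn_trainable = hidden_layers * units * 2       # gamma and beta
bn_non_trainable = hidden_layers * units * 2   # moving mean and moving variance

print(dense_params)                      # 1900510
print(bn_trainable, bn_non_trainable)    # 7000 7000
print(f"{bn_trainable / dense_params:.2%} extra trainable parameters")   # 0.37%
```

Under these assumptions the parameter overhead itself is well under 1%, so the slower iterations mostly reflect the extra normalization work done in every step.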

For the MNIST classification problem (using a GTX 1080 GPU), Batch Normalization converges in fewer iterations (top), but each iteration takes longer. Ultimately, the Batch Normalization version still converges faster (bottom), but the improvement is less pronounced when measured in total training time.

Incorporating XLA and fused Batch Normalization (the fused argument in tf.layers.batch_normalization) could help speed up the Batch Normalization operation by combining several individual operations into a single kernel.
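One way to try XLA in TensorFlow 2.x is to compile the forward pass with tf.function(jit_compile=True), as in the sketch below; it assumes TF 2.x and does not touch the fused setting. XLA can fuse the element-wise Batch Normalization steps into fewer kernels, but whether that speeds up a given model depends on the model and hardware, so treat it as something to measure rather than a guaranteed gain.

```python
import numpy as np
import tensorflow as tf

dense = tf.keras.layers.Dense(500)
bn = tf.keras.layers.BatchNormalization()

def forward(x):
    # Dense -> BN (training mode, batch statistics) -> ReLU
    return tf.nn.relu(bn(dense(x), training=True))

# The same computation compiled by XLA.
forward_xla = tf.function(forward, jit_compile=True)

x = tf.constant(np.random.default_rng(0).normal(size=(500, 784)).astype("float32"))
eager_out = forward(x)        # eager call first, so the layers create their variables here
xla_out = forward_xla(x)      # compiled call reuses the same variables
print(float(tf.reduce_max(tf.abs(eager_out - xla_out))))   # small: same math, different kernels
```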

Regardless, Batch Normalization can be a very valuable tool for speeding up the training of deep neural networks. As always with training deep neural networks, the best way to figure out whether an approach will help with your problem is to try it!
