Curse of Batch Normalization

What are some drawbacks of using batch normalization?

Sahil Uppal
Towards Data Science


Photo by Freddie Collins on Unsplash

Batch Normalization is indeed one of the major breakthroughs in the field of deep learning, and it has been one of the hot topics of discussion among researchers in the past few years. It is a widely adopted technique that enables faster and more stable training, and it has become one of the most influential methods in the field. However, despite its versatility, there are still some issues holding this method back, as we are going to discuss in this article, which shows that there’s still room for improvement in normalization methods.

Why do we use Batch Normalization?

Before discussing anything else, we should first know what batch normalization is, how it works, and what its use cases are.

What Batch Normalization is

During training, the output distribution of each intermediate activation layer shifts at each iteration as we update the weights of the preceding layers. This phenomenon is referred to as internal covariate shift (ICS). So a natural thing to do, if I want to prevent this from happening, is to fix all the distributions. In simple words, if my distributions keep shifting around, I’ll just clamp them and not let them shift, which helps gradient-based optimization, prevents vanishing gradients, and helps my neural network train faster. Reducing this internal covariate shift was the key principle driving the development of batch normalization.

How it works

Batch Normalization normalizes the output of the previous layer by subtracting the empirical mean over the batch and dividing by the empirical standard deviation:

x̂ = (x − μ) / √(σ² + ε)

where μ and σ² are the batch mean and batch variance respectively, and ε is a small constant added for numerical stability. This gives each feature zero mean and unit variance across the batch (it standardizes the data, although it does not by itself make the data Gaussian).

Then, the network learns a new mean and variance through two learnable parameters, γ (scale) and β (shift): y = γ·x̂ + β. So in short, you can think of batch normalization as something that helps you control the first and second moments of the distribution of the batch.
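To make the two steps concrete, here is a minimal pure-Python sketch of the training-mode forward pass for a batch of feature vectors. The function name `batch_norm_train` and the value ε = 1e-5 are assumptions for illustration, not taken from any particular framework. It computes per-feature batch statistics in a first pass, then normalizes, scales and shifts in a second pass.

```python
import math


def batch_norm_train(batch, gamma, beta, eps=1e-5):
    """Batch-normalize a mini-batch of feature vectors (training mode).

    batch: list of N samples, each a list of D features.
    gamma, beta: learnable scale and shift, one value per feature.
    eps: small constant for numerical stability (assumed value).
    """
    n = len(batch)
    d = len(batch[0])
    # Pass 1: per-feature batch mean and (biased) batch variance.
    mean = [sum(x[j] for x in batch) / n for j in range(d)]
    var = [sum((x[j] - mean[j]) ** 2 for x in batch) / n for j in range(d)]
    # Pass 2: normalize each activation, then scale by gamma and shift by beta.
    out = []
    for x in batch:
        out.append([
            gamma[j] * (x[j] - mean[j]) / math.sqrt(var[j] + eps) + beta[j]
            for j in range(d)
        ])
    return out


batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
# With gamma = 1 and beta = 0, each column comes out with mean 0 and variance close to 1.
print(batch_norm_train(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0]))
```

In a convolutional layer the same idea applies per channel, with the statistics taken over both the batch and the spatial positions.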

Feature distribution output from an intermediate convolution layer from VGG-16 Network. 1. (Before) without any normalization, 2. (After) applying batch normalization.

Benefits

I’ll list some of the benefits of using batch normalization, but I won’t go into much detail, as there are tonnes of articles already covering them.

  • Faster convergence.
  • Decreases the importance of initial weights.
  • Robust to hyperparameters.
  • Requires less data for generalization.
1. Faster Convergence, 2. Robust to hyperparameters

Cursed Batch Normalization

So, getting back to the purpose of this article: there are many situations in which batch normalization starts to hurt performance or doesn’t work at all.

Unstable when using small batch sizes

As discussed above, the batch normalization layer has to calculate the mean and variance across the batch to normalize the previous outputs. This statistical estimate is fairly accurate when the batch size is large, but its accuracy keeps decreasing as the batch size gets smaller.
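A quick way to see this is to simulate it. The sketch below draws synthetic activations from a Gaussian (an assumption; real activations need not be Gaussian) and measures how much the batch mean fluctuates from one mini-batch to the next. The spread should shrink roughly like σ/√(batch size), so small batches give noticeably noisier statistics; the batch variance estimate also gets noisier as batches shrink.

```python
import random
import statistics


def batch_mean_spread(batch_size, trials=2000, mu=5.0, sigma=2.0, seed=0):
    """Standard deviation of the batch mean across many random mini-batches.

    Activations are drawn from a synthetic Gaussian N(mu, sigma^2); this is
    an illustrative assumption, not data from a real network.
    """
    rng = random.Random(seed)
    means = []
    for _ in range(trials):
        batch = [rng.gauss(mu, sigma) for _ in range(batch_size)]
        means.append(statistics.mean(batch))
    return statistics.pstdev(means)


# The spread of the batch mean grows as the batch size shrinks.
for b in (2, 4, 8, 16, 32):
    print(b, round(batch_mean_spread(b), 3))
```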

ResNet-50’s validation error of Batch Norm trained with 32, 16, 8, 4 and 2 images/GPU

Above is ResNet-50’s validation error plot. As one can infer, with a batch size of 32 the final validation error is around 23%, and the error keeps increasing as the batch size gets smaller. (The batch size can’t be 1 for batch normalization, because each sample would then be its own mean.) Between the largest and the smallest batch sizes there is a huge gap in validation error, around 10 percentage points.
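The batch-size-1 case is easy to check, and a tiny sketch makes it explicit. With a single sample, every feature equals its own batch mean, so the normalized output is all zeros (before γ and β are applied) no matter what the input was. The helper `normalize` below is hypothetical and covers only the fully connected case; in convolutional layers the statistics are also taken over spatial positions, so a batch of one is not degenerate in exactly this way, just very noisy.

```python
def normalize(batch, eps=1e-5):
    """Standardize each feature using statistics of this batch only (no gamma/beta)."""
    n, d = len(batch), len(batch[0])
    mean = [sum(x[j] for x in batch) / n for j in range(d)]
    var = [sum((x[j] - mean[j]) ** 2 for x in batch) / n for j in range(d)]
    return [[(x[j] - mean[j]) / (var[j] + eps) ** 0.5 for j in range(d)] for x in batch]


# A "batch" of one sample: every feature equals its own batch mean.
print(normalize([[0.7, -3.2, 42.0]]))  # → [[0.0, 0.0, 0.0]]
```

With a batch of two, each feature is mapped to roughly -1 and +1 (if the two values differ) whatever the actual values are, which hints at how little information very small batches leave.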

If a small batch size is the problem, why don’t we just use a larger one? Well, we can’t use a large batch size in every situation. Consider fine-tuning: we may deliberately keep the batch size small so that large gradient updates don’t hurt the pretrained model. Or consider distributed training: a large global batch is split across instances, so each instance ends up computing batch statistics on its own small batch.
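To illustrate the distributed case, the sketch below splits one global batch of 64 synthetic activations across 16 hypothetical devices with 4 samples each. Each device’s local mean is noisy, and that local mean is what plain batch norm would use. Synchronous batch normalization (listed among the alternatives at the end) avoids this by combining per-device counts, sums and sums of squares into global statistics; `synced_stats` is a simplified single-process stand-in for that all-reduce, not a real framework API.

```python
import random


def local_stats(shard):
    """Mean and biased variance of one device's shard (what plain BN sees)."""
    n = len(shard)
    mean = sum(shard) / n
    var = sum((x - mean) ** 2 for x in shard) / n
    return mean, var


def synced_stats(shards):
    """Global mean/variance from per-device (count, sum, sum of squares).

    Mimics the reduction used by synchronous batch norm; the 'devices' here
    are just Python lists. The E[x^2] - mean^2 formula is fine for this
    illustration but can lose precision in general.
    """
    count = sum(len(s) for s in shards)
    total = sum(sum(s) for s in shards)
    total_sq = sum(sum(x * x for x in s) for s in shards)
    mean = total / count
    var = total_sq / count - mean * mean
    return mean, var


rng = random.Random(1)
global_batch = [rng.gauss(0.0, 1.0) for _ in range(64)]
# Split the global batch across 16 hypothetical devices (4 samples each).
shards = [global_batch[i:i + 4] for i in range(0, 64, 4)]
print([round(local_stats(s)[0], 2) for s in shards])  # noisy per-device means
print(synced_stats(shards), local_stats(global_batch))  # agree up to rounding
```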

Leads to Increased Training Time

Based on experiments conducted by NVIDIA and Carnegie Mellon University, the researchers report that even though batch normalization is not computationally intensive and reduces the total number of iterations needed for convergence, the per-iteration time can increase noticeably, and it increases further as the batch size grows.

ResNet-50 training-time distribution on ImageNet using Titan X Pascal

As you can see, batch normalization consumes about a quarter of the total training time. The reason is that batch norm requires two passes over the input data: one to compute the batch statistics and another to normalize the output.

Different results in training and inference

For instance, consider the real-world application of object detection. While training an object detector, we usually go with a large batch size (YOLOv4 and Faster-RCNN are both trained with batch size = 64 by default). But after these models are put into production, they don’t work as well as they did during training. This is because they were trained with a large batch size, while in real time they receive a batch size of one, since they have to process each frame sequentially. Considering this limitation, some implementations tend to use precomputed means and variances based on the activations of the training set. Another option is to compute the mean and variance values from your test-set activation distribution, but again not batch-wise.
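A common way to implement the "precomputed statistics" approach is to keep running averages of the batch statistics during training and switch to them at inference. Below is a minimal single-feature sketch of that idea; the class name, the momentum of 0.1 (a common framework default, used here as an assumption) and the use of the biased variance are illustrative choices. In inference mode the output for a sample no longer depends on the rest of the batch, so a batch of one frame is handled consistently.

```python
import random


class RunningBatchNorm:
    """Single-feature batch norm with running statistics (hypothetical sketch)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum  # assumed value: weight given to each new batch
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0
        self.training = True

    def __call__(self, batch):
        if self.training:
            # Training: normalize with this batch's own statistics...
            n = len(batch)
            mean = sum(batch) / n
            var = sum((x - mean) ** 2 for x in batch) / n
            # ...and fold them into exponential moving averages.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            # Inference: reuse the statistics accumulated during training.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in batch]


# Usage: train on synthetic activations drawn from N(3, 2^2), then switch to inference.
rng = random.Random(0)
bn = RunningBatchNorm()
for _ in range(500):
    bn([rng.gauss(3.0, 2.0) for _ in range(64)])
bn.training = False
print(bn.running_mean, bn.running_var)  # close to 3 and 4
print(bn([5.0]))  # a single frame now gets a meaningful output (about 1.0)
```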

Not good for online learning

In contrast to batch learning, online learning is a type of learning technique in which the system is trained incrementally by feeding it data instances sequentially, either individually or in small groups called mini-batches. Each learning step is fast and cheap, so the system can learn about new data on the fly, as it arrives.

Typical online learning pipeline

Because the system depends on an external data source, data may arrive individually or in batches of varying size. As the batch size changes from one iteration to the next, the batch statistics become unreliable estimates of the scale and shift of the input data, which eventually hurts performance.
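The core issue is that, in training mode, batch normalization makes each sample’s output depend on whatever other samples happen to share its batch. The small sketch below (a hypothetical scalar helper) normalizes the same value, 2.0, in three differently sized batches and gets three different results.

```python
def standardize_in_batch(batch, eps=1e-5):
    """Training-mode batch norm on scalars: each output depends on the whole batch."""
    n = len(batch)
    mean = sum(batch) / n
    var = sum((x - mean) ** 2 for x in batch) / n
    return [(x - mean) / (var + eps) ** 0.5 for x in batch]


sample = 2.0
print(standardize_in_batch([sample, 0.0])[0])              # ≈ 1.0
print(standardize_in_batch([sample, 0.0, 10.0, 12.0])[0])  # ≈ -0.78
print(standardize_in_batch([sample])[0])                   # 0.0
```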

Not good for Recurrent Neural Networks

Although batch normalization speeds up training and improves generalization significantly in convolutional neural networks, it has proven difficult to apply to recurrent architectures. Batch normalization can be applied between stacked RNN layers, where normalization is applied “vertically”, i.e. to the output of each RNN layer. But it cannot be applied “horizontally”, i.e. between timesteps, as it hurts training because of exploding gradients due to repeated rescaling.

[NOTE]: Some research experiments suggest that batch normalization makes neural networks more vulnerable to adversarial attacks. But we haven’t included this point due to the lack of study and proof.

Alternatives

So these were some drawbacks of using batch normalization. There are several alternatives that are used in situations where batch normalization falls short.

  • Layer Normalization.
  • Instance Normalization.
  • Group Normalization (+ weight standardization).
  • Synchronous Batch Normalization.
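The first three alternatives differ from batch norm mainly in which values the mean and variance are computed over: they use statistics from a single sample, so they do not depend on the batch size at all. As a rough sketch (pure Python, learnable γ/β omitted, function names hypothetical), they can be written as one per-sample function, since group normalization with a single group reduces to layer normalization and with one group per channel to instance normalization. Weight standardization and synchronous batch norm are not shown here.

```python
def group_norm(sample, num_groups, eps=1e-5):
    """Group normalization of ONE sample, given as C channels of spatial values.

    sample: list of C channels, each a flat list of H*W activations.
    num_groups=1 gives layer norm; num_groups=C gives instance norm.
    Learnable per-channel gamma/beta are omitted for brevity.
    """
    c = len(sample)
    if c % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    size = c // num_groups
    out = []
    for g in range(num_groups):
        channels = sample[g * size:(g + 1) * size]
        # Statistics are shared by all values in this group of channels.
        values = [v for ch in channels for v in ch]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        for ch in channels:
            out.append([(v - mean) / (var + eps) ** 0.5 for v in ch])
    return out


def layer_norm(sample, eps=1e-5):
    # One group containing every channel of the sample.
    return group_norm(sample, 1, eps)


def instance_norm(sample, eps=1e-5):
    # One group per channel.
    return group_norm(sample, len(sample), eps)


# One sample with 4 channels and 3 spatial positions each; no batch dimension needed.
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [10.0, 0.0, -10.0], [7.0, 7.5, 8.0]]
print(group_norm(x, 2))
```

Because nothing here looks across samples, the same code works for a batch of one, for online learning, and at every timestep of an RNN.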

Conclusion

Training deep neural networks is simple, but I don’t think it’s quite easy yet. It is simple in the sense that there are only a few architectures I can choose between, there’s a fixed learning rate that everyone uses, a fixed optimizer, and a fixed set of tricks. These tricks have been chosen almost via natural selection: someone comes up with a trick and introduces it; if it works, it stays, and if it doesn’t, people eventually forget about it and no one uses it again. That said, batch normalization is a milestone technique in the development of deep learning. However, normalizing along the batch dimension introduces some problems, as discussed, which suggests that there’s still room for improvement in normalization techniques.
