It Is Necessary to Combine Batch Normalization and Skip Connections

These techniques must go hand in hand

Antoine Labatie
Towards Data Science

This post is based on my paper published at ICML 2019 [1]. It attempts to provide novel insights into the following questions:

  • Why do specific architectures of deep neural networks (DNNs) work while others don’t?
  • What’s the role played by batch norm [2] and skip connections [3]?

Answering these questions is a hot research topic with a lot at stake. Indeed, such answers would advance the theoretical understanding of DNNs and could unlock further improvements in their design.

How to assess the quality of an architecture?

To get started, the first key observation is that any DNN mapping from input to output requires two elements to be specified: (1) the architecture; (2) the values of the model parameters (weights and biases) inside the architecture. The set of all DNN mappings obtained by fixing the architecture and varying the model parameters inside it is referred to as the hypothesis class.

The purpose of the hypothesis class is to impose a constraint on training. Indeed, training consists of finding a DNN mapping that simultaneously: (1) belongs to the hypothesis class; (2) agrees with the training data. Requiring DNN mappings to belong to the hypothesis class is a way of expressing the prior knowledge that the true mapping itself belongs to the hypothesis class. It is essentially this prior knowledge that makes it possible to draw inferences about the test data using only the training data. For this reason, the prior knowledge is commonly referred to as the inductive bias (for more details, I refer the reader to the great book Understanding Machine Learning: From Theory to Algorithms by Shalev-Shwartz and Ben-David).

Returning to our initial goal, we may assess the quality of an architecture by assessing the quality of its inductive bias. This assessment may be performed using the following procedure: fix the architecture and randomly sample model parameters inside the architecture. If most DNN mappings sampled with this procedure have bad properties, it means that there’s an inductive bias towards bad properties, i.e. that bad properties will be favored during training. In turn, this will lead either to untrainability — i.e. underfitting — when bad properties are incompatible with low training loss, or to poor generalization — i.e. overfitting — when bad properties are compatible with low training loss but unlikely to generalize.
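
The procedure can be sketched in a few lines of NumPy. This is an illustration only, under assumptions not stated in the post: the architecture is a fully-connected ReLU network without biases, parameters are sampled as He-scaled Gaussian weights, and the "bad property" is supplied by the caller as a predicate `is_bad`. The helpers `sample_parameters`, `dnn_mapping`, `top_direction_share` and `fraction_bad` are hypothetical names, and `top_direction_share` is just one possible proxy for a signal concentrated along a single line.

```python
import numpy as np

def sample_parameters(depth, width, rng):
    """Randomly sample parameters inside a fixed fully-connected architecture.
    Assumption: He-scaled Gaussian weights (variance 2/width), no biases."""
    return [rng.normal(0.0, np.sqrt(2.0 / width), size=(width, width))
            for _ in range(depth)]

def dnn_mapping(weights, x):
    """Apply the DNN mapping to a batch x of shape (n_inputs, width).
    Each layer is ReLU -> Linear (pre-activation convention)."""
    h = x
    for W in weights:
        h = np.maximum(h, 0.0) @ W.T
    return h

def top_direction_share(y):
    """Hypothetical property score: fraction of the total variance of y carried
    by its single top principal direction (close to 1 = concentrated on a line)."""
    eig = np.linalg.eigvalsh(np.cov(y, rowvar=False))
    total = eig.sum()
    return 1.0 if total <= 0 else eig.max() / total

def fraction_bad(depth, width, is_bad, n_mappings=20, n_inputs=256, seed=0):
    """Fix the architecture, sample its parameters n_mappings times, and return
    the fraction of sampled DNN mappings whose outputs are flagged by is_bad."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n_inputs, width))
    flagged = sum(bool(is_bad(dnn_mapping(sample_parameters(depth, width, rng), x)))
                  for _ in range(n_mappings))
    return flagged / n_mappings

# Example: how often are the outputs of a deep plain ReLU net concentrated on a line?
print(fraction_bad(depth=50, width=32, is_bad=lambda y: top_direction_share(y) > 0.9))
```

A high fraction would indicate an inductive bias towards the flagged property; the threshold 0.9 is arbitrary and only serves the example.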

Before we can apply this procedure, we still need to make the notion of “bad properties” of DNN mappings more precise.

What are bad properties of a DNN mapping?

Let’s consider a fixed DNN mapping, specified by fixing the architecture and fixing the model parameters inside the architecture. This fixed DNN mapping receives a random input and propagates this input throughout its layers. We keep track of the propagation by defining:

  • The random signal: y^l = Φ^l(x), obtained by applying the fixed DNN mapping Φ^l up to layer l to the random input x
  • The random noise: dy^l = Φ^l(x+dx) − Φ^l(x), obtained as the corruption at layer l originating from the random corruption dx of the random input x
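
For a fixed DNN mapping, both quantities can be tracked by pushing a clean batch and a corrupted batch through the same layers. The sketch below assumes fully-connected ReLU -> Linear layers without biases and He-scaled random weights; `propagate_signal_and_noise` is a hypothetical helper, not code from the paper.

```python
import numpy as np

def propagate_signal_and_noise(weights, x, dx):
    """Return [y^0, ..., y^L] and [dy^0, ..., dy^L] for a fixed mapping.

    y^l  = Phi^l(x): the clean input propagated up to layer l.
    dy^l = Phi^l(x + dx) - Phi^l(x): the input corruption dx as seen at layer l.
    Assumption: each layer is ReLU -> Linear, fully-connected, no biases."""
    h, h_noisy = x, x + dx
    ys, dys = [h], [h_noisy - h]
    for W in weights:
        h = np.maximum(h, 0.0) @ W.T
        h_noisy = np.maximum(h_noisy, 0.0) @ W.T
        ys.append(h)
        dys.append(h_noisy - h)
    return ys, dys

# A fixed mapping: weights are sampled once, then kept frozen.
rng = np.random.default_rng(0)
width, depth = 16, 10
weights = [rng.normal(0.0, np.sqrt(2.0 / width), size=(width, width)) for _ in range(depth)]
x = rng.normal(size=(128, width))            # random input
dx = 1e-3 * rng.normal(size=(128, width))    # random input corruption
ys, dys = propagate_signal_and_noise(weights, x, dx)
```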

Now “bad properties” of the DNN mapping may be defined as unwanted behavior of the signal and noise: either the signal y^l becomes meaningless, or the noise dy^l gets out of control. More precisely, two “pathologies” may be defined by pushing such “bad properties” to the extreme (we focus on these two “pathologies” since they are observed in our context, but other “pathologies” could be defined and observed in other contexts):

  • Pathological Signal: the signal y^l loses its directions of variance and gets concentrated along a single line in deep layers. This pathology is, for example, incompatible with the one-hot targets of multiclass classification, whose number of directions of variance is typically equal to the number of classes minus 1. An inductive bias towards this pathology likely leads to untrainability.
  • Pathological SNR: the noise dy^l explodes with respect to the signal y^l, with the signal-to-noise ratio SNR^l decaying exponentially with l. This pathology may be compatible with low training loss, but any input corruption dx on the test set will lead to the corrupted signal y^l+dy^l = Φ^l(x+dx) becoming pure noise — i.e. meaningless. An inductive bias towards this pathology likely leads to poor generalization.
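
Both pathologies can be diagnosed from a batch of signals y^l and noises dy^l. The sketch below uses assumed definitions: “# directions of variance” is taken as the trace of the covariance of y^l divided by its largest eigenvalue (one common notion of effective rank), and SNR^l as the total variance of y^l over the mean squared norm of dy^l. The paper’s precise definitions may differ in details.

```python
import numpy as np

def effective_rank(y):
    """'# directions of variance' of a batch y of shape (n_inputs, width).

    Assumption: trace of the covariance divided by its largest eigenvalue.
    It equals 1 when y varies along a single line and equals width when the
    variance is spread evenly over all directions (0 is returned for constant y)."""
    eig = np.linalg.eigvalsh(np.cov(y, rowvar=False))
    top = eig.max()
    return float(eig.sum() / top) if top > 0 else 0.0

def snr(y, dy):
    """Signal-to-noise ratio: total variance of the signal y divided by the
    mean squared norm of the noise dy (assumed definition)."""
    return float(np.var(y, axis=0).sum() / np.mean(np.sum(dy ** 2, axis=1)))
```

Evaluated layer by layer, an effective rank close to 1 in deep layers points to pathological signal, while SNR^l / SNR^0 decaying exponentially with l points to pathological SNR.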

Applying our procedure for various architectures

Now let’s apply our procedure to various architectures of convolutional DNNs with ReLU activation function (fully-connected DNNs are included as the special case with spatial size equal to 1):

  • DNNs without batch norm and without skip connections suffer from pathological signal — i.e. # directions of variance of y^l close to 1 in deep layers
  • DNNs without batch norm and with skip connections similarly suffer from pathological signal — i.e. # directions of variance of y^l close to 1 in deep layers
  • DNNs with batch norm and without skip connections suffer from pathological SNR, i.e. SNR^l / SNR^0 decaying exponentially with l
  • DNNs with batch norm and with skip connections do not suffer from any pathology — i.e. remain well-behaved at all depths
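
The four cases can be explored with a small simulation sketch. It relies on illustrative assumptions rather than the paper’s setup: fully-connected layers (spatial size 1), He-scaled Gaussian weights, pre-activation residual branches of the form [BN ->] ReLU -> Linear, and batch norm without learned scale/shift whose statistics are computed on the clean batch and reused for the corrupted batch so that the DNN mapping stays fixed. The function names are hypothetical, and no particular output values are claimed here; the script simply prints the effective rank of y^L and SNR^L / SNR^0 for each configuration.

```python
import numpy as np

def batch_norm_pair(h, h_noisy, eps=1e-5):
    """Batch norm without learned affine parameters. Statistics come from the
    clean batch and are reused for the noisy batch (assumption: fixed mapping)."""
    mean, std = h.mean(axis=0), np.sqrt(h.var(axis=0) + eps)
    return (h - mean) / std, (h_noisy - mean) / std

def effective_rank(y):
    """'# directions of variance' proxy: trace / largest eigenvalue of the covariance."""
    eig = np.linalg.eigvalsh(np.cov(y, rowvar=False))
    return eig.sum() / eig.max()

def snr(y, dy):
    """Total signal variance over mean squared noise norm."""
    return np.var(y, axis=0).sum() / np.mean(np.sum(dy ** 2, axis=1))

def simulate(use_bn, use_skip, depth=30, width=64, n=1024, noise=1e-3, seed=0):
    """Return (effective rank of y^L, SNR^L / SNR^0) for one randomly sampled mapping."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(n, width))
    dx = noise * rng.normal(size=(n, width))
    y, y_noisy = x, x + dx
    for _ in range(depth):
        # He-scaled Gaussian weights for a fully-connected layer.
        W = rng.normal(0.0, np.sqrt(2.0 / width), size=(width, width))
        h, h_noisy = batch_norm_pair(y, y_noisy) if use_bn else (y, y_noisy)
        # Residual branch: [BN ->] ReLU -> Linear.
        r = np.maximum(h, 0.0) @ W.T
        r_noisy = np.maximum(h_noisy, 0.0) @ W.T
        # With a skip connection the branch is added to its input; otherwise it replaces it.
        y, y_noisy = (y + r, y_noisy + r_noisy) if use_skip else (r, r_noisy)
    return effective_rank(y), snr(y, y_noisy - y) / snr(x, dx)

for use_bn in (False, True):
    for use_skip in (False, True):
        rank, snr_ratio = simulate(use_bn, use_skip)
        print(f"bn={use_bn!s:5} skip={use_skip!s:5} "
              f"eff_rank={rank:6.2f} SNR^L/SNR^0={snr_ratio:.3g}")
```

Freezing the batch norm statistics is a design choice made so that dy^L measures the corruption through one fixed mapping; recomputing statistics on the corrupted batch would be an equally valid but different experiment.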

What’s going on?

The main force attracting DNNs towards pathology is the multiplicativity of feedforward layer composition (Conv and ReLU layers can be seen respectively as multiplication by a random matrix and multiplication by a Bernoulli random vector):

  • DNNs without skip connections are pathological in deep layers since they are subject to plain feedforward multiplicativity
  • DNNs without batch norm and with skip connections are pathological in deep layers since the roughly equal variance in residual and skip connection branches does not effectively counter feedforward multiplicativity
  • DNNs with batch norm and with skip connections remain well-behaved at all depths since the decaying ratio ∝ 1/(l+1) of signal variance between residual and skip connection branches does effectively counter feedforward multiplicativity
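
The variance argument behind the last two bullets can be written out as a rough back-of-the-envelope derivation. It rests on simplifying assumptions not stated in the post: unit-variance input y^0, He-scaled weights, approximately Gaussian pre-activations, residual branches uncorrelated with the skip branch in expectation over the random weights, and Var denoting per-component variance.

```latex
% Pre-activation residual network (assumption): y^{l+1} = y^l + r^l,
% with residual branch r^l = W^l \, \mathrm{ReLU}(\mathrm{BN}(y^l)).
% ReLU acts as multiplication by a Bernoulli mask:
\mathrm{ReLU}(z) = z \odot \mathbf{1}\{z > 0\}

% With batch norm, each residual branch has roughly unit variance, so the
% skip branch accumulates variance linearly and the residual share decays:
\mathrm{Var}(r^l) \approx 1
\;\Rightarrow\;
\mathrm{Var}(y^l) \approx \mathrm{Var}(y^0) + \sum_{k=0}^{l-1} \mathrm{Var}(r^k) \approx l + 1
\;\Rightarrow\;
\frac{\mathrm{Var}(r^l)}{\mathrm{Var}(y^l)} \approx \frac{1}{l+1}

% Without batch norm, He scaling roughly preserves the second moment,
% so both branches stay comparable and the variance doubles per layer:
\mathrm{Var}(r^l) \approx \mathrm{Var}(y^l)
\;\Rightarrow\;
\mathrm{Var}(y^{l+1}) \approx 2\,\mathrm{Var}(y^l),
\qquad
\frac{\mathrm{Var}(r^l)}{\mathrm{Var}(y^l)} \approx 1
```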

Conclusion

Let’s summarize our results (to dig deeper, I refer the interested reader to the paper and code):

  • The combination of batch norm and skip connections encodes a well-behaved inductive bias in deep nets
  • The benefits of these techniques, however, are difficult to disentangle. It is only when they are combined that — by diluting the residual branch into the skip connection branch — they counter feedforward multiplicativity

I hope that these results will unlock new perspectives in the understanding of deep nets.

Disclaimer

For ease of exposition (and without altering the analysis), some notations have been simplified in this post compared to the paper:

  • The pre-activation perspective has been adopted, with each layer l starting after the convolution and ending again after the convolution
  • # directions of variance of y^l corresponds in the paper to the effective rank of y^l
  • SNR^l / SNR^0 corresponds in the paper to the inverse of the squared normalized sensitivity
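
Under the simplified notation of this post, the last correspondence can be made explicit. Assuming SNR^l is the total variance of y^l divided by the mean squared norm of dy^l (with y^0 = x and dy^0 = dx), and writing χ^l for the normalized sensitivity, one consistent way to express the relation is the following; the paper’s exact definition may differ in details.

```latex
\mathrm{SNR}^l = \frac{\mathrm{Var}(y^l)}{\mathbb{E}\,\|dy^l\|^2},
\qquad
\chi^l = \sqrt{\frac{\mathbb{E}\,\|dy^l\|^2 \,/\, \mathrm{Var}(y^l)}{\mathbb{E}\,\|dx\|^2 \,/\, \mathrm{Var}(x)}}
\quad\Longrightarrow\quad
\frac{\mathrm{SNR}^l}{\mathrm{SNR}^0} = \frac{1}{(\chi^l)^2}
```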

References

[1] A. Labatie, Characterizing Well-Behaved vs. Pathological Deep Neural Networks (2019), ICML 2019

[2] S. Ioffe and C. Szegedy, Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (2015), ICML 2015

[3] K. He, X. Zhang, S. Ren and J. Sun, Identity Mappings in Deep Residual Networks (2016), ECCV 2016
