
The Serendipitous Effectiveness of Weight Decay in Deep Learning

Coming full circle after an unintended change of effects

Antoine Labatie
Towards Data Science
Dec 30, 2020



Weight decay is without question an integral part of the deep learning toolkit. It might seem like the prototypical example of a simple idea whose effectiveness has stood the test of time. If we look closely, however, a big surprise emerges: weight decay’s effectiveness in deep learning owes more to serendipity than to its original rationale.

Original rationale of weight decay

In machine learning, the weight decay term λ∥w∥², with strength given by hyperparameter λ>0, can be added to the loss function L(w) to penalize the norm of model parameters w. The minimization of L(w) is then replaced by the minimization of L(w) + λ∥w∥².
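As a concrete illustration, here is a minimal sketch of such a training step in PyTorch, with the penalty added explicitly to the loss; the model, data and hyperparameter values below are placeholders chosen just for the example.

import torch
import torch.nn as nn

# Placeholder model, loss and optimizer, just for illustration
model = nn.Linear(10, 1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

lam = 1e-4  # weight decay strength λ

def training_step(x, y):
    optimizer.zero_grad()
    data_loss = loss_fn(model(x), y)                           # L(w)
    penalty = sum((p ** 2).sum() for p in model.parameters())  # ∥w∥²
    total = data_loss + lam * penalty                          # L(w) + λ∥w∥²
    total.backward()
    optimizer.step()
    return total.item()

x, y = torch.randn(32, 10), torch.randn(32, 1)
print(training_step(x, y))

In practice, frameworks usually fold the penalty directly into the optimizer update (e.g. the weight_decay argument of torch.optim.SGD), which amounts to the same thing up to the factor-of-2 convention in the gradient 2λw of λ∥w∥².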

Adding such weight decay can be seen as imposing a soft constraint on the norm of w. This soft constraint reduces the model’s effective capacity (i.e. its representation power), and with it, its tendency to overfit.

To get a better intuition, let’s consider the hypothetical case where the regularization strength becomes infinite λ → ∞. In that case, the minimization of L(w) + λ∥w∥² becomes dominated by the term λ∥w∥² and the solution w=0 becomes independent of the training data. Without adaptation to the training data, no overfitting!

Historically, this idea of penalizing the norm of model parameters as a regularizer emerged independently in many different contexts. It can be traced back at least to the work of Tikhonov in 1943 in the context of inverse problems, hence its alternative name of “Tikhonov regularization”. The idea never left center stage thereafter, being notably at the core of the support vector machines that dominated the machine learning landscape before the advent of deep learning.

On a theoretical level, it can be proven that any class of models with a convex and Lipschitz loss function L(w) on a bounded domain is “agnostic PAC-learnable” through the minimization of L(w) + λ∥w∥² [1]. In simple terms, this means that the minimizer of L(w) + λ∥w∥² only has a very small probability of overfitting given enough data. Crucially, the number of data points required not to overfit is independent of the number of parameters. This property is very appealing in the context of deep learning, with models made up of millions or even trillions of parameters.
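Stated informally and up to constants (the precise statement and constants are in [1]): for m training examples, a convex ρ-Lipschitz loss, a bounded domain ∥w∥ ≤ B, and a suitable choice of λ of order ρ/(B√m), the minimizer ŵ of the regularized training loss satisfies, in expectation over the training set,

(true loss of ŵ) − (best achievable true loss on the domain) ≤ O(ρB/√m).

The number of parameters never enters this bound; only m, ρ and B do.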

So what’s the catch? As we said, one of the hypotheses underlying this rationale is that L(w) is Lipschitz. And, as we shall see shortly, this hypothesis is fundamentally wrong for deep models with normalization layers. The original rationale of weight decay therefore becomes essentially void in modern state-of-the-art models, which nearly always incorporate normalization layers: e.g. batch normalization or group normalization in Convolutional Neural Networks, or layer normalization in Transformers.

Weight decay in deep learning

Normalization → scale invariance

Normalization plays a crucial role in the successful “scaling” of neural networks to very large and deep models. Its primary purpose is to ensure that, no matter how wide and deep a neural network is, the signal propagated through it remains normalized.

For this purpose to be achieved, each normalization layer transforms its input x to produce the output Norm(x)=(x-μ(x))/σ(x), with μ(x), σ(x) the mean and standard deviation of x over some subset of its components. The inclusion of this simple operation in a neural network has profound consequences, one of which is the emergence of scale invariance.

To see how this property emerges, let’s assume — as is typical — that x itself is the result of a convolution with weights w. If w is scaled by some positive constant ρ>0, then so is the convolution output, now equal to ρx. In turn, the statistics of ρx are scaled by ρ compared to the statistics of x, and Norm(ρx) remains equal to (x-μ(x))/σ(x).
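This cancellation of ρ is easy to verify numerically. The sketch below uses a toy normalization over the last axis (the exact axes over which μ and σ are computed depend on the normalization scheme, and the usual numerical-stability epsilon is omitted to match the formula above).

import torch

def norm(x):
    # Norm(x) = (x - μ(x)) / σ(x), with statistics over the last axis
    mu = x.mean(dim=-1, keepdim=True)
    sigma = x.std(dim=-1, keepdim=True)
    return (x - mu) / sigma

x = torch.randn(4, 16)
rho = 3.7
# Scaling the input (and hence the weights producing it) leaves the output unchanged
print(torch.allclose(norm(rho * x), norm(x), atol=1e-5))  # True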

So when normalization is present, the output of the neural network with weights w is always equal to the output of the neural network with weights ρw. This means that the losses of both networks are equal, L(w) = L(ρw), i.e. that L(w) is “scale-invariant”.

A profound implication of this property is that L(w) is fully determined by its values on the unit sphere ∥w∥=1. Indeed, the value of L(w) for any w can be recovered from the value of its projection on the unit sphere w/∥w∥, given that L(w)=L(w/∥w∥) by scale invariance.

From this follow two implications for the gradient ∇L(w) of L(w) (both are checked numerically in the sketch after this list):

  • ∇L(w) is orthogonal to w, since L(w) does not change along the direction of w used for the projection on the unit sphere
  • ∇L(w) scales as 1/∥w∥, since the rate of variation of the loss at w is multiplied by 1/∥w∥ compared to the point w/∥w∥ (which explains why scale-invariant functions cannot be Lipschitz!)
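Both properties can be checked with automatic differentiation. The sketch below uses a made-up scale-invariant loss as a toy stand-in for the loss of a normalized network; the specific function, dimension and constant ρ are arbitrary choices for the illustration.

import torch

target = torch.randn(10)

def scale_invariant_loss(w):
    # Toy scale-invariant loss: it depends only on the direction w/∥w∥
    u = w / w.norm()
    return ((u - target) ** 2).sum()

w = torch.randn(10, requires_grad=True)
grad = torch.autograd.grad(scale_invariant_loss(w), w)[0]

# 1) ∇L(w) is orthogonal to w
print(torch.dot(grad, w).item())  # ≈ 0

# 2) ∇L(w) scales as 1/∥w∥: at ρw the gradient norm is divided by ρ
rho = 5.0
w_scaled = (rho * w).detach().requires_grad_()
grad_scaled = torch.autograd.grad(scale_invariant_loss(w_scaled), w_scaled)[0]
print(grad.norm().item(), (rho * grad_scaled.norm()).item())  # ≈ equal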

Weight decay + scale invariance → increased effective learning rate

In the remainder of this post, we focus on the setting most widely encountered in deep learning: stochastic gradient descent (SGD). In this context, we denote by w_t the model parameters at iteration t, and by ∇L(w_t|B_t), ∇L(w_t|D) the gradients at point w_t computed respectively on the mini-batch B_t and on the whole dataset D.

Fig. 1: Weight equilibrium of SGD with learning rate η and weight decay strength λ

As can be seen in Fig. 1, SGD without weight decay would solely follow the direction of the mini-batch gradient ∇L(w_t|B_t), which, by orthogonality to w_t, gives the step from w_t to w_{t+1} a centrifugal component that always pushes the parameters outward. As a consequence, w_t would “spiral out” to ∥w_t∥ → ∞, causing gradients to vanish, ∥∇L(w_t|B_t)∥ ∝ 1/∥w_t∥ → 0, and the optimization to slow down.

When weight decay is added, on the other hand, its centripetal force compensates the centrifugal force of gradients such that an equilibrium with approximately constant ∥w_t∥ can be reached. This equilibrium is stable due to the opposite dependence on ∥w_t∥ of the weight decay and gradient forces.
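This equilibrium behaviour can be reproduced on a toy scale-invariant loss, with mini-batch noise simulated by perturbing the target at each step; the dimensions, noise level, learning rate and λ below are arbitrary choices for the illustration. Without weight decay the norm keeps creeping up, with weight decay it settles.

import torch

torch.manual_seed(0)
base_target = torch.randn(100)

def minibatch_loss(w):
    # Toy stand-in for a mini-batch loss of a normalized network:
    # it depends only on w/∥w∥, and mini-batch noise is simulated
    # by perturbing the target at every call
    target = base_target + 0.5 * torch.randn(100)
    u = w / w.norm()
    return ((u - target) ** 2).sum()

def run_sgd(lam, lr=0.1, steps=20000):
    w = torch.randn(100, requires_grad=True)
    norms = []
    for t in range(steps):
        loss = minibatch_loss(w) + lam * (w ** 2).sum()  # L(w|B_t) + λ∥w∥²
        grad = torch.autograd.grad(loss, w)[0]
        with torch.no_grad():
            w -= lr * grad
        if t % 5000 == 0:
            norms.append(round(w.norm().item(), 1))
    return norms

print("∥w_t∥ without weight decay:", run_sgd(lam=0.0))   # keeps growing
print("∥w_t∥ with weight decay:   ", run_sgd(lam=1e-2))  # stabilizes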

When this equilibrium is reached, it can be shown that the update during iteration t of the projection w_t/∥w_t∥ on the unit sphere is approximately equal to -η_e∇L(w_t/∥w_t∥|B_t), with the “effective learning rate” η_e equal to η/∥w_t∥² [2]. This means that the update during iteration t of w_t/∥w_t∥ is equivalent to an update of SGD with learning rate η_e at point w_t/∥w_t∥.

Furthermore, the effective learning rate η_e can be shown to scale as the square root of ηλ, due to the fact that ∥w_t∥² at equilibrium scales as the square root of η/λ (cf [3] and Fig. 2). The primary effect of weight decay in deep learning is therefore not to reduce the model capacity, but to increase the effective learning rate!
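A back-of-the-envelope derivation of this scaling (the exact constants depend on the weight decay convention, but the square-root behaviour does not): with the penalty λ∥w∥², one SGD step reads w_{t+1} = (1-2ηλ)w_t - η∇L(w_t|B_t). Since the gradient is orthogonal to w_t,

∥w_{t+1}∥² = (1-2ηλ)²∥w_t∥² + η²∥∇L(w_t|B_t)∥² ≈ (1-4ηλ)∥w_t∥² + η²g²/∥w_t∥²,

where g denotes the gradient norm at the projection w_t/∥w_t∥ (using the 1/∥w∥ scaling from before) and the (ηλ)² term is neglected. Imposing ∥w_{t+1}∥² = ∥w_t∥² at equilibrium gives 4ηλ∥w_t∥² ≈ η²g²/∥w_t∥², hence ∥w_t∥² ≈ (g/2)√(η/λ) and η_e = η/∥w_t∥² ≈ (2/g)√(ηλ).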

Let’s now see why this unintended change of effects still leads, in the end, to regularization, enabling weight decay to come full circle.

Fig. 2: For various choices of (η,λ), we plot ∥w_t∥² and the ratio of ∥w_t∥² to the square root of η/λ during the first 100 epochs of training of ResNet-8 on CIFAR-10, with w_t the concatenation of all 3x3 conv kernels in the network

Increased effective learning rate → regularization

Let’s start by considering an increase of learning rate in the case of networks with neither scale invariance nor weight decay.

In that case, at each iteration SGD “descends” the mini-batch gradient by taking the step -η∇L(w_t|B_t). Since the mini-batch gradient ∇L(w_t|B_t) is only an approximation of the dataset gradient ∇L(w_t|D), this introduces into the trajectory a stochastic noise with zero expectation but nonzero variance (cf Fig. 3).

Fig. 3: Noise in the trajectory of SGD with learning rate η

To understand the effect of an increase of learning rate, let’s now contrast the cases: (i) where we perform two iterations with learning rate η, and (ii) where we perform one iteration with learning rate 2η. While the expected length of trajectory is roughly the same in both cases, the stochastic noise is increased in case (ii) compared to case (i). Indeed, the noise variance scales as 4η² when performing one iteration with learning rate 2η, while it scales as 2η² — by independence of each iteration — when performing two iterations with learning rate η.
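A tiny numerical check of this bookkeeping, with the per-step gradient noise modelled as independent Gaussian noise of fixed variance (a deliberately crude stand-in for real mini-batch noise):

import torch

eta, sigma, n_trials = 0.1, 1.0, 200_000
noise = lambda: sigma * torch.randn(n_trials)

# Case (i): two iterations with learning rate η, independent noise in each
displacement_i = eta * noise() + eta * noise()
# Case (ii): one iteration with learning rate 2η
displacement_ii = 2 * eta * noise()

print(displacement_i.var().item())   # ≈ 2η²σ² = 0.02
print(displacement_ii.var().item())  # ≈ 4η²σ² = 0.04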

Thus, an increase of the learning rate leads to an increased stochasticity of SGD per unit length of trajectory. In turn, this increased stochasticity favors the convergence to “flat” regions of low loss, which are less prone to overfitting than sharper regions [4]. An increase of the learning rate is therefore a form of regularization.

How does that translate to networks with scale invariance and weight decay? Recall that the update during iteration t of w_t/∥w_t∥ is equivalent to an update of SGD with learning rate η_e at point w_t/∥w_t∥. Consequently, an increase of the effective learning rate η_e (obtained in particular by increasing the weight decay strength λ) leads to an increased stochasticity of the trajectory of w_t/∥w_t∥. In turn, this favors the convergence of w_t/∥w_t∥ to “flatter” regions of low loss, which are less prone to overfitting. An increase of the weight decay strength λ is therefore still a form of regularization.

Conclusion

To sum up, weight decay’s primary effect has changed in deep learning compared to classical machine learning. It is no longer to limit the model capacity, but rather to increase the effective learning rate. Since an increase of the effective learning rate still leads to solutions less prone to overfitting, weight decay has serendipitously remained a form of regularization.

Disclaimer

For the sake of clarity and without altering the analysis, a few complexities have been ignored in the main text.

  • Convex-Lipschitzness of L(w) and agnostic PAC learnability. Technically, for agnostic PAC learnability, the convexity and Lipschitzness of L(w) are required for any dataset distribution.
  • Scale and shift parameters of the norm. Normalization is usually followed by an affine transform with parameters β, γ, such that it produces the output: γ(x-μ(x))/σ(x)+β. This does not alter our argument on the emergence of the scale invariance property.
  • Parameters not subject to scale invariance. Typically, some parameters in a neural network are not subject to scale invariance (e.g. the final fully-connected layer). Our argument on the scale of equilibrium still applies for all the parameters subject to scale invariance.
  • Flatness and generalization. To guarantee generalization, flatness is not sufficient and an additional constraint on the norm of the model parameters is required [4]. In networks with scale invariance, such a constraint is present for w_t/∥w_t∥, since the norm of w_t/∥w_t∥ remains equal to 1.

References

[1] Shai Shalev-Shwartz, Shai Ben-David, Understanding Machine Learning: From Theory to Algorithms (2014), Cambridge University Press

[2] Elad Hoffer, Ron Banner, Itay Golan, Daniel Soudry, Norm Matters: Efficient and Accurate Normalization Schemes in Deep Networks, NIPS 2018

[3] Twan van Laarhoven, L2 Regularization versus Batch and Weight Normalization, arXiv 2017

[4] Behnam Neyshabur, Srinadh Bhojanapalli, David McAllester, Nathan Srebro, Exploring Generalization in Deep Learning, NIPS 2017
