Train ImageNet without Hyperparameters with Automatic Gradient Descent

Towards architecture-aware optimisation

Chris Mingard
Towards Data Science


TL;DR We’ve derived an optimiser called automatic gradient descent (AGD) that can train ImageNet without hyperparameters. This removes the need for expensive and time-consuming learning rate tuning, selection of learning rate decay schedulers, etc. Our paper can be found here.

I worked on this project with Jeremy Bernstein, Kevin Huang, Navid Azizan and Yisong Yue. See Jeremy’s GitHub for a clean PyTorch implementation, or my GitHub for an experimental version with more features. Figure 1 summarises the differences between AGD, Adam, and SGD.

Figure 1 Solid lines show train accuracy and dotted lines show test accuracy. Left: In contrast to our method, Adam and SGD with default hyperparameters perform poorly on a deep fully connected network (FCN) on CIFAR-10. Middle: A learning rate grid search for Adam and SGD. Our optimiser performs about as well as fully-tuned Adam and SGD. Right: AGD trains ImageNet to a respectable test accuracy.

Motivation

Anyone who has trained a deep neural network has likely had to tune the learning rate, both (1) to make training as efficient as possible and (2) because the right learning rate can significantly improve generalisation. It is also a huge pain.

Figure 2 Why learning rates are important for optimisation. To maximise the speed of convergence, you want to find the Goldilocks learning rate: large, but not so large that the non-linear terms in the objective function kick you around.
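Here is a toy version of the Goldilocks picture. It is our own sketch, not taken from the paper: plain gradient descent on the one-dimensional quadratic f(x) = (a/2)x², with an assumed curvature a = 4. Each step multiplies x by (1 − lr·a). The iterates therefore shrink towards the minimum only when 0 < lr < 2/a, shrink fastest at lr = 1/a, and blow up once lr exceeds 2/a.

```python
def gradient_descent(lr: float, curvature: float = 4.0, x0: float = 1.0, steps: int = 20) -> float:
    """Run gradient descent on f(x) = (curvature / 2) * x**2 and return |x| after `steps` steps."""
    x = x0
    for _ in range(steps):
        grad = curvature * x  # f'(x)
        x -= lr * grad        # same as x *= (1 - lr * curvature)
    return abs(x)


if __name__ == "__main__":
    # With curvature 4 the stable range is 0 < lr < 0.5, and lr = 0.25 jumps straight to the minimum.
    for lr in (0.01, 0.1, 0.25, 0.45, 0.6):
        print(f"lr={lr:<5} |x| after 20 steps = {gradient_descent(lr):.3g}")
```

With a = 4, lr = 0.25 reaches the minimum in a single step, lr = 0.01 crawls, and lr = 0.6 diverges. Real loss surfaces have no single curvature, so the right value has to be found by search.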

However, with SGD the optimal learning rate depends strongly on the architecture being trained. Finding it often requires a costly grid search that sweeps over many orders of magnitude. Other hyperparameters, such as momentum and the learning rate decay schedule, also need to be selected and tuned.

We present an optimiser called automatic gradient descent (AGD) that does not need a learning rate to train a wide range of architectures and datasets, scaling all the way up to ResNet-50 on ImageNet. This removes the need to tune the optimiser’s hyperparameters, since both the effective learning rate and the learning rate decay drop out of the analysis. The result saves compute and massively speeds up the process of training a model.

Why do we need hyperparameters anyway?

A deep learning system is composed of lots of interrelated components: architecture, data, loss function and gradients. There is structure in the way these components interact, but nobody has yet nailed this structure down exactly. So we are left with lots of tuning (e.g. learning rate, initialisation, schedulers) to ensure rapid convergence and avoid overfitting.

But characterising these interactions perfectly could remove all the degrees of freedom in the optimisation process, which are currently handled by manual hyperparameter tuning. Second-order methods characterise the sensitivity of the objective to weight perturbations using the Hessian, and remove degrees of freedom that way. However, such methods can be computationally intensive and are therefore impractical for large models.

We derive AGD by characterising these interactions analytically:

  1. We bound the change in the output of the neural network in terms of the change in weights, for given data and architecture.
  2. We relate the change in objective (the average loss over all inputs in a batch) to the change in the output of the neural network.
  3. We combine these results in a so-called majorise-minimise approach. We majorise the objective: that is, we derive an upper bound on the objective that lies tangent to it at the current weights. We can then minimise this upper bound, knowing that doing so will move us downhill. This is visualised in Figure 3, where the red curve is the majorisation of the objective function (the blue curve).
Figure 3 The left panel shows the basic idea behind majorise-minimise — minimising the objective function (blue) is done by minimising a series of upper bounds, or majorisations, (red). The right panel shows how a change in weights induces a change in the function, which in turn induces a change in the loss on a single datapoint, which in turn induces a change in the objective. We bound ∆L in terms of ∆W, and use this to construct our majorisation.
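The majorise-minimise loop can be illustrated on a one-dimensional toy problem. This sketch shows the general technique only, not AGD itself. It assumes an objective whose second derivative is bounded by a known constant `SMOOTHNESS`. Under that assumption, the quadratic f(x₀) + f′(x₀)(x − x₀) + (SMOOTHNESS/2)(x − x₀)² lies above f everywhere and touches it at x₀. Minimising that upper bound gives x₀ − f′(x₀)/SMOOTHNESS. Because the bound is tangent at x₀, each such step can only lower the objective.

```python
import math

SMOOTHNESS = 1.0  # assumed upper bound on the second derivative of the toy objective


def objective(x: float) -> float:
    # Toy objective log(cosh(x)); its second derivative sech^2(x) never exceeds 1.
    return math.log(math.cosh(x))


def gradient(x: float) -> float:
    return math.tanh(x)


def majoriser(x: float, anchor: float) -> float:
    """Quadratic upper bound on the objective that touches it at `anchor`."""
    d = x - anchor
    return objective(anchor) + gradient(anchor) * d + 0.5 * SMOOTHNESS * d * d


def majorise_minimise(x0: float, steps: int = 10) -> list[float]:
    """Repeatedly jump to the minimiser of the current majoriser; return the objective values."""
    x = x0
    history = [objective(x)]
    for _ in range(steps):
        x = x - gradient(x) / SMOOTHNESS  # closed-form minimiser of majoriser(., x)
        history.append(objective(x))
    return history


if __name__ == "__main__":
    print(majorise_minimise(3.0))
```

The printed objective values never increase, which is the guarantee that makes majorise-minimise attractive. AGD applies the same idea with a bound built from the network architecture instead of a known smoothness constant.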

AGD in PyTorch

In this section, we go through all the key parts of the algorithm. See Appendix A for a sketch of the derivation.

On parameterisation

We use a parameterisation that differs slightly from the conventional PyTorch defaults. AGD can be derived without assuming this parameterisation, but using it simplifies the analysis. For a fully connected layer l, we use orthogonal initialisation, scaled such that the singular values have magnitude sqrt((output dimension of l)/(input dimension of l)).

We use this normalisation because it has nice properties that PyTorch’s default parameterisation lacks, including stability with width, resistance to blow-ups in the activations, and promotion of feature learning. This is similar to Greg Yang and Edward Hu’s muP.
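Below is a minimal, framework-free sketch of this initialisation. It aims to show the scaling, not to be efficient. Assumptions: weights are stored as lists of rows with shape (output dimension, input dimension), as in PyTorch, and every singular value is set to sqrt(d_out/d_in). This is our reading of the paper’s convention: under that scaling, an input of norm sqrt(d_in) maps to an output of norm at most sqrt(d_out), as used in Appendix A. The helper names are hypothetical. In PyTorch itself, `torch.nn.init.orthogonal_` with its `gain` argument set to the same factor should have the same effect for a linear layer.

```python
import math
import random


def _orthonormal_rows(k: int, dim: int, rng: random.Random) -> list[list[float]]:
    """Return k mutually orthonormal random vectors of length dim (requires k <= dim)."""
    rows: list[list[float]] = []
    while len(rows) < k:
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        for r in rows:  # Gram-Schmidt: remove the component along each earlier row
            dot = sum(a * b for a, b in zip(v, r))
            v = [a - dot * b for a, b in zip(v, r)]
        norm = math.sqrt(sum(a * a for a in v))
        if norm > 1e-8:  # redraw in the (vanishingly unlikely) degenerate case
            rows.append([a / norm for a in v])
    return rows


def scaled_orthogonal(d_out: int, d_in: int, seed: int = 0) -> list[list[float]]:
    """Hypothetical helper: a d_out x d_in matrix whose singular values all equal sqrt(d_out / d_in)."""
    rng = random.Random(seed)
    if d_out <= d_in:
        w = _orthonormal_rows(d_out, d_in, rng)   # orthonormal rows
    else:
        cols = _orthonormal_rows(d_in, d_out, rng)  # orthonormal columns...
        w = [list(row) for row in zip(*cols)]       # ...transposed to d_out x d_in
    scale = math.sqrt(d_out / d_in)
    return [[scale * a for a in row] for row in w]
```

When d_out ≤ d_in the rows are orthonormal, so W Wᵀ = (d_out/d_in)·I. Otherwise the columns are orthonormal, so Wᵀ W = (d_out/d_in)·I. Either way every singular value equals sqrt(d_out/d_in).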

On the update

The step can be broken into two separate parts. The first is the calculation of eta (η), the “automatic learning rate”, which scales the update at all layers. Eta depends logarithmically on the gradient norm. When the gradients are small, eta is approximately linear in the gradient norm (like standard optimisers), but when they are very large, the logarithm automatically performs a type of gradient clipping.
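As we read the paper, the automatic learning rate is η = log((1 + √(1 + 4G))/2). Here G = (1/L)·Σₖ √(d_k/d_{k−1})·‖∇_{W_k}‖_F is a depth-averaged summary of the layer gradient norms (see Appendix A), with L the depth. The sketch below assumes G has already been computed, and the function names are ours. For small G the formula behaves like η ≈ G. For large G it grows only like ½·log G, which is the clipping behaviour described above. The `majorisation` function is the one-dimensional bound −ηG + ½(e^η − 1)² that η minimises, again following our reading of the appendix.

```python
import math


def automatic_learning_rate(G: float) -> float:
    """eta = log((1 + sqrt(1 + 4G)) / 2) for a non-negative gradient summary G."""
    return math.log(0.5 * (1.0 + math.sqrt(1.0 + 4.0 * G)))


def majorisation(eta: float, G: float) -> float:
    """Upper bound on the change in objective, -eta*G + 0.5*(exp(eta) - 1)^2, as a function of eta."""
    return -eta * G + 0.5 * (math.exp(eta) - 1.0) ** 2


if __name__ == "__main__":
    for G in (1e-3, 1e-1, 1.0, 1e2, 1e4):
        print(f"G={G:<8g} eta={automatic_learning_rate(G):.4f}")
```

Setting the derivative of the majorisation to zero gives e^{2η} − e^η − G = 0, a quadratic in e^η. Its positive root is exactly the formula in `automatic_learning_rate`.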

The second part is the per-layer update. Each layer is updated using eta multiplied by the layer’s weight norm, multiplied by the normalised gradient, and divided by the depth. The division by depth is what makes the update scale correctly with depth. Interestingly, gradient normalisation drops out of the analysis, whereas other optimisers such as Adam incorporate similar ideas heuristically.
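Putting the two parts together, here is a minimal sketch of one AGD step as we understand it from the paper. It is written without any framework so that every operation is visible. Assumptions:

- each layer is a dense weight matrix, stored as a list of rows of shape (output dimension, input dimension), with its gradient already computed;
- there are no biases, affine parameters or convolutional layers;
- each layer’s “weight norm” is the fixed value sqrt(d_out/d_in).

The names `agd_step` and `frobenius` are hypothetical.

```python
import math

Matrix = list[list[float]]


def frobenius(m: Matrix) -> float:
    return math.sqrt(sum(x * x for row in m for x in row))


def agd_step(weights: list[Matrix], grads: list[Matrix]) -> float:
    """Update `weights` in place with one AGD-style step and return the automatic learning rate eta."""
    depth = len(weights)
    # Per-layer weight norm: the target singular value sqrt(d_out / d_in).
    scales = [math.sqrt(len(w) / len(w[0])) for w in weights]

    # Gradient summary G: scale-weighted gradient norms averaged over the layers.
    G = sum(s * frobenius(g) for s, g in zip(scales, grads)) / depth

    # Automatic learning rate: about G when G is small, about 0.5 * log(G) when G is large.
    eta = math.log(0.5 * (1.0 + math.sqrt(1.0 + 4.0 * G)))

    for w, g, s in zip(weights, grads, scales):
        g_norm = frobenius(g)
        if g_norm == 0.0:
            continue  # a layer with zero gradient is left unchanged
        step = eta / depth * s / g_norm  # eta * weight norm * normalised gradient / depth
        for i, row in enumerate(w):
            for j in range(len(row)):
                row[j] -= step * g[i][j]
    return eta
```

Every layer moves by the same relative amount: the Frobenius norm of each layer’s update is (η/depth)·sqrt(d_out/d_in), regardless of how large that layer’s raw gradient was.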

Experiments

The goal of these experiments was to test AGD’s ability to (1) converge across a wide range of architectures and datasets, and (2) achieve test accuracy comparable to that of tuned Adam and SGD.

Figure 4 shows the learning curves of four architectures, from a fully connected network (FCN) to ResNet-50, on datasets from CIFAR-10 to ImageNet. We compare AGD, shown with solid lines, to a standard optimiser, shown with dashed lines (SGD for ImageNet and tuned Adam for the other three). The top row shows the train objective (loss) and the automatic learning rate η. The bottom row shows the train and test accuracies. Figure 5 compares AGD, tuned Adam and tuned SGD on an 8-layer FCN. All three algorithms perform very similarly and reach near-identical test accuracy.

Figure 6 shows that AGD trains FCNs over a wide range of depths (2 to 32) and widths (64 to 2048). Figure 7 shows how AGD depends on batch size (from 32 to 4096) on a 4-layer FCN. It seems to converge to a good optimum at every batch size we tried!

Figure 4 AGD vs Adam on four architectures: a depth-16 FCN on CIFAR-10, ResNet-18 on CIFAR-10, VGG-16 on CIFAR-100 and ResNet-50 on ImageNet-1k. AGD keeps a reasonable pace with hyperparameter-tuned Adam (which required a grid search over several orders of magnitude)! Solid lines denote AGD and dashed lines Adam (except for ImageNet, where we use SGD instead). The top row shows the train objective (i.e. loss) and the value of the automatic learning rate η during training. The bottom row shows the train and test accuracies.
Figure 5 AGD vs Adam vs SGD on a depth-8 FCN with mean square error loss. Adam and SGD have their learning rates tuned. On the left, we plot the train and test objective functions (i.e. the loss). The middle shows the train and test accuracies. The right shows the average, minimum and maximum changes to the weights during each epoch.
Figure 6 AGD converges out of the box for a huge range of depths and widths. Smaller architectures lack the capacity to achieve low loss, but AGD still trains them!
Figure 7 And just to check AGD doesn’t only work for batch size 128, here is a selection of batch sizes for a depth-4 FCN.

Conclusion

To summarise, we have presented an “architecture-aware” optimiser, automatic gradient descent (AGD). It can train everything from small systems, like an FCN on CIFAR-10, to large-scale systems, like ResNet-50 on ImageNet, at a range of batch sizes and without the need for manual hyperparameter tuning.

While AGD has not removed all hyperparameters from machine learning, those that remain (batch size and architecture) typically fall into the “make them as large as possible to fill up the time/compute budget” category.

However, there is still lots to be done. We do not explicitly account for the stochasticity that minibatching introduces into the gradient, which depends on the batch size. We also haven’t looked into regularisation such as weight decay. We have done a little work on adding support for affine parameters (in batch norm layers) and bias terms, but we haven’t tested it extensively, and it is not as well justified by theory as the rest of the results here.

Perhaps most importantly, we still need to do the analysis required for transformers, and test AGD on NLP tasks. Preliminary experiments with GPT-2 on OpenWebText2 indicate that AGD works here too!

Finally, if you want to try AGD, check out Jeremy’s GitHub for a clean version, or my GitHub for a development version with support for bias terms and affine parameters! We hope you will find it useful.

Appendix A

We will go through a sketch of the important steps of the proof here. This is designed for anyone who wants to see how the main ideas come together, without going through the full proof, found in our paper here.

Equation (1) specifies explicitly how the overall objective across the dataset S decomposes into contributions from individual datapoints. L denotes loss, x the inputs, y the targets and w the weights. Equation (2) decomposes the linearisation error of the objective: the contribution of higher-order terms to the change in loss ΔL(w) for a given change in weights Δw. This quantity matters because it is exactly the contribution of the higher-order terms when the loss is expanded around w. Bounding it tells us how far we can move before those terms become important, so that we take steps of a sensible size, downhill.
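One way to write these two equations in symbols is sketched below. The notation is our own choice:

- 𝓛 is the objective and ℓ the loss on a single datapoint;
- f(x) = f(x; w) is the network output;
- Δ𝓛(w) = 𝓛(w + Δw) − 𝓛(w) and Δf(x) = f(x; w + Δw) − f(x; w).

The identity in (2) holds exactly, which can be checked by expanding both sides. The first sum on the right is the inner-product term, and the second is the linearisation error of the loss.

```latex
\begin{align}
\mathcal{L}(w) &= \frac{1}{|S|} \sum_{(x, y) \in S} \ell\big(f(x; w), y\big) \tag{1} \\
\Delta\mathcal{L}(w) - \nabla_w \mathcal{L}(w)^\top \Delta w
  &= \frac{1}{|S|} \sum_{(x, y) \in S} \nabla_{f(x)} \ell\big(f(x), y\big)^\top
     \Big[\Delta f(x) - \nabla_w f(x)\, \Delta w\Big] \nonumber \\
  &\quad + \frac{1}{|S|} \sum_{(x, y) \in S}
     \Big[\Delta \ell\big(f(x), y\big) - \nabla_{f(x)} \ell\big(f(x), y\big)^\top \Delta f(x)\Big] \tag{2}
\end{align}
```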

The first term on the RHS of Equation (2) is an inner product between two high-dimensional vectors: the linearisation error of the model and the derivative of the loss with respect to f(x). Since there is no clear reason why these two vectors should be aligned, we assume that their inner product is zero.

Adding L(W) to each side of Equation (2), so that the left-hand side contains L(W+ΔW), and noting that the linearisation error of the loss happens to be a Bregman divergence, we can simplify the notation:

A Bregman divergence measures the discrepancy between two points (in this case, the outputs of the network under two different parameter choices). It is defined in terms of a strictly convex function, in this case the loss, and behaves like a distance, although it need not be symmetric.
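For a concrete feel for this quantity, here is a short sketch of the Bregman divergence b(p, q) = φ(p) − φ(q) − ∇φ(q)·(p − q) of a convex function φ. We apply it to the square loss φ(f) = ‖f − y‖²/(2d) with a fixed target y. Normalising by 2d, with d the output dimension, is our assumption. For this loss the divergence reduces to ‖p − q‖²/(2d): it does not depend on the target y, and it is exactly the second-order part of the loss.

```python
from typing import Callable, Sequence

Vector = Sequence[float]


def bregman(phi: Callable[[Vector], float],
            grad_phi: Callable[[Vector], list[float]],
            p: Vector, q: Vector) -> float:
    """Bregman divergence of phi between p and q: phi(p) - phi(q) - <grad phi(q), p - q>."""
    g = grad_phi(q)
    return phi(p) - phi(q) - sum(gi * (pi - qi) for gi, pi, qi in zip(g, p, q))


def square_loss(y: Vector):
    """Square loss ||f - y||^2 / (2d) and its gradient in f, for a fixed target y of dimension d."""
    d = len(y)

    def loss(f: Vector) -> float:
        return sum((fi - yi) ** 2 for fi, yi in zip(f, y)) / (2 * d)

    def grad(f: Vector) -> list[float]:
        return [(fi - yi) / d for fi, yi in zip(f, y)]

    return loss, grad


if __name__ == "__main__":
    loss, grad = square_loss([1.0, -2.0, 0.5])
    print(bregman(loss, grad, [0.3, 0.1, -0.7], [1.5, -1.0, 2.0]))
```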

Calculating the Bregman divergence is actually quite straightforward for mean square error loss, and gives us

Where dₗ is the output dimension of the network. We now assert the following scalings. All of these are somewhat arbitrary, but having them in this form will make the analysis much simpler.
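Consider an L-layer fully connected network with weight matrices W_k ∈ ℝ^{d_k × d_{k−1}}, input dimension d_0 and output dimension d_L. Our reading of the three scalings for such a network is given below, where ‖·‖_* is the spectral norm (largest singular value) and η is the automatic learning rate.

```latex
\begin{align*}
\|x\|_2 &= \sqrt{d_0} && \text{(input scaling)} \\
\|W_k\|_* &= \sqrt{d_k / d_{k-1}} && \text{(weight scaling)} \\
\|\Delta W_k\|_* &= \sqrt{d_k / d_{k-1}} \cdot \frac{\eta}{L} && \text{(update scaling)}
\end{align*}
```

Under these choices, every layer’s weights change by the same relative amount, ‖ΔW_k‖_*/‖W_k‖_* = η/L.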

We use the following two bounds on the size of the network output. Equation (5) bounds the magnitude of the network output and follows from applying the (input scaling) and (weight scaling) to a fully connected network. Equation (6) bounds the maximum change in f(x) for a given change in the weights W. The second inequality in (6) is tightest at large depth but holds for any depth.
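Both bounds can be sanity-checked numerically. The sketch below assumes a bias-free ReLU network in which every layer satisfies the scalings with equality, and it checks only the scalar factors, not a network itself. First, the product of spectral norms telescopes, so ‖f(x)‖₂ ≤ √d_0·Πₖ√(d_k/d_{k−1}) = √d_L. Second, if every layer changes by a relative amount η/L, the relative output change is at most (1 + η/L)^L − 1. That factor never exceeds e^η − 1 and approaches it as the depth grows.

```python
import math


def output_norm_bound(widths: list[int]) -> float:
    """||x|| * prod_k ||W_k||_*, with ||x|| = sqrt(d_0) and ||W_k||_* = sqrt(d_k / d_{k-1})."""
    bound = math.sqrt(widths[0])
    for d_in, d_out in zip(widths, widths[1:]):
        bound *= math.sqrt(d_out / d_in)
    return bound  # telescopes to sqrt(widths[-1])


def perturbation_factor(eta: float, depth: int) -> float:
    """Relative output change when every layer changes by eta / depth: (1 + eta/depth)^depth - 1."""
    return (1.0 + eta / depth) ** depth - 1.0


if __name__ == "__main__":
    print(output_norm_bound([784, 1024, 1024, 10]), math.sqrt(10))
    for depth in (1, 2, 8, 32, 128):
        print(depth, perturbation_factor(0.5, depth), math.expm1(0.5))
```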

Now, we substitute Equation (6) back into Equation (4), and expand all terms out explicitly to get Equation (7).
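Here is a sketch of this substitution in our own notation. It makes three assumptions: the square loss ℓ(f, y) = ‖f − y‖²/(2d_L), whose Bregman divergence is ‖Δf(x)‖²/(2d_L); the first term of Equation (2) set to zero; and the bound ‖Δf(x)‖₂ ≤ √d_L (e^η − 1). We write 𝓛 for the objective, L for the depth, and ⟨·,·⟩ for the trace inner product between matrices. The last inequality uses d_L·(e^η − 1)²/(2d_L) = ½(e^η − 1)².

```latex
\begin{align*}
\mathcal{L}(W + \Delta W)
  &= \mathcal{L}(W) + \sum_{k=1}^{L} \big\langle \nabla_{W_k} \mathcal{L}(W),\, \Delta W_k \big\rangle
     + \frac{1}{|S|} \sum_{(x, y) \in S} \frac{\|\Delta f(x)\|_2^2}{2 d_L} \\
  &\le \mathcal{L}(W) + \sum_{k=1}^{L} \big\langle \nabla_{W_k} \mathcal{L}(W),\, \Delta W_k \big\rangle
     + \frac{1}{2}\big(e^{\eta} - 1\big)^2
\end{align*}
```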

We can replace the sum in Equation (7) with G, defined in Equation (8), under an additional assumption about the gradient conditioning, which is discussed in detail in the paper. This gives Equation (9), the majorisation, which is the red curve in Figure 3. We minimise the majorisation by differentiating with respect to η and solving the resulting quadratic in exp(η), retaining the positive solution. This gives the following update:
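Written out in our notation, the final steps look as follows. We write 𝓛 for the objective, L for the depth, d_k for the output dimension of layer k and ‖·‖_F for the Frobenius norm. Each layer is moved along its normalised negative gradient with the update scaling. Since ⟨∇, ∇/‖∇‖_F⟩ = ‖∇‖_F, the sum of inner products becomes exactly −ηG. As we understand it, the gradient conditioning assumption is what lets the Frobenius normalisation stand in for the spectral one here. The rest is one-dimensional calculus in η.

```latex
\begin{align}
G &= \frac{1}{L} \sum_{k=1}^{L} \sqrt{\frac{d_k}{d_{k-1}}}\, \big\|\nabla_{W_k} \mathcal{L}(W)\big\|_F \tag{8} \\
\mathcal{L}(W + \Delta W) &\le \mathcal{L}(W) - \eta G + \frac{1}{2}\big(e^{\eta} - 1\big)^2 \tag{9} \\
\frac{\partial}{\partial \eta}\Big[-\eta G + \tfrac{1}{2}\big(e^{\eta} - 1\big)^2\Big]
  &= -G + \big(e^{\eta} - 1\big) e^{\eta} = 0
  \;\Longrightarrow\; e^{2\eta} - e^{\eta} - G = 0 \nonumber \\
\eta &= \log \frac{1 + \sqrt{1 + 4G}}{2}, \qquad
\Delta W_k = -\frac{\eta}{L} \sqrt{\frac{d_k}{d_{k-1}}}\,
  \frac{\nabla_{W_k} \mathcal{L}(W)}{\big\|\nabla_{W_k} \mathcal{L}(W)\big\|_F} \nonumber
\end{align}
```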

And this concludes our derivation of automatic gradient descent. Please let us know if you have any comments, questions or other kinds of feedback.

All images in the blog were made by the authors of our paper. Figure 2 is inspired by this diagram.
