
Deep Neural Networks (DNNs) are regarded as one of the most effective tools for finding patterns in large datasets through training. At the core of training lies a complex loss landscape, and training a DNN boils down to minimizing this loss as the number of iterations increases. A few of the most commonly used optimizers are Stochastic Gradient Descent (SGD), RMSProp (Root Mean Square Propagation), and Adam (Adaptive Moment Estimation).
Recently (September 2024), researchers from Apple (and EPFL) proposed a new optimizer, AdEMAMix¹, which they show to work better and faster than the AdamW optimizer for language modeling and image classification tasks.
In this post, I will go into detail about the mathematical concepts behind this optimizer and discuss some very interesting results presented in this paper. Topics that will be covered in this post are:
- Review of the Adam Optimizer
- Exponential Moving Average (EMA) in Adam
- The Main Idea Behind AdEMAMix: Mixture of Two EMAs
- The Exponential Decay Rate Scheduler in AdEMAMix
Let’s begin.
Adam Optimizer Concepts: Review
We will see soon that AdEMAMix is an extension of the Adam² optimizer, so let’s brush up on the concepts behind the Adam optimizer.
Adam combines the concepts of momentum and adaptive learning rates (as in RMSProp). It keeps track of an exponentially decaying average of the past gradients, which serves as the momentum term:

m_t = β1 m_{t−1} + (1 − β1) g_t (Eq. 1)
g_t and β1 are the gradient at time step t and a constant parameter (≈0.9 is suggested as a reasonable value in the Adam paper) that controls the exponential decay rate of this moving average, respectively. This quantity is usually referred to in the literature as the first moment.
The second moment keeps track of the squared gradients from the past:

v_t = β2 v_{t−1} + (1 − β2) g_t² (Eq. 2)
Since both m_t and v_t are initialized to 0, the estimates are biased towards 0, especially during the early steps. To remove this bias, bias-corrected versions of the moments are introduced:

m̂_t = m_t / (1 − β1^t),  v̂_t = v_t / (1 − β2^t) (Eq. 3)
How does this bias correction actually help? Let’s do a simple calculation starting from m_0 = 0. Without the bias correction, the first few moments look as below:

m_1 = (1 − β1) g_1
m_2 = β1 (1 − β1) g_1 + (1 − β1) g_2
m_3 = β1² (1 − β1) g_1 + β1 (1 − β1) g_2 + (1 − β1) g_3 (Eq. 4)
With the bias correction, however, we have a modified expression:

m̂_1 = m_1 / (1 − β1) = g_1
m̂_2 = m_2 / (1 − β1²) = [β1 (1 − β1) g_1 + (1 − β1) g_2] / (1 − β1²) (Eq. 5)
If we compare Eq. 4 and Eq. 5, i.e. without and with the bias correction, we see that including the bias correction leads to far less biased estimates of the gradients in the early stages. Since β1 is set to 0.9, (1 − β1) = 0.1 is small, so without the bias correction we only take a small contribution from g_1 and g_2.
We can plot the first moment with and without bias correction as a function of training steps for some random gradient values, to show that without bias correction the first moment stays close to zero in the early steps.
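Below is a minimal sketch of such a plot (the gradient values are just random numbers standing in for real gradients; the script is mine, not from either paper):

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)
beta1 = 0.9
grads = np.random.randn(100)  # random stand-in gradients

m = 0.0
m_raw, m_hat = [], []
for t, g in enumerate(grads, start=1):
    m = beta1 * m + (1 - beta1) * g      # Eq. 1
    m_raw.append(m)                      # without bias correction
    m_hat.append(m / (1 - beta1 ** t))   # Eq. 3: bias-corrected

plt.plot(m_raw, label='$m_t$ (no bias correction)')
plt.plot(m_hat, label=r'$\hat{m}_t$ (bias corrected)')
plt.xlabel('Training step')
plt.ylabel('First moment')
plt.legend()
plt.show()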

Including these bias corrections, the parameter update rule becomes:

θ_t = θ_{t−1} − η m̂_t / (√v̂_t + ε) (Eq. 6)
The Adam optimizer with decoupled weight-decay regularization λ (i.e. AdamW) can be written as:

θ_t = θ_{t−1} − η [ m̂_t / (√v̂_t + ε) + λ θ_{t−1} ] (Eq. 7)
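To make the update rules concrete, here is a minimal NumPy sketch of a single AdamW step following Eqs. 1–3 and 7 (the function name and default hyperparameters are my own illustration, not a reference implementation):

import numpy as np

def adamw_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1e-2):
    """One AdamW update for a parameter array theta; t is the 1-based step."""
    m = beta1 * m + (1 - beta1) * g          # first moment, Eq. 1
    v = beta2 * v + (1 - beta2) * g ** 2     # second moment, Eq. 2
    m_hat = m / (1 - beta1 ** t)             # bias corrections, Eq. 3
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)  # Eq. 7
    return theta, m, v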

Exponential Moving Average (EMA):
We need to get accustomed to the intuition behind EMA. If we look back at Eq. 1 and 2, we see that both m_t and v_t follow a recursive relation in which the past value (m_{t−1}, v_{t−1}) is scaled by a factor β and the current observation is scaled by (1 − β). Where does the exponential part come in? To see this, let’s proceed with one of them, m_t.
We initialize m_0 to 0 and unroll the recursion to find a general formula:

m_t = (1 − β1) Σ_{i=0}^{t−1} β1^i g_{t−i} (Eq. 8)
Here we see that the contribution of each past gradient decays exponentially as we move further back in time, since every step back multiplies it by an additional factor of β1.
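As a quick sanity check, the following snippet (my own, purely illustrative) confirms that the unrolled sum in Eq. 8 gives the same value as the recursion in Eq. 1:

import numpy as np

np.random.seed(0)
beta1 = 0.9
grads = np.random.randn(20)  # random stand-in gradients g_1, ..., g_t

# recursive EMA (Eq. 1), starting from m_0 = 0
m = 0.0
for g in grads:
    m = beta1 * m + (1 - beta1) * g

# unrolled form (Eq. 8): each step back in time adds a factor of beta1
t = len(grads)
m_unrolled = sum((1 - beta1) * beta1 ** i * grads[t - 1 - i] for i in range(t))

print(np.isclose(m, m_unrolled))  # True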
From now on, I will use the notation of the original AdEMAMix paper. The authors denote the set of all gradients up to training step t as:

{g^(0), g^(1), …, g^(t)} (Eq. 9)
Using this, we can write the EMA for the first moment as:

EMA(β, g^(0), …, g^(t)) = Σ_{i=0}^{t} β^i (1 − β) g^(t−i) (Eq. 10)
This reveals the key concept behind the EMA: current observations are more important, and past observations become less important the further back in time we go. The multiplicative term β^i decays exponentially with i; since 0 < β < 1, the effect of g^(t−1) on the EMA is smaller than that of g^(t), and so on. We can plot this in Python as below:
import matplotlib.pyplot as plt

beta1 = 0.9
epochs = 50  # number of past time steps (iterations) to show
# weight given to the gradient from i steps in the past (Eq. 10)
weights = [(1 - beta1) * (beta1 ** i) for i in range(epochs)]
plt.figure(figsize=(10, 6))
plt.plot(weights,
         label=r"Exponential Decay of Gradient Influence $(\beta_1 = 0.9)$")
plt.xlabel('Time Steps in the Past')
plt.ylabel('Influence of Past Gradients')
plt.title('Gradient $(g_t)$ Influence with Time')
plt.legend()
plt.show()

Main Idea Behind AdEMAMix: Mixture of EMAs
To quote from the paper:
A single EMA cannot both give a significant weight to recent gradients, and give a non-negligible weight to older gradients.
We need to clarify this concept. Looking back at the previous equations, you can probably already see why, but let’s make it even clearer.
Let’s consider Eq. 1: when we increase β from 0.9 to 0.99, the term (1 − β) g_t, where g_t is the current gradient, contributes less towards m_t. At the same time, increasing β increases the contribution of the term β m_{t−1}, i.e. the effect of past gradients. This is also partly visible in Eq. 8.
So it is indeed difficult for a single EMA to give high importance to recent gradients while still keeping a non-negligible contribution from older gradients. But how much do older gradients contribute? Through experimentation, the researchers found that for language modeling and computer vision tasks, gradients can stay relevant for "tens of thousands of steps".
The researchers show that when they increase β in AdamW, i.e. give more importance to past gradients, the performance actually worsens. They argue that this is not because we are forcefully including outdated gradient information, but because increasing β reduces the sensitivity to the current gradient.
The authors mention:
A small β (e.g. 0.9) will give a high weight to the immediate past and negligible weights to older timesteps. In contrast, a high β (e.g. 0.9999) is giving a relatively uniform, yet non-negligible weight to all gradients. No β value can simultaneously give a high weight to the immediate past and a non-negligible weight to very old timesteps.
The authors then go on to suggest the key concept for constructing the AdEMAMix optimizer.
A linear combination between a "fast-changing" (e.g. β1=0.9) and a "slow-changing" (e.g. β3=0.9999) EMA allows the iterate to beneficiate from (i) the great speedup provided by the larger (slow-changing) momentum, while (ii) still being reactive to small changes in the loss landscape (fast-changing).
Now we need to understand what they mean by "fast-changing" and "slow-changing".
Fast Changing EMA (β1 = 0.9):
- The "fast-changing" EMA uses a lower value for the decay rate, β1 = 0.9 (compared to 0.9999), which gives more weight to recent gradients and ‘forgets’ older gradients.
- This EMA is responsive to small, recent changes in the loss landscape because it places a stronger emphasis on the most recent gradients.
Slow Changing EMA (β3 = 0.9999):
- The "slow-changing" EMA uses a higher value for the decay rate (β3 = 0.9999), i.e. it retains past gradient information for a much longer time.
- As this EMA changes slowly, it is naturally useful for loss landscapes that are long and contain many flat regions.
Half-Life for Fast and Slow-Changing EMAs:
Previously, we asked how much older gradients contribute. Given the EMA formula, we can do some simple calculations, as discussed in the paper. Since the weighting factors for the gradients in the EMA (Eq. 10) form a geometric series, we can use high-school mathematics to derive a half-life.
Half-life in this context is defined as the number of successive previous steps receiving a cumulative weight of 0.5. It is exactly half because the infinite geometric series in Eq. 10, i.e. for t → ∞, sums to 1. Back to the half-life calculation:

Σ_{i=0}^{t_half} (1 − β) β^i = 0.5 (Eq. 11)
We can use the summation formula for a geometric series:

(1 − β) (1 − β^(t_half + 1)) / (1 − β) = 1 − β^(t_half + 1) = 0.5 (Eq. 12)
Taking the natural logarithm on both sides, we get an expression for t_half:

t_half = ln(0.5) / ln(β) − 1 (Eq. 13)
For β1 = 0.9, we get a half-life of ≈5.5788 steps; for β3 = 0.9999, however, we get a much longer half-life of ≈6930.1252 steps.
This highlights that for the small β = 0.9, half of the total weight is given to the previous ~6 gradients, while for the higher β = 0.9999 that same half is spread over roughly 6930 steps. For small β we focus on the most recent gradients; for larger β we also include contributions from much older gradients.
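As a quick check, here is a tiny helper (mine, not from the paper) that reproduces both numbers from Eq. 13:

import math

def half_life(beta):
    """Number of previous steps receiving a cumulative weight of 0.5 (Eq. 13)."""
    return math.log(0.5) / math.log(beta) - 1

print(half_life(0.9))     # ~5.5788
print(half_life(0.9999))  # ~6930.1252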
We can write a simple script in Python to show how the weights assigned to past gradients drop as a function of how far back in time they are, for different values of β.
import numpy as np
import matplotlib.pyplot as plt

max_epoch = 3000
# plot in the paper covers 10k steps (Fig. 3a, AdEMAMix)
steps = np.arange(max_epoch)  # i = how many steps in the past
beta_values = [0.9, 0.99, 0.999, 0.9999]  # example values (my choice)

plt.figure(figsize=(10, 6))
# weights for different betas (Eq. 10)
for beta in beta_values:
    weights = (beta ** steps) * (1 - beta)
    plt.plot(steps, weights, label=rf'$\beta$ = {beta}')
plt.title('Exponential Decay of Weights in EMA', fontsize=12)
plt.xlabel('Time Steps in the Past (i=0 is current)', fontsize=10)
plt.ylabel(r'Weight: $\beta^i \, \left(1 - \beta\right)$', fontsize=10)
plt.legend(fontsize=12)
# plt.yscale('log')
plt.xscale('log')
plt.show()

AdEMAMix:
With these concepts backing the need for two EMAs, one with a small and one with a large β, the authors introduce the AdEMAMix update rules:

m1_t = β1 m1_{t−1} + (1 − β1) g_t,  m̂1_t = m1_t / (1 − β1^t)
m2_t = β3 m2_{t−1} + (1 − β3) g_t
v_t = β2 v_{t−1} + (1 − β2) g_t²,  v̂_t = v_t / (1 − β2^t)
θ_t = θ_{t−1} − η [ (m̂1_t + α m2_t) / (√v̂_t + ε) + λ θ_{t−1} ] (Eq. 14)
For the new multiplicative factor α in Eq. 14, the authors found that α ∈ [4,10] works best.
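Here is a rough NumPy sketch of a single AdEMAMix update following Eq. 14 (the function name, default hyperparameters, and the use of a fixed β3 instead of the scheduled one discussed later are my own simplifications, not the authors' reference implementation):

import numpy as np

def ademamix_step(theta, g, m1, m2, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
                  beta3=0.9999, alpha=5.0, eps=1e-8, weight_decay=1e-2):
    """One AdEMAMix update: fast EMA m1, slow EMA m2, second moment v (Eq. 14)."""
    m1 = beta1 * m1 + (1 - beta1) * g        # fast-changing EMA
    m2 = beta3 * m2 + (1 - beta3) * g        # slow-changing EMA (the extra term)
    v = beta2 * v + (1 - beta2) * g ** 2     # second moment
    m1_hat = m1 / (1 - beta1 ** t)           # bias corrections as in Adam
    v_hat = v / (1 - beta2 ** t)
    update = (m1_hat + alpha * m2) / (np.sqrt(v_hat) + eps) + weight_decay * theta
    return theta - lr * update, m1, m2, v

Compared to a standard AdamW step, the only additions are the slow EMA m2 and the α-weighted term added to the numerator of the update.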
Compared to AdamW, we have only one extra update rule (the slow EMA m2). For comparison, here are the AdamW equations again:

m_t = β1 m_{t−1} + (1 − β1) g_t,  m̂_t = m_t / (1 − β1^t)
v_t = β2 v_{t−1} + (1 − β2) g_t²,  v̂_t = v_t / (1 − β2^t)
θ_t = θ_{t−1} − η [ m̂_t / (√v̂_t + ε) + λ θ_{t−1} ] (Eq. 15)
Results: While discussing the results of training LLMs, the authors show that a 1.3B-parameter AdEMAMix LLM trained on 101B tokens performs comparably to an AdamW model trained on 197B tokens, roughly 95% more data. They also highlight that older gradients, which can stay relevant for thousands of steps, can be leveraged to reach lower minima.
One particular result that I found very interesting is that models trained with AdEMAMix ‘forget’ the training data more slowly. The authors give an example of training an LLM on the RedPajama dataset with a held-out batch B. First, they separately train an AdamW and an AdEMAMix model on RedPajama data that does not contain B. Then this batch is injected at different points during training (e.g. at iteration 90K, 170K, 230K, etc.). The loss on B as a function of the training iteration for these two optimizers is shown below; let’s look at this intriguing plot from the paper itself:

The top and bottom rows show results using the AdamW and AdEMAMix optimizers respectively. For both optimizers, there is a rapid decrease in the loss on B right after training on B, and the drop is sharper with AdamW than with AdEMAMix. So one might think that AdamW helps the LLM learn B better because the loss drops so abruptly, but with AdamW the loss on B then climbs back up faster, which the authors interpret as the model forgetting B faster. This is in contrast with the training curves for AdEMAMix, which are smoother: the loss on B climbs back up more slowly, and ultimately this held-out batch B has a bigger impact on training when using AdEMAMix, as its loss stays lower than with AdamW. That’s pretty cool!!
The Math Behind the β3 Scheduler:
Since β3 is the new addition in AdEMAMix (Eq. 14), the authors give a brilliant justification for the scheduling strategy they chose for it. β3 is made a function of the training iteration t, and the schedule is introduced in the paper as an exponential function:

β3(t) = exp[ ln(β_start) ln(β3) / ( (1 − t/T_{β3}) ln(β3) + (t/T_{β3}) ln(β_start) ) ] (Eq. 16)
where T_{β3} = T is the total number of training iterations, and β_start is set to β1 = 0.9.
Here we need to ask why the authors chose an exponential function for this scheduler. They give a nice intuition builder: let’s say we start from β = 0.9 and increase it by δβ = 0.0001, so that β′ = 0.9001. If we look back at the half-life equation, this causes barely any increase in the half-life:

t_half(0.9001) − t_half(0.9) ≈ 5.5858 − 5.5788 ≈ 0.007 steps
But for the same δβ, when β is at 0.999 we get a staggering increase of roughly 77 iterations in the half-life:

t_half(0.9991) − t_half(0.999) ≈ 768.82 − 691.80 ≈ 77 steps
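Using the half_life helper defined earlier, we can verify both numbers:

print(half_life(0.9001) - half_life(0.9))    # ~0.007
print(half_life(0.9991) - half_life(0.999))  # ~77.0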

So a linear scheduler for β3 is not a good choice. What we want instead is for the half-life to increase linearly, so that as we increase β3 we linearly increase the influence of past gradients via t_half. At the beginning of training, β3 = β_start = β1 = 0.9, so t_half is small and we focus on recent gradients; as training progresses, we want to accumulate weight from gradients further and further in the past.
Let’s now derive this formula in detail, step by step; part of it is already shown in the appendix of the paper. We start with the half-life as a function of the decay rate β, and its inverse (dropping the constant −1 from Eq. 13, which only shifts the half-life and does not change the argument):

f(β) = t_half = ln(0.5) / ln(β) (Eq. 17)

f⁻¹(x) = exp( ln(0.5) / x ) = 0.5^(1/x) (Eq. 18)
As its argument t_half increases, the inverse function, and hence β, also increases.
What we want is a linear (smooth) transition of f(β) from f(β_start) to f(β_end), so an interpolation parameter μ ∈ [0, 1] is introduced:

f(β(μ)) = (1 − μ) f(β_start) + μ f(β_end) (Eq. 19)
When μ = 0 we are at f(β_start), and for μ = 1 we are at f(β_end). But f(β) is essentially t_half, which now increases linearly; what we actually want is the schedule for β itself, so we use the inverse transformation to get back from f(β) space to β space:

β(μ) = f⁻¹( (1 − μ) f(β_start) + μ f(β_end) ) (Eq. 20)
The inverse function f⁻¹ maps this linear increase in t_half back to a corresponding value of β. Let’s go through the calculation to reach Eq. 16 from here. Using the definition of the inverse function from Eq. 18:

β(μ) = exp[ ln(0.5) / ( (1 − μ) f(β_start) + μ f(β_end) ) ] (Eq. 21)
Let’s simplify the exponent using the definition of the function f(β) from Eq. 17:

ln(0.5) / ( (1 − μ) ln(0.5)/ln(β_start) + μ ln(0.5)/ln(β_end) ) = ln(β_start) ln(β_end) / ( (1 − μ) ln(β_end) + μ ln(β_start) ) (Eq. 22)
Using this simplified expression for the exponent, we can rewrite Eq. 21 and arrive at the final formula for the scheduler:

β(μ) = exp[ ln(β_start) ln(β_end) / ( (1 − μ) ln(β_end) + μ ln(β_start) ) ] (Eq. 23)
This already looks very similar to Eq. 16: all we need is to set β_end = β3 and μ = t/T_{β3}. We can also plot the β3 scheduler as a function of the interpolation factor μ.
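Here is a minimal sketch of that plot, implementing Eq. 23 directly (the choice of β_start = 0.9 and β_end = 0.9999 mirrors the values used above; the function name is mine):

import numpy as np
import matplotlib.pyplot as plt

def beta3_scheduler(mu, beta_start=0.9, beta_end=0.9999):
    """Eq. 23: interpolate linearly in half-life space, then map back to beta."""
    log_s, log_e = np.log(beta_start), np.log(beta_end)
    return np.exp(log_s * log_e / ((1 - mu) * log_e + mu * log_s))

mu = np.linspace(0, 1, 500)
plt.plot(mu, beta3_scheduler(mu), label=r'$\beta_3(\mu)$')
plt.xlabel(r'Interpolation factor $\mu = t / T_{\beta_3}$')
plt.ylabel(r'$\beta_3$')
plt.title(r'Exponential $\beta_3$ Scheduler')
plt.legend()
plt.show()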

This brings us to the end of the post. We dived into the mathematics and the intuition behind AdEMAMix, the new optimizer recently proposed by researchers from Apple and EPFL. TL;DR: a mixture of two EMAs is better than one, as one focuses on the most recent gradients while the other accumulates contributions from older gradients, which can stay relevant for tens of thousands of steps when training large models.
When will you be using this optimizer for your LLM?
References:
[1] The AdEMAMix Optimizer: Better, Faster, Older; M. Pagliardini, P. Ablin, D. Grangier (2024).
[2] Adam: A Method for Stochastic Optimization; D. P. Kingma, J. Ba (2015).
[3] Notes are available on my GitHub.
Note: All the images, if not specified with a reference, are author’s work and available for free on the linked notebook in reference.