A Search for Efficient Meta-Learning: MAMLs, Reptiles, and Related Species

Cody Marie Wild
Towards Data Science
22 min read · Sep 27, 2020

It’s an oft-lamented fact that the capabilities of modern machine learning tend to be narrow and brittle: while a given technique can be applied to a number of tasks, an individual learned model specializes in only one and needs a lot of data to acquire that specialized competence.

Meta Learning asks: instead of starting from scratch on each new task, is there a way to train a model across tasks so that the acquisition of specific new tasks is faster and more data-efficient? Approaches in meta learning and the related discipline of few-shot learning have taken many shapes — from learning task-agnostic embedding spaces to recurrent networks that are passed in training data sequentially and encode a learning algorithm in their state evolution weights — but arguably the most intuitive of these methods has been MAML: Model Agnostic Meta Learning.

[As an aside: this blog post is going to assume a certain amount of background context on meta learning, most centrally around the idea of learning taking place not over a single task but a distribution of tasks. If you’re a bit shaky on that idea or on meta-learning as a concept, I’d recommend reading my earlier post on meta learning writ large, and then coming back to this one, which delves into MAML and related methods in more detail.]

The premise of MAML is: if we want an initial model that can adapt to new tasks drawn from some task distribution, given a small number of data points on each new task, we should structure our model to directly optimize for that goal. In particular, MAML executes the following set of steps:

  1. Define some set of initial parameters: θ
  2. Sample a task t from the task distribution, and perform k (generally <10) steps of gradient descent on training batches of data from t, initializing parameter values at θ. At the end of k steps, you end up with parameters ϕ.
  3. Evaluate your loss on a testing batch from task t at parameter value ϕ, and then calculate the derivative of that loss with respect to our initial parameters θ. That is, calculate how we can modify our cross-task network initialization to lead to a better loss after a small number of optimization steps on a new task. This kind of derivative, which propagates through the learning procedure itself back to the values of the starting weights, is notably different from the derivatives we usually take in gradient descent, in ways that we’ll delve into shortly.
  4. Use that gradient vector to update our θ initialization parameters, throw away the ϕ learned for this specific task, and start the process over again with a new sampled task, initializing our network with our recently-updated θ value.
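To make the four steps concrete, here is a minimal sketch in plain Python. It assumes a hypothetical toy task family in which each task is a 1D quadratic loss L_t(w) = 0.5·a·(w − c)² with random curvature a and optimum c, and it reuses the training loss in place of a separate test batch. These choices are ours, picked so that every derivative (including the second derivative) is exact and cheap; function names like `sample_task` and `maml_meta_gradient` are illustrative, not from any library.

```python
import random

# Hypothetical toy task family (our assumption, not from the post): task t has
# loss L_t(w) = 0.5 * a * (w - c)**2 with curvature a and optimum c.
def sample_task(rng):
    return rng.uniform(0.5, 2.0), rng.uniform(-3.0, 3.0)

def grad(w, a, c):
    return a * (w - c)  # first derivative of L_t

def hess(w, a, c):
    return a  # second derivative of L_t (constant for a quadratic)

def maml_meta_gradient(theta, task, alpha=0.1, k=5):
    """Steps 2-3: k inner SGD steps from theta, then d L_t(phi) / d theta."""
    a, c = task
    w, dphi_dtheta = theta, 1.0
    for _ in range(k):
        # The update w -> w - alpha * L'(w) has derivative 1 - alpha * L''(w).
        dphi_dtheta *= 1.0 - alpha * hess(w, a, c)
        w -= alpha * grad(w, a, c)
    phi = w
    # Loss gradient at phi (the same loss stands in for the test batch), pulled back to theta.
    return grad(phi, a, c) * dphi_dtheta

def maml(meta_steps=2000, meta_lr=0.05, seed=0):
    rng = random.Random(seed)
    theta = 5.0                       # step 1: initial parameters
    for _ in range(meta_steps):
        task = sample_task(rng)       # step 2: sample a task
        # Steps 3-4: update theta with the meta-gradient; phi is discarded.
        theta -= meta_lr * maml_meta_gradient(theta, task)
    return theta

print(maml())  # ends up near 0, the center of the symmetric spread of task optima
```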

A reasonable question to ask, given this approach, is:

Is this just the same thing as training a parameter vector that works well for all tasks on average?

The shorthand I’ll use for this approach (where we just sample alternating batches from different tasks, and do normal gradient descent on each batch in turn) is joint training. The empirical answer as to whether MAML actually outperforms joint training isn’t clear to me, but conceptually what MAML is trying to do is subtly different, and, in my experience, the project of pulling apart the threads of that difference has been valuable in getting a clearer view, not only of MAML itself, but of the variations on it that have been subsequently proposed.
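For contrast, a sketch of joint training on the same kind of hypothetical quadratic tasks (again our own toy setup, not from the post): there is no inner loop, just one ordinary gradient step per task at the shared parameters. With two fixed tasks, it settles near the curvature-weighted average of the task optima, which is exactly a single "good on average" point.

```python
# Joint training on hypothetical quadratic tasks L(w) = 0.5 * a * (w - c)**2
# (our toy setup): one ordinary gradient step per task, no inner loop.
def joint_training(tasks, steps=5000, lr=0.01, theta=0.0):
    for step in range(steps):
        a, c = tasks[step % len(tasks)]   # alternate between tasks
        theta -= lr * a * (theta - c)     # plain first-order step at the shared parameters
    return theta

tasks = [(1.0, 0.0), (3.0, 4.0)]          # (curvature a, optimum c) for two tasks
print(joint_training(tasks))              # close to (1*0 + 3*4) / (1 + 3) = 3
```

The small gap from exactly 3 comes from the finite step size and the fixed task order.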

What Makes An Initialization Good?

An obvious question posed by MAML’s goal of learning a good cross-task initialization is: what are the different ways that different parameter initializations could be good? What are the ways they could potentially impact final loss values? Put more simply, what makes one initialization better than another for some distribution of tasks?

Because, sometimes, the very beginning isn’t actually the very best place for parameters to start

First, let’s imagine the case where we don’t optimize at all on a new task. In this case, there’s no distinction between our ‘final’ parameter values ϕ and the initial θ ones. Changes to θ would (trivially) mean an equivalent change to ϕ, since they’re the same value.

When we start actually taking optimization steps on tasks, and learning a task-specific ϕ that differs from θ, things get more interesting.

It could be, just like above, that the benefit of your meta-learned θ is just that it has low loss to begin with on the task, and so optimization steps taken from there get a head start because of starting from that low loss value. In general, optimization will make your loss better relative to your starting point, so having a lower-loss initialization provides a lower rough ceiling on your final loss, and can lead to the final ϕ you learn being better.

However, let’s imagine an alternate case, where, for some reason, we could only choose between initialization values θ which all themselves have the same loss value on the new task you’ve drawn. In this world, we couldn’t push our post-optimization loss lower by pushing the loss of our initial parameters lower, because we’ve artificially removed that lever. In this setting, what are other ways that one initialization could be better than another?

For the sake of visualization, let’s imagine a simple example where all our equivalent-loss points are in a 2D circle around our theoretical minimum. Imagine that one point sits on a plateau, that stays up at a high loss as you move in the direction of the minimum, until it drops at the last moment. Another point sits at the top of a gradual slope downward towards the minimum. Even though these points start out at the same (relatively high) loss, the one at the top of the gradual mountain is obviously better, at least in a world where you’re using gradient descent as your optimization strategy, because gradient descent will have an easy time following the sloping path from the starting point to the minimum-loss point. The point on the plateau, on the other hand, wouldn’t give gradient descent anything to work with initially, and so would have a more difficult time finding the minimum.
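A minimal 1D sketch of the plateau-versus-slope picture, using a loss function we constructed purely for illustration (it is not from the post, and it collapses the 2D circle to two points on a line): both starting points have essentially the same loss, but plain gradient descent only makes progress from the one on the slope.

```python
import math

# A hypothetical 1D loss with its minimum at w = 0 (constructed for illustration):
# gradual slope for w >= 0, a plateau for w < 0 that only drops near the minimum.
def loss(w):
    if w >= 0:
        return w * w
    return 4.0 * (1.0 - math.exp(-(w / 0.5) ** 2))

def grad(w):
    if w >= 0:
        return 2.0 * w
    return 4.0 * math.exp(-(w / 0.5) ** 2) * 8.0 * w   # d/dw of the plateau branch

def descend(w, lr=0.1, k=5):
    for _ in range(k):
        w -= lr * grad(w)
    return w

slope_start, plateau_start = 2.0, -2.0
print(loss(slope_start), loss(plateau_start))                    # both about 4.0
print(loss(descend(slope_start)), loss(descend(plateau_start)))  # only the slope start improves
```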

So, these are two rough senses in which initializations can be “good”:

  1. They can themselves have low loss in expectation across tasks
  2. They can be positioned in ways that better facilitate subsequent steps of loss reduction on tasks

This is obviously a rough conceptual categorization, and I certainly don’t claim that this completely captures the relationship between parameters and loss, but I do think it’s a useful dichotomy.

At a high level, the hope of MAML is that, because it calculates derivatives that account for how changes in initial values do or don’t propagate through gradient descent to influence final values (in short, because it backpropagates through SGD itself), it can (in theory) find initializations that are good according to both criterion (1) and criterion (2).

Joint training, which doesn’t have an inner loop of per-task optimization, can’t explicitly take into account the effects of such optimization, and so would only tend to find initializations that are good due to having a low loss on each sampled task in the absence of further optimization.

Why might we care about this distinction? One reason is that it may be implausible to find a single set of parameters that are a good “average” solution across tasks; it may just require capturing too much information in one parameter vector. However, given that any parameters learned over a full task distribution will have to trade off performance on different tasks from the distribution, if we are able to find parameters that are able to learn effectively on new tasks, we’ll hopefully be able to recover from that inherent regression-to-the-mean mediocrity given a small number of task-specific examples.

In practice, for any given algorithm or set of tasks, these two criteria aren’t actually going to be an either-or dichotomy, but a mixture: an algorithm might lean more on one approach or the other to drive its performance, and might lean on each to differing degrees across different tasks. Most empirical evaluations I’ve seen don’t do a good job of differentiating whether their performance gains come from low-loss initializations versus easy-to-optimize initializations. To be fair, it’s possible that this conceptual distinction would be too messy and difficult to map onto a clean experimental metric; I certainly haven’t tried.

But, in the absence of good empirical testing, it’s hard to know: maybe there are huge gains to be gotten from the easy-to-optimize initializations, the kind that MAML has a theoretical advantage at finding, or maybe they’re simply a theoretical edge case with no real practical value, and meta-learning mostly just succeeds by finding parameters that are good on average. My read is that we haven’t answered that question with much certainty yet. But it’s useful to have in mind that much of the theoretical value proposition of MAML rests on the hope of finding initializations that are good both due to low loss and ease of per-task optimization.

Breaking (Up) the Chain

Let’s take a closer look at the derivative that MAML uses to optimize. In words, MAML’s goal is:

How can I modify my initialization parameters (θ) to decrease the loss of the parameters (ϕ) that I reach after applying a few steps of per-task gradient descent to θ?

Mathematically, that looks like this:

Source: Meta Learning With Implicit Gradients (Rajeswaran et al.). The ultimate quantity we’re laying out equations for here is the derivative of the loss calculated with the parameters we reach after multiple steps, with respect to the initial θ.
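The equation itself appears only as an image in the original, so here is the chain-rule form it expresses, written in notation consistent with the rest of this section: $\mathcal{L}_i$ is the loss on task i’s test batch, and $\mathrm{Alg}_i(\theta)$ is the result of running the inner optimization on task i starting from θ.

```latex
\frac{d}{d\theta}\,\mathcal{L}_i\big(\mathrm{Alg}_i(\theta)\big)
  = \frac{d\,\mathrm{Alg}_i(\theta)}{d\theta}\;
    \nabla_{\phi}\mathcal{L}_i(\phi)\,\Big|_{\phi = \mathrm{Alg}_i(\theta)}
```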

Let’s look more closely at the two components of this.

Alg(θ) here is the value you get after applying an optimization algorithm (in this case, k steps of SGD) to the initial θ parameters. ϕ and Alg(θ) mean the same thing; the latter just highlights the fact that where you end up (ϕ) is a function of where you start (θ), and the algorithm that gets you there is gradient descent. The loss-gradient part of the equation is the same as it would be if you were just performing another step of normal gradient descent: the gradient of your loss on task i, evaluated at whatever parameters you end up at (in this case, ϕ, otherwise known as Alg(θ)). So, this part of the derivative is uncomplicated and straightforward.

Here’s where things get tricky. Normally, the weights you calculate a gradient with respect to are the same as the weights you’re going to be updating. But in this case, we calculate a single-task loss gradient at our final weights ϕ, but then want to update our initial weights θ. These are the weights that carry over between tasks, the shared initialization that we start optimizing from at the beginning of each new task. These are the weights we ultimately want to learn in our meta-learning loop, and so these are the weights we need a gradient with respect to. That means that we don’t just need to take a gradient of loss with respect to ϕ, but have to go one level deeper, and take it with respect to θ. This is what makes MAML an algorithm that optimizes for a set of parameters that support an effective learning process, rather than just for a better-performing fixed parameter point.

This second-level derivative is the most salient way in which MAML is different than simple joint training.

This “one level deeper” requires that we use the chain rule. To use the chain rule here, we need to multiply the loss gradient at ϕ by the derivative of the parameter values ϕ with respect to θ (a matrix, since both are vectors). In other words, in order to know how to manipulate θ to influence the loss we end up with at ϕ, we need to:

  1. Know the direction in which to change ϕ to get lower loss, and
  2. Know the direction in which to change θ in order to move the post-optimization value ϕ in a particular direction

Intuitively, the goal of this derivative is to capture the dynamics of the multi-step gradient descent process, and prioritize changes to θ that will effectively propagate influence through that process. Without knowing (2), if you just naively apply the same change to θ that you would to ϕ, you may end up not effectively reducing loss, because the learning process that starts from your new θ might not result in the ϕ value that’s moved in the direction that you wanted.

The way we get from initialization θ to ϕ is by performing gradient descent on k batches from a new sampled task.

Equations capturing the process of parameter updating in basic gradient descent. The final phi value depends on the derivatives of loss with respect to parameters calculated at each intermediate parameter value along the chain (represented by theta subscripted with 0-k)
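A sketch of those update equations in LaTeX, using symbols we introduce here: α for the inner learning rate and $\hat{\mathcal{L}}_i$ for the training-batch loss on task i (which may use a different batch at each step). Differentiating each update with respect to the previous point gives a factor of $I - \alpha H_j$, where $H_j$ is the Hessian at the j-th intermediate point, so the parameter-to-parameter derivative is a product of k such factors:

```latex
\theta_0 = \theta, \qquad
\theta_{j+1} = \theta_j - \alpha\,\nabla_{\theta}\hat{\mathcal{L}}_i(\theta_j) \quad (j = 0, \dots, k-1), \qquad
\phi = \mathrm{Alg}_i(\theta) = \theta_k

\frac{d\phi}{d\theta}
  = \big(I - \alpha H_{k-1}\big)\big(I - \alpha H_{k-2}\big)\cdots\big(I - \alpha H_0\big),
\qquad H_j = \nabla^2_{\theta}\hat{\mathcal{L}}_i(\theta_j)
```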

To determine how shifts in θ propagate through to changes in the final ϕ, you need to capture how the gradients calculated at each point in this chain would respond to changes in the value of the parameter point at which they’re calculated. If we change θ, we get a different starting value for our parameters (obviously), but we will also calculate a different gradient than we would have before. This is the second derivative: it measures how changes in θ influence the vector-valued loss gradient calculated at that θ value. Calculating a full MAML derivative requires calculating this second derivative at each of the k points used in the k steps of updating.
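A small numerical check of this, assuming a hypothetical 1D loss L(w) = cosh(w) − 0.5·w whose second derivative changes from point to point (our choice, not from the post). The full meta-gradient multiplies in one factor of (1 − α·L''(θ_j)) per inner step, which requires keeping every intermediate θ_j around; the sketch compares it against a finite-difference estimate of how the post-optimization loss responds to θ.

```python
import math

# Hypothetical 1D loss (our choice, not from the post): L(w) = cosh(w) - 0.5 * w.
# Its second derivative cosh(w) changes from point to point, unlike a quadratic.
def loss(w):
    return math.cosh(w) - 0.5 * w

def grad(w):
    return math.sinh(w) - 0.5

def hess(w):
    return math.cosh(w)

def inner_loop(theta, alpha=0.1, k=3):
    """k SGD steps from theta, keeping every intermediate point (the memory cost of full MAML)."""
    path = [theta]
    for _ in range(k):
        path.append(path[-1] - alpha * grad(path[-1]))
    return path

def full_maml_gradient(theta, alpha=0.1, k=3):
    path = inner_loop(theta, alpha, k)
    dphi_dtheta = 1.0
    for w in path[:-1]:                    # one second-derivative factor per inner step
        dphi_dtheta *= 1.0 - alpha * hess(w)
    return grad(path[-1]) * dphi_dtheta    # chain rule: L'(phi) * dphi/dtheta

def finite_difference_gradient(theta, alpha=0.1, k=3, eps=1e-5):
    def final_loss(t):
        return loss(inner_loop(t, alpha, k)[-1])
    return (final_loss(theta + eps) - final_loss(theta - eps)) / (2 * eps)

theta = 0.7
print(full_maml_gradient(theta), finite_difference_gradient(theta))  # should agree closely
print(grad(inner_loop(theta)[-1]))  # the loss gradient at phi alone, for comparison
```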

This presents a few problems. First of all, second derivatives are costly to calculate, and also to store, since computed naively they are N × N matrices for an N-parameter model. Also, it requires us to keep a record of the different weight vectors we encounter over the k updates, because we need to use them during the gradient-calculation step to calculate a second derivative at each of those locations. This means that, for an N-parameter model (where, as in modern models, N can be very large), a k-step meta-learning model would need kN memory for storing those intermediate parameter vectors inside the inner loop.

This intense memory requirement associated with naive implementations of MAML has given researchers a meaningful incentive to search for efficient approximations that capture some elements of the fully-realized gradient, but with lower computational requirements. That impulse is what motivates the variants of MAML that we’ll explore throughout the remainder of this post.

First Order MAML

Turns out it’s hard to have good pictures of meta-algorithms, so instead, here’s a rendering of what scientists think the first ancestor of all current mammals looked like: a first-order mammal, if you will...

The problem to be solved here is: how do you calculate the derivative that tells you how the changes you make to initialization values affect final loss values? And how do you do that cheaply?

The simplest way to solve a problem is... let’s just try ignoring the problem and see how badly off that leaves us. A pragmatic instinct, even if not the most satisfying one. This is the tack taken by First Order MAML (or FOMAML).

Mathematically, instead of actually estimating the parameter-to-parameter derivative matrix mentioned above — the derivative of parameter vector ϕ with respect to initial parameter vector θ, which would require caching intermediate SGD steps — FOMAML assumes that derivative is just an identity matrix. In words, this means encoding an assumption that a unit change in the initial parameters corresponds to a unit change in the parameter values post-gradient-descent.

Conceptually, if this assumption were true, it would mean that differences in initialization values don’t lead to different optimization paths, they just lead to shifting a fixed optimization path around (since, if you imagine the path as unchanging and rigid, when you move the starting point, the endpoint moves by the same amount).

By assuming that the derivative of ϕ with respect to θ is an identity matrix, you lose the ability to capture ways initialization may change your optimization path, rather than just shift it, but, on the positive side, that difficult-to-calculate component of the chain rule drops out, and our update to θ just becomes the loss gradient we get at ϕ.

θ is updated in the direction of the green arrow — the gradient of test set loss, as it’s calculated at ϕ
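On a hypothetical 1D quadratic task the difference between the two updates is easy to see, because there the exact derivative dϕ/dθ is just (1 − α·a)^k. A minimal sketch of both (the task and function names are illustrative assumptions):

```python
# Hypothetical quadratic task (an assumption for illustration): L(w) = 0.5 * a * (w - c)**2.
def inner_loop(theta, a, c, alpha, k):
    w = theta
    for _ in range(k):
        w -= alpha * a * (w - c)
    return w

def fomaml_gradient(theta, a, c, alpha=0.1, k=5):
    phi = inner_loop(theta, a, c, alpha, k)
    return a * (phi - c)                  # dphi/dtheta assumed to be the identity

def maml_gradient(theta, a, c, alpha=0.1, k=5):
    phi = inner_loop(theta, a, c, alpha, k)
    dphi_dtheta = (1 - alpha * a) ** k    # exact here, since the Hessian is the constant a
    return a * (phi - c) * dphi_dtheta

print(fomaml_gradient(1.0, a=4.0, c=0.0), maml_gradient(1.0, a=4.0, c=0.0))
```

On a steep task (with 0 < α·a < 1 but α·a not small), k steps land near the task optimum almost regardless of θ, so the exact factor (1 − α·a)^k is small and MAML largely discounts that task’s gradient; FOMAML, having assumed the identity, applies it at full strength.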

FOMAML’s strategy to get to better loss values is to update θ according to the gradient of loss calculated at ϕ on the test set. It’s interesting to consider the ways this is both different from and similar to simply performing joint training — the most basic multi-task approach, where you intermix batches of data from different tasks, and take one normal first-order gradient step on each.

The biggest similarity is that the step used to update θ is qualitatively the same as the one you’d use in vanilla gradient descent: a single, first-order gradient calculated at a single point.

The first difference is that, in FOMAML, we’re calculating gradients using a separate per-task test set, that wasn’t used for the k-step training. However, this doesn’t strike me as likely to be a meaningful distinction: with only small-k steps of training, the likelihood that a new batch sampled from the training set would repeat examples seen in those k steps of training is pretty low, which would make a training and test set sampled from the same underlying distribution basically equal in terms of the novelty of their datapoints to the model.

The second distinction is that, rather than updating our parameter vector based on the gradients at that parameter setting, we take a few steps of gradient descent, and then update our parameters according to what the gradient was after those steps. This is a difference, but it’s unclear, at least to me, what we should expect the implications to be: the gradients calculated a few steps later in an optimization chain don’t seem inherently more informative than ones calculated earlier.

Cold-Blooded Simplicity

Full MAML and First Order MAML lay out two ends of a spectrum: on one end, a conceptually justified method with intense memory and computational requirements, and on the other end, a simplification that lowers the computational burden, but at the cost of a quite limiting set of assumptions.

In this context, the Reptile method from OpenAI came onto the scene, as a midway point between these two opposed alternatives: a heuristic that adds more information than is captured by a single gradient step, but does so in a way that sidesteps MAML’s explicit formalisms. (As an aside: no, the authors didn’t explain why the method was called Reptile, outside of the obvious context of it offering an alternative to the existing MAML method.)

“Really? A machine learning model named after me?”

As with many things in machine learning, I find Reptile easiest to explain first as an algorithm, and then secondly in the language of mathematical formalism.

Just as in MAML, the problem is structured as few-shot learning over a distribution of tasks: each time a task is sampled, we perform k steps of gradient descent on batches of data from that task, where k varies, but is generally less than 10.

For each task, Reptile takes ϕ, the parameters reached after k steps on a task, and calculates the update to the initial parameters θ to be the total movement in parameter space over those k updates, or (ϕ − θ). This amounts to taking multiple gradient descent steps, drawing a line between where you started at step 0 and where you ended up after step k, and then using that vector as the update direction for θ. It’s interesting to note that, if k=1, this is the same as joint training: the line between the start and end of a single gradient step is the same as the step itself, so θ will just be pulled in the direction of the negative gradient of the task’s loss at that point.

θ is updated according to the green arrow: the vector that points towards the k-step optimization solution
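A minimal sketch of the Reptile update on a hypothetical 1D quadratic task (our toy, not the Reptile authors’ code), with an outer step size we call `epsilon` that scales (ϕ − θ). With k = 1 it reduces to a joint-training step with learning rate epsilon·alpha, matching the observation above.

```python
# Hypothetical quadratic task (assumption): L(w) = 0.5 * a * (w - c)**2.
def inner_loop(theta, a, c, alpha, k):
    w = theta
    for _ in range(k):
        w -= alpha * a * (w - c)
    return w

def reptile_update(theta, a, c, alpha=0.1, k=5, epsilon=0.5):
    """Move theta a fraction epsilon of the way towards phi: theta + epsilon * (phi - theta)."""
    phi = inner_loop(theta, a, c, alpha, k)
    return theta + epsilon * (phi - theta)

def joint_training_update(theta, a, c, lr):
    return theta - lr * a * (theta - c)   # one ordinary gradient step

# With k = 1 the two updates coincide when lr = epsilon * alpha.
print(reptile_update(1.0, 2.0, 0.5, alpha=0.1, k=1, epsilon=0.5),
      joint_training_update(1.0, 2.0, 0.5, lr=0.05))
```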

Things become more interesting when we take multiple steps of gradient descent. Now, we’re getting more information about the loss surface, because we aggregate gradient estimates made at multiple different points. This gives us some of the same information we’d get by calculating an explicit second derivative: if multiple gradients reinforce each other by pointing in the same direction, that means the gradient isn’t changing much around that point, and our aggregate vector will point more strongly in that direction.

Something to notice here is that this formulation breaks the clean MAML chain rule derivation: we don’t directly calculate the final test-set loss derivative and multiply it by a parameter-to-parameter derivative, as we were doing before. It’s purely a heuristic that says “move the global parameters towards the ones that resulted from a few steps of task training”. But, on an intuitive level, because we aggregate gradients over multiple steps, we give the single-task loss gradient more weight when an initialization is in an “easy to optimize from” place, since those places are more likely to have multiple gradients in a row pointing in the same direction, which will sum together. If the initialization is in a place that’s difficult to optimize from on a given task, due to noisy, conflicting gradients over multiple steps, the individual-step gradients will cancel out. If you squint, you can see this as conceptually comparable to updating θ according to a term that scales the loss gradient at ϕ by the amount of influence θ has on ϕ, since if θ is in a place of noisier gradients, any change made to θ is likely to lead the optimization path to find a meaningfully different ϕ value, where the gradient you calculated last time may no longer apply.
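One way to poke at this intuition, under a deliberately simple assumption: on a hypothetical quadratic task, a step size that overshoots the minimum makes successive gradients alternate in sign. Since the Reptile vector (ϕ − θ) is just −α times the sum of the k step gradients, we can measure how much those gradients reinforce or cancel. This illustrates the heuristic; it is not evidence for it.

```python
def step_gradients(theta, a, c, alpha, k):
    """Gradients of a hypothetical task loss L(w) = 0.5 * a * (w - c)**2 at each inner step."""
    grads, w = [], theta
    for _ in range(k):
        g = a * (w - c)
        grads.append(g)
        w -= alpha * g
    return grads          # note: phi - theta == -alpha * sum(grads)

def agreement(grads):
    """|sum of gradients| / sum of |gradients|: 1.0 if they all agree, near 0 if they cancel."""
    return abs(sum(grads)) / sum(abs(g) for g in grads)

alpha, k = 0.1, 6
smooth = step_gradients(1.0, a=1.0, c=0.0, alpha=alpha, k=k)       # alpha * a = 0.1: steady descent
oscillating = step_gradients(1.0, a=19.0, c=0.0, alpha=alpha, k=k)  # alpha * a = 1.9: overshoots each step
print(agreement(smooth), agreement(oscillating))
```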

All of that said: this is definitely a heuristic, and you should take this framing of mine as just one possible interpretation, without any specific empirical testing behind it.

The authors also argue that this method has the effect of finding a θ initialization that is, in expectation, close in parameter space to the optimal parameters for each task. Their claim here is based on treating the parameters reached after k steps as being a good approximation of the optimal parameters. This strikes me as not enormously convincing: there’s a pretty meaningful qualitative difference between a single step of gradient descent and the full optimization procedure to take you to an optimal solution for a task. And even though Reptile takes k steps, that k is generally quite small, and so it seems like the solution reached on each task would be more qualitatively similar to the result of a single-step update, as opposed to an optimal parameter configuration.

Implicit MAML

As I’ve mentioned a few times, fully calculating the derivative of final parameters ϕ with respect to initial ones θ is a pain. In order to be able to propagate derivatives back through the multi-step computation process, you need to store your intermediate parameters from each of the k steps. When we’re talking about million-parameter neural nets, having to store multiple copies of the entire network in memory for each meta-update step is a meaningful hurdle, especially as the k within your k-step inner optimization loop goes up.

Where FOMAML assumes a world in which the derivative is trivial, and Reptile constructs a heuristic that captures aspects of the gradient’s information without directly calculating it, Implicit MAML, or iMAML, constructs an approximate derivative that is more analytically grounded than Reptile’s, but which allows for a greater degree of expressiveness than FOMAML’s assumption that the gradient is just the identity.

iMAML frames the problem as:

What is the derivative of the optimal single-task parameters, ϕ*, with respect to θ?

If you use this new framing, a number of quantities become analytically simpler and easier to work with. But, at first glance, this should seem like a doomed effort: the minimum-loss point in your loss space just is wherever it is, regardless of where you initialize your parameters. So how would it be possible for your θ initialization parameters to influence the location of the minimum-loss point, in a way that would make it possible to calculate a derivative of that point with respect to them?

The authors of Implicit MAML introduce a modification to their objective to fix this: a regularization that penalizes solutions according to their distance from the initial values. As you move farther from the initialization values within the updates on a single task, the squared-distance penalty eventually overwhelms the benefit of lower loss, and you reach an equilibrium point of minimum loss. So, under this modified objective, different initial points will locate the center of that squared-distance penalty in different places, and thus result in different minimum loss points of a hybrid loss that incorporates both the true underlying loss and the distance penalty. This establishes a relationship between initial parameter values and the minimum-loss point of your space (under the hybrid loss), making it meaningful to talk about the derivative of the minimum-loss parameters with respect to θ.

At first, this modified objective seems like a purely arbitrary change to make the derivatives be well-defined (and that’s true to some extent), but it seems more reasonable when we remember that meta-learning is trying to solve for initializations that perform well in few-step learning settings on each new task. So, if the definition of your problem means you’re only taking a few learning steps on a new task, that means you won’t ever have a chance to move very far away from your initializations. In that light, a modified objective that prioritizes nearby solutions seems more like a sensible approximation, rather than something purely made up for mathematical expediency.

Normal loss regularized by the squared distance between theta and your current parameter value
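Written out, the regularized inner objective looks like this, in notation consistent with the rest of this section; λ is the regularization strength (our symbol) and $\mathrm{Alg}^{\star}_i(\theta)$ is the minimizer of the combined objective:

```latex
\mathrm{Alg}^{\star}_i(\theta) = \arg\min_{\phi}\;
  \hat{\mathcal{L}}_i(\phi) + \frac{\lambda}{2}\,\lVert \phi - \theta \rVert^2
```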

Conveniently enough, it happens to be the case that the gradient of this objective’s optimal parameters with respect to θ has a closed form analytic solution, and that this closed form solution can be efficiently calculated using Conjugate Gradient, without needing to store the chain of mid-optimization parameter values needed to calculate a full derivative. I’m not going to go into the details of how Conjugate Gradient works, and why it can be more efficient, since that would be its own separate blog post, but I strongly recommend this well-written, accessible tutorial, if anyone wants to learn more.
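To give a flavor of why Conjugate Gradient helps: it solves a linear system A·x = b for a symmetric positive-definite A using only matrix-vector products A·v, so A never has to be stored. In iMAML’s setting A would be (I + H/λ) and each product a Hessian-vector product. The sketch below is textbook CG in plain Python on a hypothetical 2×2 Hessian, not the iMAML authors’ implementation.

```python
# Textbook conjugate gradient: solves A x = b for symmetric positive-definite A,
# touching A only through matvec(v) = A v.
def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def conjugate_gradient(matvec, b, iters=50, tol=1e-12):
    x = [0.0] * len(b)
    r = list(b)                 # residual b - A x, with x = 0
    p = list(r)                 # first search direction
    rs_old = dot(r, r)
    for _ in range(iters):
        if rs_old < tol:        # squared residual norm small enough
            break
        Ap = matvec(p)
        step = rs_old / dot(p, Ap)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x

# Hypothetical Hessian H at phi and regularization strength lam.
H = [[3.0, 1.0], [1.0, 2.0]]
lam = 1.0

def matvec(v):                  # computes (I + H / lam) v without forming the matrix
    Hv = [dot(row, v) for row in H]
    return [vi + hvi / lam for vi, hvi in zip(v, Hv)]

g = [1.0, -1.0]                 # stand-in for the test-loss gradient at phi
print(conjugate_gradient(matvec, g))
```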

The derivative of Alg* (the parameters corresponding with the optimal value of the regularized objective) with respect to the initial parameters. This is what we use as a parameter-to-parameter derivative, multiplying it by the loss gradient at ϕ to get our MAML-style update
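For reference, the closed form in question, with a short derivation. It is a standard implicit-function argument and assumes the inner problem is solved exactly and that $I + \frac{1}{\lambda}H$ is invertible; $\phi^{\star}$ is shorthand for $\mathrm{Alg}^{\star}_i(\theta)$:

```latex
% Stationarity of the regularized objective at its minimizer:
\nabla_{\phi}\hat{\mathcal{L}}_i(\phi^{\star}) + \lambda\,(\phi^{\star} - \theta) = 0
% Differentiate both sides with respect to \theta:
\nabla^2_{\phi}\hat{\mathcal{L}}_i(\phi^{\star})\,\frac{d\phi^{\star}}{d\theta}
  + \lambda\Big(\frac{d\phi^{\star}}{d\theta} - I\Big) = 0
% Solve for the parameter-to-parameter derivative:
\frac{d\,\mathrm{Alg}^{\star}_i(\theta)}{d\theta}
  = \Big(I + \tfrac{1}{\lambda}\,\nabla^2_{\phi}\hat{\mathcal{L}}_i(\phi^{\star})\Big)^{-1}
```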

So, if we assume that the parameters we get after k steps are close enough to being optimal under our regularized objective, then the derivative of optimal parameter values with respect to θ will be a pretty good approximation of the derivative of our actually-reached parameters. With that approximated component in hand, we can combine it with the easy-to-calculate loss gradient at ϕ and plug them into the update rule.
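A 1D sketch of the resulting update, assuming a hypothetical quadratic task so that the exact answer is available for comparison: the inner loop runs k gradient steps on the regularized objective, and the meta-gradient is the loss gradient at ϕ times (1 + L''(ϕ)/λ)⁻¹. When the inner loop runs long enough to actually reach the regularized optimum, this matches the exact derivative; with small k it is only an approximation, which is the assumption discussed here.

```python
# Hypothetical quadratic task (assumption): L(w) = 0.5 * a * (w - c)**2, so L''(w) = a.
def regularized_inner_loop(theta, a, c, lam, alpha=0.1, k=5):
    """k gradient steps on L(w) + lam/2 * (w - theta)**2, starting from theta."""
    w = theta
    for _ in range(k):
        w -= alpha * (a * (w - c) + lam * (w - theta))
    return w

def imaml_gradient(theta, a, c, lam, alpha=0.1, k=5):
    phi = regularized_inner_loop(theta, a, c, lam, alpha, k)
    # Implicit derivative (1 + L''(phi)/lam)^-1, evaluated at the phi we actually reached.
    return a * (phi - c) / (1.0 + a / lam)

def exact_gradient_at_optimum(theta, a, c, lam):
    """d/dtheta L(Alg*(theta)) using the closed form Alg*(theta) = (a*c + lam*theta) / (a + lam)."""
    phi_star = (a * c + lam * theta) / (a + lam)
    return a * (phi_star - c) * lam / (a + lam)

print(imaml_gradient(1.0, 2.0, 0.0, 1.0, k=5),     # few inner steps: approximate
      imaml_gradient(1.0, 2.0, 0.0, 1.0, k=200),   # inner loop converged
      exact_gradient_at_optimum(1.0, 2.0, 0.0, 1.0))
```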

It’s true that, unless our regularization term is so powerful that we drown out the effect of actual loss after a few steps, we’re likely not actually reaching a minimal loss point: k is generally small, and so ϕ is just the result of a few steps of gradient descent in the (hopeful) direction of the minimum. But assumptions always require being a little bit wrong to make your computations nicer, and this one seems reasonably likely to hold well enough in practice.

It’s useful to situate this approach alongside both MAML and First-Order MAML. In original MAML, the gradient of ϕ with respect to θ is just literally that quantity, fully and properly calculated, painful memory requirements and all. In First Order MAML, we assume that a change in initial parameters results in a constant shift in final values, and basically give up on capturing how the dynamics of the k-step learning process might depend on your starting point. In Implicit MAML, we allow the actual value of θ to influence the inter-parameter gradient, but only by taking into account the (regularized) theoretical optimum induced by θ, rather than the empirical SGD path we took to get to ϕ from θ. So, this is somewhere in the middle between the two: a simplifying assumption, but one that still captures characteristics of the loss landscape around our initial parameters.

All Simplifications Are Simple In Their Own Way

The illustration below is an attempt at pulling together these different methods, and visually capturing how they differ. Single small arrows indicate the loss gradient of particular batches, on a particular task. Green indicates which gradient or combination of gradients is used to update initial parameters θ.

Joint training updates using the loss gradient calculated at the initializations themselves, since joint training only uses single batches of each task. The various variants of MAML all start with the dotted-arrow gradient, which is the loss gradient calculated at the parameters, ϕ, that you reach after k update steps. First Order MAML uses that gradient directly as the update vector for the initializations. MAML takes that vector, and multiplies it by the (matrix) derivative of ϕ with respect to θ. iMAML multiplies it by the derivative of the modified-objective-optimal parameters with respect to θ. Reptile sums all of the individual k updates, then scales that vector down, and uses that to update θ.
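The same comparison as a hypothetical sketch: for one 1D quadratic task, compute each method’s update direction, so that θ_new = θ − meta_lr·direction. For brevity every method shares the same plain k-step ϕ; a faithful iMAML would run its inner loop on the regularized objective instead. The task, names, and defaults are our assumptions.

```python
# All five update directions for one hypothetical quadratic task
# L(w) = 0.5 * a * (w - c)**2 (chosen so every derivative is exact).
def update_directions(theta, a, c, alpha=0.1, k=5, lam=1.0, epsilon=1.0):
    def grad(w):
        return a * (w - c)

    phi = theta
    for _ in range(k):
        phi -= alpha * grad(phi)
    g_phi = grad(phi)                              # the dotted-arrow gradient at phi
    return {
        "joint": grad(theta),                      # gradient at theta itself
        "fomaml": g_phi,                           # identity derivative assumed
        "maml": (1 - alpha * a) ** k * g_phi,      # exact derivative dphi/dtheta
        "imaml": g_phi / (1 + a / lam),            # implicit derivative (1 + H/lam)^-1
        "reptile": epsilon * (theta - phi),        # scaled total movement, as a descent direction
    }

print(update_directions(1.0, a=1.0, c=0.0))
```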

In joint training, only one batch is run on each task before parameters are updated, so they are inherently updated using only the gradient from that first (and only) batch. In MAML, FOMAML, and iMAML, we take multiple update steps, and the gradient from the last of these batches is used to update θ, after being multiplied by some version of the derivative of the final parameters with respect to the initial ones (the identity, in FOMAML’s case).

Conclusion

Along with brittleness, another widespread critique of modern machine learning is its tendency to tell itself just-so stories: coming up with a theoretical justification for why its methods should work, observing empirical performance gains, and assuming the latter validates the theoretical claims of the former.

In that spirit, it seems useful to balance out this mostly conceptual and theoretical post with some grounding empirical questions, to help think about which conceptual claims have the most support.

Do algorithms that take advantage of second-order information actually reliably outperform joint training equivalents like FOMAML?

As far as I can tell, yes, but that difference is more pronounced at smaller values of k, somewhat paradoxically suggesting that much of the value is in θ having low loss across tasks to start with, and that using multi-step information is better at getting you that low loss. When k is larger, you may be able to “catch up” more effectively from having started in a worse place, optimization-wise. At larger values of k, MAML, iMAML and Reptile still outperform, on the tested tasks, but not as substantially.

Does the more principled iMAML outperform the more heuristic Reptile?

I am… not sure. The empirical results on this are sparse, and don’t seem conclusive. Looking at the ones reported in the iMAML paper, iMAML performs better than Reptile on Omniglot (where each task is classifying letters from a different alphabet), but is within Reptile’s error bounds on mini ImageNet.

Part of the problem here is that — because original MAML itself is so computationally expensive — we’re mostly testing on simple tasks where existing models already perform quite well, and squinting at differences between 97–99% accuracy. I’d love to see more evaluation tasks that are genuinely difficult to learn in the absence of meta-methods, but obviously that would come at a high computational cost.

Meta Learning has a lofty ideal — learning how to learn, how to adapt to an unseen task. And it genuinely does seem to be outperforming more simplistic joint-training style methods. But it’s important to not get too caught up in the sweeping frames that meta learning creates for itself, and to instead try to identify the actual mechanical distinctions between simpler and more complex approaches, and rigorously assess whether our reasoning for the value they add matches up with reality.

Some remaining open questions of mine:

  • Is there a k at which FOMAML and MAML converge to equivalent performance? The tested k values of 1 and 5 are much lower than you’d actually use in practice, even for very data-efficient fine-tuning, and it’s not clear whether meta-optimization methods would still provide value in those more realistic regimes.
  • To what extent does meta-learning actually optimize for a more effective and efficient few-step update, as opposed to just finding a generically good starting point? I’d love to see more metrics (or be pointed to existing ones) that split this difference by actually characterizing update paths, and actually qualitatively examine the ones learned by MAML vs joint training approaches.
