
Basics of few-shot learning with optimization-based meta-learning

Overview of the mechanics behind MAML, FOMAML, and Reptile methods in optimization-based meta-learning

Photo by Kelly Sikkema on Unsplash

Meta-learning approaches can be broadly classified as metric-based, optimization-based, or model-based. In this post, we will focus mostly on the mathematics behind optimization-based meta-learning approaches.

Terminology. Meta-learning models are trained with a meta-training dataset (with a set of tasks τ = {_τ_₁, _τ_₂, _τ_₃, …}) and tested with a meta-testing dataset (tasks τₜₛ). Each task τᵢ consists of a task training set (i.e. support set) _Dᵢ_ᵗʳ and a task test set (i.e. query set) _Dᵢ_ᵗˢ. One common type of meta-learning problem is N-way k-shot learning, in which the model must distinguish between N classes given k examples per class.

Illustration of the meta-training, meta-testing, support and query datasets for a 2-way 1-shot example. Image by author.
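To make the data setup concrete, below is a minimal sketch in Python of how one such task could be sampled from a dataset grouped by class (the function and toy dataset here are illustrative, not from any particular library):

```python
import random

def sample_task(data_by_class, n_way=2, k_shot=1, k_query=1):
    """Sample one N-way k-shot task: a support set and a query set."""
    classes = random.sample(sorted(data_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        examples = random.sample(data_by_class[cls], k_shot + k_query)
        support += [(x, label) for x in examples[:k_shot]]   # k examples per class
        query += [(x, label) for x in examples[k_shot:]]     # held-out query examples
    return support, query

# e.g. a 2-way 1-shot task from a toy dataset grouped by class
data_by_class = {"dog": ["d1", "d2"], "cat": ["c1", "c2"], "fish": ["f1", "f2"]}
support_set, query_set = sample_task(data_by_class)
```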

Transfer learning (fine-tuning)

Before going on to discuss meta-learning, we will briefly mention another commonly used approach, transfer learning via fine-tuning, which transfers knowledge from a base model (e.g. one built by identifying many different objects) to a novel task (e.g. identifying specifically dogs). The idea is to build models pre-trained on general tasks, and fine-tune the model on a new, specific task (either by updating only a limited set of layers in the neural network and/or by using a slower learning rate). We will go over the mathematical notation in this section so that we can later compare and contrast with meta-learning.

In a fine-tuning setting, we would first derive an optimized set of parameters _θ_ᵖʳᵉ-ᵗʳ pre-trained on _D_ᵖʳᵉ-ᵗʳ,
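_θ_ᵖʳᵉ-ᵗʳ = argmin_θ L(θ, _D_ᵖʳᵉ-ᵗʳ)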

During fine-tuning, we would then update the parameters to minimize the loss on the training set _D_ᵗʳ,
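θ ← _θ_ᵖʳᵉ-ᵗʳ − α ∇_θ L(_θ_ᵖʳᵉ-ᵗʳ, _D_ᵗʳ)

where α is the learning rate.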

The equation illustrates one gradient step, but in practice this is optimized via multiple gradient steps. As an illustration, the figure below shows the path in parameter space going from the pre-trained parameter values _θ_ᵖʳᵉ-ᵗʳ toward the fine-tuned parameter values θ.

Fine-tuning. Image by author.

In transfer learning via fine-tuning, the hope is that the base model has learned the basic patterns (such as shapes, contrasts, and objects in images) so that fine-tuning can quickly and easily adapt it to a new task. However, the approach is not explicitly designed around learning to learn. If the novel task does not overlap with the base tasks, the transfer of knowledge can perform poorly. Meta-learning, on the other hand, is designed explicitly around constructing tasks and algorithms for generalizable learning.

MAML

Model-agnostic meta-learning (MAML) was proposed by Finn et al. in 2017. It is an optimization-based meta-learning approach. The idea is that instead of finding parameters that are good for a given training dataset (or a fine-tuned training set), we want to find initial parameters that, after fine-tuning, generalize well to other test sets.

For one task. Given a task, we will first use a support training dataset _D_ᵗʳ in a fine-tuning step. The optimal parameter ϕ for _D_ᵗʳ is,
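ϕ = θ − α ∇_θ L(θ, _D_ᵗʳ)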

Unlike fine-tuning (where we would have stopped here), we want to calculate how well this optimal parameter ϕ does on a query test dataset _D_ᵗˢ, with the loss function L(ϕ, _D_ᵗˢ). The objective is to optimize the initial parameter θ such that it performs well on the query test set after fine-tuning. In other words, we update θ in a meta-training step as,
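θ ← θ − β ∇_θ L(ϕ, _D_ᵗˢ)

where β is the meta-learning rate.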

Here we need to calculate ∇_θ L(ϕ, _D_ᵗˢ), which is the derivative of the loss function with respect to θ.

We can illustrate the paths in the parameter space as follows,

MAML for one task. Image by author.

Note that instead of directly updating θ at the fine-tuning step, we get a sense of the direction toward the optimal parameters based on the support train and test datasets (paths in gray), and update θ in the meta-training step.

For task sets. Instead of just one task, for generalizability across a variety of tasks, we can perform this meta-learning at each step by averaging across a set of tasks τ = {_τ_₁, _τ_₂, _τ_₃, …}. Hence the optimal parameter _ϕ_ᵢ for the support set of task τᵢ is,
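_ϕ_ᵢ = θ − α ∇_θ L(θ, _Dᵢ_ᵗʳ)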

The meta-training step is,
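θ ← θ − β ∇_θ Σᵢ L(_ϕ_ᵢ, _Dᵢ_ᵗˢ)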

The term ∇_θ L(_ϕ_ᵢ, _Dᵢ_ᵗˢ) can be further expanded. Below we will omit the subscript i, but the discussion applies on a per-task basis. With the chain rule, the term can be expressed as,
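∇_θ L(ϕ, _D_ᵗˢ) = ∇_ϕ L(ϕ, _D_ᵗˢ) ∇_θ ϕ = ∇_ϕ L(ϕ, _D_ᵗˢ) (I − α ∇²_θ L(θ, _D_ᵗʳ))

where ∇²_θ L(θ, _D_ᵗʳ) is the Hessian matrix.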

We can expand on the earlier path visuals to include multiple tasks,

MAML for multiple tasks. Image by author.

Here we get a sense of the direction toward the optimal parameters for each task (in different colors), and update θ based on the average across the tasks (path in black).
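To make the two loops concrete, below is a minimal sketch of this procedure in PyTorch on a toy scalar regression problem (the task family, model, and hyperparameters are illustrative, not from the original paper). The key detail is create_graph=True, which keeps the inner gradient differentiable with respect to θ, so the meta-gradient includes the second-order term:

```python
import torch

torch.manual_seed(0)
theta = torch.zeros(1, requires_grad=True)  # meta-parameter θ
alpha, beta = 0.1, 0.01                     # inner and meta learning rates

def loss(params, x, y):
    # mean squared error for a scalar linear model y ≈ params * x
    return ((params * x - y) ** 2).mean()

def sample_task():
    # toy task family (illustrative): y = w*x with a task-specific slope w
    w = torch.randn(1)
    x_tr, x_ts = torch.randn(5), torch.randn(5)
    return (x_tr, w * x_tr), (x_ts, w * x_ts)

for step in range(100):
    tasks = [sample_task() for _ in range(4)]
    meta_grad = torch.zeros(1)
    for (x_tr, y_tr), (x_ts, y_ts) in tasks:
        # fine-tuning step: ϕ = θ − α ∇_θ L(θ, Dᵗʳ)
        g, = torch.autograd.grad(loss(theta, x_tr, y_tr), theta, create_graph=True)
        phi = theta - alpha * g
        # meta-gradient ∇_θ L(ϕ, Dᵗˢ); backprop through ϕ picks up the Hessian term
        meta_g, = torch.autograd.grad(loss(phi, x_ts, y_ts), theta)
        meta_grad += meta_g
    with torch.no_grad():
        theta -= beta * meta_grad / len(tasks)  # meta-training step on θ
```

Averaging the per-task meta-gradients before updating θ corresponds to the sum over tasks in the meta-training equation above.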

First-order MAML

In the MAML meta-training step, we need to calculate the Hessian matrix. As an alternative, first-order MAML (FOMAML) uses a first-order approximation, regarding ∇_θ L(θ, _D_ᵗʳ) as a constant and hence ignoring the second-derivative terms. This means we treat the term ∇_θ ϕ as the identity matrix I, resulting in,
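θ ← θ − β ∇_ϕ L(ϕ, _D_ᵗˢ)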

This can be illustrated visually as follows,

First-order MAML. Image by author.

Note we are not performing a meta-gradient computation by unrolling all the way up the computation graph; instead, we use the first-order approximation ∇_ϕ L(ϕ, _D_ᵗˢ) as the gradient for updating θ.
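In code, a sketch of this change amounts to replacing the per-task body of the MAML sketch above with the following (same illustrative setup): ϕ is detached from the computation graph, and the gradient with respect to ϕ is applied to θ:

```python
# first-order variant: detach ϕ so backpropagation stops there (no Hessian term)
g, = torch.autograd.grad(loss(theta, x_tr, y_tr), theta)
phi = (theta - alpha * g).detach().requires_grad_()
# ∇_ϕ L(ϕ, Dᵗˢ) is used directly as the gradient for updating θ
meta_g, = torch.autograd.grad(loss(phi, x_ts, y_ts), phi)
meta_grad += meta_g
```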

Reptile

Reptile (by OpenAI) is an alternative approach with performance on par with MAML, but it is more computationally and memory efficient than MAML, as there is no explicit calculation of second derivatives.

First we’ll introduce an update function Uᵏ, which is just a reformulation (and generalization) of the fine-tuning step in MAML,
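ϕ = Uᵏ(θ, _D_ᵗʳ), with U¹(θ, _D_ᵗʳ) = θ − α ∇_θ L(θ, _D_ᵗʳ)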

where k is the number of times ϕ is updated.

With Reptile, at each iteration, 1) a task τᵢ is sampled, 2) the optimal parameter ϕᵢ for τᵢ is calculated after k updates, and 3) the model parameter θ is updated as,
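θ ← θ + ε (_ϕ_ᵢ − θ)

where ε is the outer step size.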

Instead of one task per iteration, multiple tasks can be evaluated, leading to a batch version as follows,
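θ ← θ + ε (1/n) Σᵢ (_ϕ_ᵢ − θ)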

where
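_ϕ_ᵢ = Uᵏ(θ, _Dᵢ_ᵗʳ)

and n is the number of tasks sampled per iteration.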

The parameter paths can be schematically visualized as,

Reptile. Image by author.

The key distinction that differentiates Reptile from regular stochastic gradient descent averaged across different tasks is the estimation of _ϕ_ᵢ over k>1 steps and the use of (_ϕ_ᵢ − θ) as the gradient for updating θ. In vanilla stochastic gradient descent, the parameters are updated after each gradient step (i.e. k=1). The authors Nichol et al. have shown that when k>1, the algorithm picks up on higher-order derivatives, and the resulting behavior is similar to MAML and distinctly different from when k=1.
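As a minimal sketch of the full loop in plain Python (same illustrative toy scalar regression setup as before):

```python
import random

alpha, epsilon, k = 0.02, 0.1, 5  # inner learning rate, outer step size, inner steps
theta = 0.0                       # scalar meta-parameter θ

def grad_loss(param, data):
    # gradient of mean squared error for the scalar model y = param * x
    return sum(2 * x * (param * x - y) for x, y in data) / len(data)

for step in range(1000):
    # 1) sample a task (toy family, illustrative: y = w*x)
    w = random.gauss(0, 1)
    data = [(x, w * x) for x in [random.gauss(0, 1) for _ in range(5)]]
    # 2) ϕ = Uᵏ(θ, Dᵗʳ): k gradient steps starting from θ
    phi = theta
    for _ in range(k):
        phi -= alpha * grad_loss(phi, data)
    # 3) Reptile update: move θ a small step toward ϕ
    theta += epsilon * (phi - theta)
```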


References

  1. Finn et al. Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks. ICML 2017. arXiv
  2. Nichol et al. On First-Order Meta-Learning Algorithms. arXiv 2018. arXiv

_Originally published at https://boyangzhao.github.io on August 7, 2021._

