A Framework For Contrastive Self-Supervised Learning And Designing A New Approach

In a new paper, we discuss the key ideas driving performance in self-supervised learning and show what matters.

William Falcon
Towards Data Science


Contrastive learning: Batch of inputs.

This is the partner blog matching our new paper: A Framework For Contrastive Self-Supervised Learning And Designing A New Approach (by William Falcon and Kyunghyun Cho).

In the last year, a stream of “novel” self-supervised learning algorithms have set new state-of-the-art results in AI research: AMDIM, CPC, SimCLR, BYOL, Swav, etc…

In our recent paper, we formulate a conceptual framework for characterizing contrastive self-supervised learning approaches. We used our framework to analyze three examples of these leading approaches, SimCLR, CPC, AMDIM, and show that although these approaches seem different on the surface, they are all in fact slight tweaks of one another.

In this blog we will:

  • Review self-supervised learning.
  • Review contrastive learning.
  • Propose a framework for comparing recent approaches.
  • Compare CPC, AMDIM, Moco, SimCLR, and BYOL using our framework.
  • Formulate a new approach, YADIM, using our framework.
  • Describe some of our results.
  • Describe the computational requirements to achieve these results.

The majority of this work was conducted while at Facebook AI Research.

Implementations

You can find all the augmentations and approaches described in this article implemented in PyTorch Lightning, which lets you train on arbitrary hardware and makes side-by-side comparison of the approaches much easier.

AMDIM

BYOL

CPC V2 (only verified implementation outside of DeepMind to our knowledge).

Moco V2

SimCLR

Self-Supervised Learning

Recall that in supervised learning, a system is given an input (x) and a label (y).

Supervised learning: Input on the left, label on the right.

In self-supervised learning, the system is only given (x). Instead of a (y), the system “learns to predict part of its input from other parts of its input” [reference].

In self-supervised learning, the input is used both as the source and target

In fact, this formulation is so generic that you can get creative about ways of “splitting” up the input. These strategies are called pretext tasks and researchers have tried all sorts of approaches. Here are three examples: (1) predicting relative locations of two patches, (2) solving a jigsaw puzzle, (3) colorizing an image.

Examples of pretext tasks

Although the approaches above are full of creativity, they don’t actually work well in practice. However, a more recent stream of approaches that use contrastive learning has started to dramatically close the gap with supervised learning on ImageNet.

The latest approach (Swav) is closing the gap with the supervised variation trained on ImageNet (credit: Swav authors)

Contrastive Learning

A fundamental idea behind most machine learning algorithms is that similar examples should be grouped together and far from other clusters of related examples.

This idea is what’s behind one of the earliest works on contrastive learning, Learning a Similarity Metric Discriminatively, with Application to Face Verification, by Chopra et al. in 2005.

The animation below illustrates this main idea:

Contrastive learning achieves this by using three key ingredients: an anchor, a positive, and one or more negative representations. To create a positive pair, we need two examples that are similar, and for a negative pair, we use a third example that is not similar.

But in self-supervised learning, we don’t know the labels of the examples. So, there’s no way to know whether two images are similar or not.

However, if we assume that each image is its own class, then we can come up with all sorts of ways of forming these triplets (the positive and negative pair). This means that in a dataset of size N, we now have N labels!

Now that we know the labels (kind of) for each image, we can use data augmentations to generate these triplets.
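To make the "each image is its own class" idea concrete, here is a minimal, dependency-free sketch. The `augment` function is a hypothetical stand-in (it only adds a little noise), and `make_triplets` is an illustrative name, not code from the paper.

```python
import random

def augment(image, rng):
    """Hypothetical stand-in augmentation: add small random noise to each pixel."""
    return [pixel + rng.uniform(-0.1, 0.1) for pixel in image]

def make_triplets(batch, rng):
    """Treat every image as its own class and build one triplet per image."""
    triplets = []
    for i, image in enumerate(batch):
        anchor = augment(image, rng)    # first view of image i
        positive = augment(image, rng)  # second view of the same image i
        # Any other image has a different "label", so it can serve as a negative.
        j = rng.choice([k for k in range(len(batch)) if k != i])
        negative = augment(batch[j], rng)
        triplets.append({
            "label": i,
            "anchor": anchor,
            "positive": positive,
            "negative": negative,
            "negative_label": j,
        })
    return triplets

rng = random.Random(0)
batch = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]  # three toy "images" of two pixels each
triplets = make_triplets(batch, rng)
```

With a batch of N images this yields N labels for free: the index of each image in the batch.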

Characteristic 1: Data Augmentation pipeline

The first way we can characterize a contrastive self-supervised learning approach is by defining a data augmentation pipeline.

A data augmentation pipeline A(x) applies a sequence of stochastic transformations to the same input.

A stochastic data augmentation pipeline applied to an input

In deep learning, data augmentation aims to build representations that are invariant to noise in the raw input. For example, the network should recognize the above pig as a pig even if it’s rotated, or if the colors are gone, or even if the pixels are “jittered” around.

In contrastive learning, the data augmentation pipeline has a secondary goal which is to generate the anchor, positive and negative examples that will be fed to the encoder and will be used for extracting representations.
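As a rough illustration of a stochastic pipeline A(x), the sketch below chains two toy transforms on a tiny grayscale image stored as nested lists. The transform names, the jitter strength, and the image itself are assumptions made for this example; real pipelines use library transforms on full images.

```python
import random

def random_flip(image, rng):
    """Mirror every row with probability 0.5."""
    if rng.random() < 0.5:
        return [row[::-1] for row in image]
    return image

def brightness_jitter(image, rng, strength=0.2):
    """Scale all pixels by a random factor close to 1."""
    scale = 1.0 + rng.uniform(-strength, strength)
    return [[pixel * scale for pixel in row] for row in image]

def pipeline(image, rng, transforms=(random_flip, brightness_jitter)):
    """A(x): apply each stochastic transform in sequence."""
    for transform in transforms:
        image = transform(image, rng)
    return image

rng = random.Random(0)
x = [[0.1, 0.2], [0.3, 0.4]]
view_1 = pipeline(x, rng)  # calling A(x) twice on the same input
view_2 = pipeline(x, rng)  # gives two different versions of it
```

Because every transform draws its own random numbers, two calls on the same input produce two distinct views, which is exactly what a positive pair needs.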

CPC pipeline

CPC introduced a pipeline that applies transforms like color jitter, random greyscale, random flip, etc… but it also introduced a special transform that splits an image into overlapping sub-patches.

The key CPC transform

Using this pipeline, CPC can generate many sets of positive and negative samples. In practice, this process is applied to a batch of examples where we can use the rest of the examples in the batch as the negative samples.
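The patch-splitting transform can be sketched as follows. The patch size and stride here are toy values chosen for an 8x8 example, not the settings used by CPC, and `split_into_patches` is a name made up for this illustration.

```python
def split_into_patches(image, patch_size, stride):
    """Split a 2D image (list of rows) into a grid of square patches.

    A stride smaller than patch_size makes neighbouring patches overlap.
    """
    height, width = len(image), len(image[0])
    grid = []
    for top in range(0, height - patch_size + 1, stride):
        row_of_patches = []
        for left in range(0, width - patch_size + 1, stride):
            patch = [row[left:left + patch_size]
                     for row in image[top:top + patch_size]]
            row_of_patches.append(patch)
        grid.append(row_of_patches)
    return grid

image = [[r * 8 + c for c in range(8)] for r in range(8)]    # toy 8x8 "image"
patches = split_into_patches(image, patch_size=4, stride=2)  # 50% overlap
```

Here the 8x8 image becomes a 3x3 grid of 4x4 patches, and each patch shares half of its columns with its right-hand neighbour.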

Generating positive, anchor, and negative pairs from a batch of images. (Batch size = 3).

AMDIM pipeline

AMDIM takes a slightly different approach. After it performs the standard transforms (jitter, flip, etc…), it generates two versions of an image by applying the data augmentation pipeline twice to the same image.

This idea was actually proposed in 2014 via this paper by Dosovitskiy et al. The idea is to use a “seed” image to generate many versions of the same image.

SimCLR, Moco, Swav, BYOL pipelines

The pipeline in AMDIM worked so well that every approach that has followed uses the same pipeline but makes slight tweaks to the transforms that happen beforehand (some add jitter, some add gaussian blur, etc…). However, most of these transforms are inconsequential compared with the main idea introduced in AMDIM.

In our paper, we ran ablations on the impact of these transforms and found that the choice of transforms is critical to the performance of the approach. In fact, we believe that the success of these approaches is mostly driven by the particular choice of transforms.

These findings are in line with similar results posted by SimCLR and BYOL.

The video below illustrates the SimCLR pipeline in more detail.

Characteristic 2: Encoder

The second way we characterize these methods is by the choice of encoder. Most of the approaches above use ResNets of various widths and depths.

ResNet architecture (credit: Original ResNet authors)

When these methods began to come out, CPC and AMDIM actually designed custom encoders. Our ablations found that AMDIM did not generalize well to other encoders, while CPC suffered less from a change in the encoder.

Encoder robustness on CIFAR-10

Every approach since CPC has settled on a ResNet-50. And while there may be more optimal architectures that we’ve yet to invent, standardizing on the ResNet-50 means we can focus on improving the other characteristics to drive improvements as a result of better training methods and not better architectures.

One finding did hold true for every ablation: wider encoders perform much better in contrastive learning.

Characteristic 3: Representation extraction

The third way to characterize these methods is by the strategy they employ to extract representations. This is arguably where the “magic” happens in all of these methods and where they differ the most.

To understand why this is important, let’s first define what we mean by representations. A representation is the set of unique characteristics that allow a system (and humans) to understand what makes that object, that object, and not a different one.

This Quora post uses an example of trying to classify a shape. To successfully classify the shapes, a good representation might be the number of corners detected in each shape.

In this collection of methods for contrastive learning, these representations are extracted in various ways.

CPC

CPC introduces the idea of learning representations by predicting the “future” in latent space. In practice this means two things:

1) Treat an image as a timeline with the past at the top left and the future at the bottom right.

CPC “future” prediction task

2) The predictions don’t happen at the pixel level, but instead, they use the outputs of the encoder (i.e., the latent space).

From pixel space to latent space

Finally, the representation extraction happens by formulating a prediction task in which the outputs of the encoder (H) serve as targets for the context vectors generated by a projection head (which the authors call a context encoder).

CPC representation extraction
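A very reduced sketch of "predicting the future in latent space": a context vector is mapped to a prediction, and that prediction is scored against candidate latents with a dot product. The matrix `W`, the vectors, and the function names are all hypothetical; the real method learns these with neural networks.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def score_candidates(context, W, candidates):
    """Predict a future latent from the context, then score each candidate."""
    prediction = matvec(W, context)
    return [dot(prediction, h) for h in candidates]

context = [1.0, 0.0]                         # context vector for the "past" patches
W = [[1.0, 0.0], [0.0, 1.0]]                 # hypothetical predictor (identity for illustration)
future_latent = [0.9, 0.1]                   # encoder output for the true "future" patch
other_latents = [[-0.5, 0.8], [0.0, -1.0]]   # encoder outputs for unrelated patches
scores = score_candidates(context, W, [future_latent] + other_latents)
```

Training pushes the score of the true future latent (the first candidate here) above the scores of the others, so no pixels ever need to be predicted.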

In our paper, we find that this prediction task is unnecessary as long as the data augmentation pipeline is strong enough. And while there are a lot of hypotheses about what makes a good pipeline, we suggest that a strong pipeline creates positive pairs that share a similar global structure but have a different local structure.

AMDIM

AMDIM, on the other hand, uses the idea of comparing representations across views from feature maps extracted from intermediate layers of a convolutional neural network (CNN). Let’s unpack this into two parts: 1) multiple views of an image, and 2) intermediate layers of a CNN.

1) Recall that the data augmentation pipeline of AMDIM generates two versions of the same image.

2) Each version is passed into the same encoder to extract feature maps for each image. AMDIM does not discard the intermediate feature maps generated by the encoder but instead uses them to make comparisons across spatial scales. Recall that as an input makes its way through the layers of a CNN, the receptive fields encode information for different scales of an input.

References [1], [2].

AMDIM leverages these ideas by making the comparisons across the intermediate outputs of a CNN. The following animation illustrates how these comparisons are made across the three feature maps generated by the encoder.

AMDIM representation extraction: AMDIM uses the same encoder to extract 3 sets of feature maps. Then it makes comparisons across feature maps.
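The cross-scale comparison can be sketched like this. As a simplifying assumption, each of the three feature maps is summarized by a single vector of the same size, and every scale of one view is compared with every scale of the other; the exact set of pairs AMDIM uses is not spelled out here, and the scale names are made up.

```python
from itertools import product

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cross_view_scores(maps_view_1, maps_view_2):
    """Score every pairing of a scale from view 1 with a scale from view 2."""
    return {
        (scale_1, scale_2): dot(maps_view_1[scale_1], maps_view_2[scale_2])
        for scale_1, scale_2 in product(maps_view_1, maps_view_2)
    }

# Toy summaries of three feature maps per view, produced by the same encoder.
maps_view_1 = {"early": [1.0, 0.0], "middle": [0.5, 0.5], "last": [0.0, 1.0]}
maps_view_2 = {"early": [0.9, 0.1], "middle": [0.4, 0.6], "last": [0.1, 0.9]}
scores = cross_view_scores(maps_view_1, maps_view_2)
```

Three scales per view give nine cross-view scores in this toy setup, each of which could feed the contrastive loss.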

The rest of these methods make slight tweaks to the idea proposed by AMDIM.

SimCLR

SimCLR uses the same idea as AMDIM but makes 2 tweaks.

A) Use only the last feature map

B) Run that feature map through a projection head and compare both vectors (similar to the CPC context projection).
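The two tweaks can be sketched as below: take only the final feature vector of each view, pass it through a small projection head, and compare the projected vectors. The weights `W1` and `W2` are fixed, hypothetical values (a real head is learned), and the two-layer shape is an assumption of this sketch.

```python
import math

def matvec(matrix, vector):
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]

def relu(vector):
    return [max(0.0, v) for v in vector]

def projection_head(features, W1, W2):
    """Small MLP: linear -> ReLU -> linear."""
    return matvec(W2, relu(matvec(W1, features)))

def cosine_similarity(a, b):
    dot_product = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot_product / (norm_a * norm_b)

W1 = [[1.0, -1.0], [0.5, 0.5], [-1.0, 1.0]]  # hypothetical weights, 2 -> 3
W2 = [[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]     # hypothetical weights, 3 -> 2

last_map_view_1 = [1.0, 0.5]  # last feature map of view 1, flattened
last_map_view_2 = [0.9, 0.6]  # last feature map of view 2, flattened
z1 = projection_head(last_map_view_1, W1, W2)
z2 = projection_head(last_map_view_2, W1, W2)
similarity = cosine_similarity(z1, z2)
```

The comparison happens on `z1` and `z2`, not on the encoder outputs themselves.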

Moco

As we mentioned earlier, contrastive learning needs negative samples to work. Normally this is done by comparing an image in a batch against the other images in a batch.

Moco does the same thing as AMDIM (with the last feature map only) but keeps a running history of the batches it has seen and reuses their encoded vectors as additional negative samples. The effect is that the number of negative samples used to provide a contrastive signal increases beyond a single batch size.

Credit: Original Moco Authors. (Source)
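The history of past batches can be pictured as a fixed-size first-in, first-out queue. The sketch below uses Python's `collections.deque`; the class name, queue size, and toy vectors are assumptions for illustration, and details of the actual method (such as how the stored vectors are encoded) are left out.

```python
from collections import deque

class NegativeQueue:
    """Fixed-size FIFO of encoded vectors from previous batches."""

    def __init__(self, max_size):
        self.items = deque(maxlen=max_size)

    def negatives(self):
        """Return the stored vectors to use as negative samples."""
        return list(self.items)

    def enqueue(self, batch_vectors):
        """Add the newest batch; the oldest entries fall off automatically."""
        self.items.extend(batch_vectors)

queue = NegativeQueue(max_size=4)
for step in range(3):
    batch = [[float(step), 0.0], [float(step), 1.0]]  # two toy encoded vectors
    negatives = queue.negatives()  # negatives come from earlier batches
    queue.enqueue(batch)
```

By the third step the current batch of two vectors is contrasted against four stored negatives, which is more than a single batch could provide.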

BYOL

BYOL uses the same main ideas as AMDIM (but with the last feature map only), with two changes.

Credit: Deepmind (source)
  1. BYOL uses two encoders instead of one. The second encoder starts as an exact copy of the first, but instead of updating its weights on every pass, BYOL updates them as a rolling average of the first encoder’s weights.
  2. BYOL does not use negative samples. Instead, it relies on the rolling weight updates to give a contrastive signal to the training. However, a recent ablation found that this may not be necessary and that adding batch normalization is in fact what ensures the system does not generate trivial solutions.
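The rolling average in the first point can be written as a one-line update. The decay value and the toy weights below are illustrative assumptions, and `ema_update` is a name chosen for this sketch.

```python
def ema_update(second_encoder, first_encoder, decay=0.99):
    """Move the second encoder's weights a small step toward the first encoder's."""
    return [decay * s + (1.0 - decay) * f
            for s, f in zip(second_encoder, first_encoder)]

first_encoder = [1.0, 2.0]
second_encoder = list(first_encoder)  # starts as an exact copy

first_encoder = [2.0, 4.0]            # pretend a gradient step changed the first encoder
second_encoder = ema_update(second_encoder, first_encoder, decay=0.9)
```

After this update the second encoder sits at roughly [1.1, 2.2]: it has moved only 10% of the way toward the first encoder, so it changes slowly and smoothly.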

Swav

Swav frames its representation extraction task as one of “online clustering” where the authors enforce “consistency between codes from different augmentations of the same image.” [reference]. So, it’s the same approach as AMDIM (using only the last feature map), but instead of comparing the vectors directly against each other, it computes the similarity against a set of K precomputed codes.

Credit: Swav Authors (source)

In practice, this means that Swav generates K clusters and for each encoded vector it compares against those clusters to learn new representations. This work can be viewed as mixing the ideas of AMDIM and Noise as Targets.
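As a simplified picture of "comparing against K clusters", the sketch below scores one encoded vector against K=3 cluster vectors and turns the scores into a soft assignment. The cluster vectors, the temperature, and the function names are assumptions, and the actual method includes further steps that are not shown here.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores, temperature):
    """Numerically stable softmax over temperature-scaled scores."""
    scaled = [s / temperature for s in scores]
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def assign_to_clusters(vector, clusters, temperature=0.1):
    """Soft assignment of one encoded vector to each of the K clusters."""
    return softmax([dot(vector, c) for c in clusters], temperature)

clusters = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # K = 3 toy cluster vectors
encoded = [1.0, 0.0]                              # encoded view of an image
assignment = assign_to_clusters(encoded, clusters)
```

Two augmentations of the same image should end up with consistent assignments, which is the signal used instead of a direct vector-to-vector comparison.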

Characteristic 3, takeaways

The representation extraction strategy is where these approaches all differ. However, the changes are very subtle, and without rigorous ablations it’s hard to tell what actually drives results.

From our experiments, we found that the CPC and AMDIM strategies have a negligible effect on the results but instead add complexity. The primary driver that makes these approaches work is the data augmentation pipeline.

Characteristic 4: Similarity measure

The fourth characteristic we can use to compare these approaches is the similarity measure that they use. All of the approaches above use a dot product or cosine similarity. Although our paper does not list these ablations, our experiments show that the choice of similarity is largely inconsequential.
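For reference, the two measures differ only by normalization, as this small sketch with toy vectors shows.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine_similarity(a, b):
    """Dot product of the two vectors after scaling each to unit length."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))

a = [1.0, 2.0]
b = [2.0, 4.0]  # same direction as a, twice the length

dot_score = dot(a, b)
cosine_score = cosine_similarity(a, b)
```

The dot product grows with the length of the vectors (10.0 here), while cosine similarity only looks at the angle between them (about 1.0 here, since the vectors point the same way).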

Characteristic 5: Loss function

The fifth characteristic we use to compare these approaches is the choice of loss function. All of these approaches (except BYOL) have converged on using an NCE loss. The NCE loss has two parts: a numerator and a denominator. The numerator pulls similar vectors close together, and the denominator pushes all other vectors far apart.

Without the denominator, the loss can trivially become a constant, and thus the representations learned will not be useful.
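Here is one common way to write such a loss for a single anchor, as a sketch rather than the exact variant any one paper uses. The dot-product similarity, the temperature, and the toy vectors are assumptions of this example.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def nce_loss(anchor, positive, negatives, temperature=1.0):
    """-log( exp(sim(anchor, positive)) / sum over positive and negatives )."""
    numerator = math.exp(dot(anchor, positive) / temperature)
    denominator = numerator + sum(
        math.exp(dot(anchor, negative) / temperature) for negative in negatives
    )
    return -math.log(numerator / denominator)

anchor = [1.0, 0.0]
negatives = [[0.0, 1.0], [-1.0, 0.0]]

loss_good = nce_loss(anchor, [1.0, 0.0], negatives)   # positive aligned with the anchor
loss_bad = nce_loss(anchor, [-1.0, 0.0], negatives)   # positive pointing the wrong way
loss_no_negatives = nce_loss(anchor, [-1.0, 0.0], [])  # nothing to contrast against
```

The aligned positive gives the lower loss. With no negatives the ratio is always 1, so the loss is zero no matter what the vectors are, which illustrates why the denominator matters.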

BYOL drops the denominator and instead relies on the rolling-average updates to the second encoder to provide the contrastive signal. As mentioned earlier, though, recent ablations suggest that this may not actually be the driver of the contrastive signal.

In this video, I give a full explanation on the NCE loss using SimCLR as an example.

Yet Another DIM (YADIM)

We wanted to show the usefulness of our framework by generating a new approach to self-supervised learning without pretext motivations or involved representation extraction strategies. We call this new approach Yet Another DIM (YADIM).

YADIM can be characterized as follows:

Characteristic 1: Data augmentation pipeline

For YADIM we merge the pipelines of CPC and AMDIM.

Characteristic 2: Encoder

We use the encoder from AMDIM, although any encoder such as a ResNet-50 can also work.

Characteristic 3: Representation extraction

The YADIM strategy is simple. Encode the multiple versions of an image and use the last feature map to make a comparison. There is no projection head or other complicated comparison strategy.

Characteristic 4: Similarity metric

We stick to dot product for YADIM.

Characteristic 5: Loss function

We also use the NCE loss.
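One way to picture the framework is as a record with five fields, one per characteristic. The class and field names below are hypothetical and exist only for this illustration; the values simply restate the YADIM choices listed above.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ContrastiveApproach:
    """The five characteristics used to describe a contrastive approach."""
    augmentation_pipeline: str
    encoder: str
    representation_extraction: str
    similarity_measure: str
    loss_function: str

yadim = ContrastiveApproach(
    augmentation_pipeline="CPC and AMDIM pipelines merged",
    encoder="AMDIM encoder (a ResNet-50 can also work)",
    representation_extraction="compare last feature maps, no projection head",
    similarity_measure="dot product",
    loss_function="NCE",
)
```

Describing another approach means filling in the same five fields with different values, which makes side-by-side comparison straightforward.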

YADIM Results

Even though our only meaningful choice was to merge the pipelines of AMDIM and CPC, YADIM still manages to do really well compared with other approaches.

Unlike all the related approaches, we generate the results above by actually implementing each approach ourselves. In fact, our implementation of CPC V2 is, to our knowledge, the first public implementation outside of DeepMind.

More importantly, we use PyTorch Lightning to standardize all implementations so we can objectively distill the main drivers of the above results.

Computational efficiency

The methods above are trained using huge amounts of computing resources. The prohibitive costs mean that we did not conduct a rigorous hyperparameter search but simply used the hyperparameters from STL-10 to train on ImageNet.

Using PyTorch Lightning to efficiently distribute the computations, we were able to get an epoch through ImageNet down to about 3 minutes using 16-bit precision.

These are the compute resources we used for each approach

Based on the p3dn.24xlarge instance at $31.212 per hour
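The cost of a run follows directly from the hourly rate quoted above. The instance count and training hours in this sketch are hypothetical placeholders, not figures from our experiments.

```python
HOURLY_RATE_USD = 31.212  # per-instance hourly rate quoted above

def training_cost(num_instances, hours, hourly_rate=HOURLY_RATE_USD):
    """Total cost in USD of running num_instances machines for the given hours."""
    return num_instances * hours * hourly_rate

# Hypothetical run: 2 instances for 10 hours.
example_cost = training_cost(num_instances=2, hours=10)
```

For this made-up run the total comes to roughly $624, which shows how quickly multi-day, multi-machine training adds up.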

Key takeaways

  1. We introduced a conceptual framework to compare and more easily design contrastive learning approaches.
  2. AMDIM, CPC, SimCLR, Moco, BYOL, and Swav differ from each other in subtle ways. The main differences are found in how they extract representations.
  3. AMDIM and CPC introduced the main key ideas used by other approaches. SimCLR, Moco, BYOL, and Swav can be viewed as variants of AMDIM.
  4. The choice of the encoder does not matter as long as it is wide.
  5. The representation extraction strategy does not matter as long as the data augmentation pipeline generates good positive and negative inputs.
  6. Using our framework we can formulate new contrastive self-supervised learning (CSL) approaches. We designed YADIM (Yet Another DIM) as an example that performs on par with competing approaches.
  7. The cost of training these approaches means that only a handful of research groups in the world can continue to make progress. However, our release of all these algorithms in a standardized way at least alleviates the burden of implementing these algorithms and verifying those implementations.
  8. Since most of the results are driven by wider networks and specific data augmentation pipelines, we suspect the current line of research may have limited room to improve.

Acknowledgments

As noted in our paper, I’d like to thank some of the authors of CPC, AMDIM, and BYOL for helpful discussions.

Most of this work was conducted while at Facebook AI Research. The ablations and long training times would not have been possible without the FAIR compute resources.

I’d also like to thank colleagues at FAIR and NYU CILVR for helpful discussions: Stephen Roller, Margaret Li, Tullie Murrell, Cinjon Resnick, Ethan Perez, Shubho Sengupta and Soumith Chintala.

PyTorch Lightning

In addition, this work happens to have been one of the main reasons for creating PyTorch Lightning: rapid iteration of ideas using massive computing resources without getting caught up in all the engineering details required to train models at this scale.

Finally, I’d like to thank my advisors Kyunghyun Cho and Yann LeCun for their patience while I worked on this research and built PyTorch Lightning in parallel.


⚡️PyTorch Lightning Creator • PhD Student, AI (NYU, Facebook AI research).