
Self-Supervised Learning (SSL) Overview

Why It Matters, What It Is & Different Types of Self-Supervised Learning

A robot is "self-supervising". Photo by Brett Jordan on Unsplash.
A robot is “self-supervising”. Photo by Brett Jordan on Unsplash.

The term "self-supervised learning" came from the quote of Yann Lecun on April 30, 2019 (tweet and post):

I now call it "self-supervised learning", because "unsupervised" is both a loaded and confusing term.

In this post, I will explain what self-supervised learning is, why it matters, how it can be used, and the different categories of self-supervised learning across multiple domains, including text, image, speech/audio, and graph.

What is Self-Supervised Learning?

Self-supervised learning is a subcategory of unsupervised learning because it leverages unlabeled data. The key idea is to let the model learn a representation of the data without manual labels. Once the model has learned how to represent the data, it can be used for downstream tasks with a smaller amount of labeled data, achieving similar or better performance than models trained without self-supervised learning.

It has three steps:

  1. Generate the input data and labels from the unlabeled data programmatically, based on an understanding of the data
  2. Pre-training: train the model with the data/labels from the previous step
  3. Fine-tuning: use the pre-trained model as the initial weights and train it for the tasks of interest

If we use manually labeled data instead of automatically generated labels in the second step, it becomes supervised pre-training, a familiar step in transfer learning.
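
As a rough illustration of these three steps, here is a minimal PyTorch sketch on synthetic data; the masking pretext task, the tiny encoder, and all names are illustrative rather than taken from any specific paper:

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))

# Steps 1 + 2: generate inputs/labels programmatically (here, by masking random
# positions) and pre-train the encoder to reconstruct the masked values.
pretext_head = nn.Linear(64, 32)
opt = torch.optim.Adam(list(encoder.parameters()) + list(pretext_head.parameters()), lr=1e-3)
for _ in range(100):
    x = torch.randn(256, 32)                   # stand-in for unlabeled data
    mask = torch.rand_like(x) < 0.15           # programmatically generated targets
    recon = pretext_head(encoder(x.masked_fill(mask, 0.0)))
    loss = ((recon - x)[mask] ** 2).mean()     # score only the masked positions
    opt.zero_grad()
    loss.backward()
    opt.step()

# Step 3: fine-tune the pre-trained encoder on a small labeled set.
classifier = nn.Linear(64, 2)
ft_opt = torch.optim.Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=1e-4)
x_small, y_small = torch.randn(64, 32), torch.randint(0, 2, (64,))
for _ in range(20):
    loss = nn.functional.cross_entropy(classifier(encoder(x_small)), y_small)
    ft_opt.zero_grad()
    loss.backward()
    ft_opt.step()
```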

Why is Self-Supervised Learning important?

Self-supervised learning has been successful in multiple fields, e.g., text, image/video, speech, and graph. Essentially, self-supervised learning mines unlabeled data and boosts performance. In Yann LeCun's cake metaphor (video, slide), self-supervised learning (the génoise of the cake) provides millions of bits of learning signal per sample, while supervised learning (the icing) provides only 10 to 10,000 bits. That is to say, self-supervised learning can extract far more useful information from each sample than supervised learning.

Human-generated labels usually focus on a specific view of the data. For example, we can describe an image of a horse on the grass (like the image shown below) with only the single term "horse" for image recognition, or provide pixel coordinates for semantic segmentation. However, there is far more information in the data, e.g., the horse's head and tail are on opposite ends of the body, or the horse is usually on top of the grass (not underneath). The models can potentially learn better and more complex representations directly from the data instead of from manual labels. Not to mention that manual labels can sometimes be wrong, which is harmful to the models. One experiment has shown that cleaning up the PASCAL dataset can improve mAP by 13. Even though this is not a comparison against the state of the art, it still shows that false labels can result in worse performance.

Photo by David Dibert on Unsplash.

Data labeling is costly, time-consuming, and labor-intensive. In addition, supervised learning approaches need new labels for new data and new tasks. More importantly, it has been shown that self-supervised pre-training can even outperform supervised pre-training on image-based tasks (e.g., image recognition, object detection, semantic segmentation) [ref]. In other words, extracting information directly from the data can be more helpful than manual labels. Depending on the task, we may not need as many costly labels now or in the near future as self-supervised learning advances.

The superiority of self-supervised learning has been validated on image-based tasks, thanks to the large-scale labeled datasets in the image domain, which has a longer history than other domains in the recent deep learning wave. I believe similar results will be demonstrated in other domains in the future as well. Hence, self-supervised learning is essential for advancing the machine learning field.

How can it be used?

Usually, when a self-supervised model is released, we can download the pre-trained weights. Then we can fine-tune the pre-trained model and use the fine-tuned model for a specific downstream task. The most well-known example of self-supervised learning is probably BERT (ref). BERT was pre-trained on 3.3 billion words in a self-supervised fashion. We can fine-tune BERT for a text-related task, such as sentence classification, with much less effort and data than training a model from scratch. Based on a fine-tuned BERT model, I've created an app on Hugging Face that predicts whether a tweet is from Elon Musk (link). I'll write a separate post about how I created it. Feel free to play with it and have fun!
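
For a concrete picture, here is a minimal sketch of fine-tuning a pre-trained BERT for sentence classification with the Hugging Face transformers library; the two toy sentences and labels below are placeholders, not the tweet dataset behind my app:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

texts = ["rockets are the future", "buy my new album today"]   # toy placeholder data
labels = torch.tensor([1, 0])

batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

model.train()
for _ in range(3):                               # a few fine-tuning steps
    outputs = model(**batch, labels=labels)      # the loss is computed internally
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```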

What are the categories of self-supervised learning?

Let me describe each category with a few words and dive into each of them later.

  1. Generative approaches: recover the original information
     a. Non-autoregressive: mask a token/pixel and predict the masked token/pixel (e.g., masked language modeling (MLM))
     b. Autoregressive: predict the next token/pixel
  2. Predictive tasks: design labels based on an understanding, a clustering, or an augmentation of the data
     a. Predict the context (e.g., predict the relative location of image patches, or whether the next fragment is the next sentence)
     b. Predict the cluster ID of each sample
     c. Predict the image rotation angle
  3. Contrastive learning (a.k.a. contrastive instance discrimination): set up a binary classification problem based on positive and negative sample pairs created by augmentation
  4. Bootstrapping approaches: use two similar but different networks to learn the same representation from augmented pairs of the same sample
  5. Regularization: add loss and regularization terms based on the assumptions/intuitions that
     a. the positive pairs should be similar
     b. the outputs from different samples in the same batch should be different

Generative approaches

Image by the author.

Predicting masked inputs from the surrounding data is the earliest category of self-supervised methods. The idea can be traced back to the quote, "You shall know a word by the company it keeps." – John Rupert Firth (1957), a linguist. This series of algorithms started with word2vec (ref) in the text field in 2013. The continuous bag of words (CBOW) objective of word2vec predicts a central word from its neighbors, which is very similar to ELMo ([ref](https://arxiv.org/abs/1802.05365)) and the masked language modeling (MLM) of BERT ([ref](https://arxiv.org/abs/1810.04805)). These models are all categorized as non-autoregressive generative approaches. The main differences are that later models use more advanced architectures, such as bidirectional LSTMs (for ELMo) and Transformers (for BERT), and that the more recent models generate contextual embeddings.
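
As a rough sketch of this non-autoregressive setup, here is how MLM-style inputs and labels can be generated programmatically from token ids; it simplifies BERT's recipe (e.g., it omits the 80/10/10 replacement trick), and the helper name is mine:

```python
import torch

def make_mlm_batch(token_ids, mask_token_id, mask_prob=0.15, ignore_index=-100):
    inputs = token_ids.clone()
    labels = token_ids.clone()
    mask = torch.rand(token_ids.shape) < mask_prob   # choose positions to hide
    inputs[mask] = mask_token_id                     # replace them with [MASK]
    labels[~mask] = ignore_index                     # score only the masked positions
    return inputs, labels

token_ids = torch.randint(5, 30000, (8, 128))        # stand-in for tokenized text
inputs, labels = make_mlm_batch(token_ids, mask_token_id=103)
```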

In the speech field, Mockingjay (ref) masked all the dimensions of consecutive features, and TERA ([ref](https://arxiv.org/abs/2007.06028)) masked specific subsets of feature dimensions. In the image field, OpenAI applied the BERT-style regime to images ([ref](https://openai.com/blog/image-gpt/)). In the graph field, GPT-GNN also masked attributes and edges ([ref](https://arxiv.org/abs/2006.15437)). These methods all masked part of the input data and tried to predict it back.

On the other hand, another generative approach is to predict the next token/pixel/acoustic feature. In the text field, the GPT series models ([ref](https://arxiv.org/abs/2005.14165) & ref) are the pioneers of this category. APC ([ref](https://arxiv.org/abs/1910.12607)) and ImageGPT ([ref](https://openai.com/blog/image-gpt/)) applied the same idea in the speech and image fields respectively. Interestingly, because adjacent acoustic features are so easy to predict, the speech model is usually asked to predict a token further ahead in the sequence (at least 3 steps away).
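
The autoregressive setup is even simpler to sketch: the target at each position is the same sequence shifted ahead by k steps (k = 1 for text; as noted above, speech models typically predict at least 3 steps ahead). The helper below is illustrative, not taken from any of the cited papers:

```python
import torch

def shift_targets(sequence, k=1):
    inputs = sequence[:, :-k]    # everything except the last k steps
    targets = sequence[:, k:]    # the same sequence, k steps ahead
    return inputs, targets

tokens = torch.randint(0, 30000, (4, 64))   # stand-in for a tokenized batch
inputs, targets = shift_targets(tokens, k=1)
```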

The great success of self-supervised learning (especially BERT/GPT) motivated researchers to apply similar generative approaches to other fields such as image and speech. However, for image and speech data, it is harder to generate the masked inputs, since choosing from a limited set of text tokens is easier than choosing from an essentially unlimited space of image pixels or acoustic features. The performance improvements were not as large as in the text field. Therefore, researchers also developed many other non-generative approaches, covered in the following sections.

Predictive tasks

Image by the author.

The main idea is to design simpler goals/targets so that the model does not have to generate data. The most crucial and challenging point is that the task needs to be at an appropriate difficulty level for the model to learn.

For example, in predicting context in the text field, both BERT and ALBERT predict whether the next fragment is the next sentence. BERT creates negative training samples by replacing the next fragment with a random fragment (next-sentence prediction; NSP), while ALBERT creates negative training samples by swapping the previous and the next fragment (sentence-order prediction; SOP). SOP has been shown to outperform NSP (ref). One explanation is that it is so easy to distinguish random sentence pairs by topic that the model does not learn much from the NSP task, whereas SOP forces the model to learn the coherence relationship. As a result, designing good tasks requires domain knowledge, as well as experiments to validate their effectiveness.
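
To make the difference concrete, here is a toy sketch of how NSP and SOP training pairs can be constructed; it assumes `docs` is a list of documents, each a list of at least two consecutive sentences, and the helper names are mine:

```python
import random

def nsp_pair(docs):
    doc = random.choice(docs)
    i = random.randrange(len(doc) - 1)
    if random.random() < 0.5:
        return doc[i], doc[i + 1], 1              # true next sentence
    negative = random.choice(random.choice(docs)) # random sentence (a weak, likely off-topic negative)
    return doc[i], negative, 0

def sop_pair(docs):
    doc = random.choice(docs)
    i = random.randrange(len(doc) - 1)
    if random.random() < 0.5:
        return doc[i], doc[i + 1], 1              # correct order
    return doc[i + 1], doc[i], 0                  # same sentences, swapped order

docs = [["A1", "A2", "A3"], ["B1", "B2"]]         # toy corpus of two documents
print(nsp_pair(docs))
print(sop_pair(docs))
```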

The idea of predicting context like SOP was also applied to the image field (as predicting the relative location of the image patches ([ref](https://ieeexplore.ieee.org/document/9060816))) and the speech field (as predicting the time interval between two acoustic feature groups (ref)).

Another approach is to generate the labels by clustering. In the image field, DeepCluster applied k-means clustering (ref). In the speech field, HuBERT applied k-means clustering ([ref](https://arxiv.org/abs/2106.07447)) and BEST-RQ employed a random-projection quantizer ([ref](https://arxiv.org/abs/2202.01855)).

Other tasks in the image field include: predicting the gray-scale channel from the color channels of images (and vice versa; ref), reconstructing a randomly cropped patch of an image (i.e., inpainting; [ref](https://arxiv.org/abs/1604.07379)), reconstructing images at their original resolution ([ref](https://arxiv.org/abs/1609.04802)), predicting the rotation angle of images ([ref](https://arxiv.org/abs/1803.07728)), predicting the colors of images (ref1, ref2, ref3), and solving jigsaw puzzles ([ref](https://arxiv.org/abs/1603.09246)).
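
Among these tasks, rotation prediction is especially easy to sketch, because the label is simply the index of the rotation that was applied; the snippet below is a toy illustration rather than the original implementation:

```python
import torch

def rotation_batch(images):                      # images: (N, C, H, W)
    rotated, labels = [], []
    for k in range(4):                           # 0, 90, 180, 270 degrees
        rotated.append(torch.rot90(images, k, dims=(2, 3)))
        labels.append(torch.full((images.shape[0],), k, dtype=torch.long))
    return torch.cat(rotated), torch.cat(labels)

images = torch.randn(16, 3, 32, 32)              # stand-in for unlabeled images
x, y = rotation_batch(images)                    # y in {0, 1, 2, 3}
```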

Contrastive Learning (Contrastive Instance Discrimination)

Image by the author.

The key concept of contrastive learning is to generate positive and negative training sample pairs based on an understanding of the data. The model needs to learn a function such that two positive samples have a high similarity score and two negative samples have a low similarity score. As a result, appropriate sample generation is essential to ensure the model learns the underlying features/structures of the data.

Contrastive learning in the image field applies two different data augmentations to the same original image to generate a positive sample pair, and uses two different images as a negative sample pair. The two most critical and challenging parts are the strength of the augmentation and the selection of negative sample pairs. If the augmentation is so strong that there is no relationship between the two augmented samples from the same image, the model cannot learn. Likewise, if the augmentation is so weak that the model can easily solve the problem, the model also cannot learn useful information for downstream tasks. As for selecting negative sample pairs, if we randomly assign two images as a negative pair, they may belong to the same class (e.g., two images of cats), which introduces conflicting noise to the model. If the negative pairs are very easy to distinguish, the model again cannot learn the underlying features/structures of the data. The most famous examples of contrastive learning are SimCLR (v1, v2) and MoCo ([v1](https://arxiv.org/abs/1911.05722), [v2](https://arxiv.org/abs/2003.04297)).
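
As a rough sketch of the contrastive objective, here is an NT-Xent-style loss on a batch of paired embeddings, loosely following the SimCLR recipe; the projector outputs are random placeholders:

```python
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.1):
    z = F.normalize(torch.cat([z1, z2]), dim=1)      # (2N, D), unit-norm embeddings
    sim = z @ z.t() / temperature                    # pairwise cosine similarities
    n = z1.shape[0]
    sim.masked_fill_(torch.eye(2 * n, dtype=torch.bool), float("-inf"))  # drop self-similarity
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])    # positive = the other view
    return F.cross_entropy(sim, targets)

z1, z2 = torch.randn(32, 128), torch.randn(32, 128)  # stand-in projector outputs
loss = nt_xent(z1, z2)
```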

In the speech field, one approach applies augmentation as in SimCLR (Speech SimCLR). Another approach uses adjacent features as positive pairs and features from different samples as negative pairs (e.g., CPC, Wav2vec (v1, v2.0), VQ-wav2vec, and Discrete BERT). In the graph field, DGI maximizes the mutual information between the patch representations and the global representation of a graph, and minimizes the mutual information between the patch representations of corrupted graphs and the global representation of the original graph.

One interesting realization is that classification in self-supervised text models is conceptually similar to contrastive learning. Classification maximizes the output of the positive class and minimizes the outputs of the negative classes. Likewise, contrastive learning maximizes the output of positive pairs and minimizes the outputs of negative pairs. The key distinction is that classification has a finite number of negative classes (in the case of text tokens), whereas contrastive learning has an effectively infinite number of negative classes (in the case of images and acoustic features). In theory, we could design a classifier for images/speech with a small number of classes, where each class is one original image and the inputs are its augmented versions. However, this would not be practical since it would only be applicable to a limited number of images/classes.

Bootstrapping approaches

Image by the author.

Researchers further developed bootstrapping approaches to avoid using negative examples, since negative examples are computationally expensive during training and not easy to select well. The key ideas of bootstrapping approaches are: 1) generate a positive pair of samples from two augmentations of the same original sample (just like contrastive learning); 2) set up one network as the target network (also called the teacher network) and another network as the online network (also called the student network), which has the same architecture as the target network plus an additional feedforward layer (called the predictor); 3) fix the weights of the target/teacher network during each training step and update only the online/student network; 4) update the weights of the target/teacher network based on the weights of the online/student network.

The most important design choices are: 1) the online network needs to have the predictor (an additional layer); 2) only the weights of the online network are updated by gradients. Otherwise, the networks would collapse (i.e., output the same values regardless of the inputs).

In the image field, BYOL updated the weights of the target/teacher network by taking an exponential moving average (EMA) of the weights of the online/student network (ref), whereas SimSiam simply copied the weights over ([ref](https://arxiv.org/abs/2011.10566)).
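
Here is a minimal sketch of one bootstrapping step in this style, with an EMA-updated target network and a stop-gradient on the target branch; the tiny networks and the noise "augmentation" are placeholders, not the published architectures:

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

online = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
predictor = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)                          # the target is never updated by gradients

opt = torch.optim.Adam(list(online.parameters()) + list(predictor.parameters()), lr=1e-3)

def byol_step(view1, view2, momentum=0.99):
    p = F.normalize(predictor(online(view1)), dim=1)
    with torch.no_grad():                            # stop-gradient on the target branch
        t = F.normalize(target(view2), dim=1)
    loss = (2 - 2 * (p * t).sum(dim=1)).mean()       # equivalent to a normalized MSE
    opt.zero_grad()
    loss.backward()
    opt.step()
    with torch.no_grad():                            # EMA update of the target weights
        for tp, op in zip(target.parameters(), online.parameters()):
            tp.mul_(momentum).add_(op, alpha=1 - momentum)
    return loss

x = torch.randn(64, 32)                              # stand-in data; added noise acts as "augmentation"
byol_step(x + 0.1 * torch.randn_like(x), x + 0.1 * torch.randn_like(x))
```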

Data2vec from Meta is a unified framework for the image, speech, and text fields (ref). It also uses EMA to update the target/teacher network, but it uses a masked prediction task: it feeds the original data to the target/teacher network and the masked data to the online/student network. One important design choice is that the objective is to predict the averaged embeddings of the masked input regions/tokens from the top few layers of the target/teacher network.

Regularization

Image by the author.

This is another family of approaches that only needs positive pairs, without negative examples. Surprisingly, these methods can use identical architectures for the two networks, and they do not need the "stop gradient" mechanism that updates only one of the networks during training. By adding extra regularization terms, the model still does not collapse. The objective function terms include:

  1. Invariance: this loss term keeps the two embeddings from the same positive pair as similar as possible. Barlow Twins's and DeLoRes's invariance terms push the diagonal elements of the cross-correlation matrix toward 1, in the image field and the audio field respectively. In the image field, VICReg minimizes the mean-squared Euclidean distance between the two embeddings (ref).
  2. Variance: this regularization term keeps the samples in the same batch sufficiently different from one another, since they are not the same sample. Barlow Twins's and DeLoRes's redundancy-reduction terms push the off-diagonal elements of the cross-correlation matrix toward 0, in the image field and the audio field respectively. In the image field, VICReg's variance term uses a hinge loss to keep the standard deviation of the embeddings across the samples in a batch above a threshold ([ref](https://arxiv.org/abs/2105.04906)). VICReg's covariance term minimizes the magnitude of the off-diagonal terms of the covariance matrix to decorrelate the embedding dimensions; this term greatly boosts performance and makes fuller use of all the dimensions of the embedding vectors, but it is not required for preventing informational collapse (ref). These three terms are sketched below.
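
As a rough sketch of these three terms, here is a simplified VICReg-style loss on a pair of embedding batches; the coefficients and epsilon are illustrative defaults rather than tuned values:

```python
import torch
import torch.nn.functional as F

def vicreg_terms(z1, z2, gamma=1.0, eps=1e-4):
    invariance = F.mse_loss(z1, z2)                   # the positive pair should match

    def variance(z):                                  # keep each dimension's std above gamma
        std = torch.sqrt(z.var(dim=0) + eps)
        return torch.relu(gamma - std).mean()         # hinge loss

    def covariance(z):                                # decorrelate the embedding dimensions
        z = z - z.mean(dim=0)
        n, d = z.shape
        cov = (z.t() @ z) / (n - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        return (off_diag ** 2).sum() / d

    return invariance, variance(z1) + variance(z2), covariance(z1) + covariance(z2)

z1, z2 = torch.randn(256, 128), torch.randn(256, 128)  # stand-in embeddings
inv, var, cov = vicreg_terms(z1, z2)
loss = 25 * inv + 25 * var + cov                        # illustrative loss weights
```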

The VICReg paper has shown that VICReg is more robust to different network architectures than other self-supervised frameworks (Barlow Twins and SimCLR). Therefore, it may enable multi-modal applications in the future.

Summary

This article provides an overview of the history and progress of self-supervised learning (SSL). Across the text, image, audio/speech, and graph modalities, SSL has evolved from masked prediction and next-token prediction to contrastive learning, bootstrapping, and regularization. Initially, the model learns by recovering part of the data, so it does not need manual labels. The model can also learn from tasks designed based on an understanding of the data. With data augmentation, positive and negative pair examples enable contrastive learning. Most surprisingly, with bootstrapping techniques or regularization terms, the model can even learn without any negative examples. With a better understanding of SSL, I believe we will develop more robust models with less data, time, and effort in the foreseeable future. I cannot wait to see more progress in this exciting field!

I am Jack Lin, a senior data scientist at C3.ai, and I’m passionate about Deep Learning and machine learning. You can check out my other articles on Medium!

Photo by Andrea De Santis on Unsplash
