
A Practical Guide to Contrastive Learning

How to build your very first SimSiam model with FashionMNIST

Contrastive learning has many use cases these days. From NLP and computer vision to recommendation systems, contrastive learning can be used to learn underlying data representations without any explicit labels, which can then be used for downstream classification, detection, similarity search, etc.

There are plenty of online resources covering the basic ideas of contrastive learning, so I won’t add one more post repeating them. Instead, in this article I will show you how to convert a supervised learning problem into a contrastive learning problem. Specifically, I will start with a basic classification model for FashionMNIST (MIT licence). Then, I will move to a harder setting with limited training labels (e.g., reducing the full training set of 60,000 labels to 1,000). I will introduce SimSiam, a state-of-the-art contrastive learning method, and walk step by step through replacing the original linear layers with a SimSiam-style head. Finally, I’ll show the results: with a very basic configuration, SimSiam improves the F1 score by 15%.

Image source: https://pxhere.com/en/photo/395408

Now, let’s start. First, we’ll load the FashionMNIST dataset. A custom FashionMNIST class is used to obtain a subset of the training set, named finetune_dataset. The source code for the custom FashionMNIST class will be given at the end of this article.
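As a rough sketch, the custom class might look something like the following (this is an assumption, not the repository code: it subclasses torchvision’s FashionMNIST and simply keeps the first first_k samples to simulate a small annotated subset):

import torchvision

class FashionMNIST(torchvision.datasets.FashionMNIST):
    # Hypothetical sketch: wrap torchvision's FashionMNIST and optionally keep
    # only the first `first_k` samples to simulate a small annotated subset.
    def __init__(self, root, train=True, transform=None, download=False, first_k=None):
        super().__init__(root, train=train, transform=transform, download=download)
        if first_k is not None:
            self.data = self.data[:first_k]
            self.targets = self.targets[:first_k]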

import matplotlib.pyplot as plt

import torchvision.transforms as transforms

from FashionMNIST import FashionMNIST

train_dataset = FashionMNIST("./FashionMNIST", 
                             train=True, 
                             transform=transforms.ToTensor(), 
                             download=True,
                             )
test_dataset = FashionMNIST("./FashionMNIST", 
                            train=False, 
                            transform=transforms.ToTensor(), 
                            download=True,
                            )
finetune_dataset = FashionMNIST("./FashionMNIST", 
                                train=True, 
                                transform=transforms.ToTensor(), 
                                download=True, 
                                first_k=1000,
                                )

# Create a subplot with 4x4 grid
fig, axs = plt.subplots(4, 4, figsize=(8, 8))

# Loop through each subplot and plot an image
for i in range(4):
    for j in range(4):
        image, label = train_dataset[i * 4 + j]  # Get image and label
        image_numpy = image.numpy().squeeze()    # Convert image tensor to numpy array
        axs[i, j].imshow(image_numpy, cmap='gray')  # Plot the image
        axs[i, j].axis('off')  # Turn off axis
        axs[i, j].set_title(f"Label: {label}")  # Set title with label

plt.tight_layout()  # Adjust layout
plt.show()  # Show plot

Running this code shows a grid of images from train_dataset:

First 16 images from the FashionMNIST training set. Image by author.

Next, we’ll define the supervised classification model. The architecture contains a backbone of convolutional layers and an MLP head of two linear layers. This will set a consistent baseline for the following experiments, as SimSiam will only replace the MLP head for contrastive learning purposes.

import torch.nn as nn

class supervised_classification(nn.Module):

    def __init__(self):
        super(supervised_classification, self).__init__()

        self.backbone = nn.Sequential(
                                nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm2d(32),
                                nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm2d(64),
                                nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm2d(128),
        )

        self.fc = nn.Sequential(
                                nn.Linear(128*4*4, 32),
                                nn.ReLU(),
                                nn.Linear(32, 10),
        )

    def forward(self, x):
        x = self.backbone(x).view(-1, 128 * 4 * 4)

        return self.fc(x)

We’ll train the model for 10 epochs:

import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import wandb

wandb_config = {
    "learning_rate": 0.001,
    "architecture": "fashion mnist classification full training",
    "dataset": "FashionMNIST",
    "epochs": 10,
    "batch_size": 64,
    }

wandb.init(
    # set the wandb project where this run will be logged
    project="supervised_classification",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

supervised = supervised_classification().to(device)

optimizer = optim.SGD(supervised.parameters(), 
                      lr=wandb_config["learning_rate"], 
                      momentum=0.9, 
                      weight_decay=1e-5,
                      )

train_dataloader = DataLoader(train_dataset, 
                              batch_size=wandb_config["batch_size"], 
                              shuffle=True,
                              )

# Training loop
loss_fun = nn.CrossEntropyLoss()
for epoch in range(wandb_config["epochs"]):
    supervised.train()

    train_loss = 0
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        optimizer.zero_grad()

        prediction = supervised(image.to(device))

        loss = loss_fun(prediction, target.to(device))
        loss.backward()
        optimizer.step()

        wandb.log({"training loss": loss})

torch.save(supervised.state_dict(), "weights/fully_supervised.pt")

Using the classification_report from the scikit-learn package, we’ll get the following results:

from sklearn.metrics import classification_report

supervised = supervised_classification()

supervised.load_state_dict(torch.load("weights/fully_supervised.pt"))
supervised.eval()
supervised.to(device)

test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

target_list = []
prediction_list = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):
    with torch.no_grad():
        prediction = supervised(image.to(device))

    prediction_list.extend(torch.argmax(prediction, dim=1).detach().cpu().numpy())
    target_list.extend(target.detach().cpu().numpy())

print(classification_report(target_list, prediction_list))

# Create a subplot with 4x4 grid
fig, axs = plt.subplots(4, 4, figsize=(8, 8))

# Loop through each subplot and plot an image
for i in range(4):
    for j in range(4):
        image, label = test_dataset[i * 4 + j]  # Get image and label
        image_numpy = image.numpy().squeeze()    # Convert image tensor to numpy array
        prediction = supervised(torch.unsqueeze(image, dim=0).to(device))
        prediction = torch.argmax(prediction, dim=1).item()  # Predicted class index
        axs[i, j].imshow(image_numpy, cmap='gray')  # Plot the image
        axs[i, j].axis('off')  # Turn off axis
        axs[i, j].set_title(f"Label: {label}, Pred: {prediction}")  # Set title with label

plt.tight_layout()  # Adjust layout
plt.show()  # Show plot
Classification results of the fully supervised model. Image by author.

Now, let’s consider a new problem: what should we do if we’re given only a limited subset of the training labels, e.g., only 1,000 of the 60,000 training images are annotated? The natural idea is to simply train the model on the small annotated set. So, without changing the backbone, we train the model on this limited subset for 100 epochs (we increase the number of epochs for a fair comparison with the SimSiam training later):

import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import wandb

wandb_config = {
    "learning_rate": 0.001,
    "architecture": "fashion mnist classification full training on finetune set",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 64,
    }

wandb.init(
    # set the wandb project where this run will be logged
    project="supervised_classification",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

supervised = supervised_classification().to(device)

optimizer = optim.SGD(supervised.parameters(), 
                      lr=wandb_config["learning_rate"], 
                      momentum=0.9, 
                      weight_decay=1e-5,
                      )

finetune_dataloader = DataLoader(finetune_dataset, 
                                 batch_size=wandb_config["batch_size"], 
                                 shuffle=True,
                                 )

# Training loop
loss_fun = nn.CrossEntropyLoss()
for epoch in range(wandb_config["epochs"]):
    supervised.train()

    train_loss = 0
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(finetune_dataloader, total=len(finetune_dataloader))):
        optimizer.zero_grad()

        prediction = supervised(image.to(device))

        loss = loss_fun(prediction, target.to(device))
        loss.backward()
        optimizer.step()

        wandb.log({"training loss": loss})

torch.save(supervised.state_dict(), "weights/fully_supervised_finetunedataset.pt")
Fully supervised training loss on the limited training set. Image by author.
Quantitative evaluation results on the testing set. Note that performance drops by more than 25% when the training set is reduced. Image by author.

Now it’s time for some contrastive learning. To mitigate the shortage of annotation labels and fully utilize the large quantity of unlabelled data, contrastive learning can help the backbone learn data representations without being tied to a specific task. The backbone can then be frozen for a given downstream task, and only a shallow network needs to be trained on the limited annotated dataset to achieve satisfactory results.

The most commonly used contrastive learning approaches include SimCLR, SimSiam, and MoCo (see my previous article on MoCo). Here, we compare SimCLR and SimSiam.

SimCLR contrasts positive and negative pairs within each data batch, which requires the NT-Xent loss (an extension of the cosine-similarity loss over a batch) and a large batch size to supply enough negative samples. SimCLR also relies on the LARS optimizer to accommodate the large batch size.
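For intuition, here is a minimal sketch of an NT-Xent loss over a batch of paired embeddings (a simplified illustration, not the exact SimCLR implementation): each sample treats the other augmented view of the same image as its positive and all remaining 2N-2 samples in the batch as negatives, which is why a large batch matters.

import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    # z1, z2: (N, D) embeddings of two augmented views of the same N images
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)      # (2N, D) unit-norm embeddings
    sim = z @ z.T / temperature                              # (2N, 2N) scaled cosine similarities
    mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(mask, float("-inf"))               # a sample is never its own negative
    # For row i, the positive is the other augmented view of the same image
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)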

SimSiam, however, uses a Siamese architecture, which avoids using negative pairs and further avoids the need for large batch sizes. The differences between SimSiam and SimCLR are given in the table below.

Comparison between SimCLR and SimSiam. Image by author.
The SimSiam architecture. Image source: https://arxiv.org/pdf/2011.10566

We can see from the figure above that the SimSiam architecture contains only two parts: the encoder/backbone and the predictor. During training, gradient propagation is stopped on one branch of the Siamese network, and the cosine similarity is computed between the predictor output of one augmented view and the backbone output of the other.

So, how do we implement this architecture in practice? Continuing from the supervised classification design, we keep the backbone the same and only modify the MLP head. In the supervised architecture, the MLP outputs a 10-element vector indicating the probabilities of the 10 classes. For SimSiam, however, the purpose is not to perform "classification" but to learn a "representation," so the predictor output needs the same dimension as the backbone output for the loss calculation. The model and the negative cosine similarity (with stop-gradient) are given below:

import torch.nn as nn

class SimSiam(nn.Module):

    def __init__(self):

        super(SimSiam, self).__init__()

        self.backbone = nn.Sequential(
                                nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm2d(32),
                                nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm2d(64),
                                nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                                nn.ReLU(),
                                nn.BatchNorm2d(128),
        )

        self.prediction_mlp = nn.Sequential(nn.Linear(128*4*4, 64),
                               nn.BatchNorm1d(64),
                               nn.ReLU(),
                               nn.Linear(64, 128*4*4),
        )

    def forward(self, x):
        x = self.backbone(x)

        x = x.view(-1, 128 * 4 * 4)
        pred_output = self.prediction_mlp(x)
        return x, pred_output

cos = nn.CosineSimilarity(dim=1, eps=1e-6)
def negative_cosine_similarity_stopgradient(pred, proj):
    return -cos(pred, proj.detach()).mean()

The pseudo-code for training SimSiam, as given in the original paper, is shown below:

Training pseudo-code for SimSiam. Source: https://arxiv.org/pdf/2011.10566

And we convert it into real training code:

import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import RandAugment

import wandb

wandb_config = {
    "learning_rate": 0.0001,
    "architecture": "simsiam",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 256,
    }

wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

simsiam = SimSiam().to(device)

random_augmenter = RandAugment(num_ops=5)

optimizer = optim.SGD(simsiam.parameters(), 
                      lr=wandb_config["learning_rate"], 
                      momentum=0.9, 
                      weight_decay=1e-5,
                      )

train_dataloader = DataLoader(train_dataset, batch_size=wandb_config["batch_size"], shuffle=True)

# Training loop
for epoch in range(wandb_config["epochs"]):
    simsiam.train()

    print(f"Epoch {epoch}")
    train_loss = 0
    for batch_idx, (image, _) in enumerate(tqdm.tqdm(train_dataloader, total=len(train_dataloader))):
        optimizer.zero_grad()

        # Create two random augmentations of the same batch (RandAugment expects uint8 images)
        aug1 = random_augmenter((image * 255).to(dtype=torch.uint8)).to(dtype=torch.float32) / 255.0
        aug2 = random_augmenter((image * 255).to(dtype=torch.uint8)).to(dtype=torch.float32) / 255.0

        proj1, pred1 = simsiam(aug1.to(device))
        proj2, pred2 = simsiam(aug2.to(device))

        # Symmetrized loss with stop-gradient on the projection branch
        loss = negative_cosine_similarity_stopgradient(pred1, proj2) / 2 + negative_cosine_similarity_stopgradient(pred2, proj1) / 2
        loss.backward()
        optimizer.step()

        wandb.log({"training loss": loss})

    if (epoch+1) % 10 == 0:
        torch.save(simsiam.state_dict(), f"weights/simsiam_epoch{epoch+1}.pt")

We trained for 100 epochs as a fair comparison to the limited supervised training; the training loss is shown below. Note: due to its Siamese design, SimSiam can be very sensitive to hyperparameters such as the learning rate and the MLP hidden dimensions. The original SimSiam paper provides a detailed configuration for the ResNet-50 backbone. For ViT-based backbones, we recommend reading the MoCo v3 paper, which adopts a SimSiam-style design within a momentum update scheme.

Training loss for SimSiam. Image by author.
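If you want to move closer to the original recipe, here is a rough sketch of the paper’s hyperparameter choices (SGD with a base learning rate of 0.05 scaled by batch size / 256, momentum 0.9, weight decay 1e-4, and cosine learning-rate decay). These values come from the paper’s ResNet-50 setup and would still need tuning for our small CNN, so treat them as a starting point rather than a recommendation:

import torch.optim as optim

# Hyperparameters from the original SimSiam recipe (not tuned for this small model)
batch_size = 256
pretrain_epochs = 100
base_lr = 0.05 * batch_size / 256

optimizer = optim.SGD(simsiam.parameters(), lr=base_lr, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=pretrain_epochs)

# Call scheduler.step() once per epoch, after the inner batch loop.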

Then, we run the trained SimSiam on the testing set and visualize the representations using UMAP reduction:

import tqdm
import numpy as np

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

simsiam = SimSiam()                      

test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
simsiam.load_state_dict(torch.load("weights/simsiam_epoch100.pt"))

simsiam.eval()
simsiam.to(device)

features = []
labels = []
for batch_idx, (image, target) in enumerate(tqdm.tqdm(test_dataloader, total=len(test_dataloader))):

    with torch.no_grad():

        proj, pred = simsiam(image.to(device))

    features.extend(np.squeeze(pred.detach().cpu().numpy()).tolist())
    labels.extend(target.detach().cpu().numpy().tolist())

import plotly.express as px
import umap.umap_ as umap

reducer = umap.UMAP(n_components=3, n_neighbors=10, metric="cosine")
projections = reducer.fit_transform(np.array(features))

px.scatter(projections, x=0, y=1,
    color=labels, labels={'color': 'Fashion MNIST Labels'}
)
The UMAP of the SimSiam representation over the testing set. Image by author.

It’s interesting to see two small islands in the reduced-dimension map above, covering classes 5, 7, 8, and some of 9. If we pull up the FashionMNIST class list, these correspond to "Sandal," "Sneaker," "Bag," and "Ankle boot," i.e., footwear plus bags. The big purple cluster corresponds to clothing classes such as "T-shirt/top," "Trouser," "Pullover," "Dress," "Coat," and "Shirt." This shows that SimSiam has learned a meaningful representation in the vision domain.
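If you want to check the label-to-name mapping yourself, torchvision exposes the class names on the dataset object (assuming the custom class subclasses torchvision’s FashionMNIST):

# Label index -> class name, as defined by torchvision's FashionMNIST
print(test_dataset.classes)
# ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
#  'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']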


Now that we have meaningful representations, how can they benefit our classification problem? We simply load the trained SimSiam backbone into the classification model. However, instead of fine-tuning the whole architecture on the limited training set, we freeze the backbone and fine-tune only the linear layers, because we don’t want to corrupt the representations the backbone has already learned.

import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import wandb

wandb_config = {
    "learning_rate": 0.001,
    "architecture": "supervised learning with simsiam backbone",
    "dataset": "FashionMNIST",
    "epochs": 100,
    "batch_size": 64,
    }
wandb.init(
    # set the wandb project where this run will be logged
    project="simsiam-finetune",
    # track hyperparameters and run metadata
    config=wandb_config,
)

# Initialize model and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

supervised = supervised_classification().to(device)

# Load the pre-trained SimSiam weights and keep only the keys that also exist in the
# classification model (i.e., the backbone); the MLP head stays randomly initialized.
simsiam_weights = torch.load("weights/simsiam_epoch100.pt")
model_dict = supervised.state_dict()
backbone_dict = {k: v for k, v in simsiam_weights.items() if k in model_dict}
supervised.load_state_dict(backbone_dict, strict=False)

finetune_dataloader = DataLoader(finetune_dataset, batch_size=32, shuffle=True)

for param in supervised.backbone.parameters():
    param.requires_grad = False
parameters = [para for para in supervised.parameters() if para.requires_grad]
optimizer = optim.SGD(parameters, 
                      lr=wandb_config["learning_rate"], 
                      momentum=0.9, 
                      weight_decay=1e-5,
                      )

# Training loop
for epoch in range(wandb_config["epochs"]):
    supervised.train()

    train_loss = 0
    for batch_idx, (image, target) in enumerate(tqdm.tqdm(finetune_dataloader)):
        optimizer.zero_grad()

        prediction = supervised(image.to(device))

        loss = nn.CrossEntropyLoss()(prediction, target.to(device))
        loss.backward()
        optimizer.step()

        wandb.log({"training loss": loss})

torch.save(supervised.state_dict(), "weights/supervised_with_simsiam.pt")

Here is the evaluation result of the classification model pre-trained with SimSiam. The average F1 score increases by 15% compared to the supervised-only method.

The classification scores of the SimSiam model fine-tuned on the limited set. Image by author.

Summary. We showcased a simple but intuitive example of contrastive learning using FashionMNIST. By using SimSiam to pre-train the backbone and fine-tuning only the linear layers on the limited training set (which contains only about 2% of the labels of the full training set), we increased the average F1 score by 15% over the fully supervised method. The trained weights, the notebook, and the customized FashionMNIST dataset class are all included in this GitHub repository.

Give it a try!


References:

  • Chen and He, Exploring Simple Siamese Representation Learning. CVPR 2021.
  • Chen et al., A Simple Framework for Contrastive Learning of Visual Representations. ICML 2020.
  • Chen et al., An Empirical Study of Training Self-Supervised Vision Transformers. ICCV 2021.
  • Xiao et al., Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. arXiv preprint 2017. GitHub: https://github.com/zalandoresearch/fashion-mnist
