Implementing math in deep learning papers into efficient PyTorch code: SimCLR Contrastive Loss

Learning to turn the advanced math formulas in deep learning papers into performant PyTorch code in 3 steps.

Moein Shariatnia
Towards Data Science

--

Introduction

One of the best ways to deepen your understanding of the math behind deep learning models and loss functions, and also a great way to improve your PyTorch skills, is to get used to implementing deep learning papers all by yourself.

Books and blog posts can help you get started with coding and with the basics of ML / DL, but after studying a few of them and getting good at the routine tasks in the field, you will soon realize that you are on your own in the learning journey, and that most of the resources online feel boring and too shallow. However, I believe that if you can study new deep learning papers as they get published, understand the required pieces of math in them (not necessarily all the mathematical proofs behind the authors’ theories), and implement them as efficient code, nothing can stop you from staying up to date in the field and learning new ideas.

Contrastive Loss implementation

I’ll introduce my routine and the steps I follow to implement the math in deep learning papers, using a non-trivial example: the Contrastive Loss in the SimCLR paper.

Here’s the mathematical formulation of the loss:

Contrastive (NT-Xent) loss from the SimCLR paper | from https://arxiv.org/pdf/2002.05709.pdf
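Written out, the loss for a positive pair of examples (i, j) in a batch of 2N augmented views is:

```latex
% NT-Xent loss for a positive pair (i, j), as defined in the SimCLR paper
\ell_{i,j} = -\log
  \frac{\exp\!\big(\mathrm{sim}(z_i, z_j)/\tau\big)}
       {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\!\big(\mathrm{sim}(z_i, z_k)/\tau\big)},
\qquad
\mathrm{sim}(u, v) = \frac{u^{\top} v}{\lVert u \rVert \, \lVert v \rVert}
```

Here z_i and z_j are the embeddings of the two augmented views of the same image, τ is the temperature hyperparameter, and the final loss is the average of this quantity over all positive pairs in the batch.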

I agree that the mere look of the formula can be daunting! And you might be thinking that there must be lots of ready-made PyTorch implementations on GitHub, so let’s just use them :) And yes, you’re right: there are dozens of implementations online. However, I think this is a good example for practicing this skill, and it can serve as a good starting point.

Steps to implement math in code

My routine in implementing the math in papers into efficient PyTorch code is as follows:

  1. Understand the math, explain it in simple terms
  2. Implement an initial version using simple Python “for” loops, no fancy matrix multiplications for now
  3. Convert your code into efficient matrix-friendly PyTorch code

OK, let’s get straight to the first step.

Step 1: Understanding the math and explaining it in simple terms

I’m assuming that you have a basic knowledge of linear algebra and are familiar with mathematical notation. If you’re not, you can use this tool to find out what each of these symbols is and what it does, simply by drawing the symbol. You can also check this awesome Wikipedia page where most of the notation is described. These are the opportunities where you learn new stuff: searching and reading what is needed at the time you need it. I believe it’s a more efficient way of learning than starting a math textbook from scratch and putting it away after a few days :)

Back to our business. As the paragraph above the formula in the paper explains, in the SimCLR learning strategy you start with N images and transform each of them twice to get augmented views of those images (2*N images in total). Then, you pass these 2*N images through a model to get an embedding vector for each of them. Now, you want to make the embedding vectors of the two augmented views of the same image (a positive pair) closer in the embedding space (and do the same for all the other positive pairs). One way to measure how similar (close, pointing in the same direction) two vectors are is Cosine Similarity, defined as sim(u, v) (see the definition above).

In simple terms, here is what the formula describes. For each item in our batch, which is the embedding of one of the augmented views of an image (remember: the batch contains all the embeddings of the augmented views of different images → if we start with N images, the batch has a size of 2*N), we first find the embedding of the other augmented view of that image to form a positive pair. Then, we calculate the cosine similarity of these two embeddings, divide it by the temperature τ, and exponentiate it (the numerator of the formula). Next, we exponentiate the temperature-scaled cosine similarities of all the other pairs we can build with the embedding vector we started with (except the pair with itself; that is what the indicator 1[k≠i] in the formula means), and we sum them up to build the denominator. We can now divide the numerator by the denominator, take the natural log, and flip the sign! That gives us the loss for the first item in our batch. We just need to repeat the same process for all the other items in the batch and then take the average, so that we can call PyTorch’s .backward() method to compute the gradients.

Step 2: Implementing it using simple Python code, with naive “for” loops!

Simple Pythonic implementation, using slow “for” loops
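Here is a minimal sketch of such a loop-based implementation. The random toy embeddings and the temperature value of 0.07 are illustrative assumptions; the variable names (aug_views_1, aug_views_2, projections, pos_pairs) follow the walkthrough below.

```python
import torch
import torch.nn.functional as F

# Toy embeddings (size 3) of the two augmented views of images A and B
aug_views_1 = torch.randn(2, 3)  # A1, B1
aug_views_2 = torch.randn(2, 3)  # A2, B2

# Concatenate into one matrix of 4 vectors: A1, B1, A2, B2
projections = torch.cat([aug_views_1, aug_views_2], dim=0)

# Which index is the positive pair of which: A1 <-> A2, B1 <-> B2
pos_pairs = {0: 2, 1: 3, 2: 0, 3: 1}

temperature = 0.07  # illustrative value; a hyperparameter in the paper

# Normalize to unit length so a dot product equals the cosine similarity
projections = F.normalize(projections, dim=1)

losses = []
for i in range(len(projections)):
    j = pos_pairs[i]
    # Numerator: exponentiated (temperature-scaled) similarity of the positive pair
    numerator = torch.exp(projections[i] @ projections[j] / temperature)

    # Denominator: sum over every other vector k != i (the 1[k != i] indicator)
    denominator = 0.0
    for k in range(len(projections)):
        if k != i:
            denominator += torch.exp(projections[i] @ projections[k] / temperature)

    losses.append(-torch.log(numerator / denominator))

loss = torch.stack(losses).mean()
print(loss)
```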

Let’s go over the code. Let’s say we have two images: A and B. The variable aug_views_1 holds the embeddings (each of size 3) of one augmented view of these two images (A1 and B1), and aug_views_2 does the same for the other views (A2 and B2); so, the first item of both matrices relates to image A and the second item of both relates to image B. We concatenate the two matrices into the projections matrix (which holds 4 vectors: A1, B1, A2, B2).

To keep track of which vectors in the projections matrix are related, we define a pos_pairs dictionary that stores which two items of the concatenated matrix form a positive pair. (I’ll explain the F.normalize() thing soon!)

As you see in the next lines of code, I go over the items of the projections matrix in a for loop, find the related vector using our dictionary, and then calculate the cosine similarity. You might wonder why we do not divide by the magnitudes of the vectors, as the cosine similarity formula suggests. The point is that before starting the loop, I normalize all the vectors in our projections matrix with the F.normalize function so that they have a length of 1. So, there’s no need to divide by the magnitudes in the line where we calculate the cosine similarity.
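As a quick sanity check (a tiny standalone snippet, not part of the loss code itself), the dot product of two F.normalize-d vectors is exactly their cosine similarity:

```python
import torch
import torch.nn.functional as F

u, v = torch.randn(3), torch.randn(3)
cos_manual = (u @ v) / (u.norm() * v.norm())                       # textbook definition
cos_via_normalize = F.normalize(u, dim=0) @ F.normalize(v, dim=0)  # dot product of unit vectors
print(torch.allclose(cos_manual, cos_via_normalize))               # True
```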

After building our numerator, I find the indexes of all the other vectors in the batch (every index except i itself) to calculate the cosine similarities that make up the denominator. Finally, I calculate the loss by dividing the numerator by the denominator, applying the log function, and flipping the sign. Make sure to play with the code to understand what happens in each line.

Step 3: Converting it into efficient matrix-friendly PyTorch code

The problem with the previous Python implementation is that it’s too slow to be used in our training pipeline; we need to get rid of the slow “for” loops and convert the code into matrix multiplications and array manipulations in order to leverage the parallelization power of modern hardware.

PyTorch implementation
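Below is a sketch of what the vectorized version can look like, under the same assumptions as before (toy embeddings and an illustrative temperature of 0.07); the labels_1, labels_2, sim_matrix, and mask variables match the description that follows.

```python
import torch
import torch.nn.functional as F

aug_views_1 = torch.randn(2, 3)  # embeddings of A1, B1
aug_views_2 = torch.randn(2, 3)  # embeddings of A2, B2

# Arbitrary labels that only mark which views come from the same original image
labels_1 = torch.tensor([0, 1])
labels_2 = torch.tensor([0, 1])

projections = F.normalize(torch.cat([aug_views_1, aug_views_2], dim=0), dim=1)  # (4, 3)
labels = torch.cat([labels_1, labels_2], dim=0)                                 # (4,)

temperature = 0.07  # illustrative value

# Cosine similarity of all possible pairs (unit-length vectors, so a matrix product is enough)
sim_matrix = projections @ projections.T / temperature  # (4, 4)
exp_sim = torch.exp(sim_matrix)

# mask is True exactly at the "green cells": same label, but not the same index
mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~torch.eye(len(labels), dtype=torch.bool)

numerator = (exp_sim * mask).sum(dim=1)            # one positive pair per row
denominator = exp_sim.sum(dim=1) - exp_sim.diag()  # drop the diagonal (orange cells)

loss = (-torch.log(numerator / denominator)).mean()
print(loss)
```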

Let’s see what’s happening in this code snippet. This time, I’ve introduced the labels_1 and labels_2 tensors to encode the arbitrary classes these images belong to, because we need a way to encode the relationship between the A1, A2 and B1, B2 images. It does not matter whether you choose the labels 0 and 1 (as I did) or, say, 5 and 8.

After concatenating both the embeddings and the labels, we start by creating a sim_matrix containing the cosine similarities of all possible pairs.

What the sim_matrix looks like: the green cells contain our positive pairs; the orange cells are the pairs that need to be ignored in the denominator | Visualization by the author

The visualization above is all you need :) to understand how the code works and why we’re taking these steps. Considering the first row of the sim_matrix, we can calculate the loss for the first item in the batch (A1) as follows: we need to divide the exponentiated A1A2 by the sum of the exponentiated A1B1, A1A2, and A1B2, and keep the result in the first item of a tensor storing all the losses. So, we first need a mask to find the green cells in the visualization above; the two lines in the code defining the variable mask do exactly this. The numerator is calculated by multiplying our sim_matrix by the mask we just created and then summing the items of each row (after masking, there is only one non-zero item in each row, i.e. the green cell). To calculate the denominator, we need to sum over each row while ignoring the orange cells on the diagonal. To do that, we use the .diag() method of PyTorch tensors. The rest is self-explanatory!
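Continuing the hypothetical two-image example above (labels [0, 1, 0, 1]), the mask picks out exactly one green cell per row, and the diagonal is what exp_sim.diag() removes from the denominator:

```python
print(mask.int())
# tensor([[0, 0, 1, 0],   # A1's positive pair is A2
#         [0, 0, 0, 1],   # B1's positive pair is B2
#         [1, 0, 0, 0],   # A2's positive pair is A1
#         [0, 1, 0, 0]])  # B2's positive pair is B1
```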

Bonus: Using AI assistants (ChatGPT, Copilot, …) to implement the formula

We have great tools at our disposal to help us understand and implement the math in deep learning papers. For example, you can ask ChatGPT (or other similar tools) to implement the code in PyTorch after giving it the formula from the paper. In my experience, ChatGPT is most helpful and gives the best final answers with less trial and error if you can get yourself to the Pythonic for-loop implementation step on your own. Give that naive implementation to ChatGPT and ask it to convert it into efficient PyTorch code that uses only matrix multiplications and tensor manipulations; you’ll be surprised by the answer :)

Further Reading

I encourage you to check out the following two great implementations of the same idea to learn how you can extend this implementation to more nuanced situations, such as the supervised contrastive learning setting.

  1. Supervised Contrastive Loss, by Guillaume Erhard
  2. SupContrast, by Yonglong Tian

About Me

I’m Moein Shariatnia, a machine learning developer and medical student, focusing on using deep learning solutions for medical imaging applications. My research is mostly on investigating the generalizability of deep models under various circumstances. Feel free to reach out to me, via Email, Twitter, or LinkedIn.
