Demystifying PyTorch’s WeightedRandomSampler by example

A straightforward approach to dealing with imbalanced datasets

Chris Hughes
Towards Data Science


Recently, I found myself in the familiar situation of working with a vastly imbalanced dataset, which was impacting the training of my CNN model on a computer vision task. Whilst there are various ways of approaching this, the findings of a study into handling class imbalance when training CNN models on different datasets concluded that, in almost all cases, the best strategy was oversampling the minority class(es); that is, increasing how often images from these classes are seen by the model during training.

However, whilst the idea seems simple enough, implementing this in PyTorch usually involves interacting with the somewhat enigmatic WeightedRandomSampler. The documentation for WeightedRandomSampler is quite sparse, both with regard to how it works and how to set the parameters to ensure that it behaves as we would expect. Despite having used this many times in the past, when returning to use it after a long absence, I have often found myself trawling through various forum and StackOverflow posts to ensure that I am setting it correctly. Whilst there is a great blog post which provides a mathematically rigorous breakdown of how it is implemented behind the scenes, as we humans are often notoriously poor at understanding probability theory, it can be difficult to gain an intuition from this alone.

In this article, I shall take a pragmatic approach to understanding the behaviour of WeightedRandomSampler, with the aim of answering the following questions:

  • How do I calculate the weights that will be used to balance my dataset?
  • As this approach is based on probability, can I be certain that this will balance the dataset in the way that I want?
  • Will all of my dataset be seen during training?
  • What if I don’t want to balance the dataset evenly, but achieve some other ratio?

We shall do this by interacting with the WeightedRandomSampler object in the context of dealing with a real-world dataset, before running a simple experiment to determine whether balancing the dataset results in any performance improvements for our simple problem.

Tl;dr: If you just want to see some working code that you can use directly, all of the code required to replicate this post is available as a GitHub gist here.

Creating an imbalanced dataset

First, let’s download some data to use as an example. Here, I am using the Oxford Pets dataset, which contains 37 different categories of cats and dogs. On Linux, we can download this with the following commands:

wget https://thor.robots.ox.ac.uk/~vgg/data/pets/images.tar.gz -P data/pets
ls data/pets
tar -xzf data/pets/images.tar.gz -C data/pets
rm data/pets/images.tar.gz

Now that we have some data, I often find that a good way of exploring this is by creating a pandas DataFrame. We can use pathlib from the standard library to quickly get a list of all image paths and, as the class names are contained in each file name, we can extract these as part of the same step.
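A minimal sketch of this step is shown below. It assumes that the archive was extracted to data/pets/images and that each file is named as the class name followed by an underscore and a number (for example, Siamese_12.jpg); the helper names are illustrative rather than taken from the original gist.

```python
from pathlib import Path


def class_name_from_path(path):
    """Return the class name encoded in a file name such as 'Siamese_12.jpg'."""
    # Everything before the final underscore is the class name;
    # class names themselves may contain underscores.
    return Path(path).stem.rsplit("_", 1)[0]


def list_image_paths(images_dir):
    """Return all .jpg files in images_dir, or an empty list if it is missing."""
    images_dir = Path(images_dir)
    if not images_dir.is_dir():
        return []
    return sorted(images_dir.glob("*.jpg"))


# Assumed extraction location; adjust to wherever the images were unpacked.
image_paths = list_image_paths("data/pets/images")
class_names = [class_name_from_path(p) for p in image_paths]
```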

We can now use these to easily create a DataFrame, which we can use to quickly inspect the distribution of our labels:
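As a rough illustration, the sketch below uses a handful of made-up file names in place of the real image paths; the column names image_path and label are assumptions.

```python
import pandas as pd

# Toy stand-in for the (path, class name) pairs extracted from the file names.
records = [
    ("Siamese_1.jpg", "Siamese"),
    ("Siamese_2.jpg", "Siamese"),
    ("Siamese_3.jpg", "Siamese"),
    ("Birman_1.jpg", "Birman"),
    ("Birman_2.jpg", "Birman"),
    ("pug_1.jpg", "pug"),
]
df = pd.DataFrame(records, columns=["image_path", "label"])

# Number of images per class, most frequent first.
label_counts = df["label"].value_counts()
print(label_counts)
```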

From this, we can see that the dataset is well balanced, with almost all classes having around 200 images. Whilst this would usually be a good thing, to make the dataset suitable for our purposes, we can define a function to extract an imbalanced subset of this:
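One way such a function might look is sketched below. The function name, its arguments and the 10% minority fraction are hypothetical choices for illustration, and the DataFrame is toy data rather than the real dataset.

```python
import pandas as pd


def create_imbalanced_subset(df, majority_label, minority_label, minority_frac=0.1, seed=42):
    """Keep every majority image and a random fraction of the minority images."""
    majority = df[df["label"] == majority_label]
    minority = df[df["label"] == minority_label].sample(frac=minority_frac, random_state=seed)
    return pd.concat([majority, minority]).reset_index(drop=True)


# Toy balanced data: 200 images for each of two classes.
toy_df = pd.DataFrame({
    "image_path": [f"Siamese_{i}.jpg" for i in range(200)]
    + [f"Birman_{i}.jpg" for i in range(200)],
    "label": ["Siamese"] * 200 + ["Birman"] * 200,
})

imbalanced_df = create_imbalanced_subset(toy_df, "Siamese", "Birman", minority_frac=0.1)
print(imbalanced_df["label"].value_counts())
```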

Selecting two of the categories, we can use this to create our imbalanced dataset, as demonstrated below:

Let’s encode our class labels with integer values and create a lookup to use in our datasets.
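A small sketch of such an encoding, using a toy list of labels; the variable names are illustrative.

```python
# Toy list of class names, one entry per image.
labels = ["Siamese", "Birman", "Siamese", "Siamese"]

# Map each class name to an integer, in sorted order so the mapping is stable.
label_lookup = {name: idx for idx, name in enumerate(sorted(set(labels)))}
# Reverse mapping, useful for turning predictions back into class names.
class_lookup = {idx: name for name, idx in label_lookup.items()}

encoded_labels = [label_lookup[name] for name in labels]
```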

Here, I chose the Siamese and Birman breeds of cat, as they have a passing similarity based on the preview images on the dataset website. We can confirm this by randomly inspecting some of the images.

Based on this, these classes seem to be a suitable choice for a simple, but non-trivial task.

Visualising batch distributions

Now that we have defined our dataset, let’s explore the effects of the class imbalance by inspecting the distribution of each batch that the model will see during a single training epoch.

As we do not need to load the images at this point, let’s create a tensor dataset from the labels and the index of each image so that we can iterate through this quickly.
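A minimal sketch of this idea, using toy labels in place of the real ones:

```python
import torch
from torch.utils.data import TensorDataset

# Toy encoded labels: 10 majority-class (1) and 2 minority-class (0) images.
labels = torch.tensor([1] * 10 + [0] * 2)
# The position of each image in the dataset, used to track which images are seen.
indices = torch.arange(len(labels))

# Each item is an (index, label) pair; no image data is loaded.
label_dataset = TensorDataset(indices, labels)
```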

We can define a function which does this, by keeping track of the classes and indexes seen during each batch and plotting these, as seen below:
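The sketch below covers only the bookkeeping part of such a function and leaves the plotting out. The function name and the toy data are assumptions for illustration.

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, TensorDataset


def summarise_batches(dataloader):
    """Return per-batch class counts and the set of dataset indices seen."""
    batch_class_counts = []
    seen_indices = set()
    for indices, labels in dataloader:
        batch_class_counts.append(Counter(labels.tolist()))
        seen_indices.update(indices.tolist())
    return batch_class_counts, seen_indices


# Toy data: 90 majority-class (1) and 10 minority-class (0) images.
labels = torch.tensor([1] * 90 + [0] * 10)
dataset = TensorDataset(torch.arange(len(labels)), labels)
loader = DataLoader(dataset, batch_size=10, shuffle=True)

batch_class_counts, seen_indices = summarise_batches(loader)
for batch_number, counts in enumerate(batch_class_counts):
    print(batch_number, dict(counts))
print(f"{len(seen_indices)} of {len(dataset)} images seen")
```

With plain shuffling, every index appears exactly once per epoch, so the set of seen indices always covers the whole dataset.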

We can now use this to explore how our data looks.

From this, we can clearly observe the effects of the imbalance, as some batches don’t contain any images from our minority class at all! Additionally, we can see that every image in our dataset would be seen during training. When we take the average representation across all batches, we observe the same proportions as in our dataset, which is as we would expect.

Balancing our dataset with WeightedRandomSampler

Now, let’s look at how we can balance our dataset using WeightedRandomSampler.

The first thing that we need to do is to calculate the weights that will be used to sample each image; from the docs, we can see that we need a weight for each image in the dataset. In my opinion, the most confusing part about this is that these weights do not have to sum to 1. In reality, these weights represent the relative probability that an image will be selected; PyTorch simply normalises them so that they sum to 1 behind the scenes for convenience.
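To check that only the relative size of the weights matters, a small experiment along the following lines can help. The helper name and the weights are made up for illustration; the sampler is seeded purely for reproducibility.

```python
import torch
from torch.utils.data import WeightedRandomSampler


def empirical_frequencies(weights, num_samples=100_000, seed=0):
    """Draw indices with WeightedRandomSampler and return how often each was picked."""
    generator = torch.Generator().manual_seed(seed)
    sampler = WeightedRandomSampler(
        weights, num_samples=num_samples, replacement=True, generator=generator
    )
    drawn = torch.tensor(list(sampler))
    return torch.bincount(drawn, minlength=len(weights)) / num_samples


# The same relative weights at two different scales.
freqs_small = empirical_frequencies([1.0, 1.0, 2.0])
freqs_large = empirical_frequencies([50.0, 50.0, 100.0])
print(freqs_small)
print(freqs_large)
```

Both sets of weights should give frequencies close to 0.25, 0.25 and 0.5, as the third item has twice the weight of each of the others in both cases.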

Now that we understand what we need, let’s look at how we can calculate these weights. First, we need to calculate how many images belong to each of our classes; using pandas, we can do this as demonstrated below:

Now that we have our class counts, we can calculate the weight for each class by taking the reciprocal of the count. This will ensure that classes with a higher representation will have a smaller weight.

Now, we simply need to assign the appropriate weight to each sample based on its class. In practice, we can do this directly from the class counts, as demonstrated below:
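Putting the counting, the reciprocal and the per-sample lookup together, a sketch might look like the following. The DataFrame is toy data and the label column name is an assumption.

```python
import pandas as pd

# Toy stand-in for the imbalanced dataset: 200 Siamese and 20 Birman images.
df = pd.DataFrame({"label": ["Siamese"] * 200 + ["Birman"] * 20})

# Number of images in each class.
class_counts = df["label"].value_counts()

# Reciprocal of the count: the rarer the class, the larger its weight.
class_weights = 1.0 / class_counts

# One weight per image, looked up from the image's class.
sample_weights = df["label"].map(class_weights).to_numpy()
```

With these weights, the total weight of each class is 1 (200 × 1/200 and 20 × 1/20), so each draw is equally likely to come from either class.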

Next, we can create our sampler and DataLoader:
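A minimal sketch of this step, using a toy dataset of indices and labels in place of the real images; the batch size and the seeded generator are illustrative choices.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy data: 200 majority-class (1) and 20 minority-class (0) images.
labels = torch.tensor([1] * 200 + [0] * 20)
dataset = TensorDataset(torch.arange(len(labels)), labels)

# One weight per image: the reciprocal of the size of the image's class.
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].double()

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(dataset),  # draw as many samples per epoch as the dataset holds
    replacement=True,          # allow the same image to be drawn more than once
    generator=torch.Generator().manual_seed(0),  # seeded only for reproducibility
)

# shuffle must be left unset: DataLoader does not allow shuffle=True with a sampler.
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

drawn_labels = torch.cat([batch_labels for _, batch_labels in loader])
majority_fraction = drawn_labels.float().mean().item()
print(f"Fraction of draws from the majority class: {majority_fraction:.2f}")
```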

Here, we can see that we have provided our calculated sample weights as an argument and set replacement as True; without this, we would not be able to oversample at all. For now, we have just set the number of samples as the length of our dataset, but we will discuss this more later.

Once again, we can visualise the distribution of our DataLoader batches, this time using WeightedRandomSampler.

From this, we can see that our batches are quite well balanced! Looking at the number of images seen, we can see that the sampler has done this by oversampling the minority class and undersampling the majority class.

To understand which images are being selected in more detail, we can create a DataFrame containing the number of times that each image was seen during this epoch.
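As the sampler simply yields dataset indices, one way to build such a DataFrame is to count how often each index is drawn. The sketch below uses toy labels and assumed column names.

```python
import pandas as pd
import torch
from torch.utils.data import WeightedRandomSampler

# Toy data: 200 majority-class (1) and 20 minority-class (0) images.
labels = torch.tensor([1] * 200 + [0] * 20)
sample_weights = 1.0 / torch.bincount(labels)[labels].double()
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)

# The sampler yields dataset indices, so one pass over it is one epoch of draws.
drawn_indices = torch.tensor(list(sampler))
times_seen = torch.bincount(drawn_indices, minlength=len(labels))

seen_df = pd.DataFrame({
    "image_index": range(len(labels)),
    "label": labels.tolist(),
    "times_seen": times_seen.tolist(),
})
print(seen_df.groupby("label")["times_seen"].describe())
```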

To make this data easier to interpret, we can represent this as an Empirical cumulative distribution function plot, using the snippet below:
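For reference, the values behind an ECDF plot can be computed by hand, as in the sketch below; the plotting itself is left out, and the counts are toy values rather than results from the experiment.

```python
import numpy as np


def ecdf(values):
    """Return each distinct value and the proportion of observations <= that value."""
    x, counts = np.unique(np.asarray(values), return_counts=True)
    y = np.cumsum(counts) / counts.sum()
    return x, y


# Toy "times seen" counts for six images.
x, y = ecdf([0, 0, 1, 2, 2, 5])
for value, proportion in zip(x, y):
    print(f"{proportion:.2f} of images were seen {value} times or fewer")
```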

From this, we can see that to achieve our desired proportion, each image in our minority class was seen at least 5 times, with some as many as 13 times! In contrast, many of the images from our majority class were not seen at all!

You may note that it seems odd that some images from our majority class were seen multiple times whilst others were not seen at all, which does not seem ideal. Unfortunately, this is a trade-off of a probability-based sampling method. Whilst we could set our WeightedRandomSampler to sample without replacement, this would also prevent us from oversampling, so it is not useful to us here!

The next logical question to ask is whether we can ensure that every image is seen during a training run.

When will all of my images be seen during training?

Adjusting the number of samples per epoch

As seen above, 97 of the images from our majority class would not have been seen during the training epoch. The reason for this is the num_samples argument that we defined when we created our WeightedRandomSampler instance. As we specified the number of samples to be equal to the total number of images in our original, unbalanced, dataset, it makes sense that our sampler will have to ignore some images in order to oversample our minority class.

Adjusting this parameter to double the size of our original dataset, we can see that more of our images are seen over the course of an epoch.
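The effect of num_samples can be sketched on toy data by counting how many images are never drawn in a single epoch. The dataset sizes and the helper name here are assumptions, so the numbers will differ from those in the article.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy data: 1000 majority-class (1) and 100 minority-class (0) images.
labels = torch.tensor([1] * 1000 + [0] * 100)
sample_weights = 1.0 / torch.bincount(labels)[labels].double()


def count_unseen(num_samples, seed=0):
    """Number of images never drawn in one epoch of num_samples draws."""
    sampler = WeightedRandomSampler(
        sample_weights,
        num_samples=num_samples,
        replacement=True,
        generator=torch.Generator().manual_seed(seed),  # seeded for reproducibility
    )
    seen = set(sampler)  # iterating the sampler yields dataset indices
    return len(labels) - len(seen)


unseen_default = count_unseen(len(labels))
unseen_doubled = count_unseen(2 * len(labels))
print(f"Unseen with num_samples = dataset size: {unseen_default}")
print(f"Unseen with num_samples = 2 x dataset size: {unseen_doubled}")
```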

However, this does introduce some confusion into what exactly an epoch represents. In Machine Learning, we define an epoch as a single pass through the entire dataset. When each sample is seen only once, this definition makes it very clear to understand exactly what the model has seen up to the current point in training; this becomes less clear when images are being sampled with varying frequencies.

As the notion of an epoch is largely to help us track progress during a training run and has no bearing on the model itself — which just sees a constant stream of images — I prefer to leave the num_samples set to the length of the dataset and trust that all images will be seen at some point as we train for more epochs.

How many epochs will it take before the model has seen all images?

Resetting our sampler back to its original parameters:

let’s explore how long it takes for all of our dataset to be seen during a training run.

To do this, we can set up a simple experiment, where we keep track of all images seen over the course of multiple epochs and plot the number of unique images that have been seen so far by the end of each epoch. We can do this using the following snippet:
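A sketch of the tracking part of this experiment is given below, with the plotting omitted. It uses toy labels, and the function name and the cap on the number of epochs are illustrative.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy data: 200 majority-class (1) and 20 minority-class (0) images.
labels = torch.tensor([1] * 200 + [0] * 20)
sample_weights = 1.0 / torch.bincount(labels)[labels].double()
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)


def unique_seen_per_epoch(sampler, dataset_size, max_epochs=100):
    """Cumulative number of distinct indices seen at the end of each epoch."""
    seen = set()
    history = []
    for _ in range(max_epochs):
        seen.update(sampler)  # iterating the sampler draws one epoch of indices
        history.append(len(seen))
        if len(seen) == dataset_size:
            break
    return history


history = unique_seen_per_epoch(sampler, len(labels))
for epoch, num_seen in enumerate(history, start=1):
    print(f"Epoch {epoch}: {num_seen} of {len(labels)} images seen so far")
```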

Here, we can see that it took 10 epochs until all images were seen at least once. However, as this is dependent on probability, there is a high likelihood that this number will change!

To get a more robust estimate, let’s take inspiration from Monte Carlo methods and run this experiment multiple times and observe the distribution. To make this easier to interpret, we can represent this as a kernel density estimate plot.
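The repeated trials might be set up along the following lines. This sketch uses toy labels, summarises the results with simple statistics instead of a density plot, and the number of trials is an arbitrary choice.

```python
import statistics

import torch
from torch.utils.data import WeightedRandomSampler

# Toy data: 200 majority-class (1) and 20 minority-class (0) images.
labels = torch.tensor([1] * 200 + [0] * 20)
sample_weights = 1.0 / torch.bincount(labels)[labels].double()

# A single seeded generator is shared across trials, so each trial draws
# different samples whilst the whole experiment remains reproducible.
sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=len(labels),
    replacement=True,
    generator=torch.Generator().manual_seed(0),
)


def epochs_until_all_seen(sampler, dataset_size, max_epochs=1000):
    """Epochs needed to see every index at least once (capped at max_epochs)."""
    seen = set()
    for epoch in range(1, max_epochs + 1):
        seen.update(sampler)
        if len(seen) == dataset_size:
            return epoch
    return max_epochs


trial_results = [epochs_until_all_seen(sampler, len(labels)) for _ in range(100)]
print("median:", statistics.median(trial_results))
print("min:", min(trial_results), "max:", max(trial_results))
```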

From this, we can see that, most of the time, around 9–10 epochs are necessary to be confident that all of the data will be seen.

Of course, this estimate highly depends on the proportion of imbalance in the underlying dataset. It seems logical to suggest that as the ratio of imbalance increases, more epochs will be needed. We can confirm this by repeating our trial for different levels of imbalance, as demonstrated below:

From the plot, we can observe that our intuition was correct. However, it is interesting to note that when using WeightedRandomSampler on a balanced dataset, it takes around 5 epochs to see all of the data; which suggests that this is not ideal in this case!

Obtaining non-balanced dataset proportions

Now that we have explored how we can use WeightedRandomSampler to balance our training set, let’s briefly examine how we can adjust our class weights to achieve any proportion that we would like.

One example of where this may be useful is in object detection, where we would like most of our training to be focused on images containing the item that we wish to detect, but the available datasets often contain a high number of background images.

As an example, let's investigate how we can heavily imbalance the dataset the other way; such that our minority class becomes dominant.

Once again, first we need to calculate the total number of samples for each class:

Next, we can define the target proportion for each class:

Finally, to calculate our sample weights, we simply have to multiply each class weight (the reciprocal of the class count) by the corresponding target proportion, as demonstrated below:
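A sketch of this calculation on toy labels is given below; the 75%/25% target split is an illustrative choice, not necessarily the ratio used in the original experiment.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy data: 200 majority-class (1) and 20 minority-class (0) images.
labels = torch.tensor([1] * 200 + [0] * 20)
class_counts = torch.bincount(labels).double()

# Desired share of each class in the sampled data, indexed by class id.
# Here the minority class (0) should make up 75% of what the model sees.
target_proportions = torch.tensor([0.75, 0.25], dtype=torch.double)

# Reciprocal-count weight scaled by the target share of the class.
class_weights = target_proportions / class_counts
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(
    sample_weights,
    num_samples=10_000,
    replacement=True,
    generator=torch.Generator().manual_seed(0),  # seeded only for reproducibility
)
drawn_labels = labels[torch.tensor(list(sampler))]
minority_fraction = (drawn_labels == 0).double().mean().item()
print(f"Fraction of draws from the minority class: {minority_fraction:.3f}")
```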

As before, let’s pass these weights into our sampler and visualise our batches.

From this, we can see that we have obtained the distribution that we are looking for.

Does oversampling improve performance?

Hopefully, at this point, we have developed an intuition behind how WeightedRandomSampler works. However, you may be thinking, does simply showing the same image to the network more frequently really make a difference? Let’s set up a small experiment to investigate.

Let’s train an image classifier on our imbalanced dataset. Of course, the results will heavily depend on many factors — such as the model and dataset used — but this is designed as a simple example. Here, I have selected the following based on training recipes that have consistently worked well for me in the past:

  • Model: ResNet-RS50
  • Optimizer: AdamW
  • LR scheduler: Cosine decay with warmup
  • Images resized to 224

As our dataset is quite small, to simplify things further, let’s only train the final linear layer used for classification in our architecture; as our images are very similar to the ImageNet images that the model has been pretrained on, it should be safe to assume that the features learned in the backbone should work well enough here.
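The general pattern for training only the final layer is sketched below. A tiny stand-in network is used here instead of a pretrained ResNet-RS50, and the learning rate is an arbitrary value; only the freezing logic is the point.

```python
import torch
from torch import nn

# Small stand-in for a pretrained network: a "backbone" followed by a linear head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),  # classification head for our two classes
)

# Freeze everything, then re-enable gradients for the final linear layer only.
for parameter in model.parameters():
    parameter.requires_grad = False
for parameter in model[-1].parameters():
    parameter.requires_grad = True

# Only the trainable parameters are handed to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
```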

To evaluate, we can use the validation set we created earlier, which contains a balanced sample of images that were not seen during training.

All training was carried out using a single NVIDIA V100 GPU. To handle the training loop, I used the PyTorch-accelerated library. However, as PyTorch-accelerated handles all distributed training concerns, the same code could be used on multiple GPUs — without having to change WeightedRandomSampler to a distributed sampler — simply by defining a configuration file, as described here.

We can define a script to conduct this experiment, as seen below:

Packages used:

Selecting the best metrics after training for 10 epochs, by running the commands:

I obtained the following results:

Oh no, this looks like oversampling actually made things worse!

However, with the way that our experiment is set up, this is not hugely surprising, as we are restricting the model’s exposure to new images from the majority class in favour of repeatedly showing a small number of images from the minority class. If only there were a way to get more out of the small set of Birman images…

Adding Data Augmentation

To try and help the model learn more from our images, we can use data augmentation to generate slightly modified versions of each image during training.

In this case, I decided to use the predefined RandAugment policy from timm, as this requires minimal hyperparameter tuning; timm’s RandAugment implementation is described in detail here.

Applying RandAugment to an image from our training dataset

Let’s run the experiment again, this time with data augmentation:

The results are presented below:

This time, we can see that the combination of data augmentation and oversampling resulted in a significant increase in performance!

Conclusion

Hopefully that has provided a somewhat comprehensive overview of how to get started with WeightedRandomSampler, and helped to illustrate how it can be used.

All of the code required to replicate this post is available as a GitHub gist.

Chris Hughes is on LinkedIn.

References

Dataset Used

Principal Machine Learning Engineer/Scientist Manager at Microsoft. All opinions are my own.