
Class Imbalance

In this post, we discuss the problem of class imbalance and techniques that might help improve model performance.


INTRODUCTION

When I started learning about machine learning and its various subfields, I took various MOOCs and read many articles and books. I used freely available datasets and built models which gave very good performance. My mindset at that time was that everything relied on the model we use and that data played a minor role. This mindset arose mostly because the data was available in a clean, balanced form, with only minor pre-processing required. This belief suffered a major setback when I joined my first machine learning role at a research institute. The institute specialized in medical research, and my work focused on ultrasound images. The data available was not clean and structured like the publicly available datasets. Previously I spent most of my time selecting the model; now most of my time was spent cleaning, structuring and sorting the data. Moreover, in publicly available datasets the distribution of classes is almost uniform, and even when it is not, each class has an adequate number of representative samples. But in the real world, especially in the medical domain, there is an inherent class imbalance because people with a particular disease or condition are far fewer than healthy people.

Class imbalance, if you don't know, is the problem where the number of examples available for one or more classes in a classification problem is far smaller than for the other classes. The classes with a large number of samples are called the majority classes, while the classes with very few samples are called the minority classes. Before joining the role I had read about the class imbalance problem but never got the opportunity to work on imbalanced datasets, as most of the datasets available online are perfectly balanced, with all classes having almost identical numbers of samples. However, one of the first problems I started working on had a very unbalanced distribution. It was a binary classification problem where I was supposed to detect whether a given woman would deliver term or pre-term. Pre-term means a woman delivers at less than 37 weeks of gestation, while term means she delivers at 37 weeks or later. Babies born pre-term suffer from various health complications, as their organs are not fully developed, and they require special care for their survival.

According to WHO, every year 15 million babies are born premature and around 1 million babies die due to complications resulting from premature birth. Hence, it is essential to detect pre-term pregnancies so that adequate care can be provided to the baby and his or her life can be saved.

The data available was highly imbalanced: just 300 participants delivered pre-term, while around 3,000 delivered at term.

The problem was challenging, and I started researching the topic of class imbalance and techniques to mitigate it. I discovered that most of the material on class imbalance available online mentions just two techniques: over-sampling and under-sampling. After digging deeper and reading some research papers, I discovered some other lesser-known techniques, which I would like to discuss in this article along with the well-known ones.

Note: The techniques discussed in this post are from a Computer Vision point of view but can be translated to other domains since the core concepts remain the same.

Techniques for handling Class-Imbalance Problem

The techniques for handling Class-Imbalance can be further divided into two broad categories:

  1. Data Level Methods
  2. Algorithm Level Methods

Data Level Methods:

Data level methods are those where we change the distribution of the training set while keeping the algorithm and its components, such as the loss function and optimizer, constant. Data level methods aim to alter the dataset in a way that makes standard algorithms work well.

There are two well-known data-level methods commonly applied in the machine learning domain.

1. Oversampling:

Oversampling, or to be more precise minority-class over-sampling, is a very simple and widely known technique used to solve the problem of class imbalance. In this technique, we try to make the distribution of all the classes in a mini-batch equal by sampling an equal number of examples from each class, which means sampling minority-class examples more often than their share of the dataset would suggest. Practically, this is done by increasing the sampling probability of examples belonging to the minority classes and correspondingly down-weighting the sampling probability of examples belonging to the majority classes. The easiest way to decide the sampling probability for each class is to take the inverse of the number of samples present in that class. This increases the probability of samples belonging to minority classes and decreases the probability of samples belonging to majority classes. Some papers instead use the inverse of the square root of the number of samples to weigh the samples.

The performance of oversampling depends on the number of representative samples present in the original dataset, because oversampling only increases how often the minority samples appear in a batch; the number of unique samples remains the same.

Code in Pytorch to implement Minority Class Oversampling

Figure 1: Code for Oversampling using Pytorch WeightedSampler
Figure 2: Batch Distribution after using Weighted Sampler for Oversampling
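For readers who cannot view the figures, a minimal sketch of this approach (not the exact code in the figures) is given below. Here train_dataset and targets are hypothetical placeholders for your own dataset and its tensor of labels, and inverse class frequency is used as the sampling weight:

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# targets: 1-D LongTensor with the class label of every sample in train_dataset (hypothetical)
class_counts = torch.bincount(targets)          # number of samples per class
class_weights = 1.0 / class_counts.float()      # inverse-frequency weight per class
sample_weights = class_weights[targets]         # weight for every individual sample

# Drawing with replacement lets minority samples appear several times per epoch,
# so each mini-batch ends up roughly balanced.
sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
```

Passing this sampler to the DataLoader balances the mini-batches without touching the dataset itself.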

2. Undersampling:

Another popular technique for solving the class imbalance problem is undersampling, which does the opposite of oversampling. Here, we randomly remove samples from the majority class until all the classes have the same number of samples. This technique has a significant disadvantage: it discards data, which might reduce the number of representative samples in the dataset. To fix this shortcoming, various methods carefully remove only redundant samples, thereby preserving the variability of the dataset.

One naïve way to implement undersampling is to draw random samples from the majority class equal in number to the samples in the minority class while keeping the minority class untouched. To implement undersampling in Pytorch, WeightedRandomSampler might be used, with the weights of the majority class reduced such that roughly the same number of examples are sampled from the majority and minority classes.

For example: say you have a binary classification task where the number of samples in the majority class is 2000 while the number of samples in the minority class is 1000. To do undersampling using WeightedRandomSampler, we can give all the examples belonging to the minority class a weight of 1, while for the examples belonging to the majority class we set the weight to 0.5 (num_samples_minority / num_samples_majority), which means each majority-class example is half as likely to be drawn as each minority-class example.

Code in Pytorch to implement Undersampling:

Figure 3: Code for Undersampling using Pytorch WeightedSampler
Figure 4: Batch Distribution after using WeightedSampler for Undersampling
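Again, a minimal sketch of this idea (not the exact code in the figures), reusing the 2000/1000 example above with the same hypothetical train_dataset and targets, where label 0 is the majority class and label 1 the minority class:

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# targets: 1-D LongTensor of labels, 0 = majority class, 1 = minority class (hypothetical)
n_majority = int((targets == 0).sum())   # e.g. 2000
n_minority = int((targets == 1).sum())   # e.g. 1000

# Minority examples keep weight 1.0; majority examples get n_minority / n_majority (0.5 here),
# so both classes contribute roughly the same number of samples.
sample_weights = torch.where(targets == 1,
                             torch.full_like(targets, 1.0, dtype=torch.float),
                             torch.full_like(targets, n_minority / n_majority, dtype=torch.float))

# Sampling without replacement and asking for 2 * n_minority samples gives a smaller,
# roughly balanced epoch instead of repeating minority samples.
sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=2 * n_minority,
                                replacement=False)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
```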

Algorithm Level Methods:

Algorithm level methods are the class of methods for handling class imbalance where the aim is to keep the dataset constant while altering the training or inference algorithm.

In this post, I will discuss the two methods which I readily use when working with imbalanced datasets.

1. Cost-Sensitive Learning:

In cost-sensitive learning, the basic idea is to assign different costs to classes according to their distribution. There are various ways of implementing cost-sensitive learning, such as using a higher learning rate for examples belonging to the minority class than for examples belonging to the majority class, or using class-weighted loss functions, which take the class distribution into account when computing the loss and hence penalize the classifier more for misclassifying examples from the minority class than from the majority class. There are various class-weighted loss functions, but the two most widely used are WeightedCrossEntropy and Focal Loss.

Figure 5: Equation of Weighted Cross Entropy Loss
Figure 6: Equation of Focal Loss

WeightedCrossEntropy uses the classical CrossEntropy loss and incorporates a weight term that gives more weight to a specific class, which in class imbalance problems is the minority class. Hence, the weighted cross-entropy loss penalizes the classifier more when it misclassifies examples belonging to the minority classes.

Focal loss, first introduced by FAIR in their paper Focal Loss for Dense Object Detection, is designed so that it does two things to address class imbalance. Firstly, it penalizes hard examples more than easy examples, which helps the algorithm perform better. Hard examples are those where the model is not confident and predicts the ground truth with a low probability, whereas easy examples are those where the model is highly confident and predicts the ground truth with a high probability. For example, suppose the model predicts the ground-truth class with a probability of 0.8. Keeping γ = 2, α = 1 and log = log10 for simplicity and substituting these values into the focal loss equation given in figure 6, we get -(1-0.8)²log(0.8) = 0.00387. Now suppose the model predicts the ground-truth class with a probability of 0.2; keeping the same conditions as above, we get -(1-0.2)²log(0.2) = 0.4473. As you can see, focal loss penalizes the model more for the example it is less confident about than for the example it is highly confident about.
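The two numbers above can be reproduced in a couple of lines of Python; log base 10 is used here purely to match the hand calculation:

```python
import math

def focal_term(p_true, gamma=2.0, alpha=1.0):
    # Focal loss for a single prediction, where p_true is the probability
    # the model assigns to the ground-truth class (log base 10, as above).
    return -alpha * (1 - p_true) ** gamma * math.log10(p_true)

print(focal_term(0.8))  # ~0.00388 -> easy example, small penalty
print(focal_term(0.2))  # ~0.44734 -> hard example, much larger penalty
```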

In the case of class imbalance, the minority class is rare and the model sees few examples of it, so the model is less confident when predicting it (hard examples). The majority class, in contrast, has abundant samples, which gives the model more opportunities to learn, so the model is more confident when predicting it (easy examples).

Hence, focal loss gives more weight to examples from the minority class and less weight to examples from the majority class, which in turn makes the model focus more on the minority class and improves the overall performance of the classifier.

Secondly, similar to weighted cross-entropy, it has a weight term in its loss function. Setting the weight term appropriately penalizes the model more when it misclassifies the minority class than the majority class. Because of these two properties, focal loss is extensively used for tasks that suffer from class imbalance.

WeightedCrossEntropy can be easily implemented in Pytorch by taking the CrossEntropy loss function available in Pytorch and specifying the weight term as shown in figure 7 below.

Figure 7: Weighted CrossEntropy loss code in Pytorch
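As a rough sketch of what figure 7 shows (the exact weighting scheme is a design choice; here the per-class weights are simply proportional to the inverse class frequency, using the 2000/1000 example from earlier):

```python
import torch
import torch.nn as nn

# Class counts from the earlier example: 2000 majority, 1000 minority (hypothetical)
class_counts = torch.tensor([2000.0, 1000.0])
class_weights = class_counts.sum() / class_counts   # rarer class gets the larger weight

# nn.CrossEntropyLoss accepts a per-class weight tensor
criterion = nn.CrossEntropyLoss(weight=class_weights)

# loss = criterion(logits, labels)  # logits: (batch, 2), labels: (batch,) from your model and loader
```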

To implement focal loss, you can either implement it from scratch by referring to the paper, or use one of the many packages that provide it, such as Catalyst, Kornia, etc.
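If you do implement it from scratch, a minimal multi-class sketch built on top of Pytorch's cross-entropy (which already returns -log(p_t) per example) could look like this; alpha is kept as a scalar here, but it can also be a per-class weight tensor:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=1.0, gamma=2.0):
    # Per-example cross-entropy equals -log(p_t), where p_t is the probability
    # assigned to the ground-truth class.
    ce = F.cross_entropy(logits, targets, reduction="none")
    p_t = torch.exp(-ce)
    # Down-weight easy examples (p_t close to 1) with the (1 - p_t)^gamma modulating factor.
    return (alpha * (1 - p_t) ** gamma * ce).mean()
```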

2. One-Class Classification:

One-class classification, as the name suggests, is the technique of handling class imbalance by modelling the distribution of only the minority class and treating all other classes as out-of-distribution/anomalous. Using this technique, we aim to create a classifier that can detect examples belonging to the minority class rather than discriminating between the minority and majority classes. In practice, this is done by training the model only on instances belonging to the minority class and, at test time, using examples belonging to all the classes to test the ability of the classifier to correctly identify examples belonging to the minority class while flagging examples belonging to other classes.

The one-class classification technique can be implemented in various ways. One widely used approach in computer vision is to use autoencoders: we train the autoencoder only on examples belonging to the minority class and make it regenerate its input. At test time, we pass images belonging to all the classes and measure the reconstruction error of the model using a metric such as MSE or RMSE. If an image belongs to the minority class, the reconstruction error will be low, as the model is already familiar with its distribution, whereas the reconstruction error will be high for examples belonging to other classes.

Now you might be thinking, "But how will we know which error is low and which is high? What threshold value should we use to decide whether a given image belongs to the minority class or is an anomaly?" In practice, we determine the threshold by plotting the errors obtained by passing images belonging to all the categories and then choosing the value that gives a clear separation between the minority class and the other classes. People also take the mean and standard deviation of the reconstruction error obtained by passing images belonging to the minority class; any example with a reconstruction error greater than mean + std is classified as an anomaly.
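A minimal sketch of the mean-plus-standard-deviation rule, assuming a trained autoencoder and a tensor recon_errors holding the per-image reconstruction errors of minority-class images (both hypothetical), might look like this:

```python
import torch
import torch.nn.functional as F

# recon_errors: per-image MSE on a held-out set of minority-class images (hypothetical tensor)
threshold = recon_errors.mean() + recon_errors.std()

def is_anomaly(image, autoencoder):
    """Flag an image as an anomaly if its reconstruction error exceeds the threshold."""
    autoencoder.eval()
    with torch.no_grad():
        batch = image.unsqueeze(0)                      # add batch dimension
        error = F.mse_loss(autoencoder(batch), batch)   # reconstruction error
    return error.item() > threshold.item()
```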

Hybrid methods that use both a neural network and a specialized one-class classifier are also frequently used. This approach comprises two stages. In the first stage, the image is passed through a feature extractor such as a ResNet, and features are extracted from one of the last layers of the model. Once the features are extracted, they are flattened and passed to a specialized OCC algorithm such as a One-Class SVM or Isolation Forest. The classifier learns from the features and uses the learned knowledge to classify whether an example belongs to the training distribution or is an anomaly.

Sklearn provides a few one-class classifiers such as OneClassSVM, IsolationForest, etc. One should use cross-validation to tune the hyperparameters of these models for better performance.
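A minimal sketch of this two-stage pipeline, assuming a tensor minority_images of pre-processed minority-class images and an array test_features of features extracted the same way at test time (both hypothetical); the OneClassSVM hyperparameters shown are just starting points to tune with cross-validation:

```python
import torch
import torchvision.models as models
from sklearn.svm import OneClassSVM

# Stage 1: pretrained ResNet as a feature extractor (classification head removed)
resnet = models.resnet18(pretrained=True)
feature_extractor = torch.nn.Sequential(*list(resnet.children())[:-1])
feature_extractor.eval()

with torch.no_grad():
    # minority_images: float tensor of shape (N, 3, 224, 224) (hypothetical)
    features = feature_extractor(minority_images).flatten(1).numpy()

# Stage 2: fit the one-class classifier on the extracted features
occ = OneClassSVM(kernel="rbf", nu=0.1, gamma="scale")
occ.fit(features)

# At test time: +1 means "looks like the minority class", -1 means anomaly
predictions = occ.predict(test_features)
```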

Conclusion:

Class imbalance, though widely prevalent in machine learning and its sub-domains, is not given the attention it requires, and the commonly known techniques are ancient and might not always help.

Hence, in this post, I have tried to explain a few techniques I used to deal with inherent class imbalance. However, there are many techniques not covered in this post that might also be important for mitigating class imbalance.

If you find anything you think might help make this post better, please contact me by leaving a message here or through my LinkedIn or Twitter. I hope you liked the post, and have a great day ahead :).

Inspiring Quote:

"Believe you can and you’re halfway there".

