Machine Learning — Multiclass Classification with Imbalanced Dataset

Challenges in classification and techniques to improve performance


Classification problems with multiple classes and an imbalanced dataset present a different challenge than binary classification problems. The skewed distribution makes many conventional machine learning algorithms less effective, especially at predicting minority class examples. Let us first understand the problem at hand and then discuss ways to overcome it.

  1. Multiclass Classification: A classification task with more than two classes; e.g., classify a set of images of fruits which may be oranges, apples, or pears. Multiclass classification assumes that each sample is assigned to one and only one label: a fruit can be either an apple or a pear, but not both at the same time.
  2. Imbalanced Dataset: Imbalanced data typically refers to classification problems where the classes are not represented equally. For example, you may have a 3-class classification problem of a set of fruits to classify as oranges, apples, or pears, with 100 instances in total. 80 instances are labeled Class-1 (Oranges), 10 instances Class-2 (Apples), and the remaining 10 instances Class-3 (Pears). This is an imbalanced dataset, with a ratio of 8:1:1. Most classification datasets do not have an exactly equal number of instances in each class, but a small difference often does not matter. There are also problems where a class imbalance is not just common but expected; for example, datasets that characterize fraudulent transactions are imbalanced. The vast majority of the transactions will be in the “Not-Fraud” class and a very small minority in the “Fraud” class.

Dataset

The dataset we will be using for this example is the famous “20 Newsgroups” dataset. The 20 Newsgroups dataset is a collection of approximately 20,000 newsgroup documents, partitioned (nearly) evenly across 20 different newsgroups. The collection has become a popular dataset for experiments in text applications of machine learning techniques, such as text classification and text clustering.

scikit-learn provides the tools to fetch and pre-process the dataset; refer to its documentation for more details. The number of articles for each newsgroup, shown below, is roughly uniform.
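For reference, loading the dataset takes only a couple of lines (the remove argument, which strips metadata, is optional):

```python
from sklearn.datasets import fetch_20newsgroups

# Load the training split; stripping headers, footers, and quotes
# keeps the model from overfitting on newsgroup metadata.
newsgroups_train = fetch_20newsgroups(subset='train',
                                      remove=('headers', 'footers', 'quotes'))

print(len(newsgroups_train.data))      # number of documents
print(newsgroups_train.target_names)   # the 20 class names
```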

We remove some news articles from selected groups to make the overall dataset imbalanced, as shown below.
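The exact resampling step is not shown in the text; one simple way to skew the distribution is a small helper (hypothetical, with arbitrary class choices) that keeps only a fraction of the documents in selected groups:

```python
import numpy as np

rng = np.random.RandomState(42)

def make_imbalanced(texts, labels, minority_classes, keep_fraction=0.1):
    """Keep only a fraction of the samples in the given classes (illustrative helper)."""
    labels = np.asarray(labels)
    keep = np.ones(len(labels), dtype=bool)
    for cls in minority_classes:
        idx = np.where(labels == cls)[0]
        n_drop = int(len(idx) * (1 - keep_fraction))
        keep[rng.choice(idx, size=n_drop, replace=False)] = False
    kept_idx = np.where(keep)[0]
    return [texts[i] for i in kept_idx], labels[kept_idx]
```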

Now our imbalanced dataset with 20 classes is ready for further analysis.

Build Model

As this is a classification problem, we will use an approach similar to the one described in my previous article on sentiment analysis. The only difference is that here we are dealing with a multiclass classification problem.

The last layer in the model is Dense(num_labels, activation='softmax'), with num_labels=20 classes; 'softmax' is used instead of 'sigmoid'. The other change in the model is the loss function: loss='categorical_crossentropy', which is suited for multiclass problems.
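A minimal sketch of such a model in Keras (the hidden layers and vocab_size here are assumptions; only the softmax output layer and the categorical cross-entropy loss come from the text):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout

num_labels = 20      # one output unit per newsgroup
vocab_size = 15000   # assumed vectorizer vocabulary size

model = Sequential([
    Dense(512, activation='relu', input_shape=(vocab_size,)),
    Dropout(0.5),
    # Softmax output: one probability per class, summing to 1.
    Dense(num_labels, activation='softmax'),
])

# categorical_crossentropy expects one-hot encoded labels.
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
```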

Train Model

Training the model with a 20% validation set (validation_split=0.2) and verbose=2, we see the validation accuracy after each epoch. After just 10 epochs we reach a validation accuracy of 90%.
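In code, the training call could look like this (assuming x_train is the vectorized document matrix and y_train the one-hot encoded labels):

```python
# Hold out 20% of the training data for validation
# and log one summary line per epoch (verbose=2).
history = model.fit(x_train, y_train,
                    batch_size=32,
                    epochs=10,
                    validation_split=0.2,
                    verbose=2)
```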

Evaluate Model

This looks like very good accuracy, but is the model really doing well?

How do we measure model performance? Suppose we train our model on the imbalanced fruit data from the earlier example. Since the data is heavily biased towards Class-1 (Oranges), the model overfits on the Class-1 label and predicts it in most cases, and we achieve an accuracy of 80%. That seems very good at first, but looking closely, the model may never be able to classify apples or pears correctly. The question, then, is: if accuracy is not the right metric here, which metrics should we use to measure the performance of the model?

Confusion-Matrix

With imbalanced classes, it’s easy to get a high accuracy without actually making useful predictions. So accuracy as an evaluation metric makes sense only if the class labels are uniformly distributed. In the case of imbalanced classes, a confusion matrix is a good technique for summarizing the performance of a classification algorithm.

A confusion matrix is a performance measurement for a classification algorithm whose output can be two or more classes.

Confusion matrix: x-axis = predicted label, y-axis = true label
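A plot like the one above can be produced with scikit-learn; a sketch, assuming x_test/y_test are the vectorized test documents with one-hot labels and class_names is the list of newsgroup names:

```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Predicted class = argmax over the softmax probabilities.
y_pred = np.argmax(model.predict(x_test), axis=1)
y_true = np.argmax(y_test, axis=1)  # undo the one-hot encoding

# Normalizing each row makes cell (i, j) the fraction of true class i
# predicted as class j, so the diagonal shows per-class recall.
cm = confusion_matrix(y_true, y_pred, normalize='true')

plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='.2f',
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()
```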

When we look closely at the confusion matrix, we see that the classes [alt.atheism, talk.politics.misc, soc.religion.christian], which have very few samples [65, 53, 86] respectively, indeed have much lower scores [0.42, 0.56, 0.65] compared to the classes with a higher number of samples, like [rec.sport.hockey, rec.motorcycles]. Thus, looking at the confusion matrix, one can clearly see how the model performs on classifying the various classes.

How to improve the performance?

There are various techniques for improving the performance of models trained on imbalanced datasets.

Re-sampling Dataset

There are two ways to make our dataset balanced:

  1. Under-sampling: remove samples from over-represented classes; use this if you have a huge dataset.
  2. Over-sampling: add more samples for under-represented classes; use this if you have a small dataset. A sketch of both with imblearn follows below.
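A minimal sketch of both strategies using the imblearn package, assuming X is the vectorized feature matrix and y the integer class labels:

```python
from collections import Counter
from imblearn.under_sampling import RandomUnderSampler
from imblearn.over_sampling import RandomOverSampler

# Under-sampling: randomly drop samples from the majority classes.
X_under, y_under = RandomUnderSampler(random_state=42).fit_resample(X, y)

# Over-sampling: randomly duplicate samples from the minority classes.
X_over, y_over = RandomOverSampler(random_state=42).fit_resample(X, y)

print(Counter(y_under))
print(Counter(y_over))
```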

SMOTE (Synthetic Minority Over-sampling Technique)

SMOTE is an over-sampling method that creates synthetic samples of the minority class. We use the imblearn Python package to over-sample the minority classes.
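Applying SMOTE itself takes only a few lines; note that it operates on feature vectors, not raw text (x_train and y_train_int, the integer class labels, are assumed from the earlier steps):

```python
from imblearn.over_sampling import SMOTE

# SMOTE interpolates between a minority sample and its nearest
# minority-class neighbours to create new synthetic feature vectors.
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(x_train, y_train_int)

print(x_train.shape[0], '->', X_resampled.shape[0])
```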

We have 4197 samples before and 4646 samples after applying SMOTE, so SMOTE has indeed increased the samples of the minority classes. We will check the performance of the model with the new dataset.

Validation accuracy improved from 90% to 94%. Let us test the model:

Test accuracy improved slightly over before (from 87% to 88%). Let us now have a look at the confusion matrix.

We see that the classes [alt.atheism, talk.politics.misc, sci.electronics, soc.religion.christian] have improved scores [0.76, 0.58, 0.75, 0.72] compared to before. Thus the model performs better than before at classifying the individual classes, even though overall accuracy is similar.

Another Trick

Since the classes are imbalanced, what about providing some bias to the minority classes? We can estimate class weights in scikit-learn using compute_class_weight and pass the class_weight parameter while training the model. This biases training towards the minority classes and thus helps improve the performance of the model across classes.
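A sketch of how that could look (y_train_int again denotes the integer class labels; the fit arguments mirror the earlier training call):

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' weights each class inversely proportional to its frequency.
classes = np.unique(y_train_int)
weights = compute_class_weight(class_weight='balanced',
                               classes=classes,
                               y=y_train_int)
class_weight = dict(zip(classes, weights))

model.fit(x_train, y_train,
          epochs=10,
          validation_split=0.2,
          class_weight=class_weight,
          verbose=2)
```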

Precision-Recall Curves

Precision-Recall is a useful measure of success of prediction when the classes are very imbalanced. Precision is a measure of the ability of a classification model to identify only the relevant data points, while recall is a measure of the ability of a model to find all the relevant cases within a dataset.

The precision-recall curve shows the trade-off between precision and recall at different thresholds. A high area under the curve represents both high recall and high precision, where high precision relates to a low false positive rate, and high recall relates to a low false negative rate.

High scores for both precision and recall show that the classifier is returning accurate results (precision), as well as returning a majority of all positive results (recall). An ideal system with high precision and high recall will return many results, with all results labeled correctly.

Below is a precision-recall plot for the 20 Newsgroups dataset using scikit-learn.
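A sketch of how such a per-class plot can be generated, treating each class one-vs-rest (y_test one-hot labels and the model's softmax outputs are assumed):

```python
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, auc

y_score = model.predict(x_test)  # softmax scores, shape (n_samples, 20)

plt.figure(figsize=(10, 8))
for i in range(num_labels):
    # One-vs-rest: class i is the positive class, all others negative.
    precision, recall, _ = precision_recall_curve(y_test[:, i], y_score[:, i])
    plt.plot(recall, precision, lw=1,
             label=f'class {i} (area = {auc(recall, precision):.2f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend(fontsize='small', ncol=2)
plt.show()
```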

Precision-Recall Curve

We would like the area under the P-R curve for each class to be close to 1. Except for classes 0, 3, and 18, all classes have an area above 0.75. You can try different classification models and hyper-parameter tuning techniques to improve the result further.

Conclusion

We discussed the problems associated with multiclass classification on an imbalanced dataset, and demonstrated how the right tools and techniques help us develop better classification models.

Thanks for reading. The code can be found on Github.
