Multi-class classification using Focal Loss and LightGBM

There are several approaches for incorporating Focal Loss in a multi-class classifier. Here’s one of them.

Luca Carniato
Towards Data Science


Motivation

Many real-world classification problems have an imbalanced distribution of classes. When data is heavily imbalanced, classification algorithms tend to predict the majority class. There are several approaches to mitigate class imbalance.

One approach is to assign sample weights inversely proportional to the class frequency. Another approach is to use oversampling/undersampling techniques. Popular techniques to generate artificial samples for the minority classes are the Synthetic Minority Oversampling TEchnique (SMOTE) and ADAptive SYNthetic sampling (ADASYN), both included in the imblearn Python library.

Recently, the use of the Focal Loss objective function was proposed. The technique was used for binary classification by Tsung-Yi Lin et al. [1].

In this post, I will demonstrate how to incorporate Focal Loss into a LightGBM classifier for multi-class classification. The code is available on GitHub.

Binary classification

For a binary classification problem (labels 0/1) the Focal Loss function is defined as follows:

FL(pₜ) = −αₜ (1 − pₜ)^γ log(pₜ)

Eq.1 Focal Loss function

Where pₜ is a function of the true labels. For binary classification, this function is defined as:

pₜ = p if y = 1, and pₜ = 1 − p otherwise

Eq.2 Class probabilities

Where p is obtained by applying the sigmoid function to the raw margins z:

p = 1 / (1 + e^(−z))

Eq.3 Sigmoid function for converting raw margins z to class probabilities p

Focal Loss can be interpreted as a binary cross-entropy function multiplied by a modulating factor (1 − pₜ)^γ, which reduces the contribution of easy-to-classify samples. The weighting factor αₜ balances the modulating factor. Quoting the authors: “with γ = 2, an example classified with pₜ = 0.9 would have 100× lower loss compared with CE and with pₜ ≈ 0.968 it would have 1000× lower loss”. Reducing the loss of easy-to-classify examples allows training to focus more on hard-to-classify ones.
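The quoted ratios follow directly from the modulating factor: ignoring αₜ, FL is just the cross-entropy scaled by (1 − pₜ)^γ, so the loss is reduced by a factor 1 / (1 − pₜ)^γ. A quick numeric check:

```python
gamma = 2.0
for p_t in (0.9, 0.968):
    # ratio CE / FL when alpha_t is ignored: 1 / (1 - p_t)^gamma
    print(p_t, 1.0 / (1.0 - p_t) ** gamma)   # ≈ 100 and ≈ 977 (~1000)
```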

An excellent post on incorporating Focal Loss in a binary LightGBM classifier can be found in Max Halford's blog [2].

Multiclass classification

There are several approaches for incorporating Focal Loss in a multi-class classifier. Formally, the modulating and weighting factors should be applied to the categorical cross-entropy. This approach requires providing the first-order and second-order derivatives of the multi-class loss with respect to the raw margins z.

Another approach is One-vs-the-rest (OvR), in which a binary classifier is trained for each class C: the data from class C is treated as positive, and all other data as negative. In this post, the OvR approach is used, employing the binary classifier developed by Halford [2].

The class OneVsRestLightGBMWithCustomizedLoss shown below encapsulates the approach:
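The following is a minimal sketch of such a wrapper, not the exact class from the repository. It assumes LightGBM 3.x (where lgb.train still accepts the fobj and feval arguments) and a loss object exposing lgb_obj/lgb_eval methods that return the first- and second-order derivatives and an evaluation metric, in the style of the Focal Loss implementation from [2]:

```python
import numpy as np
import lightgbm as lgb
from scipy.special import expit
from sklearn.preprocessing import label_binarize


class OneVsRestLightGBMWithCustomizedLoss:
    """One-vs-the-rest wrapper around lgb.train with a customized loss (sketch)."""

    def __init__(self, loss, n_estimators=100):
        self.loss = loss                    # e.g. the FocalLoss object from [2]
        self.n_estimators = n_estimators

    def fit(self, X, y, **fit_params):
        self.classes_ = np.unique(y)
        # One binary target column per class: class c vs. the rest
        Y = label_binarize(y, classes=self.classes_)

        self.estimators_ = []
        for i in range(len(self.classes_)):
            train_set = lgb.Dataset(X, Y[:, i])

            valid_sets, early_stopping = None, None
            if 'eval_set' in fit_params:    # optional data for early stopping
                X_val, y_val = fit_params['eval_set']
                y_val_i = label_binarize(y_val, classes=self.classes_)[:, i]
                valid_sets = [lgb.Dataset(X_val, y_val_i, reference=train_set)]
                early_stopping = 10         # illustrative value

            booster = lgb.train(
                params={},
                train_set=train_set,
                num_boost_round=self.n_estimators,
                valid_sets=valid_sets,
                early_stopping_rounds=early_stopping,
                fobj=self.loss.lgb_obj,     # first/second-order derivatives
                feval=self.loss.lgb_eval,   # evaluation metric on raw margins
            )
            self.estimators_.append(booster)
        return self

    def predict_proba(self, X):
        # Boosters trained with a custom objective return raw margins z;
        # convert them with the sigmoid of Eq. 3 and normalize across classes.
        margins = np.column_stack([est.predict(X) for est in self.estimators_])
        probabilities = expit(margins)
        return probabilities / probabilities.sum(axis=1, keepdims=True)

    def predict(self, X):
        return self.classes_[np.argmax(self.predict_proba(X), axis=1)]
```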

This class reimplements the OneVsRestClassifier class from the sklearn.multiclass module. The motivation for reimplementing the original OneVsRestClassifier class is to be able to forward additional parameters to the fit method. This is useful for passing an evaluation set (eval_set) for early stopping, thus reducing the computation time and avoiding overfitting.

Moreover, this class uses the generic LightGBM training API, which is required to obtain meaningful results when dealing with raw margins z and customized loss functions (see [2] for more details). Without these constraints, it would have been possible to implement the class more generically, not only accepting any loss function but also any model implementing the Scikit Learn model interface.

The other methods of the class are part of the Scikit-Learn model interface: fit, predict, and predict_proba. In the predict and predict_proba methods, the base estimator returns the raw margins z, since LightGBM returns raw margins when a customized loss function is used. Class probabilities are computed from the margins using the sigmoid function, as shown in Eq. 3.

An example

Let’s start by creating an artificial imbalanced dataset with 3 classes, where 1% of the samples belong to the first class, 1% to the second, and 98% to the third. As usual, the dataset is divided into a train set and a test set:
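A sketch of the setup is shown below; the generator parameters are assumed for illustration and may differ from the values used in the original code:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Artificial 3-class dataset with a 1% / 1% / 98% class distribution
X, y = make_classification(
    n_samples=10000,
    n_features=20,
    n_informative=5,
    n_classes=3,
    weights=[0.01, 0.01, 0.98],
    random_state=42,
)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=42
)
```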

To keep the experiment simple, early stopping is not used. The resulting confusion matrix after training is shown below:
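As a rough sketch (assumed hyperparameters; the code on GitHub may differ), the baseline experiment with the standard LightGBM classifier and its default logarithmic loss looks like this:

```python
import lightgbm as lgb
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score

# First experiment: standard LGBMClassifier with the default multi-class loss
baseline = lgb.LGBMClassifier(n_estimators=100)
baseline.fit(X_train, y_train)

y_pred = baseline.predict(X_test)
print(accuracy_score(y_test, y_pred))
print(recall_score(y_test, y_pred, average='macro'))
print(confusion_matrix(y_test, y_pred))
```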

Fig.1 Confusion matrix on the test set using the standard LightGBM classifier

For this first experiment, an accuracy of 0.990 and a recall value of 0.676 were obtained on the test set. The same experiment was repeated using the OneVsRestLightGBMWithCustomizedLoss classifier with the Focal Loss.
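A sketch of the second experiment follows; FocalLoss here stands for the implementation from [2], and the alpha/gamma values are illustrative rather than the ones used for the reported results:

```python
# Hypothetical FocalLoss with the lgb_obj/lgb_eval interface sketched above,
# following the implementation in [2]; alpha and gamma values are illustrative.
loss = FocalLoss(alpha=0.75, gamma=2.0)

clf = OneVsRestLightGBMWithCustomizedLoss(loss=loss)
clf.fit(X_train, y_train)
# clf.fit(X_train, y_train, eval_set=(X_test, y_test))  # with early stopping

y_pred = clf.predict(X_test)
print(accuracy_score(y_test, y_pred))
print(recall_score(y_test, y_pred, average='macro'))
print(confusion_matrix(y_test, y_pred))
```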

As can be seen from the code above, the loss function is configurable outside the classifier and can be injected into the class constructor. Early stopping can be turned on by forwarding an eval_set to the fit method, as shown in the commented line above. For the second experiment, the resulting confusion matrix is shown below:

Fig.2 Confusion matrix on the test set using LightGBM and the customized multi-class Focal Loss class (OneVsRestLightGBMWithCustomizedLoss)

In this case, an accuracy of 0.995 and a recall of 0.838 were obtained, improving on the first experiment using the default logarithmic loss. The improvement is also evident in the confusion matrix, where the false positives for Class 0 and the false negatives for Class 1 are significantly reduced.

Conclusions

In this post, I demonstrated an approach for incorporating Focal Loss in a multi-class classifier, by using the One-vs-the-rest (OvR) approach.

With the Focal Loss objective function, sample weight balancing or the artificial addition of new samples to reduce the imbalance is not required. On an artificially generated imbalanced multi-class dataset, the use of Focal Loss increased the recall value and eliminated some false positives and false negatives in the minority classes.

The validity of the approach must be confirmed by exploring real-world datasets where noise and non-informative features are expected to influence the classification results.

[1] Lin, T. Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). Focal loss for dense object detection. In Proceedings of the IEEE international conference on computer vision (pp. 2980–2988).

[2] Max Halford (2020). Focal loss implementation for LightGBM. https://maxhalford.github.io/blog/lightgbm-focal-loss/#first-order-derivative
