Neural Network Calibration using PyTorch

Make your model usable for safety-critical applications with a few lines of code.

Lukas Huber
Towards Data Science


Photo by Greg Shield on Unsplash

Imagine you are a radiologist working in a new high-tech hospital. Last week you got your first neural-network-based model to assist you in making diagnoses from your patients’ data and, eventually, to improve your accuracy. But wait! Very much like us humans, synthetic models are never 100% accurate in their predictions. So how do we know whether a model is absolutely certain or whether it just barely surpasses the point of guessing? This knowledge is crucial for correct interpretation and key to selecting an appropriate treatment.

If you’re more of an engineer: this scenario is just as relevant for autonomous driving, where a car constantly has to decide whether there is an obstacle in front of it or not. Ignoring uncertainties can get ugly real quick here.

If you are like 90% of the Deep Learning community (including past me), you simply assumed that the predictions produced by the Softmax function represent probabilities, since they are neatly squashed into the range [0, 1]. This is a popular pitfall, because these predictions generally tend to be overconfident. This behaviour is affected by a variety of architectural choices, such as the use of Batch Normalization or the number of layers (Guo et al.).

You can find an interactive Google Colab notebook with all the code here.

Reliability Plots

As we now know, it is desirable to output calibrated confidences instead of their raw counterparts. To get an intuitive understanding of how well a specific architecture performs in this regard, Reliability Diagrams are often used.

Reliability Plot for a ResNet101 trained for 10 Epochs on CIFAR10 (Image by author)

Summarized in one sentence, Reliability Plots show how well the predicted confidence scores hold up against the actual accuracy. For example, given 100 predictions, each with a confidence of 0.9, we expect 90 of them to be correct if the model is perfectly calibrated.

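To fully understand what’s going on, we need to dig a bit deeper. As we can see from the plot, all confidence scores of the test set are grouped into M = 10 equal-width bins [0, 0.1), [0.1, 0.2), …, [0.9, 1]. Let B_m denote the set of samples whose confidence falls into bin m, and for each sample i let ŷᵢ be the predicted class, yᵢ the true label and p̂ᵢ the predicted confidence (the largest Softmax probability). For each bin we can then calculate its accuracy

$$\mathrm{acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbf{1}(\hat{y}_i = y_i)$$

and its confidence

$$\mathrm{conf}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \hat{p}_i$$

Both values are then visualized as a bar plot, with the identity line indicating perfect calibration.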
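The binning step behind a reliability diagram can be sketched in PyTorch as follows. This is a minimal sketch, not the code from the article’s notebook: it assumes you already have a 1-D tensor of confidences p̂ᵢ and a boolean tensor marking correct predictions, and the function name reliability_bins is our own.

```python
import torch


def reliability_bins(confidences: torch.Tensor, correct: torch.Tensor, n_bins: int = 10):
    """Return per-bin accuracy, mean confidence and sample count.

    confidences: 1-D float tensor of predicted confidences in [0, 1]
    correct:     1-D bool tensor, True where the predicted class equals the label
    With n_bins=10 the bins are [0, 0.1), [0.1, 0.2), ..., [0.9, 1.0].
    """
    # Map each confidence to a bin index; clamp so that a confidence of 1.0 lands in the last bin
    bin_ids = torch.clamp((confidences * n_bins).long(), max=n_bins - 1)
    accs, confs, counts = [], [], []
    for m in range(n_bins):
        in_bin = bin_ids == m
        n_m = int(in_bin.sum())
        counts.append(n_m)
        if n_m == 0:
            # Accuracy and confidence are undefined for empty bins
            accs.append(float("nan"))
            confs.append(float("nan"))
            continue
        accs.append(correct[in_bin].float().mean().item())   # acc(B_m)
        confs.append(confidences[in_bin].mean().item())      # conf(B_m)
    return accs, confs, counts
```

The returned lists map directly onto the diagram: draw accs as bars over the bin centres and overlay the identity line; the gap between each bar and the diagonal is the miscalibration of that bin.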

Metrics

Diagrams and plots are just one side of the story. To score a model based on its calibration error, we need to define metrics. Fortunately, the two most commonly used metrics are really intuitive.

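The Expected Calibration Error (ECE) simply takes a weighted average of the absolute difference between accuracy and confidence over all M bins, weighting each bin by its share of the n samples (acc(B_m) and conf(B_m) being the accuracy and confidence of bin B_m):

$$\mathrm{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \left| \mathrm{acc}(B_m) - \mathrm{conf}(B_m) \right|$$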

For safety-critical applications like the ones described above, it may be more useful to measure the maximum discrepancy between accuracy and confidence. This is what the Maximum Calibration Error (MCE) does:

$$\mathrm{MCE} = \max_{m \in \{1, \dots, M\}} \left| \mathrm{acc}(B_m) - \mathrm{conf}(B_m) \right|$$
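Both metrics reduce to a few lines once the samples are binned: ECE weights each bin’s absolute accuracy/confidence gap by its share of the samples, MCE keeps only the largest gap. The following is a minimal sketch, assuming equal-width bins and 1-D input tensors; the function name calibration_errors is hypothetical and not taken from the article’s notebook. Empty bins are skipped: they carry zero weight in the ECE, and we exclude them from the MCE.

```python
import torch


def calibration_errors(confidences: torch.Tensor, correct: torch.Tensor, n_bins: int = 10):
    """Return (ECE, MCE) for equal-width confidence bins.

    confidences: 1-D float tensor of predicted confidences in [0, 1]
    correct:     1-D bool tensor, True where the prediction matches the label
    """
    n = confidences.numel()
    # Bin index per sample; a confidence of exactly 1.0 goes into the last bin
    bin_ids = torch.clamp((confidences * n_bins).long(), max=n_bins - 1)
    ece, mce = 0.0, 0.0
    for m in range(n_bins):
        in_bin = bin_ids == m
        n_m = int(in_bin.sum())
        if n_m == 0:
            continue  # empty bins contribute nothing
        acc = correct[in_bin].float().mean().item()
        conf = confidences[in_bin].mean().item()
        gap = abs(acc - conf)
        ece += (n_m / n) * gap  # weighted by the fraction of samples in this bin
        mce = max(mce, gap)     # worst-case bin
    return ece, mce
```

Multiply the results by 100 to get percentages like the ones reported in the Results section.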

Temperature Scaling

We now want to focus on how to tackle this issue. While many solutions such as Histogram Binning, Isotonic Regression, Bayesian Binning into Quantiles (BBQ) and Platt Scaling exist (with corresponding extensions for multi-class problems), we will focus on Temperature Scaling. It is the easiest to implement, and in most of the experiments by Guo et al. it also gave the best results among the methods named above.

To fully understand it, we need to take a step back and look at the outputs of a neural network. Assuming a multi-class problem with K classes, the last layer of a network outputs a vector of logits z = (z₁, …, z_K) ∈ ℝᴷ. The predicted probability for class k is then obtained using the Softmax function σ:

$$\sigma(\mathbf{z})_k = \frac{\exp(z_k)}{\sum_{j=1}^{K} \exp(z_j)}$$
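As a small worked example with made-up logits, for K = 3 and z = (2, 1, 0) the Softmax function gives

$$\sigma(\mathbf{z}) = \frac{\left(e^{2},\ e^{1},\ e^{0}\right)}{e^{2} + e^{1} + e^{0}} \approx (0.665,\ 0.245,\ 0.090)$$

The outputs lie in [0, 1] and sum to 1, which is exactly why they are so easily mistaken for calibrated probabilities.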

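Temperature scaling operates directly on the logits z (not on the predicted probabilities!) and divides them by a single parameter T > 0 shared by all classes. The calibrated confidence is then obtained by

$$\hat{q} = \max_{k} \, \sigma\!\left(\frac{\mathbf{z}}{T}\right)_k, \qquad \sigma\!\left(\frac{\mathbf{z}}{T}\right)_k = \frac{\exp(z_k / T)}{\sum_{j=1}^{K} \exp(z_j / T)}$$

For T = 1 we recover the original Softmax output, T > 1 softens the distribution (lowering the confidence) and T < 1 sharpens it. Because dividing all logits by the same positive T does not change their order, the predicted class, and therefore the accuracy, stays the same.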
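The effect of T can be checked numerically. This sketch uses randomly generated, hypothetical logits (8 samples, K = 5 classes) and a helper confidence_and_prediction of our own; it computes q̂ = max_k σ(z/T)_k for several temperatures and shows that only the confidence changes, never the predicted class.

```python
import torch

# Hypothetical logits for a batch of 8 samples and K = 5 classes
torch.manual_seed(0)
logits = torch.randn(8, 5) * 4.0


def confidence_and_prediction(logits: torch.Tensor, T: float):
    """Return q_hat = max_k sigma(z / T)_k and the predicted class per sample."""
    probs = torch.softmax(logits / T, dim=1)
    conf, pred = probs.max(dim=1)
    return conf, pred


for T in (0.5, 1.0, 2.0, 5.0):
    conf, pred = confidence_and_prediction(logits, T)
    # Mean confidence shrinks as T grows, while the predicted classes stay identical
    print(f"T={T}: mean confidence {conf.mean().item():.3f}, predictions {pred.tolist()}")
```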

It is important to note that the parameter T is optimized with respect to the Negative Log-Likelihood (NLL) loss on the validation set, while the network’s parameters are kept fixed during this stage.
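Written out in our own notation, with N validation samples, logit vectors zᵢ and labels yᵢ, the optimization problem is

$$T^{*} = \arg\min_{T > 0} \; -\frac{1}{N} \sum_{i=1}^{N} \log \sigma\!\left(\frac{\mathbf{z}_i}{T}\right)_{y_i}$$

Since the network is frozen, it suffices to run it once over the validation set to collect the logits zᵢ; afterwards only the scalar T is optimized.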

Results

Reliability Plot for a ResNet101 trained for 10 Epochs on CIFAR10 and calibrated using Temperature Scaling (Image by author)

As we can see from the figure, the bars are now much closer to the identity line, indicating almost perfect calibration. The metrics confirm this: the ECE dropped from 2.10% to 0.25% and the MCE from 27.27% to 3.86%, a drastic improvement.

Implementation in PyTorch

As promised, the implementation in PyTorch is rather straightforward.

First, we define the T_scaling method, which takes the logits together with a specific temperature T and returns the calibrated confidences.
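The article’s original code listing is not reproduced in this text, so here is a minimal sketch of what such a method can look like. It assumes T_scaling receives the logits and the temperature as a tensor and returns the scaled logits z / T; the extra helper calibrated_confidences is hypothetical and applies the Softmax to obtain the calibrated confidences σ(z / T).

```python
import torch


def T_scaling(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    """Divide the logits by a single temperature T > 0 shared by all classes."""
    return logits / temperature


def calibrated_confidences(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Softmax over the temperature-scaled logits, i.e. sigma(z / T)."""
    return torch.softmax(T_scaling(logits, temperature), dim=-1)
```

Returning scaled logits rather than probabilities is a deliberate choice in this sketch: PyTorch’s cross-entropy loss expects unnormalized logits, so T_scaling can be plugged straight into the NLL objective when fitting T.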

In the next step, the parameter T has to be estimated using the L-BFGS algorithm. This should take only a couple of seconds on a GPU.
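A minimal sketch of this fitting step with torch.optim.LBFGS, assuming the validation logits and labels have already been collected from the frozen network. The name fit_temperature is hypothetical, T_scaling is re-defined so the block runs on its own, and max_iter=50 is an arbitrary choice. As a small deviation from optimizing T directly, this sketch optimizes log T, which keeps T > 0 without explicit constraints.

```python
import torch
import torch.nn.functional as F


def T_scaling(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (sketch): scale the logits by one shared temperature
    return logits / temperature


def fit_temperature(val_logits: torch.Tensor, val_labels: torch.Tensor, max_iter: int = 50) -> float:
    """Fit T by minimizing the NLL of the scaled validation logits (network stays frozen)."""
    val_logits = val_logits.detach()  # no gradients flow back into the network
    # Optimize log T so that T = exp(log_T) is always positive; start at T = 1
    log_T = torch.zeros(1, device=val_logits.device, requires_grad=True)
    optimizer = torch.optim.LBFGS([log_T], max_iter=max_iter, line_search_fn="strong_wolfe")

    def closure():
        optimizer.zero_grad()
        # cross_entropy on logits = mean NLL of softmax(z / T)
        loss = F.cross_entropy(T_scaling(val_logits, log_T.exp()), val_labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_T.exp().item()
```

In practice, val_logits would be gathered by running the trained model over the validation set inside torch.no_grad(); the fitted temperature is then applied to the test logits before the Softmax. Moving the logits and labels to the GPU first is enough to run the whole fit there.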

The code is somewhat similar to the repository by gpleiss.

You are welcome to play around in the Google Colab Notebook I created here.

Conclusion

As shown in this article, network calibration can be accomplished in just a few lines of code, with drastic improvements. If there is enough interest, I’m happy to discuss other approaches to model calibration in another Medium article. If you are interested in a deeper dive into this topic, I highly recommend the paper “On Calibration of Modern Neural Networks” by Guo et al.

Cheers!
