
Introduction
A confusion matrix is a convenient way to present the types of mistakes a machine learning model makes. It is an N by N grid of numbers, where the value in the [n, m] cell is the number of examples annotated with the n-th class but recognized as the m-th class. In this tutorial, I will focus on creating a confusion matrix and displaying it as a heatmap. The color palette will encode the sizes of the different groups, making it easy to notice similarities or significant differences between them. This kind of visualization is handy when you deal with numerous categories.
Here is a visual explanation of the elements of the confusion matrix.

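Before diving in, a minimal toy example may help make the [n, m] convention concrete (the labels and counts below are invented purely for illustration): scikit-learn puts the true class along the rows and the predicted class along the columns.
from sklearn.metrics import confusion_matrix

true_labels = ["Cat", "Cat", "Dog", "Dog", "Pig"]
predictions = ["Cat", "Dog", "Dog", "Dog", "Cat"]
cm = confusion_matrix(true_labels, predictions, labels=["Cat", "Dog", "Pig"])
print(cm)
# [[1 1 0]   <- one Cat correct, one Cat recognized as a Dog
#  [0 2 0]   <- both Dogs correct
#  [1 0 0]]  <- the Pig was recognized as a Cat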
Please remember that the data used to demonstrate confusion matrices is artificial and does not represent any real classification model.
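If you would like to follow along, you can synthesize similar lists yourself. The sketch below is only a hypothetical stand-in; the class names, class balance, and error rate are my assumptions, not the data behind the plots in this post.
import random

random.seed(42)  # make the runs reproducible
classes = ["None", "Cat", "Dog", "Pig"]  # assumed class names
# Imbalanced data with None as the majority class
true_labels = random.choices(classes, weights=[80, 10, 7, 3], k=2000)
# A classifier that is right about 90% of the time and guesses otherwise
predictions = [t if random.random() < 0.9 else random.choice(classes)
               for t in true_labels]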
Now, I will explain step by step how to generate such a confusion matrix using Python modules.
Python bare minimum
To create a confusion matrix with a heatmap, you need three modules (matplotlib, which we will also import directly later, is installed as a seaborn dependency):
pip install scikit-learn seaborn pandas
Assuming you have two lists of predictions and true labels, you need to do the following:
- Calculate the confusion matrix – confusion_matrix
- Convert the result into a data frame – pd.DataFrame
- Create the heatmap plot – sn.heatmap
- And finally, save the plot to a file – cfm_plot.figure.savefig
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sn

if __name__ == '__main__':
    # Predictions from your model and the corresponding annotated labels
    predictions = ["None", "Dog", "Cat", ...]
    true_labels = ["None", "Dog", "Dog", ...]
    # Rows are true classes, columns are predicted classes
    cm = confusion_matrix(true_labels, predictions)
    df_cfm = pd.DataFrame(cm)
    cfm_plot = sn.heatmap(df_cfm)
    cfm_plot.figure.savefig("data/confusion_matrix_v1.png")
Here is the output:

The output is not impressive at all. By default, much useful information is hidden, and several settings are not suitable for our data. Let’s improve the plot.
Labels
To display the labels, we first need to build a list of all of them. We can derive it from our predictions and true labels (first line). To improve readability and keep the same order between runs, we move the majority class (None) to the front and sort the remaining labels (second line). Without this, the order of the labels might differ every time we run the code, because sets in Python have no stable iteration order.
Next, we add labels=label_names to the confusion_matrix call and index=label_names, columns=label_names to the data frame constructor.
label_names = list(set(predictions + true_labels))  # all distinct labels
label_names = ["None"] + sorted([a for a in label_names if a != "None"])
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
cfm_plot = sn.heatmap(df_cfm)

We can see one issue with the labels: the labels on the Y axis are partially truncated. To fix this, we can enlarge the plot canvas using figsize, which requires importing matplotlib.pyplot.
import matplotlib.pyplot as plt

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))  # larger canvas so the labels fit
cfm_plot = sn.heatmap(df_cfm)

Values
In the next step, we will display a value in each cell. It is a convenient way to read the exact number of errors for every pair of categories. We will use the annot parameter of the heatmap method to display the values. It accepts a data frame with the same dimensions as our confusion matrix, so we can simply pass the same data frame, df_cfm, again.
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cfm_plot = sn.heatmap(df_cfm, annot=df_cfm)

We can see the exact values, but the plot has two issues. The first is that large numbers are displayed in scientific notation, and the other is the large number of zeros, which makes the plot hard to read.
To fix the number formatting, we change the default string formatting code used for annotations from .2g to an empty string using the fmt argument.
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cfm_plot = sn.heatmap(df_cfm, annot=df_cfm, fmt="")

To hide the zeros, we will copy the data frame and replace each 0 with an empty string.
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")  # note: pandas >= 2.1 prefers df_cfm.map
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="")
cfm_plot.figure.savefig("data/confusion_matrix_v6.png")

Scale and colors
The idea of a heatmap is to use colors to visually differentiate the values. In our case, only one cell stands out, the number of true positives for the None class, while the cells for the other classes (Dog and Pig) look almost the same. The issue is the range of our values, from 0 to 1413, where most values are close to 0. To make the values more distinguishable, we can change the color scale from linear to logarithmic. This is done with the norm parameter, setting it to LogNorm() from matplotlib.colors.
from matplotlib.colors import LogNorm

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm())

The plot looks much better with a logarithmic color scale. With the empty cells blanked out, it is easier to analyze the non-empty ones, as there is less distraction. Still, in some cases it might be tricky to follow the rows and columns. To fix this, we can separate the cells with vertical and horizontal lines using the linewidths and linecolor parameters.
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(),
                      linewidths=0.5, linecolor="grey")

The last step is to choose the preferred color palette. Seaborn has several ready-to-use palettes, which are presented here: https://seaborn.pydata.org/tutorial/color_palettes.html. To change the color palette, provide its name via the cmap parameter. Here is an example using the crest palette.
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(),
                      linewidths=0.5, linecolor="grey", cmap="crest")

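At this point the plot is complete. For reference, here is the whole script with all the steps combined; the label lists are elided with ... as in the snippets above, and the output file name is an arbitrary choice of mine.
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import pandas as pd
import seaborn as sn
from sklearn.metrics import confusion_matrix

if __name__ == '__main__':
    predictions = ["None", "Dog", "Cat", ...]  # your model's outputs
    true_labels = ["None", "Dog", "Dog", ...]  # the annotated labels
    # Majority class first, the rest sorted, for a stable order between runs
    label_names = list(set(predictions + true_labels))
    label_names = ["None"] + sorted([a for a in label_names if a != "None"])
    cm = confusion_matrix(true_labels, predictions, labels=label_names)
    df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
    # Blank out zeros so only actual confusions draw attention
    cell_value = df_cfm.applymap(lambda v: v if v else "")
    plt.figure(figsize=(10, 7))
    cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(),
                          linewidths=0.5, linecolor="grey", cmap="crest")
    cfm_plot.figure.savefig("data/confusion_matrix_final.png")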
Here are examples of some other palettes:



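Such comparison images can be generated in a loop. This is a small convenience sketch of mine, not part of the walkthrough above; the palette names are standard seaborn and matplotlib colormap names, chosen arbitrarily.
for palette in ["flare", "rocket", "viridis", "YlGnBu"]:
    plt.figure(figsize=(10, 7))
    p = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(),
                   linewidths=0.5, linecolor="grey", cmap=palette)
    p.figure.savefig(f"data/confusion_matrix_{palette}.png")
    plt.close(p.figure)  # free the figure before the next iteration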
Conclusions
Playing with colors and formatting might seem like a waste of time, as the numbers are what matter most. However, the right plot can significantly improve the readability and accessibility of our data, especially when we present it to a client who is not as familiar with it as we are. It is worth spending a few extra minutes to figure out whether there is a better way of presenting the raw data and the insights from the analysis.