
Heatmap for Confusion Matrix in Python

One image can be worth a thousand words.

Image by author

Introduction

A confusion matrix is a convenient way to present the types of mistakes a machine learning model makes. It is an N by N grid of numbers, where the value in the [n, m] cell is the number of examples annotated as the n-th class and recognized (predicted) as the m-th class. In this tutorial, I will focus on creating a confusion matrix and displaying it as a heatmap. The colors encode the size of each group, making it easy to notice similar counts or large differences between them. This kind of visualization is especially handy when you deal with numerous categories.
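
To make the [n, m] convention concrete, here is a small sketch using scikit-learn's confusion_matrix on a few hypothetical labels (not the data behind the figures in this article). Rows follow the true labels and columns follow the predictions, both in the order given by labels.

```python
from sklearn.metrics import confusion_matrix

# Hypothetical toy labels, for illustration only
true_labels = ["None", "Dog", "Dog", "Cat", "None", "Cat"]
predictions = ["None", "Dog", "Cat", "Cat", "Dog", "Cat"]
labels = ["None", "Dog", "Cat"]

# Row n = true class labels[n], column m = predicted class labels[m]
cm = confusion_matrix(true_labels, predictions, labels=labels)
print(cm)

# Cell [1, 2] counts examples annotated as "Dog" but recognized as "Cat"
print("Dog recognized as Cat:", cm[1, 2])
# The diagonal holds the correct predictions
print("Correct predictions:", cm.trace())
```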

Here is a visual explanation of the elements of the confusion matrix.

Image by author

Please remember that the data used to demonstrate confusion matrices is artificial and does not represent any real classification model.

Now, I will explain step by step how to generate such a confusion matrix using Python modules.


Python bare minimum

To create a confusion matrix with a heatmap, you need three packages (matplotlib, which seaborn builds on, is installed together with seaborn):

pip install scikit-learn seaborn pandas

Assuming you have two lists of predictions and true labels, you need to do the following:

  1. Calculate the confusion matrix – confusion_matrix
  2. Convert the matrix into a data frame – pd.DataFrame
  3. Create the heatmap plot – sn.heatmap
  4. And finally, save the plot to a file – cfm_plot.figure.savefig
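import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sn

if __name__ == '__main__':
    # Your own labels go here; both lists are truncated with "..."
    predictions = ["None", "Dog", "Cat", ...]
    true_labels = ["None", "Dog", "Dog", ...]

    # 1. Rows of the matrix are true labels, columns are predictions
    cm = confusion_matrix(true_labels, predictions)
    # 2. Wrap the NumPy array in a data frame
    df_cfm = pd.DataFrame(cm)
    # 3. Draw the heatmap; sn.heatmap returns a matplotlib Axes
    cfm_plot = sn.heatmap(df_cfm)
    # 4. Save the figure that holds the Axes (the data/ directory must exist)
    cfm_plot.figure.savefig("data/confusion_matrix_v1.png")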
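
The lists above are elided with "...", so the snippet will not run as written. Below is a minimal self-contained sketch of the same four steps; the labels are hypothetical stand-ins, and it assumes a headless Agg backend and a temporary output directory instead of data/.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render without a display, e.g. on a server
import pandas as pd
import seaborn as sn
from sklearn.metrics import confusion_matrix

# Hypothetical toy labels standing in for the elided lists
predictions = ["None", "Dog", "Cat", "None", "Pig", "Dog"]
true_labels = ["None", "Dog", "Dog", "None", "Pig", "Cat"]

# 1. Calculate the confusion matrix
cm = confusion_matrix(true_labels, predictions)
# 2. Convert it into a data frame
df_cfm = pd.DataFrame(cm)
# 3. Create the heatmap
cfm_plot = sn.heatmap(df_cfm)
# 4. Save the plot (to a temporary directory instead of data/)
out_path = os.path.join(tempfile.mkdtemp(), "confusion_matrix_v1.png")
cfm_plot.figure.savefig(out_path)
```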

Here is the output:

Image by author

The output is not impressive at all. By default, much useful information is hidden, and the default settings are not suited to our data. Let's improve the plot.


Labels

To display the class names, we first need a list of all labels. We build it from our predictions and true labels (first line). To improve readability and keep the same order between runs, we move the majority class (None) to the first position and sort the remaining labels (second line). Without this, the order of labels might be different every time we run the code, because the iteration order of a Python set of strings can change between interpreter runs.

Next, we pass labels=label_names to the confusion_matrix function and index=label_names, columns=label_names to the data frame constructor.

label_names = list(set(predictions + true_labels))
label_names = ["None"] + sorted([a for a in label_names if a != "None"])

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
cfm_plot = sn.heatmap(df_cfm)
Image by author
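
The label-ordering logic can be wrapped in a small function. The sketch below uses a hypothetical helper, ordered_labels, and hypothetical toy labels; it also skips the majority class if it does not occur in the data, which the two-line version above does not handle.

```python
import pandas as pd
from sklearn.metrics import confusion_matrix


def ordered_labels(predictions, true_labels, first="None"):
    """Hypothetical helper: put `first` up front, sort the rest alphabetically."""
    label_names = set(predictions + true_labels)
    rest = sorted(label for label in label_names if label != first)
    return [first] + rest if first in label_names else rest


# Hypothetical toy labels
predictions = ["Dog", "None", "Cat", "None", "Pig"]
true_labels = ["Dog", "None", "Dog", "None", "Cat"]

label_names = ordered_labels(predictions, true_labels)
cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
print(df_cfm)

# Rows are true labels and columns are predictions, both in label_names order
print(df_cfm.loc["Dog", "Cat"])
```

If labels is omitted, scikit-learn uses the sorted set of labels instead, which would place None between Dog and Pig rather than first.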

We can see one issue with the labels: the labels on the y-axis are partially truncated. To fix this, we can enlarge the figure canvas with the figsize argument of plt.figure (width and height in inches).

import matplotlib.pyplot as plt

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cfm_plot = sn.heatmap(df_cfm)
Image by author
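
plt.figure(figsize=...) works because seaborn draws into the current axes. An alternative, sketched below with a hypothetical matrix and deliberately long class names, creates the figure and axes explicitly with plt.subplots and passes the axes through the ax parameter of sn.heatmap. The extra fig.tight_layout() call is a matplotlib function not used in the original; it adjusts the margins so long tick labels are less likely to be cut off.

```python
import matplotlib
matplotlib.use("Agg")  # headless rendering
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn

# Hypothetical 3x3 matrix with long class names that tend to get cut off
names = ["None", "German shepherd", "Maine coon"]
df_cfm = pd.DataFrame([[50, 2, 1], [3, 20, 0], [0, 1, 15]], index=names, columns=names)

# Create the figure and axes explicitly, then draw the heatmap into that axes
fig, ax = plt.subplots(figsize=(10, 7))
sn.heatmap(df_cfm, ax=ax)
# Shrink the plot area so the tick labels fit inside the canvas
fig.tight_layout()
```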

Values

In the next step, we will display a value in each cell. It is a convenient way to see the exact number of errors for each pair of categories. To display the values, we use the annot parameter of the heatmap function. It accepts a rectangular data set with the same dimensions as our confusion matrix, so we can simply pass the same data frame again, that is, df_cfm.

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cfm_plot = sn.heatmap(df_cfm, annot=df_cfm)
Image by author
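
annot also accepts True, in which case seaborn annotates each cell with the plotted value itself. The sketch below, using hypothetical counts, draws both variants side by side and reads the annotation texts back to check that they match; fmt="d" (integer formatting) is an assumption chosen for the annot=True variant.

```python
import matplotlib
matplotlib.use("Agg")  # headless rendering
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn

names = ["None", "Dog", "Cat"]
# Hypothetical counts
df_cfm = pd.DataFrame([[120, 4, 0], [7, 35, 2], [0, 3, 28]], index=names, columns=names)

fig, (ax_data, ax_true) = plt.subplots(1, 2, figsize=(12, 5))
# Passing the data frame itself as annot ...
sn.heatmap(df_cfm, annot=df_cfm, fmt="", ax=ax_data)
# ... writes the same numbers as annot=True, which annotates with the plotted data
sn.heatmap(df_cfm, annot=True, fmt="d", ax=ax_true)

# One text label per cell, read back from each axes
data_texts = [t.get_text() for t in ax_data.texts]
true_texts = [t.get_text() for t in ax_true.texts]
print(data_texts == true_texts)
```

Passing a separate data frame as annot is what makes the next steps possible, because the annotations can then differ from the values that drive the colors.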

We can see the exact values, but the plot has two issues. First, large numbers are displayed in scientific notation (for example, 1413 becomes 1.4e+03). Second, the many zeros make the plot hard to read.

To fix how the numbers are displayed, we change fmt, the string formatting code used for the annotations, from its default .2g to an empty string.

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cfm_plot = sn.heatmap(df_cfm, annot=df_cfm, fmt="")
Image by author
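
The effect of fmt can be reproduced with plain Python string formatting. seaborn builds each annotation roughly as "{:" + fmt + "}".format(value), so the sketch below applies the same format codes to 1413, the largest count mentioned in this article.

```python
value = 1413  # the largest count mentioned in the text

# Default fmt=".2g": two significant digits, switching to scientific notation
print("{:.2g}".format(value))
# fmt="": plain formatting, so integers are shown in full
print("{:}".format(value))
# fmt="d": explicit integer formatting, an alternative for purely numeric annotations
print("{:d}".format(value))
# The empty format also handles empty strings, which matters once zeros become ""
print(repr("{:}".format("")))
```

Note that "{:d}".format("") raises a ValueError, so fmt="d" stops working as soon as the annotation table contains empty strings, while fmt="" keeps working.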

To hide the zeros, we create a copy of the data frame in which each 0 is replaced with an empty string and pass it as annot.

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
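# Replace zeros with empty strings (applymap is deprecated since pandas 2.1 in favor of df_cfm.map)
cell_value = df_cfm.applymap(lambda v: v if v else "")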
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="")
cfm_plot.figure.savefig("data/confusion_matrix_v6.png")
Image by author
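
DataFrame.applymap still works but has been deprecated since pandas 2.1, where the same operation is called DataFrame.map. The sketch below, with hypothetical counts, picks whichever method the installed pandas provides and checks that the original data frame is left untouched.

```python
import pandas as pd

names = ["None", "Dog"]
# Hypothetical counts with a zero off the diagonal
df_cfm = pd.DataFrame([[1413, 0], [6, 42]], index=names, columns=names)

# pandas >= 2.1 renamed DataFrame.applymap to DataFrame.map;
# pick whichever this pandas version provides
elementwise = df_cfm.map if hasattr(pd.DataFrame, "map") else df_cfm.applymap
# Keep non-zero counts, replace zeros with an empty string
cell_value = elementwise(lambda v: v if v else "")
print(cell_value)
```

seaborn's heatmap also has a mask parameter (for example, mask=(df_cfm == 0)) that hides whole cells, color included, rather than only their annotations.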

Scale and colors

The idea of a heatmap is to use colors to visually distinguish the values. In our case, only a few cells stand out: the number of true positives for the None class and, to a lesser extent, for the other classes (Dog and Pig), while the remaining cells all look the same. The issue is the range of our values, from 0 to 1413, where most values are close to 0. To make the values easier to distinguish, we can change the color scale from linear to logarithmic by setting the norm parameter to LogNorm() from matplotlib.colors.

from matplotlib.colors import LogNorm

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm())
Image by author
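
To see why the logarithmic scale helps, the sketch below compares where a few counts land on the 0 to 1 color scale under matplotlib's linear Normalize and under LogNorm. It assumes a range of 1 to 1413. Zero counts cannot be placed on a logarithmic scale at all; LogNorm masks non-positive values, which fits with those cells appearing empty in the plot above.

```python
from matplotlib.colors import LogNorm, Normalize

vmin, vmax = 1, 1413  # 1413 is the largest count mentioned in the text

linear = Normalize(vmin=vmin, vmax=vmax)
log = LogNorm(vmin=vmin, vmax=vmax)

# Position of a few counts on the 0-1 color scale under each normalization
for count in (1, 10, 100, 1413):
    print(f"{count:>5}: linear={float(linear(count)):.3f}  log={float(log(count)):.3f}")
```

Under the linear scale, a count of 100 still sits below 0.1, so it gets almost the same color as 1. Under the logarithmic scale, it lands near the middle of the palette.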

The plot looks much better with a logarithmic color scale. With the zero cells left empty, it is easier to analyze the non-empty cells, as there is less distraction. In some cases, it might still be tricky to follow the rows and columns. To fix this, we can draw lines between the cells using the linewidths and linecolor parameters.

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(), 
                      linewidths=0.5, linecolor="grey")
Image by author

The last step is to choose the preferred color palette. Seaborn has several ready-to-use palettes, which are presented here: https://seaborn.pydata.org/tutorial/color_palettes.html. To change the palette, pass its name to the cmap parameter. Here is an example using the crest palette.

cm = confusion_matrix(true_labels, predictions, labels=label_names)
df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)
plt.figure(figsize=(10, 7))
cell_value = df_cfm.applymap(lambda v: v if v else "")
cfm_plot = sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(),
                      linewidths=0.5, linecolor="grey", cmap="crest")
Image by author: Palette crest
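
Putting all the steps together, the sketch below bundles them into one function and renders the same toy matrix with each palette shown in this article. plot_confusion_heatmap is a hypothetical helper, not part of any library, and the labels are made up. It assumes a recent seaborn version that ships the crest palette and accepts norm without explicit vmin/vmax.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless rendering
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sn
from matplotlib.colors import LogNorm
from sklearn.metrics import confusion_matrix


def plot_confusion_heatmap(true_labels, predictions, path, cmap="crest", first="None"):
    """Hypothetical helper bundling the steps of this article into one call."""
    # Stable label order: the majority class first, the rest sorted
    names = set(predictions + true_labels)
    label_names = ([first] if first in names else []) + sorted(n for n in names if n != first)

    cm = confusion_matrix(true_labels, predictions, labels=label_names)
    df_cfm = pd.DataFrame(cm, index=label_names, columns=label_names)

    # Hide zeros in the annotations (DataFrame.map replaces applymap in pandas >= 2.1)
    elementwise = df_cfm.map if hasattr(pd.DataFrame, "map") else df_cfm.applymap
    cell_value = elementwise(lambda v: v if v else "")

    fig, ax = plt.subplots(figsize=(10, 7))
    sn.heatmap(df_cfm, annot=cell_value, fmt="", norm=LogNorm(),
               linewidths=0.5, linecolor="grey", cmap=cmap, ax=ax)
    fig.savefig(path)
    plt.close(fig)  # free memory when generating many plots
    return df_cfm


# Hypothetical toy labels, rendered with each palette from this article
true_labels = ["None"] * 8 + ["Dog", "Dog", "Cat", "Pig"]
predictions = ["None"] * 7 + ["Dog", "Dog", "Cat", "Cat", "Pig"]
out_dir = tempfile.mkdtemp()
for palette in ["crest", "viridis", "magma", "rocket_r"]:
    df = plot_confusion_heatmap(true_labels, predictions,
                                os.path.join(out_dir, f"confusion_matrix_{palette}.png"),
                                cmap=palette)
```

Closing each figure inside the function keeps memory use flat when many palettes or models are compared in a loop.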

Here are examples of some other palettes:

Image by author: Palette viridis
Image by author: Palette magma
Image by author: Palette rocket_r

Conclusions

Playing with colors and formatting might seem like a waste of time, as the numbers are what matters most. However, the right plot can significantly improve the readability and accessibility of our data, especially when we present it to a client who is not as familiar with it as we are. It is worth spending a few extra minutes to find a better way of presenting the raw data and the insights from the analysis.



