
Explainable AI (XAI) with Class Maps

Introducing a novel visual tool for explaining the results of classification algorithms, with examples in R and Python.

Photo by Charles Deluvio on Unsplash

The need for model explanations

Classification algorithms aim to identify which group each observation in a set belongs to. A machine learning practitioner typically builds multiple models and selects as the final classifier the one that optimizes a set of accuracy metrics on a held-out test set. Sometimes, practitioners and stakeholders want more from a classification model than just predictions. They may wish to know the reasons behind a classifier’s decisions, especially when it is built for high-stakes applications. For instance, consider a medical setting in which a classifier determines a patient to be at high risk of developing an illness. If medical experts can learn the contributing factors behind this prediction, they can use this information to help determine suitable treatments.

Some models, such as single decision trees, are transparent, meaning that they show the mechanism for how they make decisions. More complex models, however, tend to be the opposite – they are often referred to as "black boxes", as they provide no explanation for how they arrive at their decisions. Unfortunately, opting for transparent models over black boxes does not always solve the explainability problem. The relationship between a set of observations and its labels is often too complex for a simple model to suffice; transparency can come at the cost of accuracy [1].

The increasing use of black-box models in high-stakes applications, combined with the need for explanations, has led to the development of Explainable AI (XAI), a set of methods that help humans understand the outputs of machine learning models. Explainability is a crucial part of the responsible development and use of AI.

Visualizing the perspective of the classifier

When we cannot derive an explanation directly from a classifier, we can use post-hoc interpretations, a category of XAI methods that explain a model’s predictions without necessarily revealing how it works [3]. Visualization is a powerful way to explain complex decisions in a simple, human-friendly manner. Furthermore, visualization tools can be easily placed within a Machine Learning pipeline, allowing for automatic explanations of training sets, test sets, or special subsets (for example, visualizing a classifier’s performance across sensitive groups).

Raymaekers et al. (2021) introduce the class map, a visual tool that relates the probabilities output by a classifier to the distribution of the observations it makes predictions for. Class maps show us observations from the perspective of a classifier, shedding light on the relationship between a classifier’s decision rules and the structure of the data. Practitioners can use class maps to determine whether a classification model makes errors only on edge cases, or whether it misses observations that humans can easily identify. We can also get an idea of whether a classifier performs consistently across sub-groups of observations, indicating its level of fairness.

Essentially, class maps help us easily determine whether we can trust a model before deploying it.

The best way to understand a class map is with a simple example. (Note: the code for this and all of the following examples is available at the end of this post.)

Example: Classifying the Iris Dataset

Consider the task of classifying the benchmark iris dataset, which contains four attributes of 150 observations of iris flowers, evenly divided into three species: setosa, versicolor and virginica. To get an idea of the structure of the data, we first show a Principal Component Analysis (PCA) view, colored by class label:

A Principal Component Analysis (PCA) view of the iris dataset, colored by ground-truth class label. The "setosa" irises appear as a distinct group, while "versicolor" and "virginica" have some overlap, making them harder to classify | Image by author

Observations from the setosa species exist as a distinct group, while versicolor and virginica flowers have more in common. A classifier will likely have difficulty classifying observations from the latter two classes.

As the data appears to be linearly separable, we opt to fit a Linear Discriminant Analysis (LDA) classifier (note: we do not do a train/test split in this example, but fit the classifier to the entire dataset). We obtain a fitted classifier with 98% accuracy. We could stop there, or we could ask: where does the classifier make mistakes, and why?
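For readers who want to follow along in Python, a minimal sketch of this setup with scikit-learn might look as follows (the class map itself is implemented in the R classmap package, as described in the code section at the end; the snippet below only reproduces the model-fitting step and is not that package’s API):

```python
# Minimal sketch: fit an LDA classifier to the full iris dataset (no train/test split),
# mirroring the setup described above. Uses scikit-learn rather than the R classmap package.
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X, y = load_iris(return_X_y=True)

lda = LinearDiscriminantAnalysis()
lda.fit(X, y)

print(f"Training accuracy: {lda.score(X, y):.2%}")  # roughly 98% on iris
```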

Let’s look at a class map for the observations with ground-truth labels "versicolor" and "virginica":

Class maps of observations with ground-truth labels "versicolor" (left) and "virginica" (right). The class maps show the classifier’s perspective of the data. Each observation is a point, colored by the classifier’s predicted class. The y-axis shows the probability that an observation belongs to an alternative class. The x-axis is farness, a measure of how far an object lies from its class center. Here, the misclassified observations are far from their class, indicating that they are challenging examples to classify | Image by author

The class map visualizes the observations belonging to a single ground-truth class as a scatterplot. Each observation is colored by its predicted class. For each observation, the class map plots the probability that the observation belongs to an alternative class (on the y-axis) against a measure of how far it is from its own class (on the x-axis). These quantities are called Probability of Alternative Classification (PAC) and farness, respectively, and are derived from the trained classifier. An interested reader can find the mathematical definitions of each quantity in Raymaekers et al. (2021).
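To make these two axes a bit more concrete, the rough Python sketch below computes simplified stand-ins for PAC and farness from the LDA fit above and draws a bare-bones class map. These are not the exact definitions from Raymaekers et al. (2021), only an illustration of the idea:

```python
# Continuing the LDA sketch above: simplified stand-ins for PAC and farness.
# These are NOT the exact quantities defined in Raymaekers et al. (2021), only rough analogues.
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import mahalanobis
from scipy.stats import rankdata

probs = lda.predict_proba(X)                        # posterior probabilities per class

# PAC-like quantity: weight of the best *alternative* class relative to the ground-truth class
# (0 = certain of the true class, 1 = certain of some other class).
own = probs[np.arange(len(y)), y]
masked = probs.copy()
masked[np.arange(len(y)), y] = -np.inf
alt = masked.max(axis=1)
pac = alt / (own + alt)

# Farness stand-in: within-class percentile of the Mahalanobis distance to the class center.
farness = np.empty(len(y))
for c in np.unique(y):
    idx = np.where(y == c)[0]
    Xc = X[idx]
    center = Xc.mean(axis=0)
    cov_inv = np.linalg.inv(np.cov(Xc, rowvar=False))
    d = np.array([mahalanobis(x, center, cov_inv) for x in Xc])
    farness[idx] = rankdata(d) / len(d)

# A bare-bones class map for ground-truth class "versicolor" (label 1), colored by prediction.
mask = (y == 1)
plt.scatter(farness[mask], pac[mask], c=lda.predict(X)[mask], cmap="tab10", s=20)
plt.xlabel("farness (simplified)")
plt.ylabel("PAC (simplified)")
plt.title('Class map sketch, ground-truth class "versicolor"')
plt.show()
```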

Looking horizontally, the class map of observations belonging to class "versicolor" (left) shows that the classifier sees most observations as typical of their class. A few observations are deemed "far" from their class. A roughly similar picture emerges for the virginica flowers, but the distribution of observations in this class is more spread out – there are more observations with high farness. Looking vertically at each class map, we see the conditional probability that each observation belongs to its class, derived from the trained classifier. Most observations are seen as highly likely to belong to their ground-truth class. However, three observations in total are misclassified. Given the PCA view of the data, it is not surprising that the classifier confuses some versicolor and virginica flowers.

We can also see how the classifier views observations with ground-truth label "setosa":

Class map of class "setosa". The classifier is certain that all observations belong in the same class | Image by author

Looking vertically, we see that the classifier is strongly certain that all observations of this class belong to their true label. This is true regardless of where an observation lies with respect to its class center.

While this example is much simpler than a typical classification problem, it shows the idea behind a class map. With this tool, we are able to see how the data are viewed by a specific classifier. We understand that the "setosa" class is easy to classify, while the other two classes provide more challenge. The classifier makes mistakes on observations that are atypical of their own class.

Both quantities plotted in a class map are derived from a trained classifier, making this method model-dependent. Currently, class maps are available for the following classifiers: Discriminant Analysis, K-nearest neighbors, Support Vector Machines, Decision Trees, Random Forest, and Neural networks.

But what if we want explanations of other classification models?

The localized class map: a model-agnostic class map

A second kind of class map is the localized class map, a model-agnostic alternative introduced in Kawa (2021) [2]. The y-axis, which shows the Probability of Alternative Classification (PAC), remains unchanged. The x-axis of the localized class map displays localized farness, a measure that uses local properties of the data to assess where an object lies in the data space. This gives a picture of how the classifier performs with respect to the local and global structure of the data.

Localized farness measures how far an observation lies from its class by looking within its local neighborhood. It is based on a method presented in Napierala and Stefanowski (2015) [4]. For a k-neighborhood of an observation i, localized farness computes a weighted ratio of neighbors from the same ground-truth class. If i belongs to a neighborhood where nearly all neighbors are from the same class, it has a low localized farness; essentially, it lives close to its class members. If i lives in a neighborhood where most of its neighbors are from a different class, then it is far from its own class. In the definition (which can be found in this blog post), more weight is given to closer neighbors. The user determines the appropriate k, the size of the neighborhood, with the default set to k=10.
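As a rough illustration of the idea (not the exact definition, which is given in Kawa (2021) and the blog post linked above), localized farness can be sketched in Python as a distance-weighted share of different-class neighbors:

```python
# Rough sketch of localized farness: the distance-weighted fraction of an observation's
# k nearest neighbors that carry a *different* ground-truth label.
# The inverse-distance weights below are an illustrative choice, not the exact weighting
# scheme used by the localized class map.
import numpy as np
from sklearn.neighbors import NearestNeighbors

def localized_farness(X, y, k=10):
    y = np.asarray(y)
    nn = NearestNeighbors(n_neighbors=k + 1).fit(X)
    dist, idx = nn.kneighbors(X)
    dist, idx = dist[:, 1:], idx[:, 1:]             # drop each observation itself
    w = 1.0 / (dist + 1e-12)                        # closer neighbors get more weight
    w /= w.sum(axis=1, keepdims=True)
    different = (y[idx] != y[:, None]).astype(float)
    return (w * different).sum(axis=1)              # 0 = deep inside its own class, 1 = surrounded by other classes
```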

By replacing farness with localized farness, we can make class maps for any classifier (if the data are not numeric, we can use Gower’s distance to find its nearest neighbors). Furthermore, the resulting view of the distribution of the data does not depend on the classifier. So, we can use this version of the class map to get an idea of how difficult the data is to classify, and compare different classifiers on the same data.

As with the original class map, we explain the localized class map with a benchmark example. More detailed explanations and examples can be found here.

Example: Classifying Handwritten Digits

A benchmark problem in machine learning is the classification of the MNIST handwritten digits dataset. We train a Multi-layer Perceptron (MLP) on a training set of 60,000 handwritten digits, obtaining 96.11% accuracy on a held-out test set of 10,000 digits. To get an idea of the performance of this classifier, we can look at a confusion matrix:

Figure 1: Confusion matrix of an MLP classifier reflecting test-set performance by class | Image by author

Reading the matrix horizontally, we can see how the classifier performs for each label in the test set (digits 0–9). For the images with true label "3", we see that the classifier assigns most images to their true class. However, it incorrectly assigns 12 images to label "5". Looking at observations with true label "9", we see that the classifier most often confuses "9" with "4".
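A rough scikit-learn sketch of this setup is shown below; the network architecture and training settings are illustrative guesses rather than the exact configuration used for the results above, so the accuracy will not match 96.11% exactly:

```python
# Sketch: train an MLP on MNIST and inspect its test-set confusion matrix.
# The hidden-layer size and number of epochs are illustrative, not the settings used in this post.
from sklearn.datasets import fetch_openml
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, confusion_matrix

X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
X = X / 255.0                                       # scale pixel values to [0, 1]
y = y.astype(int)
X_train, X_test = X[:60000], X[60000:]              # the standard 60,000 / 10,000 MNIST split
y_train, y_test = y[:60000], y[60000:]

mlp = MLPClassifier(hidden_layer_sizes=(128,), max_iter=20, random_state=0)
mlp.fit(X_train, y_train)

y_pred = mlp.predict(X_test)
print("Test accuracy:", accuracy_score(y_test, y_pred))
print(confusion_matrix(y_test, y_pred))             # rows = true digit, columns = predicted digit
```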

For some applications, a confusion matrix gives a good enough picture of the classification. However, sometimes we want to know more. For instance,

  • For which examples does a classifier make mistakes?
  • Does the classifier make errors that humans also make? Or does it make mistakes on relatively easy examples?
  • Are some classes more difficult to classify than others? In other words, do some classes contain many challenging examples and/or contamination?

As in the first example, we can use a class map to answer these questions from the perspective of the classifier. To get a model-agnostic view, we can plot a localized class map for each digit in the test set. Below is the localized class map for observations with ground-truth label "3":

Localized class map of the test set observations with ground truth label "3". The classifier performs better for this class than its overall performance. Most observations that are misclassified have a high localized farness, meaning that they live in neighborhoods that are not homogeneous | Image by author

As before, each observation is visualized as a point, colored by its predicted label. The y-axis remains the Probability of Alternative Classification (PAC), which reflects how likely the classifier is to assign an observation to an alternative class. Here, we see that most observations have PAC = 0; the classifier is quite confident that they belong to their own class. However, quite a few points have a higher PAC value, reflecting the classifier’s uncertainty about their class membership.

On the x-axis we have localized farness, which measures how far an observation lies from its own class with respect to a local neighborhood (here, we use the default value of k to compute localized farness). Most objects have a relatively low localized farness, meaning that they are typical examples of a digit "3". Some objects, however, deviate from their class. We have identified three examples in the localized class map and plotted them below:

Examples of observations with ground truth label "3" that have a high localized farness. They are atypical examples from their class | Image by author

Viewing from left to right, the first two images have a PAC close to 0 – the classifier is quite sure that they belong to digit "3", although they look different from the average digit. The last image has PAC = 1: the classifier is strongly certain that this image is of digit "7". It is sensible that the classifier believes this; the image is a challenging example for a human to classify.
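Putting the earlier sketches together, a bare-bones version of such a localized class map can be drawn roughly as follows (again using the simplified PAC and the localized_farness() sketch from above, not the implementation that produced the figures in this post):

```python
# Continuing the MNIST sketch: a bare-bones localized class map for ground-truth digit 3,
# using the simplified PAC and the localized_farness() sketch defined earlier.
import numpy as np
import matplotlib.pyplot as plt

probs = mlp.predict_proba(X_test)
own = probs[np.arange(len(y_test)), y_test]
masked = probs.copy()
masked[np.arange(len(y_test)), y_test] = -np.inf
alt = masked.max(axis=1)
pac = alt / (own + alt)                             # simplified PAC, as in the iris sketch

lf = localized_farness(X_test, y_test, k=10)        # k-NN on raw pixels; other representations are possible

digit = 3
m = (y_test == digit)
sc = plt.scatter(lf[m], pac[m], c=y_pred[m], cmap="tab10", s=10)
plt.xlabel("localized farness")
plt.ylabel("PAC")
plt.title(f'Localized class map sketch, ground-truth digit "{digit}"')
plt.colorbar(sc, label="predicted digit")
plt.show()
```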

We can also plot a localized class map from the same classifier, this time for examples with ground-truth label "9", using k=40 nearest neighbors to compute localized farness:

Localized class map for test set observations with ground-truth label "9". The classifier performs poorly even on observations that lie in relatively homogeneous neighborhoods. It has the most confusion between classes "9" and "4" | Image by author

For this class, the classifier performs relatively poorly. Indeed, we see quite a few points misclassified as digit "4". The classifier seems to make mistakes irrespective of an observation’s localized farness: it misses some easy examples, and correctly identifies some challenging ones. This indicates that the classifier does not make its decisions in line with observations’ local distances. Had we used, for instance, a k-nearest neighbor classifier, we would instead find a strong correlation between localized farness and PAC.
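As a quick check of that last claim, one could swap in a k-nearest neighbor classifier and compare its PAC values with localized farness; since both depend on the local label composition, a clearly positive correlation is to be expected. A hedged sketch, reusing the objects from the earlier snippets:

```python
# Optional check of the claim above, reusing X_train, y_train, X_test, y_test, and lf
# from the earlier sketches. Note: predict_proba on 10,000 test digits can be slow.
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=10).fit(X_train, y_train)
knn_probs = knn.predict_proba(X_test)

knn_own = knn_probs[np.arange(len(y_test)), y_test]
knn_masked = knn_probs.copy()
knn_masked[np.arange(len(y_test)), y_test] = -np.inf
knn_alt = knn_masked.max(axis=1)
knn_pac = knn_alt / (knn_own + knn_alt)

print(np.corrcoef(knn_pac, lf)[0, 1])               # expected: a clearly positive correlation
```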

As we did previously, we can look at examples of observations with localized farness = 0 (they are typical examples of their class):

Examples of observations with ground truth label "9" that have zero localized farness. They are typical examples from their class | Image by author

and those with relatively high localized farness, which lie in neighborhoods close to other digits:

Examples of observations with ground truth label "9" that have high localized farness. They are atypical examples from their class, and most are misclassified | Image by author

We can see why each example lives in a heterogeneous neighborhood. In the top row, the examples resemble the digit they are predicted to be; the top-right example could truly be the number "4". The bottom row is easier for a human to identify, yet still shows examples that are atypical of the standard digit "9".

Conclusion

Machine learning practitioners and stakeholders often require explanations of the results of classification algorithms. We introduce the class map, a novel visualization tool that helps end-users understand how a classifier makes its decisions. We then introduce the localized class map, a model-agnostic extension of the class map. Using benchmark examples, we have seen that both the class map and the localized class map help us answer several questions about how a classifier works, such as:

  • Does the classifier make decisions in line with an object’s difficulty? Or does it make mistakes on typical examples?
  • Is the data relatively homogeneous, or are there a lot of challenging examples?
  • Where does a classifier have confusion?
  • Why does a classifier make a mistake on a particular observation?

While this post focuses on sharing the basic idea of the class map and its localized counterpart, the tool has uses beyond visualizing a single classification algorithm. It can also be used to compare different classifiers on the same data, or to assess the fairness of a classifier. For a more detailed understanding of the class map, we refer to the original paper [5] and its associated R package. For the localized class map, we refer to [2] and its associated implementation.

Code Examples: Class Maps in R and Python

Both versions of the class map are open-source and easy to use. The class map is implemented in the R package classmap, available on CRAN.

The localized class map is implemented in Python and in R, and both implementations are built with popular machine learning frameworks in mind. In Python, the localized class map can be plotted for most scikit-learn classifiers. In R, the localized class map is compatible with caret. Below, we present reproducible examples in each language.

R Example: Iris – class map

Python Example: MNIST – localized class map

In R, the localized class map is implemented to be compatible with caret, a machine learning framework that contains a large number of classification models. Below, we show part of an example with the Iris dataset (full example is here).

R Example: Iris – localized class map

Thanks for reading!

References

[1] Adadi, A. and Berrada, M., Peeking Inside the Black-Box: A Survey on Explainable Artificial Intelligence (XAI), 2018, IEEE Access

[2] Kawa, N. Visualizing Classification Results, 2021, Github

[3] Lipton, Z. C., The mythos of model interpretability: In machine learning, the concept of interpretability is both important and slippery, 2018, arXiv

[4] Napierala, K. and Stefanowski, J., Types of minority class examples and their influence on learning classifiers from imbalanced data, 2015, Journal of Intelligent Information Systems

[5] Raymaekers, J., Rousseeuw, P. J., and Hubert, M., Class maps for visualizing classification results, 2021, Technometrics

