Explainable AI (XAI) with SHAP - Multi-Class Classification Problem

A practical guide to XAI analysis with SHAP for a multi-class classification problem

Idit Cohen
Towards Data Science


Image by Pixabay Sager from Pixabay

Model explainability has become a basic part of the machine learning pipeline. Keeping a machine learning model as a “black box” is no longer an option. Luckily, analytical tools such as lime, ExplainerDashboard, Shapash, Dalex, and others are evolving rapidly and becoming more popular. In a previous post, we explained how to use SHAP for a regression problem. This guide provides a practical example of how to use the open-source Python package SHAP for XAI analysis of a multi-class classification problem, how to interpret its output, and how to use it to improve the model.

SHAP (SHapley Additive exPlanations) by Lundberg and Lee (2017) is a method to explain individual predictions, based on the game-theoretically optimal Shapley values [1]. Calculating exact Shapley values to get feature contributions is computationally expensive, because the number of feature coalitions to evaluate grows exponentially with the number of features. Two methods make the computation of SHAP values more efficient: KernelSHAP and TreeSHAP (for tree-based models only).
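To make the cost of the exact calculation concrete, here is a minimal sketch that computes Shapley values by brute force for a tiny, made-up "game". The payoff function `toy_value` and the player names are assumptions invented for illustration; they are not part of the SHAP package or of the firewall dataset.

```python
from itertools import combinations
from math import factorial


def shapley_values(players, value):
    """Exact Shapley values by enumerating every coalition of the other players."""
    n = len(players)
    phi = {}
    for p in players:
        others = [q for q in players if q != p]
        total = 0.0
        for size in range(n):
            # Weight of a coalition of this size in the Shapley formula.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                # Marginal contribution of p when it joins coalition s.
                total += weight * (value(s | {p}) - value(s))
        phi[p] = total
    return phi


def toy_value(coalition):
    """Hypothetical payoff: fixed amount per player, plus a bonus when a and b cooperate."""
    base = {"a": 1.0, "b": 2.0, "c": 3.0}
    bonus = 2.0 if {"a", "b"} <= coalition else 0.0
    return sum(base[p] for p in coalition) + bonus


phi = shapley_values(["a", "b", "c"], toy_value)
print(phi)
```

With three players the loop visits only a handful of coalitions, but every extra player doubles that number, which is why approximations or model-specific algorithms are needed for real models. Note also that the values sum to the payoff of the full coalition: the bonus of 2 is split evenly between a and b.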

SHAP provides global and local interpretation methods based on aggregations of Shapley values. In this guide we will use the Internet Firewall Data Set from Kaggle [2] to demonstrate some of the SHAP output plots for a multiclass classification problem.

import pandas as pd

# load the csv file as a data frame
df = pd.read_csv('log2.csv')
y = df.Action.copy()
X = df.drop('Action', axis=1)

Create the model and fit it as usual.

from sklearn import model_selection
from sklearn.ensemble import RandomForestClassifier

X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, random_state=0)
cls = RandomForestClassifier(max_depth=2, random_state=0)
cls.fit(X_train, y_train)

Now, to get a basic impression of the model, I recommend viewing the feature importance and the confusion matrix. As a baseline for the feature importance, I used scikit-learn's feature_importances_, which is based on the impurity decrease within each tree [3].

import numpy as np
import matplotlib.pyplot as plt

importances = cls.feature_importances_
indices = np.argsort(importances)
# use X.columns (not df.columns, which still contains 'Action') so labels match the model's inputs
features = X.columns
plt.title('Feature Importances')
plt.barh(range(len(indices)), importances[indices], color='g', align='center')
plt.yticks(range(len(indices)), [features[i] for i in indices])
plt.xlabel('Relative Importance')
plt.show()
Image by Author

Later we can compare these results to the feature importance calculated by the Shapley values.

Confusion matrix

A confusion matrix is a way to visualize the performance of a model. More importantly, it shows exactly where the model fails.

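from sklearn.metrics import ConfusionMatrixDisplay  # replaces plot_confusion_matrix, removed in scikit-learn 1.2
class_names = list(cls.classes_)  # alphabetical order, matching the model's class indices
disp = ConfusionMatrixDisplay.from_estimator(cls, X_test, y_test, display_labels=class_names, cmap=plt.cm.Blues, xticks_rotation='vertical')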
Image by Author

The model could not detect any instance of the class reset-both. The reason is an imbalanced dataset that provides very few reset-both examples to learn from.

y.value_counts()
allow 37640
deny 14987
drop 12851
reset-both 54
Name: Action, dtype: int64
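The way a rare class disappears from the predictions can be read directly from the rows of a confusion matrix. The following is a minimal sketch in plain Python; the toy labels are hypothetical and only mimic the imbalance above, they are not the results of the model in this post, and the helper names `confusion_counts` and `per_class_recall` are made up for illustration.

```python
from collections import Counter


def confusion_counts(y_true, y_pred, labels):
    """Rows are true classes, columns are predicted classes."""
    pairs = Counter(zip(y_true, y_pred))
    return [[pairs[(t, p)] for p in labels] for t in labels]


def per_class_recall(matrix, labels):
    """Recall = diagonal entry divided by the number of true instances in that row."""
    return {
        label: (row[i] / sum(row) if sum(row) else 0.0)
        for i, (label, row) in enumerate(zip(labels, matrix))
    }


labels = ["allow", "deny", "drop", "reset-both"]
# Hypothetical toy labels (12 instances, only one of them reset-both).
y_true = ["allow"] * 6 + ["deny"] * 3 + ["drop"] * 2 + ["reset-both"]
y_pred = ["allow"] * 5 + ["deny"] + ["deny"] * 2 + ["allow"] + ["drop"] * 2 + ["deny"]

matrix = confusion_counts(y_true, y_pred, labels)
recall = per_class_recall(matrix, labels)
for label in labels:
    print(f"{label:>10}: recall = {recall[label]:.2f}")
```

In this toy data the single reset-both instance is predicted as deny, so its recall is zero even though the overall accuracy still looks reasonable.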

SHAP Summary Plot

SHAP values of a model’s output explain how features impact the output of the model.

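import shap

# compute SHAP values
# (the indexing used below assumes a list with one array per class, as older SHAP releases return)
explainer = shap.TreeExplainer(cls)
shap_values = explainer.shap_values(X)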

Now we can plot relevant plots that will help us analyze the model.

shap.summary_plot(shap_values, X.values, plot_type="bar", class_names=class_names, feature_names=X.columns)

In this plot, the impact of a feature on the classes is stacked to create the feature importance plot. Thus, if you created features in order to differentiate a particular class from the rest, that is the plot where you can see it. In other words, the summary plot for multiclass classification can show you what the machine managed to learn from the features.
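As far as I understand the bar plot, each colored segment is the mean absolute SHAP value of a feature for one class, and the features are ranked by the total length of their stacked bar. The sketch below reproduces that aggregation on a tiny, hypothetical set of SHAP values using plain Python lists; the function name `mean_abs_by_class` and the numbers are assumptions made for illustration.

```python
def mean_abs_by_class(shap_by_class):
    """shap_by_class[k][i][j] is the SHAP value of feature j for sample i and class k.

    Returns per_class[k][j], the mean absolute SHAP value over all samples.
    """
    result = []
    for class_values in shap_by_class:
        n_samples = len(class_values)
        n_features = len(class_values[0])
        result.append([
            sum(abs(row[j]) for row in class_values) / n_samples
            for j in range(n_features)
        ])
    return result


# Hypothetical SHAP values: 2 classes, 2 samples, 3 features.
toy = [
    [[0.5, -0.25, 0.0], [-0.5, 0.25, 0.0]],  # class 0
    [[0.25, 0.0, -1.0], [-0.25, 0.0, 1.0]],  # class 1
]
feature_names = ["f0", "f1", "f2"]

per_class = mean_abs_by_class(toy)
# The length of each stacked bar is the sum of its per-class segments.
totals = [sum(segment[j] for segment in per_class) for j in range(len(feature_names))]
ranking = sorted(feature_names, key=lambda name: -totals[feature_names.index(name)])

print(per_class)  # one list of segment lengths per class
print(ranking)    # features from most to least important overall
```

Here f2 ranks first overall although class 0 does not use it at all, which is exactly the kind of class-specific pattern the stacked plot makes visible.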

In this example we can see that the class drop hardly uses the features pkts_sent, Source Port, and Bytes Sent. We can also see that the classes allow and deny use the same features to a similar degree, which is why the confusion between them is relatively high. To separate the allow and deny classes better, one needs to generate new features dedicated to distinguishing these two classes.

You can also see the summary_plot of a specific class.

shap.summary_plot(shap_values[1], X.values, feature_names = X.columns)
Image by Author

The summary plot combines feature importance with feature effects. Each point on the summary plot is a Shapley value for a feature and an instance. The position on the y-axis is determined by the feature and on the x-axis by the Shapley value. You can see that the feature pkts_sent, being the least important feature, has low Shapley values. The color represents the value of the feature from low to high. Overlapping points are jittered in the y-axis direction, so we get a sense of the distribution of the Shapley values per feature. The features are ordered according to their importance.

In the summary plot, we see the first indications of the relationship between the value of a feature and the impact on the prediction. But to see the exact form of the relationship, we have to look at SHAP dependence plots.

SHAP Dependence Plot

The partial dependence plot (PDP or PD plot for short) shows the marginal effect one or two features have on the predicted outcome of a machine learning model (J. H. Friedman 2001 [3]). A partial dependence plot can show whether the relationship between the target and a feature is linear, monotonic, or more complex.

The partial dependence plot is a global method: it considers all instances and makes a statement about the global relationship between a feature and the predicted outcome. The PDP assumes that the plotted feature is not correlated with the other features. If this assumption is violated, the averages calculated for the partial dependence plot will include data points that are very unlikely or even impossible.
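The averaging behind a partial dependence curve can be sketched in a few lines. This is a minimal illustration in plain Python, not the implementation of any particular library; `toy_predict` is a hypothetical stand-in for a fitted model, and the rows are made-up data.

```python
def partial_dependence(predict, rows, feature_index, grid):
    """For each grid value, overwrite one feature in every row and average the predictions."""
    curve = []
    for value in grid:
        total = 0.0
        for row in rows:
            modified = list(row)  # copy so the original data stays untouched
            modified[feature_index] = value
            total += predict(modified)
        curve.append(total / len(rows))
    return curve


def toy_predict(row):
    """Hypothetical model: linear in feature 0, with an interaction on feature 1."""
    return 2.0 * row[0] + row[0] * row[1]


rows = [[0.0, 1.0], [1.0, 3.0], [2.0, 2.0]]
curve = partial_dependence(toy_predict, rows, feature_index=0, grid=[0.0, 1.0, 2.0])
print(curve)
```

Every grid value is paired with every row, including combinations that never occur in the data. That is where the unlikely or impossible data points come from when the features are correlated.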

A dependence plot is a scatter plot that shows the effect a single feature has on the predictions made by the model. In a house-price model, for example, such a plot might show the predicted property value increasing significantly once the average number of rooms per dwelling exceeds 6.

  • Each dot is a single prediction (row) from the dataset.
  • The x-axis is the actual value from the dataset.
  • The y-axis is the SHAP value for that feature, which represents how much knowing that feature’s value changes the output of the model for that sample’s prediction.

The color corresponds to a second feature that may have an interaction effect with the feature we are plotting (by default this second feature is chosen automatically). If an interaction effect is present between this other feature and the feature we are plotting, it will show up as a distinct vertical pattern of coloring.

# If we pass a numpy array instead of a data frame then we
# need to pass the feature names in separately
shap.dependence_plot(0, shap_values[0], X.values, feature_names=X.columns)
Image by Author

In the example above we can see a clear vertical pattern of coloring for the interaction between the features Source Port and NAT Source Port.

SHAP Force plot

A force plot explains a single model prediction. It shows how the features contributed to the model’s prediction for a specific observation, which makes it convenient for error analysis or for a deep understanding of a particular case.

i=8
shap.force_plot(explainer.expected_value[0], shap_values[0][i], X.values[i], feature_names = X.columns)

From the plot we can see:

  1. The model predict_proba value: 0.79
  2. The base value: this is the value that would be predicted if we didn’t know any features for the current instance. The base value is the average of the model output over the training dataset (explainer.expected_value in the code). In this example base value = 0.5749
  3. The numbers on the plot arrows are the value of the feature for this instance. Elapsed Time (sec)=5 and Packets = 1
  4. Red represents features that pushed the model score higher, and blue represents features that pushed the score lower.
  5. The bigger the arrow, the bigger the impact of the feature on the output. The amount of decrease or increase in the impact can be seen on the x-axis.
  6. An Elapsed Time of 5 seconds increases the probability that the class is allow. A Packets value of 1 reduces that probability.
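The relationship between the base value, the arrows, and the final prediction is purely additive: base value plus all contributions equals the model output. A linear model with independent features is a convenient way to check this by hand, because each SHAP value is then simply the weight times the feature's deviation from its mean. The sketch below assumes such a model; the weights, background data, and the helper `linear_shap` are hypothetical and unrelated to the random forest above.

```python
def linear_shap(weights, bias, background, x):
    """SHAP values for f(x) = bias + sum(w_j * x_j), assuming independent features."""
    n = len(background)
    means = [sum(row[j] for row in background) / n for j in range(len(weights))]
    # Base value: the average model output over the background data.
    base_value = bias + sum(w * m for w, m in zip(weights, means))
    # Contribution of each feature: weight times deviation from the feature mean.
    contributions = [w * (xj - m) for w, xj, m in zip(weights, x, means)]
    return base_value, contributions


weights, bias = [0.5, -0.25], 1.0
background = [[0.0, 0.0], [2.0, 4.0], [4.0, 8.0]]  # hypothetical background data
x = [4.0, 6.0]                                      # the instance being explained

base_value, contributions = linear_shap(weights, bias, background, x)
prediction = bias + sum(w * xj for w, xj in zip(weights, x))

print("base value:   ", base_value)
print("contributions:", contributions)  # positive pushes higher (red), negative lower (blue)
print("prediction:   ", prediction)
```

The first feature pushes the output up and the second pushes it down, and adding both to the base value reproduces the prediction exactly.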

SHAP waterfall plot

The waterfall plot is another local analysis plot of a single instance prediction. Let’s take instance number 8 as an example:

row = 8
# shap_values was computed on X, so the displayed feature values must come from X as well
shap.waterfall_plot(shap.Explanation(values=shap_values[0][row],
                                     base_values=explainer.expected_value[0],
                                     data=X.iloc[row],
                                     feature_names=X.columns.tolist()))
Image by Author
  1. f(x) is the model predict_proba value: 0.79.
  2. E[f(x)] is the base value = 0.5749.
  3. On the left are the feature values, and on the arrows are the feature contributions to the prediction.
  4. Each row shows how the positive (red) or negative (blue) contribution of each feature moves the value from the expected model output over the background dataset to the model output for this prediction [2].
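The walk from E[f(x)] to f(x) can be written out as a running total. The sketch below uses hypothetical contribution values and generic feature names (not the numbers from the plot above) and applies the smallest contributions first, an ordering chosen for this illustration.

```python
def waterfall_steps(base_value, contributions):
    """Return (feature, contribution, running_total) tuples, smallest |contribution| first."""
    steps = []
    running = base_value
    for name, value in sorted(contributions.items(), key=lambda item: abs(item[1])):
        running += value
        steps.append((name, value, running))
    return steps


base_value = 0.5  # hypothetical E[f(x)]
contributions = {"feature_a": 0.25, "feature_b": -0.125, "feature_c": 0.0625}

steps = waterfall_steps(base_value, contributions)
for name, value, running in steps:
    direction = "red (+)" if value > 0 else "blue (-)"
    print(f"{name}: {value:+.4f} {direction} -> {running:.4f}")

final_output = steps[-1][2]  # f(x) for this hypothetical instance
```

The last running total is the model output for the instance, so the plot can be read as a step-by-step account of how the prediction departs from the average.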

Summary

The SHAP framework has proved to be an important advancement in the field of machine learning model interpretation. SHAP combines several existing methods to create an intuitive, theoretically sound approach to explain predictions for any model. SHAP values quantify the magnitude and direction (positive or negative) of a feature’s effect on a prediction[6]. I believe XAI analysis with SHAP and other tools should be an integral part of the machine learning pipeline. The code in this post can be found here.
