Use SHAP loss values to debug/monitor your model

How you should understand the SHAP loss value and use it

Chuangxin Lin
Towards Data Science


Responsible AI has been a very hot topic in recent years. Accountability and explainability have now become necessary components of machine learning models, particularly when the models make decisions that will impact people’s lives, such as medical diagnosis and financial services. This is a very large topic in machine learning, and a lot of ongoing work is dedicated to its various aspects; you can check more resources on this topic[1]. In this post, I will focus on SHAP (SHapley Additive exPlanations), which is one of the most popular explainability packages thanks to its versatility (local/global explainability; model-specific/agnostic explainers) and its solid theoretical foundation in game theory. You can find many posts and tutorials on how SHAP helps you understand how your ML model works, i.e., how each of your features contributes to the model prediction. In this post, however, I will talk about SHAP loss values, which many people may be less familiar with. I will walk through some key concepts with an example and also share some of my thoughts.

To begin with, you may want to check the example provided in the SHAP package. There are two important notes:

  • Just like the SHAP values for prediction, the SHAP loss values show how each feature contributes to the logloss, pushing it away from the expected value. (Note: in this post, “loss value” refers to logloss, since we are looking at a classification problem.) The expected value is the baseline, and it depends on the label: for a data instance whose label is True, it is calculated by setting the labels of all instances to True and averaging the logloss (and likewise with False for a data instance whose label is False).
  • You should use the “interventional” method to calculate SHAP loss values.

Essentially, this means that when integrating out the absent features, you should use the marginal distribution instead of the conditional distribution. The marginal distribution is obtained by filling in the absent features with values from the background dataset.
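To make “filling in the absent features with values from the background dataset” concrete, here is a minimal brute-force sketch in NumPy. It is not how TreeExplainer computes anything (TreeShap is far more efficient); it simply enumerates every feature subset S, computes the interventional value v(S) by averaging the model output over the background rows with the features in S fixed to the instance’s values, and combines the results with the Shapley weights. The function names are my own.

```python
import itertools
import math

import numpy as np


def interventional_value(f, x, background, subset):
    """v(S): mean model output when the features in `subset` are fixed to x's
    values and the absent features keep each background row's values
    (i.e., the marginal distribution of the absent features)."""
    mixed = np.array(background, dtype=float, copy=True)
    idx = list(subset)
    mixed[:, idx] = x[idx]
    return float(np.mean(f(mixed)))


def exact_shapley(f, x, background):
    """Brute-force Shapley values over all subsets (exponential cost, so only
    meant for a handful of features)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    phi = np.zeros(n)
    for j in range(n):
        others = [k for k in range(n) if k != j]
        for size in range(n):
            for subset in itertools.combinations(others, size):
                weight = (math.factorial(size) * math.factorial(n - size - 1)
                          / math.factorial(n))
                gain = (interventional_value(f, x, background, subset + (j,))
                        - interventional_value(f, x, background, subset))
                phi[j] += weight * gain
    return phi
```

As a sanity check, for a linear model f(X) = X @ w this interventional definition gives w_j * (x_j - mean of feature j over the background), and the values add up to f(x) minus the average prediction over the background.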

Whether to use “interventional” (i.e., marginal distribution) or “tree_path_dependent” (i.e., conditional distribution) is an important nuance (see the docstring in the SHAP package), and it’s worth further discussion. But I don’t want to confuse you at the very beginning. For now, you just need to know that in common practice TreeShap calculates SHAP values very fast because it takes advantage of the conditional distribution encoded in the tree structure of the model, but using the conditional distribution can introduce a causality problem[2].

Train an XGBoost Classifier

The example in this post is modified from the tutorial example in the SHAP package, and you can find the full code and notebook here. I first trained an XGBoost classifier. The dataset uses 12 features to predict whether a person makes over 50K a year:

['Age', 'Workclass', 'Education-Num', 'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex', 'Capital Gain', 'Capital Loss', 'Hours per week', 'Country']

You can use the SHAP package to calculate the SHAP values. The force plot gives you local explainability, showing how the features contribute to the model prediction for an instance of interest (Fig. 1). The summary plot gives global explainability (Fig. 2). You can check Part 1 in the Jupyter Notebook. There is nothing new here, just the common use of SHAP, so I will leave the details to you and jump to Part 2, SHAP values for the model loss.

Fig. 1. The force plot shows how each feature contributes to pushing the model output from the base value to the final prediction. Note that the output is in log-odds space.
Fig. 2. The SHAP summary plot gives global explainability. A feature with a larger magnitude has a more important impact on the prediction.

Explain the Log-Loss of the Model

Now we are more interested in each feature’s contribution to the model loss, so we need to calculate SHAP loss values. In some sense, this is similar to residual analysis. The code snippet is as follows. Note that you need to

  • provide background data, since we use the “interventional” approach. The computation can be expensive, so the background data should be of a reasonable size (here I use 100 instances);
  • set model_output to “log_loss”.
import shap
# subsample to provide the background data (stratified by the target variable);
# subsample_data is a helper that is not shown in this post (see the notebook)
X_subsample = subsample_data(X, y)
explainer_bg_100 = shap.TreeExplainer(model, X_subsample,
                                      feature_perturbation="interventional",
                                      model_output="log_loss")
shap_values_logloss_all = explainer_bg_100.shap_values(X, y)
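The post does not show the subsample_data helper. If you want to run the snippet on its own, here is a hypothetical stand-in that draws a stratified subsample with NumPy. The name matches the snippet, but the implementation and the n_samples/seed parameters are my assumptions, not the notebook’s code.

```python
import numpy as np


def subsample_data(X, y, n_samples=100, seed=0):
    """Hypothetical stand-in: draw about n_samples rows so that each class of
    y keeps roughly its share of the data (stratified sampling)."""
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    chosen = []
    for label in np.unique(y):
        label_idx = np.flatnonzero(y == label)
        # number of rows for this class, proportional to its frequency in y
        k = min(len(label_idx), int(round(n_samples * len(label_idx) / len(y))))
        chosen.extend(rng.choice(label_idx, size=k, replace=False))
    rows = np.sort(np.array(chosen, dtype=int))
    # keep a pandas DataFrame a DataFrame; otherwise index a NumPy array
    return X.iloc[rows] if hasattr(X, "iloc") else np.asarray(X)[rows]
```

Because the per-class counts are rounded, the total can be off by one from n_samples when the class shares do not divide evenly.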

Force Plot

Now the force plot for a data instance has a similar interpretation to that in Fig. 1, but in terms of log loss instead of prediction. A successful prediction (ground truth True and prediction True) is given in Fig. 3, and a wrong prediction (ground truth True and prediction False) in Fig. 4. You can see how the features in blue try to reduce the logloss from the base value, while the red ones increase it. It’s noteworthy that the base value (expected value) of the model loss depends on the label (True/False), so it is a function rather than a single number. The expected values are calculated by first setting all the data labels to True (or False) and then computing the average log loss; you can check the notebook for more details. I am not sure whether there is a particular reason for calculating the base values this way, but after all, the base values just serve as a reference, so I think it should not matter very much.
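Here is a small self-contained sketch of the label-dependent base value, using a toy logistic model with three features in place of the XGBoost classifier (all names and numbers are my assumptions). The base value for label y is the average log loss over the background rows with every label set to y, and because SHAP values are additive, an instance’s SHAP loss values plus the base value for its label add up to that instance’s log loss.

```python
import itertools
import math

import numpy as np


def log_loss(p, y):
    """Per-instance binary log loss of predicted probability p for label y."""
    p = np.clip(p, 1e-12, 1 - 1e-12)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


def predict_proba(X, w, b):
    """Toy logistic model standing in for the XGBoost classifier."""
    return 1.0 / (1.0 + np.exp(-(X @ w + b)))


def base_value(y, background, w, b):
    """Expected loss: give every background row the label y, then average."""
    return float(np.mean(log_loss(predict_proba(background, w, b), y)))


def loss_value(x, y, background, subset, w, b):
    """Interventional v(S) for the loss: features in S come from x, the rest
    from each background row, and the label is y for every row."""
    mixed = np.array(background, dtype=float, copy=True)
    idx = list(subset)
    mixed[:, idx] = x[idx]
    return float(np.mean(log_loss(predict_proba(mixed, w, b), y)))


def shap_loss_values(x, y, background, w, b):
    """Brute-force Shapley values of the log loss (few features only)."""
    n = len(x)
    phi = np.zeros(n)
    for j in range(n):
        others = [k for k in range(n) if k != j]
        for size in range(n):
            for s in itertools.combinations(others, size):
                weight = (math.factorial(size) * math.factorial(n - size - 1)
                          / math.factorial(n))
                phi[j] += weight * (loss_value(x, y, background, s + (j,), w, b)
                                    - loss_value(x, y, background, s, w, b))
    return phi
```

With an empty subset, loss_value reduces to base_value, which is why the sum of the SHAP loss values is the instance’s loss minus the base value for its own label.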

Fig. 3. Force plot for a data instance: the ground truth is True and the model predicts True.
Fig. 4. Force plot for a data instance: the ground truth is True and the model predicts False.

Summary Plot

Similarly, we have the summary plot for the model logloss (Fig. 5). It tells you how much each feature contributes to the model logloss (the bars show the mean absolute SHAP loss value). A feature with a large contribution contributes a lot to the model loss: it could be increasing the logloss for some data instances and reducing it for others. One might therefore expect this ranking to be consistent with the top features by SHAP values in Fig. 2. But we can see that the rankings are a bit different. While “Relationship” remains at the top, the order of “Age”, “Education-Num”, “Capital Gain”, “Hours per week” and “Occupation” is different. And “Capital Gain” has a relatively larger contribution in Fig. 5 than in Fig. 2. This suggests that “Capital Gain” plays an important role in reducing the log loss, while, relatively speaking, it may not be as important as “Relationship” for the model to make the prediction. Note that the summary plot in Fig. 5 should be interpreted with caution: since the bars are based on the mean absolute value, both the effect of reducing logloss and the effect of increasing logloss count toward a feature’s importance. In plain language, a large (absolute) contribution does not necessarily mean a feature is a “good” feature.
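The caveat about the mean absolute value is easy to demonstrate. A minimal NumPy sketch, assuming the SHAP loss values are stored as an array with one row per instance and one column per feature (the layout returned by shap_values); the function name is my own, and it reports the signed mean next to the mean absolute value so the net effect on the loss is visible.

```python
import numpy as np


def rank_by_mean_abs(shap_loss, feature_names):
    """Rank features by mean |SHAP loss value| (what the bar summary plot
    shows) and also report the signed mean, i.e. the net effect on the loss."""
    shap_loss = np.asarray(shap_loss, dtype=float)
    mean_abs = np.abs(shap_loss).mean(axis=0)
    mean_signed = shap_loss.mean(axis=0)
    order = np.argsort(mean_abs)[::-1]
    return [(feature_names[i], float(mean_abs[i]), float(mean_signed[i]))
            for i in order]
```

With two instances and values [[1.0, -0.5], [-1.0, -0.5]] for features A and B, A ranks first with a mean absolute value of 1.0 even though its increases and decreases cancel out (signed mean 0.0), while B, which lowers the loss for every instance, ranks second.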

Fig. 5. Similar to Fig. 2, but based on model logloss.

Of course, you can use the scatter summary plot instead of the bar summary plot to see the detailed distribution and dig deeper when debugging your model (i.e., improving its performance). Another way I investigate it is to decompose the SHAP loss values into a negative component (Fig. 6) and a positive component (Fig. 7). In terms of model debugging, you want the negative component to be as large in magnitude as possible and the positive component as small as possible for every feature, since ideally every feature reduces the final model logloss.
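The decomposition in Fig. 6 and Fig. 7 is a masked sum per feature. A minimal NumPy sketch, again assuming one row per instance and one column per feature; the function name is my own.

```python
import numpy as np


def split_by_sign(shap_loss):
    """Per-feature sums of the negative SHAP loss values (loss reduced) and of
    the positive ones (loss increased)."""
    shap_loss = np.asarray(shap_loss, dtype=float)
    negative = np.where(shap_loss < 0, shap_loss, 0.0).sum(axis=0)
    positive = np.where(shap_loss > 0, shap_loss, 0.0).sum(axis=0)
    return negative, positive
```

The two components always add back up to the plain per-feature sum, so nothing is lost by looking at them separately.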

Fig. 6. Sum of all the negative shap loss values for each feature.
Fig. 7. Sum of all the positive shap loss values for each feature.

Monitoring plot

Now we come to the most interesting part: using SHAP loss values to monitor your model. Model drift and data drift are real-world problems in which your model deteriorates and produces unreliable or inaccurate predictions. These usually happen silently, and it is very hard to identify the root cause. In a recent paper[3], the SHAP authors use SHAP loss values to monitor model health. The idea is very appealing, and I would like to explore it further. Note that the API is available but seems to be under ongoing development.

First, we need to calculate the SHAP loss values for the training data and the test data. In the context of monitoring, you would calculate the SHAP loss values for datasets from different time snapshots. You may recall that we already did this in the previous section, but note that there we used background data sampled from the entire dataset. For monitoring, it makes more sense to calculate the SHAP loss values for the training dataset and the test dataset separately, using background data drawn from the training dataset and the test dataset respectively. The code snippets are as follows:

# shap loss values for training data
X_train_subsample = subsample_data(X=X_train, y=y_train)
explainer_train_bg_100 = shap.TreeExplainer(model, X_train_subsample,
                                            feature_perturbation="interventional",
                                            model_output="log_loss")
shap_values_logloss_train = explainer_train_bg_100.shap_values(X_train, y_train)

# shap loss values for test data
X_test_subsample = subsample_data(X=X_test, y=y_test)
explainer_test_bg_100 = shap.TreeExplainer(model, X_test_subsample,
                                           feature_perturbation="interventional",
                                           model_output="log_loss")
shap_values_logloss_test = explainer_test_bg_100.shap_values(X_test, y_test)

The monitoring plots for the top features are shown in Fig. 8. First, all the data instances are ordered by index, and we assume the index indicates the evolution of time (from left to right along the axis). In this toy example, we don’t have data from different time snapshots, so we simply treat the training data as the current data and the test data as the future data we would like to monitor.

There are some important points for understanding these monitoring plots, based on the current implementation in the SHAP package. To check whether the SHAP loss values are consistent over time, a t-test is repeatedly conducted to compare two data samples. The current implementation uses an increment of 50 data points to split the data. That means the first t-test compares data[:50] to data[50:], the second compares data[:100] to data[100:], and so on. The t-test fails if the p-value is smaller than 0.05/n_features; in other words, it uses a 95% confidence level with a Bonferroni correction. Where the t-test fails, a vertical dashed line is plotted to indicate the location. Somewhat surprisingly, the monitoring plots show inconsistent SHAP loss values for [“Relationship”, “Education-Num”, “Capital Gain”], and this happens when we enter the time snapshot of the test data (Fig. 8).
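The repeated t-test described above can be sketched in a few lines with scipy.stats.ttest_ind. This is my own re-implementation of the procedure as described, not SHAP’s code: monitoring_splits is a hypothetical name, and the loop bound (stopping while the right-hand segment still has more than inc points) is my choice.

```python
import numpy as np
from scipy import stats


def monitoring_splits(values, n_features, inc=50, alpha=0.05):
    """Compare values[:i] with values[i:] for i = inc, 2*inc, ... using a
    two-sample t-test; flag a split when p < alpha / n_features (Bonferroni)."""
    values = np.asarray(values, dtype=float)
    threshold = alpha / n_features
    results = []
    for i in range(inc, len(values) - inc, inc):
        _, p = stats.ttest_ind(values[:i], values[i:])
        results.append((i, float(p)))
    flagged = [i for i, p in results if p < threshold]
    return results, flagged
```

Passing inc=6500 gives the coarser splits I also tried; values would be one column of the concatenated training and test SHAP loss values.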

Fig. 8. Monitoring plots for top features. The training dataset and test dataset are concatenated to mimic data from different time snapshots. Note that the test data starts from index 26047.

The reason for using an increment of 50 data points is not very clear to me. In this example, we know that [0:26048] is the training data and [-6513:] is the test data, so I changed the increment to 6500 to see whether it gives a different result. But the monitoring plots still show the same inconsistency (i.e., the t-test fails) when the test data is compared (Fig. 9).

Fig. 9. Monitoring plots for top features. Similar to Fig. 8 but now we use the increment of 6500 data points. The purpose is to compare the test data to the last “time segment” of the training data directly.

Finally, I think it’s a good idea to run the t-test on the training data and the test data directly. This verifies the conclusion again: the SHAP loss values are inconsistent between the training dataset and the test dataset.

# t-test for top features (assume equal variance)
t-test for feature: Relationship , p value: 2.9102249320497517e-06
t-test for feature: Age , p value: 0.22246187841821208
t-test for feature: Education-Num , p value: 4.169244713493427e-06
t-test for feature: Capital Gain , p value: 1.0471308847541212e-27
# t-test for top features (unequal variance, i.e., Welch’s t-test)
t-test for feature: Relationship , p value: 1.427849321056383e-05
t-test for feature: Age , p value: 0.2367209506867293
t-test for feature: Education-Num , p value: 3.3161498092593535e-06
t-test for feature: Capital Gain , p value: 1.697971581168647e-24
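A direct comparison like the one above takes one call per feature and per variance assumption. A minimal sketch with scipy.stats.ttest_ind, where equal_var=True gives Student’s t-test and equal_var=False gives Welch’s t-test; the function name and the array layout (one row per instance, one column per feature) are my assumptions, and this is not necessarily the notebook’s exact code.

```python
import numpy as np
from scipy import stats


def compare_train_test(shap_train, shap_test, feature_names, top_idx):
    """Per-feature two-sample t-tests between training and test SHAP loss
    values: (name, Student's p-value, Welch's p-value)."""
    shap_train = np.asarray(shap_train, dtype=float)
    shap_test = np.asarray(shap_test, dtype=float)
    rows = []
    for j in top_idx:
        a, b = shap_train[:, j], shap_test[:, j]
        p_equal = stats.ttest_ind(a, b, equal_var=True).pvalue
        p_welch = stats.ttest_ind(a, b, equal_var=False).pvalue
        rows.append((feature_names[j], float(p_equal), float(p_welch)))
    return rows
```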

The inconsistency of SHAP loss values between the training data and the test data is very unexpected and can be troublesome. Remember that we simply used a training/test split of the entire dataset, so there is good reason to believe that the training dataset and the test dataset should be consistent, in terms of both data distribution and SHAP loss value contributions. In any case, this is just a simple experiment, and more investigation is needed before drawing any firm conclusion. But I think there may be some reasons why the SHAP package indicates that the monitoring functionality is only preliminary, for example:

  • the use of an increment of 50 data points looks arbitrary to me;
  • the t-test looks very sensitive and can give many false alarms.

Another interesting discussion point is the use of background data. Note that for the monitoring plots, the SHAP loss values on the training dataset and the test dataset are calculated using different background data (subsamples from the training dataset and the test dataset, respectively). Since the “interventional” approach to calculating SHAP loss values is very expensive, I only tried background subsamples of 100 data instances. That could yield high-variance SHAP loss values. Perhaps larger background data would reduce the variance and make the SHAP loss values consistent in the monitoring plots. And when I used the same background data (subsamples from the entire dataset) for both, there was no inconsistency in the monitoring plot. So how you choose the background data matters a lot!
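To get a feel for why a 100-instance background can be noisy, consider the simplest case that can be checked by hand: for a linear model with interventional perturbation, the SHAP value of feature j is w_j * (x_j - mean of feature j over the background), so the background enters only through its sample mean. The sketch below uses a linear model on raw outputs, not the log loss of a tree model, so it only illustrates the mechanism; the function names are my own.

```python
import numpy as np


def linear_interventional_shap(x, w, background):
    """Interventional SHAP values of f(x) = w @ x: w_j * (x_j - background mean)."""
    return w * (x - np.asarray(background).mean(axis=0))


def background_spread(population, x, w, size, n_draws=200, seed=0):
    """Standard deviation of the SHAP values across repeated background draws
    of a given size (sampled without replacement from `population`)."""
    rng = np.random.default_rng(seed)
    draws = []
    for _ in range(n_draws):
        idx = rng.choice(len(population), size=size, replace=False)
        draws.append(linear_interventional_shap(x, w, population[idx]))
    return np.std(draws, axis=0)
```

For standard-normal features, the spread with 100 background rows is roughly 0.1 * |w_j|, several times larger than with 2000 rows, which is the kind of noise that two different 100-row backgrounds can inject into the training and test SHAP loss values.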

Conclusions and Discussions

I hope this post gives you a useful introduction to SHAP loss values. You can better debug your ML models by investigating the SHAP loss values. They can also be a useful way to monitor your ML models for model drift and data drift, which is still a very big challenge in the community. But note the limitation: to use SHAP loss values for monitoring, you need the ground truth for newly arriving data, which is usually only available after a certain period. Also, unfortunately, this functionality is still under development, and the appropriateness of the t-test needs further justification.

Last but not least, calculating SHAP values (TreeShap) with the marginal distribution or the conditional distribution can give different results (see the equations in the illustration below). Using the conditional distribution introduces a causality problem, while using the marginal distribution feeds unlikely data points to the model[4]. There seems to be no consensus about which one to use; it depends on the scenario[2,5]. This paper[6] has some interesting comments on this topic, which I would like to quote here:

In general, whether or not users should present their models with inputs that don’t belong to the original training distribution is a subject of ongoing debate.

….

This problem fits into a larger discussion about whether or not your attribution method should be “true to the model” or “true to the data” which has been discussed in several recent articles.

Illustration of the use of the marginal distribution and the conditional distribution to integrate out missing values (i.e., absent features). Here X1 is the present feature and X2, X3 are absent features. Reproduced from [2].
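The difference between the two choices is easy to see on a tiny example where two features are perfectly correlated and the model only uses the second one. In the brute-force sketch below (my own toy setup; the “conditional” branch uses the empirical conditional expectation, averaging over the data rows that match the instance on the present features), the marginal (interventional) definition gives the first feature zero credit, [0.0, 0.5], while the conditional definition splits the credit, [0.25, 0.25]. This is the “true to the model” versus “true to the data” distinction in miniature.

```python
import itertools
import math

import numpy as np


def value(f, x, data, subset, conditional):
    """v(S) with the features in S fixed to x's values."""
    idx = list(subset)
    if conditional:
        # empirical conditional expectation: average f over the data rows
        # that agree with x on every present feature
        match = np.all(data[:, idx] == x[idx], axis=1)
        return float(np.mean(f(data[match])))
    # marginal (interventional): overwrite S in every row, keep the rest
    mixed = data.astype(float)
    mixed[:, idx] = x[idx]
    return float(np.mean(f(mixed)))


def shapley(f, x, data, conditional):
    """Brute-force Shapley values under either value function."""
    n = len(x)
    phi = np.zeros(n)
    for j in range(n):
        others = [k for k in range(n) if k != j]
        for size in range(n):
            for s in itertools.combinations(others, size):
                weight = (math.factorial(size) * math.factorial(n - size - 1)
                          / math.factorial(n))
                phi[j] += weight * (value(f, x, data, s + (j,), conditional)
                                    - value(f, x, data, s, conditional))
    return phi


# two perfectly correlated binary features; the model only looks at the second
data = np.array([[0, 0], [1, 1]] * 50)
x = np.array([1, 1])


def model(X):
    return X[:, 1].astype(float)
```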

Thank you for your time, and don’t hesitate to leave any comments or start a discussion!

All the plots in this post are created by the author by using the SHAP package. Please kindly let me know if you think any of your work is not properly cited.

[1] Introduction to Responsible Machine Learning

[2] Janzing, D., Minorics, L., & Blöbaum, P. (2019). Feature relevance quantification in explainable AI: A causality problem. https://arxiv.org/abs/1910.13413

[3] Lundberg, S.M., Erion, G., Chen, H., DeGrave, A., Prutkin, J.M., Nair, B., Katz, R., Himmelfarb, J., Bansal, N. and Lee, S.I. (2020). From local explanations to global understanding with explainable AI for trees. Nature Machine Intelligence, 2(1), 56–67.

[4] Molnar, C. Interpretable Machine Learning, chapter on SHAP. https://christophm.github.io/interpretable-ml-book/shap.html

[5] Sundararajan, M., & Najmi, A. (2019). The many Shapley values for model explanation. arXiv preprint arXiv:1908.08474.

[6] Sturmfels, P., Lundberg, S., & Lee, S. I. (2020). Visualizing the impact of feature attribution baselines. Distill, 5(1), e22.
