
Addressing the “Black Box” Issue in Machine Learning

4 must-know techniques to create more transparency and explainability in model predictions

Model Interpretability

Photo by Will Porada on Unsplash

There is no doubt that machine learning models have taken the world by storm in recent decades. Their ability to identify patterns and generate predictions far beyond what traditional statistical techniques can achieve is truly remarkable and hard to dispute.

However, despite all their promising advantages, many people remain sceptical. Specifically, one of the main setbacks of machine learning models is their lack of transparency and interpretability.

In other words, although machine learning models are highly capable of generating predictions that are robust and accurate, that capability often comes at the expense of complexity: it becomes difficult to inspect and understand the logic behind those predictions.

Our goal in this article is to unpack and address the issue of black-box models by answering two fundamental questions:

  • Which features in the data does the model consider most important?
  • How does each feature affect the model’s predictions, both in a big-picture sense and on a case-by-case basis?

To help us answer those questions, we will explore 4 techniques and discuss how each of them can be used to create more transparency in model predictions:

  • Feature importance
  • Permutation importance
  • Partial dependence plots
  • SHAP values

So, if you are ready to start peeling back and examining how exactly your model is using input data to make predictions, let’s begin!

The reference notebook to this article can be found here.


Data description and preparation

The medical cost personal dataset provides the insurance costs for 1,338 policyholders living in the United States along with their personal information such as:

  • Age
  • Sex
  • Body mass index (BMI)
  • Number of children
  • Smoking status
  • Residential region
First 5 rows of the data

Since the insurance charge is a continuous variable, this is a regression problem. I have decided to use a random forest regressor as our model of choice.

For the purpose of this article, let’s not worry too much about which model to use but instead direct our attention more towards learning how to interpret the results of our model predictions.

# Imports and data loading (file name assumed)
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
data = pd.read_csv('insurance.csv')
# Feature encoding 
data['sex'] = data['sex'].map({'female': 1, 'male': 0})
data['smoker'] = data['smoker'].map({'yes': 1, 'no': 0})
data = pd.get_dummies(data)  # one-hot encodes the remaining categorical column (region)
# Predictor and target variables
X = data.drop('charges', axis = 1)
y = data.charges
# Train test split
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size = 0.2, random_state = 42)
# Fit random forest model
rf = RandomForestRegressor(random_state = 42).fit(X_train, y_train)

Feature importance

I wrote an article a few months ago on the topic of feature selection and dimensionality reduction. One of the techniques discussed in that article is called feature importance.

Essentially, when using any tree-based models e.g. decision tree, random forest or gradient boosting, each feature will be assigned a "feature importance" which highlights how important that feature is to the model when making predictions.

To understand the intuition behind feature importance, it is important that we first talk about how the decision tree algorithm actually works.

Image credits to Dan Becker, Kaggle

Recall that when a decision tree model is fitted to the training data, each node in the decision tree represents a single feature that is used by the model to split the data into groups. The goal here is to split the data in such a way that data with similar values in their target variable end up in the same group after each split.

In other words, we want data that exhibit the same characteristics to be grouped together so that when new unseen data comes in, our model will be able to predict its final value by traversing from the top of the tree down to the bottom of the tree.

How exactly does the model decide which feature to use for each node? Well, the decision tree algorithm is optimised to reduce Gini impurity (for classification) or variance (for regression). Gini impurity measures the probability of misclassifying a randomly chosen observation in a node, whereas variance measures how far the observations in a node are spread from their mean. Therefore, at each split, the model chooses the feature that gives the lowest impurity score or the lowest variance.
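To make those two criteria concrete, here is a minimal sketch (purely illustrative, with made-up toy numbers, not part of the original notebook) of how Gini impurity and variance are computed for a node, and how a good split drives the variance down:

import numpy as np
def gini_impurity(labels):
    # Probability of misclassifying a randomly drawn observation in this node
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1 - np.sum(p ** 2)
def variance(values):
    # Spread of the target values in this node (the regression criterion)
    return np.var(values)
# A perfectly mixed two-class node has the maximum Gini impurity of 0.5
print(gini_impurity([0, 1, 0, 1]))
# Toy regression example: splitting on "smoker" separates cheap and expensive
# policies, so the average variance after the split is far lower than before it
charges = np.array([2000, 2500, 3000, 21000, 23000, 25000])
smoker = np.array([0, 0, 0, 1, 1, 1])
before = variance(charges)
after = np.mean([variance(charges[smoker == 0]), variance(charges[smoker == 1])])
# (Groups are equal-sized here, so a plain mean stands in for the weighted average)
print(f"Variance before split: {before:,.0f}, after split: {after:,.0f}")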

If you are curious to learn more about how tree-based models work, check out my article below:

Battle of the Ensemble – Random Forest vs Gradient Boosting

Now, going back to feature importance. Feature importance measures how useful a particular feature is at helping our model minimise the loss function. The better a feature is at splitting the data into groups with similar target values, the more important that feature is to the model.

Here is how to compute and visualise feature importance in practice.

import numpy as np
import matplotlib.pyplot as plt
# Calculate feature importance (mean decrease in impurity across the forest's trees)
importances = rf.feature_importances_
std = np.std([tree.feature_importances_ for tree in rf.estimators_], axis = 0)
indices = np.argsort(importances)[::-1]
# Plot feature importance with error bars showing the spread across trees
plt.figure(figsize = (10, 5))
plt.title("Feature importances")
plt.bar(range(X_train.shape[1]), importances[indices], yerr = std[indices])
plt.xticks(range(X_train.shape[1]), X_train.columns[indices], rotation = 90)
plt.show()
Feature importance

As we can see from the graph, the 3 most important features are:

  • Smoker
  • BMI
  • Age

This result seems reasonable as smoking habits, BMI and age are all very common indicators of a person’s health condition and as a result, the amount they pay for their health insurance premium.


Permutation importance

Another technique similar to feature importance is called permutation importance.

Effectively, once our model has been trained, permutation importance involves shuffling a single feature in the validation data, leaving the target and all other columns in place, and then evaluating the accuracy of predictions on that now-shuffled data.

This method works because if we shuffle a feature that the model does not rely on when making predictions, the accuracy will remain fairly consistent throughout the permutation process. On the flip side, if the accuracy deteriorates as a result of the reshuffling, that indicates the feature is important to the model.
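To make the mechanics concrete, here is a minimal manual sketch of that shuffle-and-score idea, reusing the rf, X_val and y_val objects defined earlier (the eli5 library used below does the same thing for us, with repeated shuffles and a tidier display):

import numpy as np
# Baseline score on the untouched validation data (R-squared for a regressor)
baseline = rf.score(X_val, y_val)
rng = np.random.RandomState(42)
for col in X_val.columns:
    X_shuffled = X_val.copy()
    # Shuffle a single column, leaving the target and all other columns in place
    X_shuffled[col] = rng.permutation(X_shuffled[col].values)
    drop = baseline - rf.score(X_shuffled, y_val)
    print(f"{col}: drop in R-squared = {drop:.4f}")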

Let’s see how this works in practice.

import eli5
from eli5.sklearn import PermutationImportance
# Permutation importance: shuffle each feature in the validation set and measure the score drop
perm = PermutationImportance(rf, random_state = 42).fit(X_val, y_val)
eli5.show_weights(perm, feature_names = X_val.columns.tolist())
Permutation importance

As we can see, the result we obtain from permutation importance is almost identical to that under feature importance. The values towards the top of the table are the most important features and those towards the bottom matter the least.

Reshuffling introduces some randomness, so the number after the plus-minus sign accounts for it: it measures how much performance varied from one reshuffling to the next.
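If you would like to reproduce that plus-minus figure yourself, scikit-learn offers sklearn.inspection.permutation_importance as an alternative to eli5 (my suggestion, not the approach used in the original notebook); it repeats the shuffle several times per feature and reports the mean and standard deviation of the score drop:

from sklearn.inspection import permutation_importance
# Shuffle each feature 10 times to estimate the variability of the score drop
result = permutation_importance(rf, X_val, y_val, n_repeats=10, random_state=42)
for idx in result.importances_mean.argsort()[::-1]:
    print(f"{X_val.columns[idx]}: "
          f"{result.importances_mean[idx]:.4f} ± {result.importances_std[idx]:.4f}")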


Partial dependence plots

So far, we have touched upon two techniques that help us determine what variables most affect model predictions.

What if we want to take this one step further and examine how they affect predictions? The next two techniques that we will look at can help us accomplish this task.

To build a partial dependence plot, we repeatedly alter the value of the variable we are interested in while holding everything else fixed, and record the model's average prediction at each value.
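As a rough illustration of that procedure (a hand-rolled sketch of the idea, not what the pdpbox library used below does internally), you could sweep one feature across a grid of values, hold every other column fixed, and average the model's predictions at each grid point:

import numpy as np
import matplotlib.pyplot as plt
def manual_pdp(model, X, feature, grid_points=30):
    # Average prediction as the chosen feature is swept over a grid of values
    grid = np.linspace(X[feature].min(), X[feature].max(), grid_points)
    averaged = []
    for value in grid:
        X_mod = X.copy()
        X_mod[feature] = value  # set this feature to the grid value for every row
        averaged.append(model.predict(X_mod).mean())
    return grid, np.array(averaged)
grid, avg_pred = manual_pdp(rf, X_val, 'bmi')
plt.plot(grid, avg_pred)
plt.xlabel('bmi')
plt.ylabel('Average predicted charges')
plt.show()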

Let’s now look at two practical examples. Suppose we would like to know how BMI and age affect a policyholder’s insurance premium.

from pdpbox import pdp
# Create data to plot: average prediction as BMI is varied over a grid
pdp_bmi = pdp.pdp_isolate(model = rf, dataset = X_val, model_features = X_train.columns, feature = 'bmi')
# Plot data
pdp.pdp_plot(pdp_bmi, 'bmi')
plt.show()
Partial dependence plot for BMI feature

As we can see, premium levels start to rise around BMI = 25, dip slightly, and then increase again.

# Create data to plot
pdp_age = pdp.pdp_isolate(model = rf, dataset = X_val, model_features = X_train.columns, feature = 'age')

# Plot data
pdp.pdp_plot(pdp_age, 'age')
plt.show()
Partial dependence plot for age feature

On the other hand, we see premium levels increasing steadily with age, though at different rates for:

  • Age before 40
  • Age between 40–60
  • Age 60 and above

We can also perform a bivariate or 2-dimensional PDP which shows us the interaction between two features.

# 2D partial dependence plot showing the interaction between BMI and age
features_to_plot = ['bmi', 'age']
inter = pdp.pdp_interact(model = rf, dataset = X_val, model_features = X_train.columns, features = features_to_plot)
pdp.pdp_interact_plot(pdp_interact_out = inter, feature_names = features_to_plot, plot_type = 'contour')
plt.show()
2D partial dependence plot

I hope by now you can see and appreciate the additional layer of insight that a partial dependence plot offers.

Using this technique, not only do we now know which features are important to our model but more importantly, how those features actually affect model predictions.


SHAP values

SHAP is short for SHapley Additive exPlanations. It breaks down an individual prediction to show the impact of each contributing feature, which enables us to study model predictions on a case-by-case basis.

For example, this technique can be particularly useful for financial institutions such as banks and insurance companies. If a bank turns down a customer for a loan, it is legally required to explain the basis for the rejection. Similarly, if an insurance company declines to underwrite a policy for someone, it is obligated to explain its reasons.

SHAP values interpret the impact of a feature taking a certain value compared with the prediction we would have made had that feature taken some baseline value.

So, each feature will have its own SHAP value for a given prediction, and the sum of all SHAP values explains the difference between that prediction and the base value, where the base value is the average of the model output over the training data. Furthermore, we can visualise this decomposition of SHAP values using a graph.
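A quick way to convince yourself of this additivity property is to check it numerically. The small sketch below (my own check, using the shap package's TreeExplainer with the rf model and X_val from earlier) reconstructs a prediction from its base value and SHAP values:

import numpy as np
import shap
explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_val)  # one row of SHAP values per validation row
# Base value + sum of a row's SHAP values should reproduce that row's prediction
base_value = np.ravel(explainer.expected_value)[0]  # scalar for a regressor
reconstructed = base_value + shap_values[0].sum()
predicted = rf.predict(X_val.iloc[[0]])[0]
print(np.isclose(reconstructed, predicted))  # True, up to floating point error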

Suppose we would like to examine the SHAP values for the first 3 rows of the validation data.

First 3 rows of the validation data

Firstly, we need to define a function to calculate and visualise the SHAP value decomposition.

import shap
# Define SHAP plot function 
shap.initjs()  # load the JavaScript needed to render force plots in a notebook
def shap_plot(index):
    explainer = shap.TreeExplainer(rf)
    shap_values = explainer.shap_values(X_val)
    # Force plot for a single row: base value plus each feature's contribution
    res = shap.force_plot(explainer.expected_value, shap_values[index], X_val.iloc[index])
    return res

Once our function is ready, we simply pass in the index of the row that we want to investigate. Let’s start with the first row of the validation data.

# First row 
shap_plot(0)
First row of the validation data

So, how exactly do you interpret this?

Well, first of all, the model predicts the insurance charges to be 9,543.19, which is lower than the base value of 1.355e+4 (about 13,550).

Feature values that cause increased predictions are in pink whereas feature values that cause decreased predictions are in blue. The length of each coloured bar shows the magnitude of the feature’s effect on the final prediction.

In this particular scenario, because the predicted value is lower than the base value, the combined length of the blue bars is greater than that of the pink bars. Specifically, the blue total minus the pink total equals the distance from the base value down to the predicted value.

Moreover, we can also observe that smoker = 0 (i.e. non-smoker) is the most significant factor in driving down the insurance charges for this particular individual.

Let’s now look at two more examples.

# Second row
shap_plot(1)
Second row of the validation data
# Third row
shap_plot(2)
Third row of the validation data

To summarise, in this article, we have discussed 4 different techniques to inspect and interpret model predictions:

  • Feature importance
  • Permutation importance
  • Partial dependence plots
  • SHAP values

Hopefully, with this new knowledge, you can now start to examine not only which features are important to your model but also how exactly they impact model predictions.

I guess machine learning isn’t that much of a black box after all, is it?

Thank you so much for reading. Feel free to connect or follow me for more future content!

