Understanding your Neural Network’s predictions

A concise guide to assessing your neural network’s feature importance.

Antoine Villatte
Towards Data Science

Image by author

Neural networks are extremely convenient. They can be used for both regression and classification, work on structured and unstructured data, handle temporal data very well, and can usually reach high performance if they are given a sufficient amount of data.

What is gained in convenience is, however, lost in interpretability, and that can be a major setback when models are presented to a non-technical audience, such as clients or stakeholders.

For instance, last year, the Data Science team I am part of wanted to convince a client to move from a decision tree model to a neural network, and for good reasons: we had access to a large amount of data, and most of it was temporal. The client was on board, but wanted to keep an understanding of what the model based its decisions on, which means evaluating its feature importance.

Does it make sense?

That is debatable. With a decision tree or a boosting model, feature importance can be retrieved directly, through the fitted attribute feature_importances_ for most tree-based models or through the get_booster() and get_score() methods for XGBoost models.
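As a minimal sketch of what that looks like for a tree-based model (the synthetic dataset and the variable names are assumptions made for illustration, not part of the original project):

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data: 5 features, of which only the first 2 carry signal
# (shuffle=False keeps the informative columns first).
X, y = make_classification(
    n_samples=500, n_features=5, n_informative=2,
    n_redundant=0, shuffle=False, random_state=0,
)

tree = DecisionTreeClassifier(max_depth=4, random_state=0).fit(X, y)

# Impurity-based importances: one value per feature, summing to 1.
tree_importances = tree.feature_importances_
for idx, value in enumerate(tree_importances):
    print(f"feature_{idx}: {value:.3f}")

# For a fitted xgboost.XGBClassifier, the analogous call would be
# model.get_booster().get_score(importance_type="gain").
```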

For a neural network, these attributes and methods do not exist. Each neuron is trained to learn when to activate or not based on the signal it receives, so that each layer extracts some information — or concept — from the original input, up until the final prediction layer. Therefore, the usefulness of retrieving the features’ importance of a more “black-box” kind of model is questionable.

I’ve even heard deep learning experts say that it is best to let the data do the talking, and not to try to understand the model too much. Basically, is it useful to know whether a cat’s fur matters more to the neural network than its eyes? Maybe not. But it is useful to know that, for the model, a cat on a table is no less a cat than one on the floor, and that’s what we’ll do here.

Permutate, perturbate, and evaluate

We’ll use the permutation importance method. For classic machine learning models, scikit-learn provides a function to do that, and even recommends it over impurity-based importance when dealing with high-cardinality features. If you want to use this function on your model, this code snippet will compute and display its permutation importance:
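A minimal sketch of that call, assuming a scikit-learn estimator and a held-out validation set (the synthetic data and the random forest are placeholders for your own model and data):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic data: only the first 2 of the 5 features are informative.
X, y = make_classification(
    n_samples=600, n_features=5, n_informative=2,
    n_redundant=0, shuffle=False, random_state=0,
)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)

model = RandomForestClassifier(n_estimators=50, random_state=0)
model.fit(X_train, y_train)

# Shuffle each column 10 times and record the drop in the model's score.
result = permutation_importance(
    model, X_val, y_val, n_repeats=10, random_state=0
)

# Display the features from most to least important.
for idx in result.importances_mean.argsort()[::-1]:
    mean = result.importances_mean[idx]
    std = result.importances_std[idx]
    print(f"feature_{idx}: {mean:.3f} +/- {std:.3f}")
```

Computing the importance on held-out data rather than on the training set shows which features the model relies on to generalize, not which ones it has memorized.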

The principle behind permutation importance

Let’s say you have several students, and you want to evaluate their likelihood of passing a math exam. To do so, you have access to three variables: the time they spent studying for the exam, their ease in math, and their hair color.

Student data for a math exam. Image by author

In this example, Paul studied a lot and is moderately gifted in math. He is very likely to pass his math exam. Mike, on the other hand, studied much less and is not very gifted, so he is unlikely to pass. Bob didn’t study at all, but is extremely gifted, so he still has a chance.

Let’s permutate the values of the “Study Time” feature:

Impact of shuffling the 1st column. Image by author

Paul went from studying a lot to not studying at all. His moderate ease in math will not be enough to compensate, and he is now unlikely to pass. Likewise, the other students had their likelihood of success highly impacted by this perturbation.

We can therefore infer that the study time is an important feature to predict the exam’s outcome.

We get the same result when we perturbate the “Ease in math” feature:

Impact of shuffling the 2nd column. Image by author

Bob has now become ungifted in math, and hasn’t studied at all. He is extremely unlikely to pass the exam.

By the same reasoning as before, this feature is also important.

Now, when we permutate the hair color feature:

Impact of shuffling the 3rd column. Image by author

Mike going from blond to dark-haired will not improve his chances for the exam, nor will any hair color change have any impact on any student. This feature therefore has no importance for our prediction.
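The three experiments above can be reproduced in a few lines. This is only a sketch: the numbers are invented (the original table is an image), and pass_probability is a hypothetical stand-in model that, by construction, ignores hair color.

```python
import numpy as np

# Invented values for illustration only. Columns: study time (hours),
# ease in math (0-10), hair color (encoded as an integer).
students = ["Paul", "Mike", "Bob"]
features = ["study_time", "ease_in_math", "hair_color"]
X = np.array([
    [10.0, 5.0, 0.0],   # Paul: studied a lot, moderately gifted
    [3.0, 2.0, 1.0],    # Mike: studied little, not very gifted
    [0.0, 10.0, 2.0],   # Bob: did not study, extremely gifted
])


def pass_probability(data):
    """Hypothetical stand-in model: a logistic score ignoring hair color."""
    score = 0.5 * data[:, 0] + 0.5 * data[:, 1] - 4.0
    return 1.0 / (1.0 + np.exp(-score))


baseline = pass_probability(X)

mean_deviation = {}
for col, name in enumerate(features):
    X_perm = X.copy()
    # Rotate the column by one row: a fixed permutation that moves every
    # value, used instead of a random shuffle to keep this tiny example
    # deterministic.
    X_perm[:, col] = np.roll(X[:, col], 1)
    deviation = np.abs(pass_probability(X_perm) - baseline)
    mean_deviation[name] = float(deviation.mean())

for name, value in mean_deviation.items():
    print(f"{name}: {value:.3f}")
```

The two features used by the stand-in model produce a clear deviation, while hair color produces exactly zero.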

Limits of this method

Let’s say that out of 100 students, one cheater managed to get his hands on the exam questions, which guarantees him a pass. If we permutate the “cheater” column, at most one student goes from cheater to non-cheater, and one other student goes from non-cheater to cheater. Out of 100 students, only two will be impacted, and we’ll wrongly consider this feature unimportant because of its low prevalence.

Therefore, this method will not work well on imbalanced binary features or on rare categories of categorical features. For these cases, it is better to set the whole column to the rare value and see how it impacts the predictions (in our analogy, that would mean setting the “cheater” column to True for every student).
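A short sketch of the difference between the two approaches, using invented data and a hypothetical stand-in model in which cheating guarantees a pass:

```python
import numpy as np

rng = np.random.default_rng(0)
n_students = 100

# Invented data: a base pass probability per student, and one cheater.
base_probability = rng.uniform(0.2, 0.8, size=n_students)
cheater = np.zeros(n_students, dtype=bool)
cheater[0] = True


def pass_probability(base, is_cheater):
    """Hypothetical stand-in model: cheating guarantees a pass."""
    return np.where(is_cheater, 1.0, base)


baseline = pass_probability(base_probability, cheater)

# Permutation: at most two students see their value change.
shuffled = rng.permutation(cheater)
dev_shuffle = np.abs(pass_probability(base_probability, shuffled) - baseline)

# Forcing the rare value: every student becomes a cheater.
forced = np.ones(n_students, dtype=bool)
dev_forced = np.abs(pass_probability(base_probability, forced) - baseline)

print("affected by shuffling:", int((dev_shuffle > 0).sum()))
print("affected by forcing True:", int((dev_forced > 0).sum()))
print("mean deviation (shuffle):", round(float(dev_shuffle.mean()), 4))
print("mean deviation (forced):", round(float(dev_forced.mean()), 4))
```

Shuffling barely moves the average prediction, whereas forcing the rare value reveals how strongly the model reacts to it.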

Implementation

The first step is to make an unperturbated inference on your testing set. Then, for each feature, you’ll shuffle it randomly and make what I’ll call a perturbated inference.

Once all the perturbated inferences are made, concatenate them into a single DataFrame, and then calculate, for each observation, how far each perturbated prediction deviates from the original one.

From there, a good way to visualize the impact of each perturbation is to make a box-plot of all the observations’ deviations.

Let’s use, for instance, the Kaggle dataset for the Home Credit Default Risk competition.
After the pre-processing and training stages, I got two datasets: X_test, which contains the static data for the testing set, and X_test_batch, which contains the temporal data for the testing set.

The following snippet goes through every feature and creates a perturbated inference:
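As a minimal sketch of this step (an illustration, not the exact pipeline used on the competition data): the column names, the array shapes, and the predict function below are invented stand-ins for the real inputs and the trained network.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_samples, n_steps = 200, 6

# Invented stand-ins for the two inputs: a static table and a
# (samples, time steps, features) array of temporal data.
X_test = pd.DataFrame(
    rng.normal(size=(n_samples, 3)), columns=["income", "age", "noise"]
)
temporal_names = ["balance", "late_payments"]
X_test_batch = rng.normal(size=(n_samples, n_steps, len(temporal_names)))


def predict(static, temporal):
    """Hypothetical stand-in for the trained network's predict method."""
    score = (
        1.5 * static["income"].to_numpy()
        - 0.5 * static["age"].to_numpy()
        + temporal[:, :, 1].mean(axis=1)
    )
    return 1.0 / (1.0 + np.exp(-score))


# Unperturbated inference.
predictions = {"original": predict(X_test, X_test_batch)}

# Static features: shuffle one column at a time.
for name in X_test.columns:
    X_perturbed = X_test.copy()
    X_perturbed[name] = rng.permutation(X_perturbed[name].to_numpy())
    predictions[name] = predict(X_perturbed, X_test_batch)

# Temporal features: shuffle whole sequences across samples, so that
# the time order inside each sequence is preserved.
for idx, name in enumerate(temporal_names):
    batch_perturbed = X_test_batch.copy()
    shuffled_rows = rng.permutation(n_samples)
    batch_perturbed[:, :, idx] = X_test_batch[:, :, idx][shuffled_rows]
    predictions[name] = predict(X_test, batch_perturbed)

# One column per inference: the original plus one per perturbated feature.
perturbed_inferences = pd.DataFrame(predictions)
print(perturbed_inferences.head())
```

Shuffling temporal features sequence by sequence is a design choice of this sketch: it breaks the link between a sample and its history without creating sequences that could never occur.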

Then, this code snippet will compute the deviation from the original inference for each perturbation:
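A minimal sketch of that computation, assuming the predictions are stored in a DataFrame with one "original" column and one column per perturbated feature. The values are invented, and the absolute difference is one possible choice of deviation among others.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Invented predictions: the original inference plus one column per
# perturbated feature.
original = rng.uniform(0.05, 0.95, size=200)
perturbed_inferences = pd.DataFrame({
    "original": original,
    "feature_a": np.clip(original + rng.normal(0, 0.20, size=200), 0, 1),
    "feature_b": np.clip(original + rng.normal(0, 0.02, size=200), 0, 1),
    "feature_c": original,  # perturbing this feature changed nothing
})

# Absolute deviation of every perturbated prediction from the original,
# computed row by row (one row per observation).
deviations = (
    perturbed_inferences.drop(columns="original")
    .sub(perturbed_inferences["original"], axis=0)
    .abs()
)

print(deviations.describe().loc[["mean", "50%", "max"]])
```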

Finally, this code snippet will plot the feature importance:
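One way to draw such a plot is sketched below with matplotlib. The deviations are invented, and the output file name is an arbitrary choice.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no display required
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Invented deviations (rows: observations, columns: perturbated features).
deviations = pd.DataFrame({
    "feature_a": np.abs(rng.normal(0, 0.20, size=200)),
    "feature_b": np.abs(rng.normal(0, 0.05, size=200)),
    "feature_c": np.abs(rng.normal(0, 0.01, size=200)),
})

# Sort the features by median deviation, most impactful first.
order = deviations.median().sort_values(ascending=False).index.tolist()

fig, ax = plt.subplots(figsize=(8, 4))
ax.boxplot([deviations[name].to_numpy() for name in order], showfliers=False)
ax.set_xticks(range(1, len(order) + 1))
ax.set_xticklabels(order)
ax.set_ylabel("Absolute deviation from the original prediction")
ax.set_title("Permutation feature importance")
fig.tight_layout()
fig.savefig("feature_importance.png")
```

Each box summarizes how much the predictions moved when one feature was shuffled: the higher and wider the box, the more the model depends on that feature.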

You should get a plot like this:

If you want to see the whole data science pipeline, I have made a public Docker image that contains all of the steps from the raw data up to the feature importance plot here: https://hub.docker.com/r/villatteae/neuralnet_feat_importance/tags

Simply run the following Docker commands:

docker pull villatteae/neuralnet_feat_importance:latest
docker run -p 10000:10000 -d villatteae/neuralnet_feat_importance

Once the container is running, it is available at localhost:10000. The username and password for the instance are admin and admin. Note that the image is quite heavy (~17 GB).
