
Top 5 techniques for Explainable AI

Not 1, not 2, but 5 techniques for Explainable AI

Photo by Drew Dizzy Graham on Unsplash

Imagine you are a medical professional using AI for stroke prediction. The AI has predicted a stroke for one of your patients. When you tell this to your patient, he is likely to panic and will ask two obvious questions:

  • First, can you explain why the stroke was predicted?
  • Second, what can be done to avoid it?

To answer these questions, you will need explainable AI. Such situations arise in many domains, such as healthcare, credit risk, product recommendation, and many others.

So in this article, we will go through the stroke prediction example to understand explainable AI. I will demonstrate explainable AI using the following techniques:

  • Explaining with data visualization
  • Explaining with a logistic regression machine learning model
  • Explaining with a decision tree machine learning model
  • Explaining with a neural network machine learning model
  • Explaining with SHAP

This is not an exhaustive list, as there are many other techniques. However, the ones above are commonly used.

Patient Situation

Let us start by understanding the situation of the patient for whom a stroke is predicted. The patient is a 67-year-old male. He has a medical history of heart disease, an average glucose level of 228, and a body mass index (BMI) of 37. The AI has predicted that this patient will suffer a stroke.

As a medical professional, you now need to be ready to explain this prediction. Let us go through the different techniques that will help you do so.

Explaining with Data Visualisation

Here we will not use anything sophisticated, just a simple data visualization technique. The idea is simple:

  • First, take the data that was used to train the AI. The training data will contain patients who suffered a stroke as well as those who did not.
  • Using visualization techniques, analyze the differences between the patients who suffered a stroke and those who did not. This will help you understand the factors associated with stroke.
  • You can then compare these factors with the patient's situation. This will help you understand why the patient is at risk.
Explaining with Visualisation Approach (image by author)

One of the best ways to compare two groups is a radar plot. Shown below is a radar plot of the data that was used to train the AI for stroke prediction. The blue area corresponds to patients who did not suffer a stroke, and the orange area to patients who did. The difference between the two areas gives us insight into the reasons for stroke.

Radar plot on data which was used to train AI (image by author)

We observe that patients who had a stroke tend to have high glucose levels, heart disease, and hypertension, are relatively older, and have formerly smoked. This gives us a good understanding of the reasons for stroke.
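If you want to reproduce this kind of comparison yourself, the values a radar plot draws can be computed in a few lines. The sketch below is a minimal, hypothetical example in plain Python: the rows are made-up toy patients (not the data behind the plot above), and min-max scaling is my assumption so that every axis of the radar plot shares the same 0 to 1 range.

```python
# Hypothetical toy training rows (NOT the article's dataset):
# (age, avg_glucose_level, heart_disease, hypertension, stroke_label)
FEATURES = ["age", "avg_glucose_level", "heart_disease", "hypertension"]
ROWS = [
    (80, 230.0, 1, 1, 1),
    (72, 195.0, 0, 1, 1),
    (67, 210.0, 1, 0, 1),
    (35, 85.0, 0, 0, 0),
    (45, 100.0, 0, 0, 0),
    (50, 140.0, 0, 1, 0),
]


def min_max_scale(rows, n_features):
    """Scale each feature column to [0, 1] so all radar axes share one range."""
    columns = list(zip(*[row[:n_features] for row in rows]))
    lows = [min(col) for col in columns]
    highs = [max(col) for col in columns]
    scaled = []
    for row in rows:
        feats = [
            (v - lo) / (hi - lo) if hi > lo else 0.0
            for v, lo, hi in zip(row[:n_features], lows, highs)
        ]
        scaled.append((feats, row[n_features]))
    return scaled


def group_profiles(rows, n_features):
    """Mean scaled value of every feature for label 0 (no stroke) and 1 (stroke)."""
    scaled = min_max_scale(rows, n_features)
    profiles = {}
    for label in (0, 1):
        group = [feats for feats, y in scaled if y == label]
        profiles[label] = [sum(col) / len(group) for col in zip(*group)]
    return profiles


profiles = group_profiles(ROWS, len(FEATURES))
for name, no_stroke, stroke in zip(FEATURES, profiles[0], profiles[1]):
    print(f"{name:>18}: no stroke={no_stroke:.2f}  stroke={stroke:.2f}  gap={stroke - no_stroke:+.2f}")
```

Features with a large positive gap are the ones that separate stroke patients from the rest in this toy data. Plotting the two profile lists on polar axes (for example with matplotlib) would give a radar plot like the one above.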

Now we can compare this with our patient's situation. We see that the patient meets most of the conditions associated with stroke, which helps explain why he is at risk.

Patient Situation as compared to stroke-causing factors (image by author)

Though this technique is simple and efficient, it does not tell us which of the stroke-related factors are more important than others.

Let us now investigate the second technique, which will help us solve this problem.

Explaining with a machine learning model

The AI that predicted the stroke is based on a machine learning model, and you can use such models to explain the stroke prediction. There are various machine learning models; I will demonstrate the following three.

  • Logistic Regression
  • Decision Tree
  • Neural Network

Let us start with a logistic regression model.

Logistic Regression Model

The result of a logistic regression model is shown here. The Y-axis lists the different factors, and the X-axis shows the importance of each factor for a stroke condition.

Explaining using Logistic regression (image by author)

The blue bars indicate factors that increase the risk of stroke, and the red bars indicate factors that help avoid it. For example, age, working for a private company, high glucose levels, heart disease, and hypertension are associated with a higher risk of stroke, whereas being married, being self-employed, and living in a rural area are associated with a lower probability of stroke.

The patient's situation is shown as grey bars. We see that the patient's age, his work for a private company, his glucose level, and his heart condition are the factors that explain why he is at risk of stroke. The grey bar for his BMI (body-mass index) is small, so BMI is not a main reason for a potential stroke.
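To make the bars concrete, here is a minimal sketch of one common way to obtain them, under stated assumptions: a logistic regression fitted by plain gradient descent on standardized features, where the sign of each coefficient gives the direction of a factor (positive like the blue bars, negative like the red ones), and coefficient times the patient's standardized value gives a patient-specific contribution (one plausible reading of the grey bars). The toy data, the feature set, and the hyperparameters (`lr`, `epochs`, `l2`) are hypothetical; in practice you would more likely use a library such as scikit-learn's `LogisticRegression`.

```python
import math

# Hypothetical toy data (NOT the article's dataset)
FEATURES = ["age", "avg_glucose_level", "heart_disease"]
X_RAW = [
    [80, 230.0, 1], [72, 195.0, 0], [67, 210.0, 1], [75, 180.0, 1],  # stroke
    [35, 85.0, 0], [45, 100.0, 0], [50, 140.0, 1], [30, 90.0, 0],    # no stroke
]
Y = [1, 1, 1, 1, 0, 0, 0, 0]


def standardize(rows):
    """Z-score every column; return scaled rows plus the means and std devs used."""
    columns = list(zip(*rows))
    means = [sum(col) / len(col) for col in columns]
    stds = [
        math.sqrt(sum((v - m) ** 2 for v in col) / len(col)) or 1.0
        for col, m in zip(columns, means)
    ]
    scaled = [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]
    return scaled, means, stds


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def fit_logistic(X, y, lr=0.1, epochs=2000, l2=0.01):
    """Batch gradient descent on the mean log-loss plus a small L2 penalty."""
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for xi, yi in zip(X, y):
            err = sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b) - yi
            for j in range(d):
                grad_w[j] += err * xi[j]
            grad_b += err
        w = [wj - lr * (gj / n + l2 * wj) for wj, gj in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


Z, means, stds = standardize(X_RAW)
w, b = fit_logistic(Z, Y)

patient = [67, 228.0, 1]  # age, glucose and heart disease of the article's patient
z_patient = [(v - m) / s for v, m, s in zip(patient, means, stds)]
contributions = [wj * zj for wj, zj in zip(w, z_patient)]
prob = sigmoid(sum(contributions) + b)

for name, wj, cj in zip(FEATURES, w, contributions):
    print(f"{name:>18}: coefficient={wj:+.2f}  patient contribution={cj:+.2f}")
print(f"predicted stroke probability: {prob:.2f}")
```

The small L2 penalty keeps the coefficients finite, which matters here because the toy data is perfectly separable.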

This approach is more precise than the earlier radar plot technique, as we can now pinpoint the top factors leading to a potential stroke.

Logistic regression models are an excellent way to explain predictions, as they help us identify which factors are more important than others, provided the input features are on comparable scales (e.g., standardized). For example, here we know that the patient's glucose level is more important than his body-mass index.
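One caveat is worth a quick check: a raw coefficient depends on the unit of its feature. The sketch below assumes glucose is measured in mg/dL (the article does not state the unit) and uses the approximate conversion factor of 18 between mg/dL and mmol/L; the coefficient value and the glucose readings are hypothetical. Changing the unit rescales the coefficient by the same factor while its contribution to the log-odds stays the same, and z-scores do not change at all, which is why coefficients are usually compared on standardized features.

```python
import math
import statistics

MG_DL_PER_MMOL_L = 18.0  # approximate glucose conversion factor (assumption)

coef_per_mg_dl = 0.02    # hypothetical logistic-regression coefficient
glucose_mg_dl = 228.0    # the patient's average glucose level from the article
glucose_mmol_l = glucose_mg_dl / MG_DL_PER_MMOL_L

# The same fitted model expressed in mmol/L needs a coefficient 18 times larger...
coef_per_mmol_l = coef_per_mg_dl * MG_DL_PER_MMOL_L
# ...but the term it adds to the log-odds is identical.
contribution_mg = coef_per_mg_dl * glucose_mg_dl
contribution_mmol = coef_per_mmol_l * glucose_mmol_l
print(f"coefficient: {coef_per_mg_dl:.2f} per mg/dL vs {coef_per_mmol_l:.2f} per mmol/L")
print(f"log-odds contribution: {contribution_mg:.3f} vs {contribution_mmol:.3f}")


def zscores(values):
    """Standardize a list of values (population standard deviation)."""
    mu, sigma = statistics.mean(values), statistics.pstdev(values)
    return [(v - mu) / sigma for v in values]


# Standardized values do not depend on the unit.
levels_mg = [85.0, 104.0, 150.0, 228.0]  # hypothetical glucose readings
levels_mmol = [v / MG_DL_PER_MMOL_L for v in levels_mg]
z_mg, z_mmol = zscores(levels_mg), zscores(levels_mmol)
print(all(math.isclose(a, b) for a, b in zip(z_mg, z_mmol)))
```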

However, above what exact glucose level is the patient at risk of a potential stroke? Let us answer this question using the decision tree model.

Decision Tree Model

Shown here is the result of the decision tree model. The model has different decision nodes, shown as black points. Each decision node tests a field against a threshold value, and the final node indicates whether a patient will have a stroke or not. The decision path for the patient is shown in green.

Explaining using Decision tree (image by author)

With the decision tree model, we know both the factors leading to stroke and their threshold values. For example, the patient's decision path goes through the node for average glucose level, with a threshold value of 104.47. So we now know that, for our patient, an average glucose level above 104.47 leads the model to predict a potential stroke.
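Where does a threshold such as 104.47 come from? A common approach (scikit-learn's decision trees, for instance, work this way) is to try the midpoints between consecutive observed values of a feature and keep the split that makes the two resulting groups purest, for example as measured by Gini impurity. The article does not say which library or split criterion was used, so treat this as an illustrative sketch; the glucose values and labels below are hypothetical.

```python
def gini(labels):
    """Gini impurity of a list of 0/1 labels."""
    if not labels:
        return 0.0
    p = sum(labels) / len(labels)
    return 1.0 - p ** 2 - (1.0 - p) ** 2


def best_threshold(values, labels):
    """Return (threshold, weighted Gini) of the best split `value <= threshold`."""
    pairs = sorted(zip(values, labels))
    best_t, best_score = None, float("inf")
    for (x_prev, _), (x_next, _) in zip(pairs, pairs[1:]):
        if x_prev == x_next:
            continue  # no threshold fits between identical values
        t = (x_prev + x_next) / 2  # midpoint between neighbouring observations
        left = [y for x, y in pairs if x <= t]
        right = [y for x, y in pairs if x > t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(pairs)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score


# Hypothetical average glucose levels and stroke labels
glucose = [80.0, 95.0, 100.0, 110.0, 150.0, 200.0]
stroke = [0, 0, 0, 1, 1, 1]
threshold, impurity = best_threshold(glucose, stroke)
print(threshold, impurity)  # → 105.0 0.0
```

A real tree repeats this search over every feature at every node and picks the feature and threshold with the lowest impurity, which is how a node such as "average glucose level, 104.47" ends up in the tree.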

This is more precise than the logistic regression model shown earlier. It can also help in advising the patient on his glucose level and on taking the necessary actions to reduce it below 104.47.

Thus we are not only explaining the reasons for a potential stroke but also advising the patient on steps to take to reduce the risk.
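The decision path and the advice can also be sketched in code. Below, a small hand-written tree is traversed for the patient, recording each test along the way, and then queried again with a lower glucose level to see whether the prediction changes. Only the 104.47 glucose threshold comes from the article; the `age <= 55` split and the overall shape of the tree are hypothetical.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Leaf:
    prediction: str


@dataclass
class Node:
    feature: str
    threshold: float
    left: "Tree"   # followed when value <= threshold
    right: "Tree"  # followed when value > threshold


Tree = Union[Node, Leaf]

# Hypothetical tree: only the 104.47 glucose threshold comes from the article.
TREE = Node(
    "age", 55,
    left=Leaf("no stroke"),
    right=Node("avg_glucose_level", 104.47, left=Leaf("no stroke"), right=Leaf("stroke")),
)


def predict_with_path(tree, patient):
    """Walk from the root to a leaf, recording every test the patient passes through."""
    path = []
    while isinstance(tree, Node):
        value = patient[tree.feature]
        if value <= tree.threshold:
            path.append(f"{tree.feature} = {value} <= {tree.threshold}")
            tree = tree.left
        else:
            path.append(f"{tree.feature} = {value} > {tree.threshold}")
            tree = tree.right
    return tree.prediction, path


patient = {"age": 67, "avg_glucose_level": 228.0, "heart_disease": 1, "bmi": 37.0}
prediction, path = predict_with_path(TREE, patient)
print(prediction, path)

# What-if: the same patient with glucose brought below the threshold
what_if = {**patient, "avg_glucose_level": 100.0}
print(predict_with_path(TREE, what_if)[0])
```

In this toy tree, lowering glucose below 104.47 flips the prediction; in a larger real tree, other branches could still lead to a stroke prediction, so the what-if query is worth re-running on the actual model.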

Neural Network model

Now let us move to neural network machine learning models. Shown below is a neural network for stroke prediction. It consists of input neurons, intermediate neurons, and output predictions. Each connection carries either a positive or a negative signal from one neuron to the next: a positive signal is indicated by a grey line and a negative signal by a red line.

The green-bordered circles indicate neurons that were activated during the prediction for the patient.

Neuron 1_4 is strongly activated, as shown here by its thick green circle. This neuron sends a negative signal to the NO STROKE neuron, pushing the prediction away from NO STROKE. This means that the patient has a high probability of stroke.
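The effect of a negative signal into the NO STROKE neuron can be reproduced with a toy output layer. The sketch below assumes a softmax over two output neurons and uses made-up activations and weights (only the neuron names echo the figure). To isolate the effect described above, neuron 1_4 is given a negative weight into NO STROKE and a zero weight into STROKE: while it is strongly active, the probability of STROKE is high, and silencing it brings that probability back below 50%.

```python
import math


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical hidden activations; names only echo the article's figure
hidden = {"1_1": 0.1, "1_2": 0.0, "1_3": 0.2, "1_4": 0.9}
# Hypothetical weights into the two output neurons (negative = "red" signal)
weights = {
    "NO STROKE": {"1_1": 0.5, "1_2": 0.3, "1_3": 0.4, "1_4": -2.0},
    "STROKE": {"1_1": -0.2, "1_2": 0.1, "1_3": 0.1, "1_4": 0.0},
}
OUTPUTS = ["NO STROKE", "STROKE"]


def stroke_probability(activations):
    """Weighted sum into each output neuron, then softmax; return P(STROKE)."""
    logits = [sum(weights[out][h] * a for h, a in activations.items()) for out in OUTPUTS]
    return softmax(logits)[OUTPUTS.index("STROKE")]


p_with = stroke_probability(hidden)
p_without = stroke_probability({**hidden, "1_4": 0.0})  # silence neuron 1_4
print(f"P(stroke) with 1_4 active: {p_with:.2f}, with 1_4 silenced: {p_without:.2f}")
```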

Now let us "peek" inside neuron 1_4 to see what is "inside" it. The "inside" of the neuron is shown as a radar plot, based on the input neurons and the signals received from them. Blue means a positive signal from an input neuron and red means a negative signal.

Inside the neuron 1_4 (image by author)

We see very strong signals from the input neurons for glucose level, heart disease, and age. This means that, for this patient, glucose level, heart disease, and age are the conditions leading to the stroke prediction. These explanations are similar to those we obtained with the earlier models.
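One plausible way to compute what such a radar plot shows is to multiply each input value by the weight of its connection into the neuron: the sign tells you whether the input sends a positive (blue) or negative (red) signal, and the magnitude how strong it is. The scaled inputs, weights, bias, and ReLU activation below are all hypothetical, chosen only to illustrate the calculation.

```python
# Hypothetical scaled inputs for the patient (0 = low, 1 = high)
inputs = {"age": 0.8, "avg_glucose_level": 0.9, "heart_disease": 1.0, "hypertension": 0.0, "bmi": 0.5}
# Hypothetical weights of the connections into neuron 1_4, plus its bias
weights_into_1_4 = {"age": 1.2, "avg_glucose_level": 1.5, "heart_disease": 1.1, "hypertension": 0.8, "bmi": -0.3}
bias_1_4 = -0.5

# Signal from each input = input value times connection weight
signals = {name: inputs[name] * weights_into_1_4[name] for name in inputs}
activation = max(0.0, sum(signals.values()) + bias_1_4)  # ReLU

# Strongest signals first, regardless of sign
ranked = sorted(signals.items(), key=lambda item: abs(item[1]), reverse=True)
for name, signal in ranked:
    colour = "positive (blue)" if signal > 0 else "negative (red)" if signal < 0 else "none"
    print(f"{name:>18}: {signal:+.2f}  {colour}")
print(f"activation of 1_4: {activation:.2f}")
```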

SHAP

All three models above give similar explanations; however, each of them is interpreted in a very different way. This is where SHAP, which stands for SHapley Additive exPlanations, comes in. It is capable of explaining any model.

SHAP provides a single approach to understanding predictions, irrespective of the machine learning model used.

The SHAP explanation of our stroke prediction is shown below. The visualization looks the same irrespective of the machine learning model used.

SHAP for patient stroke prediction

You see two vertical dotted lines. The left grey dotted line corresponds to the expected probability of stroke for any patient in the data, which is approximately 50%. This is because, in the data used for machine learning, about 50% of patients had a stroke and 50% did not. The right red dotted line corresponds to our patient. It is at 74.8%, as our patient has a 74.8% probability of stroke. The horizontal bars represent the different factors and how they add up to move the patient from a 50% to a 74.8% probability of stroke.
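To see where such numbers come from, here is a brute-force sketch of Shapley values for a single prediction. It treats the model as a black box, which is why the idea works for any model: a feature's contribution is its weighted average marginal effect on the prediction over all subsets of the other features, with features outside a subset filled in from background rows. The model, its coefficients, and the background rows are hypothetical. Enumerating every subset is only feasible for a handful of features; the `shap` library uses faster exact or approximate algorithms (for example `TreeExplainer` for tree models and `KernelExplainer` for arbitrary ones).

```python
import math
from itertools import combinations

FEATURES = ["age", "avg_glucose_level", "bmi"]


def model(x):
    """Hypothetical black-box model: returns a stroke probability for [age, glucose, bmi]."""
    z = 0.04 * (x[0] - 50) + 0.01 * (x[1] - 100) + 0.05 * (x[2] - 25)
    return 1.0 / (1.0 + math.exp(-z))


# Hypothetical background rows standing in for the training data
BACKGROUND = [[45, 90.0, 24.0], [60, 120.0, 28.0], [38, 85.0, 22.0], [70, 150.0, 30.0]]


def coalition_value(f, x, background, subset):
    """Average prediction when features in `subset` take the patient's values
    and all other features take the values of each background row."""
    total = 0.0
    for row in background:
        mixed = [x[i] if i in subset else row[i] for i in range(len(x))]
        total += f(mixed)
    return total / len(background)


def shapley_values(f, x, background):
    """Exact Shapley values by enumerating every subset of the other features."""
    n = len(x)
    phis = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                gain = coalition_value(f, x, background, s | {i}) - coalition_value(f, x, background, s)
                phi += weight * gain
        phis.append(phi)
    return phis


patient = [67, 228.0, 37.0]  # age, average glucose level and BMI from the article
base_value = coalition_value(model, patient, BACKGROUND, set())
phis = shapley_values(model, patient, BACKGROUND)
prediction = model(patient)

print(f"base value: {base_value:.3f}")
for name, phi in zip(FEATURES, phis):
    print(f"{name:>18}: {phi:+.3f}")
print(f"base + contributions = {base_value + sum(phis):.3f}, prediction = {prediction:.3f}")
```

The last line checks the additive property: the base value (the average prediction over the background data, playing the role of the grey line) plus all contributions equals the patient's prediction (the red line).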

The word "Additive" in the name SHAP reflects the nature of this analysis, which shows how the contributions of different factors add up to the predicted outcome. Shown below is the contribution of each factor towards the stroke prediction.

Contribution of each factor towards stroke prediction (image by author)

We see that age contributes 15 percentage points to the probability of stroke, the glucose level contributes 10 percentage points, and BMI (body-mass index) contributes 4.4 percentage points.

Because we know the contribution of each factor to the probability of stroke, we can also use this analysis to advise the patient. Age cannot be changed, but if the average glucose level and the body-mass index can be reduced, the risk of stroke can go down from 74.8% to about 60%.
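The arithmetic behind this estimate is simply subtracting the two modifiable contributions from the patient's predicted probability. Treat it as an approximation: it assumes that lowering glucose and BMI removes their contributions entirely while the other factors' contributions stay unchanged, whereas in practice the SHAP values would need to be recomputed for the patient's new values.

```latex
p_{\text{new}} \approx 74.8\% - \underbrace{10\%}_{\text{glucose}} - \underbrace{4.4\%}_{\text{BMI}} = 60.4\% \approx 60\%
```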

Factors that can help reduce stroke risk (image by author)

So we are able to explain the predictions, as well as advise on what can be done to reduce the risk and by how much.

Summary

Here is a summary of the techniques described in this article.

Summary of techniques (image by author)

As you can see, these explainable AI techniques are not a "nice-to-have" but mandatory. Using them will help you communicate better with the people affected by AI decisions. In some cases, as the stroke prediction example shows, understanding these techniques can help improve or even save lives.

Additional Resources

Website

You can visit my website to do analytics with zero coding: https://experiencedatascience.com

Please subscribe to get an email whenever Pranay Dave publishes.


YouTube channel

Here is a link to my YouTube channel: https://www.youtube.com/c/DataScienceDemonstrated

