Understanding Machine Learning Models Better with Explainable AI

Building an interactive dashboard in a few lines of code with ExplainerDashboard

Devashree Madhugiri
Towards Data Science


Image Source: Author (Designed in Canva)

It is fascinating to see how a Machine Learning model works through a web-based dashboard. Imagine having access to interactive plots that show model performance, feature importance and What-if analysis. What is exciting is that you do not need any web development expertise to build such an informative dashboard: just a few lines of Python code are enough to generate a fully interactive Machine Learning dashboard. This is possible with a library called ‘ExplainerDashboard’.

ExplainerDashboard is a Python package that generates interactive dashboards. These dashboards help users understand and explain how a model works and how it arrives at its predictions. Without such a tool, a machine learning model is often a “black box”: it is difficult to explain the reasons behind its decisions and which factors influence them. With this package, we can easily launch the dashboard as a web app or display it inside Jupyter or Colab notebooks. The web app provides several interactive plots that explain the workings of the machine learning model it is applied to. Supported models include those built with Scikit-learn, XGBoost, CatBoost, LightGBM and a few others. The dashboard offers insights into the model through various interactive plots. These include SHAP value plots for understanding feature dependence and SHAP interaction plots for feature interactions. There are also Partial Dependence Plots (PDP) that show the impact of a selected feature, and visualizations of the decision paths through decision trees. It is also possible to evaluate how changing a feature value affects the model's prediction through a ‘What-if’ analysis. Advanced users can further customize these dashboards with some additional code. In this tutorial, we will just explore the functionality of the package. I will walk you through a step-by-step approach to create your own machine learning model dashboard in Python with a few lines of code. Towards the end of the tutorial, we’ll see what insights we can gather from this dashboard.

ExplainerDashboard Library

This Python package builds an explainable dashboard for a machine learning model, either as a standalone web app or inline in a notebook. The default components of the dashboard are self-explanatory, and you do not need to define any additional functions. The interactive components of the dashboard are built with Dash, a well-known library for web apps, and the plots themselves are made with Plotly. Finally, the entire dashboard is a web app that runs on your local machine using a Flask server. You can find all the details in the official ExplainerDashboard documentation.

Let us begin the tutorial by installing the ExplainerDashboard library with pip:

pip install explainerdashboard

You can run this command directly in a Colab or Kaggle notebook. However, if you are using a Jupyter notebook on your local machine, installing the package in a virtual environment is a better choice. This avoids conflicts between package dependencies.
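To confirm that the installation worked in the same environment your notebook kernel uses, a quick check with Python's standard `importlib.metadata` module is enough. The sketch below assumes Python 3.8 or later. The helper name `installed_version` is only illustrative and is not part of ExplainerDashboard.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str = "explainerdashboard"):
    """Return the installed version string of `package`, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_version()
    print(f"explainerdashboard {v}" if v else "explainerdashboard is not installed")
```

If this reports that the package is missing even though `pip install` succeeded, the notebook kernel is most likely using a different environment than the one pip installed into.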

Approach:

  • Import the libraries and the sample dataset
  • Create a DataFrame from the dataset (no data preprocessing is required for this demo)
  • Split the data and train the model
  • Deploy the dashboard on a local port

Regression Example: Machine Learning Dashboard

For this tutorial, let us use the ‘Diabetes’ toy dataset from the sklearn library to build a Machine Learning dashboard for a regression problem.

Importing the Libraries

# Importing libraries and packages
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from explainerdashboard import RegressionExplainer, ExplainerDashboard

We need the Pandas library for the DataFrame, while ExplainerDashboard (its RegressionExplainer and ExplainerDashboard classes) builds the dashboard. The sklearn library is used to load the toy dataset, split it, and provide the RandomForestRegressor that we train in this regression example.

Importing the dataset

# Import the Diabetes dataset
from sklearn.datasets import load_diabetes
data = load_diabetes()
# Print the dataset
data

Loading the dataset

We load the features into X and the target values into y, each as a Pandas DataFrame.

# Create a DataFrame of features from the dataset
X = pd.DataFrame(data.data, columns=data.feature_names)
# Print the first five rows of the DataFrame
print(X.head())
# Load the target values into y
y = pd.DataFrame(data.target, columns=["target"])
print(y.head())
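As an alternative to building the DataFrames by hand, scikit-learn (version 0.23 or later, an assumption about your installed version) can return the dataset directly as a features DataFrame and a target Series. This is a sketch of that shortcut, not the approach used in the rest of this tutorial.

```python
from sklearn.datasets import load_diabetes

# return_X_y=True together with as_frame=True yields (features DataFrame, target Series)
X, y = load_diabetes(return_X_y=True, as_frame=True)

print(X.shape)          # (442, 10)
print(list(X.columns))  # ['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']
print(y.name)           # target
```

Because y is already a 1-D Series here, it can be passed to the model without the ravel() step used later.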

Now our data is ready, and after splitting it we can train the model using RandomForestRegressor.

Splitting the dataset

Let us split the dataset in an 80:20 ratio using sklearn's train_test_split function.

#Splitting the Dataset
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
print(X_train.shape,y_train.shape,X_test.shape,y_test.shape)
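With the 442 rows of the Diabetes dataset, the print statement above shows (353, 10) (353, 1) (89, 10) (89, 1). The sizes come from the following rule. For a float test_size, scikit-learn rounds the number of test rows up, ceil(0.2 × 442) = ceil(88.4) = 89, and the training set gets the remaining 442 − 89 = 353 rows. The helper `split_sizes` below is a hypothetical name that reproduces this arithmetic and checks it against `train_test_split`.

```python
import math

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split


def split_sizes(n_samples: int, test_size: float):
    """(n_train, n_test) for a float test_size: test count rounded up, train gets the rest."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


X, y = load_diabetes(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

print(split_sizes(len(X), 0.2))   # (353, 89)
print(len(X_train), len(X_test))  # 353 89
```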

Training the model

We can now train a RandomForestRegressor with arbitrarily chosen hyperparameters (50 estimators and a maximum depth of 5). You can also try different values, or use XGBoost instead, and compare the results.
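If you want to follow that suggestion, one way to compare settings is 5-fold cross-validation on the training split, which leaves the test set untouched. The minimal, self-contained sketch below compares two RandomForestRegressor settings rather than XGBoost so that it needs only scikit-learn. The second setting is an arbitrary illustration, not a tuned recommendation.

```python
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, train_test_split

X, y = load_diabetes(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Illustrative candidate settings; the first one is the setting used in this tutorial
candidates = [
    {"n_estimators": 50, "max_depth": 5},
    {"n_estimators": 100, "max_depth": None},
]

cv_scores = {}
for params in candidates:
    model = RandomForestRegressor(random_state=0, **params)
    # Mean R^2 over 5 folds of the training split only
    cv_scores[str(params)] = cross_val_score(model, X_train, y_train, cv=5, scoring="r2").mean()

for name, r2 in cv_scores.items():
    print(f"{name}: mean CV R^2 = {r2:.3f}")
```

The same loop works for any scikit-learn compatible regressor, so an XGBoost model could be added as another candidate if that package is installed.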

#Training the model
model = RandomForestRegressor(n_estimators=50, max_depth=5)
model.fit(X_train, y_train.values.ravel())

Note: We use ravel() to convert ‘y_train’ into a 1-D array in this step. Reshaping the column vector y this way avoids the DataConversionWarning that RandomForestRegressor raises when it receives a column vector instead of a 1-D array.
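The difference ravel() makes is easiest to see in the array shapes. This small sketch rebuilds the target DataFrame exactly as above and compares the 2-D column vector with its flattened 1-D version.

```python
import pandas as pd
from sklearn.datasets import load_diabetes

data = load_diabetes()
y = pd.DataFrame(data.target, columns=["target"])

col = y.values           # 2-D column vector
flat = y.values.ravel()  # same values flattened to 1-D
print(col.shape, flat.shape)  # (442, 1) (442,)
```

Selecting the column as a Series (y["target"]) would give an equivalent 1-D target.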

Setting up the Dashboard instance using the trained model

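# Create the explainer from the trained model and the test data
explainer = RegressionExplainer(model, X_test, y_test)
# Build the dashboard (the What-if tab is switched off with whatif=False)
db = ExplainerDashboard(explainer, title="Diabetes Prediction", whatif=False)
# Run the app on local port 3050
db.run(port=3050)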

The dashboard will run on a local server on port 3050. The port may be different in your case if you choose another number, say 8080 or 3000.
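If the chosen port is already in use, for example by a dashboard you started earlier, the server cannot start on it. A small standard-library helper can pick a free port first. The names `port_is_free` and `pick_port` are hypothetical, and there is a short race window between checking a port and the dashboard actually binding to it.

```python
import socket


def port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if nothing is currently accepting connections on host:port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # connect_ex returns 0 when a server accepts the connection, i.e. the port is taken
        return sock.connect_ex((host, port)) != 0


def pick_port(preferred=(3050, 8080, 3000)) -> int:
    """Return the first free port from `preferred`, otherwise one assigned by the OS."""
    for port in preferred:
        if port_is_free(port):
            return port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0 asks the OS for any free port
        return sock.getsockname()[1]
```

You would then call db.run(port=pick_port()) in place of the hard-coded port number.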

Clicking on the link to the local server (http://localhost:3050 here) will open the dashboard in a separate tab of your web browser.

Your fully interactive Machine Learning Dashboard is ready!

You can find the complete code for this notebook on my GitHub repository.

Dashboard GIF by Author

Insights from the Dashboard

With this dashboard, we can gain insights such as:

  • SHAP values, which indicate how much each individual feature contributes to a given prediction
  • Permutation importances, which let us dig deeper and see how much the model performance deteriorates when the values of a feature are randomly shuffled
  • Individual decision trees, which we can visualize for tree-based models such as the RandomForestRegressor in this tutorial (or XGBoost). For classifier models, the dashboard also shows confusion matrices, ROC-AUC curves, etc. to help us better understand the model's decisions.
  • What-if analysis (if it is turned on when creating the dashboard; it is switched off above with whatif=False), which helps us understand how the model's prediction changes when we modify the feature values of an individual observation.
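The permutation importance idea in the list above can be reproduced by hand. This sketch shuffles one test-set column at a time and records the average drop in R². It is an illustration of the technique, not ExplainerDashboard's own implementation, which may use a different metric and number of repeats. scikit-learn also offers a ready-made version in sklearn.inspection.permutation_importance. The function name `permutation_importance_manual` is hypothetical.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split


def permutation_importance_manual(model, X, y, n_repeats=5, seed=0):
    """Mean drop in R^2 when each column of X is shuffled (higher = more important)."""
    rng = np.random.default_rng(seed)
    baseline = r2_score(y, model.predict(X))
    importances = {}
    for col in X.columns:
        drops = []
        for _ in range(n_repeats):
            X_perm = X.copy()
            # Shuffling one column breaks its relationship with the target
            X_perm[col] = rng.permutation(X_perm[col].to_numpy())
            drops.append(baseline - r2_score(y, model.predict(X_perm)))
        importances[col] = float(np.mean(drops))
    return importances


X, y = load_diabetes(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
model = RandomForestRegressor(n_estimators=50, max_depth=5, random_state=0)
model.fit(X_train, y_train)

imps = permutation_importance_manual(model, X_test, y_test)
for feature, drop in sorted(imps.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{feature:>4}: {drop:.3f}")
```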

However, it helps to have a basic understanding of the above plots and their parameters to make sense of the insights from such a machine learning dashboard. For a detailed treatment of the theory behind these methods, I recommend the book ‘Interpretable Machine Learning’ by Christoph Molnar.
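As a small taste of that theory, the Shapley value definition behind SHAP can be computed by brute force for a toy model. The definition averages a feature's marginal contribution over all coalitions of the other features. This sketch makes a simplifying assumption: "absent" features take the value of a single baseline point. It also enumerates every coalition, so it only suits a handful of features. The dashboard relies on the shap library and its much faster algorithms instead. The function `shapley_values` and the toy model are hypothetical illustrations.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x, with absent features set to their baseline value."""
    n = len(x)

    def value(coalition):
        # Features in the coalition keep x's value; the others use the baseline
        return f([x[i] if i in coalition else baseline[i] for i in range(n)])

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Weight |S|! (n - |S| - 1)! / n! for each coalition S not containing i
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Toy model: a linear term in feature 0 plus an interaction between features 1 and 2
def toy_model(z):
    return 3 * z[0] + 2 * z[1] * z[2]


phi = shapley_values(toy_model, x=[1, 2, 3], baseline=[0, 0, 0])
print([round(v, 6) for v in phi])  # [3.0, 6.0, 6.0]
```

The values add up to f(x) − f(baseline) = 15 − 0. The linear term is credited entirely to feature 0, and the interaction term is split equally between features 1 and 2.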

That’s it for this tutorial. I hope you learned something new and interesting.

Until the next article, Happy Reading!


Data Enthusiast focusing on applications of Machine Learning and Deep Learning in different domains.