Explainable AI (XAI)

Model Complexity, Accuracy and Interpretability

Importance of model interpretability and how interpretability decreases as complexity increases

Ann Sajee
Towards Data Science
May 7, 2020


The Balancing Act: Model Accuracy vs Interpretability | Credit: kasiastock/Shutterstock

Introduction

Complex real-world challenges often require complex models to deliver highly accurate predictions. However, such models tend not to be very interpretable. In this article, we look at the relationship between complexity, accuracy and interpretability.

Motivation:

I am currently working as a Machine Learning Researcher on a project on Interpretable Machine Learning. This series is based on my implementation of the book “Interpretable Machine Learning” by Christoph Molnar.

Series: Interpretable Machine Learning

Part 1: Model Complexity, Accuracy and Interpretability

We will use a real-world dataset to showcase the relationship between the complexity, accuracy and interpretability of a Machine Learning model. We will try Machine Learning models of increasing complexity and see how accuracy increases while interpretability decreases.

Part 2: Unboxing the “Black Box” Model

We will take the most complex model as our final model, since it gives very high accuracy, and apply Model-agnostic methods from “Interpretable Machine Learning” to interpret it. Refer here:

https://medium.com/@sajee.a/unboxing-the-black-box-models-23b4808a3be5

Contents:

  1. Accuracy vs Interpretability
  2. Complexity and Accuracy
  3. Importance of Interpretability
  4. Implementation: how interpretability decreases as complexity increases

Model Accuracy vs Interpretability

In the real world, when working on any problem, it is important to understand the trade-off between Model Accuracy and Model Interpretability. Business users want Data Scientists to build models with higher accuracy, while Data Scientists face the challenge of explaining how these models make predictions.

Which is more important: a model that gives the best accuracy on unseen data, or one whose predictions we can understand even when its accuracy is poor? Below is a comparison of the accuracy of traditional models against their interpretability.

Accuracy vs Interpretability

The graph shows some of the most widely used Machine Learning algorithms and how interpretable they are. Complexity here refers to how the model works internally. A model can be parametric (Linear Models) or non-parametric (K-Nearest Neighbours), a simple Decision Tree (CART) or an Ensemble model (a Bagging method such as Random Forest, or a Boosting method such as Gradient Boosting Trees). More complex models usually give more accurate predictions, but they are harder to interpret.

Model Complexity and Accuracy

Typical accuracy-complexity trade-off

The goal of any supervised machine learning algorithm is to achieve low bias and low variance. In practice, it is usually not possible to minimize both at once, so there is a trade-off between Bias and Variance.

Linear Regression assumes a linear relationship, when in reality the relationship is often more complex. These simplifying assumptions give high Bias (both training and test errors are high), and the model tends to underfit. High bias can be reduced by using more complex functions or adding more features; this is where complexity, and with it accuracy, increases. Beyond a certain point, however, the model becomes too complex and overfits the training data, i.e. it has low Bias but high Variance, which shows up as high error on test data. Complex models such as deep, unpruned Decision Trees tend to overfit.
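To make the underfit/overfit pattern concrete, here is a minimal sketch on synthetic, non-linear data (an assumption for illustration, not the bike data). It grows Decision Trees of increasing depth: a very shallow tree has high training and test error (high bias), while a very deep tree drives training error to nearly zero but keeps a higher test error (high variance).

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Synthetic non-linear data (assumption: stands in for a real dataset)
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(500, 1))
y = np.sin(X[:, 0]) * 50 + rng.normal(0, 10, size=500)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

results = {}
for depth in [1, 2, 4, 8, 16]:
    model = DecisionTreeRegressor(max_depth=depth, random_state=0).fit(X_train, y_train)
    train_mse = mean_squared_error(y_train, model.predict(X_train))
    test_mse = mean_squared_error(y_test, model.predict(X_test))
    results[depth] = (train_mse, test_mse)
    print(f"depth={depth:2d}  train MSE={train_mse:8.1f}  test MSE={test_mse:8.1f}")
```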

Since Machine Learning models tend to overfit, we can use a resampling technique such as Cross-Validation to estimate performance on unseen data and to choose model settings that generalize better.

Importance of Model Interpretability

In use cases where the impact of a prediction is high, understanding “why” a certain prediction was made is very important. Knowing the “why” can help you learn more about the problem, the data and the reasons a model might fail.

Reasons to learn about interpretability:

  1. Curiosity & Learning
  2. Safety Measure — Ensure learning is error-free
  3. Debugging to detect Bias in model training
  4. Interpretability increases social acceptance
  5. Debug and audit Machine learning models

Implementation:

Dataset — Bike Rental Prediction

The Bike Rental Dataset is available from the UCI Machine Learning Repository: http://archive.ics.uci.edu/ml/datasets/Bike+Sharing+Dataset.

This dataset contains hourly and daily counts of rented bicycles from the bicycle rental company Capital Bikeshare in Washington D.C., along with weather and seasonal information. This article uses the hourly data.

Goal: Predict how many bikes will be rented depending on the weather and the day.

Input Variables:

  1. Total_count (target): Count of total rental bikes including both casual and registered
  2. Yr: Year (0: 2011, 1: 2012)
  3. Month: Month (1 to 12)
  4. Hr: hour (0 to 23)
  5. Temp: Normalized temperature in Celsius. The values are derived via (t-t_min)/(t_max-t_min), t_min=-8, t_max=+39 (only in hourly scale)
  6. Atemp: Normalized feeling temperature in Celsius. The values are derived via (t-t_min)/(t_max-t_min), t_min=-16, t_max=+50 (only in hourly scale)
  7. Humidity: Normalized humidity. The values are divided by 100 (max)
  8. Windspeed: Normalized wind speed. The values are divided by 67 (max)
  9. Holiday: Whether the day is a holiday or not
  10. Weekday: Day of the week
  11. Workingday: 1 if the day is neither a weekend nor a holiday, otherwise 0
  12. Season : Season (1:winter, 2:spring, 3:summer, 4:fall)
  13. Weather:
  • 1: Clear, Few clouds, Partly cloudy, Partly cloudy
  • 2: Mist + Cloudy, Mist + Broken clouds, Mist + Few clouds, Mist
  • 3: Light Snow, Light Rain + Thunderstorm + Scattered clouds, Light Rain + Scattered clouds
  • 4: Heavy Rain + Ice Pellets + Thunderstorm + Mist, Snow + Fog
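The temperature columns are min-max scaled, so a normalized value can be mapped back to degrees Celsius with t = value * (t_max - t_min) + t_min. This small helper (a hypothetical name, not part of the dataset) applies the ranges listed above. Note that one unit of normalized temp corresponds to 39 - (-8) = 47 °C.

```python
# Undo the min-max scaling described for temp and atemp:
# t = normalized * (t_max - t_min) + t_min
def denormalize(value, t_min, t_max):
    return value * (t_max - t_min) + t_min

TEMP_RANGE = (-8, 39)    # t_min, t_max for temp (degrees C)
ATEMP_RANGE = (-16, 50)  # t_min, t_max for atemp (degrees C)

print(denormalize(0.5, *TEMP_RANGE))   # 15.5
print(denormalize(0.5, *ATEMP_RANGE))  # 17.0
```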

Features:

Index(['month', 'hr', 'workingday', 'temp', 'atemp', 'humidity', 'windspeed','total_count', 'season_Fall', 'season_Spring', 'season_Summer','season_Winter', 'weather_1', 'weather_2', 'weather_3', 'weather_4','weekday_0', 'weekday_1', 'weekday_2', 'weekday_3', 'weekday_4','weekday_5', 'weekday_6', 'holiday', 'year'],dtype='object')

Exploratory Data Analysis:

Bike rides increase over time

The number of bike rides increases over the two years from 2011 to 2012.

Correlation matrix

Windspeed and humidity have a slight negative correlation. Temp and atemp carry nearly the same information and are therefore highly positively correlated, so we can use either temp or atemp when building the model.

Histogram of target: on most days, bike rides have been around 20–30 rides/hr

Preprocessing:

Features such as casual and registered are dropped, since they add up to total_count (the target). Similarly, atemp carries the same information as temp, so one of the two is dropped to reduce multicollinearity. Categorical features are transformed with One-Hot Encoding into a format that works better with regression models.
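The preprocessing steps above can be sketched with pandas as follows. The three rows are made-up values, and the raw column names are assumptions chosen to match the feature names listed earlier. `pd.get_dummies` produces columns such as `season_Winter` and `weather_1`, in the same style as the feature index shown above.

```python
import pandas as pd

# A few hypothetical rows using raw column names assumed for this article
raw = pd.DataFrame({
    "casual": [3, 8, 5],
    "registered": [13, 32, 27],
    "total_count": [16, 40, 32],
    "temp": [0.24, 0.22, 0.80],
    "atemp": [0.29, 0.27, 0.76],
    "season": ["Winter", "Winter", "Summer"],
    "weather": [1, 2, 1],
    "weekday": [6, 0, 3],
})

# casual + registered == total_count, so keeping them would leak the target;
# atemp nearly duplicates temp, so drop it to reduce multicollinearity
df = raw.drop(columns=["casual", "registered", "atemp"])

# One-hot encode categoricals into columns like season_Winter, weather_1, weekday_0
df = pd.get_dummies(df, columns=["season", "weather", "weekday"])
print(df.columns.tolist())
```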

Model Implementations:

We will be going through models with increasing complexity and see how the interpretability decreases.

  1. Multivariate Linear Regression (Linear, Monotonicity)
  2. Decision Tree Regressor
  3. Gradient Boosting Regressor

Multivariate Linear Regression:

Linear regression involving multiple variables is called “multiple linear regression” or “multivariate linear regression”.


The goal of multiple linear regression (MLR) is to model the linear relationship between the explanatory (independent) variables and the response (dependent) variable. In essence, multiple regression is the extension of ordinary least-squares (OLS) regression to more than one explanatory variable.
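In the standard textbook form (written here as a sketch, with p features x_1, ..., x_p, intercept β0 and error term ε), the model and the OLS objective used to estimate its coefficients are:

```latex
y = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_p x_p + \varepsilon

\hat{\beta} = \arg\min_{\beta} \sum_{i=1}^{n} \Big( y_i - \beta_0 - \sum_{j=1}^{p} \beta_j x_{ij} \Big)^2
```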

Linear regression relies on several assumptions that often do not hold in real-world datasets:

  1. Linearity
  2. Homoscedasticity (Constant variance)
  3. Independence
  4. Fixed features
  5. Absence of multicollinearity

Linear Regression implementation:
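The original implementation is not reproduced here. The following is a minimal scikit-learn sketch of the same workflow (train/test split, fit, then MSE, R² and MAE). It uses synthetic data as a stand-in for the preprocessed bike features, so its scores will not match the results below.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the preprocessed features and total_count
# (assumption: replace X and y with the real bike-sharing data)
rng = np.random.default_rng(42)
X = rng.uniform(0, 1, size=(1000, 3))
y = 10 + 200 * X[:, 0] - 50 * X[:, 1] + 5 * X[:, 2] + rng.normal(0, 10, size=1000)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

lin_reg = LinearRegression().fit(X_train, y_train)
y_pred = lin_reg.predict(X_test)

print("Mean Squared Error:", mean_squared_error(y_test, y_pred))
print("R2 score:", r2_score(y_test, y_pred))
print("Mean Absolute Error:", mean_absolute_error(y_test, y_pred))
```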

Linear Regression results:

Mean Squared Error: 19592.4703292543
R² score: 0.40700134640548247
Mean Absolute Error: 103.67180228987019

Using Cross-validation:
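A minimal sketch of k-fold cross-validation with scikit-learn's `cross_val_score`, assuming 5 shuffled folds and R² scoring (the original fold count and scoring are not shown). The data is again a synthetic stand-in. Each fold is held out once while the model is trained on the remaining folds, so the spread of the fold scores indicates how stable the estimate is.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score

# Synthetic stand-in data (assumption: use the real features/target instead)
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(1000, 3))
y = 10 + 200 * X[:, 0] - 50 * X[:, 1] + rng.normal(0, 10, size=1000)

# 5-fold CV: each fold serves once as the validation set
cv = KFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LinearRegression(), X, y, cv=cv, scoring="r2")
print("R2 per fold:", np.round(scores, 3))
print(f"Mean R2: {scores.mean():.3f} (+/- {scores.std():.3f})")
```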

Interpret Multiple Linear Regression:

Linear models are easier to interpret: we can look at the coefficient of each variable to understand its effect on the prediction, as well as at the intercept.

Intercept of the equation (β0):

The intercept is the predicted value of y (the target) when all features are zero (x = 0).

18.01100142944577

Coefficients corresponding to X.columns helps us understand the effect of each feature on the target outcome.

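This means that, holding all other features fixed, a one-unit increase in “temp” increases the predicted number of bike rides by 211.05. Because temp is normalized, one unit spans the entire range from -8 to +39 °C. The same reading applies to the rest of the features.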
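A minimal sketch of how the intercept and coefficients can be read from a fitted scikit-learn model and paired with column names. The three feature columns and their values are synthetic assumptions, not the article's data. The last line rescales the temp coefficient to a per-degree effect, using the 47 °C range of the normalization.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Hypothetical features named like the article's columns; values are synthetic
rng = np.random.default_rng(1)
X = pd.DataFrame({
    "temp": rng.uniform(0, 1, 500),
    "humidity": rng.uniform(0, 1, 500),
    "workingday": rng.integers(0, 2, 500),
})
y = 20 + 150 * X["temp"] - 40 * X["humidity"] + 10 * X["workingday"] + rng.normal(0, 5, 500)

lin_reg = LinearRegression().fit(X, y)
print("Intercept (B0):", lin_reg.intercept_)

# Pair each coefficient with its column name to read per-unit effects
coefs = pd.Series(lin_reg.coef_, index=X.columns).sort_values(ascending=False)
print(coefs)

# temp is min-max scaled over -8..+39 degrees C, so one unit equals 47 degrees
print("Effect per degree C:", coefs["temp"] / 47)
```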

Decision Tree Regressor:

Decision trees work by iteratively splitting the data into distinct subsets in a greedy fashion. For regression trees, the splits are chosen to minimize either the MSE (mean squared error) or the MAE (mean absolute error) within the resulting subsets.

CART — Classification and Regression Trees:

CART takes a feature and determines which cut-off point minimizes the variance of y for a regression task. The variance tells us how much the y values in a node are spread around their mean value. CART evaluates candidate cut-off points for every feature and chooses the split that minimizes the weighted average variance of the resulting subsets.
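The split search for a single feature can be sketched as below. This is a simplified illustration, not scikit-learn's implementation. Full CART repeats this search for every feature, keeps the best split and recurses into the children. The `hr`/`rentals` numbers are made up.

```python
import numpy as np

def best_split(x, y):
    """Return (threshold, weighted_variance) for the cut-off on one feature
    that minimizes the sample-weighted variance of y in the two child nodes."""
    order = np.argsort(x)
    x_sorted, y_sorted = x[order], y[order]
    n = len(y)
    best_threshold, best_score = None, np.inf
    for i in range(1, n):
        if x_sorted[i] == x_sorted[i - 1]:
            continue  # no valid cut between identical feature values
        left, right = y_sorted[:i], y_sorted[i:]
        score = (len(left) * left.var() + len(right) * right.var()) / n
        if score < best_score:
            # Use the midpoint between neighbouring values as the threshold
            best_threshold = (x_sorted[i - 1] + x_sorted[i]) / 2
            best_score = score
    return best_threshold, best_score

# Toy example: hour of day vs. rentals, with a jump after 3 a.m. (made-up numbers)
hr = np.array([0, 1, 2, 3, 7, 8, 9, 10], dtype=float)
rentals = np.array([10, 12, 9, 11, 200, 210, 190, 205], dtype=float)
print(best_split(hr, rentals))  # threshold 5.0
```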

DecisionTreeRegressor(ccp_alpha=0.0, criterion='mse', max_depth=None,max_features=None, max_leaf_nodes=15,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=10,
min_weight_fraction_leaf=0.0, presort='deprecated',
random_state=None, splitter='best')

Decision tree results:

The Decision Tree Regressor gives a better fit to the data.

Mean Squared Error: 10880.635297455
R² score: 0.6706795022162286
Mean Absolute Error: 73.76311613574498

Decision tree split:

The decision tree fits the data better than Linear Regression: its R² is about 0.67, compared with about 0.41.

Using Cross-Validation:
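Cross-validation can also be used to choose the tree size instead of fixing it by hand. The sketch below uses `GridSearchCV` to compare a few `max_leaf_nodes` values. This is an illustration on synthetic data (an assumption), not the article's original procedure. The grid includes the value 15 used above.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

# Synthetic stand-in data (assumption: substitute the bike-sharing features/target)
rng = np.random.default_rng(0)
X = rng.uniform(0, 24, size=(1000, 1))
y = 100 * np.sin(X[:, 0] / 24 * np.pi) + rng.normal(0, 15, size=1000)

# 5-fold CV picks the tree size that generalizes best on held-out folds
search = GridSearchCV(
    DecisionTreeRegressor(min_samples_split=10, random_state=0),
    param_grid={"max_leaf_nodes": [5, 15, 50, 200]},
    cv=5,
    scoring="r2",
)
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))
```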

Decision tree graph:

Decision Tree Regressor output
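Besides a plotted graph, scikit-learn can print a fitted tree as text rules with `sklearn.tree.export_text`, which makes the splitting logic easy to read. This sketch fits a small tree to made-up data with two features named `hr` and `temp`.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Tiny synthetic example (made-up data): rentals jump at 7 a.m. and rise with temp
rng = np.random.default_rng(0)
hr = rng.integers(0, 24, 300)
temp = rng.uniform(0, 1, 300)
X = np.column_stack([hr, temp])
y = np.where(hr >= 7, 200, 20) + 100 * temp

tree = DecisionTreeRegressor(max_leaf_nodes=4, random_state=0).fit(X, y)
rules = export_text(tree, feature_names=["hr", "temp"])
print(rules)
```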

Interpret Decision Trees:

Feature Importance:

The importance of a feature is the total reduction in variance achieved by all the splits in which that feature was used, compared with their parent nodes. A feature might be used for more than one split or not at all. Summing these contributions for each of the p features shows how much each feature contributed to the model's predictions; in scikit-learn, the importances are normalized to sum to 1.
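The computation described above can be reproduced by walking the fitted tree's internal arrays (`tree_.children_left`, `tree_.impurity`, `tree_.weighted_n_node_samples`, and so on). The result should match `feature_importances_` for a single tree. The data here is synthetic.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic data: feature 0 is linear, feature 1 is non-linear, feature 2 is noise
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(500, 3))
y = 5 * X[:, 0] + np.sin(6 * X[:, 1]) + rng.normal(0, 0.1, 500)
tree = DecisionTreeRegressor(max_leaf_nodes=15, random_state=0).fit(X, y)

t = tree.tree_
importance = np.zeros(X.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf: no split, so no contribution
        continue
    # Weighted variance decrease achieved by the split at this node
    importance[t.feature[node]] += (
        t.weighted_n_node_samples[node] * t.impurity[node]
        - t.weighted_n_node_samples[left] * t.impurity[left]
        - t.weighted_n_node_samples[right] * t.impurity[right]
    )
importance /= importance.sum()  # normalize so importances sum to 1

print(np.round(importance, 3))
print(np.round(tree.feature_importances_, 3))
```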

We can see that hr, temp, year, workingday and season_Spring are the features used to split the decision tree.

Decision Tree Regressor — Feature Importance Bar chart

Gradient Boosting Regressor:

Boosting is an ensemble technique in which the predictors are built sequentially rather than independently. Gradient Boosting uses Decision Trees as weak learners.

Boosting converts weak learners into a strong learner by training many models in a gradual, additive and sequential manner, minimizing a loss function (e.g., squared error for regression problems) for the final model.
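The additive, sequential idea can be sketched from scratch for squared error. With that loss, the negative gradient is simply the residual, so each new tree is fit to the current residuals and added with a small learning rate. This is a simplified sketch: scikit-learn's `GradientBoostingRegressor` adds refinements such as the `friedman_mse` split criterion and optional subsampling. The function names are hypothetical.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def fit_gradient_boosting(X, y, n_estimators=100, learning_rate=0.1, max_depth=3):
    """Minimal gradient boosting for squared error (hypothetical helper)."""
    init = float(np.mean(y))          # F_0: the constant that minimizes squared error
    pred = np.full(len(y), init)
    trees = []
    for _ in range(n_estimators):
        residuals = y - pred          # negative gradient of 0.5 * (y - F)^2
        tree = DecisionTreeRegressor(max_depth=max_depth, random_state=0).fit(X, residuals)
        pred = pred + learning_rate * tree.predict(X)  # small step toward the residuals
        trees.append(tree)
    return init, learning_rate, trees

def predict_gradient_boosting(model, X):
    init, learning_rate, trees = model
    pred = np.full(X.shape[0], init)
    for tree in trees:
        pred += learning_rate * tree.predict(X)
    return pred

# Usage on synthetic data (assumption)
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(300, 2))
y = 10 * X[:, 0] ** 2 + np.sin(5 * X[:, 1])
model = fit_gradient_boosting(X, y)
train_mse = np.mean((y - predict_gradient_boosting(model, X)) ** 2)
print("Training MSE:", train_mse)
```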

GBR often achieves higher accuracy than simpler regression models thanks to boosting, and gradient-boosted trees are among the most widely used algorithms in machine learning competitions.

GradientBoostingRegressor(alpha=0.9, ccp_alpha=0.0, criterion='friedman_mse',init=None, learning_rate=0.1, loss='ls', max_depth=6,max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=100, n_iter_no_change=None, presort='deprecated',random_state=None, subsample=1.0, tol=0.0001,validation_fraction=0.1, verbose=0, warm_start=False)

The result from GBR is as below:

Mean Squared Error: 1388.8979420780786
R² score: 0.9579626971080454
Mean Absolute Error: 23.81293483364058

The Gradient Boosting Regressor gives the best R² value, about 0.958. However, this model is very difficult to interpret.

Interpret Ensemble Model:

Ensemble models definitely fall into the category of “Black Box” models since they are composed of many potentially complex individual models.

Each tree is trained sequentially to correct the errors (residuals) of the ensemble built so far. The final prediction is the sum of the contributions of all 100 trees, each up to 6 levels deep, so gaining a full understanding of the decision process by examining each individual tree is infeasible.
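One way to see the scale of the problem is to count the trees and leaves inside a fitted `GradientBoostingRegressor`. Its `estimators_` attribute holds one regression tree per boosting stage. This sketch uses the same `n_estimators=100` and `max_depth=6` as the configuration above, on synthetic data (an assumption).

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

# Synthetic data (assumption: stands in for the bike-sharing features)
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(1000, 5))
y = X @ np.array([3.0, -2.0, 1.0, 0.0, 0.5]) + np.sin(6 * X[:, 0])

gbr = GradientBoostingRegressor(n_estimators=100, max_depth=6, random_state=0).fit(X, y)

# estimators_ has shape (n_estimators, 1) for regression: one tree per stage
trees = gbr.estimators_.ravel()
total_leaves = sum(tree.get_n_leaves() for tree in trees)
print(len(trees), "trees,", total_leaves, "leaves in total")
```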

Part 2: Model-agnostic methods to interpret the Gradient Boosting Regressor model: https://medium.com/@sajee.a/unboxing-the-black-box-models-23b4808a3be5

References:

https://christophm.github.io/interpretable-ml-book/
