Beautiful visual model interpretation of classification strategies — Kannada MNIST Digits Recognition


Mirza Rahim Baig
Towards Data Science

--

Kannada handwritten digits

The Kannada MNIST dataset is a great recent piece of work (details here), and I’m delighted that it is available to the public as well. I’m sure the community here will soon be posting state-of-the-art accuracy numbers on this dataset. That is why I’m doing something different.

Instead, we will try to visualize: try to see what the model sees, and assess things pixel by pixel. Our goal is interpretability. In this article, I’ll start with the ‘simplest’, easiest-to-interpret algorithm. Hopefully, I’ll post results with other modeling techniques in a later article.

To reiterate and clarify: I will not be focusing on getting the best possible performance. Rather, I’ll focus on visualizing the output, making sense of the model, and understanding where it failed and why. This is more interesting to assess when the model isn’t working extremely well. :)

Visualizing the digits data

Function to plot one random digit along with its label

import numpy as np
import matplotlib.pyplot as plt

def plot_random_digit():
    # Pick a random training example and show it with its label
    random_index = np.random.randint(0, X_train.shape[0])
    plt.imshow(X_train[random_index], cmap='BuPu_r')
    plt.title(y_train[random_index])
    plt.axis("Off")

plt.figure(figsize=[2, 2])
plot_random_digit()
A random Kannada digit plotted as image

Looking at 50 samples at one go

plt.figure(figsize=[10, 6])
for i in range(50):
    plt.subplot(5, 10, i + 1)
    plt.axis('Off')
    if i < 10:
        plt.title(y_train[i])  # label only the first row to avoid clutter
    plt.imshow(X_train[i], cmap='BuPu_r')

As someone who cannot read the Kannada script, the symbols seem somewhat similar to me for -

  • 3 and 7
  • 6 and 9

At the outset, I would expect that the predictors could be somewhat confused between these pairs. Although this isn’t necessarily true — maybe our model can identify the digits better than I can.

Reshaping the datasets for predictive model building

The individual examples are 28 x 28. For most predictive modeling methods in scikit-learn, we need to flatten each example into a 1D array of 784 values.
We’ll use the reshape method of NumPy arrays.

X_train_reshape = X_train.reshape(X_train.shape[0], 784)
X_test_reshape = X_test.reshape(X_test.shape[0], 784)

Building and understanding the Logistic regression model

Let’s build a Logistic regression model for our multiclass classification problem.

Note again that we’ll not be focusing on getting the best possible performance, but on how to understand what the model has learnt.

With a logistic regression model, it is easy and interesting to analyse the coefficients to understand what the model has learnt.
A multi-class classification can be formulated in a couple of ways in scikit-learn. They are -

  • One vs Rest
  • Multinomial

1. One vs Rest:

Also known as one-vs-all, this strategy consists of fitting one classifier per class. For each classifier, the class is fitted against all the other classes. One advantage of this approach is its interpretability.

Since each class is represented by one and only one classifier, it is possible to gain knowledge about the class by inspecting its corresponding classifier. This is the most commonly used strategy for multi-class classification and is a fair default choice.

For our case, it would mean building 10 different classifiers.

Read more about it here:
https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html
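As an illustration, here is a minimal sketch of the same formulation written explicitly with scikit-learn’s OneVsRestClassifier wrapper (an alternative to the multi_class="ovr" flag we’ll pass directly later):

from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Fits one binary LogisticRegression per digit (10 in total),
# each learning "this digit vs. all the others"
ovr = OneVsRestClassifier(LogisticRegression(solver="liblinear"))
ovr.fit(X_train_reshape[:10000], y_train[:10000])
print(len(ovr.estimators_))  # 10: one classifier per class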

2. Multinomial:

In this strategy, we model the logarithm of the probability of seeing a given output using the linear predictor.
For the multinomial case, the loss minimised is the multinomial (cross-entropy) loss fit across the entire probability distribution. The softmax function is used to find the predicted probability of each class.

Read more about this here:
https://en.wikipedia.org/wiki/Multinomial_logistic_regression#As_a_log-linear_model
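As a quick illustration of the softmax step, here is a minimal sketch; the coefficient matrix W, intercept vector b, and input x are random placeholders standing in for fitted values:

import numpy as np

def softmax(z):
    # Subtract the max before exponentiating, for numerical stability
    z = z - np.max(z)
    exp_z = np.exp(z)
    return exp_z / exp_z.sum()

rng = np.random.default_rng(42)
W = rng.normal(size=(10, 784))  # placeholder coefficients: 10 classes x 784 pixels
b = rng.normal(size=10)         # placeholder intercepts
x = rng.random(784)             # a placeholder flattened 28 x 28 image

probs = softmax(W @ x + b)
print(probs.round(3), probs.sum())  # 10 class probabilities, summing to 1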

Note: This distinction is important, and requires you to interpret the coefficients differently for the two models.

First, let’s build our model using the One vs. Rest scheme.

from sklearn.linear_model import LogisticRegression
lr1 = LogisticRegression(solver="liblinear", multi_class="ovr")

# Fitting on first 10000 records for faster training
lr1.fit(X_train_reshape[:10000], y_train[:10000])

Assessing performance on the train set

The predictions of the model for the training data

from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
import seaborn as sns

y_train_pred = lr1.predict(X_train_reshape[:10000])
print("Train accuracy:", accuracy_score(y_train[:10000], y_train_pred))
cm = confusion_matrix(y_train[:10000], y_train_pred)

plt.figure(figsize=[7, 6])
sns.heatmap(cm, cmap="Reds", annot=True, fmt='.0f')
plt.show()

That’s VERY high training accuracy! Overfitting?

Also, it looks like the model is NOT very confused between 3 and 7, or 6 and 9, at least not on the train set.

Error Analysis: Checking out the mis-classified cases

We’ll convert the labels to a Pandas series for ease of indexing, isolate the mis-classified cases, and plot some examples.

11 cases were mis-classified

  • Studying some cases
  • Picking 9 random cases — we’ll plot the digits, along with the true and predicted labels
The mis-classified cases
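Here is a minimal sketch of how these steps might look, assuming the variables defined in the snippets above:

import pandas as pd

y_true = pd.Series(y_train[:10000])
y_pred = pd.Series(y_train_pred)

# Indices where the prediction disagrees with the true label
mis_idx = y_true.index[y_true != y_pred]
print(len(mis_idx), "cases were mis-classified")

# Plot 9 random mis-classified digits, with true and predicted labels
plt.figure(figsize=[6, 6])
for plot_pos, idx in enumerate(np.random.choice(mis_idx, 9, replace=False)):
    plt.subplot(3, 3, plot_pos + 1)
    plt.imshow(X_train[idx], cmap='BuPu_r')
    plt.title("true: {}, pred: {}".format(y_true[idx], y_pred[idx]))
    plt.axis("Off")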

Can you see why the model was confused?
Let’s see how the model fares on the test set.

Confusion matrix on the test set

Making predictions on the test data, and plotting the confusion matrix.
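A minimal sketch of that step; it assumes the test labels are available as y_test:

y_test_pred = lr1.predict(X_test_reshape)
print("Test accuracy:", accuracy_score(y_test, y_test_pred))

cm_test = confusion_matrix(y_test, y_test_pred)
plt.figure(figsize=[7, 6])
sns.heatmap(cm_test, cmap="Reds", annot=True, fmt='.0f')
plt.show()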

Confusion Matrix on test set — smells like over-fitting

Looking at the confusion matrix and the classification report -

Recall is lowest for 3 and 7 — the model is significantly confused between them. Similarly, there is confusion between 4 and 5. Also, many 0s are mistaken for 1s and 3s.

Okay! So it looks like the performance has fallen sharply on the test set. There’s a very good chance we’re over-fitting on the train set.

We acknowledge that the model could be improved.

But let’s not worry about that for now. Let’s focus on ways to understand what the model has learnt.

Model interpretation

Understanding the contribution of each pixel

The coefficients we have just learnt for each pixel are based on the One vs. Rest scheme.

Let’s go ahead and analyze the coefficients for our OVR model.

The shape of the coefficients, lr1.coef_.shape, is (10, 784), i.e. we have 784 coefficients for each label — a coefficient for each pixel for each digit!

A positive coefficient means a high value on that pixel increases the chances of this label, compared to ALL other classes. The coefficients therefore tell us how this pixel differentiates this label from all the other labels together.

Extracting the pixel coefficients and plotting on a heat-map for the label 0

plt.figure(figsize=[3, 3])
coefs = lr1.coef_[0].reshape(28, 28)
plt.imshow(coefs, cmap="RdYlGn", vmin=-np.max(coefs), vmax=np.max(coefs))  # setting the mid-point of the colour scale to 0
plt.show()
Heatmap for 0 — OVR

I’ve used a divergent colour scheme to differentiate between the positive and negative signs.

In the image above, pixels with a green colour have positive coefficients: values in those pixels help classify the digit as 0. As expected, the red colour in the centre indicates that the presence of values in those pixels means lower chances of the digit being a zero. Yellow is close to 0, meaning the pixel doesn’t help differentiate in any way.

Making such pixel heatmaps for all the digits
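One way to produce these, reusing the single-digit snippet in a loop over lr1.coef_:

plt.figure(figsize=[12, 5])
for i in range(10):
    plt.subplot(2, 5, i + 1)
    coefs = lr1.coef_[i].reshape(28, 28)
    plt.imshow(coefs, cmap="RdYlGn", vmin=-np.max(coefs), vmax=np.max(coefs))
    plt.title(i)
    plt.axis("Off")
plt.suptitle("Pixel coefficient heatmaps per digit (OVR)")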

Heatmap — all digits — OVR scheme

Have a good look at these heatmaps; they reveal what the model has learnt. Be mindful that we have a ‘One vs. Rest’ formulation, especially when comparing the heatmap of one digit with those of others.

Now, let’s build a model using the multinomial scheme.

  • We need to specify the multi_class parameter as "multinomial"
  • The ‘liblinear’ solver doesn’t support this, so we choose the “sag” solver.

lr2 = LogisticRegression(random_state=42, multi_class="multinomial", solver="sag")
lr2.fit(X_train_reshape[:10000], y_train[:10000])

Assessing performance on the test set

Plotting the confusion matrix (the same code as for the OVR model, with lr2 in place of lr1).

Understanding the contribution of each pixel

As before, we have 784 coefficients for each label — one coefficient per pixel.

Now, a positive coefficient tells us what makes this label what it is: a high value on that pixel increases the softmax score of this label. But if 3 labels have a similar presence in a particular pixel, the coefficients for all 3 may have similar values.

Extracting the pixel coefficients and plotting on a heatmap for the label 0
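A minimal sketch, mirroring the earlier OVR snippet but reading the coefficients from lr2:

plt.figure(figsize=[3, 3])
coefs = lr2.coef_[0].reshape(28, 28)
plt.imshow(coefs, cmap="RdYlGn", vmin=-np.max(coefs), vmax=np.max(coefs))
plt.show()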

How different/similar is this to the heatmap from the OVR model?
Let’s make the heatmaps for all the digits.

Making such pixel heatmaps for all the digits (again looping over lr2.coef_, exactly as we did for the OVR model)

How do these heatmaps compare to the mean images for each label?

Plotting the average image for each digit.

plt.figure(figsize=(10, 4))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.title(i)
    # Average all training images that carry label i
    plt.imshow(np.mean(X_train[y_train == i], axis=0), cmap='gray')
plt.suptitle('Mean images for each digit')

Plotting them all together — have a good look.

Mean images vs. OVR vs. Multinomial

Exercise -

You’ve seen the heatmaps for the OVR method, as well as the multinomial method. And you also have the average image for each label.

  • Compare and contrast the heatmaps with the mean images.
  • What do you think is going on? Can you try to understand what the models have learnt for each digit?
  • Why are the models not performing so well on certain digits? Can the heatmaps help you understand why?

Possible next steps for those interested in trying out more things -

I suggest you try the following (a rough sketch of both ideas follows the list) -

  1. Use Logistic regression with regularization (ridge, lasso, elasticnet) and hyper-parameter optimization using cross validation to reduce overfitting.
  2. Use SVD/PCA to denoise and reconstruct the original data; follow it up with a tuned Logistic regression model.
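A minimal sketch of both suggestions; the hyper-parameter values (Cs=5, n_components=50) are placeholder assumptions to tune, not recommendations:

from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegressionCV

# 1. L2-regularized logistic regression, choosing the regularization
#    strength by built-in cross-validation over a small grid of Cs values
lr_reg = LogisticRegressionCV(Cs=5, penalty="l2", solver="sag",
                              multi_class="multinomial", max_iter=500)
lr_reg.fit(X_train_reshape[:10000], y_train[:10000])

# 2. PCA denoising: project the images onto the top principal components
#    and reconstruct them, then train a fresh model on the denoised data
pca = PCA(n_components=50)
X_denoised = pca.inverse_transform(pca.fit_transform(X_train_reshape[:10000]))
lr_denoised = LogisticRegressionCV(Cs=5, penalty="l2", solver="sag",
                                   multi_class="multinomial", max_iter=500)
lr_denoised.fit(X_denoised, y_train[:10000])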

Well, that’s all for our little demo here! I’ll soon share more demos with different modeling techniques, ways to interpret them, and more experiments on the same dataset with supervised and unsupervised techniques.

Found this interesting? Stay tuned for more such demos.

Do share your remarks/comments/suggestions!

--

Seasoned Data Scientist, author of two books, teacher, corporate trainer. Solving complex problems at Zalando, Europe’s largest online shopping platform.