
Peeking inside the black box

A simple approach to gain insight into the limitations of your CNN model

Photo by Marc-Olivier Jodoin on Unsplash

There is something magical about the way modern algorithms solve our current data-driven problems. Despite coming from a neuroscience background with a strong understanding of our sensory systems, I feel humbled when I write algorithms that mimic them.

One thing that strikes me most, however, is that both can be described as black boxes, a term usually reserved for deep learning models. For example, right now as you read this, what are you really seeing?

words…sentences…paragraphs? No way!

You’re not seeing anything of the sort. Rather, light emitted at differing wavelengths, angles, and intensities enters your retina and produces a cascade of events that is transduced into an electrochemical signal. This signal is sent to a small region deep inside your brain, where it is integrated and then projected to a vast region at the back of your brain… somehow allowing you to interpret the information as text in an article.

If you really want to blow your mind, consider the consequences of reading black text on a light background. If our visual system is activated by light, what we are really reading is more like the "shadow" of text (since black = no light, hence no activation of our retinal cells)!


Today we are going to find ways to peek inside the black box to help gain insight from the results of a convolutional neural network (CNN). In a previous article, I provided step-by-step instructions to build your own CNN from scratch, and I compared that model to several others built with transfer learning. You can check out my GitHub repo for this project if you are interested!

Recently, I was asked whether I knew where my model was making errors… What a great question! I had fallen down the rabbit hole of minimizing loss and maximizing accuracy, and was delighted when I developed a model that performed well. Meanwhile, I had no idea where or why the model was failing.

The purpose of this post is to find out where my model was making errors. To do so, I will generate a confusion matrix using the code below (note: this example presents the confusion matrix for the top-performing CNN I built from scratch, which reached a validation accuracy of 76%).
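The full code lives in the GitHub repo; below is a minimal sketch of how such a matrix can be produced with scikit-learn. The names `model` (a trained Keras classifier) and `val_generator` (a validation generator that yields images in a fixed, non-shuffled order) are placeholders for whatever objects your own pipeline provides.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Predicted class probabilities for every validation image, shape (n_images, 5)
y_prob = model.predict(val_generator)
y_pred = np.argmax(y_prob, axis=1)      # predicted class per image
y_true = val_generator.classes          # ground-truth labels (requires shuffle=False)

# Rows = true labels, columns = predicted labels
cm = confusion_matrix(y_true, y_pred)
ConfusionMatrixDisplay(cm, display_labels=range(5)).plot(cmap="Blues")
plt.title("True vs. predicted diagnosis")
plt.show()
```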

Refresher on the dataset

Just as a quick primer for those who haven’t seen the post, this was a multi-class classification task for a condition known as diabetic retinopathy (DR). The distribution of diagnoses was highly imbalanced: the original dataset contained 1805 controls (class 0), 370 mild (class 1), 999 moderate (class 2), 193 severe (class 3), and 295 proliferative (class 4) images.

To interpret a confusion matrix, we aim to maximize the values along the diagonal running from the top left to the bottom right. We can see in Figure 1 that n=361 images in the control group were classified correctly, while n=8 were misclassified as class 1 and n=3 as class 2. That isn’t bad.

But looking at the other classes, the accuracy rate was nothing to write home about!

Figure 1 – Confusion matrix presenting the true vs predicted labels for this dataset. Recall, 0 was the control group, and 1–4 represent increasing severity of diabetic retinopathy. Image by author

Going a little further, we can use the code below to plot images of the top errors, along with their true and predicted values.
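Again, a rough sketch rather than the exact project code: one simple way to find the "top" errors is to rank validation images by their cross-entropy loss and display the worst offenders. Here `x_val` (the validation images as an array) is a placeholder, and `y_true`, `y_pred`, and `y_prob` are reused from the sketch above.

```python
import numpy as np
import matplotlib.pyplot as plt

# Per-image cross-entropy loss: large when the true class received low probability
losses = -np.log(y_prob[np.arange(len(y_true)), y_true] + 1e-12)
worst = np.argsort(losses)[-10:][::-1]   # indices of the 10 largest losses

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
for ax, idx in zip(axes.ravel(), worst):
    ax.imshow(x_val[idx])
    ax.set_title(f"true: {y_true[idx]}  pred: {y_pred[idx]}")
    ax.axis("off")
plt.tight_layout()
plt.show()
```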

Plotting images of the top 10 errors shows us that all of them were predicted to be control images.

Figure 2 – Images of the top 10 errors. Image by author

What information do we gain by plotting the errors?

I think it’s fair to say that the model presented here was heavily influenced by the disproportionate group sizes. Of the 733 total validation images, 650 were predicted to be either control or class 2, in line with their over-representation in the dataset.

How to improve the model

Seeing as the imbalanced dataset significantly influenced the model, there are a few methods that may rectify this issue. Check out this article for more details.

  • Focal loss: Instead of weighting every example equally, a customized loss function down-weights examples the model already classifies easily, so training focuses on the harder, typically under-represented classes and puts everything on a more even playing field (a minimal sketch follows this list).
  • Over/Under sampling: By randomly discarding a proportion of the over-represented group(s) (undersampling) or creating copies of the under-represented group(s) (oversampling), we can synthetically balance the group sizes.
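As promised, here is a minimal sketch of what a focal-style loss could look like in Keras/TensorFlow; the function name and the gamma value are my own choices for illustration, not code from the original project.

```python
import tensorflow as tf

def sparse_categorical_focal_loss(gamma=2.0):
    """Focal loss for integer labels: down-weights easy, confidently classified examples."""
    def loss_fn(y_true, y_pred):
        y_true = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        # Probability the model assigned to the true class of each example
        p_true = tf.reduce_sum(
            tf.one_hot(y_true, depth=tf.shape(y_pred)[-1]) * y_pred, axis=-1)
        p_true = tf.clip_by_value(p_true, 1e-7, 1.0)
        return -tf.reduce_mean(tf.pow(1.0 - p_true, gamma) * tf.math.log(p_true))
    return loss_fn

# Usage with the same hypothetical `model` as above:
# model.compile(optimizer="adam",
#               loss=sparse_categorical_focal_loss(gamma=2.0),
#               metrics=["accuracy"])
```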

I’ll try these different options out and let you know how they change the model accuracy. Thanks for reading!

