
The Nuclear Mousetrap

Try small easy solutions before pulling out the big guns

Image by Marco Schroeder from Pixabay

Artificial Intelligence is Hard

Scientists are now building trillion-parameter models. Huge systems like GPT-3, and even bigger models like Wu Dao 2.0, are rolling off the assembly line. These truly huge models, their smaller transformer-based little brothers, and even good old from-scratch neural networks are very popular in the media and in practical applications. There are scenarios where these bazookas are very useful, but many other situations can be solved with extremely simple machine learning models.

The nuclear mousetrap is a thought experiment taught to me in my first year of engineering school. The idea is that a thermonuclear explosion is effective at killing a mouse, but it is also a very expensive way to build a mousetrap. It’s overkill. For some reason, physicists also like to talk about nuclear mousetraps, but in our case we are talking about overengineering a solution when simpler "acceptable" solutions exist. Artificial intelligence is full of huge solutions to tiny problems. Be careful not to be fooled into adopting the model with the "best" machine learning metrics (e.g., precision, recall, F1-score), only to realize that you will pay a lot more for hardware and your inference throughput will be much lower. In real life, a right-sized solution from the engineering perspective is more important than squeezing every last drop out of your machine learning metrics.

Let’s work through an example or two to hammer this point home. The code for these examples can be found here.

Example 1: Identify the Handwritten Digit "0"

The MNIST dataset is a well-known collection of small grayscale images of handwritten digits. It is a standard benchmark for measuring and comparing the effectiveness of machine vision solutions. Machine learning models with the best MNIST performance can be found here. Unfortunately, the best-performing solutions are also computationally expensive and have an annoyingly long latency, as we will soon see. What can we do to get a fast and cheap solution for recognizing the digit 0?

First, let’s recall that dimensionality reduction (e.g., projecting onto principal components) has a compute cost at inference time and impacts our model performance. Simply deleting and downsampling dimensions, however, is free of these costs. Let’s use a bunch of techniques to delete parts of the MNIST image, and downsample it like crazy to minimize our costs.

The dataset is split into a training set of 60,000 images and a testing set of 10,000 images. Each MNIST image is composed of 784 pixels (a 28×28 grid), each holding an integer value between 0 (black) and 255 (white). Here are 3 examples of what these images look like.

To begin simplifying our data, we can simply flatten each image into a long list of numbers, completely ignoring that the data is a picture. Next, we convert the labels for the digits 0, 1, 2, … 9 into a label that is "1" when the digit is a zero, and "0" otherwise.
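This preparation takes only a few lines. Here is a minimal sketch, assuming the Keras MNIST loader (any source of the 28×28 images would do; the article’s own loading code may differ):

import numpy as np
from tensorflow.keras.datasets import mnist

# Load the standard 60,000/10,000 train/test split of 28x28 grayscale images.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Flatten each 28x28 image into a vector of 784 pixel values.
X_train = X_train.reshape(len(X_train), -1)
X_test = X_test.reshape(len(X_test), -1)

# Binary label: 1 when the digit is a zero, 0 otherwise.
y_train = (y_train == 0).astype(np.uint8)
y_test = (y_test == 0).astype(np.uint8)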

At this point, we should check how well a very simple model like a decision tree fits our dataset. The model fits the training data well, and the decision tree caught 94% of the zeros in the testing data. The following is part of the classification report for this "NO EFFORT" model:

NO EFFORT
            precision    recall  f1-score   support
Not 0       0.99         0.99    0.99       9020
Was 0       0.92         0.94    0.93       980
accuracy                         0.99       10000
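For reference, a report like the one above can be generated with just a few lines of scikit-learn. This is a sketch with default hyperparameters; the article’s actual training code may differ:

from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

# Fit a plain decision tree on the flattened pixels and report test performance.
tree = DecisionTreeClassifier(random_state=0)
tree.fit(X_train, y_train)
print(classification_report(y_test, tree.predict(X_test),
                            target_names=["Not 0", "Was 0"]))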

This initial model used all 784 unsigned 8-bit numbers (bytes) in each picture. Next, let’s take every other column in the picture, and then chop off the first 10 and last 10 pixels because we don’t really need them to recognize zeros. We also crunch every pixel down into a one or a zero: before, a pixel was grayscale, anywhere in the 0 to 255 range; now it is either black or white. After these changes, we end up with 372 pixels per image. Next, we can examine the training images and delete low-diversity pixel positions, that is, positions that are mostly 0s or mostly 1s regardless of the written digit. We may only peek at the training data; once we decide which pixels to delete, we apply the same deletions to the testing images.
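Here is a sketch of those downsampling steps. The binarization and low-diversity thresholds are my assumptions; the article only reports the resulting pixel counts:

# Keep every other column, crop 10 pixels from each end, and binarize.
def shrink(X):
    imgs = X.reshape(len(X), 28, 28)[:, :, ::2]   # every other column -> 28x14 = 392
    flat = imgs.reshape(len(X), -1)[:, 10:-10]    # drop first/last 10 pixels -> 372
    return (flat > 0).astype(np.uint8)            # grayscale -> black (0) or white (1)

X_train_small = shrink(X_train)
X_test_small = shrink(X_test)

# Drop low-diversity pixels, peeking only at the training data.
rate = X_train_small.mean(axis=0)
keep = (rate > 0.05) & (rate < 0.95)              # cutoff values are assumptions
X_train_small = X_train_small[:, keep]
X_test_small = X_test_small[:, keep]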

Having downsampled and removed low-diversity pixels, we now check which pixels have a low correlation with the label we created. Pixels that don’t correlate well with the answer can simply be deleted. After this change, we end up with 281 pixels per image.
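A sketch of that step, with the correlation cutoff as an assumption (the article only reports that 281 pixels survive):

# Keep only pixels whose correlation with the binary label clears a small cutoff.
label_corr = np.array([abs(np.corrcoef(X_train_small[:, i], y_train)[0, 1])
                       for i in range(X_train_small.shape[1])])
keep = label_corr > 0.05                          # cutoff is an assumption
X_train_small = X_train_small[:, keep]
X_test_small = X_test_small[:, keep]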

The next step is to identify pairs of pixels that are highly correlated with each other and remove one pixel from each pair. After doing this, we end up with 255 binary pixels per image. Since each pixel is now a single bit, the whole image fits into 32 bytes instead of the original 784. And we didn’t only save storage space: we fit a decision tree on this crunched-down image dataset, and the following is part of the classification report for this smaller "MEDIUM EFFORT" model:

MEDIUM EFFORT
            precision  recall  f1-score   support         
Not 0       0.99       0.99    0.99       9020        
Was 0       0.92       0.95    0.94       980      
accuracy                       0.99       10000

Wait… what? We deleted 752 out of 784 bytes of our data for each image! How is the model still working at the exact same performance level? It’s actually a bit better if you look closely.
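For completeness, the correlated-pair pruning that got us down to 255 pixels can be sketched like this (the 0.9 cutoff is my assumption):

# For each highly correlated pair of pixels, keep one and drop the other.
corr = np.corrcoef(X_train_small, rowvar=False)
n_pixels = corr.shape[0]
drop = set()
for i in range(n_pixels):
    for j in range(i + 1, n_pixels):
        if i not in drop and abs(corr[i, j]) > 0.9:
            drop.add(j)                            # pixel j duplicates pixel i
keep = [k for k in range(n_pixels) if k not in drop]
X_train_small = X_train_small[:, keep]
X_test_small = X_test_small[:, keep]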

Well, this is why we try simple before going complicated… and we’re not done. We can kill off a further 127 features with basically no consequences. With only 128 of the 255 remaining pixels, we fit a decision tree model to the training data and get the following test results for this "TINY" model:

TINY
         precision    recall  f1-score   support         
Not 0    1.00         0.99    0.99       9020        
Was 0    0.92         0.97    0.94       980

Still… no… collapse!
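The article does not spell out how the final 128 pixels were chosen. One plausible approach, purely as an illustration, is to rank pixels by the medium model’s feature importances and keep the top half:

# Hypothetical selection of the 128 most useful pixels via tree feature importances.
medium_tree = DecisionTreeClassifier(random_state=0).fit(X_train_small, y_train)
top_128 = np.argsort(medium_tree.feature_importances_)[-128:]
X_train_tiny = X_train_small[:, top_128]
X_test_tiny = X_test_small[:, top_128]

tiny_tree = DecisionTreeClassifier(random_state=0).fit(X_train_tiny, y_train)
print(classification_report(y_test, tiny_tree.predict(X_test_tiny),
                            target_names=["Not 0", "Was 0"]))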

Let’s sum up what we got out of this effort: we found a faster, more storage-efficient way to detect handwritten 0 characters. But does this trick work for other stuff too?

Example 2: Identify the Sandals

Having shown that our dumb approach works well on handwritten digits, does this same approach work on more complicated data? Scientists have been concerned for some time that MNIST is a bit too easy a dataset, and so there is a more challenging alternative called Fashion MNIST. Although Fashion MNIST has the same dimensions (784 unsigned 8-bit numbers representing a 28×28 grayscale picture), it contains more complicated objects than handwritten digits, like shirts and purses. The objects in the dataset are: 'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'.

As an aside, I am very entertained by the Fashion MNIST image embedding:

Embedding of Fashion MNIST dataset from: https://github.com/zalandoresearch/fashion-mnist

Before we re-apply our approach from MNIST to Fashion MNIST, let’s look at a few of the images:

A dress and two sandals from the Fashion MNIST dataset

We can see these objects are indeed a lot more complicated than handwritten digits.
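Re-running the pipeline only requires a new binary label. Here is a sketch, again assuming the Keras loader (class 5 is 'Sandal' in Fashion MNIST):

from tensorflow.keras.datasets import fashion_mnist

# Same 60,000/10,000 split and 28x28 format as MNIST, but with clothing items.
(Xf_train, yf_train), (Xf_test, yf_test) = fashion_mnist.load_data()
Xf_train = Xf_train.reshape(len(Xf_train), -1)
Xf_test = Xf_test.reshape(len(Xf_test), -1)

# Binary label: 1 for sandals (class 5), 0 for everything else.
yf_train = (yf_train == 5).astype(np.uint8)
yf_test = (yf_test == 5).astype(np.uint8)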

Well, how do our tricks perform on this dataset? Did the reduced models still "work" well enough? Here are the quality metrics:

NO EFFORT FASHION
           precision  recall  f1-score   support
Not Sandal 0.99       0.99    0.99       9000
Sandal     0.89       0.89    0.89       1000
accuracy                      0.98       10000

MEDIUM EFFORT FASHION
           precision  recall  f1-score   support
Not Sandal 0.99       0.98    0.99       9000
Sandal     0.87       0.88    0.87       1000
accuracy                      0.97       10000

TINY FASHION
           precision  recall  f1-score   support
Not Sandal 0.98       0.99    0.98       9000
Sandal     0.87       0.85    0.86       1000
accuracy                      0.97       10000

As we can see from the results above, these tricks work pretty well even in more advanced circumstances. With no changes to our approach, we produced very simple decision trees that correctly caught 89%, 88%, and 85% of the sandals in the testing set. We didn’t exactly explore the design space; this whole thing was a pretty low-effort first look. In many use cases, we can debounce the output of a model so that several consecutive predictions of the same object increase our confidence in the observation, resulting in precision and recall above the single-frame levels reported here.
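As a rough illustration of what debouncing could look like (the window size and agreement threshold are arbitrary choices here):

from collections import deque

# Report a detection only when most of the recent frames agree.
class Debouncer:
    def __init__(self, window=5, required=4):
        self.history = deque(maxlen=window)
        self.required = required

    def update(self, prediction):
        self.history.append(int(prediction))
        return sum(self.history) >= self.required

debouncer = Debouncer()
for frame_prediction in [1, 0, 1, 1, 1, 1]:
    is_confident = debouncer.update(frame_prediction)
print(is_confident)  # True once enough recent frames agree on "sandal"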

Conclusion

We saw in this article that simple, low-effort, low-complexity approaches can achieve reasonable performance in terms of model quality, latency, and storage requirements. If you are trying to maximize the speed of your solution, minimize the number of model parameters (to fit things into small devices or embedded platforms), or minimize inference cost (e.g., a REST API backed by a reserved GPU costs far more than CPU-based reserved compute), then a giant pre-trained model may simply be overkill for the problem you care about. Remember to check whether your "big" solution is a nuclear mousetrap.

The code for this article is available HERE.

If you liked this article, then have a look at some of my most read past articles, like "How to Price an AI Project" and "How to Hire an AI Consultant." And hey, join the newsletter!

Until next time!

-Daniel linkedin.com/in/dcshapiro [email protected]

