
Artificial Intelligence is Hard
Scientists are now building trillion-parameter models. Huge systems like GPT-3, and even bigger models like Wu Dao 2.0, are rolling off the assembly line. These truly huge models, their smaller transformer-based little brothers, and even good old from-scratch neural networks are very popular in the media and in practical applications. There are scenarios where these bazookas are genuinely useful, but many other problems can be solved with extremely simple machine learning models.
The nuclear mousetrap is a thought experiment taught to me in my first year of engineering school. The idea is that a thermonuclear explosion is effective at killing a mouse, but it is also a very expensive way to build a mousetrap. It's overkill. For some reason, physicists also like to talk about nuclear mousetraps, but in our case the point is the overengineering of a solution when a simpler, "acceptable" solution exists. Artificial Intelligence is full of huge solutions to tiny problems. Be careful not to be fooled into adopting the model with the "best" machine learning metrics (e.g., precision, recall, F1-score), only to realize that you will pay a lot more for hardware and your inference throughput will be much lower. In real life, a right-sized solution from the engineering perspective is more important than squeezing every last drop out of your machine learning metrics.
Let’s work through an example or two to hammer this point home. The code for these examples can be found here.
Example 1: Identify the Handwritten Digit "0"
The MNIST dataset is a well-known collection of small grayscale images of handwritten digits. It is a standard benchmark for measuring and comparing the effectiveness of machine vision solutions. Machine learning models with the best MNIST performance can be found here. Unfortunately, the best-performing solutions are also computationally expensive and have an annoyingly long latency, as we will soon see. What can we do to get a fast and cheap solution for recognizing the digit 0?
First, let's recall that dimensionality reduction techniques (e.g., PCA) add a compute cost at inference time, because the reduction has to be computed for every input, and they also affect model performance. Simply deleting and downsampling dimensions, however, is free of these costs. Let's use a bunch of techniques to delete parts of each MNIST image, and downsample it like crazy to minimize our costs.
The dataset is split into a training set of 60,000 images, and a testing set of 10,000 images. Each MNIST image is composed of 784 pixels, each holding an integer value between 0 (black) and 255 (white). Here are 3 examples of what these images look like.

To begin simplifying our data, we can simply flatten each image into a long list of numbers, completely ignoring that the data is a picture. Next, we convert the labels for the digits 0, 1, 2, … 9 into a label that is "1" when the digit is a zero, and "0" otherwise.
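To make this concrete, here is a minimal sketch of that setup. Loading MNIST via scikit-learn's fetch_openml is an assumption; the article's own loading code may differ.

```python
import numpy as np
from sklearn.datasets import fetch_openml

# Fetch MNIST as flat vectors of 784 pixel intensities (0-255) plus string labels.
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
X = X.astype(np.uint8)

# Binary label: 1 when the digit is a zero, 0 otherwise.
y_binary = (y == "0").astype(np.uint8)

# The standard split: first 60,000 images for training, last 10,000 for testing.
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y_binary[:60000], y_binary[60000:]
```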
At this point, we should check how well a very simple model, like a decision tree, fits our dataset. The model fits the training data well, and the decision tree caught 94% of the zeros in the testing data. The following is part of the classification report for this "NO EFFORT" model:
NO EFFORT
precision recall f1-score support
Not 0 0.99 0.99 0.99 9020
Was 0 0.92 0.94 0.93 980
accuracy 0.99 10000
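In code, that no-effort baseline might look like the sketch below, continuing from the loading snippet above. Leaving the tree at scikit-learn's default hyperparameters is an assumption.

```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report

# Fit a plain decision tree on the raw flattened pixels.
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X_train, y_train)

# Evaluate on the held-out test images.
predictions = clf.predict(X_test)
print(classification_report(y_test, predictions, target_names=["Not 0", "Was 0"]))
```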
This initial model used all 784 unsigned 8-bit numbers (bytes) in each picture. Next, let's keep only every other column of the picture, and then chop off the first 10 and last 10 pixels, because we don't really need them to recognize zeros. We also crunch every pixel down to a one or a zero: before, a pixel was grayscale and could be anywhere in the 0 to 255 range; now it is either black or white. After these changes, we are left with 372 pixels per image. Next, we examine the training images and delete pixel positions with low diversity, i.e., positions that are almost always 0 or almost always 1 regardless of the written digit. We may only peek at the training data when deciding which pixels to delete; we then apply the same deletions to the testing images.
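A sketch of that downsampling step, continuing the snippets above. The binarization threshold and the 2% diversity cutoff are assumptions; the article only describes the transformations qualitatively.

```python
def downsample(images):
    # Reshape flat vectors back to 28x28, keep every other column,
    # flatten again, and chop 10 pixels off each end (392 - 20 = 372 pixels).
    images = images.reshape(-1, 28, 28)[:, :, ::2].reshape(-1, 392)
    images = images[:, 10:-10]
    # Crunch every grayscale pixel down to black (0) or white (1);
    # the "anything above zero is white" threshold is an assumption.
    return (images > 0).astype(np.uint8)

X_train_small = downsample(X_train)
X_test_small = downsample(X_test)

# Drop low-diversity pixel positions: almost always 0 or almost always 1 in the
# training data. Only the training set is inspected; the same columns are then
# removed from the test set. The 2% cutoff is an assumption.
pixel_means = X_train_small.mean(axis=0)
keep = (pixel_means > 0.02) & (pixel_means < 0.98)
X_train_small, X_test_small = X_train_small[:, keep], X_test_small[:, keep]
```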
Having downsampled the images and removed low-diversity pixels, we now check which pixels have a low correlation with the label we created. Pixels that don't correlate well with the answer can simply be deleted. After this change, we end up with 281 pixels per image.
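Continuing the sketch, the label-correlation filter might look like this; the 0.05 cutoff is an assumption.

```python
# Correlation of each remaining pixel with the binary "is it a zero?" label,
# computed on training data only.
correlations = np.array([
    abs(np.corrcoef(X_train_small[:, i], y_train)[0, 1])
    for i in range(X_train_small.shape[1])
])

# Keep only pixels that correlate (even weakly) with the answer.
keep = correlations > 0.05
X_train_small, X_test_small = X_train_small[:, keep], X_test_small[:, keep]
```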
The next step is to identify pairs of pixels that correlate highly with each other and remove one pixel of each such pair. After doing this, we end up with 255 binary pixels per image. The whole image now fits into 32 bytes instead of the original 784. And we didn't only save storage space: we fit a decision tree on this crunched-down image dataset, and the following is part of the classification report for this smaller "MEDIUM EFFORT" model:
MEDIUM EFFORT
precision recall f1-score support
Not 0 0.99 0.99 0.99 9020
Was 0 0.92 0.95 0.94 980
accuracy 0.99 10000
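For completeness, the pairwise-correlation pruning that produced this medium-effort feature set might look like the sketch below, continuing from the snippets above; the 0.9 threshold is an assumption.

```python
# Pairwise correlations between the remaining pixels, on training data only.
corr_matrix = np.corrcoef(X_train_small, rowvar=False)
n_pixels = corr_matrix.shape[0]

# When two pixels are highly correlated, keep the first and drop the second.
to_drop = set()
for i in range(n_pixels):
    if i in to_drop:
        continue
    for j in range(i + 1, n_pixels):
        if j not in to_drop and abs(corr_matrix[i, j]) > 0.9:
            to_drop.add(j)

keep = [i for i in range(n_pixels) if i not in to_drop]
X_train_small, X_test_small = X_train_small[:, keep], X_test_small[:, keep]

# The medium-effort model is then just a decision tree on this smaller matrix.
medium_clf = DecisionTreeClassifier(random_state=0).fit(X_train_small, y_train)
print(classification_report(y_test, medium_clf.predict(X_test_small),
                            target_names=["Not 0", "Was 0"]))
```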
Wait… what? We deleted 752 out of 784 bytes of our data for each image! How is the model still working at the exact same performance level? It’s actually a bit better if you look closely.
Well, this is why we try simple before going complicated… and we’re not done. We can kill off a further 127 features with basically no consequences. With only 128 of the 255 remaining pixels, we fit a decision tree model to the training data and get the following test results for this "TINY" model:
TINY
precision recall f1-score support
Not 0 1.00 0.99 0.99 9020
Was 0 0.92 0.97 0.94 980
Still… no… collapse!
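The article's code doesn't need anything exotic to get there. As one hypothetical way to pick the survivors, the sketch below ranks features by the medium-effort tree's importances and keeps the top 128; this selection method is an assumption, not necessarily how the original 128 pixels were chosen.

```python
# Rank the remaining pixels by how much the medium-effort tree relied on them,
# then keep only the 128 most important ones (an assumed selection method).
top_128 = np.argsort(medium_clf.feature_importances_)[-128:]

tiny_clf = DecisionTreeClassifier(random_state=0)
tiny_clf.fit(X_train_small[:, top_128], y_train)
predictions = tiny_clf.predict(X_test_small[:, top_128])
print(classification_report(y_test, predictions, target_names=["Not 0", "Was 0"]))
```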
Now let’s see what we got out of this effort:


Let's sum up where we are: we found a faster, more storage-efficient way to detect handwritten 0 characters. But does this trick work for other stuff too?
Example 2: Identify the Sandals
Having shown that our dumb approach works well on handwritten digits, does it also work on more complicated data? Scientists have been concerned for some time that MNIST is a bit too easy a dataset, and so there is a more complicated dataset called Fashion MNIST that poses more of a challenge than handwritten digits. Although Fashion MNIST still has the same dimensions (784 unsigned 8-bit numbers representing a 28×28 grayscale picture), it contains more complicated objects than MNIST, like shirts and purses. The classes in the dataset are: 'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'.
As an aside, I am very entertained by the Fashion MNIST image embedding:

Before we re-apply our approach from MNIST to Fashion MNIST, let’s look at a few of the images:

We can see these objects are indeed a lot more complicated than handwritten digits.
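Re-running the pipeline only requires swapping the dataset and the target label. Here is a minimal sketch, assuming Fashion MNIST is pulled from OpenML and keeps the usual 60,000/10,000 train/test ordering; the original loading code may differ.

```python
# Fashion MNIST has the same 784-pixel layout, so the rest of the pipeline is reused.
X, y = fetch_openml("Fashion-MNIST", version=1, return_X_y=True, as_frame=False)
X = X.astype(np.uint8)

# Class 5 is "Sandal"; the binary target becomes "is this a sandal?".
y_binary = (y == "5").astype(np.uint8)

X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y_binary[:60000], y_binary[60000:]
# ...then apply exactly the same downsampling and pixel pruning as before.
```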
Well, how do our tricks perform on this dataset? Here are the results:


That looks great! And how about the quality metrics? Did the model "work" well enough?
NO EFFORT FASHION
precision recall f1-score support
Not Sandal 0.99 0.99 0.99 9000
Sandal 0.89 0.89 0.89 1000
accuracy 0.98 10000
MEDIUM EFFORT FASHION
precision recall f1-score support
Not Sandal 0.99 0.98 0.99 9000
Sandal 0.87 0.88 0.87 1000
accuracy 0.97 10000
TINY FASHION
precision recall f1-score support
Not Sandal 0.98 0.99 0.98 9000
Sandal 0.87 0.85 0.86 1000
accuracy 0.97 10000
As we can see from the results above, these tricks work pretty well even in more advanced circumstances. With no changes in our approach, we produced very simple decision trees that correctly caught 89%, 88%, and 85% of the sandals in the testing set. We didn't exactly explore the design space; this whole thing was a pretty low-effort first look. In many use cases, we can also debounce the output of a model, so that a few predictions in a row of the same object increase our confidence in the observation, resulting in precision and recall above the single-frame levels reported here.
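As a toy illustration of that debouncing idea, here is a minimal sketch; the three-frames-in-a-row window is an arbitrary assumption.

```python
def debounce(predictions, window=3):
    # Only confirm a detection after `window` consecutive positive predictions.
    confirmed, streak = [], 0
    for p in predictions:
        streak = streak + 1 if p == 1 else 0
        confirmed.append(1 if streak >= window else 0)
    return confirmed

# Isolated noisy frames are suppressed; sustained detections get through.
print(debounce([0, 1, 0, 1, 1, 1, 1, 0]))  # -> [0, 0, 0, 0, 0, 1, 1, 0]
```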
Conclusion
We saw in this article that simple, low-effort, low-complexity approaches can achieve reasonable performance in terms of model quality, latency, and storage requirements. If you are trying to maximize the speed of your solution, minimize the number of model parameters (to fit things into small devices or embedded platforms), or minimize inference cost (e.g., a REST API with reserved GPUs costs way more than CPU-based reserved compute), then a giant pre-trained model may simply be overkill for the problem you care about. Remember to check whether your "big" solution is a nuclear mousetrap.
The code for this article is available HERE.
If you liked this article, then have a look at some of my most read past articles, like "How to Price an AI Project" and "How to Hire an AI Consultant." And hey, join the newsletter!
Until next time!