Fashion MNIST Classification with TensorFlow featuring Deepmind Sonnet

Suraj Narayanan Sasikumar
Towards Data Science
Aug 11, 2018

In this post we’ll be looking at how to perform a simple classification task on the Fashion MNIST dataset using TensorFlow (TF) and Deepmind’s Sonnet library.

This post is also available as a Colaboratory notebook. Feel free to copy the notebook to your drive and mess around with the code.

My aim for this post is two-pronged:

  1. Show how to use the bells and whistles provided by TF in a simple Machine Learning (ML) task.
  2. Act as a simple getting-started example for Deepmind’s Sonnet library.

Much of the explanation is in the form of comments within the code, so consider reading the code along with the post. The post assumes that the reader has a basic understanding of ML and the TF framework. That said, I have tried to provide external links for the technical terms used.

To kick things off, let’s first install Sonnet. A simple pip installation via the command line will do the trick, but make sure that TensorFlow is already installed and that its version is >= 1.5.

$ pip install dm-sonnet

Assuming the other libraries are installed, we’ll import the required Python libraries.

[Out]
Tensorflow version: 1.10.0-rc1

The Fashion MNIST Dataset

The more traditional MNIST dataset has been overused to the point (99%+ accuracy) where it is no longer a worthy classification problem. Zalando Research came up with a new starting point for Machine Learning research: rather than the 10 digits, 10 different classes of clothing are captured in 28x28 grayscale images. Myriad variations of these 10 clothing classes make up the Fashion MNIST dataset.

A Sample from the Fashion MNIST dataset (Credit: Zalando, MIT License)

Using Keras (a high-level API for TensorFlow), we can download Fashion MNIST directly with a single function call. Since it’s relatively small (70K records), we’ll load it directly into memory.
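To put "relatively small" in numbers, here is a back-of-the-envelope estimate of the memory needed to hold all 70K images at once. It makes two assumptions:

- The images arrive as 8-bit unsigned integers, which is how the Keras loader returns them.
- Normalization later converts them to 32-bit floats.

Labels are left out, since they add only about 70 KB.

```python
import numpy as np

# Shapes reported by the loader: 60,000 training and 10,000 test images of 28x28 pixels.
num_images = 60_000 + 10_000
pixels_per_image = 28 * 28

uint8_bytes = num_images * pixels_per_image * np.dtype(np.uint8).itemsize
float32_bytes = num_images * pixels_per_image * np.dtype(np.float32).itemsize

print(f"uint8:   {uint8_bytes / 1e6:.1f} MB")    # raw pixel values in [0, 255]
print(f"float32: {float32_bytes / 1e6:.1f} MB")  # after conversion to floats
```

Even the float32 copy is in the low hundreds of megabytes, which fits comfortably in memory on a Colab instance.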

Preprocess the Dataset

Since the dataset is hand-crafted for ML research, we don’t need to perform any data wrangling. The only pre-processing we require is mean centering and variance normalization. The resulting data distribution has a roughly spherical structure, so gradient descent needs fewer steps to converge. Refer to Sec 5.3 of LeCun, Yann A., et al. “Efficient backprop.” for a precise explanation of why centering and normalization lead to faster convergence of gradient descent.

[Out]
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 1s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step
Training Data ::: Images Shape: (60000, 28, 28), Labels Shape: (60000,)
Test Data ::: Images Shape: (10000, 28, 28), Labels Shape: (10000,)
Random 25 Images from the Training Data:
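The centering and normalization described in the preprocessing step can be sketched in NumPy as follows. This is a minimal sketch with a few assumptions:

- It uses a single global mean and standard deviation computed over all training pixels. The notebook may compute them differently, for example per pixel.
- The function name `standardize` and the random stand-in arrays are hypothetical.

The test set is transformed with the training statistics, so no information from the test data leaks into preprocessing.

```python
import numpy as np

def standardize(train_images, test_images, eps=1e-8):
    """Mean-center and variance-normalize images using training-set statistics.

    Assumption: one global mean/std over all training pixels (hypothetical choice).
    """
    train = train_images.astype(np.float32)
    test = test_images.astype(np.float32)
    mean = float(train.mean())
    std = float(train.std())
    # eps guards against division by zero if every pixel had the same value.
    train_n = (train - mean) / (std + eps)
    # Reuse the training statistics on the test set: no test-set leakage.
    test_n = (test - mean) / (std + eps)
    return train_n, test_n

# Random stand-in data shaped like Fashion MNIST (uint8 pixels in [0, 255]).
rng = np.random.default_rng(0)
x_train = rng.integers(0, 256, size=(600, 28, 28), dtype=np.uint8)
x_test = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
x_train_n, x_test_n = standardize(x_train, x_test)
print(x_train_n.mean(), x_train_n.std())  # approximately 0 and 1
```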

Building the model

Using Deepmind’s Sonnet library, we’ll build two models: a simple Multi-Layer Perceptron (MLP) and a Convolutional Network. We’ll then set up the training apparatus so that switching between the two models is just a matter of changing a configuration parameter.

As an aside, Keras is another high-level API, which from TF v1.9 onward is tightly integrated into TF. Rapid prototyping is instrumental for any ML research project, and both Keras and Sonnet are extremely useful in that regard. Admittedly, Keras is a much more mature project and has the official backing of the TF team. Moreover, there is already a multitude of Keras tutorials and projects on the interweb, so adding one more wouldn’t make much sense. Sonnet, on the other hand, is rarely used outside of Deepmind, but it is a must-know for anyone following their research.

Deepmind Sonnet

Sonnet is a TensorFlow library from Deepmind that abstracts out the process of model building. The zen of Sonnet is to encapsulate components of your model as Python objects (modules). These can then be plugged into a TF graph as and when required, providing a seamless mechanism for code reuse. This design frees us from having to worry about internal details such as variable reuse and weight sharing. For a detailed guide, refer to the official documentation. The source code is also well documented and worth reading, especially when you are stuck trying to implement something.
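In Sonnet 1.x, a module is written by subclassing `snt.AbstractModule` and implementing `_build`, and Sonnet handles variable creation and reuse inside the TF graph.

The snippet below is not Sonnet. It is a framework-free NumPy toy of the same idea:

- A module object creates its parameters the first time it is called, once the input size is known.
- It reuses those parameters on every later call, so "weight sharing" is simply calling the same object twice.

The class name `LazyLinear` is hypothetical.

```python
import numpy as np

class LazyLinear:
    """Toy module: parameters are created on the first call and reused afterwards.

    Hypothetical illustration of the Sonnet design idea; not Sonnet's API.
    """

    def __init__(self, output_size, seed=0):
        self.output_size = output_size
        self._rng = np.random.default_rng(seed)
        self.w = None  # created lazily, once the input size is known
        self.b = None

    def __call__(self, inputs):
        if self.w is None:
            input_size = inputs.shape[-1]
            self.w = self._rng.normal(0.0, 0.01, size=(input_size, self.output_size))
            self.b = np.zeros(self.output_size)
        return inputs @ self.w + self.b

layer = LazyLinear(output_size=10)
out_a = layer(np.ones((4, 784)))
w_after_first_call = layer.w
out_b = layer(np.zeros((2, 784)))  # same weights reused: nothing new is created
print(out_a.shape, out_b.shape, layer.w is w_after_first_call)
```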

In our example we’ll create two modules: FMNISTMLPClassifier and FMNISTConvClassifier. As the names suggest, FMNISTMLPClassifier uses an MLP and FMNISTConvClassifier uses a convolutional neural network. We’ll then set up the training apparatus in TensorFlow and plug in the model we want to train.
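The architectures of FMNISTMLPClassifier and FMNISTConvClassifier are not reproduced here. As a rough NumPy sketch of what each kind of model computes on a batch of 28x28 images, the code below pairs two forward passes, each producing 10 logits per image:

- an MLP: flatten, one ReLU hidden layer, then 10 logits;
- a minimal ConvNet: one 3x3 "valid" convolution with ReLU, then a dense layer to 10 logits.

It ends with a name-based switch like the post's 'mlp'/'conv' parameter. All layer sizes and the names `make_mlp`, `make_conv`, `build_model` and `BUILDERS` are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_mlp(rng, hidden=128):
    """Hypothetical MLP: flatten 28x28 -> ReLU hidden layer -> 10 logits."""
    w1 = rng.normal(0.0, np.sqrt(2.0 / 784), size=(784, hidden))
    w2 = rng.normal(0.0, np.sqrt(2.0 / hidden), size=(hidden, 10))

    def forward(images):
        x = images.reshape(len(images), -1)  # (batch, 784)
        return np.maximum(x @ w1, 0.0) @ w2  # (batch, 10)
    return forward

def make_conv(rng, filters=8):
    """Hypothetical ConvNet: 3x3 'valid' conv + ReLU, then a dense layer to 10 logits."""
    kernels = rng.normal(0.0, np.sqrt(2.0 / 9), size=(3, 3, filters))
    w = rng.normal(0.0, np.sqrt(2.0 / (26 * 26 * filters)), size=(26 * 26 * filters, 10))

    def forward(images):
        # Every 3x3 patch of every image: (batch, 26, 26, 3, 3).
        patches = sliding_window_view(images, (3, 3), axis=(1, 2))
        # Apply each kernel to every patch (cross-correlation, as in deep-learning convs).
        features = np.maximum(np.einsum("bhwij,ijc->bhwc", patches, kernels), 0.0)
        return features.reshape(len(images), -1) @ w  # (batch, 10)
    return forward

# Switching models is a single string, mirroring the post's 'mlp' / 'conv' parameter.
BUILDERS = {"mlp": make_mlp, "conv": make_conv}

def build_model(name, seed=0):
    if name not in BUILDERS:
        raise ValueError(f"unknown model {name!r}; expected one of {sorted(BUILDERS)}")
    return BUILDERS[name](np.random.default_rng(seed))

images = np.random.default_rng(1).random((4, 28, 28))
for name in ("mlp", "conv"):
    print(name, build_model(name)(images).shape)
```

The training code only ever sees the string key. That is what makes swapping models a one-word change.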

Putting Together the Training Apparatus

The training apparatus contains the following components:

  1. The Input Pipeline through which the data is fed to the model.
  2. An Optimization algorithm for performing Gradient Descent.
  3. A Loss function that is to be optimized by the Optimizer.
  4. The model that is being trained.

Input pipeline

In TensorFlow, the preferred way to feed data to a model is the tf.data module. It lets us apply transformations to input data in a simple and reusable manner. It also supports more elaborate input pipelines, such as aggregating data from multiple sources or plugging in complex data-manipulation steps. In this example we showcase only its basic functionality; the reader is encouraged to go through the official guide.

We want our input pipeline to have the following three properties:

  1. Ability to switch between training and test datasets seamlessly, allowing us to perform evaluation after every epoch.
  2. Shuffle the dataset to avoid learning unintended correlations from the ordering of the data on disk.
  3. Batch the dataset for mini-batch Stochastic Gradient Descent.

In a single epoch, the training loop runs mini-batch training steps until the entire training set has been covered, followed by an accuracy evaluation over the test dataset.
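The post uses tf.data for this pipeline. In TF 1.x, one common arrangement is:

- `Dataset.shuffle` followed by `Dataset.batch`;
- a reinitializable iterator (`tf.data.Iterator.from_structure`) to switch between training and test data.

The notebook's exact setup may differ. The sketch below is not the tf.data API: it is a framework-free NumPy generator that shows the three properties and the epoch structure in isolation. The helper `batches` and the tiny datasets are hypothetical.

```python
import numpy as np

def batches(images, labels, batch_size, shuffle, rng):
    """Yield (images, labels) mini-batches that cover the dataset exactly once."""
    n = len(images)
    # A fresh permutation per call, so every epoch sees a different ordering.
    order = rng.permutation(n) if shuffle else np.arange(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]

rng = np.random.default_rng(42)
# Tiny stand-in datasets: 10 training and 4 test examples.
x_train, y_train = np.arange(10).reshape(10, 1), np.arange(10)
x_test, y_test = np.arange(4).reshape(4, 1), np.arange(4)

for epoch in range(2):
    seen = []
    # Training pass: shuffled mini-batches of 4 (the last batch holds the remaining 2).
    for _, yb in batches(x_train, y_train, batch_size=4, shuffle=True, rng=rng):
        seen.extend(yb.tolist())
    # Evaluation pass: the same helper pointed at the test set, without shuffling.
    test_batches = list(batches(x_test, y_test, batch_size=4, shuffle=False, rng=rng))
    print(f"epoch {epoch}: covered all training examples = {sorted(seen) == list(range(10))}, "
          f"test batches = {len(test_batches)}")
```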

Optimizer

The Adam (Adaptive Moment Estimation) optimizer is a variant of Stochastic Gradient Descent. Among other techniques, Adam uses an adaptive learning rate for each parameter. This lets parameters associated with uncommon features take aggressive steps, while those associated with common features take smaller ones. For a detailed exposition of the different SGD optimizers, read this wonderful post.
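In TF 1.x, Adam is available as `tf.train.AdamOptimizer`, whose defaults are learning_rate=0.001, beta1=0.9, beta2=0.999 and epsilon=1e-08.

To make the per-parameter adaptive step concrete, here is a NumPy transcription of the published Adam update rule, applied to a toy quadratic. It is a sketch for illustration, not TensorFlow's implementation. The helper name `adam_step` and the toy problem are hypothetical.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; returns the new parameter and moment estimates."""
    m = beta1 * m + (1 - beta1) * grad       # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2  # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias correction (t starts at 1)
    v_hat = v / (1 - beta2 ** t)
    # Each parameter's step is scaled by the root of its own squared-gradient average.
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

# Toy problem: minimize sum((x - target)^2), whose gradient is 2 * (x - target).
target = np.array([3.0, -2.0])
x, m, v = np.zeros(2), np.zeros(2), np.zeros(2)
for t in range(1, 2001):
    grad = 2.0 * (x - target)
    x, m, v = adam_step(x, grad, m, v, t, lr=0.05)
print(np.round(x, 3))
```

On the very first step, each coordinate moves by about `lr`, whatever the magnitude of its gradient. That is the per-parameter normalization at work.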

Loss Function

We evaluate the cross-entropy loss after applying softmax to the model’s output (the logits). For more details, see the ML CheatSheet entry on loss functions listed in the references.
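In TF 1.x, this computation is typically done in one call with `tf.nn.sparse_softmax_cross_entropy_with_logits(labels=..., logits=...)` for integer class labels. The post does not show which call it uses.

The NumPy sketch below spells out what such a call computes:

- Softmax is folded into a numerically stable log-softmax, by subtracting each row's maximum before exponentiating.
- The loss is averaged over the batch.

The function name `softmax_cross_entropy` is hypothetical.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between softmax(logits) and integer class labels."""
    # Subtract the row max before exponentiating so large logits don't overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each example, then average.
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
print(softmax_cross_entropy(logits, labels))
```

With all-zero logits over 10 classes, the loss is exactly log(10), the cost of a uniform guess.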

Model

Here we’ll use the models that we built using Sonnet. We’ll set up the training so that we can swap between the two models (MLP and ConvNet) by changing a single configuration parameter.

Let’s put together all the components.

[Out]
Epoch 1 ::: Training Time: 27.48s, Training Loss: 0.35541, Training Accuracy: 0.82533, Test Accuracy: 0.86060
Epoch 2 ::: Training Time: 26.22s, Training Loss: 0.27885, Training Accuracy: 0.88165, Test Accuracy: 0.88280
Epoch 3 ::: Training Time: 25.68s, Training Loss: 0.25212, Training Accuracy: 0.89918, Test Accuracy: 0.88710
Epoch 4 ::: Training Time: 25.82s, Training Loss: 0.21601, Training Accuracy: 0.91033, Test Accuracy: 0.89750
Epoch 5 ::: Training Time: 26.27s, Training Loss: 0.18370, Training Accuracy: 0.91778, Test Accuracy: 0.90500
Epoch 6 ::: Training Time: 25.84s, Training Loss: 0.19794, Training Accuracy: 0.92612, Test Accuracy: 0.89190
Epoch 7 ::: Training Time: 26.45s, Training Loss: 0.15230, Training Accuracy: 0.93163, Test Accuracy: 0.90500
Epoch 8 ::: Training Time: 25.84s, Training Loss: 0.15200, Training Accuracy: 0.93763, Test Accuracy: 0.90360
Epoch 9 ::: Training Time: 25.85s, Training Loss: 0.12375, Training Accuracy: 0.94403, Test Accuracy: 0.90550
Epoch 10 ::: Training Time: 26.06s, Training Loss: 0.11385, Training Accuracy: 0.95010, Test Accuracy: 0.91050

That’s it! We have trained a convolutional neural network to classify the Fashion MNIST dataset with a test accuracy of 91.05%. To train the MLP-based model, just change 'conv' to 'mlp' in the train function call.
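For readers who want to see how the four components interact without TensorFlow or Sonnet, here is a scaled-down, framework-free sketch of the same training loop. It is a stand-in, not the post's code:

- A linear softmax classifier replaces the Sonnet models.
- Synthetic 2-D data with three well-separated classes replaces Fashion MNIST.
- All sizes and hyperparameters are assumptions.

Each epoch reshuffles the training set, runs mini-batch Adam updates on the softmax cross-entropy loss, and then reports test accuracy, mirroring the epoch structure of the training log.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, dim = 3, 2
centers = np.array([[0.0, 3.0], [3.0, -2.0], [-3.0, -2.0]])  # synthetic class centers

def make_split(n):
    y = rng.integers(0, num_classes, size=n)
    return centers[y] + rng.normal(0.0, 0.5, size=(n, dim)), y

x_train, y_train = make_split(600)
x_test, y_test = make_split(200)

def logits_fn(params, x):
    w, b = params
    return x @ w + b

def loss_and_grads(params, x, y):
    logits = logits_fn(params, x)
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(len(y)), y]).mean()
    # Gradient of mean softmax cross-entropy w.r.t. logits: (probs - one_hot) / batch.
    dlogits = probs.copy()
    dlogits[np.arange(len(y)), y] -= 1.0
    dlogits /= len(y)
    return loss, [x.T @ dlogits, dlogits.sum(axis=0)]

def accuracy(params, x, y):
    return float((logits_fn(params, x).argmax(axis=1) == y).mean())

params = [np.zeros((dim, num_classes)), np.zeros(num_classes)]
m = [np.zeros_like(p) for p in params]
v = [np.zeros_like(p) for p in params]
lr, beta1, beta2, eps, step = 0.05, 0.9, 0.999, 1e-8, 0

for epoch in range(1, 6):
    order = rng.permutation(len(x_train))  # reshuffle every epoch
    for start in range(0, len(order), 32):  # mini-batches of 32
        idx = order[start:start + 32]
        loss, grads = loss_and_grads(params, x_train[idx], y_train[idx])
        step += 1
        for i, g in enumerate(grads):  # Adam update for each parameter array
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g ** 2
            m_hat = m[i] / (1 - beta1 ** step)
            v_hat = v[i] / (1 - beta2 ** step)
            params[i] = params[i] - lr * m_hat / (np.sqrt(v_hat) + eps)
    print(f"Epoch {epoch} ::: Last-batch Loss: {loss:.5f}, "
          f"Test Accuracy: {accuracy(params, x_test, y_test):.5f}")
```

The numbers this prints come from synthetic data and say nothing about Fashion MNIST performance.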

The entire code for the post is available on GitHub.

References

  1. LeCun, Yann A., et al. “Efficient backprop.”
  2. Deepmind Sonnet
  3. Fashion MNIST
  4. Basic Classification from TF Guide
  5. CS231n Stanford course notes
  6. ML CheatSheet Loss Functions
  7. TF Metrics

Independent researcher in AI/ML. Currently my research focuses on developing directed exploration algorithms for Deep Reinforcement Learning.