In this article, I will give a hands-on example (with code) of how one can use the popular PyTorch framework to apply the Vision Transformer, which was suggested in the paper "An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale" (which I reviewed in another post), to a practical Computer Vision task.

To do that, we will look at the problem of handwritten digit recognition using the well-known MNIST dataset.

I would like to offer a caveat right away, just to set expectations. I chose the MNIST dataset for this demonstration because it is simple enough that a model can be trained on it from scratch and used for predictions, without any specialized hardware, in minutes rather than hours or days, so literally anyone with a computer can run it and see how it works. I haven't tried very hard to optimize the hyperparameters of the model, and I certainly didn't aim to reach state-of-the-art accuracy (currently around 99.8% for this dataset) with this approach.
In fact, while I will show that the Vision Transformer can attain a respectable 98%+ accuracy on MNIST, it can be argued that it is not the best tool for this job. Since each image in this dataset is small (just 28×28 pixels) and consists of a single object, applying global attention can only be of limited utility. I might write another post later to examine how this model can be used on a bigger dataset with larger images and a greater variety of classes. For now, I just want to show how it works.
In terms of the implementation, I will be relying on the code from this open-source repository by Phil Wang, particularly on the following Vision Transformer (ViT) class from the vit_pytorch.py file:
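The full class is not reproduced verbatim in this post; the condensed sketch below is my own paraphrase of its overall structure, with PyTorch's built-in nn.TransformerEncoder standing in for the repository's own Transformer blocks, so it should be read as an illustration rather than the exact code from the file:

```python
import torch
from torch import nn
from einops import rearrange


class ViT(nn.Module):
    """Condensed sketch of a Vision Transformer classifier (not the repository's exact code)."""

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, 'image size must be divisible by patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        self.patch_size = patch_size

        # All trainable parameters and layers are defined in __init__
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        # The repository implements its own Transformer blocks; PyTorch's built-in
        # encoder is used here as a stand-in
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.mlp_head = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(),
                                      nn.Linear(mlp_dim, num_classes))

    def forward(self, img):
        p = self.patch_size
        # Cut the image into non-overlapping p x p patches and flatten each patch
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        x = self.patch_to_embedding(x)
        # Prepend a learnable class token and add positional embeddings
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1) + self.pos_embedding
        # nn.TransformerEncoder expects (sequence, batch, dim) by default
        x = self.transformer(x.transpose(0, 1)).transpose(0, 1)
        # Classify from the final representation of the class token
        return self.mlp_head(x[:, 0])
```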
Like any PyTorch neural network module class, it has an initialization (__init__) method, where all of the trainable parameters and layers are defined, and a forward method, which establishes how these layers are assembled into the overall network architecture.
For brevity, only the definition of the ViT class itself is given here, without the dependent classes. If you want to use this code on your computer, you will need to import the whole vit_pytorch.py file (which is surprisingly small, only about a hundred lines of code; I am linking to my own forked version on GitHub in case the original file changes in the future), as well as a recent version of PyTorch (I used 1.6.0) and the einops library, which is used for tensor manipulations.
To start using the MNIST dataset, we need to load it first, which we can do as follows (from this point on in the post all of the code is mine, although a lot of it is quite standard):
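The exact loading snippet is not reproduced here; a minimal version using torchvision might look like the following (the batch sizes and the normalization constants 0.1307/0.3081, the commonly quoted MNIST mean and standard deviation, are my assumptions):

```python
import torch
from torchvision import datasets, transforms

# Normalize with the commonly quoted MNIST mean/std (assumed values)
transform_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_set = datasets.MNIST('./data', train=True, download=True, transform=transform_mnist)
test_set = datasets.MNIST('./data', train=False, download=True, transform=transform_mnist)

# Batch sizes are illustrative choices
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=True)
```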
The transform_mnist transformation in the code above is used to normalize the image data to have zero mean and a standard deviation of 1, which is known to facilitate neural network training. The train_loader and test_loader objects contain the MNIST images already randomly split into batches so that they can be conveniently fed into the training and validation procedures.
Each item in the dataset contains an image with a corresponding ground-truth label. The goal of our Transformer, once trained on the training portion of the dataset (60,000 handwritten digit images), will be to predict, based on the image, the correct label for each sample in the test portion (10,000 images).
We will be using the following function to train our model for each epoch:
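The sketch below is a reconstruction that follows the description in the next paragraph; the name train_epoch and the exact print format are my own choices:

```python
import torch.nn.functional as F


def train_epoch(model, optimizer, data_loader, loss_history):
    model.train()
    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        # The model returns raw class scores; convert them to log-probabilities
        output = F.log_softmax(model(data), dim=1)
        loss = F.nll_loss(output, target)
        loss.backward()        # gradients of the loss w.r.t. every trainable parameter
        optimizer.step()       # parameter update

        if i % 100 == 0:
            print(f'Batch {i}/{len(data_loader)}  loss: {loss.item():.4f}')
            loss_history.append(loss.item())
```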
The function loops over every batch in the data_loader object. For each batch, it computes the output of the model (as a log_softmax) and the negative log-likelihood loss for that output, then calculates the gradients of the loss with respect to each trainable model parameter via loss.backward() and updates the parameters via optimizer.step(). Every 100th batch, it prints an update on the progress of the training and appends the value of the current loss to the loss_history list.
After each epoch of training, we will be able to see how our current model is doing on the test set using the following function:
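Again, this is a reconstruction matching the description below; the function name evaluate is mine:

```python
def evaluate(model, data_loader, loss_history):
    model.eval()
    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0.0

    with torch.no_grad():                      # no gradients needed for validation
        for data, target in data_loader:
            output = F.log_softmax(model(data), dim=1)
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct_samples += pred.eq(target).sum().item()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print(f'Average test loss: {avg_loss:.4f}  '
          f'Accuracy: {correct_samples}/{total_samples} '
          f'({100.0 * correct_samples / total_samples:.2f}%)')
```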
While this is similar to the training procedure above, now we do not calculate any gradients and instead just compare the output of the model to the ground truth labels to calculate accuracy and update the loss history.
Once all of the functions are defined, it’s time to initialize our model and run the training. We will use the following code:
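A reconstruction along these lines, assuming the ViT class from vit_pytorch.py has been imported; the keyword arguments follow the hyperparameters described below, while the random seed is an arbitrary choice of mine:

```python
torch.manual_seed(42)  # arbitrary seed, for reproducibility

model = ViT(image_size=28, patch_size=7, num_classes=10, channels=1,
            dim=64, depth=6, heads=8, mlp_dim=128)
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

N_EPOCHS = 25
train_loss_history, test_loss_history = [], []
for epoch in range(1, N_EPOCHS + 1):
    print('Epoch:', epoch)
    train_epoch(model, optimizer, train_loader, train_loss_history)
    evaluate(model, test_loader, test_loss_history)
```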
Here we define our Vision Transformer model with a patch size of 7×7 (which for a 28×28 image means 4×4 = 16 patches per image), 10 possible target classes (0 to 9), and 1 color channel (since the images are grayscale).
In terms of the network parameters, we use an embedding dimension of 64 units, a depth of 6 Transformer blocks, 8 attention heads, and 128 units in the hidden layer of the output MLP head. For the optimizer, we will use Adam (as in the paper) with a learning rate of 0.003. We will train the model for 25 epochs and look at the results.
There is no particular justification for using the hyperparameter values above. I just picked something that seemed reasonable. It is certainly possible that optimizing these would lead to higher accuracy and/or faster convergence.
Once the code was run for 25 epochs (in a regular free Google Colab notebook with a Tesla T4 GPU), it produced the following output:

Well, 98.36% accuracy is not too bad. It is better than what one could expect to get from a fully-connected network (where I get about 97.8–97.9% without any tricks), so there certainly seems to be a benefit from the attention layers. Of course, as I mentioned above, the Vision Transformer is not particularly suited to this task, and even a simple convolutional network with a few layers can achieve accuracy at or above 99%. Perhaps this Transformer model can do a bit better as well after optimizing the hyperparameters.
But the point is, it works, and, as was described in the paper, when applied to bigger, more complex problems, it can be competitive with the best convolutional models. Hopefully, this brief tutorial has shown how you, the reader, can use it in your own work.