
Understand and Implement Vision Transformer with TensorFlow 2.0

Self-Attention Mechanism and Goodbye Convolution!

Break Images into Patches & Experience the Magic (Source: Author)

When the Transformer network came out, it quickly became the go-to model for NLP tasks. 'An Image is Worth 16x16 Words', presented at the International Conference on Learning Representations (ICLR) 2021 by Alexey Dosovitskiy et al., showed for the first time how a Transformer can be applied to computer vision tasks and outperform CNNs (e.g. ResNet) in image classification. This post is a deep dive and step-by-step implementation of the Vision Transformer (ViT) using TensorFlow 2.0. What you can expect to learn from this post –

  1. Detailed Explanation of Self-Attention Mechanism.
  2. ViT Structure Clearly Explained.
  3. Implement ViT from scratch with TensorFlow 2.0.
  4. An Example of ViT in action for CIFAR-10 classification.
  5. Different Implementations of ViT and Subtle Differences.

This post will be long, and I expect it to be a good companion to the original ViT paper for understanding the research ideas and turning them into code. All the code/images used here are available on my GitHub. So sit back, grab your coffee, and we are ready to go!


Disclaimer: I would like to spend some time describing the attention mechanism because, for the implementation of ViT, I will use the Keras MultiHeadAttention layer, so it is worth understanding what goes on inside it. If you want to jump straight to the ViT implementation, skip to Section 2.


1. Transformer & Attention:

To understand the Vision Transformer, we first need to focus on the basics of the Transformer and the attention mechanism. For this part I will follow the paper Attention is All You Need. The paper itself is an excellent read, and the descriptions/concepts below are mostly taken from there; understanding them clearly will only help us proceed further.

The idea of the Transformer is to use attention without recurrence (i.e. no RNNs). So the Transformer is still a sequence-to-sequence (Seq2Seq) model and follows an encoder-decoder structure. To quote from the paper –

The encoder maps an input sequence of symbol representations (x1,…,xn) to a sequence of continuous representations z=(z1,…,zn). Given z, the decoder then generates an output sequence (y1,…,ym) of symbols one element at a time. At each step the model is auto-regressive, consuming the previously generated symbols as additional input when generating the next.

Let’s see the Transformer structure introduced in the paper –

Fig. 1: Transformer Architecture (Source: Attention is All You Need by A. Vaswani et al.)

Once we understand the encoder part of the structure above, we can move to the Vision Transformer. The encoder layer contains two very important components:

  • Multi-head self-attention block.
  • Position-wise, fully-connected feed-forward network.

Let’s focus on the multi-head self-attention part. The paper itself has a diagram of scaled dot-product attention and multi-head attention which consists of several attention layers running in parallel.

Fig. 2: Multi-Head Attention (Source: Attention is All You Need by A. Vaswani et al.)

The three labels in the diagram, Q, K, V, denote the Query, Key and Value vectors. For now, think of this as an information-retrieval analogy: we search (query), and the search engine compares our query with keys and responds with values (output).

In the original paper, three different uses of multi-head attention are described. Let's quote directly from the paper –

  1. In "encoder-decoder attention" layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence. This mimics the typical encoder-decoder attention mechanisms in sequence-to-sequence models.
  2. The encoder contains self-attention layers. In a self-attention layer all of the keys, values and queries come from the same place, in this case, the output of the previous layer in the encoder. Each position in the encoder can attend to all positions in the previous layer of the encoder.
  3. Similarly, self-attention layers in the decoder allow each position in the decoder to attend to all positions in the decoder up to and including that position.

For our purpose (understanding the Vision Transformer), the most important point is 2, i.e. self-attention in the encoder part. Let's dive deep!


1.1. Self Attention:

To understand self-attention, pictorial representations will help more than text. I will use some images from my slides.

Let's consider an input sequence (x_1, x_2, …, x_m). The output of the self-attention layer for this input sequence is a set of context vectors (C_1, C_2, …, C_m) of the same length as the input sequence. The picture below will help us –

Fig. 3: Steps towards Attention (Self). Turn input sequences into context vectors. (Source: Author's Slides)

In the picture above we define the weights that will be trained as W_q, W_k, W_v for the queries, keys and values. How are they actually used? Let's see the picture below –

Fig. 4: The weights for query, key and values that will be updated during training. (Source: Author's Slides)

It is important to note that the same weights are applied for all i's. In the Attention is All You Need paper, the dimension of the queries and keys is taken as d_k, and the dimension of the values as d_v. For example, if we have a 5-dimensional x_i (e.g. [0 1 1 2 3]) and the query has dimension 3, then W_q will have dimension 5×3. The same goes for the keys and values and their corresponding weights. So how do we finally calculate these context vectors? Let's see below –

Fig. 5: Dot-Product in Attention. Dot product of all the keys in the sequence with the ith query. (Source: Author's Slides)

As mentioned in the paper, the dot product (k_i^T ⋅ q_j) is a choice; other forms such as additive or multiplicative attention can also be used. The important point is that for calculating α_j (see figure above) at position j (corresponding to x_j) we use only the query q_j, but all the keys (k_i's). We are left with the final step to calculate the output of the attention layer, which is to use the values as below –

Fig. 6: Obtain the context vectors (C's) as output of the attention block. (Source: Author's Slides)

I emphasize once again: to calculate the context vector at position 'i' we need values from all the inputs. The inputs interact with each other (hence the term 'self') and figure out where to pay more attention. Finally, how do we interpret the scaling factor (1/√d_k) in the scaled dot-product attention function?

Usually, we initialize our layers with the intention of having equal variance throughout the model. But when we take the dot product of two vectors (Q, K) with variance σ², the result is a scalar with a variance d_k times higher. Remember also that d_k is the dimension of both Q and K, while V has dimension d_v.

Normally distributed queries and keys. The total variance after the dot product is d_k times higher. (Source: Author's Notebook)

If we do not scale the variance back down to σ², the softmax over the logits will saturate to 1 for one random element and 0 for all the others. The gradients through the softmax will then be close to zero, so we cannot learn the parameters appropriately.

Hopefully now, you can appreciate the scaled dot product attention diagram (figure: 2) that was introduced in the paper.
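To make the mechanics above concrete, here is a minimal sketch of scaled dot-product attention written with plain TensorFlow ops; the function name and the toy shapes are mine, purely for illustration –

import tensorflow as tf

def scaled_dot_product_attention(q, k, v):
    # q, k: (batch, seq_len, d_k); v: (batch, seq_len, d_v)
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    # similarity of every query with every key, scaled by 1/sqrt(d_k)
    scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(d_k)
    # attention weights (the alphas) sum to 1 across the key positions
    weights = tf.nn.softmax(scores, axis=-1)
    # each context vector is a weighted sum of the values
    return tf.matmul(weights, v), weights

# tiny usage example with random tensors
q = tf.random.normal((1, 4, 8))   # 4 positions, d_k = 8
k = tf.random.normal((1, 4, 8))
v = tf.random.normal((1, 4, 16))  # d_v = 16
context, attn = scaled_dot_product_attention(q, k, v)
print(context.shape, attn.shape)  # (1, 4, 16) (1, 4, 4)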


1.2. Multi-Head Self-Attention:

It’s a very simple extension of single-head self-attention.

  • In Multi-head self-attention we have h single head self-attentions (layers). See figure 2 above.
  • In a single-head self-attention, trainable parameters are weights Wq, Wk, Wv.
  • The h single-head self-attention layers do not share parameters, so there are 3h weight matrices in total.
  • Each single-head self-attention outputs a context vector.
  • These context vectors are concatenated.
  • If a single-head attention outputs a d-dimensional context vector, i.e. each C_i is d×1, then the multi-head output is an hd×1 dimensional vector, given h single-head self-attention layers.

To quote from the ‘Attention is All You Need’ paper on the importance of multi-head attention –

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.

We have gone through the self-attention mechanism and multi-head attention as an extension of it. This is the building block of the Transformer Encoder in the Vision Transformer (ViT) paper, and now we are ready to dive into the ViT paper and its implementation.
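Since the implementation below relies on the built-in Keras layer rather than a hand-rolled version, here is a tiny usage sketch of tf.keras.layers.MultiHeadAttention; the toy shapes are arbitrary, and passing the same tensor as query and value is what makes it self-attention –

import tensorflow as tf

x = tf.random.normal((2, 16, 64))  # (batch, sequence length, embedding dim)
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=64)
out = mha(x, x)  # query = value (= key): self-attention
print(out.shape)  # (2, 16, 64)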


2. Vision Transformer:

First, take a look at the ViT architecture as shown in the original paper, 'An Image is Worth 16x16 Words' –

Fig. 7: Vision Transformer Architecture (Source: An Image is Worth 16x16 Words by A. Dosovitskiy et al.)

We have already discussed the main component of ViT, the Transformer Encoder, and the multi-head attention within it. The next part is to generate patches from images and add positional embeddings. I will use CIFAR-10 data for this example implementation. Note that the paper mentions that ViTs are data-hungry architectures: even on a relatively large dataset like ImageNet, without strong regularization, ViT accuracy falls a few percentage points below ResNet. The picture changes when Transformers are trained on larger datasets (14M-300M images), so CIFAR-10 is used here just as an example implementation and not for a performance comparison with other networks.

The original implementation is available on Google's GitHub, and a very similar version exists in the TensorFlow models repository. Below is a ViT TL;DR:

  • Take an image (e.g. 256×256×3).
  • Split the image into smaller patches (e.g. 16×16×3, giving N = 256×256/16² = 256 patches in total).
  • These patches are then linearly embedded. We can think of them as tokens.
  • Use them as input for Transformer Encoder (contains multi-head self-attention).
  • Perform the classification.
  • Bye-Bye Convolution.

Let's get started by loading the data; I will use the tf.data format –

import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from sklearn.model_selection import train_test_split

# load CIFAR-10 and one-hot encode the labels
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
train_lab_categorical = tf.keras.utils.to_categorical(y_train, num_classes=10, dtype='uint8')
test_lab_categorical = tf.keras.utils.to_categorical(y_test, num_classes=10, dtype='uint8')

# hold out 20% of the training images as a stratified validation set
train_im, valid_im, train_lab, valid_lab = train_test_split(x_train, train_lab_categorical, test_size=0.20, stratify=train_lab_categorical, random_state=40, shuffle=True)

# build tf.data pipelines
training_data = tf.data.Dataset.from_tensor_slices((train_im, train_lab))
validation_data = tf.data.Dataset.from_tensor_slices((valid_im, valid_lab))
test_data = tf.data.Dataset.from_tensor_slices((x_test, test_lab_categorical))

autotune = tf.data.AUTOTUNE
train_data_batches = training_data.shuffle(buffer_size=40000).batch(128).prefetch(buffer_size=autotune)
valid_data_batches = validation_data.shuffle(buffer_size=10000).batch(32).prefetch(buffer_size=autotune)
test_data_batches = test_data.shuffle(buffer_size=10000).batch(32).prefetch(buffer_size=autotune)

2.1. Patch Generation:

Let’s discuss what’s presented in the original paper.

Consider an image x ∈ R^(H×W×C) and reshape it into a sequence of flattened 2D patches x_p ∈ R^(N×(P²·C)), where (H, W) is the resolution of the original image, C is the number of channels, (P, P) is the resolution of each image patch, and N = HW/P² is the resulting number of patches, which also serves as the effective input sequence length for the Transformer. For example, a 32×32 CIFAR-10 image with P = 4 gives N = 32·32/4² = 64 patches.

For the patch generation, I will follow what was done in the original code, but I will also discuss another method, from the Keras blog.

  • The example shown in the Keras blog uses [tf.image.extract_patches](https://www.tensorflow.org/api_docs/python/tf/image/extract_patches). Using this, we literally create patches from the images and then flatten them. A dense layer with learnable weights then projects each patch to a hidden dimension (this will become clearer soon). In addition, a learnable position embedding is added to the projected vector. The final shape of the output is (batch_size, num_patches, hidden_dim). An example using an image tensor up to patch creation is shown below.
  • In the original code, instead of creating patches and then projecting them to a certain dimension with learnable weights via a Dense layer, we directly use a convolutional layer (with learnable weights) whose number of filters equals this hidden dimension. The shape is then already (batch_size, num_patches, hidden_dim), and a learnable position embedding of the same shape is added to the input.

We will discuss both methods. But before that, what is this hidden dimension? It is the dimension of the Query and Key (previously written as d_k), and we will use it in the encoder block when we need the MultiHeadAttention layer. The projection is done in such a way that we can directly feed the embedded (flattened) patches to the Transformer. Great! Slowly things are building up.

Patch Generation (Example From Keras Blog):

Below is a code block similar to the one used in the Keras blog, which divides images into patches of a given patch size (a helper function can then be used to visualize the patches) –
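Here is a minimal sketch along those lines; the class name Patches mirrors the Keras example, and the sketch should be treated as illustrative rather than the exact blog code –

import tensorflow as tf
from tensorflow.keras import layers

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # cut non-overlapping patch_size x patch_size windows out of each image
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID')
        patch_dims = patches.shape[-1]  # patch_size * patch_size * channels
        # flatten the grid of patches into a sequence: (batch, num_patches, patch_dims)
        return tf.reshape(patches, [batch_size, -1, patch_dims])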

Using this for an example image in CIFAR-10 data, we get the following result –

Fig. 8: Image broken into patches ('tokens'), using the code block above.

If we follow this implementation (check the Keras example), we first need to project the patches (via a Dense layer) to a dimension that matches the query dimension in the [MultiHeadAttention](https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention) layer, and then add the corresponding position embedding. Below is the code block –

### Positional Encoding Layer
import tensorflow as tf
from tensorflow.keras import layers

class PatchEncode_Embed(layers.Layer):
  '''
  Two steps happen here:
  1. flatten the patches
  2. map to dimension D: patch embeddings
  '''
  def __init__(self, num_patches, projection_dim):
    super(PatchEncode_Embed, self).__init__()
    self.num_patches = num_patches
    # linear projection of the flattened patches to the hidden dimension
    self.projection = layers.Dense(units=projection_dim)
    # learnable position embedding, one vector per patch
    self.position_embedding = layers.Embedding(
        input_dim=num_patches, output_dim=projection_dim)

  def call(self, patch):
    positions = tf.range(start=0, limit=self.num_patches, delta=1)
    encoded = self.projection(patch) + self.position_embedding(positions)
    return encoded

2.2. Patch Generation & Positional Encoding:

I will follow the original implementation: instead of generating patches and then adding learnable weights via a Dense layer, we add this learnability directly through a Conv2D layer whose number of filters matches the query dimension of the [MultiHeadAttention](https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention) layer. We also need to add the positional embedding, which we do by randomly initializing weights in a custom layer that extends the tf.keras.layers.Layer class. Let's see below –
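A minimal sketch of such a layer is shown below; the class name PatchEmbedding and its argument names are mine, and the original implementation differs in details –

import tensorflow as tf
from tensorflow.keras import layers

class PatchEmbedding(layers.Layer):
    def __init__(self, num_patches, patch_size, hidden_dim):
        super(PatchEmbedding, self).__init__()
        self.hidden_dim = hidden_dim
        # Conv2D with kernel = stride = patch_size splits the image into patches
        # and linearly projects each one to hidden_dim in a single step
        self.projection = layers.Conv2D(filters=hidden_dim, kernel_size=patch_size,
                                        strides=patch_size, padding='valid')
        # learnable positional embedding, one vector per patch
        self.pos_embedding = self.add_weight(
            name='pos_embedding', shape=(1, num_patches, hidden_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02), trainable=True)

    def call(self, images):
        x = self.projection(images)                      # (batch, H/P, W/P, hidden_dim)
        batch = tf.shape(images)[0]
        x = tf.reshape(x, [batch, -1, self.hidden_dim])  # (batch, num_patches, hidden_dim)
        return x + self.pos_embedding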

Some important points related to the positional encoding –

  1. Inductive Bias: Once an image is broken into patches, we lose the structure of the input, and the positional embedding allows the model to learn about the structure of the input image. These positional embeddings are learnable, and they highlight how much of the image structure the model can learn on its own. The paper also mentions that different embedding techniques, such as relative or 2D embeddings, do not change the performance much. The main difference between this Transformer approach and a CNN is that in a CNN the kernels help us learn/understand the 2D neighbourhood structure, whereas in the Transformer, apart from the MLP layer, this local 2D structure is not used: at initialization the positional embeddings carry no information about the 2D position of the patches, and all spatial relations between the patches are learned from scratch.
  2. Class Token: The paper also mentions that the [class] token was used to keep the structure as close to the original Transformer as possible. The researchers also tried using only the image-patch embeddings, globally average-pooling (GAP) them, followed by a linear classifier. The initially poor performance of this variant turned out to be caused neither by the absence of the token nor by GAP, but by a learning rate that was not optimal.

Following the last point, I will not use the [class] token in this implementation. Going back to Figure 7, we can see that in the Transformer Encoder block we need to implement the normalization and the MLP part in addition to the Multi-Head Attention layer. Let's move to that part!


3. Transformer Encoder Block:

3.1. MLP: The multi-layer perceptron contains GELU non-linearity. I will not discuss GELU in detail here; check out the Gaussian Error Linear Units (GELU) paper. This activation is available within tf.nn. The MLP sizes are given in the paper, but for this simplified implementation we will use much smaller Dense layers (fewer units). Since the encoder block repeats, we have to be careful about the number of units in the Dense layers: the output dimension has to be compatible with the input of the next MultiHeadAttention layer.
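As a rough sketch (the function name mlp_block and the unit sizes are mine, not the values from the paper), the MLP part can be written as –

import tensorflow as tf
from tensorflow.keras import layers

def mlp_block(x, hidden_units, output_dim, dropout_rate=0.1):
    # expand with a GELU non-linearity ...
    x = layers.Dense(hidden_units, activation=tf.nn.gelu)(x)
    x = layers.Dropout(dropout_rate)(x)
    # ... then project back to output_dim so that the next MultiHeadAttention
    # layer receives an input of compatible dimension
    x = layers.Dense(output_dim)(x)
    x = layers.Dropout(dropout_rate)(x)
    return x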

3.2. Norm: The Norm in the figure refers to the LayerNormalization layer. An excellent visual comparison is available in the original Group Normalization paper. In short, if we consider an input tensor with shape (N, C, H, W), it computes the mean and variance (μ_i, σ_i) along the (C, H, W) axes. This ensures that the computation for one input example is entirely independent of the other examples in the batch.

With these details, we are ready to move to code! Don't overlook the residual connections in the Transformer block.
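Below is a rough sketch of a single encoder block, reusing the mlp_block above; the function name and the way the hyperparameters are passed are my choices, not the exact original code –

import tensorflow as tf
from tensorflow.keras import layers

def transformer_encoder_block(inputs, num_heads, hidden_dim, mlp_units):
    # Norm -> multi-head self-attention -> residual connection
    x = layers.LayerNormalization(epsilon=1e-6)(inputs)
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=hidden_dim // num_heads)(x, x)
    x1 = layers.Add()([attention_output, inputs])
    # Norm -> MLP with GELU -> residual connection
    x2 = layers.LayerNormalization(epsilon=1e-6)(x1)
    x2 = mlp_block(x2, mlp_units, hidden_dim)
    return layers.Add()([x2, x1])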


4. Putting All Together: Vision Transformer

We have built up all the small pieces required for ViT – images as word tokens, positional encodings, and the Transformer encoder block. We will now put them together to build the Vision Transformer. Let's jump to the code block –
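Below is a simplified sketch that assembles the pieces introduced above (PatchEmbedding and transformer_encoder_block); all hyperparameter values here are illustrative, not the ones from the paper and not necessarily the ones behind the training curves shown next –

import tensorflow as tf
from tensorflow.keras import layers

def build_vit(input_shape=(32, 32, 3), patch_size=4, hidden_dim=64,
              num_heads=4, num_layers=6, mlp_units=128, num_classes=10):
    num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
    inputs = layers.Input(shape=input_shape)
    # (pixel scaling/normalization is assumed to be handled in the data pipeline)
    # patch generation + linear projection + positional embedding (Section 2.2)
    x = PatchEmbedding(num_patches, patch_size, hidden_dim)(inputs)
    # stack of identical Transformer encoder blocks (Section 3)
    for _ in range(num_layers):
        x = transformer_encoder_block(x, num_heads, hidden_dim, mlp_units)
    # no [class] token: average over the patch dimension instead
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

vit_model = build_vit()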

The only steps remaining are compiling the model and training it for a given number of epochs. Training for 120 epochs with the Adam optimizer, I obtain the following training curves –

Fig. 9: Training curves for the ViT model on CIFAR-10 data. (Source: Author's Notebook)

The model starts to overfit; several additions, like ReduceLROnPlateau, EarlyStopping, etc., can be used to counter this. Beyond that, I think the dataset is simply too small for the data-hungry ViT. Once again, ViT performance surpasses a generic ResNet only when the dataset is large (14M-300M images). We obtain the confusion matrix below for the test set –

Fig. 10: Confusion matrix for the CIFAR-10 test set obtained with the ViT model described above. (Source: Author's Notebook)

We have come to the end of this post, in which we went through all the details of understanding and implementing ViT from scratch using TensorFlow 2. All the code is available on my GitHub.

References:

  1. Code/images used here: my GitHub. I recommend opening the notebook in nbviewer for proper rendering.
  2. Attention is All You Need: A. Vaswani et al.
  3. An Image is Worth 16x16 Words: A. Dosovitskiy et al.
  4. Transformers for Image Recognition at Scale: Google Blog.
  5. Keras Vision Transformer blog.
  6. Google Research, Transformer model: GitHub.

Stay strong & Cheers!!


If you're interested in more fundamental machine learning concepts, you can consider joining Medium using my link. You don't pay anything extra, but I'll get a tiny commission. Appreciate you all!!

