
Image Classification with Vision Transformer

How to classify images with the help of Transformer-based model

Photo by drmakete lab on Unsplash

Since its introduction in 2017, the Transformer has been widely recognized as a powerful encoder-decoder model capable of solving pretty much any language modeling task.

BERT, RoBERTa, and XLM-RoBERTa are a few examples of state-of-the-art models in language processing that use a stack of Transformer encoders as the backbone of their architecture. ChatGPT and the GPT family also use the decoder part of the Transformer to generate texts. It’s safe to say that almost every state-of-the-art model in natural language processing incorporates the Transformer in its architecture.

Transformer performance is so good that it seems wasteful not to use it for tasks beyond natural language processing, such as computer vision. However, the big question is: can we actually use it for computer vision tasks?

It turns out that the Transformer also has good potential for computer vision tasks. In 2020, the Google Brain team introduced a Transformer-based model for image classification called Vision Transformer (ViT). Its performance is very competitive with conventional CNNs on several image classification benchmarks.

Therefore, in this article, we’re going to talk about this model. Specifically, we’ll walk through how a ViT model works and how we can fine-tune it on our own custom dataset for an image classification task with the help of the HuggingFace library.

So, as the first step, let’s get started with the dataset that we’re going to use in this article.


About the Dataset

We will use a snack dataset that you can easily access from the HuggingFace datasets library. This dataset is listed as having a CC-BY 2.0 license, which means that you are free to share and use it, as long as you cite the dataset source in your work.

Let’s take a sneak peek at this dataset:

Subset of images in the dataset

We only need a few lines of code to load the dataset, as you can see below:

!pip install -q datasets

from datasets import load_dataset 

# Load dataset
dataset = load_dataset("Matthijs/snacks")
print(dataset)

# Output
  '''
  DatasetDict({
      train: Dataset({
          features: ['image', 'label'],
          num_rows: 4838
      })
      test: Dataset({
          features: ['image', 'label'],
          num_rows: 952
      })
      validation: Dataset({
          features: ['image', 'label'],
          num_rows: 955
      })
  })'''

The dataset is a dictionary object that consists of 4838 training images, 955 validation images, and 952 test images.

Each image comes with a label, which belongs to one of 20 snack classes. We can check these 20 different classes with the following code:

print(dataset["train"].features['label'].names)

# Output
'''
['apple','banana','cake','candy','carrot','cookie','doughnut','grape',
'hot dog', 'ice cream','juice','muffin','orange','pineapple','popcorn',
'pretzel','salad','strawberry','waffle','watermelon']'''

And let’s create a mapping between each label and its corresponding index.

# Mapping from label to index and vice versa
labels = dataset["train"].features["label"].names
num_labels = len(dataset["train"].features["label"].names)
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

print(label2id)
print(id2label)

# Output
'''
{'apple': 0, 'banana': 1, 'cake': 2, 'candy': 3, 'carrot': 4, 'cookie': 5, 'doughnut': 6, 'grape': 7, 'hot dog': 8, 'ice cream': 9, 'juice': 10, 'muffin': 11, 'orange': 12, 'pineapple': 13, 'popcorn': 14, 'pretzel': 15, 'salad': 16, 'strawberry': 17, 'waffle': 18, 'watermelon': 19}
{0: 'apple', 1: 'banana', 2: 'cake', 3: 'candy', 4: 'carrot', 5: 'cookie', 6: 'doughnut', 7: 'grape', 8: 'hot dog', 9: 'ice cream', 10: 'juice', 11: 'muffin', 12: 'orange', 13: 'pineapple', 14: 'popcorn', 15: 'pretzel', 16: 'salad', 17: 'strawberry', 18: 'waffle', 19: 'watermelon'}
'''

One important thing to know before we move on is that the images come in varying dimensions. Therefore, we need to perform some image preprocessing steps before feeding the images into the model for fine-tuning.
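
As a quick sanity check (a minimal sketch, assuming the dataset has been loaded into the dataset variable as above), we can print the dimensions of a few training images to see that they indeed vary:

# Print the size (width, height) of a few training images
for idx in range(3):
    img = dataset['train'][idx]['image']   # PIL image
    print(img.size)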

Now that we know the dataset that we’re working with, let’s take a closer look at ViT architecture.


How ViT Works

Before the introduction of ViT, the fact that a Transformer model relies on the self-attention mechanism posed a big challenge for applying it to computer vision tasks.

The self-attention mechanism is the reason why Transformer-based models can differentiate the semantic meaning of a word used in different contexts. For example, a BERT model can distinguish the meaning of the word ‘park’ in the sentences ‘They park their car in the basement’ and ‘She walks her dog in a park’ thanks to the self-attention mechanism.

However, there is one problem with self-attention: it’s a computationally expensive operation, as it requires each token to attend to every other token in a sequence.

Now if we apply the self-attention mechanism to image data, each pixel in an image would need to attend and be compared to every other pixel. The problem is that the computational cost grows quadratically with the number of pixels, which is simply not feasible for an image with a reasonably large resolution.

Image by author
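
To get a feel for the numbers, here is a rough back-of-the-envelope comparison (a toy calculation, not taken from the ViT paper) between attending over individual pixels and attending over 16 x 16 patches for a 224 x 224 image; as we’ll see next, the latter is exactly how ViT tames the cost:

# Rough cost comparison: attending over pixels vs. over 16 x 16 patches
num_pixels = 224 * 224             # 50,176 tokens if every pixel is a token
num_patches = (224 // 16) ** 2     # 196 tokens if every 16 x 16 patch is a token

print(f'Pairwise interactions (pixels):  {num_pixels ** 2:,}')    # ~2.5 billion
print(f'Pairwise interactions (patches): {num_patches ** 2:,}')   # 38,416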

In order to overcome this problem, ViT introduces the concept of splitting the input image into patches. Each patch has a dimension of 16 x 16 pixels. Let’s say that we have an image with the dimension of 48 x 48 pixels, then the patches of our image will look something like this:

Image by author

In its application, there are two options for how ViT splits our image into patches:

  1. Reshape our input image of size height x width x channels into a sequence of flattened 2D image patches of size num_patches x (patch_size^2 * channels). Then, pass the flattened patches through a simple linear layer to get the embedding of each patch.
  2. Pass our input image through a convolutional layer with the kernel size and stride equal to the patch size. Then, flatten the output of that convolutional layer.

After testing the model performance on several datasets, it turns out that the second approach leads to better performance. Therefore, in this article, we’re going to use the second approach.
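
Before that, here is a minimal sketch of what the first approach (reshape plus linear projection) could look like, just for reference; the variable names are purely illustrative:

import torch
import torch.nn as nn

# Toy image with dim (batch x channel x height x width)
toy_img = torch.rand(1, 3, 48, 48)
patch_size = 16
hidden_size = 768

# Cut the image into non-overlapping 16 x 16 patches and flatten each patch
patches = toy_img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.contiguous().view(1, 3, -1, patch_size, patch_size)  # (1, 3, 9, 16, 16)
patches = patches.permute(0, 2, 1, 3, 4).flatten(2)                    # (1, 9, 768)

# Project the flattened patches with a linear layer to get patch embeddings
linear_projection = nn.Linear(3 * patch_size ** 2, hidden_size)
patch_embeddings = linear_projection(patches)
print(patch_embeddings.size())  # torch.Size([1, 9, 768])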

Let’s use a toy example to demonstrate the splitting process of an input image into patches with a convolutional layer.

import torch
import torch.nn as nn

# Create toy image with dim (batch x channel x height x width)
toy_img = torch.rand(1, 3, 48, 48)

# Define conv layer parameters
num_channels = 3
hidden_size = 768 #or emb_dimension
patch_size = 16

# Conv 2D layer
projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, 
             stride=patch_size)

# Forward pass toy img
out_projection = projection(toy_img)

print(f'Original image size: {toy_img.size()}')
print(f'Size after projection: {out_projection.size()}')

# Output
'''
Original image size: torch.Size([1, 3, 48, 48])
Size after projection: torch.Size([1, 768, 3, 3])
'''

The next thing that the model will do is flatten the patches and put them sequentially as you can see in the image below:

Image by author

We can do the flattening process with the following code:

# Flatten the output after projection with Conv2D layer

patch_embeddings = out_projection.flatten(2).transpose(1, 2)
print(f'Patch embedding size: {patch_embeddings.size()}')

# Output
'''
Patch embedding size: torch.Size([1, 9, 768]) #[batch, no. of patches, emb_dim]
'''

What we have after the flattening process is basically the vector embedding of each patch. This is similar to token embeddings in many Transformer-based language models.

Next, similar to BERT, ViT will add a special vector embedding for the [CLS] token in the first position of our patches’ sequence.

Image by author

# Define [CLS] token embedding with the same emb dimension as the patches
batch_size = 1
cls_token = nn.Parameter(torch.randn(1, 1, hidden_size))
cls_tokens = cls_token.expand(batch_size, -1, -1)

# Prepend [CLS] token in the beginning of patch embedding
patch_embeddings = torch.cat((cls_tokens, patch_embeddings), dim=1)
print(f'Patch embedding size: {patch_embeddings.size()}')

# Output
'''
Patch embedding size: torch.Size([1, 10, 768]) #[batch, no. of patches+1, emb_dim]
'''

As you can see, by prepending the [CLS] token embedding at the beginning of our patch embeddings, the length of the sequence increases by one. The final step after this is to add the positional embedding to our sequence of patches. This step is important so that our ViT model can learn the order of the patches.

This position embedding is a learnable parameter that will be updated by the model during the training process.

Image by author

# Define position embedding with the same dimension as the patch embedding
position_embeddings = nn.Parameter(torch.randn(batch_size, 10, hidden_size))

# Add position embedding into patch embedding
input_embeddings = patch_embeddings + position_embeddings
print(f'Input embedding size: {input_embeddings.size()}')

# Output
'''
Input embedding size: torch.Size([1, 10, 768]) #[batch, no. of patches+1, emb_dim]
'''

Now, the patch embeddings with the position embeddings added will be the input to a stack of Transformer encoders. The number of Transformer encoders depends on the type of ViT model that you use. Overall, there are three ViT variants:

  • ViT-base: 12 layers, a hidden size of 768, and a total of 86M parameters.
  • ViT-large: 24 layers, a hidden size of 1024, and a total of 307M parameters.
  • ViT-huge: 32 layers, a hidden size of 1280, and a total of 632M parameters.

In the following code snippet, let’s say that we want to use ViT-base. This means that we have 12 layers of Transformer encoders:

# Define parameters for ViT-base (example)
num_heads = 12
num_layers = 12

# Define Transformer encoders' stack
transformer_encoder_layer = nn.TransformerEncoderLayer(
           d_model=hidden_size, nhead=num_heads,
           dim_feedforward=int(hidden_size * 4),
           dropout=0.1, batch_first=True) # input has shape (batch, seq, emb_dim)
transformer_encoder = nn.TransformerEncoder(
           encoder_layer=transformer_encoder_layer,
           num_layers=num_layers)

# Forward pass
output_embeddings = transformer_encoder(input_embeddings)
print(f' Output embedding size: {output_embeddings.size()}')

# Output
'''
Output embedding size: torch.Size([1, 10, 768])
'''

Finally, the stack of Transformer encoders will output the final vector representation of each image patch. The dimensionality of the final vector corresponds to the hidden size of the ViT model that we use.

Image by author

And that’s basically it.

We can certainly build and train our own ViT model from scratch. However, as with other Transformer-based models, ViT needs to be trained on a large amount of image data (14M–300M images) in order to generalize well on unseen data.

If we want to use ViT on a custom dataset, the most common approach is to fine-tune a pretrained model. The easiest way to do this is with the HuggingFace transformers library: all we have to do is call the ViTModel.from_pretrained() method and pass the path to our pretrained model as an argument. The ViTModel class from HuggingFace also acts as a wrapper for all of the steps that we’ve discussed above.

!pip install transformers

from transformers import ViTModel

# Load pretrained model
model_checkpoint = 'google/vit-base-patch16-224-in21k'
model = ViTModel.from_pretrained(model_checkpoint, add_pooling_layer=False)

# Example input image
input_img = torch.rand(batch_size, num_channels, 224, 224)

# Forward pass input image
output_embedding = model(input_img)
print(output_embedding)
print(f"Ouput embedding size: {output_embedding['last_hidden_state'].size()}")

# Output
'''
BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.0985, -0.2080,  0.0727,  ...,  0.2035,  0.0443, -0.3266],
         [ 0.1899, -0.0641,  0.0996,  ..., -0.0209,  0.1514, -0.3397],
         [ 0.0646, -0.3392,  0.0881,  ..., -0.0044,  0.2018, -0.3038],
         ...,
         [-0.0708, -0.2932, -0.1839,  ...,  0.1035,  0.0922, -0.3241],
         [ 0.0070, -0.3093, -0.0217,  ...,  0.0666,  0.1672, -0.4103],
         [ 0.1723, -0.1037,  0.0317,  ..., -0.0571,  0.0746, -0.2483]]],
       grad_fn=<NativeLayerNormBackward0>), pooler_output=None, hidden_states=None, attentions=None)

Output embedding size: torch.Size([1, 197, 768])
'''

The output of the complete ViT model is a vector embedding representing each image patch plus the [CLS] token. It has the dimension of [batch_size, image_patches+1, hidden_size]. With a 224 x 224 input and 16 x 16 patches, we get (224/16)^2 = 196 patches plus the [CLS] token, which is why the sequence length in the output above is 197.

To perform an image classification task, we follow the same approach as with the BERT model. We extract the output vector embedding of the [CLS] token and pass it through the final linear layer to determine the class of the image.

Image by author

num_labels = 20

# Define linear classifier layer
classifier = nn.Linear(hidden_size, num_labels) 

# Forward pass on the output embedding of [CLS] token
output_classification = classifier(output_embedding['last_hidden_state'][:, 0, :])
print(f"Output embedding size: {output_classification.size()}")

# Output
'''
Output embedding size: torch.Size([1, 20]) #[batch, no. of labels]
'''

Fine-Tuning Implementation

In this section, we will fine-tune a ViT-base model that was pre-trained on the ImageNet-21K dataset, which consists of approximately 14 million images and 21,843 classes. Each image in the dataset has a dimension of 224 x 224 pixels.

To begin, we need to define the checkpoint path for the pre-trained model and load the necessary libraries.

import numpy as np
import torch
import cv2
import torch.nn as nn
from transformers import ViTModel, ViTConfig
from torchvision import transforms
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

#Pretrained model checkpoint
model_checkpoint = 'google/vit-base-patch16-224-in21k'

Image Dataloader

As previously mentioned, the ViT-base model has been pretrained on a dataset consisting of images with the dimension of 224 x 224 pixels. The images have also been normalized according to a particular mean and standard deviation in each of their color channels.

As a result, before we can feed our own dataset into the ViT model for fine-tuning, we must first preprocess our images. This involves transforming each image into a tensor, resizing it to the appropriate dimensions, and then normalizing it using the same mean and standard deviation values as the dataset on which the model was pretrained.

class ImageDataset(torch.utils.data.Dataset):

  def __init__(self, input_data):

      self.input_data = input_data
      # Transform input data
      self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], 
                             std=[0.5, 0.5, 0.5])
        ])

  def __len__(self):
      return len(self.input_data)

  def get_images(self, idx):
      return self.transform(self.input_data[idx]['image'])

  def get_labels(self, idx):
      return self.input_data[idx]['label']

  def __getitem__(self, idx):
      # Get input data in a batch
      train_images = self.get_images(idx)
      train_labels = self.get_labels(idx)

      return train_images, train_labels

From the image dataloader above, we will get batches of preprocessed images with their corresponding labels. We can use the output of this dataset class as the input to our model during the fine-tuning process.
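
As a quick check (a minimal sketch, assuming the dataset has been loaded as shown earlier), we can wrap the training split in this dataset class and inspect the shape of one batch:

# Wrap the training split and fetch one batch to verify the shapes
train_dataset = ImageDataset(dataset['train'])
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

train_images, train_labels = next(iter(train_dataloader))
print(train_images.size())   # torch.Size([8, 3, 224, 224])
print(train_labels.size())   # torch.Size([8])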

Model Definition

The architecture of our ViT model is straightforward. Since we’ll be fine-tuning a pretrained model, we can use the ViTModel.from_pretrained() method and provide the model checkpoint as an argument.

We also need to add a linear layer at the end, which will act as the final classifier. The output of this layer should be equal to the number of distinct labels in our dataset.

class ViT(nn.Module):

  def __init__(self, config=ViTConfig(), num_labels=20, 
               model_checkpoint='google/vit-base-patch16-224-in21k'):

        super(ViT, self).__init__()

        self.vit = ViTModel.from_pretrained(model_checkpoint, add_pooling_layer=False)
        self.classifier = (
            nn.Linear(config.hidden_size, num_labels) 
        )

  def forward(self, x):

    x = self.vit(x)['last_hidden_state']
    # Use the embedding of [CLS] token
    output = self.classifier(x[:, 0, :])

    return output

The above ViT model generates final vector embeddings for each image patch plus the [CLS] token. To classify images, as you can see above, we extract the final vector embedding of the [CLS] token and pass it to the final linear layer to obtain the final class prediction.

Model Fine-Tuning

Now that we have defined the model architecture and prepared the image dataset for batching, we can start to fine-tune our ViT model. The training script is a standard PyTorch training script, as you can see below:

def model_train(dataset, epochs, learning_rate, bs):

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Load model, loss function, and optimizer
    model = ViT().to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Prepare dataset and dataloader
    train_dataset = ImageDataset(dataset)
    train_dataloader = DataLoader(train_dataset, num_workers=1, batch_size=bs, shuffle=True)

    # Fine tuning loop
    for i in range(epochs):
        total_acc_train = 0
        total_loss_train = 0.0

        for train_image, train_label in tqdm(train_dataloader):
            output = model(train_image.to(device))
            loss = criterion(output, train_label.to(device))
            acc = (output.argmax(dim=1) == train_label.to(device)).sum().item()
            total_acc_train += acc
            total_loss_train += loss.item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print(f'Epochs: {i + 1} | Loss: {total_loss_train / len(train_dataset): .3f} | Accuracy: {total_acc_train / len(train_dataset): .3f}')

    return model

# Hyperparameters
EPOCHS = 10
LEARNING_RATE = 1e-4
BATCH_SIZE = 8

# Train the model
trained_model = model_train(dataset['train'], EPOCHS, LEARNING_RATE, BATCH_SIZE)

Since our snack dataset has 20 distinct classes, we’re dealing with a multiclass classification problem, so CrossEntropyLoss() is the appropriate loss function. In the example above, we train our model for 10 epochs with a learning rate of 1e-4 and a batch size of 8. You can play around with these hyperparameters to tune the performance of the model.

After you train the model, you will get an output that looks similar to the one below:

Image by author

Model Prediction

Since we have fine-tuned our model, naturally we want to use it for prediction on the test data. To do so, let’s first create a function that encapsulates all of the necessary image preprocessing steps and the model inference process.

def predict(img):

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=True),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], 
                             std=[0.5, 0.5, 0.5])
        ])

    img = transform(img)
    output = trained_model(img.unsqueeze(0).to(device))
    prediction = output.argmax(dim=1).item()

    return id2label[prediction]

As you can see above, the image preprocessing step during inference is exactly the same as the step that we did on the training data. Then, we use the transformed image as the input to our trained model, and finally we map its prediction to the corresponding label.

If we want to predict a specific image from the test data, we can just call the function above to get the prediction. Let’s try it out.

print(predict(dataset['test'][900]['image']))
# Output: waffle

Example of test data from the dataset

Our model predicted our test image correctly. Let’s try another one.

print(predict(dataset['test'][250]['image']))
# Output: cookie

Example of test data from the dataset

And our model predicted the test data correctly again. By fine-tuning a ViT model, we can get good performance on a custom dataset. You can follow the same process for any custom image classification dataset.
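
If we want a more systematic picture than spot-checking individual images, we can also evaluate the fine-tuned model on the whole test split. Below is a minimal sketch that reuses the ImageDataset class and the trained_model from above; the exact accuracy you get will depend on your own training run:

def model_evaluate(dataset, bs=8):

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Build the test dataloader with the same preprocessing as training
    test_dataset = ImageDataset(dataset)
    test_dataloader = DataLoader(test_dataset, batch_size=bs)

    total_acc_test = 0
    trained_model.eval()
    with torch.no_grad():
        for test_image, test_label in tqdm(test_dataloader):
            output = trained_model(test_image.to(device))
            total_acc_test += (output.argmax(dim=1) == test_label.to(device)).sum().item()

    print(f'Test accuracy: {total_acc_test / len(test_dataset): .3f}')

# Evaluate on the test split
model_evaluate(dataset['test'])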


Conclusion

In this article, we have seen how the Transformer can be used not only for language modeling tasks, but also for computer vision tasks, in this case image classification.

To do so, first the input image is decomposed into patches with a size of 16 x 16 pixels. Then, the Vision Transformer model utilizes a stack of Transformer encoders to learn the vector representation of each image patch. Finally, we can use the final vector representation of the [CLS] token prepended at the beginning of image patch sequence to predict the label of our input image.

I hope this article is useful for getting started with the Vision Transformer model. As always, you can find the code implementation presented in this article in this notebook.


Dataset Reference

https://huggingface.co/datasets/Matthijs/snacks

