
How to Build an Image-Captioning Model in Pytorch

A detailed step-by-step explanation of how to build an image-captioning model in Pytorch

Photo by Adam Dutton on Unsplash

In this article, I will explain how you can build an image captioning model architecture using the Pytorch deep learning library. In addition to explaining the intuition behind the model architectures, I will also provide the Pytorch code for the models.

Note that this article was written in June 2022, so earlier or later versions of Pytorch may behave slightly differently, and the code in this article may need minor adjustments to run.

What is Image Captioning?

As the name implies, Image Captioning is the task of feeding an image into an AI model and receiving a text caption that describes or summarizes the contents of the image as output. For example, if I were to input the following picture into an image captioning model:

Photo by LyAn Voyages on Unsplash

the model would return a text caption like "Dog running in water". Image captioning models consist of 2 main components: a CNN (Convolutional Neural Network) encoder and a language-model decoder, typically an RNN/LSTM, that can produce text. The CNN encoder compresses the important information about the input image into a feature embedding, and the decoder uses that embedding to produce a text caption.

To train image captioning models, the most commonly used datasets are the Flickr8k dataset and the MSCOCO dataset. You can find the download link for the Flickr8k dataset [here](https://www.kaggle.com/datasets/adityajn105/flickr8k) and the link for the MSCOCO dataset [here](https://cocodataset.org/#home). The Flickr8k dataset consists of 8,000 images – each with 5 different captions describing the image – and the MSCOCO dataset consists of 328,000 images. From an introductory perspective, the Flickr8k dataset is recommended because it is much smaller than MSCOCO, making it far easier to work with. But if you are trying to build a model that can be deployed or used in production, then MSCOCO is probably the better choice.
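
If you download the Flickr8k dataset from the Kaggle link above, here is a minimal sketch of how you might load the captions; the captions.txt file name and its image,caption CSV layout are assumptions about that particular distribution, so adjust the path and column names to match whatever your download actually contains:

import pandas as pd

# Assumed layout: a captions.txt file with "image" and "caption" columns,
# where every image appears 5 times (once per caption)
captions_df = pd.read_csv("flickr8k/captions.txt")

print(captions_df.head())
print(captions_df["image"].nunique(), "images,", len(captions_df), "captions")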

Encoder-Decoder Model Architecture

As I mentioned before, an Encoder-Decoder architecture consists of 2 components: a Convolutional Neural Network to encode the image (i.e. transform it into a rich embedding representation), and a Recurrent Neural Network (or an LSTM) that will take in as input this image, and be trained to sequentially decode the caption using a mechanism called Teacher Forcing. A good diagram of what this model looks like is below:

Image by Author

The CNN: Typically, the CNN is the least computationally intensive/complex part of this model architecture. Why? Because most image captioning models use transfer learning and simply load the pre-trained weights of already existing, powerful CNN architectures. In this article, for example, I will be using the Inception V3 network, loaded from Pytorch's torchvision library. However, there are many other CNNs you can use besides Inception, like ResNet, VGG, or LeNet. The main reason I am using transfer learning for the CNN instead of training it from scratch is simply how general and broad its job is: the only purpose of the CNN is to create a rich embedding for the image, so it doesn't need to learn anything very task-specific. Any features that are beneficial for image captioning will be learned during the training process for the whole model, when the CNN's weights are fine-tuned.

Here is the Pytorch model code for the CNN Encoder:

import torch
import torch.nn as nn
import torchvision.models as models


class CNNEncoder(nn.Module):
    def __init__(self, embed_size):
        super(CNNEncoder, self).__init__()
        # Load the pre-trained Inception V3 network (newer torchvision
        # versions use the weights= argument instead of pretrained=True)
        self.inception = models.inception_v3(pretrained=True,
                                             aux_logits=False)
        # Replace the final fully connected layer so the network outputs
        # a feature embedding of the size the decoder expects
        self.inception.fc = nn.Linear(self.inception.fc.in_features,
                                      embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        # images: (batch_size, 3, 299, 299) -> features: (batch_size, embed_size)
        features = self.inception(images)
        return self.dropout(self.relu(features))

As you can see, it is a relatively straightforward CNN architecture. The only modification is that we take the last fully connected layer of the Inception network and replace it so that it maps to the embedding size we want our feature embeddings to be (which is also the size the RNN decoder will take as input).
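
As a quick sanity check, here is a small, hypothetical usage example showing that the encoder maps a batch of images to embeddings of the chosen size; the batch size of 4 is arbitrary, and 299×299 is the input resolution Inception V3 expects:

encoder = CNNEncoder(embed_size=256)
encoder.eval()  # disable Dropout/BatchNorm randomness for this check

# A dummy batch of 4 RGB images at Inception V3's expected 299x299 resolution
dummy_images = torch.randn(4, 3, 299, 299)

with torch.no_grad():
    features = encoder(dummy_images)

print(features.shape)  # torch.Size([4, 256])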

The RNN Decoder: Unlike CNNs, we typically do not use transfer learning for RNNs/LSTMs. If you are familiar with the LSTM architecture, you know that it processes its input sequentially, one time step at a time, with each cell receiving the hidden state of the previous cell along with the current input token. During training with Teacher Forcing, that input token is the previous word of the ground-truth caption rather than the word the model actually predicted. For this model, we additionally prepend the feature embedding produced by the CNN to the sequence of word embeddings, so the image acts as the first input the LSTM sees before it starts decoding the caption. You can see this represented in the image of the whole encoder-decoder architecture:

Image by Author

Here is the Pytorch code for the LSTM:

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super(DecoderRNN, self).__init__()
        # Learned word embeddings, trained from scratch with the rest of the model
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size)
        # Maps each hidden state to a score for every word in the vocabulary
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, features, captions):
        # captions: (seq_len, batch_size) of word indices
        embeddings = self.dropout(self.embed(captions))
        # Prepend the image features as the first "token" of the sequence
        embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs

Some of the hyperparameters in this model will have to be chosen by you (specifically the hidden_size and embed_size parameters; I just used 256 for both). The vocab size is a value you will have to calculate from the dataset you use. In the model code above, as with many/most LSTM architectures, to actually get the word each cell predicts, I use a linear layer that maps the hidden state to scores over the vocabulary.

One more important thing to mention is that in this model, I generate the word embeddings from scratch. As you can see in the LSTM code, I use an nn.Embedding layer that takes the index of each word in the vocab and transforms it into an embedding of size embed_size. Nowadays, we typically don't train word embeddings from scratch (for instance, many people reuse pre-trained transformer weights) because doing so prolongs training, but because we are using an LSTM as our decoder, we can't really load pre-trained weights and use them.
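
As a rough illustration of how you might compute the vocab size, here is a minimal sketch that builds a word-to-index mapping from the captions_df DataFrame loaded in the earlier Flickr8k sketch; the special tokens and the naive whitespace tokenization are assumptions for this example:

# Special tokens commonly used in captioning models (an assumption, not a requirement)
word2idx = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}

for caption in captions_df["caption"]:
    for word in caption.lower().split():  # naive whitespace tokenization
        if word not in word2idx:
            word2idx[word] = len(word2idx)

vocab_size = len(word2idx)
print(vocab_size)

Finally, here is the Pytorch code that ties the encoder and decoder together into the full model: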

class Encoder_Decoder(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super(Encoder_Decoder, self).__init__()
        self.cnn = CNNEncoder(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size,
                                     vocab_size)

    def forward(self, images, captions):
        # Encode the images, then let the decoder produce word scores
        features = self.cnn(images)
        outputs = self.decoderRNN(features, captions)
        return outputs

We simply connect the two models by feeding the output of the CNN as input to the LSTM decoder, and return the final output of the LSTM. This final Encoder_Decoder model is the one we will actually train on our data (not the other two).
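
To make the training procedure concrete, here is a minimal sketch of a training loop that uses Teacher Forcing; the train_loader, the <PAD> index passed to the loss, and the convention of feeding captions[:-1] so the model predicts every next word are assumptions for this example, not a prescribed recipe:

model = Encoder_Decoder(embed_size=256, hidden_size=256, vocab_size=vocab_size)
criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<PAD>"])
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(10):
    # train_loader: assumed DataLoader yielding images of shape (batch, 3, 299, 299)
    # and captions of shape (seq_len, batch) containing word indices
    for images, captions in train_loader:
        # Teacher Forcing: feed the ground-truth caption (minus its last word)
        # and train the model to predict each next word, starting from the image
        outputs = model(images, captions[:-1])  # (seq_len, batch, vocab_size)
        loss = criterion(outputs.reshape(-1, outputs.shape[2]),
                         captions.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()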

Modifications to the Encoder-Decoder Model

While the model architecture described above is already a good way to construct image captioning models, there are a couple of ways to modify the model to make it more powerful.

Attention Mechanism: At a high level, as the decoder generates each word of the caption, an attention mechanism allows the model to pay attention to – or essentially focus on – the relevant parts of the image. For instance, taking the previous example of a dog running in water, when the decoder generates the word 'dog', an attention mechanism would allow the model to focus on the spatial region of the image that contains the dog.

There are 2 types of attention mechanisms we could use:

Soft Attention: Soft attention constructs each word by considering/paying attention to multiple parts of the image, each to a different degree. Basically, think of it as using – sometimes many – different parts of an image, each with a different weight: certain parts of the image are considered more strongly than others.

Hard Attention: Hard attention also considers multiple parts of an image, but in contrast to soft attention, every part that is considered gets the same strength. In other words, each part of the image that is selected (i.e. that the model is 'paying attention to') is fully considered, and the parts that are not selected are completely disregarded while generating a specific word. The short sketch below contrasts the two.
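
Here is a tiny, hypothetical illustration of the difference, using made-up attention scores over four image regions; it only shows how the weighting differs, not how the scores would be computed in a real model:

scores = torch.tensor([2.0, 0.5, 1.0, -1.0])  # made-up relevance scores for 4 image regions

# Soft attention: every region contributes, weighted by a softmax distribution
soft_weights = torch.softmax(scores, dim=0)   # roughly [0.61, 0.14, 0.22, 0.03]

# Hard attention: only the selected region contributes, with full weight
# (in practice the region is usually sampled from the distribution rather than argmax-ed)
hard_weights = torch.zeros_like(scores)
hard_weights[torch.argmax(scores)] = 1.0      # [1.0, 0.0, 0.0, 0.0]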

To add the attention mechanism, before we feed the feature embedding produced by the CNN to the decoder, we add a separate linear attention layer that we apply to the feature embedding. Below is some pseudo-code/Pytorch code demonstrating soft attention:

self.attention = nn.Linear(embed_size, embed_size)
self.softmax = nn.Softmax(dim=1)  # normalize across the embedding dimension

# Compute attention weights and re-weight the feature embedding with them
attention_output = self.softmax(self.attention(feature_embedding))
feature_embedding = feature_embedding * attention_output

The softmax creates a probability distribution over the dimensions of the feature embedding. Multiplying these probabilities (the attention weights) with the original feature embedding produces the new embedding that we feed to the decoder.
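
To show where this sits in the overall model, here is a hedged sketch of how the earlier Encoder_Decoder forward pass could be modified to apply the attention layer to the image features before they reach the decoder; the class name is hypothetical, and this is a simplified variant that operates on the single feature embedding rather than on per-region spatial feature maps:

class AttentiveEncoderDecoder(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size):
        super(AttentiveEncoderDecoder, self).__init__()
        self.cnn = CNNEncoder(embed_size)
        self.attention = nn.Linear(embed_size, embed_size)
        self.softmax = nn.Softmax(dim=1)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size)

    def forward(self, images, captions):
        features = self.cnn(images)                                 # (batch, embed_size)
        attention_weights = self.softmax(self.attention(features))
        features = features * attention_weights                     # re-weighted embedding
        return self.decoderRNN(features, captions)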

Using a Transformer as a Decoder: Using a powerful transformer like BERT is the next step towards building a better image captioning model. While it may seem like an obvious choice, it is more challenging in practice, because off-the-shelf Transformers like BERT were built and pre-trained to take text as input. For example, BERT was pre-trained with tasks like next sentence prediction and masked language modeling, making it difficult to incorporate anything other than text into its input. Incorporating a Transformer end-to-end in the training process would require a specially modified transformer. Some papers have been published in this area, such as [this one](https://arxiv.org/abs/2101.10804).


I hope you found this content easy to understand. If you think that I need to elaborate further or clarify anything, drop a comment below.

References

MSCOCO Dataset: https://cocodataset.org/#home

Flickr8k Dataset: https://www.kaggle.com/datasets/adityajn105/flickr8k

