
How to Fine-Tune a Pretrained Vision Transformer on Satellite Data

A step-by-step tutorial in PyTorch Lightning

Image created by the author using Midjourney.

The Vision Transformer is a powerful AI model for image classification. Released in 2020, it brought the efficient transformer architecture to computer vision.

In pretraining, an AI model ingests large amounts of data and learns common patterns. The Vision Transformer was pretrained on ImageNet-21K, a dataset of 14 million images and 21,000 classes.

Satellite images are not covered in ImageNet-21K, and the Vision Transformer would perform poorly if applied out-of-the-box.

Here, I will show you how to fine-tune a pretrained Vision Transformer on 27,000 satellite images from the EuroSAT dataset. We will predict land cover, such as forests, crops, and industrial areas.

Example images from the EuroSAT RGB dataset. Sentinel data is free and open to the public under EU law.

We will work in PyTorch Lightning, a Deep Learning library that builds on PyTorch. Lightning reduces the amount of code one has to write, and lets us focus on modeling.

All code is available on GitHub.

Setting up the project

The pretrained Vision Transformer is available on [Huggingface](https://huggingface.co/google/vit-base-patch16-224). The model architecture and weights can be installed from GitHub. We will also need to install PyTorch Lightning. I used version 2.2.1 for this tutorial, but any version > 2.0 should work.

pip install -q git+https://github.com/huggingface/transformers
pip install lightning==2.2.1

We can split our project into four steps, which we will cover in detail:

  • Pretrained Vision Transformer: Lightning Module
  • EuroSAT dataset
  • Train the Vision Transformer on the EuroSAT dataset
  • Calculate the accuracy on the test set

Adapting the Vision Transformer to our dataset

The Vision Transformer from Huggingface comes with a classification head fine-tuned on ImageNet-1K, a subset of ImageNet with 1,000 classes.

Our dataset contains only 10 classes for different types of land cover. Therefore, we need to modify the output section of the Vision Transformer to a new classification head with the correct number of classes.

Vision Transformer architecture. Adapted by the author from the original paper. [arxiv]

The code for instantiating a pretrained model from Huggingface makes this straightforward. We only need to specify the new number of classes via num_labels, and tell the model to ignore the fact that we changed the output size.

from transformers import ViTForImageClassification

# Load the pretrained checkpoint and replace the classification head with a 10-class head
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224",
                                                  num_labels=10,
                                                  ignore_mismatched_sizes=True,
                                                  )

The model reminds us that we now need to re-train:

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match: … You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

We can choose from different flavours of the Vision Transformer. Here we stick with vit-base-patch16-224, the smallest variant, which uses 16 x 16 patches from images of 224 x 224 pixels. This model has 85.8 million parameters and requires 340 MB of memory.
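As a quick sanity check, we can count the parameters of the loaded model ourselves. The snippet below is a minimal sketch that reuses the from_pretrained call from above; note that the count includes the newly initialized classification head.

from transformers import ViTForImageClassification

# Load the checkpoint with the new 10-class head, as shown above
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224",
                                                  num_labels=10,
                                                  ignore_mismatched_sizes=True)

# Sum the number of elements over all weight tensors
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f} million parameters")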

Vision Transformer as a Lightning Module

In PyTorch Lightning, a deep learning model is defined as a Lightning Module. We only need to specify

  • Setup of the model: Load the pretrained Vision Transformer
  • Forward step: Apply the model to a batch of data
  • Training, validation, and test step
  • The optimizer to be used in training

The training step must return the loss, in this case the cross-entropy loss, which quantifies the mismatch between the predicted and the true classes.

Logging is convenient. With calls to self.log, we can log training and evaluation metrics directly to our preferred logger – in my case, TensorBoard. Here, we log the training loss and the validation accuracy.

Note that, in order to access the predictions made by Huggingface’s Vision Transformer, we need to retrieve them from the model output as predictions.logits.
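A minimal LightningModule along these lines might look like the sketch below. The class name ViTClassifier, the learning rate, and the logged metric names are my own choices and may differ from the repository:

import torch
import lightning as L
from transformers import ViTForImageClassification


class ViTClassifier(L.LightningModule):
    """Sketch of a LightningModule wrapping the pretrained Vision Transformer."""

    def __init__(self, num_classes=10, lr=1e-4):
        super().__init__()
        # Setup: load the pretrained Vision Transformer with a new classification head
        self.model = ViTForImageClassification.from_pretrained(
            "google/vit-base-patch16-224",
            num_labels=num_classes,
            ignore_mismatched_sizes=True,
        )
        self.lr = lr

    def forward(self, x):
        # Huggingface returns an output object; the class scores are in .logits
        return self.model(x).logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("val_loss", torch.nn.functional.cross_entropy(logits, y))
        self.log("val_acc", (logits.argmax(dim=-1) == y).float().mean())

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("test_acc", (logits.argmax(dim=-1) == y).float().mean())

    def configure_optimizers(self):
        # Adam updates the model weights during fine-tuning
        return torch.optim.Adam(self.parameters(), lr=self.lr)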

Lightning DataModule for the EuroSAT dataset

You can download the EuroSAT dataset from Zenodo. Make sure to select the RGB version, which has already been converted from the original satellite imagery. We will define the dataset within a LightningDataModule.

The setup stage uses torchvision transform functions. To comply with the input expected by the Vision Transformer, we need to upscale the satellite images to 224 x 224 pixels, convert them to torch tensors, and normalize them.

We split the dataset so that 70% remain for training (fine-tuning), 10% for validation, and 20% for testing. By stratifying on the class labels, we ensure an equal distribution of classes across all three subsets.

The functions train_dataloader() etc. are convenient for setting up the dataloaders later in the run script.
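A sketch of such a DataModule is shown below. The directory layout (read with ImageFolder), the normalization statistics, and the batch size are assumptions on my part:

import lightning as L
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split


class EuroSATDataModule(L.LightningDataModule):
    """Sketch of a DataModule for the EuroSAT RGB images."""

    def __init__(self, data_dir="EuroSAT_RGB", batch_size=32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Upscale to the 224 x 224 input expected by the ViT, convert to
        # tensors, and normalize (ImageNet statistics assumed here)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def setup(self, stage=None):
        full = ImageFolder(self.data_dir, transform=self.transform)
        indices = list(range(len(full)))
        # Stratified 70 / 10 / 20 split on the class labels
        train_idx, rest_idx = train_test_split(
            indices, train_size=0.7, stratify=full.targets, random_state=42)
        rest_targets = [full.targets[i] for i in rest_idx]
        val_idx, test_idx = train_test_split(
            rest_idx, train_size=1 / 3, stratify=rest_targets, random_state=42)
        self.train_set = Subset(full, train_idx)
        self.val_set = Subset(full, val_idx)
        self.test_set = Subset(full, test_idx)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)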

Putting it all together: the run script

Now that we have the building blocks for the dataset and the model, we can write a run script that performs the fine-tuning.

For clarity, I created separate modules for the dataset (eurosat_module) and the model (vision_transformer) that need to be imported.
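A skeleton of the run script, using the class names from the sketches above (my own naming), could look like this:

# run.py - sketch of the fine-tuning run script
from eurosat_module import EuroSATDataModule
from vision_transformer import ViTClassifier

# Instantiate the data module and build the dataloaders
datamodule = EuroSATDataModule(data_dir="EuroSAT_RGB", batch_size=32)
datamodule.setup()
train_dataloader = datamodule.train_dataloader()
valid_dataloader = datamodule.val_dataloader()
test_dataloader = datamodule.test_dataloader()

# Instantiate the model with the 10 EuroSAT land cover classes
model = ViTClassifier(num_classes=10)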


Train the model

The Lightning Trainer takes a model and dataloaders for training and validation data. It offers a variety of flags that can be customized to your needs – here we use only three of them (see the sketch after the list):

  • devices, to use only one GPU for training
  • early stopping callback, to stop training if the validation loss does not improve for six epochs
  • logger, to log the training process to TensorBoard
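A possible configuration of these three flags, monitoring the validation loss logged in the LightningModule sketch above (the log directory is my own choice):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

trainer = Trainer(
    accelerator="gpu",
    devices=1,                                                  # train on a single GPU
    callbacks=[EarlyStopping(monitor="val_loss", patience=6)],  # stop after 6 epochs without improvement
    logger=TensorBoardLogger(save_dir="logs"),                  # write metrics to TensorBoard
)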

The beauty of PyTorch Lightning is that training is now done in one line of code:

trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

Under the hood, the trainer uses backpropagation and the Adam optimizer to update the model weights. It stops training when the monitored validation metric has not improved for the specified number of epochs.

In fact, fine-tuning on EuroSAT is completed within a few epochs. The panel shows the training loss, as logged by TensorBoard. After two epochs, the model has already reached 98.3% validation accuracy.

Snapshot from TensorBoard created by the author.

Evaluate on the test set

Our dataset is small and fine-tuning is fast, so we do not need to store the trained model separately and can apply it directly to the test set.

In one line of code, we compute the accuracy on the test set:

trainer.test(model=model, dataloaders=test_dataloader, verbose=True)

With only one epoch of fine-tuning, I achieved a test set accuracy of 98.4%, meaning that the land cover types were correctly classified for almost all satellite images.

Even with fewer samples, we can achieve great accuracy. The panel shows the test set accuracy for different numbers of training samples, each seen only once during fine-tuning. With only 320 satellite images, an average of 32 per class, the test set accuracy is already 80%.

Image created by the author.

Key takeaways

Pretrained models are a great way to reduce your training time. They are already good at a general task and just need to be adapted to your specific dataset.

In real-world applications, data is often scarce. The EuroSAT dataset consists of only 27,000 images, roughly 0.2% of the size of ImageNet-21K. Pretraining takes advantage of the larger dataset, and we can use the smaller, application-specific dataset efficiently for fine-tuning.

Lightning is great for training deep learning models without having to worry about all the technical details. The LightningModule and the Trainer API offer convenient abstractions and are performant.

If you want to stay within the Huggingface ecosystem to fine-tune a Vision Transformer, I recommend this tutorial.

The complete code, including configuration files that allow you to add your own datasets, is available on GitHub:

[GitHub – crlna16/pretrained-vision-transformer: Pretrained Vision Transformer with PyTorch…](https://github.com/crlna16/pretrained-vision-transformer)

