
PolyViT: Co-training Vision Transformers on Images, Videos and Audio

Multi-Modality is the way forward for efficient and scalable solutions in the Transformer Era.


Photo by Hal Gatewood on Unsplash

Introduction

The famed Transformer [1] architecture, released in 2017, was simply outstanding and went on to break almost all SOTA benchmarks in Natural Language Processing. It was then adapted, without any major modifications, to other domains such as image and video classification, object detection, audio classification, and even generative networks. One of the most popular of these architectures, the Vision Transformer [2], achieved SOTA results on the ImageNet classification task and set the baseline for the Computer Vision domain. We now have mainstream research happening around Transformers for every domain under the Deep Learning umbrella. The architectures and results achieved are exceptional, but they raise a simple question:

Do we need a model for each modality?

Can we train a single Transformer model on multiple datasets? PolyViT [3] sets out to solve that by adopting a co-training method that trains a single model on image, video, and audio datasets. This would be massive in terms of reducing the training and deployment effort across multiple modalities.

Note: "Modality" can be read as the type of input being processed (image, video, or audio).

Background

Multi-task learning is not new in Deep Learning. The abstract idea is to develop architectures that can generalize while sharing parameters across multiple types of datasets such as images, videos, speech, point clouds, etc. An important observation has been the reduced performance compared to single-task models, which in turn hindered further research on this hypothesis. Performance degrades further as more tasks are combined, leaving researchers with little encouragement to venture into this space. Training a multi-task model is also quite tricky in terms of initializing and calibrating the parameters for stable training.

Isn’t this also an attempt to mimic a human brain which can effortlessly learn patterns across various modalities?

The Perceiver [4] architecture is a Transformer that can tokenize and process different modalities by cross-attending to a set of latent tokens, but it is not a parameter-shared network. Architectures like UniT [5] and the joint video and image encoder [6] employ multi-modal strategies but do not outperform single-task models.

Idea & Architecture

To set the context for PolyViT, we first need to know about three Transformer architectures, each the best in its respective modality:

  1. Vision Transformer (ViT) [2] for Image
  2. Video Vision Transformer (ViViT) [7] for Video
  3. Audio Spectrogram Transformer (AST) [8] for Audio

Check out this blog on ViT for a comprehensive understanding of the architecture. ViViT is an extension of ViT, the only notable difference being 3D spatio-temporal patches called tubelets instead of the 2D image patches used in ViT. The AST is simply a ViT trained on spectrograms (time-frequency representations of speech/audio) treated as images. All three models achieve their best performance when pre-trained on large datasets like ImageNet-21K or JFT.
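To make the tokenization difference concrete, here is a rough sketch of 2D patch embedding versus 3D tubelet embedding. The 16x16 patch size, the 2x16x16 tubelet size, and the input resolutions are illustrative values, not necessarily the ones used in the papers:

```python
import torch
import torch.nn as nn

embed_dim = 768

# ViT / AST style tokenizer: non-overlapping 2D patches via a strided convolution.
patch_embed_2d = nn.Conv2d(3, embed_dim, kernel_size=16, stride=16)

# ViViT style tokenizer: 3D spatio-temporal "tubelets" spanning 2 frames of 16x16 pixels.
tubelet_embed_3d = nn.Conv3d(3, embed_dim, kernel_size=(2, 16, 16), stride=(2, 16, 16))

image = torch.randn(1, 3, 224, 224)      # (batch, channels, height, width)
video = torch.randn(1, 3, 32, 224, 224)  # (batch, channels, frames, height, width)

image_tokens = patch_embed_2d(image).flatten(2).transpose(1, 2)    # (1, 14*14, 768)
video_tokens = tubelet_embed_3d(video).flatten(2).transpose(1, 2)  # (1, 16*14*14, 768)
print(image_tokens.shape, video_tokens.shape)
```

Once the inputs are turned into token sequences like this, the same Transformer encoder can consume any of them.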

PolyViT architecture illustrating the co-training setup. Image Credits – PolyViT paper

Coming to PolyViT, we can see in the above image how the encoder is shared among the modalities. The base architecture is a pre-trained ViT with added task-specific tokenizers and attention heads. There are separate input embeddings, positional embeddings, and class tokens for each of the three modalities: image, video (a sequence of frames), and audio. The tokenization step is modality-specific so that it can handle a different number of tokens per modality. During training, each pass performs one task from one modality (for example, an image-classification task on the image modality).
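Below is a minimal PyTorch-style sketch of this layout: one shared encoder, modality-specific tokenizers and class tokens, and one head per task. The module names, task list, and dimensions are my own illustrative choices, not taken from the official implementation, and positional embeddings are omitted for brevity:

```python
import torch
import torch.nn as nn

class PolyViTSketch(nn.Module):
    """Minimal sketch: one shared Transformer encoder, modality-specific
    tokenizers and class tokens, and one linear head per task."""

    def __init__(self, embed_dim=768, depth=12, num_heads=12, task_classes=None):
        super().__init__()
        task_classes = task_classes or {"imagenet": 1000, "kinetics400": 400, "audioset": 527}
        # Shared encoder: this is where the bulk of the parameters live.
        layer = nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, depth)
        # Modality-specific tokenizers (2D patches for image/audio, 3D tubelets for video).
        self.tokenizers = nn.ModuleDict({
            "image": nn.Conv2d(3, embed_dim, kernel_size=16, stride=16),
            "video": nn.Conv3d(3, embed_dim, kernel_size=(2, 16, 16), stride=(2, 16, 16)),
            "audio": nn.Conv2d(1, embed_dim, kernel_size=16, stride=16),
        })
        # One learnable class token per modality.
        self.cls_tokens = nn.ParameterDict(
            {m: nn.Parameter(torch.zeros(1, 1, embed_dim)) for m in self.tokenizers})
        # One classification head per task.
        self.heads = nn.ModuleDict(
            {t: nn.Linear(embed_dim, n) for t, n in task_classes.items()})

    def forward(self, x, modality, task):
        tokens = self.tokenizers[modality](x).flatten(2).transpose(1, 2)  # (B, N, D)
        cls = self.cls_tokens[modality].expand(tokens.size(0), -1, -1)
        encoded = self.encoder(torch.cat([cls, tokens], dim=1))  # shared across modalities
        return self.heads[task](encoded[:, 0])                   # classify from the class token
```

A forward pass always specifies both the modality (which tokenizer and class token to use) and the task (which head to use), mirroring the one-task-per-pass training described above.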

Co-training

This procedure uses the same model and a shared set of hyperparameters for the different tasks, with each mini-batch sourced from a single task (and hence a single modality). Optimization is handled by Stochastic Gradient Descent (SGD). There are various ways of constructing the sequence of mini-batches: Task-by-Task, Alternating, Accumulating gradients, Uniform sampling, and Weighted sampling. We use U to denote the total number of SGD steps during co-training.

Mini batch sampling methods. Image Credits – PolyViT paper

As we can observe from the above figure, Task-by-Task sampling performs all the SGD steps for one task before moving on to the next, with the task order chosen randomly. Alternating sampling cycles through the tasks in a fixed (deterministic) order, taking one SGD step per task at a time. Uniform sampling draws the task for each mini-batch from a uniform distribution with probability 1/T, resulting in roughly U/T steps per task. Weighted sampling assigns each task a weight proportional to the number of steps used in its single-task baseline. Finally, the Accumulating-gradients procedure is quite different from the rest: perform a forward and backward pass for one mini-batch of each task, accumulate the gradients, and then apply a single parameter update for that entire accumulated batch (spanning multiple tasks).

My intuition is that each parameter update should correspond to a single task; mixing all sorts of tasks and their gradients within one update is likely to destabilize the training quite badly.
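For illustration, here is a hedged sketch of how the weighted sampling schedule could be built; the task names and step counts are placeholders rather than the paper's actual values:

```python
import random

def weighted_task_schedule(steps_per_task, seed=0):
    """Draw a task for each of the U co-training SGD steps, with probability
    proportional to the step count of that task's single-task baseline."""
    rng = random.Random(seed)
    tasks = list(steps_per_task)
    weights = [steps_per_task[t] for t in tasks]
    U = sum(weights)  # total co-training steps: same budget as training the tasks separately
    return rng.choices(tasks, weights=weights, k=U)

# Placeholder step counts; the real values come from each single-task baseline schedule.
single_task_steps = {"imagenet": 10_000, "kinetics400": 20_000, "audioset": 5_000}
for step, task in enumerate(weighted_task_schedule(single_task_steps)):
    pass  # sample a mini-batch from `task`'s dataset and take one SGD step on it
```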

Whichever sampling technique is used, the result is parameter sharing across the n tasks, which is one of the primary objectives of PolyViT. To scale to a large number of modalities, the authors also introduce optional modality-specific layers, L(adapt), applied right after tokenization.
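As a hedged illustration of the L(adapt) idea (the layer counts and helper names are my own), the encoder depth can be split into a few modality-specific layers followed by the shared layers:

```python
import torch.nn as nn

def build_encoder(embed_dim=768, num_heads=12, depth=12, l_adapt=2,
                  modalities=("image", "video", "audio")):
    """Split the encoder into L(adapt) modality-specific layers (applied right
    after tokenization) and (depth - L(adapt)) layers shared by all modalities."""
    def layers(n):
        return nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, num_heads, batch_first=True), n)
    adapt = nn.ModuleDict({m: layers(l_adapt) for m in modalities}) if l_adapt > 0 else None
    shared = layers(depth - l_adapt)
    return adapt, shared
```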

Results & Benchmarks

Now we shall look at the results of PolyViT for each of the task-sampling methods across a variety of datasets: ImageNet-1K, CIFAR-10 and CIFAR-100, Kinetics-400, Moments in Time (MiT), AudioSet, and VGGSound.

PolyViT trained with the various task-sampling methods in a 9-task setup. L(adapt) = 0 means no modality-specific layers are used. Image Credits – PolyViT paper

The results reiterate that gradient updates should be task-specific for robust results: the Task-by-Task schedule, whose task order is randomly sampled, performs very poorly. The weighted schedule was found to be the best and is hence used in all subsequent experiments. One thing to note is that the number of steps per task is derived from its single-task baseline, so co-training requires the same computational resources as training the 9 tasks separately. Fabulous, isn't it?

PolyViT with varying L(adapt) values compared to single-task baselines. Image Credits – PolyViT paper

Looking at the above results, PolyViT co-trained per modality (3 models, 263M parameters in total) outperforms the SOTA benchmark on 7 out of 9 tasks. Co-trained on all 9 tasks, it stays within 0.2–0.3 percent of the SOTA results with a 3x decrease in parameters (93M). PolyViT with L(adapt) = 0 is within about 1% of the single-task baselines with an 8x decrease in parameters (773M to 93M). We can also conclude that single-modality co-training improves accuracy on the smaller datasets within each modality, confirming a regularization effect and an ability to fit smaller datasets that is not seen in standard ViT models.

Similarly, the trained PolyViT was evaluated on downstream tasks by fitting only a linear head on top of the frozen encoder. The multi-modal PolyViT shines here too, with very good performance across 11 tasks. Further, buoyed by the improved performance within a modality, the authors co-trained PolyViT per modality, which surpassed the SOTA benchmarks with a considerable reduction in parameters. Kindly refer to the PolyViT paper for more details on these experiments and the training setup (Appendix section).
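For readers unfamiliar with linear-probe evaluation, here is a minimal sketch, assuming a frozen encoder that maps an input batch to fixed-size feature vectors; the function name and hyperparameters are placeholders, not values from the paper:

```python
import torch
import torch.nn as nn

def linear_probe(encoder, feature_dim, num_classes, loader, epochs=10, lr=1e-3):
    """Freeze the co-trained encoder and train only a linear head on a new task."""
    encoder.eval()
    for p in encoder.parameters():
        p.requires_grad_(False)        # encoder weights stay fixed
    head = nn.Linear(feature_dim, num_classes)
    optimizer = torch.optim.SGD(head.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for x, y in loader:            # (inputs, labels) mini-batches
            with torch.no_grad():
                features = encoder(x)  # e.g. the class-token embedding, shape (B, feature_dim)
            loss = loss_fn(head(features), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return head
```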

My Thoughts

In my opinion, it is an excellent attempt that has delivered SOTA-level results. There are a lot of positives: parameter efficiency, a single model as opposed to n task-specific models, and simplicity of implementation and maintenance. The inference time also stays constant across modalities and tasks.

PolyViT will now enable the deployment of relatively lightweight multi-modal Transformer models on edge devices, as opposed to heavy, task-specific models. Think about the usability of a single model for many tasks on a device. The avenues can surely be expanded and scaled such that we can hopefully have "one model for a majority of ML tasks". The regularizing effect on smaller datasets is also a welcome addition, making it easier for individual practitioners and researchers to mitigate overfitting.

The only small caveat I observed is that the addition of L(adapt) layers runs against the initial hypothesis of a fully parameter-shared model. Since the results with the L(adapt) = 0 setting are already very good, we can give this a pass and consider PolyViT a fully parameter-shared model 😊

Conclusion

PolyViT with co-training in the single-modality setup surpassed the previous SOTA results on 5 standard video and audio classification tasks. Switch to the multi-modal setting and it remains efficient with very little drop in performance (0.5–1%). The fact that it can learn and generalize across multiple tasks and multiple modalities (domains) with near-SOTA performance is itself a huge win. The second biggest takeaway is the parameter efficiency of the multi-modal setup compared to the single-task baselines. And thanks to the robust representations, fine-tuning becomes easier, with little hyperparameter tuning required. There is still scope for experimentation in the form of using larger datasets like ImageNet-21K and expanding to other domains.

More research paper reviews by me

Non-Deep Networks

Expire-Span: Not All Memories are Created Equal explained

MLP-Mixer: An all-MLP Architecture for Vision

References

[1] Transformer: https://arxiv.org/pdf/1706.03762.pdf

[2] Vision Transformer: https://arxiv.org/pdf/2010.11929.pdf

[3] PolyViT: https://arxiv.org/pdf/2111.12993.pdf

[4] Perceiver: https://arxiv.org/pdf/2103.03206.pdf

[5] UniT: https://arxiv.org/pdf/2102.10772.pdf

[6] Joint Video and Image Encoder: https://arxiv.org/pdf/2104.00650.pdf

[7] Video Vision Transformer: https://arxiv.org/pdf/2103.15691.pdf

[8] Audio Spectrogram Transformer: https://arxiv.org/pdf/2104.01778.pdf

