Knowledge Distillation — A Survey Through Time

In this blog post, you will review Knowledge Distillation (KD) and six follow-up papers.

Nishant Nikhil
Towards Data Science


History

In 2012, AlexNet outperformed all existing models on the ImageNet data, and neural networks were about to see major adoption. By 2015, many state-of-the-art results had been broken, and the trend was to apply neural networks to any use case you could find. The success of VGG Net further affirmed the use of deeper models, or ensembles of models, to get a performance boost.

(An ensemble of models is just a fancy term for combining the outputs of multiple models, for example by averaging their predicted probabilities or by taking a majority vote. If there are three models and two of them predict ‘A’ while one predicts ‘B’, the majority-vote prediction is ‘A’ (two votes versus one).)
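A minimal sketch of the two combination rules mentioned above, in Python with NumPy. The function names are illustrative and do not come from any library.

```python
from collections import Counter

import numpy as np


def average_ensemble(probs_per_model):
    """Average the class-probability vectors of several models.

    probs_per_model: array-like of shape (num_models, num_classes).
    """
    return np.mean(np.asarray(probs_per_model, dtype=float), axis=0)


def majority_vote(labels_per_model):
    """Return the label predicted by the largest number of models."""
    return Counter(labels_per_model).most_common(1)[0][0]


# Three models: two vote 'A', one votes 'B'.
print(majority_vote(["A", "A", "B"]))  # A

# Soft averaging of per-model probabilities over the classes [A, B].
probs = [[0.9, 0.1], [0.6, 0.4], [0.3, 0.7]]
print(average_ensemble(probs))  # approximately [0.6, 0.4]
```

Either way, every member model still has to run at inference time, which is where the cost comes from.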

But these deeper models and ensembles are too costly to run during inference. (An ensemble of 3 models uses 3x the computation of a single model.)

Ideation

Geoffrey Hinton, Oriol Vinyals and Jeff Dean came up with a strategy to train shallow models guided by these pre-trained ensembles (Distilling the Knowledge in a Neural Network, https://arxiv.org/abs/1503.02531). They called this knowledge distillation because knowledge is distilled from a pre-trained model into a new model. Because this resembles a teacher guiding a student, it is also called teacher-student learning.

(Image from https://nervanasystems.github.io/distiller/knowledge_distillation.html)

In Knowledge Distillation, they used the output probabilities of the pre-trained model, softened with a temperature, as soft labels for the new shallow model (alongside the usual hard labels). The rest of this post walks through improvements to this technique.
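The soft-label idea can be sketched as a loss function. This is a minimal NumPy sketch, not the paper's reference code: `temperature=4.0` and `alpha=0.9` are illustrative defaults, and it uses a KL divergence, which differs from the paper's cross-entropy against soft targets only by the teacher's entropy (a constant with respect to the student).

```python
import numpy as np


def log_softmax(logits, temperature=1.0):
    """Row-wise log-softmax of logits / temperature, computed stably."""
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def kd_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.9):
    """Soft-target distillation term plus hard-label cross-entropy.

    student_logits, teacher_logits: (batch, num_classes); labels: (batch,) int ids.
    temperature and alpha are illustrative, not values taken from the paper.
    """
    log_p_t = log_softmax(teacher_logits, temperature)  # softened teacher (log-probs)
    log_q_t = log_softmax(student_logits, temperature)  # softened student (log-probs)
    # KL(teacher || student) at temperature T, averaged over the batch.
    kl = np.sum(np.exp(log_p_t) * (log_p_t - log_q_t), axis=-1).mean()
    # Ordinary cross-entropy with the ground-truth labels at T = 1.
    log_q = log_softmax(student_logits, 1.0)
    ce = -log_q[np.arange(len(labels)), labels].mean()
    # Multiplying by T^2 keeps the soft term's gradients comparable across temperatures.
    return alpha * temperature ** 2 * kl + (1.0 - alpha) * ce


student = np.array([[2.0, 0.5, -1.0]])
teacher = np.array([[3.0, 1.0, -2.0]])
print(kd_loss(student, teacher, np.array([0])))
```

A higher temperature flattens the teacher's distribution, so the student also learns how the teacher ranks the wrong classes, not just which class wins.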

Fitnets

In 2015 came FitNets: Hints for Thin Deep Nets (published at ICLR’15)

FitNets add an additional term to the KD loss. They take the representation from the middle of both networks and add a mean squared error loss between the feature representations at these points.

The trained network provides a learned intermediate representation that the new network mimics. These representations, called hints, help the student learn efficiently.
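A minimal sketch of a hint loss, assuming single-example feature maps of shape (channels, height, width) with matching spatial size. Because the student is thinner, its hint has fewer channels than the teacher's, so the FitNets paper puts a small regressor on the student side; here a hypothetical 1x1 convolution (`regressor_1x1`) plays that role, with random weights standing in for learned ones.

```python
import numpy as np


def regressor_1x1(student_feat, weight):
    """Hypothetical regressor: a 1x1 convolution mapping (C_s, H, W) -> (C_t, H, W).

    weight has shape (C_t, C_s); a 1x1 conv is a matrix multiply over channels at each pixel.
    """
    return np.einsum("ts,shw->thw", weight, student_feat)


def hint_loss(student_feat, teacher_feat, weight):
    """Mean squared error between the regressed student features and the teacher's hint."""
    diff = regressor_1x1(student_feat, weight) - teacher_feat
    return np.mean(diff ** 2)


rng = np.random.default_rng(0)
student_mid = rng.standard_normal((16, 8, 8))  # thin student: 16 channels
teacher_mid = rng.standard_normal((64, 8, 8))  # wide teacher: 64 channels
w = 0.1 * rng.standard_normal((64, 16))        # learned in practice; random here
print(hint_loss(student_mid, teacher_mid, w))
```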

FitNets can compress the model while maintaining almost the same performance.

Looking back, using a single point for giving hints is suboptimal, and many subsequent papers try to improve these hints.

Paying more attention to attention

Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer was published at ICLR 2017.

Image from the paper

They have a similar motivation to FitNets, but rather than the raw representation from one point in the network, they use attention maps as the hints (an L2 loss between the normalized attention maps of the student and the teacher). They also give hints at multiple points in the network, rather than at the single hint point used in FitNets.
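A minimal sketch of activation-based attention transfer in NumPy, assuming single-example (C, H, W) feature maps and the p = 2 choice for collapsing channels. Because the attention map only has spatial dimensions, the student and teacher may have different channel counts at each hint point, as long as the spatial sizes match. The loss weight used in training is left out.

```python
import numpy as np


def attention_map(feat, p=2):
    """Spatial attention map of a (C, H, W) feature map.

    Sums |activation|^p over channels, flattens to length H*W, and L2-normalizes,
    so maps from layers with different channel counts become comparable.
    """
    am = np.sum(np.abs(feat) ** p, axis=0).reshape(-1)
    return am / (np.linalg.norm(am) + 1e-12)


def attention_transfer_loss(student_feats, teacher_feats):
    """Sum over hint points of the L2 distance between normalized attention maps."""
    return sum(
        np.linalg.norm(attention_map(s) - attention_map(t))
        for s, t in zip(student_feats, teacher_feats)
    )


rng = np.random.default_rng(0)
# Three hint points: channel counts differ, spatial sizes match pairwise.
student = [rng.standard_normal((c, hw, hw)) for c, hw in [(16, 32), (32, 16), (64, 8)]]
teacher = [rng.standard_normal((c, hw, hw)) for c, hw in [(64, 32), (128, 16), (256, 8)]]
print(attention_transfer_loss(student, teacher))
```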

A Gift from Knowledge Distillation

In the same year, A Gift from Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning was published at CVPR 2017.

This is also similar to FitNets and the attention transfer paper, but instead of representations or attention maps, it gives hints using Gram matrices.

They have an analogy for this in the paper:

“In the case of people, the teacher explains the solution process for a problem, and the student learns the flow of the solution procedure. The student DNN does not necessarily have to learn the intermediate output when the specific question is input but can learn the solution method when a specific type of question is encountered. In this manner, we believe that demonstrating the solution process for the problem provides better generalization than teaching the intermediate result.”

Image from the paper

To measure this “flow of solution procedure”, they use a Gram matrix between the feature maps of two layers (the paper calls it the FSP matrix). So instead of using an intermediate feature representation as the hint, as FitNets does, this method uses the Gram matrix between the feature representations of two layers as the hint.
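A minimal NumPy sketch of the FSP matrix and a matching loss, assuming single-example (C, H, W) feature maps where both layers in a pair share the same spatial size. Entry (i, j) is the spatial average of channel i of the first layer times channel j of the second. Comparing student and teacher matrices directly also assumes they have the same channel counts at the chosen layers; the per-pair weights are illustrative.

```python
import numpy as np


def fsp_matrix(feat_a, feat_b):
    """FSP (Gram-style) matrix between two layers' feature maps.

    feat_a: (C_a, H, W), feat_b: (C_b, H, W) with the same H and W.
    Returns (C_a, C_b): entry (i, j) is the spatial mean of feat_a[i] * feat_b[j].
    """
    c_a, h, w = feat_a.shape
    a = feat_a.reshape(c_a, h * w)
    b = feat_b.reshape(feat_b.shape[0], h * w)
    return a @ b.T / (h * w)


def fsp_loss(student_pairs, teacher_pairs, weights=None):
    """Weighted sum of squared Frobenius distances between FSP matrices."""
    if weights is None:
        weights = [1.0] * len(student_pairs)
    total = 0.0
    for lam, (s_a, s_b), (t_a, t_b) in zip(weights, student_pairs, teacher_pairs):
        diff = fsp_matrix(s_a, s_b) - fsp_matrix(t_a, t_b)
        total += lam * np.sum(diff ** 2)
    return total


rng = np.random.default_rng(0)
student_pair = (rng.standard_normal((16, 8, 8)), rng.standard_normal((32, 8, 8)))
teacher_pair = (rng.standard_normal((16, 8, 8)), rng.standard_normal((32, 8, 8)))
print(fsp_matrix(*student_pair).shape)  # (16, 32)
print(fsp_loss([student_pair], [teacher_pair]))
```

Because the matrix relates two layers, it describes how features change from one layer to the next rather than what a single layer outputs.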

Paraphrasing Complex Network

Then in 2018 came Paraphrasing Complex Network: Network Compression via Factor Transfer, published at NeurIPS 2018.

Image from the paper

They add another module to the teacher, which they call a paraphraser. It is basically an auto-encoder that doesn’t reduce dimensions. It forks out from the teacher’s last layer and is trained on a reconstruction loss.

The student gets a module of its own, named a translator. It maps the outputs of the student’s last layer to the dimensions of the teacher’s paraphrased representation. This latent paraphrased representation from the teacher (the “factor”) is then used as the hint.

tl;dr The student should be able to reproduce the teacher network’s auto-encoded representation of the input.
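A simplified NumPy sketch of the two losses involved. In the paper, the paraphraser and translator are small stacks of convolutional layers; here each is replaced by a single hypothetical 1x1 linear map (`conv1x1`), and the feature maps are single examples of shape (C, H, W) with matching spatial size. The L1 distance between L2-normalized factors reflects my reading of the paper and should be treated as an assumption.

```python
import numpy as np


def conv1x1(feat, weight):
    """Hypothetical stand-in for a conv module: a 1x1 conv with weight (C_out, C_in)."""
    return np.einsum("oc,chw->ohw", weight, feat)


def paraphraser_reconstruction_loss(teacher_feat, enc_w, dec_w):
    """Trains the paraphraser (an auto-encoder) to reconstruct the teacher's features."""
    factor = conv1x1(teacher_feat, enc_w)
    recon = conv1x1(factor, dec_w)
    return np.mean((recon - teacher_feat) ** 2)


def normalize(x):
    """Flatten and L2-normalize."""
    v = x.reshape(-1)
    return v / (np.linalg.norm(v) + 1e-12)


def factor_transfer_loss(teacher_feat, student_feat, enc_w, trans_w):
    """L1 distance between the normalized teacher factor and the translated student factor."""
    teacher_factor = conv1x1(teacher_feat, enc_w)    # paraphraser encoder output
    student_factor = conv1x1(student_feat, trans_w)  # translator output, same shape
    return np.sum(np.abs(normalize(teacher_factor) - normalize(student_factor)))


rng = np.random.default_rng(0)
teacher_last = rng.standard_normal((64, 8, 8))
student_last = rng.standard_normal((16, 8, 8))
enc_w = rng.standard_normal((32, 64))    # paraphraser encoder: 64 -> 32 channels
dec_w = rng.standard_normal((64, 32))    # paraphraser decoder: 32 -> 64 channels
trans_w = rng.standard_normal((32, 16))  # translator: 16 -> 32 channels
print(paraphraser_reconstruction_loss(teacher_last, enc_w, dec_w))
print(factor_transfer_loss(teacher_last, student_last, enc_w, trans_w))
```

In this sketch the paraphraser would be trained first with the reconstruction loss and then kept fixed, while the student and its translator are trained to match the teacher's factor.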

A Comprehensive Overhaul of Feature Distillation

In 2019, A Comprehensive Overhaul of Feature Distillation was published at ICCV 2019.

Image from the paper

They argue that the position from which hints are usually taken isn’t optimal. Hints are typically taken after the ReLU, and some information is lost in that transformation because ReLU zeroes out negative responses. They instead use the teacher’s pre-ReLU features and propose a margin ReLU activation (a shifted ReLU). “In our margin ReLU, the positive (beneficial) information is used without any transformation while the negative (adverse) information is suppressed. As a result, the proposed method can perform distillation without missing the beneficial information”

They also employ a partial L2 distance function, designed to skip the distillation of information in the negative region. (There is no loss at a position where the teacher’s target is non-positive and the student’s response is already below it.)
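A minimal NumPy sketch of the margin ReLU and the partial L2 distance, for batches of shape (N, C, H, W). The margin for each channel is the expected value of the teacher's negative pre-ReLU responses; the paper derives it from batch-norm statistics, whereas `channel_margins` estimates it empirically from a batch, which is an assumption. The student's features are assumed to already be mapped to the teacher's shape.

```python
import numpy as np


def channel_margins(teacher_feats):
    """Per-channel margin: mean of the negative pre-ReLU teacher responses.

    teacher_feats: (N, C, H, W). Channels with no negative values get margin 0.
    """
    c = teacher_feats.shape[1]
    x = np.moveaxis(teacher_feats, 1, 0).reshape(c, -1)  # (C, N*H*W)
    neg_sum = np.where(x < 0, x, 0.0).sum(axis=1)
    neg_count = np.maximum((x < 0).sum(axis=1), 1)
    return neg_sum / neg_count


def margin_relu(x, margins):
    """max(x, m): positives pass unchanged, values below the margin are clipped to it."""
    return np.maximum(x, margins[None, :, None, None])


def partial_l2(teacher_target, student_feat):
    """Squared error, skipped where the student is below a non-positive teacher target."""
    skip = (student_feat <= teacher_target) & (teacher_target <= 0)
    sq = (teacher_target - student_feat) ** 2
    return np.sum(np.where(skip, 0.0, sq))


rng = np.random.default_rng(0)
teacher_pre = rng.standard_normal((8, 4, 5, 5))  # pre-ReLU teacher features
student_pre = rng.standard_normal((8, 4, 5, 5))  # student features, already shape-matched
target = margin_relu(teacher_pre, channel_margins(teacher_pre))
print(partial_l2(target, student_pre))
```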

Contrastive Representation Distillation

Contrastive Representation Distillation was published at ICLR 2020. Here too, the student learns from the teacher’s representations (the features just before the final classification layer), but instead of an MSE loss it uses a contrastive loss over them.
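To illustrate the contrastive idea, here is a simplified in-batch InfoNCE loss in NumPy. This is a swapped-in simplification, not CRD itself: CRD uses an NCE objective with a large memory buffer of negatives and learned projection heads, while this sketch uses the other samples in the batch as negatives and assumes both embeddings were already projected to the same dimension. `temperature=0.1` is illustrative.

```python
import numpy as np


def l2_normalize(x):
    """Normalize each row to unit length."""
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)


def info_nce_distill_loss(student_emb, teacher_emb, temperature=0.1):
    """Student embedding i should match teacher embedding i (positive) more than the
    teacher embeddings of the other samples in the batch (negatives).

    student_emb, teacher_emb: (batch, dim).
    """
    s = l2_normalize(student_emb)
    t = l2_normalize(teacher_emb)
    logits = s @ t.T / temperature               # (batch, batch) similarity scores
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))           # positives sit on the diagonal


teacher = np.eye(4)
aligned = info_nce_distill_loss(teacher, teacher)
shuffled = info_nce_distill_loss(np.roll(teacher, 1, axis=0), teacher)
print(aligned, shuffled)  # the aligned loss is much smaller
```

Unlike MSE, this objective also pushes each student embedding away from the teacher embeddings of other inputs.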

In summary, these papers have used different methods to:

  1. Increase the amount of information transferred during distillation
    (feature representations, Gram matrices, attention maps, paraphrased representations, pre-ReLU features).
  2. Make distillation more effective by tweaking the loss function
    (contrastive loss, partial L2 distance).

Another interesting way to look at these ideas: new ideas are vector sums of old ideas.

  1. Gram Matrices for KD = Neural Style Transfer + KD
  2. Attention Maps for KD = Attention Is All You Need + KD
  3. Paraphrased representations for KD = Autoencoder + KD
  4. Contrastive Representation Distillation = InfoNCE + KD

What other vector sums could there be?

  1. GANs for KD (that is, replace the contrastive loss with a GAN loss between feature representations)
  2. Weak-supervision KD (Self-Training with Noisy Student Improves ImageNet Classification)

This blog post is inspired by the tweet-storm on Knowledge Distillation (https://twitter.com/nishantiam/status/1295076936469762048).
