Distillation of Knowledge in Neural Networks

State-of-the-art performance on much Smaller & Faster models

Mukul Malik
Towards Data Science

Distillation of knowledge (in machine learning) is an architecture-agnostic approach that takes the knowledge generalized and consolidated within one neural network and uses it to train another neural network.

Importance

Currently, especially in NLP, very large-scale models are being trained. Many of them can’t even fit on an average person’s hardware. Moreover, due to the law of diminishing returns, a large increase in model size yields only a small increase in accuracy.

These models can barely run on commercial servers, let alone on a smartphone.

Using distillation, one can reduce the size of a model like BERT by 87% and still retain 96% of its performance.

Basically, distillation enables one to:

  • get state-of-the-art accuracy
  • with a fraction of the size
  • with a fraction of the response time
  • with a fraction of the fine-tuning time
  • fit an ensemble on a single GPU
  • run models on CPUs
  • enable federated learning
  • etc.

Problem with Normal Neural Networks

The goal of every learner is to optimize its performance on the training data. This doesn’t necessarily translate into a generalization of the knowledge within the dataset.

Take the MNIST dataset as an example. Let’s pick a sample image of the digit 3.

In the training data, the digit 3 is represented by a corresponding one-hot vector:

0 0 0 1 0 0 0 0 0 0

This vector simply says that the digit in the image is 3, but it says nothing explicitly about the shape of the 3, for example that the shape of 3 is similar to that of 8.

Hence:

the neural network is never explicitly asked to learn a generalized understanding of the training data. The degree of generalization is an implicit ability of the neural network.

As a result, in a normally trained neural network, the information (detected feature) within each neuron is not equally significant with respect to the desired output.

Simply put, normally trained neural networks carry a lot of dead weight in the form of neurons that never learned to generalize the data, which lowers their accuracy on the test data.

Distillation

Distillation enables us to train another neural network using a pre-trained network, without the dead weight of the original neural network.

This lets us compress the network without much loss of accuracy.

Hence, distilled models can reach higher accuracy than normally trained models of the same size.

Note: Knowledge can be distilled from any kind of learner (logistic regression, SVM, neural networks, etc.) into any other kind of learner.

For the simplicity of this blog, though, I’ll refer only to neural networks.

Generalization of Information

Let us take a step back and revisit the goal of a neural network:

predict the output for samples that the network has never seen during training by generalizing the knowledge within the training data.

Take, for example, a discriminative neural network whose objective is to identify the correct class for a given input.

Such a network returns a probability distribution across all classes, even the wrong ones.

The probabilities it assigns to the wrong classes tell us a lot about the network’s ability to generalize over the concepts within the training data.

Measure of Generalization

For a decently trained neural network on MNIST, the following observations would hold for an image of the digit 3:

  • the probability for 3 is significantly greater than the probabilities for 8 and 0
  • the probabilities of 8 and 0 are comparable to each other
  • the probabilities of 8 and 0 are still noticeably higher than those of the other digits

So the neural network identifies the digit in the image as a 3, but it also suggests that the shape of 3 is quite similar to the shapes of 8 and 0 (all are quite curvy).

Process of Generalization of Information

Let’s start with the obvious: train an enormously large neural network (as large as your hardware can support) in the normal manner. We’ll refer to this network as the cumbersome network.

Note: This cumbersome model can very well be an ensemble of multiple normally trained models.

Definitions:

  • soft targets: the network’s probability distribution across all classes
  • hard targets: the one-hot vector labels in the original training data

When the cumbersome model is not a single model but an ensemble of multiple models, we use the arithmetic or geometric mean of their output distributions as the soft targets.
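A minimal sketch of the two averaging options, in plain Python with probability lists; the function names and the two teacher outputs in the usage lines are hypothetical. The geometric mean has to be renormalized, because a class-wise geometric mean of distributions does not sum to 1 in general, and it assumes every probability is strictly positive (which softmax outputs are).

```python
import math

def arithmetic_mean(dists):
    """Class-wise arithmetic mean of several probability distributions."""
    n = len(dists)
    return [sum(d[i] for d in dists) / n for i in range(len(dists[0]))]

def geometric_mean(dists):
    """Class-wise geometric mean, renormalized so the result sums to 1.

    Assumes every probability is > 0 (true for softmax outputs).
    """
    n = len(dists)
    # Average the log-probabilities, exponentiate, then renormalize.
    raw = [math.exp(sum(math.log(d[i]) for d in dists) / n) for i in range(len(dists[0]))]
    total = sum(raw)
    return [r / total for r in raw]

# Hypothetical outputs of two teachers for classes 0, 3 and 8.
teacher_a = [0.2, 0.7, 0.1]
teacher_b = [0.1, 0.6, 0.3]
print(arithmetic_mean([teacher_a, teacher_b]))
print(geometric_mean([teacher_a, teacher_b]))
```

The geometric mean penalizes a class more heavily when any single teacher gives it a low probability, whereas the arithmetic mean lets one confident teacher compensate for another.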

The new model can be trained on the same dataset as the original model or on a different dataset called the ‘transfer set’.

Transfer set: pass the data through the cumbersome model and use its output (a probability distribution) as the target values. The transfer set can consist of the dataset used to train the original model, a new dataset, or both.
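A sketch of assembling a transfer set, assuming a hypothetical `toy_teacher` that stands in for the cumbersome model and maps a scalar input to a probability distribution over three classes. Original labels are not needed here: every input, old or new, is relabeled with the teacher’s output.

```python
import math

def toy_teacher(x):
    """Hypothetical stand-in for the cumbersome model: a softmax over 3 classes."""
    logits = [0.1 * x, 0.7 * x, 0.2 * x]
    m = max(logits)  # shift for numerical stability; does not change the result
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def build_transfer_set(original_inputs, new_inputs, teacher):
    """Pair each input (original, new, or both) with the teacher's output distribution."""
    return [(x, teacher(x)) for x in list(original_inputs) + list(new_inputs)]

transfer_set = build_transfer_set(original_inputs=[1.0, 2.0], new_inputs=[0.5], teacher=toy_teacher)
for x, soft_target in transfer_set:
    print(x, [round(p, 3) for p in soft_target])
```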

Temperature & Entropy

By adjusting the temperature of the soft targets, it is possible to adjust how large this transfer set needs to be.

When soft targets have high entropy, they give much more information per training sample than hard targets (let’s get back to that in a minute).

This leads to:

  • smaller losses, hence smaller correction gradients during backpropagation
  • less variation between the gradients of different training examples

As a result:

  • a greater learning rate can be used to train the model
  • a smaller dataset can be used to train the model
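To make the entropy argument concrete, a small plain-Python sketch (helper names are illustrative) measures the Shannon entropy of a one-hot hard target and of soft targets produced by a temperature-scaled softmax, q_i = exp(z_i / T) / Σ_j exp(z_j / T), for the logits 0.1, 0.7, 0.2 used in the MNIST example later in this post. The hard target has zero entropy, while the soft targets’ entropy grows with temperature toward the maximum of ln 3 for three classes.

```python
import math

def softmax_t(logits, T):
    """Temperature-scaled softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # shift for numerical stability; does not change the result
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return sum(-pi * math.log(pi) for pi in p if pi > 0)

hard_target = [0.0, 1.0, 0.0]   # one-hot label for the digit 3 (classes 0, 3, 8)
logits = [0.1, 0.7, 0.2]

print("hard target:", entropy(hard_target))
for T in (0.5, 1, 2, 5):
    print(f"T = {T}:", round(entropy(softmax_t(logits, T)), 4))
```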

More on Temperature

Something we learned in physics applies here: ‘entropy increases with temperature’.

Let us use an analogy:

Imagine a box of balls neatly stacked on top of each other. If we increase the entropy of the box by shaking it a bit, the balls don’t fall out of the box but rather spread out a bit.

The resulting distribution has a shape similar to the original distribution, but the magnitude of its peaks changes.

However, at higher temperatures, a certain amount of heat added to the system causes a smaller change in entropy than the same amount of heat at a lower temperature.

The following is an example comparing the probability distribution of a system at low temperature with one at high temperature.

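Now, the formula we use to apply temperature when transferring knowledge from one neural network to another is the softmax with temperature:

$$ q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} $$

  • qᵢ : resulting probability of class i
  • zᵢ : logit of class i
  • zⱼ : logits of all classes (summed over j in the denominator)
  • T : temperature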
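The formula translates directly into a few lines of Python. This is a plain-Python sketch (the function name `softmax_with_temperature` is illustrative); subtracting the largest scaled logit before exponentiating avoids overflow without changing the result. The usage lines print the probabilities at several temperatures for the logits 0.1, 0.7, 0.2 of classes 0, 3 and 8 used in the example below.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """q_i = exp(z_i / T) / sum_j exp(z_j / T) for every class i."""
    if T <= 0:
        raise ValueError("temperature T must be positive")
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [0.1, 0.7, 0.2]  # classes 0, 3 and 8
for T in (0.5, 1, 2, 5):
    print(f"T = {T}:", [round(q, 4) for q in softmax_with_temperature(logits, T)])
```

At T = 1 this is the ordinary softmax; larger T flattens the output, smaller T sharpens it.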

Logits: vector of raw (non-normalized) predictions that a classification model generates, which is ordinarily then passed to a normalization function. If the model is solving a multi-class classification problem, logits typically become an input to the softmax function.

Now, imagine the following scenario:

Take an MNIST data point for the digit 3 (yes, I like the number 3), but let us concentrate on only 3 (yes, yes) classes:

  • 0
  • 3
  • 8

Dataset values (one-hot, classes 0, 3, 8): 0, 1, 0

Logits (NN output): 0.1, 0.7, 0.2

Temp 0.5: 0.1805, 0.5991, 0.2204

Temp 1: 0.2546, 0.4640, 0.2814

Temp 2: 0.2940, 0.3969, 0.3091

Temp 5: 0.3177, 0.3582, 0.3241

Higher temperature results in a softer probability distribution over classes.
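As a hand check of the T = 0.5 case, and of why a higher temperature gives a softer distribution, divide the logits by T and apply the softmax formula (values rounded to four decimals). The second line uses the fact that the ratio of any two probabilities depends only on the difference of their logits divided by T.

```latex
q_3 = \frac{e^{0.7/0.5}}{e^{0.1/0.5} + e^{0.7/0.5} + e^{0.2/0.5}}
    = \frac{e^{1.4}}{e^{0.2} + e^{1.4} + e^{0.4}}
    \approx \frac{4.0552}{1.2214 + 4.0552 + 1.4918}
    = \frac{4.0552}{6.7684} \approx 0.5991

\frac{q_i}{q_j} = e^{(z_i - z_j)/T}, \qquad
\frac{q_3}{q_0} = e^{0.6/0.5} \approx 3.320 \;\; (T = 0.5), \qquad
\frac{q_3}{q_0} = e^{0.6/5} \approx 1.127 \;\; (T = 5)
```

Raising T shrinks every ratio toward 1 while preserving the order of the classes, which is exactly the softening described above.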

Now, let’s visualize the smoothness of the probability distribution as we increase the temperature:

Training the Distilled Model

The simplest form of distillation trains a model on the soft targets that a cumbersome model generates at a high temperature, using the same high temperature in the softmax of the distilled model while it is being trained.

After the training, the temperature of the distilled model is set to 1.

By now we have established that the distilled network can be trained on a transfer set consisting of soft targets.

BUT.

We can also leverage the true labels, i.e. the hard targets, when they are known for all or some of the data.

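One of the most effective ways of doing this is to use a weighted average of 2 objective functions:

  • cross-entropy with the soft targets, computed using the same high temperature in the distilled model’s softmax as was used to generate them from the cumbersome model
  • cross-entropy with the hard targets, computed from the same logits of the distilled model but with the temperature set to 1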
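A minimal plain-Python sketch of the two objectives combined as a weighted average for a single training example. The temperature `T = 4.0`, the weight `alpha = 0.9` (which puts the lower weight on the hard-target term) and all function names are illustrative assumptions, not values from this post; the soft term is multiplied by T², the scaling used by Hinton et al. In practice the same computation would be written with a deep-learning framework’s tensor operations so that gradients are computed automatically.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax with a max-shift for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, predicted):
    """-sum_i target_i * log(predicted_i); terms with target_i == 0 are skipped."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)

def distillation_loss(student_logits, teacher_logits, hard_label, T=4.0, alpha=0.9):
    """Weighted average of the soft-target and hard-target objectives (illustrative)."""
    # Objective 1: teacher soft targets at temperature T vs student softmax at the same T.
    soft_loss = cross_entropy(softmax(teacher_logits, T), softmax(student_logits, T))
    # Objective 2: one-hot hard target vs the same student logits at T = 1.
    hard_target = [1.0 if i == hard_label else 0.0 for i in range(len(student_logits))]
    hard_loss = cross_entropy(hard_target, softmax(student_logits, 1.0))
    # Scale the soft term by T^2 so its gradients stay comparable to the hard term.
    return alpha * (T ** 2) * soft_loss + (1 - alpha) * hard_loss

teacher = [0.1, 0.7, 0.2]  # hypothetical teacher logits for classes 0, 3, 8
student = [0.3, 0.4, 0.3]  # hypothetical student logits
print(distillation_loss(student, teacher, hard_label=1))
```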

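Note: The magnitudes of the gradients produced by the soft targets scale as 1/T², whereas the gradients from the hard targets undergo no such scaling. So we multiply the soft-target loss by T², which keeps the relative contributions of the soft and hard targets roughly unchanged when the temperature is changed.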
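Where the 1/T² comes from, following the high-temperature approximation in Hinton et al.’s distillation paper: let C be the cross-entropy between the soft targets p_i (the cumbersome model’s softmax of its logits v_i at temperature T) and the distilled model’s probabilities q_i (its logits z_i at the same temperature), and let N be the number of classes. Assuming T is large compared with the logits and both sets of logits have zero mean:

```latex
\frac{\partial C}{\partial z_i}
  = \frac{1}{T}\left(q_i - p_i\right)
  = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)
  \approx \frac{1}{T}\left(\frac{1 + z_i/T}{N + \sum_j z_j/T} - \frac{1 + v_i/T}{N + \sum_j v_j/T}\right)
  \approx \frac{1}{N T^2}\left(z_i - v_i\right)
```

So the soft-target gradients shrink like 1/T², and multiplying that term by T² restores their scale.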

Code

HuggingFace provides scripts to train your own DistilBERT & DistilRoBERTa. DistilBERT, for example, is 40% smaller and 60% faster than BERT while retaining about 97% of its language-understanding performance.

To get the repository, use:

git clone https://github.com/huggingface/transformers.git

then

cd transformers/examples/distillation

First, we binarize the data, i.e. tokenize it and convert each token to its index in the model’s vocabulary.

python scripts/binarized_data.py \
--file_path data/dump.txt \
--tokenizer_type bert \
--tokenizer_name bert-base-uncased \
--dump_file data/binarized_text
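To illustrate what binarizing means, here is a toy sketch: a hypothetical four-entry vocabulary and plain whitespace splitting stand in for the real bert-base-uncased subword tokenizer that the script above uses. Storing the IDs as unsigned 16-bit integers is an assumption made here for compactness; every index of a 30,522-entry vocabulary fits in 16 bits.

```python
from array import array

def binarize(text, vocab, unk_token="[UNK]"):
    """Lowercase, split on whitespace and map each token to its vocabulary index.

    A toy stand-in for a real subword tokenizer; unknown words map to [UNK].
    """
    ids = [vocab.get(token, vocab[unk_token]) for token in text.lower().split()]
    # 'H' = unsigned short (16-bit on common platforms): enough for a 30,522-entry vocabulary.
    return array("H", ids)

toy_vocab = {"[UNK]": 0, "the": 1, "cat": 2, "sat": 3}  # hypothetical vocabulary
print(list(binarize("The cat sat quietly", toy_vocab)))
```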

HuggingFace’s masked-language-modeling loss follows XLM’s and smooths the probability of masking with a factor that puts more emphasis on rare words. So we first count the occurrences of each token in the data:

python scripts/token_counts.py \
--data_file data/binarized_text.bert-base-uncased.pickle \
--token_counts_dump data/token_counts.bert-base-uncased.pickle \
--vocab_size 30522
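The idea behind this smoothing can be sketched as follows: each token gets a masking weight proportional to count^(-s), so rarer tokens are chosen for masking more often than their raw frequency would suggest. The exponent s = 0.7, the function names and the toy token IDs are assumptions for illustration; this is not HuggingFace’s implementation.

```python
from collections import Counter

def masking_weights(token_ids, smoothing=0.7):
    """Weight each distinct token by count ** -smoothing (rarer tokens weigh more)."""
    counts = Counter(token_ids)
    return {token: count ** -smoothing for token, count in counts.items()}

def position_mask_probs(sequence, weights):
    """Turn per-token weights into a probability of picking each position for masking."""
    raw = [weights[token] for token in sequence]
    total = sum(raw)
    return [w / total for w in raw]

corpus = [7, 7, 7, 7, 3, 3, 9]  # hypothetical token IDs: 7 is common, 9 is rare
weights = masking_weights(corpus)
print(position_mask_probs([7, 3, 9], weights))
```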

Training with distillation is really simple once you have pre-processed the data:

python train.py \
--student_type distilbert \
--student_config training_configs/distilbert-base-uncased.json \
--teacher_type bert \
--teacher_name bert-base-uncased \
--alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --alpha_clm 0.0 --mlm \
--freeze_pos_embs \
--dump_path serialization_dir/my_first_training \
--data_file data/binarized_text.bert-base-uncased.pickle \
--token_counts data/token_counts.bert-base-uncased.pickle \
--force # overwrites the `dump_path` if it already exists.
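The `--alpha_*` flags weight separate loss terms that are added together. In DistilBERT’s training objective these are the distillation cross-entropy against the teacher’s soft targets (`ce`), the masked-language-modeling loss on the true tokens (`mlm`) and a cosine loss that aligns the directions of the student’s and teacher’s hidden states (`cos`); `clm` (causal language modeling) is switched off above with 0.0. A minimal sketch of the weighting, assuming the individual losses have already been computed; the function names and numbers are illustrative, not the script’s code.

```python
import math

def cosine_loss(student_hidden, teacher_hidden):
    """1 - cosine similarity: 0 when both hidden-state vectors point in the same direction."""
    dot = sum(s * t for s, t in zip(student_hidden, teacher_hidden))
    norm_s = math.sqrt(sum(s * s for s in student_hidden))
    norm_t = math.sqrt(sum(t * t for t in teacher_hidden))
    return 1.0 - dot / (norm_s * norm_t)

def total_loss(loss_ce, loss_mlm, loss_cos, loss_clm=0.0,
               alpha_ce=5.0, alpha_mlm=2.0, alpha_cos=1.0, alpha_clm=0.0):
    """Weighted sum of the loss terms; default weights mirror the command-line flags above."""
    return (alpha_ce * loss_ce + alpha_mlm * loss_mlm
            + alpha_cos * loss_cos + alpha_clm * loss_clm)

loss_cos = cosine_loss([0.2, 0.9, -0.1], [0.25, 0.8, 0.0])  # hypothetical hidden states
print(total_loss(loss_ce=1.3, loss_mlm=2.1, loss_cos=loss_cos))
```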

Conclusion

By distilling neural networks, we obtain a smaller model that closely resembles the original model while being lighter and faster to run.

Distilled models are thus an interesting option for putting large-scale neural networks into production.
