Distilling BERT Using an Unlabeled Question-Answering Dataset

How to leverage unlabeled data for a question-answering task using knowledge distillation

Sergey Parakhin
Towards Data Science



The data labeling process is quite complicated, especially for tasks such as Machine Reading Comprehension (Question Answering). In this post, I want to describe one of the techniques we used to adapt a question-answering model to a specific domain with a limited amount of labeled data: Knowledge Distillation. It turned out that we can use it not only to “compress” the model but also to leverage in-domain unlabeled data.

Question Answering and SQuAD 📖

One of the simplest forms of Question Answering systems is Machine Reading Comprehension (MRC), where the task is to find a short answer to a question within the provided document. The most popular benchmark for MRC is the Stanford Question Answering Dataset (SQuAD) [1]. SQuAD 2.0 contains 100,000 question-answer pairs and 53,775 unanswerable questions written for 23,215 paragraphs from popular Wikipedia articles. The answer to every answerable question is a segment of text (a span) from the corresponding reading passage. For unanswerable questions, the system should determine that no answer is supported by the paragraph and abstain from answering.

Examples of answerable and unanswerable questions from SQuAD 2.0
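
To make the data format concrete, here is a minimal sketch (assuming the 🤗 `datasets` library) that loads SQuAD 2.0 and prints one unanswerable and one answerable example; unanswerable questions simply come with an empty `answers` field.

```python
# Minimal sketch: inspect SQuAD 2.0 examples with the 🤗 `datasets` library.
from datasets import load_dataset

squad = load_dataset("squad_v2", split="train")

for example in squad:
    if len(example["answers"]["text"]) == 0:  # unanswerable: empty answers field
        print("Unanswerable:", example["question"])
        break

for example in squad:
    if len(example["answers"]["text"]) > 0:   # answerable: answer text + character offset
        print("Answerable:", example["question"])
        print("Answer span:", example["answers"]["text"][0],
              "at character", example["answers"]["answer_start"][0])
        break
```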

There are other question-answering datasets available as well, such as Natural Questions and MS MARCO, but SQuAD 2.0 is one of the most widely used, and it was the starting point for our project.

Evaluating pre-trained models 👩🏻‍🔬

Recently, we worked on a question answering system for an online store selling photo and video cameras, for which we trained a machine reading comprehension model. In our project, we tested various pre-trained question answering models (thanks to 🤗 Hugging Face) on our small labeled dataset and found that ALBERT-xxlarge [2] trained on SQuAD 2.0 shows promising results on our domain:

Results on our test dataset using the models pre-trained on SQuAD 2.0 (image by author)
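
As a rough sketch of how such an evaluation can start, the 🤗 `pipeline` API lets you run any SQuAD 2.0 checkpoint from the Hub on your own question-document pairs. The checkpoint name and the example below are placeholders, not the exact models or data we used.

```python
# Sketch: try a SQuAD 2.0 checkpoint from the Hub on an in-domain example.
# "some-user/albert-xxlarge-v2-squad2" is a placeholder checkpoint name.
from transformers import pipeline

qa = pipeline("question-answering", model="some-user/albert-xxlarge-v2-squad2")

prediction = qa(
    question="What is the maximum shutter speed?",
    context="The camera offers a maximum mechanical shutter speed of 1/8000 s.",
    handle_impossible_answer=True,  # allow the model to abstain, SQuAD 2.0 style
)
print(prediction)  # {'score': ..., 'start': ..., 'end': ..., 'answer': ...}
```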

But the model is too slow for us to use in production. A commonly used method in such cases is Knowledge Distillation.

Knowledge distillation ⚗️

Knowledge Distillation [3] is usually used as a model compression technique when we want to train a faster or smaller model. In this process, we train a smaller model (the student) using the output probabilities of our primary, larger model (the teacher), so the student learns to “imitate” its teacher’s behavior. A loss based on comparing output distributions is much richer than one computed from hard targets alone. The idea is that by using such soft labels, we can transfer some “dark knowledge” from the teacher model. Additionally, learning from soft labels prevents the model from becoming too sure of its predictions, which is similar to the label smoothing [4] technique.

Knowledge Distillation (image by author)

Knowledge distillation is a beautiful technique that works surprisingly well and is especially useful with Transformers — larger models often show better results, but it’s hard to put such big models into production.
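
As a minimal sketch of the loss this describes (not our exact implementation), the student is trained on a weighted sum of the usual cross-entropy on hard labels and a temperature-scaled divergence against the teacher's distribution:

```python
# Sketch of a standard knowledge-distillation loss: hard labels + soft teacher targets.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    # Soft part: KL divergence between temperature-scaled distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # rescale so gradient magnitudes stay comparable, as in [3]
    # Hard part: regular cross-entropy on ground-truth labels.
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard
```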

Distillation using unlabeled data 🔮

The models distilled or trained on SQuAD don’t show competitive results on our dataset: SQuAD doesn’t overlap with our domain, so such distillation doesn’t work well. It also seems that larger models trained on SQuAD generalize much better to out-of-domain data. To make the distillation process work, we need to use data from our domain. Besides our small labeled dataset, we also had about 30,000 unlabeled question-document pairs (with no highlighted answers), and we wondered how we could use them.

What do we need for distillation? A teacher and a labeled dataset. ALBERT-xxlarge can be the teacher model. We don’t have labels for our 30K examples, but can we simply drop the part of the loss that uses them? Sure, we will inherit more of the teacher’s mistakes when distilling without ground-truth labels. But we don’t have models better than ALBERT-xxlarge at the moment, so even getting similar results with a smaller model would be useful for us. So, we tried to distill the knowledge from ALBERT-xxlarge into ALBERT-base using only the 30K unlabeled examples.

Knowledge Distillation without ground-truth labels (image by author)
Results on our test dataset after distillation on unlabeled in-domain data (image by author)

As you can see, we got an ALBERT-base with F1/EM close to its teacher’s, and we didn’t use any labeled data for training. Of course, this doesn’t mean we don’t need labeled data anymore: the score is still far from ideal, and we also inherited the teacher’s mistakes, so adding labeled data may improve this training procedure.
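
Concretely, the only change compared to the distillation loss sketched above is that the hard cross-entropy term disappears, and the soft term is applied to both the start- and end-position logits of the extractive QA head. A minimal sketch (the function names are ours, not the actual training code):

```python
# Sketch: distillation step for extractive QA with no ground-truth spans.
# Only the soft targets (teacher's start/end distributions) drive the student.
import torch
import torch.nn.functional as F

def soft_qa_loss(student_start, student_end, teacher_start, teacher_end, T=2.0):
    def kd(s_logits, t_logits):
        return F.kl_div(
            F.log_softmax(s_logits / T, dim=-1),
            F.softmax(t_logits / T, dim=-1),
            reduction="batchmean",
        ) * (T * T)
    # Average the start- and end-position losses.
    return 0.5 * (kd(student_start, teacher_start) + kd(student_end, teacher_end))

# Inside the training loop (student and teacher are *ForQuestionAnswering models):
# with torch.no_grad():
#     t_out = teacher(**batch)
# s_out = student(**batch)
# loss = soft_qa_loss(s_out.start_logits, s_out.end_logits,
#                     t_out.start_logits, t_out.end_logits)
```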

Distillation as a pre-training 🏋🏻‍♂️

We can also think of distillation as an additional pre-training step that improves sample efficiency once we do use labeled data. Below, you can see that, in our case, distillation helps the subsequent fine-tuning on labeled data.

Results on our test dataset using the models pre-trained on SQuAD 2.0 and fine-tuned on in-domain labeled data (~1000 labeled examples) (image by author)
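
A rough sketch of that second step, assuming the distilled student was saved as a regular 🤗 checkpoint: we simply continue training it with the standard extractive-QA loss, which the model computes itself when gold start/end positions are provided. The path and the token indices below are illustrative placeholders.

```python
# Sketch: fine-tune the distilled student on a small labeled in-domain dataset.
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

path = "distilled-albert-base"  # hypothetical local path to the distilled student
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForQuestionAnswering.from_pretrained(path)

# One labeled example: the gold answer span is given as token positions.
enc = tokenizer(
    "What is the maximum shutter speed?",
    "The camera offers a maximum shutter speed of 1/8000 s.",
    return_tensors="pt",
)
outputs = model(**enc,
                start_positions=torch.tensor([14]),  # token index of span start (illustrative)
                end_positions=torch.tensor([17]))    # token index of span end (illustrative)
outputs.loss.backward()  # standard extractive-QA cross-entropy on the labeled span
```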

Self-distillation 👯‍♀️

The better the teacher, the better the student you can probably train. So, one of the main directions for improvement is the teacher model itself. Besides using our labeled dataset, there is another exciting technique: self-distillation, where we use the same model architecture for both student and teacher. Self-distillation allows us to train a student model that can perform better than its teacher. That may sound strange, but because the student updates its weights while learning from data the teacher didn’t see, it can end up performing slightly better than a teacher of comparable size on data from that distribution. Our experiments also reproduced this behavior when we applied self-distillation to ALBERT-xxlarge and then used the result as one of our teachers for further distillation into a smaller model.

Ensembles 👯‍♀️️️👯‍♂️

And of course, the distillation training procedure allows us to effectively use an ensemble of teacher models and distill them into a single smaller model. Combining all these approaches (see below) with domain-adaptive language pre-training, we were able to achieve good results using a limited number of labeled examples.

All steps of our training pipeline (image by author)
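
As a side note on the ensemble step: with several teachers, the soft target simply becomes the average of the teachers' distributions, and the rest of the loss stays the same. A minimal sketch in PyTorch (the function names are ours):

```python
# Sketch: soft targets from an ensemble of teachers (average of their distributions).
import torch
import torch.nn.functional as F

def ensemble_soft_targets(teacher_logits_list, T=2.0):
    # Each element: a tensor of start (or end) logits with shape (batch, seq_len).
    probs = [F.softmax(logits / T, dim=-1) for logits in teacher_logits_list]
    return torch.stack(probs).mean(dim=0)  # averaged teacher distribution

def ensemble_kd_loss(student_logits, teacher_logits_list, T=2.0):
    target = ensemble_soft_targets(teacher_logits_list, T)
    return F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                    target, reduction="batchmean") * (T * T)
```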

Wait, is our data really unlabeled? 🤔

It’s worth noting that even though we didn’t label any data for distillation, we still had questions, which is not always the case. Compared to NLP tasks like text classification and NER, a question-answering model’s input consists of both a question and a document. And even though the question is part of the input, writing it can be considered part of the data labeling process (you have to write a question, not just highlight an answer). In this sense, a truly unlabeled QA dataset is one where we only have documents, without questions.

But even though we had questions, we didn’t manually prepare each question for a specific document. We collected these 30,000 QA pairs by matching the questions against a set of independent documents using the pre-trained USE-QA model. This way, we can start with a pre-trained model and approach further improvements in production: after we start collecting real questions asked by users interacting with our system, we can find candidate documents for those questions in the same way and use the resulting dataset for knowledge distillation without labeling many examples.

Using questions collected from a deployed system (image by author)
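
A minimal sketch of this matching step, following the published TensorFlow Hub usage of USE-QA (the module URL below is the multilingual QA variant; the questions and documents are toy placeholders): questions and documents are embedded into a shared space, and pairs are ranked by dot-product similarity.

```python
# Sketch: pair questions with documents using USE-QA embeddings from TF Hub.
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

module = hub.load("https://tfhub.dev/google/universal-sentence-encoder-multilingual-qa/3")

questions = ["What is the maximum shutter speed?"]            # placeholder data
documents = ["The camera offers a shutter speed of 1/8000 s.",
             "The battery lasts for about 400 shots."]

q_emb = module.signatures["question_encoder"](tf.constant(questions))["outputs"]
d_emb = module.signatures["response_encoder"](
    input=tf.constant(documents),
    context=tf.constant(documents),  # here the document itself serves as its context
)["outputs"]

scores = np.inner(q_emb, d_emb)  # dot-product similarity: questions x documents
best_doc = documents[int(np.argmax(scores[0]))]
print(best_doc)
```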

Experimental Setup 🛠

We used the examples from 🤗 Hugging Face Transformers for knowledge distillation and a single NVIDIA 2080 Ti for all of our experiments. Gradient accumulation allowed us to work with larger models such as ALBERT-xxlarge and RoBERTa-large, and mixed precision (fp16) let us train the models faster. To do self-distillation for ALBERT-xxlarge on one 2080 Ti, we first pre-computed the teacher’s predictions (soft labels).
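
A sketch of that precomputation step: run the teacher once in inference mode over the unlabeled set and cache its start/end logits on disk, so the teacher never has to sit in GPU memory while the student trains. The teacher path and the two example pairs are placeholders.

```python
# Sketch: precompute and cache the teacher's start/end logits for the unlabeled set,
# so self-distillation fits on a single GPU (the teacher runs only once, offline).
import torch
from torch.utils.data import DataLoader
from transformers import (AutoModelForQuestionAnswering, AutoTokenizer,
                          default_data_collator)

teacher_path = "path/to/albert-xxlarge-teacher"  # placeholder fine-tuned teacher checkpoint
tokenizer = AutoTokenizer.from_pretrained(teacher_path)
teacher = AutoModelForQuestionAnswering.from_pretrained(teacher_path).cuda().eval()

# Tokenize the unlabeled question-document pairs (two toy examples here).
pairs = [("What is the maximum shutter speed?",
          "The camera offers a shutter speed of 1/8000 s."),
         ("How long does the battery last?",
          "The battery lasts for about 400 shots.")]
features = [tokenizer(q, d, truncation=True, padding="max_length", max_length=384)
            for q, d in pairs]
loader = DataLoader(features, batch_size=8, collate_fn=default_data_collator)

# Run the teacher once and save its logits to disk as the soft labels.
cached = []
with torch.no_grad():
    for batch in loader:
        out = teacher(**{k: v.cuda() for k, v in batch.items()})
        cached.append({"start_logits": out.start_logits.cpu(),
                       "end_logits": out.end_logits.cpu()})
torch.save(cached, "teacher_soft_labels.pt")  # reload these during student training
```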

Conclusion

Knowledge distillation is a handy technique that allows us to “compress” huge Transformer models so we can use them in production. We can also use it as a form of pre-training on unlabeled data, which helps improve sample efficiency when fine-tuning on labeled examples. That can be especially useful for tasks such as machine reading comprehension, where the process of labeling data is quite complicated. Distillation using unlabeled data is not a novel idea, and it has already shown good results in both computer vision (SimCLRv2, Noisy Student) and NLP (Well-Read Students Learn Better).

If you’d like to learn more about our journey of building a question answering system, check out our more detailed post, where we describe how we collected and labeled data, fine-tuned the model, and applied various techniques such as knowledge distillation and pruning.

[1] Pranav Rajpurkar, Robin Jia, Percy Liang, Know What You Don’t Know: Unanswerable Questions for SQuAD (2018), ACL 2018

[2] Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut, ALBERT: A Lite BERT for Self-supervised Learning of Language Representations (2019), ICLR 2020

[3] Geoffrey Hinton, Oriol Vinyals, Jeff Dean, Distilling the Knowledge in a Neural Network (2015), NIPS 2014 Deep Learning Workshop

[4] Rafael Müller, Simon Kornblith, Geoffrey Hinton, When Does Label Smoothing Help? (2019), NeurIPS 2019
