Hugging Face Transformers: Fine-tuning DistilBERT for Binary Classification Tasks

A Beginner’s Guide to NLP and Transfer Learning in TF 2.0

Ray William
Towards Data Science


Hugging Face and TensorFlow 2.0. Source

1.0) Introduction

Creating high-performing natural language models is as time-consuming as it is expensive. After all, it took the Google Brain team 3.5 days on 8 Tesla P100 GPUs to train the big version of the original Transformer,¹ and ever since the famous 340-million-parameter BERT-large model arrived in 2018, natural language models have only increased in complexity.

Size of recent natural language models in millions of parameters. Source

But what if a company doesn’t have the resources needed to train such large behemoths? Well, thanks to recent advances in transfer learning (the technique was previously well established in Computer Vision and only recently found applications in NLP), companies can more easily achieve state-of-the-art performance by simply adapting pre-trained models for their own natural language tasks.

In this article, I would like to share a practical example of how to do just that using TensorFlow 2.0 and the excellent Hugging Face Transformers library by walking you through how to fine-tune DistilBERT for sequence classification tasks on your own unique datasets.

And yes, I could have used the Hugging Face API to select a more powerful model such as BERT, RoBERTa, ELECTRA, MPNet, or ALBERT as my starting point. But I chose DistilBERT for this project due to its lighter memory footprint and faster inference speed. With 66 million parameters, DistilBERT is 40% smaller and 60% faster than its older cousin BERT-base, all while retaining more than 95% of BERT’s performance.² This makes DistilBERT an ideal candidate for businesses looking to scale their models in production, even up to more than 1 billion daily requests! And as we will see, DistilBERT can perform quite admirably with the proper fine-tuning. With that out of the way, let’s get started!

2.0) The Data

For this project, I will be classifying whether a comment is toxic or non-toxic using personally modified versions of the Jigsaw Toxic Comment dataset found on Kaggle (I converted the dataset from a multi-label classification problem to a binary classification problem).

Data preview of the modified Jigsaw Toxic Comment dataset. Image by the author.

Following conversion, the dataset exhibits class imbalance, with toxic comments making up 9.58% of all data. This is a problem because any naive model could simply “learn” the class distribution, predict the majority class every time, and still get 90.42% accuracy. So while such a model would seem successful, it would actually be completely ineffective at predicting toxic comments (the minority class), which is not at all what we want!
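To make the accuracy trap concrete, here is a minimal plain-Python sketch that scores an “always non-toxic” classifier on data with 9.58% toxic comments. The 10,000-row dataset size and the function name majority_baseline are illustrative assumptions, not part of the original code:

```python
def majority_baseline(n_total: int, toxic_fraction: float) -> tuple[float, float]:
    """Return (accuracy, toxic-class recall) of always predicting non-toxic (0)."""
    n_toxic = round(n_total * toxic_fraction)
    labels = [1] * n_toxic + [0] * (n_total - n_toxic)
    predictions = [0] * n_total  # never predicts the minority class
    correct = sum(p == y for p, y in zip(predictions, labels))
    true_positives = sum(p == 1 and y == 1 for p, y in zip(predictions, labels))
    accuracy = correct / n_total
    recall = true_positives / n_toxic
    return accuracy, recall


acc, rec = majority_baseline(n_total=10_000, toxic_fraction=0.0958)
print(f"accuracy={acc:.4f}, toxic recall={rec:.2f}")  # accuracy=0.9042, toxic recall=0.00
```

The accuracy matches the 90.42% figure, yet recall on the toxic class is zero, which is why accuracy alone is a misleading metric for this dataset.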

Ways to handle class imbalance. Source

To deal with this, we will implement a combination of undersampling and oversampling to balance out our class distribution.

(Note: Make sure to split your data beforehand and only oversample the training set to ensure your evaluation results remain as unbiased as possible!)

2.1) The ‘Unbalanced’ Dataset

It is important to note, however, that a fine balance must be struck when undersampling the majority class. If we undersample too much, we risk hurting model performance by losing out on valuable training data. But if we undersample too little (or not at all), the model’s predictions might bias towards the majority class and be unable to predict the minority class.

Keeping this in mind, I attempted to find the right balance by undersampling the modified dataset until toxic comments made up ~20% of all training data. This dataset will henceforth be referred to as the unbalanced dataset, and it is the dataset on which I received the best empirical results for this specific problem.

Class distribution of the ‘unbalanced dataset’. Image by the author.
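The advice to split first and then undersample only the training set can be sketched in plain Python. This is an illustration under assumptions, not the author’s code: rows are (text, label) pairs with label 1 = toxic, the 80/20 stratified split and the function name split_then_undersample are hypothetical, and target_minority=0.20 mirrors the ~20% figure above.

```python
import random
from collections import defaultdict


def split_then_undersample(rows, test_frac=0.2, target_minority=0.20, seed=42):
    """Stratified split first, then undersample non-toxic rows in the training split only.

    rows: list of (text, label) pairs, label 1 = toxic (minority), 0 = non-toxic.
    Returns (train_rows, test_rows); the test split keeps the original class ratio.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for text, label in rows:
        by_label[label].append((text, label))

    train_by_label, test_rows = {}, []
    for label, group in by_label.items():
        rng.shuffle(group)
        n_test = int(len(group) * test_frac)  # same fraction taken from each class
        test_rows.extend(group[:n_test])
        train_by_label[label] = group[n_test:]

    toxic = train_by_label.get(1, [])
    clean = train_by_label.get(0, [])
    # Keep just enough non-toxic rows for toxic rows to make up ~target_minority.
    n_clean = round(len(toxic) * (1 - target_minority) / target_minority)
    train_rows = toxic + rng.sample(clean, min(n_clean, len(clean)))
    rng.shuffle(train_rows)
    return train_rows, test_rows


# Toy data: 100 toxic and 900 non-toxic comments (10% toxic).
rows = [(f"comment {i}", 1 if i < 100 else 0) for i in range(1000)]
train_rows, test_rows = split_then_undersample(rows)
toxic_share = sum(label for _, label in train_rows) / len(train_rows)
print(len(train_rows), len(test_rows), round(toxic_share, 2))  # 400 200 0.2
```

Because the test rows are set aside before any resampling, the evaluation data keeps the real ~10% toxic ratio; a validation split would likely be carved out the same way.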

2.2) The ‘Balanced’ Dataset

Along with data science comes the beauty of experimentation, and I thought it could be fruitful to fine-tune DistilBERT on a fully balanced dataset as well. Not only would doing this completely eliminate the imbalanced classification problem, but I also hoped that adding synthetic data to our training set would help our model generalize better to previously unseen data.

To do this, I took the unbalanced dataset’s training set and oversampled the minority class until both classes had approximately 48,000 representative texts, thus creating the balanced dataset.

Various implementations of text augmentation via word replacement. Source

For the oversampling, I performed data augmentation using the nlpaug library via word replacement using BERT contextual embeddings. Generating this data can be a slow process depending on which language model you choose (the library currently supports implementations for DistilBERT, BERT, RoBERTa, and XLNet), but the library offers additional methods for data augmentation that fall into three categories:

  1. Character-level augmentation (can simulate typos in words while taking keyboard distance into account)
  2. Word-level augmentation (can apply back-translation, random insertion or deletion, word splitting, or thesaurus-based synonym replacement)
  3. Sentence-level augmentation (can generate sentences using next-sentence-prediction or abstractive text summarization)

Average performance gain over five text classification tasks for different training set sizes (N). The α parameter roughly means “percent of words in sentence changed by each augmentation.” SR: Synonym Replacement, RI: Random Insertion, RS: Random Swap, RD: Random Deletion. Source

Furthermore, not only does nlpaug allow you to control how you generate new text, but it also allows you to control how much (i.e., what percent) of the provided text should be modified to produce the newly generated text.

Because of this, it might be a bit confusing to know where to start, but in the 2019 paper “EDA: Easy Data Augmentation Techniques for Boosting Performance on Text Classification Tasks,” the authors provide the above figure to be used as a reference for your data augmentation pipeline.³ For my purposes, I chose to generate new sentences by replacing approximately 10% of all words in a given string of text (α = 0.1), but the best choice of hyperparameters may be different for your specific dataset. Try things out, and see how it goes!
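The role of α is easiest to see in a stripped-down version of EDA’s synonym replacement, where n = α·l words are changed in a sentence of l words. This plain-Python sketch is a stand-in, not the author’s pipeline: the author used nlpaug’s BERT-based contextual word substitution (in nlpaug, the aug_p argument of word augmenters such as ContextualWordEmbsAug plays a role similar to α), whereas here a tiny hand-written synonym dictionary (hypothetical) supplies the replacements so the example runs without downloading a model. Forcing at least one change per sentence is also my own choice.

```python
import random


def synonym_replace(sentence, synonyms, alpha=0.1, seed=None):
    """EDA-style synonym replacement that changes about alpha * len(words) words.

    synonyms: dict mapping a lowercase word to a list of replacement candidates.
    Only words with an entry in `synonyms` are eligible, so fewer than
    n words may change when the sentence has few eligible words.
    """
    rng = random.Random(seed)
    words = sentence.split()
    n_changes = max(1, round(alpha * len(words)))  # n = alpha * l, at least one
    eligible = [i for i, word in enumerate(words) if word.lower() in synonyms]
    rng.shuffle(eligible)
    for i in eligible[:n_changes]:
        words[i] = rng.choice(synonyms[words[i].lower()])
    return " ".join(words)


toy_synonyms = {"awful": ["terrible", "dreadful"], "comment": ["remark", "post"]}
print(synonym_replace("what an awful comment to leave on this page", toy_synonyms, alpha=0.1, seed=0))
```

With α = 0.1, a nine-word sentence gets one replacement and a twenty-word sentence gets two, which is the sense in which α is a “percent of words changed.”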

For more information on text augmentation, go ahead and give this article a read: Link.

(Note: Building up a Pandas DataFrame with the .append() method is very inefficient, as you will be recopying the entire DataFrame with each method call, and DataFrame.append() was deprecated in pandas 1.4 and removed in pandas 2.0. Instead, iteratively build up a dictionary containing your data and call the .from_dict() method once to construct your final DataFrame. You can see me do this in lines 47–54 of my augmentation script.)
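A minimal sketch of the dictionary-first pattern, combined with oversampling the toxic class up to a target count. Assumptions: pandas is installed, the column names comment_text and label and the function name oversample_to_dataframe are hypothetical, and augment stands in for whatever augmenter you use (nlpaug or otherwise):

```python
import pandas as pd


def oversample_to_dataframe(toxic_texts, target_count, augment):
    """Grow the toxic class to target_count rows using augment(text) -> str.

    Rows are accumulated in a dict of lists and converted to a DataFrame once,
    instead of appending to a DataFrame inside the loop.
    """
    data = {"comment_text": list(toxic_texts), "label": [1] * len(toxic_texts)}
    i = 0
    while len(data["comment_text"]) < target_count:
        source = toxic_texts[i % len(toxic_texts)]  # cycle through the originals
        data["comment_text"].append(augment(source))
        data["label"].append(1)
        i += 1
    return pd.DataFrame.from_dict(data)


# Toy usage with a trivial "augmenter" that just upper-cases the text.
df = oversample_to_dataframe(["you are rude", "go away"], target_count=5, augment=str.upper)
print(df["comment_text"].tolist())
# ['you are rude', 'go away', 'YOU ARE RUDE', 'GO AWAY', 'YOU ARE RUDE']
```

The resulting frame can then be combined with the non-toxic rows in a single pd.concat call rather than row by row.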

Unfortunately, in my case, training on the balanced dataset actually resulted in poorer performance (likely because, as the EDA figure above suggests, the gains from text augmentation shrink as the training set grows), so all experiments in upcoming sections of this article will be performed on the unbalanced dataset instead.

3.0) Transfer Learning with Hugging Face

Now that we have our datasets in order, it’s time to start building our model! To do so, we will take full advantage of the power of transfer learning by choosing a pre-trained model for our base and adding additional layers on top as it suits our classification task. This is effective because the pre-trained model’s weights contain information representing a high-level understanding of the English language, so we can build on that general knowledge by adding additional layers whose weights will come to represent task-specific understanding of what makes a comment toxic vs non-toxic.

As we will see, the Hugging Face Transformers library makes transfer learning very approachable, as our general workflow can be divided into four main stages:

  1. Tokenizing Text
  2. Defining a Model Architecture
  3. Training Classification Layer Weights
  4. Fine-tuning DistilBERT and Training All Weights

3.1) Tokenizing Text

Once we select a pre-trained model, it’s time to convert human-readable strings of text into a format our model can interpret. This process is known as tokenization, and the intuitive Hugging Face API makes it extremely easy to convert words and sentences → sequences of tokens → sequences of numbers that can be converted into a tensor and fed into our model.

BERT and DistilBERT tokenization process. The special [CLS] token stands for ‘classification’ and will contain an embedding for the sentence-level representation of the sequence. The special [SEP] token stands for ‘separation’ and is used to demarcate boundaries between sequences. Source

In general, different pre-trained models utilize different methods to tokenize textual inputs (in the figure above, see how DistilBERT’s tokenizer includes special tokens such as [CLS] and [SEP] in its tokenization scheme), so it will be necessary to instantiate a tokenizer object that is specific to our chosen model. To get the tokenizer used by distilbert-base-uncased, we pass our model’s name to the .from_pretrained() method of the DistilBertTokenizerFast class.

(Note: Hugging Face provides both ‘slow’ and ‘fast’ versions of its tokenizers. Whereas the ‘slow’ version is written in Python, the ‘fast’ version is written in Rust and provides significant speedups when performing batched tokenization. In this article, we use the ‘fast’ version to take advantage of these performance benefits.)

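Once we instantiate our tokenizer object, we can then go about encoding our training, validation, and test sets in batches using the tokenizer’s .batch_encode_plus() method. (In recent versions of Transformers, this method is deprecated in favor of calling the tokenizer object directly, which accepts the same arguments.)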

Important arguments we may wish to set include:

  • max_length → Controls the maximum number of tokens (not words) in each encoded sequence, including special tokens such as [CLS] and [SEP].
  • padding → If set to ‘longest,’ then pads to the longest sequence in the batch; if set to ‘max_length,’ then pads every sequence to max_length.
  • truncation → If True, then truncates text according to the value set by max_length.
  • return_attention_mask → If True, then returns the attention mask. This is optional, but attention masks tell your model which tokens to pay attention to and which to ignore (in the case of padding). Without a mask, the model also attends to padding tokens, so including the attention mask as an input to your model may improve model performance.
  • return_token_type_ids → If True, then returns the token type IDs. These are needed for some tasks that take multiple sequences as input (e.g. Question Answering takes a ‘question’ and a ‘context’ sequence), as the token type IDs tell the model where one sequence of the input ends and the other begins. For our purposes, however, they are unnecessary because our classification task only requires one sequence as input (the potentially toxic comment), and DistilBERT does not use token type IDs in any case.
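To see what padding, truncation, and the attention mask do to a batch, here is a toy plain-Python imitation. It is not the Hugging Face implementation: it assumes the text is already split into made-up integer token IDs, and only the special-token IDs ([CLS] = 101, [SEP] = 102, [PAD] = 0, as in the uncased BERT vocabulary that distilbert-base-uncased shares) are real.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token IDs in the uncased BERT vocabulary


def toy_encode(batch_token_ids, max_length):
    """Imitate truncation=True, padding='longest', return_attention_mask=True.

    batch_token_ids: list of token-ID lists without special tokens (toy input).
    """
    # Truncate content so [CLS] + content + [SEP] fits within max_length tokens.
    seqs = [[CLS_ID] + ids[: max_length - 2] + [SEP_ID] for ids in batch_token_ids]
    longest = max(len(seq) for seq in seqs)
    input_ids = [seq + [PAD_ID] * (longest - len(seq)) for seq in seqs]
    # 1 = real token the model should attend to, 0 = padding to ignore.
    attention_mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in seqs]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


enc = toy_encode([[11, 12], [21, 22, 23, 24, 25, 26, 27, 28, 29]], max_length=8)
print(enc["input_ids"])       # [[101, 11, 12, 102, 0, 0, 0, 0], [101, 21, 22, 23, 24, 25, 26, 102]]
print(enc["attention_mask"])  # [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]]
```

The second sequence loses its last three content tokens to truncation, while the first is padded with zeros that its attention mask marks as ignorable.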

The batch_encode helper function I wrote around this method ends up returning:

  1. input_ids → The text’s tokens, each mapped to its integer ID in the tokenizer’s vocabulary.
  2. attention_mask → A binary sequence telling the model which numbers in input_ids to pay attention to and which to ignore (in the case of padding).

Both input_ids and attention_mask have been converted into TensorFlow tf.Tensor objects so they can be readily fed into our model as inputs.
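A helper along these lines might look like the following sketch, a hedged reconstruction rather than the author’s exact function. It takes the tokenizer as an argument so it can be checked without downloading a model, and it returns plain lists; the commented lines show how it would plausibly be used with DistilBertTokenizerFast and tf.convert_to_tensor (train_texts is a hypothetical list of strings). One deliberate choice: it pads with padding="max_length" rather than "longest", because chunks padded to different lengths could not be stacked into a single tensor.

```python
def batch_encode(tokenizer, texts, batch_size=256, max_length=128):
    """Tokenize `texts` in chunks and return (input_ids, attention_mask) as lists.

    `tokenizer` is any callable with the Hugging Face tokenizer call signature,
    e.g. DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased").
    """
    input_ids, attention_mask = [], []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        encoded = tokenizer(
            batch,
            max_length=max_length,
            padding="max_length",          # same width for every chunk, so results stack
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,   # DistilBERT does not use token type IDs
        )
        input_ids.extend(encoded["input_ids"])
        attention_mask.extend(encoded["attention_mask"])
    return input_ids, attention_mask


# With TensorFlow and transformers installed, usage might look like (not run here):
#   from transformers import DistilBertTokenizerFast
#   import tensorflow as tf
#   tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
#   ids, mask = batch_encode(tokenizer, train_texts)
#   ids, mask = tf.convert_to_tensor(ids), tf.convert_to_tensor(mask)
```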

3.2) Defining a Model Architecture

Now that we have encoded our training, validation, and test sets, it is time to define our model architecture. Since we will be using DistilBERT as our base model, we begin by loading the pre-trained distilbert-base-uncased checkpoint through the Hugging Face library.

Initialize the Base Model

Importantly, the Hugging Face API gives us the option to tweak the base model architecture by changing several arguments in DistilBERT’s configuration class. Here, we instantiate a new config object that increases dropout and attention_dropout from their defaults of 0.1 to 0.2, but there are many other options to choose from, all of which can be found in the configuration class’s documentation that is specific to your chosen model.

After (optionally) modifying DistilBERT’s configuration class, we can pass both the model name and configuration object to the .from_pretrained() method of the TFDistilBertModel class to instantiate the base DistilBERT model without any specific head on top (as opposed to other classes such as TFDistilBertForSequenceClassification that do have an added classification head). We do not want any task-specific head attached because we simply want the pre-trained weights of the base model to provide a general understanding of the English language, and it will be our job to add our own classification head during the fine-tuning process in order to help the model distinguish between toxic and non-toxic comments.

Because DistilBERT’s pre-trained weights will serve as the basis for our model, we want to preserve them by preventing them from updating during the initial stages of training, while our model learns reasonable weights for the added classification layers. To temporarily freeze DistilBERT’s pre-trained weights, set layer.trainable = False for each of DistilBERT’s layers; we can later unfreeze them by setting layer.trainable = True once model performance converges.

Add a Classification Head

As we build up our model architecture, we will be adding a classification head on top of the embedding that DistilBERT returns as its output. In actuality, the model’s output is a tuple-like object (a TFBaseModelOutput in recent versions of Transformers, which can also be indexed like a tuple) containing:

  1. last_hidden_state → Word-level embedding of shape (batch_size, sequence_length, hidden_size=768).
  2. hidden_states → [Optional] Tuple of tf.Tensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size=768). Returned when we set output_hidden_states=True in the config.
  3. attentions → [Optional] Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads. Returned when we set output_attentions=True in the config.

Rather than accessing other layers’ hidden states, we select index 0 of this tuple to access last_hidden_state, as building off of this embedding usually leads to the best empirical results.⁴

Fine-tuning BERT with different layers on the IMDb movie reviews dataset. Source

Each hidden state is a tf.Tensor of shape (batch_size, sequence_length, hidden_size=768) and contains the word-level embedding output of either the embedding layer or one of DistilBERT’s 6 Transformer layers. Therefore, the last hidden state will be of shape (64, 128, 768) in our case, since we set BATCH_SIZE=64 and MAX_LENGTH=128, and DistilBERT has a hidden size of 768.

Shape of ‘last_hidden_state’ with [CLS] tokens highlighted in red. Source

I should emphasize that all 128 sequence tokens in the embedding provide a word-level understanding, and one may be able to extract a great deal of information from the 3D embedding perhaps with a bi-directional LSTM and a max-pooling layer as was performed in this article.

However, for our purposes, we will instead make use of DistilBERT’s sentence-level understanding of the sequence by only looking at the first of these 128 tokens: the [CLS] token. Standing for “classification,” the [CLS] token plays an important role: in the original BERT, its final hidden state is trained to hold a sentence-level representation used for Next Sentence Prediction (NSP) during pre-training. DistilBERT drops the NSP objective,² but its [CLS] position still yields a sequence-level embedding that fine-tuning can adapt to our task. Thus, we can access this sentence-level embedding by taking a slice of last_hidden_state such that we are left with a 2D tensor that represents the entire sequence of text.

Sentence-level embedding of the [CLS] token. Source
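The slicing step can be checked with NumPy alone. This sketch assumes a random array standing in for last_hidden_state (only its shape matches the real model output); the same [:, 0, :] indexing works on a tf.Tensor.

```python
import numpy as np

BATCH_SIZE, MAX_LENGTH, HIDDEN_SIZE = 64, 128, 768

# Random stand-in for last_hidden_state; only the shape matches the real output.
rng = np.random.default_rng(0)
last_hidden_state = rng.normal(size=(BATCH_SIZE, MAX_LENGTH, HIDDEN_SIZE))

# [CLS] is the first token of every sequence, so keep position 0 on the sequence axis.
cls_embedding = last_hidden_state[:, 0, :]
print(last_hidden_state.shape, "->", cls_embedding.shape)  # (64, 128, 768) -> (64, 768)
```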

To get a baseline for our model’s performance, we can start out by just adding a single, dense output layer with a sigmoid activation function on top of the [CLS] token’s sentence-level embedding. Finally, we compile the model with the Adam optimizer’s learning rate set to 5e-5 (the original BERT paper recommends 5e-5, 3e-5, and 2e-5 for fine-tuning, and the BERT authors’ GitHub repository suggests 3e-4, 1e-4, 5e-5, and 3e-5 for their smaller BERT models) and with the loss function set to focal loss instead of binary cross-entropy in order to better handle the class imbalance of our dataset.

(Note: at the time of writing, tf.keras did NOT provide focal loss as a built-in function, so you had to implement focal loss as your own custom function and pass it in as an argument; newer Keras releases do include a BinaryFocalCrossentropy loss. Please see here to understand how focal loss works and here for an implementation of the focal loss function I used.)
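For a predicted probability p and label y, let p_t = p if y = 1 and 1 − p otherwise, and let α_t = α for toxic examples and 1 − α for non-toxic ones. Binary focal loss is then FL = −α_t (1 − p_t)^γ log(p_t), which shrinks the loss on easy, well-classified examples so training focuses on hard ones. The NumPy version below is only a reference for checking values, not the implementation the author used; a Keras loss would take (y_true, y_pred) tensors and express the same formula with TensorFlow ops such as tf.where, tf.clip_by_value, and tf.math.log. The defaults α = 0.25 and γ = 2 come from the original focal loss paper (Lin et al., 2017) and are not necessarily the settings used in this project.

```python
import numpy as np


def binary_focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-7):
    """Mean binary focal loss, -alpha_t * (1 - p_t)**gamma * log(p_t).

    y_true: 0/1 labels; y_pred: predicted probabilities (e.g. sigmoid outputs).
    """
    y_true = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(y_pred, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    p_t = np.where(y_true == 1, p, 1.0 - p)              # probability of the true class
    alpha_t = np.where(y_true == 1, alpha, 1.0 - alpha)  # per-class weight
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))


y_true = [1, 0, 1, 0]
y_pred = [0.95, 0.05, 0.30, 0.60]  # two easy examples, two hard ones
print(binary_focal_loss(y_true, y_pred))
```

Setting γ = 0 and α = 0.5 recovers half of ordinary binary cross-entropy, which is a handy sanity check for any custom implementation.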

3.3) Training Classification Layer Weights

Ok, we’ve finally built up our model, so we can now begin to train the classification layer’s randomly initialized weights until model performance converges. In the case of a simple baseline model with just a single output layer, training its 769 trainable parameters (768 weights plus one bias, since all of DistilBERT’s weights are frozen) over 6 epochs results in an accuracy of 85.7% and an AUC-ROC score of 0.926 on the test set. Not bad for a model trained with just a few lines of code!
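AUC-ROC measures how well the model ranks toxic comments above non-toxic ones, independent of any decision threshold. A small plain-Python sketch (the function name and toy scores are illustrative) computes it as the probability that a randomly chosen positive gets a higher score than a randomly chosen negative, counting ties as half:

```python
def roc_auc(labels, scores):
    """AUC-ROC via pairwise comparisons: P(score of a positive > score of a negative).

    Ties count as half a win. O(P * N), so only suitable for small examples.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [1, 0, 1, 0]
print(roc_auc(labels, [0.9, 0.85, 0.2, 0.1]))  # 0.75
print(roc_auc(labels, [0.0, 0.0, 0.0, 0.0]))   # 0.5, a constant "always non-toxic" score
```

A constant predictor lands at 0.5 no matter how high its accuracy is on imbalanced data, so the 0.926 score reflects real ranking ability.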

However, we can definitely do better, and one thing we can think about at this stage is changing up our model architecture. After all, our model is pretty simple at this point with just a single output layer on top of DistilBERT, so it might be a good idea to add additional dense and/or dropout layers in between.

To explore this, I performed a grid search over the number of nodes in two added dense layers using the Comet.ml API and found that the optimal model architecture for my specific classification problem looks like:

[DistilBERT CLS Embedding Layer] + [Dense 256] + [Dense 32] + [Single-node Output Layer]

with dropout of 0.2 between each layer.

Grid search for number of nodes in each dense layer. Image by the author.
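The size of each head follows directly from its layer widths: a Dense layer with n_in inputs and n_out units has n_in × n_out weights plus n_out biases, and dropout adds no parameters. A quick arithmetic check in Python (the helper name dense_params is illustrative):

```python
def dense_params(n_in, n_out):
    """Trainable parameters of a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out


CLS_SIZE = 768  # size of DistilBERT's [CLS] embedding

baseline_head = dense_params(CLS_SIZE, 1)
tuned_head = dense_params(CLS_SIZE, 256) + dense_params(256, 32) + dense_params(32, 1)
print(baseline_head, tuned_head)  # 769 205121
```

So the tuned head trains roughly 205 thousand parameters on top of the frozen encoder, still tiny next to DistilBERT’s 66 million.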

As a result of this change, our new model scores an accuracy of 87.3% and an AUC-ROC of 0.930 on the test set by training only the added classification layers.

3.4) Fine-tuning DistilBERT and Training All Weights

Once we finish training the added classification layers, we can squeeze even more performance out of our model by unfreezing DistilBERT’s pre-trained layers and fine-tuning all weights with a lower learning rate (to prevent major updates to the pre-trained weights). Note that it is necessary to recompile our model after unfreezing layer weights, but aside from that, the training procedure looks the same as the previous step.

As a result of fine-tuning DistilBERT’s pre-trained weights, our model achieves a final accuracy of 92.18% and an AUC-ROC of 0.969 on the test set 🥳🔥🎉.

4.0) Conclusion

As Dr. Károly Zsolnai-Fehér from Two Minute Papers might say…

“What a time to be alive!”

As you can see, exciting times are upon us, as anyone with access to a computer can harness the power of state-of-the-art, multi-million dollar pre-trained models with relatively few lines of code. This ease of use is key to rapid development and model deployment, and the compact nature of smaller models such as DistilBERT makes them scalable and able to produce real-time results while still maintaining high levels of performance.

As if DistilBERT wasn’t fast enough, inference speed can be further optimized using weight quantization and model serving using ONNX Runtime, but alas, I’m getting ahead of myself, as this is a topic for a future article.

In any case, I hope the code and explanations in this article are helpful, and I hope you are now as excited about NLP and the Hugging Face Transformers library as I am! If you enjoyed the content, feel free to connect with me on LinkedIn, find me on Kaggle, or check out the code on GitHub.

Good luck out there in your NLP pursuits, and happy coding!

5.0) References

[1] A. Vaswani et al., Attention Is All You Need (2017), arXiv:1706.03762

[2] V. Sanh et al., DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter (2019), arXiv:1910.01108

[3] J. Wei and K. Zou, EDA: Easy Data Augmentation Techniques for Boosting Performance on Text Classification Tasks (2019), arXiv:1901.11196

[4] C. Sun et al., How to Fine-Tune BERT for Text Classification? (2019), arXiv:1905.05583
