Fine-tuning pre-trained transformer models for sentence entailment

A PyTorch and Hugging Face implementation of fine-tuning BERT on the MultiNLI dataset

Dhruv Verma
Towards Data Science


Image from PNGWING.

In this article, I describe the process of fine-tuning pre-trained models such as BERT and ALBERT on the task of sentence entailment using the MultiNLI dataset (Williams, Nangia, and Bowman, A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference). The models are loaded with the Hugging Face Transformers library and fine-tuned with PyTorch.

What is entailment?

To understand entailment, let’s start with an example.
1. Jim rides a bike to school every morning.
2. Jim can ride a bike.

Entailment is a relationship between two sentences: a premise entails a hypothesis if the truth of the premise guarantees the truth of the hypothesis. In this example, if the sentence ‘Jim rides a bike to school every morning.’ is true, then it follows that Jim goes to school every morning and that Jim knows how to ride a bike. Hence the second sentence, the hypothesis, must be true as well, so the first sentence entails the second.

To define entailment in simple terms, a sentence X entails a sentence Y if, whenever X is true, Y can be logically derived from it. In the dataset I used, each premise-hypothesis pair is labeled as entailment, neutral (the hypothesis may or may not be true given the premise), or contradiction (the hypothesis cannot be true if the premise is true). More on the dataset in the next section.

The MultiNLI dataset

The Multi-Genre Natural Language Inference (MultiNLI) corpus is a dataset designed for developing and evaluating machine learning models for sentence understanding. It has over 433,000 examples and is one of the largest datasets available for natural language inference (a.k.a. recognizing textual entailment). It is also designed so that existing models trained on the Stanford NLI (SNLI) corpus can be evaluated on MultiNLI. You can read more about the dataset in the paper A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference.

Three columns of the dataset were used for training: ‘gold_label’, ‘sentence1’ (the premise), and ‘sentence2’ (the hypothesis). The ‘gold_label’ column holds the label given to the pair of sentences, which is one of ‘entailment’, ‘neutral’, or ‘contradiction’.
The training set had 392,702 samples, and the validation set had 10,000 samples.
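The sketch below reads these three columns and turns the labels into integer ids, using only Python's standard library. It assumes the tab-separated release of the corpus with a header row. The file name in the comment, multinli_1.0_train.txt, is an assumption about your local copy. The label-to-id order is an arbitrary choice. As a defensive step, pairs whose gold_label is not one of the three classes are skipped.

```python
import csv
import io
from typing import Iterable, List, Tuple

# Fixed label order; the integer ids are an arbitrary choice for this sketch.
LABELS = ["entailment", "neutral", "contradiction"]
LABEL2ID = {label: i for i, label in enumerate(LABELS)}


def read_nli_rows(lines: Iterable[str]) -> List[Tuple[str, str, int]]:
    """Keep only gold_label, sentence1 (premise) and sentence2 (hypothesis)."""
    # QUOTE_NONE so that quotation marks inside sentences are kept as plain text.
    reader = csv.DictReader(lines, delimiter="\t", quoting=csv.QUOTE_NONE)
    examples = []
    for row in reader:
        label = row["gold_label"]
        if label not in LABEL2ID:
            # Skip pairs that do not carry one of the three labels.
            continue
        examples.append((row["sentence1"], row["sentence2"], LABEL2ID[label]))
    return examples


# Usage with a tiny in-memory sample; for the real file, pass an open handle,
# e.g. open("multinli_1.0_train.txt", encoding="utf-8") (file name assumed).
sample = io.StringIO(
    "gold_label\tsentence1\tsentence2\n"
    "entailment\tJim rides a bike to school every morning.\tJim can ride a bike.\n"
    "-\tAn unlabeled premise.\tAn unlabeled hypothesis.\n"
)
examples = read_nli_rows(sample)
print(examples)
```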

Training set. Image by the author.
Validation set. Image by the author.

The model — BERT

BERT (Bidirectional Encoder Representations from Transformers) is a language model from Google built on the encoder of the Transformer architecture introduced in the paper Attention Is All You Need (Vaswani et al., 2017). It uses the Transformer's attention mechanism to learn the contextual meaning of words and the relations between them. BERT and its variants, such as ALBERT and RoBERTa, have achieved state-of-the-art results on a range of natural language processing tasks, including question answering and natural language inference.

Unlike an LSTM (Long Short-Term Memory network), which processes a sequence one token at a time in a fixed direction, the Transformer encoder reads the entire sequence of words at once. This allows the model to learn the context of a word from all of its surroundings, on both its left and its right. The encoder takes a sequence of tokens as input. The tokens are first embedded into vectors and then passed through a stack of layers, each combining self-attention with a feed-forward neural network. The output is a sequence of vectors, each corresponding to the input token at the same index.
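A toy version of self-attention makes it easy to see how every output vector can draw on every position in the sequence. The sketch below is a single head of scaled dot-product attention in plain Python. It is deliberately simplified and is an illustration, not BERT's actual layer:

* It has no learned query, key, or value projections; the input vectors play all three roles.
* It uses a single head rather than multiple heads.
* It has no feed-forward sublayer.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x: List[Vector]) -> Tuple[List[Vector], List[Vector]]:
    """Toy scaled dot-product self-attention with Q = K = V = x (no learned weights)."""
    d = len(x[0])
    outputs, weights = [], []
    for query in x:
        # Score this position against every position in the sequence, left and right.
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in x]
        w = softmax(scores)
        weights.append(w)
        # Output at this position = attention-weighted average of all input vectors.
        outputs.append([sum(w[j] * x[j][i] for j in range(len(x))) for i in range(d)])
    return outputs, weights


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three toy "embedded tokens"
out, attn = self_attention(tokens)
print(len(out), [round(sum(row), 6) for row in attn])
```

Each row of attention weights sums to 1 and has a non-zero entry for every position. This is what lets a token's output vector depend on words on both sides of it.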

While I will not elaborate much on BERT's pre-training process, you can read this article for a detailed description of how BERT works and how it is trained.

Fine-tuning BERT

Finally, we come to fine-tuning a pre-trained BERT model using Hugging Face and PyTorch. For this case, I used the “bert-base” model. Because of compute limitations and training time on Google Colab, it was fine-tuned on 100,000 examples sampled from the original training set.
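Drawing the 100,000-example subset can be sketched as a uniform random sample without replacement. The article does not say how the sample was drawn. Uniform sampling, the hypothetical helper name sample_subset, and the seed of 42 are therefore all assumptions. Fixing the seed simply makes the subset reproducible between Colab sessions.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_subset(examples: Sequence[T], n: int, seed: int = 42) -> List[T]:
    """Hypothetical helper: draw n examples uniformly at random without replacement."""
    if n >= len(examples):
        return list(examples)
    rng = random.Random(seed)  # private RNG so the subset is reproducible
    return rng.sample(list(examples), n)


# Stand-in for the 392,702 training pairs; in practice this would be the parsed dataset.
full_train = list(range(392_702))
train_subset = sample_subset(full_train, 100_000)
print(len(train_subset), len(set(train_subset)))
```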

The first step is to create a DataLoader object to feed data to the model. BERT for sequence classification expects its input in a specific format. Each sequence must start with a [CLS] token, and each sentence must end with a [SEP] token. A sequence consisting of two sentences is therefore formatted as [CLS] sentence1 [SEP] sentence2 [SEP].

Each sequence also needs segment ids, which Hugging Face calls token_type_ids. Tokens belonging to the first sentence, including [CLS] and the first [SEP], are marked 0. Tokens of the second sentence, including the final [SEP], are marked 1. Finally, each sequence needs an attention mask that tells the model which positions hold real tokens (1) and which are padding (0).
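The formatting rules above can be sketched in plain Python for a single pair. This is illustrative only. The hypothetical helper build_pair_features works on pre-split lowercase words rather than BERT's WordPiece sub-tokens. It also raises an error instead of truncating long pairs. In practice, calling a Hugging Face BERT tokenizer with two texts produces input_ids, token_type_ids, and attention_mask in this layout for you.

```python
from typing import Dict, List


def build_pair_features(premise_tokens: List[str], hypothesis_tokens: List[str],
                        max_len: int, pad_token: str = "[PAD]") -> Dict[str, List]:
    """Format a pair as [CLS] s1 [SEP] s2 [SEP] with segment ids and an attention mask."""
    tokens = ["[CLS]"] + premise_tokens + ["[SEP]"] + hypothesis_tokens + ["[SEP]"]
    if len(tokens) > max_len:
        raise ValueError("pair is longer than max_len; truncation is not handled here")
    # Segment 0 covers [CLS], the premise and the first [SEP]; segment 1 covers the rest.
    segment_ids = [0] * (len(premise_tokens) + 2) + [1] * (len(hypothesis_tokens) + 1)
    attention_mask = [1] * len(tokens)  # 1 = real token
    padding = max_len - len(tokens)
    tokens += [pad_token] * padding
    segment_ids += [0] * padding
    attention_mask += [0] * padding  # 0 = padding, ignored by attention
    return {"tokens": tokens, "segment_ids": segment_ids, "attention_mask": attention_mask}


features = build_pair_features(["jim", "rides", "a", "bike"], ["jim", "can", "ride"], max_len=12)
for key, value in features.items():
    print(key, value)
```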

Creating the DataLoader object
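Once the pairs are encoded, they can be wrapped in PyTorch DataLoader objects with TensorDataset. Several details here are assumptions for illustration:

* the hypothetical helper make_dataloader;
* the batch size (the article does not state its batch size);
* shuffling only the training loader.

The token ids below are arbitrary placeholders, not real vocabulary ids.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_dataloader(input_ids, token_type_ids, attention_mask, labels,
                    batch_size: int = 16, shuffle: bool = False) -> DataLoader:
    """Hypothetical helper: wrap already-encoded, equal-length sequences in a DataLoader."""
    dataset = TensorDataset(
        torch.tensor(input_ids, dtype=torch.long),
        torch.tensor(token_type_ids, dtype=torch.long),
        torch.tensor(attention_mask, dtype=torch.long),
        torch.tensor(labels, dtype=torch.long),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Toy encoded data: 4 pairs already padded to length 6 (placeholder ids).
ids = [[1, 7, 8, 2, 9, 2]] * 4
segments = [[0, 0, 0, 0, 1, 1]] * 4
mask = [[1, 1, 1, 1, 1, 1]] * 4
labels = [0, 1, 2, 0]

train_loader = make_dataloader(ids, segments, mask, labels, batch_size=2, shuffle=True)
val_loader = make_dataloader(ids, segments, mask, labels, batch_size=2)
batch_ids, batch_segments, batch_mask, batch_labels = next(iter(val_loader))
print(batch_ids.shape, batch_labels)
```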

With the DataLoader objects for the training and validation sets created, the model can be loaded along with its optimizer. For this case, I used the pre-trained BertForSequenceClassification model, which places a classification head on top of the BERT encoder. Its num_labels argument sets the number of output classes. Since there are three classes here, I set num_labels to 3, which makes the final layer a classification head with three output units.

Loading the pre-trained model
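The following sketch loads a model with a three-way classification head. With network access, the article's setup corresponds to BertForSequenceClassification.from_pretrained with num_labels=3. The checkpoint name bert-base-uncased in the comment is an assumption, since the article only says “bert-base”.

So that the sketch runs offline, it instead builds a tiny, randomly initialised BERT from a BertConfig. Those sizes, the AdamW optimizer, and the learning rate are assumptions, not the article's settings.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

# With network access, the article's setting would correspond to something like
#   BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
# (checkpoint name assumed). A tiny, randomly initialised config is used here
# only so the sketch runs offline; its sizes are arbitrary.
config = BertConfig(
    vocab_size=1000,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=3,  # entailment, neutral, contradiction
)
model = BertForSequenceClassification(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)  # optimizer and lr are assumptions

# A fake batch of two encoded pairs, just to check the output shapes.
input_ids = torch.randint(0, 1000, (2, 8))
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([0, 2])

outputs = model(input_ids=input_ids, token_type_ids=token_type_ids,
                attention_mask=attention_mask, labels=labels)
print(model.classifier, outputs.logits.shape, outputs.loss.item())
```

model.classifier is the linear head that num_labels controls. With this tiny config, it maps 32-dimensional pooled outputs to 3 logits per pair.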

Now that the model is loaded, we can move on to the training and validation loops. The model was fine-tuned for 5 epochs.

Training and validation loops
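The training and validation loops can be sketched as one run_epoch function. It trains when given an optimizer and only evaluates otherwise, and it reports average loss and accuracy in the same form as the log below.

To keep the sketch runnable without downloading BERT, it is exercised on ToyPairClassifier. This is a hypothetical stand-in that accepts the same keyword arguments as BertForSequenceClassification and returns an object with a logits attribute. The random data, the learning rate, and the reuse of the training batches for validation are illustrative assumptions only. run_epoch is intended to work unchanged with a BertForSequenceClassification model and DataLoaders that yield (input_ids, token_type_ids, attention_mask, labels) batches.

```python
import types

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, loader, loss_fn, optimizer=None, device="cpu"):
    """One pass over loader; trains if an optimizer is given, otherwise only evaluates."""
    training = optimizer is not None
    model.train(training)  # enables dropout etc. only when training
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for input_ids, token_type_ids, attention_mask, labels in loader:
            input_ids, token_type_ids = input_ids.to(device), token_type_ids.to(device)
            attention_mask, labels = attention_mask.to(device), labels.to(device)
            logits = model(input_ids=input_ids, token_type_ids=token_type_ids,
                           attention_mask=attention_mask).logits
            loss = loss_fn(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * labels.size(0)
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            seen += labels.size(0)
    return total_loss / seen, correct / seen


class ToyPairClassifier(nn.Module):
    """Hypothetical stand-in with the same call signature as BertForSequenceClassification."""

    def __init__(self, vocab_size=50, dim=16, num_labels=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.segment = nn.Embedding(2, dim)
        self.head = nn.Linear(dim, num_labels)

    def forward(self, input_ids, token_type_ids, attention_mask):
        h = self.embed(input_ids) + self.segment(token_type_ids)
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (h * mask).sum(dim=1) / mask.sum(dim=1)  # mean over non-padding tokens
        return types.SimpleNamespace(logits=self.head(pooled))


torch.manual_seed(0)
n, seq_len = 64, 6
data = TensorDataset(torch.randint(1, 50, (n, seq_len)),
                     torch.zeros(n, seq_len, dtype=torch.long),
                     torch.ones(n, seq_len, dtype=torch.long),
                     torch.randint(0, 3, (n,)))
loader = DataLoader(data, batch_size=16, shuffle=True)
model = ToyPairClassifier()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

history = []
for epoch in range(1, 6):
    train_loss, train_acc = run_epoch(model, loader, loss_fn, optimizer)
    val_loss, val_acc = run_epoch(model, loader, loss_fn)  # toy: reuses the training data
    history.append((train_loss, train_acc, val_loss, val_acc))
    print(f"Epoch {epoch}: train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} | "
          f"val_loss: {val_loss:.4f} val_acc: {val_acc:.4f}")
```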

With the training and validation loops defined, the model can be fine-tuned on the MultiNLI dataset. The per-epoch training and validation results were as follows:

Epoch 1: train_loss: 0.5973 train_acc: 0.7530 | val_loss: 0.5398 val_acc: 0.7836 01:47:59.10 
Epoch 2: train_loss: 0.3623 train_acc: 0.8643 | val_loss: 0.5222 val_acc: 0.8072 01:48:18.70
Epoch 3: train_loss: 0.2096 train_acc: 0.9256 | val_loss: 0.6908 val_acc: 0.7939 01:48:11.29
Epoch 4: train_loss: 0.1295 train_acc: 0.9558 | val_loss: 0.7929 val_acc: 0.7891 01:47:59.77
Epoch 5: train_loss: 0.0916 train_acc: 0.9690 | val_loss: 0.8490 val_acc: 0.7906 01:47:52.39

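As the loss and accuracy values above show, the model is learning but starts to overfit after the second epoch:

* Training loss keeps falling (0.5973 to 0.0916), and training accuracy keeps rising (0.7530 to 0.9690).
* Validation loss is lowest after epoch 2 (0.5222) and rises afterwards.
* Validation accuracy peaks at 0.8072 in epoch 2.

Training on more of the 392,702 available examples, instead of the 100,000-example sample, would likely reduce this.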
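The overfitting reading can be checked directly against the logged numbers. This small script parses the five log lines above, dropping the timing column. It then reports where validation loss is lowest, where validation accuracy peaks, and in which epochs validation loss rose.

```python
import re

# The five epoch lines from the log above, without the timing column.
LOG = """\
Epoch 1: train_loss: 0.5973 train_acc: 0.7530 | val_loss: 0.5398 val_acc: 0.7836
Epoch 2: train_loss: 0.3623 train_acc: 0.8643 | val_loss: 0.5222 val_acc: 0.8072
Epoch 3: train_loss: 0.2096 train_acc: 0.9256 | val_loss: 0.6908 val_acc: 0.7939
Epoch 4: train_loss: 0.1295 train_acc: 0.9558 | val_loss: 0.7929 val_acc: 0.7891
Epoch 5: train_loss: 0.0916 train_acc: 0.9690 | val_loss: 0.8490 val_acc: 0.7906
"""

PATTERN = re.compile(
    r"Epoch (\d+): train_loss: ([\d.]+) train_acc: ([\d.]+) \| "
    r"val_loss: ([\d.]+) val_acc: ([\d.]+)"
)

# Each row: (epoch, train_loss, train_acc, val_loss, val_acc)
rows = [(int(m[1]), *map(float, m.groups()[1:])) for m in PATTERN.finditer(LOG)]

best_loss_epoch = min(rows, key=lambda r: r[3])[0]
best_acc_epoch = max(rows, key=lambda r: r[4])[0]
rising_val_loss = [cur[0] for prev, cur in zip(rows, rows[1:]) if cur[3] > prev[3]]
gaps = {r[0]: round(r[2] - r[4], 4) for r in rows}  # train_acc - val_acc per epoch

print("lowest val_loss at epoch", best_loss_epoch)   # 2
print("highest val_acc at epoch", best_acc_epoch)    # 2
print("val_loss rose in epochs", rising_val_loss)    # [3, 4, 5]
print("train_acc - val_acc per epoch", gaps)
```

Both the lowest validation loss and the highest validation accuracy fall in epoch 2. After that, validation loss rises in every epoch while the gap between training and validation accuracy widens. A common remedy, not one the article reports using, is early stopping: keep the checkpoint with the lowest validation loss.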

Thank you for reading this article! The entire code for this project, along with other model benchmarks, can be found at https://github.com/dh1105/Sentence-Entailment.


A computer science graduate interested in the applications of deep learning in NLP and vision.