Make the Most of Your Small NER Data Set by Fine-Tuning BERT

Leverage large language models to improve your model’s performance

Youness Mansar
Towards Data Science

Training NER models on small data sets can be a pain. It can be difficult to build a model that generalizes well on a non-trivial NER task when it is trained on only a few hundred samples. This is the problem we will try to solve in this post.

First, we will define the task. Then, we will describe the different variations of the model’s architecture and compare the results of a pre-trained model with those of a model trained from scratch.

The Problem

We will tackle a Named Entity Recognition (NER) problem on a recipe data set, TASTEset (linked at the end of this post). The objective of this task is to extract structured information from the raw text in the form of labeled entities.

Raw Text
Tagged text

Our objective is to assign the correct tag to each entity in the raw text.

The data set we will use is tiny, with around 700 sentences. This is the primary challenge of this task, and it is also where pre-trained models can shine: fine-tuning lets them reuse the representations learned during pre-training for the task at hand.

The Models

Architecture:

The basic architecture of our model is:

Model Architecture (Image by Author)

Tagging scheme:

The BIO labels referenced above encode entity tags on the sequence of tokens generated from the raw text, as illustrated here:

Source: https://x-wei.github.io/notes/xcs224n-lecture3.html

The start token of each {LABEL} entity is tagged as B-{LABEL} and each other token of this same entity is tagged as I-{LABEL}. All other tokens are tagged as O.
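To make the scheme concrete, here is a minimal sketch that converts entity spans into BIO tags. It assumes entities are given as hypothetical (start, end, label) token-index spans with an exclusive end; the tokens and the QUANTITY/UNIT/FOOD labels are illustrative, not copied from the data set’s annotation files.

```python
from typing import List, Tuple


def spans_to_bio(n_tokens: int, entities: List[Tuple[int, int, str]]) -> List[str]:
    """Convert (start, end, label) token spans (end exclusive) into BIO tags.

    Assumes spans do not overlap.
    """
    tags = ["O"] * n_tokens                 # tokens outside any entity
    for start, end, label in entities:
        tags[start] = f"B-{label}"          # first token of the entity
        for i in range(start + 1, end):
            tags[i] = f"I-{label}"          # remaining tokens of the same entity
    return tags


tokens = ["2", "cups", "of", "olive", "oil"]
# Hypothetical annotation for the sentence above.
entities = [(0, 1, "QUANTITY"), (1, 2, "UNIT"), (3, 5, "FOOD")]
tags = spans_to_bio(len(tokens), entities)
print(list(zip(tokens, tags)))
# [('2', 'B-QUANTITY'), ('cups', 'B-UNIT'), ('of', 'O'), ('olive', 'B-FOOD'), ('oil', 'I-FOOD')]
```

Only the first token of each entity gets a B- tag, which is what keeps two adjacent entities of the same type distinguishable.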

Implementation:

We will use the Hugging Face Transformers library for BERT. Since the pre-trained model is an instance of nn.Module, we can plug it in like any other PyTorch layer.

In the model’s __init__, we can do the following:

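import torch
import torch.nn as nn
from transformers import BertConfig, BertModel


# BaseModel, CONFIG and n_classes are defined elsewhere in the project.
class BertNerModel(BaseModel):
    def __init__(
        self,
        # ...
        dropout=0.2,
        bert_path=None,
        # ...
    ):
        super().__init__()

        # ...
        # BERT encoder built from CONFIG (a BertConfig loaded from JSON)
        self.bert = BertModel(config=CONFIG)

        # ...

        # Optionally load pre-trained BERT weights from a saved state dict
        if bert_path:
            state_dict = torch.load(bert_path)
            self.bert.load_state_dict(state_dict)

        self.do = nn.Dropout(p=dropout)

        # Per-token classifier: hidden state -> score for each tag
        self.out_linear = nn.Linear(CONFIG.hidden_size, n_classes)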

Here, CONFIG is an instance of BertConfig initialized from a JSON file, and n_classes is the number of BIO tags.
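As a minimal sketch of where CONFIG and bert_path could come from, the snippet below writes a deliberately tiny, hypothetical configuration to JSON, loads it with BertConfig.from_json_file, saves a BertModel state dict with torch.save, and restores it with load_state_dict, mirroring what the __init__ above does with bert_path. File names and sizes are assumptions; in practice the state dict would come from a pre-trained checkpoint (for example BertModel.from_pretrained("bert-base-uncased").state_dict()) together with that checkpoint’s config.

```python
import json
import os
import tempfile

import torch
from transformers import BertConfig, BertModel

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical tiny config so the example runs quickly on a CPU.
    config_path = os.path.join(tmp, "bert_config.json")
    with open(config_path, "w") as f:
        json.dump({"vocab_size": 1000, "hidden_size": 64, "num_hidden_layers": 2,
                   "num_attention_heads": 4, "intermediate_size": 128}, f)

    CONFIG = BertConfig.from_json_file(config_path)

    # Save encoder weights as a plain state dict (what bert_path would point to).
    source = BertModel(config=CONFIG)
    bert_path = os.path.join(tmp, "bert_state_dict.pt")
    torch.save(source.state_dict(), bert_path)

    # Restore them into a freshly initialized model with the same config.
    target = BertModel(config=CONFIG)
    target.load_state_dict(torch.load(bert_path, map_location="cpu"))

    loaded = target.state_dict()
    same = all(torch.equal(v, loaded[k]) for k, v in source.state_dict().items())
    print(CONFIG.hidden_size, same)  # 64 True
```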

Then, in the forward method, we do:

def forward(self, x):
    # 1 for real tokens, 0 for padding positions
    mask = (x != self.pad_idx).int()
    x = self.bert(
        x, attention_mask=mask, encoder_attention_mask=mask
    ).last_hidden_state
    # [batch, seq_len, CONFIG.hidden_size]

    x = self.do(x)

    # Per-token class scores: [batch, seq_len, n_classes]
    out = self.out_linear(x)

    return out

We first generate a mask for the padding tokens, then feed the input to the BERT model. We take the last hidden layer from BERT’s output, apply dropout, and feed it through the linear classifier layer to produce per-token scores for each class.
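The snippet below is a small, self-contained sketch of that forward pass, assuming a hypothetical tiny BertConfig, a pad index of 0 and five classes so it runs quickly on a CPU. It shows the padding mask and the tensor shapes at each step.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel

# Hypothetical tiny config, pad index and class count (illustration only).
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
pad_idx, n_classes = 0, 5

bert = BertModel(config=config).eval()
dropout = nn.Dropout(p=0.2)
out_linear = nn.Linear(config.hidden_size, n_classes)

# Two sequences; the second one is shorter and padded with pad_idx.
x = torch.tensor([[5, 17, 42, 8],
                  [9, 23, pad_idx, pad_idx]])
mask = (x != pad_idx).int()  # 1 for real tokens, 0 for padding

with torch.no_grad():
    hidden = bert(x, attention_mask=mask).last_hidden_state  # [batch, seq_len, hidden_size]
    scores = out_linear(dropout(hidden))                    # [batch, seq_len, n_classes]

print(mask.tolist())                            # [[1, 1, 1, 1], [1, 1, 0, 0]]
print(tuple(hidden.shape), tuple(scores.shape))  # (2, 4, 32) (2, 4, 5)
```

Padding positions still receive scores, so they have to be excluded from the loss, for example with the ignore_index argument of nn.CrossEntropyLoss; the post does not show how its training loop handles this.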

As a comparison, we will also train a model from scratch using nn.TransformerEncoder. The implementation follows the same logic as the BERT one; you can find it in the repository linked at the end of this post.
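Since the from-scratch model is not shown here, the following is only a sketch of what such a baseline might look like, not the author’s implementation. The class name, the learned positional embeddings and all sizes are assumptions. Note that nn.TransformerEncoder’s src_key_padding_mask uses True for positions to ignore, the opposite convention from BERT’s attention_mask.

```python
import torch
import torch.nn as nn


class ScratchNerModel(nn.Module):
    """Hypothetical baseline: embeddings + nn.TransformerEncoder + linear tag head."""

    def __init__(self, vocab_size, n_classes, d_model=128, n_heads=4,
                 n_layers=2, max_len=512, dropout=0.2, pad_idx=0):
        super().__init__()
        self.pad_idx = pad_idx
        self.tok_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_emb = nn.Embedding(max_len, d_model)  # learned positions (an assumption)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=4 * d_model,
                                           dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.do = nn.Dropout(p=dropout)
        self.out_linear = nn.Linear(d_model, n_classes)

    def forward(self, x):
        # True marks padding positions that attention should ignore.
        padding_mask = x == self.pad_idx
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        h = self.tok_emb(x) + self.pos_emb(positions)            # [batch, seq_len, d_model]
        h = self.encoder(h, src_key_padding_mask=padding_mask)   # [batch, seq_len, d_model]
        return self.out_linear(self.do(h))                       # [batch, seq_len, n_classes]


model = ScratchNerModel(vocab_size=1000, n_classes=5).eval()
x = torch.tensor([[5, 17, 42, 8], [9, 23, 0, 0]])
with torch.no_grad():
    scores = model(x)
print(scores.shape)  # torch.Size([2, 4, 5])
```

Unlike BERT, every weight here starts from random initialization, which is exactly why this baseline has to learn everything from the ~700 training sentences.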

The results

We track the validation loss and token-level accuracy using TensorBoard:

Validation Loss (Orange: BERT, Blue: nn.TransformerEncoder)
Token-level accuracy (Orange: BERT, Blue: nn.TransformerEncoder)
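Token-level accuracy is not defined precisely in the post. One natural definition, sketched here as an assumption, is the fraction of non-padding tokens whose highest-scoring class matches the gold tag; padded positions are marked with -100, the default ignore_index of nn.CrossEntropyLoss.

```python
import torch
import torch.nn.functional as F


def token_accuracy(scores: torch.Tensor, labels: torch.Tensor, pad_label: int = -100) -> float:
    """scores: [batch, seq_len, n_classes]; labels: [batch, seq_len]."""
    preds = scores.argmax(dim=-1)        # predicted class per token
    valid = labels != pad_label          # ignore padding positions
    correct = (preds == labels) & valid
    return correct.sum().item() / valid.sum().item()


# Toy scores whose argmax is [[1, 0, 1], [2, 0, 0]].
scores = F.one_hot(torch.tensor([[1, 0, 1], [2, 0, 0]]), num_classes=3).float()
labels = torch.tensor([[1, 0, 2], [2, -100, -100]])
acc = token_accuracy(scores, labels)
print(acc)  # 0.75 (3 of the 4 real tokens are correct)
```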

The fine-tuned BERT model generalizes better than the one trained from scratch, as the F1 scores below confirm:

F1 score on the validation set:

BERT Initialization: 92.9%

nn.TransformerEncoder: 76.3%
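The post does not say whether these F1 scores are computed per token or per entity. For reference, here is a minimal sketch of entity-level F1 with exact span matching, the variant commonly used for NER evaluation (libraries such as seqeval implement it); the function names and example tags are hypothetical.

```python
from typing import List, Set, Tuple


def bio_to_spans(tags: List[str]) -> Set[Tuple[int, int, str]]:
    """Extract (start, end, label) spans from a BIO sequence; end is exclusive."""
    spans, start, label = set(), None, None
    for i, tag in enumerate(tags + ["O"]):  # sentinel "O" closes a trailing entity
        # Close the open entity unless this tag continues it.
        if start is not None and tag != f"I-{label}":
            spans.add((start, i, label))
            start, label = None, None
        # B- opens an entity; a stray I- with no open entity also opens one.
        if tag.startswith("B-") or (tag.startswith("I-") and start is None):
            start, label = i, tag[2:]
    return spans


def entity_f1(true_seqs: List[List[str]], pred_seqs: List[List[str]]) -> float:
    tp = n_true = n_pred = 0
    for true_tags, pred_tags in zip(true_seqs, pred_seqs):
        t, p = bio_to_spans(true_tags), bio_to_spans(pred_tags)
        tp += len(t & p)  # only exact span + label matches count
        n_true += len(t)
        n_pred += len(p)
    if tp == 0:
        return 0.0
    precision, recall = tp / n_pred, tp / n_true
    return 2 * precision * recall / (precision + recall)


true = [["B-QUANTITY", "B-UNIT", "O", "B-FOOD", "I-FOOD"]]
pred = [["B-QUANTITY", "B-UNIT", "O", "B-FOOD", "O"]]  # FOOD span cut short
f1 = entity_f1(true, pred)
print(round(f1, 3))  # 0.667
```

Under this metric a partially recovered entity ("olive" instead of "olive oil") counts as both a false positive and a false negative, so entity-level F1 is usually stricter than a token-level score.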

Inference speed on an RTX 3080 GPU (batch size of 1):

BERT Initialization: 43 sentences/second

nn.TransformerEncoder: 303 sentences/second
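The post does not describe how throughput was measured. Below is a minimal sketch of one way to time batch-size-1 inference; the toy model, sentence length and warm-up count are assumptions, and the numbers it prints depend entirely on your hardware, so they are not comparable to the figures above.

```python
import time

import torch
import torch.nn as nn


@torch.no_grad()
def sentences_per_second(model: nn.Module, sentences, device: str = "cpu", warmup: int = 5) -> float:
    """Time one forward pass per sentence (batch size 1) and return sentences/second."""
    model.eval().to(device)
    # Warm-up passes so one-off costs (allocation, kernel selection) are not timed.
    for x in sentences[:warmup]:
        model(x.to(device))
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # GPU kernels run asynchronously
    start = time.perf_counter()
    for x in sentences:
        model(x.to(device))
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # wait for all queued work before stopping the clock
    return len(sentences) / (time.perf_counter() - start)


# Hypothetical toy model and 50 random "sentences" of 20 token ids each.
model = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 5))
sentences = [torch.randint(1, 100, (1, 20)) for _ in range(50)]
rate = sentences_per_second(model, sentences)
print(f"{rate:.0f} sentences/second")
```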

Conclusion

Using a pre-trained model gained us about 16.6 F1 points (92.9% vs. 76.3%). This was easy to do thanks to Hugging Face’s Transformers library and PyTorch, but it comes at the cost of slower inference: roughly 7x fewer sentences per second (43 vs. 303). However, we can still use weight quantization or model distillation to alleviate the inference latency problem while preserving most of the F1 performance.
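As one example of the first option, PyTorch’s dynamic quantization stores the weights of nn.Linear layers as int8 and quantizes activations on the fly; the PyTorch tutorials apply it to BERT. The sketch below uses a hypothetical toy model instead of BERT so it runs instantly, and whether it preserves F1 on this task would have to be measured. Dynamic quantization runs on the CPU, so it would not speed up the GPU numbers reported above.

```python
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

# Hypothetical stand-in for an encoder + tag classifier.
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 5)).eval()

# Replace nn.Linear layers with int8 dynamically quantized versions (returns a copy).
quantized = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(8, 64)  # 8 token vectors of size 64
with torch.no_grad():
    ref, out = model(x), quantized(x)

print(type(quantized[0]).__name__, tuple(out.shape))
print("max abs difference:", (ref - out).abs().max().item())  # small, non-zero
```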

You can find the code to reproduce all those results in this repository: https://github.com/CVxTz/ner_playground

Data set: https://github.com/taisti/TASTEset/tree/main/data (MIT License)

Thanks for reading!
