A hands-on tutorial

Visualize BERT sequence embeddings: An unconventional way

Exploring an unconventional way of visualizing the sequence embeddings generated across BERT’s encoder layers. A Python notebook with the complete code is included.

Tanmay Garg
Towards Data Science
6 min read · Jan 1, 2021


About

Transformer-encoder based language models like BERT [1] have taken the NLP community by storm, with both the research and applied communities relying heavily on these architectures to solve their tasks. They have become ubiquitous by delivering state-of-the-art results on a wide range of language tasks like text classification, next-sentence prediction, etc.

The BERT-base architecture stacks 12 encoder layers, which brings it up to a whopping ~110 million tuneable parameters! The BERT-large architecture takes it yet a notch higher with 24 encoder layers and ~340 million parameters! 🤯

Illustration from http://jalammar.github.io/illustrated-bert/

Why so many layers?

In the forward pass for an input text sequence, the output from each of these encoder blocks can be seen as a sequence of contextualised embeddings. Each contextualised embedding sequence is then fed as an input to the next layer.

This repetitive application of encoder layers enables:

  • extraction of different features of the input as it proceeds through the layers
  • each successive layer building on the patterns highlighted by the previous layer
Illustration from http://jalammar.github.io/illustrated-bert/

It is well known that with such deep architectures, it is very easy to fall into the pit of overfitting.

Since each encoder layer offers its own embeddings for the input, an interesting question naturally arises:
“As I train my model, how effectively do each layer’s embeddings generalize to unseen data?”
In simpler words, we want to see to what extent each layer of BERT is able to find patterns that still hold on unseen data.

In this tutorial we will walk through a really cool way to visualize how effective each layer is at finding patterns for a classification task.

Hands on💪

With the motivation set, let’s look at what we will be doing.

Aim

Train a BERT model for multiple epochs, and visualize how well each layer separates out the data over these epochs. We will train BERT for a sequence classification task (using the BertForSequenceClassification class). The same exercise can be extended to other tasks, such as language modeling (using the BertForMaskedLM class), with some tweaks to the implementation details. Re-train a language model on your own dataset, and inspect the characteristics of each cluster or the distribution of embeddings!

Resources/Libraries used

Training:

  1. 🤗Transformers: the BertForSequenceClassification model, but you can also plug in other transformer-encoder classifier architectures like RobertaForSequenceClassification, DistilBertForSequenceClassification, etc.
  2. PyTorch

Visualization:

  1. Seaborn
  2. Matplotlib

Dataset:

HatEval [2], a dataset with tweets labeled as Hateful/Neutral.

However, feel free to load your own dataset in the companion Jupyter notebook.

The companion Jupyter notebook

I am placing the complete code for data loading, model training and embedding visualization in this notebook. The code present in this tutorial is intended only for explanation purposes; please refer to the notebook for the complete working code.

⚠️ What we won’t be covering

Since this article focuses only on the visualization of layer embeddings, we will walk through only the relevant parts of the code. The rest of the code lies outside the scope of this tutorial.
I assume prior knowledge of the basic 🤗Transformers BERT workflow (data preparation, training/eval loops, etc.).

Let’s get started

Extract Hidden States of each BERT encoder layer:

  1. 🤗Transformers provides us with a BertForSequenceClassification model, which consists of:
    (1 x BertEmbeddings layer) → (12 x BertLayer layers) → (1 x BertPooler layer over the embedding for the ‘[CLS]’ token) → (tanh activation) → (Dropout layer) → (Linear classification layer)
    Note that the classification head (everything from the pooler layer onwards) is there only to facilitate training. We will be visualizing the embeddings coming straight out of the 12 x BertLayer layers.
  2. sent_ids and masks are prepared in a 🤗Transformers BERT-compatible form
  3. labels for the dataset are required to colour-code the visualizations
  4. The BertEmbeddings layer and each of the 12 BertLayer layers return their outputs (also known as hidden_states) when the output_hidden_states=True argument is passed to the forward pass of the model. Hence, model_out.hidden_states holds 13 tensors, which stack up to a shape of (13, number_of_data_points, max_sequence_length, embeddings_dimension)
  5. Since we are only interested in the embeddings from the 12 x BertLayer layers, we slice off the unwanted BertEmbeddings layer output, leaving us with hidden states of shape (12, number_of_data_points, max_sequence_length, embeddings_dimension). A minimal sketch of this extraction is given after this list.
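
Here is a minimal sketch of that extraction, assuming a binary classification setup; the model name, toy sentences and variable names are illustrative, and the notebook’s actual code may differ slightly:

import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.eval()

sentences = ["an example tweet", "another example tweet"]  # placeholder data
encoded = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    model_out = model(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        output_hidden_states=True,  # ask the model to return all 13 hidden states
    )

# model_out.hidden_states is a tuple of 13 tensors, each of shape
# (number_of_data_points, max_sequence_length, embeddings_dimension);
# index 0 is the BertEmbeddings output, indices 1-12 are the BertLayer outputs.
hidden_states = torch.stack(model_out.hidden_states[1:])  # drop the embedding layer
print(hidden_states.shape)  # (12, number_of_data_points, max_sequence_length, 768)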

Next, we define a function that can plot the layers’ embeddings for a split of our dataset (e.g. train/val/test) after an epoch:

  1. dim_reducer: scikit-learn’s t-SNE dimensionality-reduction implementation, used to reduce our embeddings from BERT’s default 768 dimensions down to 2 dimensions. You can also use PCA, depending on which suits your dataset better.
  2. visualize_layerwise_embeddings: the function that plots the layers’ embeddings for a split of our dataset (train/val/test) after each epoch
  3. we loop over each layer of interest to compute:
    - layer_embeds: the embeddings output by the layer, a tensor of shape (number_of_data_points, max_sequence_length, embeddings_dimension)
    - layer_averaged_hidden_states: creates a single embedding for each data point by taking an average of the embeddings across all non-masked tokens of the sequence, resulting in a tensor of shape (number_of_data_points, embeddings_dimension)
    - layer_dim_reduced_vectors: the t-SNE dimension-reduced embeddings, a tensor of shape (number_of_data_points, 2)

These computed values are finally plotted, one subplot per layer, using the Seaborn library. A minimal sketch of the whole function is given below.
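
This is roughly what such a function could look like, assuming hidden_states is the stacked (12, N, L, D) tensor from before, masks is the attention-mask tensor of shape (N, L), and labels is a 1-D tensor of class ids; it is a sketch under those assumptions, not the notebook’s exact implementation:

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def visualize_layerwise_embeddings(hidden_states, masks, labels, epoch, title,
                                   layers_to_visualize=(0, 1, 2, 3, 8, 9, 10, 11)):
    dim_reducer = TSNE(n_components=2)  # PCA could be swapped in here
    num_layers = len(layers_to_visualize)
    fig, axes = plt.subplots(2, num_layers // 2, figsize=(6 * (num_layers // 2), 12))
    axes = axes.flatten()
    masks = masks.float()

    for i, layer_idx in enumerate(layers_to_visualize):
        # (number_of_data_points, max_sequence_length, embeddings_dimension)
        layer_embeds = hidden_states[layer_idx]
        # average over non-masked tokens -> (number_of_data_points, embeddings_dimension)
        layer_averaged_hidden_states = (
            (layer_embeds * masks.unsqueeze(-1)).sum(dim=1) / masks.sum(dim=1, keepdim=True)
        )
        # t-SNE down to 2 dimensions -> (number_of_data_points, 2)
        layer_dim_reduced_vectors = dim_reducer.fit_transform(
            layer_averaged_hidden_states.detach().cpu().numpy()
        )
        sns.scatterplot(x=layer_dim_reduced_vectors[:, 0],
                        y=layer_dim_reduced_vectors[:, 1],
                        hue=labels.cpu().numpy(), ax=axes[i])
        axes[i].set_title(f"layer {layer_idx + 1}")

    fig.suptitle(f"{title}, epoch {epoch}")
    fig.savefig(f"{title}_epoch_{epoch}.png")
    plt.close(fig)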

Finally, putting together what we’ve seen so far inside a training loop:

Here we call the visualize_layerwise_embeddings function once per epoch for every split of the dataset we want to visualize separately.

I chose to visualize embeddings from the first 4 and last 4 layers. A sketch of this loop is shown below.
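
A rough sketch of how this per-epoch hook could look; train_one_epoch, num_epochs, train_dataloader, and the (sent_ids, masks, labels) tuples train_data/val_data are hypothetical placeholders standing in for the notebook’s own training code and data splits:

for epoch in range(num_epochs):
    # usual training pass: forward, loss.backward(), optimizer.step(), ...
    train_one_epoch(model, train_dataloader)  # hypothetical training helper

    model.eval()
    for split_name, (sent_ids, masks, labels) in [("train", train_data), ("val", val_data)]:
        with torch.no_grad():
            model_out = model(input_ids=sent_ids, attention_mask=masks,
                              output_hidden_states=True)
        hidden_states = torch.stack(model_out.hidden_states[1:])  # (12, N, L, D)
        visualize_layerwise_embeddings(hidden_states, masks, labels,
                                       epoch=epoch, title=split_name,
                                       layers_to_visualize=(0, 1, 2, 3, 8, 9, 10, 11))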

The Visualizations🔬👨‍⚕️

We have our visualizations ready!
I went a step further and stitched the different images into a GIF to make things more convenient! Once again, the code is in the notebook.
Pretty aesthetic, right?🤩

Visualizations for train data across layers of a BERT model

Spend a moment on each layer’s output. Try to draw some interesting inferences out of them!

I’ll give some examples:
- Can you comment on how each layer is performing with each successive epoch?
- The train accuracy of the classifier dropped from epoch 4 to epoch 5! Can you verify this from the GIF above? **

Visualizations for validation data across layers of a BERT model

Ultimately, we are more interested in knowing whether our embeddings help the model generalize. We can judge that from the validation-split visualizations above.

Some interesting questions off the top of my head are:
- Which layers generalize better than the others?
- How well is the last layer able to separate out the classes?
- Do you see any difference in separability between the embeddings for the train and validation splits?
- Does taking an average of the embeddings across all non-masked tokens of the sequence produce better results than taking only the embedding of the ‘[CLS]’ token? (You might have to tweak the notebook a little to answer this one😉; a tiny sketch of the change follows this list.)
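
For reference, the ‘[CLS]’-only variant is essentially a one-line change inside the plotting function (a hedged sketch, assuming layer_embeds has the shape described earlier):

# '[CLS]' is always the first token of every sequence, so instead of the
# masked average we can simply take the first position of each embedding:
cls_hidden_states = layer_embeds[:, 0, :]  # (number_of_data_points, embeddings_dimension)
# ...then feed cls_hidden_states to the t-SNE reducer in place of layer_averaged_hidden_states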

Going further🙇‍♂️

Don’t stop here just yet!
Go and play around in the notebook provided, and try to mix and match different layers’ output embeddings to see which combination helps you produce the best downstream performance!

** Answer: the embeddings for the 12th layer are more neatly clustered at epoch 4 than at epoch 5! It is a clear indicator of the classifier having hit and then overshot a minimum in the loss landscape.

References

[1]: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, Devlin et al., 2019
[2]: SemEval-2019 Task 5: Multilingual Detection of Hate Speech Against Immigrants and Women in Twitter, Basile et al., 2019
