
Interpreting the Prediction of BERT Model for Text Classification

How to Use Integrated Gradients to Interpret BERT Model's Prediction

Photo by Shane Aldendorff: https://www.pexels.com/photo/shallow-focus-photography-of-magnifying-glass-with-black-frame-924676/

Bidirectional Encoder Representations from Transformers, or BERT, is a language model that's very popular in the NLP domain. BERT is literally the Swiss Army knife of NLP due to its versatility and how well it performs across many different NLP tasks, such as text classification, named entity recognition, and question answering.

But there is a catch if we use BERT for a specific NLP task: its architecture consists of a deep stack of layers. For this reason, BERT has commonly been considered a black box model in the sense that it’s not easy to interpret its prediction result.

Let’s say that we have trained our BERT model to classify the sentiment of a movie review. Next, we want to use it to predict a random movie review.

Image by author

The BERT model predicts that the 'The movie is superb' review has a positive sentiment, which is correct. However, we might ask some of these questions afterward:

  • Why does the model predict our input review as a positive review instead of a negative review?
  • Which words does our model find most important in the input review for classifying it as a positive review?
  • How reliable is our model, exactly?

It's not trivial to answer the questions above if we use a black box model like BERT. However, there is a way to interpret a deep learning model's prediction so that we can answer all of the questions above, and we're going to use Integrated Gradients to do this.


What is Integrated Gradients?

Integrated gradients is a method to compute the attribution of each feature of a deep learning model based on the gradient of the model's output (prediction) with respect to the input. This method applies to any deep learning model for classification and regression tasks.

As an example, let’s say that we have a text classification model and we want to interpret its prediction. With integrated gradients, in the end, we will get the attribution score of each input word with respect to the final prediction. We can use this attribution score to find out which words play an important role in our model’s final prediction.

Image by author

To implement integrated gradients, we need two inputs: the original input and the baseline input.

The original input is pretty self-explanatory: it's just our original input. Meanwhile, the baseline input is an 'empty' or 'neutral' input. What this looks like depends on the use case that we have, for example:

  • If our input is an image: the baseline input could be a black image (all pixels are set to 0)
  • If our input is a text: the baseline input could be an all-zero embedding vector

Image by author
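
For instance, here is a minimal sketch of how such baselines could be constructed with PyTorch (the tensor shapes here are hypothetical and only for illustration):

import torch

# Hypothetical original inputs
image = torch.rand(3, 224, 224)            # an RGB image tensor
token_embeddings = torch.rand(6, 768)      # embeddings of 6 tokens

# Baseline inputs: 'empty' counterparts with the same shapes
baseline_image = torch.zeros_like(image)                    # all-black image
baseline_embeddings = torch.zeros_like(token_embeddings)    # all-zero embedding vectors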

The baseline input is important because, in order to find out which features are influential for the model's prediction, we need to compare how the model's prediction changes when we use the original input versus the baseline input.

We then gradually interpolate the baseline input toward our original input step by step, while calculating the gradient of the prediction with respect to the input features at each step. The integrated gradients are then calculated according to the following formula:

Image by author
where:
i : feature iterator
x : original input
x': baseline input
k : scaled feature perturbation 
m : total number of approximation steps
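
The formula in the image above should correspond to the standard Riemann-sum approximation of integrated gradients, which in LaTeX reads as follows (here F denotes the model's prediction function, which is not listed among the variables above):

\text{IntegratedGrads}_i^{\text{approx}}(x) = (x_i - x'_i) \times \frac{1}{m} \sum_{k=1}^{m} \frac{\partial F\big(x' + \tfrac{k}{m}(x - x')\big)}{\partial x_i}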

You might notice that the equation above is just an approximation of the true integrated gradients. This is because in practice, computing the true integral is not always numerically possible.

The higher the total number of approximation steps (m), the closer the approximation gets to the true integrated gradients. Also, as the scaled feature perturbation (k) approaches m, the baseline input more and more resembles the original input. In practice, we need to define the value of m in advance.

To understand the concept behind the above equation, let's imagine that we have an image classification model, since it's easier to visualize. Here are the steps for using integrated gradients to interpret its prediction:

1. Interpolate Baseline Input

As mentioned earlier, to implement integrated gradients we need two different inputs: original input and baseline input. If our input is an image, then the original input will be our original image and the baseline input would be an all-black image.

Image by author

The method then linearly interpolates the baseline image by increasing the k value step by step, based on the total number of steps (m) that we need to define in advance. As k approaches m, the interpolated baseline image becomes more and more identical to the input image.

Let's say that we set the total number of steps to 50. As the k value gets closer to 50, the interpolated baseline image and our input image become more and more identical.

Image by author
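
As a rough sketch of this step, the interpolated images can be generated with plain PyTorch (the image tensor below is hypothetical):

import torch

m = 50  # total number of approximation steps

# Hypothetical original image and its all-black baseline
image = torch.rand(3, 224, 224)
baseline = torch.zeros_like(image)

# Linearly interpolate from the baseline towards the original image
alphas = torch.arange(1, m + 1) / m                        # k/m for k = 1, ..., m
interpolated_images = torch.stack(
    [baseline + alpha * (image - baseline) for alpha in alphas]
)
print(interpolated_images.size())
# Output: torch.Size([50, 3, 224, 224])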

2. Calculate the Gradients

For each interpolated baseline image, the method will calculate the gradients of the model’s prediction with respect to each input feature.

These gradients measure the change in the model’s prediction with respect to the change in input features so that in the end we can estimate the importance of each input feature on the model’s prediction.

Image by author
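
In raw PyTorch, these gradients could be computed with autograd along the lines of the sketch below. Here, model and target_class are hypothetical placeholders for a generic image classifier and the class we want to explain:

import torch

def compute_gradients(model, interpolated_images, target_class):
    # interpolated_images: tensor of shape [m, C, H, W]
    interpolated_images = interpolated_images.clone().requires_grad_(True)
    # Prediction score of the target class for every interpolated image
    scores = model(interpolated_images)[:, target_class]
    # Gradient of the summed scores w.r.t. each interpolated input
    grads = torch.autograd.grad(scores.sum(), interpolated_images)[0]
    return grads  # same shape as interpolated_images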

3. Accumulate the Gradients

In the end, we accumulate the gradients from each step with an approximation method called Riemann sums. We basically sum the gradients from all steps, divide the result by the total number of steps, and then scale it by the difference between the original input and the baseline input.

Image by author
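
Continuing the sketch from the previous snippets (image, baseline, and grads are the hypothetical tensors defined above), the accumulation step boils down to a few lines:

# Average the gradients over all interpolation steps (Riemann sum) ...
avg_grads = grads.mean(dim=0)

# ... and scale by the difference between the original input and the baseline
attributions = (image - baseline) * avg_grads
print(attributions.size())
# Output: torch.Size([3, 224, 224]), one attribution value per pixel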

After this step, we get the attribution of each feature to the model's prediction.


Integrated Gradients for BERT

In this article, we're going to implement the integrated gradients method to interpret the prediction of a BERT model that has been trained for a text classification use case.

If you’re new to BERT and want to get a basic intuition of what it does and its architecture, check out my other article here.

Text Classification with BERT in PyTorch

In a nutshell, the BERT architecture consists of Transformer encoders stacked on top of each other, and the number of encoders depends on which BERT model we use.

  • BERT base model has 12 layers of Transformer encoders
  • BERT large model has 24 layers of Transformer encoders

BERT base architecture illustration (Image by author)
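
If you want to double-check the number of layers yourself, it can be read directly from the Hugging Face model config, for example:

from transformers import BertConfig

config = BertConfig.from_pretrained('bert-base-cased')
print(config.num_hidden_layers)
# Output: 12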

When we want to use a BERT model to make a prediction on a text, what we normally do first is tokenize the input text.

The tokenization process splits our input text into smaller chunks called tokens, where each token consists of either a word or a subword. Next, special tokens for the BERT model such as [CLS], [SEP], and optionally [PAD] will be added to our initial tokens. Finally, each token will be transformed into a numerical representation that can be used by machine learning algorithms.

from transformers import BertTokenizer

# Instantiate tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

text = 'The movie is superb'

# Tokenize input text
text_ids = tokenizer.encode(text, add_special_tokens=True)

# Print the tokens
print(tokenizer.convert_ids_to_tokens(text_ids))
# Output: ['[CLS]', 'The', 'movie', 'is', 'superb', '[SEP]']

# Print the ids of the tokens
print(text_ids)
# Output: [101, 1109, 2523, 1110, 25876, 102]

However, this numerical representation can't be used for integrated gradients, because token IDs are discrete and wouldn't work with the interpolation method described in the previous section. Thus, we need to transform the tokenization result into another form.

As you can see in the BERT architecture above, before each token is passed into the stack of encoders, it needs to be transformed into an embedding via an embedding layer. These token embeddings are the values we'll use as input to calculate the attribution of each token to the model's prediction.

Below is a minimal example of how we can get the embedding of tokens with BERT. We will see the detailed implementation in the next section.

from transformers import BertModel
import torch

# Instantiate BERT model
model = BertModel.from_pretrained('bert-base-cased')

embeddings = model.embeddings(torch.tensor([text_ids]))
print(embeddings.size())
# Output: torch.Size([1, 6, 768]), since there are 6 tokens in text_ids

Implementing Integrated Gradients for a BERT Model with Captum

Now that we know the concept behind integrated gradients, let's put it into action. As a first step, we need to instantiate the architecture of our BERT model:

from torch import nn

# Class of model architecture
class BertClassifier(nn.Module):

    def __init__(self, dropout=0.5):

        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 2)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask = None):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer

This architecture corresponds to the one used to train the BERT model to classify a movie review's sentiment. We use a pre-trained bert-base-cased model and then add a linear layer at the end with two outputs to classify whether a review's sentiment is negative or positive.

If you want to follow along in interpreting this model, you can download the trained model via this link:

https://github.com/marcellusruben/bert_captum/raw/main/bert_model.pt

Now let’s load the trained model’s parameters, and then set the trained model to evaluation mode.

model = BertClassifier()
model.load_state_dict(torch.load('bert_model.pt', map_location=torch.device('cpu')))
model.eval()

When we want to use integrated gradients to interpret a model’s prediction, we need to specify two things: the model’s output and the model’s input.

The model’s output is simply the prediction of a movie review’s sentiment. This means the value that comes out from the model’s last layer. To obtain the model’s output, we can simply do a forward pass.

Meanwhile, as mentioned in the previous section, the model’s input would be the embedding of tokens. This means the features after passing through the embedding layer. Thus, the values that come out from the embedding layer would be our model’s input.

Let’s define our model’s output and input.

# Define model output
def model_output(inputs):
  return model(inputs)[0]

# Define model input
model_input = model.bert.embeddings

Image by author

Now it's time for us to implement integrated gradients, and we are going to use Captum to do this. Captum is an open-source library for interpreting machine learning models, built on PyTorch. This means that the library works with any PyTorch model.

Currently, there are three groups of attribution methods supported by Captum to interpret the results of any PyTorch model:

  • Primary attribution: a method to evaluate the influence of each feature on the model’s prediction
  • Layer attribution: a method to evaluate the influence of neurons of a layer on the model’s prediction
  • Neuron attribution: a method to evaluate the influence of each input feature on the activation of a particular neuron

Since we're going to interpret the influence of each token on the model's prediction, we need primary attribution, and integrated gradients is one of the algorithms for primary attribution.

We literally only need one line of code to initialize integrated gradients algorithm with Captum. Then, we provide the model’s output and input as arguments.

from captum.attr import LayerIntegratedGradients

lig = LayerIntegratedGradients(model_output, model_input)

If you remember from the previous section, the implementation of integrated gradients under the hood requires two things: one is our original input and another is the baseline input.

Since we are dealing with text input, our original input would be the embedding of each token. Meanwhile, the baseline input would be the embeddings of a sequence of padding tokens with the same length as our original token sequence. Let's create a function to generate this baseline.

def construct_input_and_baseline(text):

    max_length = 510
    baseline_token_id = tokenizer.pad_token_id 
    sep_token_id = tokenizer.sep_token_id 
    cls_token_id = tokenizer.cls_token_id 

    text_ids = tokenizer.encode(text, max_length=max_length, truncation=True, add_special_tokens=False)

    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    token_list = tokenizer.convert_ids_to_tokens(input_ids)

    baseline_input_ids = [cls_token_id] + [baseline_token_id] * len(text_ids) + [sep_token_id]
    return torch.tensor([input_ids], device='cpu'), torch.tensor([baseline_input_ids], device='cpu'), token_list

text = 'The movie is superb'
input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)

print(f'original text: {input_ids}')
print(f'baseline text: {baseline_input_ids}')

# Output: original text: tensor([[  101,  1109,  2523,  1110, 25876,   102]])
# Output: baseline text: tensor([[101,   0,   0,   0,   0, 102]])

We get three variables after executing the function above: input_ids is the tokenized original input text, baseline_input_ids is the tokenized baseline text, and all_tokens is the list of tokens from the original input text. We will use the all_tokens variable only for visualization purposes later on.

Next, we can start to interpret the model's prediction. To do this, we use the attribute() method of the LayerIntegratedGradients() class that we initialized before. All we need to do is pass the tokenized original input and baseline input as arguments. Optionally, you can specify the number of approximation steps via the n_steps argument; otherwise, it defaults to 50.

attributions, delta = lig.attribute(inputs= input_ids,
                                    baselines= baseline_input_ids,
                                    return_convergence_delta=True,
                                    internal_batch_size=1
                                    )
print(attributions.size())
# Output: torch.Size([1, 6, 768])

We get two variables after implementing integrated gradients: attributions and delta.

From the previous section, we already know that the integrated gradients method we implement is only an approximation, because computing the true integral is not always numerically possible. The delta variable that we get here is the difference between the approximated and the true integrated gradients. We will use this variable later on only for visualization purposes.

Meanwhile, the attributions variable holds the attribution of each element of each token's embedding. To get the final attribution of each token, we need to aggregate these values across the embedding dimension, summing them and then normalizing the result. The function below does exactly that.

def summarize_attributions(attributions):

    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)

    return attributions

attributions_sum = summarize_attributions(attributions)
print(attributions_sum.size())
# Output: torch.Size([6])

And that's it! The output from the function above is the attribution of each token. Since we have six tokens after the tokenization process, we also get six attribution values, one for each token.

To make it easier for us to inspect the result, we can visualize it with VisualizationDataRecord() class and visualize_text() method from Captum.

from captum.attr import visualization as viz

score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = torch.max(model(input_ids)[0]),
                        pred_class = torch.argmax(model(input_ids)[0]).numpy(),
                        true_class = 1,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),       
                        raw_input_ids = all_tokens,
                        convergence_score = delta)

viz.visualize_text([score_vis])

There are a lot of arguments that we need to supply when we call VisualizationDataRecord() class above, so let’s dissect them one by one.

  • word_attributions : integrated gradients’ result of each token
  • pred_prob : model’s prediction
  • pred_class : the class of model’s prediction
  • true_class : ground-truth label of the input
  • attr_class : the input
  • attr_score : the sum of integrated gradients across all tokens
  • raw_input_ids : the list of tokens
  • convergence_score : the difference between approximated and true integrated gradients

Once we run the code snippet above, we will get the visualization result as follows:

Image by author

From the visualization above, we can see that the predicted label is 1 (positive), which is correct. In the 'Word Importance' section, we can see the attribution of each token to the model's prediction.

The tokens that have a positive contribution to the model’s prediction are highlighted in green. Meanwhile, the tokens that have negative contributions to the model’s prediction are highlighted in red.

It seems that the word 'superb' has the most influence in making our BERT model classify the review as positive, and this makes sense. Let's encapsulate what we have done so far into a single function and then interpret another review.

def interpret_text(text, true_class):

    input_ids, baseline_input_ids, all_tokens = construct_input_and_baseline(text)
    attributions, delta = lig.attribute(inputs= input_ids,
                                    baselines= baseline_input_ids,
                                    return_convergence_delta=True,
                                    internal_batch_size=1
                                    )
    attributions_sum = summarize_attributions(attributions)

    score_vis = viz.VisualizationDataRecord(
                        word_attributions = attributions_sum,
                        pred_prob = torch.max(model(input_ids)[0]),
                        pred_class = torch.argmax(model(input_ids)[0]).numpy(),
                        true_class = true_class,
                        attr_class = text,
                        attr_score = attributions_sum.sum(),       
                        raw_input_ids = all_tokens,
                        convergence_score = delta)

    viz.visualize_text([score_vis])

And that's it. Now whenever we want to interpret a movie review, we can simply type the review and then call the interpret_text() function above.

text = "It's a heartfelt film about love, loss, and legacy"
true_class = 1
interpret_text(text, true_class)

Image by author

Our BERT model predicted the review's sentiment correctly, i.e., it's predicted as 1 (positive). From the attribution of each token, we can see that words like 'heart', 'love', and 'legacy' contribute positively to the model's prediction, while the word 'loss' contributes negatively.

Let’s now supply our model with a negative review.

text = "A noisy, hideous, and viciously cumbersome movie"
true_class = 0
interpret_text(text, true_class)

Image by author

Once again, our BERT model predicted the movie's sentiment correctly. From the token attributions, we can see that words like 'noisy', 'hideous', and 'vicious' contribute positively to the model's prediction of a negative sentiment, which also makes sense.


Conclusion

In this article, we have implemented integrated gradients to interpret the prediction of a BERT model for text classification. With this method, we are able to find out which inputs are the most important for our deep learning model's prediction. This, in turn, also gives us information about how reliable our model is.

I hope that this article helps you to get started with integrated gradients. You can find all of the code implemented in this article in this notebook.

