
BERT – Intuitively and Exhaustively Explained

Baking General Understanding into Language Models

"Baking" by Daniel Warfield using MidJourney. All images by the author unless otherwise specified. Article originally made available on Intuitively and Exhaustively Explained.
“Baking” by Daniel Warfield using MidJourney. All images by the author unless otherwise specified. Article originally made available on Intuitively and Exhaustively Explained.

In this article we’ll discuss "Bidirectional Encoder Representations from Transformers" (BERT), a model designed to understand language. While BERT is similar to models like GPT, the focus of BERT is to understand text rather than generate it. This is useful in a variety of tasks like ranking how positive a review of a product is, or predicting if an answer to a question is correct.

Before we get into BERT we’ll briefly discuss the transformer architecture, which is the direct inspiration of BERT. Using that understanding we’ll dive into BERT and discuss how it’s built and trained to solve problems by leveraging a general understanding of language. Finally, we’ll create a BERT model ourselves from scratch and use it to predict if product reviews are positive or negative.

Who is this useful for? Anyone who wants to form a complete understanding of the state of the art of AI.

How advanced is this post? Early parts of this article are accessible to readers of all levels, while later sections concerning the from-scratch implementation are fairly advanced. Supplemental resources are provided as necessary.

Pre-requisites: I would highly recommend understanding fundamental ideas about PyTorch before reading the implementation section. You can learn more about PyTorch here:

AI for the Absolute Novice – Intuitively and Exhaustively Explained

An understanding of transformers and multi-headed-self-attention may be useful for later sections but is not required.

Transformers In a Nutshell

At this point I’ve covered the transformer, and transformer derivative architectures, a lot.


Transformers – Intuitively and Exhaustively Explained

Multi-Headed Self Attention – By Hand

GPT – Intuitively and Exhaustively Explained

Flamingo – Intuitively and Exhaustively Explained


Let’s go over the highlights.

At its most fundamental, the transformer is an "encoder/decoder" style model. When you put something into a transformer, the encoder summarizes the input into some meaning rich and abstract representation which the decoder uses to generate output.

The core idea of an encoder-decoder style model. The encoder converts the input into some abstract representation, which is passed to the decoder for generation. In this particular example, the encoder and decoder are working together to translate a phrase in English to French. From my article on transformers.

The transformer uses a variety of AI building blocks to do this general process, as can be seen from its architecture diagram.

The transformer diagram. source

First, the input embedding converts words into vectors. This turns words, which are hard to do math on, into numbers which are easier to do math on.

Word to vector embedding. From my article on transformers.

Then, we create vectors which correspond to each input location and add those position vectors to the word vectors. Thus, each resulting vector contains information about both the word, and the position of the word.

Different vectors, with different values, are added to the word vectors depending on their location. From my article on transformers.

A process called multi-headed self-attention is applied to the input phrase. This somewhat complex operation is, to a large extent, the defining characteristic of a transformer. We’ll cover this more later, but for now we’ll just say that the multi-headed self-attention mechanism makes every word in the input sequence interact with every other word in the input sequence. The output is an abstract and meaning rich representation of the input.

Multi headed self-attention allows the input to interact with itself, creating a complex matrix that represents the entire input. From my article on transformers.

The multi-headed self-attention mechanism is a complex operation. To make training the model easier some of the older and simpler structure of the input is added to the output of the attention mechanism to preserve some of that simpler representation. This is called a skip connection.

A conceptual diagram of a skip connection. From my article on transformers.

Then the values, which might be all over the place, are squashed into a reasonable distribution via a process called "normalization". After the attention mechanism, a neural network is applied to the data.

This whole process results in an output which is similar in shape to the input but is much more complex and abstract.
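To make that concrete, here's a minimal sketch of what one encoder block boils down to (a simplification using PyTorch's built-in attention, not the original transformer implementation): attention, a skip connection, normalization, and a feed forward network, with the output keeping the same shape as the input.

import torch
import torch.nn as nn

class ToyEncoderBlock(nn.Module):
    """A minimal sketch of an encoder block: attention, skip connection,
    normalization, then a small feed forward network. Dimensions are illustrative."""
    def __init__(self, d_model=256, n_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model))
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)   # every token attends to every other token
        x = self.norm1(x + attn_out)       # skip connection, then normalization
        x = self.norm2(x + self.ff(x))     # feed forward network, with its own skip
        return x

x = torch.randn(1, 10, 256)                # (batch, sequence length, vector size)
print(ToyEncoderBlock()(x).shape)          # torch.Size([1, 10, 256]) -- same shape as the input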

The encoder’s job is to represent the input optimally for the decoder. It’s the decoder’s job to generate the output. From my article on transformers.

The "decoder" is essentially made of the same core components as the encoder, but it’s objective is different. While the encoder contextualizes the input into an encoding, the decoder uses the encoding to construct the output.

And that’s the transformer in a nutshell. We glossed over a lot; feel free to check out some of the links if you want to know more. For now, though, we can start digging into the core idea of BERT.

Encoder vs Decoder Style Models

The original transformer lit AI off like a firework. From its humble beginnings as an English to French translation model, it’s been expanded to be the bedrock of a hundred-billion-dollar industry. When the transformer was invented, it inherited many of the popular conceptualizations of machine learning of its day. Particularly, thinking of language modeling as a "sequence to vector to sequence" task.

Conceptual diagrams of a few applications of different types of sequence modeling. Sequence to Sequence might be predicting the next word for text complete. Sequence to vector might be scoring how satisfied a customer was with a review. Vector to sequence might be compressing an image into a vector and asking the model to describe that image as a sequence of text. Sequence to vector to sequence might be text translation, where you need to understand a sentence, compress it into some representation, then construct a translation of that compressed representation in a different language. From my article on transformers.

This was a popular conceptualization in recurrent neural networks, which were the popular modeling architecture before the transformer. The transformer was derivative of that canon of thought, and thus was constructed as a variant of "sequence to vector to sequence" modeling, which is essentially what the encoder and decoder are.

After the transformer exploded in popularity a new era of research emerged (and still persists) which encouraged widespread experimentation around the transformer architecture. One of the strongest ideas that emerged from this research was the idea of "encoder only" and "decoder only" style models.

GPT was one such "decoder only" style model. Essentially a decoder only style model is just the right half of the transformer. Instead of feeding the input into an encoder, GPT feeds the input into the decoder and then uses that decoder to generate the output.

A conceptual diagram of GPT, a decoder only model, generating output. From my article on GPT.

This simpler architecture has some key advantages, particularly around training, which allows the model to learn more easily from large amounts of textual data.

While decoder only models have exploded in popularity, "encoder only" models are another important tool in building advanced AI systems. Recall that the encoder and the decoder are virtually identical; the only difference is their job: the encoder summarizes an input into an abstract and meaning rich representation, while the decoder generates text.

The point of an "encoder only" transformer, then, is to summarize some input sequence into an abstract, dense, and meaning rich representation. Instead of creating that representation specifically for text generation, the point of an encoder only transformer is to create a representation that’s generally useful for a variety of tasks.

BERT is the most famous encoder only model and excels at tasks which require some level of language comprehension.

BERT – Bidirectional Encoder Representations from Transformers

Before the transformer, if you wanted to predict whether an answer actually answered a question, you might use a recurrent strategy like an LSTM.

The general idea of a recurrent model, which is a type of model that feeds into itself over successive inputs. From my article on transformers.

One issue with this approach is information locality. LSTMs pass information along in a sequence in order to create a vector which represents the model’s prediction. As a result, information that’s further apart within a sequence has a harder time interacting. It’s kind of like a game of telephone, if you’ve ever played that as a kid. The longer the sequence is, the more easily the model can forget important but distant information as that information is forced to interact with other information through successive inputs.

AI Researchers did all sorts of stuff to try to alleviate the problem of information locality. They used recurrent networks that go in opposite directions to try to bake "bidirectional" understanding into the model, allowing the model to look both forward and backward.

A conceptual diagram of bidirectional recurrent networks: a left to right and a right to left network

There were also strategies like hierarchical recurrent networks which tried to use summarizations of previous layers to preserve more long-range information.

A conceptual diagram of a simultaneously bi-directional and hierarchical recurrent network

While these strategies put a Band-Aid over the issue, they don’t solve it. Ultimately, recurrent networks are spatially dependent.

One of the cool things about transformers is their ability to deal with information across large input sequences. If two words in a sequence are relevant to one another the transformer can manipulate those two words together regardless of how far apart they are. This is thanks to the self-attention mechanism.

The self-attention mechanism analyzes the input and then constructs a matrix which says which words are relevant to which other words in the input. It then uses that matrix to allow the vectors which represent those words to interact with one another.

Self-attention allows any input to arbitrarily interact with any other input. Typically an input interacts most with itself and with nearby tokens, but that does not have to be the case.

This is one of the reasons BERT has the term "bidirectional" in the name. While recurrent networks can do a weak form of bidirectionality by combining right to left and left to right analysis together, transformers are capable of arbitrary bidirectionality. Really, I think it might be better to call BERT style transformers "omnidirectional", but "OERT" doesn’t really roll off the tongue.

This shift in thinking helped make BERT an overnight success. After it was released, it absolutely crushed several well-established benchmarks which were previously dominated by recurrent strategies. This increase in performance is likely due to BERT’s ability to allow elements in the input sequence to arbitrarily interact with one another, rather than having a strong bias towards spatial proximity.

Another reason it was so successful was the way it was trained, which we’ll cover in the next section.

Training a BERT model

BERT style models employ a two-pronged approach to training, called "pre-training" and "fine tuning".


Excerpt from my article on LoRA, an approach to fine tuning.

As the state of the art of machine learning has evolved, expectations of model performance have increased; requiring more complex machine learning approaches to match the demand for heightened performance. In the earlier days of machine learning it was feasible to build a model and train it in a single pass.

Training, in its simplest sense. You take an untrained model, give it data, and get a performant model.

This is still a popular strategy for simple problems, but for more complex problems it can be useful to think of training as two parts: "pre-training" then "fine tuning". The general idea is to do an initial training pass on a bulk dataset and to then refine the model on a tailored dataset.

Pre-Training and Fine Tuning, a refinement of the typical training strategy.

This "pre-training" then "fine tuning" strategy can allow data scientists to leverage multiple forms of data and use large pre-trained models for specific tasks. As a result, pre-training then fine tuning is a common and incredibly powerful paradigm.


BERT uses a pre-training step designed to encourage the model to understand language generally, then allows for fine tuning so the model can learn specific tasks. First, let’s cover pre-training.

BERT Pre-Training

BERT is pre-trained on two objectives simultaneously: "masked language modeling", which is like fill in the blank, and "next sentence prediction", which is essentially asking the model to predict if two sentences make sense with one another.

For next sentence prediction, imagine we have some text which we can break into a list of sentences.

["I am sad.",
 "I ate a bagel, but I'm still hungry.",
 "I don't like being hungry.",
 "But, I know I'll be eating again soon!",
 "I think I'll go to Fudruckers.",
 "I'm not sure if Fudruckers has bagels, though."
]

First, we can take two sentences where we know the second sentence follows the first sentence, and we can stick them together.

("I am sad.", "I ate a bagel, but I'm still hungry.")

We can also combine two sentences which do not follow one another.

("I am sad.", "I'm not sure if Fudruckers has bagels, though.")

Then we can construct a dataset of sentences which do and do not follow one another. This would be done with a large corpus of text with many sentences.

Follows | Sentences
------------------
True    | ("I am sad.", "I ate a bagel, but I'm still hungry.")
False   | ("I am sad.", "I'm not sure if Fudruckers has bagels, though.")
...

And thus we’ve made a dataset for next sentence prediction. If we apply an AI model to this task, it will have to understand language to a sufficient degree to understand what sentences do or do not make sense with one another.

We can also add masked language modeling by randomly replacing words within the dataset with a mask.

Follows | Sentences
------------------
True    | ("I [MASK] sad.", "I ate a bagel, but I'm still hungry.")
False   | ("I am sad.", "I'm not sure if Fudruckers [MASK] bagels, though.")
...

So, for a given input, the model will have two jobs during pre-training:

  1. Predict if the second sentence follows the first sentence.
  2. Predict what the masked words should have been.

The idea is, if you do this with a whole bunch of text, the model is forced to have a solid general understanding of language. It can understand sentences enough to understand if sentences make sense when paired together and can use context clues to understand which words should exist within the masked sections.

To make these predictions BERT adds a few things onto the traditional encoder style transformer model. First of all, BERT has a few special tokens called "utility tokens" which are placed within a sequence to allow these two sentences to be represented in the input. First it adds a token called [CLS] to the beginning of the sequence which will later be used for classifying if the sequence is a positive next sentence pair. Then, the token [SEP] is added to the input to separate the two sentences from one another. The token [MASK] also exists, as previously discussed.

"[CLS] I [MASK] sad. [SEP] I ate a bagel, but I'm still hungry."

These tokens are represented as text for interpretability’s sake, but in reality they’re numbers. The input is broken down into a list of numbers, where each number represents either a word or a utility token within the sequence.

# a conceptual demonstration of tokenizing the input as a list of numbers.
[CLS] = 101
I     = 1023
[MASK]= 103
sad   = 39842
[SEP] = 102
I     = 1023
ate   = 8907
a     = 213
bagel = 208756
but   = 9867
I'm   = 2367
still = 7893
hungry= 55678

(In reality tokenization happens with sub-words, but we’ll cover that in the implementation)

So, in pre-training we give our BERT style model these tokens and ask it to predict two things: are the sentences related, and what should be in the masked locations? We then update the model based on how wrong it was.
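To put numbers on "how wrong it was": conceptually, the two objectives just become two loss terms that get added together. Here's a minimal sketch with made-up tensors (the names and shapes are illustrative, not the code we'll write later):

import torch
import torch.nn.functional as F

# hypothetical model outputs for one batch (illustrative shapes, not real predictions)
nsp_logits = torch.randn(128, 2)                  # "does sentence 2 follow sentence 1?" per example
mlm_logits = torch.randn(128, 64, 30522)          # a vocabulary prediction for every token position
nsp_labels = torch.randint(0, 2, (128,))          # 1 = follows, 0 = does not follow
mlm_labels = torch.randint(0, 30522, (128, 64))   # the original token ids before masking
masked_positions = torch.zeros(128, 64, dtype=torch.bool)
masked_positions[:, 3] = True                     # pretend one position per sequence was masked

nsp_loss = F.cross_entropy(nsp_logits, nsp_labels)
mlm_loss = F.cross_entropy(mlm_logits[masked_positions], mlm_labels[masked_positions])
loss = nsp_loss + mlm_loss                        # both objectives are optimized simultaneously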

BERT Fine Tuning

What’s cool about the pre-training process of BERT is that you can expose a ton of text to the model, allowing the model to form a very robust understanding of language. The fine tuning step leverages that understanding and applies it to a specific problem.

The exact process of fine tuning depends on the type of data you’re trying to fine tune against. Let’s use sentiment analysis as an example.

Suppose we have the following dataset of product reviews which contains the title of the review, the body of the review, and whether the review was positive or negative.

Positive | Title                 | Body
------------------------------------------------------------------------
True     | Eureka!               | I finally found something that works!
False    | Not good              | Broke after first use
...

This dataset can be manipulated into a format that’s similar to our original dataset (you may choose a different mode of re-representation depending on the dataset you’re fine tuning on).

Positive | Sequence
------------------------------------------------------------------------
True     | [CLS] Eureka! [SEP] I finally found something that works!
False    | [CLS] Not good [SEP] Broke after first use
...

Then we can train our model again, just like we did in pre-training, but on this new and more specific dataset. Through training, the model will learn to morph its understanding of next sentence prediction into an understanding of whether the sequence is a positive or negative review.

Typically, when fine tuning on a different objective, it’s a good idea to replace the "prediction head". I have an article discussing that topic here.

Self-Supervised Learning Using Projection Heads

In a BERT style model, a dense network turns the output corresponding to the [CLS] token into a true or false prediction of whether the two sentences are related to one another. The idea of projection head replacement is to swap that learned component out for a randomly initialized one. Instead of making the model unlearn next sentence prediction, you simply replace the classification network with a new, randomly initialized neural network. This typically makes it easier for the model to pivot to learning the new domain.

A conceptual diagram of the idea of throwing out the task specific logic of the final predictive neural network. This is for an image classification task, but the theory remains similar. From my article on projection heads.
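As a rough sketch of the idea (the names here are hypothetical, assuming an encoder that outputs vectors of size d_model):

import torch.nn as nn

d_model = 256

# during pre-training, this head turned the [CLS] output into a "does sentence 2 follow?" prediction
next_sentence_head = nn.Linear(d_model, 2)

# for fine tuning, we discard it and attach a fresh, randomly initialized head for the new task,
# while keeping the pre-trained encoder weights
sentiment_head = nn.Linear(d_model, 2)

# hypothetical fine-tuning forward pass:
#   cls_vector = encoder(tokens)[:, 0, :]   # the output at the [CLS] position
#   prediction = sentiment_head(cls_vector)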

Don’t worry if you don’t completely understand projection heads, we’ll cover it more in the implementation. Let’s explore the input and output of BERT a bit more closely, then get to coding.

The Input of BERT

Recall that we’re using tokenization to turn words into numbers.

# a conceptual demonstration of tokenizing the input as a list of numbers.
[CLS] = 101
I     = 1023
[MASK]= 103
sad   = 39842
[SEP] = 102
I     = 1023
ate   = 8907
a     = 213
bagel = 208756
but   = 9867
I'm   = 2367
still = 7893
hungry= 55678

Each of these numbers is then converted into a vector via a process called "embedding". Essentially, we create a big set of random vectors where we have one vector for every possible token, then we use the corresponding vector whenever that token is used in the sequence. The values of our big lookup table of vectors are randomly initialized and updated throughout the training process.

The process of converting a sequence of tokens to a sequence of vectors in BERT. This is a common approach in many transformer style architectures.
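The lookup table itself is just an embedding layer. Here's a tiny sketch of the idea (the token ids are made up for illustration):

import torch
import torch.nn as nn

token_embedding = nn.Embedding(num_embeddings=30522, embedding_dim=256)  # one learnable vector per token id
token_ids = torch.tensor([[101, 1023, 103, 102]])                        # e.g. [CLS] I [MASK] [SEP]
vectors = token_embedding(token_ids)
print(vectors.shape)  # torch.Size([1, 4, 256]) -- one 256-dimensional vector per token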

Like the traditional transformer, vectors are added to these word vectors to bake in information about the location of words, except unlike the original transformer, BERT uses a learned positional encoding. Basically, we make a random vector for each spot in the input sequence and add that to the word vectors. These positional vectors are also updated during training.

On top of the tokenization lookup table, we also create vectors for every location in the input sequence which are added to the word vectors. This allows the model to learn to represent both the meaning of words as well as the location of those words.

This is a common general strategy in modern transformers, but BERT also adds another vector to each word that represents which of the two sentences the word belongs to, which is not super typical.

Sentence level embedding is also done in BERT. Random vectors are created for each sentence and added to let the model know which sentence a word came from.

Once words and positional information have been encoded the sequence is passed through traditional encoder style transformer blocks to result in an abstract and complex output. We’ll cover this more in the implementation.

The word vectors with positional and sentence embeddings are passed through one or more encoder blocks to create a dense and meaning rich representation of the input.

Then, two things happen simultaneously after the input has been passed through encoder blocks in BERT:

  1. The vector in the output which corresponds to the [CLS] token is passed through a dense network which generates a prediction as to whether sentence two follows sentence one.
  2. All the masked tokens get passed through a neural network in order to predict what the original word was before masking.
A conceptual diagram of the inputs and the outputs of a BERT style model. At the end, the output corresponding to the [CLS] token will be used in next sentence prediction, and any masked tokens (in this case only one) will be passed through a dense network to generate a token prediction.

Naturally, a brand-new model is bad at both of these things, but by updating the parameters of the model over many examples the model begins to understand language sufficiently to solve these two problems. More importantly, though, getting good at solving these problems naturally imbues the model with an understanding of language, which can be leveraged in further tasks.


Implementing BERT From Scratch

We’ve covered all the high-level ideas, let’s make a BERT model.

We’ll use PyTorch to build and pre-train our BERT model on data from Wikipedia articles (license). We’ll then fine tune our model on a sentiment analysis task ([license](https://huggingface.co/datasets/fancyzhx/amazon_polarity)).

Full code can be found here.

Setting up the Wikipedia Pre-Training Dataset

Alright, we’re going to be using the fabulous datasets library from Huggingface to download our data. We’re also going to use nltk (natural language tool kit) to divide our Wikipedia articles by sentence.

!pip install datasets
!pip install nltk

The Wikipedia dataset is pretty large, and I don’t want to sit around waiting for stuff while playing with this article, so I opted to load up the dataset in streaming mode so I could get a subset of the data.

from datasets import load_dataset
#the dataset is big, to make things easier we're going to be streaming a subset
dataset = load_dataset("wikipedia", "20220301.en", trust_remote_code=True, streaming=True)
#creating an iterator over the training split so we can pull down a subset of articles
dataset_iter = iter(dataset['train'])

I’m also going to download punkt through nltk, which is a sentence tokenizer we’ll be using to extract sentences from articles.

import nltk
nltk.download('punkt')

And now we can go ahead and download some data and extract some sentences.

"""Breaking wikipedia articles into sentences and paragraphs
"""

import itertools

num_articles = 10000
#getting n articles
articles = list(itertools.islice(dataset_iter, num_articles))

#getting paragraphs
paragraphs = []
for article in articles:
    paragraphs.extend(article['text'].splitlines())

#filtering paragraphs so they're hopefully actually paragraphs
paragraphs = [p for p in paragraphs if len(p)>50]

#dividing paragraphs into sentences
divided_paragraphs = []
for p in paragraphs:
    divided_paragraphs.append(nltk.sent_tokenize(p))

#only using paragraphs with 3 or more sentences
divided_paragraphs = [pls for pls in divided_paragraphs if len(pls)>=3]
divided_paragraphs

You might notice I first divide the articles into paragraphs along newlines, then use nltk to turn those paragraphs into sentences, and only keep paragraphs that contain three or more sentences. This is just some back of the napkin data engineering I experimented with which seemed to, more often than not, get me decent data to play with. At the end of this block of code we have divided_paragraphs, which is a list of paragraphs which are themselves lists of sentences.

#the content of divided_paragraphs
[['Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.',
  'Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.',
  'As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.'],
 ['Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires.',
  'With the rise of organised hierarchical bodies, scepticism toward authority also rose.',
  'Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment.',
  "During the latter half of the 19th and the first decades of the 20th century, the anarchist movement flourished in most parts of the world and had a significant role in workers' struggles for emancipation.",
  'Various anarchist schools of thought formed during this period.',
  'Anarchists have taken part in several revolutions, most notably in the Paris Commune, the Russian Civil War and the Spanish Civil War, whose end marked the end of the classical era of anarchism.',
  'In the last decades of the 20th and into the 21st century, the anarchist movement has been resurgent once more.'],
...
]

This is handy for a bunch of reasons. The raw article text is kind of messy and has a lot of weird, grammatically incorrect stuff which makes sense on a website but not in a purely textual format. By only preserving the paragraphs we can be pretty sure that the sentences within the paragraphs do truly follow one another in a way that makes sense.

Now that we have, in effect, batches of sentences which follow one another we can use this data to make positive and negative pairs of sentences, where half the sentence pairs belong together and half do not.

"""Using the paragraph data to construct paris of following sentences
and pairs of random sentences
"""

import random

positive_pairs = []
negative_pairs = []

num_paragraphs = len(divided_paragraphs)

for i, paragraph in enumerate(divided_paragraphs):
    for j in range(len(paragraph)-1):
        positive_pairs.append((paragraph[j], paragraph[j+1]))
        rand_par = i
        while rand_par == i:
            rand_par = random.randint(0, num_paragraphs-1)
        rand_sent = random.randint(0, len(divided_paragraphs[rand_par])-1)
        negative_pairs.append((paragraph[j], divided_paragraphs[rand_par][rand_sent]))

At the end of this block of code we end up with two lists of sentence pairs: one containing sentences which belong together, and the other containing sentences which do not.

# positive_pairs

[('Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.',
  'Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.'),
 ('Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.',
  'As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.'),
 ('Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires.',
  'With the rise of organised hierarchical bodies, scepticism toward authority also rose.'),
 ('With the rise of organised hierarchical bodies, scepticism toward authority also rose.',
  'Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment.'),
 ('Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment.',
  "During the latter half of the 19th and the first decades of the 20th century, the anarchist movement flourished in most parts of the world and had a significant role in workers' struggles for emancipation."),
 ("During the latter half of the 19th and the first decades of the 20th century, the anarchist movement flourished in most parts of the world and had a significant role in workers' struggles for emancipation.",
  'Various anarchist schools of thought formed during this period.'),
 ('Various anarchist schools of thought formed during this period.',
  'Anarchists have taken part in several revolutions, most notably in the Paris Commune, the Russian Civil War and the Spanish Civil War, whose end marked the end of the classical era of anarchism.'),
...
]
# negative_pairs

[('Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.',
  'Wycliffite teachings on the Eucharist were declared heresy at the Blackfriars Council of 1382.'),
 ('Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.',
  'While Erdoğan declared being against antisemitism, he has been accused of invoking antisemitic stereotypes in public statements.'),
 ('Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires.',
  "In 1939, DeMille's Union Pacific was successful through DeMille's collaboration with the Union Pacific Railroad."),
 ('With the rise of organised hierarchical bodies, scepticism toward authority also rose.',
  "that father and son each bore the same double name, or that Abiathar officiated during his father's lifetime and in his father's stead—have been supported by great names, but have not been fully accepted."),
 ('Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment.',
  'The most common example of a conventional double is the takeout double of a low-level suit bid, implying support for the unbid suits or the unbid major suits and asking partner to choose one of them.'),
 ("During the latter half of the 19th and the first decades of the 20th century, the anarchist movement flourished in most parts of the world and had a significant role in workers' struggles for emancipation.",
  'The 1930s and 1940s were marked by instability and emergence of populist politicians, such as five-time President José María Velasco Ibarra.'),
...
]

And with that we have the dataset for pre-training pretty much set up.

Tokenization

In order to feed data into our model we need to somehow turn our sentences into vectors. In doing this we’ll use a pre-trained tokenizer from Huggingface. Basically, whoever made this looked at a bunch of text and figured out the most frequent components of text that exist, and then defined those components of text as a "vocabulary". We can use that to break up text into individual tokens.

First we download the tokenizer,

from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")

then we can put a sample sentence through the tokenizer to see what it’s all about.

"""Playing around with the tokenizer
"""
sentence = "Here's a weird word: Withoutadoubticus."
print(f'original sentence: "{sentence}"')
demo_tokens = tokenizer([sentence])
print(f"token IDs: {demo_tokens['input_ids']}")
tokens = tokenizer.convert_ids_to_tokens(demo_tokens['input_ids'][0])
print(f'token values: {tokens}')
original sentence: "Here's a weird word: Withoutadoubticus."
token IDs: [[101, 2182, 1005, 1055, 1037, 6881, 2773, 1024, 2302, 9365, 12083, 29587, 1012, 102]]
token values: ['[CLS]', 'here', "'", 's', 'a', 'weird', 'word', ':', 'without', '##ado', '##ub', '##ticus', '.', '[SEP]']

As you can see, the tokenizer broke our sentence up into individual components which may have included dividing individual words into more than one component. This is called sub-word tokenization, meaning the tokenizer has both words and word components in its vocabulary. This is important because it allows the tokenizer to express complicated words as a series of tokens.

Tokenization isn’t the focus of this article, so we’ll take the rest for granted. At the end of the day, the tokenizer turns a sequence of text into a bunch of numbers.

Exploring Special Tokens

Because we’re using a pre-made tokenizer for BERT style models, our tokenizer has a few utility tokens for us to play with.

tokenizer
BertTokenizerFast(name_or_path='google-bert/bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
 0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
 103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
  • [PAD], 0 , is used to pad sentences which are too short, allowing us to fill the context length of our model.
  • [UNK], 100 , allows BERT to encode any unknown values with the unknown token. This might be non-ASCII characters, for instance.
  • [CLS], 101 , a special class token we’ll be putting in the beginning of our sequence.
  • [SEP], 102 , a special token that specifies a division between sentences in the input.
  • [MASK], 103 , the token we’ll use to mask input data, allowing BERT to learn via masked language modeling.

We’ll use these to properly construct the input to our model in the next section.
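As an aside, the tokenizer we loaded above can also assemble a sentence pair with these special tokens for us, and produce token_type_ids marking which sentence each token came from. We'll build our sequences manually in the next section for clarity, but this shows the expected format:

# passing a pair of sentences produces [CLS] sentence1 [SEP] sentence2 [SEP]
pair = tokenizer("I am sad.", "I ate a bagel, but I'm still hungry.")
print(tokenizer.convert_ids_to_tokens(pair['input_ids']))
print(pair['token_type_ids'])  # 0s for the first sentence, 1s for the second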

Defining Training Batches

Now that we can tokenize data, and we understand what special tokens we have access to, we can turn our positive and negative sentence pairs into batches of data which we can use to train the model.

Each batch will contain 128 individual sentence pair examples, 64 of which are positive pairs and 64 of which are negative pairs. To keep our model fairly small, to speed up training, we’ll make the context window for our model equal to 64 tokens. So, at the end of this process, we’ll get a tensor which is [number_of_batches x 128(batch_size) x 64(sequence_length)] .

Here’s the code I used to make that happen:

from tqdm import tqdm
import torch
from multiprocessing import Pool, cpu_count

#defining the device the data ends up living on
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#number of examples in the batch
batch_size = 128  # should be divisible by 2
#sequence length of model
max_input_length = 64

#defining parallelizable function for processing batches
def process_batch(batch_index):
    #establishing bounds of the batch
    start_index = batch_index * batch_size
    end_index = start_index + batch_size

    if end_index > len(positive_pairs):
        return None, None, None

    #getting the sentence pairs of the batch, and if they're pos or neg
    sentence_pairs = []
    is_positives = []

    # Creating positive pairs
    sentence_pairs.extend(positive_pairs[start_index:start_index + int(batch_size / 2)])
    is_positives.extend([1] * int(batch_size / 2))

    # Creating negative pairs
    sentence_pairs.extend(negative_pairs[start_index + int(batch_size / 2):end_index])
    is_positives.extend([0] * int(batch_size / 2))

    # Defining outputs
    # At the end of the day we need to know three things:
    #   - the tokens for the sequences in a batch
    #   - which sentence the tokens belong to, for positional encoding
    #   - if the examples in the batch are positive or negative
    # these keep track of the first two
    batch_sentence_location_tokens = []
    batch_sequence_tokens = []

    # Tokenizing pairs
    for sentence_pair in sentence_pairs:
        sentence1 = sentence_pair[0]
        sentence2 = sentence_pair[1]

        # Tokenizing both sentences
        tokens = tokenizer([sentence1, sentence2])
        sentence1_tokens = tokens['input_ids'][0]
        sentence2_tokens = tokens['input_ids'][1]

        # Trimming down tokens
        if len(sentence1_tokens) + len(sentence2_tokens) > max_input_length:
            sentence1_tokens = [101] + sentence1_tokens[-int(max_input_length / 2) + 1:]
            sentence2_tokens = sentence2_tokens[:int(max_input_length / 2) - 1] + [102]

        # Creating sentence tokens
        sentence_tokens = [0] * len(sentence1_tokens) + [1] * len(sentence2_tokens)

        # Combining and padding
        pad_num = max_input_length - (len(sentence1_tokens) + len(sentence2_tokens))
        sequence_tokens = sentence1_tokens + sentence2_tokens + [0] * pad_num
        sentence_location_tokens = sentence_tokens + [1] * pad_num

        # Adding to batch
        batch_sequence_tokens.append(sequence_tokens)
        batch_sentence_location_tokens.append(sentence_location_tokens)

    return torch.tensor(batch_sentence_location_tokens), torch.tensor(batch_sequence_tokens), torch.tensor(is_positives)

# Determine the number of batches
num_batches = len(positive_pairs) // batch_size

# Use a Pool of workers equal to the number of CPU cores
with Pool(processes=cpu_count()) as pool:
    results = list(tqdm(pool.imap(process_batch, range(num_batches)), total=num_batches))

# Filter out None results from the process_batch function
results = [result for result in results if result[0] is not None]

# Unpack results into batches
sentence_location_batches, sequence_tokens_batches, is_positives_batches = zip(*results)

# Stack tensors into final batches
sentence_location_batches = torch.stack(sentence_location_batches).to(device)
sequence_tokens_batches = torch.stack(sequence_tokens_batches).to(device)
is_positives_batches = torch.stack(is_positives_batches).to(device)

This is a lot of code, feel free to pick through it. It’s a lot of just moving data around, so I don’t think it’s necessary to describe every fiddly little detail. However, I do think it’s useful to discuss the section of code where I construct what will ultimately become the input to the model.

# Trimming down tokens
if len(sentence1_tokens) + len(sentence2_tokens) > max_input_length:
    sentence1_tokens = [101] + sentence1_tokens[-int(max_input_length / 2) + 1:]
    sentence2_tokens = sentence2_tokens[:int(max_input_length / 2) - 1] + [102]

Here, I have two tokenized sentences, and I’m getting both of them to fit within the sequence length of the model. If the sentences are too long I opted to preserve the end of the first sentence and the beginning of the second sentence. This should still allow long inputs to be reasonably interpretable by the model.

# Creating sentence tokens
sentence_tokens = [0] * len(sentence1_tokens) + [1] * len(sentence2_tokens)

Here I’m creating a vector which has zeros for the length of the first sentence and ones for the length of the second sentence. We’ll use this vector to help us with positional encoding when we go to build the model.

# Combining and padding
pad_num = max_input_length - (len(sentence1_tokens) + len(sentence2_tokens))
sequence_tokens = sentence1_tokens + sentence2_tokens + [0] * pad_num
sentence_location_tokens = sentence_tokens + [1] * pad_num

Then we construct our output. We combine our sentence tokens together, and if the combined length is less than the model’s sequence length we add a bunch of pad tokens. If we do add pad tokens, we say they belong to the second sentence for convenience’s sake. So, conceptually, the sequence of events looks something like this:

#example sentences
sentence1 = 'Hello World!'
sentence2 = 'This is an example!'

#breaking up into tokens
sentence1_tokens = ['[CLS]', 'Hello', 'World', '!']
sentence2_tokens = ['[CLS]', 'This', 'is', 'an', 'example', '!']

#those tokens have IDs
sentence1_token_ids = [101, 1340, 87345, 1332]
sentence2_token_ids = [101, 4589, 988, 874, 13598, 1332]

#we can combine the token IDs together, with a [SEP] token in between them.
#also we don't need the [CLS] token for sentence two
sequence_token_ids = [101, 1340, 87345, 1332, 102, 4589, 988, 874, 13598, 1332]

#constructing sentence tokens for positional encoding
#note how they correspond to the sequence token ids
sent_location_tokens=[0,   0,    0,     0,    1,   1,    1,   1,   1,     1   ]

#pad, if necessary, with 0's for the tokenids and 1's for the sent locations

To be completely honest I didn’t spend a ton of time validating this code. In the process of writing, I realized there are a few quirks that will lead to inconsistent tokens for various sequences. I’ll leave that as an exercise to the reader. Realistically, if we’re building a model that can understand natural language, it should be smart enough to deal with some minor formatting quirks. If you want to build a super skookum BERT model yourself, you might want to spruce up this code.

Regardless, at the end of the day we get batches of sequences of token ids, each of which corresponds to two sentences; a vector that encodes whether each token belongs to the first or second sentence; and a record of which examples within each batch are positive or negative pairs.

# A conceptual breakdown of the data we have

#shape: [num_batches x batch_size x seq_length]
batch_tokens = [
[101, 1100, 87345, 1332,  102, 4589, 988, 874, 13598, 1332, 0,    0,    0],
[101, 987,  1332,  87345, 873, 4589, 102, 874, 13598, 1332, 1399, 1324, 1246],
...
]

#shape: [num_batches x batch_size x seq_length]
batch_location = [
[0,0,0,0,1,1,1,1,1,1,1,1,1],
[0,0,0,0,0,0,1,1,1,1,1,1,1],
...
]

#shape: [num_batches x batch_size]
batch_is_positive_labels = [
1,
0
]

Creating a Masking Function

Recall that BERT gets trained on two modeling objectives simultaneously: next sentence prediction and masked language modeling. We have all the data necessary for the first objective, so now we need to build out the second.

I glossed over some of the nitty gritty details of masked language modeling; let’s go over those details now.

Detail 1

Basically, the idea is to take in some sequence, randomly mask out certain tokens, then get the model to guess what that token ought to have been based on the surrounding text.

The [MASK] brown fox jumped over the lazy [MASK]

So, we’ll build a function that takes in an input sequence and randomly masks out values within that sequence. Our input sequences are more complicated than a simple sentence, though; we have a list of tokens that corresponds to two sentences and contains special tokens.

[CLS] Here's a famous sentence. [SEP] The quick brown fox jumped over the lazy dog. [pad] [pad] [pad]

As we build our masking function we don’t want to inadvertently mask out special tokens like [CLS] , [SEP] , and [PAD] . We only want to mask out tokens which correspond to the sentences themselves.

Detail 2

There’s another quirk worth addressing before we get into it. In the BERT paper they don’t actually replace every masked word with the [MASK] token.

After we train our model, the [MASK] token will never be seen when the model is actually being used and making inferences. If we only train our model on the [MASK] token, it might learn to disregard other words that might be important in understanding the sequence generally. So, when we decide a random token should be masked we usually replace it with the [MASK] token, but we also sometimes preserve the original token value, and sometimes replace the masked token with a completely random token.

The idea is, this should make the model think more critically about the input, and consider every token, not just the [MASK] token, to be important.

# a conceptual breakdown of masking without always using the mask token
orig_sequence   = 'The quick brown fox jumped over the lazy dog.'
masked_sequence = 'The [MASK] brown fox jumped over the lazy asparagus'
masked_tokens = ['[MASK]', 'fox', 'asparagus']
original_toks = ['quick',  'fox', 'dog']

I hope this example makes this concept clearer. Here we’re masking three words, "quick", "fox" and "dog", but by not always using the mask token the masked language modeling objective becomes much richer as the model also needs to confirm certain words make sense and other words don’t make sense within the context of the input. In the original BERT paper they decided to mask 15% of words within the input. Of that 15%, 80% are replaced with [MASK] , while 10% are replaced with a random word and 10% are not replaced at all. We’ll be using those probabilities in our implementation.

Implementing Masking

Ok, we covered the quirks, here’s the masking code:

#listing out vocab for random token masking
vocab = tokenizer.get_vocab()
valid_token_ids = list(vocab.values())

def mask_batch(batch_tokens, clone=True):
    if clone:
        batch_tokens = torch.clone(batch_tokens)

    # Define the percentage of tokens to potentially mask
    replace_percentage = 0.15

    # Define tokens that should not be replaced
    excluded_tokens = {0, 100, 101, 102, 103}

    # Create a mask to identify tokens that are eligible for replacement
    eligible_mask = ~torch.isin(batch_tokens, torch.tensor(list(excluded_tokens)).to(device))

    # Count the number of eligible tokens
    num_eligible_tokens = eligible_mask.sum().item()

    # Calculate the number of tokens to potentially mask
    num_tokens_to_mask = int(num_eligible_tokens * replace_percentage)

    # Create a random permutation of eligible token indices
    eligible_indices = eligible_mask.nonzero(as_tuple=True)
    random_indices = torch.randperm(num_eligible_tokens)[:num_tokens_to_mask]

    # Create a probability distribution for replacement
    replacement_probs = torch.tensor([0.8, 0.1, 0.1])  # Probabilities for [103, random token, leave unchanged]
    replacement_choices = torch.multinomial(replacement_probs, num_tokens_to_mask, replacement=True)

    # Vector to store if a token was masked (0: not masked, 1: masked)
    masked_indicator = torch.zeros_like(batch_tokens, dtype=torch.int32)

    # Apply replacements based on sampled choices
    for i, idx in enumerate(random_indices):
        row = eligible_indices[0][idx]
        col = eligible_indices[1][idx]

        #replacing with [MASK]
        if replacement_choices[i] == 0:
            batch_tokens[row, col] = 103
            masked_indicator[row, col] = 1
        #replacing with random token
        elif replacement_choices[i] == 1:
            batch_tokens[row, col] = random.choice(valid_token_ids)
            masked_indicator[row, col] = 1
        #not replacing at all
        elif replacement_choices[i] == 2:
            masked_indicator[row, col] = 1

    return batch_tokens, masked_indicator

batch_tokens, masked_indicator = mask_batch(sequence_tokens_batches[0])
batch_tokens

This function ends up outputting both the masked tokens and the locations of the masked tokens. Because a "masked" token might not actually be the [MASK] token, we need to keep track of where the masked tokens are separately from the token values themselves.
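To sketch how these outputs would eventually be used (the model call here is hypothetical; we haven't built it yet): the masked_indicator selects which positions count toward the masked language modeling loss, while the original, unmasked batch supplies the target token ids.

# a sketch, assuming a future model that outputs per-position vocabulary logits
original_tokens = sequence_tokens_batches[0]                   # the unmasked targets
masked_tokens, masked_indicator = mask_batch(original_tokens)  # the model's input
mask = masked_indicator.bool()

# token_logits = model(masked_tokens, ...)                     # shape [batch, seq_len, vocab_size]
# mlm_loss = torch.nn.functional.cross_entropy(token_logits[mask], original_tokens[mask])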

Embedding

Okie dokie, we’ve done basically all the preliminary work required to set up our training dataset. We have batches of tokens, information about what sentences those tokens exist in (sentence 1 or sentence 2), if those pairs belong together or not, and a function that can take a batch of tokens and mask them out.

Now we can actually start building the model.

The first step is embedding. A BERT style model, in being a derivative of transformers, expects a high dimensional vector to represent each word. The model will use these vectors to reason about words, allowing it to (hopefully) create a strong understanding of the input text. So, we need to turn our tokens (which are just integers) into these high dimensional vectors.

Recall that, in a BERT style model, we combine vectors from words, locations, and sentences to construct a vector for each element in the input.

The embedding portion of the model will take care of both the conversion of tokens into vectors and the addition of positional information by using a lookup table. We’ll define random vectors for every possible token, random vectors that correspond to each input position, and random vectors which correspond to the two sentence inputs. We’ll replace tokens and positions with these random vectors, and use them to represent a token and its position. Naturally, it will do a bad job at first as we’re using completely random data, but these random values will be learnable parameters of the model, so the model will learn how to make good vectors for both token and position encoding.

Here’s the PyTorch code to make that actually happen:

import torch.nn as nn
import torch

vocab_size = tokenizer.vocab_size
d_model = 256
n_segments = 2

class Embedding(nn.Module):
    def __init__(self):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_input_length, d_model)  # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(device)
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (batch_size, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

e = Embedding()
e.to(device)

Here we’re saying we’ll represent the words with vectors of length 256 with the parameter d_model=256 , and we’re saying we’re dealing with two sentences with n_segments=2 . If you wanted to get crazy I suppose you could play around with three or more sentence inputs, but we’re keeping it locked in at two for this example.

We can pass in a batch of data through this module and see what we get.

dummy_embedding = e(sequence_tokens_batches[0], sentence_location_batches[0])
print(dummy_embedding.shape)
print(dummy_embedding)
torch.Size([128, 64, 256])
tensor([[[-0.4109,  0.1544,  0.3778,  ..., -1.9995,  1.3578,  0.3117],
         [-0.5452, -0.7935, -0.6296,  ...,  1.0046, -0.1871, -0.3125],
         [-2.2820,  0.4665, -1.1026,  ..., -0.5876,  1.4205, -1.5876],
         ...,
         [ 1.2866,  0.9395,  0.7138,  ...,  0.4223,  0.3374,  0.6935],
         [-0.3787,  1.4489, -0.7226,  ...,  0.3139,  0.3640,  0.4926],
         [ 1.1291,  1.4248, -0.2899,  ...,  0.8080,  0.7977,  1.4257]],

        [[-0.4109,  0.1544,  0.3778,  ..., -1.9995,  1.3578,  0.3117],
         [-0.9470, -0.4977, -1.0789,  ...,  0.5366,  0.5290, -1.7874],
         [-1.5527, -0.2966, -0.3398,  ..., -0.5468,  1.3547, -0.6128],
         ...,
         [ 1.2866,  0.9395,  0.7138,  ...,  0.4223,  0.3374,  0.6935],
         [-0.3787,  1.4489, -0.7226,  ...,  0.3139,  0.3640,  0.4926],
         [ 1.1291,  1.4248, -0.2899,  ...,  0.8080,  0.7977,  1.4257]],

        [[-0.4109,  0.1544,  0.3778,  ..., -1.9995,  1.3578,  0.3117],
         [-0.9972,  0.2936, -0.3921,  ...,  0.1695, -0.2766, -1.4312],
         [-2.2029, -1.5211, -1.3297,  ..., -0.6648,  2.2392, -0.1643],
         ...,
         [ 1.2866,  0.9395,  0.7138,  ...,  0.4223,  0.3374,  0.6935],
         [-0.3787,  1.4489, -0.7226,  ...,  0.3139,  0.3640,  0.4926],
         [ 1.1291,  1.4248, -0.2899,  ...,  0.8080,  0.7977,  1.4257]],

        ...,

That's looking good to me! Batch size of 128, sequence length of 64, but now each token is represented by a vector of length 256. Keep in mind this output corresponds to just one batch.

Multi-Headed Self Attention

This section is a bit more in depth, and assumes a fair degree of knowledge about multi-headed self-attention. Feel free to skim through or skip it if you're a novice; it's not fundamental to understanding BERT as a whole. If you'd like to understand this section better, check out my article on transformers and my article on multi-headed self-attention.

BERT is a transformer style model, so multi-headed self-attention is a critical component. Because it’s critical we’re going to be implementing it from scratch. Why not.

At this point I've covered multi-headed self-attention (MHSA) so many times the topic and I are like old friends that have accidentally seen each other naked enough times that we know the shape of each other's butt cheeks. If you're like "woah, this is too much information, just like that analogy" then use PyTorch's built-in MHSA implementation (torch.nn.MultiheadAttention). It's more efficient anyway.
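If you do go the pre-built route, a minimal usage sketch might look something like this (the shapes mirror this article's conventions, but treat it as an illustration rather than a drop-in replacement for what we build below):

import torch
import torch.nn as nn

# PyTorch's built-in multi-headed attention, with batch_first=True so inputs
# are [batch x sequence x d_model] like everywhere else in this article
builtin_mhsa = nn.MultiheadAttention(embed_dim=256, num_heads=4, batch_first=True)

x = torch.randn(2, 64, 256)                # [batch x sequence x d_model]
out, attn_weights = builtin_mhsa(x, x, x)  # self-attention: query, key, and value are all x
print(out.shape)                           # torch.Size([2, 64, 256])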

For the brave, let’s get into it.

First of all we can implement a single attention head. We'll assume the query, key, and value have already been created, so we can whip that up. This doesn't have any learnable parameters; those will live in the multi-headed self-attention mechanism, which will employ this as a sub-component.

import numpy as np

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V):
        #Q, K, V of size [batch x sequence_length x dim]
        #scores are scaled by the square root of the query/key dimension (the last axis)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(Q.shape[-1])
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn

#sanity checking
q = torch.tensor([[[1.1,1.3],[0.9,0.8]]]).to(device)
k = torch.tensor([[[0.9,1],[0.2,2.1]]]).to(device)
v = torch.tensor([[[1.1,1.3],[0.9,0.8]]]).to(device)
sample = ScaledDotProductAttention().to(device)
sample(q,k,v)
(tensor([[[0.9771, 0.9927],
          [0.9912, 1.0280]]], device='cuda:0'),
 tensor([[[0.3854, 0.6146],
          [0.4559, 0.5441]]], device='cuda:0'))

For those familiar with MHSA, notice that there’s no mask. If you use a pre-made MHSA implementation you’ll almost certainly be required to specify some form of mask because MHSA is almost always used in a context where masking attention is required. In BERT, though, we want every input token to attend to every other input token, so we don’t need a mask at all.
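If you'd rather not hand-roll the attention math at all, recent versions of PyTorch also ship a fused scaled dot-product attention function; here's a quick sketch of the mask-free call (shapes are made up for illustration):

import torch
import torch.nn.functional as F

# Q, K, V of size [batch x sequence_length x dim], mirroring the module above
q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
v = torch.randn(1, 4, 8)

# attn_mask=None means every token can attend to every other token,
# which is exactly the BERT-style behavior described above
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
print(out.shape)  # torch.Size([1, 4, 8])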

Actually, turning this into multi-headed self-attention is a bit of a bear, and mostly for boring data engineering reasons. We have a batch of examples which need to be turned into queries, keys, and values, and then those need to be further divided into multiple heads. This means we effectively have two axes which we need to parallelize self-attention across: the batch dimension and a new dimension for the heads.

To get this working I decided to squeeze both of those dimensions into a single dimension, and treat the combination of the batch and head dimensions as just the batch dimension. Because PyTorch's batched operations are applied independently across the leading (batch) dimension, squeezing the batch and heads together lets us effectively parallelize attention across both at once.

Before I actually implemented MHSA, I played around with these shape transformations, experimented with a few examples, and came up with something that (I think) works.

#defining sample value matrix
#[batch_size x sequence_len x (query_key_dim * n_heads)]
#in this matrix, [0,1,2,3] represents the values for 2 heads across a single word vector
samp_val = torch.tensor([[[0,1,2,3],[4,5,6,7]],[[0,-1,-2,-3],[-4,-5,-6,-7]]])

#dividing into two heads
#[batch_size x sequence_len x query_key_dim x n_heads]
samp_val = samp_val.view(2,2,2,2)

#moving the head dimension next to the batch dimension
#[batch_size x n_heads x sequence_len x query_key_dim]
samp_val = samp_val.permute(0, 3, 1, 2)

#combining batch and head dimension
#[batch_size*n_heads x sequence_len x query_key_dim]
samp_val = samp_val.reshape(-1, 2, 2)

#that would be the input into mhsa, which would give back the same shape output
#now I want to unpack the mhsa back into the original shape
#[batch_size x sequence_len x (query_key_dim * n_heads)]
#if I do this right, the values should be exactly identical

#separating heads
#[batch_size x n_heads x sequence_len x query_key_dim]
samp_val = samp_val.reshape(2,2,2,2)

#moving the head dimension to the end
#[batch_size x sequence_len x query_key_dim x n_heads]
samp_val = samp_val.permute(0, 2, 3, 1)

#combining the last dim to effectively concatenate the result of the heads
#[batch_size x sequence_len x query_key_dim*n_heads]
samp_val = samp_val.reshape(2, 2, -1)
samp_val
#note how it's the same as the input. That implies
#that the transformations went well
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 0, -1, -2, -3],
         [-4, -5, -6, -7]]])

Now that I have those transformations blocked out conceptually, I can use them to build MHSA.

import torch
import torch.nn as nn

# Define constants
n_heads = 3
query_key_dim = 64
value_dim = 64

class MultiHeadSelfAttention(nn.Module):
    def __init__(self):
        super(MultiHeadSelfAttention, self).__init__()
        # Defining the linear layers that construct the query, key, and value
        self.W_Q = nn.Linear(d_model, query_key_dim * n_heads)   # Projects input to [batch x sequence x (q/k_dim*num_heads)]
        self.W_K = nn.Linear(d_model, query_key_dim * n_heads)   # Projects input to [batch x sequence x (q/k_dim*num_heads)]
        self.W_V = nn.Linear(d_model, value_dim * n_heads)       # Projects input to [batch x sequence x (v_dim*num_heads)]
        self.dot_prod_attn = ScaledDotProductAttention()         # Parameterless system that calculates attention
        self.proj_back = nn.Linear(value_dim * n_heads, d_model) # Projects final output of mhsa back into model dimension

    def forward(self, embedding):

        # passing embedding through dense networks
        qs = self.W_Q(embedding)  # [batch_size x sequence_len x (query_key_dim * n_heads)]
        ks = self.W_K(embedding)  # [batch_size x sequence_len x (query_key_dim * n_heads)]
        vs = self.W_V(embedding)  # [batch_size x sequence_len x (value_dim * n_heads)]

        #dividing out heads
        #[batch_size, sequence_len, q/k/v_dim, n_heads]
        qs = qs.view(batch_size, max_input_length, query_key_dim, n_heads)
        ks = ks.view(batch_size, max_input_length, query_key_dim, n_heads)
        vs = vs.view(batch_size, max_input_length, value_dim, n_heads)

        #moving the head dimension next to the batch dimension
        #[batch_size x n_heads x sequence_len x q/k/v_dim]
        qs = qs.permute(0, 3, 1, 2)
        ks = ks.permute(0, 3, 1, 2)
        vs = vs.permute(0, 3, 1, 2)

        #combining batch and head dimension
        #[batch_size*n_heads x sequence_len x q/k/v_dim]
        qs = qs.reshape(-1, max_input_length, query_key_dim)
        ks = ks.reshape(-1, max_input_length, query_key_dim)
        vs = vs.reshape(-1, max_input_length, value_dim)

        #passing batches/heads of self attention through attn
        #[batch_size*n_heads x sequence_len x q/k/v_dim]
        head_results, _ = self.dot_prod_attn(qs,ks,vs)

        #separating heads
        #[batch_size x n_heads x sequence_len x v_dim]
        head_results = head_results.reshape(batch_size,n_heads,max_input_length,value_dim)

        #moving the head dimension to the end
        #[batch_size x sequence_len x value_dim x n_heads]
        head_results = head_results.permute(0, 2, 3, 1)

        #combining the last dim to effectively concatenate the result of the heads
        #[batch_size x sequence_len x value_dim*n_heads]
        head_results = head_results.reshape(batch_size, max_input_length, -1)

        #projecting result of head back into model dimension
        return self.proj_back(head_results)

# Example usage
sample_embeddings = torch.tensor([[[1.1] * d_model] * max_input_length] * batch_size).to(device)
print("Sample embeddings shape:", sample_embeddings.shape)

sample = MultiHeadSelfAttention().to(device)
output = sample(sample_embeddings)
print('Output shape of mhsa:', output.shape)
# the output should be the same size as the input

Sample embeddings shape: torch.Size([128, 64, 256])
Output shape of mhsa: torch.Size([128, 64, 256])

Here I'm defining a few constants that we'll use throughout training. n_heads specifies how many attention heads exist per MHSA block, query_key_dim specifies how big the query and key vectors will be, and value_dim specifies how big the value vectors will be.

MHSA has four sets of parameters:

self.W_Q = nn.Linear(d_model, query_key_dim * n_heads)   # Projects input to [batch x sequence x (q/k_dim*num_heads)]
self.W_K = nn.Linear(d_model, query_key_dim * n_heads)   # Projects input to [batch x sequence x (q/k_dim*num_heads)]
self.W_V = nn.Linear(d_model, value_dim * n_heads)       # Projects input to [batch x sequence x (v_dim*num_heads)]
self.proj_back = nn.Linear(value_dim * n_heads, d_model) # Projects final output of mhsa back into model dimension

These are all dense linear networks: three that project the model's word vectors into the queries, keys, and values for MHSA, and one that projects the output of MHSA back into the model dimension.

I feel like this gets glossed over a lot, and I'm guilty of it in a lot of my previous articles. These are "pointwise dense networks", which is the default setup in PyTorch. Basically, they apply to all the vectors in your space and assume the last dimension is the vector dimension. So if you have, for instance, an input of shape [batch_size x sequence_length x input_dim] and you want to turn that into an output of shape [batch_size x sequence_length x output_dim], you can use nn.Linear(input_dim, output_dim). This network has the parameters to turn one vector into another vector, so when you apply it to your input you're essentially running the same small model across all [batch_size x sequence_length] vectors. This means the word vectors don't interact with each other; they only get re-represented as different vectors of a different length.
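Here's a tiny sketch of that pointwise behavior (arbitrary sizes, just to illustrate the shapes):

import torch
import torch.nn as nn

# a batch of 2 sequences, each containing 5 word vectors of length 16
x = torch.randn(2, 5, 16)

# nn.Linear only looks at the last dimension, so the same 16 -> 32 transformation
# is applied independently to each of the 2*5 word vectors
pointwise = nn.Linear(16, 32)
print(pointwise(x).shape)  # torch.Size([2, 5, 32])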

We use those networks to project our input into the query, key, and value:

# passing embedding through dense networks
qs = self.W_Q(embedding)  # [batch_size x sequence_len x (query_key_dim * n_heads)]
ks = self.W_K(embedding)  # [batch_size x sequence_len x (query_key_dim * n_heads)]
vs = self.W_V(embedding)  # [batch_size x sequence_len x (value_dim * n_heads)]

Then we pass those through our reshaping logic, run them through scaled dot-product attention, and reshape to spit out the result. And thus, we've implemented multi-headed self-attention for our BERT model.

Pointwise Feed Forward

We already implemented pointwise feed forward networks in the construction of the query, key, and value in multi-headed self-attention, but this process is also applied to the word vectors themselves, as per the classic transformer architecture.

Notice how, in the original transformer, feed forward is done after multi-headed self-attention in the encoder block (the block on the left). source

Just like in multi-headed self-attention, this applies a neural network to each word vector individually, allowing the model to learn to manipulate individual vectors as necessary.

d_ff = 4*d_model

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(torch.nn.functional.gelu(self.fc1(x)))

In this particular implementation we’re expanding the vectors to four times their length with a neural network, applying a non-linear activation function, then compressing that data back into the original model dimension length (256, in this example).

I describe non-linear activation functions in this article if you're unfamiliar with the topic. Basically, we're allowing our model to stretch each word vector out into a bigger representation, which lets the model represent each vector in a more diverse number of ways, then we're passing that larger representation through a function that manipulates the vector in complex, non-linear ways. The model can learn to exploit those richer representations to create better word vectors, which are then compressed back into the original modeling dimension.
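As a quick sanity check (assuming the batch_size, max_input_length, d_model, and device values defined earlier in this example), the block should hand back a tensor of exactly the shape it was given:

#dummy input of shape [batch_size x sequence_length x d_model]
dummy = torch.randn(batch_size, max_input_length, d_model).to(device)
pwff = PoswiseFeedForwardNet().to(device)
print(pwff(dummy).shape)  # torch.Size([128, 64, 256]) with the values used here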

The Encoder Block

Now that we have multi-headed self-attention and pointwise feed forward figured out, we can implement the entire encoder block.

class EncoderBlock(nn.Module):
    def __init__(self):
        super(EncoderBlock, self).__init__()
        self.mhsa = MultiHeadSelfAttention()
        self.pwff = PoswiseFeedForwardNet()

    def forward(self, x):
        mhsa_output = self.mhsa(x)
        skip1 = mhsa_output + x
        pwff_output = self.pwff(skip1)
        skip2 = skip1+pwff_output
        return skip2

Here we’re:

  1. Passing the input through multi-headed self-attention
  2. Adding the original input to the output of MHSA, combining both, creating the first skip connection.
  3. Passing that through pointwise feed forward
  4. Adding the output of pointwise feed forward to the previous skip connection output, creating the second skip connection.

Skip connections have a lot of interesting theory behind them, but for our purposes we’ll say they help a model learn more easily by combining simple and more complex information together, allowing the model to use both to its advantage.

Come to think of it, I probably should have added layer normalization. If you were being proper, you would probably want to apply torch.nn.LayerNorm after each skip connection; a sketch of what that might look like follows. Anyway, it's a minor detail as far as this demo is concerned, so whatever.
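A minimal sketch of that variant (not what was actually used for the results later in this article):

class EncoderBlockWithNorm(nn.Module):
    def __init__(self):
        super(EncoderBlockWithNorm, self).__init__()
        self.mhsa = MultiHeadSelfAttention()
        self.pwff = PoswiseFeedForwardNet()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # normalizing after the attention skip connection
        x = self.norm1(x + self.mhsa(x))
        # normalizing after the feed forward skip connection
        x = self.norm2(x + self.pwff(x))
        return x

Let's get onto the fun stuff, actually building BERT!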

Building BERT

And we have arrived. We can put all the pieces together to build a BERT model.

n_layers = 1

class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        #for converting tokens into vector embeddings
        self.embedding = Embedding()
        #encoder blocks
        self.encoder_blocks = nn.ModuleList([EncoderBlock() for _ in range(n_layers)])
        #for decoding a word vector (or tensor of them) into token predictions
        self.decoder = nn.Linear(d_model, tokenizer.vocab_size, bias=False)
        #for converting the first output token into a binary classification
        self.classifier = nn.Linear(d_model, 1, bias=False)

    def forward(self, x, seg, masked_token_locations):

        #x of shape [batch x seq_len], containing token ids
        embeddings = self.embedding(x, seg)
        x = embeddings
        for block in self.encoder_blocks:
            x = block(x)

        #passing first token through classifier
        clsf_logits = self.classifier(x[:,0,:])

        #passing the masked positions of the final encoder output through the decoder
        masked_token_embeddings = x[masked_token_locations.bool()]
        token_logits = self.decoder(masked_token_embeddings)

        return clsf_logits, token_logits

BERT is pretty straightforward if you understand the two major subcomponents we've already created:

  1. It has an embedding sub-module which turns the token_ids of the input into vectors
  2. It has a bunch of encoder blocks which manipulate the input to create a dense and meaning rich representation. Under the hood these consist of multi-headed self-attention and pointwise feed forward layers.

The only new things are the decoder and the classifier; these are used to turn certain vectors in the output of the last encoder layer into predictions. The classifier looks at the first token, which is always [CLS] (we set our data up that way), and predicts whether or not the two sentences in the input belong together.

#for every example in the batch, this takes the first vector and 
#passes it to the classifier linear network for prediction.
clsf_logits = self.classifier(x[:,0,:])

Notice how the classifier is defined as a linear network of output size 1.

#for converting the first output token into a binary classification
self.classifier = nn.Linear(d_model, 1, bias=False)

This is because we're making a binary classification (yes or no). Once the logit is passed through a sigmoid, values over 0.5 are interpreted as true, while values under 0.5 are interpreted as false.

The decoder does something similar, except it looks at all masked tokens, and instead of making a true or false prediction, it has to predict what token should be there. Thus, its output is of length tokenizer.vocab_size, meaning we predict, out of all possible tokens, what token a particular masked word should be.

The reason the outputs are called "logits" is a fiddly little Data Science topic; feel free to google the difference between logits and probabilities if you're so inclined. Also, I'm only using one encoder block to speed up training (n_layers = 1). If you were building a BERT model yourself I would expect at least a few encoder blocks stacked on top of one another.
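For the curious, the short version is that logits are raw, unbounded scores that only become probabilities once they're squashed; a toy illustration:

import torch

# a single binary-classification logit -> probability via a sigmoid
print(torch.sigmoid(torch.tensor([1.7])))  # roughly 0.85

# a vector of token logits -> a probability distribution via a softmax (sums to 1)
print(torch.softmax(torch.tensor([2.0, 0.5, -1.0]), dim=-1))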

And, with that, we’ve built a BERT model. Now we need to pre-train it on our data.

Pre-Training BERT

If you're already familiar with PyTorch (which I kind of assume you are if you've gotten this far; if not, good for you!) you'll probably recognize this code. Now that our model and data are set up, the actual training process is pretty straightforward.

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

model = BERT().to(device)
token_criterion = nn.CrossEntropyLoss()  # Expect indices, not one-hot vectors
classification_criterion = nn.BCEWithLogitsLoss()  # For logits directly
optimizer = optim.Adam(model.parameters(), lr=0.001)

#keeping track of the losses across all epochs
losses = [[]]

#these epochs can take a while, keeping it at a fairly small number
for epoch in range(4):
    for sequence_batch, location_batch, classtarg_batch in tqdm(zip(sequence_tokens_batches, sentence_location_batches, is_positives_batches)):
        # Zeroing out gradients from last iteration
        optimizer.zero_grad()

        # Masking the tokens in the input sequence
        masked_tokens, masked_token_locations = mask_batch(sequence_batch)

        # Generating class and masked token predictions
        clsf_logits, token_logits = model(masked_tokens, location_batch, masked_token_locations)

        # Setting up target for masked token prediction
        masked_token_targets = sequence_batch[masked_token_locations.bool()]

        # Calculating loss for next sentence classification
        loss_clsf = classification_criterion(clsf_logits.squeeze(), classtarg_batch.float())

        # Calculating loss for masked language modeling
        loss_mlm = token_criterion(token_logits, masked_token_targets)

        # Combining losses
        loss = loss_mlm + loss_clsf

        #keeping track of loss across the current epoch
        losses[-1].append(float(loss))

        # Backpropagation
        loss.backward()
        optimizer.step()

    print(f'=======Epoch {epoch} Completed=======')
    print(f'average loss in epoch: {np.mean(losses[-1])}')
    losses.append([])

Working through some highlights:

  1. We define our model, and put it on the device (a GPU).
  2. We define "criteria", these are the functions which will calculate the loss (how wrong the model was) from masked language modeling and next sentence prediction.
  3. We define an optimizer, which will look at how large the loss was and update the model accordingly.
  4. We go over all the data over n epochs, in this case 4.
  5. We iterate over all batches of the data
  6. We mask our batch randomly with the masking function we defined previously
  7. We run the masked tokens, along with location information and where the masked tokens are, through the model. We get back predictions for next sentence prediction and predictions as to what the model thinks each masked token should be.
  8. We pass our predictions through each respective criterion, with what the outputs should have been, to calculate loss.
  9. We call loss.backward() to calculate how the model should change to be less wrong on this particular batch.
  10. We allow the optimizer to update the model based on the model’s performance on this batch

In the end, we can see that the model improves on the data as the loss value continually declines.

7318it [13:50,  8.81it/s]
=======Epoch 0 Completed=======
average loss in epoch: 7.652484150003233
7318it [13:50,  8.81it/s]
=======Epoch 1 Completed=======
average loss in epoch: 7.468071281339797
7318it [13:50,  8.81it/s]
=======Epoch 2 Completed=======
average loss in epoch: 7.4392927800674
7318it [13:49,  8.82it/s]
=======Epoch 3 Completed=======
average loss in epoch: 7.4234244145145505

Loss in this context is an abstract concept that generally says how wrong the model is, but there’s no ideal loss value. In this environment, as long as the number’s going down that means the model is becoming less wrong, implying the model is learning to understand text. If we had a bigger model, more data, and trained over a longer period of time, it’s likely that this number would keep going down over many epochs, allowing the model to get really, really good at understanding text.

Fine Tuning

So, we now have a BERT model that has some understanding of text. Let’s use it to do something.

The amazon_polarity dataset is an open dataset of Amazon product reviews. It consists of a big batch of review titles, review content, and labels saying whether each review is positive or negative.

"""Loading the dataset and printing out an example
"""
fine_tune_ds = load_dataset("fancyzhx/amazon_polarity")
for elem in fine_tune_ds['train']:
    print(elem)
    break
{'label': 1, 'title': 'Stuning even for the non-gamer', 'content': 'This sound track was beautiful! It paints the senery in your mind so well I would recomend it even to people who hate vid. game music! I have played the game Chrono Cross but out of all of the games I have ever played it has the best music! It backs away from crude keyboarding and takes a fresher step with grate guitars and soulful orchestras. It would impress anyone who cares to listen! ^_^'}

We’re going to use this data to fine tune our BERT model to predict if the review is positive or negative.

First we need to turn this data into data that makes sense in a BERT model. The exact approach for this process can vary from task to task. Luckily for us this dataset consists of pairs of sentences (the title and the content), so we can format the fine-tuning data just like we formatted the pre-training data previously. See the previous section "Defining Training Batches" for a breakdown of this approach.

"""Turning the data from the amazon dataset into something compatible with BERT
"""
def preprocess_data(data, max_num = 100000):
    data_tokens = []
    data_positional = []
    data_targets = []

    #unpacking data
    for i, elem in enumerate(data):

        #tokenizing the title and content
        sentence1 = elem['title']
        sentence2 = elem['content']
        tokens = tokenizer([sentence1, sentence2])
        sentence1_tokens = tokens['input_ids'][0]
        sentence2_tokens = tokens['input_ids'][1]

        # Trimming down tokens so the pair fits in max_input_length (101 and 102 are the [CLS] and [SEP] token ids)
        if len(sentence1_tokens) + len(sentence2_tokens) > max_input_length:
            sentence1_tokens = [101] + sentence1_tokens[-int(max_input_length / 2) + 1:]
            sentence2_tokens = sentence2_tokens[:int(max_input_length / 2) - 1] + [102]

        # Creating sentence tokens
        sentence_tokens = [0] * len(sentence1_tokens) + [1] * len(sentence2_tokens)

        # Combining and padding
        pad_num = max_input_length - (len(sentence1_tokens) + len(sentence2_tokens))
        sequence_tokens = sentence1_tokens + sentence2_tokens + [0] * pad_num
        sentence_location_tokens = sentence_tokens + [1] * pad_num

        data_tokens.append(sequence_tokens)
        data_positional.append(sentence_location_tokens)
        data_targets.append(elem['label'])

        if i > max_num: break

    return torch.tensor(data_positional), torch.tensor(data_tokens), torch.tensor(data_targets)

#processing data into modeling data
train_pos, train_tok, train_targ = preprocess_data(fine_tune_ds['train'])
test_pos, test_tok, test_targ = preprocess_data(fine_tune_ds['test'])

#moving training to device
train_pos = train_pos.to(device)
train_tok = train_tok.to(device)
train_targ = train_targ.to(device)

#moving testing to device
test_pos = test_pos.to(device)
test_tok = test_tok.to(device)
test_targ = test_targ.to(device)

Before we fine tune, let's replace the classifier with a randomly initialized head. This allows us to preserve BERT's general language understanding, but start fresh in terms of the part of the model that's doing the classification, which is good to do because we're now classifying something completely different.

# Replacing classification head with a new head
# the new training objective is still binary classification,
# except these parameters will be used to decide if a
# review was positive or negative
model.classifier = nn.Linear(d_model, 1, bias=False).to(device)

# resetting the optimizer to have access to the parameters of the new head
optimizer = optim.Adam(model.parameters(), lr=0.001)
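One optional variant (not what I did here): because the encoder already has some general language understanding baked in, you could freeze the pre-trained parameters and only train the new head, which is a common way to fine tune quickly on small datasets. A sketch:

# freezing everything that was pre-trained
for param in model.parameters():
    param.requires_grad = False

# swapping in a fresh classification head (its parameters are trainable by default)
model.classifier = nn.Linear(d_model, 1, bias=False).to(device)

# only handing the optimizer the parameters that still require gradients
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=0.001)

In this article, though, we'll keep updating the whole model.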

Then we can run the same training code as before, except on the fine-tuning dataset and with the pre-trained model sporting a new classification head. Here we don't care about the masked language modeling objective, so I'm passing the original tokens into the model rather than the masked ones. If you wanted to do this properly you would artificially create a mask of all zeros, but I didn't feel like it.
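If you did want to do it properly, that all-zeros mask is a one-liner (a sketch to be dropped into the loop below in place of the mask_batch call):

# no tokens are masked, so the masked language modeling branch selects nothing
masked_token_locations = torch.zeros_like(train_tok_batch)

With that caveat out of the way, here's the fine tuning loop: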

ft_losses = [[]]

for epoch in range(5):
    for i in tqdm(range(0, train_pos.shape[0], batch_size)):

        if i+batch_size>=train_pos.shape[0]:
            break

        #getting batch
        train_pos_batch = train_pos[i:i+batch_size]
        train_tok_batch = train_tok[i:i+batch_size]
        train_targ_batch = train_targ[i:i+batch_size]

        # Zeroing out gradients from last iteration
        optimizer.zero_grad()

        # Masking the tokens in the input sequence
        masked_tokens, masked_token_locations = mask_batch(train_tok_batch)

        # Generating class and masked token predictions
        clsf_logits, token_logits = model(train_tok_batch, train_pos_batch, masked_token_locations)

        # (no masked-token targets needed here; we're only optimizing the classification loss)

        # Calculating loss for next sentence classification
        loss_clsf = classification_criterion(clsf_logits.squeeze(), train_targ_batch.float())

        # Combining losses
        loss = loss_clsf

        ft_losses[-1].append(float(loss))

        # Backpropagation
        loss.backward()
        optimizer.step()

    print(f'=======Epoch {epoch} Completed=======')
    print(f'average loss in epoch: {np.mean(ft_losses[-1])}')
    ft_losses.append([])

Actually, when I ran this model I forgot to change any of this code, so I was still optimizing on masked language modeling as well as on the classification of positive or negative reviews. I'm sure you could experiment with that strategy; there might be some merit to getting the model to better understand the type of text used in reviews specifically.

Anyway, the model got better:

100%|█████████▉| 781/782 [01:13<00:00, 10.60it/s]
=======Epoch 0 Completed=======
average loss in epoch: 5.904687740433384
100%|█████████▉| 781/782 [01:13<00:00, 10.65it/s]
=======Epoch 1 Completed=======
average loss in epoch: 5.426271478894731
100%|█████████▉| 781/782 [01:13<00:00, 10.57it/s]
=======Epoch 2 Completed=======
average loss in epoch: 5.234847757687082
100%|█████████▉| 781/782 [01:13<00:00, 10.60it/s]
=======Epoch 3 Completed=======
average loss in epoch: 5.128661873245972
100%|█████████▉| 781/782 [01:13<00:00, 10.61it/s]
=======Epoch 4 Completed=======
average loss in epoch: 5.058861758614319

This dataset has a test set, so we can apply this fine tuned BERT model to see how good it is at classifying reviews it’s never seen before.

is_correct = []
predicted_class = []
original_class = []

for i in tqdm(range(0, test_pos.shape[0], batch_size)):

    if i+batch_size>=test_pos.shape[0]:
        break

    #getting batch
    test_pos_batch = test_pos[i:i+batch_size]
    test_tok_batch = test_tok[i:i+batch_size]
    test_targ_batch = test_targ[i:i+batch_size]

    #making prediction, not masking anything (the all-zeros mask selects no tokens, and needs to live on the same device as the model)
    clsf_logits, _ = model(test_tok_batch, test_pos_batch, torch.zeros(test_pos_batch.shape).to(device))

    #converting logits to probabilities then rounding to classifications
    res = torch.sigmoid(clsf_logits).round().squeeze()

    #keeping track of the original class (positive or negative) and if the model was correct
    original_class.extend(np.array(test_targ_batch.to('cpu')))
    is_correct.extend(np.array((res == test_targ_batch).to('cpu')))
    predicted_class.extend(np.array(res.detach().to('cpu')))

#accuracy
sum(list(is_correct))/len(is_correct)
0.7686259603072984

So, we got a model that could classify if a review was positive or negative with a 77% accuracy. That might not sound that impressive, but the BERT model used in this example is virtually microscopic. If you used more encoder blocks, a larger model dimension, and played around with a few other model parameters I think you could easily pass 90%. In fact, if you’ve gotten this far I think you should give that a shot.
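If you do give it a shot, the knobs are the constants we defined along the way; something like the following (untested numbers, just a starting point) would get you a substantially bigger model, at the cost of much longer training times:

# a hypothetical, bigger configuration -- re-instantiate and re-train BERT() after changing these
d_model = 512          # wider word vectors
n_heads = 8            # more attention heads per encoder block
query_key_dim = 64     # per-head query/key size
value_dim = 64         # per-head value size
d_ff = 4 * d_model     # feed forward expansion
n_layers = 4           # more encoder blocks stacked on top of one another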

Conclusion

Another big article, I hope you enjoyed it!

In this article we briefly reviewed the predecessor to BERT, the transformer, then explored how the encoder component of a transformer can be used to understand text, and how BERT uses masked language modeling and next sentence prediction to encourage textual understanding in the pre-training process. We then discussed how that model can be fine-tuned, usually after swapping a projection head within the model, to apply that general textual understanding to specific problems.

After we covered the theory, we actually created a BERT style model. We explored tokenization, data processing, embedding, multi-headed self-attention, pointwise feed forward, pre-training, and fine tuning. By the end of that process we had created a BERT style model, trained it on Wiki articles to understand text, then fine-tuned it to classify if product reviews were positive or negative.

Join Intuitively and Exhaustively Explained

At IAEE you can find:

  • Long form content, like the article you just read
  • Thought pieces, based on my experience as a data scientist, engineering director, and entrepreneur
  • A discord community focused on learning AI
  • Regular Lectures and office hours
Join IAEE

Related Articles