
Explaining the Attention Mechanism

Building a Transformer from scratch to create a simple generative model


The Transformer architecture has revolutionized the field of AI and forms the basis not only for ChatGPT, but has also led to unprecedented performance in image recognition, scene understanding, and robotics. Unfortunately, the Transformer architecture itself is quite complex, making it hard to spot what really matters, in particular if you are new to machine learning. The best way to understand Transformers is to think about a problem as simple as generating random names, character by character. In a previous article, I explained all the tooling that you will need for such a model, including training models in PyTorch and batch processing, by focussing on the simplest possible model: predicting the next character based on its frequency given the previous character in a dataset of common names.

Gradient Descent and Batch-Processing for Generative Models in PyTorch

In this article, we build on this baseline to introduce a state-of-the-art model, the Transformer. We will start with basic code to read and pre-process the data, then introduce the attention architecture by focussing on its key aspect first – cosine similarity between all tokens in a sequence. We will then add query, key, and value to build a Transformer Encoder, and from there develop the Decoder, the core of all generative models. Finally, we will train the Transformer to generate random names that sound much more reasonable than those generated by a bigram model, providing you with a framework that lets you study the impact of the different components as well as replace them with built-in routines.

The Data

The code snippet that follows contains all the relevant pieces from the previous article that you need to move forward with the examples here. I recommend opening a Google Colab notebook on the side, copying and pasting the code there, and poking around with the data structures until you are sure you understand what they do. If you find this challenging, work through the previous article first.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

torch.manual_seed(42)

# Download names file
!wget https://raw.githubusercontent.com/hackerb9/ssa-baby-names/refs/heads/main/allnames.txt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Load names
with open("allnames.txt", "r") as f:
    names = f.read().splitlines()
    names = [name.lower() for name in names]

# Define alphabet and mappings
alphabet = [' '] + sorted(list(set(''.join(names)))) + ['.']
itoc = {i: c for i, c in enumerate(alphabet)}
ctoi = {c: i for i, c in enumerate(alphabet)}

encode = lambda name : [ctoi[c] for c in name]
decode = lambda tokens : ''.join([itoc[i] for i in tokens])

# Create training and validation set
n=int(0.9*len(names))
train_data, val_data = random_split(names, [n, len(names)-n])

class NameDataset(Dataset):
    def __init__(self, names):
        self.names = names
        self.ctoi = ctoi
        self.alphabet_size = len(alphabet)

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        x = [self.ctoi[c] for c in name]  # Convert characters to indices
        y = x[1:] + [self.ctoi[' ']]  # The next character to predict (shifted version of x)
        x = torch.tensor(x).to(device)
        y = torch.tensor(y).to(device)
        return x, y  

# Define a function to pad sequences
def pad_sequences(batch):
    max_len = max([len(x) for x, _ in batch])  # Find the max length in the batch
    padded_x = []
    padded_y = []

    for x, y in batch:
        padded_x.append(F.pad(x, (0, max_len - len(x)), "constant", ctoi[' ']))  # Pad x
        padded_y.append(F.pad(y, (0, max_len - len(x)), "constant", ctoi['.']))  # Pad y

    # Stack the padded sequences to create the batch
    return torch.stack(padded_x), torch.stack(padded_y)

The code downloads the data from the internet using wget, loads it into a list names, extracts all the characters into a list alphabet, and defines dictionaries to turn characters into indices and back (ctoi and itoc). For this article, I have added helper functions encode and decode that turn a name into a list of indices and back, as well as a validation set that uses 10% of the data, which will allow us to compute validation loss on previously unseen data. Finally, we define a padding function (pad_sequences) that the DataLoader uses to return random batches of names, all padded to the same length. Here, ' ' (space) is used to pad the names, and '.' is used to pad the targets. For this article, we don't use one-hot encoding, but return a tensor of indices. You can test this framework using the following code:

torch.manual_seed(42)

train_dataset = NameDataset(train_data)
val_dataset = NameDataset(val_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=pad_sequences)

name = next(iter(train_loader)) # Tuple of (x, target)
print(decode(name[0].tolist()[0]))
print(name[0][0]) # grab the 0th name
print(name[1][0]) # grab the 0th target
wilona   
tensor([23,  9, 12, 15, 14,  1,  0,  0,  0], device='cuda:0')
tensor([ 9, 12, 15, 14,  1,  0, 27, 27, 27], device='cuda:0')

We are drawing a batch of 32 names (batch_size=32) and print the first name, its encoding, and its target. The result should always be the same due to using torch.manual_seed(42). DataLoader is a Python iterable from which we can draw using next(). Note the need for name[0], as the data loader returns tuples of names and targets. As the data comes in batches, we need a second [0] to access the first name within the batch. As we have drawn 32 names, at least one of them is longer than "wilona", resulting in three characters of padding.

Now that we know what the data we are dealing with looks like, we are ready to delve into self-attention.

Self-Attention

Self-attention is a mechanism where each element in a sequence computes a weighted representation of the entire sequence, based on the similarity of its features to those of the others. This sounds promising for our problem: looking back at a single character does not provide enough information to randomly generate meaningful words. Instead, we would like a way to encode information about an entire sequence at once and use that to guess the next character.

"wilon" -> "a"

Self-attention does this in a very efficient way by computing the dot-product – closely related to the cosine similarity – between each token and every other token. The dot-product measures how well two vectors a and b align. For unit-length vectors, it equals the cosine of the angle between them: if they point in the same direction, i.e. the angle between them is zero, the dot product is 1. If they are orthogonal, i.e. the angle between them is 90 degrees, the dot product is 0. To compute it, we simply multiply individual entries and add them up:
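
a · b = a₁b₁ + a₂b₂ + … + aₙbₙ = |a| |b| cos(θ)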

So far, we have been using One-Hot encoding to encode our sequences. Let’s do that for the name "laika", turning indices in our alphabet into vectors:

n_embd=len(alphabet)

x=torch.tensor(encode('laika')).unsqueeze(0) # to add batch dimension

xenc=F.one_hot(x, num_classes=n_embd).float()
print(xenc)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

We can see that the second and the last row are the same, as they both represent the character "a". We can now look at what happens when taking the dot-product between this encoding and itself:

print(xenc @ xenc.transpose(-2,-1)) # (5x28) * (28x5) -> (5x5)
tensor([[[1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 1.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 1., 0., 0., 1.]]])

The transpose(dim0,dim1) function expects two dimensions that it will swap. If you are wondering why a matrix multiplication corresponds to computing the dot-product between all combinations of rows of the input encoding and its transpose, it might make sense to compute a couple of entries by hand and compare the result with the equation above.
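
As a minimal check using the xenc tensor from above, we can compute two of these entries directly:

# Dot product between the 2nd and 5th character ('a' and 'a'): identical one-hot vectors
print(torch.dot(xenc[0, 1], xenc[0, 4]))  # tensor(1.)
# Dot product between the 1st and 2nd character ('l' and 'a'): different characters
print(torch.dot(xenc[0, 0], xenc[0, 1]))  # tensor(0.)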

Once you are convinced that each entry of the above matrix is the dot product between all characters in "laika" with each other – a 5×5 matrix – let's look at the result. First, we observe that the matrix has a diagonal of all ones. This makes sense, as every character is most similar to itself. We also see two additional ones at (1,4) and (4,1), as the second and the fifth character in "laika" are identical.

Artistic rendering of the transformer architecture and Laika, the first dog in space, via Stable Diffusion with ChatGPT prompt. Image: own.

Multiplying the resulting self-attention matrix with the input encoding once more yields the original sequence again, this time with an emphasis on the letter "a":

print((xenc @ xenc.transpose(-2,-1)) @ xenc) # (1x5x5) x (1x5x28) -> (1x5x28)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]])

We can better see where this is going by adding the next ingredient of self-attention: turning the similarity matrix into probabilities using SoftMax:

(xenc @ xenc.transpose(-2,-1)).softmax(dim=-1)
tensor([[[0.4046, 0.1488, 0.1488, 0.1488, 0.1488],
         [0.1185, 0.3222, 0.1185, 0.1185, 0.3222],
         [0.1488, 0.1488, 0.4046, 0.1488, 0.1488],
         [0.1488, 0.1488, 0.1488, 0.4046, 0.1488],
         [0.1185, 0.3222, 0.1185, 0.1185, 0.3222]]])
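
As a sanity check, we can reproduce the values in the first row by hand; it contains a single 1 and four 0s:

import math

# softmax of [1, 0, 0, 0, 0]
print(math.exp(1) / (math.exp(1) + 4 * math.exp(0)))  # 0.4046...
print(math.exp(0) / (math.exp(1) + 4 * math.exp(0)))  # 0.1488...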

Due to the exponentiation and normalization that SoftMax does, we maintain the same information, but arrive at non-zero entries for all the other character combinations (as exp(0) is 1). Let’s turn this matrix back into a sequence by multiplying with xenc one more time:

attn = (xenc @ xenc.transpose(-2,-1)).softmax(dim=-1) @ xenc
print(attn)

tensor([[[0.0000, 0.2977, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1488, 0.0000, 0.1488, 0.4046, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.6444, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1185, 0.0000, 0.1185, 0.1185, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.2977, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.4046, 0.0000, 0.1488, 0.1488, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.2977, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1488, 0.0000, 0.4046, 0.1488, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.6444, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.1185, 0.0000, 0.1185, 0.1185, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000]]])

We can see that this sequence maintains the original meaning. "l" is still the highest probability in the first character (the first row), and the second and last row are still the same. Indeed, applying .argmax(dim=-1) to select the highest probability and decoding yields again "laika".

attn.argmax(dim=-1)
tensor([[12,  1,  9, 11,  1]])

We have learned something important, however: the first character also has a high likelihood of being an "a" (0.2977), and even an "i" or a "k" with 14.88% probability.

The self-attention mechanism is therefore turning a sequence into a representation that learns something about its overall statistics!

More excitingly, every single token in the sequence carries a version of this information. Specifically, every token now also carries probabilities that serve as the statistics for the entire sequence. While a token in this example summarizes an entire word, tokens that represent image patches will summarize the entire image. This is why the attention mechanism is so incredibly versatile and powerful. It might also dawn on you how we could possibly sample from such a model, e.g. to generate words consisting of "l", "a", "i", and "k", but the model does not learn anything yet. Let’s change that.

Query, Key and Value

So far, we have used xenc, the one-hot encoded sequence representing our word, three times. Its dimension was BxTxC, where B=1 is the batch size, T=5 the number of tokens, and C=28 the dimension of the one-hot encoding: the 26-character alphabet plus "." and " " (space).
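
A quick shape check confirms these dimensions:

print(xenc.shape)  # torch.Size([1, 5, 28]): B=1, T=5, C=28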

In order to learn not only about the statistics of the individual words, but about arbitrary relationships that contribute to our loss function – penalizing random words over reasonable baby names – we can run xenc through three different linear projections:

B, T, C = xenc.shape
dk = C

query = nn.Linear(C, dk, bias=False)
key = nn.Linear(C, dk, bias=False) 
value = nn.Linear(C, dk, bias=False) 

Q = query(xenc) # B x T x dk
K = key(xenc) # B x T x dk
V = value(xenc) # B x T x dk

We have first defined three linear layers without bias and then run xenc through each of these layers. The results are customarily called query, key, and value. As there is no bias, each linear layer corresponds to multiplication with a matrix of dimension (C x dk), whose weights the model will be able to learn. The original paper provides the following equation for the self-attention mechanism:
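
Attention(Q, K, V) = softmax(QKᵀ / √dk) V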

Although you will find many explanations that try to attach meaning to the terms query, key, and value, I find them neither very compelling nor helpful.

We have already implemented the structure of this equation in the examples above, except for the normalization by the square root of the dimension of the linear layer, dk. This scaling keeps the variance of the dot products close to one, which prevents the softmax from saturating. In code:

attn = ((Q @ K.transpose(-2,-1))/(dk**0.5)).softmax(dim=-1) @ V

It is meaningless to print out the result of this operation as the linear layers have been initialized with random values, but you can decode the output

decode(attn.argmax(dim=-1)[0].tolist())

and will retrieve a random string. We can also sample from the output distribution to retrieve random strings:

for _ in range(10):  
  attn_probs = attn.softmax(dim=-1)  # Apply softmax to get probabilities over the vocabulary
  sampled_indices = torch.multinomial(attn_probs.view(-1, attn_probs.size(-1)), 1)
  print(decode(sampled_indices.T[0].tolist()))

.rwfg
llkrz
shwhc
hmrm 
xhhmc
ihsap
hwzmt
piaf.
kmqqh
owlhc

It might be tempting to actually train such a simple model with our 100k+ names, all padded to a length of 15 characters, but self-attention is only one aspect of a Transformer Encoder, an architecture that works very well in practice. So let's dive right into it.

Transformer Encoder

The sketch below shows the complete architecture of a Transformer Encoder from the paper "Attention is All You Need". So far, we have only described the orange Attention block. As "Input Embedding", we have chosen One-Hot Encoding. Let’s go step by step through what is missing:

Transformer Encoder block. From: Vaswani, A. "Attention is all you need." Advances in Neural Information Processing Systems (2017).

Multi-Head Attention

The linear projections in the attention head – query, key, and value – project the input dimension from C to dk. In our example above, dk has been equal to C (28 in our one-hot encoded example). In order to parallelize this task, it is possible to split dk into multiple equal parts, such as two heads of dk=14 or four heads of dk=7. This results in two or four (in this example) triplets of distinct query, key, and value projections that can each focus on different aspects of the data. When analyzing text, one head might learn syntactic relationships (e.g., subject-verb agreement), while another head might capture semantic relationships (e.g., synonyms or antonyms). By combining information from multiple perspectives, the model develops a more nuanced understanding of the context. Here is the complete code for a Multi-Head Attention class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads"
        self.num_heads = num_heads
        self.dk = embed_dim // num_heads

        # Linear layers for query, key, and value (in the case of cross-attention, separate inputs are used)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, q, k, v):
        B, T, C = q.shape  # Assuming q, k, v have the same shape (B: batch size, T: sequence length, C: embedding dim)

        # Project Q, K, V using their respective linear layers
        q = self.q_proj(q)  # Shape: (B, T, C)
        k = self.k_proj(k)  # Shape: (B, T, C)
        v = self.v_proj(v)  # Shape: (B, T, C)

        # Reshape into (B, num_heads, T, dk)
        q = q.view(B, T, self.num_heads, self.dk).transpose(1, 2)  # (B, heads, T, dk)
        k = k.view(B, T, self.num_heads, self.dk).transpose(1, 2)  # (B, heads, T, dk)
        v = v.view(B, T, self.num_heads, self.dk).transpose(1, 2)  # (B, heads, T, dk)

        # Scaled dot-product attention
        attn_weights = (q @ k.transpose(-2, -1)) / (self.dk ** 0.5)  # (B, heads, T, T)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = attn_weights @ v  # (B, heads, T, dk)

        # Combine heads back to (B, T, C)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)

        # Final linear projection
        return self.out_proj(attn_output)

m = MultiHeadAttention(28,4)
attn = m(xenc,xenc,xenc)
attn.shape

You will notice a final projection (out_proj) at the end, which allows the model to weigh the output from the different heads. This addition is already described in the original "Attention is all you need" paper, but is not explicitly shown in the sketch.

The last three lines instantiate the model and run it on xenc, showing that the shape remains as expected (torch.Size([1, 5, 28])). Notice that the query, key, and value parameters all receive the same value, xenc. Keeping these inputs separate is customary in most implementations, as it allows feeding the key and value from elsewhere. This is known as cross-attention and will become relevant later, when we discuss the Transformer Decoder. Right now, we are only learning about the encoder.

Add and Norm

This layer does two things:

  1. It passes the original input through (a residual connection) and adds it to the output of the self-attention block. This is particularly helpful once multiple encoder blocks are stacked on top of each other (notice the "Nx", the number of stacked blocks in the figure above), as it allows the gradients to pass through directly.
  2. It normalizes the layer output to have zero mean and variance one. This is also particularly important when stacking multiple blocks on top of each other or when feeding the encoder output to another part of the neural network that expects its input to have zero mean and variance one.

Implementation is straightforward:

# Add and Norm
residual = xenc
attention_output = self_attention + residual  # self_attention: output of the multi-head attention block
attention_output = F.layer_norm(attention_output, normalized_shape=attention_output.shape[-1:])

Here self_attention is the output from the multi-head attention block. The functional F.layer_norm only normalizes the values to have zero mean and unit variance; the module version, nn.LayerNorm, additionally applies learnable scaling and shifting parameters and is the preferred choice when implementing layer norm in a neural network, which we will do below.
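
As a quick check, a minimal sketch that applies nn.LayerNorm to the attn tensor from the multi-head attention example above:

ln = nn.LayerNorm(28)
out = ln(attn)           # attn has shape (1, 5, 28)
print(out.mean(dim=-1))  # close to 0 for every token
print(out.std(dim=-1))   # close to 1 for every token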

Feed Forward

The final step in the transformer encoder that follows self-attention is a feed-forward layer. It usually combines two fully connected neural network layers with ReLU activation (rectified linear unit, max(x,0)). In practice, self-attention and feed forward are combined in a TransformerEncoderLayer class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=4*28, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()

        # Multi-Head Attention
        self.self_attention = MultiHeadAttention(d_model, nhead)

        # Feedforward layer
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),  # First fully connected layer
            nn.ReLU(),                          # Non-linearity
            nn.Linear(dim_feedforward, d_model)  # Second fully connected layer
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # Self-attention block
        attn_output = self.self_attention(src, src, src)
        src = self.norm1(src + attn_output)  # Add & Norm

        # Feedforward block
        ff_output = self.feedforward(src)
        src = self.norm2(src + self.dropout(ff_output))  # Add & Norm

        return src

encoder_layer = TransformerEncoderLayer(28, 4)
output = encoder_layer(xenc)
output.shape

The last three lines again instantiate this new class and test it with the encoding of "laika". Notice that we are now using nn.LayerNorm() instead of F.layer_norm(). The TransformerEncoderLayer class now implements everything that is shown in the Encoder module further above. In practice, dropout is also applied inside the self-attention block, but this has been omitted here in the interest of a minimalist implementation.

As you can see, there are quite a lot of moving pieces and code to wire up the Transformer. This does not change that most of the action is in the self-attention mechanism, which computes the cosine similarity between all token pairings.

Positional Encoding

So far, self-attention compares every token in a sequence with every other token, but there is no notion of position. Two characters at either end of a sequence are treated the same way as characters next to each other. Looking at the matrix representation might give you a different impression, as the matrix seems to maintain spatiality, but this is not the case: a sequence and its rearrangement, like "stressed" and "desserts", produce the same set of token representations, just in a different order. In that sense, the self-attention block is even worse than the bigram model, which at least has a notion of one character explicitly following another. This is addressed by positional encoding, which simply adds a position-dependent function such as sin() and cos() to the input encoding:

import torch
import math

class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, max_len=16):
        super(PositionalEncoding, self).__init__()

        # Create a long enough "position" tensor
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(math.log(10000.0) / embed_dim))  # (embed_dim / 2)

        # Apply the sine and cosine functions
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)  # Apply sine to even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Apply cosine to odd indices

        # Register the positional encoding as a buffer (no gradient updates)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: Tensor of shape (batch_size, seq_len, embed_dim)
        return x + self.pe[:x.size(1)]  # Add the positional encoding to the input tensor

m = PositionalEncoding(28)
m.forward(xenc).shape

Here, div_term is a scaling factor applied to the positions, controlling how fast the sine and cosine functions oscillate. Notice that the positional encoding is added across the entire dimension of the embedding (28 in our running example), and its frequency changes based on the index i in the embedding dimension, as can be seen in the equations used in the "Attention is all you need" paper:
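
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))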

The positional embedding thus uses two functions, sine and cosine, for even and odd indices respectively.

Albeit simple, positional encoding makes a big difference. It is best understood by considering word tokens instead of character tokens. Consider the sentence:

Only John saw the movie -> 1, 2, 3, 4, 5

For simplicity, we assume the vocabulary size to be only five, leading to a straightforward encoding. Let’s now change the meaning of this sentence quite substantially by rearranging the words:

John saw the movie only -> 2, 3, 4, 5, 1

Let's now add a simple positional encoding: 1, 2, 3, 4, 5. Adding the positional encoding to the text encoding, we obtain:

Only John saw the movie -> 2, 4, 6, 8, 10
John saw the movie only -> 3, 5, 7, 9, 6

In this case, the encodings of most words remain quite similar, e.g. "John" is encoded as 4 in one sentence and 3 in the other, except for the word "only", which is encoded once as 2 and once as 6. The Transformer is therefore able to treat the word "only" very differently based on its position in the sequence, while the encoding of words that can appear anywhere in a sentence averages out with enough data.

In practice, using sine and cosine functions as above might be overkill, and simply adding a learnable embedding will lead to better learning and more compact code:

    wpe = nn.Embedding(max_length, d_model)  # one learnable vector per position, up to max_length

    pos = torch.arange(0, T, dtype=torch.long).unsqueeze(0) # shape (1, T)
    wpe(pos)  # shape (1, T, d_model)

In either case, the positional embedding is simply added to the token embedding, token by token.

Transformer Decoder

The transformer encoder is perfect to summarize a sequence of tokens. Due to the self-attention mechanism in which all tokens are exchanging information with each other, a single token contains all pertinent information of a sequence.

This information can then be further processed by the final feed-forward network. By choosing the output dimension of the feed-forward layer, this information can be used for sentiment analysis (a simple true-false value), to create a summary token for further processing, or to produce a new sequence. For actually generating new data, such as random names, we need a variation known as a decoder.

Here is a schematic of the Transformer Decoder block from the "Attention is all you need" paper:

The Transformer Decoder block. From: Vaswani, A. "Attention is all you need." Advances in Neural Information Processing Systems (2017).

It looks very similar to the encoder, but is using two multi-head attention blocks.

The key idea is that the Transformer Decoder can generate an endless stream of tokens based on the history of previous tokens up to a maximum length often referred to as context or block size.

Here the notion "Outputs (shifted right)" in the schematic might be a little confusing. Let’s assume we want to train the model the name "laika". As when training the bigram model, we are using data x and target y:

x: 'l', 'a', 'i', 'k', 'a'
y: 'a', 'i', 'k', 'a', ' '

When training, we present the sequence x ("laika") and compute the loss with respect to the probability of the Transformer generating y ("aika "), that is, the probability for every token to generate the next one. During inference, we feed the output back to the input of the Transformer and use the output for the last token. That is, when feeding "laik", the logits associated with the last token should predict "a". Then, when feeding "laika", the next prediction should be " ". I'm sure something here can be interpreted as "shifted right", but I don't find the term very helpful.
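
This shift is exactly what the NameDataset class from the beginning of the article produces; a quick check:

x, y = NameDataset(["laika"])[0]
print(decode(x.tolist()))  # 'laika'
print(decode(y.tolist()))  # 'aika '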

In the encoder, we implemented what is known as full self-attention: every token is compared with every other token. When generating new tokens, however, we are only interested in looking backward. This can be accomplished by simply masking all forward-looking tokens – the entries in the upper right triangle of the cosine similarity matrix – leading to masked multi-head attention:

B, T, C = xenc.shape

print(f"Unmasked attention:n {(xenc @ xenc.transpose(-2,-1)).softmax(dim=-1)}")

wei = xenc @ xenc.transpose(-2,-1) 
wei = wei.masked_fill(torch.tril(torch.ones(T,T)) == 0, float('-inf')) 
wei = F.softmax(wei, dim=-1) 

print(f"Masked attention:n {wei}")
Unmasked attention:
 tensor([[[0.4046, 0.1488, 0.1488, 0.1488, 0.1488],
         [0.1185, 0.3222, 0.1185, 0.1185, 0.3222],
         [0.1488, 0.1488, 0.4046, 0.1488, 0.1488],
         [0.1488, 0.1488, 0.1488, 0.4046, 0.1488],
         [0.1185, 0.3222, 0.1185, 0.1185, 0.3222]]])
Masked attention:
 tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2689, 0.7311, 0.0000, 0.0000, 0.0000],
         [0.2119, 0.2119, 0.5761, 0.0000, 0.0000],
         [0.1749, 0.1749, 0.1749, 0.4754, 0.0000],
         [0.1185, 0.3222, 0.1185, 0.1185, 0.3222]]])

All the action happens in a single line of code that uses masked_fill() to fill the cosine similarity matrix with float('-inf') wherever a lower-triangular matrix of ones (tril) is zero. Conveniently, the SoftMax operation that follows turns these entries into zero, removing their information from the attention matrix.

The output compares unmasked attention (as above) with masked attention, showing all zeros in the upper triangle of the output matrix. Where full attention compared tokens both ways – the first "a" with the second and vice versa – the attention matrix now only looks backwards: the first character only attends to itself, the second character only to the first, the third character only to the first and the second, and so on.

The masked multi-head attention block is followed by a cross-attention block. This block is also labeled "multi-head attention" in the drawing, but you can see that the key and value encodings come from elsewhere, while the query has been computed in the masked self-attention block. This block is optional and allows altering the stream of generated tokens. Typical examples are translation problems. For example, the Transformer Decoder might have learned to generate arbitrary English text from training on internet-scale data, whereas an Encoder encodes French text to trigger the Decoder to produce the appropriate translation. In this case, one would first train the Decoder on parroting random English sentences, and later train Encoder and Decoder together. Another application is question answering in ChatGPT: the decoder can generate arbitrary text, whereas the encoder encodes the user prompt to trigger the correct response.

This is how the two pieces work together:

Transformer Encoder (left) and Decoder (right) working together. From: Vaswani, A. "Attention is all you need." Advances in Neural Information Processing Systems (2017).

In this article, we will only focus on the decoder part and train a model that outperforms a simple bigram model for name generation.

Putting things together

In order to train a Transformer to generate random names, we already have all the components, but we need to add the mask to the Multi-Head Attention block:

class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads"
        self.num_heads = num_heads
        self.dk = embed_dim // num_heads

        # Linear layers for query, key, and value (in the case of cross-attention, separate inputs are used)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, q, k, v):
        B, T, C = q.shape  # Assuming q, k, v have the same shape (B: batch size, T: sequence length, C: embedding dim)

        # Project Q, K, V using their respective linear layers
        q = self.q_proj(q)  # Shape: (B, T, C)
        k = self.k_proj(k)  # Shape: (B, T, C)
        v = self.v_proj(v)  # Shape: (B, T, C)

        # Reshape into (B, num_heads, T, dk)
        q = q.view(B, T, self.num_heads, self.dk).transpose(1, 2)  # (B, heads, T, dk)
        k = k.view(B, T, self.num_heads, self.dk).transpose(1, 2)  # (B, heads, T, dk)
        v = v.view(B, T, self.num_heads, self.dk).transpose(1, 2)  # (B, heads, T, dk)

        # Scaled dot-product attention with mask
        attn_weights = (q @ k.transpose(-2, -1)) / (self.dk ** 0.5)  # (B, heads, T, T)
        attn_weights = attn_weights.masked_fill(torch.tril(torch.ones(T, T, device=q.device)) == 0, float('-inf'))  # mask forward-looking positions
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = attn_weights @ v  # (B, heads, T, dk)

        # Combine heads back to (B, T, C)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)

        # Final linear projection
        return self.out_proj(attn_output)

m = MaskedMultiHeadAttention(28,4)
attn = m(xenc,xenc,xenc)
attn.shape

In addition to changing the class name, we added only a single line in between the existing steps that compute the scaled dot-product attention. We also need to slightly change the TransformerEncoderLayer to create a TransformerDecoderLayer that uses the masked self-attention.

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=4*32, dropout=0.1):  # feed-forward width: 4x the model dimension used below
        super(TransformerDecoderLayer, self).__init__()

        # Masked Multi-Head Attention
        self.self_attention = MaskedMultiHeadAttention(d_model, nhead)

        # Feedforward layer
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),  # First fully connected layer
            nn.ReLU(),                          # Non-linearity
            nn.Linear(dim_feedforward, d_model)  # Second fully connected layer
        )

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # Self-attention block
        attn_output = self.self_attention(src, src, src)
        src = self.norm1(src + attn_output)  # Add & Norm

        # Feedforward block
        ff_output = self.feedforward(src)
        src = self.norm2(src + self.dropout(ff_output))  # Add & Norm

        return src

decoder_layer = TransformerDecoderLayer(28, 4)
output = decoder_layer(xenc)
output.shape

For simplicity, this decoder does not contain the cross-attention block. It’s only now that we can put everything together into a model that we can train to generate random names:

class RandomNameGenerator(nn.Module):
  def __init__(self, d_model, nhead, nlayers, max_length):
    super().__init__()

    self.d_model = d_model
    self.nhead = nhead
    self.embed = nn.Embedding(len(alphabet), d_model)
#    self.pe = PositionalEncoding(d_model)
    self.wpe = nn.Embedding(max_length,d_model)
    self.decoder = nn.ModuleList([TransformerDecoderLayer(d_model, nhead) for _ in range(nlayers)])

    self.linear = nn.Linear(d_model, len(alphabet))
    self.softmax = nn.Softmax(dim=-1)

  def forward(self, x):
    B, T = x.size()

    x = self.embed(x)

    #x = self.pe(x)
    pos = torch.arange(0, T, dtype=torch.long, device=x.device).unsqueeze(0) # shape (1, t)
    x = x + self.wpe(pos)

    for layer in self.decoder:
      x = layer(x)
    x = self.linear(x)
    return x

  @torch.no_grad()
  def generate(self, x, max_new_tokens):
    for _ in range(max_new_tokens):
      logits = self(x)
      logits = logits[:, -1, :]
      probs = self.softmax(logits)
      next_token = torch.multinomial(probs, num_samples=1)
      if next_token == ctoi[' ']:
        break
      x = torch.cat((x, next_token), dim=1)
    return x[:,1:] # drop the first seed character

torch.manual_seed(42)
m = RandomNameGenerator(32, 4,2,16).to(device)

print(decode(m.generate(torch.tensor([0]).unsqueeze(0).to(device),8).tolist()[0]))
print(f"Model Parameters: {sum(p.numel() for p in m.parameters())}")
vrijajek
Model Parameters: 27484

One-hot encoding is here realized via an embedding layer (embed) that projects from dimension 28 (length of the alphabet with space and padding characters) to the model dimension d_model.

We provide two implementations of positional encoding: the original one from the paper "Attention is all you need" (commented out), and a learned embedding wpe, which provides one row per position in the sequence up to the maximum length max_length. This embedding is added to the token embedding in the forward routine.

We also now stack multiple decoder blocks in series, determined by the parameter nlayers, corresponding to the "Nx" in the drawing. We organize these in the PyTorch ModuleList data structure and use a for-loop to pass the input through all of them in the forward function.

We also add a generate function that implements multinomial sampling on the last token and appends the resulting token to the sequence. How to use this generator is shown in the second-to-last line: we seed it with a tensor containing the index of the " " character.

In this example, we instantiate the model with an internal dimension of 32, 4 heads, 2 layers, and a maximum token length of 16 for the positional embedding. We then print the total number of parameters, which will allow you to experiment with the impact of parameters like the model dimension or the number of layers yourself. In particular, you will see that the architecture scales very well (it is in fact very similar to that used for GPT-2), and it is easy to come up with parameters that create very large models.
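
For example, a hypothetical larger configuration (not one we will train here) quickly inflates the parameter count:

big = RandomNameGenerator(256, 8, 6, 16).to(device)  # d_model=256, 8 heads, 6 layers
print(f"Model Parameters: {sum(p.numel() for p in big.parameters())}")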

Let’s train this model:

optimizer = torch.optim.AdamW(m.parameters(), lr=5e-4, weight_decay=0.01, betas=(0.9, 0.99), eps=1e-8)

for epoch in range(10):
  for xenc_batch, y_batch in train_loader:
    optimizer.zero_grad()

    logits = m(xenc_batch)
    logits = logits.view(-1, logits.size(-1))  # Shape: [batch_size * max_seq_len, vocab_size]
    y_batch = y_batch.view(-1)  # Shape: [batch_size * max_seq_len]

    # Compute the loss using CrossEntropyLoss
    loss = F.cross_entropy(logits, y_batch, ignore_index=ctoi['.'])

    # Backward pass
    m.zero_grad(set_to_none=True) # make sure ALL the gradients are set to zero
    loss.backward()

    optimizer.step()

  print(f"Epoch {epoch}, Loss: {loss}")
Epoch 0, Loss: 2.013244152069092
Epoch 1, Loss: 1.9868907928466797
Epoch 2, Loss: 2.0008857250213623
Epoch 3, Loss: 1.9368605613708496
Epoch 4, Loss: 1.9350165128707886
Epoch 5, Loss: 1.9508048295974731
Epoch 6, Loss: 1.8971132040023804
Epoch 7, Loss: 2.1034555435180664
Epoch 8, Loss: 2.100424289703369
Epoch 9, Loss: 2.0112760066986084

We use the Adam optimizer with weight decay from the paper Decoupled Weight Decay Regularization. As always, we have to flatten the batch and sequence dimensions by stacking all inputs and targets for cross_entropy() to work. Optimizers are covered in more detail here:

A primer on using PyTorch Optimizers

This is already a pretty good loss, but it does not seem to improve after the first epoch. (One epoch consists of presenting the entire training set once.) The reason is that we only print the loss of the last batch within each epoch. This makes it difficult to see the effect of small tweaks that we make to the architecture, and whether the model is still learning at all.

A better way is to randomly sample across the entire dataset to compute the average loss after every epoch. We can do this with a function like this:

@torch.inference_mode()
def evaluate(model, dataset, batch_size=50, max_batches=None):
    model.eval()
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0, collate_fn=pad_sequences)
    losses = []
    for i, batch in enumerate(loader):
        X, Y = batch
        logits = model(X)
        logits = logits.view(-1, logits.size(-1))
        Y = Y.view(-1)  # Shape: [batch_size * max_seq_len]

        # Compute the loss using CrossEntropyLoss
        loss = F.cross_entropy(logits, Y, ignore_index=ctoi['.'])

        losses.append(loss.item())
        if max_batches is not None and i >= max_batches:
            break
    mean_loss = torch.tensor(losses).mean().item()
    model.train() # reset model back to training mode
    return mean_loss

As we make dataset a parameter, we can compute the loss against both the training and the validation dataset.

We can then add the following three lines to the training code above to compute the loss across 10*100 names from our train and validation dataset.

train_loss = evaluate(m, train_dataset, batch_size=100, max_batches=10)
test_loss  = evaluate(m, val_dataset,  batch_size=100, max_batches=10)

print(f"Epoch {epoch}, Train-Loss: {train_loss} Val-Loss: {test_loss}")

Running the training again yields a much smoother output, clearly showing that the model is still learning as the training loss keeps decreasing.

Epoch 0, Train-Loss: 2.0593550205230713 Val-Loss: 2.0487308502197266
Epoch 1, Train-Loss: 2.005232572555542 Val-Loss: 2.0240142345428467
Epoch 2, Train-Loss: 1.9855084419250488 Val-Loss: 1.9795506000518799
Epoch 3, Train-Loss: 1.938616156578064 Val-Loss: 1.9944255352020264
Epoch 4, Train-Loss: 1.9284851551055908 Val-Loss: 1.9430066347122192
Epoch 5, Train-Loss: 1.9315165281295776 Val-Loss: 1.945677638053894
Epoch 6, Train-Loss: 1.923316240310669 Val-Loss: 1.9368290901184082
Epoch 7, Train-Loss: 1.9058040380477905 Val-Loss: 1.9276282787322998
Epoch 8, Train-Loss: 1.9039726257324219 Val-Loss: 1.9241658449172974
Epoch 9, Train-Loss: 1.8952724933624268 Val-Loss: 1.9211432933807373

We also notice that the validation loss – that is, the loss on names that were not part of the training set – starts to stall, suggesting that we are beginning to overfit the data.
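
To sample names from the trained model, we reuse the generate method, seeding it with the space character as before; a minimal sketch:

for _ in range(10):
  seed = torch.tensor([[ctoi[' ']]], device=device)  # seed with the space character (index 0)
  print(decode(m.generate(seed, 15)[0].tolist()))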

Here is what we get:

olonyuwu
airimar
un
emilly
estoni
irmelly
alace
urt
anthorie
alur

The name "Emilly" is identical to a sample of the training set, but the others are original creations.

Summary

The Transformer is a neural network with a self-attention mechanism at its core. Self-attention works by comparing each token in a sequence with every other token, extending cosine similarity with learned linear projections. As self-attention does not account for the positions of tokens, a positional encoding is added to the sequence before it is processed by the self-attention mechanism. After self-attention, each token contains statistics about the entire sequence, which can then be fed to a fully-connected network to make predictions. As attention blocks can be stacked vertically (layers) and horizontally (heads), the transformer architecture can create models with very large numbers of parameters that are able to learn patterns of unprecedented complexity. In particular, the architecture presented here is very close to that of the decoder in GPT-2.

