Language Models for Sentence Completion

A practical application of a language model that picks the most likely candidate word that extends an English sentence by a single word

Dhruv Matani
Towards Data Science


Photo by Brett Jordan on Unsplash

Co-authored with Naresh Singh.

Table of contents

Introduction
Problem Statement
Brainstorming a solution

An LSTM model

A Transformer model

Conclusion

Introduction

Language models such as GPT have become very popular recently and are used for a variety of text generation tasks, such as in ChatGPT and other conversational AI systems. These language models are huge, often exceeding tens of billions of parameters, and need a lot of computing resources and money to run.

In the context of English language models, these massive models are over-parameterized, since they use their parameters to memorize and learn aspects of our world instead of just modeling the English language. If our application only needs the model to understand the language and its constructs, we can likely use a much smaller model.

The complete code for running inference on the trained model can be found in this notebook.

Problem Statement

Let’s assume we’re building a swipe keyboard system that tries to predict the next word you type on your mobile phone. Based on the traced swipe pattern, there are many possibilities for the user’s intended word. However, many of these possible words aren’t actual English words and can be eliminated. Even after this initial pruning and elimination step, many candidates remain, and we need to pick one as a suggestion for the user.

To further prune this list of candidates, we can use a deep-learning-based language model that looks at the provided context and tells us which candidate is most likely to complete the sentence.

For example, suppose the user has typed the sentence “I’ve scheduled this” and then swipes the pattern shown below.

Some possible English words that the user could have meant are:

  1. messing
  2. meeting

However, the user most likely meant “meeting” and not “messing”, because of the word “scheduled” earlier in the sentence.
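The two-stage process described so far has two steps. First, drop candidates that are not English words. Then let a language model pick among the rest. It can be sketched as the skeleton below. This is a minimal sketch, not the article's implementation. The `VOCABULARY` set, the `prune_and_rank` function, and the `score` callback are hypothetical stand-ins for the keyboard's dictionary and for the language model's completion probability.

```python
from typing import Callable, Iterable, List

# Hypothetical dictionary of valid English words; a real keyboard ships a large word list.
VOCABULARY = {"meeting", "messing", "melting"}


def prune_and_rank(prefix: str, swipe_candidates: Iterable[str],
                   score: Callable[[str, str], float]) -> List[str]:
    """Drop non-words, then order the remaining candidates by the language-model score."""
    # Step 1: keep only candidates that are real English words.
    real_words = [w for w in swipe_candidates if w in VOCABULARY]
    # Step 2: put the most likely completion of the prefix first.
    return sorted(real_words, key=lambda w: score(prefix, w), reverse=True)


# Scores taken from the "I've scheduled this" example reported later in this article.
example_scores = {"meeting": 0.00077, "messing": 0.00000}
ranked = prune_and_rank("I've scheduled this", ["messing", "mesting", "meeting"],
                        lambda p, w: example_scores[w])
print(ranked)  # ['meeting', 'messing']
```

The rest of the article is about where a good `score` function comes from.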

Given everything we know so far, what options do we have for doing this pruning programmatically? Let’s brainstorm some solutions in the section below.

Brainstorming a solution

Algorithms and Data Structures

Working from first principles, a reasonable start is to take a corpus of text, count pairs of words that occur next to each other, and train a Markov model that predicts the probability of each pair occurring in a sentence. You’ll notice two significant issues with this approach.

  1. Space utilization: English has anywhere between 250k and 1 million words, not counting the numerous proper nouns, whose number keeps growing. Hence, any traditional software solution that models the probability of a pair of words occurring together must maintain a lookup table with 250k × 250k = 62.5 billion word pairs even at the low end, which is excessive. Many pairs likely occur rarely and can be pruned, but even after pruning, there are a lot of pairs to worry about.
  2. Completeness: Encoding the probability of a pair of words doesn’t do justice to the problem at hand. The earlier sentence context is completely lost when you look at just the most recent pair of words. In the sentence “How is your day coming”, if you want to check the word after “coming”, you’d have a lot of pairs starting with “coming”, and none of them capture the sentence context before that word. One could use word triplets and longer sequences, but this exacerbates the space utilization problem mentioned above.
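For concreteness, here is a minimal sketch of the pair-counting (bigram) Markov model described above. It uses a made-up three-sentence corpus. It estimates P(next | previous) as count(previous, next) / count(previous, anything). The nested dictionary it builds is the lookup table whose size explodes for a real vocabulary. The previous word is the only context the model ever sees.

```python
from collections import Counter, defaultdict


def train_bigram_model(sentences):
    """Count adjacent word pairs and turn the counts into conditional probabilities."""
    pair_counts = defaultdict(Counter)
    for sentence in sentences:
        words = sentence.lower().split()
        for prev, nxt in zip(words, words[1:]):
            pair_counts[prev][nxt] += 1
    model = {}
    for prev, counter in pair_counts.items():
        total = sum(counter.values())
        # P(next | prev) = count(prev, next) / count(prev, *)
        model[prev] = {nxt: c / total for nxt, c in counter.items()}
    return model


corpus = [
    "how is your day coming along",
    "the day is coming soon",
    "winter is coming",
]
model = train_bigram_model(corpus)
print(model["coming"])  # {'along': 0.5, 'soon': 0.5}
```

The prediction after “coming” is the same regardless of what came before it. This is the completeness problem in miniature.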

Let’s shift our focus to a solution that leverages the nature of the English language and see if that can help us here.

NLP (Natural Language Processing)

Historically, the area of NLP (natural language processing) involved understanding the parts of speech (POS) of a sentence and using that information to make such pruning and prediction decisions. One can imagine using the POS tag associated with each word to determine whether the following word in a sentence is valid.

However, computing the parts of speech for a sentence is a complex process in itself and requires a specialized understanding of language, as evidenced in this page on NLTK’s parts of speech tagging.
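To make the POS idea concrete, here is a deliberately naive sketch. It uses a hand-written tag lexicon and a hand-written set of allowed tag transitions, both hypothetical. Real taggers such as NLTK's must choose each word's tag from context, since a word like “meeting” can be a noun or a verb. This sketch sidesteps that by giving every word one fixed tag, and that simplification breaks down in practice.

```python
# Hypothetical single-tag lexicon; real words often have several possible tags.
LEXICON = {"the": "DET", "baseball": "NOUN", "player": "NOUN", "ran": "VERB", "quickly": "ADV"}

# Hypothetical tag pairs that we allow to appear next to each other.
ALLOWED_TRANSITIONS = {
    ("DET", "NOUN"),
    ("NOUN", "NOUN"),
    ("NOUN", "VERB"),
    ("VERB", "ADV"),
}


def is_valid_next(prev_word: str, next_word: str) -> bool:
    """Accept next_word only if its tag may follow the tag of prev_word."""
    prev_tag = LEXICON.get(prev_word.lower())
    next_tag = LEXICON.get(next_word.lower())
    if prev_tag is None or next_tag is None:
        return False  # unknown words cannot be judged with this table
    return (prev_tag, next_tag) in ALLOWED_TRANSITIONS


print(is_valid_next("player", "ran"))  # True
print(is_valid_next("the", "ran"))     # False
```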

Next, let’s take a look at a deep-learning-based approach that requires a lot more data, but not as much language expertise, to build.

Deep Learning (Neural Networks)

The area of NLP has been upended by the advent of deep learning. With the invention of LSTM- and Transformer-based language models, the solution more often than not involves throwing some high-quality data at a model and training it to predict the next word.

In essence, this is what the GPT model is doing. GPT (Generative Pre-Trained Transformer) models are trained to predict the next word (token) given a prefix of a sentence.

Given the sentence prefix “It is such a wonderful”, the model is likely to predict the following words with high probability as the next word.

  1. day
  2. experience
  3. world
  4. life

It’s also likely that the following words will have a lower probability of completing the sentence prefix.

  1. red
  2. mouse
  3. line

The Transformer model architecture is at the heart of systems such as ChatGPT. However, for the more restricted use case of learning English language semantics, we can use a cheaper-to-run model architecture such as an LSTM (long short-term memory) model.

An LSTM model

Let’s build a simple LSTM model and train it to predict the next token given a prefix of tokens. Now, you might ask what a token is.

Tokenization

Typically for language models, a token can mean

  1. A single character (or a single byte)
  2. An entire word in the target language
  3. Something in between 1 and 2. This is usually called a sub-word

Mapping each character (or byte) to its own token is very restrictive, since we’re overloading that token to hold a lot of context about where it occurs. The character “c”, for example, occurs in many different words, and predicting the character that follows a “c” requires the model to look hard at the leading context.

Mapping a single word to a token is also problematic since English itself has anywhere between 250k and 1 million words. In addition, what happens when a new word is added to the language? Do we need to go back and re-train the entire model to account for this new word?

Sub-word tokenization is the industry standard as of 2023. It assigns unique tokens to substrings of bytes that frequently occur together. Typically, language models have anywhere from a few thousand (say 4,000) to tens of thousands (say 60,000) of unique tokens. Which substrings become tokens is decided by the BPE (byte pair encoding) algorithm.
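Below is a minimal sketch of the BPE training loop in plain Python. It assumes a tiny made-up corpus and starts from single characters rather than bytes. The loop counts every adjacent pair of symbols, merges the most frequent pair into a new symbol, and repeats. Production tokenizers such as HuggingFace Tokenizers add pre-tokenization, byte-level alphabets, and much faster data structures, so treat this only as an illustration of the idea.

```python
from collections import Counter


def merge_pair(word, pair):
    """Replace each occurrence of the adjacent symbol pair in word with one merged symbol."""
    merged, i = [], 0
    while i < len(word):
        if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
            merged.append(word[i] + word[i + 1])
            i += 2
        else:
            merged.append(word[i])
            i += 1
    return tuple(merged)


def learn_bpe(corpus, num_merges):
    """Learn up to num_merges merge rules from a whitespace-separated corpus."""
    # Each distinct word starts as a tuple of single-character symbols, with its frequency.
    words = Counter(tuple(w) for w in corpus.split())
    merges = []
    for _ in range(num_merges):
        pair_counts = Counter()
        for word, freq in words.items():
            for a, b in zip(word, word[1:]):
                pair_counts[(a, b)] += freq
        if not pair_counts:
            break  # every word is already a single symbol
        best = max(pair_counts, key=pair_counts.get)  # most frequent adjacent pair
        merges.append(best)
        new_words = Counter()
        for word, freq in words.items():
            new_words[merge_pair(word, best)] += freq
        words = new_words
    return merges, words


merges, words = learn_bpe("low low low lower lowest", 3)
print(merges)  # [('l', 'o'), ('lo', 'w'), ('low', 'e')]
```

After three merges, the frequent word “low” is a single token, while “lower” and “lowest” share the sub-word “lowe”. The number of merges is what a target vocabulary size controls. The final vocabulary is the starting alphabet plus one new symbol per merge.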

To choose the number of unique tokens in our vocabulary (called the vocabulary size), we need to be mindful of a few things:

  1. If we choose too few tokens, we’re back in the regime of a token per character, and it’s hard for the model to learn anything useful.
  2. If we choose too many tokens, the model’s embedding tables overshadow the rest of the model’s weights, and it becomes hard to deploy the model in a constrained environment. The size of the embedding table depends on the number of dimensions we use for each token. It’s not uncommon to use a size of 256, 512, 768, etc. If we use a token embedding dimension of 512 and have 100k tokens, we end up with an embedding table that uses about 200MiB of memory (assuming 4-byte float32 weights).
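The memory estimate above is a back-of-the-envelope calculation. A minimal sketch follows, assuming the table is stored densely with one value per (token, dimension) pair. The function name is made up for illustration.

```python
def embedding_table_mib(vocab_size: int, embed_dim: int, bytes_per_weight: int = 4) -> float:
    """Size of a vocab_size x embed_dim embedding table in MiB (1 MiB = 2**20 bytes)."""
    return vocab_size * embed_dim * bytes_per_weight / 2**20


print(f"{embedding_table_mib(100_000, 512):.1f} MiB")     # 195.3 MiB with float32 weights
print(f"{embedding_table_mib(100_000, 512, 2):.1f} MiB")  # 97.7 MiB with 2-byte (fp16) weights
```

The size grows linearly in both the vocabulary size and the embedding dimension. Halving either one, or storing weights in 16 bits, halves the table.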

Hence, we need to strike a balance when choosing the vocabulary size. In this example, we train our tokenizer with a vocabulary size of 6600 tokens. Next, let’s take a look at the model definition itself.

The PyTorch Model

The model itself is pretty straightforward. We have the following layers:

  1. Token Embedding (vocab size=6600, embedding dim=512), for a total size of about 13MiB (assuming 4-byte float32 as the embedding table’s data type)
  2. LSTM (num layers=1, hidden dimension=786) for a total size of about 16MiB
  3. Multi-Layer Perceptron (786 to 3144 to 6600 dimensions) for a total size of about 89MiB

The complete model has about 31M trainable parameters for a total size of about 120MiB.
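These parameter counts are easy to reproduce by hand. The sketch makes these assumptions:

- `nn.Embedding` has no bias.
- A single-layer `nn.LSTM` has input-to-hidden and hidden-to-hidden weights, plus two bias vectors, for its four gates.
- `nn.Linear` has a bias.
- `nn.LayerNorm` has a scale and a shift.
- Weights are stored as 4-byte float32 values.

The helper functions are made up for this illustration.

```python
def embedding_params(vocab_size, embed_dim):
    return vocab_size * embed_dim  # one embed_dim vector per token, no bias


def lstm_params(input_dim, hidden_dim):
    # Single layer, four gates; PyTorch keeps two bias vectors (bias_ih and bias_hh).
    return 4 * (hidden_dim * (input_dim + hidden_dim) + 2 * hidden_dim)


def linear_params(in_dim, out_dim):
    return in_dim * out_dim + out_dim  # weight matrix + bias


def layernorm_params(dim):
    return 2 * dim  # elementwise scale (weight) and shift (bias)


VOCAB, EMBED, HIDDEN = 6600, 512, 786
layers = {
    "Embedding": embedding_params(VOCAB, EMBED),
    "LSTM": lstm_params(EMBED, HIDDEN),
    "Linear (2-1)": linear_params(HIDDEN, 4 * HIDDEN),
    "LayerNorm": layernorm_params(4 * HIDDEN),
    "Linear (2-5)": linear_params(4 * HIDDEN, VOCAB),
}
total = sum(layers.values())
for name, count in layers.items():
    print(f"{name:>12}: {count:,}")
print(f"Total: {total:,} params, {total * 4 / 2**20:.1f} MiB as float32")
# Total: 30,704,016 params, 117.1 MiB as float32
```

The per-layer counts match the torchinfo summary in this section. Almost two-thirds of the parameters sit in the final `Linear` layer that projects to the 6600-token vocabulary.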

Here’s the PyTorch code for the model.

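from torch import nn


class WordPredictionLSTMModel(nn.Module):
    def __init__(self, num_embed, embed_dim, pad_idx, lstm_hidden_dim, lstm_num_layers, output_dim, dropout):
        super().__init__()
        self.vocab_size = num_embed
        self.embed = nn.Embedding(num_embed, embed_dim, padding_idx=pad_idx)
        # nn.LSTM applies dropout only between stacked layers, so with
        # lstm_num_layers=1 this argument has no effect (PyTorch warns if dropout > 0).
        self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, lstm_num_layers, batch_first=True, dropout=dropout)
        # MLP head: hidden -> 4 * hidden -> one logit per vocabulary token.
        self.fc = nn.Sequential(
            nn.Linear(lstm_hidden_dim, lstm_hidden_dim * 4),
            nn.LayerNorm(lstm_hidden_dim * 4),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(lstm_hidden_dim * 4, output_dim),
        )

    def forward(self, x):
        # x: (batch, seq_len) token ids
        x = self.embed(x)     # (batch, seq_len, embed_dim)
        x, _ = self.lstm(x)   # (batch, seq_len, lstm_hidden_dim)
        x = self.fc(x)        # (batch, seq_len, output_dim)
        # nn.CrossEntropyLoss expects the class dimension second: (batch, output_dim, seq_len)
        x = x.permute(0, 2, 1)
        return x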

Here’s the model summary using torchinfo.

LSTM Model Summary

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
WordPredictionLSTMModel                  -
├─Embedding: 1-1                         3,379,200
├─LSTM: 1-2                              4,087,200
├─Sequential: 1-3                        -
│    └─Linear: 2-1                       2,474,328
│    └─LayerNorm: 2-2                    6,288
│    └─LeakyReLU: 2-3                    -
│    └─Dropout: 2-4                      -
│    └─Linear: 2-5                       20,757,000
=================================================================
Total params: 30,704,016
Trainable params: 30,704,016
Non-trainable params: 0
=================================================================

Interpreting the accuracy: After training this model on 12M English language sentences for about 8 hours on a P100 GPU, we achieved a loss of 4.03, a top-1 accuracy of 29% and a top-5 accuracy of 49%. This means that 29% of the time, the model was able to correctly predict the next token, and 49% of the time, the next token in the training set was one of the top 5 predictions by the model.
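Top-k accuracy is straightforward to compute. Here is a minimal pure-Python sketch; in practice one would call `torch.topk` on the logits. The scores below are made-up toy values over a 4-token vocabulary, not outputs of the article's model.

```python
def top_k_accuracy(scores, targets, k):
    """Fraction of positions whose true next token is among the k highest-scoring tokens.

    scores:  one list of per-token scores (logits or probabilities) per position.
    targets: the index of the true next token at each position.
    """
    hits = 0
    for row, target in zip(scores, targets):
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += int(target in top_k)
    return hits / len(targets)


toy_scores = [
    [0.1, 0.7, 0.2, 0.0],  # true token 1 is ranked 1st
    [0.5, 0.1, 0.3, 0.0],  # true token 2 is ranked 2nd
    [0.3, 0.2, 0.0, 0.5],  # true token 0 is ranked 2nd
]
toy_targets = [1, 2, 0]
print(top_k_accuracy(toy_scores, toy_targets, 1))  # 0.3333333333333333
print(top_k_accuracy(toy_scores, toy_targets, 2))  # 1.0
```

A model can have a modest top-1 accuracy and still rank the right token near the top most of the time. This is the gap between the 29% and 49% figures above.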

What should our success metric be? While the top-1 and top-5 accuracy numbers for our model aren’t impressive, they matter less for our problem. Our candidate words are a small set of possible words that fit the swipe pattern. What we want from our model is to select the best candidate to complete the sentence so that the result is syntactically and semantically coherent. Since our model learns the nature of language through the training data, we expect it to assign a higher probability to coherent sentences. For example, if we have the sentence “The baseball player” and possible completion candidates (“ran”, “swam”, “hid”), then “ran” is a better follow-up word than the other two. So, if our model predicts “ran” with a higher probability than the rest, it works for us.

Interpreting the loss: A loss of 4.03 means that the average negative log-likelihood per predicted token is 4.03. This corresponds to an average (geometric-mean) probability of e^-4.03 ≈ 0.0178, or about 1/56, of predicting the next token correctly. A randomly initialized model typically has a loss of about 8.8, which is -log_e(1/6600), because such a model assigns roughly equal probability (1/6600) to every token, 6600 being the vocabulary size. While a loss of 4.03 may not seem great, it’s important to remember that the trained model is about 120x better than an untrained (or randomly initialized) model.
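The arithmetic above can be checked in a few lines. Cross-entropy loss is an average of per-token negative log-likelihoods. So exp(-loss) is best read as the geometric-mean probability that the model assigns to the correct next token, not as the probability for any particular token.

```python
import math

vocab_size = 6600
trained_loss = 4.03  # average cross-entropy (natural log) reported above

# exp(-loss) is the geometric-mean probability assigned to the correct next token.
trained_prob = math.exp(-trained_loss)
# A model that spreads probability uniformly over the vocabulary has loss ln(V).
uniform_loss = math.log(vocab_size)

print(f"{trained_prob:.4f} ~ 1/{1 / trained_prob:.0f}")  # 0.0178 ~ 1/56
print(f"{uniform_loss:.2f}")                             # 8.79
print(f"{trained_prob * vocab_size:.0f}x")               # 117x
```

The last line divides the trained model's probability by the uniform 1/6600. That gives the “about 120x” improvement, equivalently exp(8.79 - 4.03).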

Next, let’s take a look at how we can use this model to improve suggestions from our swipe keyboard.

Using the model to prune invalid suggestions

Let’s take a look at a real example. Suppose we have a partial sentence “I think”, and the user makes the swipe pattern shown in blue below, starting at “o”, going between the letters “c” and “v”, and ending between the letters “e” and “v”.

Some possible words that could be represented by this swipe pattern are

  1. Over
  2. Oct (short for October)
  3. Ice
  4. I’ve (with the apostrophe implied)

Of these suggestions, the most likely one is probably going to be “I’ve”. Let’s feed these suggestions into our model and see what it spits out.

[I think] [I've] = 0.00087
[I think] [over] = 0.00051
[I think] [ice] = 0.00001
[I think] [Oct] = 0.00000

The value after the = sign is the probability of the word being a valid completion of the sentence prefix. In this case, we see that the word “I’ve” has been assigned the highest probability. Hence, it is the most likely word to follow the sentence prefix “I think”.

The next question you might have is how we can compute these next-word probabilities. Let’s take a look.

Computing the next word probability

To compute the probability that a word is a valid completion of a sentence prefix, we run the model in eval (inference) mode and feed in the tokenized sentence prefix. We also tokenize the candidate word after prefixing it with a space. This is needed because the HuggingFace pre-tokenizer splits text so that each word keeps the space in front of it. We want our inputs to be consistent with the tokenization strategy used by HuggingFace Tokenizers.

Let’s assume that the candidate word is made up of 3 tokens T0, T1, and T2.

  1. We first run the model on the original tokenized sentence prefix. At the last position, we look up the probability of predicting token T0 and add it to the “probs” list.
  2. Next, we run a prediction on the prefix+T0 and check the probability of token T1. We add this probability to the “probs” list.
  3. Next, we run a prediction on the prefix+T0+T1 and check the probability of token T2. We add this probability to the “probs” list.

The “probs” list contains the individual probabilities of generating the tokens T0, T1, and T2 in sequence. Since these tokens correspond to the tokenization of the candidate word, we can multiply these probabilities to get the combined probability of the candidate being a completion of the sentence prefix.
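In symbols, for a candidate word w made up of tokens T0, T1, and T2, multiplying the entries of “probs” applies the chain rule of probability. Taking logarithms turns the product into a sum:

```latex
P(w \mid \text{prefix}) = P(T_0 \mid \text{prefix}) \, P(T_1 \mid \text{prefix}, T_0) \, P(T_2 \mid \text{prefix}, T_0, T_1)

\log P(w \mid \text{prefix}) = \sum_{i=0}^{2} \log P(T_i \mid \text{prefix}, T_{<i})
```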

The code for computing the completion probabilities is shown below.

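import torch


# A method of the inference helper used in the notebook: `self.model` is the trained
# model, `self.device` its torch device, and `tok` a HuggingFace Tokenizers tokenizer.
def get_completion_probability(self, input, completion, tok):
    self.model.eval()
    # Prefix the candidate with a space to match how the pre-tokenizer splits words.
    if not completion.startswith(" "):
        completion = " " + completion
    ids = tok.encode(input).ids
    ids = torch.tensor(ids, device=self.device).unsqueeze(0)  # (1, prefix_len)
    completion_ids = torch.tensor(tok.encode(completion).ids, device=self.device).unsqueeze(0)
    probs = []
    with torch.no_grad():
        for i in range(completion_ids.size(1)):
            y = self.model(ids)               # (1, vocab_size, seq_len)
            y = y[0, :, -1].softmax(dim=0)    # next-token distribution after the last position
            # Probability of the i-th candidate token given everything before it.
            prob = y[completion_ids[0, i]]
            probs.append(prob)
            # Append the candidate's own token and score the next one.
            ids = torch.cat([ids, completion_ids[:, i:i+1]], dim=1)
    # Per-token probabilities; multiply them (e.g. with .prod()) to score the whole word.
    return torch.stack(probs)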
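The same loop can be sketched without PyTorch to show the bookkeeping. Here `toy_next_token_probs` is a hypothetical stand-in for the model: a function that maps a token context to a distribution over the next token. The sub-word splits (“meet” + “ing”) are also made up. The sketch sums log probabilities instead of multiplying raw probabilities, a common choice that avoids floating-point underflow when a candidate spans many tokens. Exponentiating the sum recovers the product.

```python
import math
from typing import Callable, Dict, List, Sequence

# Hypothetical model interface: token context -> {next_token: probability}.
NextTokenFn = Callable[[Sequence[str]], Dict[str, float]]


def completion_log_prob(prefix: List[str], completion: List[str],
                        next_token_probs: NextTokenFn) -> float:
    """Return log P(completion | prefix), scoring the completion one token at a time."""
    context = list(prefix)
    total = 0.0
    for token in completion:
        p = next_token_probs(context).get(token, 0.0)
        if p == 0.0:
            return float("-inf")  # the model rules this token out entirely
        total += math.log(p)      # adding logs == multiplying probabilities
        context.append(token)     # condition the next step on the candidate's own token
    return total


def toy_next_token_probs(context: Sequence[str]) -> Dict[str, float]:
    """Hypothetical stand-in for the language model, keyed on the last token only."""
    table = {
        "this": {"meet": 0.5, "mess": 0.5},
        "meet": {"ing": 0.8, "s": 0.2},
        "mess": {"ing": 0.01, "y": 0.99},
    }
    return table.get(context[-1], {})


prefix = ["I've", "scheduled", "this"]
for word_tokens in (["meet", "ing"], ["mess", "ing"]):
    log_p = completion_log_prob(prefix, word_tokens, toy_next_token_probs)
    print("".join(word_tokens), round(math.exp(log_p), 3))
# meeting 0.4
# messing 0.005
```

To rank candidates, sort them by this score. The ordering is the same for log probabilities and raw probabilities, because the logarithm is monotonic.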

We can see some more examples below.

[That ice-cream looks] [really] = 0.00709
[That ice-cream looks] [delicious] = 0.00264
[That ice-cream looks] [absolutely] = 0.00122
[That ice-cream looks] [real] = 0.00031
[That ice-cream looks] [fish] = 0.00004
[That ice-cream looks] [paper] = 0.00001
[That ice-cream looks] [atrocious] = 0.00000

[Since we're heading] [toward] = 0.01052
[Since we're heading] [away] = 0.00344
[Since we're heading] [against] = 0.00035
[Since we're heading] [both] = 0.00009
[Since we're heading] [death] = 0.00000
[Since we're heading] [bubble] = 0.00000
[Since we're heading] [birth] = 0.00000

[Did I make] [a] = 0.22704
[Did I make] [the] = 0.06622
[Did I make] [good] = 0.00190
[Did I make] [food] = 0.00020
[Did I make] [color] = 0.00007
[Did I make] [house] = 0.00006
[Did I make] [colour] = 0.00002
[Did I make] [pencil] = 0.00001
[Did I make] [flower] = 0.00000

[We want a candidate] [with] = 0.03209
[We want a candidate] [that] = 0.02145
[We want a candidate] [experience] = 0.00097
[We want a candidate] [which] = 0.00094
[We want a candidate] [more] = 0.00010
[We want a candidate] [less] = 0.00007
[We want a candidate] [school] = 0.00003

[This is the definitive guide to the] [the] = 0.00089
[This is the definitive guide to the] [complete] = 0.00047
[This is the definitive guide to the] [sentence] = 0.00006
[This is the definitive guide to the] [rapper] = 0.00001
[This is the definitive guide to the] [illustrated] = 0.00001
[This is the definitive guide to the] [extravagant] = 0.00000
[This is the definitive guide to the] [wrapper] = 0.00000
[This is the definitive guide to the] [miniscule] = 0.00000

[Please can you] [check] = 0.00502
[Please can you] [confirm] = 0.00488
[Please can you] [cease] = 0.00002
[Please can you] [cradle] = 0.00000
[Please can you] [laptop] = 0.00000
[Please can you] [envelope] = 0.00000
[Please can you] [options] = 0.00000
[Please can you] [cordon] = 0.00000
[Please can you] [corolla] = 0.00000

[I think] [I've] = 0.00087
[I think] [over] = 0.00051
[I think] [ice] = 0.00001
[I think] [Oct] = 0.00000

[Please] [can] = 0.00428
[Please] [cab] = 0.00000

[I've scheduled this] [meeting] = 0.00077
[I've scheduled this] [messing] = 0.00000

These examples show the probability of the word completing the sentence before it. The candidates are sorted in decreasing order of probability.

Since Transformers are slowly replacing LSTM and RNN models for sequence-based tasks, let’s take a look at what a Transformer model for the same objective would look like.

A Transformer model

Transformer-based models are a very popular architecture for training language models to predict the next word in a sentence. The specific technique we’ll use is the causal attention mechanism: we’ll train only a Transformer encoder stack in PyTorch, and every token in the sequence may attend only to itself and the tokens before it. This matches the information available to a unidirectional LSTM layer trained only in the forward direction.
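Causal attention boils down to a mask. This pure-Python sketch builds the same boolean pattern that `torch.triu(..., diagonal=1)` produces in the `generate_src_mask` function of the model code, with True above the diagonal. In PyTorch's convention, a True entry in a boolean attention mask means that position may not be attended to. The helper names here are made up for illustration.

```python
def causal_mask(seq_len):
    """mask[i][j] is True when position i must NOT attend to position j (j > i)."""
    return [[j > i for j in range(seq_len)] for i in range(seq_len)]


def visible_positions(mask, i):
    """Positions that token i is allowed to attend to."""
    return [j for j, blocked in enumerate(mask[i]) if not blocked]


mask = causal_mask(4)
for i in range(4):
    print(i, visible_positions(mask, i))
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```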

The Transformer model we’ll see here is based directly on the nn.TransformerEncoder and nn.TransformerEncoderLayer in PyTorch.

import math

import torch
from torch import nn


def generate_src_mask(sz, device):
    # Boolean (sz, sz) mask; True above the diagonal means "may not attend",
    # so each position sees only itself and earlier positions.
    return torch.triu(torch.full((sz, sz), True, device=device), diagonal=1)


class PositionalEmbedding(nn.Module):
    """Scales token embeddings by sqrt(embed_dim) and adds a learned position embedding."""

    def __init__(self, sequence_length, embed_dim):
        super().__init__()
        self.sqrt_embed_dim = math.sqrt(embed_dim)
        self.pos_embed = nn.Parameter(torch.empty((1, sequence_length, embed_dim)))
        nn.init.uniform_(self.pos_embed, -1.0, 1.0)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim); use only the first seq_len positions.
        return x * self.sqrt_embed_dim + self.pos_embed[:, :x.size(1)]


class WordPredictionTransformerModel(nn.Module):
    def __init__(self, sequence_length, num_embed, embed_dim, pad_idx, num_heads, num_layers, output_dim, dropout, norm_first, activation):
        super().__init__()
        self.vocab_size = num_embed
        self.sequence_length = sequence_length
        self.embed_dim = embed_dim
        self.sqrt_embed_dim = math.sqrt(embed_dim)
        self.embed = nn.Sequential(
            nn.Embedding(num_embed, embed_dim, padding_idx=pad_idx),
            PositionalEmbedding(sequence_length, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.Dropout(p=0.1),
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dropout=dropout, batch_first=True, norm_first=norm_first, activation=activation,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Same MLP head as the LSTM model: embed_dim -> 4 * embed_dim -> vocabulary logits.
        self.fc = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.LayerNorm(embed_dim * 4),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(embed_dim * 4, output_dim),
        )

    def forward(self, x):
        # x: (batch, seq_len) token ids
        src_attention_mask = generate_src_mask(x.size(1), x.device)
        x = self.embed(x)
        # is_causal=True is a hint to PyTorch that the supplied mask is causal.
        x = self.encoder(x, is_causal=True, mask=src_attention_mask)
        x = self.fc(x)
        # (batch, output_dim, seq_len), the layout nn.CrossEntropyLoss expects
        x = x.permute(0, 2, 1)
        return x

We can plug this model in place of the LSTM model that we used before, since its API is compatible. This model takes longer to train for the same amount of training data and has comparable performance.

Transformer models shine on long sequences, where attention lets each token draw directly on distant context. In our case, we have sequences of length 256, and most of the context needed to perform next-word completion tends to be local, so we don’t really need the power of Transformers here.

Conclusion

We saw how we can solve very practical NLP problems using deep learning techniques based on LSTM (RNN) and Transformer models. Not every language task requires models with billions of parameters. Specialized applications that require modeling language itself, rather than memorizing large volumes of information, can be handled using much smaller models. These can be deployed more easily and efficiently than the massive language models that we are used to seeing these days.

All the image(s) except for the first one were created by the author(s).
