
Speculative Sampling – Intuitively and Exhaustively Explained

A drop-in strategy for making language models 3x faster

Machine Learning | Natural Language Processing | Data Science

Exploring the drop-in strategy that’s speeding up language models by 3x

"Speculators" by Daniel Warfield using MidJourney and Affinity Design 2. All images by the author unless otherwise specified.
"Speculators" by Daniel Warfield using MidJourney and Affinity Design 2. All images by the author unless otherwise specified.

In this article we’ll discuss "Speculative Sampling", a strategy that makes text generation faster and more affordable without compromising on performance. In doing so, we’ll take a thorough look at some of the more subtle aspects of language models.

Empirical results of using speculative sampling on a variety of text generation tasks. Notice how, in all cases, generation time is significantly faster. Source

First we’ll discuss a major problem that’s slowing down modern language models, then we’ll build an intuitive understanding of how speculative sampling elegantly speeds them up, and finally we’ll implement speculative sampling from scratch in Python.

Who is this useful for? Anyone interested in natural language processing (NLP), or cutting edge AI advancements.

How advanced is this post? The concepts in this article are accessible to Machine Learning enthusiasts, and are cutting edge enough to interest seasoned data scientists. The code at the end may be useful to developers.

Pre-requisites: It might be useful to have a cursory understanding of Transformers, OpenAI’s GPT models, or both. If you find yourself confused, you can refer to either of these articles:

GPT – Intuitively and Exhaustively Explained

Transformers – Intuitively and Exhaustively Explained

Language Models Are Getting Too Big

Over the last five years OpenAI’s GPT models have grown from 117 million parameters in 2018 to an estimated 1.8 trillion parameters in 2023. This rapid growth can largely be attributed to the fact that, in language modeling, bigger is better.

A graph of model size vs performance, showing that bigger is better. From my article on GPT. Original source

As a result, the last few years have been an arms race. Numerous companies have been dropping billions of dollars on fancy graphics cards, to the chagrin of Fortnite players everywhere.

The issue is, the models are simply getting too big. Language models, like the ones used in ChatGPT, have to generate their responses one word at a time through a process called "autoregressive generation". The bigger the model gets, the more money and time it takes to generate output word by word.

A decoder only style model, like GPT, iteratively constructing an output. The model takes an input "Translate to French: I am a manager", and generates the response word by word by using the previous outputs as part of the input. This style of text generation is called "autoregressive generation". From my article on GPT
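To make that loop concrete, here’s a rough sketch of greedy autoregressive generation. This snippet isn’t from the article’s notebook, and the gpt2 checkpoint is just a convenient stand-in for a decoder style model:

"""A minimal sketch of autoregressive generation: predict one token,
append it to the input, and repeat.
"""

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("The moon is primarily made of", return_tensors="pt").input_ids

for _ in range(10):
    with torch.no_grad():
        logits = model(input_ids).logits        #shape: [1, seq_len, vocab_size]
    next_id = torch.argmax(logits[0, -1])       #greedy: take the highest logit
    input_ids = torch.cat((input_ids, next_id.view(1, 1)), dim=1)

print(tokenizer.decode(input_ids[0]))

Notice that every pass re-runs the entire model just to produce one more token, which is exactly why big models are slow and expensive to generate with.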

OpenAI’s GPT-4, based on a leak by some guy on Twitter, uses a bunch of technologies to get around this problem. One of them, the topic of this article, is speculative sampling.

Speculative Sampling in a Nutshell

Speculative sampling (also known as "Speculative Decoding" or "Speculative Generation") was simultaneously proposed in two papers, both of which suggest a speedup in text generation by around 3x:

  • "Fast Inference from Transformers via Speculative Decoding" (Leviathan et al., Google Research)
  • "Accelerating Large Language Model Decoding with Speculative Sampling" (Chen et al., DeepMind)

Despite being published independently, both approaches are functionally identical, so we’ll treat them synonymously.

The fundamental idea of speculative sampling is this: bigger language models are better because some text generation is genuinely difficult, but not all of it is. For instance, suppose you ask a language model about the geological composition of the moon. To formulate a coherent response the model has to understand fancy sciency stuff, but it also has to put words like "a", "and", and "of" in the right spots. Knowing that the moon contains rocks called "breccias" is more difficult than knowing that the word "are" might come after the word "which".

A conceptual demonstration of generative difficulty. When a model predicts a response sequence, word by word, some of the words are difficult to predict because they require in-depth knowledge, but some of them are easy because they can be inferred from simple grammar or context clues. In this example the text in red might be more difficult to predict than the text in blue.

Speculative sampling exploits these varying degrees of difficulty by using two language models, a target model and a draft model:

  • The target model is the super big, super smart model we’re trying to speed up.
  • The draft model is a smaller, dumber, and faster model.

The idea is to use the draft model to predict numerous words in the sequence, then ask the target model to confirm that all the generated words are good. We can throw away all disagreements, resulting in an output which is identical to what the target model would have produced if it were working alone.

An example of speculative generation in action. In the first row, the draft model output "japan's benchmark bond", but the target model disagreed with "bond" in favor of "n". The word "bond" was replaced with "n", and anything the draft model might have predicted after the word "bond" is thrown out. In effect, this allows a speculative generation system to output multiple words for each pass of the target model. source
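Here’s a toy, runnable sketch of that accept/reject loop (my own illustration, not code from either paper). The "models" are just lists of token ids, but it shows how drafted tokens are kept up to the first disagreement, so the final output is exactly what the target would have produced on its own:

"""A toy illustration of the accept/reject loop. The "target" is a fixed
reference list of token ids and the "draft" is an imperfect copy of it.
"""

target_tokens = [12, 5, 8, 31, 2, 19, 7, 40, 3, 11]   #what the target would generate
draft_tokens  = [12, 5, 9, 30, 2, 19, 7, 41, 3, 11]   #the draft's (imperfect) guesses

output = []
target_passes = 0
k = 4   #number of tokens drafted per pass of the target model

while len(output) < len(target_tokens):
    start = len(output)
    guesses = draft_tokens[start:start + k]

    #one "pass" of the target model checks every guess, and supplies its own
    #token at the first disagreement
    target_passes += 1
    for i, guess in enumerate(guesses):
        if guess == target_tokens[start + i]:
            output.append(guess)                      #accepted
        else:
            output.append(target_tokens[start + i])   #corrected; the rest is discarded
            break

print(output == target_tokens, target_passes)   #True, in fewer passes than tokens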

A Natural Question

If you’re anything like me, you might be a bit confused. The common intuition, and the intuition that I communicated in both my article on transformers and my article on GPT, is that language models predict output word by word. Under that intuition it’s not exactly obvious how a target model might efficiently "double check" the output of the draft model; if the target model has to check predictions one by one, then what’s the point of going through the trouble of using the draft model in the first place?

If a model like GPT outputs text word by word, wouldn’t it have to check the output of the draft model word by word, thus taking the same amount of time? No, and we’ll talk about why in the next section.

Speculative sampling hinges on a thorough understanding of the exact output of transformers. There are some quirks which normally aren’t relevant, but are very relevant for speculative sampling.

The Secret Outputs of Transformers, and How Speculative Sampling Uses Them

As I discussed in my article on the original transformer architecture, the thing that made transformers so special was their ability to parallelize training. Before transformers, models like LSTMs had to be trained word by word, which was a slow and expensive process.

When a model like GPT is trained, an entire input sequence is provided as input, and the model is asked to predict that same sequence, just shifted by one word. The model is then trained to minimize the errors in its predictions.

An example of the training process for a language model, like GPT. The model is given an input sequence, shifted to the right by a token which designates the start of a sequence, and the model is asked to predict that same sequence, in a single pass. Any errors are used to train the model. Essentially, the model is trained to predict all next words simultaneously.

So, if the model has access to the entire input sequence, wouldn’t it cheat by just moving each word over by one space? No, and that’s because of masking.

Transformers use "masked" self attention, which essentially blocks out information about future words from reaching the information for a given word. I’ll probably cover masking in its own dedicated article (it’s definitely worthy of a deeper dive), but the intuition is this: by setting certain values in the self attention mechanism to zero, the prediction of a given word is not influenced by future words.

A conceptual diagram of training with masking. In effect, with masking, a language model is asked to predict all next words simultaneously.
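As a rough numerical sketch of what masking does (my own illustration, not tied to any particular model), the attention scores for future positions are pushed to negative infinity before the softmax:

"""A minimal sketch of causal ("masked") self attention scores.
"""

import torch

seq_len = 5
scores = torch.randn(seq_len, seq_len)     #raw attention scores, one row per position

#lower-triangular mask: position i may only attend to positions <= i
mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
scores = scores.masked_fill(~mask, float('-inf'))

attention = torch.softmax(scores, dim=-1)
print(attention)   #the upper triangle is all zeros: no information flows from the future

Pushing the scores to negative infinity is what makes the corresponding attention weights come out as zero after the softmax, which is the "setting certain values to zero" described above.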

Typically, when using a transformer, we only care about a prediction of the next word in a sequence; that’s how we get text to generate and cause venture capitalists to empty their pockets. However, because of the way the model is trained, it technically produces outputs for every position in the sequence, each computed as if the words after that position did not exist.

The true output of a transformer based language model like GPT. While we usually only care about the prediction for the final word, technically it predicts a next word for every position in the sequence.
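You can see this for yourself in a few lines of code. The snippet below isn’t from the article’s notebook, and the gpt2 checkpoint is just a convenient stand-in, but it shows that a single forward pass produces a next-word prediction at every position:

"""A minimal sketch showing that one forward pass yields a prediction for
every position in the sequence, not just the last one.
"""

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer("The moon is primarily made of rock", return_tensors="pt").input_ids

with torch.no_grad():
    logits = model(input_ids).logits                #shape: [1, seq_len, vocab_size]

#argmax over the vocabulary gives the predicted next token at *every* position
predicted_ids = torch.argmax(logits, dim=-1)[0]
for position, token_id in enumerate(predicted_ids):
    prefix = tokenizer.decode(input_ids[0][:position + 1])
    print(repr(prefix), '->', repr(tokenizer.decode(token_id.item())))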

And that’s how the target model can quickly check numerous predictions from the draft model. If we give the draft model’s output to the target model as input, and ask the target model to predict the next word, we can compare the predicted values for every word in the sequence. If there’s a discrepancy we can stop there and use the target model’s output.

Suppose the text in blue was generated by the draft model, and the target model disagreed with some word in the sequence, as highlighted in red and underlined. All drafted text generated before the disagreement can be accepted, and all drafted text after the disagreement must be rejected. At the point of first disagreement we use the output from the target model. In effect, we just generated "primarily made up of rocks and regolith" with a single pass of the target model.

A cool note about this process in general: every single time we run the target model, it predicts a new word in the sequence. The target model might confirm all of the draft model’s predictions, or it might disagree with all of them. Regardless, the target model will always contribute at least one new word. As a result, in a worst-case scenario where the draft model consistently outputs incorrect predictions, the entire system is as fast as if we were only using the target model. In other words, speculative sampling can’t slow down generation, it can only make generation faster (at least, when it’s implemented correctly).
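Stripped of all the model machinery, the check itself is nothing fancy (the token ids below are made up for illustration); it boils down to walking the two sequences in parallel and stopping at the first mismatch:

"""A minimal sketch of the verification step: keep drafted tokens up to the
first disagreement, and substitute the target's token at that point.
"""

draft_ids  = [101, 45, 7, 300, 12]    #tokens the draft model proposed
target_ids = [101, 45, 7, 287, 56]    #the target's prediction at each position

accepted = []
for d, t in zip(draft_ids, target_ids):
    accepted.append(t)                #the target's token is always trustworthy
    if d != t:
        break                         #everything after the disagreement is discarded

print(accepted)                       #[101, 45, 7, 287]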

Sequences, Tokens, TokenIds, Logits, and Probabilities

That was the theory. Before we dive into the code we should discuss some technical details about how transformers function.

Text, from a language modeling perspective, is conceptualized as a sequence: a list of "things" that come one after another. Typically these "things" can be conceptualized as words, but in reality they’re a bit more abstract than that.

A machine learning model first breaks the input sequence into tokens, which are the "things" that make up a sequence. This can be done using one of many algorithms, but the end result is that the input sequence is divided into atomic chunks. These might be individual words, portions of words, multiple words, punctuation, numbers, or spaces.

An example of tokenization, using a tokenizer called "sentencepiece"
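If you want to poke at tokenization yourself, the tokenizer used later in this article makes it easy. This little snippet is just for illustration and isn’t part of the main example:

"""A quick look at tokenization.
"""

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
print(tokenizer.tokenize("Speculative sampling is surprisingly simple."))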

Each of the tokens extracted from a tokenizer has a unique number, called the TokenId. Typically, a transformer style model learns a representative vector for each TokenId; these vectors become the input to the model and are optimized throughout training.

The same tokens, with their associated Ids
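As another quick illustration (again, not part of the main example), you can pull the learned vector for each TokenId straight out of a model’s embedding table:

"""Each TokenId indexes into a learned embedding table; those vectors are
what the model actually consumes.
"""

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large")

token_ids = tokenizer("Speculative sampling", return_tensors="pt").input_ids
print(token_ids)                                  #one integer (TokenId) per token

vectors = model.get_input_embeddings()(token_ids)
print(vectors.shape)                              #one learned vector per token in the sequence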

After the data goes through numerous rounds of self attention within the model, the data becomes an abstract sequence of vectors, one for each output. This is sometimes referred to as the "final hidden state".

The input, which has vectors that cleanly correspond with each word, gets passed through numerous layers of self attention. This process creates highly abstract representations. From my article on transformers.

This is passed through a language modeling head, which converts the model’s abstract representation into a representation that corresponds directly to the tokenizer. There’s a set number of TokenIds for a given tokenizer, and the language modeling head converts the output of the model into vectors with one value for each possible TokenId.

After the transformer does its thing, the final hidden state of the model is passed through a language modeling head, which re-structures the data into a format which directly corresponds to whatever tokenizer the model is being trained with.

These outputs are called logits. The term "logit" typically refers to the unfiltered, unprocessed, true output of the model, and it’s the thing that usually gets optimized during training. Logits are typically compared to each other using a softmax function, which converts them into probabilities: big logit values become big probabilities, small logit values become small probabilities.

A conceptual diagram of logits getting converted to probabilities
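In code, that conversion is a single softmax call. The logit values below are made up, purely for illustration:

"""Converting a handful of made-up logits into probabilities.
"""

import torch

logits = torch.tensor([4.0, 1.0, -2.0, 0.5])
probabilities = torch.softmax(logits, dim=0)
print(probabilities)   #sums to 1; the biggest logit gets the biggest probability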

These probabilities can then be converted into tokens, which then can be used to construct the output sequence. There are a few ways to go about doing that though.

  • You can simply always choose to use the highest probability token
  • You could randomly select an output in a manner which is weighted by probability
  • You could use a more complex strategy, like "top-k sampling".

Regardless, the probabilities become a tokenId, that tokenId becomes the token itself, and from the tokens, the output can be constructed.
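Here’s a rough sketch of those three strategies side by side. The logits are made up, and this isn’t code from the article’s notebook:

"""Three common ways of turning next-token probabilities into a TokenId.
"""

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])
probs = torch.softmax(logits, dim=0)

#1. greedy: always take the most probable token
greedy_id = torch.argmax(probs)

#2. weighted sampling: pick randomly, weighted by probability
sampled_id = torch.multinomial(probs, num_samples=1)

#3. top-k sampling: keep only the k most probable tokens, then sample among them
k = 3
top_probs, top_ids = torch.topk(probs, k)
topk_id = top_ids[torch.multinomial(top_probs / top_probs.sum(), num_samples=1)]

print(greedy_id, sampled_id, topk_id)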

So, to recap:

  • Sequence: This is typically used in reference to the input and output text, but can also be conceptualized as a sequence of tokens, sequence of tokenIds, sequence of logits, sequence of probabilities, whatever. "The sequence" can mean a few things, depending on the context of the discussion
  • Token: Text can be divided into atomic tokens with a tokenizer. These are used to break text up into atomic, predefined chunks. Sometimes these cleanly correspond to words, and sometimes they don’t.
  • TokenId: Each token has a corresponding TokenId, which is simply a number. The model uses this number to retrieve a learned vector for that token, thus constructing the input to the model
  • Logits and Probabilities: After the model does its thing, it outputs a series of values. These are typically softmaxed, and thus turned into probabilities. The probabilities are used to select output tokens.

Speculative Sampling in PyTorch

Now that we understand logits, probabilities, and tokens, we can start diving into a practical example of Speculative Sampling.

Let’s keep it simple: we’ll use the maximum logit to decide which token gets generated at each step. If the draft and target models both assign their maximum logit to the same token, we’ll say they agree.
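In other words, the agreement check is just a comparison of argmaxes. The shapes below are made up for illustration:

"""The agreement rule used in this example: two models "agree" at a position
if their highest-logit token ids match.
"""

import torch

#stand-in logits with shape [seq_len, vocab_size]
draft_logits = torch.randn(6, 100)
target_logits = torch.randn(6, 100)

agree = torch.argmax(draft_logits, dim=-1) == torch.argmax(target_logits, dim=-1)
print(agree)   #one boolean per position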

Full code can be found here:

MLWritingAndResearch/SpeculativeSampling.ipynb at main · DanielWarfield1/MLWritingAndResearch

Loading the Models

First, we need a draft and a target model. I’m using T5 in this example, which stands for "Text to Text Transfer Transformer". It’s an encoder-decoder style transformer (like the one I talk about in this article: https://medium.com/p/c70c38e87491), which differs from a decoder-only model (like the one I talk about in my GPT article). Regardless, it has a decoder, so it will work for our purposes. Also, conveniently, T5 comes in a variety of sizes, pre-trained, and easily available on Hugging Face.

"""Loading the draft model
"""

from transformers import T5Tokenizer, T5ForConditionalGeneration

#loading the draft model
draft = "google/flan-t5-large"
draft_tokenizer = T5Tokenizer.from_pretrained(draft)
draft_model = T5ForConditionalGeneration.from_pretrained(draft)
"""Loading the target model
"""

#loading the target model
target = "google/flan-t5-xl"
target_tokenizer = T5Tokenizer.from_pretrained(target)
target_model = T5ForConditionalGeneration.from_pretrained(target)

The whole idea of speculative decoding relies on the draft and target models using the same tokens. So, just to double check, I confirmed that the tokenizers for both models behave identically.

"""Ensuring the tokenizers are identical
in order for speculative sampling to work, tokenization for both the draft
and target model must be identical. This is a sanity check to make sure they are.
"""

#tokenizing a test sequence
tokenizer_test = "this, is, some [text] for 1234comparing, tokenizers adoihayyuz"
ex1 = target_tokenizer(tokenizer_test, return_tensors="pt").input_ids
ex2 = draft_tokenizer(tokenizer_test, return_tensors="pt").input_ids

#zero means all tokenized values are the same, so the tokenizers are
#more than likely identical
print((ex1-ex2).abs().max())
The "0", in this case, means both tokenizers behave similarly
The "0", in this case, means both tokenizers behave similarly

Building Speculative Sampling

Once you have the models, you just, kinda... do speculative sampling. Of course, as previously mentioned, to do speculative sampling in production you need a whole architecture that can handle parallelized queues of information. In this example I’m simply doing drafting and checking within the same loop on a single machine. It’s not a very complicated process, but there are some loops and logic that need to happen to get it all working. Here’s the code:

"""Performing Speculative Sampling
"""

#initializing an empty input to feed to the decoder.
#this is updated each loop with valid generations
decoder_ids = draft_model._shift_right(draft_tokenizer("", return_tensors="pt").input_ids)

#defining input. T5 is an encoder-decoder model, so input and output are handled seperatly
input_ids = draft_tokenizer("Translate to German n Battle not with monsters, lest ye become a monster, and if you gaze into the abyss, the abyss gazes also into you.", return_tensors="pt").input_ids

#defining the number of draft generations
k = 5

#keeps track of generation information, for later printouts
generated = []

#Generating Text
iter = 0
for _ in range(15):
    print('========== Speculative Sampling Iteration {} =========='.format(iter))
    iter+=1

    #creating a holding place for the generated draft
    decoder_ids_draft = decoder_ids.clone()

    before_text = draft_tokenizer.decode(decoder_ids_draft[0])
    initial_length = decoder_ids.shape[1]

    #generating draft
    for i in range(k):

        #predicting the next token with the draft model
        with torch.no_grad():
            logits = draft_model(input_ids=input_ids, decoder_input_ids=decoder_ids_draft).logits
            genid = torch.argmax(logits, dim=2)[0][-1]

        #appending the generated id to the draft
        genid = genid.expand(1,1)
        decoder_ids_draft = torch.cat((decoder_ids_draft,genid),1)

    print('=== Draft Generation')
    current_draft = draft_tokenizer.decode(decoder_ids_draft[0])
    print('generated draft tokens: {}'.format(decoder_ids_draft))
    print('generated draft text: {}'.format(current_draft))

    #Generating all next token predictions with the target, in a single pass
    with torch.no_grad():
        logits = target_model(input_ids=input_ids, decoder_input_ids=decoder_ids_draft).logits
    genids = torch.argmax(logits, dim=2)[0]
    print('=== Target Generation')
    current_target = draft_tokenizer.decode(genids)
    print('generated target tokens: {}'.format(genids))
    print('generated target text: {}'.format(current_target))

    #checking draft against target
    for i, (dv, tv) in enumerate(zip(decoder_ids_draft[0,1:],genids[:-1])):
        #target does not agree with the draft
        if dv != tv:
            #genids holds next-word predictions, so it doesn't include the
            #start token; preserve the original start token here
            first_token = decoder_ids[0][:1]
            decoder_ids = genids[:i+1]
            decoder_ids = torch.cat((first_token,decoder_ids),0)
            break
    else:
        #no disagreements: keep the start token, all draft tokens, and the
        #target model's bonus prediction for the next token
        decoder_ids = torch.cat((decoder_ids[0][:1], genids), 0)

    print('=== Validated Generation')
    current_target = draft_tokenizer.decode(decoder_ids)
    print('generated target tokens: {}'.format(decoder_ids))
    print('generated target text: {}'.format(current_target))

    #expanding dimensions so that the shape of the tensor is the same
    decoder_ids = decoder_ids.expand(1,len(decoder_ids))

    #logging
    numgen = decoder_ids.shape[1] - initial_length
    generated.append({'tokens generated': numgen, 'text before': before_text, 'text after': current_target})

Once concluded, we can observe how many tokens were generated in each loop. In this example we’re asking the model to translate a famous quote from English to German:

Every iteration of speculative sampling.

As you can see, with the chosen task and models, most iterations did not have useful draft output. In some iterations, however, like 8 and 11, the draft model allowed the system to effectively generate five tokens in one run of the target model. The models used in this example are fairly small; I imagine that with larger models the draft model would prove useful more often.
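If you want a single summary number, you can tally up the "generated" list from the loop above. This is a small follow-up sketch, not part of the original notebook:

"""Summarizing how many tokens each pass of the target model produced.
"""

total_tokens = sum(entry['tokens generated'] for entry in generated)
target_passes = len(generated)
print('{} tokens in {} target passes ({:.2f} tokens per pass)'.format(
    total_tokens, target_passes, total_tokens / target_passes))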

Conclusion

And that’s it. Speculative sampling is an incredibly elegant way to drastically speed up text generation. We use a small language model to quickly generate output, then (by using a quirk of masked attention during training) we can use a large language model to double check that work essentially for free. We only keep generated text that the larger model agrees with, so at the end we get the same output, only faster.

Follow For More!

I describe papers and concepts in the ML space, with an emphasis on practical and intuitive explanations.


Never expected, always appreciated. By donating you allow me to allocate more time and resources towards more frequent and higher quality articles. Learn More

Attribution: All of the images in this document were created by Daniel Warfield, unless a source is otherwise provided. You can use any images in this post for your own non-commercial purposes, so long as you reference this article, https://danielwarfield.dev, or both.

