The power of constrained language models

Why and how to build constrained language models with a custom beam search algorithm. A guide with Hugging Face code.

Karel D'Oosterlinck
Towards Data Science


Picture of a key on the ground.
Photo by Michael Dziedzic on Unsplash

Overview

  1. Introduction
  2. Why care about constrained language models?
  3. What is the beam search algorithm?
  4. How to constrain any language model with a custom beam search algorithm?
  5. The implementation using Hugging Face
  6. The results
  7. Conclusion

Introduction

Pre-trained generative language models (such as OpenAI’s GPT2 and GPT3) or seq2seq models (such as T5 or the recently released T0) generate free-flowing natural language. This means their output sentences can take any shape. To get the most value out of these models, we sometimes want the outputs to follow a certain structure; a model whose outputs are forced to follow such a structure is called a constrained language model. In this post I will explain:

  • Why you would want a constrained language model.
  • What the beam search algorithm is used for.
  • How to easily constrain any language model with a custom beam search algorithm.
  • How to implement this using Hugging Face.

After reading this post, you will be able to gain more value out of your language model by controlling the shape of its outputs.

This post relates to my previous post, where I explain how to unlock the true potential of GPT3 by manipulating the shape of its inputs (referred to as prompt engineering).

A notebook containing example code is available here.

Why care about constrained language models?

As previously mentioned, we sometimes want the output of our models to follow a certain structure. For example, if we want to further process the output, it helps to have some guarantees about its structure and properties.

In a case study I describe here, the output of a model needs to be parsed by an additional component. By enforcing a proper structure, I can guarantee the output will be easy to parse.

We could make sure a given model never outputs a sentence where two specific words co-occur, or we could make sure these words always co-occur. The possibilities are endless.

Alternatively, we might want to generate sentences with only an even number of characters per token, just for the fun of it.

What is the beam search algorithm?

We can enforce such a structure by using a custom beam search algorithm.

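Simply put, the beam search algorithm is used to generate highly probable continuations of a sentence. It is related to, but not the same as, the Viterbi algorithm: Viterbi is an exact dynamic-programming search used for models such as hidden Markov models, whereas beam search only approximates the most probable sequence.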

Generative language models are trained to predict a probability distribution over the next token given an input sequence. If we want to generate 10 new tokens, we can treat this as generating 1 new token 10 times over: we take the original sentence, generate the first token, use the resulting sentence to generate the second token, and so on. If we simply pick the most probable token at every step, this is called greedy decoding.
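To make greedy decoding concrete, here is a minimal sketch. It assumes a hypothetical toy "language model" (`TOY_MODEL`) that only looks at the previous word and whose probabilities are invented for illustration; a real model such as GPT2 conditions on the whole sequence and scores tens of thousands of tokens.

```python
# Hypothetical toy "language model" with made-up probabilities: for each
# previous word, a distribution over the next word. A real model conditions
# on the whole sequence, not only on the last word.
TOY_MODEL = {
    "a": {"little": 0.5, "bit": 0.4, "big": 0.1},
    "little": {"dog": 0.3, "cat": 0.3, "one": 0.4},
    "bit": {"of": 0.9, "too": 0.1},
    "big": {"dog": 1.0},
}


def greedy_decode(model, start, steps):
    """Append the single most probable next word `steps` times."""
    words, prob = [start], 1.0
    for _ in range(steps):
        next_probs = model[words[-1]]
        best = max(next_probs, key=next_probs.get)  # most probable next word
        words.append(best)
        prob *= next_probs[best]  # probability of the whole continuation
    return words[1:], prob


print(greedy_decode(TOY_MODEL, "a", 2))
```

Starting from ‘a’, this picks ‘little’ and then ‘one’, for a total probability of 0.5 × 0.4 = 0.2.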

Greedy decoding does not always produce the most probable continuation of a prompt when we consider multiple tokens.
(Image by Author) Greedy decoding — Taking the most probable next token several times in a row does not necessarily lead to the most probable continuation of the sentence.

If we want to generate the most probable continuation of 10 tokens, this is not the same as picking the most probable token 10 times in a row: greedy decoding is not optimal. Sometimes it makes sense not to pick the most probable next token at a given step, so that the following tokens can compensate for this by being more probable.
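The claim can be checked by brute force on a toy example. The sketch below reuses a hypothetical word-level toy model with invented probabilities (`TOY_MODEL` and `all_continuations` are illustrative names), enumerates every 2-word continuation, and multiplies the per-step probabilities.

```python
# Hypothetical toy model with made-up probabilities (next word given previous word).
TOY_MODEL = {
    "a": {"little": 0.5, "bit": 0.4, "big": 0.1},
    "little": {"dog": 0.3, "cat": 0.3, "one": 0.4},
    "bit": {"of": 0.9, "too": 0.1},
    "big": {"dog": 1.0},
}


def all_continuations(model, start, steps):
    """Enumerate every continuation of `steps` words with its probability."""
    results = [([], 1.0)]
    for _ in range(steps):
        extended = []
        for words, prob in results:
            last = words[-1] if words else start
            for word, p in model[last].items():
                extended.append((words + [word], prob * p))
        results = extended
    return results


best_words, best_prob = max(all_continuations(TOY_MODEL, "a", 2), key=lambda r: r[1])
print(best_words, best_prob)
```

Greedy decoding would pick ‘little’ and then ‘one’ (0.5 × 0.4 = 0.20), but the best sequence is ‘bit of’ (0.4 × 0.9 = 0.36), even though ‘bit’ is the less probable first word. Enumerating everything only works for tiny vocabularies, because the number of sequences grows exponentially with the number of tokens.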

The beam search algorithm mitigates this problem by keeping the K most probable partial sequences (the beams) at each step instead of a single one. Every beam is extended with candidate next tokens, and only the K most probable extended sequences survive to the next step. By keeping more candidates alive, we can find situations where picking a less probable token at a given step gives rise to more probable tokens in the subsequent steps.

(Image by Author) Beam search — How the beam search algorithm leads to better continuations by considering multiple possible continuations of a prompt.
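A minimal sketch of beam search with beam width `k`, again on a hypothetical toy model with invented probabilities. With `k=1` it reduces to greedy decoding; with `k=2` it recovers the more probable sequence that greedy decoding misses.

```python
# Hypothetical toy model with made-up probabilities (next word given previous word).
TOY_MODEL = {
    "a": {"little": 0.5, "bit": 0.4, "big": 0.1},
    "little": {"dog": 0.3, "cat": 0.3, "one": 0.4},
    "bit": {"of": 0.9, "too": 0.1},
    "big": {"dog": 1.0},
}


def beam_search(model, start, steps, k):
    """Keep only the k most probable partial sequences at every step."""
    beams = [([], 1.0)]
    for _ in range(steps):
        candidates = []
        for words, prob in beams:
            last = words[-1] if words else start
            # Extend every beam with every possible next word.
            for word, p in model[last].items():
                candidates.append((words + [word], prob * p))
        # Prune back to the k best candidates (the beam width).
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    return beams


for k in (1, 2):
    print(k, beam_search(TOY_MODEL, "a", 2, k)[0])
```

Practical implementations usually add log-probabilities instead of multiplying probabilities, to avoid numerical underflow, and also handle end-of-sequence tokens; the sketch leaves both out for brevity.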

As an example, if we continue the sentence ‘My cute dog is a’ by 1 token, the most probable continuation is ‘My cute dog is a little’. However, if we use beam search to find the most probable continuation of 3 tokens, ‘My cute dog is a bit of a’ becomes the most probable sentence. While ‘bit’ is less probable than ‘little’ as a single next token, the 3-token continuation starting with ‘bit’ is more probable overall.

This algorithm is a heuristic, meaning that it is not guaranteed to find the most probable continuation. As K grows and we keep more candidate sequences at each step, we will usually find a more probable final continuation. However, we pay for this with a higher computational cost, since every step has to score K times as many candidates as greedy decoding.

If you are interested in learning more about different generation techniques for language models, I suggest this blog post.

How to constrain any language model with a custom beam search algorithm?

We can obtain outputs that adhere to a certain structure by manipulating this beam search algorithm. We can dynamically set the probability of some next tokens to zero at certain steps to make sure these tokens never get generated. By banning certain tokens at the right time, we can force the outputs to follow any pattern that can be checked token by token while the sentence is being generated.
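Setting a token's probability to zero is commonly done on the model's raw scores (logits) before the softmax: a banned token gets a score of minus infinity, so its probability becomes exactly zero and the remaining probabilities are renormalized. A minimal, framework-free sketch; the function name `masked_softmax` is hypothetical.

```python
import math


def masked_softmax(logits, allowed):
    """Softmax over `logits` in which every index not in `allowed` gets probability 0."""
    if not allowed:
        raise ValueError("at least one token must stay allowed")
    # Banned tokens get a score of -inf, so exp(-inf) == 0 after the softmax.
    masked = [x if i in allowed else -math.inf for i, x in enumerate(logits)]
    top = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Four-token vocabulary; only tokens 0 and 2 are allowed at this step.
print(masked_softmax([2.0, 1.0, 0.5, 3.0], {0, 2}))
```

Token 3 has the highest raw score, but because it is banned it ends up with probability zero, and the allowed tokens share all of the probability mass.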

Imagine we want to generate an output where the first word starts with the letter ‘a’, the second with the letter ‘b’, the third with the letter ‘c’, the fourth again with the letter ‘a’, and so on. During each step of the beam search algorithm, we check the first letter of the previous word and set the probability of every next token that would start a new word with the wrong letter to zero. The model will then generate the most probable sentences it can find that adhere to this format.

(Image by Author) Constrained beam search — By removing tokens we do not want during the beam search algorithm, we can constrain the outputs of a language model to a predefined structure.

While the previous example is rather trivial, we can use this technique to enforce arbitrary constraints on the generated sentence.
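Putting the pieces together, the sketch below runs beam search on a hypothetical word-level toy model (invented probabilities, each word conditioned only on the previous word) and discards every candidate whose first letter breaks the ‘a -> b -> c’ cycle. The rule that a word starting with any other letter restarts the cycle at ‘a’ is an assumption added here so that the function is defined for every word.

```python
CYCLE = "abc"

# Hypothetical toy model with made-up probabilities (next word given previous word).
TOY_MODEL = {
    "a": {"little": 0.5, "bit": 0.4, "big": 0.1},
    "bit": {"of": 0.7, "confused": 0.3},
    "big": {"cat": 0.8, "dog": 0.2},
    "confused": {"about": 0.6, "by": 0.4},
    "cat": {"and": 0.5, "is": 0.5},
}


def required_letter(previous_word):
    """Letter the next word must start with, given the previous word."""
    first = previous_word[0].lower()
    if first in CYCLE:
        return CYCLE[(CYCLE.index(first) + 1) % len(CYCLE)]
    return CYCLE[0]  # assumption: any other letter restarts the cycle at 'a'


def constrained_beam_search(model, start, steps, k):
    beams = [([start], 1.0)]
    for _ in range(steps):
        candidates = []
        for words, prob in beams:
            letter = required_letter(words[-1])
            for word, p in model.get(words[-1], {}).items():
                if word[0].lower() != letter:
                    continue  # banned: probability set to zero
                candidates.append((words + [word], prob * p))
        # Keep the k most probable sequences that satisfy the constraint.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
    return [(words[1:], prob) for words, prob in beams]


for words, prob in constrained_beam_search(TOY_MODEL, "a", 3, k=2):
    print(" ".join(words), round(prob, 3))
```

With `k=2` and 3 steps from ‘a’, the most probable unconstrained first word, ‘little’, is banned, and the search returns ‘bit confused about’ (0.4 × 0.3 × 0.6 = 0.072) ahead of ‘big cat and’. If no candidate satisfies the constraint, the sketch simply returns an empty list.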

The implementation using Hugging Face

To save us a lot of work, we will leverage Hugging Face to do much of the heavy lifting. The documented code implementing the ‘a -> b -> c’ pattern described above (using GPT2) is available here.
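One way to express this constraint with the `transformers` library, not necessarily the approach taken in the linked notebook, is the `prefix_allowed_tokens_fn` argument of `generate`: a callback that returns the token ids allowed at the next step, which Hugging Face's own beam search then respects. The helper names below are hypothetical, the restart-at-‘a’ rule for other letters is an assumption, and the sketch assumes GPT2 tokens that start a new word decode with a leading space (e.g. " big"). It requires `transformers` and `torch` and downloads GPT2 when `generate_abc` is called, and its output may differ from the beams reported in this post.

```python
from typing import Dict, List, Tuple

CYCLE = "abc"


def next_required_letter(text: str) -> str:
    """First letter the next word must have, given the text generated so far."""
    words = text.split()
    if not words:
        return CYCLE[0]
    first = words[-1][0].lower()
    if first in CYCLE:
        return CYCLE[(CYCLE.index(first) + 1) % len(CYCLE)]
    return CYCLE[0]  # assumption: any other letter restarts the cycle at 'a'


def classify_tokens(token_strings: List[str]) -> Tuple[Dict[str, List[int]], List[int]]:
    """Split a vocabulary into word-starting tokens (grouped by first letter)
    and tokens that continue the current word."""
    word_starts: Dict[str, List[int]] = {letter: [] for letter in CYCLE}
    continuations: List[int] = []
    for token_id, s in enumerate(token_strings):
        if s.startswith(" ") and len(s) > 1 and not any(c.isspace() for c in s[1:]):
            # Token that starts a new word, e.g. " big".
            first = s[1].lower()
            if first in word_starts:
                word_starts[first].append(token_id)
        elif s and not any(c.isspace() for c in s):
            # Token without whitespace, e.g. "gest": it extends the current word.
            continuations.append(token_id)
        # Everything else (bare spaces, newlines, ...) is never allowed.
    return word_starts, continuations


def generate_abc(prompt: str, num_beams: int = 10, max_new_tokens: int = 10) -> List[str]:
    """Beam search with GPT2, restricted to the 'a -> b -> c' pattern."""
    # Imported here so the helpers above work without transformers installed.
    from transformers import GPT2LMHeadModel, GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    token_strings = [tokenizer.decode([i]) for i in range(len(tokenizer))]
    word_starts, continuations = classify_tokens(token_strings)

    def allowed_tokens(batch_id, input_ids):
        # Called by generate() for every beam at every step.
        text = tokenizer.decode(input_ids)
        return word_starts[next_required_letter(text)] + continuations

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        num_beams=num_beams,
        num_return_sequences=num_beams,
        max_new_tokens=max_new_tokens,
        prefix_allowed_tokens_fn=allowed_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in outputs]
```

Calling `generate_abc("My cute dog is a")` returns one string per beam. Because the prompt ends with ‘a’, the first generated word must start with ‘b’. Allowing word-continuation tokens at every step lets multi-token words such as ‘con’ + ‘fused’ through, since only the first token of a word decides its first letter.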

The results

GPT2 constrained to the ‘a -> b -> c’ pattern continues the prompt ‘My cute dog is a’ in the following manner:

beam 0: My cute dog is a bit confused about being called a bitch
beam 1: My cute dog is a bit confused about being called a baby
beam 2: My cute dog is a big cat and big cats are big
beam 3: My cute dog is a bit confused about being called a bunny
beam 4: My cute dog is a bit confused about being called a big
beam 5: My cute dog is a bit confused about being called a boy
beam 6: My cute dog is a bit confused about being called a b
beam 7: My cute dog is a bit confused about being called a bad
beam 8: My cute dog is a bit confused about being confused about being
beam 9: My cute dog is a bit confused about being called a black

Notice how the output of the model adheres to this ‘a -> b -> c’ structure without us having to provide the model with any examples of this pattern in the input. I particularly like how the model still manages to produce some coherent sentences — it makes sense to be confused about being called a bunny if you are a dog — despite this (rather strict) constraint.

Conclusion

By writing our own version of the beam search algorithm, we are able to constrain the output of a pre-trained language model. This can be applied to generative models such as GPT2 and GPT3 and even seq2seq models such as T5 and T0. This is particularly useful when we want the output of our models to follow a certain pre-defined structure or adhere to a set of rules.

By using the Hugging Face library, we can easily implement this custom beam search algorithm ourselves. The code can be found here.

If you are interested in getting even more value out of your language model, check out this post about how prompt engineering can help you unlock the true value of your generative language models.


PhD student in NLP at Ghent University. Visiting Student Researcher at Stanford University. Hobbyist full stack web developer.