Trends in Model Pre-training for Natural Language Understanding

The Uncertain Future of Token Prediction

Julia Turc
Towards Data Science

Pre-training is now ubiquitous in natural language understanding (NLU). Regardless of the target application (e.g., sentiment analysis, question answering, or machine translation), models are first pre-trained on vast amounts of free-form text, often hundreds of gigabytes. The intention is to initialize models with general linguistic knowledge that can later be leveraged in multiple contexts. A linguistically well-versed pre-trained model can then be fine-tuned on a much smaller dataset for the target application.

While we’ve emphatically established the usefulness of exposing a model to endless Internet blabber, it’s still not obvious how the model should interact with it. This interaction has two requirements. First, the data needs to be gamified into a task: during each training step, the model attempts to solve the task, receives feedback on its performance, then adjusts its parameters accordingly. Second, because of the sheer volume of data, the task needs to be unsupervised: the correct predictions should already be present in the raw data, without the need for human annotation.

Traditionally, pre-training tasks have revolved around predicting tokens that were artificially removed from a text document. Despite their simplicity (or maybe because of it), these techniques have dominated the field since the inception of pre-training, with truly remarkable results. Yet we are probably only scratching the surface. There must be a lot of untapped potential in datasets that exceed, by several orders of magnitude, the number of tokens we are exposed to in childhood. Recent research has sprouted innovative ideas for more elaborate pre-training tasks, such as document retrieval and paraphrasing.

The past: unidirectional language modeling

A simple yet effective technique is next-token prediction. Given a text document, a model is trained to traverse it left-to-right and predict each token along the way, based on what it has read so far. This task is also known as language modeling (LM). The vanilla unidirectional formulation of language modeling was adopted by OpenAI’s now-famous GPT models, whose massive computational scale compensates for the simplicity of the training objective. GPT-3 [1], with 175 billion parameters trained on 300 billion tokens, achieves unprecedented few-shot performance: it can solve many real-world tasks from just a handful of examples, with little or no fine-tuning after pre-training.
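To make the objective concrete, here is a minimal Python sketch of how a single tokenized document turns into next-token prediction examples. The labels come from the text itself, which is what makes the task unsupervised. The function names and the whitespace tokenization are illustrative assumptions, not GPT's actual preprocessing, since GPT models use subword tokenization.

```python
from typing import List, Tuple

import numpy as np


def next_token_examples(tokens: List[str]) -> List[Tuple[List[str], str]]:
    """Pair every token (except the first) with the tokens to its left.

    The first token has no left context in this sketch; real models usually
    prepend a start-of-sequence symbol so that it can be predicted too.
    """
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


def causal_mask(n: int) -> np.ndarray:
    """Boolean attention mask: position i may only attend to positions j <= i.

    A Transformer uses a mask like this to compute all next-token
    predictions for a sequence in a single parallel pass.
    """
    return np.tril(np.ones((n, n), dtype=bool))


doc = "the cat sat on the mat".split()
for context, target in next_token_examples(doc):
    print(" ".join(context), "->", target)
print(causal_mask(4).astype(int))
```

Every position except the first yields a training signal. This dense supervision is the sample efficiency that masked language modeling later gives up.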

Before the pre-training + fine-tuning paradigm started dominating NLU, pseudo-bidirectional language models had their moment of glory. Instead of a single pass, they would traverse the input text twice (left-to-right and right-to-left) to give the illusion of bidirectional processing. For instance, ELMo [2], which set the trend for the Muppet series, used this technique to produce continuous input representations that were later fed into an end-task model. In other words, only the input embeddings were pre-trained rather than the entire network stack. Despite their popularity at the time, pseudo-bidirectional LMs never made a comeback in the context of pre-training + fine-tuning.
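The sketch below contrasts the two passes and is a simplification under stated assumptions. The backward LM is just a forward LM run over the reversed token sequence, and the final representation for each position concatenates the states of the two passes. The helper names are hypothetical, the states are placeholder arrays rather than real LSTM outputs, and ELMo's learned weighting of layers is omitted.

```python
from typing import List, Tuple

import numpy as np

Example = Tuple[List[str], str]


def forward_examples(tokens: List[str]) -> List[Example]:
    """Left-to-right: predict tokens[i] from tokens[:i]."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


def backward_examples(tokens: List[str]) -> List[Example]:
    """Right-to-left: predict tokens[i] from tokens[i+1:].

    Implemented as a forward pass over the reversed sequence, so each
    context is listed nearest-token-first.
    """
    return forward_examples(tokens[::-1])


def pseudo_bidirectional_features(fwd: np.ndarray, bwd: np.ndarray) -> np.ndarray:
    """Concatenate per-position states from the two passes.

    fwd[i] has only seen tokens[:i+1] and bwd[i] only tokens[i:], so neither
    LM ever conditions on both sides at once; the concatenation merely puts
    the two one-sided views next to each other.
    """
    return np.concatenate([fwd, bwd], axis=-1)


doc = "the cat sat on the mat".split()
print(forward_examples(doc)[1])   # left context of "sat"
print(backward_examples(doc)[2])  # right context of "sat"
states = pseudo_bidirectional_features(np.zeros((6, 4)), np.ones((6, 4)))
print(states.shape)
```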

The present: masked language modeling

In the past two years, the de facto building block for NLU has been Google’s BERT [3], pre-trained with two objectives: masked language modeling (MLM) and next-sentence prediction. During MLM training, 15% of the tokens in each text document are selected for prediction, and most of them are replaced with a special [MASK] token. The model’s task is to recover the original tokens. Access to context on both sides of a masked token encourages the model to process text in both directions. It is posited that MLM encourages models to emulate human reasoning more closely than unidirectional and pseudo-bidirectional LMs.
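A minimal sketch of BERT-style input corruption follows. It assumes whitespace tokens and a toy vocabulary. It selects each position independently with probability 0.15, which approximates rather than reproduces BERT's exact sampling. Of the selected positions, BERT replaces 80% with [MASK], 10% with a random token, and leaves 10% unchanged. This mix softens the mismatch caused by [MASK] never appearing during fine-tuning.

```python
import random
from typing import List, Optional, Tuple

MASK = "[MASK]"


def mask_tokens(
    tokens: List[str],
    vocab: List[str],
    mask_prob: float = 0.15,
    rng: Optional[random.Random] = None,
) -> Tuple[List[str], List[Optional[str]]]:
    """Return (corrupted inputs, labels); labels[i] is None where no loss is computed."""
    rng = rng or random.Random()
    inputs = list(tokens)
    labels: List[Optional[str]] = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # not selected: no prediction at this position
        labels[i] = tok  # the model must recover the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK  # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return inputs, labels


doc = "the cat sat on the mat".split()
inputs, labels = mask_tokens(doc * 5, vocab=doc, rng=random.Random(0))
print(inputs)
print(sum(l is not None for l in labels), "of", len(labels), "positions carry a label")
```

Only the labelled positions contribute to the loss. The reduced sample efficiency of MLM stems from this.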

The main disadvantage of MLM compared with its next-token prediction predecessor is reduced sample efficiency, since only 15% of the tokens are predicted. Additionally, the [MASK] tokens introduce a discrepancy between the inputs seen during pre-training and fine-tuning, since downstream tasks do not mask their inputs. XLNet [4] proposed an alternative objective, permutation language modeling, that avoids [MASK] tokens altogether. Its adoption, however, has remained relatively limited compared to BERT.
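XLNet's permutation language modeling can be sketched as follows, under simplifying assumptions. First, sample a random factorization order over the positions. Then predict each token from the tokens that come earlier in that order, keeping their original positions. No [MASK] token ever enters the input. In this sketch every position is a target. XLNet itself predicts only the last tokens of each sampled order (partial prediction) and realizes the order through attention masks and two-stream attention, all of which are omitted here.

```python
import random
from typing import Dict, List, Tuple

# (target position, visible context as {position: token}, target token)
PermExample = Tuple[int, Dict[int, str], str]


def permutation_lm_examples(tokens: List[str], rng: random.Random) -> List[PermExample]:
    order = list(range(len(tokens)))
    rng.shuffle(order)  # factorization order over positions
    examples = []
    for t, pos in enumerate(order):
        # Context = tokens that precede `pos` in the sampled order,
        # which may lie on either side of it in the original text.
        context = {p: tokens[p] for p in sorted(order[:t])}
        examples.append((pos, context, tokens[pos]))
    return examples


doc = "the cat sat on the mat".split()
for pos, context, target in permutation_lm_examples(doc, random.Random(0)):
    print(f"predict {target!r} at position {pos} from {context}")
```

Averaged over many sampled orders, each token is predicted from contexts on both sides. This is how the objective captures bidirectional context while remaining autoregressive.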

The future: beyond token prediction

Despite their success, token-prediction objectives are not flawless. The major criticism against them is that they focus exclusively on linguistic form: models learn the characteristics of coherent language without necessarily associating meaning with it. Generative models like GPT-2 [5] are known to hallucinate, that is, to produce convincing, factual-looking text that is not anchored in reality. This is perhaps the reason why OpenAI is reluctant to open-source its models.

Recent work has made progress towards grounding natural language in the reality of our world. Research projects such as REALM (Retrieval-Augmented Language Model Pre-training) [6] and MARGE (Multilingual Autoencoder that Retrieves and Generates) [7] introduce more elaborate pre-training techniques that go beyond simple token prediction.

REALM (Retrieval-Augmented Language Model Pre-training)

REALM focuses on the specific application of open-domain question answering (open-QA): given a question and a database of documents, the task is to extract the correct answer from one of the documents. Following standard practices, pre-training is performed on a large corpus of free-form text. The innovation is an adjustment to the classic MLM task: before predicting a masked token, the model is trained to first retrieve a document that helps fill in the gap.

REALM pre-training via masked language modeling and document retrieval [6]
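The retrieve-then-predict step can be sketched as a marginal likelihood. In the REALM paper, the probability of the masked token y given the input x is p(y | x) = Σ_z p(y | z, x) p(z | x). The retrieval distribution p(z | x) is a softmax over inner products between an input embedding and document embeddings, and the sum is approximated over the top-k documents. In the Python sketch below, the embeddings and the per-document reader probabilities p(y | z, x) are made-up toy numbers. A plain argsort stands in for the maximum inner product search that a real system needs over millions of documents.

```python
from typing import Tuple

import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()


def marginal_likelihood(
    query_emb: np.ndarray,     # embedding of the masked input x
    doc_embs: np.ndarray,      # one row per candidate document z
    p_y_given_zx: np.ndarray,  # reader's probability of the correct token, per document
    k: int = 2,
) -> Tuple[float, np.ndarray]:
    scores = doc_embs @ query_emb       # relevance f(x, z) as an inner product
    top = np.argsort(scores)[::-1][:k]  # top-k documents (stand-in for MIPS)
    p_z = softmax(scores[top])          # p(z | x), renormalized over the top-k
    return float(np.sum(p_z * p_y_given_zx[top])), top


query = np.array([1.0, 0.0])
docs = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 0.0]])
reader = np.array([0.9, 0.1, 0.5])  # toy values for p(y | z, x)
p, top = marginal_likelihood(query, docs, reader, k=2)
print(top, round(p, 3), "loss:", round(float(-np.log(p)), 3))
```

Because p(z | x) enters the likelihood directly, the gradient of the loss flows into the retrieval scores. In this toy case, document 0 helps the reader most but receives the lower retrieval weight, so minimizing the loss would raise its score. This is the kind of signal that teaches the retriever document relevance.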

This technique has two major advantages. First, it encourages evidence-based predictions rather than plausible-sounding guesses. It also helps that masks are applied over salient spans like “July 1696” rather than over arbitrary tokens. Second, it conveniently sets the stage for end-to-end open-QA fine-tuning, as shown in the figure below. Note that the fine-tuning data does not explicitly link question-answer pairs to relevant documents. But since the model acquires some notion of document relevance during pre-training, the lack of this explicit signal is less damaging. The main disadvantage is the engineering complexity of retrieving a document at every training step while keeping this operation (over a potentially large set of documents) differentiable.

REALM fine-tuning for open-domain question answering [6]
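Salient span masking, mentioned in the REALM discussion above, can be sketched as follows. The sketch assumes the salient spans are already known as (start, end) token offsets. REALM identifies them with a named entity tagger and a regular expression for dates, a step omitted here. The example sentence is made up, and the helper name is hypothetical.

```python
import random
from typing import List, Tuple

MASK = "[MASK]"
Span = Tuple[int, int]  # token offsets [start, end), end exclusive


def salient_span_mask(
    tokens: List[str], spans: List[Span], rng: random.Random
) -> Tuple[List[str], List[str]]:
    """Mask every token of one randomly chosen salient span."""
    start, end = rng.choice(spans)
    inputs = tokens[:start] + [MASK] * (end - start) + tokens[end:]
    return inputs, tokens[start:end]


sentence = "the letter was written in July 1696".split()
inputs, target = salient_span_mask(sentence, spans=[(5, 7)], rng=random.Random(0))
print(" ".join(inputs), "->", target)
```

Masking a random token such as “was” could be solved from syntax alone. Recovering “July 1696”, by contrast, requires a fact, which is exactly what a retrieved document can supply.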

You might be tempted to think that including a retrieval step during pre-training reduces the generality of the pre-trained model (after all, REALM was only applied to open-QA). But MARGE shows that this is not the case.

MARGE (Multilingual Autoencoder that Retrieves and Generates)

All the methods above involve some form of reconstructing the input after it has been altered. Left-to-right LMs remove all the text to the right-hand side of the token being predicted, and MLMs elide arbitrary tokens from the input text. MARGE pre-training takes this challenge to the next level and asks the model to do the seemingly impossible: reconstruct a “target” document that it has never seen, not even in a corrupted form with truncations or omissions. Instead, the model is shown other “evidence” documents related to the target (e.g., paraphrases or even translations of it into another language) and challenged to regenerate the original text. The figure below shows an example.

Target and evidence documents used for MARGE pre-training (adapted from [7])
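The evidence-selection step can be sketched under explicit assumptions. As in the MARGE paper, relevance between two documents is taken to be the cosine similarity of their embeddings from the model's own encoder. Here the embeddings are made-up toy vectors and the helper names are hypothetical. How MARGE batches documents at scale and feeds the scores into the decoder is omitted. The one hard constraint from the description above is that the target document is excluded from its own evidence.

```python
from typing import Tuple

import numpy as np


def cosine_relevance(embs: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between document embeddings (one per row)."""
    normed = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    return normed @ normed.T


def select_evidence(embs: np.ndarray, target: int, k: int) -> Tuple[np.ndarray, np.ndarray]:
    """Pick the k documents most related to `target`, never the target itself."""
    rel = cosine_relevance(embs)[target].copy()
    rel[target] = -np.inf  # the model must not see the document it reconstructs
    top = np.argsort(rel)[::-1][:k]
    return top, rel[top]


# Toy embeddings: rows 1 and 3 stand in for a paraphrase and a translation of row 0.
embs = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [1.0, 0.05]])
evidence, scores = select_evidence(embs, target=0, k=2)
print(evidence, scores.round(3))
```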

One of the major differences between MARGE and REALM is that the former is a sequence-to-sequence model (consisting of an encoder and a decoder), while the latter consists only of BERT-style encoders. This enables MARGE to be fine-tuned on a wide range of downstream tasks. These include discriminative tasks (e.g., classification or extractive question answering) and generative ones (e.g., machine translation, summarization, or paraphrasing). MARGE makes the interesting observation that the pre-training + fine-tuning paradigm holds even when retrieval is performed only during pre-training (recall that REALM used its retriever during both stages). A truly remarkable outcome is that MARGE can perform decent zero-shot machine translation, that is, without any fine-tuning on parallel data!

Conclusion

Increasing the amount of training data remains a surefire way to boost model quality, and this trend doesn’t seem to slow down even at hundreds of billions of tokens. But despite being exposed to more text than a human being will ever process in a lifetime, machines still underperform us. This is especially true in tasks that are generative in nature or that require complex reasoning. This suggests that the way models interact with the data is highly inefficient. The research community has started moving away from pre-training tasks that rely solely on linguistic form and toward objectives that encourage anchoring language understanding in the real world.

References

  1. Brown et al., Language Models are Few-Shot Learners (2020)
  2. Peters et al., Deep contextualized word representations (2018)
  3. Devlin et al., BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (2018)
  4. Yang et al., XLNet: Generalized Autoregressive Pretraining for Language Understanding (2019)
  5. Radford et al., Language Models are Unsupervised Multitask Learners (2019)
  6. Guu et al., REALM: Retrieval-Augmented Language Model Pre-Training (2020)
  7. Lewis et al., Pre-training via Paraphrasing (2020)
