How to Answer Questions with Machine Learning

A look into the SQuAD dataset, its top NLP models, and whether they are overfitting.

Michael Berk
Towards Data Science


Have you ever wanted to build an algorithm that does your homework? While the tech isn’t there yet, we’re getting close.

Figure 1: models with top exact match (EM) scores on the SQuAD 2.0 dataset —src. Image by author.

In 2016, researchers at Stanford released a question answering dataset to train NLP models. Since then, hundreds of models have been submitted, each meticulously fine-tuned to the dataset. They boast impressive accuracy, but after so much reuse of the same dataset we’d expect some overfitting.

In this post, we leverage research from a team at UC Berkeley to determine whether these models A) are overfitting the dataset and B) are able to generalize. We’re going to stay high-level.

Without further ado, let’s dive in.

0 — Technical TLDR

Question answering (QA) models are NLP models that answer questions based on text passages. The Stanford Question and Answer Dataset (SQuAD) was introduced in 2016 to facilitate the training of QA models. The SQuAD leaderboard also tracks submitted NLP models, some of which have outperformed the human baseline.

As with other popular datasets, we’d expect to observe adaptive overfitting. To test this, researchers at UC Berkeley developed new test sets to estimate model generalizability and the impact of adaptive overfitting. Surprisingly, the models showed little evidence of adaptive overfitting, but they also did not generalize well to the new datasets.

1 — But, what’s actually going on?

Ok, let’s slow down a bit and understand the SQuAD dataset, some of the best NLP models, and the methodology used to estimate adaptive overfitting.

1.1 — What is the SQuAD Dataset?

The Stanford Question and Answer Dataset (SQuAD) was introduced to advance the field of question answering (QA) modeling. It is a reading comprehension dataset consisting of passages, questions, and answers. The passages were scraped from Wikipedia, and the questions and answers were crowdsourced using Amazon’s Mechanical Turk.

Note that there are two versions. SQuAD v1.1 does not contain unanswerable questions, but SQuAD v2.0 does.
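To make the passage/question/answer structure concrete, here is a minimal sketch of one SQuAD v2.0-style entry. In the released JSON files, entries like this sit in a `paragraphs` list under each article in `data`. Answers are stored as a text span plus its character offset (`answer_start`) into the context. Unanswerable v2.0 questions carry `is_impossible: true` and no answers. The passage is a shortened paraphrase. The helper `make_paragraph` and the `q0`/`q1` ids are hypothetical placeholders, not part of any SQuAD tooling.

```python
import json

# Shortened, paraphrased version of the Super Bowl 50 passage (illustrative only).
context = (
    "Super Bowl 50 was an American football game to determine the champion of "
    "the National Football League (NFL) for the 2015 season. The American "
    "Football Conference (AFC) champion Denver Broncos defeated the National "
    "Football Conference (NFC) champion Carolina Panthers to earn their third "
    "Super Bowl title."
)

def make_paragraph(context, qa_pairs):
    """Build one SQuAD v2.0-style paragraph entry (hypothetical helper).

    qa_pairs: list of (question, answer_text_or_None); None marks an
    unanswerable question. Ids are placeholders, not real SQuAD ids.
    """
    qas = []
    for n, (question, answer) in enumerate(qa_pairs):
        entry = {"id": f"q{n}", "question": question}
        if answer is None:
            entry.update(answers=[], is_impossible=True)
        else:
            start = context.find(answer)  # character offset of the answer span
            if start < 0:
                raise ValueError(f"answer not found in context: {answer!r}")
            entry.update(
                answers=[{"text": answer, "answer_start": start}],
                is_impossible=False,
            )
        qas.append(entry)
    return {"context": context, "qas": qas}

paragraph = make_paragraph(context, [
    ("Which NFL team represented the AFC at Super Bowl 50?", "Denver Broncos"),
    ("Who was the MVP of Super Bowl 50?", None),  # not answerable from this excerpt
])
print(json.dumps(paragraph["qas"], indent=2))
```

Storing a character offset rather than just the answer text lets a model be trained to point at an exact span inside the passage.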

1.2 — Let’s See an Example

In figure 2 below, we can see a SQuAD passage as well as predictions from Microsoft Research Asia’s model, r-net+.

Figure 2: SQuAD data source and modeling example — src. Image by author.

The text on the left in purple is information about our topic, in this case Super Bowl 50, which was taken (legally) from Wikipedia. On the right we have questions about the text, followed by acceptable answers and some of the model’s predictions.

How is it possible that these algorithms can outperform humans? Well, let’s see…

2 — What Types of Models Perform Best?

Much of the credit goes to two models: BERT and XLNet.

2.1 — BERT

Bidirectional Encoder Representations from Transformers (BERT) was introduced in 2018 and took the NLP world by storm. It popularized bidirectional learning: using the context on both sides of the token of interest.

Within that overall framework, BERT leverages two training strategies.

In figure 3, we can see the first of these strategies called Masked LM (MLM).

Figure 3: diagram of BERT with the second token masked. Image by author.

MLM masks (hides) some of the known tokens in the training data and trains the model to predict each masked token from the context around it. Because that context comes from both sides of the missing token, the approach hinges on bidirectional learning, and it results in strong performance.
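Here is a minimal sketch of how MLM training examples can be built. It follows the recipe in the BERT paper: 15% of positions are selected for prediction. Of those, 80% are replaced with `[MASK]`, 10% with a random token, and 10% are left unchanged. Two assumptions keep the sketch small. First, it uses whitespace word tokens instead of BERT’s WordPiece vocabulary. Second, `mask_tokens` is a hypothetical helper, not a library function.

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, mask_prob=0.15, vocab=None, seed=0):
    """Return (masked_tokens, labels) for one masked-LM training example.

    labels[i] holds the original token wherever the model must make a
    prediction, and None elsewhere.
    """
    rng = random.Random(seed)
    vocab = vocab or tokens  # assumption: sample random replacements from the input itself
    masked = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if rng.random() < mask_prob:  # select this position for prediction
            labels[i] = tok
            r = rng.random()
            if r < 0.8:
                masked[i] = MASK  # 80%: replace with [MASK]
            elif r < 0.9:
                masked[i] = rng.choice(vocab)  # 10%: replace with a random token
            # remaining 10%: leave the original token in place
    return masked, labels

tokens = "the broncos defeated the panthers in super bowl 50".split()
masked, labels = mask_tokens(tokens, seed=1)
print(masked)
print(labels)
```

The model only receives loss at positions where `labels` is not None. It predicts those tokens from everything else in the sentence, on both sides.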

The second is called Next Sentence Prediction (NSP), which trains the model to determine whether two sentences are sequentially related. The training pairs are split evenly between true sequences, where the second sentence actually follows the first, and random sequences, where the second sentence is drawn from elsewhere and is probably unrelated.
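The 50/50 split can be sketched in a few lines. The BERT paper draws the random partner from the whole corpus. For simplicity, this version (an assumption) draws it from the same list of sentences, excluding the true successor. `make_nsp_pairs` is a hypothetical name.

```python
import random

def make_nsp_pairs(sentences, seed=0):
    """Build (sentence_a, sentence_b, is_next) examples, roughly 50% IsNext."""
    if len(sentences) < 3:
        raise ValueError("need at least 3 sentences to draw a NotNext partner")
    rng = random.Random(seed)
    pairs = []
    for i in range(len(sentences) - 1):
        a = sentences[i]
        if rng.random() < 0.5:
            # IsNext: sentence B really follows sentence A.
            pairs.append((a, sentences[i + 1], True))
        else:
            # NotNext: pick any other sentence except A and its true successor.
            others = [s for j, s in enumerate(sentences) if j not in (i, i + 1)]
            pairs.append((a, rng.choice(others), False))
    return pairs

doc = [f"sentence {k}" for k in range(6)]
for a, b, is_next in make_nsp_pairs(doc):
    print(a, "|", b, "|", "IsNext" if is_next else "NotNext")
```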

With this baseline structure, engineers can append additional layers to make the model suitable for their tasks. One example would be adding a final classification layer to perform sentiment analysis.
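For SQuAD itself, the added layer in the BERT paper is a span predictor rather than a classifier. It learns a start vector S and an end vector E. Candidate span i..j is then scored as S·T_i + E·T_j, where T_i is BERT’s output for token i, and the highest-scoring span with j ≥ i is the answer. The sketch below covers only that final selection step. It assumes the per-token start and end scores are already computed. `best_span` and the `max_answer_len` cap are hypothetical choices for illustration.

```python
def best_span(start_scores, end_scores, max_answer_len=30):
    """Return (start, end, score) of the highest-scoring span with start <= end.

    start_scores[i] scores token i as the answer start (e.g. S.T_i);
    end_scores[j] scores token j as the answer end (e.g. E.T_j).
    """
    if len(start_scores) != len(end_scores):
        raise ValueError("score lists must have the same length")
    n = len(start_scores)
    best = (0, 0, float("-inf"))
    for i in range(n):
        # Only consider ends at or after the start, capped at max_answer_len tokens.
        for j in range(i, min(n, i + max_answer_len)):
            score = start_scores[i] + end_scores[j]
            if score > best[2]:
                best = (i, j, score)
    return best

tokens = ["Denver", "Broncos", "defeated", "Carolina"]
start, end, score = best_span([2.0, 0.1, -1.0, 0.3], [0.0, 3.0, 0.2, 0.5])
print(tokens[start:end + 1], score)
```

The j ≥ i constraint matters. Picking the best start and the best end independently could produce an "answer" that ends before it begins.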

For brevity, we will move on, but check out the comments for more resources.

2.2 — XLNet

XLNet was released in 2019 and addressed some of BERT’s shortcomings. According to a tweet by Quoc Le, a lead researcher on the project, XLNet outperforms BERT on 20 tasks.

Let’s see how XLNet achieves this level of performance with a simplified example.

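import itertools

# Enumerate every ordering of a 3-token sequence.
# 'A' and 'C' stand for words; 'M' is the unknown token we want to predict.
perms = itertools.permutations(['A', 'M', 'C'])
print(list(perms))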
# Output:
# [('A', 'M', 'C'), ('A', 'C', 'M'),
# ('M', 'A', 'C'), ('M', 'C', 'A'),
# ('C', 'A', 'M'), ('C', 'M', 'A')]

In the code above, we use Python’s itertools library to list all permutations of 3 tokens. A and C represent words, and M represents the unknown token we are trying to predict.

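One of BERT’s drawbacks is that the artificial [MASK] tokens it relies on during pretraining never appear in the real text it sees during fine-tuning. XLNet, on the other hand, is strictly an autoregressive model: it predicts each token from the tokens that come before it in some order, and it leverages permutations of that order to gather contextual information from both sides of a token.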

Figure 4: example of XLNet training structure. Here tokens may vary depending upon the permutation order. Image by author.
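To see why permutations give an autoregressive model bidirectional context, the sketch below extends the itertools example. For each factorization order, it lists which tokens a left-to-right model has already seen when it reaches M. This is a simplification. Real XLNet keeps the original token positions and realizes the permuted orders with attention masks rather than physically reordering the text. `visible_context` is a hypothetical helper.

```python
import itertools

def visible_context(tokens, target):
    """For each factorization order, return the tokens already seen when predicting `target`."""
    table = []
    for order in itertools.permutations(tokens):
        pos = order.index(target)
        table.append((order, order[:pos]))
    return table

for order, ctx in visible_context(['A', 'M', 'C'], 'M'):
    print(order, '->', ctx)
# ('A', 'M', 'C') -> ('A',)
# ('A', 'C', 'M') -> ('A', 'C')
# ('M', 'A', 'C') -> ()
# ('M', 'C', 'A') -> ()
# ('C', 'A', 'M') -> ('C', 'A')
# ('C', 'M', 'A') -> ('C',)
```

Across all orders, M is predicted from A alone, from C alone, from both, or from nothing. Over training, it therefore learns from context on both sides without ever seeing a [MASK] token.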

One of the main reasons XLNet often outperforms BERT is that BERT assumes the masked tokens are independent of each other, given the unmasked tokens. However, English exhibits lots of sequential interdependence, so this assumption often doesn’t hold: if both “New” and “York” are masked in “New York is a city”, BERT predicts each one without seeing the other. XLNet also implements some fancy technical tricks from Transformer-XL, such as its segment recurrence mechanism and relative positional encoding scheme. But (unfortunately) we’re not going to get into that.
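A toy calculation shows what the independence assumption costs. The two-token distribution below is hypothetical, invented purely for illustration. If the only options are “New York” and “Los Angeles”, each with probability 0.5, then multiplying the per-position probabilities puts mass on pairs that never occur.

```python
from collections import defaultdict
from itertools import product

# Hypothetical joint distribution over two masked tokens (illustrative numbers).
joint = {("New", "York"): 0.5, ("Los", "Angeles"): 0.5}

# Per-position (marginal) probabilities, which is all an independent predictor models.
first = defaultdict(float)
second = defaultdict(float)
for (w1, w2), p in joint.items():
    first[w1] += p
    second[w2] += p

# Treating the two masked tokens as independent: P(w1, w2) = P(w1) * P(w2).
independent = {(w1, w2): first[w1] * second[w2] for w1, w2 in product(first, second)}

for pair in sorted(independent):
    print(pair, "joint:", joint.get(pair, 0.0), "independent:", independent[pair])
```

The independent model gives “New Angeles” a probability of 0.25, while the true joint probability is 0. An autoregressive model avoids this because it predicts the second token conditioned on the first.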

3 — Are Current Models Overfitting the Dataset?

Now that we have an understanding of the dataset and some of the top modeling techniques, let’s move on to the final topic of the post.

3.1 — Evaluation Methodology

The authors of the paper wanted to determine whether these models exhibit adaptive overfitting, which occurs when many models are developed and tuned against the same test set until scores on that test set stop reflecting true out-of-sample performance.

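To do this, the authors decompose the difference between a model’s loss on the original SQuAD test set and its loss on a new test set into three terms, as shown in the following equation: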

Figure 5: loss breakdown equation — src. Image by author.

Figure 5 defines each of the key terms, but let’s briefly talk about what each one means:

  1. Adaptivity Gap (green): the difference between a model’s loss on the original SQuAD test set and its expected loss on the underlying SQuAD distribution. It captures how much repeated tuning against the same test set has inflated scores on it. If there is no adaptive overfitting, this value should be very small.
  2. Distribution Gap (blue): the difference between a model’s expected loss on the SQuAD distribution and its expected loss on a new distribution (New York Times, Reddit, Amazon). We would expect models to perform differently when a natural distribution shift occurs, such as a change of data source, so this value may be large.
  3. Generalization Gap (red): the difference between a model’s expected loss on the new distribution and its loss on the finite new test set. With a large enough new test set this is just sampling error, so, as with the first bullet, we’d expect this value to be small.

Together, these three gaps add up to the total difference between a model’s loss on the original SQuAD test set and its loss on a new test set.
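Written out, the decomposition is a telescoping sum. The notation here is assumed, since figure 5’s own symbols are not reproduced in the text. Let $L_S$ be a model’s loss on the original SQuAD test set and $L_D$ its expected loss on the SQuAD distribution. Let $L_{D'}$ be its expected loss on a new distribution and $L_{S'}$ its loss on the finite new test set.

```latex
\underbrace{L_S - L_{S'}}_{\text{total gap}}
  = \underbrace{(L_S - L_D)}_{\text{adaptivity gap}}
  + \underbrace{(L_D - L_{D'})}_{\text{distribution gap}}
  + \underbrace{(L_{D'} - L_{S'})}_{\text{generalization gap}}
```

The intermediate terms $L_D$ and $L_{D'}$ cancel, so the identity always holds. The interesting question is which of the three terms accounts for most of the total.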

3.2 — Results

First, there was little evidence of adaptive overfitting in any of the models. Models’ scores on the original SQuAD test set were linearly related to their scores on the new test sets. Mathematically, this means the adaptivity gap was small.

This result is surprising given the common belief that heavily reused benchmarks such as MNIST and ImageNet get overfit: as more researchers develop models optimized for the same dataset, we’d expect those models to gradually lose generalizability. However, with SQuAD we do not see that phenomenon.
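One rough way to check the linear relationship from the first result is to fit a line through (original score, new score) pairs and see how tightly the models follow it. The sketch below does an ordinary least-squares fit by hand and reports Pearson’s r. The five score pairs are hypothetical, invented for illustration, and are NOT figures from the paper. `fit_line` is a hypothetical helper.

```python
def fit_line(xs, ys):
    """Ordinary least-squares fit y = slope * x + intercept, plus Pearson's r."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    syy = sum((y - my) ** 2 for y in ys)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    intercept = my - slope * mx
    r = sxy / (sxx * syy) ** 0.5
    return slope, intercept, r

# Hypothetical F1 scores for five models (NOT results from the paper).
original_f1 = [70, 75, 80, 85, 90]
new_f1 = [62, 67, 71, 76, 81]

slope, intercept, r = fit_line(original_f1, new_f1)
print(f"new ≈ {slope:.2f} * original {intercept:+.2f}  (r = {r:.4f})")
```

A tight fit with no flattening at the top end suggests that gains on the original test set carried over to new data. If the most heavily tuned models had been overfitting, we would expect them to fall below the line.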

Second, there was a dramatic performance drop when the models were evaluated on the new datasets. In other words, the distribution gap was large. More specifically, on the NYT, Amazon, and Reddit datasets, the models’ F1 scores dropped by an average of 3.8, 14.0, and 17.4 points, respectively.

We might intuitively attribute these drops to the looser grammar and sentence structure of user-generated text. However, the authors investigated the Amazon and Reddit drops and were unable to find a conclusive explanation.
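Since these drops are measured in F1 points, it helps to see what SQuAD-style F1 computes. The sketch below mirrors the logic of the official evaluation script. Answers are normalized by lowercasing, removing punctuation and the articles a/an/the, and collapsing whitespace, then compared by token overlap. A question’s score is the best match over its acceptable answers. The function names are my own, and this is not the script itself. Reported EM and F1 are these per-question scores averaged over the test set and multiplied by 100, so a 17.4-point drop means average per-question F1 fell by 0.174.

```python
import re
import string
from collections import Counter

PUNCT = set(string.punctuation)

def normalize_answer(s):
    """Lowercase, drop punctuation and the articles a/an/the, collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in PUNCT)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction, gold):
    return int(normalize_answer(prediction) == normalize_answer(gold))

def f1_score(prediction, gold):
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    # SQuAD 2.0 unanswerable questions have an empty gold answer:
    # score 1 only if the prediction is also empty.
    if not pred_tokens or not gold_tokens:
        return float(pred_tokens == gold_tokens)
    common = Counter(pred_tokens) & Counter(gold_tokens)  # token-level overlap
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)

def best_f1(prediction, gold_answers):
    """A question can have several acceptable answers; keep the best match."""
    return max(f1_score(prediction, g) for g in gold_answers)

print(exact_match("The Denver Broncos!", "Denver Broncos"))
print(f1_score("Denver Broncos", "Broncos"))
```

Partial credit is the point of F1. Predicting “Denver Broncos” when the gold answer is “Broncos” earns 2/3 rather than 0.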

4 — Summary and Next Steps

While the lack of a conclusive explanation may be unsatisfying, there are some opportunities for next steps. But, before we get into that, let’s do a quick recap…

QA models answer questions based on text. SQuAD is a QA dataset that has catalyzed NLP work on QA applications. Top QA models include BERT and XLNet, but they still do not generalize well to other datasets. Finally, and surprisingly, there is little evidence of adaptive overfitting on the SQuAD dataset.

Now that we’re crystal clear on the takeaways, let’s look at some potential next steps…

  1. Construct metrics for comparing datasets that can explain performance differences. Using some standard diagnostics, the authors were unable to determine why models performed so poorly on Reddit and Amazon text relative to Wikipedia text.
  2. Look at the relationship between the bias-variance tradeoff and training data size. With a highly specific large dataset, we’d expect to see overfitting, but the relationship between data size and bias is not well known.
  3. Improve QA models overall. While some of these models are extremely effective in specific domains, they still have a long way to go before they can generalize as well as a human.

Thanks for reading! I’ll be writing 21 more posts that bring academic research to the DS industry. Check out my comment for links to the main source for this post and some useful resources.
