How to make Deep Learning Models Generalize Better

A new technique developed by Facebook’s AI research team.

Michael Berk
Towards Data Science


Invariant Risk Minimization (IRM) is an exciting new learning paradigm that helps predictive models generalize beyond the training data. It was developed by researchers at Facebook and outlined in a 2020 paper. The method can be added to virtually any modeling framework; however, it is best suited to black-box models that leverage lots of data, i.e., neural networks and their many variants.

Without further ado, let’s dive in.

0. Technical TLDR

At a high level, IRM is a learning paradigm that attempts to learn causal relationships instead of merely correlational ones. By developing training environments, i.e., structured samples of data, we can maximize accuracy while also encouraging our predictors to be invariant across those environments. Only the predictors that both fit our data well and are invariant across environments are used in the final model.

Figure 1: theoretical performance of 4-fold CV (top) vs. Invariant Risk Minimization (IRM) (bottom). These values are extrapolated from simulations in the paper. Image by the author.

Step 1: Develop your environment sets. Instead of reshuffling the data and assuming that they’re IID, we use knowledge about our data selection process to develop sampling environments. For instance, for a model that parses text in an image, our training environments could be grouped by the person who wrote the text.

Step 2: Minimize loss across environments. After we have our environments developed, we fit predictors that are approximately invariant and optimize our accuracy across environments. See section 2.1 for more.

Step 3: Generalize better! In the paper’s experiments, Invariant Risk Minimization achieves higher out-of-distribution (OOD) accuracy than traditional learning paradigms.
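As a minimal sketch of Step 1, here is one way to group samples into environments in plain Python. The record layout, file names, and writer names are hypothetical; the only idea taken from the text is that each environment collects the samples written by one person, instead of shuffling everything into a single pool.

```python
from collections import defaultdict

# Hypothetical records: (writer_id, image_path, transcription).
records = [
    ("alice", "img_001.png", "hello"),
    ("bob", "img_002.png", "world"),
    ("alice", "img_003.png", "camel"),
    ("carol", "img_004.png", "cow"),
]


def build_environments(rows, key):
    """Group rows into training environments by the value returned by key(row)."""
    envs = defaultdict(list)
    for row in rows:
        envs[key(row)].append(row)
    return dict(envs)


# One environment per writer, instead of one shuffled pool.
environments = build_environments(records, key=lambda row: row[0])
for writer, rows in environments.items():
    print(writer, len(rows))  # alice 2, bob 1, carol 1
```

Passing the grouping rule as a `key` function keeps the sketch reusable: grouping by country, camera, or collection date only changes the lambda.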

1. But what’s actually going on?

Let’s slow down a bit and understand how Invariant Risk Minimization actually works.

1.1 What’s the purpose of predictive models?

Starting at square one, the purpose of a predictive model is to generalize, i.e., to perform well on unseen data. When that unseen data comes from a different distribution than the data we trained on, we call it out-of-distribution (OOD).

To simulate new data, a variety of methods such as cross-validation have been introduced. Although cross-validation is better than evaluating on the training set alone, we’re still limited to the observed data. So, can you be certain that the model will generalize?

Well, usually you cannot.

For well-defined problems where we have a very good understanding of the data-generating mechanism, we can be confident that our data sample is representative of the population. However, for most applications, we lack this understanding.

Take an example cited in the paper. We are looking to tackle the scintillating problem of deciding whether an image shows a cow or a camel.


To do this, we train a binary classifier using cross-validation and observe a high accuracy on our testing data. Great!

However, after some more digging, it turns out that our classifier was simply using the color of the background to determine the label; when a cow was placed on a sand-colored background, the model always believed it was a camel, and vice versa.

Now, can we assume that cows will always be observed in pastures and camels in deserts?

Probably not. And although this is a trivial example, we can see how this lesson could generalize to more complex and important models.

1.2 Why are current methods insufficient?

Before diving into the solution, let’s develop more understanding of why the popular train/test learning paradigm is insufficient.

The classic train/test paradigm is referred to as Empirical Risk Minimization (ERM) in the paper. In ERM, we pool our data into training/testing sets, train our model on all features by minimizing its average loss on the training set, validate using our testing sets, and return the fitted model with the best testing (out-of-sample) accuracy. One example would be a 50/50 train/test split.
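To make the ERM objective concrete, here is a small sketch assuming squared-error loss and two hypothetical environments of (prediction, label) pairs. Pooling the data and averaging the loss produces one number that hides the fact that the predictor is perfect in one environment and always wrong in the other.

```python
# Hypothetical (prediction, label) pairs for two environments.
envs = {
    "env_a": [(1.0, 1.0), (0.0, 0.0), (1.0, 1.0), (0.0, 0.0)],
    "env_b": [(1.0, 0.0), (0.0, 1.0)],
}


def squared_error(pred, label):
    """Loss for a single example."""
    return (pred - label) ** 2


def risk(pairs):
    """Empirical risk: the average loss over a set of examples."""
    return sum(squared_error(p, y) for p, y in pairs) / len(pairs)


def erm_risk(environments):
    """ERM pools every example and averages the loss, ignoring environment membership."""
    pooled = [pair for pairs in environments.values() for pair in pairs]
    return risk(pooled)


per_env = {name: risk(pairs) for name, pairs in envs.items()}
print(per_env)         # {'env_a': 0.0, 'env_b': 1.0}
print(erm_risk(envs))  # about 0.333: the failure in env_b is diluted by env_a
```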

Now, to understand why ERM doesn’t generalize well, let’s take a look at its three main assumptions then tackle them one at a time. Quickly, they are:

  1. Our data are independent and identically distributed (IID).
  2. As we gather more data, the ratio between the number of significant features and our sample size n should decrease.
  3. Perfect testing accuracy only occurs if there is a realizable (buildable) model with perfect training accuracy.

At first glance, all three of these assumptions might appear to hold. However, *spoiler alert* they often don’t. Here’s why.

Taking a look at our first assumption, our data are almost never truly IID. In practice, data must be gathered, and the gathering process almost always introduces relationships between data points. For example, all images of camels in deserts must be taken in certain parts of the world.

Now, there are many cases where the data are close enough to IID, but it’s important to think critically about whether and how your data collection introduces bias.

Assumption #1: If our data aren’t IID, the first assumption is invalidated, and we can’t simply shuffle our data at random and expect the test split to represent new data. It’s important to consider whether your data-generating mechanism introduces bias.

For our second assumption, if we were modeling causal relationships, we would expect the number of significant features to remain fairly stable after a certain number of observations. In other words, as we gather more high-quality data, we’d pick up the true causal relationships and eventually map them perfectly, so more data wouldn’t add new significant features or improve our accuracy.

However, with ERM, this is rarely the case. Since ERM cannot determine whether a relationship is causal, more data often lets it detect, and fit, more spurious correlations. These correlations are genuinely predictive on the training distribution, so ERM has no reason to ignore them.

Assumption #2: When fitting with ERM, the number of significant features will probably increase as our sample size increases, thereby invalidating our second assumption.

Finally, our third assumption simply states that we have the capacity to build a “perfect” model. If we lack data or robust modeling techniques, this assumption will be invalidated. However, unless we know this to be false, we always assume it to be true.

Assumption #3: We assume the optimal model to be realizable for sufficiently large datasets, so assumption #3 holds.

Now there are non-ERM methods discussed in the paper, but they also fall short for a variety of reasons. You get the idea.

2. The Solution: Invariant Risk Minimization

The proposed solution, called Invariant Risk Minimization (IRM), is designed to overcome the problems listed above. IRM is a learning paradigm that estimates causal predictors from multiple training environments. And, because we’re learning from different data environments, we’re more likely to generalize to new OOD data.

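How do we do this? We leverage the idea that causal relationships are invariant: they hold in every environment, while spurious correlations change from one environment to the next.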

Returning to our example, let’s say cows and camels are seen in their respective grassy and sandy habitats 95% of the time, so if we fit on the color of the background, we’d achieve 95% accuracy. At face value, that’s a pretty good fit.

However, borrowing a core concept from randomized controlled trials called counterfactuals, if we see a counterexample to a hypothesis, we have disproven it. So, if we see even one cow in a sandy environment, we can conclude that sandy backgrounds do not cause camels.

While a strict counterfactual is a bit harsh, we build this concept into our loss function by heavily penalizing predictors that do not perform optimally in every training environment.

For instance, consider a set of environments where each corresponds to a single country. Let’s say that in 9/10 of them, cows live in pastures and camels live in deserts, but in the 10th environment this pattern reverses. When we train on the 10th environment and observe many counterexamples, the model learns that background cannot cause the label cow or camel, so it lowers the significance of that predictor.
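Here is a toy version of the 10-country example. The 95% agreement rate in the first nine countries reuses the figure from the cow/camel example above, and the 5% in the tenth is an assumption standing in for “the pattern reverses”. A background-only predictor looks good when accuracy is pooled, but its accuracy is far from invariant across environments.

```python
# Fraction of images whose background "agrees" with the animal (cow on grass,
# camel on sand) in each hypothetical country. The 10th country reverses the pattern.
agreement = [0.95] * 9 + [0.05]


def make_environment(p_agree, n=100):
    """Build n (label, background) pairs; the first round(p_agree * n) use the usual background."""
    n_agree = round(p_agree * n)
    data = []
    for i in range(n):
        label = "cow" if i % 2 == 0 else "camel"
        usual, unusual = ("grass", "sand") if label == "cow" else ("sand", "grass")
        data.append((label, usual if i < n_agree else unusual))
    return data


def background_predictor(background):
    """The spurious rule: grass means cow, sand means camel."""
    return "cow" if background == "grass" else "camel"


def accuracy(env):
    return sum(background_predictor(bg) == label for label, bg in env) / len(env)


envs = [make_environment(p) for p in agreement]
accs = [accuracy(env) for env in envs]
pooled = sum(accs) / len(accs)  # environments are equal-sized, so this is pooled accuracy

print(accs)                           # 0.95 in countries 1-9, 0.05 in country 10
print(round(pooled, 2))               # 0.86: looks fine when everything is pooled
print(round(max(accs) - min(accs), 2))  # 0.9: the predictor is not invariant
```

Looking at accuracy per environment, rather than pooled, is what exposes the background as a non-invariant predictor.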

2.1 The Method

Now that we understand IRM in plain English, let’s enter the world of math so we can understand how to implement it.

Figure 2: minimization expression — source.

Figure 2 shows our optimization expression. As indicated by the summation, we are looking to minimize the summed value across all our training environments.
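Since Figure 2 is an image, here is the expression written out. It appears to match the practical objective the paper calls IRMv1, with the letters marking the terms discussed next:

```latex
\min_{\Phi} \sum_{e \in \mathcal{E}_{tr}} \Big[ \underbrace{R^e(\Phi)}_{A} + \underbrace{\lambda}_{B} \cdot \Big\| \underbrace{\nabla_{w \mid w = 1.0}}_{C} \, \underbrace{R^e(w \cdot \Phi)}_{D} \Big\|^2 \Big]
```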

Breaking it down further, the “A” term represents our prediction error (risk) on a given training environment, where phi (𝛷) represents a data transformation, such as the log or a kernel transformation to higher dimensions. R denotes the risk function of our model for a given environment e. Note that a risk function is simply the average of a loss function over the data; a classic example is the mean squared error (MSE).

The “B” term is simply a non-negative number that is used to scale our invariance term. Remember when we said that a strict counterfactual may be too harsh? This is where we can scale how harsh we want to be. If lambda (λ) is 0, we don’t care about invariance and simply optimize accuracy. If λ is large, we care a lot about invariance and penalize accordingly.

Finally, the “C” and “D” terms represent our model’s invariance across training environments. We don’t need to get too deep into them, but in short, “C” is the gradient with respect to a dummy linear classifier w that is fixed at a value of 1.0, and “D” is the risk of that classifier applied on top of our data transformation, i.e., the risk of w · 𝛷. The entire term is the squared norm of that gradient: it is zero when w = 1.0 is already optimal for environment e, so it penalizes transformations that would need a different classifier in different environments.
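To see what the “C”/“D” penalty measures, here is a minimal sketch assuming a one-dimensional representation (each example’s 𝛷 output is a single number z) and squared-error loss. The gradient with respect to the dummy classifier w is written out by hand and checked against a finite difference; in a deep learning framework you would normally let automatic differentiation compute it.

```python
def risk(w, z, y):
    """Mean squared error of the scaled predictor w * z against labels y (the "D" term)."""
    return sum((w * zi - yi) ** 2 for zi, yi in zip(z, y)) / len(z)


def irm_penalty(z, y):
    """Squared gradient of the risk w.r.t. the dummy classifier w, evaluated at w = 1.0."""
    # d/dw mean((w*z - y)^2) = mean(2 * (w*z - y) * z); at w = 1 this is mean(2 * (z - y) * z).
    grad = 2.0 * sum((zi - yi) * zi for zi, yi in zip(z, y)) / len(z)
    return grad ** 2


def irm_penalty_numeric(z, y, eps=1e-6):
    """The same quantity via a central finite difference, as a sanity check."""
    grad = (risk(1.0 + eps, z, y) - risk(1.0 - eps, z, y)) / (2 * eps)
    return grad ** 2


z = [1.0, 2.0, 3.0]
print(irm_penalty(z, [1.0, 2.0, 3.0]))          # 0.0: w = 1.0 is already optimal here
print(irm_penalty(z, [2.0, 4.0, 6.0]))          # large: this environment wants w = 2.0
print(irm_penalty_numeric(z, [2.0, 4.0, 6.0]))  # should closely match the line above
```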

The paper goes into a lot of detail on these terms, so check out section 3 if you're curious.

In summary, “A” is our model’s error, “B” is a non-negative number that scales how much we care about invariance, and “C”/“D” is our model’s invariance penalty. If we minimize this expression, we should find a model that relies only on relationships that are invariant, and ideally causal, across our training environments.
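Putting the pieces together, here is a toy comparison of two hypothetical representations on two environments, again assuming scalar features, squared-error loss, and w fixed at 1.0. The best scaling of the “causal” feature is w = 1.0 in both environments (its errors are pure noise); the “spurious” feature fits the first environment perfectly but drifts in the second. With λ = 0 (plain error minimization) the spurious feature wins; with λ = 1 the invariance penalty flips the choice.

```python
def risk(z, y):
    """Risk of the classifier w = 1.0 on top of features z: mean squared error ("A")."""
    return sum((zi - yi) ** 2 for zi, yi in zip(z, y)) / len(z)


def penalty(z, y):
    """Squared gradient of the risk w.r.t. the dummy classifier w at w = 1.0 ("C"/"D")."""
    grad = 2.0 * sum((zi - yi) * zi for zi, yi in zip(z, y)) / len(z)
    return grad ** 2


def irm_objective(envs, lam):
    """Sum over environments of risk + lam * penalty, mirroring the expression in Figure 2."""
    return sum(risk(z, y) + lam * penalty(z, y) for z, y in envs)


# Hypothetical labels for two environments.
y1, y2 = [2.0, 0.0, 4.0], [0.0, 4.0, 2.0]

# Each candidate representation maps the inputs of each environment to scalar features.
candidates = {
    "causal": [([1.0, 2.0, 3.0], y1), ([1.0, 2.0, 3.0], y2)],
    "spurious": [([2.0, 0.0, 4.0], y1), ([0.0, 2.0, 1.0], y2)],
}

for lam in (0.0, 1.0):
    scores = {name: irm_objective(envs, lam) for name, envs in candidates.items()}
    best = min(scores, key=scores.get)
    print(f"lambda={lam}: prefers {best}")  # lambda=0.0: spurious, lambda=1.0: causal
```

In this toy setup the crossover happens at λ = 0.21, which is one way to see how “B” controls how harsh we are.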

2.2 IRM Next Steps

Unfortunately, the theory behind the IRM paradigm outlined here only covers linear cases. Transforming our data to a high-dimensional space can lead to effective linear models; however, some relationships are fundamentally non-linear. The authors leave the nonlinear case to future work.

If you want to follow the research, you can check out the work of the authors: Martin Arjovsky, Léon Bottou, Ishaan Gulrajani, and David Lopez-Paz.

And there’s our method. Not too bad, right?

3. Implementation Notes

  • Here’s a PyTorch package.
  • IRM is best suited for causal relationships that are not known. If there are known relationships, you should account for them in the structure of the model. A famous example is the convolution in convolutional neural networks (CNNs), which builds known structure about images directly into the architecture.
  • IRM has lots of potential for unsupervised models and reinforcement learning. Model fairness is also an interesting application.
  • The optimization is quite complex because there are two minimization terms. The paper outlines a transformation that makes the optimization convex, but only in the linear case.
  • IRM is robust to mild model misspecification because it’s differentiable with respect to the covariances of the training environments. So, while a “perfect” model is ideal, the minimization expression is resilient to small human error.

Thanks for reading! I’ll be writing 47 more posts that bring “academic” research to the DS industry. Check out my comments for links/ideas on IRM methods.
