Tidying up the framework of dataset shifts

Taking a step back about the causes of model degradation

Valeria Fonseca Diaz
Towards Data Science


In collaboration with Marco Dalla Vecchia as the image creator

We train models and use them to predict certain outcomes given a set of inputs. We all know that’s the game of ML. We know quite a lot about training them, so much so that now they have evolved into AI, the biggest level of intelligence that has ever existed. But when it comes to using them, we are not that far ahead, and we continue exploring and understanding every aspect that matters after models go into deployment.

So today, we will discuss model performance drift (or just model drift), also frequently known as model failure or model degradation. What we refer to is the quality of the predictions that our ML model delivers. Be it a class or a number, we care about the gap between the prediction and the real class or value. We talk about model performance drift when the quality of those predictions goes down with respect to the moment when we deployed the model. You may have found other terminology for this topic in the literature, but stay with me on model performance drift, or simply model drift, at least for the purpose of our current conversation.

What we know

Several blogs, books, and many papers have explored and explained the core concepts of model drift, so let’s start with the current picture. We have organized the ideas mainly into the concepts of covariate shift, prior shift, and conditional shift. The latter is also commonly known as concept drift. These shifts are known to be the main causes of model drift (remember, a drop in the quality of predictions). The summarized definitions go as follows:

  • Covariate shift: Changes in the distribution of P(X) without necessarily having changes in P(Y|X). This means that the distribution of the input features changes, and some of those shifts may cause the model to drift.
  • Prior shift: Changes in the distribution P(Y). Here, the distribution of the labels or of the numerical output variable changes. If the probability distribution of the output variable shifts, the current model will most likely carry a large uncertainty on its predictions, so it may easily drift.
  • Conditional shift (aka concept drift): The conditional distribution P(Y|X) changes. This means that, for a given input, the probability distribution of the output variable has changed. As far as we know, this shift usually leaves us very little room to maintain the quality of predictions. But is that really so?
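The three definitions can be made concrete with a small discrete example. The sketch below is illustrative only. It assumes a single discrete feature X and a binary label Y, with the joint distribution given as a table. The helper names (`build_joint`, `distributions`, `shift_report`) are hypothetical and do not come from any monitoring library.

```python
import numpy as np

def build_joint(p_x, p_y_given_x):
    """Joint table joint[x, y] = P(X=x) * P(Y=y | X=x)."""
    p_x = np.asarray(p_x, dtype=float)
    p_y_given_x = np.asarray(p_y_given_x, dtype=float)
    return p_x[:, None] * p_y_given_x

def distributions(joint):
    """Return P(X), P(Y) and P(Y|X) from a joint table joint[x, y]."""
    joint = np.asarray(joint, dtype=float)
    p_x = joint.sum(axis=1)             # sum over y
    p_y = joint.sum(axis=0)             # sum over x
    p_y_given_x = joint / p_x[:, None]  # normalise each row to sum to 1
    return {"P(X)": p_x, "P(Y)": p_y, "P(Y|X)": p_y_given_x}

def shift_report(joint_ref, joint_new, atol=1e-9):
    """Flag which distributions differ between a reference and a new joint."""
    ref, new = distributions(joint_ref), distributions(joint_new)
    return {name: not np.allclose(ref[name], new[name], atol=atol) for name in ref}

# Covariate shift: P(X) moves while P(Y|X) is kept fixed.
p_y_given_x = [[0.9, 0.1], [0.2, 0.8]]
reference = build_joint([0.5, 0.5], p_y_given_x)
production = build_joint([0.8, 0.2], p_y_given_x)
print(shift_report(reference, production))
# → {'P(X)': True, 'P(Y)': True, 'P(Y|X)': False}
```

Note that P(Y) moves too in this example, because P(Y) = Σ_x P(X=x) P(Y|X=x). A shift that fits the covariate-shift definition can still drag the label distribution along with it.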

Many sources list examples of these dataset shifts. One of the core research opportunities is detecting these types of shifts without the need for new labels [1, 2, 3]. Interesting metrics have recently been released to monitor the prediction performance of the model in an unsupervised way [2, 3]. They are indeed motivated by the different concepts of dataset shift, and they reflect the changes in the real probability distributions of the data quite accurately. So we are going to dive into the theory of these shifts. Why? Because perhaps we can bring some order to these definitions. By tidying up, we might be able to move forward more easily or simply understand this entire framework more clearly.

To do that, let’s go back to the beginning and make a slow derivation of the story. Grab a coffee, read slowly, and stay with me. Or just, don’t drift!

The real and the estimated model

The ML models we train attempt to get us close to a real, yet unknown, relationship or function that maps a certain input X to an output Y. We naturally distinguish the real unknown relationship from the estimated one. Yet, the estimated model is bound to the behavior of the real unknown model. That is, if the real model changes and the estimated model is not robust against these changes, the estimated model’s predictions will be less accurate.

The performance we can monitor concerns the estimated function, but the causes of model drift are found in the changes of the real model.

  • What is the real model? The real model is defined by the so-called conditional distribution P(Y|X). This is the probability distribution of an output given an input.
  • What is the estimated model? This one is a function ê(x) that specifically estimates the expected value of Y under P(Y|X=x), that is, E(Y|X=x). This function is the one connected to our ML model.
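Here is a minimal sketch that makes the distinction between P(Y|X) and ê(x) concrete. It assumes a known real model P(Y=1|X=x) = sigmoid(2x) for a binary label, and it estimates ê(x) with a simple binned average. Both the sigmoid and the binned estimator are illustrative choices, not something the framework prescribes. Any trained ML model would play the role of ê(x).

```python
import numpy as np

def true_conditional_mean(x):
    """Assumed real model: P(Y=1 | X=x) = sigmoid(2x), so E(Y | X=x) = sigmoid(2x)."""
    return 1.0 / (1.0 + np.exp(-2.0 * x))

def fit_binned_estimator(x, y, edges):
    """Estimate ê(x) as the average label observed inside each bin of x."""
    n_bins = len(edges) - 1
    idx = np.clip(np.digitize(x, edges) - 1, 0, n_bins - 1)
    means = np.array([y[idx == b].mean() for b in range(n_bins)])

    def e_hat(x_new):
        b = np.clip(np.digitize(x_new, edges) - 1, 0, n_bins - 1)
        return means[b]

    return e_hat

rng = np.random.default_rng(0)
x = rng.uniform(-3, 3, size=200_000)
y = rng.binomial(1, true_conditional_mean(x))  # labels drawn from P(Y|X)
edges = np.linspace(-3, 3, 31)                 # 30 equal-width bins
e_hat = fit_binned_estimator(x, y, edges)

centres = (edges[:-1] + edges[1:]) / 2
max_gap = np.max(np.abs(e_hat(centres) - true_conditional_mean(centres)))
print(f"largest gap between ê(x) and E(Y|X=x) at bin centres: {max_gap:.3f}")
```

For a binary label, E(Y|X=x) = P(Y=1|X=x), so ê(x) here is directly an estimate of the conditional probability. For regression it would be an estimate of the conditional mean.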

Here’s a visual representation of these elements:

(Image by author)

Good, so now that we have clarified these two elements, we’re ready to organize the ideas behind the so-called dataset shifts and see how the concepts connect to each other.

The new arrangement

The global cause of model drift

Our main objective is to understand the causes of model drift for our estimated model. Because we already understand the connection between the estimated model and the conditional probability distribution, we can restate here what we already knew: The global cause for our estimated model to drift is the change in P(Y|X).

Basic and apparently easy, but more fundamental than we think. We assume our estimated model to be a good reflection of the real model. The real model is governed by P(Y|X). So, if P(Y|X) changes, our estimated model will likely drift. We need to keep in mind the path this reasoning follows, which we showed in the figure above.

We already knew this, so what’s new about it? The new thing is that we now baptize the changes in P(Y|X) as the global cause, not just a cause. This imposes a hierarchy with respect to the other causes, and that hierarchy will help us nicely position the concepts about them.

The specific causes: Elements of the global cause

Knowing that the global cause lies in the changes in P(Y|X), it becomes natural to dig into what elements constitute this latter probability. Once we have identified those elements, we will continue talking about the causes of model drift. So what are those elements?

We have always known it. The conditional probability is theoretically defined as P(Y|X) = P(Y, X) / P(X), that is, the joint probability divided by the marginal probability of X. But we can open up the joint probability once more and obtain the magical formula we’ve known for centuries:

(Image by author)
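Written out, the formula in the image is Bayes’ theorem. It follows from expanding the joint probability as P(Y, X) = P(X|Y) P(Y). For a discrete label, the denominator can itself be expanded over the classes:

```latex
P(Y \mid X) = \frac{P(Y, X)}{P(X)} = \frac{P(X \mid Y)\,P(Y)}{P(X)},
\qquad
P(X) = \sum_{y} P(X \mid Y = y)\,P(Y = y).
```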

Do you already see where we’re going? The conditional probability is something that is fully defined by three elements:

  • P(X|Y): The inverse conditional probability
  • P(Y): The prior probability
  • P(X): The covariates' marginal probability

Because these are the three elements that define the conditional probability P(Y|X), we are ready to give a second statement: If P(Y|X) changes, those changes come from at least one of the three elements that define it. Put differently, the changes in P(Y|X) are defined by any change in P(X|Y), P(Y), or P(X).
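The statement can be checked numerically. The sketch below assumes a small, made-up joint table for a discrete feature with three values and a binary label. It extracts P(X|Y), P(Y), and P(X), then rebuilds P(Y|X) from those three elements alone and compares the result with the direct definition.

```python
import numpy as np

# Made-up joint table joint[x, y] = P(X=x, Y=y): 3 feature values, 2 classes.
joint = np.array([[0.30, 0.05],
                  [0.10, 0.15],
                  [0.05, 0.35]])

p_x = joint.sum(axis=1)             # P(X), the covariates' marginal
p_y = joint.sum(axis=0)             # P(Y), the prior
p_x_given_y = joint / p_y[None, :]  # P(X|Y), each column sums to 1

# Direct definition: P(Y|X) = P(Y, X) / P(X)
direct = joint / p_x[:, None]
# Bayes: P(Y|X) = P(X|Y) P(Y) / P(X), built only from the three elements
bayes = p_x_given_y * p_y[None, :] / p_x[:, None]

print(np.round(bayes, 3))
print("identical:", np.allclose(direct, bayes))
```

Because P(Y|X) is a function of these three arrays only, it cannot change unless at least one of them does.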

With that, we have positioned the other elements of our current knowledge as specific causes of model drift rather than as causes parallel to P(Y|X).

Going back to the beginning of this post, we listed covariate shift and prior shift. We note, then, that there’s yet another specific cause: the changes in the inverse conditional distribution P(X|Y). We usually find some mention of this distribution when talking about the changes in P(Y) as if in general we were considering the inverse relationship from Y to X [1,4].

The new hierarchy of concepts

(Image by author)

We can now make a clear comparison between the current thinking about these concepts and the proposed hierarchy. Until now, we have been talking about the causes of model drift by identifying different probability distributions. The three main distributions, P(X), P(Y), and P(Y|X), are known to be the main causes of drift in the quality of predictions returned by our ML model.

The twist I propose here imposes a hierarchy on the concepts. In it, the global cause of drift of a model that estimates the relationship X -> Y is the change in the conditional probability P(Y|X). Those changes in P(Y|X) can come from changes in P(X), P(Y), or P(X|Y).

Let’s list some of the implications of this hierarchy:

  • We could have cases where P(X) changes, but if P(Y) and P(X|Y) also change accordingly, then P(Y|X) remains the same.
  • We can also have cases where P(X) changes, but if P(Y) or P(X|Y) doesn’t change accordingly, P(Y|X) will change. If you have given this topic some thought before, you have probably noticed cases where X changes in ways that do not seem entirely independent of Y|X, so that in the end Y|X also changes. Here, P(X) is the specific cause of the changes in P(Y|X), which in turn is the global cause of our model drifting.
  • The previous two statements also hold for P(Y).
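The first two implications can be illustrated with a sketch that assumes two Gaussian classes sharing one standard deviation. This textbook setting is chosen only because P(Y|X) then has a closed form through Bayes’ rule, and the helper names are hypothetical. In the first scenario, P(Y), P(X|Y), and hence P(X) all change, but in a coordinated way that leaves P(Y|X) intact. In the second, only the prior moves and P(X|Y) does not follow along, so P(X) changes and P(Y|X) changes with it.

```python
import numpy as np

def gaussian_pdf(x, mean, std):
    """Density of a normal distribution N(mean, std**2) at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def joints(x, prior_1, mean_0, mean_1, std):
    """Return P(X=x, Y=0) and P(X=x, Y=1), each computed as P(X|Y) * P(Y)."""
    return ((1 - prior_1) * gaussian_pdf(x, mean_0, std),
            prior_1 * gaussian_pdf(x, mean_1, std))

def marginal(x, *params):
    """P(X=x): the sum of the two joints."""
    j0, j1 = joints(x, *params)
    return j0 + j1

def posterior(x, *params):
    """P(Y=1 | X=x) by Bayes' rule: joint divided by marginal."""
    j0, j1 = joints(x, *params)
    return j1 / (j0 + j1)

x = np.linspace(-4, 4, 81)
ref = (0.5, -1.0, 1.0, 1.0)  # (prior_1, mean_0, mean_1, std)

# Scenario 1: P(Y) moves to 0.7 and both class means shift by delta, so P(X|Y)
# and P(X) change as well. delta = std**2 * log-odds / (mean_1 - mean_0)
# makes these changes compensate each other exactly.
delta = np.log(0.7 / 0.3) / 2
compensated = (0.7, -1.0 + delta, 1.0 + delta, 1.0)

# Scenario 2: only P(Y) moves; P(X|Y) stays fixed, so P(X) still changes.
prior_only = (0.7, -1.0, 1.0, 1.0)

for name, params in [("compensated", compensated), ("prior only", prior_only)]:
    print(name,
          "| P(X) changed:", not np.allclose(marginal(x, *ref), marginal(x, *params)),
          "| P(Y|X) changed:", not np.allclose(posterior(x, *ref), posterior(x, *params)))
```

In both scenarios a monitor watching P(X) alone would raise the same flag, although only the second one actually moves P(Y|X).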

Because the three specific causes may or may not change independently, the overall changes in P(Y|X) can be explained by the changes in these specific elements altogether. For instance, P(X) may move a bit here and P(Y) a bit over there; those two movements then make P(X|Y) change as well, and altogether they cause P(Y|X) to change.

P(X) and P(Y|X) are not to be thought of independently: P(X) is a cause of P(Y|X)

Where is the estimated ML model in all this?

Ok, now we know that the so-called covariate and prior shifts are causes of conditional shift rather than parallel to it. Conditional shifts encompass the set of specific causes for prediction performance degradation of the estimated model. But the estimated model is rather a decision boundary or function, not really a direct estimation of the probabilities at play. So what do the causes mean for the real and estimated decision boundaries?

Let’s gather all the pieces and draw the complete path connecting all the elements:

(Image by author)

Note that our ML model can come about analytically or numerically. Moreover, it can come as a parametric or non-parametric representation. So, in the end, our ML models are an estimation of the decision boundary or regression function which we can derive from the expected conditional value.

This fact has an important implication for the causes we have been discussing. While most of the changes happening in P(X), P(Y), and P(X|Y) will imply changes in P(Y|X) and so in E(Y|X), not all of them necessarily imply a change in the real decision boundary or function. In that case, the estimated decision boundary or function will remain valid, provided it was an accurate estimate to begin with. Look at this example below:

(Image by author)
  • See that P(Y) and P(X) changed: the density and location of the points reflect a different probability distribution.
  • These changes make P(Y|X) change.
  • However, the decision boundary remained valid.
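The same point can be sketched with two Gaussian classes sharing one standard deviation. This is an assumption made only because the boundary then has a closed form, and the parameter values are hypothetical. The new regime below raises the prior, tightens the classes, and shifts them. As a result, P(Y), P(X|Y), and P(X) all change, and P(Y|X) becomes much sharper. Yet the point where P(Y=1|X=x) = 0.5 does not move.

```python
import numpy as np

def posterior(x, prior_1, mean_0, mean_1, std):
    """P(Y=1 | X=x) for two Gaussian classes sharing one std, via its log-odds."""
    log_odds = (np.log(prior_1 / (1 - prior_1))
                + ((x - mean_0) ** 2 - (x - mean_1) ** 2) / (2 * std ** 2))
    return 1 / (1 + np.exp(-log_odds))

def boundary(prior_1, mean_0, mean_1, std):
    """The x at which the log-odds are zero, i.e. P(Y=1 | X=x) = 0.5."""
    return ((mean_0 + mean_1) / 2
            - std ** 2 * np.log(prior_1 / (1 - prior_1)) / (mean_1 - mean_0))

ref = dict(prior_1=0.5, mean_0=-1.0, mean_1=1.0, std=1.0)

# Hypothetical new regime: more positives (P(Y) changes), tighter classes and
# shifted means (P(X|Y), hence P(X), change). The shift is chosen so that the
# boundary formula above still returns the old value.
std_new, prior_new = 0.5, 0.7
delta = std_new ** 2 * np.log(prior_new / (1 - prior_new)) / 2.0
new = dict(prior_1=prior_new, mean_0=-1.0 + delta, mean_1=1.0 + delta, std=std_new)

x = np.linspace(-2, 2, 40)  # even count, so x = 0 itself is not on the grid
print("P(Y|X) changed:", not np.allclose(posterior(x, **ref), posterior(x, **new)))
print("boundary unchanged:", np.isclose(boundary(**ref), boundary(**new)))

# Decisions of a model that learned the old boundary vs the new Bayes rule.
old_rule = x > boundary(**ref)
new_rule = posterior(x, **new) > 0.5
print("old decisions still optimal:", np.array_equal(old_rule, new_rule))
```

A classifier that learned the old boundary keeps making the right decisions here. Its predicted probabilities, however, are now miscalibrated, which matters if the model’s scores, and not only its labels, are consumed downstream.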

Here’s one important bit. Imagine we are looking at the changes in P(X) only, without information about the real labels, and we would like to know how good the predictions are. If P(X) shifts towards areas where the estimated decision boundary has a large uncertainty, the predictions are likely inaccurate. So in the case of a covariate shift towards uncertain regions of the decision boundary, most likely a conditional shift is also happening. But we would not know whether the decision boundary is changing or not. In that case, we can quantify a change in P(X), which can indicate a change in P(Y|X), but we would not know what is happening to the decision boundary or regression function. Here’s a representation of this problem:
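As a rough illustration of the label-free side of this problem, the sketch below assumes a hypothetical deployed classifier p̂(Y=1|X=x) = sigmoid(2x). It estimates the classifier’s error rate from inputs alone, as the average of min(p̂, 1 - p̂). This is similar in spirit to the confidence-based performance estimation described in [3], but it is a simplified stand-in, not that method’s implementation. The estimate is only valid if p̂ is calibrated and P(Y|X) has not changed, and labels are needed to verify that.

```python
import numpy as np

def estimated_model(x):
    """Hypothetical deployed model: p̂(Y=1 | X=x) = sigmoid(2x), boundary at x = 0."""
    return 1.0 / (1.0 + np.exp(-2.0 * x))

def expected_error_without_labels(x):
    """Expected error rate, assuming p̂ is calibrated and P(Y|X) is unchanged.

    Predicting the more likely class is wrong with probability min(p̂, 1 - p̂),
    so averaging that quantity over the inputs requires no labels.
    """
    p = estimated_model(x)
    return float(np.mean(np.minimum(p, 1.0 - p)))

rng = np.random.default_rng(1)
# Reference inputs sit mostly far from the boundary ...
x_reference = np.concatenate([rng.normal(-2, 0.5, 5000), rng.normal(2, 0.5, 5000)])
# ... while production inputs drift towards the uncertain region around x = 0.
x_production = rng.normal(0, 0.5, 10_000)

print(f"reference:  {expected_error_without_labels(x_reference):.3f}")
print(f"production: {expected_error_without_labels(x_production):.3f}")
```

The estimate rises sharply for the production inputs, which signals that P(X) moved into an uncertain region. It says nothing, however, about whether the real decision boundary moved.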

So now that we have said all this, it’s time for yet one more statement. We talk about conditional shift when we refer to the changes in P(Y|X). It’s possible that what we have been calling concept drift refers specifically to the changes in the real decision boundary or regression function. See here below a typical example of a conditional shift with a change in the decision boundary but without a covariate or prior shift. In fact, the change came from the change in the inverse conditional probability P(X|Y).

(Image by author)
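A discrete version of this situation can be sketched with made-up numbers: two feature values, two balanced classes, and class-conditional distributions that swap between the two periods. The symmetric swap is an assumption chosen for clarity. It leaves P(X) and P(Y) exactly where they were, while P(Y|X) and the best decision for each x flip.

```python
import numpy as np

def summarise(p_x_given_y, p_y):
    """Return P(X), P(Y) and P(Y|X) built from P(X|Y) and P(Y)."""
    joint = p_x_given_y * p_y[None, :]  # P(X, Y) = P(X|Y) P(Y)
    p_x = joint.sum(axis=1)
    return p_x, joint.sum(axis=0), joint / p_x[:, None]

p_y = np.array([0.5, 0.5])
# Rows are x values, columns are classes: X=0 mostly goes with Y=0, X=1 with Y=1.
p_x_given_y_before = np.array([[0.8, 0.2],
                               [0.2, 0.8]])
# Afterwards the class-conditional distributions swap: only P(X|Y) changes.
p_x_given_y_after = p_x_given_y_before[:, ::-1]

px_b, py_b, post_b = summarise(p_x_given_y_before, p_y)
px_a, py_a, post_a = summarise(p_x_given_y_after, p_y)

print("P(X) unchanged:", np.allclose(px_b, px_a))
print("P(Y) unchanged:", np.allclose(py_b, py_a))
print("P(Y|X) before:\n", post_b)
print("P(Y|X) after:\n", post_a)
print("best class per x, before:", post_b.argmax(axis=1), "after:", post_a.argmax(axis=1))
```

A monitor that only watched P(X), or the distribution of the model’s predictions (which depends on X alone), would see nothing here. This is why this kind of shift is so hard to detect without labels.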

Implications for our current monitoring methods

We care about understanding these causes so we can develop methods to monitor the performance of our ML models as accurately as possible. None of the proposed ideas is bad news for the available practical solutions. Quite the opposite: with this new hierarchy of concepts, we might be able to push our attempts to detect the causes of model performance degradation further. Several methods and metrics have been proposed to monitor the prediction performance of our models, mainly in light of the different concepts we have listed here. However, we may have mixed up the concepts in the assumptions behind some of these metrics [2]. For example, we might have been referring to an assumption as “no conditional shift”, when in reality it may specifically be “no change in the decision boundary” or “no change in the regression function”. We need to keep thinking about this.

More about prediction performance degradation

Zooming in and zooming out. We have dived into the framework to think about the causes of prediction performance degradation. But there is another dimension to this topic: the types of prediction performance shifts. Our models suffer because of the listed causes, and those causes are reflected in different shapes of prediction misalignment. We find mainly four types: Bias, Slope, Variance, and Non-linear shifts. Check out this post to find out more about this other side of the coin.

Summary

We studied in this post the causes of model performance degradation and proposed a framework based on the theoretical connections of the concepts we already knew before. Here are the main points:

  • The probability P(Y|X) governs the real decision boundary or function.
  • The estimated decision boundary or function is assumed to be the best approximation to the real one.
  • The estimated decision boundary or function is the ML model.
  • The ML model can experience prediction performance degradation.
  • That degradation is caused by changes in P(Y|X).
  • P(Y|X) changes because there are changes in at least one of these elements: P(X), P(Y), or P(X|Y).
  • There can be changes in P(X) and P(Y) without having changes in the decision boundary or regression function.

The general statement is: if the ML model is drifting, then P(Y|X) is changing. The reverse is not necessarily true.

This framework of concepts is hopefully just a seed for the crucial topic of ML prediction performance degradation. While the theoretical discussion is simply a delight, I trust that this connection will help us push forward the aim of measuring these changes in practice while optimizing for the required resources (samples and labels). Please join the discussion if you have other contributions from your own knowledge.

What’s causing your model to drift in prediction performance?

Happy thinking!

References

[1] https://huyenchip.com/2022/02/07/data-distribution-shifts-and-monitoring.html

[2] https://www.sciencedirect.com/science/article/pii/S016974392300134X

[3] https://nannyml.readthedocs.io/en/stable/how_it_works/performance_estimation.html#performance-estimation-deep-dive

[4] https://medium.com/towards-data-science/understanding-dataset-shift-f2a5a262a766


Data science researcher. Co-creator of MV Learn. Enthusiastic writer, technology ethics reader