
On Learning in the Presence of Underrepresented Groups

Change is Hard: A Closer Look at Subpopulation Shift (ICML 2023)

Let me introduce our latest work, accepted at ICML 2023: Change is Hard: A Closer Look at Subpopulation Shift. Machine learning models have shown great potential in many applications, but they often perform poorly on subgroups that are underrepresented in the training data. Understanding the variation in mechanisms that cause such subpopulation shifts, and how algorithms generalize across diverse shifts at scale, remains a challenge. In this work, we aim to fill this gap by providing a fine-grained analysis of subpopulation shifts and their impact on machine learning algorithms.

We first present a unified framework that dissects and explains common shifts in subgroups. Further, we introduce a comprehensive benchmark of 20 state-of-the-art algorithms, which we evaluate on 12 real-world datasets spanning the domains of vision, language, and healthcare. Through our analysis and benchmarking, we provide intriguing observations about subpopulation shifts and about how machine learning algorithms generalize under such real-world shifts. The code, data, and models have been open-sourced on GitHub: https://github.com/YyzHarry/SubpopBench.


Background and Motivation

Machine learning models frequently exhibit performance drops in the presence of distribution shifts. Such shifts occur when the underlying data distribution changes (e.g., the training distribution differs from the test distribution), causing models to degrade once deployed. Building machine learning models that are robust to these shifts is critical to deploying them safely in the real world. One ubiquitous type of distribution shift is subpopulation shift, which is characterized by changes in the proportions of some subpopulations between training and deployment. In such settings, a model may have high overall performance yet still perform poorly on rare subgroups.

For example, in the task of classifying cows and camels, cows are often photographed on green grass, while camels are often photographed on yellow sand. However, this correlation is spurious, because the presence of a cow or a camel is unrelated to the background color. As a result, a trained model performs well on such typical images but fails to generalize to animals with background colors that are rare in the training data, such as cows on sand or camels on grass.

In addition, in medical diagnosis, studies have found that machine learning models often perform worse on underrepresented age or ethnicity groups, raising important fairness concerns.

All of these shifts are generally referred to as subpopulation shift, but little is understood about the variation in mechanisms that cause them, or about how algorithms generalize across such diverse shifts at scale. So, how should we model subpopulation shift?

A Unified Framework of Subpopulation Shift

We first provide a unified framework for modelling subpopulation shift. In the classic classification setup, we have training data from multiple classes (where we use different color densities to represent different numbers of samples in each class). Under subpopulation shift, however, each input also carries attributes in addition to its class, such as the background color in the cow-camel problem. In this case, we can define discrete subpopulations based on both the attribute and the label, and the number of samples for different attributes within the same class can also vary (see figure below). Naturally, to test the model, just as we evaluate performance across all classes in standard classification, under subpopulation shift we evaluate the model over all subgroups, either to ensure that the worst performance over all subpopulations is good enough, or to ensure equally good performance across all groups.
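
To make the evaluation protocol concrete, here is a minimal Python sketch (illustrative only; names such as worst_group_accuracy are my own and not from the SubpopBench codebase) that computes per-subgroup and worst-group accuracy, where each subgroup is an (attribute, label) pair:

```python
from collections import defaultdict


def worst_group_accuracy(preds, labels, attributes):
    """Accuracy of every (attribute, label) subgroup, plus the worst one."""
    correct, total = defaultdict(int), defaultdict(int)
    for y_hat, y, a in zip(preds, labels, attributes):
        group = (a, y)
        total[group] += 1
        correct[group] += int(y_hat == y)
    per_group = {g: correct[g] / total[g] for g in total}
    return per_group, min(per_group.values())


# Toy example: 2 classes (cow=0, camel=1) x 2 backgrounds (grass=0, sand=1).
preds      = [0, 0, 1, 1, 1, 0, 1, 0]
labels     = [0, 0, 1, 1, 0, 0, 1, 1]
attributes = [0, 0, 1, 1, 1, 0, 0, 1]   # background of each image

per_group, wga = worst_group_accuracy(preds, labels, attributes)
print(per_group)  # accuracy of every (background, class) subgroup
print(wga)        # worst-group accuracy
```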

Specifically, to provide a generic mathematical formulation, we first rewrite the classification model using Bayes’ theorem. We further view each input x as being fully described by, or generated from, a set of underlying core features (x_core) and a list of attributes (a). Here, x_core denotes the underlying invariant components that are label-specific and support robust classification, whereas the attributes a may have inconsistent distributions and are not label-specific. We can then plug this modelling back into the equation and decompose it into three terms, as shown below:
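
The equation itself appears as a figure in the original post; written out (my reconstruction from the description in the next paragraph, so the exact notation in the paper may differ), the Bayes decomposition with x = (x_core, a) reads:

$$
\log p(y \mid x) \;=\; \underbrace{\log \frac{p(x_{\mathrm{core}} \mid a, y)}{p(x_{\mathrm{core}} \mid a)}}_{\text{PMI between } x_{\mathrm{core}} \text{ and } y} \;+\; \underbrace{\log \frac{p(a \mid y)}{p(a)}}_{\text{attribute bias}} \;+\; \underbrace{\log p(y)}_{\text{label bias}}
$$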

Specifically, the first term represents the pointwise mutual information (PMI) between x_core and y, which is the robust indicator related to the underlying class labels. The second and third terms correspond to potential biases arising in the attribute distribution and the label distribution, respectively. This modelling explains how the attribute and the class influence the outcome under subpopulation shift. Therefore, given that x_core is invariant between the training and testing distributions, we can ignore changes in the first term and focus on the attribute and class terms.

Based on this framework, we formally define and characterize four basic types of subpopulation shift: spurious correlations, attribute imbalance, class imbalance, and attribute generalization. Each type constitutes a basic shift component that potentially arises in subpopulation shift.

First, when a certain attribute is spuriously correlated with the label y in training but not in test data, we have spurious correlations. When certain attributes are sampled with a much smaller probability than others, we have attribute imbalance. Similarly, class labels can exhibit imbalanced distributions, biasing the model against minority labels; this leads to class imbalance. Finally, certain attributes can be entirely missing from training yet present in testing for certain classes, which motivates the need for attribute generalization. The source of attribute/class bias for each of these shifts, as well as the impact on the classification model, is summarized in the table below:

These four cases constitute the basic shift components and are the key elements for explaining complex subgroup shifts in real data. In practice, datasets often exhibit multiple types of shift simultaneously, rather than just one.
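
To make the four components concrete, here is a toy illustration (numbers are made up for exposition; this is not data from the paper) using the cow-camel setting, where each subgroup is keyed by (background, class) and the values are subgroup proportions in training versus testing:

```python
# Each dict gives p(attribute, label) for train and test; numbers are illustrative.

# Spurious correlations: the background tracks the class in training but not in test.
spurious_correlations = {
    "train": {("grass", "cow"): 0.49, ("sand", "camel"): 0.49,
              ("sand", "cow"): 0.01,  ("grass", "camel"): 0.01},
    "test":  {("grass", "cow"): 0.25, ("sand", "camel"): 0.25,
              ("sand", "cow"): 0.25,  ("grass", "camel"): 0.25},
}

# Attribute imbalance: one background is rare in training, regardless of class.
attribute_imbalance = {
    "train": {("grass", "cow"): 0.45, ("grass", "camel"): 0.45,
              ("sand", "cow"): 0.05,  ("sand", "camel"): 0.05},
    "test":  {("grass", "cow"): 0.25, ("grass", "camel"): 0.25,
              ("sand", "cow"): 0.25,  ("sand", "camel"): 0.25},
}

# Class imbalance: one class is rare in training, regardless of background.
class_imbalance = {
    "train": {("grass", "cow"): 0.45, ("sand", "cow"): 0.45,
              ("grass", "camel"): 0.05, ("sand", "camel"): 0.05},
    "test":  {("grass", "cow"): 0.25, ("sand", "cow"): 0.25,
              ("grass", "camel"): 0.25, ("sand", "camel"): 0.25},
}

# Attribute generalization: some subgroups are entirely absent from training.
attribute_generalization = {
    "train": {("grass", "cow"): 0.50, ("sand", "camel"): 0.50},
    "test":  {("grass", "cow"): 0.25, ("sand", "camel"): 0.25,
              ("sand", "cow"): 0.25,  ("grass", "camel"): 0.25},
}
```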

SubpopBench: Benchmarking Subpopulation Shift

Having set up the formulation, we propose SubpopBench, a comprehensive benchmark of state-of-the-art algorithms evaluated on 12 real-world datasets. These datasets originate from a variety of modalities and tasks, including vision, language, and healthcare applications, with data modalities ranging from natural images and text to clinical notes and chest X-rays. They also exhibit different combinations of shift components.

For details about this benchmark, please refer to our paper. With the established benchmark and over 10K trained models using 20 state-of-the-art algorithms, we reveal intriguing observations for future research in this field.

A Fine-Grained Analysis on Subpopulation Shift

SOTA Algorithms Only Improve Certain Types of Shift

First, we observe that SOTA algorithms only improve subgroup robustness on certain types of shift, but not others.

We plot here the worst-group accuracy improvement over ERM for various SOTA algorithms. For spurious correlations and class imbalance, existing algorithms provide consistent worst-group gains over ERM, indicating that progress has been made on these two specific shifts.

Interestingly, however, for attribute imbalance little improvement is observed across datasets. Moreover, performance becomes even worse under attribute generalization.

These findings stress that current advances address only specific types of shift, while no progress has been made on the more challenging shifts such as attribute generalization (AG).

The Role of Representations and Classifiers

Further, we are motivated to explore the roles of the representation and the classifier in subpopulation shift. In particular, we separate the network into two parts: the feature extractor f and the classifier g, where f extracts latent features from the input and g outputs the final prediction. We then ask: how do the representation and the classifier affect subgroup performance?

First, given a base ERM model, simply retraining the classifier on top of a fixed representation can substantially improve performance under spurious correlations and class imbalance, indicating that the representations learned by ERM are already good enough for these shifts. Interestingly, however, improving the representation rather than the classifier brings notable gains under attribute imbalance, indicating that we may need more powerful features for certain shifts. Finally, no stratified learning scheme yields performance gains under attribute generalization. This highlights that one needs to consider the design of the whole model pipeline when facing different types of shift in practice.
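
For concreteness, here is a generic PyTorch-style sketch of "retraining the classifier on a frozen representation" (illustrative only; backbone, feat_dim, and train_loader are assumed to exist, and this is not the exact recipe used in the paper):

```python
import torch
import torch.nn as nn


def retrain_classifier(backbone, feat_dim, num_classes, train_loader,
                       epochs=10, lr=1e-3):
    """Freeze the feature extractor f and retrain only the linear classifier g."""
    backbone.eval()
    for p in backbone.parameters():
        p.requires_grad = False              # keep the ERM representation fixed

    classifier = nn.Linear(feat_dim, num_classes)
    optimizer = torch.optim.SGD(classifier.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    for _ in range(epochs):
        for x, y in train_loader:            # e.g., a group-balanced split
            with torch.no_grad():
                feats = backbone(x)           # f(x): frozen features
            loss = criterion(classifier(feats), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return classifier
```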

On Model Selection & Attribute Availability

Furthermore, we observe that model selection and attribute availability considerably affect subpopulation shift evaluation.

Specifically, when we gradually remove the attribute annotations from the training and/or validation data, all algorithms experience notable performance drops, especially when no attributes are available in either the training or the validation data.

This indicates that access to attributes still plays a significant role in achieving reasonable performance under subpopulation shift, and that future algorithms should consider more realistic scenarios for model selection and attribute availability.
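
The following sketch (generic Python, not the benchmark's actual selection code) illustrates why attribute availability matters for model selection: with validation attributes we can pick the checkpoint with the best worst-group accuracy, whereas without them we can only fall back to average accuracy, which may favor models that neglect rare subgroups.

```python
from collections import defaultdict


def group_accuracies(preds, labels, attributes):
    """Accuracy of every (attribute, label) subgroup."""
    correct, total = defaultdict(int), defaultdict(int)
    for y_hat, y, a in zip(preds, labels, attributes):
        total[(a, y)] += 1
        correct[(a, y)] += int(y_hat == y)
    return {g: correct[g] / total[g] for g in total}


def select_checkpoint(val_preds_by_ckpt, val_labels, val_attributes=None):
    """Pick a checkpoint name from its validation predictions.

    With attribute annotations, select by worst-group accuracy; without them,
    fall back to average accuracy.
    """
    def score(preds):
        if val_attributes is not None:
            return min(group_accuracies(preds, val_labels, val_attributes).values())
        return sum(int(p == y) for p, y in zip(preds, val_labels)) / len(val_labels)

    return max(val_preds_by_ckpt, key=lambda name: score(val_preds_by_ckpt[name]))
```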

Metrics Beyond Worst-Group Accuracy

Finally, we reveal a fundamental tradeoff between evaluation metrics. Worst-group accuracy, or WGA, is considered the gold standard in subpopulation shift evaluation. However, does improving WGA always improve other meaningful metrics?

We first show that improving WGA can lead to improved performance on certain metrics, such as the adjusted accuracy shown here. However, if we further consider worst-case precision, it surprisingly shows a very strong negative linear correlation with WGA. This reveals a fundamental limitation of using WGA as the only metric for assessing model performance under subpopulation shift: a well-performing model with high WGA can nevertheless have low worst-class precision, which is especially alarming in critical applications such as medical diagnosis.
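
To make the tradeoff tangible, here is a small sketch (an illustrative helper reflecting one reasonable reading of worst-class precision, not the paper's evaluation code) that computes the precision of every predicted class and returns the worst one, which can be low even when accuracy-based metrics look fine:

```python
from collections import defaultdict


def worst_class_precision(preds, labels):
    """Precision per predicted class; return per-class values and the minimum."""
    tp, predicted = defaultdict(int), defaultdict(int)
    for y_hat, y in zip(preds, labels):
        predicted[y_hat] += 1
        tp[y_hat] += int(y_hat == y)
    per_class = {c: tp[c] / predicted[c] for c in predicted}
    return per_class, min(per_class.values())


# Toy example: the model over-predicts class 0.
preds  = [0, 0, 0, 1, 1, 0, 0, 0]
labels = [0, 0, 0, 1, 1, 1, 1, 1]
per_class, wcp = worst_class_precision(preds, labels)
print(per_class)  # {0: 0.5, 1: 1.0}
print(wcp)        # 0.5 -> worst-class precision is poor despite perfect class-1 precision
```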

Our observations emphasize the need for a broader and more realistic set of evaluation metrics for subpopulation shift. In our paper, we show many other metrics that also exhibit inverse correlations with WGA.

Closing Remarks

To conclude, we systematically investigate the subpopulation shift problem, formalize a unified framework to define and quantify different types of subpopulation shift, and establish a comprehensive benchmark for realistic evaluation on real-world data. Our benchmark includes 20 SOTA methods and 12 real-world datasets across different domains. Based on over 10K trained models, we reveal intriguing properties of subpopulation shift that have implications for future research. We hope our benchmark and findings will promote realistic and rigorous evaluation and inspire new advances in subpopulation shift. Below are several relevant links for our paper; thanks for reading!

Code: https://github.com/YyzHarry/SubpopBench

Project Page: https://subpopbench.csail.mit.edu/

Talk: https://www.youtube.com/watch?v=WiSrCWAAUNI

