How to learn from imbalanced data arising from multiple domains

On Multi-Domain Long-Tailed Recognition, Imbalanced Domain Generalization and Beyond (ECCV 2022)

Yuzhe Yang
Towards Data Science


Let me introduce our new work, accepted to ECCV 2022: On Multi-Domain Long-Tailed Recognition, Imbalanced Domain Generalization and Beyond. As the name suggests, this work studies how to learn a robust model when the data come from multiple domains, each of which can have (potentially different) data imbalance. Existing methods for imbalanced data and long-tailed distributions consider only a single domain, i.e., all data originate from the same distribution. Natural data, however, can originate from distinct domains, where a minority class in one domain could have abundant instances in other domains. Effectively utilizing data from different domains is likely to improve long-tailed learning over all domains. This paper extends the traditional imbalanced classification paradigm from a single domain to multiple domains.

We first propose the domain-class transferability graph, which quantifies the transferability between different domain-class pairs under data imbalance. In this graph, each node is a domain-class pair, and each edge encodes the distance between two domain-class pairs in the embedding space. We show that the transferability graph dictates the performance of imbalanced learning across domains. Inspired by this, we design BoDA, a loss function that provably upper-bounds the transferability statistics and thereby improves model performance. We construct five new multi-domain imbalanced datasets and benchmark ~20 algorithms for comparison. The code, data, and models are open-sourced on GitHub: https://github.com/YyzHarry/multi-domain-imbalance.

Background and Motivation

Real-world data often exhibit label imbalance: instead of a uniform label distribution over classes, data are by nature imbalanced, with a few classes containing a large number of instances while many others have only a few. Many methods have been proposed to deal with this phenomenon; a more detailed review can be found in my previous article.

However, existing solutions for learning from imbalanced data mainly consider the single-domain case, where all samples come from the same data distribution. In reality, data for the same task can originate from different domains. For example, the figure below shows Terra Incognita, a real-world dataset collected for wildlife recognition and classification. The pictures on the left show camera traps established at different locations, together with samples of the wild animals they captured; the panel on the right shows (part of) the label distributions obtained at different camera locations. We can clearly see that even for the same wildlife classification task, the camera parameters, shooting backgrounds, light intensity, etc. differ completely across locations; that is, there is a domain gap between different camera traps. Moreover, because some animals only appear in specific locations, each camera (domain) exhibits its own data imbalance, and some categories may even have no data at all (for example, location 100 has almost no data for categories 0 and 1). Yet, since the label distributions captured by different cameras often differ substantially, other domains are likely to have many samples in exactly those categories — for example, location 46 has more category-1 data. This indicates that we can leverage multi-domain data to address the inherent data imbalance within each domain.

Terra Incognita dataset. Even for the same wildlife classification task, the parameters, shooting background, light intensity, etc. of different cameras can be completely different; the data obtained by each camera is also extremely imbalanced; furthermore, the label distributions captured by different cameras are very different and mismatched. This illustrates that we can leverage multi-domain data to address the inherent data imbalance within each domain. (Image by Author)

Likewise, similar situations occur in other practical applications. For example, in visual recognition problems, minority classes from “photo” images could be complemented with potentially abundant samples from “sketch” images. Similarly, in autonomous driving, the minority accident class in “real” life could be enriched with accidents generated in “simulation”. Also, in medical diagnosis, data from distinct populations could enhance each other, where minority samples from one institution could be enriched with instances from others. In the above examples, different data types act as distinct domains, and such multi-domain data could be leveraged to tackle the inherent data imbalance within each domain.

Therefore, in this work we formulate the problem of Multi-Domain Long-Tailed Recognition (MDLT): learning from multi-domain imbalanced data, where each domain has its own imbalanced label distribution, and generalizing to a test set that is balanced over all domain-class pairs. MDLT thus has to tackle label imbalance, domain shift, and divergent label distributions across domains, and generalize to the entire set of classes over all domains.

Multi-Domain Long-Tailed Recognition (MDLT) learns from multi-domain imbalanced data, addresses label imbalance, domain shift, and divergent label distributions across domains, and generalizes to all domain-class pairs. (Image by Author)

Challenges of multi-domain imbalanced learning

Yet we note that MDLT brings new challenges distinct from its single-domain counterpart.

(I) First, the label distribution of each domain is likely to differ from those of other domains. For example, in the animated figure above, both the “Photo” and “Cartoon” domains exhibit imbalanced label distributions; yet the “horse” class in “Cartoon” has many more samples than in “Photo”. This creates the challenge of divergent label distributions across domains, in addition to in-domain data imbalance.

(II) Second, multi-domain data inherently involves domain shift. Simply treating different domains as a whole and applying traditional data-imbalance methods is unlikely to yield the best results, as the domain gap can be arbitrarily large.

(III) Third, MDLT naturally motivates zero-shot generalization within and across domains — i.e., generalizing both to in-domain missing classes (right part of the animated figure) and to new domains with no training data, where the latter case is typically denoted Domain Generalization (DG).

To summarize, MDLT poses new difficulties and challenges compared with traditional single-domain imbalanced classification. So, how should we approach multi-domain imbalanced learning? In the next two sections, we analyze this problem step by step — from the overall modeling, motivating examples, and observed phenomena, to the theoretical derivation and the design of the final loss function — and finally improve model performance on the MDLT task.

Domain-Class Transferability Graph

Here we first propose a set of definitions to model the problem of MDLT. We argue that in contrast to single-domain imbalanced learning where the basic unit one cares about is a class (i.e., minority vs. majority classes), in MDLT, the basic unit naturally translates to a domain-class pair.

Starting from domain-class pairs, we can measure the transferability (similarity) between them, defined as the distance between different domain-class pairs in the embedding space:

(Image by Author)

Intuitively, the transferability between two domain-class pairs is the average distance between their learned representations, characterizing how close they are in the feature space. By default, d is chosen as the Euclidean distance, but it can also incorporate higher-order statistics of (d, c); for example, the Mahalanobis distance uses the covariance. Then, based on transferability, we can further define the transferability graph:

(Image by Author)

In the transferability graph, each node refers to a domain-class pair, and each edge refers to the distance (i.e., transferability) between two domain-class pairs in the embedding space. Moreover, we can visualize this graph in a 2D space using multidimensional scaling (MDS).
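To make this concrete, here is a minimal sketch of building such a graph — assuming features extracted from a trained backbone; the array names (`features`, `domains`, `labels`) are illustrative, and this is not the paper's exact implementation:

```python
import numpy as np
from sklearn.manifold import MDS

def transferability_matrix(features, domains, labels):
    """Pairwise distances between per-(domain, class) mean embeddings.

    A first-order sketch of transferability: the smaller the distance
    between two domain-class pairs, the higher their transferability.
    """
    pairs = sorted(set(zip(domains, labels)))
    means = np.stack([
        features[(domains == d) & (labels == c)].mean(axis=0)
        for d, c in pairs
    ])
    # dist[i, j] is the edge weight between nodes pairs[i] and pairs[j].
    dist = np.linalg.norm(means[:, None, :] - means[None, :, :], axis=-1)
    return dist, pairs

# 2D layout of the transferability graph for visualization:
# dist, pairs = transferability_matrix(features, domains, labels)
# xy = MDS(n_components=2, dissimilarity="precomputed").fit_transform(dist)
```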

Overall framework of the transferability graph. (a) Distribution statistics are computed for all domain-class pairs, from which we generate a full transferability matrix. (b) MDS is used to project the graph into a 2D space for visualization. (c) We define (α, β, γ) transferability statistics to further describe the whole transferability graph. (Image by Author)

Specifically, as shown in Figures (a)(b) above, for each domain-class pair we first compute the feature statistics (mean, covariance, etc.) of all data belonging to that pair. We then calculate the transferability between every two pairs, from which we generate the complete transferability graph, represented as a matrix (Figure a). We can then visualize this similarity structure on a 2D plane using multidimensional scaling (MDS). In Figure b, different domains are marked with different colors; each point represents a domain-class pair, its size represents the amount of data it contains, and the number denotes the specific category; the distance between points can be read as transferability. Naturally, we want domain-class pairs with the same number (i.e., the same category) to be close, while pairs of different categories should be far from each other. This relationship can be abstracted into three transferability statistics: same class across different domains (α), different classes within the same domain (β), and different classes across different domains (γ):

(Image by Author)
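Given the matrix and pair list from the sketch above, the three statistics can be estimated by averaging edge weights over the corresponding kinds of node pairs (a hedged sketch; the precise definitions are in the paper's Definition 3):

```python
import numpy as np

def transferability_stats(dist, pairs):
    """Average distance over the three kinds of domain-class pair relations:
    alpha: same class, different domains
    beta:  different classes, same domain
    gamma: different classes, different domains
    """
    alpha, beta, gamma = [], [], []
    for i, (d1, c1) in enumerate(pairs):
        for j, (d2, c2) in enumerate(pairs):
            if i >= j:  # count each unordered pair once
                continue
            if d1 != d2 and c1 == c2:
                alpha.append(dist[i, j])
            elif d1 == d2 and c1 != c2:
                beta.append(dist[i, j])
            elif d1 != d2 and c1 != c2:
                gamma.append(dist[i, j])
    return np.mean(alpha), np.mean(beta), np.mean(gamma)
```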

So far, we have modeled and mathematically defined MDLT. Next we will further explore the relationship between transferability (statistics) and the final MDLT performance.

What Makes for Good Representations in MDLT?

Divergent Label Distributions Hamper Transferable Features

MDLT has to deal with differences between the label distributions across domains. To understand the implications of this issue, we start with an example.

Motivating Example. We construct Digits-MLT, a two-domain toy MDLT dataset that combines two digit datasets, MNIST-M and SVHN. The task is 10-class digit classification. We manually vary the number of samples for each domain-class pair to simulate different label distributions, and train a plain ResNet-18 with empirical risk minimization (ERM) for each case. We keep all test sets balanced and identical. The results reveal interesting observations. When the per-domain label distributions are balanced and identical across domains (Fig. a), a domain gap exists but does not prevent the model from learning discriminative features with high accuracy (90.5%). If the label distributions are imbalanced but identical (Fig. b), ERM is still able to align similar classes across the two domains, where majority classes (e.g., class 9) are closer in terms of transferability than minority classes (e.g., class 0). In contrast, when the labels are both imbalanced and mismatched across domains (Fig. c), the learned features are no longer transferable, resulting in a clear gap across domains and the worst accuracy. This is because divergent label distributions across domains create an undesirable shortcut: the model can minimize the classification loss simply by separating the two domains.

The evolving pattern of the transferability graph when varying label proportions of Digits-MLT. (a) Label distributions for two domains are balanced and identical. (b) Label distributions for two domains are imbalanced but identical. (c) Label distributions for two domains are imbalanced and divergent. (Image by Author)

Transferable Features are Desirable. As the results indicate, transferable features across (d, c) pairs are needed, especially when imbalance occurs. In particular, the transferability link between the same class across domains should be greater than that between different classes within or across domains. This can be captured via the (α, β, γ) transferability statistics, as we show next.

Transferability Statistics Characterize Generalization

Motivating Example. Again, we use Digits-MLT with varying label distributions. We consider three imbalance types to compose different label configurations: (1) Uniform (i.e., balanced labels), (2) Forward-LT, where the labels exhibit a long tail over class IDs, and (3) Backward-LT, where the labels are inversely long-tailed with respect to the class IDs. For each configuration, we train 20 ERM models with varying hyperparameters. We then calculate the (α, β, γ) statistics for each model and plot its classification accuracy against (β + γ) − α.

Correspondence between (β + γ) − α quantity and test accuracy across different label configurations of Digits-MLT. Each plot refers to specific label distributions for two domains (e.g., (a) employs “Uniform” for domain 1 and “Uniform” for domain 2). Each point corresponds to a model trained with ERM using different hyperparameters. (Image by Author)

The results reveal multiple interesting findings: (1) The (α, β, γ) statistics characterize a model’s performance in MDLT. In particular, the (β + γ) − α quantity displays a very strong correlation with test performance across the entire range and every label configuration. (2) Data imbalance increases the risk of learning less transferable features. When the label distributions are similar across domains (Fig. a), the models are robust to varying hyperparameters, clustering in the upper-right region. However, as the labels become imbalanced (Figs. b, c) and further divergent (Figs. d, e), the chance that the model learns non-transferable features (i.e., lower (β + γ) − α) increases, leading to a large drop in performance.
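To reproduce this kind of analysis on your own models, one could compute the statistics with the `transferability_stats` sketch above for each trained model and correlate the combined quantity with test accuracy (variable names are illustrative):

```python
import numpy as np

# One (alpha, beta, gamma) triple and one test accuracy per trained model.
scores = [(b + g) - a for (a, b, g) in stats_per_model]
r = np.corrcoef(scores, test_accuracies)[0, 1]  # Pearson correlation coefficient
print(f"corr((beta + gamma) - alpha, accuracy) = {r:.3f}")
```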

A Loss that Bounds the Transferability Statistics

We use the above findings to design a new loss function particularly suitable for MDLT. We first introduce the loss function, then prove that it minimizes an upper bound on the (α, β, γ) statistics. We start from a simple loss inspired by metric learning objectives. We call this loss L_{DA} since it aims for Domain-class Distribution Alignment, i.e., aligning the features of the same class across domains:

(Image by Author)

Intuitively, this loss tackles label divergence: (d, c) pairs that share the same class are pulled closer, and vice versa. It also relates to the (α, β, γ) statistics, as the numerator represents positive cross-domain pairs (α), and the denominator represents negative cross-class pairs (β, γ).
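As a hedged PyTorch sketch of this idea — simplified to operate on per-pair centroids rather than the paper's exact formulation — same-class centroids from other domains form the numerator (positives), and all other pairs form the denominator:

```python
import torch

def da_loss(centroids, domains, classes):
    """Sketch of a domain-class distribution alignment objective.

    centroids: (P, F) tensor of per-(domain, class) mean features.
    domains, classes: length-P integer tensors identifying each pair.
    """
    logits = -torch.cdist(centroids, centroids)  # closer pairs => larger logits
    loss, num_terms = centroids.new_zeros(()), 0
    for i in range(centroids.size(0)):
        pos = (classes == classes[i]) & (domains != domains[i])  # same class, other domains
        neg = (classes != classes[i]) | (domains != domains[i])  # all pairs except itself
        if pos.any():
            # -log of the softmax mass that pair i assigns to its positives
            loss = loss - (torch.logsumexp(logits[i][pos], dim=0)
                           - torch.logsumexp(logits[i][neg], dim=0))
            num_terms += 1
    return loss / max(num_terms, 1)
```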

However, it does not address label imbalance. Note that (α, β, γ) are defined in a balanced way, independent of the number of samples in each (d, c). Given an imbalanced dataset, most samples will come from majority domain-class pairs, which would dominate the loss and cause minority pairs to be overlooked.

Balanced Domain-Class Distribution Alignment (BoDA). To tackle data imbalance across (d, c) pairs, we modify the loss to the BoDA loss:

(Image by Author)

BoDA scales the original distance d by a factor of 1/N_{d,c}, where N_{d,c} is the number of samples in domain-class pair (d, c). As such, it counters the effect of imbalanced domain-class pairs by introducing a balanced distance measure.
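Relative to the L_{DA} sketch above, this amounts to a one-line change (a hedged sketch; the exact placement of the weight follows the paper's formula):

```python
# counts: length-P tensor, counts[j] = N_{d,c} of domain-class pair j.
# Dividing by the pair's sample count downweights majority pairs,
# so that minority pairs are no longer overlooked by the loss.
balanced_logits = -torch.cdist(centroids, centroids) / counts.unsqueeze(0)
```

We have the following theorem for BoDA: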

(Image by Author)

Please refer to our paper for the proof details. Theorem 1 has the following interesting implications:

  1. BoDA upper-bounds the (α, β, γ) statistics in a desired form that naturally translates to better performance. By minimizing BoDA, we ensure a low α (attracting same classes) and high β, γ (separating different classes), which are essential conditions for generalization in MDLT.
  2. The constant factors correspond to how much each component contributes to the transferability graph. Zooming in on the arguments of exp(·), we observe that the objective is proportional to α − (β · 1/|D| + γ · (|D|−1)/|D|); for example, with |D| = 2 domains this is α − (β + γ)/2. According to Definition 3, α summarizes data similarity for the same class, while the latter expression summarizes data similarity across different classes, using a weighted average of β and γ whose weights are proportional to the number of associated domains.

Calibration for Data Imbalance Leads to Better Transfer

BoDA works by encouraging feature transfer between similar classes across domains: if (d, c) and (d′, c) refer to the same class in different domains, we want to transfer their features to each other. However, minority domain-class pairs naturally have worse feature estimates due to data scarcity, and forcing other pairs to transfer to them hurts learning. Thus, when bringing two domain-class pairs closer in the embedding space, we want the minority (d, c) to transfer toward majority ones, not the inverse.

There are many details here, so I will skip them; a detailed motivating example and interpretation are given in our paper. The conclusion is that the degree of transfer can be controlled by the relative sample sizes of the two domain-class pairs, by adding a calibration term to BoDA:

(Image by Author)
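For completeness, here is one plausible and explicitly hypothetical instantiation of such a calibration term — the exponent `nu` and the direction of the sample-count ratio are illustrative assumptions, not the paper's notation; please refer to the paper for the precise form:

```python
# Hypothetical calibration sketch: calib[i, j] grows when the target pair j
# has more samples than the source pair i, strengthening the pull of scarce
# pairs toward well-estimated majority pairs (and weakening the reverse).
nu = 1.0  # hypothetical calibration exponent
calib = (counts.unsqueeze(0) / counts.unsqueeze(1)).pow(nu)
calibrated_logits = balanced_logits * calib
```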

Benchmarking MDLT datasets & Experiments

To support practical evaluation of multi-domain imbalanced learning methods, and to facilitate future research, we curate five MDLT benchmarks from existing multi-domain datasets:

(Image by Author)

In addition, we select ~20 algorithms covering various categories — multi-domain learning, distributionally robust optimization, invariant feature learning, meta-learning, imbalanced learning, etc. — as baselines, and optimize the hyperparameters for each algorithm. This process ensures that the comparison is best-vs-best and that hyperparameters are tuned for all algorithms.

During evaluation, we report the average accuracy across domains as well as the worst accuracy over domains. We further divide all domain-class pairs into many-shot (over 100 training samples), medium-shot (20∼100 training samples), few-shot (under 20 training samples), and zero-shot (no training data), and report results for each subset.
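A minimal sketch of this bucketing, assuming per-pair test accuracies and training counts are available (names illustrative):

```python
import numpy as np

def shot_buckets(train_counts, pair_accuracies):
    """Average test accuracy per shot regime over domain-class pairs."""
    buckets = {"many": [], "medium": [], "few": [], "zero": []}
    for n, acc in zip(train_counts, pair_accuracies):
        if n > 100:
            buckets["many"].append(acc)     # many-shot: over 100 samples
        elif n >= 20:
            buckets["medium"].append(acc)   # medium-shot: 20-100 samples
        elif n > 0:
            buckets["few"].append(acc)      # few-shot: under 20 samples
        else:
            buckets["zero"].append(acc)     # zero-shot: no training data
    return {k: float(np.mean(v)) if v else None for k, v in buckets.items()}
```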

Experiments: Since there are many experiments, we only show representative results here (please refer to the paper for the full results). First, BoDA consistently achieves the best average accuracy across all datasets, and the best worst-case accuracy most of the time. Moreover, on certain datasets (e.g., OfficeHome-MLT), MDL methods perform better (e.g., CORAL), while on others (e.g., TerraInc-MLT), imbalanced-learning methods achieve higher gains (e.g., CRT). Nevertheless, regardless of dataset, BoDA outperforms all methods, highlighting its effectiveness for the MDLT task. Finally, compared to ERM, BoDA slightly improves the average and many-shot performance, while substantially boosting performance on medium-shot, few-shot, and zero-shot pairs.

(Image by Author)

Analysis: understanding BoDA. We conduct further analysis of BoDA by plotting the transferability graph it learns and comparing it with ERM under different cross-domain label distributions. As the figure below shows, BoDA learns a more balanced feature space that separates different categories. When the label distributions are balanced and identical (a), both ERM and BoDA learn good features; when the labels become imbalanced (b, c), or even mismatched across domains (d, e), an obvious domain gap appears in the transferability graph learned by ERM; in contrast, BoDA always learns a balanced and aligned feature space. As a result, better learned features translate into better accuracy (a 9.5% absolute accuracy gain).

(Image by Author)

Beyond MDLT: Imbalanced Domain Generalization

Domain Generalization (DG) refers to learning from multiple domains and generalizing to unseen domains. Since the training domains naturally differ in their label distributions and may even exhibit class imbalance within each domain, we study whether BoDA can also improve DG performance. Note that all the datasets we adapted for MDLT are standard DG benchmarks, which confirms that data imbalance is an intrinsic problem in DG that has been overlooked by past works.

(Image by Author)

We test BoDA following the standard DG evaluation protocol. The table above reveals the following findings. First, BoDA alone improves upon the current SOTA on four of the five datasets and achieves notable average performance gains. Moreover, combined with the current SOTA, BoDA further boosts results by a notable margin across all datasets, suggesting that label imbalance is orthogonal to existing DG-specific algorithms. Finally, as in MDLT, the gains depend on how severe the imbalance is within a dataset — TerraInc exhibits the most severe label imbalance across domains, and on it BoDA achieves the highest gains. These intriguing results shed light on how label imbalance affects out-of-distribution generalization, and highlight the importance of accounting for label imbalance in practical DG algorithm design.

Closing remarks

To conclude, we proposed (1) a new task, multi-domain long-tailed recognition (MDLT); (2) BoDA, a new theoretically guaranteed loss function to model and improve MDLT; and (3) five new benchmarks to facilitate future research on multi-domain imbalanced data. Furthermore, we find that label imbalance affects out-of-distribution generalization, so practical and robust DG algorithm design should also take label imbalance into account. Below are several relevant links for our paper; thanks for reading!

Code: https://github.com/YyzHarry/multi-domain-imbalance

Project Page: http://mdlt.csail.mit.edu/
