
Optimizing Multi-task Learning Models in Practice

What multi-task learning models are, and how to optimize them

Photo by Laura Rivera on Unsplash

Why Multi-task Learning

Multi-task learning

Multi-task learning (MTL) [1] is a field in machine learning in which we utilize a single model to learn multiple tasks simultaneously.

Multi-task learning model (Image by the author)

In theory, the approach allows knowledge sharing between tasks and achieves better results than single-task training. Moreover, as the model tries to learn a representation to optimize multiple tasks, there is a lower chance of overfitting and, hence, better generalization.
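To make the setup concrete, here is a minimal hard-parameter-sharing MTL model in PyTorch — a sketch of the shared-bottom architecture pictured above; the layer sizes and the two-task setup are illustrative choices, not taken from any of the papers discussed.

```python
import torch
import torch.nn as nn

class SharedBottomMTL(nn.Module):
    """Minimal hard-parameter-sharing MTL model: one shared
    encoder feeds several task-specific heads."""

    def __init__(self, in_dim=32, hidden_dim=64, num_tasks=2):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ReLU()
        )
        # One small head per task on top of the shared representation.
        self.heads = nn.ModuleList(
            [nn.Linear(hidden_dim, 1) for _ in range(num_tasks)]
        )

    def forward(self, x):
        h = self.shared(x)                  # shared representation
        return [head(h) for head in self.heads]

model = SharedBottomMTL()
outputs = model(torch.randn(8, 32))         # one (8, 1) output per task
```

The shared encoder is where knowledge transfer (and, as discussed below, negative transfer) happens; the heads stay task-specific.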

Multitask Learning is an approach to inductive transfer that improves generalization by using the domain information contained in the training signals of related tasks as an inductive bias. It does this by learning tasks in parallel while using a shared representation; what is learned for each task can help other tasks be learned better. [2]

In practice, large recommendation and search systems often measure user satisfaction based on multiple metrics, such as stay time, click-through rate, and long-click rate. Each metric can be easily measured and optimized, but building many models is challenging and resource-intensive.

MTL is a good approach to solving these high-dimensional problems. Moreover, MTL utilizes resources better: by combining N tasks into a single model, it reduces the serving computation roughly N-fold.

Challenges

However, training and tuning MTL models requires effort, as they are sensitive to task correlation. When two tasks are loosely correlated or even have conflicting goals, one task’s knowledge may harm the learning process of the other, and vice versa. This phenomenon is called negative transfer.

For example, in e-commerce, it’s common to combine click-through rate (CTR) and conversion rate (CVR) predictions into one single model. At first glance, the tasks seem to have a strong correlation. However, many items are designed as clickbait, so they have a high CTR but a low CVR. Training MTL models on such samples will therefore result in low performance on the CVR prediction task.

In this article, I will introduce some practical optimizations for MTL to learn multiple tasks and transfer knowledge effectively. Let’s begin!

Multi-gate Mixture-of-Experts (MMOE)

As mentioned above, MTL models are especially sensitive to task relations. Therefore, it’s essential to understand the tradeoff between task objectives and task relations. The paper [3] proposed explicitly learning the task correlation with Multi-gate Mixture-of-Experts (MMOE).

MMOE was inspired by the Mixture-of-Experts structure, which selects the subnets based on input data. This not only increases the modelling power but also introduces sparsity into gating networks, reducing computation costs. This is referred to as One-gate MOE (OMOE).

One-gate Mixture-of-Experts network [3]

To further improve OMOE, the authors added a separate gating network for each specific task. Each gating network is simply a linear transformation combined with a softmax layer. We refer to this as multi-gate mixture-of-experts (MMOE).

Compared to OMOE, MMOE allows each task to utilize the experts differently. Therefore, when the correlation between tasks is low, the tasks can decide to use less input from the shared experts and use more from other experts.
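A minimal MMOE layer might look like the following PyTorch sketch. The dimensions and the ReLU experts are illustrative assumptions, not values from [3]; what matters is that every task mixes the same pool of experts through its own softmax gate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MMOE(nn.Module):
    """Minimal MMOE: a shared pool of experts, one softmax gate per task."""

    def __init__(self, in_dim=32, expert_dim=16, num_experts=4, num_tasks=2):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Linear(in_dim, expert_dim) for _ in range(num_experts)]
        )
        # Each gate is a linear transformation followed by softmax.
        self.gates = nn.ModuleList(
            [nn.Linear(in_dim, num_experts) for _ in range(num_tasks)]
        )
        self.heads = nn.ModuleList(
            [nn.Linear(expert_dim, 1) for _ in range(num_tasks)]
        )

    def forward(self, x):
        # Stack expert outputs: (batch, num_experts, expert_dim)
        expert_out = torch.stack([F.relu(e(x)) for e in self.experts], dim=1)
        outputs = []
        for gate, head in zip(self.gates, self.heads):
            w = F.softmax(gate(x), dim=-1)                # (batch, num_experts)
            mixed = (w.unsqueeze(-1) * expert_out).sum(dim=1)
            outputs.append(head(mixed))                    # task-specific head
        return outputs

mmoe = MMOE()
outs = mmoe(torch.randn(8, 32))
```

Because each gate learns its own mixture, a task whose objective conflicts with the others can shift its gate weights toward experts the other tasks use less.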

Multi-gate Mixture-of-Experts network [3]

To prove the role of gate networks, the authors built a synthetic dataset with tasks having different degrees of correlation. As we can observe from the image below, the performance of OMOE and MMOE is close when tasks are related or identical. However, when the correlation is low, MMOE is consistently superior.

The average performance of MMoE, OMoE, and Shared-Bottom on synthetic data with different correlations [3]

GradNorm

While MMOE focuses on optimizing network structure, another important aspect of MTL is weighted losses. For MTL, the loss is the weighted sum of all tasks’ losses. These weights play an essential role in updating the model parameters, especially the shared parameters.

Multi-task weighted loss (Image by the author)

A naive approach is to set an equal weight for each task (w_i = 1), but this is suboptimal, as each task’s loss may have a different scale. For example, suppose our MTL model predicts a person’s age and their house’s price in USD. The mean squared error loss of the house price prediction task is much bigger than that of the age prediction task, so it will likely dominate the gradient update on the shared layers. This may cause a drop in performance on the age prediction task.
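A tiny numeric sketch of the scale problem, using made-up loss values for the age and house-price example above:

```python
# Hypothetical per-task losses on very different scales (the numbers
# are illustrative, not from any paper): age in years vs. price in USD.
loss_age = 25.0        # MSE of the age head (years^2)
loss_price = 4.0e9     # MSE of the price head (USD^2)

# Naive equal weighting: the price loss dominates the total entirely,
# so gradients on the shared layers mostly serve the price task.
naive_total = 1.0 * loss_age + 1.0 * loss_price

# Hand-picked weights that equalize the two contributions.
w_age, w_price = 1.0, loss_age / loss_price
balanced_total = w_age * loss_age + w_price * loss_price
# Each task now contributes 25.0 to the total loss.
```

Hand-picking such weights per task does not scale, which is exactly the problem GradNorm automates.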

GradNorm [4] introduced a new optimization objective, the gradient loss, in which backpropagation is applied to the task weights w_i only. The more imbalanced the gradient updates between tasks are, the bigger the gradient loss is. This ensures that each task contributes equally to updating the shared layers W.

Simple Gradient loss (Image by the author)

The authors went a step further and optimized the gradient loss based on each task’s learning rate. The inverse learning rate is defined as follows.

Task inverse learning rate (Image by the author)

Intuitively, if a task’s loss decreases faster, it also converges more quickly. If a task converges faster than the others, we should deprioritize it, so it contributes less to the gradient update of the shared layers. Based on this intuition, we add the inverse learning rate to the original formula.

GradNorm loss (Image by the author)
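Putting the pieces together, here is a simplified PyTorch sketch of the GradNorm gradient loss. It assumes a single shared weight tensor; the function and variable names are mine, and in practice a separate optimizer steps only the task weights w_i with this loss, as in [4].

```python
import torch

def gradnorm_loss(task_losses, weights, initial_losses, shared_param, alpha=1.5):
    """Gradient loss in the spirit of GradNorm [4] (simplified sketch)."""
    # Gradient norm of each weighted task loss w.r.t. the shared layer;
    # create_graph=True keeps the norms differentiable w.r.t. the weights.
    norms = torch.stack([
        torch.autograd.grad(w * loss, shared_param,
                            retain_graph=True, create_graph=True)[0].norm()
        for w, loss in zip(weights, task_losses)
    ])

    # Inverse learning rate: how far each loss has fallen from its
    # initial value; faster-converging tasks get a smaller target.
    ratios = torch.stack([l.detach() / l0
                          for l, l0 in zip(task_losses, initial_losses)])
    rel_rate = ratios / ratios.mean()

    # The target norms are treated as constants [4]; in practice only
    # the task-weight optimizer is stepped with this loss.
    target = (norms.mean() * rel_rate ** alpha).detach()
    return (norms - target).abs().sum()

# Toy usage: two losses sharing one linear layer.
shared = torch.nn.Linear(4, 3)
h = shared(torch.randn(16, 4))
loss1 = (h ** 2).mean()
loss2 = ((h - 1.0) ** 2).mean()
w = torch.nn.Parameter(torch.ones(2))
gl = gradnorm_loss([loss1, loss2], w, [float(loss1), float(loss2)], shared.weight)
gl.backward()   # gradients now flow into the task weights w
```

The task weights are typically renormalized after each step so they keep summing to the number of tasks.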

The authors used a synthetic dataset with different task scales to illustrate how the solution works. With equal weighting, the large-scale tasks dominate the gradient update on the shared parameters. Using GradNorm, small-scale tasks are prioritized, so their weights become significantly bigger than those of the large-scale ones. The optimization consistently outperformed the equal-weighting baseline.

Gradient norm on a toy 10-task system [4]

Gradient-Blending

In this paper [5], the authors also raised concerns about the different convergence speeds of tasks when training a multimodal network. Namely, training a model using multimodal features (e.g. video, audio, etc.) should outperform unimodal approaches, as we use more information as input. In reality, however, the best unimodal model often beats the multimodal one.

After investigation, the reasons can be listed as follows:

  • The multi-modal network has more parameters compared to the uni-modal and, hence, is more prone to overfitting.
  • Different modalities generalize and overfit at different rates, so naively training them together has suboptimal results.

The authors suggested using joint training of different modalities with weights tuned according to tasks’ generalization and overfitting rates. The rates can be defined as follows.

Measuring a task’s overfitting and generalization rate (Image by the author)

At time step t, the loss weight of task i is defined as follows.

Task weight estimation (Image by the author)

As we can see, when the task is overfitting (smaller G_i, bigger O_i), w_i will be penalized and become smaller.

The weight updates can be computed once, or repeated every n epochs. This solution is shown to outperform other optimizations on many multimodal benchmark datasets.
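A rough sketch of how these weights could be estimated from training and validation losses measured n epochs apart. The function name, the data layout, and the w ∝ G/O² weighting are my reading of [5]; treat this as illustrative rather than the paper’s exact procedure.

```python
def gblend_weights(train_losses, val_losses):
    """Estimate Gradient-Blending-style task weights (sketch).

    Each argument is a list with one (loss_at_t, loss_at_t_plus_n)
    pair per task, measured n epochs apart.
    """
    raw = []
    for (tr0, tr1), (va0, va1) in zip(train_losses, val_losses):
        overfit = (va1 - tr1) - (va0 - tr0)   # O_i: growth of the val-train gap
        general = va0 - va1                    # G_i: drop in validation loss
        # Weight ~ G / O^2: penalize tasks that overfit, reward those
        # that still generalize; clamp to avoid division by zero.
        raw.append(max(general, 0.0) / max(overfit, 1e-8) ** 2)
    total = sum(raw) or 1.0
    return [r / total for r in raw]            # normalized, sums to 1

# Toy usage: task 2's validation loss rises while its training loss
# falls (overfitting), so it receives the smaller weight.
w = gblend_weights(
    train_losses=[(1.0, 0.5), (1.0, 0.4)],
    val_losses=[(1.2, 0.9), (1.2, 1.1)],
)
```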

The solution (G-Blend) works on different multi-modal problems [5]

Reference

[1] https://en.wikipedia.org/wiki/Multi-task_learning

[2] Caruana, Rich. "Multitask learning." Machine Learning 28 (1997): 41–75.

[3] Ma, Jiaqi, et al. "Modeling task relationships in multi-task learning with multi-gate mixture-of-experts." Proceedings of the 24th ACM SIGKDD international conference on knowledge discovery & data mining. 2018.

[4] Chen, Zhao, et al. "Gradnorm: Gradient normalization for adaptive loss balancing in deep multitask networks." International conference on machine learning. PMLR, 2018.

[5] Wang, Weiyao, Du Tran, and Matt Feiszli. "What makes training multi-modal classification networks hard?." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2020.
