Multitask learning in TensorFlow with the Head API
An introduction to custom estimators for multitask learning
A fundamental characteristic of human learning is that we learn many things simultaneously. The equivalent idea in machine learning is called multi-task learning (MTL), and it has become increasingly useful in practice, particularly for reinforcement learning and natural language processing. In fact, even in standard single-task settings, auxiliary tasks can be devised and added to the optimization process to help learning.
This post introduces the field by showing how to solve a simple multi-task problem on a facial image benchmark. The focus is on an experimental component of TensorFlow, the Head API, which helps in designing custom estimators for MTL by decoupling the shared component of the neural network from the task-specific ones. Along the way we will also have the opportunity to discuss other features of the TensorFlow core, including tf.data, tf.image, and custom estimators.
The code for the tutorial is available as a fully self-contained Colab notebook: feel free to test and experiment!
Content at a glance
To make the tutorial more interesting, we consider a realistic use case by re-implementing (part of) a 2014 paper: Facial Landmark Detection by Deep Multi-task Learning. The problem is simple: given a facial image, we need to localize a series of landmarks, i.e., points of interest on the image (nose, left eye, mouth, …), and predict a set of tags, such as the gender of the person and the pose of the face. Each landmark/tag constitutes a separate task on the image, and the tasks are clearly correlated (e.g., think of predicting the position of the left eye knowing first where the right one is).
We split our implementation into three parts: (i) loading the images (using tf.data and tf.image); (ii) implementing the convolutional network from the paper (using TensorFlow's custom estimators); (iii) adding the MTL logic with the Head API. Lots to see, no time to lose!
Step 0 — Loading the dataset
After downloading the dataset, a quick inspection reveals that the images are split among three different folders (AFLW, lfw_5590, and net_7876). Train and test splits are provided through separate text files, each row of which corresponds to the path of an image followed by its labels.
For simplicity, we use Pandas to load the text files and adjust the paths to the Unix convention (forward slashes), e.g., for the training split.
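The Pandas code is not reproduced here. As a rough, framework-free stand-in, the sketch below does the same job with the Python standard library only. It assumes, without checking against the dataset, that each row is whitespace-separated, starts with a Windows-style relative path, and continues with numeric labels. The sample file name is invented for illustration.

```python
from pathlib import PureWindowsPath


def parse_row(row):
    """Split one annotation row into a Unix-style path and a list of floats.

    Assumed (hypothetical) layout: "<windows\\relative\\path> <num> <num> ...".
    """
    fields = row.split()
    # PureWindowsPath treats backslashes as separators; as_posix() emits '/'.
    path = PureWindowsPath(fields[0]).as_posix()
    labels = [float(value) for value in fields[1:]]
    return path, labels


def load_annotations(lines):
    """Parse every non-blank line into a (path, labels) pair."""
    return [parse_row(line) for line in lines if line.strip()]


# Made-up row for illustration only.
rows = ["lfw_5590\\face_0001.jpg 10.0 12.5 3"]
print(load_annotations(rows))  # [('lfw_5590/face_0001.jpg', [10.0, 12.5, 3.0])]
```

Going through pathlib instead of a raw string replacement makes the intent explicit. On Unix, the resulting relative path can then be joined with the dataset root directory.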
Since the text files are not huge, Pandas is slightly easier to use here and offers some extra flexibility. For larger files, however, a better choice is to read them directly with the tf.data class TextLineDataset.
Step 1 — Working with tf.data and the Dataset object
Now that we have our data, we can load it with tf.data to make it estimator-ready! In the simplest case, we can just slice the Pandas DataFrame (e.g., with tf.data.Dataset.from_tensor_slices) to obtain our batches of data.
Previously, a major problem of using tf.data with Estimators was that debugging the dataset was rather complex, as it required going through tf.Session objects. In recent versions, however, datasets can be debugged with eager execution enabled, even when working with estimators. As an example, we can use the Dataset to build batches of 8 elements, take the first batch, and print it.
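The next sketch makes the slicing and batching semantics concrete without TensorFlow. It shows in plain Python what from_tensor_slices followed by batch(8) and take(1) conceptually produce: a dict of columns is cut into per-example dicts, which are then regrouped into dicts of lists. The helper names and columns are hypothetical. On a real Dataset with eager execution enabled, a loop such as `for features in dataset.batch(8).take(1): print(features)` plays the same role.

```python
from itertools import islice


def from_slices(columns):
    """Yield one dict per row from a dict of equal-length columns."""
    keys = list(columns)
    for i in range(len(columns[keys[0]])):
        yield {k: columns[k][i] for k in keys}


def batch(elements, batch_size):
    """Group consecutive elements into dicts of lists; the last batch may be short."""
    it = iter(elements)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield {k: [e[k] for e in chunk] for k in chunk[0]}


# Hypothetical columns standing in for the DataFrame (20 rows).
columns = {
    "path": [f"img_{i:04d}.jpg" for i in range(20)],
    "nose_x": [float(i) for i in range(20)],
}
first = next(batch(from_slices(columns), 8))  # roughly dataset.batch(8).take(1)
print(first["path"])
```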
It is now time to load the images from their paths! Note that in general this is not trivial, because images can come in many different formats and sizes, some may be black-and-white, and so on. Luckily for us, we can take inspiration from a TF tutorial to build a simple function that encapsulates all this logic, leveraging the tools in the tf.image module.
The function takes care of most parsing problems:
- The ‘channels’ parameter lets us load both colour and b/w images in a single line, by forcing the same number of channels on every decoded image;
- We resize all images to our desired format (40x40, in accordance with the original paper);
- We also normalize our landmark labels to denote a relative location between 0 and 1, instead of an absolute one (since we resized all images, and the original images might come in different shapes).
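The normalization in the last bullet is simple arithmetic. Each x coordinate is divided by the image width and each y coordinate by the image height, both measured on the image the labels refer to (before resizing). Below is a minimal plain-Python sketch with hypothetical function names. It also includes the inverse mapping, which is handy for plotting predictions on the resized 40x40 images.

```python
def normalize_landmarks(xs, ys, width, height):
    """Convert absolute pixel coordinates to fractions of the image size.

    width/height must be those of the image the coordinates refer to,
    i.e. the original image, before resizing.
    """
    return [x / width for x in xs], [y / height for y in ys]


def denormalize_landmarks(xs, ys, width, height):
    """Inverse mapping, e.g. to draw predicted points on a resized image."""
    return [x * width for x in xs], [y * height for y in ys]


# Hypothetical landmark in a 250x200 (width x height) image.
rel = normalize_landmarks([125.0], [50.0], width=250, height=200)
print(rel)  # ([0.5], [0.25])
abs_40 = denormalize_landmarks(*rel, width=40, height=40)
print(abs_40)  # ([20.0], [10.0])
```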
We can apply the parsing function to each element of a Dataset using its ‘map’ method. Putting this together with some additional logic for training/testing, we obtain our final loading function.
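The overall structure of such a loading function can be sketched without TensorFlow: map the parser over every element, shuffle only when training, then batch. This is an illustration under simplifying assumptions, not the notebook's code. It shuffles the full list with a seeded random generator instead of using a finite shuffle buffer, and the function name and parameters are hypothetical.

```python
import random


def make_input(elements, parse_fn, training, batch_size=32, epochs=1, seed=0):
    """Hypothetical input pipeline: map, then (if training) shuffle, then batch.

    Mirrors the usual tf.data pattern map(parse_fn) -> shuffle -> batch,
    repeated for several epochs when training and run once otherwise.
    """
    parsed = [parse_fn(e) for e in elements]  # like dataset.map(parse_fn)
    rng = random.Random(seed)
    batches = []
    for _ in range(epochs if training else 1):
        order = list(parsed)
        if training:
            rng.shuffle(order)  # a fresh shuffle for every epoch
        batches.extend(order[i:i + batch_size] for i in range(0, len(order), batch_size))
    return batches


# Evaluation keeps the original order and makes a single pass.
print(make_input(range(10), lambda x: x * 2, training=False, batch_size=4))
```

Keeping evaluation deterministic makes metrics reproducible, while reshuffling per epoch during training avoids presenting the examples in the same order every time.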
Step 2 — Building a convolutional network with the custom estimators
As a next step, we replicate the convolutional neural network (CNN) from the original paper.
The CNN is composed of two parts: the first one is a generic feature extractor for the entire image (shared across all tasks), while each task has a separate, smaller model acting on the final feature embedding of the image. For reasons that will become clear below, we will refer to each of these simpler models as a ‘head’. All heads are trained simultaneously through gradient descent.
Let us start with the feature extraction part, for which we use the tf.layers module to build our main network.
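Every convolution and pooling layer shrinks the 40x40 input, so it helps to track the spatial size of the feature maps before they are flattened for the dense part. The sketch below assumes 'valid' padding with stride 1 for convolutions and non-overlapping pooling windows. The layer stack is a hypothetical example, not necessarily the exact configuration from the paper.

```python
def output_size(size, layers):
    """Track the side length of a square feature map through a layer stack.

    Each layer is ("conv", kernel) with 'valid' padding and stride 1, or
    ("pool", window) with stride == window (output size is floored).
    Returns the list of sizes, starting with the input size.
    """
    sizes = [size]
    for kind, k in layers:
        if kind == "conv":
            size = size - k + 1
        elif kind == "pool":
            size = size // k
        else:
            raise ValueError(f"unknown layer type: {kind}")
        sizes.append(size)
    return sizes


# Hypothetical stack, for illustration only.
stack = [("conv", 5), ("pool", 2), ("conv", 3), ("pool", 2),
         ("conv", 3), ("pool", 2), ("conv", 2)]
print(output_size(40, stack))  # [40, 36, 18, 16, 8, 6, 3, 2]
```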
For the moment, we focus on a single head/task, i.e., estimating the position of the nose in the image. One way to do this is with custom estimators, which allow us to combine our own model implementation with all the functionalities of a standard Estimator object.
One drawback of custom estimators is that their code tends to be quite ‘verbose’, because we need to encapsulate the entire logic of the estimator (training, evaluation, and prediction) into a single function.
Roughly speaking, the model function receives a mode parameter, which tells us what kind of operation (e.g., training) we are expected to perform. In turn, the model function passes all the information back to the main Estimator object via an EstimatorSpec object.
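To illustrate the pattern independently of TensorFlow, here is a toy model function for a one-weight linear regressor that returns different information depending on the mode. Spec is a hypothetical stand-in for EstimatorSpec. The mode strings match the values of tf.estimator.ModeKeys, but nothing else here is TensorFlow API.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

# Same string values as tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


@dataclass
class Spec:
    """Hypothetical stand-in for tf.estimator.EstimatorSpec."""
    mode: str
    predictions: List[float]
    loss: Optional[float] = None
    train_op: Optional[Callable[[], None]] = None


def model_fn(features, labels, mode, params):
    """Toy one-weight regression model, y = w * x, with mode-dependent outputs."""
    w = params["w"]
    predictions = [w * x for x in features]
    if mode == PREDICT:
        return Spec(mode, predictions)  # no labels are needed, so no loss
    errors = [p - t for p, t in zip(predictions, labels)]
    loss = sum(e * e for e in errors) / len(errors)  # mean squared error
    if mode == EVAL:
        return Spec(mode, predictions, loss)

    def train_op(lr=0.1):
        # One gradient-descent step on the MSE with respect to w.
        grad = sum(2 * e * x for e, x in zip(errors, features)) / len(errors)
        params["w"] -= lr * grad

    return Spec(mode, predictions, loss, train_op)


params = {"w": 0.0}
spec = model_fn([1.0, 2.0], [2.0, 4.0], TRAIN, params)
spec.train_op()
print(params["w"])  # 1.0
```

Prediction needs no labels, evaluation adds the loss, and training additionally provides an operation that updates the parameters. Note how little of this code is specific to the model itself.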
Besides making the code harder to read, most of the model function tends to be ‘boilerplate’ code that depends only on the type of task we are facing, e.g., using the mean squared error for a regression problem. The Head API is an experimental feature designed to simplify writing code in situations like this, and it is our next topic.
Step 3a — Rewriting our custom estimator with the Head API
The idea of the Head API is that the main prediction component (our model function above) can be generated automatically once a few key items are specified: the feature extraction part, the loss, and the optimization algorithm.
In a sense, this is similar to the high-level interface of Keras, but it still leaves enough flexibility to define a series of more interesting heads, as we will see shortly.
For the moment, let us rewrite the previous code, this time using a “regression head”.
For all intents and purposes, the two models are equivalent, but the latter is more readable and less error-prone, as most of the estimator-specific logic is now encapsulated inside the head. We can train either model using the ‘train’ method of the estimator, and then start getting our predictions.
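This division of labour can be mimicked in plain Python. A head object owns the loss and the prediction format, and the model function only computes logits before delegating to it. The class below is a hypothetical toy, not tf.contrib.estimator.regression_head. A real head also builds the training operation from the optimizer (in TF, through its create_estimator_spec method), which this sketch omits.

```python
class ToyRegressionHead:
    """Hypothetical stand-in for a regression head (not the TF class).

    It owns everything task-specific: the loss and the prediction format,
    so the model function no longer has to deal with them.
    """

    def __init__(self, name):
        self.name = name

    def loss(self, logits, labels):
        # Mean squared error over the outputs of one example.
        return sum((p - t) ** 2 for p, t in zip(logits, labels)) / len(logits)

    def create_spec(self, mode, logits, labels=None):
        spec = {"mode": mode, "predictions": list(logits)}
        if mode != "infer":  # training and evaluation both need a loss
            spec["loss"] = self.loss(logits, labels)
        return spec


def model_fn(features, labels, mode, head, extractor):
    """The model function only computes logits and delegates to the head."""
    logits = extractor(features)
    return head.create_spec(mode, logits, labels)


# Hypothetical extractor that returns its input as a 2-D "nose" estimate.
head = ToyRegressionHead("nose")
spec = model_fn([0.4, 0.6], [0.5, 0.5], "eval", head, extractor=lambda f: f)
print(spec["loss"])
```

Swapping the task (say, to classification) now means swapping the head, while model_fn and the extractor stay untouched.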
Please do not confuse the Head API (which lives in tf.contrib.estimator) with tf.contrib.learn.head, which is deprecated.
Step 3b — Multi-task learning with the multihead
We finally get to the more interesting part of this tutorial: the MTL logic. Remember that, in the simplest case, doing MTL is equivalent to having ‘multiple heads’ on top of the same feature extraction part, as shown schematically here:
Mathematically, we can optimize all tasks jointly by minimizing the sum of the task-specific losses. For example, if L1 is the loss of the regression part (the mean squared error over each landmark) and L2 is the loss of the classification part (the different tags), we can minimize L = L1 + L2 through gradient descent.
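The notation below is introduced here for illustration: N training images, shared feature-extractor parameters θ_s, head parameters θ_1 and θ_2, landmark targets y_i with predictions ŷ_i, and predicted probability p̂_i(c_i) of the true tag c_i. With it, one way to write this objective and its gradients explicitly is:

```latex
\begin{aligned}
\mathcal{L}(\theta_s, \theta_1, \theta_2)
  &= \mathcal{L}_1(\theta_s, \theta_1) + \mathcal{L}_2(\theta_s, \theta_2), \\
\mathcal{L}_1 &= \frac{1}{N} \sum_{i=1}^{N} \bigl\lVert \hat{\mathbf{y}}_i - \mathbf{y}_i \bigr\rVert^2,
\qquad
\mathcal{L}_2 = -\frac{1}{N} \sum_{i=1}^{N} \log \hat{p}_i(c_i), \\
\nabla_{\theta_s} \mathcal{L} &= \nabla_{\theta_s} \mathcal{L}_1 + \nabla_{\theta_s} \mathcal{L}_2,
\qquad
\nabla_{\theta_1} \mathcal{L} = \nabla_{\theta_1} \mathcal{L}_1,
\qquad
\nabla_{\theta_2} \mathcal{L} = \nabla_{\theta_2} \mathcal{L}_2 .
\end{aligned}
```

The gradient line is the key point. The shared extractor receives gradient signal from both tasks, while each head is updated only by its own loss. The same reasoning holds for a weighted combination L = w_1 L_1 + w_2 L_2.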
After this (quite lengthy) introduction, it will not surprise you that the Head API has a specific head for this situation, called the multi-head. As per our previous description, it linearly combines the losses of separate heads. At this point, I will let the code speak for itself.
For simplicity, I am only considering two tasks: the prediction of the nose position, and of the face ‘pose’ (left profile, left, frontal, right, right profile). We just need to define two separate heads (a regression one and a classification one) and combine them with the multi_head object. Adding more heads is now only a matter of a few lines of code!
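The mechanics of the multi-head can be sketched in plain Python. Each head computes its own loss from its slice of the outputs, and the multi-head returns their weighted sum (uniform weights by default) together with a dict of per-head predictions. All classes here are hypothetical toys that only mirror the idea; they are not the tf.contrib.estimator classes. The logits and labels are made-up numbers for a single image.

```python
import math


class ToyRegressionHead:
    """Hypothetical regression head: MSE loss, predictions are the raw outputs."""

    def __init__(self, name):
        self.name = name

    def loss(self, logits, label):
        return sum((p - t) ** 2 for p, t in zip(logits, label)) / len(logits)

    def predict(self, logits):
        return list(logits)


class ToyClassificationHead:
    """Hypothetical classification head: softmax cross-entropy, argmax prediction."""

    def __init__(self, name, n_classes):
        self.name = name
        self.n_classes = n_classes

    def loss(self, logits, label):
        # -log softmax(logits)[label], computed stably via log-sum-exp.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        return log_z - logits[label]

    def predict(self, logits):
        return max(range(self.n_classes), key=lambda c: logits[c])


class ToyMultiHead:
    """Hypothetical multi-head: weighted sum of per-head losses."""

    def __init__(self, heads, weights=None):
        self.heads = heads
        self.weights = [1.0] * len(heads) if weights is None else list(weights)

    def loss(self, logits, labels):
        # logits and labels are dicts keyed by head name.
        return sum(w * h.loss(logits[h.name], labels[h.name])
                   for h, w in zip(self.heads, self.weights))

    def predict(self, logits):
        return {h.name: h.predict(logits[h.name]) for h in self.heads}


multi = ToyMultiHead([ToyRegressionHead("nose"), ToyClassificationHead("pose", 5)])
logits = {"nose": [0.4, 0.6], "pose": [0.0, 0.0, 2.0, 0.0, 0.0]}
labels = {"nose": [0.5, 0.5], "pose": 2}
print(multi.loss(logits, labels))
print(multi.predict(logits))  # {'nose': [0.4, 0.6], 'pose': 2}
```

Adding a third task in this sketch means writing one more head and appending it to the list, which mirrors the point made in the text.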
The input function also needs a slight modification, which we omit here for brevity: you can find it in the Colab notebook.
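The actual change is in the notebook. One likely reason for it, stated here as an assumption, is that a multi-head consumes labels as a dict keyed by head name. The hypothetical helper below packs one example's labels in that form. It further assumes heads named "nose" and "pose", with pose tags stored as 1 to 5.

```python
def to_multihead_labels(nose_xy, pose, n_pose_classes=5, pose_offset=1):
    """Hypothetical helper: pack one example's labels into a per-head dict.

    Assumptions: heads are named "nose" and "pose", and pose tags are
    stored 1-based (1..5), so subtracting pose_offset yields the 0-based
    class index that a classification head expects.
    """
    x, y = nose_xy
    pose_index = int(pose) - pose_offset
    if not 0 <= pose_index < n_pose_classes:
        raise ValueError(f"pose tag {pose!r} out of range")
    return {"nose": [float(x), float(y)], "pose": pose_index}


print(to_multihead_labels((0.5, 0.25), 3))  # {'nose': [0.5, 0.25], 'pose': 2}
```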
At this point, the estimator can be trained with the standard methods, and we can get both predictions simultaneously.
Concluding…
I hope you enjoyed this short introduction to the MTL problem: if you are interested, I highly recommend this informative post by Sebastian Ruder to learn more about the field. More generally, talking about MTL was the perfect excuse to introduce some interesting concepts from the TF framework, most notably the Head API. Don’t forget to play around with the full notebook on Google Colab!
This article appeared originally in Italian on the blog of the Italian Association for Machine Learning: https://iaml.it/blog/multitask-learning-tensorflow.