When not blogging about Data Science, I work as a Senior Data Scientist at an e-commerce company called ShopRunner. Over the past year, our team has been building out large multi-task deep learning ensembles that use both images and text to predict relevant fashion attributes and characteristics of products within our catalog. Recently, our team open-sourced the main training pipelines and framework that we have been building internally to train our multi-task learners in a package called Tonks. Tonks is available to install from PyPI, with the source code available on GitHub here.
As we went through the process of discussing open-sourcing Tonks, I realized a potential early use case could be upgrading a side project of mine in which I am building a series of reinforcement learning (RL) agents to play the mobile phone game Fate Grand Order (FGO), a project I've nicknamed Project Pendragon.
Project Pendragon has two main parts: a feature extraction pipeline that sends inputs to the RL agents, and the RL agents themselves, which make decisions and send commands back to play the game. While my RL agents that play FGO have been repeatedly upgraded, my feature extraction pipelines are basically where they were a year ago.
This post will cover how I replaced my original feature extraction pipeline, which used three large convolutional neural networks (two ResNet-50s and a ResNet-34), with a single Tonks multi-task ResNet-50 trained with multiple datasets.

Tonks
Our ShopRunner team has built out our PyTorch-based Tonks library to help us build large multi-task network ensembles using both images and text. In most of those use cases, we care about being able to return relevant attributes about products based on the provided images, descriptions, titles, and other information.
When you build multi-task learners, you are usually attacking a problem with the mindset that features learned on one task might be beneficial to another. For us in the e-commerce space, two tasks like this might be dress length and sleeve length, where two single-task models would likely both learn to look for lines and lengths while not worrying about color, pattern, and background. When tasks meet these criteria, we can combine them into multi-task networks, and in Tonks we do this by using a core model, such as a ResNet for images or a BERT model for text, and connecting the outputs of that core model to individual task heads.
When your tasks are all within the same domain, multi-task learning is also useful because it lets you build and maintain a single model instead of multiple single-task learners. In my current pipeline, I train and use three large CNNs, each with its own tailored dataset that I built previously. Tonks is built to handle situations where you train using tailored datasets for individual tasks, so here it helps me cut down on the number of models I need to maintain while letting me train with my previously built datasets.
Something to think about when building multi-task models is that when tasks are not within similar domains, multi-task models can suffer from destructive interference, where conflicting signals from different tasks pull the model in different directions. A discussion of how to handle this destructive interference could be a good follow-up post or talk, but is beyond the scope of this post. For my FGO use case, my gut feeling was that the problem was fairly doable since all the tasks use FGO screenshots with similar-ish text, color palettes, etc., so I would likely not have issues with destructive interference.
For more details about Tonks see our launch post.

Project Pendragon
In the Fall of 2018, I started making some basic bots to play FGO, but before I go into the details of the bots, I'll give a quick rundown of what FGO even is. Fate Grand Order is a turn-based mobile phone game where you pick between 3 and 6 characters with various stats and abilities. You then use that team to fight through waves of enemies until either all the enemies or your own team are defeated.
My main motivation for building these bots is that as part of regular events in FGO, players are often required to farm (repeatedly play) levels dozens, if not hundreds, of times. For a recent event, I did 100 farming runs over a week-long period where each level takes 3–5 minutes to play (so 5–8 hours of gameplay in total). So with that in mind, I thought building a bot to do this repetitive task for me would be a great side project! This once-simple "side project" has since turned into an on-again-off-again year-long rabbit hole with a lot of interesting additions like multiple reinforcement learning agents and various custom game environments.
Despite making many upgrades to the codebase around the bots and how the decisions get made, I really haven't touched the feature extraction pipelines that get information to my RL game bots.
Feature Extractor One: Whose turn is it now?
FGO is a turn-based game, and in order for the bots to play the game, I needed to be able to detect when it was actually their turn. The way that I decided to detect when the bots’ turn was starting was to look for the "attack" button that appears on the main combat screen before command cards are picked.
Below is a sample main combat screen where the attack button is highlighted in the bottom right corner.

I structured this network to have 2 classes, "attack" and "not attack," meaning the network is trained to detect whether or not the attack button is present in that portion of the game screen. If it is, then it is the bots' turn, and you can do useful things like take actions/use skills on the current screen or bring up the command card screen, where you pick which of the 5 cards to play.
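As a rough illustration of that turn check (the button coordinates, model, and class ordering here are all placeholder assumptions, not the project's actual values), the idea boils down to cropping the attack-button region out of a screenshot and running it through the 2-class CNN:

```python
import torch
from PIL import Image
from torchvision import transforms

# Hypothetical pixel box for the attack-button region; the real values
# depend on the phone/emulator resolution being captured.
ATTACK_BUTTON_BOX = (1050, 550, 1250, 700)  # (left, upper, right, lower)

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def is_bot_turn(screenshot_path, attack_model, device="cpu"):
    """Return True if the 2-class CNN thinks the attack button is present."""
    screen = Image.open(screenshot_path).convert("RGB")
    crop = screen.crop(ATTACK_BUTTON_BOX)
    batch = preprocess(crop).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = attack_model(batch)
    # Assumes class 0 = "attack" and class 1 = "not attack".
    return logits.argmax(dim=1).item() == 0
```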
Picking command cards is the main combat mechanic of FGO. So once I was able to detect whose turn it was, it was time to build a classifier that would help me identify what command cards had been dealt that turn. My first bots used this information to play algorithmically while later bots were trained via reinforcement learning to choose which cards to play, but all of them needed to be able to identify which cards had been dealt as a key part of the feature extraction process.
Feature Extractor Two: What command cards were dealt?
The main combat mechanic in FGO is picking "command cards" on your turn. There are 3 types of cards: "Arts", "Buster", and "Quick," and each card type does slightly different things. In each turn, 5 cards are dealt and the player has to pick 3 of them to play that turn. Below is a sample of the 5 cards presented and 3 being picked.

While I could have built a detector to find the card locations on the screen, I found an easier solution. These screens are relatively consistent and the cards are placed in the same spots, so what I opted to do is hard-code the locations of the five command cards and crop them out of screenshots of the command card screen (see below for an example). Then, I passed each of the five cards through a PyTorch-trained CNN to determine the card type.
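A sketch of what that looks like, with hypothetical slot coordinates and a generic PyTorch card-type model standing in for the real ones:

```python
import torch
from PIL import Image

# Hypothetical pixel boxes for the five command-card slots; the real
# coordinates are hard-coded to match the game's screen layout.
CARD_SLOTS = [
    (40, 560, 280, 900),
    (300, 560, 540, 900),
    (560, 560, 800, 900),
    (820, 560, 1060, 900),
    (1080, 560, 1320, 900),
]
CARD_TYPES = ["arts", "buster", "quick"]

def classify_command_cards(screenshot_path, card_model, transform, device="cpu"):
    """Crop the five card slots and return the predicted type for each."""
    screen = Image.open(screenshot_path).convert("RGB")
    crops = [transform(screen.crop(box)) for box in CARD_SLOTS]
    batch = torch.stack(crops).to(device)  # shape: (5, 3, H, W)
    with torch.no_grad():
        preds = card_model(batch).argmax(dim=1)
    return [CARD_TYPES[i] for i in preds.tolist()]
```

The same preprocessing transform used for the attack-button check can be reused here, since all of the crops come from the same style of screenshot.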

These card type classifications can then be fed into various RL agents and used to make decisions about which cards to play on that given turn. For a long time, these two networks (attack button and card type detector) were the core feature extractors for my first FGO reinforcement bot, nicknamed Pendragon Alter. They let me make a bot that could do the main combat in FGO and also play through a full game in an automated way. From this point onward, I really just had to think about what other information I would need to play the game and how to extract that from the game.
Feature Extractor Three: What wave of enemies are we on?
The final network I added to my framework is a wave counter that I use as an input to several different versions of my bots. The reason I added this is that FGO levels are almost always structured as having between 1 and 3 rounds of enemies that you have to fight through, and the actions an agent may want to take can depend on the round. For instance, the first wave of enemies may be relatively weak, but the third wave can be quite strong, so saving abilities for the third wave is often a good tactic.
I highlighted the round counter in red near the top of the below screenshot.

While I could have used some sort of optical character recognition to get the number, all I really cared about was telling whether it is round 1, 2, or 3, so I trained a CNN with three classes mapping to those numbers, and every time the attack button is detected, I check which round it currently is.
These three networks achieve the basic feature extraction that I need, but needing to run three large CNNs simultaneously adds quite a bit of GPU or CPU computational overhead and additional storage (two ResNet-50s are ~220MB and one ResNet-34 is ~84MB).
Tonks Training Pipeline
Our Tonks pipeline follows the general framework set up by fastai, where the pipeline is organized into dataloaders, models, and learners. We end up managing the multiple tasks and multiple datasets using a variety of dictionaries, which help with bookkeeping.
The following sections show code snippets that appear in the linked training notebook and discuss what is happening in them.
The notebook I used for training is located here.
Tonks Dataset and Dataloaders
Our custom Tonks dataset (line 26), FGOImageDataset, follows a fairly standard PyTorch dataset layout: we provide a way to index the appropriate values as part of the data generator (lines 58–59), apply transformations (line 61), and return the image data and label (line 65).
Lines 11–24 apply different transforms depending on whether we are looking at the training or validation set. I like keeping the transforms in this format, but that's just a personal preference.

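As a minimal sketch of that structure (the transforms and exact signature here are my own simplified assumptions rather than the notebook's code):

```python
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Hypothetical train/validation transforms: random crops plus ImageNet
# normalization for training, normalization only for validation.
IMAGENET_NORM = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
TRANSFORM_DICT = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.ToTensor(),
                                 IMAGENET_NORM]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               IMAGENET_NORM]),
}

class FGOImageDataset(Dataset):
    """Index image paths (x) and integer labels (y), apply the transform for
    the requested split, and return an (image tensor, label) pair."""

    def __init__(self, x, y, transform_dict=TRANSFORM_DICT, split="train"):
        self.x = x
        self.y = y
        self.transform = transform_dict[split]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        image = Image.open(self.x[index]).convert("RGB")
        return self.transform(image), np.int64(self.y[index])
```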
Once we have created our custom Tonks dataset class, we can start creating our training and validation datasets for our three tasks. This part of the process is also fairly similar to a standard PyTorch training pipeline, where you have to place your training and validation splits into datasets and then eventually into dataloaders. The only difference here is that we have three datasets instead of the normal one.

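A simplified sketch of that step, with toy placeholder paths standing in for the real per-task splits:

```python
from torch.utils.data import DataLoader

# Toy placeholder splits: each task maps to (image paths, integer labels)
# for train and validation. In practice these come from the three datasets
# originally built for the single-task models.
splits = {
    "attack":       {"train": (["img/attack_0.png"], [0]),
                     "val":   (["img/attack_1.png"], [1])},
    "card_type":    {"train": (["img/card_0.png"], [2]),
                     "val":   (["img/card_1.png"], [0])},
    "round_number": {"train": (["img/round_0.png"], [1]),
                     "val":   (["img/round_1.png"], [2])},
}

# One FGOImageDataset and one plain PyTorch DataLoader per task, keyed by
# task name so we can track which task each batch belongs to later on.
train_dataloaders_dict = {
    task: DataLoader(FGOImageDataset(*s["train"], split="train"),
                     batch_size=64, shuffle=True)
    for task, s in splits.items()
}
valid_dataloaders_dict = {
    task: DataLoader(FGOImageDataset(*s["val"], split="val"),
                     batch_size=64, shuffle=False)
    for task, s in splits.items()
}
```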
In the above section of code, we are creating training and validation datasets for each task using the custom dataset class I showed previously. This involves specifying the x and y inputs (file paths to images and their labels) as well as which transforms we would like to apply. In this pipeline, I just have randomized crops with ImageNet normalization in the training set and only normalization in the validation sets. Once the datasets are created, we make a dictionary of base PyTorch dataloaders where the keys are the names of the tasks and the values are the dataloaders associated with those tasks. The idea here is that we can keep track of which dataset we should generate batches from as part of our multi-dataset training pipeline.
The below snippet shows how we place our two dictionaries of PyTorch dataloaders into two Tonks MultiDatasetLoaders. These Tonks dataloaders are what we use to integrate with our multi-task multi-dataset training pipeline.

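A sketch of that step; the import path and keyword arguments here are my best recollection of the Tonks API rather than verified code, so treat the library docs as authoritative:

```python
# Assumed import path and keyword arguments; check the Tonks docs for the
# exact API before copying this.
from tonks import MultiDatasetLoader

TrainLoader = MultiDatasetLoader(loader_dict=train_dataloaders_dict)
ValidLoader = MultiDatasetLoader(loader_dict=valid_dataloaders_dict, shuffle=False)
```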
Example Tonks Network Architecture
The next major piece of the pipeline is the model. While we provide some sample image and text model architectures, it often makes sense to customize the architecture to your needs. For me, I just made a simple ResNet-50-based architecture and connected it to the individual task layers. In Tonks, we handle this part with two PyTorch ModuleDicts called pretrained_classifiers and new_classifiers. The idea here is that the first time you train a network, tasks go into the new_classifiers dictionary, and when the trained network is saved they get moved to the pretrained_classifiers dictionary. Then, on subsequent runs, we can load pretrained task heads into the pretrained_classifiers dictionary. This helps us keep track of where to apply which learning rates, since you might want different learning rates depending on whether a task head has previously been fine-tuned.

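Below is a simplified, self-contained sketch of this kind of architecture in plain PyTorch (not the actual Tonks class), just to make the core-model-plus-ModuleDict-heads idea concrete:

```python
import torch.nn as nn
from torchvision import models

class FGOMultiTaskResnet(nn.Module):
    """Sketch of a multi-task image model in the spirit of the post: a shared
    ResNet-50 core feeding one small classifier head per task."""

    def __init__(self, new_task_dict, pretrained_task_dict=None):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        # Keep everything except the final fully connected ImageNet layer.
        self.core = nn.Sequential(*list(resnet.children())[:-1])

        # Heads for brand-new tasks (randomly initialized).
        self.new_classifiers = nn.ModuleDict({
            task: nn.Linear(resnet.fc.in_features, n_classes)
            for task, n_classes in new_task_dict.items()
        })
        # Heads whose weights were trained in a previous run.
        self.pretrained_classifiers = nn.ModuleDict({
            task: nn.Linear(resnet.fc.in_features, n_classes)
            for task, n_classes in (pretrained_task_dict or {}).items()
        })

    def forward(self, x):
        features = self.core(x).flatten(1)
        logits = {task: head(features)
                  for task, head in self.pretrained_classifiers.items()}
        logits.update({task: head(features)
                       for task, head in self.new_classifiers.items()})
        return logits
```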
Loading Tonks Model
The main input when we load an instance of a Tonks model is the task dictionary. There are two potential ones: the first is a new_task_dict, which is the one you should use the first time you train a model, while the second is a pretrained_task_dict, which is where you place tasks that already have existing Tonks pre-trained weights. For us at ShopRunner, this is useful because we can now add new tasks to existing models with ease.
For this FGO example, I have three tasks that are all new. I place the task names and the number of categories in each task into a dictionary called new_task_dict and feed that into the model class when I initialize it. This is what tells the Tonks model to create three task heads with the right number of output nodes in each.

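For example, using the sketch architecture above (the task names and class counts mirror the three feature extractors described earlier):

```python
# Three new tasks and the number of categories in each.
new_task_dict = {
    "attack": 2,        # attack button present / not present
    "card_type": 3,     # arts / buster / quick
    "round_number": 3,  # rounds 1-3
}

# Instantiating the sketch model; the notebook uses its own ResNet-50-based
# Tonks model, but the call pattern is the same idea.
model = FGOMultiTaskResnet(new_task_dict=new_task_dict)
```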
Training
Once we have our model initialized, all we need to do before we kick off our training run is define a loss function for the various tasks, specify an optimizer, assign learning rates, create our learner, and call the fit function.
For this pipeline, each of the tasks is a multi-class problem, so we use a cross entropy loss. When we assign learning rates, we can assign different rates to different sections of the model. For the main ResNet section, we assign a low 1e-4 learning rate, but for the new sections we assign a more aggressive 1e-2 learning rate. The main idea behind this is that we don't want to drastically change the ImageNet weights in the ResNet-50 core model, but since the new classifier layers are randomly initialized, we can let them be adjusted more aggressively. Then we define a scheduler to decrease the learning rates every 2 epochs.
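In plain PyTorch, that setup might look like the following, using the sketch model's parameter groups (the optimizer choice and the scheduler's gamma value are placeholder assumptions):

```python
import torch.nn as nn
from torch import optim

# One multi-class loss shared by all three tasks.
loss_function = nn.CrossEntropyLoss()

# Conservative learning rate for the pretrained ResNet core, a more
# aggressive one for the randomly initialized task heads.
optimizer = optim.Adam([
    {"params": model.core.parameters(), "lr": 1e-4},
    {"params": model.new_classifiers.parameters(), "lr": 1e-2},
])

# Decay both learning rates every 2 epochs (gamma value is arbitrary here).
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)
```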
Once that is done, we can define our learner using the Tonks MultiTaskLearner class. This class contains all the functionality we need to train our models and takes in the model architecture we loaded previously, the training and validation Tonks dataloaders, and the task dictionary, which maps out all our tasks and is used to retrieve batches from our dataloaders.
Finally, we can call fit() on our learner. For details on the different arguments, you can check our Read the Docs page.

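A hedged sketch of those last two steps; the argument names follow the description above rather than a verified signature, so check the Tonks docs and the linked notebook before copying this:

```python
import torch

# Assumed import path and argument names.
from tonks import MultiTaskLearner

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

learn = MultiTaskLearner(
    model=model,
    train_dataloader=TrainLoader,
    val_dataloader=ValidLoader,
    task_dict=new_task_dict,
)

learn.fit(
    num_epochs=10,
    loss_function=loss_function,
    scheduler=exp_lr_scheduler,
    step_scheduler_on_batch=False,
    optimizer=optimizer,
    device=device,
)
```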
Results
Once I finished training the three-task Tonks network, I just had to replace the three networks in my Project Pendragon repos with the new multi-task network. Since Tonks is built on PyTorch, in order to use a model, all you need to keep track of is the model architecture and the weights file, so you don’t necessarily need to install Tonks and all of its dependencies to use the trained models in new projects.
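For instance, with the sketch architecture from earlier, the deployment-side loading is just standard PyTorch (the weights file name here is illustrative):

```python
import torch

# Recreate the architecture and load the trained weights in the deployment
# repo; no Tonks dependency is needed at inference time for a setup like the
# sketch above.
model = FGOMultiTaskResnet(
    new_task_dict={"attack": 2, "card_type": 3, "round_number": 3}
)
model.load_state_dict(torch.load("fgo_multitask_resnet50.pth", map_location="cpu"))
model.eval()
```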
The Tonks model performs strongly across all tasks, and so far I have not had any issues deploying it in my FGO Project Pendragon game interface. I have been using the Tonks model in my other recent developments as I continue to improve my RL agents. The most recent installment is getting the agents to play using coordinated strategies.

So while this example is a cute one based on a mobile phone game, the reasons to use a framework like Tonks are the same here as they are for industrial-scale problems. It removed my need to maintain multiple single-task networks, and I was able to train quickly and easily because I could reuse three existing datasets I had previously built.
Tonks is a library that our ShopRunner Data Science team has been using to build industrial-scale multi-task deep learning ensembles trained on both images and text with multiple datasets. For us, it has made it relatively straightforward to create multi-task models to meet new and varying needs. Since building multi-task learners is a very real-world need, but one that is not well supported by existing tooling, we open-sourced our work to help give back to the data science community.