Distribute your PyTorch model in less than 20 lines of code

A guide to making parallelization less painful

Photo by Nana Dua on Unsplash

When you approach Deep Learning for the first time, you learn that you can speed up your training by moving the model and data to the GPU. That’s definitely a significant improvement compared to training on the CPU, as you can now train your models and see results in way less time.

Great, but suppose your model becomes too big (e.g. a transformer) for a single GPU to handle anything but a small batch size, say 8. Or maybe you have so much data that an epoch takes too many iterations, and you want to increase your batch size.

So you buy another GPU to gain more capacity to store and process the data. Unfortunately, you can’t just run your script and expect it to distribute itself across the GPUs: you have to set up your environment and adapt your code.

In this guide, I will show how to distribute a minimal training pipeline across more than one GPU. Note that I will only cover data parallelism, which means splitting the dataset across different processes, each of which runs a copy of the whole model. This type of parallelism allows for computing on larger batches. Model parallelism, where each sub-process runs a different part of the model, is not covered in this guide.

In PyTorch, there are two ways to enable data parallelism:

  • DataParallel (DP);
  • DistributedDataParallel (DDP).

DataParallel

Let’s start with DataParallel, even though I won’t use it in the example.

This module works only on a single machine with multiple GPUs but has some caveats that impair its usefulness:

  • The model is replicated on all GPUs at each forward step, introducing significant overhead and slowing performance;
  • It employs multithreading, which suffers from shared resource contention issues;
  • It only works on a single node.

Even when working on a single node, it is best to use DDP.

DistributedDataParallel

There are several reasons why DDP should be the preferred choice, despite the increase in setup complexity.

  • It uses multiprocessing, so each GPU gets its own dedicated process and does not have to share access to resources with the others, avoiding the contention between threads that DataParallel suffers from;
  • It also works on multiple nodes, which allows the code to be ported from a single-node machine to a multi-node cluster with little to no changes;
  • DDP also supports different multiprocessing backends (e.g. nccl, gloo, mpi) and other fancy stuff we won’t cover in this practical guide.

Implementation

Now that we understand why we should use DDP, let’s implement it in our training pipeline.

For this guide, let’s suppose you already have a working script of a minimal CIFAR10 classifier, and that you want to parallelize it. I will show the comparison between the original code and its distributed version.

We need to implement the following steps:

  1. Distributed process initialization
  2. Distributed data loading
  3. Model wrapping in DDP
  4. Training script
  5. Collect metrics
  6. Launch the script in distributed mode

1. Distributed process initialization

Before we can start working with distributed training, we have to tell the processes we will assign to each GPU that they belong to the same group.

Remember: when you launch a script in distributed mode, you are executing a copy of your code on each process.

First, we need to define the address and the port number (use any free port) of the master node of the distributed cluster where our processes will be running (as we are assuming a single node with multiple GPUs, we set localhost).

Then we use the init_process_group function from the torch.distributed package to assign the processes to the same group. Here we set the backend (we will be using nccl) and the total number of processes, world_size (with one process per GPU, this is the number of nodes times the number of GPUs per node; in our case, simply the number of GPUs).

Multiprocessing is prone to race conditions. The barrier function helps here: it synchronizes the processes by blocking each one until all the others have reached the same function call. Imagine a race where the processes are the participants: the barrier function is the starting pistol. The pistol does not fire until every participant is in place (i.e. until all processes have called the barrier function).
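Here is a minimal sketch of this initialization (the port number is a placeholder; RANK and WORLD_SIZE are assumed to be set by the launcher we will use in step 6):

    import os
    import torch.distributed as dist

    # Master node address and a free port (localhost: single machine, multiple GPUs)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Join the process group; with the default env:// init method, RANK and
    # WORLD_SIZE are read from the environment variables set by the launcher
    dist.init_process_group(backend="nccl")

    # Block each process here until every process has reached this call
    dist.barrier()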

SLOC: +4

Photo by Braden Collum on Unsplash

2. Distributed data loading

Data loading is the first step we take when training a model.

This step does not require much: we only need to set up a distributed sampler for the training data. Its role is to create N disjoint subsets of batch indices, where N is the number of processes, and assign each subset to a process.

Warning: the batch_size value that we pass to the DataLoader specifies the batch size on each GPU. When using DistributedSampler, the effective batch size is the size specified in the DataLoader multiplied by the number of processes, as each process runs its own DataLoader instance.

Before:
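A rough sketch of the serial data loading (the dataset and transform setup are placeholders, not the repository’s exact code):

    import torchvision
    import torchvision.transforms as T
    from torch.utils.data import DataLoader

    train_set = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=T.ToTensor()
    )
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)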

After:
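And a sketch of the distributed version, where the only new line is the sampler:

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    # The sampler hands a disjoint subset of indices to each process,
    # so shuffling is delegated to it instead of the DataLoader
    train_sampler = DistributedSampler(train_set)
    train_loader = DataLoader(train_set, batch_size=32, sampler=train_sampler)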

SLOC: +1

3. Model wrapping in DDP

After distributing the data, we need to tell our model that it has to handle distributed gradient aggregation and parameter updates, so we wrap it in the DistributedDataParallel module, specifying in the device_ids parameter the index of the GPU the model should live on (an integer between 0 and N_GPUS-1). For instance, the process running on the fourth GPU will use 3 as device_ids. For this reason, I let each process store its own GPU index in the args.gpu variable (we will see this in detail shortly).

Before:
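In the serial version, the model is simply moved to the GPU (the resnet18 here is a stand-in for the minimal CIFAR10 classifier):

    import torchvision

    net = torchvision.models.resnet18(num_classes=10)  # placeholder classifier
    net = net.cuda()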

After:
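In the distributed version, the model is first moved to the GPU assigned to this process and then wrapped in DDP (args.gpu is the local GPU index described above; how it gets filled is shown a few lines below):

    from torch.nn.parallel import DistributedDataParallel as DDP

    net = net.to(args.gpu)                 # move the model to this process's own GPU first...
    net = DDP(net, device_ids=[args.gpu])  # ...then wrap it to synchronize gradients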

SLOC: +1

WARNING: If you accidentally send your model to cuda:0 before the DDP wrapping (maybe due to old residual code), the script freezes with no debug error, and you will need to put its existence to an end with CTRL-C. So, be careful to move the model to the correct device by specifying the integer associated with the device index (e.g. if you have 8 GPUs, call net.to(4) to send the model to the 5th GPU).

The best thing to do is to dynamically set the device ID, taking it from a local environment variable, typically LOCAL_RANK, which assigns a different integer to each GPU within the node. I usually set two global variables: args.gpu and args.distributed. The former stores the local rank of the GPU, while the latter is a boolean indicating whether the script is running in a distributed environment.

The initialization code will be extended to something like this:
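(A sketch under the assumptions above; variable names follow the article, but the details may differ from the repository code.)

    import os
    import torch
    import torch.distributed as dist

    def init_distributed_mode(args):
        # torchrun / torch.distributed.launch set these environment variables
        if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
            args.distributed = True
            args.rank = int(os.environ["RANK"])
            args.world_size = int(os.environ["WORLD_SIZE"])
            args.gpu = int(os.environ["LOCAL_RANK"])
        else:
            # Fall back to plain single-GPU training
            args.distributed = False
            args.gpu = 0
            return

        torch.cuda.set_device(args.gpu)
        dist.init_process_group(
            backend="nccl", rank=args.rank, world_size=args.world_size
        )
        dist.barrier()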

SLOC: +10

4. Training script

Now that everything is set up, we can start iterating over our batches and running our data through the pipeline. We need to set one last thing before we train the model: we have to tell the distributed sampler the current epoch before creating the DataLoader iterator.

The sampler shuffles batches between epochs using a random seed and the epoch number. If we don’t set the current epoch number on the sampler, the batches will follow the same ordering at each epoch, introducing a bias that could harm the training.

Before:
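A sketch of the serial training loop (num_epochs, the loss criterion, and the optimizer are assumed to be defined earlier in the script):

    for epoch in range(num_epochs):
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()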

After:
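And the distributed version, where the only new line tells the sampler which epoch we are in:

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # reshuffle differently at every epoch
        for images, labels in train_loader:
            images, labels = images.to(args.gpu), labels.to(args.gpu)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()             # DDP averages gradients across processes here
            optimizer.step()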

SLOC: +1

5. Collect metrics

OK, so now we trained our model at the speed of light, but how should we compute metrics across multiple processes?

Torchmetrics is a library that simplifies the computation of several metrics and covers the most common ones in Deep Learning. Each metric class accepts a parameter that allows aggregating the metrics computed by each process.

We create an object responsible for computing and collecting the desired metric from all processes. When instantiating the object, we set the dist_sync_on_step flag to True, to specify that we want to synchronize metrics from all processes at each step (i.e. whenever we call the update method).

Before:
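A sketch of a manual, serial accuracy computation (the exact original code is an assumption; test_loader is assumed to be built like the training loader, on the test split):

    correct, total = 0, 0
    for images, labels in test_loader:
        images, labels = images.cuda(), labels.cuda()
        preds = net(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    accuracy = correct / total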

After:
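And a sketch of the torchmetrics version (newer torchmetrics releases also require task and num_classes arguments for Accuracy):

    import torchmetrics

    # dist_sync_on_step=True synchronizes the metric state across processes
    # every time update() is called
    metric = torchmetrics.Accuracy(dist_sync_on_step=True).to(args.gpu)

    for images, labels in test_loader:
        images, labels = images.to(args.gpu), labels.to(args.gpu)
        metric.update(net(images), labels)  # per-batch update, synced across processes

    accuracy = metric.compute()             # aggregate over all batches and processes
    metric.reset()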

SLOC: -2

6. Launch the script in distributed mode

One last thing: we also need to change the command we use to launch the script, defining the number of processes we want to spawn.

Before:
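A serial run is just a plain Python call (train.py is a placeholder name for the training script):

    python train.py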

After:
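In distributed mode, the launcher spawns one process per GPU and sets the environment variables used above; with recent PyTorch versions this is torchrun, while older ones use python -m torch.distributed.launch (the GPU count here is just an example):

    torchrun --nproc_per_node=4 train.py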

BONUS: Print once

If your code contains any print calls, when you run it in distributed mode you will see your console output become chaotic. The reason is, once again, that the same code runs in multiple processes, so each process calls print. To avoid this behavior, we can check the rank of each process and print only if the rank is equal to zero.
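A minimal sketch of such a guard (the print_once helper name is my own, not the article’s):

    import torch.distributed as dist

    def print_once(*args, **kwargs):
        # Only the process with global rank 0 actually prints
        if not dist.is_initialized() or dist.get_rank() == 0:
            print(*args, **kwargs)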

Conclusion

And 15 lines of code is all you need to distribute your pipeline (actually, I cheated a little bit by not counting the imports and function definition lines, but you’ll forgive me).

Even if it is this easy, problems can come in different forms, usually as silent failures or scripts that freeze without showing any debugging error.

Distributed processing is not an easy path overall, and many things happen under the hood of the few functions we added, but the speed-up of the training process can be worth the effort (even if there are exceptions).

Here you can find the link to the repository containing the serial and the parallel version of the script, as shown in the guide.

Recommended readings

[1] – Getting Started with Distributed Data Parallel – PyTorch Tutorials 1.11.0+cu102 documentation

[2] – Writing Distributed Applications with PyTorch – PyTorch Tutorials 1.11.0+cu102 documentation

[3] – Parallel and Distributed Deep Learning

