How to Accelerate your PyTorch GPU Training with XLA

The Power of PyTorch/XLA and how Amazon SageMaker Training Compiler Simplifies its use

Chaim Rand
Towards Data Science


Photo by Patrick Fore on Unsplash
by Author

In many of our past posts (e.g., here) we have emphasized the importance of managing the cost of training. We are in constant pursuit of ways to increase the runtime performance of our training through an iterative process of 1. profiling our workloads in order to identify performance bottlenecks and resource under-utilization, and 2. optimizing our workloads to remove bottlenecks and increase resource utilization.

In this post we explore the potential of optimizing our PyTorch training step by using XLA compilation. We will begin with a brief primer on XLA. We will follow with a demonstration of its use on AWS, and show how Amazon SageMaker Training Compiler simplifies its configuration. We will conclude with some of the (current) limitations that come along with XLA.

Disclaimers

Our intention in this post is to introduce the potential opportunity for accelerating training with PyTorch/XLA. Our intention is not to endorse its use, nor the use of any other training tools we will mention. As always, the best training configuration is highly dependent on the specific details of your model. The poem I opened with was included for entertainment purposes (I couldn’t resist). Despite its implications, there is no guarantee that PyTorch/XLA will speed up your training. While our focus (and demonstration) will be on the Amazon SageMaker training environment, most of what we will discuss pertains to other training environments as well.

JIT Compilation with XLA

First created for the purpose of accelerating TensorFlow workloads, XLA (which stands for Accelerated Linear Algebra) employs Just-in-Time (JIT) compilation techniques to optimize the computation graph for the underlying accelerator. More specifically, at runtime, the XLA compiler analyzes the full computation graph associated with the model, fuses together successive tensor operations, and outputs optimized machine code for the underlying accelerator. This is in contrast to the standard method, in which each tensor operation is optimized independently. There are several ways in which fusing graph components can boost performance, including:

  1. reducing the overall number of machine-level operations (FLOPs),
  2. reducing the number of computation kernels that the CPU needs to launch on the accelerator, and
  3. reducing the memory footprint of the computation graph. The freed-up space can be used to increase the training batch size, which could potentially lead to further performance gains.
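To make the last two points concrete, here is a toy, pure-Python sketch. No real accelerator is involved: the ToyDevice class, its kernel counter, and its list-based "buffers" are hypothetical stand-ins. It evaluates relu(x * w + b) element-wise, first as three separate "kernels" that each materialize an intermediate result, and then as a single fused kernel.

```python
from dataclasses import dataclass, field


@dataclass
class ToyDevice:
    """Counts simulated kernel launches and intermediate buffers (illustrative only)."""
    kernel_launches: int = 0
    intermediate_buffers: list = field(default_factory=list)

    def launch(self, fn, *inputs):
        # Each call represents one kernel dispatched by the host (CPU).
        self.kernel_launches += 1
        return [fn(*args) for args in zip(*inputs)]


def unfused(device, x, w, b):
    # Three separate element-wise kernels; each result is materialized in memory.
    t1 = device.launch(lambda a, c: a * c, x, w)
    device.intermediate_buffers.append(t1)
    t2 = device.launch(lambda a, c: a + c, t1, b)
    device.intermediate_buffers.append(t2)
    return device.launch(lambda a: max(a, 0.0), t2)


def fused(device, x, w, b):
    # One kernel computes relu(x * w + b) in a single pass, with no intermediates.
    return device.launch(lambda a, c, d: max(a * c + d, 0.0), x, w, b)


x, w, b = [1.0, -2.0, 3.0], [0.5, 0.5, 0.5], [-1.0, 0.0, 1.0]
d_unfused, d_fused = ToyDevice(), ToyDevice()
y1 = unfused(d_unfused, x, w, b)
y2 = fused(d_fused, x, w, b)
print(y1 == y2, d_unfused.kernel_launches, d_fused.kernel_launches)  # True 3 1
```

Both versions produce the same values, but the fused one needs a single launch and no intermediate buffers. Note that in this element-wise example the arithmetic itself is unchanged; the savings come from fewer launches and fewer materialized intermediates.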

Optimizing the full graph rather than individual operations offers the potential for reduced training time as demonstrated in the XLA overview page.

As a side note, another advantage of XLA is that it was specifically designed to make supporting new backend devices easy. See here and here for more details.

XLA for PyTorch

Unlike with TensorFlow, implementing XLA for PyTorch posed a unique challenge, as we explain in the following subsection.

Eager Execution vs. Graph Execution

Deep learning frameworks can be classified according to the mode in which they represent and execute machine learning models. Some frameworks, most notably TensorFlow (by default in v1 and via tf.function in v2), support graph mode, in which the model is first represented as a computation graph (in Python) and then processed and executed by a separate execution engine (e.g., in C++). Other frameworks, such as PyTorch, execute their models eagerly, i.e., the model is not stored in a data structure but rather each line of code is directly executed. Eager mode is considered by many to be easier to debug and to enable greater programming expressivity. This is often seen as the reason behind the recent increase in the popularity of the PyTorch framework (as of the time of this writing).
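The difference between the two modes can be sketched in a few lines of plain Python. This is a toy illustration, not how TensorFlow or PyTorch work internally; Node, build_graph, and execute are hypothetical names.

```python
import operator


def eager(x):
    # Eager mode: every line runs immediately and yields a concrete value.
    y = x * 2
    z = y + 3
    return z


class Node:
    """A node in a toy computation graph: an op applied to input nodes or constants."""
    def __init__(self, op, *inputs):
        self.op, self.inputs = op, inputs


def build_graph():
    # Graph mode: first *describe* the computation; nothing is computed yet.
    x = Node("placeholder")
    y = Node(operator.mul, x, 2)
    z = Node(operator.add, y, 3)
    return x, z


def execute(output, feeds):
    # A separate "execution engine" walks the graph and evaluates it.
    def evaluate(value):
        if not isinstance(value, Node):
            return value  # a constant
        if value.op == "placeholder":
            return feeds[value]
        return value.op(*(evaluate(v) for v in value.inputs))
    return evaluate(output)


x_node, z_node = build_graph()
print(eager(5), execute(z_node, {x_node: 5}))  # 13 13
```

Because the graph-mode version describes the whole computation before anything runs, the execution engine has the opportunity to analyze and optimize it as a unit, which is what a JIT compiler such as XLA relies on.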

Check out this post for more on the differences between the two execution modes.

Clearly, the JIT compilation mechanism of XLA aligns well with the execution methodology of graph mode given the separation between the graph definition and the graph execution. Supporting XLA together with eager execution mode requires more creativity. There are several approaches for solving this (see section 2 of this paper for a brief survey). As we will see, the method introduced and used in the PyTorch/XLA library is the Lazy Tensor system.

PyTorch/XLA

PyTorch/XLA is a Python library that was created with the primary intention of using XLA compilation to enable PyTorch-based training on Google Cloud TPUs (e.g., see here). The approach underlying PyTorch/XLA is the Lazy Tensor system. A Lazy Tensor is a custom tensor type, referred to in PyTorch/XLA as an XLA Tensor. Unlike with a standard PyTorch tensor, operations on it are not immediately (or “eagerly”) executed; rather, they are collected into sequences of operations that form an intermediate representation (IR) graph. When prompted (e.g., at the end of each training step), the IR graph is passed to the XLA compiler, where it undergoes optimization for the underlying accelerator. In this way, the Lazy Tensor system enables domain-specific compilation (DSC) while maintaining the look and feel of eager execution, the hallmark of the PyTorch training framework. While the use of Lazy Tensors allows for XLA compilation without changing our model definition, it is important to be aware of the subtle differences between standard tensors and XLA tensors. For more details on Lazy Tensors and how to understand their behavior, be sure to check out this informative post.
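The following toy sketch mimics the Lazy Tensor idea in pure Python: arithmetic on a LazyTensor only records an IR node, and nothing is computed until mark_step is called. The method name echoes torch_xla's xm.mark_step, but LazyTensor and ToyLazyRuntime are hypothetical classes that are unrelated to torch_xla's internals. To mimic the "compile once" behavior, the runtime caches compiled graphs by their structure, with concrete values passed in as parameters.

```python
class LazyTensor:
    """Toy lazy tensor: arithmetic records IR instead of computing (illustrative only)."""

    def __init__(self, ir):
        self.ir = ir  # nested tuples describing how to compute this value

    @staticmethod
    def constant(value):
        return LazyTensor(("const", value))

    def __add__(self, other):
        return LazyTensor(("add", self.ir, _ir(other)))

    def __mul__(self, other):
        return LazyTensor(("mul", self.ir, _ir(other)))


def _ir(value):
    return value.ir if isinstance(value, LazyTensor) else ("const", value)


def _structure(node, values):
    """Replace concrete values with placeholders so graphs of equal structure share a key."""
    if node[0] == "const":
        values.append(node[1])
        return ("param",)
    return (node[0], _structure(node[1], values), _structure(node[2], values))


class ToyLazyRuntime:
    """Compiles each distinct graph structure once, then reuses the cached executable."""

    def __init__(self):
        self.cache = {}
        self.compilations = 0

    def _compile(self, structure):
        self.compilations += 1  # stand-in for handing the IR graph to the XLA compiler
        def run(node, values):
            if node[0] == "param":
                return next(values)
            left = run(node[1], values)
            right = run(node[2], values)
            return left + right if node[0] == "add" else left * right
        return lambda values: run(structure, iter(values))

    def mark_step(self, tensor):
        # Cut the recorded graph, compile it if it is new, and execute it.
        values = []
        key = _structure(tensor.ir, values)
        if key not in self.cache:
            self.cache[key] = self._compile(key)
        return self.cache[key](values)


runtime = ToyLazyRuntime()
results = []
for step in range(3):
    x = LazyTensor.constant(float(step))
    y = x * 2 + 1  # no computation happens here, only IR recording
    results.append(runtime.mark_step(y))
print(results, runtime.compilations)  # [1.0, 3.0, 5.0] 1
```

All three steps share one compiled graph because their structure is identical; only the input values differ.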

PyTorch/XLA for GPU

While the API documentation describes PyTorch/XLA as supporting all XLA devices, including GPUs, you are likely to discover that, as of the time of this writing, support for Google Cloud TPUs has been highly prioritized. A naïve search for “PyTorch/XLA on GPU” will turn up several disclaimers regarding its support, and some unofficial instructions for creating a custom, GPU supporting, build (e.g., see this github issue).

Thankfully, several cloud service providers have created docker images specifically supporting PyTorch/XLA on GPU. See here for the available support in GCP, and here for the latest supporting image in AWS. In the next section we will provide an example of using PyTorch/XLA on GPU using the Amazon SageMaker training service.

PyTorch/XLA on AWS

In the code block below we show how to train a HuggingFace vision transformer model using torch_xla, the PyTorch/XLA Python module. To highlight the differences between the torch_xla flow and the standard flow, we have implemented both flows and added an “is_xla” flag to toggle between them. The code assumes that we will run on a single instance with 8 GPUs. The XLA-specific lines of code are the ones guarded by the is_xla flag.

```python
import time
import torch
import os
import json
from torch.utils.data import Dataset

num_gpus = 8
is_xla = True

if is_xla:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    if os.environ.get('XRT_WORKERS') is None:
        # Add ENVARS required for XLA
        host = json.loads(os.environ["SM_HOSTS"])[0]
        os.environ["XRT_WORKERS"] = f'localservice:0;{host}:43857'
        os.environ['XRT_SHARD_WORLD_SIZE'] = '1'
        os.environ['XRT_HOST_ORDINAL'] = '0'
        os.environ['FI_EFA_USE_DEVICE_RDMA'] = '1'
        os.environ['NCCL_PROTO'] = 'simple'
        os.environ['XLA_FIX_DIV_FP64'] = '1'
        os.environ['OFI_NCCL_NIC_DUP_CONNS'] = str(num_gpus)
        os.environ["GPU_NUM_DEVICES"] = str(num_gpus)
else:
    # DDP setup
    import torch.distributed as dist

    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', 'localhost')
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(2222))
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

    # wrap the model with DDP
    def wrap_model(model, local_rank):
        from torch.nn.parallel import DistributedDataParallel as DDP
        model.to(torch.cuda.current_device())
        model = DDP(model, device_ids=[local_rank])
        return model


# A fake dataset
class FakeDataset(Dataset):
    def __len__(self):
        return 10000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(data=[index % 1000], dtype=torch.int64)
        return rand_image, label


def build_model():
    from transformers import ViTForImageClassification, ViTConfig
    return ViTForImageClassification(ViTConfig(num_labels=1000))


def main(rank, world_size=num_gpus):
    dataset = FakeDataset()
    model = build_model()
    if is_xla:
        # XLA flow: each process drives its own XLA device; no DDP wrapper
        device = xm.xla_device()
        rank = xm.get_local_ordinal()
        model = model.to(device)
    else:
        setup(rank, world_size)
        torch.cuda.set_device(rank)
        model = wrap_model(model, rank)
    batch_size = 128
    optimizer = torch.optim.Adam(model.parameters())
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=12)
    if is_xla:
        # MpDeviceLoader moves batches to the XLA device and marks step boundaries
        data_loader = pl.MpDeviceLoader(data_loader, device)
    loss_function = torch.nn.CrossEntropyLoss()
    t0 = time.perf_counter()
    for idx, (inputs, targets) in enumerate(data_loader, start=1):
        if not is_xla:
            inputs = inputs.to(torch.cuda.current_device())
            targets = targets.to(torch.cuda.current_device())
        targets = torch.squeeze(targets, -1)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs['logits'], targets)
        loss.backward()
        if is_xla:
            # all-reduces gradients across processes, then applies the update
            xm.optimizer_step(optimizer)
        else:
            optimizer.step()
        if rank == 0 and idx % 1000 == 0:
            batch_time = time.perf_counter() - t0
            print(f'step: {idx}: mean step time is {batch_time/1000}')
            t0 = time.perf_counter()
    if not is_xla:
        dist.destroy_process_group()


def _mp_fn(index):
    main(index)


if __name__ == "__main__":
    if is_xla:
        import torch_xla.distributed.xla_multiprocessing as mp
    else:
        import torch.multiprocessing as mp
    mp.spawn(_mp_fn, nprocs=num_gpus, join=True)
```

To run the training script in Amazon SageMaker, we need to configure our training job to use one of the docker images that include a build of torch_xla specifically configured and tuned to run on Amazon SageMaker’s GPU training instances.

```python
from sagemaker.pytorch import PyTorch

image_uri = "763104351884.dkr.ecr.us-east-1.amazonaws.com/" \
            "huggingface-pytorch-trcomp-training:1.11.0-" \
            "transformers4.21.1-gpu-py38-cu113-ubuntu20.04"

estimator = PyTorch(entry_point='train.py',
                    role=<role>,
                    instance_type='ml.p4d.24xlarge',
                    instance_count=1,
                    image_uri=image_uri)
estimator.fit()
```

The training script we shared above is, admittedly, a bit messy. The environment variable settings, in particular, required quite a bit of trial and error (as well as some reverse engineering). In the next subsection we show how to use Amazon SageMaker Training Compiler for a cleaner and more general solution.
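If you do stay with the manual route, one way to contain part of the mess is to isolate the environment-variable setup in a pure function that can be inspected and tested on its own. The variable names and values below are copied from the script above and reflect our single-instance setup, not an official or documented list; the helper names are hypothetical, and "algo-1" is just an example host name in the SM_HOSTS format (a JSON list of host names). Unlike the script, which checks only XRT_WORKERS, apply_env leaves any variable that is already set untouched; that is a design choice of this sketch.

```python
import json
import os


def xla_gpu_env(sm_hosts_json, num_gpus, port=43857):
    """Build the XLA-on-GPU environment variables used in the training script above."""
    host = json.loads(sm_hosts_json)[0]
    return {
        "XRT_WORKERS": f"localservice:0;{host}:{port}",
        "XRT_SHARD_WORLD_SIZE": "1",
        "XRT_HOST_ORDINAL": "0",
        "FI_EFA_USE_DEVICE_RDMA": "1",
        "NCCL_PROTO": "simple",
        "XLA_FIX_DIV_FP64": "1",
        "OFI_NCCL_NIC_DUP_CONNS": str(num_gpus),
        "GPU_NUM_DEVICES": str(num_gpus),
    }


def apply_env(env):
    # Only fill in variables that are not already set (e.g., by a managed service).
    for key, value in env.items():
        os.environ.setdefault(key, value)


env = xla_gpu_env('["algo-1"]', num_gpus=8)
print(env["XRT_WORKERS"])  # localservice:0;algo-1:43857
```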

Amazon SageMaker Training Compiler

The Amazon SageMaker Training Compiler is a feature of Amazon SageMaker that is intended to accelerate the training of deep learning models on GPU instances managed by SageMaker. Under the hood, SageMaker Training Compiler achieves the training acceleration by using XLA compilation in the manner we described above. For its PyTorch support, SageMaker Training Compiler loads a docker image with a custom build of torch_xla, automatically configures the training environment for its use, and starts up the XLA training job. In particular, SageMaker Training Compiler sets all of the obscure (undocumented) environment variables, thus reducing the work required of the user to a few simple steps.

The code block below shows how to start up our PyTorch training job with Amazon SageMaker Training Compiler.

```python
from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

distribution = {'pytorchxla': {'enabled': True}}

estimator = HuggingFace(entry_point='train.py',
                        role=<role>,
                        instance_type='ml.p4d.24xlarge',
                        instance_count=1,
                        transformers_version='4.21.1',
                        pytorch_version='1.11.0',
                        compiler_config=TrainingCompilerConfig(),
                        distribution=distribution)
estimator.fit()
```

Note that if your training script uses a HuggingFace transformer model, you can further simplify the use of SageMaker Training Compiler (and torch_xla) by using the high-level HuggingFace Trainer API, as described in the feature documentation.

Be sure to check out the feature documentation for more details. See also the API documentation for examples of the different modes of use.

Results

In the table below we compare the runtime performance, measured in samples per second, of our Vision Transformer with and without XLA compilation.

Impact of XLA compilation on samples per second — higher is better (by Author)

Applying XLA compilation boosted the performance of our model by ~15%. Once again, we emphasize that the potential benefits of XLA compilation are very dependent on the specifics of the model. Running the same experiment on a different ML model may produce very different comparative results.
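Our training script reports the mean step time, whereas the table reports samples per second. In a data-parallel job in which each GPU processes its own batch, the two are related by samples/sec = per-GPU batch size × number of GPUs / step time, and the relative boost is (new - old) / old. The defaults below (128 samples per GPU, 8 GPUs) mirror the script; the step times in the usage lines are placeholders chosen to keep the arithmetic simple and are not the measurements behind the table.

```python
def samples_per_second(mean_step_time_s, per_device_batch=128, num_devices=8):
    """Global throughput of a data-parallel job in which each device processes its own batch."""
    if mean_step_time_s <= 0:
        raise ValueError("step time must be positive")
    return per_device_batch * num_devices / mean_step_time_s


def relative_speedup(baseline, candidate):
    """Fractional gain of candidate over baseline (e.g., 0.15 means 15% more samples/sec)."""
    return (candidate - baseline) / baseline


# Placeholder step times, chosen only to illustrate the arithmetic.
base = samples_per_second(0.5)  # 128 * 8 / 0.5 = 2048 samples/sec
xla = samples_per_second(0.4)   # 128 * 8 / 0.4 = 2560 samples/sec (up to float rounding)
print(f"{relative_speedup(base, xla):.0%}")
```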

Note that it took around 10,000 steps, or close to an hour, for the step time of the XLA run to stabilize. This is a known symptom of XLA compilation.

XLA Limitations

Before using the XLA compiler, it is important to be aware of some of its limitations.

Model Dependence

Throughout this post we have emphasized the fact that XLA offers the potential for performance enhancements, but not a guarantee of them. Whether you will benefit from using XLA, and if so, how significant the gain will be, is highly dependent on the specifics of your model. In some cases, such as when your model includes tensors of dynamic shape, you may find that using XLA reduces your performance. While it is often difficult to predict whether your model is “XLA-friendly”, there are several resources (e.g., here) that should provide you with some general pointers.

As we saw in our example above, XLA compilation is a process that occurs over the course of multiple training steps. Depending on your model, it may take several minutes or more for the training step time to converge to its XLA-optimal value. In particular, if your overall training duration is relatively short, you might see no benefit to using XLA.
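A back-of-the-envelope way to reason about this: if XLA adds a one-time warm-up overhead (the extra wall-clock time spent while compilation happens and the step time settles) and then saves a fixed amount of time per step, it only pays off after overhead / saving steps. The sketch below computes that break-even point; its inputs are hypothetical quantities that you would measure for your own model.

```python
import math


def break_even_steps(warmup_overhead_s, baseline_step_s, xla_step_s):
    """Number of steady-state steps after which XLA's warm-up cost is recovered.

    warmup_overhead_s: extra wall-clock time spent while XLA compiles and the
    step time converges, relative to running the same steps without XLA.
    """
    saving_per_step = baseline_step_s - xla_step_s
    if saving_per_step <= 0:
        return math.inf  # XLA never pays off if it is not faster at steady state
    return math.ceil(warmup_overhead_s / saving_per_step)


# Hypothetical inputs: 10 minutes of extra warm-up, 0.5 s vs 0.375 s per step.
print(break_even_steps(600.0, 0.5, 0.375))  # 600 / 0.125 = 4800
```

If your planned training run is shorter than the break-even point, the non-XLA flow is likely to finish first.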

Limitations to Debugging and Experimentation

One of the reasons for the popularity of the PyTorch training framework, and its eager execution policy, is that it simplifies debugging and experimentation. In PyTorch, one can evaluate the value of any intermediate tensor without needing to jump through hoops to do so (as in graph mode execution). By adopting XLA compilation we are essentially giving that up. One can mitigate this limitation by supporting both XLA and non-XLA flows (as in our script above); however, there is always the possibility that a behavior we are seeing is unique to the XLA flow.

An additional factor that complicates debugging and experimentation is the reliance on custom torch_xla builds for our training. As of the time of this writing, there is no official torch_xla Python package. One option for evaluating the XLA behavior in a controlled environment is to pull the XLA-specific docker image from the cloud service provider’s (CSP’s) image repository (or to do this using SageMaker’s local mode). The PyTorch/XLA documentation includes instructions for configuring torch_xla to run on CPU. However, you may find running the docker image on a local GPU to be more challenging. Regardless, debugging and experimenting in this manner is clearly limiting and far from ideal.

Code Adaptations

Using torch_xla requires a number of code adaptations. These can be grouped into two types: adaptations to the training flow that are required in order to use torch_xla, and adaptations that are recommended as best practices when using torch_xla. We demonstrated some of the most basic required adaptations in our code example above. In practice, additional adaptations are often required for saving and loading models, integrating automatic mixed precision, and more. The required adaptations are documented here. There are several resources, such as here and here, describing recommended adaptations for getting the most out of torch_xla. These are often targeted at minimizing the number of times that XLA compilation is required. Models will only benefit from using XLA if they can be XLA-compiled once and reused for all subsequent steps. Models that require frequent XLA compilations (e.g., models with tensors that have dynamic shapes) will not benefit from using torch_xla. To quote this recommended blog post that includes best practices for using torch_xla, “compile once and execute often.”
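To see why dynamic shapes hurt, consider a toy cache that, like XLA, needs a separately compiled graph for every distinct input shape. A common mitigation is to pad inputs up to a small, fixed set of bucket sizes so that only a few shapes (and therefore only a few compilations) ever occur. This is a general technique rather than a torch_xla API; ToyCompileCache and pad_to_bucket are hypothetical names, and the bucket sizes are arbitrary.

```python
class ToyCompileCache:
    """Mimics a graph cache keyed on input shape: every new shape triggers a compile."""

    def __init__(self):
        self.compiled_shapes = set()
        self.compilations = 0

    def run(self, sequence):
        shape = len(sequence)
        if shape not in self.compiled_shapes:
            self.compilations += 1  # stand-in for an expensive XLA compilation
            self.compiled_shapes.add(shape)
        return sum(sequence)


def pad_to_bucket(sequence, buckets=(8, 16, 32), pad_value=0):
    """Pad a variable-length sequence up to the smallest bucket that fits."""
    for size in buckets:
        if len(sequence) <= size:
            return list(sequence) + [pad_value] * (size - len(sequence))
    raise ValueError(f"sequence of length {len(sequence)} exceeds the largest bucket")


lengths = [3, 5, 7, 9, 12, 15, 20, 31]
batches = [[1] * n for n in lengths]

dynamic, bucketed = ToyCompileCache(), ToyCompileCache()
for seq in batches:
    dynamic.run(seq)
    bucketed.run(pad_to_bucket(seq))
print(dynamic.compilations, bucketed.compilations)  # 8 3
```

In a real model, the padded positions must also be masked out (e.g., with an attention mask and ignored label positions) so that they do not affect the loss.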

The code adaptations, regardless of whether they are required or recommended, complicate the use of torch_xla. This is especially true if you choose to maintain both XLA and non-XLA flows.

Summary

XLA compilation offers the potential for significant training acceleration and, by extension, training cost savings. Unfortunately, at the time of this writing, GPU support for XLA compilation in PyTorch remains secondary to TPU support. Thankfully, CSPs such as AWS have created custom builds that allow us to reap the benefits of XLA compilation. We can only hope that the future will bring improvements to the availability and usability of torch_xla in general, and of torch_xla on GPU in particular. Such improvements would include:

  1. an official, multi-purpose, torch_xla build,
  2. alignment of the torch_xla and standard torch API flows,
  3. merging the XLA support into the official PyTorch package (as in TensorFlow), and
  4. improved usage documentation, particularly on GPU.

In the meantime, “give it a try, and see your code fly” :).

Please feel free to contact me with any comments, questions, or corrections.


I am a Machine Learning Algorithm Developer working on Autonomous Vehicle technologies at Mobileye. The views expressed in my posts are my own.