
Writing custom CUDA kernels with Triton

Triton is a language and compiler for writing highly optimized CUDA kernels. Learn the basics of GPU programming and how to use Triton.

Exploring Triton’s Just-In-Time (JIT) compiler and code generation backend

With the success of deep learning and the explosion of research papers, it is common to come up with a new idea only to find that it is not hardware accelerated. More specifically, when we come up with a new activation function or self-attention mechanism, we have to rely on the capabilities provided by PyTorch/TensorFlow to handle the forward and backward passes through the module.

PyTorch JIT is one option in these cases, but it is a high-level compiler that can only optimize parts of the code; it cannot be used to write custom CUDA kernels.

There is another problem with writing CUDA kernels: it is incredibly hard to do. Optimizing the computations for locality and parallelism is very time-consuming and error-prone, and it often requires experts who have spent a lot of time learning how to write CUDA code. GPU architectures are also rapidly evolving (for example, the recent addition of Tensor Cores), which makes it an even bigger challenge to write code that extracts the maximum performance from the hardware.

This is where OpenAI Triton comes into the picture. Triton has three main components:

Figure 1: Overview of main components of Triton. Source (with permission)
  1. Triton-C: A C-like language mainly intended for programmers already familiar with CUDA.
  2. Triton-IR: An LLVM-based Intermediate Representation (IR). Triton-IR programs are constructed directly from Triton-C. In short, LLVM provides a lot of hardware-specific optimizations, which means we can directly use Nvidia’s CUDA Compiler (NVCC) to optimize our code for the particular hardware.
  3. Triton-JIT: A Just-In-Time (JIT) compiler and code generation backend for compiling Triton-IR programs into efficient LLVM bitcode. This also includes many machine-independent optimizations, which means less work for us.

Triton-JIT is the most exciting part of the Triton project for me. It allows programmers with almost no CUDA experience to write highly optimized CUDA kernels in Python. Before discussing Triton, we need to understand how CUDA programs work on the GPU.

GPU programming basics

Let us start with the CPU (host). The CPU has access to RAM, storage disks, and all the connected peripherals. The GPU (device), on the other hand, has no direct access to any of these. A GPU has its own memory, called VRAM; data must be copied from CPU to GPU before the GPU can work on it, and copied back from GPU to CPU before the CPU can store it on a storage device or share it with the connected peripherals.

Note: This is the reason why you should minimize data movement between the CPU and GPU as much as possible. To do this, think about how you can load the data in chunks and process it in parallel, or arrange the data so that it can be reused multiple times before the next chunk is loaded.
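
As a rough illustration (not part of the original article), here is a minimal PyTorch sketch of this idea: the data is moved to the GPU once, reused for several operations, and only the final result is copied back.

import torch

arr = torch.rand(10_000)   # created in CPU RAM
arr = arr.to('cuda')       # one CPU -> GPU copy

# Reuse the GPU-resident tensor for several operations without copying back
out = arr * 2
out = out + arr
total = out.sum()

# Copy back to the CPU only when the final result is needed
total_cpu = total.cpu()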

In CUDA, we launch many threads in groups called thread blocks, which together form a grid. All threads within a thread block can communicate with each other. You can launch up to 1024 threads per block and up to 2^31-1 blocks (along the x-dimension of the grid) in a single launch. Figure 2 shows an example of this.

Figure 2: Architecture of CUDA programs. Source

The idea behind using blocks is that you do not need to change your code if you get a new GPU in the future: the new GPU can simply execute more blocks concurrently without any code changes.
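
For intuition, the number of blocks for a launch is typically derived from the problem size with a ceiling division; a quick sketch in plain Python (the values are illustrative):

N = 100_000                  # number of elements to process
THREADS_PER_BLOCK = 1024     # CUDA's per-block thread limit

# Ceiling division: enough blocks so that every element is covered
num_blocks = (N + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
print(num_blocks)            # 98 blocks of 1024 threads each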

CPU-program vs CUDA-program

Without going into technicalities, let us consider a simple example of adding two arrays of length 3.

In C++, if we want to add these arrays then we will create a for-loop that will run three times (assuming a single-threaded program).

But in CUDA, we launch 3 threads, each of which does the addition at a single index, so the whole for-loop is done in a single step. In reality, the following steps would take place (a PyTorch sketch of these steps is shown after the list):

  1. Copy arr1, arr2 from CPU to GPU.
  2. Create a new array of size 3 (or store the result of the addition in arr1).
  3. Launch 3 threads to do the addition and store the result in the new array.
  4. Copy the result from GPU to CPU.
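
As a point of reference, this is exactly what PyTorch does for us behind the scenes; a minimal sketch of the same four steps (the array values are illustrative):

import torch

# Step 1: copy arr1, arr2 from CPU to GPU
arr1 = torch.tensor([1.0, 2.0, 3.0]).to('cuda')
arr2 = torch.tensor([4.0, 5.0, 6.0]).to('cuda')

# Steps 2 and 3: allocate the result on the GPU and launch the addition
# (PyTorch launches the threads under the hood)
result = arr1 + arr2

# Step 4: copy the result from GPU back to CPU
print(result.cpu())   # tensor([5., 7., 9.])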

Because GPUs have 1000s of cores, doing simple things like addition and matrix multiplication is much faster on the GPU than on the CPU (provided the speedup is more than the time spent transferring data between the CPU and GPU).
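
To make that caveat concrete, here is a rough (and deliberately naive) timing sketch rather than a rigorous benchmark; the array size and measured numbers are illustrative:

import time
import torch

N = 10_000_000
a = torch.rand(N)
b = torch.rand(N)

t0 = time.perf_counter()
cpu_out = a + b                              # plain CPU addition
t1 = time.perf_counter()

t2 = time.perf_counter()
a_gpu, b_gpu = a.to('cuda'), b.to('cuda')    # CPU -> GPU copies
gpu_out = (a_gpu + b_gpu).cpu()              # add on GPU, then GPU -> CPU copy
torch.cuda.synchronize()
t3 = time.perf_counter()

print(f"CPU add: {t1 - t0:.4f}s, GPU add incl. transfers: {t3 - t2:.4f}s")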

CUDA vs Triton

We saw the CUDA execution model above. Now let’s see how Triton differs from the above model.

In CUDA, each kernel is associated with a thread block (i.e. a collection of threads). In Triton, each kernel is associated with a single thread. This execution paradigm removes the problems of memory synchronization and inter-thread communication while still allowing automatic parallelization.

Instead of storing threads inside a thread block, a block now consists of ranges, i.e. tiles of pointers over the data. The interesting thing about this is that you can have as many ranges as you want. So if your input is 2D, you can specify Range(10) for the x-axis and Range(5) for the y-axis, giving a tile of 10 × 5 = 50 elements. Similarly, you can define ranges along as many dimensions as you want.

Figure 3: CUDA execution model vs Triton execution model. Source (with permission)
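
As a sketch of what multi-dimensional ranges look like in practice, the hypothetical kernel below builds a 2D tile of offsets by broadcasting two 1D ranges and simply copies that tile from the input to the output (the kernel name, the BLOCK_X/BLOCK_Y parameters, and the row-major layout are my assumptions, not something from the original article):

import triton
import triton.language as tl

@triton.jit
def tile_2d_kernel(in_ptr, out_ptr, n_rows, n_cols,
                   BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr):
    # One range per axis; broadcasting them gives a BLOCK_X x BLOCK_Y tile
    offs_x = tl.arange(0, BLOCK_X)[:, None]   # range along the x-axis (rows)
    offs_y = tl.arange(0, BLOCK_Y)[None, :]   # range along the y-axis (columns)
    offsets = offs_x * n_cols + offs_y        # row-major element offsets
    # Mask out positions that fall outside the actual input
    mask = (offs_x < n_rows) & (offs_y < n_cols)
    tile = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, tile, mask=mask)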

Adding two arrays using Triton

Now that we have a basic understanding of how CUDA and Triton work, we can start writing Triton programs. Use the following command to install Triton:

pip install triton

A summary of steps is given below:

  1. Define the block. We know blocks are defined by specifying ranges. Since addition is 1D, we only need to define a range along one dimension. Let it be 512, and define it as a global BLOCK_SIZE = 512.
  2. The range of 512 actually means we are launching 512 threads to do the computation.
  3. Next, we get the index of the input data. Suppose the input array is of size N = 1000. Because we defined a block size of 512, we will process the input array in chunks of 512 elements. So the first chunk covers indices 0:512 and the second chunk 512:1024. This is done using the code shown below:
# Addition is 1D, so we only need to get the index of axis=0
pid = triton.language.program_id(axis=0)
# offsets is a block of indices that we will add to the array pointers below
block_start = pid * BLOCK_SIZE
offsets = block_start + triton.language.arange(0, BLOCK_SIZE)
  4. Masking to guard memory operations. In the above example the input array is of size N = 1000, but the second chunk’s offsets run from 512:1024. So we need a mask to protect against out-of-bounds accesses, and it needs to be specified for each axis.
mask = offsets < N
  5. Load the data. Now that we have defined the offsets and the mask, we can load the data from memory and mask out all the extra elements.
def add_kernel(arr1_ptr, arr2_ptr, output_ptr, ...):
    ...
    arr1 = triton.language.load(arr1_ptr + offsets, mask=mask)
    arr2 = triton.language.load(arr2_ptr + offsets, mask=mask)
  6. Do the relevant operation. In this case, we only need to do addition.
output = arr1 + arr2
  7. After doing the computation, store the result back in RAM. The GPU has no access to storage devices, so we first have to move the data to RAM, and from there we can write it to disk if we want.
triton.language.store(output_ptr + offsets, output, mask=mask)

The code for the entire kernel is shown below

import triton
import triton.language as tl
BLOCK_SIZE = 512
@triton.jit
def add_kernel(arr1_ptr, arr2_ptr, output_ptr, N):
    # Step 1: Get the index (program id) of the current block along axis 0
    pid = tl.program_id(axis=0)

    # Step 2: Define the offsets and mask
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    # Step 3: Load the data from RAM
    arr1 = tl.load(arr1_ptr + offsets, mask=mask)
    arr2 = tl.load(arr2_ptr + offsets, mask=mask)

    # Step 4: Do the computation
    output = arr1 + arr2
    # Step 5: Store the result in RAM
    tl.store(output_ptr + offsets, output, mask=mask)

To use the kernel, we can define a helper function as shown below

import torch

def add(arr1: torch.Tensor, arr2: torch.Tensor):
    output = torch.empty_like(arr1)
    N = output.numel()
    grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)
    add_kernel[grid](arr1, arr2, output, N)
    return output

grid specifies the space over which the kernel will work. In our case, the grid is 1D, and we specify how the data is split along it. So if the input arrays are of size 1000, we want the grid to cover the chunks [0:512] and [512:1024]. In this step, we are basically specifying how to split the input data and pass it to the kernel.

By default, grid takes one positional argument, which we call meta. The purpose of meta is to provide information like BLOCK_SIZE, but in this example we used a global variable to define that, so meta goes unused.
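
For completeness, here is a hedged sketch of what the meta-based variant might look like: BLOCK_SIZE is passed to the kernel as a tl.constexpr argument instead of a global, and the grid reads it back through meta. This mirrors the pattern used in the official Triton tutorials, but it is not the version used above, and the _v2 names are mine.

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel_v2(arr1_ptr, arr2_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    arr1 = tl.load(arr1_ptr + offsets, mask=mask)
    arr2 = tl.load(arr2_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, arr1 + arr2, mask=mask)

def add_v2(arr1: torch.Tensor, arr2: torch.Tensor):
    output = torch.empty_like(arr1)
    N = output.numel()
    # meta now carries BLOCK_SIZE, so the grid is derived from it
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)
    add_kernel_v2[grid](arr1, arr2, output, N, BLOCK_SIZE=512)
    return output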

Now we can call the add function like a normal Python function, as shown below (make sure the inputs passed to the function are already on the GPU)

arr_size = 100_000
arr1 = torch.rand(arr_size, device='cuda')
arr2 = torch.rand(arr_size, device='cuda')
pytorch_out = arr1 + arr2
triton_out = add(arr1, arr2)
print(torch.sum(torch.abs(pytorch_out - triton_out)))

Out

> python main.py
tensor(0., device='cuda:0')

Add for higher-dimensional tensors

We can also reuse the same kernel for N-dimensional tensors. This gives us the flexibility to avoid writing a separate kernel for each number of input dimensions. The idea is simple: reshape the input tensors to 1D in the helper function, run the kernel, and then reshape the output back.

The reshaping is not a time-consuming operation: as long as the tensor is contiguous, view only modifies the shape and stride metadata of the tensor rather than copying data. The modified helper function is shown below.

def add(arr1: torch.Tensor, arr2: torch.Tensor):
    input_shape = arr1.shape
    arr1 = arr1.view(-1)
    arr2 = arr2.view(-1)
    output = torch.empty_like(arr1)
    N = output.numel()
    grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)
    add_kernel[grid](arr1, arr2, output, N)
    output = output.view(input_shape)
    return output
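
One caveat worth noting (my addition, not from the original article): view only works when the tensor’s memory layout allows it, so if the inputs might be non-contiguous, a safer variant flattens with reshape, which silently falls back to a copy when needed.

def add(arr1: torch.Tensor, arr2: torch.Tensor):
    input_shape = arr1.shape
    # reshape copies only when the tensor is not contiguous;
    # view would raise an error in that case
    arr1 = arr1.reshape(-1)
    arr2 = arr2.reshape(-1)
    output = torch.empty_like(arr1)
    N = output.numel()
    grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE),)
    add_kernel[grid](arr1, arr2, output, N)
    return output.view(input_shape)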

And then we call the function the same way as before

arr_size = (100, 100, 100)
arr1 = torch.rand(arr_size, device='cuda')
arr2 = torch.rand(arr_size, device='cuda')
pytorch_out = arr1 + arr2
triton_out = add(arr1, arr2)
print(torch.sum(torch.abs(pytorch_out - triton_out)))

Out

> python main.py
tensor(0., device='cuda:0')

This was a simple example, but for complex cases like matrix multiplication, how you split the data can have a huge effect on performance. The Triton docs provide a tutorial on writing an efficient matrix multiplication kernel, which goes into much more detail.

This was a short tutorial introducing GPU programming basics and getting you started with Triton. Check out the Triton project on GitHub (openai/triton) if you want to learn more.



Originally published at https://kushajveersingh.github.io on September 12, 2021.

