Flash Attention is a revolutionary technique that dramatically accelerates the attention mechanism in transformer-based models, delivering processing speeds many times faster than naive methods. By cleverly tiling data and minimizing memory transfers, it tackles the notorious GPU memory bottleneck that large language models often struggle with.
In this post, we’ll dive into how Flash Attention leverages efficient I/O-awareness to reduce overhead, then take it a step further by crafting a block-sparse attention kernel in Triton.
💥 I will provide a simple explanation of how Flash Attention works. We will then implement the explained algorithm in Triton!
What is Attention?
The attention mechanism (or scaled dot-product attention) is a core element of transformer models, the leading architecture for language modeling. All popular models, like GPT, LLaMA, and BERT, rely on attention.
The formula is pretty simple:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
The rest is history.
Even though the formula looks simple, its computation involves multiplications of large tensors and a lot of data movement. Considering that this is a core part of the transformer architecture, optimizing the algorithm greatly improves the performance of the model in general.
In the naive implementation, attention requires O(n²) additional memory and O(n²) compute time complexity, where n is the sequence length. That’s a lot!
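To make the O(n²) cost concrete, here is a minimal single-head sketch of the naive approach (the shapes and function name are just for illustration, not code from the post). The (n, n) scores tensor is exactly where the quadratic memory goes:

```python
import torch

def naive_attention(q, k, v):
    # q, k, v: (n, d) tensors for a single head
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale  # (n, n) matrix: O(n^2) memory
    weights = torch.softmax(scores, dim=-1)     # softmax over the key dimension
    return weights @ v                          # (n, d) output

n, d = 1024, 64
q, k, v = (torch.randn(n, d) for _ in range(3))
out = naive_attention(q, k, v)  # materializes a 1024 x 1024 score matrix along the way
```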
Flash Attention
Core Idea
The main idea of Flash Attention can be summarized in a simple quote from the original paper:
We argue that a missing principle is making attention algorithms IO-aware – accounting for reads and writes between levels of GPU memory.
That is, modern GPUs have several types of memory:
- SRAM – fast, on-chip, small
- HBM – slower than SRAM, but much larger. This is what we usually refer to as GPU memory.
Check out the memory hierarchy in the image below to see the differences in bandwidth and sizes of different memory types.

💡 To conduct computation, data must be transferred from HBM to SRAM, and this transfer is not overhead-free!
The Flash Attention algorithm proposes a method of computing attention in tiles, without explicitly materializing the attention scores tensor:

💥 Not materializing a matrix means that at any given time, the matrix does not exist in its full shape in memory.
It’s easy to see that this matrix requires O(n²) memory to store. For large sequence lengths, that’s a lot of data! So, if we manage to avoid explicitly materializing this matrix, we can save lots of memory.
However, this matrix is necessary for transformer training, as it is part of backpropagation and gradient calculation. The authors propose that it’s better to recalculate this matrix during the backward pass (again, without explicit materialization). Not only does this save lots of memory, but it also provides huge speedups, as we don’t need to transfer this enormous matrix between different levels of GPU memory.
Overall, this approach not only speeds up calculations by taking GPU I/O specifics into account, but also allows processing huge sequence lengths, as memory complexity drops to O(n).
Tiled Attention Calculation
The last thing to understand is how to compute attention in tiles. Basically, this means that we will calculate attention over the full sequence by processing incoming tokens in small portions.
Well, it’s easy to calculate QK^T in tiles: since the attention dimension is not large, we can load full matrix rows and columns and multiply them tile by tile.
😨 Yes, if we want an enormous attention dimension, Flash Attention will not work without algorithm modifications.
As dimensions are usually quite small even for enormous models, this limitation is fair.
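As a toy illustration (my own sketch, not code from the post), we can check that computing QK^T one tile pair at a time reproduces the corresponding blocks of the full product:

```python
import torch

n, d, tile = 512, 64, 128
q, k = torch.randn(n, d), torch.randn(n, d)

full = q @ k.T  # reference only; the whole point of tiling is to avoid materializing this
for i in range(0, n, tile):
    for j in range(0, n, tile):
        # one (tile x d) row block of Q times one (d x tile) column block of K^T
        block = q[i:i + tile] @ k[j:j + tile].T
        assert torch.allclose(block, full[i:i + tile, j:j + tile], atol=1e-4)
```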

So, we have QK^T calculated in SRAM. All that’s left is to apply softmax, multiply by V, and that’s it!

That’s where the trick is.
The problem is that the softmax denominator requires aggregation over the sequence length to normalize scores, and we do not have access to the whole sequence at once, as we load data in tiles.
To address this, we can implement a concatenated softmax algorithm. Using it, we can calculate the softmax incrementally, adjusting already-computed values as new data comes in.
Taking the algorithm from the original article, we can define rules to compute the softmax over a concatenation of data. Given two vectors $x_1$ and $x_2$, we need to calculate the softmax denominator $l(x)$ over their concatenation $[x_1, x_2]$. With the vector’s maximum denoted $m(x)$, so that

$$m(x) = \max_i x_i, \qquad l(x) = \sum_i e^{x_i - m(x)},$$

we can easily derive the softmax denominator of the concatenation:

$$m([x_1, x_2]) = \max\big(m(x_1), m(x_2)\big),$$

$$l([x_1, x_2]) = e^{m(x_1) - m([x_1, x_2])}\, l(x_1) + e^{m(x_2) - m([x_1, x_2])}\, l(x_2).$$
The last equivalence can be easily verified, as

$$e^{m(x_1) - m([x_1, x_2])}\, l(x_1) = e^{m(x_1) - m([x_1, x_2])} \sum_i e^{x_{1,i} - m(x_1)} = \sum_i e^{x_{1,i} - m([x_1, x_2])},$$

and the analogous identity holds for $x_2$, so the two terms together give exactly $\sum_i e^{x_i - m([x_1, x_2])}$ over the whole concatenation.
So, now we have what we want: we can calculate the softmax per-tile and then, by applying the re-normalization from the formula above, recover the global softmax. The last thing to do is to incorporate the corresponding tile of the V tensor and keep applying the same re-normalization (which works because matrix multiplication is a linear operation).
And all of this without loading the full sequence into memory or materializing QK^T!
💥 Notice that we calculate Softmax(QK^T) in tiles only, without needing to have the whole matrix at any moment.
Also, in the actual algorithm, for numerical stability we compute not Softmax(x) but Softmax(x – max(x)). We can do that because softmax is invariant to constant shifts.
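Here is a small PyTorch sketch of this running computation (my own illustration, with the max-shift included); it processes the last dimension tile by tile and matches torch.softmax:

```python
import torch

def softmax_tiled(x, tile_size=64):
    # running statistics per row: maximum m and denominator l
    m = torch.full(x.shape[:-1], float("-inf"))
    l = torch.zeros(x.shape[:-1])
    for start in range(0, x.shape[-1], tile_size):
        tile = x[..., start:start + tile_size]
        m_new = torch.maximum(m, tile.max(dim=-1).values)
        # rescale the old denominator to the new maximum, then add this tile's terms
        l = l * torch.exp(m - m_new) + torch.exp(tile - m_new[..., None]).sum(dim=-1)
        m = m_new
    # with the global m and l known, a second pass yields the normalized scores
    return torch.exp(x - m[..., None]) / l[..., None]

x = torch.randn(4, 1000)
assert torch.allclose(softmax_tiled(x), torch.softmax(x, dim=-1), atol=1e-6)
```

In the attention kernel, the same rescaling factor is also applied to the running P @ V accumulator, which is exactly the re-normalization trick described above.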
Triton Implementation
Now, we can easily implement the outlined algorithm in Triton, which is a tool that allows us to write efficient GPU kernels with the ease of Python.
💡 To learn more about Triton, check out their official guides.
Outlining the Algorithm
The first step is to decide how we will assign jobs and what data each job will load. By the algorithm of tiled softmax, each job must have access to K, V over the whole sequence length. So, each job will iterate over K, V in tiles. We don’t have any algorithmic restriction on the number of Q tiles processed. Therefore, each job will load just one Q tile and work with it only – this way we will maximize job parallelism.

In summary, each job will load a single Q tile, iterate over all tiles in K and V, and store one tile of result corresponding to the Q tile.
The Kernel
What’s left is to write the actual code. Let’s focus on the core part first, and only then we’ll add Triton-specific boilerplates.
Below is a Triton pseudocode with every line explained.
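The version here is a simplified sketch rather than the exact code from the repository: it assumes a single attention head, contiguous float32 tensors of shape (seq_len, dim), and a head dimension that is a power of two and at least 16 (tl.dot needs blocks of at least that size). The fully annotated kernel is available via the link further down.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def self_attention_kernel(
    q_ptr, k_ptr, v_ptr, out_ptr,
    seq_len, scale,
    stride_qs, stride_qd, stride_ks, stride_kd,
    stride_vs, stride_vd, stride_os, stride_od,
    BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr, DIM: tl.constexpr,
):
    # each job (program) owns one tile of Q rows and the matching tile of the output
    q_offs = tl.program_id(0) * BLOCK_Q + tl.arange(0, BLOCK_Q)
    d_offs = tl.arange(0, DIM)

    # load the Q tile once; it stays in SRAM for the whole job
    q = tl.load(q_ptr + q_offs[:, None] * stride_qs + d_offs[None, :] * stride_qd,
                mask=q_offs[:, None] < seq_len, other=0.0)

    # running softmax statistics and the (unnormalized) output accumulator
    m_i = tl.zeros((BLOCK_Q,), dtype=tl.float32) - float("inf")  # running row maximum
    l_i = tl.zeros((BLOCK_Q,), dtype=tl.float32)                 # running denominator
    acc = tl.zeros((BLOCK_Q, DIM), dtype=tl.float32)             # running P @ V

    # iterate over all K/V tiles along the sequence
    for k_start in range(0, seq_len, BLOCK_K):
        k_offs = k_start + tl.arange(0, BLOCK_K)
        kv_mask = k_offs[:, None] < seq_len
        k = tl.load(k_ptr + k_offs[:, None] * stride_ks + d_offs[None, :] * stride_kd,
                    mask=kv_mask, other=0.0)
        v = tl.load(v_ptr + k_offs[:, None] * stride_vs + d_offs[None, :] * stride_vd,
                    mask=kv_mask, other=0.0)

        # attention scores for this (Q tile, K tile) pair; padded keys get -inf
        scores = tl.dot(q, tl.trans(k)) * scale
        scores = tl.where(k_offs[None, :] < seq_len, scores, float("-inf"))

        # online softmax update: new maximum, rescale old stats, add the new tile
        m_new = tl.maximum(m_i, tl.max(scores, axis=1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(scores - m_new[:, None])
        l_i = l_i * alpha + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p, v)
        m_i = m_new

    # normalize and write the output tile
    out = acc / l_i[:, None]
    tl.store(out_ptr + q_offs[:, None] * stride_os + d_offs[None, :] * stride_od,
             out, mask=q_offs[:, None] < seq_len)


def flash_attention(q, k, v, BLOCK_Q=64, BLOCK_K=64):
    # q, k, v: contiguous float32 CUDA tensors of shape (seq_len, dim), dim a power of two >= 16
    seq_len, dim = q.shape
    out = torch.empty_like(q)
    grid = (triton.cdiv(seq_len, BLOCK_Q),)  # one job per Q tile
    self_attention_kernel[grid](
        q, k, v, out, seq_len, dim ** -0.5,
        q.stride(0), q.stride(1), k.stride(0), k.stride(1),
        v.stride(0), v.stride(1), out.stride(0), out.stride(1),
        BLOCK_Q=BLOCK_Q, BLOCK_K=BLOCK_K, DIM=dim,
    )
    return out
```

Note how the Q tile is loaded once and kept in SRAM, while K and V stream through in tiles, exactly as outlined above.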
See? Easy!
What’s important is that you can see how simple the kernel becomes once you understand the idea of tiled softmax. Apart from that, there’s nothing complicated from the algorithmic perspective.
💥 This kernel can be made even faster by implementing Triton-specific optimizations. However, this is out of the scope of this article.
This pseudocode is pretty close to the actual code. You may find it on my GitHub by following the link below. All that I added is just data management and PyTorch wrappers.
kernels/src/self_attention/kernel.py at main · alexdremov/kernels
❗ Don’t hesitate to ask if something isn’t clear. I’m here in the comments 😁.
The code above was extensively tested to match PyTorch’s scaled_dot_product_attention. You can also check out the tests to see how to use the written kernel.
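For example, a quick sanity check against PyTorch could look like this (using the hypothetical flash_attention wrapper from the sketch above; tolerances are loose because tl.dot may use TF32 internally):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1000, 64, device="cuda") for _ in range(3))

# SDPA expects batched inputs, so add (batch, head) dims and strip them afterwards
expected = F.scaled_dot_product_attention(q[None, None], k[None, None], v[None, None])[0, 0]
result = flash_attention(q, k, v)

torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
```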
Benchmarking
While we wrote the kernel in Triton mainly to better understand the algorithm, it’s interesting to compare its performance with a naive implementation and PyTorch’s scaled_dot_product_attention.
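A minimal way to time both, again assuming the hypothetical flash_attention wrapper from the sketch above, is Triton’s do_bench helper:

```python
import torch
import torch.nn.functional as F
import triton

q, k, v = (torch.randn(4096, 64, device="cuda") for _ in range(3))

# runtime in milliseconds, as reported by Triton's benchmarking helper
ms_flash = triton.testing.do_bench(lambda: flash_attention(q, k, v))
ms_sdpa = triton.testing.do_bench(
    lambda: F.scaled_dot_product_attention(q[None, None], k[None, None], v[None, None])
)
print(f"flash: {ms_flash:.3f} ms, sdpa: {ms_sdpa:.3f} ms")
```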

As expected, the Flash Attention algorithm dramatically outperforms the naive implementation. I’ve also marked with a dashed line the range of lengths for which the naive implementation causes a CUDA out-of-memory error.
We see that our Triton implementation is slightly worse than PyTorch SDPA, but the difference is not too large. Considering that PyTorch SDPA is a well-optimized CUDA kernel, that’s a nice result.
Benchmarking code is also available in the repository.
kernels/benchmark/benchmark_self_attention.py at main · alexdremov/kernels
This story was originally published on alexdremov.me. Check it out! (At least, the TeX looks better there.)
Conclusions
In this post, I covered the motivation behind the Flash Attention algorithm as well as its details. Finally, we implemented it from scratch in Triton, reproducing the speedups from the paper.
I hope this post improved your understanding of Flash Attention. Feel free to leave a comment below if you have any questions.
References
Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.