Exploring the Transformer Series (18) --- FlashAttention
0x00 Overview
0.1 Problem
The core of the Transformer architecture is the powerful self-attention mechanism. However, self-attention is slow and memory-intensive, especially for long contexts. For an input sequence of length $N$, both its computational complexity and its space complexity are $O(N^2)$. In other words, the computational cost and storage of the model grow quadratically with the sequence length $N$. When the input sequence is long, the Transformer computes slowly and consumes a lot of memory, which limits the maximum sequence length of large language models. This is why, in the early stages of development, large models often supported only 2K or 4K token inputs. Researchers have therefore sought to reduce the computational cost of the Transformer, bringing its complexity down from $O(N^2)$ toward $O(N\log N)$ or even $O(N)$.
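To make the quadratic growth concrete, the short calculation below estimates the size of a single $N \times N$ attention matrix for one head. It assumes 2-byte (FP16) elements and ignores batch size and head count, both of which multiply the figure further.

```python
BYTES_PER_ELEMENT = 2  # assumption: FP16 storage

def attention_matrix_bytes(seq_len: int) -> int:
    """Bytes needed to store one N x N attention matrix for a single head."""
    return seq_len * seq_len * BYTES_PER_ELEMENT

for n in (2_048, 4_096, 32_768, 131_072):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"N = {n:>7}: {mib:>10.0f} MiB per head")
```

Doubling $N$ quadruples the footprint: 8 MiB per head at $N = 2048$ becomes 32 MiB at $N = 4096$.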
0.2 Other Solutions
Before FlashAttention, many attempts had been made, basically following two paths: reducing the computational complexity of the attention mechanism and reducing its space complexity. Models improved by these methods are usually called Efficient Transformers.
Regarding computational complexity, some works have proposed approximate attention mechanisms that lower the theoretical cost of attention. These fall mainly into sparse approximation and low-rank approximation. The basic idea of sparse approximation is to approximate the complete, dense attention matrix with a sparse matrix. For example, Reformer applies Locality-Sensitive Hashing (LSH) to Q and K and computes attention only between queries and keys that fall into the same bucket, reducing the time complexity of attention from $O(N^2)$ to $O(N\log N)$. The basic idea of low-rank approximation is to approximate the attention matrix with a low-rank matrix. For instance, linear transformers introduce a kernel feature map $\phi(\cdot)$ and rewrite
$$O = \mathrm{softmax}(QK^T)V$$
as (normalization omitted)
$$O = \phi(Q)\left(\phi(K)^T V\right)$$
This decouples Q and K inside the softmax. After this rewrite, we can first compute $\phi(K)^T V$, a $d \times d$ matrix; the time complexity of this operation is $O(Nd^2)$, which is linear in $N$. While reducing the computational complexity of attention is theoretically very attractive, these methods still have shortcomings in practice, chiefly the following two:
- Model quality is inferior to that of the original attention mechanism. Whether sparse, low-rank, or otherwise, these methods all use an approximation to estimate the attention weight matrix, which inevitably loses information. As a result, the mainstream approach is still the original attention mechanism.
- These methods cannot reduce the time consumed by memory reads. They can only reduce the computational complexity of the attention mechanism, but cannot control the space complexity or other factors in the operation of the attention mechanism, and cannot reduce the time loss caused by memory reads and writes.
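The linear-transformer rewrite above relies only on the associativity of matrix multiplication. The NumPy sketch below checks that both evaluation orders agree; the feature map $\phi(x) = \mathrm{elu}(x) + 1$ is an assumed, commonly used choice, and the softmax normalization term is omitted just as in the formula.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d = 512, 16  # sequence length and feature dimension (illustrative values)

def phi(x):
    # Assumed feature map elu(x) + 1: x + 1 for x > 0, exp(x) otherwise (always positive).
    return np.where(x > 0, x + 1.0, np.exp(x))

Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
fQ, fK = phi(Q), phi(K)

# Quadratic order: build the N x N matrix first, O(N^2 d) work.
out_quadratic = (fQ @ fK.T) @ V
# Linear order: build the d x d matrix phi(K)^T V first, O(N d^2) work.
out_linear = fQ @ (fK.T @ V)

print(np.allclose(out_quadratic, out_linear))  # True
```

The two orders give the same result, but only the second avoids materializing an $N \times N$ matrix.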
Regarding space complexity, the basic idea of this line of work is to reduce the memory requirements of the attention mechanism by minimizing data movement between HBM and SRAM, thereby reducing the time attention takes. A representative approach is kernel fusion, which combines operations that would otherwise be performed step by step in multiple CUDA kernels into one or a few kernels, reducing the number of data transfers between HBM and SRAM and saving computation time.
0.3 Flash Attention
The authors of FlashAttention observed that although these Efficient Transformers effectively reduced the model's FLOPs, their wall-clock speed did not improve significantly. The root cause is that most Efficient Transformers focus only on FLOPs (the number of floating-point operations), a common metric for compute-intensive applications and for the cost of deep learning models. However, a model's actual speed depends not only on FLOPs but also on MAC (Memory Access Cost). Especially when the computation itself is already highly efficient, the MAC overhead cannot be ignored. MAC overhead comes from two sources: reading data from memory and writing data to memory. As on a CPU, when a GPU computes, data must first be read from GPU memory into the compute units, and after the computation the results are written back to GPU memory.
The work done by Flash Attention is embodied in its paper titled “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness,” as detailed below:
- Fast (with IO-Awareness). Previous methods for accelerating Transformer computation focused on reducing computational FLOPs, for example, using sparse attention for approximation. However, the authors of Flash Attention discovered that the bottleneck to slow computation was IO read/write speed rather than computational power. Therefore, Flash Attention improves overall computational speed by reducing the number of times GPU memory (HBM) is accessed, and this is IO-awareness. Specifically, reducing the number of HBM accesses is achieved through tiling and kernel fusion techniques.
- Memory-efficient (saves GPU memory). In standard attention, the forward pass stores the attention matrices $S = QK^T$ and $P = \mathrm{softmax}(S)$ in memory, and the backward pass reads the attention matrix to compute gradients, which is why the memory complexity is $O(N^2)$. Flash Attention instead keeps a few softmax statistics and changes the computation order of the attention mechanism, which avoids materializing the attention matrix and reduces the memory requirement to $O(N)$.
- Exact Attention provides precise attention with identical computational results. While “Sparse Attention,” preceding Flash Attention, is an approximate calculation that reduces computational cost, its results differ from standard Attention. Flash Attention, on the other hand, yields results identical to standard Attention.
In simple terms, the attention formula is:
$$O = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$
FlashAttention does not need to materialize the intermediate matrices in global memory; instead, it fuses the entire computation in the formula above into a single CUDA kernel, thus reducing I/O. For classic operations such as matrix multiplication, it also uses tiling to ensure that the data held in on-chip memory never exceeds the hardware limits.
0x01 Background Knowledge
Since large models are primarily trained and inferred on GPUs, we’ll first look at GPU-related knowledge and then examine the computational characteristics of Transformers.
1.1 GPU-related concepts
When learning and using CUDA, we often encounter many concepts such as SM, SP, and Grid, which can be confusing. We’ll provide a brief explanation below. These concepts generally fall into two categories:
- Hardware resources or concepts include: SP, SM, HBM, and SRAM;
- Software abstractions or concepts include: Thread, Warp, Block, and Grid;
Hardware concepts
Let’s start by looking at some hardware concepts.
Compute units
This mainly includes the concepts of SM (Streaming Multiprocessors) and SP (Streaming Processor). A GPU is composed of a series of SMs. An SM is the basic computing unit of a GPU, much like a core in a multi-core CPU chip. The difference is that a CPU core typically runs one thread, while an SM can run multiple lightweight threads. Each SM has a certain number of registers, on-chip memory, a control unit, and several SPs or other accelerated computing units. This on-chip memory and control unit are shared by all SPs. Furthermore, each SM is equipped with a hardware-based thread scheduler for executing threads.
Memory
We’ll use the A100-40GB as an example to illustrate the GPU’s memory hierarchy. Below is a diagram of the A100-40GB’s memory hierarchy.

The diagram is a three-tier pyramid. The bottom tier is the CPU’s main memory, which is large but slow. The top two tiers belong to the GPU. GPU memory consists of several kinds of memory with different sizes and read/write speeds, and it can be divided into on-chip and off-chip memory depending on whether it resides on the chip. The two types of memory on the NVIDIA A100-40GB are summarized below.
| Type | Name | Purpose | Size | Bandwidth | Characteristics |
|---|---|---|---|---|---|
| On-chip memory | SRAM (Static Random-Access Memory) | Primarily used for caching, plus a small number of special-purpose storage units (such as textures). | Distributed across 108 streaming multiprocessors at 192 KB each, for a total of 108 × 192 KB = 20,736 KB ≈ 20 MB. | 19 TB/s | Small capacity, high bandwidth |
| Off-chip memory | HBM (High Bandwidth Memory) | Primarily used as global memory, also known as video memory. | 40~80 GB | 1.5~2.0 TB/s | Large capacity, low bandwidth |
Note that in this article, SRAM refers to the L1 cache (the combination of shared memory and the L1 data cache).
As we can see, the bandwidth of video memory is much smaller than that of SRAM, making data reading time-consuming. However, SRAM storage is too small to hold much data. Therefore, we use the storage capacity of SRAM as the upper limit, trying to ensure that each data load fills the SRAM as much as possible, thus saving data reading time.
Software Concepts
Execution model
A CUDA program can be divided into two parts (each with its own memory):
- The program running on the CPU is called the Host program, or you can think of the CPU as the Host.
- Programs running on the GPU are called Device programs, also known as Kernel functions. Alternatively, the GPU can be understood as a Device.
The typical way the GPU performs the operation involves the following steps:
- The CPU sends calculation instructions to the GPU;
- Data is copied from the CPU’s memory to the GPU’s memory, i.e., HBM;
- The GPU loads input data from the low-speed HBM into the high-speed SRAM;
- The GPU distributes computing tasks to various SMs for parallel processing.
- SM reads data from SRAM to perform computational operations;
- After the calculation is complete, the calculation results are written from SRAM to HBM;
- The calculation results are then copied from HBM to CPU memory;
Thread Model
On a GPU, multiple threads are needed to execute the kernel. For example, in vector addition, if we want to add 256-dimensional vectors, we can use 256 threads to process them in parallel, so that each thread can process one element of the vector. If the data is larger, there may not be enough threads available on the GPU, and we may need each thread to process multiple data points. Therefore, programmers need to carefully configure the threads based on the size of the data and the required level of parallelism.
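The sketch below mimics, in plain Python, how each thread derives a global index from its block and thread coordinates (the roles played by CUDA's `blockIdx.x`, `blockDim.x`, and `threadIdx.x`), and uses a grid-stride loop when there are fewer threads than elements. It only illustrates the indexing scheme under that assumption; real CUDA threads run concurrently, not in nested loops.

```python
def vector_add(a, b, grid_dim, block_dim):
    """Simulate a CUDA-style grid-stride loop on the CPU (no GPU involved)."""
    out = [0.0] * len(a)
    total_threads = grid_dim * block_dim
    for block_idx in range(grid_dim):               # plays the role of blockIdx.x
        for thread_idx in range(block_dim):         # plays the role of threadIdx.x
            i = block_idx * block_dim + thread_idx  # global thread index
            while i < len(a):                       # a thread may handle several elements
                out[i] = a[i] + b[i]
                i += total_threads                  # stride by the total thread count
    return out

print(vector_add([1.0, 2.0], [3.0, 4.0], grid_dim=1, block_dim=2))  # [4.0, 6.0]
```

With 256 elements and one block of 256 threads, each thread handles exactly one element; with 1,000 elements and 2 blocks of 128 threads, each thread loops up to 4 times.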
To facilitate programmers in designing and organizing threads, CUDA programming abstracts software resources into a thread model, which includes concepts such as Grid, Block, Thread, and Warp. The software abstraction and hardware resources corresponding to each concept are as follows.
- Thread: The basic unit of parallel execution. A CUDA parallel program is executed by multiple threads, and the thread is the most basic unit of execution. Thread execution is handled by SPs. One SP can execute one thread.
- Block: A block consists of several threads. A block occupies one SM for execution.
- Grid (Threaded Grid): Multiple blocks will form a Grid. One Kernel function corresponds to one Grid. The Grid runs on the device.
- Warp: The scheduling unit during program execution. On NVIDIA GPUs a warp consists of 32 threads. The threads in a warp execute the same instruction simultaneously, achieving SIMT (Single Instruction, Multiple Threads) parallelism. The warp is the smallest scheduling unit on an SM, and an SM can handle multiple warps concurrently.
Grid, Block, and Thread represent three levels of thread organization, a software architecture independent of hardware. Therefore, theoretically, we can arrange Grid, Block, and Thread in any dimension (one-dimensional, two-dimensional, three-dimensional). This software architecture corresponds to individual SMs or SPs in hardware. Hardware itself doesn’t have dimensions; it’s just an abstraction of dimensions in software. See the diagram below for details.

The specific explanations of these software concepts and hardware resources are as follows, and we will introduce them in a top-down hierarchy.
Grid & Device
The role of the Grid is to control the number of threads and perform differentiated execution. CUDA allows individual Kernel functions in the Host program to execute on the device according to the concept of a Grid. One Kernel function corresponds to one Grid; when a Grid runs on a device, it may exclusively occupy one device, or multiple kernels may concurrently occupy one device.
Block & SM
A Block is a thread block; threads within the same block can synchronize and communicate quickly through shared memory. Each Grid undertakes the task of one kernel function. When executing the task, the Grid divides it into several Blocks (thread blocks) that run on the SMs. The relationship between Block and SM is:
- Different blocks within the same Grid may be dispatched to different SMs for execution. A block’s thread can only be scheduled on one SM; that is, a block cannot cross SMs.
- Multiple blocks can execute concurrently on one SM, and they do not necessarily come from the same kernel function. Sometimes, even if the remaining resources on an SM are insufficient for a block of kernel A, they may still be enough for a block of kernel B. Blocks that do not fit wait and are dispatched to the SM in turn as resources free up.
- Each thread occupies a certain number of registers, and each block occupies a certain amount of shared memory, so the number of blocks resident on an SM at the same time cannot exceed what these hardware resources allow.
- A thread block can contain multiple warps. Threads within the same block can synchronize and communicate via shared memory. A thread block is the unit of work assigned to an SM, and all threads of a warp belong to the same block. If the number of threads in a block is not an integer multiple of the warp size, the last warp contains some inactive threads. In other words, even when a warp is not full, the hardware pads it with inactive threads, which still occupy that warp's execution resources.
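The padding rule in the last point can be checked with a little arithmetic. The helper below is a hypothetical illustration that assumes a warp size of 32; it returns how many warps a block occupies and how many lanes of the last warp sit idle.

```python
WARP_SIZE = 32  # warp size on current NVIDIA GPUs

def warp_layout(threads_per_block: int):
    """Return (number of warps, inactive lanes in the last warp) for one block."""
    num_warps = -(-threads_per_block // WARP_SIZE)   # ceiling division
    inactive = num_warps * WARP_SIZE - threads_per_block
    return num_warps, inactive

print(warp_layout(256))  # (8, 0): evenly divisible, no idle lanes
print(warp_layout(100))  # (4, 28): the last warp has only 4 active threads
```

This is one reason block sizes are usually chosen as multiples of 32.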
Thread & SP
A CUDA program (i.e., a kernel task) is ultimately broken down into threads. Local variables in each thread are mapped to registers in the SM, and the execution of the thread is handled by the CUDA Core, or SP.
Thread & Warp
Because the size of a block is variable, we cannot practically allocate a CUDA core array of the same size for parallel computation of any block. To better manage and execute threads, GPUs employ the SIMT (Single Instruction Multiple Threads) architecture, introducing the concept of a warp. Let’s first look at the differences between SIMT and SIMD.
- The CPU uses SIMD to process vector data. However, SIMD alone cannot execute code with conditional branches in parallel, because the branch taken can differ from one data element (thread) to another depending on the input.
- GPUs use SIMT to process data. Developers don’t need to painstakingly shape the data into a suitable vector length, and SIMT allows each thread to have different branches, enabling parallel operations across different branches.
A warp is the smallest scheduling/execution unit in the GPU programming architecture. Threads within the same warp execute the same instructions, i.e., SIMT. Blocks are divided into warps, and each warp is mapped to a CUDA core array for execution. A warp can be understood as a container for threads that guarantees a fixed number of threads and allocates uniform hardware resources to them. Each container carries only one type of "cargo", meaning its threads execute in lockstep. Typically, 32 threads form a warp and execute the same instruction in parallel within the same clock cycle. Each thread can access its own registers. Different warps read the data they need from SRAM during computation; that is, different warps load and store data at different addresses and can follow different control flow paths.
Summary
Now, we will combine the GPU’s computing core SM and the different levels of GPU storage structure to draw a simplified diagram.
- Registers: Each SM in a GPU has a large number of registers. These registers are shared between cores and dynamically allocated based on thread needs. During execution, each thread is allocated private registers that cannot be read or written by other threads.
- L1 cache/shared memory: Each SM has its own L1 cache for storing data within the SM, shared by all CUDA cores in that SM; SMs cannot access each other’s L1 cache. In the NVIDIA Volta architecture, L1 and shared memory were merged to further reduce latency. After the merge, users can still directly control shared memory in their code, and can also control how much of the L1 storage is allocated to shared memory. In FlashAttention, SRAM refers to this L1 cache/shared memory.
- L2 cache: All SMs share the L2 cache. The bandwidth of L1/L2 cache is greater than that of video memory, meaning faster read and write speeds, but their storage capacity is smaller.
- HBM: video memory, i.e., the GPU's global memory.

1.2 Transformer Memory and Computation
From a computational science perspective, performance bottlenecks in operations fall into two categories: computationally limited (Compute-bound or math-bound) and memory-limited (Bandwidth-bound or Memory-bound). To reduce the computational and space complexity of the Transformer model, we need to identify whether the resource bottleneck of the Transformer’s core component, the attention mechanism, is computational power or GPU memory. This will allow us to determine which aspect to optimize.
Basic Concepts
We will begin our analysis by starting with the basic concepts.
- Math bandwidth $\pi$. This can be understood as computing power: the number of mathematical operations a processor can perform per second, usually measured in OPS (operations per second), or in FLOPS (floating-point operations per second) for floating-point work.
- Memory bandwidth $\beta$. The amount of data a processor can read from memory per second, measured in bytes per second.
- Arithmetic intensity $I$. This describes an algorithm's demand on memory bandwidth: how many floating-point operations (FLOPs) it performs per byte of data read or written. It is the total number of FLOPs divided by the total number of bytes accessed (also called MOPs, memory operations): $I = \#\text{FLOPs} / \#\text{Bytes}$.
- Platform intensity limit $I_{max}$. The maximum number of operations the platform can perform per byte of memory traffic, in FLOPs/Byte. Dividing math bandwidth by memory bandwidth gives this upper limit: $I_{max} = \pi / \beta$.
- Theoretical performance $P$. The theoretical number of floating-point operations per second the model can achieve on the platform, in FLOPS (FLOP/s). Under the roofline model, $P = \min(\pi, \beta \cdot I)$.
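These definitions translate directly into a few lines of code. The sketch below assumes A100-40GB figures of $\pi$ = 312 TFLOP/s (FP16 Tensor Core, dense) and $\beta$ = 1.555 TB/s; the matmul intensity counts each matrix as touched exactly once, which is an idealization.

```python
# Assumed A100-40GB figures: FP16 Tensor Core dense peak and HBM bandwidth.
PI_FLOPS = 312e12        # pi, math bandwidth in FLOP/s
BETA_BYTES = 1.555e12    # beta, memory bandwidth in bytes/s

def platform_intensity(pi: float = PI_FLOPS, beta: float = BETA_BYTES) -> float:
    """I_max = pi / beta, the platform's FLOPs-per-byte limit."""
    return pi / beta

def matmul_intensity(m: int, k: int, n: int, bytes_per_el: int = 2) -> float:
    """I = FLOPs / bytes for C[m, n] = A[m, k] @ B[k, n], touching each matrix once."""
    flops = 2 * m * k * n                                # one multiply + one add per term
    bytes_moved = bytes_per_el * (m * k + k * n + m * n)
    return flops / bytes_moved

def attainable_performance(intensity: float, pi: float = PI_FLOPS,
                           beta: float = BETA_BYTES) -> float:
    """Roofline estimate P = min(pi, beta * I), in FLOP/s."""
    return min(pi, beta * intensity)

print(f"I_max = {platform_intensity():.1f} FLOPs/byte")   # about 200.6
for shape in [(4096, 4096, 4096), (1, 4096, 4096)]:       # square GEMM vs. matrix-vector
    i = matmul_intensity(*shape)
    print(shape, f"I = {i:.1f}", f"P = {attainable_performance(i) / 1e12:.1f} TFLOP/s")
```

A large square GEMM lands far above $I_{max}$, while a matrix-vector product sits near 1 FLOP/byte and is capped by $\beta$.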
Computational and memory-constrained
The program’s execution time is spent mainly in two places: computation and data reading/writing. This gives the following two times:
$$T_{compute} = \frac{\#\text{FLOPs}}{\pi}, \qquad T_{memory} = \frac{\#\text{Bytes}}{\beta}$$
In general, computation and memory access can overlap ("compute while reading/writing the next item"), so the total runtime is $T = \max(T_{compute}, T_{memory})$.
- Computationally bound. When computation time exceeds memory access time, meaning most of the time for a certain operation is spent on the GPU’s streaming multiprocessors, it indicates that computational bandwidth is the bottleneck of the algorithm. Fast reads, slow computation, this is computationally bound. In this case, the time spent accessing HBM is relatively low, and regardless of the model’s computational intensity, its theoretical performance can only be equal to the computing power of the computing platform. Examples include large matrix multiplications and convolution operations with a large number of channels.
- Memory-bound. When memory access time exceeds computation time, meaning part of the time spent completing an operation involves moving data from memory to the streaming multiprocessor rather than actually computing on the streaming multiprocessor, it indicates that memory bandwidth is the bottleneck of the algorithm. Fast computation, slow read speed, this is memory-bound. When the computational intensity of the model is less than the upper limit of the computing platform’s capacity, the theoretical performance of the model is entirely determined by the bandwidth limit of the computing platform and the computational intensity of the model itself. Pointwise operations are mostly memory-bound, such as activation functions, dropout, and masking; reduction operations are also memory-bound, such as sum, softmax, batch normalization, and layer normalization.
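The two cases can be told apart with the time model $T = \max(T_{compute}, T_{memory})$. The sketch below applies it to a pointwise activation and to a large matmul, using the same assumed A100 figures; it ignores kernel launch overhead, caches, and imperfect overlap.

```python
def estimate_runtime(flops: float, bytes_moved: float,
                     pi: float = 312e12, beta: float = 1.555e12):
    """Return (T, bound), assuming compute and memory traffic overlap perfectly."""
    t_compute = flops / pi            # T_compute = #FLOPs / pi
    t_memory = bytes_moved / beta     # T_memory  = #Bytes / beta
    bound = "compute-bound" if t_compute > t_memory else "memory-bound"
    return max(t_compute, t_memory), bound

# Pointwise op (e.g. an activation) on n FP16 values: ~1 FLOP, 2-byte read + 2-byte write each.
n = 2**26
print(estimate_runtime(flops=n, bytes_moved=4 * n))
# Square FP16 matmul of side s: 2*s^3 FLOPs, three s x s matrices moved once.
s = 8192
print(estimate_runtime(flops=2 * s**3, bytes_moved=2 * 3 * s**2))
```

The activation spends almost all its time moving data, while the matmul is limited by arithmetic, matching the examples listed above.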
Computational intensity of attention mechanism
To evaluate the bottlenecks in Transformers, we need to model the number of floating-point operations (FLOPs) required by encoder-only and decoder-only Transformer models, as well as the arithmetic intensity of these networks. The most crucial part of the attention computation is the attention weights, so let’s examine their arithmetic intensity. Assume $Q, K, V \in \mathbb{R}^{N \times d}$, and compute $S = QK^T$, $P = \mathrm{softmax}(S)$, and $O = PV$, where $d$ is the attention head dimension. Referring to the example below, the arithmetic intensity of the attention weight computation is as follows:
Note: Some papers or blogs omit steps 3 and 4, so the calculation of MAC will differ from that in this article.

Whether a matrix multiplication is compute-bound or memory-bound depends on its arithmetic intensity relative to that of the platform. The A100-40GB SXM has a platform intensity of about 201 FLOPs/byte. If a matrix multiplication's arithmetic intensity exceeds 201, its performance is limited by math bandwidth; otherwise it is limited by memory bandwidth. A GPU's compute throughput far outpaces its memory bandwidth, so for memory-intensive workloads like attention, generation speed is determined not by the GPU's computing power but by its memory bandwidth. Furthermore, some operations in attention are memory-bound pointwise operations, such as the mask applied to S, the softmax, and the dropout applied to P; their performance is also limited by memory bandwidth.
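As a rough check, the sketch below computes the arithmetic intensity of step 1, $S = QK^T$, assuming FP16 elements, that Q and K are each read once and S is written once, and no on-chip reuse beyond that. Under these assumptions the intensity simplifies to $Nd/(N + 2d)$, which never exceeds $d$, so with typical head dimensions of 64 or 128 it stays below the 201 FLOPs/byte platform intensity quoted for the A100-40GB.

```python
def qk_intensity(n: int, d: int, bytes_per_el: int = 2) -> float:
    """FLOPs per byte for S = Q @ K^T when Q, K are read from and S is written to HBM."""
    flops = 2 * n * n * d                             # N*N dot products of length d
    bytes_moved = bytes_per_el * (2 * n * d + n * n)  # read Q and K, write S
    return flops / bytes_moved                        # simplifies to N*d / (N + 2d)

RIDGE = 201  # A100-40GB platform intensity, FLOPs/byte

for n in (512, 2048, 8192):
    for d in (64, 128):
        i = qk_intensity(n, d)
        print(f"N={n:>5}, d={d:>3}: {i:6.1f} FLOPs/byte ->",
              "memory-bound" if i < RIDGE else "compute-bound")
```

Every combination prints "memory-bound", because the $N \times N$ output dominates the bytes moved as $N$ grows.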
How to balance
Researchers have analyzed the arithmetic intensity of BERT-Base and BERT-Large encoders and the GPT-2 decoder at different sequence lengths.
- For short sequence lengths (e.g., 128-512), most of the computation is performed in the projection layers of the FFN module, and most of the MHA computation is likewise performed in its projection layers.
- As the sequence length increases, matrix multiplications begin to dominate because their cost grows quadratically with the sequence length. Arithmetic intensity initially rises, since larger matrix dimensions allow more computation per loaded parameter.
- However, at still longer sequence lengths, arithmetic intensity decreases. This is because, for long sequences, the matrix multiplications and softmax computations in the MHA module begin to dominate, and these have relatively low arithmetic intensity compared with the projection layers in the FFN module.
These observations show that decoder inference is a memory-bound problem rather than a compute-bound one. So what batch size is needed to balance GPU computing power and memory bandwidth? With 2 bytes per parameter and 2 FLOPs per parameter per token, the time to read the weights must equal the time to compute:
$$\frac{2 \times \#\text{params}}{\#\text{cards} \times \beta} = \frac{B \times 2 \times \#\text{params}}{\#\text{cards} \times \pi}$$
The number of parameters and the number of cards cancel out on both sides, leaving batch size $B = \pi / \beta$. This value must be adjusted for the specifications of each chip, and network latency and the overhead of the communication library itself must also be considered.
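The balance equation can be checked numerically. The sketch below assumes FP16 weights read once per decoding step, about 2 FLOPs per parameter per token, the A100 figures used earlier, and a hypothetical 70B-parameter model on 8 cards; activations, the KV cache, and communication are ignored.

```python
def decode_step_times(num_params, batch, num_cards, pi=312e12, beta=1.555e12):
    """Return (memory time, compute time) in seconds for one decoding step (weights only)."""
    t_memory = 2 * num_params / num_cards / beta            # 2 bytes per FP16 parameter
    t_compute = batch * 2 * num_params / num_cards / pi     # ~2 FLOPs per parameter per token
    return t_memory, t_compute

balanced_batch = 312e12 / 1.555e12   # B = pi / beta, about 200.6
t_mem, t_cmp = decode_step_times(num_params=70e9, batch=balanced_batch, num_cards=8)
print(f"memory {t_mem * 1e3:.3f} ms vs compute {t_cmp * 1e3:.3f} ms")
```

At $B = \pi/\beta$ the two times coincide regardless of model size or card count; below it, the step is dominated by reading weights.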
1.3 Tiling
Tiling is a technique that reduces memory consumption by splitting the input into blocks, processing them one at a time, and maintaining some intermediate variables. It works for matrix multiplication because addition is associative, so the full product can be decomposed into a sum of many smaller tile products.
For large matrices, operating directly on the entire matrix would consume enormous amounts of memory. We know that matrix multiplication has the characteristics of block multiplication and accumulation. Therefore, a large matrix multiplication can be decomposed into smaller submatrices using the Tiling technique. These smaller matrices are then loaded separately from the slow HBM into the fast SRAM, where calculations are performed. Finally, the results of each block matrix multiplication are accumulated to obtain the final correct result.
The diagram below briefly explains how tiled matrix multiplication $C = AB$ works. The input and output matrices are partitioned, with each matrix divided into tiles. For each output tile, we scan the relevant tiles of A from left to right and the relevant tiles of B from top to bottom, loading their values from global memory into on-chip memory (colored blue in the diagram). For position $(i, j)$, we load $A[i, k]$ and $B[k, j]$ from on-chip memory for all $k$ within the tile and accumulate $\sum_k A[i,k] \times B[k,j]$ into $C[i, j]$ in on-chip memory. After the computation for one tile is complete, we write the on-chip C tile back to main memory and continue with the next tile.
Alternatively, we can load the data required for computation from HBM to SRAM in advance or asynchronously, and by combining this with pipeline orchestration, we can further hide the time required for data loading.

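A Python (NumPy) version of this per-tile computation is as follows:

```python
import numpy as np

def compute_output_tile(A_row_tiles, B_col_tiles):
    """C_ij = sum over k of A_ik @ B_kj, given row i of A's tiles and column j of B's tiles."""
    c = np.zeros((A_row_tiles[0].shape[0], B_col_tiles[0].shape[1]))  # on-chip accumulator C_ij
    for a_k, b_k in zip(A_row_tiles, B_col_tiles):  # A tiles left->right, B tiles top->bottom
        c += a_k @ b_k                              # accumulate one tile product
    return c                                        # finished C_ij is written back once
```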
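For a complete picture, the sketch below drives the same idea over every output tile and checks the result against a direct product. The `tile` parameter is an assumed stand-in for the largest block that fits in SRAM; NumPy runs everything on the CPU, so this shows the access pattern, not GPU performance.

```python
import numpy as np

def tiled_matmul(A, B, tile=32):
    """Blocked matrix multiplication: C = A @ B computed one output tile at a time."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.result_type(A, B))
    for i in range(0, M, tile):              # output tile row
        for j in range(0, N, tile):          # output tile column
            acc = np.zeros((min(tile, M - i), min(tile, N - j)), dtype=C.dtype)
            for k in range(0, K, tile):      # slide along the shared dimension
                acc += A[i:i + tile, k:k + tile] @ B[k:k + tile, j:j + tile]
            C[i:i + tile, j:j + tile] = acc  # one write per output tile
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 70))
B = rng.standard_normal((70, 90))
print(np.allclose(tiled_matmul(A, B), A @ B))  # True
```

Edge tiles are simply smaller, so the matrix sizes do not need to be multiples of the tile size.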
1.4 Operator Fusion
In inference engine implementation, a common way to accelerate operations whose performance is limited by memory bandwidth is operator fusion. The basic idea is to merge multiple operations into one operation when SRAM storage is sufficient, thereby avoiding the repeated execution of “reading input data from HBM, performing calculations, and writing the calculation results back to HBM”.
We will analyze this through an example. Suppose we want to execute operators A and B consecutively, where the output of operator A is the input of operator B. The most basic execution order is as follows:
- Start operator A and copy the data required by A from HBM to SRAM.
- Run operator A.
- Write the result of operator A back to HBM.
- Start operator B and copy the data required by B from HBM to SRAM.
- Run operator B.
- Write the result of operator B back to HBM.
This sequence involves four HBM reads/writes and two kernel launches, all of which add to the runtime.
Following the operator fusion approach, if SRAM can hold the output of operator A, we merge operators A and B into a single operation. The output of A then stays in SRAM for B to read, which reduces both the number of HBM reads/writes and the number of kernel launches, effectively reducing the runtime of memory-bound operations.
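The toy model below counts HBM transfers for the A-then-B example. `HBM`, `op_a`, and `op_b` are hypothetical stand-ins: each read or write of a whole array counts as one transfer, and kernel launches are not modeled.

```python
import numpy as np

class HBM:
    """Toy model of global memory that counts whole-array transfers (hypothetical helper)."""
    def __init__(self):
        self.reads = self.writes = 0
        self.store = {}

    def read(self, name):
        self.reads += 1
        return self.store[name]

    def write(self, name, value):
        self.writes += 1
        self.store[name] = value

def op_a(x): return x * 2.0               # stand-in for operator A
def op_b(y): return np.maximum(y, 0.0)    # stand-in for operator B

def unfused(hbm):
    hbm.write("y", op_a(hbm.read("x")))   # kernel 1: read x, write y
    hbm.write("z", op_b(hbm.read("y")))   # kernel 2: read y, write z

def fused(hbm):
    y = op_a(hbm.read("x"))               # y stays on chip
    hbm.write("z", op_b(y))               # one kernel: read x, write z

for run in (unfused, fused):
    hbm = HBM()
    hbm.store["x"] = np.linspace(-1.0, 1.0, 8)  # input already resident in HBM
    run(hbm)
    print(run.__name__, "HBM transfers:", hbm.reads + hbm.writes)
```

The unfused version makes 4 transfers and the fused version 2, with identical results.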
0x02 Optimize the attention mechanism
Because FlashAttention optimizes the memory access (HBM) process during attention computation, let’s first look at the memory access process of the standard attention mechanism.
2.1 Standard Attention Mechanism
Calculation formula
The formula for the Scaled Dot-Product Attention module is as follows:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
In this formula, Q and K both have dimensions $N \times d_k$ and V has dimensions $N \times d_v$, where $N$ is the length of the input sequence and $d_k$, $d_v$ are feature dimensions. $\mathrm{softmax}(QK^T/\sqrt{d_k})$ has dimensions $N \times N$, and the output $\mathrm{Attention}(Q, K, V)$ has dimensions $N \times d_v$.
For ease of description, the mask and scale are omitted in the following discussion. Since every head in multi-head attention follows the same computational logic, only the single-head case is described here. Assuming there are N tokens in total and each token vector has dimension d, a simplified attention computation process is shown in the following diagram:

Implementation Algorithm
The implementation algorithm of the standard attention mechanism given in the FlashAttention paper is shown in the figure below. The algorithm consists of three steps (also known as the 3-pass algorithm):
- $S = QK^T$ (compute attention scores). The goal is the dot product of each query with every key. Intuitively, the larger the dot product, the stronger the correlation between a given row of Q and a given column of $K^T$. In practice, the Q and K matrices are loaded from HBM, the product $S = QK^T$ is computed to obtain the similarity scores, and S is then written back to HBM.
- $P = \mathrm{softmax}(S)$ (compute attention weights). The softmax normalizes the attention scores: S is read from HBM, $P = \mathrm{softmax}(S)$ is computed to obtain the attention weights, and P is written back to HBM.
- $O = PV$ (compute the final attention output). P and V are read from HBM, $O = PV$ is computed, and finally O is written back to HBM.

Note: masking and dropout are omitted in the algorithm. Q, K, V, and O are all 2D matrices of shape $N \times d$, where N is the sequence length and d is the dimension of the attention head.
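The three passes map onto a few lines of NumPy. In the sketch below, which omits mask, scale, and dropout as the algorithm does, S and P are full $N \times N$ arrays, standing in for what a GPU implementation writes back to HBM between kernels. Subtracting the row maximum before exponentiating is a standard numerical-stability step that does not change the softmax result.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Three-pass attention: S = QK^T, P = softmax(S), O = PV (mask/scale/dropout omitted)."""
    S = Q @ K.T                                   # pass 1: attention scores, N x N
    P = np.exp(S - S.max(axis=1, keepdims=True))  # subtract row max for numerical stability
    P /= P.sum(axis=1, keepdims=True)             # pass 2: row-wise softmax, N x N
    O = P @ V                                     # pass 3: weighted sum of values, N x d
    return O

rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
print(O.shape)  # (256, 64)
```

Each row of P sums to 1, so every output row is a weighted average of the rows of V.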
We will illustrate the above algorithm with the following diagram.

Detailed disassembly
The diagram above does not show the interaction between SRAM and HBM. We have found a more detailed algorithm implementation from other papers, which is shown below.

The following diagram illustrates the interaction process between SRAM and HBM in the algorithm and the amount of data read and written. The numbers in the diagram are consistent with the numbers in the algorithm above.
Note: Some papers or blogs omit steps 3 and 4, so the calculation of MAC will differ from that in this article.

The problem
Standard attention algorithms suffer from two drawbacks on the GPU's hierarchical memory architecture: high memory consumption and numerous HBM reads and writes. The culprit behind both is the product $S = QK^\top$. On the one hand, this operation determines the algorithmic complexity of attention, which is quadratic in the sequence length, $O(N^2)$. On the other hand, the two intermediate matrices S and P it produces occupy too much memory and must be moved between HBM and SRAM. Because HBM bandwidth is much lower than SRAM bandwidth, this slows down the running time (wall-clock time). We analyze this in detail below.
- High memory usage. The inputs and output of the 3-pass algorithm, Q, K, V, and O, require $O(Nd)$ memory. Steps one and two generate the intermediate matrices S and P, each requiring $O(N^2)$ memory, so the total memory requirement is $O(Nd + N^2)$. When the sequence length N is very large (i.e., $N \gg d$), S and P require $O(N^2)$ memory, far more than the $O(Nd)$ needed by the inputs and output. This can exhaust GPU memory, and at the same time the memory-access pressure on GPU HBM rises sharply to $O(N^2)$.
- Numerous HBM reads and writes. The intermediate matrices are too large to fit in SRAM, so they must be moved from SRAM to HBM. However, S and P are needed again right after being stored in HBM, which leads to repeated HBM reads and writes. The three steps of the 3-pass algorithm correspond to three kernels (gemm, softmax, and gemm) that are executed sequentially. Each kernel reads data from HBM, computes, and writes the result back to HBM. In total there are eight HBM matrix reads/writes, for an overall HBM access complexity of $O(Nd + N^2)$. The eight operations are as follows:
- The first step involves three operations. Two reads load the complete Q and K matrices (each $N \times d$) from HBM, and one write stores the similarity scores S ($N \times N$) back to HBM. This step needs $O(Nd + N^2)$ memory accesses and touches the very large matrix S once (a write).
- The second step involves two operations. One read loads the complete S matrix from HBM, and one write stores P ($N \times N$) back to HBM. This step needs $O(N^2)$ memory accesses and touches very large matrices twice: one read of S and one write of P.
- The third step involves three operations. Two reads load the complete P ($N \times N$) and V ($N \times d$) matrices from HBM, and one write stores the output O ($N \times d$) back to HBM. This step needs $O(Nd + N^2)$ memory accesses and touches the very large matrix P once (a read).
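The eight operations above can be turned into a rough traffic estimate. The sketch below is a minimal illustration built on simplifying assumptions: every matrix is read or written exactly once per step, there is no caching, and traffic is counted in elements rather than bytes. The helper `standard_attention_hbm_elements` is hypothetical.

```python
def standard_attention_hbm_elements(N: int, d: int) -> dict:
    """Hypothetical helper: elements moved to/from HBM by the 3-kernel attention."""
    step1 = N * d + N * d + N * N   # gemm: read Q, read K, write S
    step2 = N * N + N * N           # softmax: read S, write P
    step3 = N * N + N * d + N * d   # gemm: read P, read V, write O
    return {"step1": step1, "step2": step2, "step3": step3,
            "total": step1 + step2 + step3}   # = 4*N*d + 4*N*N


counts = standard_attention_hbm_elements(N=4096, d=64)
share_sp = 4 * 4096 * 4096 / counts["total"]   # fraction of traffic caused by S and P
print(counts["total"], round(share_sp, 3))
```

Under these assumptions, S and P account for a fraction N / (N + d) of the traffic. For N = 4096 and d = 64, that is about 98.5%.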
2.2 Solution
Now that we know $QK^\top$ (together with the S and P matrices it produces) is the culprit, let's think about how to shrink the memory needed for intermediate results. If they are small enough to stay in SRAM, I/O traffic drops and so does the time spent on I/O.
Ideas
Our goal is to calculate O. Generally, we need to obtain all Q, K, and V, and then calculate in three steps. Alternatively, we can first obtain a small piece of Q, K, and V, calculate a portion of O in one step, and then find a way to combine the portion of O into all of O.
As mentioned earlier, the attention mechanism ($O = \mathrm{softmax}(QK^\top)V$) has three main computational modules: computing the attention scores, normalization, and the weighted sum using the attention weights. These correspond to three kernels executed sequentially: gemm (query × key), point-wise softmax, and gemm (attn_score × value). If SRAM can hold the intermediate results, we can merge these three kernels and keep the intermediate data in SRAM. This avoids repeatedly writing intermediates to HBM global memory and reading them back, which speeds up the memory-bound point-wise operations in particular. Note that for now we set aside the special nature of softmax and assume it can be merged.
Therefore, our overall solution is to use “fusion + partitioning” to avoid frequent reads and writes of large matrices from HBM. This means eliminating reads and writes to large matrices S and P. Fusion + partitioning are two sides of the same coin, intertwined, and require a unified analysis of the overall approach.
- To reduce I/O, operator fusion should be performed centered around two gemm kernels.
- Fusion requires keeping all intermediate variables on chip instead of writing them back to HBM;
- SRAM doesn’t have enough space to hold the intermediate matrices, so block partitioning is necessary during fusion. As long as the block matrices and intermediate attention results can be stored in SRAM, only SRAM needs to be accessed during computation.
Next, we’ll look at operator fusion and block computation.
Operator fusion
For attention, our approach is to merge the two gemm operators and the softmax operator into a single operator, which optimizes data input and output. The whole computation $\mathrm{softmax}(QK^\top)V$ is then done in SRAM in one go, reducing reads and writes of S and P.
The standard attention algorithm computes $S = QK^\top$ in SRAM and writes S to HBM. It then reads S back from HBM into SRAM and computes $P = \mathrm{softmax}(S)$.
Under the operator fusion scheme, the above operations can be combined into one kernel. That is, after S is calculated in SRAM, P is calculated immediately through S, thus avoiding the exchange of S between HBM and SRAM.
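A minimal pure-Python sketch of the difference follows. It assumes that one full row of S fits in fast memory. The function names and the list-based "memory" are illustrative, not a GPU implementation. The unfused version materializes all of S and P. The fused version turns each row of S into a row of P and multiplies it by V right away, so no N × N matrix is ever stored.

```python
import math


def softmax(row):
    """Safe softmax of one row (a Python list)."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def unfused_attention(Q, K, V):
    """Three separate 'kernels': the full N x N matrices S and P are materialized."""
    S = [[sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]   # kernel 1: S = Q K^T
    P = [softmax(row) for row in S]                                 # kernel 2: P = softmax(S)
    return [[sum(p * v[c] for p, v in zip(prow, V)) for c in range(len(V[0]))]
            for prow in P]                                          # kernel 3: O = P V


def fused_attention(Q, K, V):
    """One 'kernel': each row of S becomes a row of P and is consumed by V immediately."""
    out = []
    for q in Q:
        s_row = [sum(a * b for a, b in zip(q, k)) for k in K]   # one row of S, never stored globally
        p_row = softmax(s_row)                                 # P row computed right away from S row
        out.append([sum(p * v[c] for p, v in zip(p_row, V)) for c in range(len(V[0]))])
    return out


Q = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4]]
K = [[0.2, 0.1], [-0.3, 0.4], [0.0, 0.6]]
V = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
print(unfused_attention(Q, K, V))
print(fused_attention(Q, K, V))
```

The fused version still needs one whole row of S at a time. Once a row no longer fits in SRAM, fusion alone is not enough.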

Block computation
As mentioned earlier, operator fusion requires enough SRAM capacity; in other words, fusion is only feasible if SRAM can hold the intermediate results. So while fusion is effective, it does not by itself solve the problem of excessive memory overhead.
For example, in the diagram below, SRAM can hold 10,000 values, while Q and K each contain 5,000 values. If the fused operator computes $QK^\top$ in one go, all 10,000 values must be loaded into SRAM. That leaves no room for intermediate results and causes an out-of-memory error (OOM). The computation therefore has to be performed iteratively, which again results in many HBM reads and writes.

Because SRAM has limited memory, it’s impossible to compute the entire attention layer at once. Since fully connected layers and weighted summation based on attention weights are implemented using matrix multiplication, tiling operations can be used for block computation. In block computation, only the necessary blocks of Q, K, and V involved in the computation are loaded into SRAM, ensuring the total memory usage does not exceed the SRAM size. Furthermore, after calculating S, S is directly used to calculate P. This improves overall read/write speed (reducing the number of HBM accesses). See the diagram below for details.
- Divide Q[100, 50] into two smaller matrices.
- Divide K[100, 50] into two smaller matrices.
- The operator can now compute the attention for these small blocks in one go inside SRAM.
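The partitioning above can be sketched in pure Python. This minimal illustration assumes that Q and K are split by rows into blocks of 50, giving two blocks each for Q[100, 50] and K[100, 50]. The helper names are hypothetical, and Python lists stand in for on-chip tiles. Each tile $S_{ij} = Q_i K_j^\top$ needs only $Q_i$ and $K_j$, and stitching the tiles together reproduces the full $QK^\top$.

```python
import random


def matmul_nt(A, B):
    """A @ B^T for matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(ra, rb)) for rb in B] for ra in A]


def split_rows(M, block):
    """Split a matrix into consecutive row blocks of `block` rows."""
    return [M[i:i + block] for i in range(0, len(M), block)]


def tiled_scores(Q, K, block):
    """Compute S = Q K^T tile by tile; only Q_i, K_j and S_ij are needed at a time."""
    S = [[0.0] * len(K) for _ in range(len(Q))]
    for bi, Qi in enumerate(split_rows(Q, block)):
        for bj, Kj in enumerate(split_rows(K, block)):
            Sij = matmul_nt(Qi, Kj)                      # small tile that fits in "SRAM"
            for r, row in enumerate(Sij):
                S[bi * block + r][bj * block:bj * block + len(row)] = row
    return S


random.seed(0)
Q = [[random.uniform(-1, 1) for _ in range(50)] for _ in range(100)]   # Q[100, 50]
K = [[random.uniform(-1, 1) for _ in range(50)] for _ in range(100)]   # K[100, 50]
S_tiled = tiled_scores(Q, K, block=50)   # two row blocks each for Q and K
print(S_tiled == matmul_nt(Q, K))
```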

Therefore, our overall approach is as follows. Computing $QK^\top$ produces a temporary output of shape (b, n, s, s), while we only need the final result $\mathrm{softmax}(QK^\top)V$ of shape (b, n, s, d). Here b is the batch size, n the number of heads, s the sequence length, and d the head dimension. As long as s and d are relatively small, we can fuse the three matrix operations into a single CUDA kernel that directly produces $\mathrm{softmax}(QK^\top)V$.
Limitations
Having seen the general idea behind $\mathrm{softmax}(QK^\top)V$, let's take a closer look at how O is computed and at the SRAM limitations. A few points to note:
- Producing each row block $O_i$ of O is a cumulative update. Take $O_1$ as an example.
- $O_1$ is updated and accumulated step by step, one term per block of K and V, until its final value is obtained.
- For easier analysis, we view $O_1$ as a row vector in which each term of the sum is one element, as shown in the figure below. For example, the term contributed by $K_2$ and $V_2$ is the second column of the first row.
- Each update must first load the current partial result of $O_i$ (its earlier columns) into SRAM before the new column can be added.
- The softmax operation is omitted, meaning the kernel computes $O = QK^\top V$.
- SRAM can hold only small blocks of Q, K, V, and O at a time.

The difficulty lies in calculating O; let’s look at two solutions next.
Option 1
The logic of Option 1 is shown in the diagram below. The idea is as follows:
First, K and V are divided into $T_c$ blocks, and Q and O are divided into $T_r$ blocks.
Next comes the loop computation. The outer loop j iterates over the blocks of K and V, and the inner loop i iterates over the blocks of Q and O.
The outer loop logic is as follows:
- In the j-th iteration, the outer loop loads the j-th blocks $K_j$ and $V_j$ into SRAM.
- Every outer iteration updates all of $O_1$ to $O_{T_r}$, but each time it adds only one partial term to each of them.
- Only after the outer loop finishes are $O_1$ to $O_{T_r}$ fully updated. The final O is then the expected result.
Within the j-th outer iteration, the inner loop i updates O one row block at a time. For each row block $O_i$, the logic is as follows:
- Load the i-th block $Q_i$ of Q and the i-th block of O into SRAM. The block of O is its previous state $O_i^{(j-1)}$, which can be thought of as the first j-1 columns of row i.
- Use $Q_i$ and $K_j$ to compute S and P, then multiply by $V_j$ to obtain the new column $O_i^{(j)}$.
- Accumulate $O_i^{(j)}$ into $O_i^{(j-1)}$ to update $O_i$.
- Write $O_i$ back to HBM.
- Over the whole inner loop, all $T_r$ row blocks of O are processed. This is the j-th update of O.
The pseudocode is as follows.
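# ---------------------
# Tc: number of blocks of K and V
# Tr: number of blocks of Q and O
# ---------------------
O_i = 0 for every 1 <= i <= Tr        # initialize all blocks of O in HBM
for 1 <= j <= Tc:                     # outer loop over blocks of K and V
    load K_j, V_j
    for 1 <= i <= Tr:                 # inner loop over blocks of Q and O
        load O_i^{j-1}, Q_i           # read the previous O_i^{j-1} from HBM
        S = Q_i @ K_j^T               # score block S (softmax omitted in this section)
        O_i^j = S @ V_j               # the j-th column (term) of row block O_i
        O_i = O_i^{j-1} + O_i^j       # one accumulation step for O_i
        store O_i                     # write O_i back to HBM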
The corresponding diagram is shown below, with green markings indicating the outer loop.

Next, we map this scheme onto HBM and SRAM to see how it actually runs step by step.
- In the first outer iteration (j = 1), we iterate over all i once. This stage produces $O_1^{(1)}$ to $O_{T_r}^{(1)}$, which are written back to HBM together with some other important data.
- In the second outer iteration (j = 2), the inner loop loads the previously produced $O_1^{(1)}$ to $O_{T_r}^{(1)}$ into SRAM one by one and updates them to $O_1^{(2)}$ to $O_{T_r}^{(2)}$.
Note that the diagrams below simplify the inner loop. They assume that Q is read into SRAM all at once, the inner loop runs, and the results are written back to HBM all at once. In reality, each inner iteration uses $Q_i$, $K_j$, $V_j$, and $O_i^{(j-1)}$ to compute $O_i^{(j)}$, and writes $O_i$ back.
The outer loop with j = 1 is shown in the figure below.

The outer loop with j = 2 is shown in the figure below.

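Below is a runnable sketch of Option 1. It assumes, as in this section, that softmax is omitted, so the target is $O = QK^\top V$. "HBM" here is just a Python list of $O_i$ blocks. The counters `o_reads` and `o_writes` are hypothetical bookkeeping that record how often a block of O crosses between HBM and SRAM. All names are illustrative.

```python
def matmul(A, B):
    """A @ B for matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def matmul_nt(A, B):
    """A @ B^T for matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(ra, rb)) for rb in B] for ra in A]


def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


def blocks(M, size):
    return [M[i:i + size] for i in range(0, len(M), size)]


def option1(Q, K, V, size):
    """K/V in the outer loop, Q/O in the inner loop; softmax omitted (O = Q K^T V)."""
    Qb, Kb, Vb = blocks(Q, size), blocks(K, size), blocks(V, size)
    hbm_O = [[[0.0] * len(V[0]) for _ in Qi] for Qi in Qb]   # O_i blocks kept in "HBM"
    o_reads = o_writes = 0
    for Kj, Vj in zip(Kb, Vb):                  # outer loop j: load K_j, V_j
        for i, Qi in enumerate(Qb):             # inner loop i: load Q_i
            O_prev = hbm_O[i]                   # load O_i^{j-1} from HBM
            o_reads += 1
            O_new = matmul(matmul_nt(Qi, Kj), Vj)   # O_i^j = Q_i K_j^T V_j
            hbm_O[i] = add(O_prev, O_new)       # accumulate and store O_i back to HBM
            o_writes += 1
    O = [row for Oi in hbm_O for row in Oi]
    return O, o_reads, o_writes


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
K = [[0.2, 0.1], [0.3, -0.2], [0.0, 0.4], [0.1, 0.1]]
V = [[1.0, 2.0], [0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]
O, reads, writes = option1(Q, K, V, size=2)   # Tr = Tc = 2
print(reads, writes)                           # Tr * Tc reads and writes of O blocks
```

With $T_r = T_c = 2$, blocks of O are read and written $T_r \times T_c$ times.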
Option 2
Note that, in order to be consistent with the FlashAttention V2 paper, i here is the outer loop and j is the inner loop.
The problem with Option 1 is that O must be read and written frequently. In the nested loops above, the outer loop loads K and V first, and the inner loop then loads Q. As a result, each inner iteration computes only one part (one column) of $O_i$ and must read and write global memory every time. Specifically:
- As shown in the diagram below, each row block $O_i$ of O is bound to $Q_i$: updates to $O_i$ are strictly tied to $Q_i$.
- In Option 1, the inner loop iterates over Q, which conflicts with this property of O. Every inner iteration has to load $O_i$ from HBM and write it back, and $O_1$ to $O_{T_r}$ are loaded and updated one by one. They are fully updated only after the outer loop has completed all of its iterations.

In attention, the computations for different queries are completely independent. If the outer loop loads the queries (Q) first, the attention for different query blocks can be assigned to different thread blocks, and these thread blocks do not need to communicate with each other. Completing $O_i$ within a single outer iteration also reduces HBM reads and writes. So why not make Q the outer loop and K/V the inner loop? This avoids writing intermediate results to HBM and reading them back, because each outer iteration finishes the computation for one row block. Furthermore, softmax is also computed along the row dimension, so fixing Q and iterating over K and V fits the nature of softmax more naturally.
Therefore, the new algorithm swaps the loops of Option 1: it loads Q and O first, then K and V. The resulting logic of Option 2 is shown below. After the swap, the inner loop of Option 2 first computes $O_1$ completely. It takes $Q_1K_1V_1$ as the initial value, then computes $Q_1K_2V_2$ and adds it to obtain $Q_1K_1V_1 + Q_1K_2V_2$, and so on. This way, the inner loop does not need to read and write $O_i$ from HBM on every iteration, which reduces both I/O and running time.

Next, we map this scheme onto HBM and SRAM to see how it runs step by step. The diagrams simplify the inner loop: they assume that K and V are read into SRAM all at once, the inner loop runs, and the result is written back to HBM all at once. In reality, each inner iteration reads $K_j$ and $V_j$ and accumulates its contribution into $O_i$. Then $O_i$ is written back in one go at the end.
The outer loop with i = 1 is shown in the figure below.

The outer loop with i = 2 is shown in the figure below.

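For comparison, here is the same sketch with the loops swapped as in Option 2. The assumptions and hypothetical names are the same as for Option 1: softmax is omitted, and the counters are illustrative bookkeeping. The accumulator for $O_i$ never leaves "SRAM" during the inner loop, so blocks of O are never read back from HBM.

```python
def matmul(A, B):
    """A @ B for matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def matmul_nt(A, B):
    """A @ B^T for matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(ra, rb)) for rb in B] for ra in A]


def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]


def blocks(M, size):
    return [M[i:i + size] for i in range(0, len(M), size)]


def option2(Q, K, V, size):
    """Q/O in the outer loop, K/V in the inner loop; softmax omitted (O = Q K^T V)."""
    Qb, Kb, Vb = blocks(Q, size), blocks(K, size), blocks(V, size)
    O = []
    o_reads = o_writes = 0                      # o_reads stays 0: O is never read back
    for Qi in Qb:                               # outer loop i: load Q_i
        acc = [[0.0] * len(V[0]) for _ in Qi]   # O_i accumulates on chip
        for Kj, Vj in zip(Kb, Vb):              # inner loop j: load K_j, V_j
            acc = add(acc, matmul(matmul_nt(Qi, Kj), Vj))   # O_i += Q_i K_j^T V_j
        O.extend(acc)                           # write the finished O_i once
        o_writes += 1
    return O, o_reads, o_writes


Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
K = [[0.2, 0.1], [0.3, -0.2], [0.0, 0.4], [0.1, 0.1]]
V = [[1.0, 2.0], [0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]
O, reads, writes = option2(Q, K, V, size=2)   # Tr = Tc = 2
print(reads, writes)                           # no reads of O, one write per row block
```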
Summary
The $O(N^2)$ intermediate matrices and their repeated reads and writes to HBM are a major bottleneck in attention computation. To address this bottleneck, we devised two solutions. In fact, Solution 1 is the prototype of FlashAttention V1, and Solution 2 is the prototype of FlashAttention V2.
We illustrate the difference with an analogy to painting a wall. The special requirement of this analogy is that when painting horizontally, each subsequent brush stroke builds on the result of the previous strokes.
Solution 1 is as follows. Each outer loop updates every row of O, and each update is written back. If O is viewed as a matrix, the outer loop fixes one column at a time, and the inner loop updates the corresponding row in that column.
Furthermore, K and V are in the outer loop while Q is in the inner loop, so each block of K and V is paired with all of Q. Consequently, the intermediate O is also full-sized, as large as the full Q, which is quite large. O therefore cannot stay on chip and must be written to HBM.

We illustrate Solution 2 with the same analogy, as shown below. Each outer iteration updates only one row of O and writes it back once the entire row is done. If we view O as a matrix, the outer loop fixes one row at a time and the inner loop updates the columns of that row.
Here Q is in the outer loop and K and V form the inner loop. Each outer iteration finishes one block $O_i$ completely, so the final result is obtained block by block. Within the inner loop, the partial result is reused across all iterations. This guarantees that $O_i$ stays on chip the whole time, possibly even in registers.

The goal of FlashAttention (in its current prototype form) is to optimize from the perspective of the GPU's underlying data storage, reducing operations on HBM and on intermediate matrices. Specifically:
- First, operator fusion. The matrix multiplications and the softmax are merged into a single operator, so the multi-step computation runs sequentially in SRAM and the intermediate matrices S and P are never read from or written to HBM.
- Second, tiling solves the memory problem. Because SRAM is small, the inputs Q, K, and V are divided into small blocks that are loaded from HBM into SRAM. This ensures that SRAM can hold the intermediate results of the matrix operations.
- Matrix multiplication can be decomposed into block products whose results are accumulated, so a large matrix multiplication can be split into many smaller ones (the smaller the blocks, the lower the SRAM requirement). Attention is therefore computed block by block in SRAM, and the correct final result is obtained by accumulating the results of each block multiplication.
- In addition, intermediate states must also be stored during training. FlashAttention uses recomputation to trade computation for memory, reducing the size of the stored intermediate states.
The computation process of the attention mechanism is “matrix multiplication —> scale —> mask —> softmax —> dropout —> matrix multiplication”. The block computation of matrix multiplication and pointwise operations (scale, mask, dropout) is easy to implement.
Because we’ve only focused on matrix multiplication, everything seems fine for now. However, there’s a hurdle in our path: softmax. The key factor limiting the performance of attention mechanisms is actually softmax. Let’s examine the problems with softmax and how to solve them.
0x03 Softmax Improvements
Let's start with the problems of native softmax and how to improve it. In short, matrix multiplication can be split into blocks whose results are simply added, but softmax cannot. Self-attention contains a softmax operator whose per-block results cannot be directly combined, which makes it hard to tile self-attention simply.
3.1 Native Softmax
The Softmax function is a commonly used activation function in machine learning, especially in multi-class classification problems. Its purpose is to transform an arbitrary real-valued vector into a probability distribution, ensuring that the sum of the output probabilities is 1.
Formula
Suppose the array is $x = \{x_1, x_2, \dots, x_N\}$ and $x_i$ is one of its elements. The native softmax formula is:
$$\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}$$
The specific algorithm is shown in the figure below. The algorithm process requires two loops, involving two read-from-memory operations and one write-back-to-memory operation:
- Compute the normalization term $d = \sum_{j=1}^{N} e^{x_j}$. In softmax, the sum in the denominator is called the normalization term. It scales the exponentials $e^{x_i}$ of the input elements into proportions that sum to 1, matching the definition of a probability distribution. The algorithm therefore first traverses the array, exponentiating each element and accumulating the sum d, which is then used as the denominator.
- Compute the output $a_i = e^{x_i} / d$. The algorithm traverses the array again and divides the exponential of each element by d. This scales each element and completes the softmax over the whole array.
Each vector element requires three memory accesses: two reads and one write.

Implementation
The simplified code implementation is as follows.
import torch

A = torch.tensor([1., 2., 3., 5., 4.])

def native_softmax(x):
    A_exp = torch.exp(x)        # exponentiate every element
    A_sum = torch.sum(A_exp)    # sum of the exponentials (normalization term)
    return A_exp / A_sum        # scale

print(torch.softmax(A, dim=-1))
print(native_softmax(A))
Limitations
Returning to the previous optimization approach and solution, we hope to merge the three operators “gemm, point-wise softmax, and gemm”, which requires using tiling to perform block operations on the matrix.
However, softmax and tiling strategies are actually conflicting. This is because softmax lacks the associative property of addition; its calculation formula requires the summation of all elements in the global set of input data before each element can be calculated. Since block-based calculations only allow for the calculation of local sums, the block-based calculation of softmax becomes complex.
Specifically, in attention the softmax is applied row-wise. For the full score matrix, softmax normalizes along the dimension traversed by the inner loop, which means the max and sum must be computed over each entire row. A block result computed locally cannot be multiplied with V right away; it must wait for the remaining blocks of the same row. The computation can proceed only after the whole row is finished, which creates dependencies and hurts parallelism. Fusing softmax therefore constrains how the first GEMM can be partitioned. Either the row direction is not partitioned, which reduces parallelism, or something like k-slicing (split-K) is used, which increases communication cost. Likewise, if we want to fuse these three operations into a single kernel without changing how softmax is computed, we must always trade off between these two approaches.
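A small numeric check shows why softmax resists naive tiling. It uses the array from the code examples in this article. Applying softmax to each block separately and concatenating the results does not reproduce the row softmax, because each block is normalized by its own local max and sum. This is a pure-Python illustration, and the split point is arbitrary.

```python
import math


def softmax(xs):
    """Safe softmax over a Python list."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


A = [1.0, 2.0, 3.0, 5.0, 4.0]
full = softmax(A)                            # softmax over the whole row
blockwise = softmax(A[:2]) + softmax(A[2:])  # naive: softmax per block, then concatenate
print([round(v, 4) for v in full])
print([round(v, 4) for v in blockwise])      # each block sums to 1, so the list sums to 2
```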
The real challenge in computing attention block by block lies in computing softmax block by block. To overcome this limitation, we need a clever way to compute softmax without depending on the entire row of input. That is, we want to compute the softmax (the attention probabilities P) without accessing all of the input at once, so that its computation can be split into blocks and recombined. Softmax can then be fused while tiling remains possible.
The core problem FlashAttention aims to solve is how to free the algorithm from this global dependency so that fast on-chip computation with tiling becomes possible. FlashAttention saves GPU memory (its memory overhead grows linearly with sequence length) because it decouples the row-direction dependency between softmax and the subsequent GEMM. It then uses auxiliary statistics kept in an auxiliary array to rescale partial results to the correct values.
Therefore, our next step is to study how to correctly obtain the score after block division, how to correctly obtain the softmax, how to correctly obtain O, and how to optimize IO, so as to solve the memory-bound problem.
3.2 Process
The evolution from native softmax to FlashAttention involves several key steps, with two main research milestones:
- Milakov and Gimelshein (NVIDIA, 2018), Online normalizer calculation for softmax. This paper first proposed the online softmax technique. It uses an equivalent transformation to remove the row-direction dependency of softmax, allowing parallel computation via tiling.
- Rabe, Markus N., and Charles Staats (Google Research), Self-attention Does Not Need $O(n^2)$ Memory. This work extends the algorithm from softmax tiling to fused attention tiling and implements memory-efficient fused attention on TPU with JAX.

Next, this article will start with online-softmax and gradually explain the FlashAttention algorithm.
3.3 3-Pass Safe Softmax
Current problem
In practice, exp is numerically unstable, which can break softmax. For example, floating-point range is limited: for float32 and bfloat16, $e^{x}$ becomes inf once $x \ge 89$, causing overflow. Conversely, when every element of the array is a large negative number, each $e^{x_i}$ can underflow to 0. The entire denominator then becomes 0, and the softmax result is invalid.
Solution
In practice, people use the safe softmax algorithm, which relies on the shift invariance of softmax: subtracting the maximum of all elements from each element before applying softmax leaves the result unchanged. The formula is:
$$\mathrm{softmax}(x_i) = \frac{e^{x_i - m_N}}{\sum_{j=1}^{N} e^{x_j - m_N}}, \qquad m_N = \max_{1 \le j \le N} x_j$$
This algorithm performs three loops over the array:
- Find the maximum of each row. First, iterate through the array once to find the maximum $m_N$.
- Compute the exponentials and sum them to obtain the normalization term $d_N = \sum_{j=1}^{N} e^{x_j - m_N}$. Iterate through the array again, subtracting the maximum from each element before exponentiating, element by element.
- Compute and output the softmax. Finally, iterate through the array once more, exponentiate each element after subtracting the maximum, and divide by $d_N$. This scales each element to obtain its softmax value.
The pseudocode for safe softmax is shown below. There are data dependencies between the three loops: the second loop depends on $m_N$, and the third depends on both $m_N$ and $d_N$.

Implementation
import torch

A = torch.tensor([1., 2., 3., 5., 4.])

def safe_softmax(x):
    m = torch.max(x)             # maximum of the row
    shifted = x - m              # subtract the max so exp cannot overflow to inf
    A_exp = torch.exp(shifted)   # exponentiate
    A_sum = torch.sum(A_exp)     # normalization term
    return A_exp / A_sum         # broadcast division

print(torch.softmax(A, dim=-1))
print(safe_softmax(A))
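To see the overflow concretely without a GPU, here is a pure-Python comparison. Note one assumption: Python floats are float64, so the overflow threshold is about 709.78 rather than the ~89 quoted above for float32/bfloat16. Also, `math.exp` raises `OverflowError` instead of returning inf. The function names are illustrative.

```python
import math


def native_softmax(xs):
    exps = [math.exp(x) for x in xs]        # raises OverflowError for x above ~709.78 (float64)
    total = sum(exps)
    return [e / total for e in exps]


def safe_softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]    # every exponent is <= 0, so no overflow
    total = sum(exps)                       # total >= 1 because the max contributes exp(0)
    return [e / total for e in exps]


def native_overflows(xs):
    """Hypothetical helper: True if native softmax cannot evaluate xs."""
    try:
        native_softmax(xs)
    except OverflowError:
        return True
    return False


big = [1000.0, 1001.0, 1002.0]
print(native_overflows(big))                # the naive version fails on large logits
print(safe_softmax(big))                    # identical to safe_softmax([0.0, 1.0, 2.0])
```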
Shortcomings
Compared with native softmax, safe softmax adds one more loop. In the Transformer's attention mechanism, the input to softmax is the pre-softmax logits $x = QK^\top$, computed from Q and K. This means the input must be handled in one of the following two ways:
- Precompute the pre-softmax logits and store them in global GPU memory, which requires $O(N^2)$ memory.
- Compute online: in each loop, load part of Q and K into on-chip memory and compute the pre-softmax logits on the fly.
The current goal of attention optimization is to avoid the first option and save as much GPU memory as possible. SRAM is not large enough to hold the pre-softmax logits, so attention optimization ends up with the second option. However, with safe softmax we must access Q and K three times and recompute x on the fly each time. The whole process involves three reads and one store, which is very inefficient in terms of memory I/O. This avoids storing the intermediate pre-softmax logits and saves GPU memory, but it does not save computation and it increases HBM I/O, because Q and K must be loaded repeatedly.
3.4 Online softmax 2-pass
Motivation
We’ll use the example from “From Online Softmax to FlashAttention” as a comparison. As you can see from the code below, the softmax function requires three loops: the first loop calculates the maximum value of the array, the second loop calculates the denominator of the softmax function, and the third loop calculates the softmax output.

The main problem with the 3-pass algorithm above is that it is too slow, and we want to speed it up. Could we save one memory access, or one loop? For example, could we fuse formulas (3), (7), and (10) in the diagram above into a single loop? That would reduce the number of global memory accesses from three to one. Unfortunately, we cannot directly fuse (7) and (10), because (10) depends on a value that is only available after the loop for (7) has finished; this is why two traversals are needed. The root cause is a redundant dependency: the normalization term is computed with the global maximum $m_N$. We therefore need a way to merge the first two loops, which leads to the 2-pass algorithm below.
Algorithm
In safe softmax, each row needs its max and sum, so we normally wait until all the data in a row is ready before applying softmax. The maximum is computed in one for loop and the sum in another. Could we accumulate incrementally instead? Rather than computing the max and sum all at once, we could keep the current max and sum and update them step by step, eventually obtaining the same result. This way, m and d are computed together in one iteration, reducing the number of loops from two to one.
Online safe softmax completes the task with two loops. The second loop is the same as in the 3-pass algorithm, so we focus on the first. It traverses the input vector once and computes the maximum m and the normalization term d together. Specifically, for the input $x_1, \dots, x_N$, at the j-th step of the for loop, $m_j$ is the maximum of the subarray $x_1, \dots, x_j$. Likewise, $d_j'$ is the softmax denominator of that subarray, computed relative to $m_j$. After the loop ends, $m_N$ is the maximum of the entire array and $d_N'$ is the softmax denominator of the entire array.
Online safe softmax reduces the number of memory accesses per vector element from 4 (for safe softmax) to 3.

The PyTorch + Python implementation of Algorithm 3 is shown below:
import torch

def online_softmax(x: torch.Tensor) -> torch.Tensor:
    """Row-wise online softmax: one pass computes (max, normalizer), one pass writes the output."""
    assert x.ndim == 2, "only accepts 2D tensor now"
    row_count, col_count = x.shape
    output = torch.zeros_like(x)
    for r in range(row_count):
        row_max = x[r][0]
        normalizer = 1.0  # exp(x[r][0] - row_max) = 1 for the first element
        for c in range(1, col_count):
            pre_max = row_max
            cur = x[r][c]
            row_max = max(pre_max, cur)
            # rescale the old sum to the new max, then add the new term
            normalizer = normalizer * torch.exp(pre_max - row_max) + torch.exp(cur - row_max)
        output[r, :] = torch.exp(x[r, :] - row_max) / normalizer
    return output
Analysis
The 2-pass algorithm essentially changes how the softmax denominator is computed. It no longer depends on the global maximum $m_N$, only on the running maxima $m_j$, which merges the first two steps into one. In other words, to remove the dependency on N, we construct a surrogate sequence to replace the original one. This surrogate is a recurrence that depends only on $m_{j-1}$, $m_j$, $d_{j-1}'$, and $x_j$, so $m_j$ and $d_j'$ can be computed in the same loop.
Specifically, the 2-pass algorithm keeps updating the maximum m and the normalization term d as it traverses the input array. In each iteration, it rescales d according to the new maximum $m_j$ and then adds the new element's term. The update of m is the same as in 3-pass safe softmax: computing $m_j$ uses only $m_{j-1}$ and $x_j$, and no element after position j is needed. The idea of finding the maximum block by block is therefore easy to understand. The update of the softmax denominator, however, is slightly different.
- The 3-pass safe softmax algorithm uses $m_N$, the maximum of all elements, when computing d. This works because the first loop has already scanned the entire array and found $m_N$, so d can be computed directly in the second loop.
- In online safe softmax, the computation is iterative, so the current m is not the maximum of the whole vector but a running maximum. At the j-th step of the for loop, we only have the maximum $m_j$ of the subarray $x_1, \dots, x_j$, so the d accumulated so far is not $\sum_{i} e^{x_i - m_N}$. To keep d correct, we must rescale it whenever the maximum changes: the denominator update needs an additional compensation factor $e^{m_{j-1} - m_j}$. The result obtained this way is identical to that of safe softmax.
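A short trace makes the compensation term visible. The sketch below is pure Python, and the function name is hypothetical. It records $m_j$ and the running denominator after each step for the array [1, 2, 3, 5, 4] used in this article's code. It then checks each value against $\sum_{i \le j} e^{x_i - m_j}$ computed directly from the prefix. At steps where the maximum grows (for example, from 3 to 5), the old sum is multiplied by $e^{m_{j-1} - m_j} < 1$. At the last step the maximum stays 5, so the factor is 1.

```python
import math


def online_trace(xs):
    """Return (m_j, d_j') after each step j of the online loop."""
    m, d = float("-inf"), 0.0
    steps = []
    for x in xs:
        m_prev = m
        m = max(m_prev, x)                                 # running max m_j
        d = d * math.exp(m_prev - m) + math.exp(x - m)     # compensate, then add the new term
        steps.append((m, d))
    return steps


A = [1.0, 2.0, 3.0, 5.0, 4.0]
for j, (m_j, d_j) in enumerate(online_trace(A), start=1):
    direct = sum(math.exp(x - m_j) for x in A[:j])        # prefix denominator w.r.t. m_j
    print(j, m_j, round(d_j, 6), round(direct, 6))
```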
Therefore, the key to online safe softmax is computing the denominator with a recurrence. Define $d_j' = \sum_{i=1}^{j} e^{x_i - m_j}$. If $d_j'$ and $d_{j-1}'$ are linked by a recurrence that does not depend on $m_N$, we only need the elements up to position j and can merge the two loops into one traversal. Let's see how to derive it.
- First, the computation of d relies on the rules of exponents. Multiplying two powers of the same base adds their exponents, $e^{a} e^{b} = e^{a+b}$, and dividing subtracts them, $e^{a} / e^{b} = e^{a-b}$.
- Secondly, the derivation procedure for d is as follows.
- $d_{j-1}'$ is the sum of the first j-1 exponentials of the array x[1…n]. It is computed not with the global maximum $m_N$ but with the running maximum $m_{j-1}$.
- $m_{j-1}$ is the maximum of the first j-1 elements, and $m_j$ is the maximum of the first j elements. Since $m_j = \max(m_{j-1}, x_j)$, it either equals $m_{j-1}$ or is the newest element $x_j$.
- The blue box in the image below is the sum of the first j-1 terms, $d_{j-1}'$. As the formula shows, each of its terms has $m_{j-1}$, the maximum of the first j-1 elements, subtracted. By the rule of exponents above, multiplying it by $e^{m_{j-1} - m_j}$ re-expresses the sum of the first j-1 exponentials relative to the latest maximum $m_j$. Adding $e^{x_j - m_j}$ then gives the current $d_j'$.
- In this way, a single traversal yields both the maximum of the array x[1…n] and the sum of its exponentials, $d_N' = d_N$. The subsequent steps are the same as in safe softmax.
See the diagram below for a detailed proof.

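In equation form, define the running denominator as $d_j' = \sum_{i=1}^{j} e^{x_i - m_j}$, with $m_j = \max(m_{j-1}, x_j)$, $m_0 = -\infty$, and $d_0' = 0$. The recurrence then follows by splitting off the last term and applying $e^{a} e^{b} = e^{a+b}$:

```latex
\begin{aligned}
d_j' &= \sum_{i=1}^{j} e^{x_i - m_j}
      = \Big(\sum_{i=1}^{j-1} e^{x_i - m_{j-1}}\Big)\, e^{m_{j-1} - m_j} + e^{x_j - m_j} \\
     &= d_{j-1}'\, e^{m_{j-1} - m_j} + e^{x_j - m_j},
\qquad d_N' = \sum_{i=1}^{N} e^{x_i - m_N} = d_N .
\end{aligned}
```

The recurrence uses only $d_{j-1}'$, $m_{j-1}$, $m_j$, and $x_j$, so it needs no element beyond position j. Its final value equals the safe-softmax denominator $d_N$.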
Furthermore, the updates of m and d are commutative and associative. If m and d are computed separately for arbitrary blocks and then merged, the result is mathematically the same as a single sequential pass. In other words, the effect of the maximum can be deferred and corrected at the last step. The normalization term can thus be computed block by block, in any order, which takes advantage of the GPU's multi-threading.

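The commutative and associative merge can be sketched as an operation on (m, d) pairs. This is a minimal pure-Python illustration with hypothetical names `block_stats` and `merge`. Each block computes its own local max and local sum, and merging rescales both sums to the larger max. Splitting the row into blocks and merging them in any order gives the same (m, d) as processing the whole row, up to floating-point rounding.

```python
import math


def block_stats(xs):
    """(m, d) for one block: its max and the sum of exp(x - m) over the block."""
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)


def merge(a, b):
    """Combine two (m, d) pairs into the pair for the concatenated block."""
    (m1, d1), (m2, d2) = a, b
    m = max(m1, m2)
    return m, d1 * math.exp(m1 - m) + d2 * math.exp(m2 - m)   # rescale both to the shared max


A = [1.0, 2.0, 3.0, 5.0, 4.0]
left, right = block_stats(A[:2]), block_stats(A[2:])
print(merge(left, right))     # blocks merged in order
print(merge(right, left))     # order swapped: same (m, d)
print(block_stats(A))         # whole row at once, for comparison
```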
Implementation
The simplified code implementation is as follows.
import torch

A = torch.tensor([1., 2., 3., 5., 4.])

def online_softmax(x):
    m = torch.tensor(float("-inf"))  # running maximum m_j
    d = 0.0                          # running normalizer d_j'
    N = len(x)
    a = torch.zeros(N)
    for i in range(N):               # pass 1: update m and d together
        m_pre = m
        m = torch.max(m, x[i])
        d = d * (m_pre - m).exp() + (x[i] - m).exp()
    for i in range(N):               # pass 2: same as safe softmax
        a[i] = (x[i] - m).exp() / d
    return a

print(torch.softmax(A, dim=-1))
print(online_softmax(A))
Remaining shortcomings
So, what are the advantages of the 2-pass algorithm compared to the 3-pass algorithm?
Compared with the 3-pass algorithm, the 2-pass algorithm updates both the maximum m and the softmax denominator d in the first loop. This removes one loop and one full load of Q and K. It also saves one recomputation, because in the first pass the score $x_i$ is computed once and shared by the updates of m and d. As a result, GPU shared memory can hold the intermediate results, and only two round trips to global memory are needed: one to write the data and one to read the result.
However, the FLOP count of the 2-pass algorithm has not decreased. In fact it has slightly increased, because each step must additionally compute the rescaling factor $e^{m_{j-1}-m_j}$. Further optimization is therefore still needed.
3.5 Multi-pass Self-Attention
Starting with this section, we turn to FlashAttention itself.
Motivation
Let's continue the line of thought from 2-pass online softmax. Since we have a 2-pass algorithm, can we go further and derive a 1-pass online softmax algorithm? Unfortunately, we cannot. The second step still depends on the denominator $d_N$ computed in the first step, so computing the softmax values still requires two passes.
However, the goal of attention is to compute the output $O=\mathrm{softmax}(QK^T)V$, not the softmax itself. There is no 1-pass online softmax algorithm, but there may be a 1-pass attention algorithm. Computing $P=\mathrm{softmax}(QK^T)$ alone can only be compressed into two for loops. If we also take the computation of O = PV into account, we may be able to compress everything into a single for loop. That requires a recurrence for O that, like d, needs only O(1) extra state. Let's see whether this is feasible.
Multi-pass Self-Attention algorithm
Let’s first look at the Multi-pass Self-Attention algorithm. This is a 2-pass FlashAttention algorithm based on 2-pass online softmax.
In the first loop, the formulas derived for 2-pass online softmax are used. The first loop is essentially 2-pass online softmax with one added step: computing the attention score $x_i$ of the current query against the i-th key.
The second loop works as follows:
- The attention weight $a_i=e^{x_i-m_N}/d_N$ is computed, and the output of the current iteration, $o_i$, is updated. Because $a_i$, and therefore o, depends on $m_N$ and $d_N$, this step cannot be merged into the first loop. It must wait until the first loop finishes and provides $m_N$ and $d_N$.
- Compared with 2-pass online softmax, the second loop has one extra step, $o_i=o_{i-1}+a_i\,V[i,:]$. That is, o is updated once for each element traversed.
See the image below for details.
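A plain-Python sketch of the two loops for a single query row may help. It assumes the scores are simply cached in a list; the document's description does not say whether they are cached or recomputed. The names `multi_pass_attention` and `reference_attention` are illustrative, not from the paper.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def multi_pass_attention(q, K, V):
    """Output row o for one query q; K and V are lists of N rows."""
    N = len(K)
    x = [0.0] * N                      # scores x_i = q . K[i], cached for simplicity
    m, d = float("-inf"), 0.0
    # Loop 1: 2-pass online softmax plus the score computation.
    for i in range(N):
        x[i] = dot(q, K[i])
        m_new = max(m, x[i])
        d = d * math.exp(m - m_new) + math.exp(x[i] - m_new)
        m = m_new
    # Loop 2: a_i needs the final m_N and d_N, so it cannot join loop 1.
    o = [0.0] * len(V[0])
    for i in range(N):
        a = math.exp(x[i] - m) / d                         # attention weight a_i
        o = [o_k + a * v_k for o_k, v_k in zip(o, V[i])]   # o_i = o_{i-1} + a_i V[i]
    return o

def reference_attention(q, K, V):
    """Plain softmax(q K^T) V for comparison."""
    s = [dot(q, k) for k in K]
    mx = max(s)
    e = [math.exp(v - mx) for v in s]
    z = sum(e)
    return [sum(e[i] * V[i][c] for i in range(len(V))) / z for c in range(len(V[0]))]

q = [0.5, -1.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(multi_pass_attention(q, K, V))
print(reference_attention(q, K, V))
```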

From here to FlashAttention
Now $o_i$ is updated once for each element traversed. The relation between $o_i$ and $o_N$ resembles the relation between $d_i$ and $d_N$ in 2-pass online softmax. So let's look for a recurrence of the same kind: $o'_i$ = f(previous result $o'_{i-1}$, current latest result), such that after the last element is traversed, $o'_N$ equals the standard result $o_N$. If such a recurrence exists, the two loops can be merged.
Next, as in 2-pass online softmax, let's look for a recurrence between $o'_i$ and $o'_{i-1}$ that does not depend on $m_N$ and $d_N$. Specifically, we expand label 2 in the diagram above using label 1, then convert it into a recursive form, as shown in the diagram below.

As you can see, the recursion between $o'_i$ and $o'_{i-1}$ depends only on $m_{i-1}$, $m_i$, $d_{i-1}$, $d_i$, and $x_i$ (plus the value row $V[i,:]$). It does not depend on $m_N$ or $d_N$. The computation therefore satisfies both the commutative and associative laws: computing m, d, and o separately on any blocks and then re-aggregating the sub-block results is mathematically equivalent to the full computation. This allows m, d, and o to be computed within a single loop. We can therefore merge the second loop into the first and obtain the 1-pass FlashAttention algorithm.
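For reference, the recursion can be written out explicitly. The notation follows the online softmax section: $x_i$ is the i-th score of one query row, $m_i$ and $d_i$ are the running maximum and running denominator, and $V[i,:]$ is the i-th row of V. $o'_i$ denotes the output normalized with the running statistics instead of the final ones. This restates the derivation shown in the diagram; it is not an additional result.

$$o_i=\sum_{j=1}^{i}\frac{e^{x_j-m_N}}{d_N}\,V[j,:],\qquad o'_i=\sum_{j=1}^{i}\frac{e^{x_j-m_i}}{d_i}\,V[j,:],\qquad o'_N=o_N$$

$$o'_i=\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}}\cdot\frac{d_{i-1}\,e^{m_{i-1}-m_i}}{d_i}\,V[j,:]+\frac{e^{x_i-m_i}}{d_i}\,V[i,:]=o'_{i-1}\,\frac{d_{i-1}\,e^{m_{i-1}-m_i}}{d_i}+\frac{e^{x_i-m_i}}{d_i}\,V[i,:]$$

$$\text{with}\quad m_i=\max(m_{i-1},x_i),\qquad d_i=d_{i-1}\,e^{m_{i-1}-m_i}+e^{x_i-m_i}$$

The middle step uses only the identity $\frac{e^{x_j-m_{i-1}}}{d_{i-1}}\cdot\frac{d_{i-1}\,e^{m_{i-1}-m_i}}{d_i}=\frac{e^{x_j-m_i}}{d_i}$.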
3.6 1-pass FlashAttention
The final 1-pass FlashAttention algorithm is shown in the figure below. Online softmax computes m and d within one for loop. FlashAttention-v1 takes this idea a step further and computes m, d, and the attention output o within one for loop. In other words, all attention operations are implemented in a single kernel. Going from the original 3-pass self-attention to 1-pass FlashAttention saves the GPU memory for the S and P matrices and reduces HBM I/O accesses for Q and K.
Let’s reiterate why Flash Attention can achieve one-pass computation: Flash Attention ensures that all Attention calculations conform to the associative law of addition. While a standalone softmax operation cannot achieve one-pass computation, in self-Attention, after the softmax operation is completed, each term’s value is multiplied by a vector in V and then accumulated. This accumulation is crucial; with this accumulation operation, all calculations again conform to the associative law. This allows for full utilization of the parallel advantages of the GPU.
The core of FlashAttention is a recurrence in which each step rescales the previous partial result by a correction factor. This lets partial results be accumulated into the global result, so the values need not all be loaded at once and can instead be processed step by step.
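The merged loop can be sketched for one query row in plain Python. This assumes one key/value row per step; the real kernel processes tiles and many rows in parallel. The function names are illustrative. No scores, S, or P are stored: each iteration updates m, d, and o' together, using the recurrence for $o'_i$.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def one_pass_attention(q, K, V):
    """1-pass attention for one query row: x_i, m_i, d_i and o'_i in a single loop."""
    m, d = float("-inf"), 0.0
    o = [0.0] * len(V[0])
    for i in range(len(K)):
        x = dot(q, K[i])
        m_new = max(m, x)
        d_new = d * math.exp(m - m_new) + math.exp(x - m_new)
        # o'_i = o'_{i-1} * d_{i-1} e^{m_{i-1} - m_i} / d_i + e^{x_i - m_i} / d_i * V[i]
        w_old = d * math.exp(m - m_new) / d_new
        w_new = math.exp(x - m_new) / d_new
        o = [o_k * w_old + w_new * v_k for o_k, v_k in zip(o, V[i])]
        m, d = m_new, d_new
    return o

def reference_attention(q, K, V):
    """Plain softmax(q K^T) V for comparison."""
    s = [dot(q, k) for k in K]
    mx = max(s)
    e = [math.exp(v - mx) for v in s]
    z = sum(e)
    return [sum(e[i] * V[i][c] for i in range(len(V))) / z for c in range(len(V[0]))]

q = [0.5, -1.0]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
print(one_pass_attention(q, K, V))
print(reference_attention(q, K, V))
```

On the first iteration d is 0, so `w_old` is 0 and o simply becomes V[0]. After that, every step only rescales and adds.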

3.7 Algorithm FlashAttention (Tiling)
In the tiled form of softmax described above, each step updates only one element. In practice, however, FlashAttention splits the input into blocks that each contain many elements. Tiling Q and K further therefore yields a block-tiled version of FlashAttention. First, Q, K, and V are divided into blocks. Each small block is then loaded from slow global memory (HBM) into fast on-chip SRAM. The attention for the current block is computed in SRAM, and the result is written back to HBM. Throughout this process, the intermediate matrices S and P never need to be stored, which greatly reduces the number of HBM accesses (memory reads and writes). FlashAttention computes the attention output of each block separately. Finally, the block outputs are rescaled by the correct normalization factors and summed to obtain the exact attention output.


3.8 Summary
Let’s take a look at the optimization process between the major versions through the following diagram.

0x04 FlashAttention V1
Having established the feasibility of improving softmax, let’s examine the optimization strategies for FlashAttention.
4.1 Overall Approach
As mentioned earlier, the standard attention algorithm has two drawbacks in the GPU’s tiered memory architecture: high memory usage and numerous HBM read/write operations. Specifically, the intermediate matrices S and P generated during the standard attention calculation process are too large, resulting in multiple HBM read/write operations.
The main optimization idea of FlashAttention is to reduce the swapping of large intermediate matrices between HBM and SRAM. This is achieved by utilizing faster upper-layer memory computation units to reduce access to slower lower-layer memory, thereby improving model training performance. In essence, FlashAttention optimizes memory access from the perspective of GPU block/thread parallelism, rather than saving FLOPs.
To address these issues, FlashAttention uses two techniques: tiling and recomputation. In essence, it is a customized tiling + recomputation approach to the matrix multiplication problem.
- Tiling (input segmentation, used during forward and backward propagation) divides the input Q, K, and V into blocks, then loads these blocks from the HBM into SRAM. Attention is then computed block by block. Before summing the outputs of each block, they are scaled to the correct normalization factor to obtain the correct result. Finally, the result is written back to the HBM; the intermediate matrices S and P are not saved during this process.
- Recomputation (used only during backpropagation) recomputes the attention weight matrix in the backward pass instead of storing it. Only a few intermediate quantities from the forward pass, such as the normalization statistics, are kept. Recomputation increases FLOPs, but GPUs today compute much faster than they access memory. Greatly reducing HBM accesses therefore makes FlashAttention run faster overall.
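A minimal sketch of the recomputation idea for one row follows. It assumes the forward pass kept only the row statistics m and l. A row of P can then be rebuilt exactly from Q, K, m, and l whenever the backward pass needs it. The sketch covers only the recomputation of P; the gradient computation itself is omitted, and the function names are hypothetical.

```python
import math

def scores(q, K):
    return [sum(a * b for a, b in zip(q, k)) for k in K]

def forward_row(q, K):
    """Forward pass for one row: the softmax row of P plus the kept statistics (m, l)."""
    s = scores(q, K)
    m = max(s)
    e = [math.exp(v - m) for v in s]
    l = sum(e)
    return [v / l for v in e], m, l

def recompute_row(q, K, m, l):
    """Backward-time recomputation: rebuild the row of P from q, K and the saved (m, l)."""
    return [math.exp(v - m) / l for v in scores(q, K)]

q = [0.3, -0.7]
K = [[1.0, 2.0], [-1.0, 0.5], [0.0, 1.5]]
p_saved, m, l = forward_row(q, K)      # standard attention would store p_saved (N x N in total)
p_again = recompute_row(q, K, m, l)    # FlashAttention keeps only m and l (length N each)
print(all(math.isclose(a, b) for a, b in zip(p_saved, p_again)))  # True
```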
Next, we will focus on the analysis of forward propagation.
4.2 Algorithm
The main challenge in saving GPU memory or computing attention in blocks is that softmax couples all columns of S, that is, all keys and values: each output row depends on every column of its row. Softmax therefore has to be computed without access to the entire input at once. FlashAttention decouples softmax from the row-direction dependency of the subsequent GEMM, and rescales the statistics stored in auxiliary arrays to the correct values.

The algorithm flowchart above is the forward propagation implementation of FlashAttention V1. First, some variables in the algorithm need to be explained:
- All values with ij as subscripts represent the calculation result of the current block.
- All values with “i” as their index represent the calculation results up to and including the previous block.
- All instances with “new” as a superscript indicate that the current block is used to update the result.
- All results without subscripts represent global results.
- S represents the attention scores before softmax. $S_{ij}=Q_iK_j^T$ is the matrix obtained by multiplying the i-th block of Q by the (transposed) j-th block of K.
- Denote the maximum of each row of S as rowmax and the sum of each row as rowsum. $\tilde P$ and P denote the results before and after normalization, respectively. That is, $\tilde P_{ij}=\exp(S_{ij}-\mathrm{rowmax}(S_{ij}))$, and P is the result of $\mathrm{softmax}(S)$.
Next, let’s take a look at the calculation process step by step. The labels below correspond one-to-one with the row numbers in the figure.
- First, calculate the appropriate block size based on the size of the SRAM;
- Initialize O, l, and m in HBM as matrices or vectors of the corresponding shapes: O and l to all zeros, and m to -inf. l and m are the two added statistics that allow softmax to be decoupled and computed block by block;
- Divide Q, K, and V into blocks according to the block sizes, so that the global softmax can also be split into blocks, each computing its local softmax separately;
- Divide O, l, and m into the corresponding number of blocks;
- Start the outer loop, which performs the recursive block-level softmax computation;
- In the outer loop, the current blocks $K_j$ and $V_j$ are loaded from HBM into SRAM;
- Start the inner loop;
- Load $Q_i$, $O_i$, $l_i$, and $m_i$ from HBM into SRAM, and compute the intermediate values block by block. In every inner iteration, $O_i$, $l_i$, and $m_i$ are written back to HBM, so there are relatively many I/O operations with HBM;
- Perform the computation on chip in SRAM;
- Because S, P, and O are computed in blocks while softmax is defined over an entire row, safe online softmax is used in steps 10, 11, and 12 of the diagram above. Specifically, the two statistics m(x) and l(x) are updated iteratively. The numerator and denominator of the global softmax are updated using m(x), the numerator of the current block's softmax, and l(x). When all blocks have been processed, the global numerator and denominator are complete and the final output O is obtained.
4.3 Proof
We will now demonstrate the effectiveness of the algorithm.
Definitions
The relevant definitions are as follows:
- $x\in\mathbb{R}^{2B}$: x is one row of matrix S. Due to partitioning, it is split into two parts, $x=[x^{(1)},x^{(2)}]$ with $x^{(1)},x^{(2)}\in\mathbb{R}^{B}$.
- $m(x)=\max_i x_i$. In the standard scenario, this is the global maximum of the row.
- $m(x^{(1)})=\max_i x^{(1)}_i$. The maximum of block 1.
- $m(x^{(2)})=\max_i x^{(2)}_i$. The maximum of block 2.
- $f(x)=\big[e^{x_1-m(x)},\dots,e^{x_{2B}-m(x)}\big]$. In the standard scenario, f(x) is the numerator of softmax.
- $f(x^{(1)})=\big[e^{x^{(1)}_1-m(x^{(1)})},\dots,e^{x^{(1)}_B-m(x^{(1)})}\big]$. In the block-based scenario, this is the numerator of block 1's softmax.
- $f(x^{(2)})=\big[e^{x^{(2)}_1-m(x^{(2)})},\dots,e^{x^{(2)}_B-m(x^{(2)})}\big]$. In the block-based scenario, this is the numerator of block 2's softmax.
- $l(x)=\sum_i f(x)_i$. In the standard scenario, l(x) is the softmax denominator.
- $l(x^{(1)})=\sum_i f(x^{(1)})_i$. In the block-based scenario, this is the denominator of block 1's softmax.
- $l(x^{(2)})=\sum_i f(x^{(2)})_i$. In the block-based scenario, this is the denominator of block 2's softmax. Here i does not index blocks; the sum runs over all elements of the vector $f(x^{(2)})$.
Both l and m are scalars. A change of the superscript in $x^{(i)}$ denotes a different data block, which can be a vector or a matrix.
Derivation
Conventional softmax
First, consider how to compute the standard softmax in blocks, using the naive, non-safe version that does not subtract the maximum.
Whole input, single pass
If the whole input is available at once, $\mathrm{softmax}(x)_i=e^{x_i}/\sum_j e^{x_j}$.
If x contains only two elements, computing its softmax is the standard one-shot calculation shown in the figure below.
online-softmax
If the input sequence is too long, the softmax calculation requires a large amount of memory, making it impossible to store everything in SRAM for computation; therefore, it must be performed in blocks. The difficulty of block division lies in the fact that the softmax denominator depends on every value in the input vector x. Therefore, let’s see how to use online-softmax for multi-block accumulation.
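The paper computes the softmax of a sliced vector as follows. Write a vector $x\in\mathbb{R}^{2B}$ as the concatenation $x=[x^{(1)},x^{(2)}]$ with $x^{(1)},x^{(2)}\in\mathbb{R}^{B}$. The softmax is computed separately on $x^{(1)}$ and $x^{(2)}$, and the results are then pieced together.
Assume $f(x^{(i)})=\big[e^{x^{(i)}_1},\dots,e^{x^{(i)}_B}\big]$ and $l(x^{(i)})=\sum_k f(x^{(i)})_k$ for i = 1, 2. Then $f(x)=[f(x^{(1)}),f(x^{(2)})]$, $l(x)=l(x^{(1)})+l(x^{(2)})$, and $\mathrm{softmax}(x)=f(x)/l(x)$.
$x^{(1)}$ and $x^{(2)}$ are two sub-vectors. $l(x^{(1)})$ is the sum of $f(x^{(1)})$, and $l(x^{(2)})$ is the sum of $f(x^{(2)})$. The two blocks do not depend on each other. Suppose we do not know the complete x beforehand and only have $x^{(1)}$. We can first compute $f(x^{(1)})$ and $l(x^{(1)})$. When $x^{(2)}$ is ready, we compute $f(x^{(2)})$ and $l(x^{(2)})$. We then correct the previously computed $\mathrm{softmax}(x^{(1)})$ by multiplying it by $l(x^{(1)})/l(x)$, which gives the final result $\mathrm{softmax}(x)$.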
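The two-block case can be checked numerically with a minimal plain-Python sketch. This is the non-safe version, so it is only meant for small inputs where exp does not overflow.

```python
import math

x1, x2 = [1.0, 2.0], [3.0, 4.0]

# Block 1 arrives first: numerator f(x1) and partial sum l(x1), no max subtraction.
f1 = [math.exp(v) for v in x1]
l1 = sum(f1)
soft1 = [v / l1 for v in f1]          # local softmax of block 1 (temporary)

# Block 2 arrives later and is handled the same way.
f2 = [math.exp(v) for v in x2]
l2 = sum(f2)
soft2 = [v / l2 for v in f2]

# The global denominator is just l1 + l2; each local softmax is corrected by l_k / l.
l = l1 + l2
soft = [s * l1 / l for s in soft1] + [s * l2 / l for s in soft2]

# Reference: one-shot softmax over the concatenated vector.
z = sum(math.exp(v) for v in x1 + x2)
ref = [math.exp(v) / z for v in x1 + x2]
print(all(math.isclose(a, b) for a, b in zip(soft, ref)))  # True
```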
The derivation is as follows.

Safe Softmax
Next, let’s see how safe softmax handles this.
Whole input, single pass
To ensure numerical stability, for $x\in\mathbb{R}^{2B}$ the "subtract the maximum" safe softmax is computed as follows, where both the max and the sum require the complete row.

online-softmax
Next, let's look at how to accumulate over multiple blocks with online softmax. Again, write $x=[x^{(1)},x^{(2)}]$. The softmax is computed separately on $x^{(1)}$ and $x^{(2)}$, and the results are then pieced together. As you can see, only the global statistics m(x) and l(x) need to be maintained. The remaining quantities can be converted from the intermediate values of the locally computed softmax.
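Written out in the notation defined above, the decomposition reads as follows. This is the form given in the FlashAttention paper, restated here as a math block:

$$m(x)=\max\big(m(x^{(1)}),\,m(x^{(2)})\big)$$

$$f(x)=\Big[\,e^{m(x^{(1)})-m(x)}\,f(x^{(1)}),\;\; e^{m(x^{(2)})-m(x)}\,f(x^{(2)})\,\Big]$$

$$l(x)=e^{m(x^{(1)})-m(x)}\,l(x^{(1)})+e^{m(x^{(2)})-m(x)}\,l(x^{(2)}),\qquad \mathrm{softmax}(x)=\frac{f(x)}{l(x)}$$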

We will now prove its validity.
Calculate the first subvector
If we do not know the complete x beforehand and only have $x^{(1)}$, we can first compute $f(x^{(1)})$ and $l(x^{(1)})$.
The first block uses the stable version of softmax. Its numerator is $f(x^{(1)})=\big[e^{x^{(1)}_1-m(x^{(1)})},\dots,e^{x^{(1)}_B-m(x^{(1)})}\big]$, where the maximum of the first block is recorded as $m(x^{(1)})$. The partial sum of the first block is $l(x^{(1)})=\sum_k f(x^{(1)})_k$.
If we compute the softmax at this point, then $\mathrm{softmax}(x^{(1)})=f(x^{(1)})/l(x^{(1)})$.
However, this is only a local softmax and needs to be updated or discarded after subsequent calculations are completed.
Update global values
Let m(x) record the global maximum reached so far during the iteration, and let l(x) record the global sum of exponentials so far. Subsequent iterations process different blocks and gradually update m(x) and l(x).
After processing the first block, $m(x)=m(x^{(1)})$ and $l(x)=l(x^{(1)})$.
For later updates, we could keep the sub-vector $f(x^{(1)})$, but it may be very large. It is more economical to keep only the two scalars $m(x^{(1)})$ and $l(x^{(1)})$. In addition, we must keep the global scalars: the current maximum $m(x)$ and the global sum of exponentials $l(x)$, which is the denominator of the global softmax.
Calculate the second subvector
When $x^{(2)}$ is ready, compute $f(x^{(2)})$ and the partial sum of the second block, $l(x^{(2)})$.
At the same time, record the maximum of the second block, $m(x^{(2)})$, and its local sum, $l(x^{(2)})$.
If we then compute the softmax of the second sub-vector, then $\mathrm{softmax}(x^{(2)})=f(x^{(2)})/l(x^{(2)})$.
However, this is only a local softmax and needs to be updated or discarded after subsequent calculations are completed.
Update global values
Update m(x) to the current global maximum: $m(x)=\max\big(m(x^{(1)}),m(x^{(2)})\big)$.
The global sum at this point in the iteration is $l(x)=e^{m(x^{(1)})-m(x)}\,l(x^{(1)})+e^{m(x^{(2)})-m(x)}\,l(x^{(2)})$.
$l(x^{(1)})$ and $l(x^{(2)})$ cannot simply be added. Each was computed by subtracting its own block maximum, and neither is guaranteed to be the global maximum.
Calculate global softmax
Now that m(x) and l(x) have been updated, let's look at how softmax(x) is updated.
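Since the maximum has changed, the local softmax of each earlier block must be corrected:
- The local $\mathrm{softmax}(x^{(1)})$ previously subtracted $m(x^{(1)})$ in the exponent and divided by $l(x^{(1)})$. We multiply by $l(x^{(1)})$ to undo the division, add $m(x^{(1)})$ back and subtract the new m(x), and then divide by the new l(x): $\mathrm{softmax}(x^{(1)})\cdot l(x^{(1)})\,e^{m(x^{(1)})-m(x)}/l(x)$.
- Likewise, the local $\mathrm{softmax}(x^{(2)})$ previously subtracted $m(x^{(2)})$ and divided by $l(x^{(2)})$, so it becomes $\mathrm{softmax}(x^{(2)})\cdot l(x^{(2)})\,e^{m(x^{(2)})-m(x)}/l(x)$.
Therefore, merging the two gives:
$$\mathrm{softmax}(x)=\left[\frac{l(x^{(1)})\,e^{m(x^{(1)})-m(x)}}{l(x)}\,\mathrm{softmax}(x^{(1)}),\ \frac{l(x^{(2)})\,e^{m(x^{(2)})-m(x)}}{l(x)}\,\mathrm{softmax}(x^{(2)})\right]$$
Thus, the final result is obtained.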
Summarize
The above is actually an incremental calculation process:
- We first calculate the local softmax value of a block, process it, and then temporarily store it, recording the global maximum value and the global softmax denominator value.
- After the next block is processed, the softmax value of this local block is maintained; then the new global maximum value and the global softmax denominator value are obtained.
- Then update the existing local softmax values. For each block $x^{(i)}$, updating its local softmax requires the following: the block maximum $m(x^{(i)})$, the block sum of exponentials $l(x^{(i)})$, the local softmax value $\mathrm{softmax}(x^{(i)})$, the global maximum m(x), and the global sum of exponentials l(x).
- If a third block $x^{(3)}$ exists, treat $x^{(1)}$ and $x^{(2)}$ as a single sequence, merge it with $x^{(3)}$, and repeat the previous steps.
- This process repeats itself, ensuring that we always obtain a continuously updated global softmax value. After processing all blocks, the softmax value of all blocks is now “global”.
The specific details are shown in the image below.

Code demonstration
import numpy as np
import torch

def softmax(x):
    # Stable local softmax of one block, plus its statistics.
    m_x = np.max(x)            # block maximum m(x)
    f_x = np.exp(x - m_x)      # numerator f(x)
    l_x = np.sum(f_x)          # denominator l(x)
    soft_x = f_x / l_x
    return m_x, f_x, l_x, soft_x

# Two blocks processed independently.
m_x1, f_x1, l_x1, soft_x1 = softmax(np.array([1, 2]))
m_x2, f_x2, l_x2, soft_x2 = softmax(np.array([3, 4]))

# Global maximum and global denominator.
m_x_new = np.max([m_x1, m_x2])
l_new_all = np.exp(m_x1 - m_x_new) * l_x1 + np.exp(m_x2 - m_x_new) * l_x2

# Correct each local softmax to the global statistics.
soft_x1_new = soft_x1 * l_x1 * np.exp(m_x1 - m_x_new) / l_new_all
soft_x2_new = soft_x2 * l_x2 * np.exp(m_x2 - m_x_new) / l_new_all

# Reference: softmax over the full vector.
soft = torch.nn.functional.softmax(torch.Tensor([1, 2, 3, 4]), dim=0)
# [0.0320586 0.08714432] [0.23688282 0.64391426]
print(soft_x1_new, soft_x2_new)
# [0.0320586 0.08714432 0.23688284 0.6439143 ]
print(soft.numpy())
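The demo above merges exactly two blocks. Below is a minimal generalization to any number of blocks, assuming the blocks arrive one at a time as Python lists; the function names are illustrative. It repeats the same correction each time a new block arrives, exactly as in the summary steps.

```python
import math

def local_softmax(block):
    """Stable softmax of one block, plus its maximum m and sum l."""
    m = max(block)
    f = [math.exp(v - m) for v in block]
    l = sum(f)
    return m, l, [v / l for v in f]

def blockwise_softmax(blocks):
    """Softmax over the concatenation of blocks, seeing one block at a time."""
    m_g, l_g = float("-inf"), 0.0   # global m(x) and l(x) so far
    out = []                        # softmax values produced so far
    for block in blocks:
        m_k, l_k, soft_k = local_softmax(block)
        m_new = max(m_g, m_k)
        l_new = math.exp(m_g - m_new) * l_g + math.exp(m_k - m_new) * l_k
        # Undo the old normalization (times l_g, shift m_g -> m_new), then divide by l_new.
        out = [v * l_g * math.exp(m_g - m_new) / l_new for v in out]
        # Bring the new block's local softmax to the same global statistics.
        out += [v * l_k * math.exp(m_k - m_new) / l_new for v in soft_k]
        m_g, l_g = m_new, l_new
    return out

x = [1.0, 2.0, 3.0, 4.0, -1.0, 6.0, 0.5]
blocks = [x[:2], x[2:4], x[4:]]
m = max(x)
ref = [math.exp(v - m) / sum(math.exp(u - m) for u in x) for v in x]
result = blockwise_softmax(blocks)
print(all(math.isclose(a, b) for a, b in zip(result, ref)))  # True
```

This version rescales every stored value whenever a block arrives, which is fine for a demonstration. The attention algorithm avoids storing per-element softmax values altogether by folding them into O.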
Analysis including the output O
Let's analyze this further by including the output O. From a matrix perspective, the outer loop j corresponds to a column block of S (one block of K and V), and the inner loop i corresponds to a row block of O. After the j-th iteration of the outer loop, HBM holds the following:

In the (j+1)-th loop, we will perform the following processing to finally obtain the output O.

Alternatively, it can be shown in the image below.

4.4 Block Partitioning
How to partition
The SRAM size determines how the four matrices Q, K, V, and O are divided into blocks in the algorithm. All four matrices are divided into blocks by rows.
- Divide Q into $T_r=\lceil N/B_r\rceil$ blocks of $B_r$ rows each. A block is denoted $Q_i$, with shape $B_r\times d$. It contains the query vectors of its tokens.
- Divide K into $T_c=\lceil N/B_c\rceil$ blocks of $B_c$ rows each. A block is denoted $K_j$, with shape $B_c\times d$. It contains the key vectors of its tokens.
- Divide V into $T_c$ blocks of $B_c$ rows each. A block is denoted $V_j$, with shape $B_c\times d$. It contains the value vectors of its tokens.
- Divide O into $T_r$ blocks of $B_r$ rows each. A block is denoted $O_i$, with shape $B_r\times d$. Like Q, O is divided row by row.
In addition, l and m are each divided into $T_r$ blocks of length $B_r$, denoted $l_i$ and $m_i$.
Block size
Each inner iteration reads $Q_i$, $K_j$, $V_j$, and $O_i$ into SRAM and computes on them there. The block size is chosen so that SRAM holds sub-blocks that are as large as possible. Since four "matrix" blocks are loaded into on-chip SRAM, M is divided by 4d when slicing, which gives the "sequence length" (number of rows) of each block. Here M is the capacity of the on-chip SRAM available, for example, 20M, and d is the head dimension.
- The Q and O block size is $B_r=\min(\lceil M/(4d)\rceil, d)$.
- The K and V block size is $B_c=\lceil M/(4d)\rceil$.
The purpose of this setup is to ensure that SRAM can hold the blocks of Q, K, V, and O at the same time. 4d is the product of 4 and d, where 4 stands for the four blocks of Q, K, V, and O. Besides the inputs Q, K, V and the output O, only the row-maximum vector $m_i$ and the exponential sum $l_i$ need to be stored. Their overhead is $O(B_r)$, which is negligible compared with the four blocks.
Therefore, each block of Q and O ($B_r\times d$) and each block of K and V ($B_c\times d$) holds at most about M/4 elements, so together they need about M of shared memory. Adding $m_i$ and $l_i$, the required storage nearly fills the SRAM.
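The block-size rule can be tried out with a small calculator. This is an illustrative sketch: it assumes M is measured in floats and uses an arbitrary budget of 16K floats, not a real GPU's value. It computes $B_r$, $B_c$, the loop counts $T_r$ and $T_c$, and whether the four tiles fit in M.

```python
import math

def block_sizes(M, d):
    """Block sizes as in the pseudocode: Bc = ceil(M / 4d), Br = min(ceil(M / 4d), d)."""
    Bc = math.ceil(M / (4 * d))
    Br = min(Bc, d)
    return Br, Bc

def plan(N, M, d):
    Br, Bc = block_sizes(M, d)
    Tr, Tc = math.ceil(N / Br), math.ceil(N / Bc)   # numbers of Q/O and K/V blocks
    tiles = 2 * Br * d + 2 * Bc * d                  # Q_i, O_i, K_j, V_j
    return Br, Bc, Tr, Tc, tiles <= M

M = 16 * 1024     # assumed SRAM budget in floats (illustrative only)
N = 4096
for d in (64, 96, 128):
    print(d, plan(N, M, d))
```

With this budget, a larger d gives smaller blocks and more loop iterations. For d = 96, 4d does not divide M, and the rounded-up block size slightly exceeds the budget.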

Of course, this is an analysis conclusion based on the algorithm’s pseudocode. Specific engineering implementations may have slight differences, but the overall approach remains largely the same.
Limitations
Next, let's look at how the head dimension d affects FlashAttention.
In FlashAttention, the amount of SRAM required depends on $B_r$, $B_c$, and head_dim (d). In practice, $B_r$ and $B_c$ are constants, usually 64 or 128, while d varies.
- If head_dim = d is larger, $B_r$ and $B_c$ become smaller, the blocks become smaller, and the execution time becomes longer.
- Smaller blocks mean that, for the same seqlen, more iterations and thus more thread blocks are needed. At the same occupancy, more scheduling rounds are required to finish the computation, which increases the time.
- Smaller blocks also mean more loop iterations in total. Every block of Q must meet every block of K and V, so with smaller blocks the same data is moved between HBM and SRAM more times. Memory accesses therefore increase, and so does the time.
- If d is smaller, the blocks become larger. Larger blocks usually mean fewer transfers between HBM and SRAM, but they use more registers and SRAM. Because the SRAM available to each thread block is limited, this limits how many thread blocks can be active on each SM at the same time (occupancy).
4.5 Process
We analyze the process according to the algorithm.
Prerequisites

The prerequisites are:
- Q, K, and V are located in HBM.
- The SRAM size is known; call it M.
Step 1

The first step sets the block sizes. SRAM must hold sub-blocks of the Q, K, and V matrices and the intermediate output O. The block sizes are therefore derived from the SRAM size M and the head dimension d: $B_c=\lceil M/(4d)\rceil$ and $B_r=\min(\lceil M/(4d)\rceil, d)$. The Q and O sub-blocks have shape $B_r\times d$, and the K and V sub-blocks have shape $B_c\times d$. Rounding up simply yields an integer number of rows. When 4d does not divide M, the rounded-up size can slightly exceed M, which is one reason engineering implementations choose their own block sizes.
Step 2

The second step initializes O, l, and m. The specific operations are as follows:
- Initialize the output matrix O in HBM to all zeros.
- Initialize the vector l in HBM to all zeros; l stores the cumulative softmax denominator of each row.
- Initialize the vector m in HBM to -inf; m records the running maximum of each row.
Step 3

The third step divides the matrices into blocks. Q is divided into $T_r=\lceil N/B_r\rceil$ blocks, and K and V are each divided into $T_c=\lceil N/B_c\rceil$ blocks, all along rows.
- Divide Q into $T_r$ blocks of $B_r$ rows each, denoted $Q_1,\dots,Q_{T_r}$, each of shape $B_r\times d$. Each block stores the query vectors of its tokens.
- Divide K into $T_c$ blocks of $B_c$ rows each, denoted $K_1,\dots,K_{T_c}$, each of shape $B_c\times d$. Each block stores the key vectors of its tokens.
- Divide V into $T_c$ blocks of $B_c$ rows each, denoted $V_1,\dots,V_{T_c}$, each of shape $B_c\times d$. Each block stores the value vectors of its tokens.
Step 4

The fourth step divides O, l, and m into $T_r$ blocks each, with $B_r$ rows per block. The blocks are $O_1,\dots,O_{T_r}$ of shape $B_r\times d$, plus $l_1,\dots,l_{T_r}$ and $m_1,\dots,m_{T_r}$ of length $B_r$. Like Q, O is divided row by row.
Loop calculation
Next, we begin the iterative calculation. j is the outer loop and i is the inner loop: for each j, we iterate over all values of i to obtain the result. In the paper's terms, K and V form the outer loop and Q forms the inner loop. Schematically:
# ---------------------
# Tc: number of blocks of K and V
# Tr: number of blocks of Q and O
# ---------------------
for j in range(1, Tc + 1):      # outer loop over K_j, V_j
    for i in range(1, Tr + 1):  # inner loop over Q_i, O_i, l_i, m_i
        ...                     # per-block computation (steps 8-13)
See the image below for details.

Step 5

Step 5 starts the outer loop, $j=1,\dots,T_c$, which iterates across the column blocks of S, i.e., over the blocks of K and V.
Step 6

Step 6 loads the currently traversed blocks $K_j$ and $V_j$ from HBM (GPU memory) into on-chip SRAM. At this point, 50% of the SRAM is still unused and reserved for Q and O.
Step 7
Step 7 begins the inner loop across row blocks, $i=1,\dots,T_r$, which iterates over the blocks of Q, O, l, and m.

Step 8

Step 8 loads the current blocks $Q_i$ and $O_i$, together with $l_i$ and $m_i$, into SRAM.
Loop calculation
Next, we will perform calculations within the loop. Let’s first outline the workings of the inner and outer loops.
- Outer loop: Loads the sub-blocks of K and V from HBM into SRAM.
- Inner loop: Load the sub-blocks of Q, O, l, and m from HBM into SRAM, and then perform the attention S calculation on SRAM.
- First, from the current block $S_{ij}=Q_iK_j^T$, compute the row maximum of the current block $\tilde m_{ij}$, the current block's $\tilde P_{ij}$ (the softmax numerator), and its row sum $\tilde l_{ij}$.
- Second, compute the running maximum across sub-blocks, $m_i^{new}$, and the cumulative sum across sub-blocks, $l_i^{new}$.
- Finally, use the online softmax update described above to compute $O_i$. Then assign $m_i^{new}$ to $m_i$ and $l_i^{new}$ to $l_i$, and write these variables back from SRAM to HBM.
The paper's diagram illustrates this iterative process well, and we have annotated the algorithm in the figure below. Note that $O_i$, $l_i$, and $m_i$ may hold intermediate results from previous loop iterations.

Step 9

Step 9: For each pair of blocks $(Q_i, K_j)$, compute $S_{ij}=Q_iK_j^T$, the raw correlation scores between the two blocks.
$S_{ij}$ holds the raw scores between the $B_r$ query tokens of block i and the $B_c$ key tokens of block j. The shapes are $(B_r\times d)\cdot(d\times B_c)\rightarrow B_r\times B_c$.
Step 10

Continue with the scores from the previous step. From the current block $S_{ij}$, compute the intermediate states $\tilde m_{ij}$, $\tilde P_{ij}$, and $\tilde l_{ij}$ of each block.
- $\tilde m_{ij}=\mathrm{rowmax}(S_{ij})\in\mathbb{R}^{B_r}$ finds the maximum element in each row of the current block $S_{ij}$, i.e., the local maximum of each row. It corresponds to $m(x^{(1)})$ or $m(x^{(2)})$ in the proof, the maximum of block 1 or block 2 respectively. $\tilde m_{ij}$ is not necessarily the maximum of the corresponding rows of the full S.
- $\tilde P_{ij}=\exp(S_{ij}-\tilde m_{ij})\in\mathbb{R}^{B_r\times B_c}$ subtracts each row's maximum from that row's scores and then exponentiates. In the block-based scenario, this yields the unnormalized P matrix of each block. Note that the division by the softmax denominator is not yet performed. It corresponds to $f(x^{(1)})$ or $f(x^{(2)})$.
Step 11

After each block is computed, these intermediate states are updated in real time to ensure the global result is correct. Specifically:
- $m_i^{new}=\max(m_i,\tilde m_{ij})$. If the current block is $S_{ij}$, then $m_i$ is the maximum over the first j-1 blocks for a fixed i. For fixed i, once j has traversed all blocks, the result is the global row maximum.
- For fixed i, $m_i^{new}$ therefore maintains the running maximum up to and including the current block.
- $m_i^{new}$ and $l_i^{new}=e^{m_i-m_i^{new}}l_i+e^{\tilde m_{ij}-m_i^{new}}\tilde l_{ij}$ are the rowmax and rowsum after traversing the latest block $S_{ij}$. After each iteration over j, line 13 of the pseudocode performs the update.
- $m_i\leftarrow m_i^{new}$ and $l_i\leftarrow l_i^{new}$. Likewise, after traversing all j, we obtain the global rowmax and global rowsum for row block i.
- $\tilde l_{ij}=\mathrm{rowsum}(\tilde P_{ij})\in\mathbb{R}^{B_r}$ is the row-wise sum of $\tilde P_{ij}$, i.e., the rowsum in the block-based scenario. It is equivalent to $l(x^{(1)})$ or $l(x^{(2)})$.
Note that the newly computed $m_i^{new}$ and $l_i^{new}$ are not yet used here to overwrite the old $m_i$ and $l_i$, because step 12 still needs the old values.
Step Twelve

The pseudocode for step 12, $O_i\leftarrow\mathrm{diag}(l_i^{new})^{-1}\big(\mathrm{diag}(l_i)\,e^{m_i-m_i^{new}}O_i+e^{\tilde m_{ij}-m_i^{new}}\tilde P_{ij}V_j\big)$, is explained below:
- $\tilde P_{ij}V_j$: for each row, the value vectors of $V_j$ weighted by that row of $\tilde P_{ij}$ and summed. $\tilde P_{ij}$ is the unnormalized attention matrix of the current block.
- Label 1, $\mathrm{diag}(l_i^{new})^{-1}$, implements the softmax denominator. It multiplies the whole parenthesized expression to its right, which amounts to dividing every value in the parentheses by $l_i^{new}$. The old denominator $l_i$, by which $O_i$ was divided in earlier iterations, is hidden inside $O_i$; label 2 cancels it first.
- $\mathrm{diag}(l)$ turns the vector l into a diagonal matrix in which each row has a single nonzero element, on the diagonal. Multiplying a matrix by it on the left scales each row by the corresponding entry of l.
- The superscript -1 inverts the diagonal matrix, i.e., takes the reciprocal of each diagonal value. This divides each row by its softmax denominator.
- Label 2, $\mathrm{diag}(l_i)\,e^{m_i-m_i^{new}}O_i$, is the "cumulative" result so far. Its role is to update the local softmax output of the previous blocks.
- Label 3, $e^{\tilde m_{ij}-m_i^{new}}\tilde P_{ij}V_j$, is the contribution of the current block, computed from $\tilde P_{ij}$ and $V_j$.
- The exponential coefficients of labels 2 and 3, $e^{m_i-m_i^{new}}$ and $e^{\tilde m_{ij}-m_i^{new}}$, correct $O_i$ and $\tilde P_{ij}V_j$. They remove the maximum used previously ($m_i$ or $\tilde m_{ij}$) and replace it with the latest estimate $m_i^{new}$, the maximum of each row so far. Essentially, this undoes the old softmax values by multiplying each by its original denominator, so that they are "aligned" with the new denominator.
- The sum of labels 2 and 3, multiplied by label 1, yields the new $O_i$. Since the final $m_i$ and $l_i$ are the global statistics of each row, $O_i$ is exact once all blocks have been processed.
This step is actually the pseudocode implementation of the formula below.

It rewrites the standard computation of O in recursive form.
Because Q, K, and V are processed in blocks, each step yields only a partial output. All partial outputs must be combined to obtain the final output O; this combination is an accumulation over all j. In recursive form, $O_i$ is updated once per block traversed: the new value is the previous result, rescaled, plus the contribution of the current block. After all blocks have been traversed, the result is exactly the same as in the standard scenario, $O_i=\mathrm{softmax}(Q_iK^T)V$.
The derivation is as follows.

Step Thirteen

Step 13 uses $m_i^{new}$ and $l_i^{new}$ to update $m_i$ and $l_i$.
The latest accumulated statistics, $m_i$ and $l_i$, are then written back to HBM. Note that each has dimension $B_r$.
The operations of m and l can be summarized as follows.

Steps fourteen, fifteen, and sixteen
When the nested for loops end, O contains the final result: the attention output vector for each input token, which is then returned.

Summary
We summarize the algorithm using the diagram below. Looking at lines 5 to 13 of the pseudocode, only $O_i$, $\ell_i$, and $m_i$ are written back from on-chip SRAM to GPU memory (HBM). After iterating over all values of i, apart from loading the inputs, the only HBM reads and writes are those of m, l, and O, whereas standard attention must read and write S, P, and O. Therefore, the point of the block-wise safe softmax is to eliminate the HBM reads and writes of S and P, which reduces both memory traffic and memory footprint.
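Putting lines 2 to 16 together, the following pure-Python sketch follows the loop order of the pseudocode (outer loop over K/V blocks, inner loop over Q/O blocks) and compares the result with standard attention. It illustrates the arithmetic only: Python lists stand in for HBM, the per-block variables stand in for SRAM, the $1/\sqrt{d}$ scaling, masking, and dropout are omitted, and the function names are hypothetical.

```python
import math
import random

def matmul_t(a, b):
    """Return a @ b^T for matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]

def naive_attention(Q, K, V):
    """Standard attention: materialize each row of S and P, then O = P V."""
    out = []
    for srow in matmul_t(Q, K):
        mx = max(srow)
        p = [math.exp(s - mx) for s in srow]
        z = sum(p)
        out.append([sum(pk * v[c] for pk, v in zip(p, V)) / z for c in range(len(V[0]))])
    return out

def flash_attention_forward(Q, K, V, Br, Bc):
    """Tiled forward pass mirroring lines 2-16 (no scaling, masking, or dropout)."""
    N, d = len(Q), len(V[0])
    O = [[0.0] * d for _ in range(N)]          # line 2: O = 0, l = 0, m = -inf
    l = [0.0] * N
    m = [float("-inf")] * N
    for j0 in range(0, len(K), Bc):            # line 5: outer loop over K/V blocks
        Kj, Vj = K[j0:j0 + Bc], V[j0:j0 + Bc]  # line 6: load K_j, V_j
        for i0 in range(0, N, Br):             # line 7: inner loop over Q/O blocks
            Sij = matmul_t(Q[i0:i0 + Br], Kj)  # line 9: S_ij = Q_i K_j^T
            for r, srow in enumerate(Sij):     # one row of the query block at a time
                i = i0 + r
                m_t = max(srow)                               # line 10
                p_t = [math.exp(s - m_t) for s in srow]
                l_t = sum(p_t)
                m_new = max(m[i], m_t)                        # line 11
                l_new = math.exp(m[i] - m_new) * l[i] + math.exp(m_t - m_new) * l_t
                a = l[i] * math.exp(m[i] - m_new)             # line 12: label-2 factor
                b = math.exp(m_t - m_new)                     # label-3 factor
                O[i] = [(a * O[i][c] + b * sum(p * v[c] for p, v in zip(p_t, Vj))) / l_new
                        for c in range(d)]
                m[i], l[i] = m_new, l_new                     # line 13
    return O                                                   # line 16

def rand_matrix(rng, rows, cols):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
Q, K, V = rand_matrix(rng, 10, 4), rand_matrix(rng, 10, 4), rand_matrix(rng, 10, 4)
ref = naive_attention(Q, K, V)
out = flash_attention_forward(Q, K, V, Br=3, Bc=4)
print(max(abs(x - y) for ro, rr in zip(out, ref) for x, y in zip(ro, rr)))  # close to 0
```

Note that the N x N matrices S and P never exist as a whole in `flash_attention_forward`; only one $B_r \times B_c$ tile of scores is alive at a time, while O, m, and l persist across iterations.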

0x05 Computational Complexity and GPU Memory Usage
5.1 I/O Complexity
Assume that N is the sequence length, d is the dimension of an attention head, and M is the size of SRAM, with $d \le M \le Nd$.
Standard attention
Based on the preceding space complexity analysis, the GPU memory required by attention grows quadratically with the sequence length N. Because the computation itself runs on-chip, data must be exchanged constantly between HBM and SRAM, and a significant share of the runtime is spent moving data between them.

Based on the size of each matrix involved in the calculation, we can analyze its memory access cost (MAC), counted as the number of individual float values read from or written to HBM.
- Calculate S. Line 1 of the algorithm reads $Q, K \in \mathbb{R}^{N \times d}$ from HBM, which takes $2Nd$ reads. It computes $S = QK^T \in \mathbb{R}^{N \times N}$ and writes S back to HBM, which takes $N^2$ writes. The MAC of this step is $2Nd + N^2$.
- Calculate P. Line 2 reads S from HBM ($N^2$ reads), computes $P = \mathrm{softmax}(S)$, and writes P back to HBM ($N^2$ writes). The MAC of this step is $2N^2$.
- Calculate O. Line 3 reads P from HBM ($N^2$ reads) and $V \in \mathbb{R}^{N \times d}$ ($Nd$ reads), computes $O = PV$, and writes O back to HBM ($Nd$ writes). The MAC of this step is $N^2 + 2Nd$.
The total MAC of all the above is $4Nd + 4N^2$. Ignoring constant factors, this is $\Theta(Nd + N^2)$. In summary, the memory access complexity of standard attention is $\Theta(Nd + N^2)$, which is dominated by the $N^2$ term when N is much larger than d.
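Dividing the FLOPs by the MAC gives the arithmetic intensity of standard attention. Counting only the two matrix multiplications, $QK^T$ and $PV$ ($2N^2 d$ FLOPs each), and ignoring softmax:
$$I = \frac{4N^2 d}{4Nd + 4N^2} = \frac{Nd}{N + d}$$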
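The counts above can be reproduced with a small helper, sketched below in Python. It counts individual elements moved between HBM and SRAM for the three steps of standard attention and, for the FLOPs, assumes 2 FLOPs per multiply-add in $QK^T$ and $PV$ while ignoring softmax; the function names are hypothetical.

```python
def standard_attention_mac(N, d):
    """HBM element reads/writes of standard attention (lines 1-3)."""
    line1 = 2 * N * d + N * N       # read Q, K (2Nd); write S (N^2)
    line2 = N * N + N * N           # read S (N^2); write P (N^2)
    line3 = N * N + N * d + N * d   # read P (N^2) and V (Nd); write O (Nd)
    return line1 + line2 + line3    # = 4Nd + 4N^2

def standard_attention_flops(N, d):
    """FLOPs of S = QK^T and O = PV, 2 per multiply-add; softmax ignored."""
    return 2 * N * N * d + 2 * N * N * d

d = 64
for N in (1024, 4096, 16384):
    mac = standard_attention_mac(N, d)
    print(N, mac, standard_attention_flops(N, d) / mac)   # intensity = Nd / (N + d)
```

Because $Nd/(N + d) < d$, the intensity under this counting never exceeds the head dimension, no matter how long the sequence is.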
FlashAttention
From the perspective of the pseudocode:
- The outer loop loads blocks of K and V, each of size $B_c \times d$. Line 6 loads each $K_j$ and $V_j$ only once, so each iteration costs $2B_c d$ HBM accesses. The outer loop runs $T_c = N / B_c$ times, giving $2Nd$ in total.
- The inner loop loads blocks of Q and O, each of size $B_r \times d$. Line 8 loads each $Q_i$ and $O_i$ only once, and line 12 writes the $O_i$ block back, so during one pass the inner loop reads all of Q and O and writes all of O. Over its $T_r = N / B_r$ iterations this costs $\Theta(Nd)$ (the vectors $\ell_i$ and $m_i$ add only $\Theta(N)$). Combined with the $T_c$ outer iterations, the total is $\Theta(Nd\,T_c)$.
Therefore, the total IO is $\Theta(Nd + Nd\,T_c) = \Theta\left(Nd + \frac{N^2 d^2}{M}\right)$, using $T_c = N / B_c$ and the block size $B_c = \lceil M / (4d) \rceil$ set in line 1.
Because $M \le Nd$, we have $N^2 d^2 / M \ge Nd$, so the total IO of FlashAttention is $\Theta(N^2 d^2 / M)$. Typically d is 64 or 128 and M is around 100 KB, so $d^2/M$ is well below 1. Therefore, the MAC of FlashAttention v1 is much smaller than the $\Theta(Nd + N^2)$ of standard attention.
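To make the comparison concrete, the sketch below keeps the constants that $\Theta(\cdot)$ hides and counts the element accesses of the FlashAttention v1 forward loops next to those of standard attention. It treats M as the SRAM capacity in elements and uses $B_c = \lceil M / (4d) \rceil$ from line 1; both the element interpretation of M and the example values N = 4096, d = 64, M = 100000 are assumptions chosen for illustration, and the function names are hypothetical.

```python
import math

def standard_attention_mac(N, d):
    """4Nd + 4N^2 element accesses, as counted for standard attention."""
    return 4 * N * d + 4 * N * N

def flash_attention_mac(N, d, M):
    """Element accesses of the FlashAttention v1 forward loops (lines 5-13).

    M is treated as the SRAM capacity in elements (an assumption of this sketch).
    """
    Bc = math.ceil(M / (4 * d))         # line 1 block size
    Tc = math.ceil(N / Bc)              # number of outer-loop iterations
    outer = 2 * N * d                   # line 6: every K_j, V_j loaded once
    per_outer = 3 * N * d + 4 * N       # lines 8, 12, 13: read Q, O, l, m; write O, l, m
    return outer + Tc * per_outer

N, d, M = 4096, 64, 100_000
print(flash_attention_mac(N, d, M), standard_attention_mac(N, d))  # 9355264 68157440
```

With these illustrative values, the FlashAttention loops touch roughly 7 times fewer elements. Doubling M roughly halves $T_c$ and hence the dominant term, consistent with $\Theta(N^2 d^2 / M)$.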

Backpropagation
The above covers the entire forward pass of FlashAttention v1. The backward pass is not elaborated here, since it closely mirrors the forward pass and follows the same approach. The only additional point is that, besides using tiling to reduce MAC as in the forward pass, the backward pass also uses a few techniques to reduce overall memory overhead, for example:
- Backpropagation needs the matrices S and P to compute the gradients with respect to Q, K, and V. Instead of storing these $N \times N$ matrices, FlashAttention stores only the output O, the row maxima m, and the exponential sums l, and recomputes S and P block by block from Q, K, and these statistics during the backward pass, which saves memory.
- By saving the state of the random number generator in the forward pass, the dropout masks can be regenerated in the backward pass, which avoids storing all of them.
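Both tricks can be sketched in a few lines of Python. The sketch below uses a tiny Q and K, recomputes S and then P from the saved row statistics m and l, and regenerates a dropout mask from a saved seed. Python's `random.Random` is only a stand-in for the GPU random number generator, and all function names are hypothetical.

```python
import math
import random

def scores(Q, K):
    """S = Q K^T, recomputed from the inputs instead of being stored."""
    return [[sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]

def forward_statistics(S):
    """What the forward pass keeps per row: max m and sum l (P is discarded)."""
    m = [max(row) for row in S]
    l = [sum(math.exp(s - mi) for s in row) for row, mi in zip(S, m)]
    return m, l

def recompute_P(Q, K, m, l):
    """Backward pass: rebuild S, then P = exp(S - m) / l from the saved statistics."""
    return [[math.exp(s - mi) / li for s in row] for row, mi, li in zip(scores(Q, K), m, l)]

def dropout_keep_mask(seed, rows, cols, p_drop):
    """Regenerate a dropout keep-mask from a saved seed (stand-in for the GPU RNG)."""
    rng = random.Random(seed)
    return [[rng.random() >= p_drop for _ in range(cols)] for _ in range(rows)]

Q = [[0.1, 0.2], [0.3, -0.4], [1.0, 0.5]]
K = [[0.5, -0.1], [0.2, 0.2], [-0.3, 0.8]]
m, l = forward_statistics(scores(Q, K))    # O(N) values saved in the forward pass
P = recompute_P(Q, K, m, l)                # N x N matrix rebuilt in the backward pass
mask_fwd = dropout_keep_mask(1234, 3, 3, 0.1)
mask_bwd = dropout_keep_mask(1234, 3, 3, 0.1)
print(mask_fwd == mask_bwd)                # True: same seed, same mask
```

The trade-off is extra computation (recomputing $QK^T$) in exchange for storing $\Theta(N)$ statistics instead of $\Theta(N^2)$ matrices.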
5.2 Computational Complexity
Standard attention
In terms of time complexity, attention multiplies Q by the transpose of K to obtain the attention weight matrix. Ignoring the batch dimension and assuming Q and K both have shape (n, dim), multiplying an (n, dim) matrix by a (dim, n) matrix costs $O(n^2 \cdot dim)$, which grows with the square of the sequence length, so the time complexity of attention is $O(n^2)$ in n. When the sequence is long, that is, when n is large, attention becomes very time-consuming.
Assume the input sequence has length N and dimension d and is split into h heads. The attention computation can be divided into the following steps:
- Linear transformation: apply linear projections to the input sequence to obtain the three matrices Q, K, and V. If the embedding dimension of each token is k, this step costs $O(Nkd)$.
- Similarity scores: compute the attention weight matrix from Q and K. The matrix is $N \times N$ per head, and computing it for all h heads (each of dimension $d/h$) costs $O(N^2 d)$.
- Weighted summation: multiply the attention weight matrix by V to form the weighted sum that is the final output. This step also costs $O(N^2 d)$.
Therefore, the total computational complexity of attention is approximately $O(Nkd + N^2 d)$, which is dominated by the $N^2 d$ term when N is large.
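The per-step costs can be tallied with a short sketch. It counts multiply-add operations (not the memory access cost MAC used earlier), ignores the batch dimension, assumes d is divisible by h, and uses a hypothetical function name; it shows that splitting into heads does not change the $N^2 d$ cost of the two attention products.

```python
def attention_multiply_adds(N, k, d, h):
    """Multiply-add counts per step of one attention layer, batch ignored.

    N: sequence length, k: token embedding dim, d: total Q/K/V dim,
    h: number of heads (each of size d // h; d is assumed divisible by h).
    """
    dh = d // h
    projections = 3 * N * k * d       # X (N, k) @ W (k, d) for Q, K and V
    similarity = h * N * N * dh       # per head: Q_h (N, dh) @ K_h^T (dh, N)
    weighted_sum = h * N * N * dh     # per head: P_h (N, N) @ V_h (N, dh)
    return projections, similarity, weighted_sum

for h in (1, 8, 16):
    print(h, attention_multiply_adds(N=2048, k=1024, d=1024, h=h))
```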
FlashAttention
The computational workload mainly comes from matrix multiplication.
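- Line 9 ($S_{ij} = Q_i K_j^T$): $O(B_r B_c d)$ FLOPs.
- Line 12 ($\tilde{P}_{ij} V_j$): $O(B_r B_c d)$ FLOPs.
- The total number of loop iterations is $T_c T_r = \frac{N}{B_c} \cdot \frac{N}{B_r}$. Overall, the FLOPs are $O\left(\frac{N}{B_c} \cdot \frac{N}{B_r} \cdot B_r B_c d\right) = O(N^2 d)$.
This is the same as standard attention: FlashAttention reduces memory traffic but does not reduce the number of FLOPs.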
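Summing the two matrix products over every block pair confirms that tiling does not change the arithmetic. The sketch below, with a hypothetical function name, counts multiply-adds for lines 9 and 12 and allows a smaller last block when N is not a multiple of the block size.

```python
def tiled_matmul_multiply_adds(N, d, Br, Bc):
    """Sum the multiply-adds of line 9 and line 12 over every (i, j) block pair."""
    total = 0
    for j0 in range(0, N, Bc):                 # outer loop: Tc = ceil(N / Bc) iterations
        bc = min(Bc, N - j0)                   # the last block may be smaller
        for i0 in range(0, N, Br):             # inner loop: Tr = ceil(N / Br) iterations
            br = min(Br, N - i0)
            total += br * bc * d               # line 9: S_ij = Q_i K_j^T
            total += br * bc * d               # line 12: P~_ij V_j
    return total

N, d = 1000, 64
print(tiled_matmul_multiply_adds(N, d, Br=64, Bc=96) == 2 * N * N * d)   # True
```

Because the row counts of all query blocks add up to N, and likewise for the key blocks, the total is exactly $2N^2 d$ multiply-adds for any choice of block sizes.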

0xFF References
A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library. https://research.colfax-intl.com/wp-content/uploads/2023/12/colfax-flashattention.pdf
Andrew Kerr. GTC 2020: Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100. May 2020.
FlashDecoding++: Faster Large Language Model Inference on GPUs. https://arxiv.org/pdf/2311.01282.pdf
Flash-Decoding for long-context inference. https://crfm.stanford.edu/2023/10/12/flashdecoding.html
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. https://arxiv.org/pdf/2205.14135.pdf
FlashMask: Efficient and Rich Mask Extension of FlashAttention. https://arxiv.org/abs/2410.01359
FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention. https://pytorch.org/blog/flexattention/
From Online Softmax to FlashAttention. https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
Maxim Milakov and Natalia Gimelshein. Online normalizer calculation for softmax. CoRR, abs/1805.02867, 2018.
Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. https://arxiv.org/pdf/1909.08053.pdf
Self-attention Does Not Need O(n^2) Memory https://arxiv.org/abs/2112.05682
The I/O Complexity of Attention, or How Optimal is Flash Attention? https://arxiv.org/pdf/2402.07443.pdf
(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
Antinomi: FlashAttention Core Logic and V1/V2 Differences Summary
Decode Optimization - Lean Attention (Hand-grab Pancake Bear)
Learning from the official Triton example of Flash Attention V2 [forward] from the L77 Nebula
Flash Attention on Intel GPU (Minor Issues)
FlashAttention v2 paper review and advancement: Killua
FlashAttention: Accelerates computation, saves GPU memory, and provides precise I/O-aware attention (Thomas X).
FlashAttentions Chenfan Blog
FlashAttention Explained (How to Speed Up Attention) - Austin
FlashAttention Algorithm Explained in DeepHub
A Summary of FlashAttention Calculation Process ( by Pangpangdahai)
From Online Softmax to FlashAttention by Zihao Ye
GitHub: LLMForEverybody
LLM inference acceleration technology – operator fusion method of Flash Attention
NLP (17): From FlashAttention to PagedAttention, How to Further Optimize Attention Performance
Purple Qi Comes from the East
Performance optimization of Scaled Dot Product Attention (SDPA) on CPU by Mingfei
[Large Model Training] FlashAttention v1, v2 - The Clearest Formula Derivation & Algorithm Explanation ( Alan’s Brief Sharing)
[Attention Optimization][20,000 words] Principles & Diagrams: From Online-Softmax to FlashAttention V1/V2/V3 DefTruth
[Attention Optimization][10,000 words] TensorRT 9.2 MHA/Myelin Optimize vs FlashAttention-2 profile DefTruth
[FlashAttention][20,000 words] Principles & Diagrams: From Online-Softmax to FlashAttention-1/2/FlashDecoding/FlashDecoding++ DefTruth
Flash Attention paper and source code study - KIDGINBROOK
Flash Attention 1-2-3 Series Summary Zhang
https://github.com/Dao-AILab/flash-attention
Online Softmax Paper Interpretation by Zhang
ops(7): CUDA Implementation and Optimization of Self-Attention (Part 1) Purple Qi Comes from the East
ops(8): CUDA Implementation and Optimization of Self-Attention (Part 2) Purple Qi Comes from the East
[Tearing Down LLM-FlashAttention2] Just Because the For Loop Optimization Was Too Beautiful - Little Winter Melon AIGC
[Treating LLM-FlashAttention from Scratch] Starting with Softmax, a super long and easy-to-understand article!! (by Xiaodonggua AIGC)
[Tearing apart Online Softmax] Flash Attention basics, one question and they’re all silent!!! Little Winter Melon AIGC [Tearing apart LLM]
Multitasking Online Softmax TaurusMoon
A lengthy article explaining FlashAttention v1/v2 Civ
Reproduce Flash Attention 66RING using Cutlass Cute
Illustrated Guide to Accelerating Large Model Computation Series: Flash Attention V2, From Principles to Parallel Computation ( by Mengyuan)
Illustrated Guide to Accelerating Large Model Computation Series: FlashAttention V1, From Hardware to Computational Logic
Large Model Analysis: Flash Attention Snowball Effect
Accelerating Large Model Training with FlashAttention Series: The Product Perspective Behind Hit Projects (by Fang Jiarui)
Some thoughts and questions about learning Flash Attention and Flash Decoding.
My Transformer Acceleration Notes (Part 1): FlashAttention - delin
A Hand-Drawn Guide to Flash Attention! Principle Analysis and Code Implementation. Goodnight, Tombrey!
Exploring Linear Attention: Does Attention Require a Softmax? By Su Jianlin
Learning FlashAttention 2 slowly and carefully - Example 1: The Lost Little Bookboy
Learning FlashAttention slowly and carefully - The Lost Little Bookboy
Detailed Derivation of the Flash Attention Monster
Tim on the Road: A Simple Explanation of the Acceleration Principle of FlashAttention
A Thorough Understanding of FlashAttention and FlashAttention2: One of the Techniques for Enabling Large Model Context Lengths to Exceed 32K v_JULY_v
Summary of methods to reduce Transformer complexity to O(N^2) (Part 1) Civ
Summary of methods to reduce Transformer complexity to O(N^2) (Part 2) Civ
https://tridao.me/publications/flash2/flash2.pdf
Derivation of the Forward Process of Ring Attention and FlashAttentionV2 from a Coding Perspective (Yang Pengcheng)
Let’s discuss the principles of the forward pass of FlashAttentionV3 with code examples. (Yang Pengcheng)