Exploring the Transformer Series (19) --- FlashAttention V2 and its Upgrade
FlashAttention V2, Flash-Decoding, Flash-Mask, and FlashAttention-3.
0x00 Overview
FlashAttention exploits the asymmetric GPU memory hierarchy to reduce memory consumption from quadratic to linear in sequence length, achieving a 2 to 4x speedup over optimized baselines. However, it still falls well short of the speed of optimized matrix multiplication (GEMM): forward-pass throughput reaches only 30-50% of the theoretical peak FLOPs/s, and the backward pass only 25-35%. This inefficiency stems from suboptimal work partitioning between thread blocks and warps on the GPU, which causes either low occupancy or unnecessary shared memory reads/writes.
The original authors therefore upgraded FlashAttention to V2, and other researchers have built further optimizations on top of V1 and V2.
0x01 FlashAttention V2
1.1 Motivation
The authors found that inefficient work partitioning between thread blocks and warps on the GPU is a major cause of low computational efficiency. To address this, FlashAttention 2 adopts a better work partitioning scheme, making full use of parallelism and efficient work decomposition to raise hardware utilization.
1.2 Scheme
The main optimizations in FlashAttention 2 are listed below; the second and third can be regarded as optimizations at the CUDA GEMM level.
- Reduce redundant computation. Cut non-matmul FLOPs so that a larger share of the work runs on Tensor Cores.
- Parallelize along the sequence length dimension. In addition to batch and heads, the forward and backward passes are parallelized along the sequence length, so even a single head is processed by multiple thread blocks. This raises GPU utilization when the input sequence is very long (and the batch size is therefore typically small).
- Adjust the warp partitioning strategy to balance load and reduce communication. Within one thread block, the work of an attention block is distributed across warps so as to reduce data exchange and shared memory reads/writes.
Reduce redundant calculations
Why reduce non-matmul operations? Because modern GPUs execute matrix multiplication far more efficiently than any other kind of operation.
In deep learning, matrix multiplication dominates both forward and backward propagation. To accelerate it, hardware vendors have built dedicated units for GEMM (General Matrix Multiply), such as Tensor Cores; in turn, software algorithms have converged toward forms that use these units, so hardware and software shape each other. However, not every operation can be expressed as a matrix multiplication: elementwise operations such as addition, multiplication, and division fall outside it. Although these non-matmul operations account for far fewer FLOPs than the matrix multiplications, they lack dedicated acceleration and therefore run at much lower throughput. For example, an A100 delivers 312 TFLOPs/s of FP16/BF16 matmul but only 19.5 TFLOPs/s of non-matmul FP32, so each non-matmul FLOP is about 16x more expensive. It is therefore worth restructuring the algorithm to avoid non-matmul operations on the GPU as much as possible.
Both reducing redundant computation and changing the loop order are achieved by restructuring the algorithm; the main saving comes from eliminating the rescale operations that V1 performed on every iteration.
Increase parallelism
FlashAttention V1 applies parallelism in the batch size and head dimensions, meaning each head is allocated a thread block, resulting in a total of batch_size * head_num thread blocks operating in parallel. However, due to memory constraints, when processing long sequence inputs, the batch size and number of heads are typically reduced, thus decreasing the degree of parallelism.
Therefore, FlashAttention V2 additionally parallelizes along the sequence length dimension: the Q loop of V1 is spread over multiple thread blocks, increasing the total number of thread blocks and hence GPU utilization. Specifically, V2 introduces num_m_block, which splits the Q matrix along the sequence length into smaller blocks, each processed by a different thread block. Each thread block computes its assigned part of the output independently, so no dependencies or communication arise between thread blocks.
In short, sequence parallelism is about partitioning the work across thread blocks more effectively.
Adjust Warp Partitioning Strategy
FlashAttention V1 uses a split-K strategy: each warp writes its intermediate results to shared memory, the warps synchronize, and the intermediate results are then summed. These shared memory reads/writes slow down the forward pass.
FlashAttention V2 uses a better warp partitioning strategy to distribute the workload between warps within each thread block, thereby reducing communication via shared memory.
In essence, adjusting how work is split across warps is an optimization within a single thread block.
1.3 Algorithm
The main optimization of the FlashAttention V2 algorithm is the swapping of the outer and inner loops. The Q-loop is moved to the outermost layer, and the key-value loop is moved to the inner loop.

The details are as follows.
- Compared with V1, lines 3 and 6 of V2 swap the outer and inner loops: the Q loop becomes the outer loop and the KV loop the inner loop.
- Line 8 computes the score block $S_i^{(j)} = Q_i K_j^T$.
- Line 9 updates three intermediate variables:
  - $m_i^{(j)} = \max(m_i^{(j-1)}, \mathrm{rowmax}(S_i^{(j)}))$ is the running rowmax up to and including block $j$;
  - $\tilde P_i^{(j)} = \exp(S_i^{(j)} - m_i^{(j)})$ uses the current rowmax of each row to compute the unnormalized probabilities;
  - $\ell_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} + \mathrm{rowsum}(\tilde P_i^{(j)})$ is the running rowsum up to and including block $j$.
- Line 10 computes $O_i^{(j)} = \mathrm{diag}(e^{m_i^{(j-1)} - m_i^{(j)}})\, O_i^{(j-1)} + \tilde P_i^{(j)} V_j$, the unnormalized output up to and including block $j$. Lines 9 and 10 show that, with $Q_i$ fixed and the KV blocks iterated, every $\tilde P_i^{(j)}$ is computed with the latest rowmax, and $O_i^{(j)}$ and $\ell_i^{(j)}$ are rescaled to that same rowmax. Thus, after all KV blocks have been traversed, $O_i^{(T_c)}$ and $\ell_i^{(T_c)}$ are exactly the global numerator and denominator.
- Line 12 normalizes $O_i$ once: $O_i = \mathrm{diag}(\ell_i^{(T_c)})^{-1} O_i^{(T_c)}$. Normalization is not performed inside the inner loop but once in the outer loop, which reduces non-matmul operations.
- Line 13 computes the intermediate variable $L_i = m_i^{(T_c)} + \log(\ell_i^{(T_c)})$, which is written back to HBM on line 15. Storing both $m_i$ and $\ell_i$ for every Q block would cost extra HBM reads and writes, so V2 no longer stores them separately. Backpropagation still needs them to recompute $S$ and $P$ (computing dQ, dK, dV via the chain rule requires $P$), so V2 stores only $L_i$ and recovers $P_i^{(j)} = \exp(S_i^{(j)} - L_i)$, since $e^{S - m} / \ell = e^{S - (m + \log \ell)}$. This saves HBM reads and writes. L stands for log-sum-exp.
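To make the loop structure concrete, here is a minimal single-threaded C++ sketch of the forward loop described above for one (batch, head) pair. It is not the CUDA kernel: the "on-chip" buffers are ordinary local vectors, the row-major `std::vector<double>` storage and the name `fa2_forward` are assumptions for illustration, and softmax scaling, masking, and dropout are omitted. The comments refer to the algorithm line numbers used in the text.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

using Mat = std::vector<double>;  // row-major storage

// Sketch of the FlashAttention-2 forward loop for one (batch, head).
// Q, K, V are N x d; writes O (N x d) and the logsumexp L (length N).
void fa2_forward(const Mat& Q, const Mat& K, const Mat& V, int N, int d,
                 int Br, int Bc, Mat& O, std::vector<double>& L) {
    const double neg_inf = -std::numeric_limits<double>::infinity();
    O.assign(N * d, 0.0);
    L.assign(N, 0.0);
    for (int r0 = 0; r0 < N; r0 += Br) {            // line 3: outer loop over Q blocks
        const int br = std::min(Br, N - r0);
        std::vector<double> m(br, neg_inf);          // line 5: running rowmax
        std::vector<double> l(br, 0.0);              //         running rowsum
        Mat acc(br * d, 0.0);                        //         unnormalized O_i
        for (int c0 = 0; c0 < N; c0 += Bc) {        // line 6: inner loop over K/V blocks
            const int bc = std::min(Bc, N - c0);
            for (int i = 0; i < br; ++i) {
                // line 8: one row of S_i^(j) = Q_i K_j^T
                std::vector<double> s(bc);
                double row_max = neg_inf;
                for (int c = 0; c < bc; ++c) {
                    double dot = 0.0;
                    for (int k = 0; k < d; ++k) dot += Q[(r0 + i) * d + k] * K[(c0 + c) * d + k];
                    s[c] = dot;
                    row_max = std::max(row_max, dot);
                }
                // line 9: new rowmax and the factor that rescales older terms
                const double m_new = std::max(m[i], row_max);
                const double scale = std::exp(m[i] - m_new);  // 0 on the first block
                // line 10: rescale the numerator only; no division by l here
                for (int k = 0; k < d; ++k) acc[i * d + k] *= scale;
                double row_sum = 0.0;
                for (int c = 0; c < bc; ++c) {
                    const double p = std::exp(s[c] - m_new);  // entry of P~_i^(j)
                    row_sum += p;
                    for (int k = 0; k < d; ++k) acc[i * d + k] += p * V[(c0 + c) * d + k];
                }
                l[i] = scale * l[i] + row_sum;
                m[i] = m_new;
            }
        }
        for (int i = 0; i < br; ++i) {
            // line 12: normalize once per Q block
            for (int k = 0; k < d; ++k) O[(r0 + i) * d + k] = acc[i * d + k] / l[i];
            // line 13: logsumexp kept for the backward pass
            L[r0 + i] = m[i] + std::log(l[i]);
        }
    }
}
```

Block sizes that do not divide N (for example Br = 2, Bc = 3 with N = 5) exercise the partial-block path; the result should match a direct evaluation of softmax(QK^T)V up to rounding.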
Reduce redundant calculations
The FlashAttention V2 algorithm reduces redundant computation by minimizing the number of intermediate scaling operations.
Original Softmax
To ensure numerical stability (exponentials grow quickly and can overflow), the original (safe) softmax subtracts the maximum before exponentiating, at the cost of three passes over the tokens: one for the maximum, one for the sum of exponentials, and one for normalization.
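A minimal C++ sketch of this three-pass safe softmax (the function name `safe_softmax` is illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Safe softmax in three passes over x (x must be non-empty).
std::vector<double> safe_softmax(const std::vector<double>& x) {
    const double m = *std::max_element(x.begin(), x.end());  // pass 1: maximum
    double l = 0.0;
    for (double v : x) l += std::exp(v - m);                 // pass 2: sum of exp(x - m)
    std::vector<double> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = std::exp(x[i] - m) / l;                       // pass 3: normalize
    return y;
}
```

With inputs around 1000, exponentiating directly would overflow a double; subtracting the maximum keeps every exponent at most 0.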

FlashAttention V1
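The operation of FlashAttention V1 to calculate O is shown below.
$$\tilde m_{ij} = \mathrm{rowmax}(S_{ij}),\quad \tilde P_{ij} = \exp(S_{ij} - \tilde m_{ij}),\quad \tilde\ell_{ij} = \mathrm{rowsum}(\tilde P_{ij})$$
$$m_i^{\mathrm{new}} = \max(m_i, \tilde m_{ij}),\quad \ell_i^{\mathrm{new}} = e^{m_i - m_i^{\mathrm{new}}} \ell_i + e^{\tilde m_{ij} - m_i^{\mathrm{new}}} \tilde\ell_{ij}$$
$$O_i \leftarrow \mathrm{diag}(\ell_i^{\mathrm{new}})^{-1} \left( \mathrm{diag}(\ell_i)\, e^{m_i - m_i^{\mathrm{new}}} O_i + e^{\tilde m_{ij} - m_i^{\mathrm{new}}} \tilde P_{ij} V_j \right)$$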

The image below illustrates how FlashAttention uses online softmax for block-based computation.

FlashAttention V2
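FlashAttention V2 has been modified as follows.
$$m_i^{(j)} = \max\left(m_i^{(j-1)}, \mathrm{rowmax}(S_i^{(j)})\right),\quad \ell_i^{(j)} = e^{m_i^{(j-1)} - m_i^{(j)}} \ell_i^{(j-1)} + \mathrm{rowsum}\left(e^{S_i^{(j)} - m_i^{(j)}}\right)$$
$$O_i^{(j)} = \mathrm{diag}\left(e^{m_i^{(j-1)} - m_i^{(j)}}\right) O_i^{(j-1)} + e^{S_i^{(j)} - m_i^{(j)}} V_j$$
$$O_i = \mathrm{diag}\left(\ell_i^{(T_c)}\right)^{-1} O_i^{(T_c)}$$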

By comparing V1 and V2 side by side, we can better see the differences.
- The V1 algorithm iteratively rescales the preceding values in the inner loop, meaning that a rescale operation needs to be performed for each iteration of each block, which involves division operations.
- The V2 algorithm moves the rescale operation from the inner loop to the outer loop, deferring it until the inner loop has finished; this saves one division per inner-loop iteration. Specifically:
  - In the inner loop, $O_i^{(j)}$ is computed without the division by $\mathrm{diag}(\ell_i^{(j)})$; only the numerator is rescaled by $e^{m_i^{(j-1)} - m_i^{(j)}}$. Correspondingly, the multiplication by $\mathrm{diag}(\ell_i^{(j-1)})$ that V1 needed to undo the previous normalization is dropped as well.
  - After the inner loop finishes, a single rescale is applied in the outer loop, $O_i = \mathrm{diag}(\ell_i^{(T_c)})^{-1} O_i^{(T_c)}$, to obtain the final value. This removes one division (a non-matmul operation) from every inner-loop iteration. V2 only needs to keep the numerator $O_i^{(j)}$ and the denominator $\ell_i^{(j)}$ scaled to the same running max; dividing by the final denominator at the end gives the same result as V1.
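The difference between the V1 and V2 update rules can be checked with a small C++ sketch for a single query row. To keep it short it assumes scalar values (d = 1), and it uses a hypothetical `divisions` counter that records how often the running output is divided by the softmax denominator. Both functions should return the same value, with V1 dividing once per KV block and V2 only once at the end.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// V1-style update: the running output o stays normalized after every block,
// so each block multiplies by the old l and divides by the new l.
double v1_output(const std::vector<double>& s, const std::vector<double>& v,
                 std::size_t block, int& divisions) {
    double m = -std::numeric_limits<double>::infinity(), l = 0.0, o = 0.0;
    divisions = 0;
    for (std::size_t b0 = 0; b0 < s.size(); b0 += block) {
        const std::size_t b1 = std::min(s.size(), b0 + block);
        const double m_blk = *std::max_element(s.begin() + b0, s.begin() + b1);
        double l_blk = 0.0, pv = 0.0;  // local rowsum and local P~ V
        for (std::size_t c = b0; c < b1; ++c) {
            const double p = std::exp(s[c] - m_blk);
            l_blk += p;
            pv += p * v[c];
        }
        const double m_new = std::max(m, m_blk);
        const double l_new = std::exp(m - m_new) * l + std::exp(m_blk - m_new) * l_blk;
        o = (l * std::exp(m - m_new) * o + std::exp(m_blk - m_new) * pv) / l_new;
        ++divisions;
        m = m_new;
        l = l_new;
    }
    return o;
}

// V2-style update: keep the unnormalized numerator, rescale it to the new max,
// and divide by the final l only once after the loop.
double v2_output(const std::vector<double>& s, const std::vector<double>& v,
                 std::size_t block, int& divisions) {
    double m = -std::numeric_limits<double>::infinity(), l = 0.0, acc = 0.0;
    divisions = 0;
    for (std::size_t b0 = 0; b0 < s.size(); b0 += block) {
        const std::size_t b1 = std::min(s.size(), b0 + block);
        const double m_new = std::max(m, *std::max_element(s.begin() + b0, s.begin() + b1));
        const double scale = std::exp(m - m_new);  // 0 on the first block
        acc *= scale;
        l *= scale;
        for (std::size_t c = b0; c < b1; ++c) {
            const double p = std::exp(s[c] - m_new);
            l += p;
            acc += p * v[c];
        }
        m = m_new;
    }
    ++divisions;
    return acc / l;
}
```

With five scores and a block size of 2 there are three KV blocks, so V1 performs three divisions and V2 performs one.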

Swap loop order
GPU Features
Before going into detail about the parallel strategy of FlashAttention v2, we need to briefly review the basic working principle of GPUs.
From a hardware perspective, GPUs are well-suited for parallel tasks because they typically contain a large number of computing units. While a single GPU computing unit is generally less powerful than a CPU, the sheer number of units allows them to perform parallel tasks simultaneously. Streaming multiprocessors (SMs) are the actual physical computing units within a GPU; the A100 contains 108 SMs. To improve computational throughput, it’s essential to ensure that as many SMs as possible are engaged in computation at any given time.
From a software perspective, GPUs execute computations with a large number of threads, organized into thread blocks (for example, 128 threads per block). Thread blocks are scheduled onto SMs (streaming multiprocessors) for execution.
To cooperate more effectively, each thread block is further divided into warps. A warp is the basic unit of parallel execution on an NVIDIA GPU (the smallest unit the hardware actually schedules) and contains 32 threads that execute the same instruction at the same time on different data. Because the GPU schedules instructions warp by warp, this fully exploits its parallel processing capability. The threads of one warp can cooperate on a matrix multiplication; however, when values that must be combined are held by different warps, the intermediate results have to be exchanged through shared memory.
FlashAttention V1
Let’s first look at some features of version V1 from the perspective of parallelization.
First, a premise: viewed as matrices, the outer loop j in V1 walks over the column blocks of the score matrix $S = QK^T$ (one K/V block per step), and the inner loop i walks over its row blocks (the Q blocks, which are also the row blocks of O).
Second, with this loop order the entire outer loop has to be executed within a single thread block, because:
- During the forward pass, online softmax accumulates across the columns of each row, i.e., along the outer-loop direction. This requires the running statistics $m_i$ and $\ell_i$ together with $O_i$, but these are updated inside the inner loop.
- The inner loop iterates along the row direction, which conflicts with the column-direction online softmax on each row, so additional reduction logic is needed to complete the online softmax.
Ideally, V1 keeps the entire outer loop within a single thread block so that intermediate softmax results can be shared, speeding up the process. If the outer loop were spread across thread blocks, these intermediate results would have to be passed through global memory, requiring extra communication such as a cross-thread-block reduction.
Third, this loop order creates dependencies between the inner and outer loops, because updating $O_i$ requires the $m_i$ and $\ell_i$ produced by the previous outer-loop iteration. In V1's nested loops, $K_j$ and $V_j$ are loaded in the outer loop and $Q_i$ in the inner loop, so each inner iteration only computes the part of $O_i$ contributed by block $j$, and every inner iteration must read and write $O_i$, $\ell_i$, $m_i$ from/to global memory.
In summary, V1 can parallelize at the thread block level only over the batch_size and num_heads dimensions, so its efficiency drops significantly when the sequence is long and the batch size is small. See the diagram below for an example: FlashAttention v1 uses a single thread block to produce the whole result O shown in the diagram; in other words, the entire inner and outer loops run inside one thread block.

FlashAttention V2
The analysis of V1 shows that the outer loop should not run along the softmax reduction dimension. Moreover, in attention, the computations for different queries are completely independent: output O1 depends only on Q1 and has no logical dependency on Q2, Q3, or Q4, so these computations can be parallelized.
Therefore, FA2 adjusts the loop order for forward propagation, loading Q first, then K and V.
Let’s analyze the impact of adjusting the order.
- The outer loop can be parallelized. Once the Q loop is moved to the outermost level, parallelism along the "row" direction (seqlen) comes naturally, because the iterations of the outer loop do not depend on one another. This dimension can be turned from serial iterations into parallel thread blocks: attention for different query blocks is dispatched to different thread blocks, which execute in parallel without communicating.
- The inner loop needs fewer memory operations.
  - Compared with FA1, the inner loop no longer reads and writes $O_i$, $\ell_i$, $m_i$ from/to HBM on every iteration, which reduces IO operations and time.
  - Online softmax accumulates across the columns of each row, which matches the iteration direction of the inner loop, so no extra reduction logic is needed.
Therefore, V2 can parallelize over all three dimensions (batch_size, num_heads, seq_len) at the thread block level. For seq_len, the outer loop is split into $T_r = \lceil N / B_r \rceil$ parallel blocks. These thread blocks do not need to communicate with each other, which significantly increases GPU throughput.
As shown in the figure below, FlashAttention v1 uses a single thread block to generate the result O in the figure below; however, in FlashAttention v2, a single thread block is only responsible for generating a subset of the result O in the figure, which is each row (O1, O2…) in the figure below. Within a single thread block, tiling attention operations are iteratively performed on the data (Q1, K1, V1), (Q1, K2, V2), (Q1, K3, V3), and (Q1, K4, V4), accumulating the results into O1. The O1 values during the iterations are intermediate results, while O1 after the final iteration is the true result. This aligns with the semantic interpretation that attention is a weighted average sum; it can be understood that O1 is a weighted average sum representation of Q1 in a deeper semantic space.
In this way, multiple thread blocks can generate O2, O3, and O4 parts in parallel, thereby increasing the overall parallelism of the algorithm and improving GPU utilization.

Backpropagation is handled differently: its loop order stays the same as in V1, with the outer loop loading K and V and the inner loop loading Q. However, V2 adds parallelism along the sequence length dimension (the column direction). The analysis is as follows.
The goal of the backward pass is to compute dQ, dK, and dV (which also requires the intermediate results dS and dP). Let's summarize the direction along which each gradient must be reduced (summed, AllReduce-style):
- $dV_j = \sum_i (P_i^{(j)})^T dO_i$: reduce along the i direction, i.e., sum over rows.
- $dK_j = \sum_i (dS_i^{(j)})^T Q_i$: reduce along the i direction, i.e., sum over rows.
- $dQ_i = \sum_j dS_i^{(j)} K_j$: reduce along the j direction, i.e., sum over columns.
- $dP_i^{(j)} = dO_i V_j^T$ and $dS_i^{(j)} = P_i^{(j)} \circ (dP_i^{(j)} - D_i)$, where $D_i = \mathrm{rowsum}(dO_i \circ O_i)$: these depend only on the current i and j.
If Q is the inner loop and KV the outer loop, which amounts to fixing a column block and iterating over the rows, then dK and dV benefit: they can be accumulated on chip. The cost is that partial results for dQ must be written to HBM and accumulated there. The alternative, Q in the outer loop, would benefit only dQ and force the partial results of both dK and dV through HBM, consuming more GPU memory and more memory operations. So a trade-off is made: sacrifice Q, keeping KV in the outer loop and Q in the inner loop (the computation of S and P is not affected by the loop order).
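The gradient formulas and reduction directions can be checked with a small single-threaded C++ sketch. It works at row granularity (block size 1) for one head, omits scaling and masking, and uses hypothetical function names. `attention_forward` is a plain reference forward pass that also returns the logsumexp L, so the backward pass can recompute P as exp(S - L). As in V1, K/V rows form the outer loop and Q rows the inner loop.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

using Mat = std::vector<double>;  // row-major N x d

// Reference forward pass for one head (no scaling or mask):
// O = softmax(Q K^T) V and L_i = logsumexp of row i of Q K^T.
void attention_forward(const Mat& Q, const Mat& K, const Mat& V, int N, int d,
                       Mat& O, Mat& L) {
    O.assign(N * d, 0.0);
    L.assign(N, 0.0);
    for (int i = 0; i < N; ++i) {
        std::vector<double> s(N);
        double m = -std::numeric_limits<double>::infinity();
        for (int j = 0; j < N; ++j) {
            double dot = 0.0;
            for (int k = 0; k < d; ++k) dot += Q[i * d + k] * K[j * d + k];
            s[j] = dot;
            m = std::max(m, dot);
        }
        double l = 0.0;
        for (int j = 0; j < N; ++j) l += std::exp(s[j] - m);
        L[i] = m + std::log(l);
        for (int j = 0; j < N; ++j)
            for (int k = 0; k < d; ++k) O[i * d + k] += std::exp(s[j] - L[i]) * V[j * d + k];
    }
}

// Backward pass with K/V rows in the outer loop and Q rows in the inner loop.
void attention_backward(const Mat& Q, const Mat& K, const Mat& V, const Mat& O,
                        const Mat& dO, const Mat& L, int N, int d,
                        Mat& dQ, Mat& dK, Mat& dV) {
    dQ.assign(N * d, 0.0);
    dK.assign(N * d, 0.0);
    dV.assign(N * d, 0.0);
    std::vector<double> D(N, 0.0);  // D_i = rowsum(dO_i * O_i)
    for (int i = 0; i < N; ++i)
        for (int k = 0; k < d; ++k) D[i] += dO[i * d + k] * O[i * d + k];
    for (int j = 0; j < N; ++j) {          // outer loop: K_j, V_j
        for (int i = 0; i < N; ++i) {      // inner loop: Q_i
            double s = 0.0, dp = 0.0;
            for (int k = 0; k < d; ++k) {
                s += Q[i * d + k] * K[j * d + k];     // S_ij
                dp += dO[i * d + k] * V[j * d + k];   // dP_ij = dO_i . V_j
            }
            const double p = std::exp(s - L[i]);      // P_ij recomputed from L
            const double ds = p * (dp - D[i]);        // dS_ij = P_ij (dP_ij - D_i)
            for (int k = 0; k < d; ++k) {
                dV[j * d + k] += p * dO[i * d + k];   // sum over i: stays with this j
                dK[j * d + k] += ds * Q[i * d + k];   // sum over i: stays with this j
                dQ[i * d + k] += ds * K[j * d + k];   // sum over j: touched by every j
            }
        }
    }
}
```

Here dK and dV for key row j are complete as soon as the inner loop over i finishes, whereas every dQ row keeps receiving contributions until the outer loop ends, which is why a GPU kernel with this loop order has to accumulate dQ in HBM.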
The specific algorithm for backpropagation is as follows.

Sequence Parallel
When writing CUDA code, we need to determine the total number of blocks required. For FlashAttention, attention calculations are performed within each block. Because the batch and head are data-independent during attention calculations, the block partitioning depends on whether the data dependencies between Q, K, and V support parallelism.
- Because of these data dependencies, V1 partitions thread blocks along two dimensions, batch_size and num_heads, for a total of batch_size * num_heads blocks, each responsible for computing a portion of the O matrix. An example grid setup is dim3 grid(params.b, params.h).
- Because each Q block only needs to be combined with the full set of K and V and is independent of the other Q blocks, V2 partitions thread blocks along three dimensions, batch_size, num_heads, and seq_len, for a total of batch_size * num_heads * num_m_block blocks, each responsible for computing a portion of the O matrix. num_m_block is the number of partitions along the rows of Q, and each partition holds several tokens. An example grid setup is shown below.
```cuda
if (params.num_splits == 1) {
    // No sequence split: one thread block per (batch, head).
    dim3 grid(params.b, params.h, params.num_splits);
    kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
} else {
    // Judging by its name, computes the per-row dot(dO, O) term; seqlen_q is split into chunks of 128 rows.
    dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
    fmha_bwd_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
    // Sequence-parallel launch: one thread block per (batch, head, K/V block).
    int num_splits = params.seqlen_k / blocksize_c; // seqlen_k is divisible by blocksize_c
    dim3 grid(params.b, params.h, num_splits);
    kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
}
```
The purpose of sequence parallelism is to make better use of the streaming multiprocessors (SMs) and maximize their utilization. When both batch_size and num_heads are large, there are many blocks and SM utilization is high. However, long seq_len inputs often come with small batch_size and num_heads, leaving SMs idle. To address this, V2 additionally partitions along the seq_len dimension of Q.
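A back-of-the-envelope C++ sketch of the block counts. It assumes the A100's 108 SMs mentioned earlier, Q row blocks of 128 rows, and, to keep things simple, at most one resident thread block per SM (real occupancy can allow several blocks per SM). The function names are illustrative.

```cpp
#include <cassert>
#include <cstdint>

constexpr std::int64_t kNumSMs = 108;  // A100

// V1: one thread block per (batch, head).
std::int64_t v1_num_blocks(std::int64_t batch, std::int64_t heads) {
    return batch * heads;
}

// Number of Q row blocks (num_m_block) for a given block height.
std::int64_t num_m_blocks(std::int64_t seqlen_q, std::int64_t block_m) {
    return (seqlen_q + block_m - 1) / block_m;  // ceiling division
}

// V2: one thread block per (batch, head, Q row block).
std::int64_t v2_num_blocks(std::int64_t batch, std::int64_t heads,
                           std::int64_t seqlen_q, std::int64_t block_m) {
    return batch * heads * num_m_blocks(seqlen_q, block_m);
}

// Fraction of SMs that receive a block in the first wave, assuming
// (for simplicity) at most one resident thread block per SM.
double first_wave_sm_fraction(std::int64_t num_blocks, std::int64_t num_sms = kNumSMs) {
    const std::int64_t busy = num_blocks < num_sms ? num_blocks : num_sms;
    return static_cast<double>(busy) / static_cast<double>(num_sms);
}
```

With batch_size = 1 and 16 heads at a 16K sequence length, V1 launches only 16 blocks, leaving most of the 108 SMs idle, while splitting Q into 128-row blocks gives 2048 blocks.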
FlashAttention V1
FlashAttention V1 parallelizes both batch and heads.
- For a single sequence, the parallel computation of FlashAttention v1 mainly occurs between attention heads. During a single forward computation, attention heads within the same self-attention computation can be computed in parallel.
- Data within the same batch is also processed in parallel.
Therefore, FlashAttention v1 parallelizes over two dimensions at once, batch and attention heads, and needs batch size × number of heads thread blocks. Each block is scheduled onto a streaming multiprocessor (SM), and the A100 has 108 SMs. When the number of blocks is large, more SMs work in parallel, overall throughput is naturally higher, and GPU resources are fully used.
However, when processing long input sequences, memory limits usually force a smaller batch size and fewer attention heads, reducing parallelism, because setting them too large causes an out-of-memory (OOM) error. In long-context scenarios the batch size on a single GPU often becomes very small, so the number of thread blocks that can run in parallel may be far smaller than the number of SMs, lowering overall system throughput.
The thread block distribution of V1 is shown in the figure below.
Assuming batch_size = 1 and num_heads = 3, we use different colors to represent different attention heads. We know that in Multihead Attention, each attention head can be computed independently, and the results are concatenated after computation. Therefore, we assign one attention head to one block, enabling parallel computation between blocks. Within each block, the “KV outer loop, Q inner loop” process in V1 is executed. This process is organized at the lower warp level of the block and computed by threads. Finally, each block simply writes its result to the corresponding position in its maintained O after computation.

FlashAttention V2
With FlashAttention v1's parallel strategy, a long input sequence means a small batch size, so the number of parallel thread blocks can be far smaller than the number of SMs. Dimensions other than batch and attention heads must therefore be used for parallelization. FlashAttention v2 builds on v1's strategy and adds parallelism along the sequence length dimension, a natural complement to the idea of swapping the inner and outer loops.
Forward propagation partitioning
Now let’s continue assuming batch_size = 1 and num_heads = 3. Unlike V1, we’ve also split Q along the seq_len dimension, dividing it into two parts, i.e., num_m_block = 2. So now we have a total of 1 x 2 x 3 = 6 blocks running. The operations between these blocks are also independent because:
- The calculation of the head is independent, so blocks of different colors do not interfere with each other.
- When using Q as the outer loop and KV as the inner loop, the blocks between rows are independent, so the blocks of different rows do not interfere with each other.
Each block loads the corresponding slice from Q and the corresponding head slice from KV, calculates the part of O it maintains, and then writes it to the corresponding position of O.

Differences
Because the inner and outer loops of FWD and BWD are different in V2, the division of thread blocks will also be different.

The large box in the diagram represents the output matrix, and each worker is a thread block. Different thread blocks are shown in different colors; white cells are skipped entirely because they are fully masked.
- Forward propagation: Each row corresponds to a worker, which means that each row of the O matrix is calculated by a thread block (assuming num_heads = 1).
- Backpropagation: each column corresponds to one worker. This is because in BWD, KV is the outer loop and Q the inner loop. In this case dK and dV are accumulated along the row direction (summed over i), while dQ is accumulated along the column direction (summed over j). Since the majority wins (two gradients versus one), thread blocks here are partitioned by column.
Other possibilities
- Why doesn't V1 perform sequence parallelism? In fact, both FA1 and FA2 can, and later versions of the V1 code did introduce sequence-level parallelism. However, V1 organizes its grid as (batch_size, num_heads, num_m_blocks), whereas V2 uses (num_m_blocks, batch_size, num_heads). The point of this reordering is to improve the L2 cache hit rate: thread blocks with the same (batch, head) but different Q row blocks read exactly the same K and V. With num_m_blocks in the fastest-varying grid dimension, such blocks are launched next to each other, so K/V data already fetched by one block is likely still in L2 when its neighbors request it.
- Why split only the Q sequence length and not the KV sequence length? Generally, splitting the Q sequence length into parallel blocks already gives enough thread blocks to occupy the GPU. Unless the streaming multiprocessors (SMs) really are underutilized, avoid splitting along the KV dimension: blocks along KV cannot be computed independently (for one row of O, the contributions come from different blocks, and obtaining the global softmax requires aggregating their results and recomputing), which introduces extra communication overhead. That said, the CUTLASS-based V2 implementation does provide an option for splitting along the KV sequence length.
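The effect of the grid order on neighboring blocks can be illustrated with a small C++ sketch. It assumes blocks are dispatched roughly in increasing linear order with blockIdx.x varying fastest (common in practice, but not guaranteed by CUDA), and the helper names are hypothetical. It counts how many adjacent blocks share the same (batch, head) and therefore read the same K and V.

```cpp
#include <cassert>

struct BlockCoord {
    int m_block;  // which Q row block
    int batch;
    int head;
};

// Grid laid out as (num_m_blocks, batch, heads): blockIdx.x = m_block varies fastest.
BlockCoord decode_v2_order(int linear, int num_m_blocks, int batch) {
    BlockCoord c;
    c.m_block = linear % num_m_blocks;
    c.batch = (linear / num_m_blocks) % batch;
    c.head = linear / (num_m_blocks * batch);
    return c;
}

// Grid laid out as (batch, heads, num_m_blocks): blockIdx.x = batch varies fastest.
BlockCoord decode_v1_order(int linear, int batch, int heads) {
    BlockCoord c;
    c.batch = linear % batch;
    c.head = (linear / batch) % heads;
    c.m_block = linear / (batch * heads);
    return c;
}

// Among the first `window` blocks in launch order, count adjacent pairs that
// share (batch, head) and therefore read the same K and V.
template <typename Decode>
int shared_kv_neighbors(Decode decode, int window) {
    int shared = 0;
    for (int i = 1; i < window; ++i) {
        const BlockCoord a = decode(i - 1);
        const BlockCoord b = decode(i);
        if (a.batch == b.batch && a.head == b.head) ++shared;
    }
    return shared;
}
```

With batch 2, 4 heads, and 8 Q row blocks, all 8 of the first blocks in the V2 order belong to the same (batch, head), while in the V1 order no two adjacent blocks do.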
Furthermore, FlashAttention V2 exposes plenty of parallelism during training and inference prefill, because query_num is large and head_num and batch_size are large as well. It is poorly suited to the inference decoding stage, however: there query_num is 1, so parallelism comes only from batch_size * head_num, which is usually small. Therefore, FlashAttention V2 is not used as-is for decoding.
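When KV is split across thread blocks, each block can only produce a partial softmax result for a query row, and the partials have to be combined in an extra aggregation step, as noted above for the KV-sequence split. A minimal C++ sketch of that combination for one query row, assuming each partial keeps its own max m, EXP sum l, and unnormalized output o (the struct and function names are illustrative):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Partial result for one query row over a subset (split) of the keys.
struct Partial {
    double m;               // max score within the split
    double l;               // sum of exp(score - m)
    std::vector<double> o;  // unnormalized sum of exp(score - m) * v
};

// What one KV split computes for a single query row, given its scores and value rows.
Partial partial_attention(const std::vector<double>& scores,
                          const std::vector<std::vector<double>>& values) {
    Partial p{-std::numeric_limits<double>::infinity(), 0.0,
              std::vector<double>(values[0].size(), 0.0)};
    for (double s : scores) p.m = std::max(p.m, s);
    for (std::size_t j = 0; j < scores.size(); ++j) {
        const double w = std::exp(scores[j] - p.m);
        p.l += w;
        for (std::size_t k = 0; k < p.o.size(); ++k) p.o[k] += w * values[j][k];
    }
    return p;
}

// The extra aggregation step: rescale both partials to a common max, then add.
Partial merge_partials(const Partial& a, const Partial& b) {
    const double m = std::max(a.m, b.m);
    const double sa = std::exp(a.m - m), sb = std::exp(b.m - m);
    Partial r{m, a.l * sa + b.l * sb, std::vector<double>(a.o.size(), 0.0)};
    for (std::size_t k = 0; k < r.o.size(); ++k) r.o[k] = a.o[k] * sa + b.o[k] * sb;
    return r;
}

// Final normalization, done once after all splits are merged.
std::vector<double> finalize(const Partial& p) {
    std::vector<double> out(p.o.size());
    for (std::size_t k = 0; k < out.size(); ++k) out[k] = p.o[k] / p.l;
    return out;
}
```

Merging the partials from two splits of the keys gives the same output as processing all keys in one block; the cost is the extra pass over the partial results.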
Adjusting workload between warps
Having discussed the parallelism of thread blocks, let’s look at how warps within a block allocate work. This section optimizes the warp-level working mode within thread blocks, minimizing communication between warps and the number of times shared memory is accessed.
Matrix multiplication decomposes naturally into blocks, so the work can be partitioned across multiple warps to exploit their combined compute and speed up the overall computation. Each thread block computes one block of one attention head; within the thread block, threads are organized into warps, and the threads of a warp cooperate on matrix multiplications. Work partitioning is primarily about how the warps are organized. In both V1 and V2, each thread block is divided into 4 warps on the Ampere architecture and 8 warps on Hopper.
The left figure represents V1, and the right figure represents V2.

FlashAttention V1
In the forward pass of FlashAttention V1, each thread block splits K and V across 4 warps, while Q remains accessible to all four warps. The authors call this scheme "split-K".
Each warp reads the same Q block and its own slice of the K/V block from shared memory. Each warp computes its own slice of $Q_i K_j^T$ and multiplies it by its slice of V. For a given Q, the result requires all of K and V, but each warp covers only a subset of the columns; these partial results must be aggregated to obtain the corresponding rows of the final O matrix. Therefore each warp writes its intermediate results to shared memory, and then one warp (e.g., warp1) combines them. This is why communication between warps is necessary, and writing these intermediate results hurts computational efficiency. Furthermore, the dependency between the inner and outer loops prevents V1 from parallelizing further: the outer loop can only run within a single thread block, and the warp-level operations are serialized as well.

FlashAttention V2
The drawback of FlashAttention 1's split is that the forward pass computes softmax along the row direction, so the row-direction information has to be aggregated at the end: intermediate results are written back to shared memory (SRAM), followed by a costly synchronization before the addition. These memory operations slow down the computation. To overcome this, V2 uses a split-Q strategy: Q is split across the warps, each warp computes its own slice of $Q_i K_j^T$, and then multiplies it by the V block (shared by all warps) to obtain its own slice of O. No communication between warps is needed, which reduces intermediate shared memory reads and writes.
Regarding why this modification reduces shared memory reads and writes to improve performance, the original paper states the following:

In the V2 implementation, the Q dimension is partitioned across warps. Each warp reads the same K/V blocks and its own Q slice from shared memory. The partitions along Q are independent (computations along the row direction are completely independent): for a given query token, all of its scores along the key dimension live in a single warp, so every element of its local softmax is within that one warp (each warp owns a quarter of the Q rows). In other words, each warp only needs to multiply its slice of softmax($Q_i K_j^T$) by the shared V to obtain its block of output, and then write the result to the corresponding position in O. The softmax and the subsequent multiplication by V are therefore performed entirely within a single warp. Because the extra additions and their read/write operations disappear, no communication between warps is required. Furthermore, HBM writes are no longer needed inside the inner loop (they are replaced by less frequent writes in the outer loop, since the inner loop completes the computation for its rows without synchronizing across outer-loop iterations), reducing I/O overhead.
However, this warp split has a drawback in V2's backward pass: since dK and dV are reduced along the row direction, splitting by Q rows requires communication between warps.

1.4 Causal Mask Processing
V2 also includes a simple optimization for causal masking. When training an autoregressive LLM, a mask is typically applied to the attention score matrix so that each token cannot attend to later tokens.
FlashAttention computes attention block by block, so a block that is entirely masked can be skipped without any computation; this enables an early exit. Blocks whose mask is all zeros can be identified purely from their indices and returned immediately. Specifically, comparing the row and column block indices gives three cases:
- If column_index < row_index, then the entire block needs to be calculated. No causal mask required.
- If column_index > row_index, then the entire block can be skipped without further calculation. No causal mask required.
- If column_index = row_index, a causal mask must be applied to the data within the block before it is computed.
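A small C++ sketch of this classification, assuming Q and K have the same length and that query row r may attend to key column c iff c <= r. It is written for general tile sizes block_m x block_n; when the two are equal it reduces exactly to the three cases above. The enum and function names are illustrative.

```cpp
#include <cassert>

enum class BlockKind { Full, Skip, Masked };

// Causal-mask classification of the tile at (row_block, col_block) of S.
BlockKind classify_causal_block(int row_block, int col_block, int block_m, int block_n) {
    const int row_lo = row_block * block_m, row_hi = row_lo + block_m - 1;
    const int col_lo = col_block * block_n, col_hi = col_lo + block_n - 1;
    if (col_hi <= row_lo) return BlockKind::Full;   // every key precedes every query: no mask
    if (col_lo > row_hi) return BlockKind::Skip;    // every key follows every query: skip
    return BlockKind::Masked;                       // straddles the diagonal: apply the mask
}
```

For a 4 x 4 grid of square tiles, 6 tiles are computed without a mask, 6 are skipped, and only the 4 diagonal tiles pay for applying the mask.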
The following is the relevant excerpt from the paper.

1.5 MQA/GQA
FlashAttention also supports MQA and GQA. Rather than physically copying the shared key-value heads in GPU memory once per query head before computing, it uses indexing: the kernel receives the key-value head index, computes the corresponding memory address, and reads the key-value data directly from memory.
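A sketch of the head-index arithmetic, assuming K is stored as a contiguous [batch, seqlen, num_kv_heads, head_dim] array and that query heads are grouped consecutively (query heads 0..g-1 share KV head 0, and so on). Both the layout and the helper names are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Map a query head to the KV head it shares (MQA: num_kv_heads == 1; MHA: equal counts).
int kv_head_for_q_head(int q_head, int num_q_heads, int num_kv_heads) {
    assert(num_q_heads % num_kv_heads == 0);
    const int group_size = num_q_heads / num_kv_heads;  // query heads per KV head
    return q_head / group_size;
}

// Element offset of K[b][s][kv_head][k] in a contiguous
// [batch, seqlen, num_kv_heads, head_dim] array.
std::size_t k_offset(std::size_t b, std::size_t s, std::size_t kv_head, std::size_t k,
                     std::size_t seqlen, std::size_t num_kv_heads, std::size_t head_dim) {
    return ((b * seqlen + s) * num_kv_heads + kv_head) * head_dim + k;
}
```

For query head q, the kernel can read K at k_offset(b, s, kv_head_for_q_head(q, ...), ...), so no copy of K or V is made.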

1.6 Summary
Compare
We first conduct a systematic comparison of V1 and V2.

Computational load
The advantage of FlashAttention v2 lies in eliminating the multiplication and division operations in each step of the original process. The specific idea behind its operation reduction is as follows.
Suppose we have a vector x, and we divide it into two sub-vectors by dividing it into two parts.
After calculating both sub-vectors, in order to To update the softmax function globally, a denominator replacement is required: that is, the local EXP summation term is promoted to the global value. The replacement logic is to multiply by the original denominator Then divide by the new global EXP summation term. After this update is complete, you will get The final softmax. If we divide the vector x into two, instead of three, then The softmax property will be updated again after this update: when After processing. At this point, for We need to multiply the softmax by (The previous global EXP summation term), divided by the new global EXP summation term at this time.
Looking back, it becomes clear that there was no need to remove it. Because the next update will require multiplying by one To cancel out the denominator. Similarly, if There will be further partitioning later, so we don’t need to divide by this point. Because it will be multiplied by another number in the next update. To offset.
Therefore, we can actually avoid removing the summation term based on the current EXP after each block calculation; we can simply divide it directly by the final value at the end. That’s it. Essentially, in each iteration, we no longer divide by the sum of EXP. Because we no longer divide by EXP, there’s no need to update the sum of EXP. After processing the last block, we simply use the current global sum of EXP as the denominator.
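The difference can be checked numerically for one row of attention scores. In this NumPy sketch (function names hypothetical), the V1-style update keeps its result normalized after every block, so it multiplies by the old sum and divides by the new one each time. The V2-style update keeps the accumulator unnormalized and divides once at the end.

```python
import numpy as np


def attention_v1_style(s, v, block):
    """Normalize after every block: multiply by the old sum, divide by the new one."""
    m, l, o = -np.inf, 0.0, np.zeros(v.shape[1])
    for j in range(0, len(s), block):
        sj, vj = s[j:j + block], v[j:j + block]
        m_new = max(m, sj.max())
        l_new = np.exp(m - m_new) * l + np.exp(sj - m_new).sum()
        o = (o * l * np.exp(m - m_new) + np.exp(sj - m_new) @ vj) / l_new
        m, l = m_new, l_new
    return o


def attention_v2_style(s, v, block):
    """Keep the accumulator unnormalized; divide by the EXP sum once at the end."""
    m, l, acc = -np.inf, 0.0, np.zeros(v.shape[1])
    for j in range(0, len(s), block):
        sj, vj = s[j:j + block], v[j:j + block]
        m_new = max(m, sj.max())
        alpha = np.exp(m - m_new)      # only the max-rescaling remains per block
        p = np.exp(sj - m_new)
        l = l * alpha + p.sum()        # the running sum is still tracked...
        acc = acc * alpha + p @ vj
        m = m_new
    return acc / l                     # ...but divided by only once


rng = np.random.default_rng(0)
s = rng.standard_normal(12)            # one row of attention scores
v = rng.standard_normal((12, 4))
p_ref = np.exp(s - s.max())
ref = (p_ref / p_ref.sum()) @ v
```

Both functions return the same vector as `ref`. V2 simply drops one division of the whole output row per block.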
IO
After the loop order is swapped, the inner loop no longer has to read and write $O_i$, $\ell_i$, and $m_i$ from and to HBM on every iteration, as FA1 does. These values stay on-chip for the whole inner loop and are written back once, which reduces HBM accesses and therefore runtime.
V2 Overall
Let’s summarize with another overall diagram of V2.

1.7 Problem
FlashAttention-2 uses online softmax to split the attention computation of a single query block into work blocks. Each work block consists of a key block and its matching value block, and the work blocks arrive one after another to update the attention output of the given query block. For each incoming work block, FlashAttention-2 computes the online softmax, rescales the intermediate output accumulated from the previous work blocks, and adds the current work block's contribution to obtain the updated output. However, this exact attention computation is inherently sequential. That makes it slow in the decoding phase, especially when many key/value blocks must be traversed.
1.8 Implementation
Here we study a V2 implementation to see how these ideas translate into code.
Kernel fusion
Ultimately, FlashAttention performs the whole attention operation in a single kernel. It loads the inputs from HBM, performs all computation (matrix multiplication, masking, softmax, dropout, matrix multiplication) in SRAM, and writes the result back to HBM. Because these operations are fused into one kernel, the intermediate S and P matrices never have to be stored, which avoids repeatedly reading and writing them from and to HBM.

Triton implementation
Phil Tillet first proposed and implemented ideas such as swapping the loop order (outer loop on row blocks and inner loop on column blocks, instead of the reverse order in the original FlashAttention paper) and parallelization in the sequence length dimension in his Triton implementation.
Note: the FlashAttention V1 algorithm uses an outer loop over the key/value (kv) dimension and an inner loop over the query (q) dimension. The Triton implementation instead uses an outer loop over q and an inner loop over kv.
V2 swaps the loop order so that the iterations of the outer loop are independent and can be assigned to different thread blocks to run in parallel. The batch, head, and sequence loops can then all be partitioned across thread blocks, which significantly increases GPU throughput. The backward pass follows the same principle: the dimension being reduced over must stay in the inner, sequential loop rather than be split across thread blocks. The forward pass reduces over keys/values (softmax and P·V), while dK and dV in the backward pass reduce over queries, so the loop order differs between the forward and backward passes.
Basic idea
The computation in FlashAttention V2 works as follows. In the inner loop, a Q block is combined with one K block and one V block at a time to produce a partial result. These partial results are accumulated, with rescaling, into an output of the same shape as the Q block. The pseudocode is as follows.
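flash_attention_2():
    # outer loop: Q blocks are independent, so they run in parallel
    parallel do q[NUM_BLOCK_M]:
        # inner loop: K/V blocks are processed sequentially
        for i in range(NUM_BLOCK_N):
            qk = q @ k[i].T
            score = online_softmax(qk)  # exp(qk - running max); also updates the running max and sum
            out += score @ v[i]         # out is first rescaled by exp(old max - new max)
        rescale(out)                    # divide by the final sum of exponentials, once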
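The pseudocode can be turned into a runnable NumPy sketch. The outer loop runs serially here, although on the GPU each Q block would go to its own thread block. The block sizes are arbitrary and deliberately do not divide the sequence length. The names `flash_attention_2` and `reference_attention` are illustrative, not taken from any library.

```python
import numpy as np


def flash_attention_2(q, k, v, block_m=4, block_n=3):
    seq, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty_like(q)
    # Outer loop: Q blocks are independent (one thread block each on the GPU).
    for i in range(0, seq, block_m):
        qi = q[i:i + block_m]
        m = np.full(len(qi), -np.inf)            # running row max
        l = np.zeros(len(qi))                    # running row sum of exp
        acc = np.zeros((len(qi), v.shape[1]))    # unnormalized output
        # Inner loop: walk over the K/V blocks sequentially.
        for j in range(0, k.shape[0], block_n):
            qk = qi @ k[j:j + block_n].T * scale
            m_new = np.maximum(m, qk.max(axis=1))
            p = np.exp(qk - m_new[:, None])
            alpha = np.exp(m - m_new)            # correction for the previous blocks
            l = l * alpha + p.sum(axis=1)
            acc = acc * alpha[:, None] + p @ v[j:j + block_n]
            m = m_new
        out[i:i + block_m] = acc / l[:, None]    # rescale once per Q block
    return out


def reference_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
out = flash_attention_2(q, k, v)
```

The tiled result matches `reference_attention(q, k, v)` even though no full score matrix is ever formed.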
In the code, `_attention` sets up the parallel launch and calls the kernels. `_attn_fwd` works out which data the current program (thread block) should access, and `_attn_fwd_inner` performs the actual attention computation.
Thread Model
The attention of a single head computes softmax(q @ k.T) @ v, where q, k, and v each have shape [seqlen, headdim] (the scale factor is omitted here).
The parallel version splits the computation along the q axis. Each program (thread block) computes single-head attention for Block_M query tokens, i.e. a [Block_M, headdim] slice. If the input shape is [bs, head, seqlen, headdim], the total number of programs is bs × head × seqlen / Block_M. Parallelism therefore spans both the bs × head dimensions and the seqlen dimension.
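The mapping from program ids to work can be mirrored in plain Python. This sketch copies the grid shape used in `_attention.forward` and the index arithmetic in `_attn_fwd` (`start_m`, `off_hz // H`, `off_hz % H`). The shapes (batch 2, 3 heads, 512 tokens, BLOCK_M = 128) are hypothetical.

```python
import math


def launch_grid(batch, heads, seqlen, block_m):
    """Same shape as grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)."""
    return (math.ceil(seqlen / block_m), batch * heads, 1)


def program_to_work(pid0, pid1, heads, block_m):
    """Same decoding as start_m = program_id(0); off_z = off_hz // H; off_h = off_hz % H."""
    off_z, off_h = pid1 // heads, pid1 % heads
    return off_z, off_h, (pid0 * block_m, (pid0 + 1) * block_m)


grid = launch_grid(batch=2, heads=3, seqlen=512, block_m=128)
n_programs = grid[0] * grid[1] * grid[2]
work = {program_to_work(p0, p1, heads=3, block_m=128)
        for p0 in range(grid[0]) for p1 in range(grid[1])}
print(grid, n_programs, len(work))   # (4, 6, 1) 24 24
```

Each of the bs × head × seqlen / Block_M = 24 programs gets a distinct (batch, head, query-row range) triple, so no two programs write to the same output rows.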
class _attention
The `_attention` class subclasses `torch.autograd.Function` to implement Flash Attention as a custom operator with its own forward and backward passes.
import torch
import triton


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal, sm_scale):
        # shape constraints
        # q, k, v have shape [B, H, S, D]; index -1 is the last dimension, D_HEAD (the head dimension).
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        # Allocate the output.
        o = torch.empty_like(q)
        # Tile size of Q along S: each program processes a q tile of shape [1, 1, BLOCK_M, D].
        BLOCK_M = 128  # block size of the Q and O matrices
        # Tile size of K and V along S in the inner loop, i.e. the granularity of KV tiling.
        BLOCK_N = 64 if Lk <= 64 else 32  # tile size of the K and V matrices
        # num_stages configures pipelining based on the asynchronous copy feature introduced with A100;
        # roughly, it is the prefetch depth, i.e. how many tiles are kept in the buffer.
        num_stages = 4 if Lk <= 64 else 3
        # Each program uses 4 warps, i.e. 4 x 32 threads.
        num_warps = 4
        stage = 3 if causal else 1
        # Tuning for H100
        if torch.cuda.get_device_capability()[0] == 9:
            num_warps = 8
            num_stages = 7 if Lk >= 64 else 3
        # 2-D grid with triton.cdiv(q.shape[2], BLOCK_M) * q.shape[0] * q.shape[1] programs.
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        # Per-row softmax statistic of S (row max plus log2 of the row sum), saved for the backward pass.
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
            q, k, v, sm_scale, M, o,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),  #
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),  #
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),  #
            q.shape[0], q.shape[1],  #
            N_CTX=q.shape[2],  #
            BLOCK_M=BLOCK_M,  #
            BLOCK_N=BLOCK_N,  #
            BLOCK_DMODEL=Lk,  # head size
            STAGE=stage,  #
            num_warps=num_warps,  # warps per _attn_fwd program
            num_stages=num_stages  #
        )
        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, M = ctx.saved_tensors
        assert do.is_contiguous()
        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        BATCH, N_HEAD, N_CTX = q.shape[:3]
        PRE_BLOCK = 128
        NUM_WARPS, NUM_STAGES = 4, 5
        BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
        BLK_SLICE_FACTOR = 2
        RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
        # Pre-scale K by sm_scale / ln(2) so that the backward kernel can work with exp2.
        arg_k = k
        arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
        assert N_CTX % PRE_BLOCK == 0
        pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
        delta = torch.empty_like(M)
        _attn_bwd_preprocess[pre_grid](
            o, do,  #
            delta,  #
            BATCH, N_HEAD, N_CTX,  #
            BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL  #
        )
        grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
        _attn_bwd[grid](
            q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,  #
            M, delta,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            N_HEAD, N_CTX,  #
            BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1,  #
            BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            BLOCK_DMODEL=ctx.BLOCK_DMODEL,  #
            num_warps=NUM_WARPS,  #
            num_stages=NUM_STAGES  #
        )
        return dq, dk, dv, None, None
The inputs for `_attention` can be created as follows. Z, H, N_CTX, and D_HEAD are the batch size, number of heads, sequence length, and head dimension, respectively, and all of them are packed into the 4-D q, k, and v tensors. The operator itself is invoked through `_attention.apply(q, k, v, causal, sm_scale)`.
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_()
_attn_fwd
_attn_fwd is the Triton kernel for the forward pass. Each program (thread block) owns one Q block of one head: it computes that block's attention output against all relevant K/V blocks and writes it to the output matrix. The logic is as follows:
- Use `tl.program_id(0)` (start_m) to find which Q block to process, and `tl.program_id(1)` (off_hz) to find the batch (off_hz // H) and the head (off_hz % H). From these and the strides, compute the offset of this head's [N_CTX, D] slice.
- Build block pointers with `tl.make_block_ptr` for the Q and O tiles at rows [start_m * BLOCK_M, (start_m + 1) * BLOCK_M), for K as a transposed [D, N_CTX] view, and for V.
- Initialize the running row max m_i to -inf, the running row sum l_i, and the float32 accumulator acc to zero. Fold log2(e) into the scale so that exp2 can replace exp.
- Load the Q tile once; it stays on-chip for the whole inner loop. Then call `_attn_fwd_inner`. For non-causal attention it covers all K/V blocks. For causal attention it first covers the blocks before the diagonal (no mask) and then the diagonal block (with the mask).
- Divide acc by l_i once, store m_i + log2(l_i) in M for the backward pass, and write the output tile.
Because each program handles its own Q block, the programs run in parallel without communicating with each other.
The specific code is as follows.
"""
# Another trick we can use is to ask the compiler to use more threads per row by
# increasing the number of warps (`num_warps`) over which each row is distributed.
# You will see in the next tutorial how to auto-tune this value in a more natural
# way so you don't have to come up with manual heuristics yourself.
"""
import triton
import triton.language as tl


@triton.jit
def _attn_fwd(Q, K, V, sm_scale, M, Out,  #
              stride_qz, stride_qh, stride_qm, stride_qk,  # stride_qz is the batch stride; it lets programs cover different batches in parallel
              stride_kz, stride_kh, stride_kn, stride_kk,  # note: for K the k/n naming is the reverse of V's
              stride_vz, stride_vh, stride_vk, stride_vn,  # note: for V the k/n naming is the reverse of K's
              stride_oz, stride_oh, stride_om, stride_on,  #
              Z, H,  #
              N_CTX: tl.constexpr,  #
              BLOCK_M: tl.constexpr,  #
              BLOCK_DMODEL: tl.constexpr,  #
              BLOCK_N: tl.constexpr,  #
              STAGE: tl.constexpr  #
              ):
    # Goal: work out which data this program (thread block) operates on.
    # program_id(0) is this program's coordinate on the first grid axis, i.e. the outer loop. It gives the
    # position of this program's tile along q's S dimension: start_m * BLOCK_M.
    start_m = tl.program_id(0)  # outer loop of the algorithm: which block of Q
    # The second grid axis has Z * H entries (Z denotes the B dimension here), so its coordinate
    # identifies the batch and the head. The next three lines derive this program's batch/head offsets.
    off_hz = tl.program_id(1)
    off_z = off_hz // H  # batch offset
    off_h = off_hz % H  # head offset
    # Offset of the current head's [S, D] tensor: stride_qz selects the batch and stride_qh the head,
    # which is how batch and head are parallelized across programs.
    qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    # Block pointer to rows [start_m * BLOCK_M, (start_m + 1) * BLOCK_M) of that [S, D] tensor
    # (BLOCK_DMODEL = D): program start_m loads a Q sub-tensor of shape [BLOCK_M, BLOCK_DMODEL].
    # order=(1, 0) means row-major access.
    Q_block_ptr = tl.make_block_ptr(  # build a block pointer
        base=Q + qvk_offset,  # start of this head's data
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),  # Q is indexed by the outer loop, as in the algorithm
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=v_order,
    )
    # K is read transposed, as a [BLOCK_DMODEL, N_CTX] view
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),  # transposed
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),  # outer loop: start_m gives this program's offset along q
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    # tl.arange creates a contiguous range of integers starting at 0, similar to Python's range.
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running max, initialized to -inf
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0  # running sum
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)  # output accumulator o
    # load scales
    qk_scale = sm_scale
    qk_scale *= 1.44269504  # 1/log(2)
    # load q: it will stay in SRAM throughout
    # Each program uses its whole Q sub-tensor [BLOCK_M, BLOCK_DMODEL] for the entire inner loop.
    q = tl.load(Q_block_ptr)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, BLOCK_DMODEL, BLOCK_N,  #
                                        4 - STAGE, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # stage 2: on-band
    if STAGE & 2:
        # barrier makes it easier for the compiler to schedule the
        # two loops independently
        tl.debug_barrier()
        acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr,  #
                                        start_m, qk_scale,  #
                                        BLOCK_M, BLOCK_DMODEL, BLOCK_N,  #
                                        2, offs_m, offs_n, N_CTX, V.dtype.element_ty == tl.float8e5  #
                                        )
    # Epilogue
    # Step 13 of the algorithm: base-2 log-sum-exp per row
    m_i += tl.math.log2(l_i)
    # Step 12 of the algorithm: normalize once
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    # Write the results back.
    # Step 15 of the algorithm
    tl.store(m_ptrs, m_i)
    # Step 14 of the algorithm
    tl.store(O_block_ptr, acc.to(Out.type.element_ty))
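Storing `M = m_i + log2(l_i)` (base-2 log-sum-exp) is enough to rebuild every softmax probability later with one `exp2` per element, without recomputing the row max or sum. We take this to be the reason the value is saved for the backward pass. The NumPy sketch below checks the identity. Its shapes are hypothetical, and it uses the same `1.44269504` constant as the kernel.

```python
import numpy as np

LOG2E = 1.44269504   # 1 / ln(2), as in qk_scale *= 1.44269504

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((6, 8))
sm_scale = 1.0 / np.sqrt(8)

# Forward: base-2 exponentials, as the kernel does with exp2.
qk = q @ k.T * (sm_scale * LOG2E)
m = qk.max(axis=1)
l = np.exp2(qk - m[:, None]).sum(axis=1)
M = m + np.log2(l)                       # what the kernel stores: m_i += log2(l_i)

# Recomputation from M alone: one exp2 per element.
p_recomputed = np.exp2(qk - M[:, None])

# Reference softmax with natural exponentials.
s = q @ k.T * sm_scale
p_ref = np.exp(s - s.max(axis=1, keepdims=True))
p_ref /= p_ref.sum(axis=1, keepdims=True)
```

Since $2^{M} = 2^{m}\,\ell = \sum_j 2^{qk_j} = \sum_j e^{s_j}$, subtracting M in the exponent divides by the softmax denominator directly.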
_attn_fwd_inner
The `_attn_fwd_inner()` function is where the attention is actually computed. Program start_m holds a Q sub-tensor [BLOCK_M, BLOCK_DMODEL] and multiplies it in turn with N_k sub-tensors [BLOCK_DMODEL, BLOCK_N] of K. Each product [BLOCK_M, BLOCK_N] goes through the online softmax and is multiplied with the matching V sub-tensor [BLOCK_N, BLOCK_DMODEL]. This gives a contribution of shape [BLOCK_M, BLOCK_DMODEL] to O, so the O tile is the sum of N_k rescaled partial products. How many K/V blocks are visited depends on STAGE. In the causal case, stage 1 covers the N_k = start_m × BLOCK_M / BLOCK_N blocks before the diagonal, so the concatenated scores have shape [BLOCK_M, start_m × BLOCK_M]. Stage 2 then covers the diagonal block with masking. In the non-causal case, all N_CTX keys are covered.
The specific code is as follows, annotated according to the V2 process.
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    K_block_ptr, V_block_ptr,  #
                    start_m, qk_scale,  #
                    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,  #
                    STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,  #
                    N_CTX: tl.constexpr, fp8_v: tl.constexpr):  # fp8_v matches the extra argument passed by _attn_fwd
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    # Move the block pointers to the start of this stage's range.
    K_block_ptr = tl.advance(K_block_ptr, (0, lo))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # loop over k, v and update accumulator
    # Stage 1 covers [0, start_m * BLOCK_M).
    # Step 6 of the algorithm: the inner loop.
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute score = QK^T ----
        # Step 7: load K_j into SRAM; k has shape [BLOCK_DMODEL, BLOCK_N]
        k = tl.load(K_block_ptr)
        # qk has shape [BLOCK_M, BLOCK_N]
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        # Step 8
        qk += tl.dot(q, k)
        # Step 9
        if STAGE == 2:
            # On the diagonal block, mask out keys that come after the query.
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            m_ij = tl.maximum(m_i, tl.max(qk, 1))  # new running max: row-wise max along the last dimension
            qk -= m_ij[:, None]
        else:
            # new running max m_ij
            m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
        p = tl.math.exp2(qk)  # exponentials (base 2, since log2(e) is folded into qk_scale)
        # row sums of this block: l_ij
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        # correction factor alpha for the previous statistics
        alpha = tl.math.exp2(m_i - m_ij)
        # rescale the running sum l_i and add this block's sum
        l_i = l_i * alpha + l_ij
        # Step 10
        # -- update output accumulator --
        # rescale the accumulated O sub-tensor
        acc = acc * alpha[:, None]
        # Step 7: load V_j into SRAM
        v = tl.load(V_block_ptr)
        # p @ V, with p cast to the element type of V
        if fp8_v:
            p = p.to(tl.float8e5)
        else:
            p = p.to(tl.float16)
        acc += tl.dot(p, v)
        # update m_i
        m_i = m_ij
        # advance the K and V block pointers
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i
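To see that the two stages together yield exact causal attention, here is a NumPy mirror of `_attn_fwd` and `_attn_fwd_inner`. It is a sketch: pointer arithmetic is replaced by slicing, and the small sizes (N_CTX = 12, BLOCK_M = 4, BLOCK_N = 2) are hypothetical. The stage ranges, the `-1.0e6` mask, the `exp2` trick, and the `l_i = 1.0` initialization follow the kernel.

```python
import numpy as np

LOG2E = 1.44269504   # same constant the kernel folds into qk_scale


def attn_fwd_inner(acc, l_i, m_i, q, k, v, start_m, qk_scale, BLOCK_M, BLOCK_N, STAGE, offs_m):
    """NumPy mirror of _attn_fwd_inner; N_CTX is len(k)."""
    if STAGE == 1:                   # causal, off-diagonal: keys before this Q block
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:                 # causal, diagonal block: needs the mask
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
    else:                            # non-causal: all keys
        lo, hi = 0, len(k)
    for start_n in range(lo, hi, BLOCK_N):
        qk = q @ k[start_n:start_n + BLOCK_N].T
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + np.arange(BLOCK_N))[None, :]
            qk = qk * qk_scale + np.where(mask, 0.0, -1.0e6)
            m_ij = np.maximum(m_i, qk.max(axis=1))
            qk -= m_ij[:, None]
        else:
            m_ij = np.maximum(m_i, qk.max(axis=1) * qk_scale)
            qk = qk * qk_scale - m_ij[:, None]
        p = np.exp2(qk)
        alpha = np.exp2(m_i - m_ij)  # rescales what was accumulated so far
        l_i = l_i * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v[start_n:start_n + BLOCK_N]
        m_i = m_ij
    return acc, l_i, m_i


def attn_fwd_block(q, k, v, start_m, sm_scale, BLOCK_M, BLOCK_N, causal):
    """NumPy mirror of one _attn_fwd program (one Q block of one head)."""
    offs_m = start_m * BLOCK_M + np.arange(BLOCK_M)
    m_i = np.full(BLOCK_M, -np.inf)
    l_i = np.ones(BLOCK_M)       # starts at 1.0 as in the kernel; alpha = 0 on the first block clears it
    acc = np.zeros((BLOCK_M, v.shape[1]))
    qk_scale = sm_scale * LOG2E
    STAGE = 3 if causal else 1
    qb = q[offs_m]
    if STAGE & 1:
        acc, l_i, m_i = attn_fwd_inner(acc, l_i, m_i, qb, k, v, start_m, qk_scale,
                                       BLOCK_M, BLOCK_N, 4 - STAGE, offs_m)
    if STAGE & 2:
        acc, l_i, m_i = attn_fwd_inner(acc, l_i, m_i, qb, k, v, start_m, qk_scale,
                                       BLOCK_M, BLOCK_N, 2, offs_m)
    return acc / l_i[:, None]


def reference(q, k, v, sm_scale, causal):
    s = q @ k.T * sm_scale
    if causal:
        s = np.where(np.tril(np.ones(s.shape, dtype=bool)), s, -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


rng = np.random.default_rng(0)
N_CTX, D, BLOCK_M, BLOCK_N = 12, 8, 4, 2
q, k, v = (rng.standard_normal((N_CTX, D)) for _ in range(3))
sm_scale = 1.0 / np.sqrt(D)
out = {c: np.concatenate([attn_fwd_block(q, k, v, s, sm_scale, BLOCK_M, BLOCK_N, c)
                          for s in range(N_CTX // BLOCK_M)])
       for c in (False, True)}
```

For the first Q block (start_m = 0), stage 1 runs zero iterations, so only the masked diagonal pass contributes. The other blocks combine unmasked and masked passes, and all of them match the reference.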
0x02 Flash-Decoding
Although FlashAttention-2 is about 2x faster than FlashAttention, it mainly benefits the prefill stage of inference. It ignores the different behavior of attention in the decoding phase and therefore leaves much of the GPU idle during decoding. Furthermore, because it lacks support for tensor parallelism, vanilla FlashAttention-2 is also ill-suited to multi-GPU scenarios.
Modern large-scale language models require an attention mechanism that can scale well in multi-GPU scenarios to effectively support increasingly longer context lengths. To improve the computational speed of attention during the inference phase, the authors of FlashAttention proposed FlashDecoding, whose blog post can be found at: [ https://crfm.stanford.edu/2023/10/12/flashdecoding.html ]. FlashDecoding primarily accelerates LLM inference. For cases where the sequence length of Q is 1, it implements block parallelism in the key-value (K/V) direction to improve GPU utilization and thus achieve acceleration. FlashDecoding shows significant speedup effects with small batch sizes and large sequence lengths, and its performance is not sensitive to increases in sequence length.
2.1 Current Situation
The inference process in LLM essentially involves two distinct computational phases.
- The first stage is the prompt computation stage (sometimes called the pre-filling stage). In this stage, all tokens from the input prompts are forward-propagated through the model to generate the first output token. This stage is computationally intensive and requires high FLOPS/s.
- The second phase is the decoding phase (sometimes called the token generation phase). This phase begins in an autoregressive manner, where each subsequent token is generated based on the forward propagation result of the previous token and the previous KV-Cache in the sequence. As the context length increases, this cached context can become very long. Processing such long contexts sequentially slows down the decoding phase and is limited by memory bandwidth and capacity.
The diagram below summarizes the three operations involved in self-attention, as well as the corresponding dimensions involved in the decoding and pre-filling stages.

Researchers have proposed mechanisms such as the KV-Cache and FlashAttention to meet the low-latency requirements of LLMs, but these techniques do not account for the different computational characteristics of the two inference phases.
The forward pass of FlashAttention V2 is parallelized along the batch_size and seqlen dimensions of Q. As shown in the diagram, for the current block of queries, a thread block iterates over all K and V blocks and computes a local attention output for each. At every iteration the accumulated output is rescaled using the running softmax statistics, so the output is exact once the iteration over K and V is complete.

This works well for training, where seqlen or bs is large enough to keep the GPU busy. During generation, however, tokens are produced one at a time, so each decoding step has only one query token and there is nothing to parallelize along the query dimension. If bs is also small, GPU resources are badly underused. Specifically, when the batch size is smaller than the number of streaming multiprocessors (SMs) on the GPU (108 on an A100), the attention operation uses only a small fraction of the GPU, and the problem is worse for longer contexts.
2.2 Scheme
To address this issue, the authors of FlashAttention developed FlashDecoding to optimize the forward pass during inference. The basic idea is intuitive. In decoding, query_num = 1 and the batch size may be small, so too few thread blocks are launched. Instead of adding blocks along the query dimension, could we add them along the key/value sequence dimension?
Following this approach, Flash-Decoding adds a new parallelization dimension to FlashAttention V2’s parallelism of batch size and query length: the sequence length of keys/values. This new concurrency reduces latency while increasing hardware utilization, but requires additional final reduction costs.
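A back-of-the-envelope count shows why the extra dimension matters. This sketch assumes that each (batch, head, Q block) in FlashAttention V2, and each (batch, head, K/V chunk) in Flash-Decoding, becomes one thread block. The numbers (1 sequence, 16 heads, a 32K-token cache, 4096-token chunks) are hypothetical.

```python
import math

NUM_SMS_A100 = 108   # SM count of an A100, as quoted above


def fa2_decode_blocks(batch, heads, q_len=1, block_m=128):
    """FlashAttention V2: parallel over batch, heads and Q blocks only."""
    return batch * heads * math.ceil(q_len / block_m)


def flash_decoding_blocks(batch, heads, kv_len, split_size):
    """Flash-Decoding additionally splits the K/V sequence into chunks of split_size."""
    return batch * heads * math.ceil(kv_len / split_size)


fa2 = fa2_decode_blocks(batch=1, heads=16)                                     # 16
fd = flash_decoding_blocks(batch=1, heads=16, kv_len=32768, split_size=4096)   # 128
print(fa2 < NUM_SMS_A100, fd >= NUM_SMS_A100)                                  # True True
```

With a single query token, FlashAttention V2 fills only 16 of 108 SMs in this setting, while splitting the K/V sequence 8 ways yields enough thread blocks to cover them all.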
Flash Decoding mainly includes the following three steps:
- Split the keys/values into smaller chunks so that they can be processed concurrently. The chunks need not be physically separated: they remain views of the full key/value tensors, so this step involves no GPU work.
- Process the chunks in parallel, using standard FlashAttention to compute the attention between the query and each chunk. For each row of each chunk's output (one row per query), Flash-Decoding also records one extra scalar: the log-sum-exp of that row's attention scores.
- Finally, reduce the results of all chunks into the final output, using the log-sum-exp values to rescale each chunk's contribution. This works because the softmax-weighted sum over all keys decomposes into a sum over chunks.
Only steps 2 and 3 require kernels, one kernel each. Although the final reduction adds some computation, Flash-Decoding is faster overall because of the extra parallelism.
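The three steps can be sketched in NumPy for a single decoding query. The function names and sizes are hypothetical, and each chunk is processed in a plain loop rather than by its own thread block. The point is that the log-sum-exp per chunk is all the reduction needs.

```python
import numpy as np


def partial_attention(q, k_chunk, v_chunk):
    """Step 2: attention of one query over one K/V chunk, plus that chunk's log-sum-exp."""
    s = k_chunk @ q / np.sqrt(q.shape[0])
    m = s.max()
    p = np.exp(s - m)
    lse = m + np.log(p.sum())              # log(sum(exp(s))) for this chunk
    return (p / p.sum()) @ v_chunk, lse


def reduce_splits(outs, lses):
    """Step 3: weight each chunk's output by exp(lse_i - global_lse) and sum."""
    lses = np.asarray(lses)
    global_lse = lses.max() + np.log(np.exp(lses - lses.max()).sum())
    weights = np.exp(lses - global_lse)    # the weights sum to 1
    return sum(w * o for w, o in zip(weights, outs))


def flash_decoding(q, k, v, n_splits):
    # Step 1: split K/V into chunks along the sequence dimension (slices of the originals).
    chunks = zip(np.array_split(k, n_splits), np.array_split(v, n_splits))
    outs, lses = zip(*(partial_attention(q, kc, vc) for kc, vc in chunks))
    return reduce_splits(outs, lses)


rng = np.random.default_rng(0)
q = rng.standard_normal(16)                # a single decoding query
k = rng.standard_normal((37, 16))          # 37 cached keys
v = rng.standard_normal((37, 16))
out = flash_decoding(q, k, v, n_splits=5)

s = k @ q / np.sqrt(16)
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ v
```

Each chunk's normalized output is re-weighted by its share of the total exponential mass, $e^{\mathrm{lse}_i - \mathrm{lse}}$, which recovers the full softmax exactly.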

The diagram compares Flash-Decoding with FlashAttention V2, assuming 2 heads, a batch size of 1, and 5 SMs (streaming multiprocessors). Each thread block works on a single head, for example head 0 or head 1 but not both. With a batch size of 1, FlashAttention V2 can therefore launch only 2 thread blocks, whereas Flash-Decoding can launch 4.

2.3 Discussion
FlashAttention parallelizes over batch size and query length, while Flash-Decoding adds a new dimension of parallelism: the sequence length of keys/values. Even with a small batch size, it can keep the GPU fully utilized as long as the context is long enough. Like FlashAttention, Flash-Decoding stores very little extra data in global memory, so its memory overhead stays small.
FlashDecoding has two potential inefficiencies.
- The kernel needs to be started twice. The first kernel calculates the partial attention results of the query and some keys and some values for each block. The second kernel mainly corrects and reduces the partial attention results from the first kernel.
- In the first kernel, the degree of parallelism along the sequence dimension is fixed, so long and short sequences are split into the same number of blocks. Blocks of long sequences therefore take much longer than blocks of short ones, which leads to load imbalance.
FlashDecoding++ (not by Tri Dao) modifies FlashDecoding by eliminating synchronization costs through approximation of the global maximum in softmax, thus avoiding eventual rescaling. FlashDecoding++ avoids calculating intermediate local softmaxes within FlashDecoding’s inner loop; it calculates the final global softmax once the algorithm can determine all partial exponential sums. Furthermore, FlashDecoding++ uses double buffering to hide memory access latency.
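The idea of a shared maximum can be shown with a simplified NumPy sketch. Every chunk subtracts the same fixed value φ (`phi`) instead of its own maximum. Because all chunks then use the same reference, their partial numerators and denominators can simply be added, with no rescaling step. Here φ is passed in by hand, and the sketch ignores what happens when scores stray far from φ (overflow or underflow of `exp`), which a real implementation must handle. The function name is hypothetical.

```python
import numpy as np


def unified_max_attention(q, k, v, n_splits, phi):
    """Every chunk subtracts the same fixed phi instead of its own max, so the
    partial numerators and denominators can be added directly: no rescaling."""
    num = np.zeros(v.shape[1])
    den = 0.0
    for kc, vc in zip(np.array_split(k, n_splits), np.array_split(v, n_splits)):
        p = np.exp(kc @ q / np.sqrt(q.shape[0]) - phi)   # no per-chunk max
        num += p @ vc                                    # plain sums are enough
        den += p.sum()
    return num / den


rng = np.random.default_rng(0)
q = rng.standard_normal(16)
k = rng.standard_normal((37, 16))
v = rng.standard_normal((37, 16))
s = k @ q / np.sqrt(16)
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ v
out = unified_max_attention(q, k, v, n_splits=5, phi=0.0)
```

Any φ gives the same result mathematically, because it cancels between numerator and denominator. Its only job is to keep the exponentials in a safe numeric range.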
Despite these improvements, FlashDecoding and FlashDecoding++ still balance load suboptimally. They require launching an additional reduction kernel, so they pay kernel launch overhead as well as a reduction or correction cost that grows with problem size.
0x03 Flash-Mask
With the rapid development of artificial intelligence technology, large-scale models, represented by Transformer, have demonstrated extraordinary capabilities in natural language processing, computer vision, and multimodal applications. In these large-scale models, the attention mechanism is a crucial component. To determine which query-key tokens require effective attention computation during large-scale model training tasks, the industry typically uses attention masks. However, current attention masks are usually represented using two-dimensional dense matrices, which leads to several problems. On the one hand, this representation introduces a large amount of redundant computation because the attention of many invalid tokens still needs to be calculated; on the other hand, its huge storage consumption makes it difficult to achieve efficient training for long sequence scenarios, hindering efficient training.
While existing methods like FlashAttention offer computational acceleration for specific attention masks, their support for attention mask patterns is limited, making it difficult to meet the demands of large model training tasks for flexible attention masks. To address this issue, PaddlePaddle pioneered the FlashMask technology, proposing a columnar sparse attention mask representation method that supports a wide variety of flexible attention mask patterns. This reduces storage complexity, and on this basis, an efficient operator kernel is implemented with a linear memory access complexity of O(N). This significantly accelerates the training efficiency of large models, especially in long sequence scenarios.
3.1 Motivation
FlashMask can be understood as an extension of FA (FlashAttention). FA addresses the quadratic growth in compute and memory that traditional attention faces on long sequences. This growth is a significant challenge for Transformer models on any hardware, especially in LLM training with long sequences. Specifically, FA reduces attention latency through I/O-aware memory optimization and removes the quadratic memory dependency on sequence length. However, in the training scenario above, FA has two shortcomings:
- Native support for certain attention mask types is limited, and it doesn’t naturally adapt to more complex mask requirements. As shown in the pink area at the top of the diagram, FlashAttention only supports a few fixed mask types, such as causal masks, sliding window masks, causal document masks, and document masks. However, the attention mask types used in actual training tasks are often diverse, and current technology struggles to meet the flexibility requirements of large models for different training tasks.
- Previous methods represent masks as dense N × N matrices. Their O(N²) memory footprint and the extra memory traffic reduce efficiency and limit the maximum supported context length.

3.2 Approach
FlashMask's core observation is that in the attention mask patterns common in large models, the query-key masking pattern exhibits a certain continuity: for each key token, the query tokens that are masked out (whose attention is not computed) are contiguous. In other words, in the two-dimensional mask matrix shown above, the masked entries of each column form continuous runs along the column direction. Based on this insight, FlashMask transforms the two-dimensional dense mask matrix into one-dimensional row index ranges. This representation is far more compact and significantly reduces storage requirements. It can be formalized as:
$$M_{:,j} = [\mathit{start}_j, \mathit{end}_j), \quad \forall j \in \{1, \dots, N\}$$
Here N is the sequence length of the keys, $M_{:,j}$ is the j-th column of the two-dimensional dense mask matrix, and $[\mathit{start}_j, \mathit{end}_j)$ is a contiguous range of row indices. These consecutive query tokens are masked and excluded from the attention computation.
To efficiently handle complex mask patterns in causal and bidirectional attention scenarios, FlashMask proposes a novel columnar sparse representation method. Using the diagonal as a delimiter, it uses four one-dimensional vectors to represent the mask:
- Lower Triangular Start (LTS) index.
- Lower Triangular End (LTE) index.
- Upper Triangular Start (UTS) index.
- Upper Triangular End (UTE) index.
The row index range that is masked in the lower triangle is represented by [𝐿𝑇𝑆, 𝐿𝑇𝐸), and the row index range that is masked in the upper triangle is represented by [𝑈𝑇𝑆, 𝑈𝑇𝐸).
Those familiar with sparse matrices know that a sparse matrix can usually be represented by a few one-dimensional arrays instead of a two-dimensional tensor, which is an important source of the benefits of sparsity. FlashMask applies the same idea: for each token in K, four vectors record which q tokens are masked in the lower-left and upper-right regions. LT describes the masking in the lower-left region and UT the masking in the upper-right region. Take pattern (6) as an example, where q and k each have 10 tokens and we record, for each k token, which q tokens are masked. For token number 5, the masked (gray) cells are circled in red in the figure below, so [LTS, LTE) = [7, 10) and [UTS, UTE) = [2, 4).
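A small NumPy sketch makes the representation concrete by expanding the four vectors back into a dense boolean mask. Indices are 0-based, and an empty range is written as [N, N) for LT and [0, 0) for UT. These conventions and the function name are assumptions for illustration only. FlashMask itself never builds the dense matrix.

```python
import numpy as np


def dense_mask_from_columns(lts, lte, uts, ute, n_q):
    """Expand the four column vectors into a boolean [n_q, n_k] mask (True = masked).
    Column j masks query rows [lts[j], lte[j]) and [uts[j], ute[j])."""
    rows = np.arange(n_q)[:, None]
    lts, lte, uts, ute = (np.asarray(a)[None, :] for a in (lts, lte, uts, ute))
    return ((rows >= lts) & (rows < lte)) | ((rows >= uts) & (rows < ute))


n = 10
# Empty ranges everywhere: [n, n) in the lower triangle and [0, 0) in the upper one.
lts, lte = np.full(n, n), np.full(n, n)
uts, ute = np.zeros(n, dtype=int), np.zeros(n, dtype=int)
# Token 5 from the example: [LTS, LTE) = [7, 10) and [UTS, UTE) = [2, 4).
lts[5], lte[5], uts[5], ute[5] = 7, 10, 2, 4
mask = dense_mask_from_columns(lts, lte, uts, ute, n_q=n)
print(np.flatnonzero(mask[:, 5]))   # [2 3 7 8 9]

# A causal mask in this format: column j masks rows [0, j) in the upper triangle.
causal = dense_mask_from_columns(np.full(n, n), np.full(n, n), np.zeros(n, int), np.arange(n), n)
```

Four length-N integer vectors replace an N × N matrix, which is where the reduction from quadratic to linear storage comes from.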

3.3 Algorithm
FlashMask integrates columnar mask representation into the FlashAttention-2 algorithm, enhancing its support for attention masks. Building upon the block-based computation of the FlashAttention Kernel, FlashMask utilizes mask vectors such as LTS to determine the mask type of the current block.
- Fully masked blocks: All elements of this type of block are masked and can be skipped during calculation.
- Partially masked blocks: These blocks have only some elements masked, so element-by-element masking is required.
- Unmasked blocks: All elements in this type of block are not masked, which simplifies the calculation process and eliminates the need for additional masking operations.
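The classification can be sketched with interval arithmetic on the column vectors. This is a simplification. As we understand the paper, FlashMask precomputes block-level minimum/maximum values of the vectors so each tile is classified in constant time. The sketch instead scans the tile's columns and counts the masked rows in each. It assumes that each column's LT and UT ranges are disjoint (one lies below the diagonal, the other above). Names are hypothetical.

```python
import numpy as np


def overlap(a0, a1, b0, b1):
    """Length of the intersection of [a0, a1) with [b0, b1) (element-wise for arrays)."""
    return np.maximum(0, np.minimum(a1, b1) - np.maximum(a0, b0))


def classify_block(r0, r1, c0, c1, lts, lte, uts, ute):
    """Classify the tile of query rows [r0, r1) x key columns [c0, c1).
    Assumes each column's LT and UT ranges are disjoint."""
    cols = slice(c0, c1)
    masked_rows = (overlap(r0, r1, lts[cols], lte[cols])
                   + overlap(r0, r1, uts[cols], ute[cols]))
    if np.all(masked_rows == r1 - r0):
        return "fully masked"       # skip the tile
    if np.all(masked_rows == 0):
        return "unmasked"           # compute without any mask
    return "partially masked"       # apply the mask element by element


n = 8
# Causal mask in column form: column j masks rows [0, j); the LT ranges are empty.
lts, lte = np.full(n, n), np.full(n, n)
uts, ute = np.zeros(n, dtype=int), np.arange(n)
print(classify_block(0, 4, 4, 8, lts, lte, uts, ute))   # fully masked
print(classify_block(4, 8, 0, 4, lts, lte, uts, ute))   # unmasked
print(classify_block(0, 4, 0, 4, lts, lte, uts, ute))   # partially masked
```

For a causal mask this reproduces the three tile types from the FlashAttention-2 discussion: above the diagonal, below it, and on it.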
Through this classification process, FlashMask significantly improves computational efficiency, as shown in the figure below.

The algorithm in the figure below describes in detail the forward computation process of FlashMask extended FlashAttention-2, where the light blue shaded area represents the new computation steps added by FlashMask.

0x04 FlashAttention-3
The creators of FlashAttention have released V3, which features:
- More efficient GPU utilization. FlashAttention-3 uses WGMMA (warpgroup matrix multiply-accumulate) instructions on the H100 GPU, which offer up to 3x the throughput of the A100. It also uses the H100's TMA (Tensor Memory Accelerator), which accelerates data transfer between global and shared memory and handles all index calculation and out-of-bounds predication. This frees up registers, which can then be spent on larger tiles and higher efficiency.
- Better performance at lower precision. FlashAttention-3 can work with lower-precision numbers (FP8) while maintaining accuracy. Specifically, it borrows the incoherent processing technique from QuIP (2-Bit Quantization of Large Language Models With Guarantees). Multiplying the queries and keys by a random orthogonal matrix "spreads out" outliers and so reduces quantization error.
- Longer contexts in LLMs. By accelerating attention, FlashAttention-3 lets models process longer text more efficiently, making it more practical for applications to understand and generate longer, more complex content.
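The effect of incoherent processing can be checked numerically. The NumPy sketch below is an illustration, not FlashAttention-3's implementation: it builds a random orthogonal matrix from a Hadamard matrix with random signs (a common choice for such rotations), injects a synthetic outlier channel into Q and K, and uses a simple per-tensor symmetric 8-bit integer quantizer as a stand-in for FP8, which is an assumption. `random_hadamard` and `fake_quant` are hypothetical helper names. The rotation leaves QK^T unchanged while shrinking the largest magnitude, so the quantization step becomes finer and the error in the scores drops.

```python
import numpy as np


def random_hadamard(d, rng):
    """Random orthogonal matrix diag(signs) @ H / sqrt(d); d must be a power of two."""
    h = np.array([[1.0]])
    while h.shape[0] < d:                      # Sylvester construction
        h = np.block([[h, h], [h, -h]])
    signs = rng.choice([-1.0, 1.0], size=d)
    return (signs[:, None] * h) / np.sqrt(d)


def fake_quant(x, bits=8):
    """Per-tensor symmetric integer quantize/dequantize (a stand-in for FP8)."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(x).max() / qmax             # one outlier inflates this step
    return np.round(x / scale) * scale


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, d = 64, 64
    q, k = rng.standard_normal((n, d)), rng.standard_normal((n, d))
    q[:, 3] *= 100.0                           # synthetic outlier feature channel
    k[:, 3] *= 100.0
    m = random_hadamard(d, rng)
    exact = q @ k.T
    # Orthogonality: (QM)(KM)^T = Q M M^T K^T = Q K^T.
    print(np.allclose((q @ m) @ (k @ m).T, exact))   # True
    err_plain = np.abs(fake_quant(q) @ fake_quant(k).T - exact).mean()
    err_rot = np.abs(fake_quant(q @ m) @ fake_quant(k @ m).T - exact).mean()
    print(f"mean |error| plain: {err_plain:.3f}, rotated: {err_rot:.3f}")
```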
Since these optimizations are mainly hardware-specific, we will not go into detail here. Interested readers can study them on their own.
0xFF Reference
- (Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)
- [Large Model Training] FlashAttention v1, v2 - The Clearest Formula Derivation & Algorithm Explanation (Alan’s Brief Sharing)
- Online normalizer calculation for softmax (arXiv:1805.02867)
- [Attention Optimization][20,000 words] 🔥Principles & Diagrams: From Online-Softmax to FlashAttention V1/V2/V3 DefTruth
- [Attention Optimization][10,000 words] 🔥TensorRT 9.2 MHA/Myelin Optimize vs FlashAttention-2 profile DefTruth
- [FlashAttention][20,000 words] 🔥Principles & Diagrams: From Online-Softmax to FlashAttention-1/2/FlashDecoding/FlashDecoding++ DefTruth
- Antinomi: FlashAttention Core Logic and V1/V2 Differences Summary
- Decode Optimization - Lean Attention (Hand-grab Pancake Bear)
- Flash Attention on Intel GPU (Minor Issues)
- Learning from the official Triton example of Flash Attention V2 [forward] (L77 Nebula)
- Flash Attention paper and source code study - KIDGINBROOK
- FlashAttention v2 paper review and advancement: Killua
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
- FlashAttention: Accelerates computation, saves GPU memory, and provides exact, I/O-aware attention (Thomas X)
- FlashAttention Explained (How to Speed Up Attention) - Austin
- FlashAttention Core Logic and V1/V2 Differences Summary (Antinomi)
- FlashAttention Algorithm Explained in DeepHub
- A Summary of FlashAttention Calculation Process (by Pangpangdahai)
- From Online Softmax to FlashAttention by Zihao Ye
- From Online Softmax to FlashAttention
- LLM inference acceleration technology – operator fusion method of Flash Attention
- NLP (17): From FlashAttention to PagedAttention, How to Further Optimize Attention Performance
- ops(7): CUDA Implementation and Optimization of Self-Attention (Part 1) Purple Qi Comes from the East
- ops(8): CUDA Implementation and Optimization of Self-Attention (Part 2) Purple Qi Comes from the East
- Performance optimization of Scaled Dot Product Attention (SDPA) on CPU by Mingfei
- [Tearing Down LLM-FlashAttention2] Just Because the For Loop Optimization Was Too Beautiful - Little Winter Melon AIGC
- [Treating LLM-FlashAttention from Scratch] Starting with Softmax, a super long and easy-to-understand article!! (by Xiaodonggua AIGC)
- Multitasking Online Softmax TaurusMoon
- A lengthy article explaining FlashAttention v1/v2 Civ
- Reproduce Flash Attention 66RING using Cutlass Cute
- Thomas the Tank Engine X: FlashAttention: Accelerates computation, saves GPU memory, and provides precise attention that is aware of I/O.
- Illustrated Guide to Accelerating Large Model Computation Series: Flash Attention V2, From Principles to Parallel Computation (by Mengyuan)
- Illustrated Guide to Accelerating Large Model Computation Series: FlashAttention V1, From Hardware to Computational Logic
- Large Model Analysis: Flash Attention Snowball Effect
- Accelerating Large Model Training with FlashAttention Series: The Product Perspective Behind Hit Projects (by Fang Jiarui)
- Some thoughts and questions about learning Flash Attention and Flash Decoding
- Sequence Parallelism DeepSpeed-FPDT (Hand-Pulled Pancake Bear, Large Model New Vision)
- My Transformer Acceleration Notes (Part 1): FlashAttention - delin
- A Hand-Drawn Guide to Flash Attention! Principle Analysis and Code Implementation. Goodnight, Tombrey!
- Exploring Linear Attention: Does Attention Require a Softmax? By Su Jianlin
- Learning FlashAttention 2 slowly and carefully - Example 1: The Lost Little Bookboy
- Learning FlashAttention slowly and carefully - The Lost Little Bookboy
- Detailed Derivation of the Flash Attention Monster
- A Thorough Understanding of FlashAttention and FlashAttention2: One of the Techniques for Enabling Large Model Context Lengths to Exceed 32K v_JULY_v
- Summary of methods to reduce the O(N^2) complexity of Transformers (Part 1) Civ
- Summary of methods to reduce the O(N^2) complexity of Transformers (Part 2) Civ
- A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library [5]
- Andrew Kerr. GTC 2020: Developing CUDA kernels to push Tensor Cores to the absolute limit on NVIDIA A100. May 2020.
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. https://arxiv.org/abs/2307.08691
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness [2]
- FlashMask: Efficient and Rich Mask Extension of FlashAttention. https://arxiv.org/abs/2410.01359
- FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention. https://pytorch.org/blog/flexattention/
- From Online Softmax to FlashAttention (http://cs.washington.edu)
- From Online Softmax to FlashAttention. https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf
- Maxim Milakov and Natalia Gimelshein. Online normalizer calculation for softmax. CoRR, abs/1805.02867, 2018.
- Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism [6]
- Self-attention Does Not Need O(n^2) Memory. https://arxiv.org/abs/2112.05682
- The I/O Complexity of Attention, or How Optimal is Flash Attention? [4]
- Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Flashattention: fast and memory-efficient exact attention with io-awareness. CoRR, abs/2205.14135, 2022.
- Goodnight, Tom Britt (https://www.zhihu.com/people/Rancho2508)
- Derivation of the Forward Process of Ring Attention and FlashAttentionV2 from a Coding Perspective (Yang Pengcheng)
- Let’s discuss the principles of the forward pass of FlashAttentionV3 with code examples (Yang Pengcheng)
- A Discussion on PagedAttention (V1/V2) in Thread Partitioning and Data Segmentation in CUDA Programming (by Yang Pengcheng)
- [DefTruth: Attention Optimization] 📚FFPA (Split-D): FA2 infinite HeadDim expansion, 2x↑ 🎉 vs SDPA EA