Exploring the Transformer Series (20) --- KV Cache
Autoregressive inference redundancy, KV cache, prefill vs decode, implementation, and resource usage.
0x00 Overview
As the list of tokens input to the LLM grows, the Transformer’s self-attention phase can become a performance bottleneck. A longer token list means larger matrices to be multiplied. Each matrix multiplication consists of many small numerical operations, called floating-point operations, whose performance is limited by the GPU’s floating-point operations per second (FLOPS). Thus, inference latency and throughput become critical challenges in LLM deployment. These problems mainly stem from:
- Because inference is sequential and autoregressive, a naive implementation must recompute the key and value vectors of all previous tokens at every step.
- Because the cost of attention grows quadratically with the length of the input sequence, attention often accounts for the largest share of inference latency.
To address inference latency and throughput issues, the most commonly used optimization technique is KV Cache. KV Cache is a key performance optimization mechanism. By caching already computed key and value matrices, it avoids redundant computation during autoregressive generation, thus significantly improving inference efficiency (essentially trading space for time). This mechanism is similar to the short-term memory system in human thinking, enabling the model to efficiently utilize historical information. Reusing KV Cache achieves two main objectives:
- Improve Prefill efficiency. Because the number of tokens involved in Prefill is reduced, the computational load decreases, resulting in lower Prefill latency and directly improving TTFT performance. This is particularly suitable for optimizing performance in multi-turn dialogue scenarios.
- Save GPU memory. The KV cache stores reusable intermediate data that is crucial to the generation process.
This article first introduces how to predict the next token step by step without using a KV Cache, and then introduces the KV Cache.
Note: The analysis and summary in this article may differ from the actual historical trajectory of the concept’s emergence. This summary is provided simply because the author believes it is easier to explain in this way.
0x01 Problems with Autoregressive Inference
Multi-turn dialogue is a fundamental feature of modern large language models (LLMs). A multi-turn dialogue session consists of a series of consecutive dialogues, denoted as D = [d1, d2, … dN]. In each dialogue dj, the user enters a new question or command qj and then waits for the LLM’s response aj.
LLMs are autoregressive models, and their inference process has a distinctive property: tokens are generated iteratively. The model predicts the next word/character from the preceding text: the decoder’s representation of the last token of the preceding text is mapped to a probability distribution over the next token. Specifically, given an input text, the model produces an answer of length N by executing N inference steps, each of which outputs only one token. The token output in the current round is concatenated with the previous input tokens to form the input of the next round. This process repeats until a terminator is generated or the number of generated tokens reaches a set threshold, max_new_token.
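The loop just described can be sketched in a few lines of Python. This is a minimal illustration, not a real model: `toy_forward` is a hypothetical lookup table standing in for a full forward pass, and it reproduces the 新年大吉 example used later in this article; `max_new_token` mirrors the threshold named above. The `processed` list records how many tokens each step pushes through the model when nothing is cached.

```python
# Hypothetical toy "model": maps the full current context to the next token.
# A real LLM would run a forward pass over every token here (no KV cache).
TOY_NEXT = {
    ("[BOS]", "新"): "年",
    ("[BOS]", "新", "年"): "大",
    ("[BOS]", "新", "年", "大"): "吉",
    ("[BOS]", "新", "年", "大", "吉"): "[EOS]",
}

def toy_forward(tokens):
    """Stand-in for one forward pass over the whole sequence."""
    return TOY_NEXT[tuple(tokens)]

def generate(prompt, max_new_token=16, eos="[EOS]"):
    tokens = list(prompt)
    processed = []  # number of tokens fed to each forward pass
    output = []
    for _ in range(max_new_token):
        processed.append(len(tokens))   # the whole sequence is re-fed every step
        next_token = toy_forward(tokens)
        if next_token == eos:           # stop at the terminator
            break
        output.append(next_token)
        tokens.append(next_token)       # concatenate the output to the input
    return output, processed

out, processed = generate(["[BOS]", "新"])
print("".join(out), processed)  # 年大吉 [2, 3, 4, 5]
```

Each step re-feeds a sequence that is only one token longer than before, which is the redundancy analyzed in the following sections.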

1.1 Request Lifecycle
In practice, LLM prompts are typically long sequences. Without a key-value cache, this characteristic of prompts splits LLM inference into two distinct processes: the prompt phase and the token-generation phase.
- The prompt phase: when the LLM service receives a user request (Is tomato a fruit?), it generates the first output token (Yes) from the input tokens (Is, tomato, a, fruit, ?).
- The token-generation phase begins after the first token is generated. The prompt and the tokens generated so far are fed back into the model as the new input, and tokens are generated autoregressively, one at a time, until a specific stop token is generated (or a user-defined condition is met, such as exceeding a maximum length). The inputs of consecutive rounds differ by only one token, which leads to duplicated computation.
The prompt phase counts as one inference step, and each decode in the token-generation phase counts as one inference step. For example, the token-generation phase in the diagram below includes three inference steps.

We will conduct an in-depth analysis of the characteristics of the two stages.
The prompt phase (pre-filling phase), also called the initiation phase, has the following characteristics:
- Timing: Occurs during the computation of the first output token.
- Input: Input a prompt sequence.
- Function: To process all user input at once. LLMs summarize the context of the input sequence (i.e., input prompts) and generate a new token as the initial input for the decoding phase.
- Execution count: It can be completed with a single Forward.
- Computation type: There are a large number of GEMM (GEneral Matrix-Matrix multiply) operations, which belong to the Compute-bound type (computation-intensive) computation.
- Parallelism: The input tokens are processed in parallel, which is a highly parallelized matrix operation with relatively high execution efficiency.
The characteristics of the token-generation phase are as follows:
- Timing: After the first token is generated in the prompt phase, the token-generation phase begins. This occurs during the calculation of the second output token through the last token.
- Input: The newly generated token will be concatenated with the input tokens and used as the input for the next inference.
- Function: each newly generated token is fed back into the model as input, which makes token generation an autoregressive process.
- Number of executions: Assuming there are a total of N tokens in the output, the token-generation phase requires N-1 Forward executions.
- Computation type: There are a large number of GEMM (GEneral Matrix-Matrix multiply) operations, which belong to the Compute-bound type (computation-intensive) computation.
- Parallelism: Assuming there are a total of N tokens in the output, the Decoding stage requires N-1 Forwards. These N-1 Forwards can only be executed serially, so the efficiency is relatively low. In addition, during the generation process, the number of tokens that need to be paid attention to increases (the generation of each token requires attention to previous tokens), and the amount of computation will also increase accordingly.
The autoregressive generation pattern is the fundamental reason for the two-phase structure, and the two-phase structure is the visible manifestation of that pattern. The KV cache is an optimization technique for this process.
Note: In the Splitwise paper, these two phases are called the prompt phase and the token-generation phase, respectively. In practice, the terms “prefill” and “initiation” are used interchangeably; for clarity, we will prefer the former from now on.
1.2 Simplified Derivation
Let’s look at an example of how an LLM-type model responds to a given text. For clarity, the “prompt” here is a single word (a simplification). The response process breaks down into the following inference steps: given the input “新” (new), the model successively predicts “年” (year), “大” (big), “吉” (auspicious), and “[EOS]”.
Inference 1: input = [BOS]新; output = 年
Inference 2: input = [BOS]新年; output = 大
Inference 3: input = [BOS]新年大; output = 吉
Inference 4: input = [BOS]新年大吉; output = [EOS]
[BOS] and [EOS] are the start symbol and the end symbol, respectively.

Next, we’ll delve into the internals of the Transformer to examine the reasoning process described above step by step. Note: The example diagram below only shows details related to the KV Cache.
The first step inputs “新” (new) and outputs “年” (year). The specific data flow for this step is shown in the diagram below.

The second step appends “年” to “新” as the new input. The input for this inference is therefore “新年” (New Year), and the prediction is “快” (the first character of 快乐, “happy”). The specific data flow for this step is shown in the diagram below.

The third step appends “快” to the end of “新年” as the new input. That is, the input for this inference is “新年快” and the prediction is “乐”. The specific data flow for this step is shown in the figure below.

1.3 Redundancy Analysis
The three steps above are summarized in the diagram below. There is a large amount of redundant computation: each time a token is generated, the keys and values of all historical tokens are recomputed, so generating n tokens costs $O(n^2)$ key/value computations in total. Memory use and computation time grow rapidly with sequence length. For example:
- Generating the embeddings involves redundant computation.
- Generating K and V involves redundant computation.
- Computing $QK^\top$ involves redundant computation.
- The softmax operation and the multiplication by V involve redundant computation.

Since the preceding operations in each step only prepare for the attention computation, we focus on the attention part. The attention computations involved in each step are listed below, where $\text{softmaxed}(\cdot)$ denotes the weight after the softmax operation (for example, in the second step, $\text{softmaxed}(Q_2K_1^\top)$ might be 0.4 and $\text{softmaxed}(Q_2K_2^\top)$ might be 0.6).
- Step 1 computes $\text{Att}_1(Q,K,V) = \text{softmaxed}(Q_1K_1^\top)V_1$.
- Step 2 computes $\text{Att}_1(Q,K,V) = \text{softmaxed}(Q_1K_1^\top)V_1$ and $\text{Att}_2(Q,K,V) = \text{softmaxed}(Q_2K_1^\top)V_1 + \text{softmaxed}(Q_2K_2^\top)V_2$.
- $\text{Att}_1$ is computed again; it depends only on $Q_1$ and has nothing to do with $Q_2$.
- $\text{Att}_2$ is a new computation. As its formula shows, it depends only on $Q_2$, not on $Q_1$.
- Step 3 computes $\text{Att}_1$, $\text{Att}_2$, and $\text{Att}_3(Q,K,V) = \text{softmaxed}(Q_3K_1^\top)V_1 + \text{softmaxed}(Q_3K_2^\top)V_2 + \text{softmaxed}(Q_3K_3^\top)V_3$.
- $\text{Att}_1$ and $\text{Att}_2$ are repeated computations, for the same reason as in step 2.
- $\text{Att}_3$ is a new computation, which depends only on $Q_3$ and not on $Q_1$ or $Q_2$.
In other words, when predicting the i-th character, only the last attention row, $\text{Att}_i$, is new; $\text{Att}_1$ through $\text{Att}_{i-1}$ exactly repeat computations from earlier steps.
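This redundancy can be checked numerically. The sketch below, in plain Python with small random matrices (the dimensions, the seed, and the helper names `project`, `causal_attention`, and `step` are illustrative assumptions), computes $\text{Att}_i = \sum_{j \le i} \text{softmaxed}(Q_iK_j^\top)V_j$ for every position, once for a 3-token input (step 3) and once for a 4-token input (step 4). It includes the usual $1/\sqrt{d}$ scaling, which the formulas above omit.

```python
import math
import random

def project(X, W):
    """Multiply each row of X (n x d_in) by W (d_in x d_out)."""
    return [[sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))] for x in X]

def causal_attention(Q, K, V):
    """Row i: Att_i = sum_{j<=i} softmaxed(Q_i K_j^T / sqrt(d)) V_j."""
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d) for j in range(i + 1)]
        m = max(scores)                        # subtract the max for numerical stability
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        w = [x / total for x in w]             # softmaxed weights, summing to 1
        out.append([sum(w[j] * V[j][c] for j in range(i + 1)) for c in range(len(V[0]))])
    return out

random.seed(0)
d = 4
W_q, W_k, W_v = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(3))
X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(4)]  # 4 token embeddings

def step(n):
    """Vanilla step n: recompute Q, K, V and attention for the first n tokens."""
    Xn = X[:n]
    return causal_attention(project(Xn, W_q), project(Xn, W_k), project(Xn, W_v))

att_step3, att_step4 = step(3), step(4)
# Att_1..Att_3 from step 4 equal those from step 3; only Att_4 is new.
print(att_step4[:3] == att_step3)
```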
1.4 Root Causes of Redundancy
Now we explore the reasons for redundant calculations, namely why previous words do not need to be calculated repeatedly.
1.4.1 Examine the processing logic
To generate new tokens that fit the context, the LLM must compute, in the attention layer, the relationship between the last token and all previous tokens (including the tokens of the input sequence). A simple approach is to recompute the keys and values of all previous tokens in every iteration: at each step, the output token of the current round is concatenated with the input tokens to form the input of the next round. The input of round i+1 is identical to that of round i except for one added token, so the inference in round i+1 inevitably repeats part of the computation from round i, and recomputing the previous words is redundant. Moreover, this cost grows linearly with the number of previous tokens, so it is larger for longer sequences.
For each token generation, the query is computed from the current token, while the key and value are derived from all tokens and remain unchanged for subsequent tokens. The vanilla Transformer implementation recomputes the key and value when generating each new token, thus unnecessarily increasing the computational cost per attention block on the GPU.
1.4.2 Observe the processing procedure
From a network structure perspective, the main modules of the Transformer determine that redundant computations are unnecessary:
- Attention module (corresponding to number 1 in the figure below).
- During inference, previously generated tokens cannot see subsequently generated ones, so they need no attention computation with later tokens. Under this “one-way” (causal) attention, the query vector $q_i$ at time step i does not affect the keys $k_j$ and values $v_j$ of earlier time steps (j < i). For example, the key and value of position 3 computed at step i=3 are identical to those computed at step i=4. Hence, in every Transformer layer, keys and values need not be recomputed.
- During training, masking ensures that the output representation of the current token uses only information from earlier tokens, not from later ones. That is, $Q_1K_2^\top$, $Q_1K_3^\top$, and $Q_2K_3^\top$ are masked and need not be computed. This causal structure is what makes caching possible: with cached keys and values, the attention FLOPs required per generation step grow linearly, rather than quadratically, with the total sequence length. At each generation step we avoid recomputing the keys and values of past tokens and compute only those of the last generated token; each newly computed key and value is cached in GPU memory for later reuse, saving the floating-point operations needed to recompute it.
- FFN (corresponding to number 2 in the diagram below). In the FFN, the features of different positions in the sequence do not interact, and only the output feature of the last position is used to produce the probability distribution of the next token. Therefore, the i-th output of the FFN layer depends only on the i-th input and is independent of the other inputs; for example, the FFN output at position 3 depends only on the FFN input at position 3.
- Add & Norm (corresponding to number 3 in the diagram below). LayerNorm computes the mean and variance along the `d_model` dimension and then normalizes, so its output for the last position depends only on the last row of the input `hidden_state`.
- Linear (corresponding to number 4 in the diagram below). This is a linear mapping that transforms `hidden_state` from `d_model` to `vocab_size`. By the properties of matrix multiplication, the last row of `logits` depends only on the last row of `hidden_state`.
- Softmax (corresponding to number 5 in the diagram below). Softmax can be computed by storing previous results and combining them with newly computed ones.
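A small sketch of the row-wise property of the last steps: LayerNorm (here without its learned scale and shift, a simplification) and the final linear projection act on each row independently, so computing only the last row of `hidden_state` yields the same last row of `logits` as processing the whole sequence. The sizes, seed, and helper names are illustrative assumptions.

```python
import math
import random

def layer_norm(row, eps=1e-5):
    """Normalize one row along the d_model dimension (no learned scale/shift)."""
    mean = sum(row) / len(row)
    var = sum((x - mean) ** 2 for x in row) / len(row)
    return [(x - mean) / math.sqrt(var + eps) for x in row]

def linear(row, W):
    """Map a d_model row to vocab_size logits: row @ W."""
    return [sum(row[i] * W[i][j] for i in range(len(row))) for j in range(len(W[0]))]

def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    z = sum(e)
    return [v / z for v in e]

random.seed(0)
d_model, vocab_size, seq_len = 4, 6, 5
hidden_state = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(seq_len)]
W_out = [[random.gauss(0, 1) for _ in range(vocab_size)] for _ in range(d_model)]

logits_all = [linear(layer_norm(h), W_out) for h in hidden_state]  # every position
logits_last = linear(layer_norm(hidden_state[-1]), W_out)          # last position only
probs = softmax(logits_last)                                       # next-token distribution
print(logits_all[-1] == logits_last, max(range(vocab_size), key=probs.__getitem__))
```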

1.5 How to improve
Although we have identified the redundant computations, the vanilla Transformer does not exploit this during inference: it computes outputs for all input positions even though only the last position’s output is needed, which wastes a lot of computation. To improve on this, we need to cache or discard certain intermediate variables related to the preceding text, so we must carefully decide which to cache and which to discard.
1.5.1 From a network perspective
Let’s look at several options from the perspective of model architecture.
| Option | Feasible? | Reason |
|---|---|---|
| Discard the preceding X (the input tokens) | No | Explained below. |
| Cache X | Possible, but not optimal | Even if X is cached, K and V still have to be recomputed. |
| Cache previous Q | No | The previous queries are not used when the next token is actually computed. |
| Discard previous Q | Yes | The i-th output of the model depends only on the i-th query and is independent of the other queries. The new computation involves only the current $Q_i$, not the previous $Q_1, \dots, Q_{i-1}$, so there is no need to cache previous queries. |
| Discard the previous KV | No | Explained below. |
| Cache the previous KV | Yes | Explained below. |
Why can’t the previous input token be discarded?
We know that inference ultimately selects only the output feature of the last position as the probability distribution of the next token; that is, the next token is determined by the network output of the last token. However, this does not mean that inference can be performed using only the last token as input. Although the result layer is determined solely by the last token, the intermediate attention process relies on the Key-Value vectors provided earlier to carry information from those earlier elements, so the earlier information cannot be ignored.
Put another way, X produces three branches: Q, K, and V. Since the preceding K and V cannot be discarded, the preceding X cannot simply be discarded either. Q, by contrast, is only needed for the current position in an autoregressive Transformer; because of this asymmetry in the computation, caching Q does not improve inference efficiency. Therefore, Q is usually not cached during LLM inference.
Of course, since K and V are derived from X, once K and V are cached, the input X can be discarded.
Why can’t we discard the previous KV?
We’ve already mentioned that key-value pairs (KVs) are indispensable. Let’s analyze that in more detail next.
In the attention mechanism, the i-th output (and, by extension, the output of every Transformer block) depends on the complete K and V and on the current query $Q_i$. Taking the second step as an example, the red circles mark the elements involved in computing $\text{Att}_1$, and the blue circles mark those involved in computing $\text{Att}_2$. The blue circles cover all of K and V.

Let’s refine these operations further at the level of individual vectors. As the figure below shows, the computation involves all of Q, K, and V.

Feasibility of caching previous key-value pairs
Since the previous key-value pairs are necessary, let’s now look at the feasibility of caching.
- First, the historical K and V depend only on the historical outputs (tokens), not on the current output, so from this perspective they can be cached.
- Second, previous tokens remain unchanged in subsequent iterations, so their representations are the same in all subsequent iterations. During inference, the model weights ($W_Q$, $W_K$, and $W_V$) are fixed. For the same word, if its token embedding and positional encoding are fixed, the Q, K, and V obtained through $W_Q$, $W_K$, and $W_V$ are also fixed, so they only need to be computed once.
Therefore, caching historical keys and values avoids recomputing them.
1.5.2 From a mathematical perspective
Suppose we want to multiply matrices A and B. If we split A row-wise into two parts, A[:s] and A[s:], multiply each part by B, and concatenate the two results row-wise, we get exactly the same result as multiplying without splitting. Attention and the FFN are built on such matrix multiplications, so we can cache the result for A[:s] and avoid the redundant computation of recomputing over the whole input A[:].
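The row-splitting argument can be verified directly. In this sketch (integer matrices so the comparison is exact; the sizes and the split point s are arbitrary assumptions), the result for A[:s] plays the role of the cached part, and only A[s:] is computed anew.

```python
import random

def matmul(A, B):
    """Plain matrix product A @ B for lists of rows."""
    return [[sum(a * B[k][j] for k, a in enumerate(row)) for j in range(len(B[0]))] for row in A]

random.seed(1)
n, d_in, d_out, s = 5, 3, 4, 3
A = [[random.randint(-3, 3) for _ in range(d_in)] for _ in range(n)]
B = [[random.randint(-3, 3) for _ in range(d_out)] for _ in range(d_in)]

full = matmul(A, B)        # computing over the whole input A[:]
cached = matmul(A[:s], B)  # result for A[:s], computed earlier and kept
new = matmul(A[s:], B)     # only the newly appended rows A[s:]
print(cached + new == full)
```

Row-wise concatenation of the two partial products reproduces A @ B because each output row depends only on the matching input row.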

1.5.3 Conclusion
The analysis above shows that computing with the cached KV sequence, extended by the new token’s KV, is equivalent to the normal computation over the full input sequence, while greatly reducing the amount of computation. This is the KV Cache.
0x02 Optimize using KV Cache
The idea behind KV Cache is straightforward: trade space for time by caching the K and V values from previous rounds, thus avoiding recomputing the key and value vectors each time a token is generated. New tokens can be generated using the cached keys and values, which reduces computation and speeds up generation. The main functions of the KV Cache are as follows.
- The KV Cache acts as a memory bank for the autoregressive generative model: it stores the keys (K) and values (V) of all previous tokens for later reuse, ensuring that the full K and V remain available.
- Each time an iteration computes a new key vector and value vector, they are appended to the KV cache.
- The model’s first input is the complete prompt; each subsequent input contains only the token generated in the previous inference step, instead of the entire sequence.
- When computing the attention scores for the (K+1)-th token, the model does not need to recompute the keys and values of the previous K tokens; it only retrieves them from the cache and concatenates them with the current token’s key and value.
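The behavior described in this list can be captured in a minimal single-layer, single-head sketch. `KVCache`, `decode_step`, and `no_cache_last_output` are hypothetical names chosen for illustration, not any framework’s API; real implementations keep preallocated tensors per layer and per head. The loop checks that attending with cached keys and values gives the same output as the vanilla approach that recomputes K and V for the whole sequence at every step. For simplicity the prompt is also fed one token at a time here; a real prefill computes all prompt K/V in a single pass.

```python
import math
import random

def project(x, W):
    """Row vector x (length d_in) times W (d_in x d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def attend(q, keys, values):
    """softmax(q K^T / sqrt(d)) V for a single query vector q."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(w[j] * values[j][c] for j in range(len(keys))) / z for c in range(len(values[0]))]

class KVCache:
    """Hypothetical per-layer cache: one list of key vectors, one of value vectors."""
    def __init__(self):
        self.keys, self.values = [], []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

def decode_step(x, cache, W_q, W_k, W_v):
    """One new token: compute only its q/k/v, append k/v, attend over the whole cache."""
    q, k, v = project(x, W_q), project(x, W_k), project(x, W_v)
    cache.append(k, v)
    return attend(q, cache.keys, cache.values)

def no_cache_last_output(xs, W_q, W_k, W_v):
    """Vanilla approach: recompute K and V for every token, keep the last output."""
    keys = [project(x, W_k) for x in xs]
    values = [project(x, W_v) for x in xs]
    return attend(project(xs[-1], W_q), keys, values)

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d = 4
W_q, W_k, W_v = rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)
xs = rand_matrix(6, d)                # embeddings of 6 tokens
cache = KVCache()
max_err = 0.0
for t, x in enumerate(xs):
    cached_out = decode_step(x, cache, W_q, W_k, W_v)
    full_out = no_cache_last_output(xs[: t + 1], W_q, W_k, W_v)
    max_err = max(max_err, max(abs(a - b) for a, b in zip(cached_out, full_out)))
print(len(cache.keys), max_err)
```

With the cache, each step computes one new key and one new value; the vanilla path recomputes t + 1 of each.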
2.1 Terminology
Let’s first look at the structure and terminology of the KV cache. An LLM consists of multiple Transformer block layers, each maintaining its own cache of keys and values. In this article, we collectively refer to the caches of all Transformer blocks as the KV cache, and use the terms K-cache and V-cache for the keys and values, respectively. In deep learning frameworks, the K-cache (or V-cache) of each layer is typically represented as a 4D tensor of shape [B, L, H, D], where B is the batch size, L is the maximum possible context length of a request, H is the number of (KV) heads, and D is the dimension of each head. We call a kernel implementation that computes attention scores over contiguously stored K and V a vanilla kernel. The following figure shows the mathematical representation of the KV cache.
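Because each layer holds one K-cache and one V-cache of shape [B, L, H, D], the total size follows directly from the shape. The sketch below computes it; the layer count, head count, head dimension, and 2-byte (fp16) elements form a hypothetical configuration chosen only for illustration.

```python
def kv_cache_bytes(num_layers, B, L, H, D, bytes_per_elem=2):
    """Size of all K-caches plus V-caches: each layer stores two [B, L, H, D] tensors."""
    return 2 * num_layers * B * L * H * D * bytes_per_elem

# Hypothetical configuration: 32 layers, 32 KV heads, head dimension 128,
# fp16 elements (2 bytes), one request with a 4096-token context.
size = kv_cache_bytes(num_layers=32, B=1, L=4096, H=32, D=128)
print(size / 2**30, "GiB")  # 2.0 GiB
```

The size grows linearly in both B and L, which is why long contexts and large batches put pressure on GPU memory.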

2.2 Process
Next, let’s look at the autoregressive process with the KV Cache added. In the following diagram, the input prompt is “新年快” (the first three characters of 新年快乐, “Happy New Year”), and we expect the output to be “乐”, which completes the phrase. At this point, the KV values of the three characters “新年快” are computed and stored in the KV Cache.

Then input “乐” and expect the output to be “万”. The specific steps are as follows:
- Calculate the Q, K, and V values corresponding to “乐” (lè). (This corresponds to label 1 in the diagram below.)
- Extract the K and V of the three tokens “新年快” from the KV Cache and concatenate them with the new K and V to obtain the complete K and V. In other words, the KV Cache mechanism caches the keys and values of all previous time steps. This corresponds to label 2 in the diagram below.
- Store the key and value corresponding to “乐” in the KV Cache. This corresponds to label 3 in the diagram below.
- Compute attention, corresponding to label 4 in the diagram below. The input to the attention mechanism is now only the last generated token (instead of the entire sequence) together with the full K and V, formed by concatenating the KV cache with the last token’s key and value.
At this point, the query, key, and value of the new token correspond to “乐”, while the cached keys and values correspond to “新年快”.
- The new output “万” corresponds to logits, which is labeled 5 in the figure below.

The next steps are:
- Input a new token “万”, calculate only its key/value, and merge it with the 4 cached key/values (“新年快乐”) to generate “事”.
- Input a new token “事”, calculate only its key/value, and merge it with the 5 cached key/value pairs (“新年快乐万”) to generate “如”.
- Input a new token “如”, calculate only its key/value, and merge it with the 6 cached key/value pairs (“新年快乐万事”) to generate “意”.
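The savings in this walkthrough can be counted directly. The sketch below uses a hypothetical helper that counts per-token K/V computations in one layer, assuming the prompt “新年快” and the five generated characters “乐万事如意”; the last generated token is never fed back, so only four decode inputs follow the prefill.

```python
def kv_computations(prompt_len, num_generated, use_cache):
    """Per-token K/V computations (one layer) needed to generate num_generated tokens."""
    if use_cache:
        # Prefill computes K/V for the whole prompt once; each later step adds one token.
        return prompt_len + (num_generated - 1)
    # Without a cache, step i re-processes the prompt plus the i tokens generated so far.
    return sum(prompt_len + i for i in range(num_generated))

prompt, generated = "新年快", "乐万事如意"
without_cache = kv_computations(len(prompt), len(generated), use_cache=False)
with_cache = kv_computations(len(prompt), len(generated), use_cache=True)
print(without_cache, with_cache)  # 25 7
```

Without the cache, the sequence lengths 3, 4, 5, 6, and 7 are all recomputed; with it, each token’s K/V is computed exactly once.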
2.3 Redefining Phase
With the introduction of KV Cache, we redefined the two phases of the inference process and renamed them based on their characteristics. Specifically, the prompt phase was renamed the prefill phase (generating the first token), and the token generation phase was renamed the decoding phase (generating the remaining tokens). This, in turn, affects subsequent optimization methods. Dividing inference into Prefill and Decode processes is due to the significant differences in the computational patterns when generating the first token and the remaining tokens; implementing them separately facilitates targeted optimization.
2.3.1 Definition
Note: Only the parts that differ from the previous definition are given here.
The prefill phase, also known as the initiation phase, has the following characteristics:
- Function: the logical function remains the same as described above (summarize the input sequence and generate a new token as the initial input for the decoding stage), but now the prompt of a request is also converted into the KV Cache in one pass (for every Transformer layer), which is why this stage is usually called the prefill stage.
- Cache usage: this stage does not read from the KV cache, because no previous steps have been performed; it only fills the cache.
The characteristics of the Decoding phase are as follows:
- Input: Instead of using the entire sequence as input, we will input one token at a time and output one token at a time.
- Computation type: the computation now resembles a matrix-vector operation; GEMM becomes GEMV (GEneral Matrix-Vector multiply). The query matrix Q degenerates into the current time-step vector q, and the QK product between two matrices becomes a qK product between a vector and a matrix. This is still a computational saving, and because FLOPs are reduced, the compute demand of this stage is modest; however, the GPU’s computing power is not fully utilized compared with the prompt stage. Because weights and KV cache values must be transferred from memory to the compute units, this stage is limited by memory bandwidth and is memory-bound (memory-intensive). This bottleneck is particularly pronounced in applications with long contexts and long generated outputs.
- Cache usage: At this point, the KV Cache already contains historical key-value results, so each round of inference only needs to read the Cache, and then calculate the next token by combining the KV of the input token. At the same time, the new Key and Value calculated in the current round are appended to the Cache.
- Speed: The inference speed is faster than the previous token generation phase without KV Cache because many redundant calculations are omitted.
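A rough way to see why decode is memory-bound is to compare arithmetic intensity (FLOPs per byte moved) for one d × d projection. The sketch below is a simplified model under stated assumptions: fp16 elements, only the weight, input, and output counted as memory traffic, and KV cache reads ignored; the hidden size is hypothetical. With n_tokens = 1 (one decode token, a GEMV) the intensity is about 1 FLOP/byte, while a 512-token prefill (a GEMM) reuses each loaded weight hundreds of times.

```python
def arithmetic_intensity(n_tokens, d, bytes_per_elem=2):
    """FLOPs per byte for an [n_tokens, d] x [d, d] projection (simplified model).

    FLOPs: 2 * n_tokens * d * d (a multiply and an add per term).
    Bytes: read the d x d weight, read the n_tokens x d input, write the n_tokens x d output.
    """
    flops = 2 * n_tokens * d * d
    bytes_moved = bytes_per_elem * (d * d + 2 * n_tokens * d)
    return flops / bytes_moved

d = 4096  # hypothetical hidden size
print(round(arithmetic_intensity(1, d), 3))    # decode step (GEMV): 1.0
print(round(arithmetic_intensity(512, d), 1))  # 512-token prefill (GEMM): 409.6
```

Batching several decode requests raises n_tokens in the same way, because all sequences share the loaded weights; this is the reason, in this simplified model, that larger decode batches improve compute utilization.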
The corresponding image has also been updated as follows.

The diagram below illustrates how the KV Cache is used in these two stages, using the model structure as an example.
- Prefill is the process of converting the prompt of a single request into a KV cache and generating the first token. Only the last Logit is decoded to obtain the first generated token; the K and V values calculated in the intermediate processes are retained in the GPU memory.
- Decode is the stage in which new tokens are generated. It computes with both the cache produced by prefill and the cache produced by the decode stage itself, and the K and V values calculated along the way are appended to the KV Cache.

The following diagram illustrates the specific algorithm.

2.3.2 Analysis
Researchers have also conducted in-depth analysis of the prefill and decode stages. Understanding these characteristics helps us to make better targeted optimizations. Let’s take a look.
- Different inference services may have drastically different prompt and decode distributions.
- For most requests, the majority of end-to-end (E2E, total user request time) time is spent in the decoding phase.
- The prefill stage is compute-bound, so it can fully utilize computing power, making computing power the bottleneck. The decode stage is memory-bound, so memory is the bottleneck, and it cannot fully utilize computing power.
- Prefill can effectively utilize the GPU and is suitable for high-performance GPUs; the Decode stage can use GPUs with less computing power but higher memory access bandwidth.
- Prefill optimization focuses on operator merging and simplification to reduce model computation. Decoding optimization primarily addresses key-value cache access optimization, such as tile calculation and cache quantization.
- The computation time of the Prefill stage typically increases superlinearly with the increase of input length. The batch size of the Prefill stage should be limited to avoid affecting performance. Conversely, the batch size of the Decode stage should be increased to obtain higher computational intensity and throughput.

As we can see, the characteristics of these two stages are completely different. Even with excellent batching techniques, it is impossible to solve the problems caused by these two obviously different stages. For example, due to insufficient utilization of hardware resources, providing services to users will incur higher costs.
2.4 Reflection
Next, let’s look at some features related to KV Cache.
2.4.1 Historical Context
Let’s broaden our focus to sequence generation in general. For sequence models, a simple, stateless inference process recomputes all keys and values of the entire sequence in each iteration, including the input tokens provided by the client and the output tokens generated so far. To avoid this recomputation, the historical context is typically cached: the internal state that must be maintained across iterations is recorded and reused in subsequent iterations. The figure below illustrates how sequence models are modeled and gives three models as examples. TTT (Test-Time Training) compresses the context into the model’s weights; this “hidden-state model” keeps a fixed size over time while significantly increasing expressive power. Since this is not the focus of this article, we will not discuss it further.

The diagram below illustrates the state usage patterns of Transformer and LSTM. LSTM compresses the historical context (such as information containing all past tokens) into a low-dimensional vector called hidden state. In LSTM, the size of the internal memory (c) and the layer’s input/output (h) remains constant.
In the Transformer, since the attention operation needs the keys and values of all previous tokens, these K and V values are stored. The Transformer does not compress the state; instead, it uses a cache. Each processed token has its own hidden vector, and all processed hidden vectors together constitute the hidden state, with which new tokens can interact. This is the KV cache. The KV cache grows continuously over time: it does not compress any historical context, but its cost increases with the context length.
Let’s examine in detail how the attention keys (k) and values (v) grow with each iteration. When processing the token with index t, the attention operation needs all previous keys and values, $k_{1:t-1}$ and $v_{1:t-1}$, as well as the current key and value, $k_t$ and $v_t$. Therefore, attention operates on tensors of different shapes depending on the number of tokens processed.
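The contrast between the two kinds of state can be put in numbers. In this sketch (hypothetical helper names; one layer with hidden size d, ignoring heads and batching), the LSTM state stays at h and c of size d each, while the attention layer’s cache holds t keys and t values, so at step t attention runs over a K of shape [t, d].

```python
def lstm_state_elems(t, d):
    """LSTM keeps h and c (size d each), no matter how many tokens were seen."""
    return 2 * d

def kv_state_elems(t, d):
    """A causal attention layer keeps k_1..k_t and v_1..v_t: two [t, d] tensors."""
    return 2 * t * d

d = 8
k_cache = []
for t in range(1, 5):
    k_t = [float(t)] * d          # placeholder for the key of token t
    k_cache.append(k_t)           # K at step t = concat(k_{1:t-1}, k_t)
    print(t, (len(k_cache), d), lstm_state_elems(t, d), kv_state_elems(t, d))
```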

2.4.2 Q is also cached, in a sense
Although we cache only K and V, Q is in effect also cached to some extent.
First, for self-attention, Q, K, and V are all derived from X and are inherently interconnected. Second, because the Transformer has a multi-layered structure, in a single layer, the information of Q interacts with K and V, and the information of Q is actually contained within K and V to some extent. During multi-layer computation, some information of Q is also passed to the KV Cache of the next layer. This means that in multi-layered attention computation, in addition to the Q value of the current token, there will also be some Q value information from past tokens involved.
2.4.3 Each layer has an independent KV cache
The key-value cache exists in all layers of the Transformer, not just the first layer. This is because:
- The key-value cache is different for each layer.
- In all layers, the key and value vectors of each token depend only on previous tokens. When a new token is added in a subsequent iteration, the key and value vectors of existing tokens remain unchanged.
The KV cache is different for each layer.
Each decoder layer must cache its own K and V because the attention operations of different layers are independent: layer L has its own projection weights, so its K and V differ from those of every other layer. If K and V were not cached at every layer, the model would have to recompute all previous K and V when generating the next token, leading to a lot of redundant computation. Caching avoids recomputing K and V and thus speeds up generation.
Each layer relies solely on the previous tokens.
For the first layer, the key vector of a token is obtained by multiplying the token’s fixed embedding vector by the fixed parameter matrix $W_K$, so it remains unchanged in subsequent iterations no matter how many new tokens are introduced. The same logic applies to the value vector. For the second and subsequent layers, consider the output of the first layer’s self-attention stage, the KQV matrix (the attention output $\text{softmax}(QK^\top)V$). Each row of the KQV matrix is a weighted sum that depends on:
- The value vectors of the preceding tokens.
- The attention scores computed from the key vectors of the preceding tokens.
Therefore, each row in the KQV matrix depends only on the previous token. After some row-based operations, this matrix serves as the input to the second layer. This means that, apart from the newly added rows, the input to the second layer will remain unchanged in future iterations. This logic can be extended to the remaining layers through induction.
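The induction argument can be checked numerically. The sketch below is a toy, assuming a single head, random weights, no positional encoding, no normalization and no FFN (the residual connection stands in for the row-wise operations); the function names are made up for illustration. It runs a two-layer causal stack on 5 tokens and on the same 5 tokens plus one more, then compares each layer's key matrix on the shared prefix.

```python
import numpy as np

def causal_attention(x, wq, wk, wv):
    """Single-head causal self-attention over x of shape (seq_len, d)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)  # True where j > i
    scores = np.where(future, -np.inf, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def per_layer_keys(x, layers):
    """Run a toy causal stack and collect each layer's K (what a cache would hold)."""
    keys, h = [], x
    for wq, wk, wv in layers:
        keys.append(h @ wk)
        h = h + causal_attention(h, wq, wk, wv)  # residual: a row-wise operation
    return keys

rng = np.random.default_rng(0)
d = 8
layers = [tuple(rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3)) for _ in range(2)]
x_old = rng.standard_normal((5, d))                      # 5 tokens
x_new = np.vstack([x_old, rng.standard_normal((1, d))])  # the same 5 tokens plus a 6th

keys_old, keys_new = per_layer_keys(x_old, layers), per_layer_keys(x_new, layers)
for layer, (k_old, k_new) in enumerate(zip(keys_old, keys_new)):
    print(layer, np.allclose(k_old, k_new[:5]))  # True for both layers
```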
2.4.4 Computer Architecture
From a computer-architecture perspective, K and V can be understood as memory that stores instructions: the attention mechanism plays the role of the controller, the token sequence acts as the registers, and the KV cache acts as an instruction cache.
2.4.5 Applicable Prerequisites
KV cache is a technique that trades extra GPU memory for faster inference. But does it apply unconditionally to every LLM? No.
- First, only LLMs that satisfy causality are suitable for KV caching: the output at each position depends only on that token and the inputs before it, and not on later inputs. Among Transformer models, BERT-style encoders (bidirectional attention) do not satisfy this property, while GPT-style decoders do, because they use a causal mask.
- Second, KV cache places requirements on positional encoding. The positional encoding must also be causal: appending tokens must not change the encoding of existing tokens. Techniques such as ReRoPE adjust the positional information of the entire sequence when a new token is added. If an existing token's representation differs between the previous step and the current one, the KV cache condition no longer holds. And once the input preprocessing layer violates this condition, its output, which is the input to every subsequent Transformer layer, changes as well, so none of those layers can use a KV cache either.
Another important point: because of positional encoding, a token's KV cache is position-dependent. The same token appearing at different positions in the text therefore cannot share one KV cache entry.
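Both prerequisites can be illustrated with a small numpy sketch. It is a toy, not any model's actual code: attention() is a hypothetical single-head layer with an optional causal mask, and rope_rotate() is a simplified RoPE-style rotation (half-split pairing, base 10000 assumed). Appending a token changes earlier outputs under bidirectional attention but not under causal attention, and the same key vector gets a different rotated value at different positions, so it cannot be shared across positions.

```python
import numpy as np

def attention(x, wq, wk, wv, causal):
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    if causal:  # decoder-style: block attention to later positions
        scores = np.where(np.triu(np.ones(scores.shape, dtype=bool), k=1), -np.inf, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

def rope_rotate(vec, pos, base=10000.0):
    """RoPE-style rotation (half-split pairing) of an even-length vector at position pos."""
    half = vec.shape[-1] // 2
    angle = pos * base ** (-np.arange(half) / half)
    x1, x2 = vec[:half], vec[half:]
    return np.concatenate([x1 * np.cos(angle) - x2 * np.sin(angle),
                           x1 * np.sin(angle) + x2 * np.cos(angle)])

rng = np.random.default_rng(1)
d = 8
wq, wk, wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
x = rng.standard_normal((4, d))
x_plus = np.vstack([x, rng.standard_normal((1, d))])

# Encoder-style (bidirectional): appending a token changes the earlier outputs.
bi_same = np.allclose(attention(x, wq, wk, wv, False), attention(x_plus, wq, wk, wv, False)[:4])
# Decoder-style (causal): the earlier outputs are untouched.
causal_same = np.allclose(attention(x, wq, wk, wv, True), attention(x_plus, wq, wk, wv, True)[:4])
print(bi_same, causal_same)  # expected: False True

# The same key vector rotated for positions 3 and 7 yields different cached keys.
k = x[0] @ wk
print(np.allclose(rope_rotate(k, 3), rope_rotate(k, 7)))  # False
```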
0x03 Implementation
The source code of GPT2, Baichuan2, and LLaMA shows that the core code of the KV Cache is implemented in just a few lines and is not complicated.
3.1 Overall Approach
The basic idea of KV Cache is as follows:
During inference, the model repeatedly reads and updates past_key_values. On the first forward pass, past_key_values is empty and must be initialized: the entire prompt is fed in at once, and the keys and values computed for every position are stored in past_key_values.
From the second pass on, only the most recent token is fed in. It is projected to Q, K, and V on its own; the cached K and V in past_key_values are concatenated with the new token's K and V to form the complete Key and Value matrices; and attention is computed between them and the new token's Query. The concatenated Key and Value are written back to past_key_values.
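The two phases can be sketched in numpy for a single attention head. This is a minimal illustration under stated assumptions: random projection weights, no multi-head split, and random vectors standing in for the embeddings of generated tokens (a real model would embed the tokens it samples). attend() and the other names are hypothetical. The check at the end confirms that prefill plus cached decoding gives the same outputs as recomputing full causal attention over the whole sequence.

```python
import numpy as np

def attend(q, k, v):
    """Causal softmax attention of m queries over n keys/values (queries are the last m positions)."""
    m, n = q.shape[0], k.shape[0]
    scores = q @ k.T / np.sqrt(k.shape[-1])
    pos_q = np.arange(n - m, n)[:, None]  # absolute positions of the queries
    pos_k = np.arange(n)[None, :]
    scores = np.where(pos_k > pos_q, -np.inf, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

rng = np.random.default_rng(0)
d = 16
wq, wk, wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
prompt = rng.standard_normal((6, d))
new_tokens = rng.standard_normal((4, d))  # stand-ins for embeddings of generated tokens

# Prefill: process the whole prompt once and initialize the cache.
past_k, past_v = prompt @ wk, prompt @ wv
outputs = [attend(prompt @ wq, past_k, past_v)[-1]]

# Decode: feed only the newest token, append its K/V, attend with its Q.
for tok in new_tokens:
    tok = tok[None, :]
    past_k = np.concatenate([past_k, tok @ wk])
    past_v = np.concatenate([past_v, tok @ wv])
    outputs.append(attend(tok @ wq, past_k, past_v)[0])

# Reference: recompute full causal attention over the whole sequence.
full = np.vstack([prompt, new_tokens])
reference = attend(full @ wq, full @ wk, full @ wv)
print(np.allclose(np.array(outputs), reference[5:]))  # True
```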

The flowchart of the KV-Cache code implementation is as follows. As you can see, the content of the KV-Cache comes from two sources:
- The input prompt;
- The generated tokens.

In addition, the KV cache is read and written at high frequency and is very large, so it needs efficient management, for example multi-level memory pools. The concrete KV cache layout also differs by attention variant (MHA, GQA, MLA, DoubleSparse), and these variants should be isolated from one another. For example, a first-level memory pool records high-level information that is independent of any specific variant, such as which token slots each request uses, while the variant-specific KV caches (MHA, MLA, DoubleSparse) live in a second-level memory pool.
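A two-level pool of this kind might look roughly like the following sketch. The class names TokenSlotPool and MHAKVStore are hypothetical and not taken from any particular framework: the first level only maps requests to token slots, while the second level owns the variant-specific K/V buffers (here an MHA/GQA layout with separate K and V per layer).

```python
import numpy as np

class TokenSlotPool:
    """First level (hypothetical): tracks which KV slots each request owns.
    It knows nothing about the KV layout, so every attention variant can share it."""
    def __init__(self, num_slots):
        self.free = list(range(num_slots))
        self.req_to_slots = {}

    def alloc(self, req_id, n):
        if n > len(self.free):
            raise MemoryError("out of KV slots")
        slots, self.free = self.free[:n], self.free[n:]
        self.req_to_slots.setdefault(req_id, []).extend(slots)
        return slots

    def release(self, req_id):
        self.free.extend(self.req_to_slots.pop(req_id))

class MHAKVStore:
    """Second level (hypothetical): per-layer K and V buffers for MHA/GQA.
    An MLA variant would instead store one compressed latent vector per slot."""
    def __init__(self, num_slots, num_layers, num_kv_heads, head_dim):
        shape = (num_layers, num_slots, num_kv_heads, head_dim)
        self.k = np.zeros(shape, dtype=np.float16)
        self.v = np.zeros(shape, dtype=np.float16)

    def write(self, layer, slots, k, v):
        self.k[layer, slots] = k
        self.v[layer, slots] = v

    def read(self, layer, slots):
        return self.k[layer, slots], self.v[layer, slots]

slots_pool = TokenSlotPool(num_slots=16)
store = MHAKVStore(num_slots=16, num_layers=2, num_kv_heads=2, head_dim=4)
s = slots_pool.alloc("req-A", 3)
store.write(0, s, np.ones((3, 2, 4)), np.zeros((3, 2, 4)))
k, v = store.read(0, slots_pool.req_to_slots["req-A"])
print(k.shape, len(slots_pool.free))  # (3, 2, 4) 13
slots_pool.release("req-A")
print(len(slots_pool.free))  # 16
```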
3.2 Storage Structure
3.2.1 llama3
Let’s take llama3 as an example to look at the storage structure of KV Cache.
Below are the member variables of the Attention class. Since every TransformerBlock has its own Attention module, these buffers belong to a single layer.
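```python
# In Attention.__init__: one preallocated K buffer and one V buffer per layer,
# shape [max_batch_size, max_seq_len, n_local_kv_heads, head_dim]
self.cache_k = torch.zeros(
    (
        args.max_batch_size,
        args.max_seq_len,
        self.n_local_kv_heads,
        self.head_dim,
    )
).cuda()
self.cache_v = torch.zeros(
    (
        args.max_batch_size,
        args.max_seq_len,
        self.n_local_kv_heads,
        self.head_dim,
    )
).cuda()
```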
3.2.2 Transformer Library
For comparison, let's look at the HuggingFace Transformers library.
Within each layer, the key and value vectors of every head are stored in memory. HuggingFace's implementation keeps them in the past_key_values variable, which can be viewed as a six-dimensional tensor of shape [n, 2, b, h, s, d]. The dimensions mean:
- The first dimension, num_layers (n), indexes the stacked blocks. For example, with 12 stacked blocks there are 12 sets of key and value information.
- The second dimension, 2, distinguishes the two objects: index 0 is the Key and index 1 is the Value.
- The third dimension, batch_size (b), is the number of sequences processed together. For a single text, b=1.
- The fourth dimension, num_heads (h), is the number of attention heads. For example, with 12 heads per layer, h=12.
- The fifth dimension, seq_len (s), is the sequence length so far: one entry for each historical token position, in every layer and head.
- The sixth dimension, d, is the per-head dimension of the Key and Value vectors. If the token's total hidden size is 768 and there are 12 attention heads, then d = 768/12 = 64.

The structure of past_key_values is shown in the diagram above. As decoding proceeds, past_key_values is updated at every step. Two consecutive versions of past_key_values differ only in the seq_len dimension, which grows by 1 because the Key and Value of the newly generated token are concatenated onto the previous past_key_values along that dimension. Apart from this new position, the two versions are identical.
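The growth along seq_len can be mimicked with numpy arrays standing in for tensors. The sketch below assumes the older tuple-of-tuples layout (one (K, V) pair per layer, each [b, h, s, d]); new_kv() and append_step() are hypothetical helpers, and random values stand in for real keys and values.

```python
import numpy as np

n_layers, b, h, d = 2, 1, 3, 4
rng = np.random.default_rng(0)

def new_kv(seq_len):
    """Random stand-in for one layer's K (or V) tensor of shape [b, h, seq_len, d]."""
    return rng.standard_normal((b, h, seq_len, d))

# After prefilling a 5-token prompt: one (K, V) pair per layer.
past_key_values = tuple((new_kv(5), new_kv(5)) for _ in range(n_layers))

def append_step(past):
    """One decode step: concatenate the new token's K and V along seq_len (axis 2)."""
    return tuple(
        (np.concatenate([k, new_kv(1)], axis=2), np.concatenate([v, new_kv(1)], axis=2))
        for k, v in past
    )

next_pkv = append_step(past_key_values)
stacked = np.array([[k, v] for k, v in next_pkv])  # view as [n, 2, b, h, s, d]
print(stacked.shape)  # (2, 2, 1, 3, 6, 4)
unchanged = all(
    np.array_equal(old[i], new[i][:, :, :5])
    for old, new in zip(past_key_values, next_pkv)
    for i in (0, 1)
)
print(unchanged)  # True
```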
The Huggingface Transformer library abstracts the concept of cache and implements various caches. Examples of the main caches are as follows:
- DynamicCache: a cache that grows as more tokens are generated. It stores the key-value states as lists of tensors, one tensor per layer, each with the expected shape [batch_size, num_heads, seq_len, head_dim].
- StaticCache: a cache with preallocated, fixed-size buffers, designed to be used with torch.compile(model).
- SinkCache: implements the cache described in the Attention Sinks paper. It lets the model generate tokens beyond the length of its context window without losing fluency. Because it discards past tokens, the model loses the ability to generate tokens that depend on the discarded context. It stores key-value states as lists of tensors, one tensor per layer, each with the expected shape [batch_size, num_heads, seq_len, head_dim].
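The eviction rule behind attention sinks, keeping the first few tokens plus a sliding window of recent ones, can be sketched as below. This is a simplified illustration, not transformers' SinkCache: sink_evict() is a hypothetical helper, and it omits the positional re-encoding (for example re-rotating the kept keys under RoPE) that a real implementation may need.

```python
import numpy as np

def sink_evict(k, v, num_sink, window):
    """Keep the first `num_sink` positions plus the most recent `window` positions.
    k, v: [batch, heads, seq_len, head_dim]."""
    seq_len = k.shape[2]
    if seq_len <= num_sink + window:
        return k, v
    keep = np.r_[0:num_sink, seq_len - window:seq_len]
    return k[:, :, keep], v[:, :, keep]

k = np.arange(10, dtype=float).reshape(1, 1, 10, 1)  # position i stores the value i
k2, _ = sink_evict(k, k.copy(), num_sink=2, window=4)
print(k2[0, 0, :, 0])  # [0. 1. 6. 7. 8. 9.]
```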
Let’s take StaticCache as an example to see its specific data structure.
```python
past_key_values = StaticCache(
    model.config,
    batch_size=batch_size,
    device=device,
    dtype=torch.float16,
    max_cache_len=seq_length + num_tokens_to_generate,
)
```
As the code below shows, each per-layer key or value buffer has shape cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim). The constructor allocates one key buffer (new_layer_key_cache) and one value buffer (new_layer_value_cache) per layer, for num_hidden_layers layers in total.
````python
# Excerpt from transformers/cache_utils.py (the exact code varies across versions).
# Imports added for context.
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from transformers import PretrainedConfig
from transformers.cache_utils import Cache
from transformers.utils import is_torchdynamo_compiling


class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        batch_size (`int`):
            The batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `str`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
        layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
            Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
            You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

        >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        StaticCache()
        ```
    """

    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: torch.device = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        super().__init__()
        self.batch_size = batch_size or max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )
        self.dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None
            else config.num_key_value_heads
        )

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for idx in range(config.num_hidden_layers):
            if layer_device_map is not None:
                layer_device = layer_device_map[idx]
            else:
                layer_device = device
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            if not is_torchdynamo_compiling():
                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know how where to write in the cache.

        Return:
            A tuple containing the updated key and value states.
        """
        cache_position = cache_kwargs.get("cache_position")

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
            # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
            # operation, that avoids copies and uses less memory.
            try:
                k_out.index_copy_(2, cache_position, key_states)
                v_out.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
                k_out[:, :, cache_position] = key_states
                v_out[:, :, cache_position] = value_states

        return k_out, v_out
````
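The essence of StaticCache.update, writing new key/value states in place into a preallocated buffer at cache_position, can be reproduced with numpy. ToyStaticCache is a hypothetical stand-in; numpy's fancy-index assignment plays the role of index_copy_ along dim 2. Note that update returns the full-length buffer, so the attention mask must hide the slots that are still empty.

```python
import numpy as np

class ToyStaticCache:
    """Hypothetical numpy stand-in: one preallocated [batch, heads, max_len, head_dim]
    buffer per layer, written in place at `cache_position` (like index_copy_ on dim 2)."""
    def __init__(self, num_layers, batch, heads, max_len, head_dim):
        shape = (batch, heads, max_len, head_dim)
        self.key_cache = [np.zeros(shape) for _ in range(num_layers)]
        self.value_cache = [np.zeros(shape) for _ in range(num_layers)]

    def update(self, key_states, value_states, layer_idx, cache_position):
        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
        k_out[:, :, cache_position] = key_states  # in place, no reallocation
        v_out[:, :, cache_position] = value_states
        return k_out, v_out  # full-length buffers; unused slots stay zero

cache = ToyStaticCache(num_layers=1, batch=1, heads=2, max_len=8, head_dim=4)
prefill = np.ones((1, 2, 3, 4))
cache.update(prefill, prefill, 0, np.arange(3))  # prompt -> slots 0..2
k, _ = cache.update(2 * np.ones((1, 2, 1, 4)), np.ones((1, 2, 1, 4)), 0, np.array([3]))
print(k.shape, k[0, 0, :, 0])  # (1, 2, 8, 4) [1. 1. 1. 2. 0. 0. 0. 0.]
```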
3.3 How to use
We'll use LLaMA3 as an example to illustrate how the KV cache is used. With the KV cache enabled, the model maintains a list of tensor pairs (one key tensor and one value tensor per pair), with as many pairs as there are decoder blocks (often called decoder layers, denoted n_layers). For each token of each sequence in the batch, every attention head has a key vector and a value vector of dimension d_head, so each key/value tensor has shape (batch_size, seq_length, n_heads, d_head); with grouped-query attention, n_heads here is the number of KV heads.
The caching works as follows:
- In the first iteration, the key and value vectors of all prompt tokens are computed and saved to the KV cache.
- In each subsequent iteration, only the key and value vectors of the newest token are computed. The rest are read from the cache and concatenated with the new ones to form the K and V matrices, and the new vectors are also written to the cache. This avoids recomputing the key and value vectors of all previous tokens and greatly improves efficiency.
```python
import math
from typing import Optional

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from torch import nn

# ModelArgs, apply_rotary_emb and repeat_kv are defined elsewhere in llama3's model.py.


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # Initialize the KV cache: [max_batch_size, max_seq_len, n_local_kv_heads, head_dim]
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Write the current tokens' K/V into the KV cache, then read back all K/V so far
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
```
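The start_pos bookkeeping and the GQA head repetition in forward() can be traced with a numpy sketch. It assumes 8 query heads sharing 2 KV heads; write_and_read() is a hypothetical helper, and np.repeat along the head axis is used here to mirror what repeat_kv does (each KV head repeated n_rep times in a row).

```python
import numpy as np

def repeat_kv(x, n_rep):
    """[bs, seq, n_kv_heads, head_dim] -> [bs, seq, n_kv_heads * n_rep, head_dim],
    repeating each KV head n_rep times in a row."""
    return x if n_rep == 1 else np.repeat(x, n_rep, axis=2)

max_bs, max_seq, n_kv_heads, n_heads, head_dim = 2, 16, 2, 8, 4
cache_k = np.zeros((max_bs, max_seq, n_kv_heads, head_dim))  # preallocated, like self.cache_k

def write_and_read(xk, start_pos):
    """Mimic forward(): write [start_pos, start_pos + seqlen), read [0, start_pos + seqlen)."""
    bsz, seqlen = xk.shape[:2]
    cache_k[:bsz, start_pos:start_pos + seqlen] = xk
    keys = cache_k[:bsz, :start_pos + seqlen]
    return repeat_kv(keys, n_heads // n_kv_heads)

rng = np.random.default_rng(0)
prefill_keys = write_and_read(rng.standard_normal((1, 5, n_kv_heads, head_dim)), start_pos=0)
decode_keys = write_and_read(rng.standard_normal((1, 1, n_kv_heads, head_dim)), start_pos=5)
print(prefill_keys.shape, decode_keys.shape)  # (1, 5, 8, 4) (1, 6, 8, 4)
```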
0x04 Resource Usage
4.1 Dimensional Changes
The figure below illustrates the Transformer architecture, the inputs, outputs, and the shapes of the weight tensors for various operations. Assume the input is a tensor X of shape [B, L, H], where B represents the batch size, L represents the sequence length for each request (i.e., the number of input tokens in a given query), and H is the model’s embedding size.

Ignoring the split into multiple heads, the shapes during the Transformer's prefill phase change as follows:
- Preprocessing stage: mainly the preproj module. X is multiplied by the weight matrices $W_Q$, $W_K$, and $W_V$, each of shape [H, H], producing Q, K, and V, each of shape [B, L, H]. Characteristics: preproj must read the model weights from GPU memory, the amount of weight data read does not depend on the sequence length, and each token is transformed independently (a linear map along the hidden_size dimension).
- The attention computation phase mainly consists of the self-attention module and the postproj module.
- Self-attention: computes attention from Q, K, and V and outputs a tensor Y of shape [B, L, H]. Characteristics: it does not read model weights from GPU memory, only the precomputed Q, K, and V, and it relies on a mask matrix that differs from sequence to sequence.
- postproj: uses $W_O$ to map the attention output Y to a tensor Z of shape [B, L, H]. Its characteristics are the same as those of preproj.
- The FFN phase. The FFN module performs two batched matrix multiplications. In ffn_ln1, Z is multiplied by a weight tensor of shape [H, H2], producing a tensor of shape [B, L, H2]; in ffn_ln2, this is multiplied by a weight tensor of shape [H2, H], producing a tensor of shape [B, L, H]. Here H2 is the model's second (intermediate) hidden dimension. Both ffn_ln1 and ffn_ln2 have the same characteristics as preproj.
The decoding phase performs the same operations as prefill, but only on the single token generated in the previous autoregressive iteration. The input tensor in the decoding phase therefore has shape [B, 1, H] instead of [B, L, H] as in prefill.
- Preprocessing stage: Q, K, and V are all [B, 1, H]; the K and V of each token have shape [1, H].
- Attention calculation phase: the K and V read from the KV cache have shape [B, prev_kv_seq_len, H]; after concatenation with the current K and V they become [B, prev_kv_seq_len + 1, H]. $QK^\top$ is [B, 1, H] × [B, H, prev_kv_seq_len + 1] → [B, 1, prev_kv_seq_len + 1], and multiplying the scores by V is [B, 1, prev_kv_seq_len + 1] × [B, prev_kv_seq_len + 1, H] → [B, 1, H].
- FFN phase. Output is [B, 1, H].
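The shape changes in both phases can be traced with a single-head numpy sketch. The sizes B, L, H, H2 are arbitrary assumptions, and layer() is a hypothetical helper that omits the causal mask, normalization, residuals and the FFN activation because it only tracks shapes.

```python
import numpy as np

B, L, H, H2 = 2, 7, 16, 64
rng = np.random.default_rng(0)
w_q, w_k, w_v, w_o = (rng.standard_normal((H, H)) / np.sqrt(H) for _ in range(4))
w_ffn1 = rng.standard_normal((H, H2)) / np.sqrt(H)   # ffn_ln1: [H, H2]
w_ffn2 = rng.standard_normal((H2, H)) / np.sqrt(H2)  # ffn_ln2: [H2, H]

def layer(x, past_k=None, past_v=None):
    """Single-head layer that tracks shapes (no mask, norms, residuals or activation)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v              # preproj: each [B, n, H]
    if past_k is not None:                           # decode: prepend cached K/V
        k = np.concatenate([past_k, k], axis=1)      # [B, prev + 1, H]
        v = np.concatenate([past_v, v], axis=1)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(H)   # [B, n, kv_len]
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    y = (p / p.sum(axis=-1, keepdims=True)) @ v      # [B, n, H]
    z = y @ w_o                                      # postproj: [B, n, H]
    out = (z @ w_ffn1) @ w_ffn2                      # [B, n, H2] -> [B, n, H]
    return out, k, v, scores.shape

_, k_cache, v_cache, prefill_scores = layer(rng.standard_normal((B, L, H)))
out, _, _, decode_scores = layer(rng.standard_normal((B, 1, H)), k_cache, v_cache)
print(prefill_scores, decode_scores, out.shape)  # (2, 7, 7) (2, 1, 8) (2, 1, 16)
```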
The analysis above shows that the memory traffic of the attention operator is dominated by the KV sequence length, while its arithmetic intensity (computation per byte moved) grows with the Q sequence length. In the prefill stage the Q sequence is long and the attention operator is compute-bound; in the decode stage the Q length is 1 and the attention operator is memory-bound.
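A back-of-the-envelope estimate makes the contrast concrete. The helper below is an assumption-laden sketch: it counts 4 × B × L_q × L_kv × H FLOPs for $QK^\top$ plus the product with V, ignores softmax, and assumes Q, K, V are read and the output written exactly once in fp16. Under those assumptions the intensity is about L/2 FLOPs per byte in prefill but about 1 in decode.

```python
def attention_intensity(lq, lkv, hidden, batch=1, bytes_per_elem=2):
    """FLOPs per byte for one attention op (QK^T plus scores x V), ignoring softmax.
    Memory traffic counts reading Q, K, V and writing the output once."""
    flops = 4 * batch * lq * lkv * hidden
    bytes_moved = bytes_per_elem * batch * hidden * (2 * lq + 2 * lkv)
    return flops / bytes_moved

prefill = attention_intensity(lq=4096, lkv=4096, hidden=4096)
decode = attention_intensity(lq=1, lkv=4096, hidden=4096)
print(round(prefill), round(decode, 3))  # 2048 1.0
```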
4.2 Storage Capacity
4.2.1 Single layer
The per-token size of the KV cache is fixed by the model configuration, so the total KV cache size of a single layer can be expressed as
$$\text{Size}_{\text{layer}} = 2 \times B \times L \times H \times D \times P$$
where:
- 2 accounts for the two tensors, Key and Value, stored in each layer.
- B is the batch size.
- L is the total sequence length, i.e., input plus output (prompt plus completion).
- H is the number of (key/value) attention heads.
- D is the head size, i.e., the dimension of each head.
- P is the number of bytes needed to store one KV cache element in its data format; for example, fp16 takes 2 bytes.
4.2.2 Multiple layers
If N is the number of blocks, i.e., the model depth, the total KV cache storage required for one model is
$$\text{Size}_{\text{total}} = 2 \times N \times B \times L \times H \times D \times P$$
4.2.3 Actual Example
Assuming a 100K-token context (100,000 tokens), 60 layers, 8 KV heads, a head dimension of 128, bf16 storage (2 bytes), and a batch size of 1, the KV cache size is
$$2 \times 60 \times 1 \times 100{,}000 \times 8 \times 128 \times 2 = 24{,}576{,}000{,}000\ \text{bytes} \approx 24.6\ \text{GB}\ (\approx 22.9\ \text{GiB})$$
Taking LLaMA-7B as an example: loading the model takes 14 GB of GPU memory, its hidden dimension is 4096, it stacks 32 layers, and its maximum context length is 4096. For a batch of 2 sequences of length 4096, the KV cache occupies 2 (K, V) × 2 (bytes, fp16) × 32 (layers) × 4096 (hidden) × 2 (batch) × 4096 (tokens) = 4,294,967,296 bytes, i.e., 4 GiB. As the batch size and sequence length grow, the KV cache can exceed the size of the model itself. For example, with a batch size of 8 in LLaMA 2 70B, if input plus output tokens reach the model's 4096-token limit, the 80-layer KV cache needs 2 (K, V) × 80 × 8192 × 4096 × 8 × 2 B = 80 GiB. With a larger batch, the KV cache exceeds the 140 GB occupied by the parameters themselves. (This estimate treats K and V as full 8192-wide tensors; LLaMA 2 70B actually uses grouped-query attention with 8 KV heads, which shrinks the cache by a factor of 8.)
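The storage formula and the worked examples can be checked with a few lines of Python. kv_cache_bytes() is a hypothetical helper implementing 2 × layers × batch × tokens × KV heads × head dim × bytes per element; for LLaMA-7B and LLaMA 2 70B the hidden size is written as heads × head dim (32 × 128 and 64 × 128), and the last line assumes LLaMA 2 70B's 8 grouped-query KV heads.

```python
def kv_cache_bytes(n_layers, batch, seq_len, n_kv_heads, head_dim, bytes_per_elem):
    """2 (K and V) x layers x batch x tokens x KV heads x head dim x bytes per element."""
    return 2 * n_layers * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem

GiB = 2**30
# 100K-token context, 60 layers, 8 KV heads, head_dim 128, bf16, batch 1 (assumed)
ex1 = kv_cache_bytes(60, 1, 100_000, 8, 128, 2)
# LLaMA-7B: 32 layers, hidden 4096 (32 heads x 128), batch 2, 4096 tokens, fp16
ex2 = kv_cache_bytes(32, 2, 4096, 32, 128, 2)
# LLaMA 2 70B treated as full-width K/V: 80 layers, hidden 8192 (64 x 128), batch 8
ex3 = kv_cache_bytes(80, 8, 4096, 64, 128, 2)
# Same model with its 8 grouped-query KV heads
ex4 = kv_cache_bytes(80, 8, 4096, 8, 128, 2)
print(ex1, ex2 / GiB, ex3 / GiB, ex4 / GiB)  # 24576000000 4.0 80.0 10.0
```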
4.2.4 Storage Implementation
The KV cache size is proportional to the current number of tokens, the vector dimension, and the number of layers. The number of tokens is the troublesome factor, because it keeps growing during inference. Storing variable-length data is always awkward, and there are essentially three approaches:
- Preallocate a buffer of maximum capacity. This requires knowing the maximum number of tokens in advance, and allocating for the maximum is very wasteful because most requests never reach it.
- Grow the buffer dynamically, like a classic vector append that doubles its capacity when full. This works, but on GPUs the overhead of frequently allocating and freeing memory is significant, so efficiency is low.
- Split the data into small fixed-size units and use metadata to record where each piece is stored.
The last approach, now the most widely used, is called PagedAttention. At initialization, the program allocates one large region of GPU memory (e.g., 4 GB), splits it into small blocks that each hold the KV cache of a fixed number of tokens, and records which block each token uses during inference. Allocating, freeing, and managing these blocks resembles virtual memory paging in an operating system. This is the well-known idea behind vLLM (see the paper Efficient Memory Management for Large Language Model Serving with PagedAttention for details).
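The bookkeeping side of PagedAttention can be sketched without any GPU code. PagedKVAllocator is a hypothetical class, not vLLM's API: memory is assumed to be pre-split into num_blocks blocks of block_size token slots, each sequence keeps a block table, and a new block is taken from the free list only when the current one is full, so at most block_size - 1 slots per sequence are wasted.

```python
class PagedKVAllocator:
    """Hypothetical sketch of PagedAttention-style bookkeeping (not vLLM's API)."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.lengths = {}       # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve a slot for one more token; return (physical_block, offset)."""
        n = self.lengths.get(seq_id, 0)
        table = self.block_tables.setdefault(seq_id, [])
        if n % self.block_size == 0:  # current block is full (or none allocated yet)
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop())
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size

    def free(self, seq_id):
        """Return all of a finished sequence's blocks to the free list."""
        self.free_blocks.extend(self.block_tables.pop(seq_id))
        self.lengths.pop(seq_id)

alloc = PagedKVAllocator(num_blocks=8, block_size=4)
slots = [alloc.append_token("a") for _ in range(6)]  # 6 tokens -> 2 blocks
print(len(alloc.block_tables["a"]), slots[-1][1], len(alloc.free_blocks))  # 2 1 6
alloc.free("a")
print(len(alloc.free_blocks))  # 8
```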
4.3 Computational complexity
The figure below shows the computation, data transfer, and arithmetic intensity of the prefill stage. We use the asymptotic notation O for the data-transfer volume; its constant factor depends on the specific implementation.

The figure below shows the computation, data transfer, and arithmetic intensity of the decoding stage.

In the prefill stage we must compute Attn(Q, K, V) and also fill the KV cache, so the KV cache does not reduce the computation there; the savings appear in the decoding stage. The KV cache mainly saves the following two parts:
- The K and V of the previous n-1 tokens are cached and do not need to be recomputed.
- FFN: since each step only processes the newest token to produce its logits, this computation shrinks as well.
Let’s take a look at the specific execution process.
4.3.1 Table lookup
Although the embedding lookup costs little computation, a KV cache eliminates the lookups for the first t+N-1 tokens.
4.3.2 Calculation of Q, K, V
Calculating the key or value vector of a given token per head simply multiplies its embedding vector of size d_model by a weight matrix of shape (d_model, d_head).
For a single inference step (batch size b, current sequence length s, hidden size h = d_model, per layer):
- Without a cache, Q, K, and V are computed for all s tokens, costing $3 \times 2bsh^2 = 6bsh^2$ FLOPs.
- With a KV cache, only the single new token is projected, costing $6bh^2$ FLOPs.
4.3.3 Attention
During the decoding phase, we need to add an output (token) to the original sequence. Since the previous kv results can be reused, we only need to compute Decode: Attn(q, K, V). Here, the length of q is 1, while the lengths of the sequences K=[k_cache, k] and V=[v_cache, v] are greater than 1. That is, after using the KV Cache, all matrix multiplication operations in Multi-Head Attention are reduced to matrix-vector multiplication.
For a single inference step (per layer):
- Without a cache, the attention computation ($QK^\top$ and the product with V) costs $2 \times 2bs^2h = 4bs^2h$ FLOPs.
- With a KV cache, the query is a single token, so the attention computation costs $4bsh$ FLOPs.
4.3.4 MLP
In the FFN, tokens do not interact: each token is processed independently. The FFN outputs of previous tokens therefore need not be cached during decoding, because they only feed those tokens' K and V in the next layer, which the KV cache already stores. The matrix-matrix products also reduce to matrix-vector products. For a single inference step (per layer, FFN width 4h):
- Without a cache, the cost is $2 \times 2bs \cdot h \cdot 4h = 16bsh^2$ FLOPs.
- With a KV cache, only the new token is processed, costing $16bh^2$ FLOPs.
4.3.5 Comparison
Without KV cache
The computation of each Transformer layer is approximately $24bsh^2 + 4bs^2h$ FLOPs. Details are as follows.
| Module | Operation | Output | Output shape | FLOPs |
|---|---|---|---|---|
| Embedding | Table lookup | X | [b, s, h] | - |
| Attention | Compute Q, K, V | Q, K, V | [b, s, h] | $6bsh^2$ |
| Attention | $QK^\top$ | Attention scores | [b, head_num, s, s] | $2bs^2h$ |
| Attention | Multiply by V | Attention output | [b, head_num, s, head_dim] | $2bs^2h$ |
| Attention | Post-attention linear projection | Attention output | [b, s, h] | $2bsh^2$ |
| FFN | First linear layer | Intermediate state | [b, s, 4h] | $8bsh^2$ |
| FFN | Second linear layer | Z | [b, s, h] | $8bsh^2$ |
With KV cache
With a KV cache, the computation of each Transformer layer is approximately $24bh^2 + 4bsh$ FLOPs, as detailed below.
| Module | Operation | Output | Output shape | FLOPs |
|---|---|---|---|---|
| Embedding | Table lookup | X | [b, 1, h] | - |
| Attention | Compute Q, K, V | Q, K, V | [b, 1, h] | $6bh^2$ |
| Attention | $QK^\top$ | Attention scores | [b, head_num, 1, prev_kv_seq_len + 1], approximately [b, head_num, 1, s] | $2bsh$ |
| Attention | Multiply by V | Attention output | [b, head_num, 1, head_dim] | $2bsh$ |
| Attention | Post-attention linear projection | Attention output | [b, 1, h] | $2bh^2$ |
| FFN | First linear layer | Intermediate state | [b, 1, 4h] | $8bh^2$ |
| FFN | Second linear layer | Z | [b, 1, h] | $8bh^2$ |
As the tables show, the per-layer cost of a single decoding step drops by roughly a factor of s: terms quadratic in the sequence length become linear, and terms linear in it become constant.
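The per-layer totals in the two tables can be reproduced with a small FLOP counter. layer_flops() is a hypothetical helper that counts 2 FLOPs per multiply-add, assumes an FFN width of 4h, and ignores softmax, normalization and other non-matmul work, matching the tables' conventions.

```python
def layer_flops(b, s, h, kv_cache):
    """Approximate matmul FLOPs of one Transformer layer for one step (FFN width 4h).
    Without a cache all s tokens are processed; with a cache only the new one is."""
    n = 1 if kv_cache else s            # number of query tokens processed this step
    qkv = 3 * 2 * b * n * h * h         # Q, K, V projections
    attn = 2 * (2 * b * n * s * h)      # QK^T and (scores x V)
    out_proj = 2 * b * n * h * h        # post-attention projection
    ffn = 2 * (2 * b * n * h * 4 * h)   # two linear layers h -> 4h -> h
    return qkv + attn + out_proj + ffn

b, s, h = 1, 4096, 4096
full, cached = layer_flops(b, s, h, False), layer_flops(b, s, h, True)
assert full == 24 * b * s * h**2 + 4 * b * s**2 * h
assert cached == 24 * b * h**2 + 4 * b * s * h
print(full // cached)  # 4096
```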
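Summary
Suppose we have a batch of b input sequences, each with t input tokens followed by N generated tokens (total length t+N). Counting only the K and V projections and treating step 1 as prefill, without a cache step j recomputes K and V for t+j-1 tokens, whereas with a cache it computes them for just 1 token. Over the first N generation steps, a KV cache therefore saves approximately
$$4\,b\,n_{\text{layers}}\,d_{\text{model}}^2\left[(N-1)\,t + \frac{(N-1)(N-2)}{2}\right] \approx 4\,b\,n_{\text{layers}}\,d_{\text{model}}^2\left(Nt + \frac{N^2}{2}\right)\ \text{FLOPs}.$$
Setting the token count aside, each token whose K and V are reused instead of recomputed saves
$$2 \times 2 \times n_{\text{layers}} \times d_{\text{model}}^2 = 4\,n_{\text{layers}}\,d_{\text{model}}^2\ \text{FLOPs}.$$
In other words, the computation saved by KV caching is proportional to the number of tokens: the longer the text, the larger the savings.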
Taking LLaMA-7B as an example, for a batch of 2 sequences of length 4096, caching reduces the computation of the key and value projections alone by 2 × 2 × 32 × 4096 × 4096 × 4096 × 2 = 17,592,186,044,416 FLOPs. Moreover, the KV cache removes not only the key and value projections of all previous tokens but also their subsequent attention-weight computation, attention output projection, and FFN. In effect, each decoding step costs about as much as a complete forward pass over a single token (plus attention over the cached K and V), which greatly reduces the computational load.
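The LLaMA-7B figure and the N-step savings can be checked in Python. Both helpers are hypothetical and count only the K and V projections at 2 FLOPs per multiply-add; saved_over_generation() sums the per-step savings directly so it can be compared with the closed form (N-1)t + (N-1)(N-2)/2 reused tokens.

```python
def kv_projection_savings(b, n_layers, d_model, n_tokens):
    """FLOPs saved by reusing cached K and V for n_tokens tokens per sequence:
    2 (K and V) x 2 (FLOPs per multiply-add) x n_layers x d_model^2 per token."""
    return 2 * 2 * n_layers * d_model**2 * b * n_tokens

def saved_over_generation(b, n_layers, d_model, t, N):
    """Brute-force sum over N steps: step 1 is prefill (nothing reused); at step j >= 2,
    t + j - 2 earlier tokens have their K/V reused instead of recomputed."""
    return sum(kv_projection_savings(b, n_layers, d_model, t + j - 2) for j in range(2, N + 1))

print(kv_projection_savings(b=2, n_layers=32, d_model=4096, n_tokens=4096))  # 17592186044416

# The brute-force sum matches the closed form on a small example.
b, n_layers, d_model, t, N = 1, 2, 8, 5, 6
closed = 4 * b * n_layers * d_model**2 * ((N - 1) * t + (N - 1) * (N - 2) // 2)
print(saved_over_generation(b, n_layers, d_model, t, N) == closed)  # True
```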
The figure below, from the paper “A Survey on Large Language Model Acceleration based on KV Cache Management,” illustrates the computational savings of the KV cache. For each token, the savings come from avoiding the redundant computation of keys and values in equation (1), of the self-attention results in equation (2), and of the linear transformation in equation (3). The paper omits Transformer operations that are irrelevant to understanding how the KV cache accelerates inference, such as layer norm and positional encoding.

4.4 Summary
We begin with a side-by-side comparison.
| Dimension | No KV Cache | KV Cache |
|---|---|---|
| Computational complexity | Grows with the square of the sequence length | Only the new token is computed |
| GPU memory usage | Storing intermediate results for the complete sequence requires a lot of GPU memory | Caching key/value pairs keeps GPU memory requirements controllable |
| Generation speed | Slow (historical tokens are recomputed) | Fast (only new tokens are computed; the cache is reused) |
| Applicable scenarios | Short sequence generation (<100 tokens) | Long sequence generation (such as API input, video generation) |
Specifically, the advantages of KV Cache are mainly reflected in the following dimensions:
- Reduce redundant computation. In self-attention mechanisms, without a KV cache, the model needs to recalculate the key and value vectors of the entire historical sequence and participate in attention calculations every time a new token is generated, leading to a large amount of redundant computation. The KV cache effectively eliminates the overhead of redundant computation by caching the key and value representations of processed tokens, significantly reducing the computational complexity of inference.
- Improve inference speed. KV Cache caches key and value vectors, allowing the model to calculate the query vector of the current token and perform attention calculations with the cached key and value when generating a new token. Compared to fully computing , degenerating to significantly reduces FLOPs and substantially improves inference speed;
- Reduce computational complexity. The computational complexity of the self-attention mechanism is O(n^2⋅d), where n is the sequence length and d is the vector dimension. Using a KV cache reduces this complexity to O(n⋅d). Compared to full computation , degenerating to significantly reduces FLOPs, thus substantially decreasing the computational load.
- Control memory growth. As the sequence length increases, peak memory consumption grows linearly rather than quadratically, which keeps it under effective control.
- Context processing. The KV cache keeps the keys and values of the complete preceding sequence, so every new token attends to the full context exactly as it would under full recomputation. The model can therefore retrieve historical information accurately, which preserves semantic coherence and stable quality during long text generation.
- Dynamic adaptability. The cache grows with the actual length of the sequence, so the system can size it to the needs of each scenario. This makes it especially suitable for dynamic workloads such as real-time interactive dialogue.
- Cross-request reuse. In some scenarios, the prompts of multiple requests share the same prefix. The keys and values computed for that prefix are then identical across those requests, so they can be cached once and reused by subsequent requests.
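Cross-request reuse can be sketched with a toy prefix cache. This is an illustrative sketch, not how any particular engine works: `PrefixKVCache` and `_compute_kv` are hypothetical, each token's "key/value" is a pair of floats standing in for real tensors, and every prefix is stored separately, which costs quadratic memory. Production systems organize this far more efficiently, for example SGLang's RadixAttention (a radix tree over token sequences) and vLLM's automatic prefix caching (hashed KV blocks). The point shown is that a second request sharing a prefix only computes KV for its new suffix.

```python
from typing import Dict, List, Tuple

KV = Tuple[float, float]  # stand-in for one token's (key, value) tensors


class PrefixKVCache:
    """Toy cross-request cache mapping each cached token prefix to its KV entries."""

    def __init__(self) -> None:
        self.store: Dict[Tuple[int, ...], List[KV]] = {}
        self.computed_tokens = 0  # tokens that actually went through the (fake) projection

    def _compute_kv(self, token_id: int, position: int) -> KV:
        # Hypothetical placeholder for the real K/V projection of one token at one position.
        self.computed_tokens += 1
        return (float(token_id * 31 + position), float(token_id * 17 - position))

    def prefill(self, prompt: List[int]) -> List[KV]:
        # Find the longest prefix of the prompt whose KV entries are already cached.
        hit = 0
        for length in range(len(prompt), 0, -1):
            if tuple(prompt[:length]) in self.store:
                hit = length
                break
        kv = list(self.store[tuple(prompt[:hit])]) if hit else []
        # Only the uncached suffix is computed; each new prefix is stored for later requests.
        for pos in range(hit, len(prompt)):
            kv.append(self._compute_kv(prompt[pos], pos))
            self.store[tuple(prompt[: pos + 1])] = list(kv)
        return kv


cache = PrefixKVCache()
system_prompt = [1, 2, 3, 4]                      # shared prefix, e.g. a system prompt
a = cache.prefill(system_prompt + [10, 11])
first = cache.computed_tokens
b = cache.prefill(system_prompt + [20])
second = cache.computed_tokens - first
print(first, second)  # prints: 6 1
```

The first request computes all 6 tokens; the second reuses the 4 cached prefix entries and computes only 1.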
In summary, KV Cache effectively reduces redundant calculations, lowers computational complexity, and improves inference speed by caching key and value vectors in LLM inference. It also optimizes the use of GPU memory resources, thereby improving the model’s inference efficiency and throughput.
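The linear memory growth mentioned above is easy to estimate: the cache holds one key and one value vector per layer, per KV head, per token. The helper below is a back-of-the-envelope sketch; the configuration (32 layers, 32 KV heads, head dimension 128, fp16) is an assumed 7B-class setting chosen for illustration, and real footprints also depend on batch size, precision, and allocator overhead.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_elem=2):
    """KV cache size: factor 2 for keys and values, for every layer, KV head and token."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


# Assumed 7B-class config: 32 layers, 32 KV heads, head_dim 128, fp16 (2 bytes per element).
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
print(f"{per_token} bytes per token")
for seq_len in (1024, 2048, 4096):
    gib = kv_cache_bytes(32, 32, 128, seq_len) / 2**30
    print(f"{seq_len:5d} tokens -> {gib:.2f} GiB")
```

Under these assumptions each token costs 0.5 MiB, so doubling the sequence length doubles the cache (0.5, 1 and 2 GiB here), which is exactly the linear curve described above.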
0xFF Reference
- Mooncake (1): Making Mooncakes on the Dark Side of the Moon, Kimi’s KVCache-Centric Disaggregated Inference Architecture (ZHANG Mingxing)
- The Taizu Long Fist of Parallel Inference in Large Models: Deciphering Jeff Dean’s Outstanding MLSys ’23 Paper (Fang Jiarui)
- Illustrated Explanation of Mixtral 8x7B Inference Optimization Principles and Source Code Implementation (Mengyuan)
- Illustrated Guide to Accelerating Large Model Computation Series: Disaggregated Inference Architecture 2, Chunked Prefills Blurring the Boundary Between Separation and Merging (Ape)
- Random Thoughts on Mooncake (Xu Xinran)
- Understanding the DeepSpeed Inference Code
- A Brief Analysis of the Llama.cpp Code (Part 1): Parallel Mechanism and KVCache
- As DeepSeek Open-Sources FlashMLA: A Detailed Explanation of MLA from Principles to Code (Du Lingxiao [Tanzhixuan])
- KV Cache Principles and Optimization Overview (Zhang)
- On the Evolution of Large Model Architectures: The Art of Memory (zartbot)
- Illustrated KV Cache: The Key to Unlocking LLM Inference Efficiency (To Great, ChallengeHub)
- Designing SGLang’s KV Cache from Scratch (Wang Yan)
- https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/sglang/kvcache-code-walk-through
- A Survey on Large Language Model Acceleration based on KV Cache Management
- Abridged version of “A Review of LLM Acceleration Research Based on KV Cache Management” (Chang Hua and Andy)