Exploring the Transformer Series (27) --- MQA & GQA
0x00 Overview
As mentioned in the previous article on “Optimizing KV Cache,” the main work on “reducing the number of attention heads” currently consists of MQA and GQA. Both reduce the number of KV values that must be cached. Intuitively, fewer cached KV values means less GPU memory, and the resulting loss of model capacity can be compensated for by further training or by enlarging the FFN/GLU.
Because MQA and GQA are improvements on MHA, the following diagram illustrates the differences among the three. By reducing the number of key/value heads, MQA/GQA shrink KV cache storage: all attention heads (MQA), or all heads within a group (GQA), share a single set of K and V values. Only one (or a few) sets of key/value projection parameters are kept, so there is only one (or a few) copies of the K and V matrices, which significantly reduces memory usage. Furthermore, traditional MHA attention operators are heavily limited by memory bandwidth; MQA, GQA, and the later MLA raise the ratio of computation to memory access (arithmetic intensity), which also greatly improves performance.

Note:
- The complete list of articles is here. It’s estimated to eventually have around 35 articles. This list will be updated after each subsequent article is published. (Cnblogs Exploring Transformer Series: Article List)
- This series is a study and interpretation of papers, blogs, and code, drawing on many articles from online friends, to whom I express my gratitude and will list them in the references. Because there are so many references in this series, there may be omissions in the citations. If the original authors find any omissions, please point them out, and I will add them to the references.
0x01 MHA
Since MQA and GQA are modifications based on MHA, it is necessary for us to review MHA first.
1.1 Concept
MHA (Multi-Head Attention) was proposed in 2017 in the original Transformer paper “Attention Is All You Need”. Its main contribution is to split the attention computation into several smaller attention heads: Q, K, and V are each divided into multiple parts, and each head computes attention with its own Q, K, and V. The heads can be computed in parallel, and their results are finally concatenated back to the original dimension.
Let’s look at the MHA process through the following diagram, where $d$ is the word-embedding dimension, $n_h$ is the number of attention heads, $d_h$ is the dimension of each head, $h_t \in \mathbb{R}^d$ is the input of the $t$-th token to an attention layer, and $W^O \in \mathbb{R}^{d \times d_h n_h}$ is the output projection matrix. MHA can then be divided into the following four steps:
- Three projection matrices $W^Q, W^K, W^V \in \mathbb{R}^{d_h n_h \times d}$ map $h_t$ to $q_t, k_t, v_t \in \mathbb{R}^{d_h n_h}$.
- $q_t, k_t, v_t$ are each split into $n_h$ vectors $q_{t,i}, k_{t,i}, v_{t,i} \in \mathbb{R}^{d_h}$, where $i$ indexes the head. These split vectors are called the Q head, K head, and V head.
- Each attention head performs the attention calculation with its own Q, K, and V vectors.
- The outputs of all heads are concatenated and projected back to dimension $d$ by $W^O$.
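For reference, the four steps can be written out as equations in the notation above. This is a sketch: $o_{t,i}$ and $u_t$ are names introduced here, and causal masking is expressed by summing only over positions $j \le t$.

```latex
\begin{aligned}
q_t &= W^Q h_t, \qquad k_t = W^K h_t, \qquad v_t = W^V h_t \\
[q_{t,1}; \dots; q_{t,n_h}] &= q_t, \qquad [k_{t,1}; \dots; k_{t,n_h}] = k_t, \qquad [v_{t,1}; \dots; v_{t,n_h}] = v_t \\
o_{t,i} &= \sum_{j=1}^{t} \operatorname{Softmax}_j\!\left(\frac{q_{t,i}^{\top} k_{j,i}}{\sqrt{d_h}}\right) v_{j,i} \\
u_t &= W^O \, [o_{t,1}; o_{t,2}; \dots; o_{t,n_h}]
\end{aligned}
```

During decoding, the $k_{j,i}$ and $v_{j,i}$ for all earlier positions $j$ are exactly what the KV cache stores: $2 n_h d_h$ values per token per layer.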

1.2 Implementation
1.2.1 Harvard
Let’s review the implementation of the MHA code in “The Annotated Transformer”.
```python
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def clones(module, N):
    "Produce N identical layers (helper from The Annotated Transformer)."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        '''
        h: head number
        '''
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d
        self.d = d_model // h
        self.h = h
        # Q, K, V projections plus the output projection
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        # 1) Do all the linear projections in batch from d_model => h x d
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]
        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)
        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d)
        return self.linears[-1](x)
```
1.2.2 llm-foundry
For comparison, let’s look at an industrial implementation: MosaicML’s llm-foundry.
```python
import math
from typing import Optional

import torch.nn as nn

# LPLayerNorm, flash_attn_fn, triton_flash_attn_fn and
# scaled_multihead_dot_product_attention are defined elsewhere in llm-foundry.


class MultiheadAttention(nn.Module):
    """Multi-head self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = 'triton',
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        self.softmax_scale = softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
        self.attn_dropout_p = attn_pdrop
        # One fused projection produces Q, K and V, each d_model wide
        self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
        # for param init fn; enables shape based init of fused layers
        fuse_splits = (d_model, 2 * d_model)
        self.Wqkv._fused = (0, fuse_splits)  # type: ignore
        if self.qk_ln:
            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = layernorm_class(self.d_model, device=device)
            self.k_ln = layernorm_class(self.d_model, device=device)
        if self.attn_impl == 'flash':
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == 'triton':
            self.attn_fn = triton_flash_attn_fn
        elif self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
        else:
            raise ValueError(f'{attn_impl=} is an invalid setting.')
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True  # type: ignore

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        # MHA: split the fused output into three equal d_model-wide parts
        query, key, value = qkv.chunk(3, dim=2)
        key_padding_mask = attention_mask
        if self.qk_ln:
            # Applying layernorm to qk
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
        )
        return self.out_proj(context), attn_weights, past_key_value
```
The code for scaled_multihead_dot_product_attention() is as follows.
```python
import math

import torch
from einops import rearrange


def scaled_multihead_dot_product_attention(
    query,
    key,
    value,
    n_heads,
    past_key_value=None,
    softmax_scale=None,
    attn_bias=None,
    key_padding_mask=None,
    is_causal=False,
    dropout_p=0.0,
    training=False,
    needs_weights=False,
    multiquery=False,
):
    q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
    # MQA keeps a single K/V head; MHA splits K/V into n_heads heads
    kv_n_heads = 1 if multiquery else n_heads
    # K is stored pre-transposed as (b, h, d, s) so q.matmul(k) needs no transpose
    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
    if past_key_value is not None:
        if len(past_key_value) != 0:
            # append the new keys/values to the cache along the sequence axis
            k = torch.cat([past_key_value[0], k], dim=3)
            v = torch.cat([past_key_value[1], v], dim=2)
        past_key_value = (k, v)
    b, _, s_q, d = q.shape
    s_k = k.size(-1)
    if softmax_scale is None:
        softmax_scale = 1 / math.sqrt(d)
    # (b, h, s_q, d) @ (b, kv_h, d, s_k): with kv_h == 1 the single K head
    # broadcasts over all h query heads
    attn_weight = q.matmul(k) * softmax_scale
    if attn_bias is not None:
        _s_q = max(0, attn_bias.size(2) - s_q)
        _s_k = max(0, attn_bias.size(3) - s_k)
        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
        attn_weight = attn_weight + attn_bias
    min_val = torch.finfo(q.dtype).min
    if key_padding_mask is not None:
        attn_weight = attn_weight.masked_fill(
            ~key_padding_mask.view((b, 1, 1, s_k)), min_val)
    if is_causal and (not q.size(2) == 1):
        s = max(s_q, s_k)
        causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
        causal_mask = causal_mask.tril()
        causal_mask = causal_mask.to(torch.bool)
        causal_mask = ~causal_mask
        causal_mask = causal_mask[-s_q:, -s_k:]
        attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k),
                                              min_val)
    attn_weight = torch.softmax(attn_weight, dim=-1)
    if dropout_p:
        attn_weight = torch.nn.functional.dropout(attn_weight,
                                                  p=dropout_p,
                                                  training=training,
                                                  inplace=True)
    # (b, h, s_q, s_k) @ (b, kv_h, s_k, d): V broadcasts the same way
    out = attn_weight.matmul(v)
    out = rearrange(out, 'b h s d -> b s (h d)')
    if needs_weights:
        return out, attn_weight, past_key_value
    return out, None, past_key_value
```
1.3 Resource Consumption
If the model uses MHA, the KV cache must store $2 n_h d_h l$ elements per token during inference, where $l$ is the number of layers. As the number of layers and heads grows, the compute, I/O, and memory required by attention increase rapidly, yet much of this hardware capacity (compute in particular) is left underutilized.
In the diagram below, $d$ is the hidden size, $h$ is the number of heads, and $l$ is the number of tokens in the current input sequence (in this diagram $l$ is the sequence length, not the layer count used above).
- When the batch size is 1, all multiplications at the red, green, and blue dashed circles in the diagram are matrix-vector multiplications, which are clearly memory-bound: their arithmetic intensity is below 1.
- When the batch size is greater than 1 (e.g., with continuous batching):
  - The red and blue parts are linear layers, i.e., weights multiplied by activations. The weights are shared across requests, so batching turns these into matrix-matrix multiplications: the larger the batch size, the higher the arithmetic intensity and the closer they get to being compute-bound (the FFN layers behave similarly).
  - The green part is attention, i.e., activations multiplied by activations. Each request has its own KV cache, so batching only yields a set of independent matrix-vector multiplications, and because sequence lengths may differ, these products have irregular shapes across requests. Their arithmetic intensity therefore stays below 1, which is clearly memory-bound.
  - The green part is therefore hard to optimize, and the longer the input sequence, the worse this bottleneck becomes.
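A back-of-the-envelope sketch makes this contrast concrete. It is our own toy model, not from the original: fp16 storage (2 bytes per element), elementwise work such as softmax ignored, and hypothetical sizes. A linear layer's FLOPs per byte grows roughly with batch size because the weights are read once for the whole batch, while MHA decode attention stays just below 1 no matter how large the batch is, because every request reads its own KV cache.

```python
BYTES_PER_ELEM = 2  # assumption: fp16 weights, activations and KV cache


def linear_intensity(batch: int, d_in: int, d_out: int) -> float:
    """FLOPs per byte for y = x @ W with x: (batch, d_in), W: (d_in, d_out)."""
    flops = 2 * batch * d_in * d_out
    # W is read once for the whole batch; x is read and y is written per request
    bytes_moved = BYTES_PER_ELEM * (d_in * d_out + batch * d_in + batch * d_out)
    return flops / bytes_moved


def decode_attention_intensity(batch: int, n_heads: int, head_dim: int, ctx_len: int) -> float:
    """FLOPs per byte for one MHA decode step (query length 1) over ctx_len cached tokens."""
    # q @ K^T and p @ V each cost 2 * ctx_len * head_dim FLOPs per head per request
    flops = batch * n_heads * 4 * ctx_len * head_dim
    # every request reads its own K and V cache, reads its q and writes its output
    bytes_moved = BYTES_PER_ELEM * batch * n_heads * (2 * ctx_len * head_dim + 2 * head_dim)
    return flops / bytes_moved


for b in (1, 8, 64):
    print(b, round(linear_intensity(b, 4096, 4096), 2),
          round(decode_attention_intensity(b, 32, 128, 2048), 4))
```

In this model the batch size cancels out of the attention ratio entirely, which is the "green part" problem in a single line of algebra.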

To alleviate these problems and use resources more effectively, methods such as MQA (Multi-Query Attention) and GQA (Grouped-Query Attention) emerged. Both revolve around the same question: how to reduce resource consumption while preserving model quality as much as possible.
0x02 MQA
The basic assumption is that there is substantial redundancy (sparsity) along the head dimension, so the number of heads can be reduced to a fairly small number. Some attention heads are responsible for retrieval and long-context dependencies; these retrieval heads should be kept while the others are pruned. Note that this kind of head pruning is typically applied after prefill, so it only speeds up decoding, concurrency, and context switching, not the prefill stage itself.
2.1 Concept
MQA (Multi-Query Attention) comes from the 2019 paper “Fast Transformer Decoding: One Write-Head is All You Need.” MQA keeps multiple query heads, but all query heads share a single key head and a single value head. This reduces the number of key and value matrices and hence the computational and storage overhead. In effect, all differences in attention between heads now come from the queries alone, so the model must attend to different aspects of the input hidden states using only its different query heads.
The specific characteristics of MQA are as follows.
- Q retains the original number of heads; that is, after the linear transformation, Q is still segmented (like MHA), and each attention head retains its own Q vector.
- K and V have only one head: the linear projection maps K and V directly to dimension $d_{head}$ instead of projecting them to the full width and then splitting them into heads.
- All Q heads share the single K and V head; equivalently, the K and V projection parameters are shared. In implementation, this means changing the shape of the projection matrix and replacing the splitting of K and V into heads with broadcasting (or copying) the single head.
- All Q heads use the same K head to calculate their attention scores, and the outputs of all heads are calculated using the same V head (but with different attention scores).
- Finally, the results calculated for each head are concatenated.

2.2 Implementation
Let’s take LLM-foundry as an example for analysis.
2.2.1 Simplified Version
We first give a simplified comparison of MHA and MQA. Here we assume the input x is a tensor of shape (batch, seq_len, d_model), for example (1, 512, 768). The main differences are:
- The dimensions of the W matrix are different.
- The QKV segmentation methods are different.

As the code shows, in MQA all heads share the same key and value. But how can this single K/V be used by all eight heads at once? In scaled_multihead_dot_product_attention(), the single K/V head has size 1 along the head dimension, and matmul broadcasts it so that every query head is multiplied by the same tensor; this is how the sharing is achieved.
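The following is a minimal PyTorch sketch of such a side-by-side comparison, written for this article rather than taken from llm-foundry. The class names `SimpleMHA` and `SimpleMQA` are hypothetical, and masking, dropout and caching are omitted. The only differences are the width of the fused QKV projection and how K and V are split: MHA splits them into `n_heads` heads, while MQA keeps one `head_dim`-wide head and lets `torch.matmul` broadcast it over all query heads.

```python
import math

import torch
import torch.nn as nn


class SimpleMHA(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        self.wqkv = nn.Linear(d_model, 3 * d_model)  # Q, K, V are each d_model wide
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):  # x: (batch, seq_len, d_model)
        b, s, _ = x.shape
        q, k, v = self.wqkv(x).chunk(3, dim=-1)
        # split each of Q, K, V into n_heads heads: (b, h, s, head_dim)
        q, k, v = (t.view(b, s, self.n_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(self.head_dim), dim=-1)
        return self.out((attn @ v).transpose(1, 2).reshape(b, s, -1))


class SimpleMQA(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, d_model // n_heads
        # Q keeps d_model outputs; K and V get a single head_dim-wide head each
        self.wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x):
        b, s, d_model = x.shape
        q, k, v = self.wqkv(x).split([d_model, self.head_dim, self.head_dim], dim=-1)
        q = q.view(b, s, self.n_heads, self.head_dim).transpose(1, 2)  # (b, h, s, hd)
        k, v = k.unsqueeze(1), v.unsqueeze(1)  # (b, 1, s, hd): one shared KV head
        # matmul broadcasts the size-1 head dimension of k/v over all h query heads
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(self.head_dim), dim=-1)
        return self.out((attn @ v).transpose(1, 2).reshape(b, s, -1))


x = torch.randn(1, 512, 768)
mha, mqa = SimpleMHA(768, 8), SimpleMQA(768, 8)
print(mha(x).shape, mqa(x).shape)
print(sum(p.numel() for p in mha.wqkv.parameters()),
      sum(p.numel() for p in mqa.wqkv.parameters()))
```

Both modules return a tensor of the same shape as the input; the MQA projection is much smaller because it produces only `768 + 2 * 96` outputs instead of `3 * 768`.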

The overall process of MQA can be seen in the following diagram.

2.2.2 Full Version
We will now provide the complete version of the code.
```python
import math
from typing import Optional

import torch.nn as nn

# LPLayerNorm, flash_attn_fn, triton_flash_attn_fn and
# scaled_multihead_dot_product_attention are defined elsewhere in llm-foundry.


class MultiQueryAttention(nn.Module):
    """Multi-Query self attention.

    Using torch or triton attention implementation enables user to also use
    additive bias.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        attn_impl: str = 'triton',
        clip_qkv: Optional[float] = None,
        qk_ln: bool = False,
        softmax_scale: Optional[float] = None,
        attn_pdrop: float = 0.0,
        low_precision_layernorm: bool = False,
        verbose: int = 0,
        device: Optional[str] = None,
    ):
        super().__init__()
        self.attn_impl = attn_impl
        self.clip_qkv = clip_qkv
        self.qk_ln = qk_ln
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.softmax_scale = softmax_scale
        if self.softmax_scale is None:
            self.softmax_scale = 1 / math.sqrt(self.head_dim)
        self.attn_dropout_p = attn_pdrop
        # NOTE: if we ever want to make attn TensorParallel, I'm pretty sure we'll
        # want to split Wqkv into Wq and Wkv where Wq can be TensorParallel but
        # Wkv shouldn't be TensorParallel
        # - vchiley
        # Q keeps d_model outputs; K and V each get one head_dim-wide head
        self.Wqkv = nn.Linear(
            d_model,
            d_model + 2 * self.head_dim,
            device=device,
        )
        # for param init fn; enables shape based init of fused layers
        fuse_splits = (d_model, d_model + self.head_dim)
        self.Wqkv._fused = (0, fuse_splits)  # type: ignore
        if self.qk_ln:
            layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
            self.q_ln = layernorm_class(d_model, device=device)
            # K has a single head, so its layernorm covers only head_dim
            self.k_ln = layernorm_class(self.head_dim, device=device)
        if self.attn_impl == 'flash':
            self.attn_fn = flash_attn_fn
        elif self.attn_impl == 'triton':
            self.attn_fn = triton_flash_attn_fn
        elif self.attn_impl == 'torch':
            self.attn_fn = scaled_multihead_dot_product_attention
        else:
            raise ValueError(f'{attn_impl=} is an invalid setting.')
        self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
        self.out_proj._is_residual = True  # type: ignore

    def forward(
        self,
        x,
        past_key_value=None,
        attn_bias=None,
        attention_mask=None,
        is_causal=True,
        needs_weights=False,
    ):
        qkv = self.Wqkv(x)
        if self.clip_qkv:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        # MQA: Q is d_model wide, K and V are one head_dim-wide head each
        query, key, value = qkv.split(
            [self.d_model, self.head_dim, self.head_dim], dim=2)
        key_padding_mask = attention_mask
        if self.qk_ln:
            # Applying layernorm to qk
            dtype = query.dtype
            query = self.q_ln(query).to(dtype)
            key = self.k_ln(key).to(dtype)
        context, attn_weights, past_key_value = self.attn_fn(
            query,
            key,
            value,
            self.n_heads,
            past_key_value=past_key_value,
            softmax_scale=self.softmax_scale,
            attn_bias=attn_bias,
            key_padding_mask=key_padding_mask,
            is_causal=is_causal,
            dropout_p=self.attn_dropout_p,
            training=self.training,
            needs_weights=needs_weights,
            multiquery=True,  # treat K/V as a single head shared by all query heads
        )
        return self.out_proj(context), attn_weights, past_key_value
```
2.3 Effects
2.3.1 Memory
MQA shrinks the cached KV from all heads down to a single head, reducing the KV cache to $1/n_h$ of its MHA size. In MHA a single token needs to store $2 \times l \times n_h$ KV vectors (each of dimension $d_h$), while MQA reduces this to $2 \times l$: each layer keeps just one $k$ vector and one $v$ vector.
2.3.2 Speed

The paper’s authors ran a series of tests, detailed in the table above (the values are the average time, in microseconds, needed to generate each output token). Several points are worth noting:
- The training speed remained almost unchanged.
- Both inference time and beam search time were significantly reduced.
- In terms of inference speed, the encoder’s inference speed remains basically unchanged, while the decoder’s inference speed is much faster.
Although MQA has only one KV head, that head is read and effectively shared by all Q heads, so in theory MQA reduces memory usage and memory traffic rather than the amount of computation. Why, then, is it so much faster? The gain comes mainly from the smaller KV cache, as detailed below:
- Less KV cache space. Because there are fewer KV heads, fewer tensors have to be stored in GPU memory (for example, a KV cache that previously stored 32 heads now stores only 1). The freed space can be used for a larger batch size, which raises throughput (the latency of an individual request may increase, but the overall throughput of the service increases significantly).
- Less time spent reading from GPU memory. With fewer KV heads, much less KV cache (and K/V projection weight) data has to be read, so compute units wait less and attention moves from memory-bound toward compute-bound. In addition, the K and V loaded once can be reused by all query heads of the same request, which raises the arithmetic intensity of the attention computation.
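A toy estimate shows why sharing K/V helps decoding. The assumptions are ours: an fp16 KV cache, only K/V reads counted as traffic, hypothetical sizes, and a kernel that reuses each loaded KV head for every query head that shares it. FLOPs stay tied to the number of query heads, while bytes read scale with the number of KV heads, so the arithmetic intensity of decode attention becomes roughly `n_q_heads / n_kv_heads`.

```python
BYTES_PER_ELEM = 2  # assumption: fp16 KV cache


def shared_kv_decode_intensity(n_q_heads: int, n_kv_heads: int,
                               head_dim: int, ctx_len: int) -> float:
    """FLOPs per byte of one decode step (one request, query length 1).

    Only K/V cache reads are counted as traffic; q, outputs and softmax are ignored.
    """
    # q @ K^T and p @ V cost 2 * ctx_len * head_dim FLOPs each, for every query head
    flops = n_q_heads * 4 * ctx_len * head_dim
    # each KV head's K and V are read once, however many query heads use it
    kv_bytes = BYTES_PER_ELEM * n_kv_heads * 2 * ctx_len * head_dim
    return flops / kv_bytes


for name, n_kv in (("MHA", 32), ("GQA, 8 KV heads", 8), ("MQA", 1)):
    print(f"{name}: {shared_kv_decode_intensity(32, n_kv, 128, 4096):.1f} FLOPs/byte")
```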
2.3.3 Representational Capacity
Because there is only one shared KV head, the attention differences that MHA expressed through multiple Q-K-V heads must now come from the Q heads alone. This limits the model’s representational capacity, so although MQA speeds up inference effectively, its quality is slightly worse than MHA. To compensate for the parameters lost by sharing KV, the FFN/GLU is often enlarged to keep the total parameter count unchanged, which offsets part of the quality loss.
Note also that because MQA and GQA change the structure of the attention mechanism, models usually need to adopt MQA or GQA from the start of training. Converting an already-trained MHA model to MQA/GQA without further training gives poor results and requires additional training to recover. The GQA paper reports that continued training ("uptraining") with about 5% of the original pre-training compute is enough to achieve good results.
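The GQA paper builds the starting checkpoint for this uptraining by mean-pooling the original key and value heads within each group. Below is a minimal sketch of that pooling step only. It assumes PyTorch's `nn.Linear` weight layout `(out_features, in_features)` and that consecutive query heads form a group (the same convention Llama's `repeat_kv()` uses); the function name is hypothetical, and biases, if present, would need the same treatment.

```python
import torch


def mha_to_gqa_kv_weight(w: torch.Tensor, n_heads: int, n_kv_heads: int) -> torch.Tensor:
    """Mean-pool an MHA K (or V) projection into n_kv_heads grouped heads.

    w: (n_heads * head_dim, d_model). Returns (n_kv_heads * head_dim, d_model).
    """
    out_features, d_model = w.shape
    head_dim = out_features // n_heads
    group = n_heads // n_kv_heads
    # (n_kv_heads, group, head_dim, d_model): consecutive heads form one group
    w = w.view(n_kv_heads, group, head_dim, d_model)
    return w.mean(dim=1).reshape(n_kv_heads * head_dim, d_model)


wk = torch.randn(32 * 128, 4096)  # hypothetical: 32 heads of dim 128, d_model 4096
wk_gqa = mha_to_gqa_kv_weight(wk, n_heads=32, n_kv_heads=8)
print(wk_gqa.shape)  # torch.Size([1024, 4096])
```

Setting `n_kv_heads=1` gives the MQA conversion (all heads pooled into one), and `n_kv_heads=n_heads` leaves the weights unchanged.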
2.3.4 Communication
In multi-GPU parallel settings, MQA reduces memory access but increases communication overhead. Because the K and V tensors are shared by all heads, each GPU needs its own replica of them. Compared with the MHA parallel strategy in Figure (a), MQA needs all-to-all communication to reshard the input and output activation tensors, which adds communication cost, as shown in Figure (b). Moreover, keeping a replica on every card can cancel out MQA’s memory savings.

0x03 GQA
For larger models, cutting all the way down to a single KV head is too aggressive. Reducing the number of heads from 64 to 1, for example, removes far more representational capacity than reducing it from 32 to 1. Moreover, according to experiments in the GQA paper, while MQA drastically speeds up decoder inference, it causes a significant drop in generation quality and can make training unstable. GQA was therefore proposed to speed up the model at a smaller cost in quality.

The graph above shows the evolution of self-attention mechanisms from 2022 to 2024. It can be seen that MHA is gradually being phased out and replaced by GQA.
3.1 Concept
Grouped Query Attention (GQA), proposed in the paper “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints,” improves efficiency by grouping the query heads. Its core idea is to let multiple query heads share a small number of key and value heads, reducing computation and memory overhead through this grouping mechanism.
GQA is a generalization of MHA and MQA, or a compromise between them. MHA has $H$ query, key, and value heads; MQA shares a single key and value head across all query heads. GQA, instead of having all query heads share one KV head, divides the Q heads into $g$ groups, and the Q heads within a group share one key head and one value head.
In the diagram below, four query heads are divided into two groups of two. Each group also has one K head and one V head. The group labeled 1 is the first group, and the group labeled 2 is the second.

The following diagram shows the formulas and process for GQA.

Su Jianlin (“Su Shen”) has pointed out that GQA is essentially a low-rank projection of $x_i$.
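To make the low-rank view concrete (our own illustration, not the original derivation): GQA's K projection is equivalent to an MHA-shaped projection whose weight is a `d_model × (g·d_h)` matrix multiplied by a fixed duplication matrix, so that effective weight has rank at most $g \cdot d_h$. The sizes below are hypothetical, and float64 is used so the rank check is numerically stable.

```python
import torch

torch.manual_seed(0)
d_model, n_heads, n_kv_heads, head_dim = 512, 8, 2, 64
group = n_heads // n_kv_heads
dt = torch.float64

# GQA key projection: d_model -> n_kv_heads * head_dim
w_k = torch.randn(d_model, n_kv_heads * head_dim, dtype=dt)

# Duplication matrix R copies KV head j to every query head h with h // group == j
R = torch.zeros(n_kv_heads * head_dim, n_heads * head_dim, dtype=dt)
for h in range(n_heads):
    j = h // group
    R[j * head_dim:(j + 1) * head_dim, h * head_dim:(h + 1) * head_dim] = torch.eye(head_dim, dtype=dt)

# Equivalent MHA-shaped key projection: d_model x (n_heads * head_dim), but low rank
w_k_mha = w_k @ R
print(w_k_mha.shape, torch.linalg.matrix_rank(w_k_mha).item())

# Same keys as projecting with w_k and then repeating each KV head `group` times
x = torch.randn(3, d_model, dtype=dt)
k_gqa = (x @ w_k).view(3, n_kv_heads, head_dim).repeat_interleave(group, dim=1).reshape(3, -1)
print(torch.allclose(x @ w_k_mha, k_gqa))
```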

3.2 Architecture Comparison
GQA combines elements of MHA and MQA into a more efficient attention mechanism. It interpolates between the two: instead of reducing the number of key-value heads from $n_{heads}$ all the way to 1, it reduces it to $g$ with $1 < g < n_{heads}$. The new parameter $g$ is the number of KV heads (groups); each group contains $n_{heads}/g$ query heads.
Introducing $g$ gives a unified view of the three mechanisms: MHA and MQA are special cases of GQA, corresponding to $g = n_{heads}$ and $g = 1$ respectively.
- $g = 1$: equivalent to MQA, which uses a single shared key-value projection across all $n_{heads}$ heads.
- $g = n_{heads}$ (one group per attention head): equivalent to MHA.
GQA trades off smoothly between model quality and KV cache size (which drives latency and throughput), covering the range between the two extremes, MHA and MQA. In other words, GQA is a small MQA within each group, while across groups it behaves like traditional MHA.
Larger models usually scale up the number of heads, so MQA represents an increasingly aggressive cut in both memory bandwidth and capacity, whereas GQA keeps the reduction in bandwidth and capacity proportional as model size grows. In addition, standard sharding for large models replicates MQA’s single key and value head across the model partitions; GQA eliminates the waste from this replication. We therefore expect GQA to offer a particularly good trade-off for larger models.
The diagram below illustrates the differences in their architectures.

3.3 Implementation
Most mainstream training and inference frameworks and algorithms now support MQA/GQA; FlashAttention, for example, supports both. For MQA and GQA, FlashAttention does not copy the shared key-value (KV) heads into multiple copies in GPU memory. It uses indexing instead: the KV head index is passed to the kernel, which computes the corresponding memory address and reads K and V directly from memory.
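The indexing idea can be illustrated at the PyTorch level with a deliberately simple (and slow) per-head loop. This is only a conceptual sketch of the mapping `kv_head = q_head // group_size` under our own helper names; it is not FlashAttention's kernel or API. Slicing `k[:, kv]` produces a view, so no KV copy is materialized, and the result matches a copy-then-MHA reference.

```python
import math

import torch


def gqa_attention_indexed(q, k, v):
    """GQA attention that indexes the shared KV head instead of copying it.

    q: (batch, n_q_heads, q_len, head_dim); k, v: (batch, n_kv_heads, kv_len, head_dim).
    """
    n_q_heads, n_kv_heads = q.shape[1], k.shape[1]
    group = n_q_heads // n_kv_heads
    outs = []
    for h in range(n_q_heads):
        kv = h // group  # index of the KV head shared by this query head's group
        scores = q[:, h] @ k[:, kv].transpose(-2, -1) / math.sqrt(q.shape[-1])
        outs.append(torch.softmax(scores, dim=-1) @ v[:, kv])
    return torch.stack(outs, dim=1)


def gqa_attention_copied(q, k, v):
    """Reference: copy every KV head `group` times, then run ordinary MHA."""
    group = q.shape[1] // k.shape[1]
    k, v = k.repeat_interleave(group, dim=1), v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v


q = torch.randn(2, 8, 5, 16)
k, v = torch.randn(2, 2, 5, 16), torch.randn(2, 2, 5, 16)
print(torch.allclose(gqa_attention_indexed(q, k, v), gqa_attention_copied(q, k, v), atol=1e-5))
```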

Incidentally, GQA is generally not applied to the encoder self-attention layers: encoder representations are computed in parallel, so memory bandwidth is usually not the main bottleneck there.
We will use llama3’s code for analysis. First, we will provide a simplified version for learning purposes, followed by the complete version.
3.3.1 Simplified Version
For better analysis, we provide a simplified version of the code below.
In MHA, the Query, Key, and Value tensors have shape (batch_size, n_head, seq_length, head_dim). In GQA, the Query shape is unchanged, but Key and Value become (batch_size, n_head / group_size, seq_length, head_dim), where group_size is the number of query heads per group. That is, K and V have group_size times fewer heads than Q. To make the subsequent matrix multiplications work, there are generally two approaches:
- Use broadcasting by reshaping Q, K, and V: Query: (batch_size, n_head / group_size, group_size, seq_length, head_dim), Key: (batch_size, n_head / group_size, 1, seq_length, head_dim), Value: (batch_size, n_head / group_size, 1, seq_length, head_dim). This requires broadcasting and a final merge step, so the MHA code needs several modifications.
- Convert GQA to MHA before the computation: first expand the key and value heads to the same number of heads as the query, then perform the usual MHA computation.
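Both approaches compute the same result. The sketch below (hypothetical helper names, no mask, small random shapes) implements approach 1 with an explicit group dimension that `matmul` broadcasts over, and approach 2 by copying KV heads with `repeat_interleave`, then checks that they agree. As in the text, consecutive query heads form one group.

```python
import math

import torch


def gqa_broadcast(q, k, v):
    """Approach 1: reshape Q into groups and let matmul broadcast K/V.

    q: (b, n_head, s, d); k, v: (b, n_kv, t, d) with n_kv = n_head / group_size.
    """
    b, n_head, s, d = q.shape
    n_kv = k.shape[1]
    group_size = n_head // n_kv
    q5 = q.reshape(b, n_kv, group_size, s, d)  # (b, n_kv, group_size, s, d)
    k5, v5 = k.unsqueeze(2), v.unsqueeze(2)    # (b, n_kv, 1, t, d)
    attn = torch.softmax(q5 @ k5.transpose(-2, -1) / math.sqrt(d), dim=-1)
    return (attn @ v5).reshape(b, n_head, s, d)  # merge groups back into heads


def gqa_repeat(q, k, v):
    """Approach 2: expand K/V to n_head heads first, then run plain MHA."""
    group_size = q.shape[1] // k.shape[1]
    k, v = k.repeat_interleave(group_size, dim=1), v.repeat_interleave(group_size, dim=1)
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
    return attn @ v


q = torch.randn(2, 8, 4, 16)
k, v = torch.randn(2, 2, 6, 16), torch.randn(2, 2, 6, 16)
print(torch.allclose(gqa_broadcast(q, k, v), gqa_repeat(q, k, v), atol=1e-5))
```

Approach 1 never materializes the repeated K/V, at the cost of a five-dimensional layout; approach 2 keeps the MHA code untouched, which is the route Llama takes below.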
```python
# Simplified from Llama's model.py for illustration. ModelArgs, apply_rotary_emb,
# repeat_kv, fs_init, ColumnParallelLinear and RowParallelLinear come from that
# file and from fairscale (see the full version below).
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # number of query heads that share one KV head (the group size)
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        # wk/wv have n_kv_heads * head_dim outputs, so they have fewer
        # parameters than wq whenever n_kv_heads < n_heads
        self.wq = ColumnParallelLinear(args.dim, args.n_heads * self.head_dim)
        self.wk = ColumnParallelLinear(args.dim, self.n_kv_heads * self.head_dim)
        self.wv = ColumnParallelLinear(args.dim, self.n_kv_heads * self.head_dim)
        self.wo = RowParallelLinear(args.n_heads * self.head_dim, args.dim)
        # the KV cache only stores the n_local_kv_heads KV heads
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len,
                                    self.n_local_kv_heads, self.head_dim)).cuda()
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len,
                                    self.n_local_kv_heads, self.head_dim)).cuda()

    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
                mask: Optional[torch.Tensor]):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # self.n_rep = q_heads // kv_heads.
        # There are more query heads than KV heads and each KV pair serves n_rep
        # query heads, so every KV head is copied n_rep times to make dim 2 (the
        # head dimension) match q. n_rep corresponds to num_key_value_groups in
        # other codebases. repeat_kv turns K/V from
        # (bs, seqlen, n_kv_heads, head_dim) into (bs, seqlen, n_heads, head_dim).
        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
```
The code for repeat_kv() is as follows. Why use expand() followed by reshape() instead of the tensor’s built-in repeat()? The expand() step itself costs no GPU memory: it returns a view that shares data with the original tensor and allocates nothing new, whereas repeat() actually copies the data into a new, larger tensor. Note, however, that the subsequent reshape() cannot be expressed as a view of the expanded tensor, so it does materialize one copy; the result is equivalent to torch.repeat_interleave() on the head dimension, as the docstring says. Plain repeat() would also tile the heads in the wrong order (h0, h1, h0, h1, ... instead of h0, h0, h1, h1, ...), pairing query heads with the wrong KV heads.
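```python
import torch


# x is the input; n_rep is the number of repetitions, i.e. the number of
# query heads that share each KV head
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        # insert a new 4th dimension, making x 5-D: (bs, slen, n_kv_heads, 1, head_dim)
        x[:, :, :, None, :]
        # first we expand x to (bs, seq_len, head, group, head_dim):
        # the 4th dimension goes from 1 to n_rep
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)  # broadcast: k/v vectors are shared, not copied
        # reshape make head -> head * group: back to 4-D, so the 3rd dimension
        # grows from n_kv_heads to n_kv_heads * n_rep and now matches q
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
```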
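A small experiment shows what actually happens to memory. It is our own sketch and repeats `repeat_kv()` verbatim so that it runs standalone. `expand()` alone returns a view with stride 0 along the new axis. The following `reshape()` cannot merge a stride-0 axis with its neighbor as a view, so it allocates new storage, and the result equals `repeat_interleave()` on the head axis, while plain `repeat()` orders the heads differently.

```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # identical to Llama's repeat_kv above
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )


x = torch.arange(6, dtype=torch.float32).view(1, 1, 2, 3)  # 2 KV heads, head_dim 3

expanded = x[:, :, :, None, :].expand(1, 1, 2, 4, 3)
print(expanded.data_ptr() == x.data_ptr(), expanded.stride())  # shares storage; stride 0 on the new axis

out = repeat_kv(x, 4)
print(out.data_ptr() == x.data_ptr())                    # reshape allocated new storage
print(torch.equal(out, x.repeat_interleave(4, dim=2)))   # heads ordered h0,h0,h0,h0,h1,...
print(torch.equal(out, x.repeat(1, 1, 4, 1)))            # repeat tiles h0,h1,h0,h1,...
```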
3.3.2 Full Version
The complete code is as follows.
```python
import math
from typing import Optional

import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from torch import nn

# ModelArgs, apply_rotary_emb and repeat_kv are defined in the same model.py file.


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
```
Additionally, a common optimization for the decoding phase of MQA and GQA is head-group fusion: all query heads that share the same key-value head are merged into the query-row dimension, because they all compute attention against the same key-value cache. This increases the effective number of rows and therefore the arithmetic intensity of the operator. Although the query length is 1 in the autoregressive decoding phase, after head-group fusion the effective row count grows to $H_{QO}/H_{KV}$.
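The head-group fusion idea can be sketched in PyTorch as follows. This is a minimal illustration under our own assumptions, not FlashInfer's kernel: tensors use the layout (batch, heads, seq, head_dim), and consecutive query heads share one KV head, as in repeat_kv. The fused version reshapes the single decoding query of each head into group rows against one shared KV head, and the script checks that it matches the repeat-then-attend computation.

```python
import math
import torch

def decode_attention_repeat(q, k, v):
    """Reference path: expand KV heads to match query heads (like repeat_kv), then attend.
    q: (b, h_qo, 1, d); k, v: (b, h_kv, s, d)."""
    group = q.shape[1] // k.shape[1]
    k = k.repeat_interleave(group, dim=1)  # (b, h_qo, s, d)
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v  # (b, h_qo, 1, d)

def decode_attention_head_group_fused(q, k, v):
    """Fold the query heads of each group into the row dimension, so each
    KV head serves `group` query rows instead of a single row."""
    b, h_qo, one, d = q.shape
    h_kv = k.shape[1]
    group = h_qo // h_kv
    # (b, h_qo, 1, d) -> (b, h_kv, group, d): consecutive query heads share a KV head
    q_fused = q.reshape(b, h_kv, group * one, d)
    scores = q_fused @ k.transpose(-2, -1) / math.sqrt(d)  # (b, h_kv, group, s)
    out = torch.softmax(scores, dim=-1) @ v  # (b, h_kv, group, d)
    return out.reshape(b, h_qo, one, d)

torch.manual_seed(0)
b, h_qo, h_kv, s, d = 2, 8, 2, 16, 32
q = torch.randn(b, h_qo, 1, d)
k = torch.randn(b, h_kv, s, d)
v = torch.randn(b, h_kv, s, d)
ref = decode_attention_repeat(q, k, v)
fused = decode_attention_head_group_fused(q, k, v)
print(torch.allclose(ref, fused, atol=1e-5))
```

In the fused form, each KV head enters one matrix multiply with group query rows, so a kernel built this way can load that KV head once per group rather than once per query head.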

3.4 Effects
3.4.1 Memory
GQA can significantly reduce the size of the KV cache during the inference phase, providing room for a larger batch size and further improving throughput.
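Under MHA, the total KV cache size for all tokens of all sequences in a batch, counted in elements, can be expressed by the following formula (the factor 2 accounts for storing both K and V; multiply by the bytes per element, e.g. 2 for FP16, to get bytes):
$$2 \times B \times L \times H \times D \times N$$
where: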
- B represents batch size.
- L represents the total sequence length, i.e., the input sequence plus the output sequence (the prompt plus the completion).
- H represents the number of attention heads.
- D represents the head dimension, i.e., the size of each head.
- N represents the number of layers.
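Under MQA, all query heads share a single key-value head, so the cache size is:
$$2 \times B \times L \times D \times N$$
Under GQA with $G$ key-value heads (one per group), the cache size is:
$$2 \times B \times L \times G \times D \times N$$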
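To make these formulas concrete, the small helper below computes the KV cache size as 2 × B × L × H_kv × D × N × bytes per element, where H_kv is H for MHA, 1 for MQA, and the number of groups for GQA. The model shape used (32 layers, 32 heads of dimension 128, a 4096-token sequence, FP16) is an illustrative assumption, not a specific model from the text.

```python
def kv_cache_bytes(batch, seq_len, n_kv_heads, head_dim, n_layers, bytes_per_elem=2):
    """KV cache size in bytes: 2 (K and V) * B * L * H_kv * D * N * bytes per element.
    MHA: n_kv_heads = n_heads; MQA: n_kv_heads = 1; GQA: n_kv_heads = number of groups."""
    return 2 * batch * seq_len * n_kv_heads * head_dim * n_layers * bytes_per_elem

# Hypothetical model shape for illustration only.
B, L, H, D, N = 1, 4096, 32, 128, 32
GiB = 1024 ** 3
mha = kv_cache_bytes(B, L, H, D, N)
gqa = kv_cache_bytes(B, L, 8, D, N)  # 8 KV heads: each group has 4 query heads
mqa = kv_cache_bytes(B, L, 1, D, N)
print(mha / GiB, gqa / GiB, mqa / GiB)  # 2.0 0.5 0.0625
```

With 8 KV heads the cache shrinks by a factor of 4 relative to MHA, and MQA shrinks it by a factor of 32; the freed memory is what allows larger batches.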
For a detailed comparison, please refer to the image below, where $g$ is the number of KV head groups (every $n_h/g$ query heads share a single key-value head, with $n_h$ the number of query heads), $d_k$ is the head dimension, $l$ is the number of layers, $s$ is the sequence length, and $b$ is the batch size.

The benefits of GQA and MQA on GPUs come primarily from the smaller KV cache, which allows more tokens to be stored. However, their performance is sensitive to the parallelization strategy. If a GQA kernel parallelizes over Q heads (one Q head per thread block), blocks that share the same KV head are scheduled on different SMs, and each SM loads the same KV head again. This largely cancels the memory-traffic savings. Because loading KV is the bottleneck for both MHA and GQA, the kernel should reduce parallelism along the Q-head dimension, so that the Q heads of one group are processed together and each KV head is loaded only once.
3.4.2 Speed
GQA does not reduce the computational cost (FLOPs) of the attention operation itself: the shared key and value heads are broadcast to every query head in their group (this is what repeat_kv does in the code above), so each query head still attends over the full sequence exactly as in MHA; only the key-value projection parameters and the cached key-value tensors are shared. The speedup comes from memory traffic instead. During autoregressive decoding the attention step is limited by memory bandwidth, and GQA reads $n_h/g$ times less KV cache per step than MHA. Therefore, although the attention FLOPs in GQA are not reduced, the speed improves substantially, for the same reason as in MQA.
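A back-of-the-envelope way to see why sharing KV heads speeds up decoding even though FLOPs are unchanged is to estimate the arithmetic intensity (FLOPs per byte of KV cache read) of one decoding step. The sketch below is our own rough model: it counts only the QK^T and PV work per cached position, assumes FP16 storage, and ignores reading the query and writing the output. The head counts (32 query heads, head dimension 128) are illustrative assumptions.

```python
def decode_attention_intensity(n_q_heads, n_kv_heads, head_dim, bytes_per_elem=2):
    """Rough FLOPs per byte of KV cache read in one decoding step, per cached position.
    Each query head does a head_dim-long dot product with K (2*D FLOPs) and a
    head_dim-long weighted accumulation of V (2*D FLOPs); each KV head stores
    K and V (2*D elements). Query reads and output writes are ignored."""
    flops = 4 * head_dim * n_q_heads
    bytes_read = 2 * head_dim * n_kv_heads * bytes_per_elem
    return flops / bytes_read

for name, h_kv in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    print(name, decode_attention_intensity(32, h_kv, 128))
# MHA 1.0
# GQA-8 4.0
# MQA 32.0
```

The FLOPs are identical in all three cases; only the bytes read shrink by the group factor, which is why GQA and MQA decode faster when decoding is memory-bound.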
3.4.3 Representational capacity
GQA retains some of the expressive power of multi-head attention while accelerating inference speed by reducing memory access pressure.
The paper “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” compares the quality and inference efficiency of MHA, MQA, and GQA models. The authors used T5 as their research subject, with both T5-Large and T5-XXL. In the figure below, the horizontal axis is the average inference time per sample (a larger value means higher latency), and the vertical axis is the average evaluation score across multiple datasets (a larger value means a better score).
The figure below shows that MQA slightly sacrifices model accuracy but significantly reduces inference overhead, while GQA, with an appropriate number of groups, achieves both. GQA’s representational ability is significantly higher than MQA, almost identical to MHA (GQA may still result in some accuracy loss), and its inference speed is not much different from MQA, still showing a significant improvement over MHA. The number of groups in GQA is a hyperparameter; a larger number of groups brings it closer to MHA, resulting in greater inference latency but also higher model accuracy. Additionally, increasing model depth can mitigate the decline in model performance.
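Since the number of KV heads is the knob that moves a model between MQA and MHA, a minimal GQA layer makes this spectrum concrete. This is a sketch under our own assumptions: the class name GroupedQueryAttention is hypothetical, it omits the KV cache, RoPE, and model parallelism of the LLaMA code above, and it assumes consecutive query heads share a KV head.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    """Minimal GQA: n_kv_heads == n_heads gives MHA, n_kv_heads == 1 gives MQA."""
    def __init__(self, dim, n_heads, n_kv_heads):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(dim, n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)

    def forward(self, x, causal=True):
        b, s, _ = x.shape
        q = self.wq(x).view(b, s, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(x).view(b, s, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(b, s, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Broadcast each KV head to the query heads of its group (same role as repeat_kv).
        group = self.n_heads // self.n_kv_heads
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_dim)
        if causal:
            mask = torch.triu(torch.ones(s, s, dtype=torch.bool, device=x.device), diagonal=1)
            scores = scores.masked_fill(mask, float("-inf"))
        out = F.softmax(scores, dim=-1) @ v  # (b, n_heads, s, head_dim)
        return self.wo(out.transpose(1, 2).reshape(b, s, -1))

x = torch.randn(2, 10, 64)
for n_kv in (8, 2, 1):  # MHA, GQA, MQA with 8 query heads
    attn = GroupedQueryAttention(dim=64, n_heads=8, n_kv_heads=n_kv)
    print(n_kv, tuple(attn(x).shape), attn.wk.weight.numel())
```

The output shape is the same for every setting; only the size of the K/V projections (and, at inference time, of the KV cache) changes with the number of groups.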

3.5 Conversion
Although most of the latest models use GQA by default during the pre-training stage, we can also consider how to convert a pre-trained MHA model into MQA or GQA.
3.5.1 Average Pooling
If we continue training a multi-query model from an existing multi-head checkpoint (a process called uptraining), we can group the MHA heads, construct the key and value heads of each group by mean pooling the original heads in that group, and then continue pre-training. Experiments show that mean pooling performs better than selecting the first head or random initialization.

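The reference code for the mean-pooling step is as follows. It assumes that consecutive query heads form a group, matching the layout used by repeat_kv above.
import torch.nn as nn
n_heads = 4
n_kv_heads = 2
head_dim = 2
hidden_size = 3
group = n_heads // n_kv_heads  # query heads per KV head
k_proj = nn.Linear(hidden_size, n_heads * head_dim, bias=False)
# Mean pooling: the weight rows are stored head by head, so view them as
# (n_kv_heads, group, head_dim, hidden_size) and average over the heads in each group.
w = k_proj.weight.data.view(n_kv_heads, group, head_dim, hidden_size)
pool_out = w.mean(dim=1).reshape(n_kv_heads * head_dim, hidden_size)
k_proj_gqa = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False)
k_proj_gqa.weight.data = pool_out
# The same procedure applies to the value projection (v_proj).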
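Because the key projection is linear, mean pooling the weights of the heads in a group is exactly the same as averaging the keys those heads would have produced. The check below verifies this on a toy layer; the sizes are arbitrary, the layer has no bias (as in LLaMA), and consecutive heads are assumed to form a group.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_heads, n_kv_heads, head_dim, hidden_size = 8, 2, 4, 16
group = n_heads // n_kv_heads

k_proj = nn.Linear(hidden_size, n_heads * head_dim, bias=False)
w = k_proj.weight.data.view(n_kv_heads, group, head_dim, hidden_size)
k_proj_gqa = nn.Linear(hidden_size, n_kv_heads * head_dim, bias=False)
k_proj_gqa.weight.data = w.mean(dim=1).reshape(n_kv_heads * head_dim, hidden_size)

x = torch.randn(3, hidden_size)
with torch.no_grad():
    # Keys from the original heads, averaged within each group ...
    k_avg = k_proj(x).view(3, n_kv_heads, group, head_dim).mean(dim=2)
    # ... equal the keys produced by the pooled projection.
    k_pooled = k_proj_gqa(x).view(3, n_kv_heads, head_dim)
print(torch.allclose(k_avg, k_pooled, atol=1e-5))
```

This is why mean pooling is a reasonable starting point for uptraining: at initialization each new KV head produces the average of its group's original keys and values.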
3.5.2 Mask-based
The paper “Align Attention Heads Before Merging Them: An Effective Way for Converting MHA to GQA” proposes a low-cost method to prune MHA models into GQA models with arbitrary KV head compression ratios. The method uses $L_0$ masks to gradually remove redundant parameters. In addition, before training it applies orthogonal transformations to the attention heads, which leave the model's outputs unchanged, to increase the similarity between heads and thereby further improve the performance of the pruned model.
The specific plan consists of the following steps: network transformation; grouping; and pruning training.
Network conversion
This step involves transforming the model before pruning training. The specific process is roughly as follows:
- Run the model on a portion of the C4 training set and collect the resulting KV cache, which is then used to analyze the similarity between attention heads.
- Calculate the optimal orthogonal matrix based on cosine similarity or Euclidean distance.

- The calculated orthogonal matrices are fused into the corresponding Q, K, and V projection matrices to ensure computational invariance. Due to RoPE, orthogonal transformations are applied to the projection matrices of Q and K in their respective subspaces.

Orthogonal transformations can make different Attention Heads within the same group closer in the feature space, making it easier to find suitable parameter sharing methods during subsequent pruning training and improving the model’s compression effect and performance.
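As an illustration of the network-conversion step, the sketch below uses the orthogonal Procrustes solution (the orthogonal R minimizing the Euclidean distance ||K_a R - K_b||_F is R = U V^T, from the SVD of K_a^T K_b) to align one head's keys with another's, then folds R into that head's Q and K projections. Procrustes is a standard technique we assume here for the Euclidean-distance criterion; the paper's exact procedure and the per-subspace handling required by RoPE are not reproduced, and the projection matrices are random stand-ins for real heads. The final check shows why the fusion is safe: since R R^T = I, the attention logits q k^T are unchanged.

```python
import torch

def procrustes(a, b):
    """Orthogonal R minimizing ||a @ R - b||_F, for a, b of shape (n_tokens, head_dim)."""
    u, _, vh = torch.linalg.svd(a.T @ b)
    return u @ vh

torch.manual_seed(0)
n_tokens, head_dim, hidden = 256, 16, 64
# Hypothetical inputs and per-head projections; in practice keys come from calibration data.
x = torch.randn(n_tokens, hidden, dtype=torch.float64)
wq_a = torch.randn(hidden, head_dim, dtype=torch.float64)
wk_a = torch.randn(hidden, head_dim, dtype=torch.float64)
# Head b is a rotated, slightly noisy copy of head a.
true_rot, _ = torch.linalg.qr(torch.randn(head_dim, head_dim, dtype=torch.float64))
wk_b = wk_a @ true_rot + 0.01 * torch.randn(hidden, head_dim, dtype=torch.float64)

k_a, k_b = x @ wk_a, x @ wk_b
r = procrustes(k_a, k_b)
before = torch.linalg.norm(k_a - k_b)
after = torch.linalg.norm(k_a @ r - k_b)
print(after < before)  # alignment brings head a much closer to head b

# Fuse R into head a's Q and K projections: logits are unchanged because R @ R.T = I.
q_a = x @ wq_a
logits = q_a @ k_a.T
logits_rot = (x @ (wq_a @ r)) @ (x @ (wk_a @ r)).T
print(torch.allclose(logits, logits_rot))
```

An analogous transform can be folded into the V and output projections, which is how the conversion keeps the network's computation invariant.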
Find a better grouping method
After obtaining the similarity scores between each pair of Attention Heads, the Attention Heads can be regrouped based on these scores. The similarity score of a single group is the sum of the similarity scores between each pair of Attention Heads within that group, while the total similarity score for each grouping result is the sum of the similarity scores of all groups. The goal of the algorithm is to find the grouping method with the highest score.

A reasonable grouping method can make the Attention Heads in the same group more similar in the feature space, making it easier to find a suitable parameter sharing method during pruning, thereby improving the compression effect and performance of the model.
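The grouping objective described above can be illustrated with an exhaustive search that tries every partition of the heads into equal-size groups and keeps the one with the highest total within-group similarity. This is a hedged sketch: the search strategy is our assumption, it is exponential and only usable for toy head counts, and the similarity matrix is synthetic (three pairs of similar heads) rather than computed from a real KV cache.

```python
import itertools
import torch

def best_grouping(sim, group_size):
    """Return the partition of heads into groups of `group_size` that maximizes the
    sum of pairwise similarities inside each group. Exhaustive: toy sizes only."""
    n = sim.shape[0]
    assert n % group_size == 0, "number of heads must be divisible by group_size"

    def partitions(heads):
        if not heads:
            yield []
            return
        first, rest = heads[0], heads[1:]
        for others in itertools.combinations(rest, group_size - 1):
            group = (first,) + others
            remaining = [h for h in rest if h not in others]
            for tail in partitions(remaining):
                yield [group] + tail

    def score(groups):
        return sum(sim[i, j].item() for g in groups for i, j in itertools.combinations(g, 2))

    return max(partitions(list(range(n))), key=score)

# Synthetic similarity: heads with the same label are similar (1.0), others are not (0.1).
labels = torch.tensor([0, 1, 2, 0, 1, 2])
sim = (labels[:, None] == labels[None, :]).float() * 0.9 + 0.1
print(best_grouping(sim, group_size=2))
```

Here the search recovers the pairs (0, 3), (1, 4), (2, 5), even though these heads are not adjacent; a practical implementation for real head counts would need a heuristic search instead of enumeration.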
Pruning training
This step is the pruning training itself: the original key-value heads are gradually replaced by the new key-value heads while model performance is preserved. As shown in the diagram below, the process includes:
- Add new projection matrices: in each group, initialize a new key-value projection matrix by mean pooling.
- Apply an $L_0$ mask: an $L_0$ mask controls the transition between the original KV heads and the new KV heads. The mask is initialized to 1, meaning the original KV heads are used; during pruning, the mask values are gradually constrained to 0, meaning the new KV heads are used.
- Knowledge distillation: KL loss and BiLD loss encourage the student model to align its outputs with those of the teacher model, thereby preserving model performance.
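The mask mechanism in these steps can be sketched as follows. The description only says that an L0 mask is used; this sketch assumes the hard-concrete relaxation of Louizos et al. (2018), a common way to make L0 penalties trainable, with one gate per original head and a linear mix between the original and the new keys. The names HardConcreteGate and mix_kv_heads are hypothetical, and the distillation losses and training loop are omitted.

```python
import math
import torch
import torch.nn as nn

class HardConcreteGate(nn.Module):
    """L0 gate via the hard-concrete relaxation (Louizos et al., 2018).
    A large initial log_alpha makes the deterministic gate exactly 1."""
    def __init__(self, n_gates, init_log_alpha=5.0, beta=2 / 3, gamma=-0.1, zeta=1.1):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.full((n_gates,), init_log_alpha))
        self.beta, self.gamma, self.zeta = beta, gamma, zeta

    def forward(self):
        if self.training:
            # Stochastic sample used during training.
            u = torch.rand_like(self.log_alpha).clamp(1e-6, 1 - 1e-6)
            s = torch.sigmoid((u.log() - (1 - u).log() + self.log_alpha) / self.beta)
        else:
            s = torch.sigmoid(self.log_alpha)
        # Stretch to (gamma, zeta) and clip, so exact 0 and 1 are reachable.
        return (s * (self.zeta - self.gamma) + self.gamma).clamp(0.0, 1.0)

    def expected_l0(self):
        """Expected number of non-zero gates; penalizing it pushes the masks toward 0."""
        return torch.sigmoid(self.log_alpha - self.beta * math.log(-self.gamma / self.zeta)).sum()

def mix_kv_heads(k_orig, k_new, mask):
    """k_orig: original per-head keys (b, n_heads, s, d); k_new: new group keys broadcast
    to the same shape. mask == 1 keeps the original head, mask == 0 uses the new one."""
    m = mask.view(1, -1, 1, 1)
    return m * k_orig + (1 - m) * k_new

gate = HardConcreteGate(n_gates=4).eval()
k_orig, k_new = torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8)
print(gate())  # all ones at initialization: the original heads are used
print(torch.equal(mix_kv_heads(k_orig, k_new, gate()), k_orig))
with torch.no_grad():
    gate.log_alpha.fill_(-5.0)  # the state the L0 penalty drives training toward
print(torch.equal(mix_kv_heads(k_orig, k_new, gate()), k_new))
```

In training, expected_l0() would be added to the distillation loss so that the gates move from 1 to 0 and the model ends up relying only on the new, shared KV heads.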

3.6 Optimization
The paper “A Survey on Large Language Model Acceleration based on KV Cache Management” summarizes MQA, GQA, and their improvement schemes, as shown in the figure below.

Several of these improvements are described below.
- Weighted GQA introduces additional trainable weights for each key and value head and can be seamlessly integrated into existing GQA models. By learning these weights during training, it can improve model performance without increasing inference overhead.

- AsymGQA extends GQA by proposing an activation-informed merging strategy. Instead of grouping heads using uniform clustering, AsymGQA dynamically determines grouping based on activation similarity during training and constructs asymmetric groups, thereby achieving better optimization and generalization.

- QCQA uses an evolutionary algorithm to find the optimal grouping of query heads for GQA. The search is guided by a computationally inexpensive fitness function that uses the weight-sharing error and the KV cache size to evaluate text generation quality and memory capacity.

- KDGQA argues that many GQA variants use fixed grouping strategies and therefore cannot adapt to how key-value interactions evolve during training. Its Dynamic Key-Driven GQA addresses this by adaptively grouping query heads according to key head norms during training, which yields a flexible grouping strategy and improves performance.

- GQKVA proposes a general mechanism for grouping queries, keys, and values. It first introduces MKVA and GKVA, in which keys and values are grouped to share the same query. Building on this, it proposes GQKVA, which groups queries and key-value pairs separately. Typically, queries are divided into $g_q$ groups and key-value pairs into $g_{kv}$ groups; each combination of a query group and a key-value group interacts through dot-product attention, producing $g_q \times g_{kv}$ different outputs. GQKVA generalizes the different grouping strategies for queries, keys, and values while maintaining good computational efficiency and performance comparable to MHA. The figure below illustrates the strategies for grouping queries, keys, and values in attention mechanisms, including vanilla MHA, MQA, GQA, MKVA, GKVA, and GQKVA.
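To make the "no extra inference overhead" claim for Weighted GQA (the first item in the list above) concrete: assuming the trainable weights scale each original head's K (or V) projection before the heads of a group are combined, the weighted combination can be folded back into a single projection matrix after training, leaving a standard GQA layer. The sketch checks this folding; the per-head weight parameterization (alpha) is our assumption, not the paper's exact formulation.

```python
import torch

torch.manual_seed(0)
n_heads, n_kv_heads, head_dim, hidden = 8, 2, 4, 16
group = n_heads // n_kv_heads

# Original per-head K projections, grouped as (n_kv_heads, group, head_dim, hidden).
w_k = torch.randn(n_kv_heads, group, head_dim, hidden)
# Hypothetical trainable per-head weights; alpha = 1/group reproduces mean pooling.
alpha = torch.rand(n_kv_heads, group)

x = torch.randn(5, hidden)
# During fine-tuning: weighted combination of each group's per-head keys.
keys_per_head = torch.einsum("ghdi,ti->tghd", w_k, x)  # (t, n_kv, group, d)
k_weighted = (alpha[None, :, :, None] * keys_per_head).sum(dim=2)  # (t, n_kv, d)

# After training: fold the weights into one projection per KV head.
w_folded = (alpha[:, :, None, None] * w_k).sum(dim=1)  # (n_kv, d, hidden)
k_folded = torch.einsum("gdi,ti->tgd", w_folded, x)  # (t, n_kv, d)
print(torch.allclose(k_weighted, k_folded, atol=1e-5))
```

Because the folded matrix has exactly the shape of an ordinary GQA key projection, the extra weights only exist during training.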

0xFF Reference
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
[LLM Acceleration Techniques] Multi Query Attention and Attention with Linear Bias (with Source Code) He Zhi
https://github.com/meta-llama/llama3
A 20,000-word article! A comprehensive guide to Attention, from MHA to DeepSeek MLA, with numerous diagrams and detailed explanations! ShuYini AINLPer
From MHA, MQA, GQA to MLA Su Jianlin
Alibaba’s first-round coding question: “Implement GQA” (See the image for more information)
MHA -> GQA: Improving LLM Inference Efficiency AI Chat
Align Attention Heads Before Merging Them: An Effective Way for Converting MHA to GQA
FLASHINFER: EFFICIENT AND CUSTOMIZABLE ATTENTION ENGINE FOR LLM INFERENCE SERVING
Kernel design of DeepSeek MLA in FlashInfer yzh119
The Taizu Long Fist of Parallel Inference in Large Models: Deciphering Jeff Dean’s MLSys 23 Outstanding Paper (by Fang Jiarui)
Code transfer for intuitive analysis of MHA, GQA, and MQA on GPUs, triggered by GQA performance data anomalies.
The evolution path from MHA to MQA to GQA to MLA : If I were given an AI
Y. Chen, C. Zhang, X. Gao, R. D. Mullins, G. A. Constantinides, and Y. Zhao, “Optimised Grouped-Query Attention Mechanism for Transformers,” in Workshop on Efficient Systems for Foundation Models II @ ICML2024, Jul. 2024. [Online]. Available: https://openreview.net/forum?id=13MMghY6Kh
S. S. Chinnakonduru and A. Mohapatra, “Weighted Grouped Query Attention in Transformers,” Jul. 2024. [Online]. Available: http://arxiv.org/abs/2407.10855
V. Joshi, P. Laddha, S. Sinha, O. J. Omer, and S. Subramoney, “QCQA: Quality and Capacity-aware grouped Query Attention,” Jun. 2024. [Online]. Available: http://arxiv.org/abs/2406.10247
Z. Khan, M. Khaquan, O. Tafveez, B. Samiwala, and A. A. Raza, “Beyond Uniform Query Distribution: Key-Driven Grouped Query Attention,” Aug. 2024. [Online]. Available: http://arxiv.org/abs/2408.08454
F. Javadi, W. Ahmed, H. Hajimolahoseini, F. Ataiefard, M. Hassanpour, S. Asani, A. Wen, O. M. Awad, K. Liu, and Y. Liu, “GQKVA: Efficient Pre-training of Transformers by Grouping Queries, Keys, and Values,” Dec. 2023. [Online]. Available: http://arxiv.org/abs/2311.03426