
Exploring the Transformer Series (11) --- Mask

Transformer masks: padding mask, sequence/causal mask, implementation details, data flow, and advanced sample-packing strategies.



0x00 Overview

In machine learning, a mask is essentially a tensor of the same size as the target tensor, usually binary (0/1). The idea originated from the CBOW training mechanism of word2vec, which predicts the center word from its context: in effect, the center word is masked out. Different tasks and applications may require different kinds of mask operations. In self-attention models, the two most common are the padding mask and the sequence mask.

  • Padding mask: When processing variable-length sequences, special padding symbols (such as <pad>) are usually added to the end of the sequence to maintain a consistent length. The purpose of the padding mask is to set the attention score at the position corresponding to these padding symbols to a very small value (such as negative infinity), so that the model ignores these padding symbols when calculating the attention score, thus avoiding interference from the padding symbols in the calculation.
  • Sequence mask: In some tasks, the model must be prevented from seeing future information when generating sequences, so the attention scores need to be masked. A sequence mask makes the attention score matrix lower triangular (or upper triangular, depending on convention) by setting the scores for positions after the current one to a very small value. This forces the model to consider only the relationship between the current token and the tokens before it, ignoring subsequent tokens, so that generation relies only on the already generated part and is unaffected by future information: the model only "sees" the current and preceding tokens. Sequence masks are also sometimes called causal masks.

Self-attention that applies a sequence mask is called masked self-attention, also known as causal self-attention.


0x01 Requirements

Let’s analyze in detail why we need a mask.

1.1 Avoiding Deviation

Actual situation

During neural network training, a single batch may contain multiple text sequences of different lengths, but the network requires a regular (rectangular) tensor as input. Therefore, when generating the dataset, we align the input sequences so that all sequences within the same batch have the same length. Specifically:

  • If an input sequence is too long, it is truncated and the extra words are discarded.
  • Shorter sequences are padded with a special symbol (e.g., <pad>).

See the diagram below for details. The diagram illustrates the situation encountered when combining multiple sentences into a batch: the sentences have different lengths. We need to pad or trim all sentences according to a pre-defined maximum length to form multiple sentences of the same length before we can combine them into a batch (a three-dimensional tensor) and feed it into the model for training.

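[Figure: sentences of different lengths padded or truncated to a common maximum length before being combined into a batch]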

The problem

This approach causes a problem when calculating attention: if the padding positions are taken into account during the attention calculation, they introduce errors into the final result. Let's analyze this in detail.

Suppose a row vector of attention scores is $x = [x_1, x_2, \dots, x_n]$, with element $x_i$. The native softmax is computed as:

$$\mathrm{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$

The algorithm requires two loops: the first accumulates the denominator, and the second computes the softmax value of each element, effectively scaling it. Since the padding words are added manually and carry no meaning, we generally don't want the model to spend attention on these irrelevant words when computing attention scores, which would waste computation. Nor do we want these positions to participate in backpropagation, where they could affect the model's final performance. In practice, however, the padding value is generally 0, and since $e^0 = 1$, the padding positions still contribute to the softmax denominator. Including these invalid positions in the calculation skews the attention scores and distorts the global probability distribution, so they need special handling.
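To make the deviation concrete, here is a small sketch assuming PyTorch and a made-up row of five scores in which the last two positions are padding with score 0. The naive softmax gives those positions non-zero weight, so the three real tokens share less than the full probability mass.

```python
import torch

# Hypothetical attention scores for one query: 3 real tokens + 2 padding positions.
scores = torch.tensor([2.0, 1.0, 0.5, 0.0, 0.0])
is_real = torch.tensor([True, True, True, False, False])

# Naive softmax over the whole row: each padding score of 0 contributes e^0 = 1
# to the denominator, so padding "steals" probability mass.
naive = scores.softmax(dim=-1)

# Softmax restricted to the real tokens only (what we actually want).
clean = scores[is_real].softmax(dim=-1)

print(naive)                  # padding positions receive non-zero weight
print(clean)                  # weights over the three real tokens sum to 1
print(naive[is_real].sum())   # less than 1: real tokens lose part of the mass
```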

Solution

The intuitive solution is for the model to focus its attention only on words that actually carry meaning: find all non-pad tokens and compute attention (and the loss) only over them. Equivalently, we can use a mask to exclude the invalid regions from the computation. This is what we call a padding mask.

1.2 Preventing Peeping

Actual situation

First, recall the attention formula, which computes attention over the entire input sequence:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Secondly, encoders and decoders operate in different ways:

  • Because an encoder needs to encode the entire sentence, each word needs to consider its context. Therefore, each word can see all the words in the sentence during the calculation process.
  • The decoder, however, is essentially a unidirectional self-attention structure in which each word can only see the words before it. The reason is that inference is autoregressive: tokens are produced one by one, and the decoder has no knowledge of what comes next. At each step it can only see the prompt and the tokens it has already generated, so it cannot compute attention between the current word and words that have not yet appeared.

The decoder's operating mode requires special handling during training. If the decoder were also trained autoregressively, one step at a time, training would be too slow. As mentioned earlier, Teacher Forcing is adopted to accelerate training: it uses a matrix-parallel computation similar to the encoder's and predicts all target words in one step. This has two advantages: first, parallel computation over positions speeds up training; second, feeding the decoder the correct tokens instead of its own predictions from the previous time step (which may be wrong during training) lets the model converge faster.

Let's set Teacher Forcing aside for a moment and focus on parallel computation. The simplest training method would be to construct $n$ samples (each a prefix of the prediction sequence) from a prediction sequence of length $n$ and feed them into the model in parallel. For the first sample, the model predicts the first character; for the last sample, it uses the first $n-1$ characters to predict the $n$-th character. Because the label already provides the teacher token for every time step, the decoder can compute in parallel just like the encoder: it receives the inputs for all time steps at once and predicts the token at every position simultaneously.
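The construction can be sketched in plain Python. The target sentence and token names below are hypothetical; the sketch only shows how one shifted copy of the target yields every prefix-to-next-token pair at once, which is also why each position must later be kept from seeing its own future.

```python
# Hypothetical target sentence wrapped in start/end tokens.
target = ["<bos>", "I", "love", "China", "<eos>"]

# Teacher forcing: the decoder input is the ground truth without its last token,
# and the labels are the ground truth without its first token (shifted by one).
decoder_input = target[:-1]   # ['<bos>', 'I', 'love', 'China']
labels = target[1:]           # ['I', 'love', 'China', '<eos>']

# Position t should only use the prefix decoder_input[:t + 1] to predict labels[t].
# Without a sequence mask, attention would let it see the whole decoder_input.
for t, label in enumerate(labels):
    visible = decoder_input[: t + 1]
    print(f"step {t}: sees {visible} -> predicts {label!r}")
```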


The problem

At this point, each sample actually contains the entire sentence. However, when predicting the token at time $t$, the decoder may only use the target tokens from times $1$ to $t-1$; it must not use the token at time $t$ or beyond. In other words, the model should make each prediction from only part of the input. The problem is that the entire target sentence is fed to the decoder at once, so the model already knows the whole sentence. When predicting the word at some position, the decoder could therefore use target words both before and after it, "cheating" with target words from future time steps. For example, when training on the target "I love China", the model could use the information about "China" while it is supposed to be predicting "love".

As the saying goes, "Heavenly secrets should not be revealed." If a model could see the token it is supposed to output next, it could easily learn to be lazy: instead of computing the output, it could simply copy the next element of the input sequence, rendering training ineffective. Furthermore, because a Transformer stacks multiple attention layers, the leak compounds: in the first layer, the current token $x_i$ absorbs information from the next token $x_{i+1}$; when $x_{i+1}$ is computed in the next attention layer, it sees that $x_i$ already contains $x_{i+1}$. This amounts to using a token's own information to predict itself, which is clearly a form of information leakage.

Therefore, during training, the decoder should not know the following information in advance, and should not calculate the attention between the current word and the following words, but only the attention between the current word and all preceding words.

Solution

To ensure the model isn't influenced by future words at a given time step, the decoder employs a sequence mask, which hides the information after time step $t$. The decoder then sees only a prefix of the target sequence and cannot see future information. In other words, the decoding output at time step $t$ should depend only on the outputs before $t$, not on anything at or after $t$. This is the essence of a sequence mask. You can think of it as a timeline: when predicting a word, you cannot use the words that follow it because, in reality, they haven't occurred yet.


In summary, the purpose of the padding mask is to avoid deviations caused by padding symbols. The purpose of the sequence mask is to shield future information, prevent peeking, and ensure that each position can only see the preceding tokens.

0x02 Padding Mask

Next, let’s see how Padding Mask is implemented.

2.1 Logic

The core logic is to ensure that padding words receive no attention weight after the softmax operation, so they contribute nothing to the output.

Mask matrix

One method researchers found was to use a mask matrix during training. For sentences in a batch that have already been padded to the same length, a mask matrix covers the padded positions before the softmax function is applied. There are different implementations of the mask matrix:

  • Each value in the matrix is a Boolean, and the positions marked False are the ones to be masked.
  • In the mask matrix, a very large negative number (such as -1e9) is placed at the corresponding position of the filler word, otherwise 0 is placed.

After the masking matrix is processed, the attention scores obtained by these masked positions after passing through the softmax activation function will be zero or close to zero. In this way, the token representations at the corresponding positions will not participate in the weighted summation process mentioned above, that is, no attention is allocated to them, and they will no longer affect the prediction of the global probability.

Calculate attention steps

The specific steps for attention calculation after adding a mask are as follows:

  1. Create a mask matrix. If a position in the input sequence is a filler word, place a very large negative number (e.g., -1e9) at the corresponding position in the mask matrix; otherwise, place 0.
  2. Add the mask matrix to the attention scores. Because the mask holds very large negative numbers at the padding positions, the attention scores at those positions also become very large negative numbers.
  3. Apply the softmax function to the masked attention scores. Since the scores at the padding positions are very large negative numbers, their weights after softmax are close to 0, while the weights at the remaining positions are renormalized so that they sum to 1 (softmax is a normalization function).
  4. Calculate the weighted sum. Using the output of softmax as weights, calculate the weighted sum of the values.
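The four steps can be sketched as follows, assuming PyTorch, an additive mask of -1e9, and a hypothetical helper named masked_attention_weights (not part of any library). It checks that padding columns receive essentially zero weight and that the remaining weights equal a softmax over the real tokens alone.

```python
import torch

def masked_attention_weights(scores: torch.Tensor, is_pad: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper. scores: [batch, len_q, len_k]; is_pad: [batch, len_k], True at padding."""
    # Step 1: additive mask, -1e9 at padding keys and 0 elsewhere; shape [batch, 1, len_k]
    # so that it broadcasts over every query row.
    mask = torch.zeros(is_pad.shape, dtype=scores.dtype).masked_fill(is_pad, -1e9).unsqueeze(1)
    # Step 2: add the mask to the raw scores.
    masked_scores = scores + mask
    # Step 3: softmax over the key dimension; padding columns become ~0.
    return masked_scores.softmax(dim=-1)

torch.manual_seed(0)
scores = torch.randn(1, 4, 4)                      # random stand-in for Q K^T / sqrt(d_k)
is_pad = torch.tensor([[False, False, True, True]])
weights = masked_attention_weights(scores, is_pad)

# Step 4: weighted sum of the values.
value = torch.randn(1, 4, 8)
output = weights @ value
```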

In the image below, the top part shows the masking operation corresponding to the encoder input, and the bottom part shows the masking operation corresponding to the decoder input.

[Figure: masking applied to the encoder input (top) and to the decoder input (bottom)]

2.2 Implementation

Let's analyze the Harvard NLP code from The Annotated Transformer. For clarity, we include the padding code as well.

Set the padding symbol

Taking the target sentence as an example, the collate_batch() function pads the target sentence when loading data.

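def collate_batch(
    batch,  # list of sentence pairs
    max_padding=128,  # maximum sentence length
    pad_id=2,  # id of the padding token <blank>
):
    # Other code omitted

        processed_tgt = torch.cat(  # build the target sentence: <s> id + token ids + </s> id
            [
                bs_id,
                torch.tensor(
                    tgt_vocab(tgt_pipeline(_tgt)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )

        """
        Call pad() (torch.nn.functional.pad) on processed_tgt: if processed_tgt is shorter
        than max_padding, it is right-padded with pad_id; if it is longer, the negative
        padding amount truncates it. The result is then appended to tgt_list.
        """
        tgt_list = []  # in the full code, tgt_list is created once, before the loop over the batch
        tgt_list.append(
            pad(
                processed_tgt,
                (0, max_padding - len(processed_tgt)),
                value=pad_id,
            )
        )

    # Other code omitted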

Create a mask

Here, we extract the mask-related part of the Batch class for further analysis. The statement that generates src_mask is a single line, self.src_mask = (src != pad).unsqueeze(-2), which serves two purposes:

  • It sets the non-pad elements of src to True and the pad elements to False. For example, if a sentence is [0, 3, 1, 2, 2], its mask is [True, True, True, False, False]. Here 0 and 1 are the <s> and </s> tokens, which count as sentence components and are therefore not masked.
  • unsqueeze() adds a dimension, because src_mask is later combined with the attention scores, which have three dimensions, so the shapes must stay consistent. The final src_mask is a Boolean tensor of shape [batch size, 1, longest sentence length]. After broadcasting over the query dimension, the entry in row i and column j indicates whether the attention of the i-th query word to the j-th key word is meaningful: True if the key is a real token, False if it is padding. When the mask is applied, positions marked False are masked, while positions marked True remain unchanged. After this processing, padding placeholders can no longer absorb the query's attention.
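
class Batch:
    def __init__(self, src, tgt=None, pad=2):  # 2 = <blank>
        self.src = src  # source-language sentences, shape [batch_size, length]
        # Create the source-language mask so that padding is ignored. unsqueeze() adds a
        # dimension because the mask is later combined with the attention scores, which
        # have three dimensions, so the shapes must stay consistent.
        # (src != pad) returns a Boolean tensor of the same size: False where src equals pad, True elsewhere.
        # unsqueeze(-2) (equivalent to unsqueeze(1) for a 2-D src) gives shape [batch_size, 1, seq_len].
        # The result is a [batch_size, 1, seq_len] Boolean tensor; False marks the positions to be masked.
        self.src_mask = (src != pad).unsqueeze(-2)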
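A quick check of that one line, assuming PyTorch and the Harvard convention that <blank> has id 2, reproduces the example from the text:

```python
import torch

pad = 2                                  # id of <blank>, as in the Batch class above
src = torch.tensor([[0, 3, 1, 2, 2]])    # <s>, a word, </s>, then two padding tokens
src_mask = (src != pad).unsqueeze(-2)

print(src_mask.shape)                    # torch.Size([1, 1, 5])
print(src_mask.tolist())                 # [[[True, True, True, False, False]]]
```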

Implement mask

The code that applies the mask matrix is located in the attention() function. Note that here, both the padding mask and the sequence mask are applied together.

import math
import torch

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"

    # First compute the attention scores
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # After multiplying query by the transpose of key to get the (len_q, len_k) score matrix,
    # use the mask to cover the result. Creating and applying the mask are merged into one step here.
    if mask is not None:
        # Wherever the mask is 0 (False), replace the score with -1e9
        scores = scores.masked_fill(mask == 0, -1e9)

    # Only then apply softmax
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
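As a usage sketch, the function can be run on a toy batch; the random inputs and seed are arbitrary, and dropout is omitted so the block stands alone. The padding columns of p_attn should be essentially zero for every query row.

```python
import math
import torch

def attention(query, key, value, mask=None):
    # Same computation as attention() above, with dropout omitted.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    return torch.matmul(p_attn, value), p_attn

torch.manual_seed(0)
x = torch.randn(1, 5, 4)                     # [batch, seq_len, d_k], used as Q, K and V
src = torch.tensor([[0, 3, 1, 2, 2]])        # last two tokens are <blank> (id 2)
src_mask = (src != 2).unsqueeze(-2)          # [1, 1, 5], broadcast over the 5 query rows
out, p_attn = attention(x, x, x, mask=src_mask)
print(p_attn[0, :, 3:].abs().max().item())   # ~0: no query attends to padding
```

Note that the padding mask only hides padding keys: the query rows at padding positions still produce outputs, which are typically ignored later, for example by excluding those positions from the loss.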

0x03 Sequence mask

3.1 Logic

The core logic of sequence masking is to hide the information after the current time step during decoding, so we need a way to hide the information after time step t. Sequence masking is needed during training and during the prefill phase of autoregressive inference; the decode phase of inference, in which a single new token attends only to already-computed past positions, does not require it. For ease of implementation, however, the same code is still used.
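The claim that the decode phase needs no mask can be checked with a small sketch (PyTorch, random tensors, arbitrary sizes): running all positions together with a causal mask gives the same per-position outputs as processing one query at a time against only the keys and values up to that position, which is what a key-value cache provides.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d = 6, 8
q, k, v = (torch.randn(seq_len, d) for _ in range(3))    # random stand-ins for Q, K, V

# Prefill: all positions at once, so the causal mask is required.
scores = q @ k.T / math.sqrt(d)
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
prefill_out = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1) @ v

# Decode: one query per step against the cached keys/values of positions 0..t.
# None of the cached positions lies in the query's future, so no mask is applied.
decode_out = torch.stack([
    (q[t] @ k[: t + 1].T / math.sqrt(d)).softmax(dim=-1) @ v[: t + 1]
    for t in range(seq_len)
])

print(torch.allclose(prefill_out, decode_out, atol=1e-6))
```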

Mask matrix

We need to generate a mask matrix, which we will add when calculating attention. By designing a suitable mask, we can cut off the path for each element to obtain information from the future (by forcing the corresponding attention to zero), thus shielding or limiting the model’s attention to certain positions when calculating the attention score. The characteristics of this mask matrix are as follows:

  • The shape of this matrix is the same as that of the attention distribution matrix, with a size of [seq_len, seq_len].
  • In terms of content, it is a lower triangular matrix, and the exact values depend on the representation. For a Boolean (0/1) matrix, the upper triangle is 0, while the lower triangle and the diagonal are 1. For a floating-point (additive) matrix, the upper triangle is set to negative infinity and the rest to 0. Either way, the attention between each query position and each key position can be controlled individually.
  • This matrix is applied to each sequence before the softmax calculation. That is, in $\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, the mask is applied to the scaled dot product, and the masked elements are set to negative infinity (meaning "infinitely dissimilar," i.e., unrelated). In other words, the inner product of query(t) with the key of any future time step becomes -inf.
  • During the softmax operation, the weights corresponding to these negative infinity values become zero. When the weights are subsequently multiplied by V, the current position has no access to subsequent word information. Therefore, the output at time t uses only the key-value pairs from time steps 1 through t.
  • This operation lets us compute the loss for the entire decoder output sequence at once instead of token by token; combined with Teacher Forcing, it enables the parallel training described earlier.

An example mask matrix is shown below: a 10 × 10 lower triangular matrix. When decoding the first character, attention can only be computed between the first character and itself. When decoding the second character, it can only be computed between the second character and the first and second characters, and so on.

[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
 [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
 [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
 [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
 [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
 [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]

The specific formula is as follows.

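$$\mathrm{MaskedAttention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V, \qquad M_{ij} = \begin{cases} 0, & j \le i \\ -\infty, & j > i \end{cases}$$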
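The mask above and the formula can be built directly. This sketch assumes PyTorch, uses torch.tril for the 0/1 form, and derives the additive form M with -inf above the diagonal; random scores stand in for QK^T / sqrt(d_k).

```python
import torch

n = 10
# 0/1 form: 1 on and below the diagonal, matching the 10 x 10 matrix above.
bool_mask = torch.tril(torch.ones(n, n, dtype=torch.int64))
# Additive form M: 0 where attention is allowed, -inf above the diagonal.
M = torch.zeros(n, n).masked_fill(bool_mask == 0, float("-inf"))

torch.manual_seed(0)
scores = torch.randn(n, n)                  # stand-in for Q K^T / sqrt(d_k)
weights = (scores + M).softmax(dim=-1)

print(bool_mask[1].tolist())                # [1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
print(weights[1, 2:].sum().item())          # 0.0: step 2 puts no weight on later steps
```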

Masked self-attention

Next, let's look at Masked Self-Attention. In the decoder block, the input sequence first passes through Masked Self-Attention ("masked" meaning covered). Its Q, K, and V all come from the same input, so max_len_q = max_len_k_v = max_len. The difference from the multi-head attention calculation described above is that the score matrix undergoes a masking operation before it is fed into softmax to compute the weight matrix. As a result, each word in the sentence can only attend to the words before it, including itself; this is essentially a unidirectional Transformer. This also reveals the design motivation of Masked Self-Attention: to prevent the model from seeing inputs at future time steps and to keep the decoder's behavior consistent between training and prediction.

We use the first decoder layer to explain its operation sequence as follows.

  1. After input embedding and positional encoding, we obtain the word embedding matrix $X$.
  2. $X$ is multiplied by the three weight matrices $W^Q$, $W^K$, $W^V$. These three linear transformations give the matrices $Q = XW^Q$, $K = XW^K$, $V = XW^V$.
  3. $Q$ is multiplied by the transpose of $K$ and scaled, giving the attention score matrix $S = \frac{QK^\top}{\sqrt{d_k}}$.
  4. $S$ is combined element-wise with the mask matrix $M$, giving the masked attention scores $S' = S + M$. This preserves what should be visible during decoding and hides what should not be. Specifically, it keeps the lower triangular portion of the score matrix (including the diagonal) unchanged while setting the entire upper triangular portion to negative infinity. After this processing, the i-th row of the score matrix, corresponding to the i-th time step of q, retains only the scores between q and the keys of the first i time steps; the rest are masked.
  5. Applying softmax gives the weight matrix $A = \mathrm{softmax}(S')$.

Obviously, the masked parts (set to -inf) become 0 after softmax. That is, in the i-th row of the weight matrix, the first i weights sum to 1 and all later weights are 0.

  6. Finally, $A$ is multiplied by $V$ to obtain the context matrix $Z = AV$. As discussed earlier, the i-th row of $Z$ is a weighted average of all rows of $V$ using the weights in the i-th row of $A$. After masking, however, only the first i weights in that row are non-zero, so the i-th row of $Z$ is actually a weighted average of the first i rows of $V$.
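The six steps can be traced in a short sketch, assuming PyTorch, a single head, and random weights in place of trained ones. It also checks the conclusion of step 6: changing the last input row leaves every earlier row of Z untouched.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_model, d_k = 4, 8, 8
W_q, W_k, W_v = (torch.randn(d_model, d_k) for _ in range(3))   # random stand-ins for learned weights

def masked_self_attention(X):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v                          # step 2: three linear maps
    S = Q @ K.T / math.sqrt(d_k)                                 # step 3: scaled scores
    n = X.size(0)
    M = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)  # 0 on/below diag, -inf above
    A = (S + M).softmax(dim=-1)                                  # steps 4-5: mask, then softmax
    return A @ V, A                                              # step 6: Z = AV

X = torch.randn(seq_len, d_model)       # step 1: stand-in for embeddings + positional encoding
Z, A = masked_self_attention(X)

# Perturb only the last position: earlier rows of Z must not change.
X_changed = X.clone()
X_changed[-1] += 1.0
Z_changed, _ = masked_self_attention(X_changed)
print(torch.allclose(Z[:-1], Z_changed[:-1]))
```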

Furthermore, the first token of each sentence in Y is the start symbol, so the information in Y is shifted right by one time step. Therefore, in masked multi-head self-attention, computing the context of the i-th time step actually uses only the information of the previous i-1 time steps.

The above describes the matrix $Z$ obtained from a single attention head. With multi-head attention, the $Z$ matrices from all heads are concatenated, and a linear transformation is applied to obtain the final result matrix.


Cross attention

Now consider this question: Does the cross-attention process following masked attention also need an attention mask? The answer is no.

In the decoder's first MultiHeadAttention module (the masked one), q, k, and v all come from the same source: the target word embeddings or the output of the previous decoder layer. In the subsequent MultiHeadAttention module, however, k and v come from the output of the last encoder layer, while q comes from the output of the masked MultiHeadAttention module. Because the encoder sees the entire input sequence and has already captured all of its information, the decoder's Q may attend to all of the encoder's K and V. In other words, during both training and prediction, the decoder is allowed to see all the information in the source sequence; this information does not need a sequence mask. In practice, however, src_mask, the source-language padding mask, is still needed.

In summary, for the decoder's self-attention, the two masks are merged by taking the minimum at each position (for 0/1 masks, this is a logical AND): a position is masked if either mask masks it. See the diagram below for details.

[Figure: merging the padding mask and the sequence mask into the decoder mask]
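The equivalence between the element-wise minimum of 0/1 masks, a Boolean AND, and the sum of additive (0/-inf) masks can be checked with a sketch; the 4-token example and the helper to_additive are hypothetical.

```python
import torch

# Hypothetical 0/1 masks for a 4-token target whose last token is padding (1 = keep, 0 = mask).
pad_mask = torch.tensor([1, 1, 1, 0]).expand(4, 4)          # same key-padding row for every query
seq_mask = torch.tril(torch.ones(4, 4, dtype=torch.int64))  # 1 on and below the diagonal

merged_min = torch.minimum(pad_mask, seq_mask)   # element-wise minimum
merged_and = pad_mask.bool() & seq_mask.bool()   # Boolean AND, the form used by the Harvard code

def to_additive(mask01):
    """Hypothetical helper: 0 where kept, -inf where masked."""
    return torch.zeros(mask01.shape).masked_fill(mask01 == 0, float("-inf"))

# Summing additive masks: any -inf term keeps the position masked.
merged_add = to_additive(pad_mask) + to_additive(seq_mask)

print(merged_min.tolist())  # [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 0]]
```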

3.2 Implementation

Generate mask

Here we will extract the mask-related part of the Batch class and analyze it further.

The statement that generates src_mask is the single line self.src_mask = (src != pad).unsqueeze(-2). It is explained in the padding mask implementation above and is not repeated here.

Generating tgt_mask is more complex; the logic is in the make_std_mask() function. Unlike src_mask, tgt_mask must cover not only the pad positions but also everything above the main diagonal (the upper-right triangle), which means combining the padding mask with the mask for future words. make_std_mask() works as follows:

  • First, generate the mask corresponding to the filler words. Suppose the content of a sentence is [0, 3, 1, 2, 2], then its corresponding mask is [True, True, True, False, False].
  • Then the subsequent_mask() function is called to generate a mask related to future words. This is a matrix where the diagonal and the area below it are all True. The specific mask is as follows.
[[
  [ True, False, False, False, False ],
  [ True, True, False, False, False ],
  [ True, True, True, False, False ],
  [ True, True, True, True, False ],
  [ True, True, True, True, True ],
]]
  • Finally, the mask corresponding to the fill word and the mask related to the future word are ANDed together to obtain the final mask as follows:
[[
  [ True, False, False, False, False ],
  [ True, True, False, False, False ],
  [ True, True, True, False, False ],
  [ True, True, True, False, False ],
  [ True, True, True, False, False ],
]]

The source code for the make_std_mask() function is as follows.

@staticmethod
def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."

    # Generate the padding mask so that padding is ignored
    tgt_mask = (tgt != pad).unsqueeze(-2)  # target-language mask that ignores padding

    """
    subsequent_mask() generates the mask for future words, which is then ANDed with
    tgt_mask to obtain the final mask.
    tgt.size(-1) is the sequence length.
    """
    tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(
        tgt_mask.data
    )
    return tgt_mask

The source code for the subsequent_mask() function is as follows.

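def subsequent_mask(size):
    """
    Mask out subsequent positions.
    Used when building the mask for tgt.
    """

    # Define the shape of the mask tensor: a matrix of shape (1, size, size).
    # The leading 1 keeps the dimensions consistent with tgt, whose first dimension is batch_size.
    attn_shape = (1, size, size)

    # torch.triu() produces an upper triangular matrix. Points to note:
    # 1. diagonal=1 excludes the main diagonal (it starts one step above it).
    # 2. torch.ones fills the matrix with 1s and triu keeps only the upper-right part,
    #    so the result has 1s strictly above the diagonal and 0s elsewhere.
    # 3. To save space, the matrix is converted to uint8.
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )

    # subsequent_mask == 0 inverts the triangle: every 0 becomes True and every 1 becomes False,
    # leaving True on and below the diagonal.
    return subsequent_mask == 0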

Let’s print the output and see. The result of print(subsequent_mask(5)) is as follows.

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])

It outputs a square matrix whose diagonal and lower-left triangle are True and whose upper-right triangle is False. The first row has only the first column True, meaning that at time 1 only input 1 can be attended to. The third row indicates that at time 3 inputs 1, 2, and 3 can be attended to, but not inputs 4 and 5, because for the decoder these are future information.
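To connect the two functions with the final combined mask shown earlier, the sketch below re-declares them (subsequent_mask as above, make_std_mask without the @staticmethod wrapper) and applies them, assuming PyTorch, to the example sentence [0, 3, 1, 2, 2] with pad id 2.

```python
import torch

def subsequent_mask(size):
    # (1, size, size) Boolean matrix: True on and below the diagonal.
    attn_shape = (1, size, size)
    upper = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
    return upper == 0

def make_std_mask(tgt, pad):
    # Step 1: padding mask, shape (batch, 1, seq_len).
    tgt_mask = (tgt != pad).unsqueeze(-2)
    # Step 2: AND with the causal mask; broadcasting gives (batch, seq_len, seq_len).
    return tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)

tgt = torch.tensor([[0, 3, 1, 2, 2]])   # <s>, word, </s>, <blank>, <blank>
mask = make_std_mask(tgt, pad=2)
print(mask.shape)                       # torch.Size([1, 5, 5])
print(mask[0].int().tolist())
```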

Apply mask

The sequence mask is applied in the same attention() function as the padding mask, shown earlier, so it is not elaborated on here.

3.3 Hugging Face Transformers

Let's look at the corresponding code in the Hugging Face transformers library. It is basically the same as the Harvard approach, except that it also supports a key-value cache through past_key_values_length.

import torch

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
):
    """
    Create a causal (unidirectional) mask for decoder self-attention.

    Args:
        input_ids_shape (torch.Size): The shape of input_ids tensor, typically (batch_size, tgt_len).
        dtype (torch.dtype): The (floating-point) data type of the mask.
        device (torch.device): The device on which the mask will be placed.
        past_key_values_length (int, optional): The length of past key values. Default is 0.

    Returns:
        torch.Tensor: The causal mask tensor of shape
        (batch_size, 1, tgt_len, tgt_len + past_key_values_length).
    """
    bsz, tgt_len = input_ids_shape
    # Start with every entry set to the most negative value of dtype (masked).
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    # Entry (i, j) satisfies j < i + 1, i.e. j <= i: on or below the diagonal, set to 0 (visible).
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        # Cached past positions are always visible, so prepend a block of zeros.
        mask = torch.cat(
            [
                torch.zeros(
                    tgt_len, past_key_values_length, dtype=dtype, device=device
                ),
                mask,
            ],
            dim=-1,
        )
    # Add batch and head dimensions and broadcast to the batch size.
    return mask[None, None, :, :].expand(
        bsz, 1, tgt_len, tgt_len + past_key_values_length
    )
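A minimal usage sketch, assuming PyTorch: make_causal_mask below is a hypothetical re-statement of the same logic that takes the batch size and length as integers and runs on the CPU. With a cache of two past positions, each of the three new queries can see both cached positions plus its own prefix.

```python
import torch

def make_causal_mask(bsz, tgt_len, dtype=torch.float32, past_key_values_length=0):
    # Hypothetical re-statement of _make_causal_mask above (CPU only, shape given as ints).
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    cond = torch.arange(tgt_len)
    mask.masked_fill_(cond < (cond + 1).view(tgt_len, 1), 0)   # j <= i -> 0 (visible)
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        # Cached past keys are always visible: prepend a block of zeros.
        past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype)
        mask = torch.cat([past, mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

m = make_causal_mask(bsz=1, tgt_len=3, past_key_values_length=2)
print(m.shape)                          # torch.Size([1, 1, 3, 5])
print((m[0, 0] == 0).int().tolist())    # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]
```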

0x04 Data Flow

The Harvard code expresses the two types of masks through two variables, and because the encoder and decoder modules share them, the logic can be hard to follow. Let's examine the data flow more closely. In general, the encoder and decoder have the following mask requirements:

  • The encoder should not attend to <pad>, which is not a sentence component. However, there is no need to prevent "peeping into future information."
  • The decoder must not relate a word to the words that follow it, nor to padding, so the padding mask and the sequence mask coexist.

Here’s another table where you can see the characteristics of the two variables in the code.

| Variable name | Mask type | Encoder self-attention | Decoder masked self-attention | Decoder cross-attention |
|---|---|---|---|---|
| src_mask | Padding mask | Used | Not used (padding is handled in tgt_mask) | Used |
| tgt_mask | Padding mask + sequence mask | Not used | Used | Not used |

4.1 How to apply it to attention

Let’s first look at which type of attention in which module the two types of masks should logically be used for.

Padding mask. Padding masks are used wherever there is padding, so both the encoder and decoder have padding masks.

  • Because encoding doesn’t require masking information beyond the current time step, information from any position can be obtained by words at any position. Therefore, the encoder’s mask is simply a padding mask. This is used in self-attention.
  • For the decoder:
    • Padding masks are used in cross-attention.
    • Padding masks are used in self-attention.

Sequence Mask (Attention Mask)

  • The decoder’s cross-attention does not require a sequence mask because the encoder’s output, as K and V, already contains all the information about the sequence.
  • Sequence masks are used in the decoder's self-attention, where they prevent the decoder from "peeking" at the remaining time steps of the target sentence while predicting the current one. Therefore, the scaled dot-product attention in the decoder's self-attention needs both a padding mask and a sequence mask as its attn_mask. The two masks are combined into one: ANDed when they are Boolean masks (as in make_std_mask()), or added when they are additive masks of 0 and -inf.

In fact, if we want to restrict the decoder to accessing only part of the encoder information in cross-attention (the encoder output, which PyTorch calls memory), we can also use a mask. PyTorch's nn.Transformer provides a memory_mask argument for this, but in typical scenarios the decoder is allowed to access all encoder information, so memory masks are rarely used.
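For comparison, PyTorch's built-in nn.Transformer takes these masks as separate arguments. The sketch assumes a recent PyTorch (with batch_first), tiny random dimensions, and already-embedded inputs. Note that its Boolean masks use the opposite convention to the Harvard code: True means the position may not be attended to. Changing only the last target position leaves the earlier decoder outputs unchanged, which confirms the sequence mask.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, batch, src_len, tgt_len = 16, 2, 5, 4
model = nn.Transformer(
    d_model=d_model, nhead=2, num_encoder_layers=1, num_decoder_layers=1,
    dim_feedforward=32, dropout=0.0, batch_first=True,
).eval()

# Random stand-ins for already-embedded inputs: [batch, length, d_model].
src = torch.randn(batch, src_len, d_model)
tgt = torch.randn(batch, tgt_len, d_model)

# PyTorch Boolean masks: True = "do not attend" (opposite of the Harvard convention).
src_pad = torch.tensor([[False, False, False, True, True],     # sentence 1: last two are padding
                        [False, False, False, False, False]])  # sentence 2: no padding
tgt_causal = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()
tgt_pad = torch.zeros(batch, tgt_len, dtype=torch.bool)        # no target padding in this toy batch

def run(tgt):
    with torch.no_grad():
        return model(
            src, tgt,
            src_key_padding_mask=src_pad,      # encoder self-attention: padding mask
            tgt_mask=tgt_causal,               # decoder self-attention: sequence mask
            tgt_key_padding_mask=tgt_pad,      # decoder self-attention: padding mask
            memory_key_padding_mask=src_pad,   # cross-attention: source padding mask
        )

out = run(tgt)
# Changing only the last target position must not affect earlier outputs.
tgt_changed = tgt.clone()
tgt_changed[:, -1] += 1.0
out_changed = run(tgt_changed)
print(torch.allclose(out[:, :-1], out_changed[:, :-1], atol=1e-5))
```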

4.2 Variable Description

In the code, two variables relate to the mask: src_mask and tgt_mask. The encoder uses only src_mask; the decoder uses both src_mask and tgt_mask. src_mask is the padding mask, while tgt_mask is a combined mask that includes both the padding mask and the sequence mask.

Setting the mask in the Batch class takes two steps, after which tgt_mask is the combined mask:

  1. Step 1: set the padding mask.
  2. Step 2: apply the sequence mask on top of the padding mask.

The specific code is:

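import torch

def subsequent_mask(size):
    "Mask out subsequent positions: True on and below the diagonal."
    attn_shape = (1, size, size)
    return torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8) == 0

def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    # Note: this happens in two steps
    tgt_mask = (tgt != pad).unsqueeze(-2)  # Step 1: padding mask, shape (batch, 1, seq_len)
    tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(
        tgt_mask.data
    )  # Step 2: sequence mask under the padding-mask constraint, shape (batch, seq_len, seq_len)
    return tgt_mask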

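By the time it reaches the attention function, src_mask has shape (batch size, 1, 1, seq_length). The Batch class builds it as (batch size, 1, seq_length), and MultiHeadedAttention adds the head dimension with unsqueeze(1). The reasons for this shape are:

  • On the source side only the padding tokens need to be hidden, and they are positions along the last (key) dimension, so a single vector per sentence is enough.
  • All heads share the same mask, so the second (head) dimension is 1 and masked_fill broadcasts it across heads.
  • In the encoder’s self-attention every time step may attend to all other non-padding time steps, so every query row is identical; the third (query) dimension is therefore also 1 and is broadcast as well.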

The shape of tgt_mask is (batch size, 1, seq_length, seq_length). The target must be masked in a triangular pattern: each time step (query row) sees a different set of positions, so a full square matrix over the time steps is needed.
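A minimal sketch of how these shapes broadcast against the attention scores, assuming the Harvard convention (1/True = keep, 0/False = mask) and that multi-head attention adds the head dimension with unsqueeze(1). The token ids and pad id are made up, and subsequent_mask is written here with torch.tril, which gives the same boolean mask as the triu-based helper above.

```python
import torch

def subsequent_mask(size):
    # True on and below the diagonal: query position i may attend to keys 0..i
    return torch.tril(torch.ones(1, size, size, dtype=torch.bool))

pad = 0                                  # hypothetical padding-token id
src = torch.tensor([[5, 6, 7, 0, 0],     # 3 tokens + 2 padding positions
                    [5, 6, 8, 9, 7]])    # 5 tokens, no padding
tgt = src.clone()                        # reuse the same toy ids as the target

src_mask = (src != pad).unsqueeze(-2)                                  # (batch, 1, seq_len)
tgt_mask = (tgt != pad).unsqueeze(-2) & subsequent_mask(tgt.size(-1))  # (batch, seq_len, seq_len)

# Multi-head attention adds a head dimension so one mask serves every head
src_mask_4d = src_mask.unsqueeze(1)      # (batch, 1, 1, seq_len)
tgt_mask_4d = tgt_mask.unsqueeze(1)      # (batch, 1, seq_len, seq_len)

heads = 2
scores = torch.randn(2, heads, 5, 5)     # (batch, heads, query_len, key_len)
enc_scores = scores.masked_fill(src_mask_4d == 0, -1e9)  # broadcast over heads and queries
dec_scores = scores.masked_fill(tgt_mask_4d == 0, -1e9)  # broadcast over heads only

print(src_mask_4d.shape, tgt_mask_4d.shape)  # torch.Size([2, 1, 1, 5]) torch.Size([2, 1, 5, 5])
```

The size-1 dimensions are what let a single padding vector cover every head and every query row, while the decoder mask keeps a full query-by-key matrix.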

Encoder data flow

Let’s take an example. For simplicity, assume batch = 2, heads = 2, and a maximum sequence length of 5. The first sequence has length 3 and the second has length 5 (你 means “you”; 你好吗 means “how are you?”):

  • [<bos>, 你, <eos>, <pad>, <pad>]
  • [<bos>, 你, 好, 吗, <eos>]

The encoder’s mask is just a padding mask, because padding positions should receive no attention weight and must not interfere with the representations of real words. The mask has shape (2, 1, 1, 5) and can be represented by two vectors:

  • The first vector is {1,1,1,0,0}: the first three positions of the first sentence are tokens and the last two are padding, so the last two positions are masked. Every position in this sequence can therefore attend only to the first three positions.
  • The second vector is {1,1,1,1,1}: every position in this sequence can attend to all five positions.

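In actual computation there are multiple heads. For the first sequence, the (1, 1, 5) mask is first broadcast across both heads and all five query rows, so that each head effectively sees a 5 × 5 mask whose last two columns are 0.

The mask is then applied: every score in those two columns is replaced by a very large negative value (-1e9 in the Harvard code) before softmax, so those positions receive zero attention weight.

For the second sequence, the mask is broadcast across both heads in the same way; since it contains no zeros, the scores are the same before and after masking.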
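The encoder example with the two sequences above can be reproduced in a few lines. This is a sketch with made-up token ids and random scores standing in for QK^T/√d_k; only the masking step follows the Harvard code (masked_fill with -1e9 before softmax).

```python
import torch

PAD, BOS, EOS = 0, 1, 2        # hypothetical special-token ids
NI, HAO, MA = 3, 4, 5          # hypothetical ids for 你, 好, 吗

src = torch.tensor([[BOS, NI, EOS, PAD, PAD],
                    [BOS, NI, HAO, MA, EOS]])
batch, seq_len, heads = src.size(0), src.size(1), 2

# Padding mask as it reaches attention: (batch, 1, 1, seq_len), True = real token
src_mask = (src != PAD).unsqueeze(-2).unsqueeze(1)

torch.manual_seed(0)
scores = torch.randn(batch, heads, seq_len, seq_len)   # (batch, heads, query, key)
masked = scores.masked_fill(src_mask == 0, -1e9)        # broadcast over heads and query rows
attn = masked.softmax(dim=-1)

print(src_mask[0, 0, 0].int().tolist())  # [1, 1, 1, 0, 0]
print(src_mask[1, 0, 0].int().tolist())  # [1, 1, 1, 1, 1]
```

Every row of attn for the first sequence puts zero weight on the last two columns, while the second sequence’s scores pass through masking unchanged.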

Decoder data flow

The decoder’s self-attention mask requires a combination of both the padding mask and the sequence mask as the attn_mask. This is because the decoder module must consider not only the mask caused by padding but also the issue of peeking at subsequent words.

  • During training the whole answer (target sequence) is fed in at once, but at deployment the model predicts one step at a time, so in principle the words after the current step must not be visible at the current step.
  • The answers in a batch are padded to a common length, so a padding mask is also needed to keep padding tokens from affecting the representations of real words.

Note: this reasoning applies only to training; however, to keep the code reusable, the same code is also used for inference.

In the implementation, the two masks are merged by taking the minimum at each position (for 0/1 masks this is the same as a logical AND): a position is masked if either mask masks it. The decoder’s self-attention mask has shape (2, 1, 5, 5).

For the first sequence, attn_mask is obtained by an element-wise AND of two masks: the padding mask [1, 1, 1, 0, 0] (broadcast across the five query rows) and the 5 × 5 lower-triangular sequence mask. The result is:

  [1, 0, 0, 0, 0]
  [1, 1, 0, 0, 0]
  [1, 1, 1, 0, 0]
  [1, 1, 1, 0, 0]
  [1, 1, 1, 0, 0]

For the second sequence, all five positions are tokens, so the padding mask is all 1s. ANDing this all-ones matrix with the triangular matrix leaves the lower-triangular sequence mask unchanged.

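In actual computation, the mask is broadcast across both heads and applied to the scores. For the first sequence, every score above the diagonal or in the last two columns is replaced by a very large negative value, so after softmax row i puts weight only on positions 0 through min(i, 2).

For the second sequence, only the scores above the diagonal are masked, so row i attends to positions 0 through i.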
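A sketch of the decoder-side combination for the same two sequences, with hypothetical token ids and random scores. It builds the padding mask and the triangular sequence mask separately, checks that AND and element-wise minimum agree on 0/1 masks, and applies the combined (2, 1, 5, 5) mask across both heads.

```python
import torch

PAD = 0                                   # hypothetical padding id
tgt = torch.tensor([[1, 3, 2, 0, 0],      # <bos> 你 <eos> <pad> <pad>
                    [1, 3, 4, 5, 2]])     # <bos> 你 好 吗 <eos>
seq_len = tgt.size(1)

pad_mask = (tgt != PAD).unsqueeze(-2)                                      # (2, 1, 5)
seq_mask = torch.tril(torch.ones(1, seq_len, seq_len, dtype=torch.bool))   # (1, 5, 5)
combined = pad_mask & seq_mask                                             # (2, 5, 5)

# For 0/1 masks, element-wise AND equals taking the minimum at each position
assert torch.equal(combined.int(), torch.minimum(pad_mask.int(), seq_mask.int()))

tgt_mask = combined.unsqueeze(1)                                           # (2, 1, 5, 5)
print(tgt_mask[0, 0].int())
print(tgt_mask[1, 0].int())

torch.manual_seed(0)
scores = torch.randn(2, 2, seq_len, seq_len)                               # (batch, heads, query, key)
attn = scores.masked_fill(tgt_mask == 0, -1e9).softmax(dim=-1)             # broadcast over heads
```

The first printed matrix is the one shown above (zeros above the diagonal and in the last two columns); the second is the plain lower-triangular matrix.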

4.3 Use

From a masking perspective, the biggest difference between training and inference lies in the decoder input at each time step. During training, the input is the entire target sequence; during inference, it is the output sequence produced up to the current time step.

To simulate inference during training, a mask hides the subsequent words, so that the decoder can only attend to words that precede the current position and cannot see future words. This logic exists mainly for training, which uses teacher forcing: the whole target sequence is fed in at once, so earlier tokens must be prevented from observing later ones. At inference time the decoder input contains only tokens that have already been generated, so there are no future tokens to hide; the mask and the model structure are nevertheless kept, so that every position is computed the same way as in training and the same code can be reused.
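A minimal single-head sketch (random weights, not the Harvard code) of why the causal mask makes teacher forcing match step-by-step inference: changing a “future” token leaves earlier outputs untouched, and running on a shorter prefix reproduces the corresponding rows of the full masked pass. The function name causal_self_attention is hypothetical.

```python
import torch

def causal_self_attention(x, wq, wk, wv):
    """Single-head scaled dot-product self-attention with a causal mask.
    x: (seq_len, d_model); wq, wk, wv: (d_model, d_model) projection matrices."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / k.size(-1) ** 0.5
    keep = torch.tril(torch.ones(x.size(0), x.size(0), dtype=torch.bool))  # True = visible
    return scores.masked_fill(~keep, float("-inf")).softmax(dim=-1) @ v

torch.manual_seed(0)
d = 8
wq, wk, wv = (torch.randn(d, d) for _ in range(3))
x = torch.randn(5, d)                  # teacher forcing: the whole target prefix is fed at once
out = causal_self_attention(x, wq, wk, wv)

# 1) Changing the last ("future") token does not affect earlier positions
x_changed = x.clone()
x_changed[4] = torch.randn(d)
out_changed = causal_self_attention(x_changed, wq, wk, wv)
print(torch.allclose(out[:4], out_changed[:4], atol=1e-5))   # True

# 2) Running on a 3-token prefix, as at an inference step, reproduces the first 3 rows
prefix_out = causal_self_attention(x[:3], wq, wk, wv)
print(torch.allclose(prefix_out, out[:3], atol=1e-5))        # True
```

Without the mask, both checks would generally fail, because every row would mix in information from position 4.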

Training

Next, let’s trace how the masks are passed along during training. The model we ultimately build is an instance of the EncoderDecoder class: the encoder receives src_mask, and the decoder receives both src_mask and tgt_mask.

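import torch.nn as nn

class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)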

Let’s delve deeper into the decoder and examine its parameters. In the forward() function of the DecoderLayer class, we can see:

  • Self-attention uses tgt_mask, which applies both the padding mask and the sequence mask to the target language.
  • Cross-attention uses src_mask, which applies the source-side padding mask, so that target queries ignore the padding positions of the encoder output (memory).

In a multi-layer Transformer decoder, every decoder layer uses the same memory (the final encoder output) for cross-attention.

class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"
    def forward(self, x, memory, src_mask, tgt_mask):
        "Follow Figure 1 (right) for connections."
        m = memory
        # Target-language self-attention: tgt_mask (padding + sequence mask) is applied to the scores before softmax
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # Cross-attention: queries come from x (output of the first sublayer), keys/values from m (encoder output);
        # src_mask hides the padding positions of the source sentence
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        # Finally, the position-wise feed-forward network (two linear layers)
        return self.sublayer[2](x, self.feed_forward)

Finally, it enters the attention function attention(), which will not be elaborated on here.

Inference

At inference time, only the prefill stage needs this full mask; the decode stage, when optimized with a KV cache, does not, because each step has a single new query that may attend to every cached position. When predicting, only the source-language batch is available, so in the Batch class tgt is None. As the code below shows, during prediction the encoder uses a padding mask (src_mask), while the decoder uses a sequence mask rebuilt at each decoding step.
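A minimal sketch, with random single-head weights, of why the decode stage needs no mask once a KV cache is used. Note that the Harvard greedy_decode shown below does not use a KV cache: it re-runs the decoder on the whole prefix with a fresh subsequent_mask at every step. The cache loop here is an illustrative assumption, not that code.

```python
import torch

torch.manual_seed(0)
d, steps = 8, 6
wq, wk, wv = (torch.randn(d, d) for _ in range(3))
x = torch.randn(steps, d)   # embeddings of the tokens in generation order

# Prefill-style pass: all positions at once, so a causal mask is required
q, k, v = x @ wq, x @ wk, x @ wv
scores = (q @ k.T) / d ** 0.5
causal = torch.tril(torch.ones(steps, steps, dtype=torch.bool))
full_out = scores.masked_fill(~causal, float("-inf")).softmax(-1) @ v

# Decode-style loop with a KV cache: one new query per step and no mask,
# because the cache only ever holds keys/values of tokens up to the current one
k_cache, v_cache, step_outs = [], [], []
for t in range(steps):
    q_t = x[t] @ wq
    k_cache.append(x[t] @ wk)
    v_cache.append(x[t] @ wv)
    K, V = torch.stack(k_cache), torch.stack(v_cache)
    step_outs.append(((K @ q_t) / d ** 0.5).softmax(-1) @ V)

print(torch.allclose(torch.stack(step_outs), full_out, atol=1e-4))  # True
```

Each decode step computes exactly one row of the causally masked prefill result, which is why the cached path can drop the mask.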

def example_simple_model():
    V = 11
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    model = make_model(V, V, N=2)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, model_size=model.src_embed[0].d_model, factor=1.0, warmup=400
        ),
    )

    batch_size = 80
    for epoch in range(20):
        model.train()
        run_epoch(
            data_gen(V, batch_size, 20),
            model,
            SimpleLossCompute(model.generator, criterion),
            optimizer,
            lr_scheduler,
            mode="train",
        )
        model.eval()
        run_epoch(
            data_gen(V, batch_size, 5),
            model,
            SimpleLossCompute(model.generator, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )[0]

    # Set up inference here
    model.eval()
    src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
    max_len = src.shape[1]
    src_mask = torch.ones(1, 1, max_len) # padding mask (no padding here, so all ones)

    # greedy_decode is called here
    print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))

Now let’s look directly at how the decoder is run during prediction.

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # memory is the encoder output, computed once and reused at every step
    memory = model.encode(src, src_mask)
    batch_size = src.shape[0]
    # Initialize the predicted sentences with the start symbol
    ys = torch.ones(batch_size, 1).fill_(start_symbol).type_as(src)
    for i in range(max_len-1):
        # ys has shape (batch_size, t), so the target mask must be (t, t)
        out = model.decode(memory, src_mask, ys, transformer.subsequent_mask(ys.size(1)).type_as(src))
        # out has shape (batch_size, t, d_model): one hidden state per generated position
        # out[:, -1] is the hidden state of the newest position; generator is a
        # linear layer followed by log_softmax, giving scores over the vocabulary
        prob = model.generator(out[:, -1])
        # torch.max returns the maximum of each row and its index (the predicted token id)
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.unsqueeze(1)
        # Append the predicted word to the sentences
        ys = torch.cat([ys, next_word.type_as(src)], dim=1)
    return ys

The expression transformer.subsequent_mask(ys.size(1)).type_as(src) in the code above shows clearly how the target mask is constructed: at each step a fresh lower-triangular mask is built for the ys.size(1) tokens generated so far.

The Decoder’s forward function below shows that the decoder still goes through attention(), but this time the input x is the embedded target (here, ys).

class Decoder(nn.Module):
    "Generic N layer decoder with masking."

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask) 
        return self.norm(x)

4.4 PyTorch

The PyTorch Transformer documentation lists six mask arguments, which fall into two categories.

The first type is called an attention mask, which defines which parts of the input sequence are allowed to be focused on. It corresponds to the sequence mask in the Harvard code.

  • src_mask: the self-attention mask in the encoder, with shape (source_len, source_len).
  • tgt_mask: the causal self-attention mask in the decoder, with shape (target_len, target_len).
  • memory_mask: the mask used in the decoder’s cross-attention, with shape (target_len, source_len). The queries in cross-attention come from the decoder and are matched against the encoder’s keys and values. There is no future information to hide here, so only padding matters, and padding is usually handled by memory_key_padding_mask; memory_mask is therefore normally left unset.

The second category is the key-padding mask, which marks the padding positions in the source sequence, target sequence, and memory sequence (the encoder’s output sequence), so that they are ignored as keys. It corresponds to the padding mask in the Harvard code.

  • src_key_padding_mask: Shape (batch_size, source_len)
  • tgt_key_padding_mask: Shape (batch_size, target_len)
  • memory_key_padding_mask: Shape (batch_size, source_len)

As the example below shows, the attention mask and each key-padding mask have their own specific roles.

import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PAD_IDX = 1  # assumed padding-token id; use your vocabulary's value

# Generate the target (causal) mask: 0.0 on and below the diagonal, -inf above it
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# Or, equivalently (this redefinition replaces the version above):
def generate_square_subsequent_mask(sz):
    return torch.triu(torch.full((sz, sz), float('-inf'), device=DEVICE), diagonal=1)

def create_mask(src, tgt):
    # src and tgt are sequence-first: (seq_len, batch_size)
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    # Attention-mask part
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len), device=DEVICE).type(torch.bool)  # all False: nothing masked

    # Key-padding-mask part: True marks padding, shape (batch_size, seq_len)
    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
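A sketch of where each of these masks goes in torch.nn.Transformer’s forward call. The sizes, the PAD_IDX value, and the embedding are assumptions for illustration. Unlike create_mask above, it uses boolean masks (True = do not attend) for every argument; mixing a float tgt_mask with boolean padding masks is accepted by recent PyTorch but may emit a deprecation warning about mismatched mask types.

```python
import torch
import torch.nn as nn

PAD_IDX = 1                        # hypothetical padding id
src_len, tgt_len, batch, d_model = 6, 5, 2, 16

# Sequence-first token ids (seq_len, batch), matching nn.Transformer's default layout
src = torch.randint(2, 10, (src_len, batch))
tgt = torch.randint(2, 10, (tgt_len, batch))
src[4:, 0] = PAD_IDX               # first source sentence ends with 2 padding tokens
tgt[3:, 1] = PAD_IDX               # second target sentence ends with 2 padding tokens

# Boolean masks throughout: True = "do not attend"
tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)
src_padding_mask = (src == PAD_IDX).transpose(0, 1)   # (batch, src_len)
tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)   # (batch, tgt_len)

model = nn.Transformer(d_model=d_model, nhead=2, num_encoder_layers=1,
                       num_decoder_layers=1, dim_feedforward=32, dropout=0.0)
model.eval()
embed = nn.Embedding(10, d_model)

with torch.no_grad():
    out = model(embed(src), embed(tgt),
                tgt_mask=tgt_mask,                         # causal attention mask
                src_key_padding_mask=src_padding_mask,     # encoder self-attention keys
                tgt_key_padding_mask=tgt_padding_mask,     # decoder self-attention keys
                memory_key_padding_mask=src_padding_mask)  # cross-attention keys (encoder output)
print(out.shape)  # torch.Size([5, 2, 16])
```

The source padding mask is passed twice: once for the encoder’s self-attention and once, as memory_key_padding_mask, for the decoder’s cross-attention, which mirrors how the Harvard src_mask is used in both places.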

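In essence, this is a more finely divided version of the masking scheme in the Harvard code. We summarize the correspondence as follows: the Harvard src_mask corresponds to PyTorch’s src_key_padding_mask (and to memory_key_padding_mask in cross-attention), while the Harvard tgt_mask corresponds to the combination of PyTorch’s tgt_mask (causal part) and tgt_key_padding_mask (padding part).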

[Figure: correspondence between the Harvard-code masks and the PyTorch masks]

4.5 Summary

The flowchart below outlines the code logic. As you can see, the Encoder only uses src_mask, while the Decoder uses both src_mask and tgt_mask.

[Figure: flowchart of how src_mask and tgt_mask pass through the code]

We will now present the interaction data flow diagram from the perspective of model architecture, as follows.

[Figure: mask data flow between encoder and decoder, from the model-architecture perspective]

0x05 Advanced

5.1 Sample Packing and Mask

As context length grows, batch alignment becomes a problem. When training on long texts, especially with batch sizes greater than one, pad tokens can waste a significant amount of space, because the lengths of long-text samples often span several orders of magnitude. The following figure gives an example.

[Figure: pad-token waste when batching long-text samples of very different lengths]

For example, a 4K sample and a 64K sample might appear in the same batch. The 4K sample is then padded to the length of the longest sample in the batch, i.e., 64K, which adds about 60K padding tokens and wastes a significant amount of computation.

Definition

Fortunately, most current fine-tuning frameworks solve this problem with sample packing. Sample packing essentially removes the notion of batch size: instead of a batch of 3 samples, the three samples are joined end to end into one longer sequence, and the attention mask is changed accordingly so that different samples in the same sequence do not interfere with each other. The advantage is that there are no more pad tokens: one input may contain 2 long samples or 100 short samples.
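A minimal sketch of the packing step itself, assuming samples are plain lists of token ids. The helper pack_samples is hypothetical (real frameworks differ in details such as truncation and separator tokens); it returns the concatenated ids together with position ids that restart at 0 for each sample and a document id per token, which is the information a per-sample attention mask is built from.

```python
import torch

def pack_samples(samples):
    """Concatenate several token-id lists into one packed sequence (no padding).
    Returns input ids, position ids that restart at 0 for every sample,
    and the index of the sample (document id) that each token came from."""
    input_ids = torch.tensor([tok for s in samples for tok in s])
    position_ids = torch.cat([torch.arange(len(s)) for s in samples])
    doc_ids = torch.cat([torch.full((len(s),), i) for i, s in enumerate(samples)])
    return input_ids, position_ids, doc_ids

samples = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]   # three toy samples
input_ids, position_ids, doc_ids = pack_samples(samples)
print(input_ids.tolist())     # [11, 12, 13, 21, 22, 31, 32, 33, 34]
print(position_ids.tolist())  # [0, 1, 2, 0, 1, 0, 1, 2, 3]
print(doc_ids.tolist())       # [0, 0, 0, 1, 1, 2, 2, 2, 2]
```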

[Figure: several samples concatenated end to end into one packed sequence]

However, in practice, the LongAlign paper mentions that having long samples and very short samples in the same batch may affect model convergence. To solve this problem, samples of similar length are usually placed in the same batch during training.

Attention mask

Taking Megatron-LM (DeepSpeed-Megatron) as an example, pre-training typically involves many different datasets, each containing numerous documents. To improve training efficiency, during actual training, a single sample (sequence) may contain multiple different documents (sample packing). For instance, with an 8K pre-training sequence length, one sample could contain eight 1K documents.

Within a single document, the decoder-only GPT model is causal: each token cannot see subsequent tokens. An attention mask is therefore needed during training, and here it is a standard lower-triangular matrix (causal mask); in the figure, the green cells are 1s and the rest are 0s.

[Figure: standard lower-triangular causal mask]

If one sequence contains multiple documents, the attention mask must become a block-diagonal matrix. For example, if the sequence length is 16 and the four documents have lengths 3, 4, 5, and 4, the four diagonal blocks (red boxes in the figure) are each standard lower-triangular matrices and everything outside them is 0. This makes training equivalent to feeding the four documents as separate samples.
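A sketch that builds this block-diagonal causal mask for document lengths 3, 4, 5, and 4 and checks the equivalence claim numerically on random single-head q/k/v. The helper names block_causal_mask and attend are hypothetical; production kernels usually avoid materializing the full mask.

```python
import torch

def block_causal_mask(doc_lens):
    """Boolean (seq_len, seq_len) mask, True = may attend.
    Token i may attend to token j only if both are in the same document and j <= i."""
    doc_ids = torch.repeat_interleave(torch.arange(len(doc_lens)), torch.tensor(doc_lens))
    seq_len = doc_ids.numel()
    same_doc = doc_ids[:, None] == doc_ids[None, :]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return same_doc & causal

doc_lens = [3, 4, 5, 4]            # sequence length 16
mask = block_causal_mask(doc_lens)
print(mask.int())

torch.manual_seed(0)
d = 8
q, k, v = (torch.randn(16, d) for _ in range(3))

def attend(q, k, v, keep):
    scores = (q @ k.T) / d ** 0.5
    return scores.masked_fill(~keep, float("-inf")).softmax(-1) @ v

# Packed sequence with the block mask vs. each document on its own with a causal mask
packed_out = attend(q, k, v, mask)
pieces, start = [], 0
for n in doc_lens:
    sl = slice(start, start + n)
    pieces.append(attend(q[sl], k[sl], v[sl], torch.tril(torch.ones(n, n, dtype=torch.bool))))
    start += n
print(torch.allclose(packed_out, torch.cat(pieces), atol=1e-5))  # True
```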

[Figure: block-diagonal causal mask for four documents of lengths 3, 4, 5, and 4]

The paper “LongAlign: A Recipe for Long Context Alignment of Large Language Models” discusses several issues related to sample packing. As the left figure below shows, sequence lengths vary widely, from 0 to 60K, so a naive batching approach leads to a significant bubble problem. To address both efficiency and effectiveness, the authors propose three techniques: Packing, Loss Weighting, and Sorted Batching.

The right side of the figure below shows sample packing, introduced earlier: different samples are concatenated into a single sequence that is as close as possible to the maximum sequence length, with padding only at the end. A block-diagonal attention mask then separates the samples to avoid cross-contamination between them; this is document-level attention.

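[Figure: LongAlign; left: distribution of sequence lengths; right: sample packing with a block-diagonal attention mask]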

Strategy

In the paper “Enhancing Training Efficiency Using Packing with Flash Attention”, the authors summarize the advantages of different packing strategies, masking methods, and their combination with FlashAttention.

As shown in the figure below, the authors analyze different packing schemes and their impact, specifically the following methods:

  • RandomSampling + Padding: the most traditional method, random sampling followed by padding. It performs redundant computation on padding, and this redundancy accounts for a significant portion of the total computation.
  • GroupByLength + Padding: sort the data first so that the sequence lengths within each batch are close. This reduces the proportion of padding.
  • RandomSampling + PosID: random sampling without padding; variable-length sequences are supported via PosID. There is almost no redundant computation, but there may be significant load imbalance in computational cost.
  • FixedLengthPacking: random sampling and random packing, possibly truncating the last sample so that the maximum sequence length is always filled. It does not distinguish between samples (a plain causal mask is used over the whole packed sequence), has no redundant computation, and gives a very balanced load.
  • FixedLengthPacking + PosID: compared with FixedLengthPacking, it adds PosID, which distinguishes the samples and yields the corresponding block-diagonal mask. However, end truncation is still possible, and the load can be uneven.
  • MultiPack + PosID: packs samples so that each sequence is as close as possible to the batch’s maximum sequence length, reducing imbalance in sequence lengths. The data needs to be sorted. See GitHub - imoneoi/multipack_sampler: Multipack distributed sampler for fast padding-free training of LLMs.
  • SortedPacking + PosID: sorting makes the computational cost within the same batch as similar as possible, which minimizes uneven computational load.
  • RandomPacking + PosID: the main difference from FixedLengthPacking + PosID is that the last sample is not truncated, which may leave some bubbles.
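The PosID variants rely on the fact that sample boundaries can be recovered from position ids alone, without materializing a mask. Variable-length attention kernels such as FlashAttention’s varlen functions take cumulative sequence lengths (cu_seqlens, int32) for this purpose. The helper below is a hypothetical sketch of that conversion, assuming position ids restart at 0 at the start of every packed sample.

```python
import torch

def cu_seqlens_from_position_ids(position_ids):
    """Recover cumulative sequence lengths from packed position ids.
    A new sample starts wherever the position id resets to 0."""
    starts = torch.nonzero(position_ids == 0).flatten()
    end = torch.tensor([position_ids.numel()])
    return torch.cat([starts, end]).to(torch.int32)

# Four packed samples of lengths 3, 4, 5, 4 (16 tokens in total)
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3])
cu = cu_seqlens_from_position_ids(position_ids)
print(cu.tolist())              # [0, 3, 7, 12, 16]
print(torch.diff(cu).tolist())  # [3, 4, 5, 4]
```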

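[Figure: comparison of packing strategies from “Enhancing Training Efficiency Using Packing with Flash Attention”]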

5.2 The Role of Masks

Studies have shown that pure self-attention suffers from rank collapse as depth increases, which limits the model’s expressive power and its ability to make use of additional depth. However, most existing work on rank collapse ignores other key components of the Transformer that may mitigate the problem. The paper “On the Role of Attention Masks and LayerNorm in Transformers” analyzes rank collapse in self-attention comprehensively, taking attention masks and layer normalization (LayerNorm) into account. Specifically, the authors find that pure masked attention still collapses exponentially to a rank-1 subspace, but sparse or local masked attention provably slows the collapse. For LayerNorm, the authors first show that for certain classes of value matrices, collapse to a rank-1 subspace still happens exponentially. However, by constructing nontrivial counterexamples, they show that with an appropriate choice of value matrices, a general sequence need not converge to a rank-1 subspace, and that self-attention dynamics with LayerNorm can have equilibria of any rank from 1 to full rank. These results refute the earlier assumption that LayerNorm plays no role in self-attention rank collapse, and show that self-attention with LayerNorm is a more expressive and versatile nonlinear dynamical system than previously thought.

Contributions

  • Impact of attention masks on rank collapse: the paper is the first to systematically analyze how attention masks affect rank collapse in Transformers. Using graph-theoretic tools, it proves that as long as the mask’s graph is quasi-strongly connected, tokens still undergo rank collapse, and that sparse or local attention masks only slow it down. This finding provides a theoretical basis for designing more efficient attention mechanisms.
  • Mitigating effect of LayerNorm on rank collapse: by constructing nontrivial counterexamples, the authors show that LayerNorm can effectively alleviate token rank collapse in certain situations. With an appropriate choice of value matrices, self-attention dynamics with LayerNorm can have equilibria of any rank from 1 to full rank.

Masked attention

The authors first analyze the case without LayerNorm and then focus on the impact of attention masking.

[Figure: the paper’s theorem on exponential rank collapse of masked self-attention without LayerNorm]

The result above shows that under pure self-attention, as long as there is a token that every other token can reach (attend to directly or indirectly) within a fixed number of layers, rank collapse is exponential. In particular, the conclusion extends to a more general class of attention patterns: the mask graph only needs to be quasi-strongly connected. Rank collapse therefore occurs for many attention masks used in practice, including the causal mask of the GPT series and the sparse attention patterns deployed in many efficient Transformer models.
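A toy simulation, not the paper’s setup, that makes the collapse visible: pure self-attention with identity query/key/value weights, no residual connection, and no LayerNorm, iterated for 30 layers under a full, a causal, and a local (window 1) mask. Each layer replaces every token by a convex combination of the tokens it can see, so the spread of the tokens (measured here as the largest pairwise distance) can only shrink. The helper names are hypothetical, and the sketch does not reproduce the paper’s rates.

```python
import torch

def diameter(x):
    """Largest pairwise distance between token vectors; 0 means all tokens coincide."""
    return torch.cdist(x, x).max().item()

def run_attention_layers(x, keep, layers):
    """Pure masked self-attention dynamics x <- softmax(x x^T / sqrt(d), masked) x,
    with identity weights, no residual connection and no LayerNorm."""
    d = x.size(-1)
    for _ in range(layers):
        scores = (x @ x.T) / d ** 0.5
        x = scores.masked_fill(~keep, float("-inf")).softmax(-1) @ x
    return x

torch.manual_seed(0)
n, d = 8, 4
x0 = torch.randn(n, d, dtype=torch.float64)
idx = torch.arange(n)
masks = {
    "full": torch.ones(n, n, dtype=torch.bool),
    "causal": idx[None, :] <= idx[:, None],              # key j visible to query i if j <= i
    "local": (idx[None, :] - idx[:, None]).abs() <= 1,   # sliding window of radius 1
}
for name, keep in masks.items():
    final = run_attention_layers(x0, keep, 30)
    print(f"{name:6s} diameter: {diameter(x0):.3f} -> {diameter(final):.2e}")
```

Under the causal mask the first token attends only to itself and never moves, so the other tokens drift toward it, which matches the picture of a single center node that all tokens can reach.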

The authors discuss the following interesting implications.

  • Local vs. global attention: the exponential rate in the bound is monotonic in the radius of the mask graph, so rank collapse should be slower for graphs with larger radii. This suggests that local attention not only makes attention computation more efficient but also implicitly mitigates rank collapse.
  • Focused vs. uniform attention: the rate in the upper bound also depends monotonically on a lower bound on the attention weights between reachable tokens: the smaller that lower bound, the slower the rank collapse. It can be read as a measure of how “focused” attention is among reachable tokens, since it is largest when attention is spread evenly across them. Besides applying attention masks to limit the number of reachable tokens, another way to control how focused attention is, is a temperature term: a larger temperature makes the attention distribution more even across reachable tokens and thus makes rank collapse happen faster across layers.
  • Trade-off between rank collapse and universal approximation: finally, for strongly connected graphs, the results also reveal a trade-off between universal function approximation and the rate of rank collapse. Prior work has shown that a Transformer with a strongly connected mask graph is a universal approximator of sequence-to-sequence functions, but it needs a number of layers on the order of the mask graph’s diameter to achieve the full approximation property. Masks with smaller diameters are therefore more efficient for function approximation, but more prone to rank collapse.
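The temperature effect can be checked with the same kind of toy dynamics. This sketch assumes the temperature simply divides the attention logits (an illustrative parameterization, not taken from the paper), uses an unmasked attention with identity weights and no LayerNorm, and compares how far the tokens have collapsed after 3 layers.

```python
import torch

def diameter(x):
    # Largest pairwise distance between token vectors
    return torch.cdist(x, x).max().item()

def run_full_attention(x, temperature, layers):
    """Unmasked pure self-attention with identity weights and a temperature:
    x <- softmax(x x^T / (sqrt(d) * temperature)) x. No residual, no LayerNorm."""
    d = x.size(-1)
    for _ in range(layers):
        x = ((x @ x.T) / (d ** 0.5 * temperature)).softmax(-1) @ x
    return x

torch.manual_seed(0)
x0 = torch.randn(8, 4, dtype=torch.float64)
for t in (0.25, 1.0, 4.0):
    print(f"temperature {t}: diameter after 3 layers = {diameter(run_full_attention(x0, t, 3)):.3e}")
```

A high temperature makes each layer close to uniform averaging, so the tokens merge almost immediately; a low temperature keeps attention sharp and the tokens stay apart for longer.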

Masked attention with LayerNorm

Next, let’s examine the properties of masked attention with LayerNorm.

The authors first give a negative result: for certain classes of value matrices, if all pairs of tokens initially have non-negative cosine similarity, then as long as the mask graph is quasi-strongly connected, tokens still collapse exponentially to a common vector, i.e., rank collapse. However, when the mask graph is causal, it has only one center node and the upper bound is looser, which suggests that causal masks have an advantage over full masks in slowing rank collapse.

The authors then give a counterexample showing that, for a general class of input sequences, tokens under self-attention with LayerNorm converge to an equilibrium in which rank collapse does not occur. They further prove a general result: with LayerNorm and appropriately chosen value matrices, self-attention dynamics can have an equilibrium of any rank from 1 to full rank.

0xFF Reference

  • LLM pre-training corpus, preprocessing and dataset indexing, loading summary AI chat
  • FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention by Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong
  • Sample Packing: Attention Issues and Optimizations in Long-Sequence LLM Training
  • https://blog.csdn.net/zhaohongfei_358/article/details/125858248
  • Transformer Series: Detailed Explanation of Decoder Principles (xiaogp)
  • LongAlign: A Recipe for Long Context Alignment of Large Language Models
  • NIPS 2024 | The Role of Attention Masks and LayerNorms in Transformers [CV Technical Guide]
  • On the Role of Attention Masks and LayerNorm in Transformers
  • The use of three types of masks in transformers (Early Summer)
  • [Deep Learning] A Deep Explanation of the Mask Mechanism in Transformer (Articoder)
  • What is the Mask in the Transformer Encoder/Decoder architecture? [AIGC Beginner’s Guide]
  • LongAlign