
Exploring the Transformer Series (31) --- Medusa

Medusa: multi-decoding heads, tree attention, typical acceptance, sparse tree construction, training strategies, and decoding flow.


0x00 Overview

Medusa is one of the earlier works in the field of self-speculation and has greatly inspired subsequent work. Its main idea is multi-decoding head + tree attention + typical acceptance (threshold). Medusa does not use a separate draft model, but adds multiple decoding heads (MEDUSA heads) to the original model to predict multiple subsequent tokens in parallel.

A typical LLM has only one head (the LM head), which predicts the next token. Medusa retains the original LM Head after the last Transformer layer of the LLM and adds multiple (say K) trainable Medusa Heads (decoding heads). Each head is responsible for one further future position, so together they predict the tokens at K additional positions beyond the one predicted by the LM head. Medusa lets each head propose multiple candidate tokens (its top-k), instead of a single candidate per position as in standard speculative decoding. All candidates are then assembled into multiple candidate sequences, which together form a tree, and these sequences are validated in parallel using a tree attention mechanism.




0x01 Principle

1.1 Motivation

[Figure: the draft-then-verify idea of speculative sampling]

The figure above shows the core idea of speculative sampling. First, multiple candidate tokens are generated cheaply (generally with a small model). Then, these tokens are checked in a single parallel verification pass of the large model. This reduces the number of decoding steps the large model must run and thus accelerates generation. However, using a separate "speculation" model also has drawbacks:

  • It is difficult to find a small but powerful model to generate tokens that are relatively simple for the original model.
    • It is difficult to align the draft model with the large model, resulting in distribution shift.
    • Not all LLMs have readily available small models. Retraining a small model requires a significant additional investment.
  • Maintaining two different models in a system increases the computational complexity of the inference process, leads to architectural complexity, and makes deployment in a distributed system more difficult.
  • Using speculative sampling introduces additional decoding overhead, especially when using a relatively high sampling temperature value.

1.2 Reference

Medusa primarily drew upon two works: BPD and SpecInfer.

  • The main model itself has an LM head used to map the hidden layer outputs to the probability distribution of the vocabulary to achieve decoding of a single token. To generate multiple tokens, the paper “Blockwise Parallel Decoding for Deep Autoregressive Models” uses multiple decoding heads on the backbone model to accelerate inference. By training an auxiliary model, the model can predict the output at future positions, and then use these predictions to skip some greedy decoding steps, thereby accelerating the decoding process.
  • The paper “SpecInfer: Accelerating Generative Large Language Model Serving with Speculative Inference and Token Tree Verification” proposes the following approach. Small models can guess the output of large models efficiently, so multiple small models can be used to guess multiple token sequences, providing more candidates and increasing the chance of a correct guess. To verify these token sequences efficiently, the authors propose a Token Tree Attention mechanism. The token sequences generated by the small models are first merged into a token tree, and the tree is then flattened and fed into the large model, so that the entire token tree is verified in a single decoding step.

1.3 Approach

Based on these two ideas, Medusa decided to let the target LLM make its own predictions. This involved introducing multiple additional prediction heads above the last decoder layer of the target LLM, allowing the model to generate multiple tokens in parallel at each decoding step as “speculation” results. We will analyze this in detail.

1.3.1 Single Model & Multi-Head

To discard the independent Draft Model and retain only one model while preserving the Draft-then-Verify paradigm, Medusa adds several Medusa Heads after the final hidden layer of the backbone model. Each decoding head is a single-layer feedforward network with residual connections. These Medusa Heads are an upgrade to the multi-head approach in BPD, changing from one head generating one token to one head generating multiple candidate tokens. Because these Heads have the ability to predict the token at the corresponding position and can be executed in parallel, it is possible to obtain multiple draft tokens in a single forward pass. See the diagram below for details.

Some readers might wonder whether the accuracy of the later heads can be guaranteed, since they must predict tokens further ahead while skipping over the intervening ones. It is indeed hard, but if we keep the top few candidates (e.g., the top 3) at each position, the probability that the correct token is among them increases significantly. The Medusa authors observed that when predicting the next-next token, top-1 accuracy may be only about 60%, while top-5 accuracy can exceed 80%. Also, because the Medusa heads take the original model's hidden states as input, the distribution shift relative to the original model is relatively small.

[Figure: Medusa heads attached to the last hidden state of the backbone model]

1.3.2 Tree Validation

Greedy decoding (taking only the top-1 token of each head) is not accurate enough, and its speedup is limited. Medusa therefore takes the top-k candidates of each head and combines the candidate sets from different heads into a tree. To verify these draft tokens efficiently, Medusa constructs multiple token sequences from the Cartesian product of the heads' candidates. With Tree Attention, the attention mask only lets each token see the tokens on its own path, and the position encodings are adjusted accordingly. Multiple candidates can then be processed in parallel in one forward pass without increasing the batch size.

The tree and attention mask matrix in Medusa are shown in the figure below. In each hop, we see that Medusa retains multiple possible tokens, that is, the tokens with the highest probabilities. This forms the so-called tree structure. Intuitively, each token in each hop can be combined with all the tokens in the next hop to form a sentence, or the sentence can terminate at this hop. For example, in the figure, a total of 2 heads generate 2 hops of tokens, so this tree contains 6 possible sentences: Head 1 generates 2 possible tokens (It and I) in the next position, and Head 2 generates 3 possible tokens (is, ’, and the) in the position after that. Thus, the next position and the position after that have 2 x 3 = 6 possible candidate sequences, as shown on the left side of the figure below.

The corresponding Attention Mask matrix is shown on the right. Unlike the original speculative decoding, the tree contains multiple decoding paths, and these paths cannot access each other. For example, (1) “It is” and (2) “I is” are two paths. When computing the probability distribution at “is” in path (1), that token can attend only to “It” in path (1), not to “I” in path (2). Medusa therefore introduces a new attention mask, called “Tree attention,” for computing the probability distributions of multiple paths in parallel. Essentially, it follows the causal-mask rule within each path, while different paths cannot see each other.

The authors of Medusa state that in SpecInfer, each speculator generates a sequence of different lengths, so the mask is dynamically changing. In contrast, Medusa’s Tree Attention Mask remains statically invariant during inference, which further improves the efficiency of preprocessing the tree attention mask.

[Figure: the candidate tree (left) and its tree attention mask (right)]
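The masking rule can be sketched in a few lines of Python. This minimal sketch assumes the candidate tree is flattened into a list of nodes, each storing the index of its parent (-1 for nodes that hang directly off the already-verified prefix). A node may attend only to itself and its ancestors. The flattening order, the `parents` list, and the function name are illustrative choices, not Medusa's internal format. Every node can additionally attend to the whole verified prefix (the KV cache), which the sketch leaves out.

```python
def tree_attention_mask(parents):
    """Return a 0/1 mask where mask[i][j] == 1 iff node j is node i or one of its ancestors."""
    n = len(parents)
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:          # walk up the tree until we leave it
            mask[i][j] = 1
            j = parents[j]
    return mask

# Flattened tree of the example: It, I, then Head 2's tokens repeated under each prefix.
tokens = ["It", "I", "is", "'", "the", "is", "'", "the"]
parents = [-1, -1, 0, 0, 0, 1, 1, 1]

mask = tree_attention_mask(parents)
for tok, row in zip(tokens, mask):
    print(f"{tok:>4}", row)
```

In the output, the row for "is" under "It" has ones only at "It" and itself, and the row for "is" under "I" has ones only at "I" and itself. This is the "paths cannot see each other" rule from the text.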

1.3.3 Summary

The table below shows the differences between BPD, SpecInfer, and Medusa.

| Aspect | Blockwise Parallel Decoding | SpecInfer | Medusa |
|---|---|---|---|
| Multiple models | Instead of actually constructing k-1 auxiliary models, the original model is slightly modified so that it can predict the next k tokens. | A batch of small speculative models (SSMs) predicts multiple candidate token sequences in parallel. The SSMs can be distilled, quantized, or pruned versions of the original LLM. | No separate draft model. |
| Multiple heads | k projection layers are added; their outputs are the logits of tokens at k different positions. | n/a | The original LM Head after the last Transformer Layer of the LLM is retained, and multiple Medusa Heads are added to obtain multiple candidate token sequences. |
| Tree | n/a | The candidates predicted by the SSMs are merged into a token tree and verified in parallel by the original LLM. Each speculator generates sequences of different lengths, so the mask changes dynamically. | The Tree Attention Mask is static during inference, which makes preprocessing the mask more efficient. |
| Training | Retrain the original model. | Train the small models. | The large model is not retrained: it is frozen and only the decoding heads are trained. |

0x02 Design Core Points

2.1 Process

MEDUSA’s general approach is similar to speculative decoding, where each decoding step mainly consists of three sub-steps:

  • Candidate generation. Medusa obtains candidate tokens for multiple positions from the Medusa heads attached to the original model.
  • Candidate processing. Medusa combines the candidate tokens at each position into candidate sequences and verifies them with tree attention. Because the Medusa heads sit on top of the original model, this verification pass also produces the Medusa logits used for candidate generation in the next decoding step.
  • Candidate acceptance. The final output is selected through typical acceptance.

Medusa’s greater advantage lies in the fact that, apart from the initial Prefill, subsequent steps can achieve the effect of verifying and generating simultaneously. In other words, Medusa’s inference process can be understood as: Prefill + Verify + Verify + …
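The verify-and-accept step can be illustrated with the simplest acceptance rule: greedy matching at temperature 0, not the typical acceptance used by Medusa for sampling. This minimal sketch makes two assumptions. Each candidate path starts with the root token, and `base_preds[i][j]` is the original model's argmax prediction after the j-th token of path i, as produced by the single tree-attention verification pass. The function name `greedy_accept` is hypothetical.

```python
def greedy_accept(candidates, base_preds):
    """Pick the candidate path whose draft tokens match the base model for the longest prefix.

    candidates[i] : token IDs of path i, starting with the root token.
    base_preds[i] : base-model argmax after each token of path i (same length as the path).
    Returns (best path index, number of accepted draft tokens,
             accepted draft tokens, bonus token predicted after them).
    """
    best_idx, best_len = 0, -1
    for i, (path, preds) in enumerate(zip(candidates, base_preds)):
        n = 0
        # Draft token path[n + 1] is accepted if the base model also predicts it after path[: n + 1].
        while n < len(path) - 1 and path[n + 1] == preds[n]:
            n += 1
        if n > best_len:
            best_idx, best_len = i, n
    accepted = candidates[best_idx][1 : best_len + 1]
    # The base model's prediction after the last accepted token comes from the same forward pass.
    bonus = base_preds[best_idx][best_len]
    return best_idx, best_len, accepted, bonus

candidates = [[1, 3, 6], [1, 3, 5], [1, 2, 4]]
base_preds = [[3, 5, 9], [3, 5, 7], [2, 8, 1]]
print(greedy_accept(candidates, base_preds))  # → (1, 2, [3, 5], 7)
```

Each verification pass therefore emits the accepted drafts plus one bonus token. The bonus token becomes the root of the next step's tree, which is why decoding looks like Prefill + Verify + Verify + ….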

2.2 Model Structure

The code below shows the Medusa model structure. Medusa retains the original LM Head after the last Transformer Layer of the LLM, and then adds multiple Medusa Heads, which are multiple different branch outputs. This allows for the prediction of multiple candidate token sequences.

The input to the Medusa heads is the hidden-layer output of the large model. This is another important difference from speculative decoding with an external small model. The small model starts from token embeddings obtained by a table lookup, which carry far less information than the large model's last hidden state, so its draft quality depends heavily on the small model's own capability. Because the Medusa heads start from the large model's last hidden state, their structure can be very simple.

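import torch
import torch.nn as nn
from transformers import AutoTokenizer
# KVLlamaForCausalLM and ResBlock are defined in the Medusa repository (they are not part of transformers).

class MedusaLlamaModel(KVLlamaForCausalLM):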
    """The Medusa Language Model Head.
    This module creates a series of prediction heads (based on the 'medusa' parameter)
    on top of a given base model. Each head is composed of a sequence of residual blocks
    followed by a linear layer.
    """

    def __init__(
        self,
        config,
    ):
        # Load the base model
        super().__init__(config)
        # For compatibility with the old APIs

        medusa_num_heads = config.medusa_num_heads
        medusa_num_layers = config.medusa_num_layers
        base_model_name_or_path = config._name_or_path
        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size
        self.medusa = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers
        self.base_model_name_or_path = base_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
        # Create a list of Medusa heads
        self.medusa_head = nn.ModuleList(
            [
                nn.Sequential(
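                    # One fresh ResBlock per layer (`[ResBlock(...)] * n` would reuse a single shared module).
                    *[ResBlock(self.hidden_size) for _ in range(medusa_num_layers)],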
                    nn.Linear(self.hidden_size, self.vocab_size, bias=False),
                )
                for _ in range(medusa_num_heads)
            ]
        )
        

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        output_orig=False,
        position_ids=None,
        medusa_forward=False,
        **kwargs,
    ):
        """Forward pass of the MedusaModel.

        Args:
            input_ids (torch.Tensor, optional): Input token IDs.
            attention_mask (torch.Tensor, optional): Attention mask.
            medusa_forward (bool, optional): If False, fall back to the base model's standard forward pass.
            past_key_values (tuple, optional): Tuple containing past key and value states for attention.
            output_orig (bool, optional): Whether to also output predictions from the original LM head.
            position_ids (torch.Tensor, optional): Position IDs.

        Returns:
            torch.Tensor: A tensor containing predictions from all Medusa heads.
            (Optional) Original predictions from the base model's LM head.
        """
        if not medusa_forward:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
                **kwargs,
            )
        with torch.inference_mode():
            # Pass input through the base model
            outputs = self.base_model.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                position_ids=position_ids,
                **kwargs,
            )
            if output_orig:
                # Logits from the original LM head
                orig = self.base_model.lm_head(outputs[0])
        # Clone the output hidden states
        hidden_states = outputs[0].clone()
        medusa_logits = []
        # TODO: Consider parallelizing this loop for efficiency?
        for i in range(self.medusa):
            # Logits from the i-th Medusa head
            medusa_logits.append(self.medusa_head[i](hidden_states))
        if output_orig:
            return torch.stack(medusa_logits, dim=0), outputs, orig
        return torch.stack(medusa_logits, dim=0)

2.3 Multiple Heads

2.3.1 Head Structure

Medusa adds medusa_num_heads extra Medusa Heads. Each Medusa Head is a single-layer feedforward network with a residual connection (a ResBlock), followed by an output linear layer. That output layer has the same shape as the model's default lm_head (hidden_size to vocab_size), so each head can predict a subsequent token over the full vocabulary.

self.medusa_head = nn.ModuleList(
    [
        nn.Sequential(
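            # One fresh ResBlock per layer (`[ResBlock(...)] * n` would reuse a single shared module).
            *[ResBlock(self.hidden_size) for _ in range(medusa_num_layers)],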
            nn.Linear(self.hidden_size, self.vocab_size, bias=False),
        )
        for _ in range(medusa_num_heads)
    ]
)

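Printing this module for a model with hidden size 4096 and a 32,000-token vocabulary, configured with four heads of one ResBlock each, gives the following output.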

ModuleList(
  (0-3): 4 x Sequential(
    (0): ResBlock(
      (linear): Linear(in_features=4096, out_features=4096, bias=True)
      (act): SiLU()
    )
    (1): Linear(in_features=4096, out_features=32000, bias=False)
  )
)

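Let $p_t^{(k)}$ denote the output distribution of the k-th Medusa head over the vocabulary at position t. Here $h_t \in \mathbb{R}^d$ is the backbone's last hidden state, d is the hidden size, and V is the vocabulary size. The distribution is computed as follows (the ResBlock's linear layer also has a bias term, omitted here):

$$p_t^{(k)} = \mathrm{softmax}\left(W_2^{(k)} \left(\mathrm{SiLU}\left(W_1^{(k)} h_t\right) + h_t\right)\right), \qquad W_1^{(k)} \in \mathbb{R}^{d \times d},\; W_2^{(k)} \in \mathbb{R}^{V \times d}$$

The prediction of the original model is denoted $p_t^{(0)}$.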
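The head computation p = softmax(W2 (SiLU(W1 h) + h)) can be checked with a dependency-free sketch. It uses plain Python lists with a toy hidden size and vocabulary; all weights are made-up numbers, and the ResBlock bias is omitted. It also shows a convenient property. If W1 is all zeros, SiLU(0) = 0, so the head reduces to softmax(W2 h). A fresh head whose W2 is copied from the LM head would therefore start out reproducing the LM head's distribution.

```python
import math

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def silu(z):
    return z / (1.0 + math.exp(-z))

def softmax(z):
    m = max(z)                          # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def medusa_head(h, W1, W2):
    """p = softmax(W2 (SiLU(W1 h) + h)): one residual block, then the vocabulary projection."""
    r = [silu(a) + b for a, b in zip(matvec(W1, h), h)]   # residual block
    return softmax(matvec(W2, r))                          # project to the vocabulary

d, V = 3, 4                                                    # toy hidden and vocabulary sizes
h = [0.5, -1.0, 2.0]                                           # last hidden state h_t
W1 = [[0.1, 0.0, -0.2], [0.3, 0.1, 0.0], [0.0, -0.1, 0.2]]     # d x d
W2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [0.5, 0.5, 0.5]]  # V x d

p = medusa_head(h, W1, W2)
p0 = medusa_head(h, [[0.0] * d for _ in range(d)], W2)         # equals softmax(W2 h)
print([round(x, 3) for x in p])
```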

Below is a diagram illustrating the combination of code and model structure.

[Figure: the Medusa head code mapped onto the model structure]

2.3.2 Predicted Positions

Each head in Medusa predicts a different offset. The k-th head predicts the token at position t+k+1 (k = 1, ..., K), while the original model's decoding head still predicts position t+1, which corresponds to k = 0. Concretely, for an input token sequence $x_1, \dots, x_t$, let $h_t$ be the last hidden state of the original model at position t, and attach K heads to it. The original head predicts $x_{t+1}$ from $h_t$. The first head added by Medusa also takes $h_t$ as input but predicts $x_{t+2}$; that is, it skips over $x_{t+1}$ to predict the token after it. The second head predicts $x_{t+3}$, and so on. Each head contributes its top candidates, and the predictions from these heads form multiple candidate sequences, which are then processed simultaneously using tree attention. At each decoding step, the longest accepted candidate sequence is selected as the final prediction. In this way, multiple tokens can be produced per step, reducing the total number of decoding steps and improving inference speed.

As shown in the figure below, Medusa adds three additional heads to the original model; together with the original LM head, they predict candidates for the next four tokens in parallel.

[Figure: the original LM head plus three Medusa heads predicting the next four tokens in parallel]

2.4 Disadvantages

The disadvantages of Medusa are as follows:

  • Each Medusa head has only a single residual MLP layer between the last Transformer block and its output projection, which may limit its expressive power.
  • Medusa increases the number of model parameters, which increases memory usage.
  • Each head in Medusa runs independently, so the prediction of the “next next token” does not depend on which “next token” was chosen. This leads to poorer drafts and a lower acceptance rate, and it can even cause a net slowdown when the batch size is large.
  • The lack of sequence dependencies can also lead to inefficient tree pruning algorithms.
  • The draft quality is still not high, the speedup effect is limited, and the output distribution cannot be guaranteed to be consistent with the target LLM under non-greedy decoding.

Subsequent research has therefore improved on Medusa. For example, Clover focuses on providing sequence dependencies and on adding modules with stronger representational capacity than a single MLP. Hydra increases the dependency between the predictions of the draft heads. Hydra++ uses the base model's output distribution as the teacher signal and trains the draft heads by knowledge distillation. Also, similar to EAGLE, Hydra++ adds an independent decoder layer. Each Hydra head takes as input not only the previous token but also that token's representation from this decoder layer.

0x03 Tree Verification

Each Medusa Head produces its top-k predicted tokens, and candidate sequences are formed from the Cartesian product of these predictions. We could verify each candidate sequence with a separate forward pass, but this is too time-consuming. Therefore, the Medusa authors designed a tree attention mechanism. A mask over the candidate tree restricts each token to attend only to its ancestors on the same path, and the position indices used by the position embeddings are set accordingly. With tree attention, Medusa can construct, maintain, and verify multiple candidate sequences in parallel.

3.1 Decoding Path

In its basic version, Medusa decodes greedily and takes only the top-1 token from each head. With the additional decoding heads, Medusa instead takes the top-k tokens of each head (top-k selection), so each head outputs k tokens. The predictions of different Medusa heads are not necessarily consistent with each other. The distributions of heads 1 and 2, $p_t^{(1)}$ and $p_t^{(2)}$, are formally conditionally independent, since both are computed only from $h_t$. In reality, however, the token at position t+3 depends on the token chosen at position t+2. We therefore cannot simply take the most likely token from $p_t^{(1)}$ and from $p_t^{(2)}$ and send them to the verification stage, because the resulting sentence may be logically inconsistent. Medusa alleviates this problem by taking combinations of the top-k tokens as candidate sequences. The token produced by the original LM_head serves as the root node of a tree, and each path traversed from the root downward is called a decoding path (a candidate path in the paper). Each candidate sequence corresponds to all nodes on a path in the tree, not just the leaf node, because tree attention verifies every token on the path.

Suppose there are K heads and head k keeps its top-$s_k$ predictions. Combining them yields $\prod_{k=1}^{K} s_k$ complete paths, and the tree contains $\sum_{k=1}^{K} \prod_{i=1}^{k} s_i$ nodes in total (excluding the root). When constructing the tree, the simplest method is to take the Cartesian product of the candidates from all heads. The example below applies the Cartesian product to the outputs of two heads: the top-k tokens of each head are treated as nodes, and each head is treated as one layer of the tree. There are 6 decoding paths in the figure. Head 1 generates 2 possible tokens (It and I) at the next position, and Head 2 generates 3 possible tokens (is, ’, and the) at the next-next position, giving 2 x 3 = 6 candidate sequences. To distinguish different prefixes, Medusa stores tokens redundantly. For example, the three predicted tokens of Head 2 appear twice, once under the prefix It and once under the prefix I. Under the tree mask, each token can only see its own prefix.

[Figure: Cartesian-product candidate tree for two heads (top-2 and top-3)]
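The Cartesian-product construction and its size can be sketched in Python. With top-s_k candidates for head k, a full tree has prod(s_k) complete paths and sum over k of (s_1 × ... × s_k) nodes, not counting the root. The function names are hypothetical, and strings stand in for token IDs.

```python
from itertools import product
from math import prod

def cartesian_candidates(head_topk):
    """All full-length candidate sequences: one per element of the Cartesian product."""
    return [list(p) for p in product(*head_topk)]

def tree_size(topk_per_head):
    """(number of full paths, number of tree nodes excluding the root) for a full Cartesian tree."""
    paths = prod(topk_per_head)
    nodes = sum(prod(topk_per_head[:k]) for k in range(1, len(topk_per_head) + 1))
    return paths, nodes

heads = [["It", "I"], ["is", "'", "the"]]   # the two-head example
for seq in cartesian_candidates(heads):
    print(" ".join(seq))

print(tree_size([2, 3]))            # → (6, 8)
print(tree_size([10, 10, 10, 10]))  # → (10000, 11110)
```

The second call shows how quickly a full Cartesian tree explodes as heads and candidates are added. This is the problem that the sparse tree construction in Section 3.2 addresses.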

3.2 Optimal Construction Method

The example above takes the top-2 tokens from Head 1 and the top-3 tokens from Head 2, giving 6 candidate paths across two heads. With more heads and more candidates per head, the number of decoding paths grows multiplicatively with both the top-k size and the number of heads. This produces a massive number of candidate paths and a huge search space. More candidate sequences raise the token acceptance rate, but verifying them also costs more computation. The new problems then become:

  • How can we reduce the number of candidate decoding paths?
  • How can we obtain the optimal decoding path from the candidate decoding paths?

Intuitively, candidates built from the top-k predictions of different heads have different accuracies. More accurate predictions should be prioritized to build a more efficient tree, rather than using all possible combinations. Medusa measures the accuracy of the top-k predictions of each decoding head on a calibration dataset. It then greedily adds to the tree the nodes that maximize the expected acceptance length, which achieves a higher speedup for the same total number of nodes. Essentially, the method speeds things up by pruning: the low-accuracy candidates of each head are removed from the tree.

Specifically, a calibration dataset (such as Alpaca-eval) is used to measure the accuracy of each head's candidates. Let $a_k^{(i)}$ denote the accuracy of the i-th top prediction of the k-th head. Assume these accuracies are independent, and consider a candidate sequence $[i_1, i_2, \dots, i_k]$, which takes the $i_1$-th candidate of head 1, the $i_2$-th candidate of head 2, and so on. Its accuracy can then be written as $\prod_{j=1}^{k} a_j^{(i_j)}$. Let I denote the set of candidate sequences in the tree. The expected acceptance length of the tree is then:

$$\mathbb{E}[\text{acceptance length}] = \sum_{[i_1, i_2, \dots, i_k] \in I} \; \prod_{j=1}^{k} a_j^{(i_j)}$$
When constructing the tree, Medusa greedily adds the candidate sequences with the highest accuracy until the number of nodes in the tree reaches a preset budget. This maximizes the expected acceptance length, and thus the speedup, for a given tree size. The result is a hand-designed sparse tree in which higher-ranked nodes have more child paths. Conceptually, all possible combinations are scored, the top n combinations are kept as a fixed tree, and the rest are pruned.
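The greedy construction can be sketched as a best-first search. The sketch makes three assumptions. First, `acc[k][i]` is the measured accuracy of the i-th ranked token of head k, sorted in descending order; the numbers in the usage line are made up, not measurements. Second, accuracies are independent, so a candidate's score is the product of its per-head accuracies. Third, `build_sparse_tree` is a hypothetical name; this is not the script Medusa uses to produce its presets. Because every accuracy is at most 1, a child never scores higher than its parent. Expanding the best node first therefore keeps the tree prefix-closed and yields the `num_nodes` highest-scoring candidate sequences.

```python
import heapq

def build_sparse_tree(acc, num_nodes):
    """Greedily pick the num_nodes candidate sequences with the highest estimated accuracy.

    acc[k][i]: accuracy of the i-th ranked token of head k (descending per head).
    A node is a tuple of ranks (i_1, ..., i_k), the same notation as medusa_choices.
    Returns the chosen nodes (in insertion order) and the tree's expected acceptance length.
    """
    # heapq is a min-heap, so store negated scores to pop the best node first.
    heap = [(-a, (i,)) for i, a in enumerate(acc[0])]
    heapq.heapify(heap)
    chosen, expected = [], 0.0
    while heap and len(chosen) < num_nodes:
        neg_score, node = heapq.heappop(heap)
        chosen.append(node)
        expected += -neg_score               # each node contributes its path accuracy
        depth = len(node)
        if depth < len(acc):                 # children: extend with the next head's candidates
            for i, a in enumerate(acc[depth]):
                heapq.heappush(heap, (neg_score * a, node + (i,)))
    return chosen, expected

nodes, exp_len = build_sparse_tree([[0.6, 0.2, 0.1], [0.4, 0.2, 0.1]], 4)
print(nodes)              # → [(0,), (0, 0), (1,), (0, 1)]
print(round(exp_len, 2))  # → 1.16
```

The result uses the same rank-tuple notation as the medusa_choices presets in Section 3.3.1.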

The following figure shows an example of a sparse tree for the MEDUSA-2 Vicuna-7B model. This tree structure extends four levels, indicating that four MEDUSA heads are involved in the computation. The tree is initially generated using a Cartesian product method and then pruned based on the statistical expectation of the top k predictions of each MEDUSA head measured on the Alpaca-eval dataset. The leftward tilt of the tree visually represents the algorithm’s tendency to use tokens with higher accuracy. Each node represents one of the top-k predictions from the MEDUSA head, and the edges show the connections between them. Red lines highlight the paths that correctly predict future tokens. This optimizes a tree with 1000 paths to only 42 paths, and these paths can terminate early, without necessarily traversing to the last level.

[Figure: sparse tree for the MEDUSA-2 Vicuna-7B model]

3.3 Implementation

3.3.1 Key Variables

Let’s first look at the key variables involved in the attention tree.

demo_tensor

demo_tensor is a toy input tensor in which each row holds one head's candidate token IDs, padded with zeros:

[2, 3, 0, 0, 0, 0, 0, 0 ...] # 1st depth we choose top 2
[4, 5, 6, 0, 0, 0, 0, 0 ...] # 2nd depth we choose top 3

See the image below.

[Figure: demo_tensor]

medusa_choices

medusa_choices is a nested list that describes the Medusa tree structure and thus determines the decoding paths. Each element of the outer list is one node of the tree, and each inner list is the path from the root to that node. Its i-th entry is the rank of the token chosen from the i-th head (0 is the top-1 token). For example, [0, 1] denotes the top-1 token of head 1 followed by the top-2 token of head 2. From medusa_choices we can construct all the data structures of the sparse tree. The presets in the source code look like this:

vicuna_7b_stage2 = [(0,), (0, 0), (1,), (0, 1), (0, 0, 0), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 3), (3,), (0, 1, 0), (2, 0), (4,), (0, 0, 2), (0, 4), (1, 1), (1, 0, 0), (0, 0, 0, 0), (5,), (0, 0, 3), (0, 5), (0, 2, 0), (3, 0), (0, 1, 1), (0, 6), (6,), (0, 7), (0, 0, 4), (4, 0), (1, 2), (0, 8), (7,), (0, 3, 0), (0, 0, 0, 1), (0, 0, 5), (2, 1), (0, 0, 6), (1, 0, 1), (0, 0, 1, 0), (2, 0, 0), (5, 0), (0, 9), (0, 1, 2), (8,), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7), (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0), (9,), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0), (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)]
vicuna_7b_stage1_ablation = [(0,), (0, 0), (1,), (0, 0, 0), (0, 1), (1, 0), (2,), (0, 2), (0, 0, 1), (3,), (0, 3), (0, 1, 0), (2, 0), (0, 0, 2), (0, 4), (4,), (0, 0, 0, 0), (1, 0, 0), (1, 1), (0, 0, 3), (0, 2, 0), (0, 5), (5,), (3, 0), (0, 1, 1), (0, 6), (6,), (0, 0, 4), (1, 2), (0, 0, 0, 1), (4, 0), (0, 0, 5), (0, 7), (0, 8), (0, 3, 0), (0, 0, 1, 0), (1, 0, 1), (7,), (2, 0, 0), (0, 0, 6), (2, 1), (0, 1, 2), (5, 0), (0, 2, 1), (0, 9), (0, 0, 0, 2), (0, 4, 0), (8,), (1, 3), (0, 0, 7), (0, 1, 0, 0), (1, 1, 0), (6, 0), (9,), (0, 0, 8), (0, 0, 9), (0, 5, 0), (0, 0, 2, 0), (1, 0, 2), (0, 1, 3), (0, 0, 0, 3), (3, 0, 0), (3, 1)]
vicuna_7b_stage1 = [(0,), (0, 0), (1,), (2,), (0, 1), (1, 0), (3,), (0, 2), (4,), (0, 0, 0), (0, 3), (5,), (2, 0), (0, 4), (6,), (0, 5), (1, 1), (0, 0, 1), (7,), (3, 0), (0, 6), (8,), (9,), (0, 1, 0), (0, 7), (0, 8), (4, 0), (0, 0, 2), (1, 2), (0, 9), (2, 1), (5, 0), (1, 0, 0), (0, 0, 3), (1, 3), (0, 2, 0), (0, 1, 1), (0, 0, 4), (6, 0), (1, 4), (0, 0, 5), (2, 2), (0, 3, 0), (3, 1), (0, 0, 6), (7, 0), (1, 5), (1, 0, 1), (2, 0, 0), (0, 0, 7), (8, 0), (0, 0, 0, 0), (4, 1), (0, 1, 2), (0, 4, 0), (9, 0), (0, 2, 1), (2, 3), (1, 6), (0, 0, 8), (0, 5, 0), (3, 2), (5, 1)]

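In our example, medusa_choices = [[0], [0, 0], [0, 1], [0, 2], [1], [1, 0], [1, 1], [1, 2]]. The root node is the token 1, standing for the prediction of the original LM head, which the example code below prepends to demo_tensor. By depth, the tree can be visualized as follows.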

[1]
[2, 3]
[4, 5, 6]

medusa_buffers

The data structure information for medusa_buffers is as follows.

medusa_buffers = generate_medusa_buffers(medusa_choices, device='cpu')

medusa_buffers = {
    "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
    "tree_indices": medusa_tree_indices,
    "medusa_position_ids": medusa_position_ids,
    "retrieve_indices": retrieve_indices,
    }

The member variables serve the following purposes:

  • medusa_attn_mask: This is the mask used for tree attention.
  • tree_indices: The position of an element in the tree within demo_tensor, which is used in the generate_candidates() function.
  • medusa_position_ids: Ensures that nodes at the same depth have the same position ID. This ID is added to the position encoding, and incorporating this information during training results in a better Medusa head. It is used in the tree_decoding() function.
  • retrieve_indices: Maps the tree back to the Cartesian-product candidates. Each row lists, for one candidate path, the positions of its tokens in the flattened tree, and hence in the logits. With it, the logits of every candidate sequence can be gathered from the tree's logits. It is used in the tree_decoding() and generate_candidates() functions.

tree_indices

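For each node of the flattened tree, tree_indices gives the index of that node's token in the flattened demo_tensor. After the root token is prepended, index 0 is the root, indices 1 to 10 hold head 1's candidates, and indices 11 to 20 hold head 2's candidates. For this input tensor, the tree_indices are therefore as follows.

[0, 1, 2, 11, 12, 13, 11, 12, 13]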

The grown tree looks like this.

1
|-- 2
|  |-- 4
|  |-- 5
|  |-- 6
|-- 3
|  |-- 4
|  |-- 5
|  |-- 6

The flattened tree nodes obtained from demo_tensor are as follows.

[1, 2, 3, 4, 5, 6, 4, 5, 6]

See the image below.

[Figure: tree_indices]

medusa_position_ids

medusa_position_ids: Ensures that nodes at the same depth have the same position ID. With this information, the position encoding for each token is: position in the sequence + depth in the tree. This allows the depth information to be known during subsequent training of the medusa head, resulting in a better medusa head. This variable is used in the tree_decoding() function.

The position ID corresponding to the input tensor is as follows.

[0, 1, 1, 2, 2, 2, 2, 2, 2] # Medusa position IDs
 |  |  |  |  |  |  |  |  |
[1, 2, 3, 4, 5, 6, 4, 5, 6] # Flatten tree representation of the tensor

The visualization is as follows.

[Figure: medusa_position_ids]

retrieve_indices

retrieve_indices maps the tree back to the Cartesian-product candidates. Each row gives the positions of the tokens on one candidate path within the flattened tree, and therefore within the tree's logits. With this information, the logits of each candidate sequence can be extracted from the logits.

The retrieve_indices for this example are as follows.

[0, 2, 8]
[0, 2, 7]
[0, 2, 6]
[0, 1, 5]
[0, 1, 4]
[0, 1, 3]

The tree is mapped to the Cartesian product as follows.

[1, 3, 6]
[1, 3, 5]
[1, 3, 4]
[1, 2, 6]
[1, 2, 5]
[1, 2, 4]

The specific visualization is as follows.

[Figure: retrieve_indices]

medusa_attn_mask

The final tree uses the top-k tokens of each head as nodes and each head as one layer of the tree, and every path from the root constitutes a set of predictions to be verified. Within this tree, the Attention Mask needs a new design that restricts each token to attend only to its ancestors on the same path. At the same time, correct position indices must be set for the corresponding position embeddings. The mask matrix works as follows:

  • Each row of the mask matrix corresponds to one token's prediction task (that token acting as the query).
  • In the tree mask, position encodings no longer follow the flattened order: tokens at the same depth share the same position ID, as in medusa_position_ids.

Examples from the paper are as follows.

[Figure: tree attention mask example from the paper]

The mask for this example is as follows.

[Figure: tree attention mask for this example]

3.3.2 Example Code

Example code is as follows

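import torch
# generate_medusa_buffers is defined in the Medusa repository (medusa/model/utils.py).
from medusa.model.utils import generate_medusa_buffers

# Toy candidates: row 0 holds head 1's top-2 token IDs, row 1 holds head 2's top-3 (10 slots per head).
demo_tensor = torch.zeros(2,10).long()
demo_tensor[0,0] = 2
demo_tensor[0,1] = 3
demo_tensor[1,0] = 4
demo_tensor[1,1] = 5
demo_tensor[1,2] = 6
print('Demo tensor: \n', demo_tensor)
demo_tensor = demo_tensor.flatten()
# Prepend the root token (ID 1), standing in for the original LM head's prediction.
demo_tensor = torch.cat([torch.ones(1).long(), demo_tensor])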
print('='*50)
medusa_choices = [[0], [0, 0], [0, 1], [0, 2], [1], [1, 0], [1, 1], [1, 2]]
medusa_buffers = generate_medusa_buffers(medusa_choices, device='cpu')
tree_indices = medusa_buffers['tree_indices']
medusa_position_ids = medusa_buffers['medusa_position_ids']
retrieve_indices = medusa_buffers['retrieve_indices']
print('Tree indices: \n', tree_indices.tolist())
print('Tree representation of the tensor: \n', demo_tensor[tree_indices].tolist())
print('='*50)
print('Medusa position ids: \n', medusa_position_ids.tolist())
print('='*50)
print('Retrieve indices: \n', retrieve_indices.tolist())
demo_tensor_tree = demo_tensor[tree_indices]
# Append a -1 sentinel so that padded (-1) entries in retrieve_indices select it.
demo_tensor_tree_ext = torch.cat([demo_tensor_tree, torch.ones(1).long().mul(-1)])
print('Retrieve representation of the tensor: \n', demo_tensor_tree_ext[retrieve_indices].tolist())
print('='*50)
# Tree attention mask, with the batch and head dimensions indexed away
print(medusa_buffers['medusa_attn_mask'][0,0,:,:].int())

Print result:

Demo tensor: 
 tensor([[2, 3, 0, 0, 0, 0, 0, 0, 0, 0],
        [4, 5, 6, 0, 0, 0, 0, 0, 0, 0]])
==================================================
Tree indices: 
 [0, 1, 2, 11, 12, 13, 11, 12, 13]
Tree representation of the tensor: 
 [1, 2, 3, 4, 5, 6, 4, 5, 6]
==================================================
Medusa position ids: 
 [0, 1, 1, 2, 2, 2, 2, 2, 2]
==================================================
Retrieve indices: 
 [[0, 2, 8], [0, 2, 7], [0, 2, 6], [0, 1, 5], [0, 1, 4], [0, 1, 3]]
Retrieve representation of the tensor: 
 [[1, 3, 6], [1, 3, 5], [1, 3, 4], [1, 2, 6], [1, 2, 5], [1, 2, 4]]
==================================================
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 1, 0, 0, 0],
        [1, 0, 1, 0, 0, 0, 1, 0, 0],
        [1, 0, 1, 0, 0, 0, 0, 1, 0],
        [1, 0, 1, 0, 0, 0, 0, 0, 1]], dtype=torch.int32)

3.3.3 Overall Visualization

See the image below for a detailed visualization.

3115

3.3.4 Usage

Call flow

The complete calling code is shown below. The basic logic is:

  • Build the sparse tree representation from the configured medusa_choices, using the generate_medusa_buffers() function.
  • Initialize the key/value cache.
  • Set up the tree attention mask and run the prompt through the model to obtain logits and medusa_logits, using the initialize_medusa() function. logits is the output of lm_head, and medusa_logits is the output of the Medusa heads.
  • Build candidate paths from the top-k predictions of the Medusa heads, using the generate_candidates() function.
  • Verify the candidate paths with tree attention and pick the best path, using the tree_decoding() and evaluate_posterior() functions. tree_decoding() runs tree-attention-based inference; evaluate_posterior() evaluates the candidates against its output.
  • Select the logits and medusa_logits belonging to the accepted candidate tokens, and update input_ids, the key/value cache, and so on, using the update_inference_inputs() function.
def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, top_p=0.8, sampling = 'typical', fast = True, max_steps = 512):

    # Avoid modifying the input_ids in-place
    input_ids = input_ids.clone()

    # Cache medusa buffers (the fixed patterns for tree attention)
    if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices:
        # Load the cached medusa buffer
        medusa_buffers = model.medusa_buffers
    else:
        # Initialize the medusa buffer
        # 1. Build the sparse tree representation from the configured medusa choices
        medusa_buffers = generate_medusa_buffers(
            medusa_choices, device=model.base_model.device
        )
    model.medusa_buffers = medusa_buffers
    model.medusa_choices = medusa_choices

    # Initialize the past key and value states
    if hasattr(model, "past_key_values"):
        past_key_values = model.past_key_values
        past_key_values_data = model.past_key_values_data
        current_length_data = model.current_length_data
        # Reset the past key and value states
        current_length_data.zero_()
    else:
        (
            past_key_values,
            past_key_values_data,
            current_length_data,
        ) = initialize_past_key_values(model.base_model)
        model.past_key_values = past_key_values
        model.past_key_values_data = past_key_values_data
        model.current_length_data = current_length_data

    input_len = input_ids.shape[1]
    reset_medusa_mode(model)
    
    # Initialize tree attention mask and process prefill tokens
    medusa_logits, logits = initialize_medusa(
            input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
    )
    new_token = 0
    
    for idx in range(max_steps): 
        # Generate candidates with topk predictions from Medusa heads
        # candidates: multiple candidate token sequences (one per path); tree_candidates: the flattened token tree
        candidates, tree_candidates = generate_candidates(
                medusa_logits,
                logits,
                medusa_buffers["tree_indices"],
                medusa_buffers["retrieve_indices"],
                temperature, posterior_threshold, posterior_alpha, top_p, sampling, fast
            )
        # Use tree attention to verify the candidates and get predictions
        # Run verification inference on tree_candidates with the tree attention mechanism, producing new logits and medusa_logits
        medusa_logits, logits, outputs = tree_decoding(
                model,
                tree_candidates,
                past_key_values,
                medusa_buffers["medusa_position_ids"],
                input_ids,
                medusa_buffers["retrieve_indices"],
            )
        # Evaluate each path and pick the best one. If no sequence passes, only the first token is used (accept_length is 0); otherwise the accepted tokens of the best sequence are used
        best_candidate, accept_length = evaluate_posterior(
                logits, candidates, temperature, posterior_threshold, posterior_alpha , top_p, sampling, fast
            )
        # Select the logits and medusa_logits belonging to the chosen candidate, and update the inputs, the key/value cache, etc.
        input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                input_ids,
                candidates,
                best_candidate,
                accept_length,
                medusa_buffers["retrieve_indices"],
                outputs,
                logits,
                medusa_logits,
                new_token,
                past_key_values_data,
                current_length_data,
            )
        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break
        if new_token > 1024:
            break
    return input_ids, new_token, idx

Initialization

The initialize_medusa() function performs initialization: it runs the prompt through the model to obtain medusa_logits and logits, then installs the Medusa attention mask in the base model.

def initialize_medusa(input_ids, model, medusa_attn_mask, past_key_values):
    """
    Initializes the Medusa structure for a given model.

    This function performs the following operations:
    1. Forward pass through the model to obtain the Medusa logits, original model outputs, and logits.
    2. Sets the Medusa attention mask within the base model.

    Args:
    - input_ids (torch.Tensor): The input tensor containing token ids.
    - model (MedusaLMHead): The model containing the Medusa layers and base model.
    - medusa_attn_mask (torch.Tensor): The attention mask designed specifically for the Medusa structure.
    - past_key_values (list of torch.Tensor): Contains past hidden states and past attention values.

    Returns:
    - medusa_logits (torch.Tensor): Logits from the Medusa heads.
    - logits (torch.Tensor): Original logits from the base model.
    """
    medusa_logits, outputs, logits = model(
        input_ids, past_key_values=past_key_values, output_orig=True, medusa_forward=True
    )
    model.base_model.model.medusa_mask = medusa_attn_mask
    return medusa_logits, logits

Inside the model, medusa_mask is merged with the causal mask into a single combined mask, and this combined mask is what is passed to every decoder layer during the forward pass.

class LlamaModel(LlamaPreTrainedModel):
    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                # inputs_embeds.dtype,
                torch.float32,  # [MODIFIED] force to cast to float32
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        # [MODIFIED] add medusa mask
        if hasattr(self, "medusa_mask") and self.medusa_mask is not None:
            medusa_mask = self.medusa_mask
            medusa_len = medusa_mask.size(-1)
            combined_attention_mask[:, :, -medusa_len:, -medusa_len:][
                medusa_mask == 0
            ] = combined_attention_mask.min()
            if hasattr(self, "medusa_mode"):
                # debug mode
                if self.medusa_mode == "debug":
                    torch.save(combined_attention_mask, "medusa_mask.pt")

        return combined_attention_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,  # [MODIFIED] past_key_value is KVCache class
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        # ......

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
        
        # ......

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            if self.gradient_checkpointing and self.training:
                # ......
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)
      
        # ......
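The combination step in _prepare_decoder_attention_mask() can be mimicked with plain lists. This sketch assumes the additive-mask convention (0.0 means "may attend", a very negative value means "masked") and uses float('-inf') where the model code uses the mask tensor's minimum value; causal_mask and apply_medusa_mask are hypothetical names, and batch/head dimensions are dropped.

```python
from typing import List

NEG = float("-inf")  # stands in for the "masked" value (the model code uses the tensor minimum)

def causal_mask(q_len: int, past_len: int) -> List[List[float]]:
    """Additive causal mask: query i sits at absolute position past_len + i
    and may attend to keys 0 .. past_len + i."""
    return [[0.0 if j <= past_len + i else NEG for j in range(past_len + q_len)]
            for i in range(q_len)]

def apply_medusa_mask(mask: List[List[float]], medusa_mask: List[List[int]]) -> List[List[float]]:
    """Inside the bottom-right tree block, mask out every entry where medusa_mask is 0."""
    n = len(medusa_mask)
    row0, col0 = len(mask) - n, len(mask[0]) - n
    out = [row[:] for row in mask]  # copy instead of editing in place
    for i in range(n):
        for j in range(n):
            if medusa_mask[i][j] == 0:
                out[row0 + i][col0 + j] = NEG
    return out

# Two cached prompt tokens, then a 3-node tree: a root plus two sibling children.
tree = [[1, 0, 0],
        [1, 1, 0],
        [1, 0, 1]]
combined = apply_medusa_mask(causal_mask(q_len=3, past_len=2), tree)
for row in combined:
    print(row)
```

Because every tree node's ancestors sit earlier in the flattened tree, the causal part never unmasks anything the tree mask forbids; the tree mask only removes the extra entries (sibling branches) from the last block, while all tree tokens still see the whole prefix.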

Generate candidate paths

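The generate_candidates() function takes the next token from the original LM head (greedy, or sampled when temperature > 0 and fast is off) and the top-k tokens from each Medusa head, arranges them into tree nodes via tree_indices, and then uses retrieve_indices to expand the tree into candidate sequences, one per path (the Cartesian-product candidates).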

def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, temperature = 0, posterior_threshold=0.3, posterior_alpha = 0.09, top_p=0.8, sampling = 'typical', fast = False):
    """
    Generate candidates based on provided logits and indices.
    
    Parameters:
    - medusa_logits (torch.Tensor): Logits from a specialized Medusa structure, aiding in candidate selection.
    - logits (torch.Tensor): Standard logits from a language model.
    - tree_indices (list or torch.Tensor): Indices representing a tree structure, used for mapping candidates.
    - retrieve_indices (list or torch.Tensor): Indices for extracting specific candidate tokens.
    - temperature (float, optional): Controls the diversity of the sampling process. Defaults to 0.
    - posterior_threshold (float, optional): Threshold for typical sampling. Defaults to 0.3.
    - posterior_alpha (float, optional): Scaling factor for the entropy-based threshold in typical sampling. Defaults to 0.09.
    - top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
    - sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
    - fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False.

    Returns:
    - tuple (torch.Tensor, torch.Tensor): A tuple containing two sets of candidates:
        1. Cartesian candidates derived from the combined original and Medusa logits.
        2. Tree candidates mapped from the Cartesian candidates using tree indices.
    """
    # Greedy decoding: Select the most probable candidate from the original logits.
    if temperature == 0 or fast:
        candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)
    else:
        if sampling == 'typical':
            candidates_logit = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
        elif sampling == 'nucleus':
            candidates_logit = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
        else:
            raise NotImplementedError
    # Extract the TOPK candidates from each Medusa head's logits (TOPK is a module-level constant: the top-k width per head).
    candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim = -1).indices

    # Concatenate the token chosen from the original LM head with the top-k tokens of every Medusa head (token ids, not logits).
    candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1)

    # Map the combined candidates to the tree indices to get the tree's nodes (tree candidates).
    tree_candidates = candidates[tree_indices]

    # Extend the tree candidates by appending a zero; any -1 padding entry in retrieve_indices selects this slot.
    tree_candidates_ext = torch.cat([tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device)], dim=0)

    # Retrieve the Cartesian candidates (one row per path) using the retrieve indices.
    cart_candidates = tree_candidates_ext[retrieve_indices]

    # Unsqueeze the tree candidates for dimension consistency.
    tree_candidates = tree_candidates.unsqueeze(0)
    return cart_candidates, tree_candidates
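The flat candidates vector built above has the layout [LM-head token] + [head 1 top-k] + [head 2 top-k] + ..., and tree_indices picks the tree nodes out of it. The sketch below reconstructs tree_indices for a given medusa_choices, assuming a top-k width of 10 (the width of the demo tensor in 3.3.2). build_tree_indices and layout_candidates are hypothetical helpers, and the indexing formula is inferred from the tree indices printed in that example rather than copied from Medusa's code.

```python
from typing import List

TOPK = 10  # assumed top-k width per head (the demo tensor in 3.3.2 has 10 columns)

def build_tree_indices(medusa_choices: List[List[int]], topk: int = TOPK) -> List[int]:
    """Hypothetical helper: index into the flat candidates vector for every tree node.
    Node 0 is the LM-head token; a node at depth d whose last choice is r reads
    slot 1 + (d - 1) * topk + r, i.e. rank r of Medusa head d."""
    choices = sorted(medusa_choices, key=lambda c: (len(c), c))
    return [0] + [1 + (len(c) - 1) * topk + c[-1] for c in choices]

def layout_candidates(lm_token: int, medusa_topk: List[List[int]]) -> List[int]:
    """Flat candidates vector: [LM-head token] + head-1 top-k + head-2 top-k + ..."""
    flat = [lm_token]
    for head in medusa_topk:
        flat.extend(head)
    return flat

choices = [[0], [0, 0], [0, 1], [0, 2], [1], [1, 0], [1, 1], [1, 2]]
tree_indices = build_tree_indices(choices)
flat = layout_candidates(1, [[2, 3] + [0] * 8, [4, 5, 6] + [0] * 7])
tree_candidates = [flat[i] for i in tree_indices]
print(tree_indices)     # [0, 1, 2, 11, 12, 13, 11, 12, 13]
print(tree_candidates)  # [1, 2, 3, 4, 5, 6, 4, 5, 6]
```

Sibling subtrees reuse the same slots (11, 12, 13 appear twice) because each depth reads from the same head's top-k list, whatever the parent token is.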

Validate candidate paths

The tree_decoding() function feeds the flattened token tree through the base model in a single forward pass, using the tree attention mask and position ids offset by the current input length. It then uses retrieve_indices to rearrange the tree logits back into per-path order, giving the logits at every position of each Cartesian-product candidate.

def tree_decoding(
    model,
    tree_candidates,
    past_key_values,
    medusa_position_ids,
    input_ids,
    retrieve_indices,
):
    """
    Decode the tree candidates using the provided model and reorganize the logits.
    
    Parameters:
    - model (nn.Module): Model to be used for decoding the tree candidates.
    - tree_candidates (torch.Tensor): Input candidates based on a tree structure.
    - past_key_values (torch.Tensor): Past states, such as key and value pairs, used in attention layers.
    - medusa_position_ids (torch.Tensor): Positional IDs associated with the Medusa structure.
    - input_ids (torch.Tensor): Input sequence IDs.
    - retrieve_indices (list or torch.Tensor): Indices for reordering the logits.
    
    Returns:
    - tuple: Returns medusa logits, regular logits, and other outputs from the model.
    """

    # Compute new position IDs by adding the Medusa position IDs to the length of the input sequence.
    position_ids = medusa_position_ids + input_ids.shape[1]

    # Use the model to decode the tree candidates. 
    # The model is expected to return logits for the Medusa structure, original logits, and possibly other outputs.
    tree_medusa_logits, outputs, tree_logits = model(
        tree_candidates,
        output_orig=True,
        past_key_values=past_key_values,
        position_ids=position_ids,
        medusa_forward=True,
    )
    
    # Rearrange the tree logits into per-path order: row p holds the logits at each node of candidate path p.
    logits = tree_logits[0, retrieve_indices]  # gather the Cartesian-product (per-path) logits via retrieve_indices
    medusa_logits = tree_medusa_logits[:, 0, retrieve_indices]
    return medusa_logits, logits, outputs

Calculate the optimal path

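The evaluate_posterior() function compares each candidate path with the base model's predictions, computes how many tokens of each path are accepted, and returns the best candidate together with its accepted length.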

def evaluate_posterior(
    logits, candidates, temperature, posterior_threshold=0.3, posterior_alpha = 0.09, top_p=0.8, sampling = 'typical', fast = True
):
    """
    Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.

    Depending on the temperature value, the function either uses greedy decoding or evaluates posterior
    probabilities to select the best candidate.

    Args:
    - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size).
    - candidates (torch.Tensor): Candidate token sequences.
    - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding.
    - posterior_threshold (float): Threshold for posterior probability.
    - posterior_alpha (float): Scaling factor for the threshold.
    - top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
    - sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
    - fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to True.
    Returns:
    - best_candidate (torch.Tensor): Index of the chosen best candidate.
    - accept_length (int): Length of the accepted candidate sequence.
    """
    # Greedy decoding based on temperature value
    if temperature == 0:
        # Find the tokens that match the maximum logits for each position in the sequence
        posterior_mask = (
            candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
        ).int()
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max()
        # Choose the best candidate
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
        
    if sampling == 'typical':
        if fast:
            posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
            candidates_prob = torch.gather(
                posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
            ).squeeze(-1)
            posterior_entropy = -torch.sum(
                posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
            )  # torch.sum(torch.log(*)) is faster than torch.prod
            threshold = torch.minimum(
                torch.ones_like(posterior_entropy) * posterior_threshold,
                torch.exp(-posterior_entropy) * posterior_alpha,
            )
            posterior_mask = candidates_prob > threshold
            candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)

            # Choose the best candidate based on the evaluated posterior probabilities
            accept_length = candidates_accept_length.max()
            if accept_length == 0:
                # If no candidates are accepted, just choose the first one
                best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
            else:
                best_candidates = torch.where(candidates_accept_length == accept_length)[0]
                # Accept the best one according to likelihood
                likelihood = torch.sum(
                    torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
                )
                best_candidate = best_candidates[torch.argmax(likelihood)]
            return best_candidate, accept_length
        # Calculate posterior probabilities and thresholds for candidate selection
        posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha, fast)
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        # Choose the best candidate based on the evaluated posterior probabilities
        accept_length = candidates_accept_length.max()
        
        if accept_length == 0:
            # If no candidates are accepted, just choose the first one
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    
    if sampling == 'nucleus':
        assert top_p < 1.0 + 1e-6, "top_p should be between 0 and 1"
        posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
        accept_length = candidates_accept_length.max()
        # Choose the best candidate
        if accept_length == 0:
            # Default to the first candidate if none are accepted
            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
        else:
            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
        return best_candidate, accept_length
    else:
        raise NotImplementedError
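The greedy branch (temperature == 0) is the easiest to follow in isolation. The sketch below mirrors it with plain lists: a candidate token is accepted while it equals the base model's argmax at the previous position, the accept length is the matching-prefix length (what cumprod(...).sum() computes), and the first path with the longest match wins. greedy_accept is a hypothetical helper, and argmax_tokens stands in for torch.argmax(logits, dim=-1).

```python
from typing import List, Tuple

def greedy_accept(candidates: List[List[int]],
                  argmax_tokens: List[List[int]]) -> Tuple[int, int]:
    """candidates[p]: path p, whose first token came from the base model itself.
    argmax_tokens[p][j]: the base model's argmax token after reading candidates[p][:j + 1].
    Returns (best path index, number of accepted speculative tokens)."""
    best, best_len = 0, 0
    for p, (cand, pred) in enumerate(zip(candidates, argmax_tokens)):
        n = 0
        # Length of the matching prefix; same as cumprod(match_mask).sum().
        for j in range(1, len(cand)):
            if cand[j] != pred[j - 1]:
                break
            n += 1
        if n > best_len:  # strict '>' keeps the first longest path, like torch.argmax
            best, best_len = p, n
    return best, best_len

print(greedy_accept([[1, 3, 6], [1, 2, 5]], [[3, 7, 9], [2, 5, 8]]))  # (1, 2)
```

If no path matches at all, the result is (0, 0), which corresponds to the accept_length == 0 fallback in the code above: only the base model's own token is kept.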

3.4 Typical Acceptance

In speculative decoding, rejection sampling draws candidate tokens from the draft model's output distribution and then uses the original model to decide whether to accept them. When a token is rejected, a replacement is sampled from an adjusted distribution derived from the original model, so that the final output follows the original model's distribution exactly. In practice, exactly matching the original model's distribution is often unnecessary; ensuring output quality and diversity is enough. Relaxing this requirement lets more reasonable candidate tokens be accepted and speeds up decoding. Medusa therefore uses a typical acceptance scheme: it derives a threshold from the probabilities predicted by the original model, and a candidate token whose probability exceeds that threshold is considered “typical” and is accepted.
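For contrast with typical acceptance, the per-token accept/resample rule of standard speculative sampling can be sketched as follows. speculative_accept is a hypothetical helper operating on explicit probability lists; real implementations vectorize this over the whole draft sequence and stop at the first rejection.

```python
import random
from typing import List

def speculative_accept(token: int, p: List[float], q: List[float],
                       rng: random.Random) -> int:
    """One step of standard speculative sampling for a single draft token.
    p: target (original) model distribution; q: draft distribution; token was drawn from q.
    Returns the draft token if accepted, otherwise a token resampled from the residual."""
    # Accept with probability min(1, p(x) / q(x)).
    if rng.random() < min(1.0, p[token] / q[token]):
        return token
    # Rejected: resample from normalize(max(0, p - q)), which makes the overall
    # output distribution exactly p.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    if sum(residual) == 0.0:  # guard against floating-point round-off
        return token
    return rng.choices(range(len(p)), weights=residual)[0]

rng = random.Random(0)
print(speculative_accept(1, [0.2, 0.3, 0.5], [0.2, 0.3, 0.5], rng))  # 1: p == q, always accepted
```

The exactness guarantee is what typical acceptance gives up: it trades a perfect distributional match for a simpler, threshold-based test.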

3.4.1 Common Application Methods

The output of an LLM is a probability distribution over the vocabulary, and the sampling strategy directly shapes the result. Sometimes we want completely deterministic output; at other times we want richer, more varied output.

Deterministic decoding always yields the same output for the same input; it is essentially a search process. Two typical methods are:

  • Greedy search: output the highest-probability token at every step.
  • Beam search: maintain a beam of k paths. At each step, extend every path in the beam with candidate next tokens, keep the k paths with the highest cumulative probability as the new beam, and repeat.

Probabilistic sampling draws tokens from the distribution. Three variants are common:

  • Multinomial sampling: sample directly from the full distribution. Because it is purely random, it can easily pick words with extremely low probability.
  • Top-k sampling: sample at random from the k most probable candidates, after renormalizing their probabilities.
  • Top-p sampling (also known as nucleus sampling): sort the probabilities from largest to smallest, keep the smallest set of candidates whose cumulative probability reaches p, renormalize, and sample from that set.

Sampling-based methods usually have a temperature parameter: the higher the temperature, the more diverse the samples, which suits creative generation such as essay writing.
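The strategies above can be sketched in a few lines of plain Python. This is an illustrative implementation on explicit probability lists, not the code of any particular library; the function names are hypothetical, and real decoders apply the same operations to logits tensors.

```python
import math
import random
from typing import List

def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax: higher temperature flattens the distribution."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def greedy(logits: List[float]) -> int:
    """Greedy search step: index of the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])

def _renormalize(probs: List[float], keep: set) -> List[float]:
    kept = [p if i in keep else 0.0 for i, p in enumerate(probs)]
    total = sum(kept)
    return [p / total for p in kept]

def top_k_filter(probs: List[float], k: int) -> List[float]:
    """Keep the k most probable tokens and renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return _renormalize(probs, set(order[:k]))

def top_p_filter(probs: List[float], top_p: float) -> List[float]:
    """Nucleus filter: keep the smallest high-probability set whose mass reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cum = set(), 0.0
    for i in order:
        keep.add(i)
        cum += probs[i]
        if cum >= top_p:
            break
    return _renormalize(probs, keep)

def sample(probs: List[float], rng: random.Random) -> int:
    """Multinomial sampling from a (possibly filtered) distribution."""
    return rng.choices(range(len(probs)), weights=probs)[0]

rng = random.Random(0)
probs = softmax([2.0, 1.5, 0.5, -1.0], temperature=0.7)
print(greedy([2.0, 1.5, 0.5, -1.0]))        # 0
print(sample(top_k_filter(probs, 2), rng))  # token 0 or 1
print(sample(top_p_filter(probs, 0.9), rng))
```

Filtering and temperature only reshape the distribution; the final draw is always a multinomial sample from whatever distribution remains.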

3.4.2 Approach

In speculative decoding, the authors used rejection sampling so that the output distribution is consistent with the original model's. Later research, however, found that this strategy becomes less efficient as the sampling temperature rises. Consider the extreme case where the draft model is as good as the target model and their distributions align perfectly: ideally, every draft output should then be accepted. Yet because the draft model samples independently of the original model, and a higher temperature means more creativity and therefore more diverse draft tokens, the chance that a draft token is accepted by the original model drops, and the number of tokens decoded in parallel per step shrinks. Greedy decoding, by contrast, accepts every draft output in this case and is maximally efficient.

This behavior is counterintuitive. In real-world use, sampling is meant to produce diverse responses, and the temperature only modulates how “creative” they are. A higher temperature should therefore give the original model a better, not worse, chance of accepting the draft model's output, even if the result no longer exactly matches the original model's distribution. So why not simply accept any candidate token that looks plausible?

3.4.3 Typical Acceptance

MEDUSA argues that because sampling is used for creativity, the distribution of the candidate sequences does not need to match the original model's distribution exactly. The goal is to select typical candidates: any candidate sequence that is not an extremely unlikely outcome is acceptable. Intuitively, decoding should neither insist on overly predictable words nor admit too many unexpected ones, which keeps the vocabulary rich while avoiding repetitive generation.

Medusa therefore draws on truncation sampling to expand the pool of candidates that the original model can accept. It sets a threshold based on the original model's predicted probabilities: if a candidate token's probability exceeds this threshold, the token and its prefix are accepted.

Specifically, as in truncation sampling, the authors use the minimum of a hard threshold and an entropy-dependent threshold to decide whether to accept a candidate token. This ensures that meaningful tokens are selected and that decoding continues reasonably. The first token is always accepted using greedy decoding, so at least one token is generated at every step. Finally, the longest accepted candidate sequence is chosen as the result. The advantage of this method is its adaptability: with the sampling temperature set to zero, it reduces to the most efficient form, greedy search; as the temperature increases, the method becomes more efficient and accepts longer sequences.

  • When a few tokens have high probability, the entropy of the original model's distribution is small, so the entropy-dependent threshold (which scales with exp(-entropy)) is large and the acceptance condition is stricter.
  • When probability is spread fairly evenly across tokens, the entropy is large, so the entropy-dependent threshold is small and the acceptance condition is more lenient.
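The acceptance rule implemented in the fast typical branch of evaluate_posterior() can be written compactly as below. The notation is introduced here for illustration: ε corresponds to posterior_threshold, δ to posterior_alpha, and H is the entropy of the original model's (temperature-scaled) next-token distribution.

```latex
p_{\text{original}}\left(x_{n+k} \mid x_1, \dots, x_{n+k-1}\right)
  > \min\Big(\epsilon,\; \delta \exp\big(-H\big(p_{\text{original}}(\cdot \mid x_1, \dots, x_{n+k-1})\big)\big)\Big),
\qquad
H(p) = -\sum_{v} p(v) \log p(v)
```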

3116

The specific implementation is located in the evaluate_posterior() function, which will not be described in detail here.
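A minimal sketch of typical acceptance for a single candidate path, written in plain Python to mirror the fast typical branch of evaluate_posterior(): the entropy is computed with the same +1e-5 inside the log, the threshold is min(posterior_threshold, posterior_alpha * exp(-entropy)), and a rejected token cuts off the rest of the path. typical_accept_length is a hypothetical name, and probs is assumed to already include temperature scaling.

```python
import math
from typing import List

def typical_accept_length(candidate: List[int], probs: List[List[float]],
                          posterior_threshold: float = 0.3,
                          posterior_alpha: float = 0.09) -> int:
    """Number of speculative tokens in candidate[1:] accepted by typical acceptance.
    candidate[0] is the token chosen from the base model itself (always kept);
    probs[j] is the base model's temperature-scaled distribution after reading
    candidate[:j + 1]."""
    accepted = 0
    for j in range(1, len(candidate)):
        p = probs[j - 1]
        # Entropy of the base model's distribution (same +1e-5 guard as the Medusa code).
        entropy = -sum(pi * math.log(pi + 1e-5) for pi in p)
        # Hard threshold vs. entropy-dependent threshold; the smaller one is used.
        threshold = min(posterior_threshold, posterior_alpha * math.exp(-entropy))
        if p[candidate[j]] <= threshold:
            break  # a rejected token also rejects everything after it on the path
        accepted += 1
    return accepted

# The same token probability (0.05) under a peaked vs. a flat distribution:
print(typical_accept_length([0, 1], [[0.9, 0.05, 0.05]]))  # 0: peaked -> stricter
print(typical_accept_length([0, 1], [[0.05] * 20]))        # 1: flat -> more lenient
```

With the default thresholds, the peaked distribution gives a threshold of about 0.061 and the flat one about 0.0045, which is why the same 0.05 token is rejected in the first case and accepted in the second.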

0x04 Training

MEDUSA's decoding heads must be trained to make good predictions. Different training methods can be chosen depending on the available resources:

  • MEDUSA-1: freezes the original model's backbone (including its original LM head) and trains only the added MEDUSA heads. This suits settings with limited compute, or where the original model's performance must not be affected. QLoRA can also be used when training the heads, further saving memory and compute.
  • MEDUSA-2: trains the original model and the MEDUSA heads together. MEDUSA-1 saves resources but does not fully exploit the speedup that multiple heads can offer; MEDUSA-2 goes further. Because the backbone is trained alongside the MEDUSA heads, the heads' distribution stays consistent with the original model, which mitigates distribution drift and significantly improves head accuracy. MEDUSA-2 suits settings with ample compute, or SFT performed starting from a base model.

Additionally, if the original model’s SFT dataset is available, training can proceed directly. If the original model’s SFT data is unavailable, or if the original model was trained using RLHF, training data for the MEDUSA head can be obtained through self-distillation.

4.1 MEDUSA-1

MEDUSA-1 freezes the parameters of the original model and trains only the newly added decoding heads. Training the heads mainly involves the cross-entropy loss between the predictions of the MEDUSA heads and the ground truth. Specifically, the k-th head at position t is compared with the ground truth $y_{t+k+1}$ at position t+k+1. Writing the k-th head's predicted distribution as $p_t^{(k)}$, the training loss for the k-th head is:

$$\mathcal{L}_k = -\log p_t^{(k)}(y_{t+k+1})$$

When k is large, $\mathcal{L}_k$ also becomes larger, because predictions from later heads are more uncertain. To balance the loss magnitudes across heads, exponentially decaying weights $\lambda_k$ are added, and the final Medusa loss over K heads is:

$$\mathcal{L}_{\text{MEDUSA-1}} = \sum_{k=1}^{K} -\lambda_k \log p_t^{(k)}(y_{t+k+1})$$

Here $\lambda_k$ is the scaling factor for the k-th head, a series of hyperparameters. Because a larger k means a harder prediction task and therefore a larger loss, the scaling factor keeps later heads from dominating training. In practice, $\lambda_k$ is set to the k-th power of a constant such as 0.8.
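The weighted MEDUSA-1 loss at a single position t can be sketched as follows. medusa1_loss is a hypothetical helper that takes, for each head k, the probability the head assigns to its ground-truth token at position t+k+1; the decay constant is exposed as a parameter, with 0.8 used here as an assumed default. A real training loop would average this over positions and batch elements with a cross-entropy routine rather than calling math.log per token.

```python
import math
from typing import List

def medusa1_loss(target_probs: List[float], decay: float = 0.8) -> float:
    """Weighted MEDUSA-1 loss at one position t.
    target_probs[k - 1] is p_t^(k)(y_{t+k+1}): the probability head k assigns to
    its ground-truth token. Head k is weighted by decay ** k (assumed default 0.8)."""
    return sum(decay ** k * -math.log(p) for k, p in enumerate(target_probs, start=1))

# Three heads; later heads are less confident, but their terms are down-weighted.
print(medusa1_loss([0.6, 0.4, 0.2]))
```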

4.2 MEDUSA-2

To further improve the accuracy of Medusa Heads, MEDUSA-2 trains the original model together with multiple decoder heads, resulting in higher accuracy and acceleration rates for each decoder head. However, maintaining the output quality of the original model requires some special training techniques. Medusa-2 uses the following three strategies to achieve this goal.

Combined loss

To preserve the backbone's next-token prediction ability, the backbone's own cross-entropy loss $\mathcal{L}_{\text{LM}}$ (the loss of the original model's LM head) is added to the MEDUSA-1 loss $\mathcal{L}_{\text{MEDUSA-1}}$ from 4.1. A weighting factor $\lambda_0$ balances the losses of the backbone and the Medusa heads:

$$\mathcal{L}_{\text{MEDUSA-2}} = \mathcal{L}_{\text{LM}} + \lambda_0 \mathcal{L}_{\text{MEDUSA-1}}$$

In practice, $\lambda_0$ is set to different values for direct training and for self-distillation.

Differential learning rates

The original model is already trained, while the MEDUSA heads still need substantial training, so using the same learning rate for both is not appropriate. The new heads can use a larger learning rate and the backbone a smaller one, so that the heads converge quickly while the backbone keeps its capabilities. In practice, the two learning rates differ by a factor of 4, for example 2e-3 for the heads and 5e-4 for the backbone.

Heads warmup

The newly added heads have a large loss at the start of training, which produces large gradients that can damage the original model's capabilities. To address this, training can be split into two stages: in the first stage, only the heads are trained under the MEDUSA-1 strategy; in the second stage, MEDUSA-2 joint training is performed. This is essentially equivalent to gradually increasing $\lambda_0$ during training.
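The two-stage warmup can be expressed as a per-step schedule for the weights in $\mathcal{L}_{\text{MEDUSA-2}} = \mathcal{L}_{\text{LM}} + \lambda_0 \mathcal{L}_{\text{MEDUSA-1}}$. This is a hypothetical sketch: the step counts, the function names, and the linear ramp of $\lambda_0$ in stage 2 are assumptions chosen to illustrate "gradually increasing $\lambda_0$", not a schedule from the paper; 0.2 is used as the target value.

```python
def warmup_schedule(step, stage1_steps, ramp_steps, lambda0=0.2):
    """Return (freeze_backbone, lm_weight, head_weight) for a training step.

    Hypothetical schedule: stage 1 trains only the heads (MEDUSA-1);
    stage 2 trains jointly (MEDUSA-2) while lambda_0 ramps up linearly.
    """
    if step < stage1_steps:
        # Stage 1: backbone frozen, only the Medusa heads' loss is optimized
        return True, 0.0, 1.0
    # Stage 2: unfreeze the backbone and ramp lambda_0 from near 0 up to its target
    progress = min(1.0, (step - stage1_steps + 1) / ramp_steps)
    return False, 1.0, lambda0 * progress

def total_loss(lm_loss, medusa1_loss, lm_weight, head_weight):
    # lm_weight * L_LM + head_weight * L_MEDUSA-1
    return lm_weight * lm_loss + head_weight * medusa1_loss

for step in (0, 999, 1000, 1499, 5000):
    freeze, w_lm, w_head = warmup_schedule(step, stage1_steps=1000, ramp_steps=500)
    print(step, freeze, w_lm, round(w_head, 4))
```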

4.3 Code

Let’s take a look at how a pre-trained LLM can be adapted to MEDUSA, which involves the following steps:

  • Add decoding heads: add several MEDUSA heads after the last hidden layer of the LLM.
  • Initialize the heads: use random initialization, or initialize them from the parameters of the original model's LM head, which can speed up training.
  • Choose a training strategy: select MEDUSA-1 or MEDUSA-2 according to the actual situation.
  • Prepare training data: reuse the original model's training data, or generate training data by self-distillation.
  • Train: according to the selected strategy and data, train only the MEDUSA heads or fine-tune the LLM jointly with them.
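The first two steps (adding heads and initializing them) can be sketched in PyTorch. Each head follows the single-layer residual design described for Medusa, softmax(W2 (SiLU(W1 h) + h)), with W1 zero-initialized and W2 copied from the LM head. The class names, the lm_head argument, zeroing the bias of W1, and the toy sizes are assumptions for illustration. Because the residual block starts as an identity map, a freshly initialized head reproduces the backbone's next-token logits.

```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """Computes h + SiLU(W1 h + b); W1 and b start at zero, so the block starts as identity."""
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)  # assumption: bias also zeroed for an exact identity
        self.act = nn.SiLU()

    def forward(self, x):
        return x + self.act(self.linear(x))


class MedusaHeads(nn.Module):
    """K extra heads on the backbone's last hidden state (hypothetical sketch)."""
    def __init__(self, hidden_size, vocab_size, num_heads=4, num_layers=1, lm_head=None):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(
                *[ResBlock(hidden_size) for _ in range(num_layers)],
                nn.Linear(hidden_size, vocab_size, bias=False),  # W2: projection to vocab
            )
            for _ in range(num_heads)
        )
        if lm_head is not None:
            # Initialize each W2 from the backbone's LM head instead of randomly
            with torch.no_grad():
                for head in self.heads:
                    head[-1].weight.copy_(lm_head.weight)

    def forward(self, hidden_states):
        # [batch, seq_len, hidden] -> [num_heads, batch, seq_len, vocab]
        return torch.stack([head(hidden_states) for head in self.heads], dim=0)


lm_head = nn.Linear(64, 1000, bias=False)  # stand-in for the backbone's LM head
heads = MedusaHeads(64, 1000, num_heads=4, lm_head=lm_head)
h = torch.randn(2, 7, 64)                  # fake last-layer hidden states
print(heads(h).shape)                      # torch.Size([4, 2, 7, 1000])
```

The stacked [num_heads, batch, seq_len, vocab] output matches the logits[i, ...] indexing used in the training code below.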

The training code is shown below. The new heads are trained first; because each head is trained against labels shifted by a different offset, the top-k predictions of the heads can later be combined into candidate continuations.

from torch.nn import CrossEntropyLoss
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother

# Label id ignored by CrossEntropyLoss (-100)
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

# Customized for training Medusa heads
class CustomizedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Compute the training loss for the model.

        Args:
            model (torch.nn.Module): The model for which to compute the loss.
            inputs (dict): The input data, including input IDs, attention mask, and labels.
            return_outputs (bool): Whether to return model outputs along with the loss.
            **kwargs: Extra arguments passed by newer Trainer versions (e.g. num_items_in_batch).

        Returns:
            Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs.
        """
        # DDP will give us model.module
        if hasattr(model, "module"):
            medusa = model.module.medusa
        else:
            medusa = model.medusa

        # logits: [medusa, batch, seq_len, vocab], one slice per Medusa head
        logits = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
        labels = inputs["labels"]
        # Shift so that tokens < n predict n
        loss = 0
        loss_fct = CrossEntropyLoss()
        log = {}
        for i in range(medusa):
            medusa_logits = logits[i, :, : -(2 + i)].contiguous()
            # The LM head predicts the token 1 position ahead, so Medusa head i
            # predicts the token 2 + i positions ahead: shift the labels by 2 + i.
            medusa_labels = labels[..., 2 + i :].contiguous()
            medusa_logits = medusa_logits.view(-1, logits.shape[-1])
            medusa_labels = medusa_labels.view(-1)
            medusa_labels = medusa_labels.to(medusa_logits.device)
            loss_i = loss_fct(medusa_logits, medusa_labels)
            loss += loss_i
            not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
            medusa_labels = medusa_labels[not_ignore]

            # Add top-k accuracy (only top-1 here) for logging
            for k in range(1, 2):
                _, topk = medusa_logits.topk(k, dim=-1)
                topk = topk[not_ignore]
                correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
                log[f"medusa{i}_top{k}"] = correct.float().mean().item()

            log[f"medusa{i}_loss"] = loss_i.item()

        self.log(log)
        return (loss, logits) if return_outputs else loss
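The label offsets in compute_loss are easy to get wrong, so here is a framework-free sketch of the same slicing on a toy token list. medusa_alignment is a hypothetical helper: it pairs each position whose hidden state feeds head i with the label that head is trained on, mirroring logits[i, :, :-(2 + i)] against labels[..., 2 + i:].

```python
def medusa_alignment(tokens, head):
    """Return (position, target) pairs for Medusa head `head` (0-based).

    The hidden state at position t has seen tokens[0..t]. The LM head predicts
    tokens[t + 1]; Medusa head i predicts tokens[t + 2 + i].
    """
    offset = 2 + head
    positions = list(range(len(tokens)))[:-offset]  # like logits[i, :, :-(2 + i)]
    targets = tokens[offset:]                        # like labels[..., 2 + i:]
    return list(zip(positions, targets))

tokens = ["The", "cat", "sat", "on", "the", "mat"]
for head in range(3):
    print(head, medusa_alignment(tokens, head))
```

For head 0, position 0 ("The") is paired with "sat", two tokens ahead, whereas the LM head would be trained to predict "cat" there. Each further head loses one more position at the end of the sequence.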

0x05 Decoding

5.1 Example

The official GitHub repository provides the following decoding loop, which also records the wall time of each stage.

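@contextmanager
def timed(wall_times, key):
    # Wait for previously queued GPU work so it is not billed to this stage
    torch.cuda.synchronize()
    start = time.time()
    yield
    # Wait for the GPU work launched inside the block before stopping the clock
    torch.cuda.synchronize()
    end = time.time()
    elapsed_time = end - start
    wall_times[key].append(elapsed_time)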

def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
    wall_times = {'medusa': [], 'tree': [], 'posterior': [], 'update': [], 'init': []}
    
    with timed(wall_times, 'init'):
        if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices:
            # Load the cached medusa buffer
            medusa_buffers = model.medusa_buffers
        else:
            # Initialize the medusa buffer
            medusa_buffers = generate_medusa_buffers(
                medusa_choices, device=model.base_model.device
            )
        model.medusa_buffers = medusa_buffers
        model.medusa_choices = medusa_choices

        # Initialize the past key and value states
        if hasattr(model, "past_key_values"):
            past_key_values = model.past_key_values
            past_key_values_data = model.past_key_values_data
            current_length_data = model.current_length_data
            # Reset the past key and value states
            current_length_data.zero_()
        else:
            (
                past_key_values,
                past_key_values_data,
                current_length_data,
            ) = initialize_past_key_values(model.base_model)
            model.past_key_values = past_key_values
            model.past_key_values_data = past_key_values_data
            model.current_length_data = current_length_data

        input_len = input_ids.shape[1]
        reset_medusa_mode(model)
        medusa_logits, logits = initialize_medusa(
                input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
        )
    new_token = 0

    for idx in range(max_steps): 
        with timed(wall_times, 'medusa'):
            candidates, tree_candidates = generate_candidates(
                    medusa_logits,
                    logits,
                    medusa_buffers["tree_indices"],
                    medusa_buffers["retrieve_indices"],
                )

        with timed(wall_times, 'tree'):
            medusa_logits, logits, outputs = tree_decoding(
                    model,
                    tree_candidates,
                    past_key_values,
                    medusa_buffers["medusa_position_ids"],
                    input_ids,
                    medusa_buffers["retrieve_indices"],
                )

        with timed(wall_times, 'posterior'):
            best_candidate, accept_length = evaluate_posterior(
                    logits, candidates, temperature, posterior_threshold, posterior_alpha
                )
        
        with timed(wall_times, 'update'):
            input_ids, logits, medusa_logits, new_token = update_inference_inputs(
                    input_ids,
                    candidates,
                    best_candidate,
                    accept_length,
                    medusa_buffers["retrieve_indices"],
                    outputs,
                    logits,
                    medusa_logits,
                    new_token,
                    past_key_values_data,
                    current_length_data,
                )

        if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
            break

    return input_ids, new_token, idx, wall_times
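Each iteration of the loop above expands the heads' top-k guesses along the sparse tree given by medusa_choices, verifies every tree node in one tree_decoding forward pass under a tree attention mask, and then reads candidate paths back out. The sketch below reproduces that bookkeeping in plain Python. It illustrates the idea behind generate_medusa_buffers and generate_candidates and is not the repository's tensor implementation; the function names are hypothetical, and in the real mask every tree token additionally attends to the whole cached prefix.

```python
def build_tree(medusa_choices):
    """Sort tree nodes and build the tree attention mask among them.

    A node is a tuple of top-k ranks, e.g. (0, 1) = head 0's top-1 token followed
    by head 1's top-2 token. Node 0 is the root: the base LM head's token.
    mask[a][b] == 1 iff node b is node a or one of its ancestors.
    """
    nodes = [()] + sorted({tuple(c) for c in medusa_choices}, key=lambda c: (len(c), c))
    mask = [[1 if b == a[: len(b)] else 0 for b in nodes] for a in nodes]
    return nodes, mask

def flatten_tree(base_token, head_topk, nodes):
    # All tree tokens in node order: the sequence fed to one forward pass
    return [base_token] + [head_topk[len(n) - 1][n[-1]] for n in nodes[1:]]

def candidate_paths(base_token, head_topk, nodes):
    # One candidate per leaf: a node that is not a prefix of any longer node
    leaves = [n for n in nodes
              if not any(len(o) > len(n) and o[: len(n)] == n for o in nodes)]
    return [[base_token] + [head_topk[d][r] for d, r in enumerate(leaf)]
            for leaf in leaves]

choices = [[0], [1], [0, 0], [0, 1], [0, 0, 0]]
head_topk = [[10, 11], [20, 21], [30, 31]]  # top-2 token ids of heads 0, 1, 2
nodes, mask = build_tree(choices)
print(flatten_tree(7, head_topk, nodes))     # [7, 10, 11, 20, 21, 30]
print(candidate_paths(7, head_topk, nodes))  # [[7, 11], [7, 10, 21], [7, 10, 20, 30]]
for row in mask:
    print(row)
```

The mask lets six tree tokens be scored in a single pass while each one only sees its own ancestors, which is what allows several candidate continuations to be verified at once.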

The following example shows how to call medusa_forward.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3" # define GPU id, remove if you want to use all GPUs available
import torch
from tqdm import tqdm
import time
from contextlib import contextmanager
import numpy as np
from medusa.model.modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
from medusa.model.medusa_model import MedusaModel
from medusa.model.kv_cache import *
from medusa.model.utils import *
from medusa.model.medusa_choices import *
import transformers
from huggingface_hub import hf_hub_download

# Load the model
model_name = 'FasterDecoding/medusa-vicuna-7b-v1.3'
model = MedusaModel.from_pretrained(
    model_name,
    medusa_num_heads = 4,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)
tokenizer = model.get_tokenizer()

medusa_choices = mc_sim_7b_63

# Set inference parameters
temperature = 0.
posterior_threshold = 0.09
posterior_alpha = 0.3

# Set the prompt
prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hi, could you share a tale about a charming llama that grows Medusa-like hair and starts its own coffee shop? ASSISTANT:"

# Run inference
with torch.inference_mode():
    input_ids = tokenizer([prompt]).input_ids
    output_ids, new_token, idx, wall_time = medusa_forward(
                    torch.as_tensor(input_ids).cuda(),
                    model,
                    tokenizer,
                    medusa_choices,
                    temperature,
                    posterior_threshold,
                    posterior_alpha,
                )
    output_ids = output_ids[0][len(input_ids[0]) :]
    print("Output length:", output_ids.size(-1))
    # idx is 0-based, so idx + 1 decoding steps were run
    print("Compression ratio:", new_token / (idx + 1))
    
# Decode the generated tokens
output = tokenizer.decode(
                    output_ids,
                    spaces_between_special_tokens=False,
                )
print(output)
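The posterior_threshold (0.09) and posterior_alpha (0.3) arguments above control typical acceptance in evaluate_posterior: a speculated token is kept if the base model's probability for it exceeds min(threshold, alpha * exp(-H)), where H is the entropy of the base model's distribution at that position, and acceptance stops at the first rejected token. With temperature = 0 the check reduces to greedy matching against the argmax. The sketch below is a simplified, hypothetical pure-Python version: it takes probabilities directly instead of temperature-scaled logits, and it breaks ties between equally long candidates by taking the first one.

```python
import math

def accept_length(candidate, dists, threshold=0.09, alpha=0.3, greedy=False):
    """Number of speculated tokens in one candidate path that are accepted.

    candidate[0] is the token already chosen from the base LM head; candidate[1:]
    are Medusa guesses. dists[j] is the base model's next-token distribution
    after candidate[: j + 1], all obtained from one tree-attention forward pass.
    """
    accepted = 0
    for j in range(len(candidate) - 1):
        p, tok = dists[j], candidate[j + 1]
        if greedy:
            # temperature == 0: the guess must equal the base model's argmax
            ok = tok == max(range(len(p)), key=p.__getitem__)
        else:
            # Typical acceptance: p(tok) > min(threshold, alpha * exp(-entropy))
            entropy = -sum(x * math.log(x) for x in p if x > 0)
            ok = p[tok] > min(threshold, alpha * math.exp(-entropy))
        if not ok:
            break  # only a prefix of the candidate can be accepted
        accepted += 1
    return accepted

def pick_best(candidates, all_dists, **kwargs):
    lengths = [accept_length(c, d, **kwargs) for c, d in zip(candidates, all_dists)]
    best = max(range(len(lengths)), key=lengths.__getitem__)  # ties: first candidate
    return best, lengths[best]

# Two candidates over a 3-token vocabulary
cand_a, dists_a = [0, 1, 2], [[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]]
cand_b, dists_b = [0, 2, 0], [[0.05, 0.9, 0.05], [0.5, 0.3, 0.2]]
print(pick_best([cand_b, cand_a], [dists_b, dists_a]))               # (1, 2)
print(pick_best([cand_b, cand_a], [dists_b, dists_a], greedy=True))  # (1, 1)
```

Candidate a's second guess has probability 0.1, which is not the argmax, so greedy matching rejects it; typical acceptance keeps it because 0.1 exceeds min(0.09, 0.3 * exp(-H)). In each step the loop then appends accept_length + 1 tokens, counting the base LM head's token.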

5.2 Computational and Space Complexity

The figure below shows the computational and space complexity of the prefill, decoding, and MEDUSA decoding stages.

  • b is the batch size.
  • s is the sequence length.
  • h is the hidden dimension.
  • i is the intermediate dimension.
  • n is the number of attention heads.
  • d is the head dimension.
  • q is the candidate length of MEDUSA.

[Figure: computational and space complexity of the prefill, decoding, and MEDUSA decoding stages]

Additionally, the diagram below illustrates the operation flow of Medusa. Without operator fusion or a tiling strategy, operators such as DCM (Dense Causal Mask) and Softmax each cause a large amount of I/O between GPU memory and the on-chip cache.

[Figure: operation flow of Medusa]

0xFF Reference

SpecInfer: Accelerating Generative Large Language Model Serving with Speculative Inference and Token Tree Verification

Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

LLM Speculation Decoding & Medusa Enables AI Chatter

[Tearing Apart LLM-Medusa] Parallel Decoding Paradigm: Medusa is Here, Get Out of the Way!! (by Xiaodonggua AIGC)

Fang Jiarui: A Clever Technique for Large Model Inference—Speculative Decoding

[Transformer 101 Series] In-depth Study of LLM Speculative Sampling (Part 1) (aaronxic)

https://github.com/FasterDecoding/Medusa/blob/main/notebooks/medusa_introduction.ipynb

Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads, Jan 2024, Princeton University. Proceedings of the ICML 2024.

[2401.10774] Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

GitHub - FasterDecoding/Medusa: Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads

LLM Inference Acceleration: Medusa - An Inheritance and Development of Blockwise Parallel Decoding (by Fang Jiarui)

Fang Jiarui: The Renaissance of LLM Inference Acceleration: Noam Shazeer and Blockwise Parallel Decoding

A 10,000-word overview of 10+ LLM speculative sampling inference acceleration solutions (AI Casual Discussion)

[2401.07851] Unlocking Efficiency in Large Language Model Inference: A Comprehensive Survey of Speculative Decoding

A Quick Look at Medusa and Lookahead Speculative Inference (A-Yuan)

Open Source Progress | Medusa: Using multi-head decoding, inference speed for large models is improved by more than 2 times.

arXiv:1811.03115: UC Berkeley, Google Brain, Blockwise Parallel Decoding for Deep Autoregressive Models.

arXiv:2211.17192: Google Research, Fast Inference from Transformers via Speculative Decoding

arXiv:2202.00666: ETH Zürich, University of Cambridge, Locally Typical Sampling

arXiv:2106.05234: Dalian University of Technology, Princeton University, Peking University, Microsoft Research Asia, Do Transformers Really Perform Bad for Graph Representation?

A 30,000-word detailed analysis of Tsinghua University's latest review work: A review of efficient inference using large models.

Accelerating Large Model Inference - MEDUSA Linsight

LLM Inference Acceleration - Medusa (uuuuu)

Blockwise Parallel Decoding: Paper Interpretation and AI Talk

https://sites.google.com/view/medusa-llm

https://github.com/FasterDecoding/Medusa

Clover: Speculative Sampling That Outperforms Medusa (AI Casual Discussion)

[2405.00263] Clover: Regressive Lightweight Speculative Decoding with Sequential Knowledge

Hydra: Sequentially-Dependent Draft Heads for Medusa Decoding

[Paper Interpretation] Medusa: Using Multiple Decoders to Predict Multiple Subsequent Tokens in Parallel (tomsheep)

LLM Inference Acceleration (Part 3): Medusa Speculative Sampling (Yueda)