Faster Text Generation with Self-Speculative Decoding
Self-speculative decoding is a novel approach to text generation that combines the strengths of speculative decoding with early exiting from a large language model (LLM). This method allows for efficient generation by using the same model's early layers for drafting tokens, and later layers for verification.
Traditional Speculative Decoding
Traditional speculative decoding uses two models: a smaller one (the draft model) that generates a sequence of draft tokens, and a larger one (the verification model) that checks them. The smaller model does most of the generation work, while the larger model verifies and corrects the results. This speeds up text generation because the larger model verifies an entire draft sequence at once instead of generating one token at a time.
Self-Speculative Decoding
In self-speculative decoding, the authors build on this concept but use the early layers of a large model to generate draft tokens that are then verified by the model's deeper layers. This "self" aspect of speculative decoding, which requires specific training, allows the model to perform both drafting and verification. This, in turn, improves speed and reduces computational costs compared to traditional speculative decoding.
Usage with Transformers
To enable early-exit self-speculative decoding in the 🤗 transformers library, we just need to add the assistant_early_exit argument to the generate() function.
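```python
import transformers

early_exit_layer = 4  # draft tokens by exiting after this many layers
prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama2-7B"  # a LayerSkip-trained checkpoint

tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
# The same model drafts with its first `early_exit_layer` layers and verifies with all layers.
outputs = model.generate(**inputs, assistant_early_exit=early_exit_layer)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```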
Note: While the assistant_early_exit argument can potentially enable early-exit self-speculative decoding for any decoder-only transformer, the logits from the intermediate layers cannot be unembedded (the process of decoding through the LM head, described later in the blog post) unless the model is specifically trained for that. You will also only obtain speedups for a checkpoint that was trained in a way that increases the accuracy of its earlier layers.
Benchmarking
We ran an extensive list of benchmarks to measure the speedup of LayerSkip's self-speculative decoding with respect to autoregressive decoding on various models. We also compare self-speculative decoding (based on early exiting) with standard speculative decoding techniques. To reproduce the results, you may find the code here and the command to run each experiment in this spreadsheet.
Training Modifications: Layer Dropout and Early Exit Loss
In the training phase, we introduce layer dropout, which allows the model to skip certain layers during training. The dropout rate increases progressively in deeper layers, making the model less reliant on its later layers, as well as enhancing the model's generalization and speeding up training.
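To make the idea concrete, here is a minimal pure-Python sketch of layer dropout with a per-layer skip probability that grows with depth. The linear schedule in the hypothetical `layer_dropout_rates` helper and the `p_max` parameter are illustrative choices, not the schedule used to train the LayerSkip checkpoints. Each entry of `layers` stands in for a full residual transformer block.

```python
import random
from typing import Callable, List, Optional


def layer_dropout_rates(num_layers: int, p_max: float) -> List[float]:
    """Hypothetical linear schedule: 0.0 for the first layer, p_max for the last."""
    if num_layers == 1:
        return [p_max]
    return [p_max * i / (num_layers - 1) for i in range(num_layers)]


def forward_with_layer_dropout(
    x: float,
    layers: List[Callable[[float], float]],
    p_max: float,
    training: bool = True,
    rng: Optional[random.Random] = None,
) -> float:
    """Apply `layers` in order, skipping deeper layers more often during training."""
    rng = rng or random.Random()
    rates = layer_dropout_rates(len(layers), p_max)
    for layer, p in zip(layers, rates):
        if training and rng.random() < p:
            continue  # skipped: the input passes through unchanged, like a residual identity
        x = layer(x)
    return x
```

At inference time (`training=False`) every layer runs. During training, the later layers are skipped most often, so the model cannot rely on them to always be present.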
In addition to layer dropout, early exit loss is applied to ensure the LM head learns to unembed different layers. The total loss function for training the model with early exits is given by a summation of normalized loss from each exit (intermediate layers). This technique enables efficient training by distributing the learning task across all layers.
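As a rough illustration of this loss, the sketch below assumes that every exit produces logits through the same shared LM head, and that the per-exit cross-entropy losses are combined with normalized weights. The `exit_scales` weights are hypothetical inputs. The actual per-layer scaling and curriculum used for LayerSkip are not reproduced here.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """-log softmax(logits)[target], computed with the log-sum-exp trick for stability."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[target]


def early_exit_loss(
    per_exit_logits: List[Sequence[float]],
    target: int,
    exit_scales: Sequence[float],
) -> float:
    """Normalized weighted sum of the losses at each exit.

    per_exit_logits[l] stands for lm_head(hidden state after layer l); the same
    LM head is assumed to be shared by every exit.
    """
    if len(per_exit_logits) != len(exit_scales):
        raise ValueError("need one scale per exit")
    total = sum(exit_scales)
    return sum(
        (scale / total) * cross_entropy(logits, target)
        for logits, scale in zip(per_exit_logits, exit_scales)
    )
```

Because the weights are normalized, adding more exits does not inflate the overall loss scale. Every exit still contributes a gradient to the shared LM head.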
Self-Drafting and Self-Verification
Once training is complete, we can apply self-speculative decoding during inference. The process begins with self-drafting, where tokens are generated by exiting early from an intermediate layer. The number of speculative tokens defines how many draft tokens are produced during this stage, and the layer we exit at defines how large and accurate the draft stage is.
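The drafting loop can be sketched with a toy stand-in for the model. Here the hidden state is a single number derived from the last token, and `lm_head` maps it to a token id. `next_token`, `self_draft`, and the toy layers are hypothetical helpers for illustration, not the 🤗 transformers implementation. The point is that drafting runs only the first `exit_layer` layers before unembedding with the shared LM head.

```python
from typing import Callable, List, Optional, Sequence

Layer = Callable[[int], int]


def next_token(
    tokens: Sequence[int],
    layers: Sequence[Layer],
    lm_head: Callable[[int], int],
    exit_layer: Optional[int] = None,
) -> int:
    """Greedy next token; exit_layer=E runs layers 0..E-1, None runs the full model."""
    h = tokens[-1]  # toy "embedding": the hidden state depends only on the last token
    for layer in layers[:exit_layer]:
        h = layer(h)
    return lm_head(h)  # the same LM head unembeds early and final hidden states


def self_draft(
    prefix: Sequence[int],
    layers: Sequence[Layer],
    lm_head: Callable[[int], int],
    exit_layer: int,
    num_speculative_tokens: int,
) -> List[int]:
    """Autoregressively draft tokens using only the first `exit_layer` layers."""
    tokens = list(prefix)
    draft: List[int] = []
    for _ in range(num_speculative_tokens):
        tok = next_token(tokens, layers, lm_head, exit_layer)
        draft.append(tok)
        tokens.append(tok)  # a real implementation would also extend the KV cache here
    return draft
```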
The next stage is self-verification, where the full model is used to verify the draft tokens. The verification stage reuses the portion of the cache computed during drafting. If the draft tokens match the verified tokens, they are added to the final output. This makes better use of the memory bandwidth in our system: generating a sequence of tokens one at a time with the full model is much more expensive than verifying a draft in a single pass, as long as several of the draft tokens match.
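Below is a minimal sketch of greedy verification. It assumes a hypothetical `full_model_predictions` callable that runs the full model once over the prefix plus the draft and returns its greedy next-token prediction at every position. Draft tokens are accepted up to the first disagreement. At that point the full model's own token is used instead, and if every draft token matches, the same pass yields one extra token. Sampling-based acceptance and cache handling are left out.

```python
from typing import Callable, List, Sequence


def self_verify(
    prefix: List[int],
    draft: List[int],
    full_model_predictions: Callable[[Sequence[int]], List[int]],
) -> List[int]:
    """Return the tokens accepted in this round (greedy acceptance).

    preds[i] is the full model's greedy prediction for the token that follows
    tokens[: i + 1]; all positions come from a single forward pass.
    """
    if not prefix:
        raise ValueError("prefix must contain at least one token")
    tokens = prefix + draft
    preds = full_model_predictions(tokens)
    accepted: List[int] = []
    for i, tok in enumerate(draft):
        expected = preds[len(prefix) + i - 1]
        if tok != expected:
            accepted.append(expected)  # replace the first mismatch with the full model's token
            return accepted
        accepted.append(tok)
    accepted.append(preds[len(tokens) - 1])  # bonus token when the whole draft is accepted
    return accepted
```

Every round therefore produces at least one token, so verification never falls behind plain autoregressive decoding in token count. The speedup comes from rounds where several draft tokens are accepted at once.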
Optimizations: Shared Weights, Shared KV Cache, and Shared Compute
Self-speculative decoding benefits significantly from cache reuse, particularly the KV cache, which stores key-value pairs computed during the drafting stage. This cache allows the model to skip redundant calculations, as both the draft and verification stages use the same early layers.
Compared to traditional two-model speculative decoding, early-exit self-speculative decoding can benefit from the following savings:
- Shared Weights: Reuses the weights from the first $E$ layers for both drafting and verification.
- Shared KV Cache: Reuses key-value pairs from the first $E$ layers for both drafting and verification.
- Shared Compute: Reuses the compute of the first $E$ layers by using an Exit Query Cache that saves only the query vector of the exit layer $E-1$, so that the verification process won't need to compute layers $0$ to $E-1$.
The combination of KV and exit query caches, known as the KVQ cache, reduces memory overhead and improves inference latency.
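To see what the shared compute saves, the toy count below tallies (layer, token) evaluations for one draft-and-verify round. The `layer_evals_per_round` helper is hypothetical, and the accounting is simplified: it assumes verification processes exactly the draft tokens and ignores how attention cost grows with context length.

```python
def layer_evals_per_round(
    num_layers: int, exit_layer: int, num_draft: int, exit_query_cache: bool = True
) -> int:
    """Toy count of (layer, token) evaluations in one draft-and-verify round."""
    if not 0 < exit_layer <= num_layers:
        raise ValueError("exit_layer must be in 1..num_layers")
    # Drafting: each draft token runs layers 0..exit_layer-1, whose KV entries are cached.
    drafting = num_draft * exit_layer
    # Verification: with the exit query cache the draft tokens resume at layer
    # exit_layer; without it they would be recomputed from layer 0.
    first_verified_layer = exit_layer if exit_query_cache else 0
    verification = num_draft * (num_layers - first_verified_layer)
    return drafting + verification
```

For a 32-layer model exiting at layer 8 with 4 draft tokens, this gives 128 evaluations with the exit query cache versus 160 without it. The difference is exactly the 4 × 8 evaluations already done during drafting.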
How Early Can We Exit?
The early exit layer of the draft stage is a hyperparameter that we can tune or modify during inference:
- The earlier we exit, the faster draft tokens are generated, but the less accurate they are.
- The later we exit, the more accurate the draft tokens are, but the slower they are to generate.
We wrote a script to sweep across different early exit layers and measure the tokens per second on A100 GPUs. In the tables below, we plot the tokens per second versus the early exit layer for different Llama models, for both LayerSkip and baseline checkpoints (you can view the full logs here).
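A sweep of this kind can be sketched independently of any particular model. The hypothetical `sweep_early_exit` helper below times a caller-supplied `generate_fn(layer)`, which runs one generation and returns the number of new tokens. The clock is injectable so the timing logic can be checked without a GPU. This is an illustrative sketch, not the benchmarking script linked above.

```python
import time
from typing import Callable, Dict, Iterable


def sweep_early_exit(
    generate_fn: Callable[[int], int],
    exit_layers: Iterable[int],
    clock: Callable[[], float] = time.perf_counter,
) -> Dict[int, float]:
    """Measure tokens per second for each candidate early-exit layer.

    generate_fn(layer) runs one generation with that exit layer and returns
    the number of newly generated tokens.
    """
    results: Dict[int, float] = {}
    for layer in exit_layers:
        start = clock()
        num_new_tokens = generate_fn(layer)
        elapsed = clock() - start
        results[layer] = num_new_tokens / elapsed
    return results
```

With the transformers snippet from the usage section, `generate_fn` could call `model.generate(**inputs, assistant_early_exit=layer)` with a fixed `max_new_tokens` and return `outputs.shape[-1] - inputs["input_ids"].shape[-1]`. A warm-up call before the sweep, plus `torch.cuda.synchronize()` before returning, makes GPU timings more reliable.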
Observations
For the baseline checkpoints that have not been pretrained or continually pretrained with the LayerSkip training recipe, early exit self-speculative decoding is slower than autoregressive decoding. This is because during training of most LLMs, earlier layers are not motivated to learn to predict the output, and hence generating tokens using earlier layers will have a very low acceptance rate.
On the other hand, for the Llama checkpoints that were continually pretrained with the LayerSkip training recipe, early exit self-speculative decoding achieves a speedup over autoregressive decoding for at least a subset of the layers. For most models, except Llama3.2 1B, we notice a regular pattern as we traverse the layers: speedup starts low for the first few layers, increases gradually to a sweet spot, and then decreases again.
The early exit layer sweet spot is when we have the optimal tradeoff between high accuracy of predictions and low overhead of generating tokens. This sweet spot depends on each model, and may also depend on the prompt or domain of the prompt.
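The sweet spot can be illustrated with a simple cost model. The sketch below assumes each draft token is accepted independently with probability alpha. In that case, a round of `num_draft` drafts yields (1 - alpha^(num_draft+1)) / (1 - alpha) tokens on average, which is the expected-length formula commonly used to analyze speculative decoding. Drafting is charged `exit_layer / num_layers` of a full forward pass per token, and verification is charged one pass over the remaining layers. The acceptance curve in `hypothetical_alpha` is made up for illustration and is not measured data for any model.

```python
from typing import Callable


def expected_tokens_per_round(alpha: float, num_draft: int) -> float:
    """Expected tokens per round if each draft token is accepted independently with probability alpha.

    Includes the token the full model contributes at the first rejection
    (or after a fully accepted draft).
    """
    if alpha >= 1.0:
        return float(num_draft + 1)
    return (1.0 - alpha ** (num_draft + 1)) / (1.0 - alpha)


def toy_speedup(
    exit_layer: int,
    num_layers: int,
    num_draft: int,
    acceptance_rate: Callable[[int], float],
) -> float:
    """Speedup over autoregressive decoding, in units of one full forward pass per token."""
    alpha = acceptance_rate(exit_layer)
    # Drafting runs exit_layer layers per draft token, one token at a time.
    draft_cost = num_draft * exit_layer / num_layers
    # Verification: one (assumed memory-bound) pass over the remaining layers.
    verify_cost = (num_layers - exit_layer) / num_layers
    return expected_tokens_per_round(alpha, num_draft) / (draft_cost + verify_cost)


def hypothetical_alpha(exit_layer: int) -> float:
    """Made-up acceptance curve: deeper exits agree more often with the full model."""
    return 1.0 - 0.8 ** exit_layer


speedups = {e: toy_speedup(e, 32, 4, hypothetical_alpha) for e in range(1, 32)}
best_exit = max(speedups, key=speedups.get)
```

Under these assumptions, an acceptance rate near zero (as for the baseline checkpoints above) gives a speedup below 1, and the best exit layer falls between the first and last layers. Changing the hypothetical curve moves the sweet spot, which is consistent with its dependence on the model and the prompt.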
Conclusion
LayerSkip leverages the synergy between early exit, layer dropout, and cache reuse to create a fast and efficient text generation pipeline. By training the model to unembed outputs from different layers and optimizing the verification process with caches, this approach strikes a balance between speed and accuracy. As a result, it significantly improves inference speed in large language models while maintaining high-quality outputs. It also uses less memory than traditional speculative decoding techniques, because a single model serves as both the draft and the verification model.
Self-speculation is an exciting field where the same LLM can create draft tokens and fix itself. Other self-speculation approaches include:
- Draft & Verify: where the draft stage involves skipping pre-determined attention and feed forward layers.
- MagicDec: where the draft stage uses a subset of the KV cache, which is useful for long context inputs.
- Jacobi Decoding and Lookahead Decoding: where the draft stage is a series of "guess tokens" that are either random or obtained from an n-gram lookup table.