vLLM 02 - speculative decoding
Why Speculative Decoding (SD)?
Decoding is memory-bound: loading the model weights and the KV cache takes most of each step's time
memory-bound cases: big matrix * small matrix, or vector * matrix → O(n^2) FLOPs over O(n^2) bytes of memory traffic
compute-bound cases: large matrix * matrix → O(n^3) FLOPs over O(n^2) bytes
Find a way to increase compute while not significantly increasing GPU memory access
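A rough back-of-the-envelope comparison of the two cases (a minimal sketch; the matrix size, batch size, and fp16 element size are my own illustrative assumptions):

```python
# Rough arithmetic intensity (FLOPs per byte of memory traffic), fp16 weights,
# n x n weight matrix; numbers are illustrative only.
n = 4096
elem = 2  # bytes per fp16 element

# matrix * vector (decoding one token): ~2*n^2 FLOPs, ~n^2 weight elements read
gemv_flops = 2 * n * n
gemv_bytes = n * n * elem
print("GEMV intensity:", gemv_flops / gemv_bytes)  # ~1 FLOP/byte -> memory-bound

# matrix * matrix (b token positions at once): same weights, b times the FLOPs
b = 256
gemm_flops = 2 * n * n * b
gemm_bytes = n * n * elem + 2 * n * b * elem  # weights + input/output activations
print("GEMM intensity:", gemm_flops / gemm_bytes)  # ~b FLOPs/byte -> compute-bound
```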
Solution: Guess multiple tokens and verify
- Example: In terms of token generation per iteration:
- Guess 3 tokens, acceptance rate 2/3
- 2 guessed tokens are accepted, and the verification pass itself yields 1 new token → 3 tokens
- Iteration time
- Computation: (1+3)× the FLOPs of a normal decode step
- Memory:
- w/o SD: Model parameters (8x2 GB) + KV caches (n * 100 KB)
- w/ SD: Model parameters (8x2 GB) + KV caches ((n+3) * 100 KB)
- Iteration time almost unchanged
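To make the memory side concrete, a minimal sketch using the numbers above (the 8B fp16 parameter count and ~100 KB KV cache per token come from the example; n = 1000 is an assumed context length):

```python
# Per-iteration memory traffic with and without SD.
params_bytes = 8e9 * 2            # 8B parameters in fp16 -> ~16 GB
kv_per_token = 100 * 1024         # ~100 KB of KV cache per token
n = 1000                          # tokens already in the sequence (assumed)

def bytes_per_iter(extra_slots):
    return params_bytes + (n + extra_slots) * kv_per_token

print(bytes_per_iter(3) / bytes_per_iter(0))  # ~1.00002 -> traffic almost unchanged
print("tokens per iteration:", 2 + 1)         # 2 accepted guesses + 1 verified token
```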
How to guess tokens?
N-gram
- most-used algorithm in production; simple and effective
- N-grams essentially build a mapping like: if the last 3 tokens are A, B, C, the next 2 tokens are D, E
- Example with “To be or not to be, this is a question”:
- [To be or] → [not to]
- [be or not] → [to be]
- …
- [be, this is] → [a question]
Build the n-gram table from the request's input, then use it to guess tokens (a minimal sketch follows the example below).
- Example:
- Assume the LLM has already generated a sequence ending with [To be or]
- Look up the table: [To be or] → [not to]
- Guess: next two tokens are [not to]
- Verify: the target model checks both guesses in one pass; both are accepted
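A minimal Python sketch of this n-gram lookup (not vLLM's actual implementation; the prefix length of 3 and suffix length of 2 mirror the example above):

```python
from collections import defaultdict

def build_ngram_map(tokens, prefix_len=3, k=2):
    """Map every prefix_len-token window to the k tokens that followed it."""
    table = defaultdict(list)
    for i in range(len(tokens) - prefix_len - k + 1):
        prefix = tuple(tokens[i:i + prefix_len])
        suffix = tuple(tokens[i + prefix_len:i + prefix_len + k])
        table[prefix].append(suffix)
    return table

def propose(table, generated, prefix_len=3):
    """Guess the next tokens from the last prefix_len generated tokens."""
    candidates = table.get(tuple(generated[-prefix_len:]), [])
    return list(candidates[0]) if candidates else []   # guess, or give up

text = "To be or not to be , this is a question".split()
table = build_ngram_map(text)
print(propose(table, ["whatever", "To", "be", "or"]))  # -> ['not', 'to']
```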
Bottleneck of SD: token-guessing algorithm
Tree verification
- N-gram mapping can be one-to-many, because some phrases appear multiple times with different suffixes
- [To be or] → [not to], [sleep in], [go to]
- [To be or] guess multiple choices: [not to], [sleep in], [go to]
- We need an efficient kernel for tree verification
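A toy sketch of verifying several branches against a greedy target model. In practice all branches are scored in a single forward pass with a tree-shaped attention mask (that is the efficient kernel mentioned above); here the per-branch loop and the `target_next` callable are conceptual stand-ins, not real vLLM APIs:

```python
def verify_tree(prefix, branches, target_next):
    """Return the longest branch whose tokens all match the target model's argmax."""
    best = []
    for branch in branches:
        accepted, ctx = [], list(prefix)
        for tok in branch:
            if target_next(ctx) != tok:   # first mismatch ends this branch
                break
            accepted.append(tok)
            ctx.append(tok)
        if len(accepted) > len(best):
            best = accepted
    return best

# toy target model: always continues the loop "to be or not to be or not ..."
loop = ["to", "be", "or", "not"]
target_next = lambda ctx: loop[len(ctx) % len(loop)]
branches = [["not", "to"], ["sleep", "in"], ["go", "to"]]
print(verify_tree(["to", "be", "or"], branches, target_next))  # -> ['not', 'to']
```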
How do we decide whether a guessed token is accepted or rejected?
- Deterministic (greedy) sampling: a bad case for SD, because each position has only one correct token, so the acceptance rate may be low
- Random sampling: a guess is accepted when the target model assigns it high enough probability relative to the draft's (see the sketch below)
Example: guessed tokens are right → all of them are kept, plus one token from the verification pass
Example: guessed tokens are wrong → everything from the first rejected guess onward is discarded
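For random sampling, the standard speculative-sampling acceptance rule (Leviathan et al. / Chen et al.) is probabilistic rather than a fixed threshold; a minimal sketch:

```python
import random

def accept_token(p_draft: float, p_target: float) -> bool:
    # Accept the guessed token with probability min(1, p_target / p_draft).
    # Combined with resampling from max(0, p_target - p_draft) on rejection,
    # the overall output distribution matches the target model's exactly.
    return random.random() < min(1.0, p_target / p_draft)

# A guess the target model likes at least as much as the draft did: always kept.
print(accept_token(p_draft=0.4, p_target=0.8))   # True
# A guess the target model finds unlikely: usually rejected.
print(accept_token(p_draft=0.5, p_target=0.05))  # False ~90% of the time
```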
Model-based (draft model)
Parallel guessing: guess the next few tokens independently, in a single draft pass
- Fast, but lower acceptance rate
Autoregressive guessing: guess tokens one at a time, each conditioned on the previous guesses
- Slower to propose, but higher acceptance rate
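A toy sketch contrasting the two guessing styles (`draft_step` and `draft_heads` are hypothetical stand-ins for a small draft model and per-offset prediction heads; they are not vLLM APIs):

```python
def propose_autoregressive(draft_step, tokens, k):
    """k sequential draft passes; each guess conditions on the previous ones."""
    out = []
    for _ in range(k):
        out.append(draft_step(tokens + out))
    return out  # slower to produce, but guesses are mutually consistent

def propose_parallel(draft_heads, tokens, k):
    """One pass; the guess for position i ignores guesses 0..i-1."""
    return [draft_heads[i](tokens) for i in range(k)]  # fast, lower acceptance

# toy stand-ins: a "model" that just cycles through a tiny vocabulary
vocab = ["the", "cat", "sat"]
draft_step = lambda toks: vocab[len(toks) % 3]
draft_heads = [lambda toks, i=i: vocab[(len(toks) + i) % 3] for i in range(3)]
print(propose_autoregressive(draft_step, ["a"], 3))  # ['cat', 'sat', 'the']
print(propose_parallel(draft_heads, ["a"], 3))       # ['cat', 'sat', 'the']
```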
Why is deployment in production so hard?
The acceptance rate is high enough (> 75%), so that is not the issue.
Is SD beneficial at all?
- Small batch size: more memory-bound → SD helps
- Large batch size: more compute-bound
- SD pushes the workload from memory-bound toward compute-bound
- In practice, the workload may already be compute-bound, and SD then makes it worse
How many tokens should we guess?
- It should be determined from the arithmetic intensity
- Arithmetic intensity: the metric that tells whether a kernel is compute-bound or memory-bound
- Arithmetic intensity = FLOPs / bytes of memory traffic; compare it against the hardware's ridge point (peak FLOPs / memory bandwidth)
- Every piece of hardware has its own suitable intensity (ridge point)
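A worked example of using the ridge point to bound the speculation length. The hardware numbers are rough A100-class figures, and treating per-token intensity as ~1 FLOP/byte (weight reads dominate, KV-cache traffic ignored) is a simplification of my own:

```python
# Ridge point of the hardware: above it a kernel is compute-bound.
peak_flops = 312e12   # ~312 TFLOPS fp16/bf16 dense (A100-class, approximate)
mem_bw     = 2.0e12   # ~2 TB/s HBM bandwidth (approximate)
ridge = peak_flops / mem_bw
print("ridge point ~", ridge, "FLOPs/byte")   # ~156

# Decode intensity grows roughly with batch_size * (1 + guessed tokens),
# since the same weights are reused for every position being verified.
batch_size = 32
per_token_intensity = 1.0   # ~1 FLOP/byte for an fp16 matrix-vector product
for k in range(6):          # k = number of speculative tokens
    intensity = batch_size * (1 + k) * per_token_intensity
    tag = "memory-bound" if intensity < ridge else "compute-bound"
    print(f"guess {k} tokens: intensity ~{intensity:.0f} -> {tag}")
```

In this toy setting, guessing more than about 4 tokens pushes the batch past the ridge point, so additional guesses no longer hide memory latency and only add compute.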
Other engineering issues
The draft (small) model needs its own KV cache. Where do we allocate it?
- If it lives inside vLLM's KV-cache pool, it takes extra memory away from the target model
- If it lives outside, the KV-cache memory is no longer managed as a single pool
The draft model may need a different parallel configuration. For example, the target model runs with TP=8 while the draft model may only need TP=2
- Assume the draft model runs without TP on GPU 0: vLLM reserves the same memory fraction on every GPU, so the extra reservation on the other GPUs is wasted
Pre-allocate KV cache for guessed tokens.
- This breaks the system's assumption that only one token is generated per step; the vLLM interfaces may need to change throughout
- What if pre-allocated tokens cross vLLM’s block boundary?
- Need to discard wrong tokens
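A toy calculation of the block-boundary case (block_size = 16 is vLLM's default block size; the sequence length and number of speculative slots are made up):

```python
import math

block_size = 16   # tokens per KV-cache block
seq_len    = 30   # tokens already written to the KV cache
num_spec   = 4    # speculative slots reserved before verification

blocks_now    = math.ceil(seq_len / block_size)               # 2
blocks_needed = math.ceil((seq_len + num_spec) / block_size)  # 3
print("extra blocks to allocate up front:", blocks_needed - blocks_now)  # 1
# The guessed tokens cross a block boundary here, so the allocator must hand
# out a new block before verification; if guesses are rejected, those slots
# (and possibly the whole extra block) must be released or reused.
```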
The sampling step turns into a verification step: the sampler must also check the guessed tokens
Minimize the proposer's overhead: e.g. a naive Python for-loop for the n-gram lookup is slow
How many tokens should we guess per request?
How to treat different requests differently
- Different requests may need different numbers of guessed tokens, and some of them should not run speculative decoding at all
Summary of challenges above:
- When SD is beneficial, how to find the best configuration
- How to detect when SD is not beneficial? How to turn off SD in such cases?
Paper reading:
- Optimizing Speculative Decoding for Serving Large Language Models Using Goodput
- LLM Inference Unveiled: Survey and Roofline Model Insights
Source
https://www.youtube.com/watch?v=WF5xaQqtKUE&list=PLJj_urhaf2_qxpg8A5-6xoMvMLBKQMTX1&index=19