CANN-昇腾NPU-Speculative-Decoding进阶-怎么选draft模型
Speculative Decoding 用小模型(draft)猜 token,大模型(target)批量验证,猜对的直接接受。猜对率(acceptance rate)是关键——取决于 draft 模型跟 target 模型的分布匹配度。CANN 的 ATB 支持 Speculative Decoding,这篇讲怎么选 draft 模型让猜对率最高。
回顾:Speculative Decoding 流程
draft 模型猜 5 个 token: [A, B, C, D, E]
target 模型一次性验证 5 个 token:
A ✓, B ✓, C ✗ → 接受 A, B,拒绝 C, D, E
target 模型从 C 开始重新采样
猜对 2/5 = acceptance rate 40%。加速比 = 1 / (1 - acceptance_rate) = 1.67×。
Draft 模型选择的三条原则
原则 1:同架构优先
Llama2-7B 的 draft 优先选 Llama2-1.1B(同架构),不要选 Qwen-1.5B(不同架构)。
原因:同架构的 token embedding 和词表完全一致,分布更接近。
from atb import LLM, SpeculativeConfig
# 同架构:Llama2-7B + Llama2-1.1B
model = LLM(
"meta-llama/Llama-2-7b-hf",
device="npu:0",
speculative=SpeculativeConfig(
draft_model="meta-llama/Llama-2-1.1b-hf",
num_spec_tokens=5, # 每次猜 5 个 token
)
)
原则 2:大小比例 5-10×
| Target | Draft(推荐) | 大小比例 |
|---|---|---|
| 7B | 1.1B | 6× |
| 13B | 1.1B | 12× |
| 70B | 7B | 10× |
比例太大(>15×):draft 分布跟 target 差距大,acceptance rate 低。
比例太小(<3×):draft 太慢,加速比上不去。
原则 3:同训练数据优先
用同一份数据训练的模型,分布更接近:
Llama2-7B ← Llama2 预训练数据
Llama2-1.1B ← Llama2 预训练数据(同数据)
→ acceptance rate ~65%
Llama2-7B ← Llama2 预训练数据
Qwen-1.5B ← Qwen 预训练数据(不同数据)
→ acceptance rate ~40%
Acceptance Rate 实测
Llama2-7B target,Atlas 800I A2:
| Draft 模型 | 架构 | Acceptance Rate | 加速比 | TTFT 开销 |
|---|---|---|---|---|
| Llama2-1.1B | 同 | 65% | 2.1× | +15ms |
| TinyLlama-1.1B | 同 | 55% | 1.7× | +15ms |
| Qwen-1.5B | 异 | 40% | 1.3× | +20ms |
| GPT2-small | 异 | 25% | 0.9× | +25ms |
GPT2-small 的加速比 <1,因为 acceptance rate 太低,验证开销比收益还大。
num_spec_tokens 调优
猜多少个 token 最优?取决于 acceptance rate:
acceptance_rate = p
num_spec_tokens = k
期望接受 token 数 = 1 + p + p² + ... + p^(k-1) = (1 - p^k) / (1 - p)
最优 k:
p=0.7 → k=5(期望 3.3 token)
p=0.5 → k=3(期望 1.75 token)
p=0.3 → k=2(期望 1.3 token)
# 高 acceptance rate:多猜几个
speculative=SpeculativeConfig(
draft_model="Llama-2-1.1b-hf",
num_spec_tokens=5, # acceptance ~65%,猜 5 个
)
# 低 acceptance rate:少猜几个
speculative=SpeculativeConfig(
draft_model="Qwen-1.5b",
num_spec_tokens=3, # acceptance ~40%,猜 3 个
)
Draft 模型的显存预算
Draft 模型也占显存:
| Draft | 参数量 | 显存 (bf16) |
|---|---|---|
| Llama2-1.1B | 1.1B | 2.2 GB |
| Llama2-7B | 7B | 14 GB |
64GB 显存:Target 7B (14GB) + Draft 1.1B (2.2GB) + KV Cache (40GB) = 56.2GB ✅
64GB 显存:Target 70B (140GB) + Draft 7B (14GB) = 需要多卡 TP
自适应 Speculative Decoding
ATB 支持根据 acceptance rate 动态调整 num_spec_tokens:
speculative=SpeculativeConfig(
draft_model="Llama-2-1.1b-hf",
num_spec_tokens=5,
adaptive=True, # 开启自适应
min_spec_tokens=2, # 最少猜 2 个
max_spec_tokens=8, # 最多猜 8 个
)
ATB 内部跟踪最近的 acceptance rate:
- rate > 70%:增加 num_spec_tokens(多猜)
- rate < 40%:减少 num_spec_tokens(少猜)
Medusa Head:不用 draft 模型
除了独立的 draft 模型,还有 Medusa 方案——在 target 模型上加多个预测头:
# Medusa:target 模型自带的多个预测头
# 不需要额外 draft 模型,不占额外显存
speculative=SpeculativeConfig(
draft_model=None, # 不用独立 draft 模型
medusa_heads=4, # 4 个 Medusa 预测头
)
Medusa 的 acceptance rate 比 draft 模型低(~45%),但不占额外显存、无 TTFT 开销。
Speculative Decoding 的加速效果取决于 draft 模型选择。三原则:同架构、5-10× 大小、同训练数据。Llama2-7B + Llama2-1.1B 的 acceptance rate ~65%,加速 2.1×。仓库在这里:
https://atomgit.com/cann/ATB
https://atomgit.com/cann/ops-transformer
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)