Speculative Decoding 用小模型(draft)猜 token,大模型(target)批量验证,猜对的直接接受。猜对率(acceptance rate)是关键——取决于 draft 模型跟 target 模型的分布匹配度。CANN 的 ATB 支持 Speculative Decoding,这篇讲怎么选 draft 模型让猜对率最高。

回顾:Speculative Decoding 流程

draft 模型猜 5 个 token: [A, B, C, D, E]
target 模型一次性验证 5 个 token:
  A ✓, B ✓, C ✗ → 接受 A, B,拒绝 C, D, E
  target 模型从 C 开始重新采样

猜对 2/5 = acceptance rate 40%。加速比 = 1 / (1 - acceptance_rate) = 1.67×。

Draft 模型选择的三条原则

原则 1:同架构优先

Llama2-7B 的 draft 优先选 Llama2-1.1B(同架构),不要选 Qwen-1.5B(不同架构)。

原因:同架构的 token embedding 和词表完全一致,分布更接近。

from atb import LLM, SpeculativeConfig

# 同架构:Llama2-7B + Llama2-1.1B
model = LLM(
    "meta-llama/Llama-2-7b-hf",
    device="npu:0",
    speculative=SpeculativeConfig(
        draft_model="meta-llama/Llama-2-1.1b-hf",
        num_spec_tokens=5,  # 每次猜 5 个 token
    )
)

原则 2:大小比例 5-10×

Target Draft(推荐) 大小比例
7B 1.1B
13B 1.1B 12×
70B 7B 10×

比例太大(>15×):draft 分布跟 target 差距大,acceptance rate 低。
比例太小(<3×):draft 太慢,加速比上不去。

原则 3:同训练数据优先

用同一份数据训练的模型,分布更接近:

Llama2-7B ← Llama2 预训练数据
Llama2-1.1B ← Llama2 预训练数据(同数据)
→ acceptance rate ~65%

Llama2-7B ← Llama2 预训练数据
Qwen-1.5B ← Qwen 预训练数据(不同数据)
→ acceptance rate ~40%

Acceptance Rate 实测

Llama2-7B target,Atlas 800I A2:

Draft 模型 架构 Acceptance Rate 加速比 TTFT 开销
Llama2-1.1B 65% 2.1× +15ms
TinyLlama-1.1B 55% 1.7× +15ms
Qwen-1.5B 40% 1.3× +20ms
GPT2-small 25% 0.9× +25ms

GPT2-small 的加速比 <1,因为 acceptance rate 太低,验证开销比收益还大。

num_spec_tokens 调优

猜多少个 token 最优?取决于 acceptance rate:

acceptance_rate = p
num_spec_tokens = k
期望接受 token 数 = 1 + p + p² + ... + p^(k-1) = (1 - p^k) / (1 - p)

最优 k:
p=0.7 → k=5(期望 3.3 token)
p=0.5 → k=3(期望 1.75 token)
p=0.3 → k=2(期望 1.3 token)
# 高 acceptance rate:多猜几个
speculative=SpeculativeConfig(
    draft_model="Llama-2-1.1b-hf",
    num_spec_tokens=5,  # acceptance ~65%,猜 5 个
)

# 低 acceptance rate:少猜几个
speculative=SpeculativeConfig(
    draft_model="Qwen-1.5b",
    num_spec_tokens=3,  # acceptance ~40%,猜 3 个
)

Draft 模型的显存预算

Draft 模型也占显存:

Draft 参数量 显存 (bf16)
Llama2-1.1B 1.1B 2.2 GB
Llama2-7B 7B 14 GB

64GB 显存:Target 7B (14GB) + Draft 1.1B (2.2GB) + KV Cache (40GB) = 56.2GB ✅
64GB 显存:Target 70B (140GB) + Draft 7B (14GB) = 需要多卡 TP

自适应 Speculative Decoding

ATB 支持根据 acceptance rate 动态调整 num_spec_tokens:

speculative=SpeculativeConfig(
    draft_model="Llama-2-1.1b-hf",
    num_spec_tokens=5,
    adaptive=True,        # 开启自适应
    min_spec_tokens=2,    # 最少猜 2 个
    max_spec_tokens=8,    # 最多猜 8 个
)

ATB 内部跟踪最近的 acceptance rate:

  • rate > 70%:增加 num_spec_tokens(多猜)
  • rate < 40%:减少 num_spec_tokens(少猜)

Medusa Head:不用 draft 模型

除了独立的 draft 模型,还有 Medusa 方案——在 target 模型上加多个预测头:

# Medusa:target 模型自带的多个预测头
# 不需要额外 draft 模型,不占额外显存
speculative=SpeculativeConfig(
    draft_model=None,      # 不用独立 draft 模型
    medusa_heads=4,        # 4 个 Medusa 预测头
)

Medusa 的 acceptance rate 比 draft 模型低(~45%),但不占额外显存、无 TTFT 开销。


Speculative Decoding 的加速效果取决于 draft 模型选择。三原则:同架构、5-10× 大小、同训练数据。Llama2-7B + Llama2-1.1B 的 acceptance rate ~65%,加速 2.1×。仓库在这里:

https://atomgit.com/cann/ATB

https://atomgit.com/cann/ops-transformer

Logo

AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。

更多推荐