Draft-OPD:让推测解码的草稿模型,从"自己犯的错"里学习

核心摘要

如果你做过推测解码(Speculative Decoding),大概率碰到过一个让人头疼的现象——草稿模型 SFT 跑完一阶段就死活涨不动了。继续喂数据,接受长度(accepted length)在固定值附近来回晃,再多 token 也是浪费。

这篇 Draft-OPD 就是冲着这个 plateau 来的。作者把锅扣在了离线训练 vs 在线推理的分布不匹配上:SFT 阶段,草稿模型看到的每个前缀都是目标模型一笔一笔写出来的"完美轨迹";但推理时,验证发生在草稿模型自己提出的 token 块上——它得在自己诱导的状态分布里活下去。这两个分布根本不是一回事。

直接套 On-Policy Distillation(OPD,让目标模型来监督草稿模型自己 rollout 的轨迹)行不行?不行。草稿模型独立 rollout 容易跑飞,目标辅助 rollout 又会让序列回到目标分布、把"草稿诱导的错误"信号洗掉。Draft-OPD 的解法挺聪明:用推测解码本身去采集稳定的 rollout,把每个草稿块的起点当 anchor 记下来;然后从这些 anchor 出发"重放"草稿生成,在被验证拒绝的位置上反向 KL 惩罚错误模式,在被接受的位置上正向 KL 抓住正确分布

效果是真的能打:思维模式(thinking mode)下相比 EAGLE-3 提升 23%、相比 DFlash 提升 13%,关闭思维模式后实现超过 5× 的无损加速。SGLang 部署上 Q3-30B-A3B 的 AIME25 吞吐量也能涨 17%。这是一篇把"训练-推理分布偏移"这件事在草稿模型场景下讲透了的论文,思路可以迁移到很多 distillation 任务里。


论文信息

  • 标题:Draft-OPD: On-Policy Distillation for Speculative Draft Models
  • arXiv2605.29343(v2,2026-05-29)
  • 作者:Haodi Lei, Yafu Li, Haoran Zhang, Shunkai Zhang, Qianjia Cheng, Xiaoye Qu, Ganqu Cui, Bowen Zhou, Ning Ding, Yun Luo, Yu Cheng
  • 机构:上海交通大学 / 上海人工智能实验室 / 清华大学 / 香港中文大学 / 北京大学 / 浙江大学

一、问题动机:SFT 为什么会"卡住"

先说一下推测解码的基本设定,免得后面聊起来打结。

推测解码的玩法是:用一个轻量的草稿模型 \(q_\phi\) 一次提议 \(K\) 个 token,再让大目标模型 \(p_\theta\) 并行去验证这些 token——能接受多长就保留多长,剩下的 reroll。这套机制的整体加速比,几乎完全由"平均每次能被接受多少 token"(记作 \(\tau\))决定。\(\tau\) 越长,每一次 forward 的目标模型就摊薄得越多,加速比越高。

那草稿模型怎么训?最常见的做法是 SFT:让目标模型生成一堆轨迹 \(y\),然后让草稿模型学着预测目标模型的下一个 token。EAGLE-3、DFlash 都是这条路子。

问题出在哪?看 Figure 1:

图1:草稿模型训练过程中的接受长度变化。SFT 预热后,继续做离线 SFT 很快就 plateau,甚至单纯把 OPD 数据拿去做 SFT 还会让接受长度倒退

图1:上图就是这篇论文的"动机一击"——蓝色 SFT 曲线在初期快速上涨之后基本躺平,把 OPD 收集来的 on-policy 数据直接套 SFT loss 反而比原来还差(红线下滑),只有把 KL 蒸馏目标用对(绿线 Draft-OPD)才能继续上涨

作者把这种现象拆开看,给了一个很到位的诊断:离线-推理不匹配(offline-to-inference mismatch)

  • 训练时:草稿模型看到的前缀 \(y_{\le t}\) 全是目标模型亲手写的,干净、平滑、分布稳定。
  • 推理时:草稿模型先连续提出 \(K\) 个 token,\(K\) 个 token 都是它自己写的。一旦在第 \(r\) 个位置被拒绝,下一轮起始位置 \(y_{\le t+r}\) 又会变成"目标和草稿混着写出来的"——这是个完全不同的状态分布。

说白了,SFT 教出来的草稿模型,在自己诱导的分布上从来没被监督过。\(\tau\) 卡死在某个值,本质就是因为它不知道在"自己的状态"下该怎么走。

这个诊断不算颠覆,但论文做得漂亮的地方是把它量化清楚了——而且证明了一件反直觉的事:就算把 OPD 收集到的 on-policy 数据拿过来用 SFT loss 训,效果不仅没涨,还可能跌(Figure 1 红线)。问题不在数据来自哪里,问题在 loss 的形式本身假设了 token 是从目标分布采的。


二、为什么直接做 OPD 行不通

OPD(On-Policy Distillation)这个想法其实在 LLM 蒸馏里挺常见:让学生自己 rollout,老师在学生的轨迹上打分。在策略对齐、reasoning 蒸馏里都用过。

但作者发现,直接把 OPD 套到草稿模型上不 work。原因有两个,看 Figure 2:

图2:直接 OPD 为什么不适合草稿模型。(a) 草稿模型独立 rollout 出来的轨迹高度重复、完全跑飞 (b) 朴素的目标辅助 rollout 让序列回到目标分布,丢失了草稿诱导的错误信号

图2:(a) 子图展示草稿模型独立 rollout 极易陷入循环复读——这是因为草稿模型本身参数小、没有经过 RLHF,单独跑长序列就是灾难。(b) 子图展示了一个更微妙的问题:用目标模型 token 替换被拒绝的位置后继续往下采,整段序列会迅速向目标分布靠拢,"草稿到底在什么状态下犯什么错"这件事被洗掉了

这两个失败模式说的其实是同一件事的两个极端:

  • 完全让草稿自己走——分布是真 on-policy 了,但轨迹质量太差,目标模型根本不知道该监督什么
  • 让目标全程辅助走——轨迹质量很好,但草稿诱导的错误信号被稀释,蒸馏退化成另一种形式的 SFT

中间地带在哪?这就是 Draft-OPD 的切入点。


三、Draft-OPD:用推测解码本身做数据采集

核心架构看 Figure 3:

图3:Draft-OPD 整体框架。用推测解码采集稳定 rollout,记录每个草稿块的起始位置作为 anchor,然后从多个 anchor 重放草稿生成,在被验证拒绝/接受的位置上施加不同方向的 KL 损失

图3:左侧是 rollout 阶段——目标模型和草稿模型按推测解码的标准流程交替进行,每个草稿块的起点 \(a_m\) 被作为 anchor 记下来;右侧是训练阶段——从每个 anchor 重新让草稿模型展开 \(K\) 个 token,对照目标模型的概率分布算 KL,关键是连被拒绝的位置也要算,这部分才是真正的"on-policy 错误信号"

整个方法可以拆成三块。

3.1 目标辅助 rollout + Anchor 记录

给定 prompt \(x\),标准推测解码循环是这样的:草稿模型从位置 \(a_m\) 出发提议 \(K\) 个 token:

\[d_m = (d_{m,1}, \ldots, d_{m,K}) \sim q_\phi(\cdot \mid x, y_{\le a_m})\]

目标模型并行验证,接受了 \(r_m\) 个,下一轮 anchor 就是 \(a_m + r_m\)。最终的 rollout 序列 \(y\) 是目标模型质量的(因为被拒绝的部分都被目标 token 替换了),但 anchor 列表 \(\{a_m\}\) 完整保留了"草稿模型在哪些状态下被叫起来提议"的信息

这一步的精髓在于:"采数据"和"打标签"复用了推测解码本身。不需要额外的目标模型 forward,不需要额外的 reward model——一次 rollout 同时拿到了高质量序列和 on-policy 状态点。

3.2 错误位置回放

光有 anchor 还不够,因为最终序列 \(y\) 是目标分布的。怎么把"草稿犯的错"还原回来?

答案是回放。对每个 anchor \(a_m\),定义回放上下文 \(c_m = (x, y_{\le a_m})\),从 \(c_m\) 开始重新让草稿模型生成那个原始的 \(d_m\)(注意是原本被验证、被部分拒绝的那个块,不是新采一次)。然后在每个位置 \(k\) 算两个概率:

  • 学生:\(\log q_{m,k}(d_{m,k}) = \log q_\phi(d_{m,k} \mid c_m, d_{m,\lt k})\)
  • 教师:\(\log p_{m,k}(d_{m,k}) = \log p_\theta(d_{m,k} \mid c_m, d_{m,\lt k})\)

注意这里的细节:教师是在草稿模型生成的 token 序列上算概率的,不是在最终 rollout 序列上。这意味着教师能看到"如果一直让草稿写下去会发生什么",并对每一步给出反馈——包括那些在原始验证中被拒绝的位置

被拒绝的位置才是金矿。SFT 永远看不到这些位置,因为目标模型本来就不会写出这些 token。Draft-OPD 通过回放硬是把这部分错误信号挖了回来。

3.3 验证感知的混合 KL 目标

接下来就是 loss 设计。作者把所有 token 按验证结果分两组,用不同方向的 KL:

接受位置\(\mathcal{I}_{acc} = \{(m,k): 1 \le k \le r_m\}\))—— 用前向 KL

\[\mathcal{L}_{acc} = \frac{1}{|\mathcal{I}_{acc}|} \sum_{(m,k) \in \mathcal{I}_{acc}} D_{KL}(p_{m,k} \,\|\, q_{m,k})\]

直觉是:在草稿已经做对的位置,让它的分布覆盖目标分布的所有 mode(前向 KL 是 mean-seeking 的)。

拒绝位置\(\mathcal{I}_{rej} = \{(m,k): r_m \lt k \le K\}\))—— 用反向 KL + 位置权重:

\[\mathcal{L}_{rej} = \frac{1}{Z} \sum_{(m,k) \in \mathcal{I}_{rej}} w_k \cdot D_{KL}(q_{m,k} \,\|\, p_{m,k})\]

其中 \(w_k = \gamma^{k-1}\)\(\gamma=0.8\)。直觉是:在草稿犯错的位置,惩罚草稿自己 confident 但目标不同意的那些模式(反向 KL 是 mode-seeking 的,能抑制错误高峰);同时块里越靠前的错误对接受长度损害越大(一旦第 1 个 token 被拒,整个块都没了),所以加大它们的权重。

最终目标:

\[\mathcal{L}_{Draft\text{-}OPD} = \frac{\lambda_{acc} \mathcal{L}_{acc} + \lambda_{rej} \mathcal{L}_{rej}}{\lambda_{acc} + \lambda_{rej}}\]

实验中 \(\lambda_{acc} = \lambda_{rej} = 1\)

我第一遍读到这里的时候愣了一下——前向 KL 给接受位置、反向 KL 给拒绝位置,这个分配不是反着来的吗?再想想就明白了:接受位置上你要的是"覆盖目标的全部行为"(不管目标走哪条路你都跟得上),所以 mean-seeking;拒绝位置上你要的是"压住自己犯错的高峰"(让那个错误模式的概率塌下去),所以 mode-seeking。这个设计相当精巧。


四、实验结果:5× 加速、消融全胜

4.1 主实验

实验在 Qwen3 系列(4B / 8B / 30B-A3B-Thinking-2507)上做,草稿模型用 DFlash 风格的并行架构(5-8 层 Transformer,块大小 \(K=16\))。Baseline 选了 EAGLE-3 和 DFlash。

Model 方法 Thinking 开(avg speedup / τ) Thinking 关(avg speedup / τ)
Q3-4B EAGLE-3 4.41× / 5.33 4.58× / 5.84
Q3-4B DFlash 4.51× / 5.51 5.36× / 6.04
Q3-4B Draft-OPD 5.31× / 5.96 6.22× / 6.60
Q3-8B EAGLE-3 4.58× / 5.64 4.99× / 5.99
Q3-8B DFlash 4.67× / 5.19 5.69× / 6.04
Q3-8B Draft-OPD 5.36× / 5.73 6.49× / 6.57

几个数我盯着看了一会儿:

  • 思维模式下涨幅最大(相比 EAGLE-3 提升 23%)。这个挺反直觉——按说 thinking trace 那种 self-talk 风格应该对草稿模型更友好(重复、套路化),但实际是 thinking 模式下 SFT plateau 得更厉害,于是 Draft-OPD 的边际收益更大。作者后面在 Figure 5 里也讨论了这个现象:thinking 模式下不同任务的 token-level NLL 差异更大,分布更复杂,这正是 on-policy 信号最值钱的地方。
  • Q3-8B 在 thinking 模式下的 τ 反而比 4B 略低(5.73 vs 5.96)。这倒不是 bug——更大的目标模型有更尖锐的分布,草稿模型更难匹配。从 speedup 角度看 8B 的 5.36× 还是比 4B 的 5.31× 略高,因为目标模型一次 forward 的成本占比更高。

4.2 SGLang 部署吞吐

光看 τ 不够,真实部署还要看吞吐。在 SGLang(concurrency=32)上:

Model Task DFlash (tok/s) Draft-OPD (tok/s) 提升
Q3-4B AIME25 8410 9043 +8%
Q3-4B MATH-500 10062 10943 +9%
Q3-8B AIME25 5985 6645 +11%
Q3-8B MATH-500 6991 7940 +13%
Q3-30B AIME25 4014 4718 +17%

模型越大、任务越难,提升越明显。这件事其实印证了一个工程直觉——目标模型 forward 成本越贵的场景,每多接受一个 token 的边际收益越大,所以提升 τ 在大模型上更值钱。

4.3 消融实验

这是我最喜欢看的部分。作者在 Q3-4B 上做了两组消融,验证三个核心设计的贡献。

组件消融(Table 3)——在 Q3-4B 上:

配置 MATH-500 (×) HumanEval (×) MT-Bench (×)
Draft-OPD(完整) 5.55 5.17 3.18
移除位置权重衰减 5.13 4.96 3.07
全反向 KL(拒绝/接受位置都用反向) 5.11 4.94 3.08
全前向 KL 5.34 5.01 3.09
随机 anchor(不用真实推测解码点) 5.04 4.99 2.96

读这张表的几个观察:

  • 混合 KL 的设计是真的有用。全前向 KL 还能保住大部分收益(5.34 vs 5.55,掉 4%),全反向 KL 就掉得比较多(5.11,掉 8%)。这说明在接受位置上做 mean-seeking 的覆盖更重要,而在拒绝位置上做 mode-seeking 的压制是锦上添花。
  • anchor 选取至关重要。"随机 anchor" 这一行是把 anchor 随便选——结果掉到 5.04,几乎和 baseline 持平。这反过来证明了 anchor 必须对应真实推测解码中草稿模型实际提议的位置,分布偏移这件事不是开玩笑。
  • 位置权重衰减贡献中等(3-7%),但实现成本几乎为零,加上肯定是没毛病的。

Naive rollout 对比(Table 4)

配置 MATH-500 (×) HumanEval (×) MT-Bench (×)
Draft-OPD 5.55 5.17 3.18
朴素目标辅助 rollout(无错误位置回放) 5.06 4.80 3.02

掉 7-9%。这张表才是真正回答"为什么需要回放"的——不回放就退化成"在目标分布上做的另一种 SFT",前面 Figure 2(b) 那个失败模式直接复现。

4.4 训练数据消融

图4:训练数据消融实验,Qwen3-4B 思维模式

图4:作者还做了训练数据规模和种类的消融——OPD 数据量从 4K 涨到 16K 接受长度持续上升,但增速逐渐放缓;不同领域的数据(数学/代码)混合训练比单一领域更好

OPD 数据用了 16K 样本(GSM8K 2K + MATH 5K + AoPS 4K + CodeAlpaca 5K),全部只用 prompt,response 是目标模型在线生成的——这点很关键,不要用静态参考答案,否则又退化回 SFT。

4.5 Token 级 NLL 分析

图5:思维/非思维模式下的 token-level NLL 在不同评测集上的分布

图5:这张图展示了不同任务下 token 级负对数似然的分布——thinking 模式下分布的右尾(高 NLL token)明显更长,意味着草稿模型在思维链中遇到"难预测"token 的频率更高,而这些正是 on-policy 训练能改善的地方

这张图其实是给主实验的"thinking 模式下提升更大"做了一个机制性的解释:thinking 模式下 token 分布更复杂、长尾更重,离线 SFT 的覆盖能力本来就有限,on-policy 信号的价值更大。


五、我的判断:值得看,思路可迁移

读完这篇论文,我觉得它最值钱的地方有两点:

第一,把"训练-推理分布偏移"这件事在草稿模型场景里讲清楚了。RLHF 圈子里说 distribution shift 已经说烂了,但放到 spec decoding 这个场景,之前的论文要么模糊地提一下、要么把它和别的问题混在一起谈。这篇直接给出了诊断(Figure 1)、机制(Figure 2 的两种失败模式)和解药(Figure 3 的 anchor + replay),逻辑链条是闭合的。

第二,"用推测解码本身做数据采集"这个 trick 很优雅。它解决了 OPD 在草稿模型上的核心矛盾——草稿模型独立 rollout 不行,目标模型完全辅助又洗掉信号。Anchor + replay 找了一个非常聪明的中间态:rollout 用目标质量保证序列稳定,loss 在 anchor 起点的草稿生成上算保证 on-policy。这个思路其实可以迁移到很多 distillation 场景——任何"学生独立 rollout 跑飞、老师辅助又洗掉信号"的 setup,都可以借鉴 anchor 这个机制。

要说我有什么疑虑,主要是两个:

  • Anchor 数量和分布对训练效率影响多大? 论文里 K=16,每个 prompt 大概会产生几十个 anchor。但如果 K 更大、目标模型接受率更高(thinking 模式下平均 τ ≈ 6 已经接近 K 的一半了),anchor 会变稀——这种情况下 Draft-OPD 还能保持优势吗?论文没充分讨论。
  • 回放的计算成本。每个 anchor 都要让草稿模型重新展开 K 个 token,这对训练吞吐有影响。论文给了主实验数据但没给训练时间对比。如果回放让训练成本翻倍,那 23% 的 τ 提升就要重新算账。

另外,跟同期工作放一起看——EAGLE-3 系列、DFlash、还有更早的 Medusa,大家其实都在围着"草稿模型的 plateau"打转。EAGLE-3 通过引入 multi-layer feature 提升表征能力,DFlash 通过架构上的并行化降低草稿成本,Draft-OPD 则是从训练目标本身下手。这三条路是正交的,理论上可以叠加——但作者没做这种组合实验,是个遗憾。


工程启发

如果你正在做推测解码的草稿模型训练,这篇论文有几个直接可借鉴的点:

  1. 不要无脑加 SFT 数据。Figure 1 说得很清楚,SFT 的 plateau 是结构性问题,加多少数据都解决不了。
  2. OPD 数据 ≠ OPD loss。哪怕你已经收集了 on-policy 数据,用 SFT loss 训也不会涨(红线),必须用 KL 蒸馏。
  3. 混合 KL 方向。接受位置前向 KL,拒绝位置反向 KL + 位置衰减,这个配方可以直接抄。
  4. 用推测解码本身采集训练数据。这个 trick 在工程上非常友好——你部署阶段已经有推测解码 pipeline 了,加个 anchor 记录就能直接拿来训练,不用搭额外的 rollout 系统。

如果你在做更广义的 LLM 蒸馏(不止 spec decoding),anchor + replay 这个机制其实在任何"学生需要 on-policy 状态但又跑不稳"的场景都能用。比如 reasoning trace 蒸馏、tool use 蒸馏、agent 行为蒸馏——都值得想想能不能用类似的"老师辅助 rollout 但学生在 anchor 处重新展开"的思路把训练-推理 gap 给闭掉。


参考资料

  • 论文:arXiv:2605.29343
  • 项目页:https://www.haodilei.top/draft-opd
  • 代码:https://github.com/bingyang-lei/Draft-OPD
  • 模型:https://huggingface.co/collections/bingyang-lei/draft-opd

觉得有启发的话,欢迎点赞、在看、转发。跟进最新AI前沿,关注我