TrOPD:用投机解码的信任区域,把"学生跑飞"的在线蒸馏拉回来

核心摘要

如果你做过大模型蒸馏,一定碰到过这个尴尬情况:用 SFT(前向 KL)老老实实模仿教师没什么问题,但一上 On-Policy Distillation(OPD,让学生自己 rollout、教师在学生轨迹上打分),训练曲线分分钟坐过山车——一开始还行,跑着跑着就崩了,loss 飙升、生成开始复读、benchmark 暴跌。

这篇 TrOPD(Trust Region On-Policy Distillation,arXiv:2606.01249)就是冲这个不稳定性来的。作者把锅扣在 OPD 的反向 KL 估计器上:当学生分布 \(\pi_S\) 和教师分布 \(\pi_T\) 在某个 token 上差距很大(比如 \(\pi_S(x)/\pi_T(x)\) 飞到几百),K1 估计器会把这个比值的对数当成"梯度信号"放大,结果异常 token 主导了整批梯度,参数被一脚踹飞。

TrOPD 的解法很有意思——它没有继续在估计器层面打补丁(裁梯度、加掩码、塞熵正则那一套已经被试过了),而是借了投机解码(Speculative Decoding)的接受概率:定义一个信任区域 \(P_{\mathrm{trust}}(x) = \min(\pi_T(x)/\pi_S(x), 1)\),把 token 一刀切成"信任区域内"和"异常值"两类。信任区域内的样本用反向 KL 优化追求 mode-seeking;异常值改用前向 KL,让学生先把偏离的 token 概率往教师那拉。再叠加一个离策略指导——按概率混入"教师写前缀、学生续写"的轨迹,在 KL 巨大的早期阶段稳住训练。

效果是真的能打:在 Skywork-OR1-Math-7B 蒸到 1.5B 学生上,AIME24 从 OPD 的 35.0 跳到 51.7,AMC23 从 73.0 跳到 88.4;多域设置(数学/代码/通用)和 Qwen3 蒸馏链路上同样大幅领先。这是一篇"诊断准、解药漂亮、实验扎实"的工程向论文,思路对任何做 on-policy 蒸馏的同学都有参考价值。


论文信息

  • 标题:Trust Region On-Policy Distillation
  • arXiv2606.01249
  • 作者:Xingrun Xing, Haoqing Wang, Boyan Gao, Ziheng Li, Yehui Tang
  • 关键词:On-Policy Distillation, Trust Region, Speculative Decoding, Reverse KL, LLM Reasoning

一、问题动机:OPD 为什么会"崩"

先把前置背景对齐一下。

OPD 这一两年挺火,玩法是这样的:学生 \(\pi_S\) 自己 rollout 一段 \(y\),教师 \(\pi_T\) 在学生的每一个 token 上给出参考分布;loss 是反向 KL:

\[\mathcal{L}_{\mathrm{OPD}}(\theta) = \mathbb{E}_{y \sim \pi_S}\bigl[\mathrm{KL}\bigl(\pi_S(\cdot\mid y_{<t}) \,\|\, \pi_T(\cdot\mid y_{<t})\bigr)\bigr]\]

反向 KL 在直觉上很合理——它鼓励学生只在教师认可的 mode 上下注,不会被教师的低密度区拖偏。但实操中有个核心难点:这个 KL 是没法精确算的(\(|V| \approx 10^5\) token 太大),通常用 K1 估计器:

\[\hat{\mathrm{KL}}_{\mathrm{K1}} = \log \frac{\pi_S(x)}{\pi_T(x)}, \quad x \sim \pi_S\]

问题就出在这里。看 Figure 1:

图1:OPD 训练曲线对比。原始 OPD 在训练几百步后剧烈震荡甚至崩溃,TrOPD 全程稳定,且在 AIME/AMC 等评测上显著领先

图1:左边几条曲线是原始 OPD 加各种异常值估计补丁(梯度裁剪、token 掩码、前向 KL 兜底)后的训练表现,蓝绿色 TrOPD 曲线则全程平滑——这就是这篇论文的"动机一击"。右边给的是数学评测集上的最终分数,TrOPD 显著高于一众基线

把这个不稳定性拆开看,作者的诊断是这样的:

  • 训练初期 \(\pi_S\)\(\pi_T\) 还很接近,每个 token 的 \(\pi_S(x)/\pi_T(x)\) 大致在 1 附近,K1 估计是准的。
  • 训练几步后,学生在某些 token 上探索到了教师极度不喜欢的位置(\(\pi_T(x) \to 0\)),此时 \(\log(\pi_S/\pi_T)\) 会爆炸到几十甚至上百。
  • 梯度被这些异常 token 主导——它们在 batch 里可能就几个 token,但梯度幅度比正常 token 大几个数量级,整次更新等于沿着这些"教师极不喜欢但学生采样到了"的方向猛推一把,参数被踹偏。
  • 结果就是熵塌陷、复读、benchmark 暴跌。

这个故事其实在策略梯度训练里早就有了——异常重要性比例(importance ratio)的爆炸是 PPO/TRPO 提出的根源问题。作者敏锐地把这件事搬到了蒸馏上,并指出:之前几个 OPD 补丁(梯度裁剪、token 掩码、Entropy 正则、前向 KL 兜底)本质上都在打补丁式地解决这个问题,但都有副作用——要么过度激进(前向 KL 退化为 SFT,丢了 mode-seeking),要么过度保守(梯度裁剪扔掉了关键信号)。


二、TrOPD 的核心思路:把"信任"这件事正式建模

TrOPD 的核心切入点很巧妙:与其在估计器层面打补丁,不如先把 token 划成"该信任"和"不该信任"两类,再针对性处理。

2.1 信任区域:借投机解码的接受概率

作者直接把投机解码(Speculative Decoding)里的接受概率搬过来当判据:

\[P_{\mathrm{trust}}(x \mid y_{<t}) = \min\!\left(\frac{\pi_T(x \mid y_{<t})}{\pi_S(x \mid y_{<t})}, \; 1\right) \tag{6}\]

这个公式对熟悉投机解码的同学应该很亲切——它就是用大模型 \(\pi_T\) 验证小模型 \(\pi_S\) 提议的 token 时的接受率。物理意义是:

  • \(\pi_T(x) \ge \pi_S(x)\)(教师比学生更喜欢这个 token),\(P_{\mathrm{trust}} = 1\),完全信任;
  • \(\pi_T(x) < \pi_S(x)\)(学生过度高估了这个 token),\(P_{\mathrm{trust}} < 1\),按比例打折。

实操中再设一个阈值 \(\delta\)(论文里 \(\delta = 0.5\)):\(P_{\mathrm{trust}}(x) \ge \delta\) 的 token 进入"信任区域 \(\mathcal{T}\)",其余归为"异常值集 \(\mathcal{O}\)"。

看 Figure 2:

图2:TrOPD 框架概览。左侧展示信任区域 vs 异常值的概率密度划分;中间是信任区域内的反向 KL(mode-seeking),异常值集上切到前向 KL(mass-covering);右侧是离策略指导,按概率混入"教师写前缀、学生续写"的样本

图2:图的左半部分是分布层面的直觉——绿色信任区域内学生概率不超过教师,反向 KL 估计稳定;红色异常值区域学生概率远超教师,K1 估计会爆炸。中间块是核心 loss 设计:两种 KL 在两个区域分别上岗。右侧是离策略指导分支:按概率 \(\rho\) 用教师采样的前缀替换学生 rollout,缓解早期巨大 KL 带来的不稳定

2.2 信任区域内:稳定的反向 KL

\(\mathcal{T}\) 上,K1 估计是稳的,直接用反向 KL:

\[\mathcal{L}_{\mathrm{trust}}(\theta) = \mathbb{E}_{x \in \mathcal{T}}\bigl[\log \pi_S(x) - \log \pi_T(x)\bigr] \tag{7}\]

这部分让学生在教师认可的 mode 上集中概率,是 OPD 真正想要的 mode-seeking 行为。

2.3 异常值集:换成前向 KL

\(\mathcal{O}\) 上反向 KL 不可信,改用前向 KL(也即 SFT 风格的 cross-entropy on teacher samples):

\[\mathcal{L}_{\mathrm{outlier}}(\theta) = \mathbb{E}_{x \in \mathcal{O}}\bigl[-\log \pi_S(x)\bigr] \tag{8}\]

直觉是——既然学生在这些 token 上偏离过头了,那就先用前向 KL 把这些 token 的概率"往教师那拉",等学生回到信任区域里再切回反向 KL 学 mode。

总 loss:

\[\mathcal{L}_{\mathrm{TrOPD}}(\theta) = \mathcal{L}_{\mathrm{trust}}(\theta) + \lambda \mathcal{L}_{\mathrm{outlier}}(\theta) \tag{9}\]

论文里 \(\lambda = 1\)

2.4 离策略指导:早期稳定的关键

光有 trust/outlier 划分还不够。训练最早期,学生分布跟教师差距巨大,几乎所有 token 都是 outlier,反向 KL 这条路根本没机会启动。作者再加一个 trick——离策略指导(Off-Policy Guidance)

按概率 \(\rho\)(论文里 \(\rho = 0.5\))将一部分样本替换为"教师写前缀、学生续写"的混合轨迹。这样早期阶段学生看到的状态分布更接近教师,可以快速把信任区域比例做大;等学生跟教师靠近了,再让 \(\rho\) 自然下降(论文里没做 schedule,固定 0.5)。

这个设计有点像 RLHF 里的"behavior cloning warmup",但更精细——不是 batch 级别的 SFT 切 RL,而是 token 级别的混合采样。


三、实验:单域、多域、Qwen3,全打过去

3.1 单域数学蒸馏(Skywork-OR1-Math-7B → 1.5B)

教师是 Skywork-OR1-Math-7B,学生是 DeepSeek-Distilled-Qwen-1.5B,数据集 OpenThoughts3-Math。

方法 AIME24 AIME25 AMC23 平均
Student baseline 28.8 24.0 70.0 40.9
OPD 35.0 28.5 73.0 45.5
EOPD 32.5 25.7 71.4 43.2
REOPOLD 36.2 27.9 75.6 46.6
Entropy OPD 38.8 30.0 76.5 48.4
TrOPD 51.7 41.7 88.4 60.6

AIME24 从 OPD 的 35.0 直接拉到 51.7,提升幅度 +16.7 绝对点——这个幅度在数学蒸馏里非常夸张,相当于把 1.5B 学生推到了远超原 baseline 的水平。

3.2 多域蒸馏(DeepSeek 链路)

教师 Skywork-OR1-7B(数学+代码+通用),学生还是 DeepSeek-Distilled-Qwen-1.5B:

方法 AIME24 AIME25 AMC23 GPQA MMLU-R IFBench LCB 平均
Student 28.8 24.0 70.0 30.0 53.0 14.7 18.3 34.1
OPD 31.5 27.0 71.5 33.5 60.4 18.0 21.7 37.7
TrOPD 48.5 39.8 87.0 41.4 65.8 22.3 27.8 47.5

注意这里不只是数学涨——GPQA Diamond +7.9、MMLU-Redux +5.4、IFBench +4.3、LiveCodeBench +6.1,全面优于 OPD。这说明 TrOPD 的稳定性收益不是数学专属。

3.3 多域蒸馏(Qwen3 链路)

教师换成 Qwen3-Nemotron-4B,学生 Qwen3-SFT-1.7B:

方法 AIME24 AIME25 AMC23 GPQA 平均
OPD 41.2 33.0 79.5 36.7 47.6
TrOPD 55.6 45.0 89.6 44.1 58.6

跨教师/学生组合一致领先,说明方法对架构和初始化的依赖较小。

3.4 消融:每个组件都有用

图3:(a) 训练过程中熵的演化对比,TrOPD 维持健康熵;(b) 梯度范数对比,TrOPD 梯度全程稳定,原始 OPD 在 200 步后梯度爆炸

图3:左子图是熵曲线——TrOPD 始终保持在合理水位(既不塌陷也不发散),原始 OPD 在 ~300 步左右熵剧烈震荡然后塌缩到接近 0。右子图是梯度范数——这才是最直观的证据,TrOPD 梯度范数全程在一个量级内,原始 OPD 在异常值出现后梯度直接飞了几十倍

消融数据(AIME24 / 平均):

配置 AIME24 平均
完整 TrOPD 51.7 60.6
- 离策略指导 46.0 56.2
- 异常值前向 KL 41.5 52.1
- 信任区域划分(退化为纯 OPD) 35.0 45.5

三个组件都不可少,其中信任区域划分贡献最大(去掉就退化回 OPD baseline),离策略指导和异常值前向 KL 是稳定性收益。

3.5 vs 同期工作 AOPD

AOPD(Asymmetric On-Policy Distillation)是同期一个用非对称 loss 处理 OPD 不稳定性的工作。在同样的 Skywork→1.5B 设置下:

方法 AIME24 AIME25 AMC23 平均
AOPD 44.5 35.0 82.5 54.0
TrOPD 51.7 41.7 88.4 60.6

TrOPD 仍然领先 +6.6。作者指出,AOPD 的非对称设计本质是给反向 KL 加了一个软裁剪,但没有显式地处理"哪些 token 该信任、哪些不该"——TrOPD 的硬划分加双 loss 设计在这一点上更彻底。

3.6 训练动力学

图4:训练过程中信任区域内 token 比例的演化。初期约 60% token 处于信任区域,随训练推进逐步上升到 85% 以上,说明学生分布在向教师收敛

图4:这张图给了一个机制性的可视化——训练初期由于教师/学生分布差异大,只有 ~60% token 落在信任区域,需要离策略指导和异常值前向 KL 帮忙;随训练推进,信任区域比例稳步上升到 85%+,反向 KL 主导训练。这条曲线说明三个组件是有"接力"关系的——前期靠离策略指导和前向 KL 拉近分布,后期靠反向 KL 抓 mode


四、我的判断:值得抄,但要看清边界

读完这篇我的几个判断:

第一,"信任区域"这个比喻用得很准。它一下把 OPD 的不稳定性归因到"哪些 token 该信任 K1 估计器、哪些不该"这个层面,比之前各种打补丁的工作更结构化。借投机解码的接受概率当判据这件事,逻辑也很顺——投机解码本来就是在解决"小模型采样在大模型分布下的可信度"问题,这跟蒸馏是同一件事的两面。

第二,三组件设计的"接力"很聪明。Figure 4 把这件事讲透了:早期靠离策略指导 + 异常值前向 KL 拉近分布,后期靠反向 KL 抓 mode。这种"分阶段但用同一套机制平滑切换"的设计比手动 schedule(先 SFT 再 OPD)更优雅。

第三,提升幅度真的大。AIME24 +16.7 这种数字在蒸馏论文里很罕见,一般 mode-seeking 蒸馏方法跟前向 KL 比也就涨几个点。这意味着 OPD 不稳定性这件事被低估了——之前文献里 OPD 给的"小幅领先 SFT"的数字可能都是在崩溃前/崩溃边缘 cherry-pick 的。

要说我有什么疑虑:

  • 阈值 \(\delta = 0.5\) 是不是普适? 作者没做 \(\delta\) 的扫描,所有实验都是 0.5。直觉上 \(\delta\) 应该跟教师/学生的 size gap 相关——gap 越大,初期 outlier 越多,可能需要更小的 \(\delta\) 才能让信任区域不至于太空。这一点论文没讨论,是个小遗憾。
  • 离策略指导的 \(\rho = 0.5\) 也是固定的。Figure 4 显示训练后期信任区域比例已经 85%+,这时候继续给一半样本 off-policy 反而可能拖慢收敛。一个简单的 schedule(\(\rho\) 从 0.5 线性衰减到 0.1)可能能再涨一点点。
  • 跟 RL 的关系没讲清楚。TrOPD 用的"信任区域"这个名字明显在致敬 TRPO,但论文没有把 TrOPD 跟 PPO/TRPO 做形式上的对比——比如把信任区域看成一个软约束,能不能直接套用 PPO 的 clip?这个理论分析的口子如果开开,论文的影响力会更大。

跟同期工作放在一起看——AOPD 用非对称 loss、Entropy OPD 加熵正则、REOPOLD 引入参考策略,大家都是在治 OPD 不稳定性这个病。TrOPD 的 trust region 划分是这一系列工作里最干净的一刀。


工程启发

如果你正在做 LLM on-policy 蒸馏,这篇论文有几个直接可以抄的点:

  1. 不要无脑用反向 KL + K1 估计器。看到 OPD 训练崩了不要先怪数据,大概率是异常 token 在主导梯度。
  2. 把 token 分成两类对待\(P_{\mathrm{trust}} = \min(\pi_T/\pi_S, 1)\) 这个判据极其便宜(就是采样时已经算过的概率比),直接拿来切 trust/outlier 就完事。
  3. outlier 上换前向 KL。不要在 outlier 上继续硬抗反向 KL,那是异常值打梯度的来源;前向 KL 能稳稳把分布拉回来。
  4. 早期混 off-policy 样本。教师写前缀、学生续写,按概率 0.5 混进 batch,能极大地缓解早期不稳定。
  5. 盯紧梯度范数和熵。Figure 3 里两条诊断曲线非常实用——熵塌陷或梯度爆炸都是早期信号,不要等 benchmark 跌了再回头查。

如果你做 RLHF/RLVR,这篇的 trust region 划分思路也能借鉴——任何用 K1 estimator 算 KL 的场景(PPO 的 KL penalty 项、DPO 的隐式 KL)都可能踩同一个坑。


参考资料

  • 论文:arXiv:2606.01249
  • 关键基线:OPD(On-Policy Distillation)、AOPD(Asymmetric OPD)、Entropy OPD、REOPOLD
  • 相关方法:Speculative Decoding(投机解码,提供了 \(P_{\mathrm{trust}}\) 的灵感)、TRPO/PPO(提供了"trust region"的命名和直觉)

觉得有启发的话,欢迎点赞、在看、转发。跟进最新AI前沿,关注我