截断重要性采样(TIS)
1. 简而言之
大语言模型强化学习微调不稳定的一个关键来源是训练-推理不匹配(training-inference mismatch)。为了最大化训练效率,现代强化学习训练框架(如 VeRL)通常会采用两种不同的计算引擎:一种是为快速推理(rollout)高度优化的引擎(如 vLLM),另一种是为梯度计算设计的训练引擎(如 FSDP)。尽管这两种引擎在数学原理上是等价的,但由于浮点数精度误差和硬件层面的具体优化差异,它们会产生数值上不完全相同的输出。近期的一系列研究已经指出,这种看似微不足道的不匹配,会在优化过程中引入显著的问题,是导致训练不稳定的核心因素之一。
2. 不匹配问题
为简化起见,我们以 REINFORCE 算法为例,该算法通过以下方式更新策略
实践中,轨迹生成成本高昂,现代强化学习框架(例如 VeRL)通常采用高度优化的推理引擎(例如 vLLM、SGLang)来提升吞吐量,同时使用独立后端(例如 FSDP、Megatron)进行模型训练。这种混合设计使得更新过程变为:
此处我们使用
实验中观察到意外的 rollout 训练失配现象, 尽管
2.1 优化驱动的恶性循环
你可能认为训练-推理失配是硬件与软件栈的静态特性。然而,通过实验证明这种失配与训练动态及模型状态相互耦合。
我们推测这是由于以下两阶段级联故障所致:
- 阶段一:数值敏感度增强。强化学习优化器将模型权重推至
bfloat16数据类型相对精度较低的数值范围(例如极小或极大值)。 - 阶段二:内核驱动的误差放大。这些初始微小的
bfloat16量化误差随后被输入 vLLM 和 FSDP 的不同内核实现。差异化的计算顺序充当非线性放大器,使初始微小偏差最终雪崩式扩大为最终逻辑值的巨大差异。
这形成了一个恶性反馈循环:失配导致有偏且含噪的梯度,可能将参数进一步推向数值敏感区域,进而加剧下一轮迭代的失配程度,直至系统彻底崩溃。
3. 缓解训练-推理失配的尝试
接下来我们将列举为缓解训练-推理失配所尝试的方法。其中部分方法有所助益,另一些则收效甚微。
3.1 使用 FP32 lm_head
受 Minimax-M1 技术报告及博客文章 《Your Efficient RL Framework Secretly Brings You Off-Policy RL Training》启发,我们通过修改 vLLM 将 lm_head 转换为 FP32 精度。但在实验中,修补后失配问题依然存在,模型崩溃仍不可避免。
3.2 禁用分块预填充
我们还尝试通过禁用分块预填充来验证是否能解决崩溃问题。然而,实验结果显示该方法并未解决崩溃问题。
3.3 启用 enforce_eager 与 free_cache_engine
VeRL 官方提供的 DAPO 方案指出,启用 CUDA 图(enforce_eager=False)可能导致模型性能下降。为探究这是否会影响训练-推理失配问题,我们通过消融实验研究了 vLLM 引擎超参数 enforce_eager 的影响,并同步考量另一超参数 free_cache_engine。实验结果显示,调整 enforce_eager 与 free_cache_engine 的取值对训练-推理失配现象及测试性能均无显著影响。
4. 接纳失配—实施算法级修复
4.1 重要性采样
当直接对目标分布下的期望值进行蒙特卡洛估计较为困难时,重要性采样允许我们从替代分布中进行抽样。在上面描述的强化学习场景中,目标分布是
4.2 截断重要性采样 TIS
不同于在系统层面缓解分布失配,我们提出通过调整模型更新机制使其感知这种失配。简单的方法是采用重要性采样校正。具体而言,我们通过在当前梯度计算中添加重要性比率来处理
变为
尽管关于如何设计稳定有效的重采样方法已有广泛研究,但在实践中我们发现通常采用经典技术——截断重要性采样便已足够:
其中
4.2.1 扩展至其他算法
将上述分析扩展到其他算法是直截了当的,因为我们可以将梯度计算的具体形式从 REINFORCE 的
PPO 的策略梯度
为提升吞吐量,混合强化学习系统采用 vLLM 引擎进行推演生成——从
与上述分析类似,
其中
4.3 与两种 TIS 变体的比较
我们还总结了两种用于缓解分布差距的替代方案。
- PPO 重要性采样 (PPO-IS)
注意: Colossal 框架使用此实现。
基础重要性采样 (vanilla-IS)
注意: Nemo-RL 使用此实现。
为评估 TIS 的有效性并理解其设计选择的影响,我们进行了对比 TIS 与上述两种变体的实验。TIS 始终优于两种变体,尤其在差异显著的情况下(如 FP8/INT8 量化场景)表现更为突出。
4.3.1 vanilla-IS 对比 TIS
关于基础重要性采样(vanilla-IS),其不稳定性主要源于当
4.3.2 PPO-IS 对比 TIS
采用 PPO-IS 方法后,梯度实际上仍会偏离 PPO 的同策略版本。换言之,尽管该方法可能仍在朝着无偏目标进行优化,但相比标准 PPO 算法其效率可能有所不足。
此外需要说明的是, PPO 信任域技术的提出旨在将轨迹采样
4.4 TIS 工作机制的直观解释
虽然 TIS 的确切机制仍是待解之谜,我们对其缓解分布差异的原理提供高层级阐释。
特别需要注意的是,忽略具有
与此同时, TIS 坚持对
5. 重要性采样的进一步讨论
训练-推理失配将原本同策略的强化学习问题转化为异策略问题,其中用于生成轨迹的策略(行为策略,
受 [Yao 等, 2025] 首次揭示这一隐式异策略问题的研究启发,我们分析了两种主要的 IS 形式:理论完备的Sequence-Level IS 与常见但存在缺陷的Token-Level IS 近似——后者也是该文献中探讨的启发式方法。
5.1 Sequence-Level 重要性采样
正确且无偏的策略梯度估计器在整个生成序列(轨迹)上应用单一重要性比率
让我们逐步推导Sequence-Level 重要性采样估计器
- 目标是在目标 FSDP 策略下最大化期望奖励:
- 因此真实策略梯度为:
- 由于我们只能从 vLLM 策略中采样,故使用重要性采样来改变期望的分布:
该估计器在数学上等价于标准优势函数形式的策略梯度。关键在于证明重要性采样比率能精确修正期望值,揭示底层真实的同策略梯度,进而可对其进行优化。
此推导最终得到策略梯度的优势函数形式:
此处
该估计器是无偏的,这意味着
这种方法的关键方面包括:
正确的分布校正:该方法通过计算整个序列的单一比率而不是每个 token 的比率来正确地应用重要性采样。这一点至关重要,因为它保持了行为策略(
π_vllm)和目标策略(π_fsdp)之间正确的概率关系。无偏梯度估计:通过使用完整的序列概率比:
估计器保持无偏,这意味着
是精确成立的。 状态访问度量的考虑:推导过程通过项
正确考虑了状态访问分布的差异,该项表示在目标策略下访问各个状态的频率。
5.1.1 序列级别 IS 的方差挑战
虽然在理论上是合理的,但序列级别重要性采样在实践中可能会遇到高方差问题:
- 当策略差异很大时,序列级别的比率可能变得极大或极小
- 这会导致不稳定的梯度估计,可能损害训练收敛性
- 在长序列中这个问题更加严重,因为每一步的小差异会以乘法方式累积
5.1.2 截断重要性采样(TIS)解决方案
为了在保持理论正确性的同时解决方差问题,截断重要性采样限制了极端比率的影响:
其中
这种截断引入了一些偏差,但显著降低了方差,通常在实践中带来更稳定的训练效果。
5.2 Token-Level 重要性采样
一种常见启发式方法,通常受到 PPO 等算法的启发并在 (Yao 等人, 2025) 中使用,采用逐词元重要性比率。虽然这通常比 Sequence-Level 比率具有更低的方差,但它是一种有偏估计器,对于自回归模型在理论上并不严谨。
让我们推导Token-Level 重要性采样梯度估计器
该公式通过错误地在时间步求和和内部应用重要性采样比率开始:即
被定义为 我们可以将此轨迹期望重写为在 vLLM 策略下访问状态的期望。
注:此处
表示由 采样的完整轨迹所得的经验回报,作为状态-动作价值函数 的蒙特卡洛估计值。通过引入基线函数并改变动作期望的计算方式,最终得到如下形式:
最终表达式清晰地揭示了 Token-Level 重要性采样的梯度偏差。
5.3 Token-Level 重要性采样的偏差来源
将
5.3.1 误差源 1:状态访问分布失配 🌍
有效的离策略修正必须考虑两种分布偏移:动作概率分布与状态访问概率分布。词元级方法仅修正了前者。
- 真实梯度(
):期望计算基于正确目标 FSDP 分布下的状态访问, 。 - 缺陷梯度(
):期望计算基于错误行为 vLLM 分布下的状态访问, 。
该方法隐含假设状态访问比率为 1,即
5.3.2 误差源 2:失配奖励信号 🎯
第二个关键错误在于,词元级梯度使用错误策略的奖励信号来加权更新。
- 真实梯度(
):该更新通过目标全分片数据并行策略的优势函数 进行缩放,该函数代表在该策略下的预期未来奖励。 - 有缺陷的梯度(
):该更新由行为 vLLM 策略的优势函数进行缩放, 。
目标策略的梯度正在被属于行为策略的奖励信号所缩放。由于状态分布和奖励信号存在根本性不匹配, Token-Level 梯度实际上是一个有偏且理论不稳健的估计量。
🔧 这些理论表明,尽管 Token-Level 方法可能具有较低的方差,但梯度偏差仍然存在,可能导致训练不稳定——这一预测在我们的实验中得到了验证。我们还针对 Token 级和序列级方法提出了详细的偏差与方差分析(第一部分和第二部分)。
6. 掩码重要性采样(MIS)
为改进 TIS,我们提出掩码重要性采样(MIS),该方法对重要性采样比率超过阈值
6.1 Sequence-Level MIS
在 Sequence-Level MIS 中,我们基于整个序列的重要性比率为整个序列应用掩码。具体而言,对于一个由采样策略
当
这种方式相比于 TIS 更加严格,因为它完全排除了那些可能引入巨大方差的样本,而不是仅仅截断重要性比率。这有助于进一步稳定训练,特别是在策略差异较大的情况下。
6.2 Token-Level MIS
在 Token-Level MIS 中,我们为每个 token 单独计算重要性比率,并基于该比率决定是否对该 token 的贡献进行掩码。对于序列中的第
当
与 Token-Level TIS 相比, Token-Level MIS 更加严格,因为它完全排除了那些可能引入不稳定性的 token 贡献,而不是仅仅截断比率值。
6.3 MIS 与 TIS 的比较
方差控制:
- TIS 通过截断操作限制了重要性比率的最大值,但仍然保留了所有样本的贡献
- MIS 通过完全移除高比率样本/Token,从根本上消除了这些可能引入巨大方差的贡献
偏差-方差权衡:
- TIS 引入了一些偏差(通过截断),但保持了较低的方差
- MIS 可能引入更大的偏差(通过完全排除样本),但能够更有效地控制方差
适用场景:
- 当策略差异相对较小且主要由少数极端比率主导时, MIS 可能更有效
- 当策略差异较为均匀分布时, TIS 可能提供更好的偏差-方差平衡
6.4 在 PPO 中的应用
将 MIS 扩展到 PPO 算法中,我们可以得到相应的表达式:
对于 Sequence-Level MIS-PPO:
对于 Token-Level MIS-PPO:
其中
通过这种方式, MIS 为处理训练-推理不匹配问题提供了另一种有效的算法级解决方案,能够与 TIS 形成互补,在不同场景下提供更好的稳定性和性能。