Skip to content

拒绝采样(Reject Sampling)在RLHF中的应用

1. 拒绝采样的核心思想

拒绝采样是一种从复杂分布中生成样本的技术,通过以下步骤实现:

  1. 生成候选样本:使用一个易于采样的提议分布生成候选样本。
  2. 评估样本质量:根据目标分布的概率密度函数(或替代评估标准)计算每个样本的接受概率
  3. 接受或拒绝样本:以一定概率接受样本,否则拒绝。

在RL场景中,目标是从当前策略(Checkpoint)生成的样本中筛选出高质量数据,用于监督式微调。

2. 在RL中应用拒绝采样的具体流程

步骤 1:生成候选样本

  • 使用当前RL策略(Checkpoint)对一批输入(例如用户查询)生成多个候选响应。
  • 例如,对于每个输入 $ q $,策略 $ \pi_{\theta} $ 生成 $ N $ 个响应 $ { o_1, o_2, ..., o_N } $。

步骤 2:评估样本质量

  • 通过奖励模型(Reward Model)人工标注对每个候选响应打分,得到奖励值 $ r_i $。
  • 奖励值可能基于:
    • 响应与期望答案的匹配度(如BLEU、ROUGE)。
    • 人类偏好(如安全性、连贯性)。
    • 任务相关的性能指标(如代码正确性、数学解题步骤)。

步骤 3:计算接受概率

  • 对每个候选响应 $ o_i $,计算接受概率 $ \alpha_i $: $$ \alpha_i = \frac{r_i}{\max(r_1, r_2, ..., r_N)} \quad \text{或} \quad \alpha_i = \text{sigmoid}(r_i - \tau) $$ 其中 $ \tau $ 为阈值,用于控制筛选严格度。

步骤 4:接受或拒绝样本

  • 对每个候选响应,以概率 $ \alpha_i $ 接受该样本,否则拒绝。
  • 最终保留的样本构成新的SFT数据集 $ \mathcal{D}_{\text{SFT} } $。

步骤 5:监督式微调(SFT)

  • 使用 $ \mathcal{D}_{\text{SFT} } $ 对模型进行微调,提升其在特定任务上的表现。

3. 拒绝采样的优势

  1. 数据质量提升:通过筛选高奖励样本,避免低质量数据污染训练集。
  2. 稳定性增强:在RL接近收敛时,策略生成的样本趋于稳定,拒绝采样可进一步优化数据分布。
  3. 效率优化:仅保留有价值样本,减少冗余计算。

4. 实际应用案例

  • PPO(Proximal Policy Optimization):在策略更新时,通过重要性采样和裁剪(Clipping)隐式实现拒绝采样。
  • GRPO(Group Relative Policy Optimization):使用分组评估奖励,动态调整接受阈值。
  • AlphaZero:在蒙特卡洛树搜索(MCTS)中,通过价值网络筛选高价值路径。

5. 数学示例

假设当前策略 $ \pi_{\theta} $ 生成三个候选响应,奖励分别为 $ r_1=0.8, r_2=0.5, r_3=0.9 $:

  1. 归一化接受概率:$ \alpha_1 = 0.8/0.9 \approx 0.89 $, $ \alpha_2 = 0.5/0.9 \approx 0.56 $, $ \alpha_3 = 1 $。
  2. 按概率接受:响应3必然被接受,响应1以89%概率接受,响应2以56%概率接受。
  3. 最终可能保留响应1和3,加入SFT数据集。

6. 注意事项

  • 奖励模型的偏差:若奖励模型与真实目标不一致,可能导致筛选偏差。
  • 多样性损失:过度筛选可能减少数据多样性,需平衡质量与覆盖度。
  • 计算开销:生成和评估大量候选样本会增加训练成本。

Maintained by Robin