拒绝采样(Reject Sampling)在RLHF中的应用
1. 拒绝采样的核心思想
拒绝采样是一种从复杂分布中生成样本的技术,通过以下步骤实现:
- 生成候选样本:使用一个易于采样的提议分布生成候选样本。
- 评估样本质量:根据目标分布的概率密度函数(或替代评估标准)计算每个样本的接受概率。
- 接受或拒绝样本:以一定概率接受样本,否则拒绝。
在RL场景中,目标是从当前策略(Checkpoint)生成的样本中筛选出高质量数据,用于监督式微调。
2. 在RL中应用拒绝采样的具体流程
步骤 1:生成候选样本
- 使用当前RL策略(Checkpoint)对一批输入(例如用户查询)生成多个候选响应。
- 例如,对于每个输入 $ q $,策略 $ \pi_{\theta} $ 生成 $ N $ 个响应 $ { o_1, o_2, ..., o_N } $。
步骤 2:评估样本质量
- 通过奖励模型(Reward Model)或人工标注对每个候选响应打分,得到奖励值 $ r_i $。
- 奖励值可能基于:
- 响应与期望答案的匹配度(如BLEU、ROUGE)。
- 人类偏好(如安全性、连贯性)。
- 任务相关的性能指标(如代码正确性、数学解题步骤)。
步骤 3:计算接受概率
- 对每个候选响应 $ o_i $,计算接受概率 $ \alpha_i $: $$ \alpha_i = \frac{r_i}{\max(r_1, r_2, ..., r_N)} \quad \text{或} \quad \alpha_i = \text{sigmoid}(r_i - \tau) $$ 其中 $ \tau $ 为阈值,用于控制筛选严格度。
步骤 4:接受或拒绝样本
- 对每个候选响应,以概率 $ \alpha_i $ 接受该样本,否则拒绝。
- 最终保留的样本构成新的SFT数据集 $ \mathcal{D}_{\text{SFT} } $。
步骤 5:监督式微调(SFT)
- 使用 $ \mathcal{D}_{\text{SFT} } $ 对模型进行微调,提升其在特定任务上的表现。
3. 拒绝采样的优势
- 数据质量提升:通过筛选高奖励样本,避免低质量数据污染训练集。
- 稳定性增强:在RL接近收敛时,策略生成的样本趋于稳定,拒绝采样可进一步优化数据分布。
- 效率优化:仅保留有价值样本,减少冗余计算。
4. 实际应用案例
- PPO(Proximal Policy Optimization):在策略更新时,通过重要性采样和裁剪(Clipping)隐式实现拒绝采样。
- GRPO(Group Relative Policy Optimization):使用分组评估奖励,动态调整接受阈值。
- AlphaZero:在蒙特卡洛树搜索(MCTS)中,通过价值网络筛选高价值路径。
5. 数学示例
假设当前策略 $ \pi_{\theta} $ 生成三个候选响应,奖励分别为 $ r_1=0.8, r_2=0.5, r_3=0.9 $:
- 归一化接受概率:$ \alpha_1 = 0.8/0.9 \approx 0.89 $, $ \alpha_2 = 0.5/0.9 \approx 0.56 $, $ \alpha_3 = 1 $。
- 按概率接受:响应3必然被接受,响应1以89%概率接受,响应2以56%概率接受。
- 最终可能保留响应1和3,加入SFT数据集。
6. 注意事项
- 奖励模型的偏差:若奖励模型与真实目标不一致,可能导致筛选偏差。
- 多样性损失:过度筛选可能减少数据多样性,需平衡质量与覆盖度。
- 计算开销:生成和评估大量候选样本会增加训练成本。