Skip to content

DataParallelPPOCritic

整体概述

DataParallelPPOCritic 类实现了基于 FSDP(Fully Sharded Data Parallel)的分布式价值函数训练,用于 PPO 强化学习算法中估计状态价值以计算优势函数。 dp_critic.py:46-47 该类继承自 BasePPOCritic,提供两个核心功能:计算价值估计(compute_values)和更新价值网络(update_critic)。 dp_critic.py:158

类初始化和配置

构造函数设置

python
def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer):

dp_critic.py:47-50

构造函数接收配置对象、Critic 模型和优化器,并设置关键配置参数:

  • use_remove_padding:启用序列打包优化,移除填充 token 以提高内存效率 dp_critic.py:51-52
  • ulysses_sequence_parallel_size:Ulysses 序列并行大小,用于处理长序列 dp_critic.py:54
  • device_name:设备类型(CUDA/NPU) dp_critic.py:55

核心方法解析

1. _forward_micro_batch - 微批次前向传播

这是价值计算的核心方法,处理单个微批次的前向传播: dp_critic.py:57

输入处理

python
response_length = micro_batch["responses"].size(-1)
input_ids = micro_batch["input_ids"]
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]

dp_critic.py:58-70

多模态输入支持

代码支持视觉-语言模型的多模态输入: dp_critic.py:59-64

序列打包优化路径

当启用 use_remove_padding 时,使用 Flash Attention 的变长序列优化:

  • 使用 unpad_input 移除填充 token dp_critic.py:75-77
  • 重新排列 position_ids 以适配旋转位置编码 dp_critic.py:81-90
  • 支持 Ulysses 序列并行处理 dp_critic.py:93-96

模型推理

根据是否使用序列打包,调用不同的模型前向传播路径:

  • 打包路径:只传递 input_idsposition_idsattention_mask 设为 None dp_critic.py:99-105
  • 标准路径:传递完整的输入 dp_critic.py:124-130

价值提取

支持两种模型架构:

  • TRL 的 AutoModelForCausalLMWithValueHead:从 output[2] 提取价值 dp_critic.py:107-109
  • 标准架构:从 output.logits 提取价值 dp_critic.py:134-135

最终返回响应部分的价值估计: dp_critic.py:136-137

2. _optimizer_step - 优化器步骤

处理梯度裁剪和优化器更新: dp_critic.py:139-155

支持三种分布式训练后端:

  • FSDP:使用内置的 clip_grad_norm_ dp_critic.py:142-143
  • FSDP2:使用专用的 fsdp2_clip_grad_norm_ dp_critic.py:144-145
  • 标准 PyTorch:使用 torch.nn.utils.clip_grad_norm_ dp_critic.py:146-147

包含梯度异常检测,当梯度不是有限值时跳过更新: dp_critic.py:149-154

3. compute_values - 价值计算

这是对外暴露的价值计算接口,用于 PPO 算法中的优势计算: dp_critic.py:158-187

数据预处理

  • 设置模型为评估模式 dp_critic.py:159
  • 提取批次大小和动态批处理配置 dp_critic.py:160-161
  • 选择必要的数据字段 dp_critic.py:163-166

批次处理

支持两种批处理模式:

  • 动态批处理:基于 token 数量的智能分批 dp_critic.py:168-170
  • 固定批处理:按固定大小分批 dp_critic.py:171-172

推理和后处理

  • 无梯度推理处理每个微批次 dp_critic.py:177-179
  • 拼接所有微批次结果 dp_critic.py:180
  • 恢复动态批处理的原始顺序 dp_critic.py:182-183
  • 应用响应掩码,只保留动作 token 的价值 dp_critic.py:185-186

4. update_critic - 价值网络更新

实现 PPO 算法中的价值函数更新: dp_critic.py:190-256

训练循环结构

遵循 PPO 论文的标准实现:

  • 外层循环:PPO epochs dp_critic.py:205
  • 中层循环:mini-batches dp_critic.py:206
  • 内层循环:micro-batches(梯度累积) dp_critic.py:218

损失计算

使用 core_algos.compute_value_loss 计算价值损失: dp_critic.py:226-233

  • 支持价值裁剪(cliprange_value
  • 可配置的损失聚合模式(loss_agg_mode

梯度处理

  • 动态批处理时按比例缩放损失 dp_critic.py:234-236
  • 固定批处理时按梯度累积步数缩放 dp_critic.py:237-238
  • 执行反向传播 dp_critic.py:240

指标收集

收集训练指标包括损失值、裁剪比例和预测均值: dp_critic.py:242-248

技术要点

1. 分布式训练集成

  • FSDP 支持:与 PyTorch FSDP 和 FSDP2 深度集成,支持参数分片
  • Ulysses 序列并行:处理超长序列的并行策略
  • 动态批处理:基于 token 数量的智能批处理,提高 GPU 利用率

2. 内存优化

  • 序列打包:移除填充 token,减少内存使用和计算量
  • 梯度检查点:可选的激活重计算以节省内存
  • 混合精度:使用 bfloat16 进行前向传播 dp_critic.py:66

3. 多模态支持

支持视觉-语言模型,能够处理图像和文本的联合输入

在 PPO 训练中的作用

该 Critic 在 PPO 训练流程中的位置: hybrid_flow.rst:190-201

  1. 生成序列后调用 compute_values 计算状态价值
  2. 与奖励一起用于计算优势函数
  3. 在策略更新后调用 update_critic 改进价值估计

Notes

这个实现展现了现代分布式强化学习系统的复杂性,集成了多种优化技术以处理大规模语言模型的训练。代码设计考虑了内存效率、计算性能和多种硬件配置的兼容性。动态批处理和序列打包等优化对于实际部署中的性能至关重要。

Maintained by Robin