DataParallelPPOCritic
整体概述
DataParallelPPOCritic 类实现了基于 FSDP(Fully Sharded Data Parallel)的分布式价值函数训练,用于 PPO 强化学习算法中估计状态价值以计算优势函数。 dp_critic.py:46-47 该类继承自 BasePPOCritic,提供两个核心功能:计算价值估计(compute_values)和更新价值网络(update_critic)。 dp_critic.py:158
类初始化和配置
构造函数设置
def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Optimizer):dp_critic.py:47-50
构造函数接收配置对象、Critic 模型和优化器,并设置关键配置参数:
use_remove_padding:启用序列打包优化,移除填充 token 以提高内存效率 dp_critic.py:51-52ulysses_sequence_parallel_size:Ulysses 序列并行大小,用于处理长序列 dp_critic.py:54device_name:设备类型(CUDA/NPU) dp_critic.py:55
核心方法解析
1. _forward_micro_batch - 微批次前向传播
这是价值计算的核心方法,处理单个微批次的前向传播: dp_critic.py:57
输入处理
response_length = micro_batch["responses"].size(-1)
input_ids = micro_batch["input_ids"]
attention_mask = micro_batch["attention_mask"]
position_ids = micro_batch["position_ids"]dp_critic.py:58-70
多模态输入支持
代码支持视觉-语言模型的多模态输入: dp_critic.py:59-64
序列打包优化路径
当启用 use_remove_padding 时,使用 Flash Attention 的变长序列优化:
- 使用
unpad_input移除填充 token dp_critic.py:75-77 - 重新排列 position_ids 以适配旋转位置编码 dp_critic.py:81-90
- 支持 Ulysses 序列并行处理 dp_critic.py:93-96
模型推理
根据是否使用序列打包,调用不同的模型前向传播路径:
- 打包路径:只传递
input_ids和position_ids,attention_mask设为 None dp_critic.py:99-105 - 标准路径:传递完整的输入 dp_critic.py:124-130
价值提取
支持两种模型架构:
- TRL 的
AutoModelForCausalLMWithValueHead:从output[2]提取价值 dp_critic.py:107-109 - 标准架构:从
output.logits提取价值 dp_critic.py:134-135
最终返回响应部分的价值估计: dp_critic.py:136-137
2. _optimizer_step - 优化器步骤
处理梯度裁剪和优化器更新: dp_critic.py:139-155
支持三种分布式训练后端:
- FSDP:使用内置的
clip_grad_norm_dp_critic.py:142-143 - FSDP2:使用专用的
fsdp2_clip_grad_norm_dp_critic.py:144-145 - 标准 PyTorch:使用
torch.nn.utils.clip_grad_norm_dp_critic.py:146-147
包含梯度异常检测,当梯度不是有限值时跳过更新: dp_critic.py:149-154
3. compute_values - 价值计算
这是对外暴露的价值计算接口,用于 PPO 算法中的优势计算: dp_critic.py:158-187
数据预处理
- 设置模型为评估模式 dp_critic.py:159
- 提取批次大小和动态批处理配置 dp_critic.py:160-161
- 选择必要的数据字段 dp_critic.py:163-166
批次处理
支持两种批处理模式:
- 动态批处理:基于 token 数量的智能分批 dp_critic.py:168-170
- 固定批处理:按固定大小分批 dp_critic.py:171-172
推理和后处理
- 无梯度推理处理每个微批次 dp_critic.py:177-179
- 拼接所有微批次结果 dp_critic.py:180
- 恢复动态批处理的原始顺序 dp_critic.py:182-183
- 应用响应掩码,只保留动作 token 的价值 dp_critic.py:185-186
4. update_critic - 价值网络更新
实现 PPO 算法中的价值函数更新: dp_critic.py:190-256
训练循环结构
遵循 PPO 论文的标准实现:
- 外层循环:PPO epochs dp_critic.py:205
- 中层循环:mini-batches dp_critic.py:206
- 内层循环:micro-batches(梯度累积) dp_critic.py:218
损失计算
使用 core_algos.compute_value_loss 计算价值损失: dp_critic.py:226-233
- 支持价值裁剪(
cliprange_value) - 可配置的损失聚合模式(
loss_agg_mode)
梯度处理
- 动态批处理时按比例缩放损失 dp_critic.py:234-236
- 固定批处理时按梯度累积步数缩放 dp_critic.py:237-238
- 执行反向传播 dp_critic.py:240
指标收集
收集训练指标包括损失值、裁剪比例和预测均值: dp_critic.py:242-248
技术要点
1. 分布式训练集成
- FSDP 支持:与 PyTorch FSDP 和 FSDP2 深度集成,支持参数分片
- Ulysses 序列并行:处理超长序列的并行策略
- 动态批处理:基于 token 数量的智能批处理,提高 GPU 利用率
2. 内存优化
- 序列打包:移除填充 token,减少内存使用和计算量
- 梯度检查点:可选的激活重计算以节省内存
- 混合精度:使用 bfloat16 进行前向传播 dp_critic.py:66
3. 多模态支持
支持视觉-语言模型,能够处理图像和文本的联合输入
在 PPO 训练中的作用
该 Critic 在 PPO 训练流程中的位置: hybrid_flow.rst:190-201
- 生成序列后调用
compute_values计算状态价值 - 与奖励一起用于计算优势函数
- 在策略更新后调用
update_critic改进价值估计
Notes
这个实现展现了现代分布式强化学习系统的复杂性,集成了多种优化技术以处理大规模语言模型的训练。代码设计考虑了内存效率、计算性能和多种硬件配置的兼容性。动态批处理和序列打包等优化对于实际部署中的性能至关重要。