Skip to content

MegatronPPOCritic

整体概述

MegatronPPOCritic 类实现了 PPO 算法中的价值函数(Critic)训练,使用 Megatron-LM 作为分布式训练后端。 megatron_critic.py:46 该类主要负责:1)计算状态价值估计,2)更新价值函数参数,3)支持多维并行训练(TP、PP、DP等)。

逐行/逐段解析

类初始化部分

python
def __init__(self, config, model_config, hf_config, tf_config, critic_module, critic_optimizer, critic_optimizer_config):

megatron_critic.py:47-56 初始化方法接收配置参数、模型组件和优化器。 megatron_critic.py:57-65 关键组件包括:

  • critic_module: Megatron 模型模块列表
  • critic_optimizer: 分布式优化器
  • tf_config: Megatron transformer 配置 megatron_critic.py:68-80 优化器步骤参数配置,包含序列并行、梯度累积等分布式训练设置。

配置验证

megatron_critic.py:82-87 配置验证确保 Megatron 后端不支持的功能(如 Ulysses 序列并行)被正确禁用,并要求在启用数据洗牌时设置随机种子。

价值计算方法

megatron_critic.py:90-142 compute_values 方法是核心的前向推理函数:

  1. 数据预处理: megatron_critic.py:91-100 将数据移至 GPU,提取批次大小和动态批处理配置
  2. 前向传播: megatron_critic.py:103-110 调用 forward_backward_batch 进行仅前向计算
  3. 结果处理: megatron_critic.py:111-130 在管道并行的最后阶段提取价值预测,应用响应掩码
  4. 跨进程同步: megatron_critic.py:133-137 通过广播确保所有管道并行 rank 获得相同结果

前向后向批处理

megatron_critic.py:154-289 forward_backward_batch 是核心计算函数:

  1. 数据广播: megatron_critic.py:167-171 在管道并行组间同步数据
  2. 微批次划分: megatron_critic.py:176-200 支持动态批处理和固定批处理两种模式
  3. 损失函数定义: megatron_critic.py:204-238 内嵌损失函数计算价值函数损失和统计信息
  4. 前向步骤: megatron_critic.py:240-258 定义单个微批次的前向计算逻辑

Critic 更新方法

megatron_critic.py:292-331 update_critic 方法执行价值函数训练:

  1. 数据迭代: megatron_critic.py:295-313 遍历小批次数据,执行前向后向传播
  2. 优化器更新: megatron_critic.py:315-327 执行梯度更新并收集训练指标

技术要点

分布式训练架构

该实现使用 Megatron-LM 的多维并行策略: megatron_critic.py:26

  • 管道并行:通过 mpu.is_pipeline_last_stage() 处理不同管道阶段
  • 张量并行:自动处理张量分片
  • 数据并行:通过分布式优化器实现

动态批处理优化

megatron_critic.py:176-192 支持基于 token 数量的动态批处理,通过 rearrange_micro_batches 函数优化内存使用和计算效率。

内存管理

megatron_critic.py:140 在每次计算后清空 GPU 缓存,防止内存泄漏。

潜在改进

  1. 错误处理:当前代码对分布式通信失败的处理较少,可以增加重试机制
  2. 性能监控:可以添加更详细的性能指标收集,如通信开销统计
  3. 配置验证:可以在初始化时进行更全面的配置兼容性检查

Notes

该实现是 VERL 框架中 Megatron 后端的核心组件之一,与 MegatronPPOActor 配合使用。 megatron_workers.py:889-897 在 CriticWorker 中被实例化和调用。 megatron_workers.py:921-934 该架构支持大规模模型的高效训练,是 VERL 支持 5D 并行训练的关键实现。

Maintained by Robin