Skip to content

HFRollout

整体概述

HFRollout 类是基于 Hugging Face Transformers 库的序列生成引擎,用于在强化学习训练过程中生成模型响应。 hf_rollout.py:39-44 它继承自 BaseRollout 基类,主要负责接收提示(prompts)并生成相应的序列输出。 hf_rollout.py:34-36

逐行/逐段解析

类初始化

python
def __init__(self, module: nn.Module, config):
    super().__init__()
    self.config = config
    self.module = module

hf_rollout.py:40-43

初始化方法接收一个 PyTorch 模型模块和配置对象,用于后续的序列生成。

主要生成方法

generate_sequences 方法是核心接口,负责处理批量提示生成: hf_rollout.py:45-51

  • 首先获取批次大小并计算需要分割的块数
  • 使用微批次处理来避免内存溢出
  • 将输入数据分块处理后重新合并结果

微批次生成逻辑

_generate_minibatch 方法实现了具体的生成逻辑: hf_rollout.py:53-92

采样参数处理

  • 支持从输入元信息中覆盖默认配置参数
  • 处理 do_sampletemperaturetop_ptop_k 等采样参数
  • 区分训练、验证和贪婪解码三种模式

生成配置构建: hf_rollout.py:64-92

  • 贪婪解码:do_sample=False,使用单束搜索
  • 验证模式:使用 val_kwargs 中的参数
  • 训练模式:使用配置中的采样参数,支持多次采样(num_return_sequences=self.config.n

模型推理执行

核心推理部分使用了多项优化技术: hf_rollout.py:103-122

FSDP 兼容性处理

python
if isinstance(self.module, FSDP):
    param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)

hf_rollout.py:106-108

这里处理了 FSDP(Fully Sharded Data Parallel)模式下的参数聚合,确保生成时能访问完整模型参数。

自动混合精度: 使用 torch.autocast 进行 bfloat16 推理以提升性能: hf_rollout.py:109-122

序列后处理

生成完成后需要进行多项后处理: hf_rollout.py:124-153

长度填充: Hugging Face 的 generate 方法在所有序列达到 EOS 时会停止,因此需要手动填充到目标长度: hf_rollout.py:128-137

多样本处理: 当 num_return_sequences > 1 时,需要相应地重复输入张量: hf_rollout.py:139-143

位置编码更新: 为生成的响应部分计算正确的位置编码: hf_rollout.py:149-153

技术要点

  1. 批处理优化:通过微批次处理避免大批次导致的内存问题
  2. FSDP 集成:正确处理分布式训练中的参数聚合
  3. 混合精度推理:使用 bfloat16 提升推理效率
  4. 灵活采样策略:支持贪婪、采样和验证三种不同模式
  5. 张量操作:大量使用 PyTorch 张量操作进行高效的序列处理

更新 position_ids和attention_mask

  1. 生成序列的位置编码更新:
python
# 获取生成的回复长度
response_length = response.size(1)

# 创建增量位置编码
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
# 增加batch维度
delta_position_id = delta_position_id.unsqueeze(0).repeat(generated_batch_size, 1)

# 基于最后一个prompt位置继续计算response的位置编码
response_position_ids = position_ids[:, -1:] + delta_position_id
# 拼接prompt和response的位置编码
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)

这段代码的关键点:

  1. torch.arange(1, response_length + 1) 创建从1开始的连续位置编码
  2. unsqueeze(0) 增加batch维度
  3. repeat(generated_batch_size, 1) 复制到batch size大小
  4. 通过 position_ids[:, -1:] 获取prompt的最后位置,作为基准点
  5. 将delta_position_id加到基准点上,得到连续的位置编码
python
# 为生成的回复创建注意力掩码
response_attention_mask = get_response_mask(
    response_id=response,        # 生成的回复序列
    eos_token=eos_token_id,     # 结束符标记
    dtype=attention_mask.dtype   # 保持数据类型一致
)
# 拼接原始和新生成的注意力掩码
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)

函数会创建适当的注意力掩码,用于:

  • 标记有效token(1)和填充token(0)
  • 在EOS token后的位置设置为0
  • 确保模型不会关注到填充位置

Maintained by Robin