HFRollout
整体概述
HFRollout 类是基于 Hugging Face Transformers 库的序列生成引擎,用于在强化学习训练过程中生成模型响应。 hf_rollout.py:39-44 它继承自 BaseRollout 基类,主要负责接收提示(prompts)并生成相应的序列输出。 hf_rollout.py:34-36
逐行/逐段解析
类初始化
def __init__(self, module: nn.Module, config):
super().__init__()
self.config = config
self.module = modulehf_rollout.py:40-43
初始化方法接收一个 PyTorch 模型模块和配置对象,用于后续的序列生成。
主要生成方法
generate_sequences 方法是核心接口,负责处理批量提示生成: hf_rollout.py:45-51
- 首先获取批次大小并计算需要分割的块数
- 使用微批次处理来避免内存溢出
- 将输入数据分块处理后重新合并结果
微批次生成逻辑
_generate_minibatch 方法实现了具体的生成逻辑: hf_rollout.py:53-92
采样参数处理:
- 支持从输入元信息中覆盖默认配置参数
- 处理
do_sample、temperature、top_p、top_k等采样参数 - 区分训练、验证和贪婪解码三种模式
生成配置构建: hf_rollout.py:64-92
- 贪婪解码:
do_sample=False,使用单束搜索 - 验证模式:使用
val_kwargs中的参数 - 训练模式:使用配置中的采样参数,支持多次采样(
num_return_sequences=self.config.n)
模型推理执行
核心推理部分使用了多项优化技术: hf_rollout.py:103-122
FSDP 兼容性处理:
if isinstance(self.module, FSDP):
param_ctx = FSDP.summon_full_params(self.module, writeback=False, recurse=False)hf_rollout.py:106-108
这里处理了 FSDP(Fully Sharded Data Parallel)模式下的参数聚合,确保生成时能访问完整模型参数。
自动混合精度: 使用 torch.autocast 进行 bfloat16 推理以提升性能: hf_rollout.py:109-122
序列后处理
生成完成后需要进行多项后处理: hf_rollout.py:124-153
长度填充: Hugging Face 的 generate 方法在所有序列达到 EOS 时会停止,因此需要手动填充到目标长度: hf_rollout.py:128-137
多样本处理: 当 num_return_sequences > 1 时,需要相应地重复输入张量: hf_rollout.py:139-143
位置编码更新: 为生成的响应部分计算正确的位置编码: hf_rollout.py:149-153
技术要点
- 批处理优化:通过微批次处理避免大批次导致的内存问题
- FSDP 集成:正确处理分布式训练中的参数聚合
- 混合精度推理:使用 bfloat16 提升推理效率
- 灵活采样策略:支持贪婪、采样和验证三种不同模式
- 张量操作:大量使用 PyTorch 张量操作进行高效的序列处理
更新 position_ids和attention_mask
- 生成序列的位置编码更新:
# 获取生成的回复长度
response_length = response.size(1)
# 创建增量位置编码
delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
# 增加batch维度
delta_position_id = delta_position_id.unsqueeze(0).repeat(generated_batch_size, 1)
# 基于最后一个prompt位置继续计算response的位置编码
response_position_ids = position_ids[:, -1:] + delta_position_id
# 拼接prompt和response的位置编码
position_ids = torch.cat([position_ids, response_position_ids], dim=-1)这段代码的关键点:
torch.arange(1, response_length + 1)创建从1开始的连续位置编码unsqueeze(0)增加batch维度repeat(generated_batch_size, 1)复制到batch size大小- 通过
position_ids[:, -1:]获取prompt的最后位置,作为基准点 - 将delta_position_id加到基准点上,得到连续的位置编码
# 为生成的回复创建注意力掩码
response_attention_mask = get_response_mask(
response_id=response, # 生成的回复序列
eos_token=eos_token_id, # 结束符标记
dtype=attention_mask.dtype # 保持数据类型一致
)
# 拼接原始和新生成的注意力掩码
attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1)函数会创建适当的注意力掩码,用于:
- 标记有效token(1)和填充token(0)
- 在EOS token后的位置设置为0
- 确保模型不会关注到填充位置