Skip to content

FSDPVLLMShardingManager

整体概述

verl/workers/sharding_manager/fsdp_vllm.py 文件实现了 FSDPVLLMShardingManager 类,这是一个用于管理 FSDP(FullyShardedDataParallel)训练模型与 vLLM 推理引擎之间参数同步的分片管理器 fsdp_vllm.py:55-61 。该类的核心作用是在强化学习训练过程中,将训练模型的权重高效地同步到推理引擎,实现训练和推理的无缝切换。

逐行/逐段解析

类初始化 (__init__ 方法)

初始化方法接收多个关键参数来配置分片管理器 fsdp_vllm.py:64-75 :

  • module: FSDP 包装的训练模型
  • inference_engine: vLLM 推理引擎实例
  • model_configrollout_config: 模型和推理配置
  • full_params: 是否使用完整参数模式
  • device_mesh: 设备网格,用于分布式计算
  • offload_param: 是否启用参数卸载到 CPU
  • load_format: 权重加载格式
  • layered_summon: 是否使用分层召唤优化

初始化过程中会根据 FSDP 版本设置不同的状态字典类型 fsdp_vllm.py:97-106 ,并配置张量并行相关参数 fsdp_vllm.py:108-109 。

上下文管理器入口 (__enter__ 方法)

这是权重同步的核心逻辑,使用上下文管理器模式确保资源的正确管理 fsdp_vllm.py:127 。

LoRA 参数收集

内部定义了 __collect_lora_params 函数来处理 LoRA(Low-Rank Adaptation)参数的收集 fsdp_vllm.py:128-133 。该函数支持两种模式:

  1. 分层召唤模式:当 layered_summon=True 且基础模型已预加载时,使用优化的分层参数召唤 fsdp_vllm.py:138-144
  2. 完整召唤模式:使用 FSDP.summon_full_params 获取完整参数 fsdp_vllm.py:146

权重同步流程

主要的权重同步流程包括:

  1. 内存清理:在开始前清空 GPU 缓存 fsdp_vllm.py:194
  2. 参数提取:根据模型类型(普通模型或 PEFT 模型)提取相应参数 fsdp_vllm.py:200-206
  3. 权重键转换:将 FSDP 格式的权重键转换为 vLLM 兼容格式 fsdp_vllm.py:207
  4. 引擎唤醒:如果启用了缓存引擎释放功能,先唤醒推理引擎 fsdp_vllm.py:210-214
  5. 参数更新:调用 update_params 方法将参数同步到推理引擎 fsdp_vllm.py:217

上下文管理器退出 (__exit__ 方法)

退出时进行清理工作 fsdp_vllm.py:238 :

  1. 引擎休眠:如果启用了缓存引擎释放,让推理引擎进入休眠状态释放内存 fsdp_vllm.py:239-240
  2. 恢复训练模式:将模型设置回训练模式 fsdp_vllm.py:242
  3. 随机状态恢复:恢复之前保存的随机数生成器状态 fsdp_vllm.py:248-250

数据预处理和后处理

  • 预处理 (preprocess_data):在张量并行大小大于1时,执行全收集操作确保每个rank都有相同的输入数据 fsdp_vllm.py:253-262
  • 后处理 (postprocess_data):将数据按张量并行维度分块,每个rank只保留自己的部分 fsdp_vllm.py:265-270

技术要点

  1. 上下文管理器模式:使用 __enter____exit__ 方法确保资源的正确获取和释放
  2. 分布式计算:支持张量并行和数据并行的混合分布式策略
  3. 内存管理:通过参数卸载、缓存清理等机制优化内存使用
  4. LoRA 支持:专门处理 LoRA 适配器的参数同步
  5. 异步处理:支持异步推理引擎的唤醒和休眠机制

Maintained by Robin