FSDPVLLMShardingManager
整体概述
verl/workers/sharding_manager/fsdp_vllm.py 文件实现了 FSDPVLLMShardingManager 类,这是一个用于管理 FSDP(FullyShardedDataParallel)训练模型与 vLLM 推理引擎之间参数同步的分片管理器 fsdp_vllm.py:55-61 。该类的核心作用是在强化学习训练过程中,将训练模型的权重高效地同步到推理引擎,实现训练和推理的无缝切换。
逐行/逐段解析
类初始化 (__init__ 方法)
初始化方法接收多个关键参数来配置分片管理器 fsdp_vllm.py:64-75 :
module: FSDP 包装的训练模型inference_engine: vLLM 推理引擎实例model_config和rollout_config: 模型和推理配置full_params: 是否使用完整参数模式device_mesh: 设备网格,用于分布式计算offload_param: 是否启用参数卸载到 CPUload_format: 权重加载格式layered_summon: 是否使用分层召唤优化
初始化过程中会根据 FSDP 版本设置不同的状态字典类型 fsdp_vllm.py:97-106 ,并配置张量并行相关参数 fsdp_vllm.py:108-109 。
上下文管理器入口 (__enter__ 方法)
这是权重同步的核心逻辑,使用上下文管理器模式确保资源的正确管理 fsdp_vllm.py:127 。
LoRA 参数收集
内部定义了 __collect_lora_params 函数来处理 LoRA(Low-Rank Adaptation)参数的收集 fsdp_vllm.py:128-133 。该函数支持两种模式:
- 分层召唤模式:当
layered_summon=True且基础模型已预加载时,使用优化的分层参数召唤 fsdp_vllm.py:138-144 - 完整召唤模式:使用
FSDP.summon_full_params获取完整参数 fsdp_vllm.py:146
权重同步流程
主要的权重同步流程包括:
- 内存清理:在开始前清空 GPU 缓存 fsdp_vllm.py:194
- 参数提取:根据模型类型(普通模型或 PEFT 模型)提取相应参数 fsdp_vllm.py:200-206
- 权重键转换:将 FSDP 格式的权重键转换为 vLLM 兼容格式 fsdp_vllm.py:207
- 引擎唤醒:如果启用了缓存引擎释放功能,先唤醒推理引擎 fsdp_vllm.py:210-214
- 参数更新:调用
update_params方法将参数同步到推理引擎 fsdp_vllm.py:217
上下文管理器退出 (__exit__ 方法)
退出时进行清理工作 fsdp_vllm.py:238 :
- 引擎休眠:如果启用了缓存引擎释放,让推理引擎进入休眠状态释放内存 fsdp_vllm.py:239-240
- 恢复训练模式:将模型设置回训练模式 fsdp_vllm.py:242
- 随机状态恢复:恢复之前保存的随机数生成器状态 fsdp_vllm.py:248-250
数据预处理和后处理
- 预处理 (
preprocess_data):在张量并行大小大于1时,执行全收集操作确保每个rank都有相同的输入数据 fsdp_vllm.py:253-262 - 后处理 (
postprocess_data):将数据按张量并行维度分块,每个rank只保留自己的部分 fsdp_vllm.py:265-270
技术要点
- 上下文管理器模式:使用
__enter__和__exit__方法确保资源的正确获取和释放 - 分布式计算:支持张量并行和数据并行的混合分布式策略
- 内存管理:通过参数卸载、缓存清理等机制优化内存使用
- LoRA 支持:专门处理 LoRA 适配器的参数同步
- 异步处理:支持异步推理引擎的唤醒和休眠机制