Skip to content

PyTorch FSDP 后端实现

本框架通过为 Actor (actor)、Critic (critic)、Reference (reference)、Rollout (rollout)和 Reward (Reward Model)实现对应的 Worker (worker),全面支持 PyTorch 的 Fully Sharded Data Parallel (FSDP)后端。此外,我们还在 fsdp_vllm.py 中实现了 FSDPVLLMShardingManager,用于在 FSDP 与 vLLM 之间进行权重重分片(resharding),以实现高效的模型状态转换。

1. 优势与适用场景

1.1 优势

  • 模型兼容性强,易于扩展 用户仅需实现相应的 dtensor_weight_loader,即可完成 FSDP 与 vLLM 之间的权重同步。对于支持 Hugging Face (HF)格式的模型,用户可直接使用 hf_weight_loader,无需额外修改代码,即可兼容任何同时被 HF 和 vLLM 支持的模型。

  • 计算流程清晰,便于管理 FSDP 后端将各模型的前向传播与反向传播逻辑组织得更为清晰,有利于复杂训练流程的开发与调试。

1.2 劣势

  • 大规模模型扩展性有限 在面对超大规模模型(如 Llama 70B 或 405B)时, FSDP 的内存和通信开销可能导致扩展性不足。

  • 重分片开销较高 在 actor 与 rollout 模型之间进行权重重分片时,其通信与转换开销可能高于 Megatron-LM 后端。

鉴于其简洁性和开发友好性,我们推荐将 FSDP 后端用于算法研究与原型开发阶段,尤其适用于中小规模模型的快速迭代与验证。


2. FSDP Worker 实现

2.1 ActorRolloutRefWorker

ActorRolloutRefWorker 是一个集成了 Actor、Rollout 和 Reference 功能的复合 Worker,支持混合部署模式。

2.1.1 Actor/Rollout 混合引擎

2.1.2 模型初始化接口

python
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
  • Dispatch.ONE_TO_ALL:当驱动进程调用 init_model 时,该函数将在每个 GPU Worker 上并行执行,完成本地模型的初始化。

初始化流程主要包括以下组件:

  • DataParallelPPOActor:封装了基于 FSDP 的 PPO 基础计算逻辑,包括对数概率(log probability)计算和模型参数更新。
  • vLLMRollout:集成 vLLM 实现高效的自回归生成。我们对 vLLM 引擎进行了修改,使其支持 SPMD (Single Program, Multiple Data)模式,以适配 WorkerGroup 的分布式架构。
  • FSDPVLLMShardingManager:作为上下文管理器,负责在 Actor (FSDP)与 Rollout (vLLM)之间执行权重重分片操作。

更多实现细节请参见 源代码

2.1.3 生成序列并重新计算对数概率

python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
  • Dispatch.DP_COMPUTE_PROTO:数据将沿数据并行维度进行分发与聚合。
  • 该函数中, Rollout 使用 vLLM 执行自回归生成,而 Actor 则对生成的响应重新计算其在旧策略下的对数概率 logπθold(a|s),用于后续的 PPO 优势估计。

2.1.4 更新 Actor 模型

python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
  • 使用 PPO 目标函数(含策略梯度与熵正则项)更新 Actor 模型参数。PPO 损失函数定义如下:
LCLIP(θ)=Et[min(rt(θ)A^t,clip(rt(θ),1ϵ,1+ϵ)A^t)]

其中 rt(θ)=πθ(at|st)πθold(at|st) 为概率比,A^t 为优势函数估计值。

2.1.5 Reference Model

2.1.6 Reference 初始化

Reference 复用 Actor 的初始化接口,但不初始化优化器和混合引擎组件。初始化完成后,模型由 DataParallelPPOActor 封装,仅用于前向推理。

2.1.7 计算参考对数概率

python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_ref_log_prob(self, data: DataProto):
  • 该函数调用 DataParallelPPOActor 中的对数概率计算模块,获取参考策略下生成动作的对数概率 logπθref(a|s),用于后续 KL 散度或奖励计算。

2.2 CriticWorker 与 RewardWorker

2.2.1 模型初始化

Critic (Critic)与 Reward (Reward Model)的初始化流程与 Reference 类似。区别在于, Critic 还需初始化优化器以支持反向传播更新。

2.2.2 计算价值函数

python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_values(self, data: DataProto):
  • Critic 模型接收状态 s,输出状态价值估计 Vϕ(s),用于优势函数计算:
A^t=δt+(γλ)δt+1++(γλ)Tt+1δT1

其中 δt=rt+γV(st+1)V(st)

2.2.3 更新 Critic

python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
  • 使用均方误差(MSE)损失函数更新价值网络参数:
Lvalue(ϕ)=Et[(Vϕ(st)Vttarget)2]

2.2.4 计算奖励分数

python
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def compute_rm_score(self, data: DataProto):
  • RewardWorker 调用预训练的 Reward,对生成响应进行打分,输出标量奖励 r(s,a)

3. 混合分片支持

当前版本暂不支持 FSDP 混合分片模式(Hybrid Sharding)。若需实现该功能,可能需要:

  1. 构建二维设备网格(2D device mesh),结合张量并行与数据并行;
  2. 为不同模型分别设计和测试 dtensor_weight_loaderhf_weight_loader 的适配逻辑;
  3. 实现跨分片策略的权重映射与同步机制。

未来版本将考虑引入对混合分片的支持,以提升大规模模型训练的效率与灵活性。

Maintained by Robin