veScale-FSDP 深度解读:让大模型分布式训练的分片方式真正灵活起来
论文信息:veScale-FSDP: Flexible and High-Performance FSDP at Scale 作者:Zezhou Wang, Youjie Li, Zhiqi Lin 等(ByteDance Seed) 链接:arXiv:2602.22437 Github:https://github.com/ByteDance/veScale-FSDP
一、论文核心研究背景
训练大语言模型(LLM)需要将模型和优化器状态分散到数千甚至上万块 GPU 上。Fully Sharded Data Parallel(FSDP),也被称为 Zero Redundancy Optimizer(ZeRO),是目前最主流的选择——它通过将模型参数、梯度和优化器状态分片(shard)到不同 GPU 上来节省内存,同时保持数据并行的编程简洁性。
然而,现有 FSDP 系统存在一个根本性的结构性矛盾:
- 现有分片格式(元素级 element-wise 或行级 row-wise)产生的分片边界,常常与下游计算所需的块结构(block structure)不对齐。
- 现代训练技术——如 块量化(block-wise quantization) 和 矩阵优化器(Shampoo、Muon)——都明确要求张量块在分片后保持完整。
这个矛盾导致现有系统不得不在灵活性和性能之间做痛苦的取舍:要么放弃这些先进训练技术,要么承受巨大的填充(padding)和通信开销。
二、主要贡献与创新点
veScale-FSDP 提出了三项核心创新,系统性解决了上述矛盾:
贡献 1:RaggedShard — 真正灵活的分片格式
技术普及:传统分片就像把一本书要么拆成单个字(element-wise),要么拆成等长的行(row-wise)。RaggedShard 则允许你按任意”段落”(自定义块大小)来拆,而且每段长度可以不同——就像你可以按章节拆,也可以按自然段落拆,每章长度各异。
RaggedShard 支持任意分片粒度和自定义块大小,通过一个整数序列定义每个分片的长度。它统一了所有现有分片格式:
- 块大小 = 1 → 等价于元素级分片
- 块大小 = 行维度 → 等价于行级分片
- 任意块大小 → 支持块量化、矩阵优化器等块结构计算
贡献 2:结构感知规划算法
在 RaggedShard 提供的灵活性基础上,veScale-FSDP 设计了一个规划算法,自动重排 RaggedShard 张量以最大化通信效率。该算法需要同时满足三个约束:
- 非跨设备块(Non-Sharded Block):块不能跨越设备边界
- 连续张量内存(Contiguous Tensor Memory):消除内部填充,避免内存碎片
- 负载均衡(Balanced Load):保持集合通信的对称性
这个问题本身是 NP-hard 的(可归约到经典的分区问题),论文提出了一个多项式时间动态规划启发式算法,在实际场景中运行时间小于 0.3 秒,填充开销控制在 3% 以内。
贡献 3:DBuffer — 零拷贝全局缓冲区
技术普及:想象多个 GPU 各自有一块内存,传统方式下数据在通信前后需要来回拷贝(copy-in/copy-out)。DBuffer 就像在所有 GPU 上铺了一张连续的”大地毯”,每个 GPU 直接在上面划出自己的区域使用,通信时数据原地不动,省去了来回搬运。
DBuffer 为 RaggedShard 张量提供全局缓冲区语义,支持:
- 组级算子(group-level operators):实现内核融合,减少启动开销
- 通信前后零拷贝访问:消除 FSDP2 中高达 5~23ms 的 copy-out/copy-in 开销
三、关键技术解析(含技术普及)
3.1 为什么现有 FSDP 分片格式不够用?
技术普及:FSDP(Fully Sharded Data Parallel)是数据并行的一种进阶形式。普通数据并行(DP)是每个 GPU 存一份完整模型,各自处理不同数据批次;FSDP 则把模型参数、梯度和优化器状态”切碎了”分散到所有 GPU 上,只在计算前临时把需要的参数凑齐(AllGather),计算完再分散(ReduceScatter)。这样每个 GPU 只存 1/N 的模型状态,大幅节省内存。
但”怎么切”是个关键问题。切得太碎(元素级)→ 通信频繁且无法做块计算;切得太粗(行级)→ 遇到矩阵优化器时整块被切成两半,无法独立计算。
现有四大 FSDP 系统的局限性对比:
| 系统 | 分片方式 | 核心局限 |
|---|---|---|
| DeepSpeed ZeRO | 元素级 | AllGather 操作碎片化 |
| PyTorch FSDP1 | 元素级 | ReduceScatter 慢,内存开销大 |
| PyTorch FSDP2 | 行级(Shard(0)) | Copy-Out 5 |
| Megatron-FSDP | 行级+填充 | 33% 缓冲区填充膨胀 |
3.2 RaggedShard 如何统一所有分片格式?
RaggedShard 的核心是一个整数序列 offsets,定义每个设备上张量的起始和结束位置。例如:
offsets = [0, 512, 1024, 1536, 2048]
表示 4 个设备分别持有长度 512、512、512、512 的段。如果改为:
offsets = [0, 32, 1056, 1088, 2048]
则表示设备 0 和 2 各持有 32 个元素,设备 1 和 3 各持有 1024 个元素——这就是”Ragged”(参差不齐)的含义。
与 TP/EP 的兼容性:通过 StridedRaggedShard 和基于最小公倍数(LCM)的粒度适配,RaggedShard 可以与张量并行(TP)和专家并行(EP)无缝组合。
3.3 DBuffer 的零拷贝设计
DBuffer 的关键洞察是:如果所有 RaggedShard 张量都从一个预先分配的全局连续缓冲区中切出,那么通信前后的数据就不需要拷贝。
具体实现:
- 预先分配一个全局连续缓冲区,大小等于所有设备本地分片之和
- 每个 RaggedShard 张量只是这个缓冲区的一个”视图”(slice)
- AllGather / ReduceScatter 直接在缓冲区上操作
- 计算时直接读取缓冲区中的数据,无需 copy-in;计算结果写回缓冲区,无需 copy-out
这消除了 PyTorch FSDP2 中实测高达 5~23ms 的拷贝开销。
四、实验结论与价值总结
端到端性能(1024 GPUs)
| 模型 | 吞吐量提升 | 内存降低 |
|---|---|---|
| LLaMA-3-70B | ~5% | 16~30% |
| GPT-OSS-120B | 11~66% | 16~30% |
| Internal MoE | 11~66% | 16~30% |
提升来源:优化的通信重叠、DBuffer 零拷贝集合通信、灵活分片粒度避免的填充开销。
扩展性
- 弱扩展(Weak Scaling):近线性扩展到 8K GPUs
- 强扩展(Strong Scaling):在 120M token 全局批次下线性扩展到 10K GPUs
- 模型扩展:2.4T 参数模型仅需 1K GPUs 即可运行,无性能退化
实际训练技术验证
8-bit Adam:使用 32×32 块大小,每设备独立量化本地分片,无需任何额外通信。
分布式 Muon:通过 Redistribute(u, RaggedShard(r)) 将完整矩阵收集到根设备,应用 Newton-Schulz 迭代,再重新分片回去。在 256 块 Hopper GPU 上达到 47.3% MFU。
消融实验
| 禁用组件 | 归一化吞吐量 |
|---|---|
| 无(完整系统) | 100.0% |
| 仅 DBuffer | 92.8% |
| 仅规划算法 | 65.4% |
| 仅 RaggedShard | N/A(系统无法有意义地运行) |
规划算法贡献了最大的性能提升(34.6%),DBuffer 贡献 7.2%,两者缺一不可。
局限与未来方向
- 规划算法虽为多项式时间,但在超大规模(数万 GPU)下的运行时仍需关注
- 目前主要针对数据并行场景,与更复杂的 3D 并行组合时的最优策略有待探索
- 代码已开源:github.com/volcengine/veScale
一句话总结
veScale-FSDP 通过 RaggedShard 打破了”元素级 or 行级”的二元分片困境,让 FSDP 首次同时支持任意块结构计算和高效通信,为大模型训练的下一代优化技术(Muon、块量化等)铺平了道路。