FSDP 深度解析
ChatGPT 掀起的大模型训练浪潮让不少同学都对训练大模型跃跃欲试,在找训练 baseline 的时候肯定发现大模型训练的 codebase 更倾向于用 DeepSpeed、ColossalAI 等大模型训练框架,而鲜有问津 PyTorch 原生的 FSDP (FullyShardedDataParallel)。这到底是为啥嘞?是 FSDP 不够节省显存?训练速度太慢?还是说不好用?请耐心看完这篇文章,相信一定会有所收获。
FSDP 的前生今世
FSDP 的实现借鉴了 FairScale。PyTorch 在开发大型特性时一般会新建一个库来做一些验证性的支持,并收集用户发反馈,FairScale、Dynamo(PyTorch 2.0 的基石)、torchdistx 均是如此。等到特性日益成熟后,(也许)就会合入到 PyTorch。相比于 PyTorch 官方在 Tutorial 里对 FSDP 简短的介绍,FairScale 显然做的更好,在正式开始介绍之前,贴一张 FairScale 的介绍,大家不妨思考一下,你真的需要 FSDP 么(其他大规模训练框架亦是如此)

ZeRO 系列简介
看过上面这张图的同学肯定会发现,FairScale 把 FSDP 定义为 ZeRO3,考虑到有些小伙伴可能对 ZeRO 系列的大模型优化策略不是很熟悉,这边做一个简短的介绍:

模型训练的时候,显存占用大体可以分成三部分,即激活值、模型权重、模型梯度和优化器状态。对于视觉模型而言,显存占比最大的是激活值,因此使用混合精度训练能够大幅度的降低激活值的显存占用(fp16)。然而对于大语言模型或者多模态模型而言,优化后三者的显存占用则显得更重要。
以 PyTorch 为例,当你使用 DistributedDataParallel 时,其实会在每个进程为模型参数、模型梯度、优化器状态分配内存,并在训练过程中同步地更新这些数据。这样的做法虽然能够通过数据并行以达到加速训练的目的,但是它在显存分配上的策略,显然是非常糟糕的。既然每个进行的参数都是一样的,为什么每个进程还需要保存完整的参数呢?所以 ZeRO 就主张每个进程只保存参数的一部分,用到的时候再 all gather 到各个进程。ZeRO 有三个阶段的优化策略,即:
- ZeRO1:只把优化器状态进行分片
- ZeRO2:对优化器状态 + 梯度进行分片
- ZeRO3:对优化器状态 + 梯度 + 模型参数进行分片
以 7.5 B (φ)参数量的模型为例,先简单计算一下模型参数、模型梯度、优化器状态的显存占用情况:
fp32 训练: 模型参数量为 φ,其梯度也为 φ,在使用 Adam 的情况下,优化器状态为 2φ。如果是普通的 fp32 训练,那么实际占用的内存就是 (1 + 1 + 2)φ * 4:16 φ 字节 (4 为 fp32 数据占据的内存大小);
fp16 训练: 如果开启混合精度训练,为了保证参数更新的精度,优化器状态需要维持在 fp32 ,此外还需要额外保存一份 fp32 模型参数的拷贝,因此显存占用为 2φ(模型参数) + 2φ(模型梯度) + 8φ(优化器状态) + 4φ(模型参数 fp32 拷贝,deepspeed 实现存储在优化器):16 φ 字节。
带入这样的视角,相信就能理解为什么上图中 7.5B 的模型显存占用可以高达 120B,以及为什么 ZeRO 系列为何如此有效。
FSDP - ZeRO3?
言归正传,FairScale 说 FSDP 相当于 ZeRO3 的优化,那我们不妨通过一个简单的例子,来感受一下(例子中优化器选择 SGD,因为 PyTorch 的 Adam 做了非常多的优化,其显存实际占用会明显高于理论)。在正式测试之前,我们先来看一下单卡 fp32 训练、单卡 fp16 训练、DDP fp16 训练的测试:
单卡 fp16 + fp32
class Layer(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Sequential(
*(nn.Linear(10000, 10000) for _ in range(10))
)
def forward(self, x):
return self.linear(x)
def test_fp32():
model = Layer().cuda()
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
data = torch.ones(10000).cuda()
for i in range(10):
optimizer.zero_grad()
output = model(data)
loss = output.sum()
loss.backward()
optimizer.step()
memory = max_memory_allocated()
print(f'step memory allocate: {memory / 1e9:.3f}G')
def test_fp16():
torch.cuda.init()
model = Layer().cuda()
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
data = torch.ones(10000).cuda()
for _ in range(10):
with autocast(device_type='cuda'):
optimizer.zero_grad()
output = model(data)
loss = output.sum()
loss.backward()
optimizer.step()
memory = max_memory_allocated()
print(f'memory allocated: {memory / 1e9:.3f}G')跑过代码后发现,显存占用如下: fp32: 12.035G fp16: 14.035G
啥?amp 显存占用还多了 2G?这是咋算的?这里就不得不提到 amp 的实现方式了。PyTorch 的 amp 不会改变模型权重的类型,即仍然以 fp32 存储,而选择在白名单算子的 forward backward 前后,把 fp32 的 weights 转换成 fp16,以计算出 fp16 的激活值和 fp16 的梯度,其中 fp16 的梯度还会进一步转换成 fp32,以保证参数更新的精度。但是既然权重和梯度仍然保留 fp32,优化器状态也理应保持不变,那为啥还多了 2G?原因在于 forward 和 backward 这份 fp16 的权重被缓存了,这部分实现在 amp 的 C++ 代码里。缓存的 fp16 梯度,就是多出来 2G 的源头。
要想节省这部分参数,需要给 autocast 传入 cache_enabled=False,
def test_fp16():
torch.cuda.init()
model = Layer().cuda()
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
data = torch.ones(10000).cuda()
for _ in range(10):
with autocast(device_type='cuda', cache_enabled=False):
optimizer.zero_grad()
output = model(data)
loss = output.sum()
loss.backward()
optimizer.step()
memory = max_memory_allocated()
print(f'memory allocated: {memory / 1e9:.3f}G')这样一来,显存消耗为 12.235G,基本和 fp32 一致,也符合预期。
DDP 训练
DDP 只是在每个进程创建模型,更新模型而已,显存占用应该还是 12G 吧?
def _test_ddp_fp16():
rank = dist.get_rank()
model = DistributedDataParallel(Layer().cuda())
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
data = torch.ones(10000).cuda()
for _ in range(10):
with autocast(device_type='cuda', cache_enabled=False):
optimizer.zero_grad()
output = model(data)
loss = output.sum()
loss.backward()
optimizer.step()
memory = max_memory_allocated()
if rank == 0:
print(f'memory allocated: {memory / 1e9:.3f}G')然而结果是: 16.036G
原理也很简单,ddp 执行 gradient computation 和 gradient synchronization 时需要有一个桶(bucket,具体介绍见之前的 DDP 介绍),桶会保留一份 gradient 的拷贝,因此会额外消耗 4G 左右的显存。
FSDP 训练
我们在使用 FSDP 时,需要通过配置 auto_wrap_policy 参数来选择模型分片策略,不然显存优化只能达到 ZeRO-stage1 的水准。如何配置 auto_wrap_policy 以及其对应的原理会在后面的章节具体介绍。
from torch.distributed.fsdp.wrap import _module_wrap_policy
def _test_fsdp_fp16():
rank = dist.get_rank()
fsdp_model = FullyShardedDataParallel(
module=Layer(), device_id=rank,
auto_wrap_policy=partial(
_module_wrap_policy,
module_classes=nn.Linear))
optimizer = SGD(fsdp_model.parameters(), lr=0.1, momentum=0.9)
data = torch.ones(10000).cuda()
for _ in range(10):
optimizer.zero_grad()
output = fsdp_model(data)
loss = output.sum()
loss.backward()
optimizer.step()
memory = max_memory_allocated()
if rank == 0:
print(f'step memory allocate: {memory / 1e9:.3f}G')
torch.cuda.reset_max_memory_allocated()结果是 1.524G,显存占用基本等价于 ZeRO3 的优化效果。
之所以做了这些内存占用分析,是希望大家从 DDP 切换到 FSDP 时,能够理性的看待显存优化。
FSDP 分片策略
上一章我们提到,我们需要通过 auto_wrap_policy 来指定模型分片策略,那么这个参数是如何起作用的呢?以及为什么不配这个参数,其优化效果只能达到 ZeRO-stage1。
与 DistiributedDataParallel 类似,FSDP 也是通过一个 model wrapper: FullyShardedDataParallel 来实现参数切分的逻辑。被 wrap 的 model 会成为 root fsdp module,而 root fsdp module 在构建时,会根据用户定义的 auto_wrap_policy 递归地把 submodule wrap 成 child fsdp module:

以官方实现的 _module_wrap_policy 为例,其中关键参数 module_classes 用于说明哪个类型的 submodule 应该被 wrap 成 child fsdp module
def _module_wrap_policy(
module: nn.Module,
recurse: bool,
nonwrapped_numel: int,
module_classes: Set[Type[nn.Module]],
) -> bool:
"""
This auto wrap policy wraps every module that is an instance of any type in
``module_classes`` as its own FSDP instance. The root module given by
``module`` is always wrapped as an FSDP instance regardless. Since the
wrapping proceeds bottom up, each FSDP instance manages the parameters in
its subtree excluding any already managed by a child FSDP instance.
Args:
module (nn.Module): Current module being considered.
recurse (bool): If ``False``, then this function must decide whether
``module`` should be wrapped as an FSDP instance or not. If
``True``, then the function is still recursing down the module
tree as a part of the DFS.
nonwrapped_numel (int): Parameter numel not yet wrapped.
module_classes (Set[Type[nn.Module]]): Set of module classes that are
wrapped as FSDP instances.
Returns:
``True`` if ``recurse=True``, and whether ``module`` should be wrapped
if ``recurse=False``.
"""
if recurse:
return True # always recurse
if inspect.isclass(module_classes):
module_classes = (module_classes, )
return isinstance(module, tuple(module_classes))在上一章中我们将其指定成 nn.Linear,也就是说每个 nn.Linear 都会被 wrap 成 child fsdp module。 所有的 fsdp module 在 forward 过程中都会触发参数的 unshard (all gather) 和 shard。
- root fsdp module 的 forward,会在 pre-forward 阶段 all gather 不同进程的参数,并注册一些 prebackward-hook 和 post-backward-hook。然后在 post-forward 阶段释放不属于当前 rank 的参数。
其中 pre-backward-hook 会在执行 backward 之前再次 gather 参数,而 post-backward-hook 负责实现梯度的 reduce-scatter,即梯度同步 + 梯度分发。 需要注意的是,fsdp-module forward 时不会进一步 gather child fsdp module 的 parameter。
相比于 child fsdp module,root fsdp module 的 forward 还会额外做一些 cuda stream 初始化等工作,这里不做额外的展开。

2.child fsdp module 的 foward
主体逻辑基本同 root fsdp module


可见每次 fsdp module 只会 gather 部分参数,这样是符合我们预期的。那如果我们不设置 auto_wrap_policy 又会如何?那就是没有 child fsdp module

root fsdp module 在 forward 阶段,会直接 gather 所有的参数,也就意味着无法做到 ZeRO-stage3 中,通过对参数分片来实现节省显存。但是 ZeRO1 和 ZeRO2 里对梯度和优化器状态的分片,还是可以做到的。理由是 forward 阶段仍然会注册 post-backward-hook,因此 gradient reduce-scatter 的逻辑仍然会起作用。构建 Optimizer 时,传入的是 root fsdp module 的 parameters,因此优化器会直接更新分片后的参数、记录分片后参数的状态,因此优化器状态的分片的优化也是有效的。
auto_wrap_policy 需要遵循一定的接口规范即接受以下几个参数:
module:递归遍历 submodule 时,访问到的 module recurse:判断一个 submodule 为 child fsdp module 后,是否再进一步递归判断该 submodule 的 submodule 需要被 wrap 成 child fsdp module nonwrapped_numel:这个参数的的含义是当前模块,不需要被分片的参数的参数量。什么是不需要被分片的参数呢?一般来说包含两部分,即已经被分片的参数和用户指定的需要被忽略的参数(ignored_params)。基于这个参数可以实现 size-based wrap policy,例如官方实现的 size_based_auto_wrap_policy 。
FSDP 把 auto_wrap_policy 这个参数的配置权交给用户,扩展性固然是提升了,但是也无形的增加了 FSDP 的学习成本,比如 auto_wrap_policy 会起什么作用,它的几个入参的含义又是什么,刚使用 FSDP 的用户难免会为此感到一头雾水。
然而如果 FSDP 的使用成本仅限于此,我相信大家还是愿意去学习和使用的,然而一些隐性的约定和一些奇奇怪怪的报错,就非常劝退了。
FSDP 试错的血与泪
替换 submodule 的风险
上一章我们提到,fsdp 会把 submodule 替换成 wrap 之后的 child fsdp module,看到这你或许会奇怪,如果我 parent module 访问了 submodule 的一些属性或者方法,这个时候 submodule 被替换成 fsdp module,难道不会触发 attribute error 么?对于这种情况,FSDP 机智的重载 getattr 方法:
def __getattr__(self, name: str) -> Any:
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self._fsdp_wrapped_module, name)这样对于没有定义的属性,它就会从 submodule 里去找。然而这样做仍然会有风险。
- 如果你访问的属性恰巧和 child fsdp module 本身的属性重名,就出现拿错属性的情况
- 如果你直接访问了 submodule 的 parameter,并对其做了一些操作。由于 parameter 是在 forward 阶段才会被 gather,那么此时你直接获取的是一个分片后的参数,大概率也会报错
- 如果你恰巧没有直接调用 child fsdp module 的
__call__方法,例如这种情况:
class Layer(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.processor = nn.Linear(1, 1)
self.linear1 = nn.Linear(1, 1)
self.linear2 = nn.Linear(1, 1)
def forward(self, x):
return self.linear1(x) + self.linear2(x)
class ToyModel(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.linear = nn.Linear(1, 1)
self.layer = Layer() # 会被 auto wrap policy 指定为 child fsdp module
def forward(self, x):
y = self.linear(self.layer.processor(x))
return self.layer(y)假设 Layer 被 wrap 成了 fsdp module,由于 ToyModel.forward 里,直接调用了 self.layer.processor 的 forward,此时由于 layer 的 forward 没有被触发,layer.precessor 里的参数仍然处于分配的状态,也会报错。
又例如这种情况:
class A:
...
def loss(self, inputs: torch.Tensor, data_samples: List[DataSample]) -> dict:
feats = self.extract_feat(inputs)
return self.head.loss(feats, data_samples)
class B:
...
def loss(self, feats: Tuple[torch.Tensor], data_samples: List[DataSample], **kwargs) -> dict:
cls_score = self(feats) # 没有走 FSDP 的 forward
losses = self._get_loss(cls_score, data_samples, **kwargs)
return losses假如 class A 中的 self.head 类型为 class B,且被 wrap 成了 child fsdp module。那么在执行 self.head.loss 的时候,会通过 FSDP 的 getattr 方法直接找到 class B 的 loss,此时的局部变量 self 已经是 class B 实例而并非 FSDP,因此在执行 self(feats) 时不会进入 FSDP 的 forward 触发参数 all gather,进一步引发错误。
多参数组的优化器
PyTorch 的 optimizer 支持对 model 里的不同参数设置不同的学习率、动量等超参数。设置过程大概类似这样:
param_groups = []
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
param_groups.append({'param': module.weight, lr=0.01})
param_groups.append({'param': module.bias, lr=0.1})
elif:
optimizer = SGD(param_groups, lr=0.1)然而问题在于,在 PyTorch 2.0 之前,一旦 root fsdp module,child fsdp module 构建完成,它会删除原有的参数,例如 bn.weights,bn.bias,转而将 fsdp module 下所有未被切片的参数,转换成一个大的 flatten parameters。举个例子,如果上一章的 example 里,如果没有指定 auto_wrap_policy,那么就只会保留最外层的 root fsdp module。那么所有的 linear 层的 parameters 都会汇总成一个大的 flatten parameters,放在 root_fsdp_module 下:
rank = dist.get_rank()
fsdp_model = FullyShardedDataParallel(
module=Layer(), device_id=rank,
# auto_wrap_policy=partial(
# _module_wrap_policy,
# module_classes=nn.Linear),
)
print(list(fsdp_model.parameters()))此时每个 rank 只会打印出一个参数:
[Parameter containing:
Parameter(FlatParameter([-4.6519e-05, -6.2861e-03, 3.9519e-03, ..., -3.2763e-03,
7.1111e-04, -8.2136e-03], device='cuda:3', requires_grad=True))]因此在 PyTorch 2.0 之前,一旦使用了 FSDP,就很难对每个参数设置不同的学习率了,因为 fsdp wrap 之后多个参数会合并成一个参数。之后的 gradient 分片、参数更新也都是基于 flatten tensor 去实现的。 由于参数更新也是基于 flatten tensor 实现的,因此 FSDP 要求,每个 fsdp module 下的参数,dtype 和 requires_grad 属性都应该统一,否则无法 concat 成一个大的 flatten tensor。
PyTorch 2.0 为 FSDP 添加了 use_orig_params 参数,开启这个参数的情况下,FSDP wrap 的过程中不会删除原有的参数,而会让原有参数的内存指向 flatten params 的某个区域。这是一个很棒的更新,在不引入额外显存消耗的情况下,让用户仍然能够访问到分片之前的参数,并为其设置不同的优化器超参。引入这个参数后,按理说 ,fsdp module 下所有参数 requires_grad 属性统一的限制应该也解除了,但不幸的是,PyTorch 2.0 并没有调整这部分逻辑,不过在主分支上已经修复了这个问题,相信即将到来的 PyTorch 2.1 能够解决这个痛点。
FSDP 的接口稳定性
尽管说早在 PyTorch 1.11,FSDP 就已经是一个 beta 版本的特性了,然而时至今日,FSDP 模块仍然处于高速迭代的状态。FSDP 的开发者也于 2023 年 2 月发起了一个 discussion,介绍了一些设计理念,以及内部的重构。 除此之外,FSDP 的外部接口更新的也比较快,打开 PyTorch FSDP 的 api 文档,你会发现不少接口都贴上了 deprecated 标签。不过总的来说,新接口确实比老接口要易用、灵活很多,MMEngine 这次集成 FSDP,也都是基于新接口去开发的。
总结
FSDP 在显存节省方面,其效果确实与 ZeRO3 等价,但是需要注意的是,在开启混合精度训练(autocast)的情况下,需要把 cache_enabled 设置为 Flase。
FSDP 在易用性方面,上手成本比较高,用户需要理解 FSDP wrap module 的逻辑,auto_wrap_policy 的作用,以及一些限制。在不足够熟悉 FSDP 本身的逻辑和限制,足够了解 model 结构的情况下,容易出现报错,且触error message 和 error 真正的诱因没有太大关联,难以 debug。
PyTorch 2.0 通过 use_ori_params 参数大大提升了 FSDP 的易用性,但是对 requires_grad 属性统一的限制仍然存在。要解决这个问题可以坐等 PyTorch 2.1 更新,并指定 use_orig_params=True。但如果想要临时解决的话需要在 auto_wrap_policy 做一些改动,由于是基于 FSDP 内部的协议做的修改,可能不是很稳定,在这就不做赘述。
总的来说,FSDP 在易用性方面确实差强人意,但是在灵活性方面,留给了用户更大的操作空间,不过相信随着 PyTorch 的不断迭代,相信 FSDP 也会逐渐变得和 DDP 一样好用。MMEngine 也会紧跟 FSDP 更新的动向,在保持灵活性的基础上,尽可能的降低大家的使用门槛,总结出一套简单、易配置的最佳实践。