ZeRO 技术原理:基础与实现
本次大规模训练技术系列分享围绕微软 ZeRO Optimizer 的思路与实现展开,全文包含以下四个部分:
- 大规模训练的技术挑战与现有并行训练方式
- ZeRO Optimizer 的三个不同级别
- ZeRO-3 的具体实现思路与方式
- ZeRO 的局限性与大模型训练的未来方向
1. 训练大模型的挑战
随着人工智能技术在全球的推广应用,自动驾驶、人脸识别、自然语言处理等越来越多的领域通过深度学习大幅提升了算法的整体性能与表现,GPU 也成为训练模型不可或缺的基础计算设备。然而,随着模型规模不断增大,加之训练数据量也越来越大,单个 GPU 的计算能力已完全无法满足大规模网络的训练需求。
在密集型训练的代表——自然语言处理中,OpenAI 于 2020 年 6 月发布的第三代语言模型 GPT-3 参数量达到了 1750 亿,相比 GPT-2 最大版本 15 亿参数增长了百倍以上。2021 年 4 月 25 日,华为云也发布盘古系列超大预训练模型,其中包含 30 亿参数的全球最大视觉(CV)预训练模型,以及与循环智能、鹏城实验室联合开发的千亿参数、40 TB 训练数据的全球最大中文语言(NLP)预训练模型。这些庞大模型训练的背后,必然少不了一套精妙运转的训练系统。本次分享将揭秘超大模型训练系统中必不可少的一项技术——ZeRO。
2. 现有并行方法
在探索 ZeRO 之前,我们需要先了解当前分布式训练主要的三种并行模式:数据并行、模型并行和流水线并行。
2.1 数据并行
当模型规模足够小且单个 GPU 能够承载时,数据并行是一种有效的分布式训练方式。因为每个 GPU 都会复制一份模型参数,我们只需要把训练数据均分给多个 GPU,然后让每个 GPU 作为一个计算节点独立完成前向和反向传播运算。数据并行不仅通信量较小,而且可以方便地做通信计算重叠,因此能够取得较好的加速比。
2.2 模型并行
如果模型规模较大,单个 GPU 的内存承载不下,我们可以将模型网络结构进行拆分,将模型的单层分解成若干份,把每一份分配到不同的 GPU 中,从而在训练时实现模型并行。训练过程中,正向和反向传播计算出的数据通过 All-Gather 或 All-Reduce 完成整合。这样的特性使得模型并行成为处理模型中大层的理想方案之一。
然而,深度神经网络层与层之间的依赖使得通信成本与模型并行通信组中的计算节点(GPU)数量正相关。在其他条件不变的情况下,模型规模的增加能够提供更好的计算通信比。
2.3 流水线并行
流水线并行可以理解为层与层之间的重叠计算,也可以理解为按照模型的结构与深度,将不同的 layer 分配给指定 GPU 进行计算。相较于数据并行需要 GPU 之间的全局通信,流水线并行只需在层之间点对点地传递部分 activations,因此流水线并行对通信带宽的需求更低。
然而,流水线并行需要相对稳定的通信频率来确保效率,这导致应用时需要手动进行网络分段,并插入繁琐的通信原语。同时,流水线并行的并行效率也依赖各卡负载的手动调优,这些操作都对应用该技术的研究员提出了更高要求。

流水线并行
3. 为什么需要 ZeRO?
在三种并行方式中,数据并行因其易用性得到了最为广泛的应用。然而,数据并行会产生大量冗余 Model States 的空间占用。ZeRO 的本质是在数据并行的基础上,对冗余空间占用进行深度优化。
在《大规模训练系列之技术挑战》一文中,我们介绍了大规模训练中的显存占用可以分为 Model States 与 Activation 两部分,而 ZeRO 正是为了解决 Model States 而诞生的一项技术。
首先,我们来看模型在训练过程中 Model States 由什么组成:
- Optimizer States:优化器在进行梯度更新时所需要用到的数据,例如 SGD 中的 Momentum,以及混合精度训练时的 Float32 Master Parameters。
- Gradient:反向传播后产生的梯度信息,决定参数的更新方向。
- Model Parameter:模型参数,也就是整个过程中通过数据“学习”得到的信息。
在传统数据并行下,每个进程都使用同样的参数进行训练,并且持有对 Optimizer States 的完整拷贝,同样占用大量显存。在混合精度场景下,以参数量为
- Float16 的参数和梯度备份,这两项分别消耗
和 字节内存(1 个 Float16 = 2 字节); - Float32 的参数、Momentum、Variance 备份,对应 3 份
的内存占用(1 个 Float32 = 4 字节)。
最终需要
ZeRO 在数据并行的基础上,引入了对冗余 Model States 的优化。使用 ZeRO 后,各个进程只保存完整状态的
3.1 ZeRO 的三个级别
相比传统数据并行的简单复制,ZeRO 通过将模型的参数、梯度和 Optimizer State 划分到不同进程来消除冗余的内存占用。
ZeRO 有三个不同级别,分别对应于对 Model States 不同程度的分区(Partition):
- ZeRO-1:分割 Optimizer States;
- ZeRO-2:分割 Optimizer States 与 Gradients;
- ZeRO-3:分割 Optimizer States、Gradients 与 Parameters。

ZeRO-DP 优化的三个阶段。其中
表示模型大小(参数数量), 表示优化器状态的内存乘数, 表示数据并行度。
3.1.1 ZeRO-1
Optimizer States Partitioning (
):4 倍内存缩减,通信量与传统数据并行相同。
优化器在进行梯度更新时,会使用参数与 Optimizer States 计算新的参数。而在正向或反向传播中,Optimizer States 并不参与计算。因此,我们可以让每个进程只持有一小段 Optimizer States,利用这一小段 Optimizer States 更新完与之对应的一小段参数后,再把各个小段拼接为完整的模型参数。ZeRO-1 正是这么做的:
ZeRO Optimizer Stage 1 动画[^4]
假设我们有
- 对自己存储的 Optimizer States(包括 Momentum、Variance 与 FP32 Master Parameters)进行计算与更新。
- 更新后的 Partitioned FP32 Master Parameters 会通过 All-Gather 传回到各个进程中。
经过这两步,完成一次完整的参数更新。
通过 ZeRO-1 对 Optimizer States 的分段化存储,7.5B 参数量的模型内存占用将由原始数据并行下的 120 GB 缩减到 31.4 GB。
3.1.2 ZeRO-2
Optimizer States and Gradient Partitioning (
):8 倍内存缩减,通信量与传统数据并行相同。
ZeRO-1 将 Optimizer States 分小段存储在多个进程中,因此在计算时,这一小段的 Optimizer States 也只需要得到进程所需的对应一小段 Gradient。遵循这一原理,和 Optimizer States 一样,ZeRO-2 也将 Gradient 进行了切片:
在一个 Layer 的 Gradient 都被计算出来后:
- Gradient 通过 All-Reduce 进行聚合(类似于 DDP)。
- 聚合后的梯度只会被某一个进程用来更新参数,因此其它进程上的这段 Gradient 不再被需要,可以立即释放(按需保留)。
这样就在 ZeRO-1 的基础上实现了对 Gradient 的切分。
通过 ZeRO-2 对 Gradient 和 Optimizer States 的分段化存储,7.5B 参数量的模型内存占用将由 ZeRO-1 中的 31.4 GB 进一步下降到 16.6 GB。
3.1.3 ZeRO-3
Optimizer States, Gradient and Parameter Partitioning (
):内存缩减与数据并行度成线性关系。
当 Optimizer States 与 Gradient 都被分布式切割、分段存储和更新之后,剩下的就是 Model Parameter。ZeRO-3 通过对 Optimizer States、Gradient 和 Model Parameter 三方面的分割,使所有进程共同协作,只存储一份完整的 Model States。其核心思路是精细化通信,按照计算需求完成参数的收集与释放。
3.2 ZeRO-3 宏观概览
ZeRO-3 相对于 ZeRO-1 和 ZeRO-2,实现方式复杂很多。首先我们站在宏观角度理解 ZeRO-3 的算法原理。
3.2.1 初始化
一个模型由多个 Submodule 组成。在初始化时,ZeRO-3 会将每个 Submodule Parameter Tensor 下的数据按照 GPU 数量分摊切割成多个小的 ds_tensor,存储在不同 GPU 进程中。因为 ds_tensor 可以共同组合出完整数据,所以原始 param 下的数据变为冗余信息,会被释放掉。

ZeRO-3 初始化参数分区
3.2.2 训练过程中
在训练过程中,ZeRO-3 会按照 Submodule 的计算需求进行参数的收集与释放:在当前 Submodule 正向/反向传播计算前,ZeRO-3 通过 All-Gather 拿到分摊存储在不同进程中的 ds_tensor,重建原始的 param。重建后的参数即可参与计算。
在当前 Submodule 正向/反向传播计算后,param 下的数据并没有发生变更,与 ds_tensor 相同,造成了冗余。因此,param 会再次被释放。

ZeRO-3 训练中的参数收集与释放
经过 ZeRO-3,一套完整的 Model States 就被分布式存储在多个 GPU 进程中。通过按照计算需求进行数据收集和释放,实现在存储空间有限的情况下训练超大规模模型。7.5B 参数量、64 卡并行的模型,内存占用将由 ZeRO-2 的 16.6 GB 最终下降到 1.9 GB。相较于传统数据并行下 120 GB 的内存空间,ZeRO-3 显著提升了内存使用效率[^1]。
以上就是 ZeRO-3 宏观算法原理的概述。在接下来的几个章节中,我们将深入源码,解读 ZeRO-3 的实现方式与逻辑。
ZeRO Optimizer Stage 3 动画[^4]
3.3 ZeRO-3 在 DeepSpeed 中的具体实现
在这里,我们深入代码,探索 ZeRO-3 是如何实现 Model Parameter 分布式存储的。
核心机制:初始化 -> 分割与收集 -> Submodule 收集 -> Submodule 释放
3.3.1 初始化——模型参数的分割
参数的分割遵循均匀分配原则,每个进程获得等量的参数分片。
首先,为了防止内存爆炸,巨大的 Model Parameters 必须在加载之前就被拆分并分发到各个进程中。ZeRO-3 在模型初始化时通过 zero.Init 完成分摊与切割。
model = zero.Init(module=model)zero.Init 对传入的 module 执行以下四步:
- 判定传入 ZeRO-3 的
module非None。 - 在一个
for循环中,遍历其下submodule中的所有参数。 - 在 tensor 的 data 被分割改变之前,对每个
parameter tensor套上一层_convert_to_deepspeed_param的包装,用于记录 tensor 的特性(shape、numel 等),防止后期因为 padding 和 partition 导致原始数据特性的丢失。 - 参数完成
_convert_to_deepspeed_param之后,param.partition()对其进行均分切割并分摊给各个进程。
param.partition() 中按照如下步骤进行参数切分:
- 根据进程数量(
self.world_size)计算 parameter partition 之后的大小:
partition_size = tensor_size // self.world_size- 创建一个
partition_size大小的空白 tensor:
partitioned_tensor = torch.zeros(
partition_size,
dtype=param.dtype,
device=self.remote_device,
)- 计算 partition 需要截取和存储的数据区间:
start = partition_size * self.rank
end = start + partition_size- 把原始 param 拉成一维后,按照进程自己的 rank 决定偏移量的
start和end,计算出截取区间并放入partitioned_tensor,把这个新创建的 tensor 挂在原始的param.ds_tensor下:
one_dim_param = param.contiguous().view(-1)
src_tensor = one_dim_param.narrow(0, start, partition_size)
param.ds_tensor.copy_(src_tensor)- 把原始的
param.data缩减为 1 个标量 tensor:
# 因为 param.data 已经被分散存储在 param.ds_tensor 下,
# 所以这里会将 param.data 释放掉,改为只存储一个标量的形式参数。
# 这也是为什么要通过 _convert_to_deepspeed_param 记录原始信息的原因。
param.data = torch.ones(1).half().to(param.device)通过以上五个步骤,每个 module 中的参数就被拆分并存储到不同进程中。当这一步结束时,原始 param.data 的长度变为 1,分段后的参数则存放在 param.ds_tensor 中。
假设有 param.partition() 切分为

Parameter Partition
3.3.2 初始化——模型参数收集初始化
根据每个 submodule 的需求实现更精细化的参数收集与释放。
拆分好 model parameter 之后,下一步需要考虑如何在需要时快速找到这些分摊存储的参数,并重新组合成完整参数进行运算。参数的收集与释放虽然发生在每次 forward 与 backward 中,但控制信息需要在初始化阶段建立。针对这一目的,ZeRO-3 中创建了两个类:
PartitionedParameterCoordinatorPrefetchCoordinator
这两个类负责在 forward 和 backward 时协调 module parameters 的获取与释放。
为了能够在模型 forward 和 backward 中及时拿到模型参数,ZeRO 初始化过程的一个重要环节就是给每个 submodule 创建 hooks。
首先了解 PyTorch 中的 hook。根据 PyTorch 文档:
"You can register a function on a Module or Tensor. The hook can be a forward hook or a backward hook. The forward hook will be executed when a forward call is executed. The backward hook will be executed in the backward phase."
通过使用 hook,我们可以在保留网络输入输出结构的同时,方便地获取、改变网络中间层变量的值和梯度。ZeRO-3 Optimizer 初始化过程中,代码通过递归方式对 module 下的每个 submodule 都挂上四个 hook:
_pre_forward_module_hook:在 submodule 的 forward 开始前负责获取 module parameters;_post_forward_module_hook:在 submodule 的 forward 结束后负责释放 module parameters;_pre_backward_module_hook:在 submodule 的 backward 开始前负责获取 module parameters;_post_backward_module_hook:在 submodule 的 backward 结束后负责释放 module parameters。
在每个 submodule 的 forward 和 backward 计算前,hook 会调用:
PartitionedParameterCoordinator中的fetch_sub_module和all_gather,收集并重建自己需要的 parameter。PrefetchCoordinator中的prefetch_next_sub_modules则最大化利用通信带宽,提前 all_gather 未来 submodule 需要的 parameter,为之后的计算做好准备。
计算完成后,hook 通过 PartitionedParameterCoordinator 中的 release_sub_module 再次释放当前 submodule 的 parameters。
通过这样的方式,在每一个 iteration 中,各个 submodule 就可以对自己需要的参数做出计算前的获取和计算后的释放。

Forward and Backward Hooks
3.3.3 前向传播中的 ZeRO-3
- 前向传播中 Model Parameter 的获取(Pre-Forward Hook)

Pre-Forward Hook
在初始化时,ZeRO-3 Optimizer 将全部 module parameter 分散 partition 到不同 GPU 上。因此,在每个 submodule 做 forward 之前,需要:
- 明确 submodule 所需要的 parameter;
- 通过进程间通信拿到分散存储的 partitioned parameter;
- 重新构造出原始 parameter 进行运算。
整个流程通过 PartitionedParameterCoordinator 和 PrefetchCoordinator 实现。每个 submodule 在 Pre-Forward Hook 中进行四步操作:
param_coordinator.record_trace:在第一个 iteration 时,record_trace会通过param_coordinator记录下一份 model 的完整运行记录trace,也就是各nn.Module的执行顺序。在之后的 iteration,运行记录已经创建好,record_trace不再发挥作用。param_coordinator.fetch_sub_module:因为 module forward 会逐层进行,当获得 submodule 的信息后:- 通过
submodule.named_parameters()收集当前需要的全部 partitioned parameters。 - 通过
all_gather,各个进程中的 partitioned parameters 会被重新组合构建成原始 parameter。 - 利用原始 parameter 进行
submodule.forward的计算。
- 通过
param_coordinator.prefetch_next_sub_modules:为了节省通信时间、提高效率,Pre-Forward Hook 中也会提前预取当前 submodule 之后的 submodule 参数,并对其标记以便后续调用。param_coordinator.increment_step:Step会更新当前 Submodule 在trace中走到了哪一步,从而确定之后prefetch_next_sub_modules的起点。
经过以上四步处理,便实现了:
- 完成 submodule 计算所需的所有 parameter 重建;
- 完成下一个 submodule 计算的准备;
- submodule 加入
most_recent_sub_module_step字典并做记录。
在第一个 iteration 后,通过之前创建好的 trace,之后计算过程中会按照 trace 中的顺序,从当前 step 进行参数的 fetch 和 eager prefetch。
通过以上完整的四个步骤,就实现了一个 submodule 在 Pre-Forward Hook 中的操作。在实际过程中,由于 module 可以逐层分成多个 submodule,整个 module 的 forward 过程中会不断对各 submodule 重复以上操作。
- 前向传播中 Model Parameter 的分割释放(Post-Forward Hook)

Post-Forward Hook
当 submodule 完成正向传播计算后,post_forward_hook 会释放掉当前 submodule 的参数,参数也会再次被 partition。但与初始化时的 partition 不同,此时每个进程中已经有了自己的小段 data,所以此时 partition 只需要把计算前重建的完整大 tensor 再次释放掉:
# param.data does not store anything meaningful in partitioned state
param.data = torch.ones(1, dtype=self.dtype).to(param.device)通过这样的方式,每个进程中的 submodule 只需要在计算前收集参数、计算后释放参数,从而大大减少冗余空间占用。
当 module 所有的 submodule 都完成正向传播后,engine 会将记录 submodule 执行顺序的 step_id 重新归为 0,重新回到整个计算 trace 的最初起点,准备下一次计算流程的开始。
3.3.4 反向传播中的 ZeRO-3
- 反向传播中 Model Parameter 的获取(Pre-Backward Hook)

Backward Hooks
pre-backward_hook 同样通过 record_trace、fetch_sub_module、prefetch_next_sub_modules 和 next_step 实现过程记录、参数获取,并为下一步做准备。
然而,由于 PyTorch 不支持 Pre-Backward Hook,因此需要曲线救国:使用 register_forward_hook 挂上一个 autograd.Function,从而在 module backward 之前执行自定义操作。在 backward 前,参数收集和分割的操作通过 torch.autograd.Function 挂在各个 submodule 的 tensor 上。
当该 tensor 反向传播计算时,autograd 的 backward 会调用 ctx.pre_backward_function(ctx.module),依次完成:
record_tracefetch_sub_moduleprefetch_next_sub_modulesnext_step
这四步操作与 Pre-Forward Hook 中的四步一致。
- 反向传播中 Model Parameter 的分割释放(Post-Backward Hook)
当 backward 结束后,Post-Backward Hook 中的 Post-Backward Function 也会和 post_forward_function 一样将 parameter 释放,从而减少 model parameter 的空间占用[^3]。
3.3.5 Evaluation

ZeRO Evaluation[^1]
ZeRO 在 Stage 2 时就在如下四个方面表现出色。
ZeRO-R optimizes activation memory by identifying and removing activation replication in existing MP approaches through activation partitioning. It also offloads activations to CPU when appropriate.
ZeRO-2 和 ZeRO-R 配合可以支持高达 1700 亿参数的模型训练:
- 模型规模:相较于 Megatron 局限于 40B 参数,ZeRO-2 和 ZeRO-R 的组合可以支持多达 1700 亿参数的模型训练,是当时 SOTA 方式的 8 倍。
- 训练速度:在 400 张 NVIDIA V100 GPU 集群上,ZeRO 可以将 100B 参数量的模型训练速度提升近 10 倍,达到 38 TFLOPS/GPU,总体高达 15 Petaflops。
- 可扩展性:在 64-400 个 GPU 区间,ZeRO 使训练速度具备超 GPU 增量的加速比。Model States 内存占用的减少支持了更大的 batch size 训练,从而提升模型的整体性能。
- 易用性:数据和模型开发人员无需做任何模型并行即可训练高达 130 亿参数的模型,从而减少模型重构带来的成本开销[^1]。
在 ZeRO-3 的加持下,ZeRO Optimization 的性能会得到进一步提升。
ZeRO-3 可以在单纯数据并行的模式下,实现在 1024 个 GPU 上训练超过 1 Trillion 的模型。配合模型并行,ZeRO 通过 16 路模型并行和 64 路数据并行,更是支持高达超过 2 Trillion 的模型训练[^1]。
4. 未来展望:ZeRO 的局限与大模型训练的方向
4.1 ZeRO 的局限性
ZeRO 在每个 submodule 的前向和反向传播中进行了参数的 collection 与 partition。在这种策略下:
- 单个 submodule 在前向或反向传播中所占用的显存(参数、梯度、Outputs、Workspace)需小于单个 GPU 的容量。
- 频繁利用通信来传递参数、梯度等信息,导致通信成为瓶颈。
4.1.1 大 Layer
例如 Transformer 模型中的一个
为了解决超大 Layer 这一难题,研究人员在 ZeRO 基础之上引入了对单层 Layer 的拆分技术,也就是俗称的模型并行。这里简单提一下两个比较有意思的工作:
- Megatron-LM[^5] 充分利用了 Transformer 的模型结构,对多个 GEMM 进行了高效的拆分。在 MLP 中,以纵向并行的方式划分第一个 GEMM,后续的 GeLU 与第二个 GEMM 只在本地进行,唯一的通信在 Dropout 前对第二个 GEMM 的输出做加和。通过这样的方式,GEMM 可以被分到不同的 GPU 上,并只需在正向和反向传播时各做一次 All-Reduce。对于 Self-Attention 模块也使用了类似的拆分方法,核心仍是利用分块矩阵乘法。

Megatron-LM Structure
- Optimus[^7] 同样利用了 Transformer 模型矩阵乘法的本质,但不在行列维度上分割矩阵,而是采用二维矩阵分割,并在理论效率上显著超过前者。
4.1.2 大通信

流水线并行
通信问题主要考虑引入流水线并行来缓解。流水线并行将模型按层切分成多个 Stage,每个 Worker 只持有一部分 Layer。切分后,不但每张卡上的参数和计算量减少了,同时 Worker 之间也只需要通信临界层的 Activations。
对于 Transformer 模型来说,临界层的 Activations 大小远远小于参数、梯度的大小,因此可以采用节点间做流水线并行,节点内多卡做数据并行的方式来缓解节点间的通信压力,同时充分利用节点内的超高带宽。也可以将数据并行分为两级:一级在节点内做通信量较大的 ZeRO 数据并行,另一级在多个流水线并行组间做普通的数据并行。
4.2 小结
细心的读者可能已经发现,将上述的流水线并行、模型并行与数据并行相融合,就形成了目前火热的 3D 混合并行。也正是 3D 混合并行支撑起了 GPT-3、盘古等千亿参数 Transformer 模型的训练。纵然 3D 混合并行威力巨大,其仍然有许多局限性,这将在之后的系列分享中再展开。
5. 引用
[^2]: Turing-NLG: A 17-billion-parameter language model by Microsoft
[^4]: KDD 2020: Hands on Tutorials: DeepSpeed - System optimizations enable training deep learning models
附录
PyTorch 的模型必须具有以下三种特性:
- 必须继承
nn.Module这个类,让 PyTorch 知道这个类是一个 Module。 - 在
__init__(self)中设置好需要的“组件”(如Conv、Pooling、Linear、BatchNorm等)。 - 在
forward(self, x)中用定义好的“组件”进行组装,就像搭积木一样把网络结构搭建出来,这样一个模型就定义好了。
根据 PyTorch 文档,nn.Module 是所有模型的基础 class,我们构建的各种模型网络也是这个 nn.Module 的 subclass,并且每个 Module 也可以包含其他的 Module。
"All network components should inherit from nn.Module and override the forward() method. That is about it, as far as the boilerplate is concerned. Inheriting from nn.Module provides functionality to your component."
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5) # submodule: Conv2d
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))PyTorch 给出的上述例子中,class Model 继承了 nn.Module,其内部两个 nn.Conv2d 各自也继承了 nn.Module,nn.Conv2d 就是 class Model 的 submodule。在 Stage 3 中,ZeRO 正是利用了 module 的这种嵌套特性来实现模型参数的记录与并行。