Skip to content

PyTorch 分布式训练详解

PyTorch 的分布式计算能力核心在于 torch.distributed 包,其底层基于 Collective Communication(c10d)库实现高效的跨进程张量通信。本文将深入剖析分布式数据并行训练(DDP)的通信机制、环境配置、核心概念及底层通信原语,帮助读者高效地应用 PyTorch 进行大规模模型训练。

1. 核心通信机制

PyTorch 的分布式通信主要依赖 torch.distributed 包,它支持以下两类通信 API:

  • 集合通信(Collective Communication APIs):用于所有进程之间进行协调和数据交换,是分布式数据并行训练(Distributed Data-Parallel,DDP)的核心支撑。
  • 点对点通信(P2P Communication APIs):用于两个指定进程之间进行直接数据传输,是基于 RPC 的分布式训练(RPC-Based Distributed Training)的基础。

本教程侧重于分布式数据并行训练(DDP)的通信机制与 API 使用。

2. 分布式训练基础模板

torch.distributed 包通过消息传递语义(Message Passing Semantics)将计算任务并行化到多个进程,支持单机多卡和多机集群环境。这与仅支持单机多进程的 torch.multiprocessing 包有本质区别——分布式包支持跨机器通信,并可选择多种通信后端。

2.1 单节点多进程训练模板

以下是一个基础的分布式训练模板,用于在单机上启动多个训练进程。

该模板通过 torch.multiprocessing 启动 2 个独立的进程,每个进程都完成分布式环境配置进程组初始化,随后执行训练函数 run

python
"""run.py: 基础分布式启动脚本"""
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int):
    """
    具体的分布式训练逻辑实现函数。

    :param rank: 当前进程的唯一标识(0 到 world_size-1)。
    :param world_size: 参与训练的总进程数。
    """
    # 核心训练逻辑在此处实现
    print(f"进程 {rank} 已启动,总进程数:{world_size}")
    # 示例:执行一个 barrier 同步操作
    dist.barrier()
    print(f"进程 {rank} 同步完成。")


def init_process(rank: int, world_size: int, fn, backend: str = "gloo"):
    """
    初始化分布式环境(进程组)。
    """
    # 使用环境变量方式协调进程,这是 torch.distributed 推荐的方式
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # 主节点 IP 地址
    os.environ["MASTER_PORT"] = "29500"      # 主节点通信端口

    # 初始化进程组
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size)  # 执行训练函数


if __name__ == "__main__":
    world_size = 2  # 总进程数
    mp.set_start_method("spawn")  # 使用 spawn 启动方法
    processes = []

    for rank in range(world_size):
        # 为每个 rank 启动一个独立的进程
        p = mp.Process(target=init_process, args=(rank, world_size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()  # 阻塞主进程直到所有子进程结束

    dist.destroy_process_group()  # 清理资源(如果不在子进程内调用)

执行流程分析:

  1. 环境配置与进程组初始化(dist.init_process_group
  2. 分布式训练函数执行
  3. 进程同步与资源清理

3. 核心概念解析

3.1 DDP vs. DataParallel

在实践中,我们通常选择 torch.nn.parallel.DistributedDataParallel(DDP) 而非 torch.nn.DataParallel(DP),原因在于 DDP 在性能和扩展性上具有显著优势。

特性torch.nn.DataParallel(DP)torch.nn.parallel.DistributedDataParallel(DDP)
进程模型单进程多线程(主从结构)多进程(对等结构)
适用范围仅支持单机多卡支持单机和多机
通信机制基于线程的张量拷贝和广播基于集合通信的梯度同步
性能瓶颈受 Python GIL 限制,主 GPU 负载高,通信开销大无 GIL 限制,负载均衡,通信开销更小
内存/优化器所有 GPU 共享一个 Python 解释器和优化器每个进程(通常对应一个 GPU)拥有独立的解释器和优化器
扩展性不支持与模型并行结合支持与模型并行结合使用

3.2 DDP 核心优势

  1. 性能优化:多进程架构消除 GIL 瓶颈,显著提升计算密集型任务性能。
  2. 通信效率:梯度同步通过高效的集合通信(如 All-Reduce)实现,避免额外的参数广播步骤,降低通信开销。
  3. 扩展能力:支持模型并行与数据并行混合策略,可处理超大规模模型。

4. torch.distributed 通信后端

PyTorch 分布式包的一大优势是其抽象设计,允许用户根据硬件和场景选择最合适的通信后端。主要实现的后端包括 Gloo、NCCL 和 MPI。

4.1 通信后端

4.1.1 NCCL 后端

  • GPU 优化:专为 NVIDIA GPU 设计的集合通信库。
  • 性能领先:在 GPU 上提供最优通信性能。
  • 原生支持:集成于 CUDA 版本的 PyTorch。

4.1.2 Gloo 后端

  • CPU 通信:完整支持点对点与集合通信。
  • GPU 通信:支持集合通信,但性能逊于 NCCL。
  • 适用场景:CPU 训练或 GPU 训练的回退方案。

4.1.3 MPI 后端

  • 标准化实现:基于消息传递接口标准。
  • 高性能计算:在大规模集群上经过深度优化。
  • 高级特性:部分实现支持 CUDA IPC 和 GPUDirect 技术。
后端适用场景优势劣势/限制
NCCLGPU 训练(推荐)为 CUDA 张量提供高度优化的集合通信,性能最佳。仅支持集合通信,通常不用于 CPU。
GlooCPU 训练(推荐)支持所有点对点和集合通信操作。GPU 性能不如 NCCL,常作为备选。
MPI高性能计算集群广泛可用性,在特定集群上高度优化。需要手动编译 PyTorch 启用

4.2 后端功能对比

下表对比了各种通信后端在不同设备上的支持情况(以集合通信为例):

集合操作Gloo(CPU)Gloo(GPU)NCCL(GPU)
broadcast
all_reduce
reduce
all_gather
reduce_scatter××
all_to_all××
barrier

4.3 MPI 后端补充说明

MPI(Message Passing Interface)是高性能计算的标准,也是 torch.distributed API 设计的参考。

使用方式:使用 MPI 后端时,进程的生成和管理由 mpirun 等 MPI 启动工具负责,因此不需要在 Python 脚本中手动使用 mp.Process 或设置 rankworld_size

启动示例

  1. 训练脚本中替换为 init_process(0, 0, run, backend="mpi")
  2. 使用 mpirun -n 4 python myscript.py 启动 4 个进程。

5. 初始化分布式环境

在使用任何分布式功能之前,必须通过初始化进程组来建立进程间的通信通道。

5.1 核心初始化函数

python
# 核心初始化函数
torch.distributed.init_process_group(
    backend="nccl",           # 通信后端
    init_method="env://",     # 初始化方法
    rank=rank,                # 进程排名
    world_size=world_size,    # 总进程数
)

# 状态检查函数
torch.distributed.is_available()              # 分布式功能可用性
torch.distributed.is_initialized()            # 进程组初始化状态
torch.distributed.is_torchelastic_launched()  # TorchElastic 启动检测
函数描述
dist.init_process_group()初始化默认进程组,阻塞执行直到所有进程加入。
dist.destroy_process_group()销毁默认进程组,释放资源。
dist.is_available()检查分布式包是否可用。
dist.is_initialized()检查默认进程组是否已初始化。

⚠️ 重要提示init_process_group 不是线程安全的。进程组的创建必须在单个线程中执行,以避免竞争条件和进程挂起。

5.2 关键参数

  • backend:指定通信后端(如 "nccl""gloo")。
  • world_size:参与训练的总进程数。
  • rank:当前进程的唯一标识(0rank<world_size)。
  • init_method:指定进程组的初始化方式(URL 字符串),默认为 "env://"

5.3 初始化方法(init_method

PyTorch 支持三种初始化方法,用于协调进程间的连接信息:

5.3.1 环境变量初始化(env://)——推荐

这是最常用的方法,通过设置环境变量 MASTER_ADDRMASTER_PORTWORLD_SIZERANK 来协调进程。这是 PyTorch 官方启动工具(如 torchrun)推荐的方法。

python
# 脚本启动前设置环境变量或在代码中设置 os.environ
os.environ["MASTER_ADDR"] = "10.1.1.20"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl")  # 自动从环境中读取配置

5.3.2 TCP 初始化(tcp://

通过指定 rank 0 进程的 IP 地址和端口进行初始化:

python
dist.init_process_group(
    backend="nccl",
    init_method="tcp://10.1.1.20:23456",  # rank 0 的地址
    rank=args.rank,
    world_size=4,
)

5.3.3 共享文件系统初始化(file://

利用所有进程都可访问的共享文件路径来协调通信:

python
dist.init_process_group(
    backend="gloo",
    init_method="file:///path/to/shared/storage/directory/shared_file",
    world_size=4,
    rank=args.rank,
)

注意:使用文件系统初始化时,每次初始化都应使用一个新的或为空的文件,并在完成后清理该文件。

5.4 资源清理

在训练结束时,应调用 dist.destroy_process_group() 来清理资源。这对于使用 NCCL 后端尤其重要,有助于确保 NCCL 通信器的销毁过程(ncclCommAbort)在所有 rank 上以一致的顺序执行,避免进程挂起。

6. 分布式通信原语

初始化完成后,即可使用通信原语进行进程间的数据交换。

6.1 点对点(P2P)通信

P2P 通信用于两个指定进程之间的数据传输,可分为同步和异步模式。

Send and Recv

6.1.1 同步通信(send() / recv()

  • dist.send(tensor, dst):将张量发送到目标进程(dst)。
  • dist.recv(tensor, src):从源进程(src)接收张量,并存储到预分配的 tensor 中。
  • 特性:阻塞调用,直到通信完成。
python
def p2p_sync_demo(rank: int, size: int):
    """同步点对点通信示例"""
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        dist.send(tensor=tensor, dst=1)  # 阻塞发送
    else:
        dist.recv(tensor=tensor, src=0)  # 阻塞接收
    print(f"Rank {rank} received data: {tensor[0]}")

6.1.2 异步通信(isend() / irecv()

  • dist.isend(tensor, dst):异步发送,返回一个 Request 对象。
  • dist.irecv(tensor, src):异步接收,返回一个 Request 对象。
python
def p2p_async_demo(rank: int, size: int):
    """异步点对点通信示例"""
    tensor = torch.zeros(1)
    request = None

    if rank == 0:
        tensor += 1
        request = dist.isend(tensor=tensor, dst=1)  # 非阻塞发送
    else:
        request = dist.irecv(tensor=tensor, src=0)  # 非阻塞接收

    # 等待通信完成
    request.wait()
    print(f"Rank {rank} has data: {tensor[0]}")

异步通信约束

  • wait() 完成前禁止修改发送张量。
  • wait() 完成前禁止访问接收张量。
  • 违反约束将导致未定义行为。

6.2 集合通信(Collective Communication)

集合通信涉及进程组内的所有进程,是 DDP 的核心机制。所有进程必须参与并使用相同的通信原语。

6.2.1 broadcast(广播)

  • 功能:将源进程(src)的张量复制到进程组内所有其他进程。
Source Process iAll Processes j广播

6.2.2 gather(收集)

  • 收集所有进程的 tensor 到目标进程的 gather_list 中。
  • 非目标进程的 gather_list 必须为 None
收集

6.2.3 scatter(散播)

  • 将源进程的 scatter_list 分发到所有进程的 tensor 中。
  • 源进程的 scatter_list 必须是长度为进程数的列表。
散播

6.2.4 reduce(归约)

  • 功能:对组 G 内所有进程的张量执行指定操作 op(如求和、取最大值等),并将结果存储在目标进程的张量中。
  • 应用:通常用于多进程计算梯度后的梯度聚合(但 DDP 使用 all_reduce)。
归约

6.2.5 all_reduce(全归约)

  • 功能:对所有进程的张量执行指定的归约操作(如求和 Σ、取最大值 max),并将最终结果存储回所有进程的张量中。
  • 用途:这是 DDP 中同步梯度的核心操作。
tensorji=1Ntensori,j{1,,N}全归约

6.2.6 all_gather(全收集)

  • 功能:收集所有进程的张量,并将一个包含所有进程张量副本的列表发送给所有进程。
  • 用途:常用于同步 Batch Normalization 层统计信息,或收集所有进程的预测结果。
tensor_listj{tensor1,tensor2,,tensorN},j{1,,N}全收集

6.2.7 reduce_scatter(归约-分散)

  • 功能:结合 reducescatter。首先对所有进程的输入张量进行归约,然后将结果分散到各个进程的输出张量中。
  • 用途:是 ZeRO/Fully Sharded Data Parallel(FSDP) 等参数分片训练策略中的关键操作。

reduce_scatter

6.2.8 barrier(屏障)

  • 功能:阻塞所有进程,直到进程组中的所有进程都达到屏障点后才继续执行。
  • 用途:用于确保严格的进程同步。

7. 分布式训练启动工具

在实际生产环境中,通常不手动编写 mp.Process 启动脚本,而是使用 PyTorch 提供的启动工具。

7.1 torchrun(推荐)

torchrun 是 PyTorch 官方推荐的分布式训练启动器。它支持弹性(elasticity)和容错(fault tolerance),能够自动处理进程的创建、初始化和资源分配。

7.1.1 单节点多 GPU 训练

bash
# 假设有 4 个 GPU
torchrun --nproc_per_node=4 train_script.py [训练脚本参数]

7.1.2 多节点分布式训练(例如 2 个节点)

在每个节点上运行相同的 torchrun 命令,但需指定唯一的 node_rank

bash
# 节点 1(rank=0)
torchrun --nproc_per_node=GPU数量 \
         --nnodes=2 \
         --node_rank=0 \
         --master_addr="192.168.1.1" \
         --master_port=29500 \
         train_script.py [训练脚本参数]

# 节点 2(rank=1)
torchrun --nproc_per_node=GPU数量 \
         --nnodes=2 \
         --node_rank=1 \
         --master_addr="192.168.1.1" \
         --master_port=29500 \
         train_script.py [训练脚本参数]

7.2 训练脚本的简化配置

当使用 torchrun 启动时,它会自动设置所有必需的环境变量(RANKWORLD_SIZELOCAL_RANK 等),因此训练脚本的初始化可以大大简化:

python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_distributed_model(model: torch.nn.Module) -> DDP:
    """
    使用 torchrun / env:// 方式设置分布式环境和 DDP 模型。
    """
    # 1. 初始化进程组(使用默认的 env:// 方法)
    # torchrun 已自动设置环境变量
    dist.init_process_group(backend="nccl")

    # 2. 获取当前进程的本地排名并设置设备
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # 3. 创建 DistributedDataParallel 模型
    ddp_model = DDP(
        model.to(local_rank),
        device_ids=[local_rank],
        output_device=local_rank,
    )
    return ddp_model


# 在训练代码的末尾,添加资源清理
# dist.destroy_process_group()

7.3 最佳实践总结

  1. 后端选择:始终优先使用 NCCL 后端进行 GPU 训练。
  2. 初始化方法:统一采用 env:// 环境变量方式。
  3. 设备管理:正确设置 local_rank,确保 GPU 设备隔离。
  4. 资源管理:训练完成后调用 destroy_process_group() 清理资源。
  5. 进程管理:使用 torchrun 作为分布式训练启动工具。

8. Reference

  1. PyTorch Distributed Training
  2. PyTorch Distributed Training Tutorial

Maintained by Robin