Skip to content

🚀 PyTorch 分布式与并行训练:深度解析

PyTorch 的分布式计算能力核心在于 torch.distributed 包,其底层基于 Collective Communication (c10d) 库实现了高效的跨进程张量通信。本文将重点对 分布式数据并行训练 (DDP) 的通信机制、环境配置、核心概念以及底层通信原语进行深入剖析,以确保读者能够专业、高效地应用 PyTorch 进行大规模AI模型训练。

核心通信机制

PyTorch 的分布式通信主要依赖 torch.distributed 包,它支持以下两种通信 API:

  • 集合通信 (Collective Communication APIs):用于所有进程之间进行协调和数据交换,是 分布式数据并行训练 (Distributed Data-Parallel, DDP) 的核心支撑。
  • 点对点通信 (P2P Communication APIs):用于两个指定进程之间进行直接数据传输,是 基于 RPC 的分布式训练 (RPC-Based Distributed Training) 的基础。

本教程将侧重于分布式数据并行训练 (DDP) 的通信机制和 API 使用。

1. 分布式训练基础模板

torch.distributed 包通过消息传递语义(Message Passing Semantics)使研究人员和开发者能够将计算任务轻松并行化到多个进程,无论是单机多卡还是多机集群环境。这与仅支持单机多进程的 torch.multiprocessing 包有本质区别,分布式包支持跨机器通信,并且可支持多种通信后端。

1.1 单节点多进程训练模板

以下是一个基础的分布式训练模板,用于在单机上启动多个训练进程。

此模板通过 torch.multiprocessing 启动了 $2$ 个独立的进程,每个进程都完成了分布式环境配置进程组初始化,随后执行训练函数 run

python
"""run.py: 基础分布式启动脚本"""
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank: int, world_size: int):
    """
    具体的分布式训练逻辑实现函数。
    :param rank: 当前进程的唯一标识(0 到 world_size-1)。
    :param world_size: 参与训练的总进程数。
    """
    # 核心训练逻辑在此处实现
    print(f"进程 {rank} 已启动,总进程数:{world_size}")
    # 示例:执行一个 barrier 同步操作
    dist.barrier()
    print(f"进程 {rank} 同步完成。")


def init_process(rank: int, world_size: int, fn, backend: str = 'gloo'):
    """
    初始化分布式环境(进程组)。
    """
    # 使用环境变量方式协调进程,这是 torch.distributed 推荐的方式
    os.environ['MASTER_ADDR'] = '127.0.0.1'  # 主节点 IP 地址
    os.environ['MASTER_PORT'] = '29500'      # 主节点通信端口

    # 初始化进程组
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size) # 执行训练函数


if __name__ == "__main__":
    world_size = 2  # 总进程数
    mp.set_start_method("spawn") # 使用 'spawn' 启动方法
    processes = []

    for rank in range(world_size):
        # 为每个 rank 启动一个独立的进程
        p = mp.Process(target=init_process, args=(rank, world_size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join() # 阻塞主进程直到所有子进程结束

    dist.destroy_process_group() # 清理资源(如果不在子进程内调用)

执行流程分析:

  1. 环境配置与进程组初始化 (dist.init_process_group)
  2. 分布式训练函数执行
  3. 进程同步与资源清理

2. 核心概念解析

DDP vs. DataParallel

在实践中,我们通常选择 torch.nn.parallel.DistributedDataParallel (DDP) 而非 torch.nn.DataParallel (DP),原因在于 DDP 在性能和扩展性上具有显著优势。

特性torch.nn.DataParallel (DP)torch.nn.parallel.DistributedDataParallel (DDP)
进程模型单进程多线程(主从结构)多进程(对等结构)
适用范围仅支持单机多卡支持单机和多机
通信机制基于线程的张量拷贝和广播基于集合通信的梯度同步
性能瓶颈受 Python GIL 限制,主 GPU 负载高,通信开销大无 GIL 限制,负载均衡,通信开销更小
内存/优化器所有 GPU 共享一个 Python 解释器和优化器每个进程(通常对应一个 GPU)拥有独立的解释器和优化器
扩展性不支持与模型并行结合支持与模型并行结合使用

DDP 核心优势

  1. 性能优化:多进程架构消除 GIL 瓶颈,显著提升计算密集型任务性能
  2. 通信效率:梯度同步通过高效的集合通信(如 All-Reduce)同步梯度机制避免额外的参数广播步骤,降低通信开销
  3. 扩展能力:支持模型并行与数据并行混合策略,可处理超大规模模型

3. torch.distributed 通信后端

PyTorch 分布式包的一大优势是其抽象设计,允许用户根据硬件和场景选择最合适的通信后端。主要实现的后端包括 Gloo、NCCL 和 MPI。

3.1 通信后端

NCCL 后端

  • GPU 优化:专为 NVIDIA GPU 设计的集合通信库
  • 性能领先:在 GPU 上提供最优通信性能
  • 原生支持:集成于 CUDA 版本的 PyTorch

Gloo 后端

  • CPU 通信:完整支持点对点与集合通信
  • GPU 通信:支持集合通信,但性能逊于 NCCL
  • 适用场景:CPU 训练或 GPU 训练的回退方案

MPI 后端

  • 标准化实现:基于消息传递接口标准
  • 高性能计算:在大规模集群上经过深度优化
  • 高级特性:部分实现支持 CUDA IPC 和 GPUDirect 技术
后端适用场景优势劣势/限制
NCCLGPU 训练(推荐)为 CUDA 张量提供高度优化的集合通信,性能最佳。仅支持集合通信,通常不用于 CPU。
GlooCPU 训练(推荐)支持所有点对点和集合通信操作。GPU 性能不如 NCCL,常作为备选。
MPI高性能计算集群广泛可用性,在特定集群上高度优化。需要手动编译 PyTorch 启用

3.2 后端功能对比

下表对比了各种通信后端在不同设备上的支持情况(以集合通信为例):

集合操作Gloo (CPU)Gloo (GPU)NCCL (GPU)
broadcast$\checkmark$$\checkmark$$\checkmark$
all_reduce$\checkmark$$\checkmark$$\checkmark$
reduce$\checkmark$$\checkmark$$\checkmark$
all_gather$\checkmark$$\checkmark$$\checkmark$
reduce_scatter$\times$$\times$$\checkmark$
all_to_all$\times$$\times$$\checkmark$
barrier$\checkmark$$\checkmark$$\checkmark$

3.3 MPI 后端补充说明

MPI (Message Passing Interface) 是高性能计算的标准,是 torch.distributed API 设计的参考。

使用方式:使用 MPI 后端时,进程的生成和管理由 mpirun 等 MPI 启动工具负责,因此不需要在 Python 脚本中手动使用 mp.Process 或设置 rankworld_size

启动示例

  1. 训练脚本中替换为 init_process(0, 0, run, backend='mpi')
  2. 使用 mpirun -n 4 python myscript.py 启动 $4$ 个进程。

4. 初始化分布式环境

在使用任何分布式功能之前,必须通过初始化进程组来建立进程间的通信通道。

4.1 核心初始化函数

python
# 核心初始化函数
torch.distributed.init_process_group(
    backend='nccl',           # 通信后端
    init_method='env://',     # 初始化方法
    rank=rank,               # 进程排名
    world_size=world_size    # 总进程数
)

# 状态检查函数
torch.distributed.is_available()         # 分布式功能可用性
torch.distributed.is_initialized()         # 进程组初始化状态
torch.distributed.is_torchelastic_launched()  # Torchelastic 启动检测
函数描述
dist.init_process_group()初始化默认进程组,阻塞执行直到所有进程加入。
dist.destroy_process_group()销毁默认进程组,释放资源。
dist.is_available()检查分布式包是否可用。
dist.is_initialized()检查默认进程组是否已初始化。

⚠️ 重要提示: init_process_group 不是线程安全的。进程组的创建必须在单个线程中执行,以避免竞争条件和进程挂起。

4.2 关键参数

  • backend: 指定通信后端(如 'nccl''gloo')。
  • world_size: 参与训练的总进程数。
  • rank: 当前进程的唯一标识($0 \le \text{rank} < \text{world_size}$)。
  • init_method: 指定进程组的初始化方式(URL字符串),默认为 "env://"

4.3 初始化方法(init_method

PyTorch 支持三种初始化方法,用于协调进程间的连接信息:

1. 环境变量初始化 (env://) - 推荐

这是最常用的方法,通过设置环境变量 MASTER_ADDRMASTER_PORTWORLD_SIZERANK 来协调进程, 这是 PyTorch 官方启动工具(如 torchrun)推荐的方法。

python
# 脚本启动前设置环境变量或在代码中设置 os.environ
os.environ['MASTER_ADDR'] = '10.1.1.20'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend='nccl') # 自动从环境中读取配置

2. TCP 初始化 (tcp://)

通过指定 rank=0 进程的 IP 地址和端口进行初始化:

python
dist.init_process_group(
    backend='nccl',
    init_method='tcp://10.1.1.20:23456', # rank 0 的地址
    rank=args.rank,
    world_size=4
)

3. 共享文件系统初始化 (file://)

利用所有进程都可访问的共享文件路径来协调通信:

python
dist.init_process_group(
    backend='gloo',
    init_method='file:///path/to/shared/storage/directory/shared_file',
    world_size=4,
    rank=args.rank
)

注意: 使用文件系统初始化时,每次初始化都应使用一个新的或为空的文件,并在完成后清理该文件。

4.4 资源清理

在训练结束时,应调用 dist.destroy_process_group() 来清理资源。这对于使用 NCCL 后端尤其重要,有助于确保 NCCL 通信器的销毁过程(ncclCommAbort)在所有等级上以一致的顺序执行,避免进程挂起。

5. 分布式通信原语

初始化完成后,即可使用通信原语进行进程间的数据交换。

5.1 点对点(P2P)通信

P2P 通信用于两个指定进程之间的数据传输,可分为同步和异步模式。

Send and Recv

1. 同步通信 (send() / recv())

  • dist.send(tensor, dst):将张量发送到目标进程 (dst)。
  • dist.recv(tensor, src):从源进程 (src) 接收张量,并存储到预分配的 tensor 中。
  • 特性: 阻塞调用,直到通信完成。
python
def p2p_sync_demo(rank: int, size: int):
    """同步点对点通信示例"""
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        dist.send(tensor=tensor, dst=1)  # 阻塞发送
    else:
        dist.recv(tensor=tensor, src=0)  # 阻塞接收
    print(f'Rank {rank} received data: {tensor[0]}')

2. 异步通信 (isend() / irecv())

  • dist.isend(tensor, dst):异步发送,返回一个 Request 对象。
  • dist.irecv(tensor, src):异步接收,返回一个 Request 对象。
python
def p2p_async_demo(rank: int, size: int):
    """异步点对点通信示例"""
    tensor = torch.zeros(1)
    request = None

    if rank == 0:
        tensor += 1
        request = dist.isend(tensor=tensor, dst=1)  # 非阻塞发送
    else:
        request = dist.irecv(tensor=tensor, src=0)  # 非阻塞接收

    # 等待通信完成
    request.wait()
    print(f'Rank {rank} has data: {tensor[0]}')

异步通信约束:

  • wait() 完成前禁止修改发送张量
  • wait() 完成前禁止访问接收张量
  • 违反约束将导致未定义行为

5.2 集合通信(Collective Communication)

集合通信涉及进程组内的所有进程,是 DDP 的核心机制。所有进程必须参与并使用相同的通信原语。

1. broadcast(广播)

  • 功能: 将源进程(src)的张量复制到进程组内所有其他进程。

$$ \text{Source Process } i \to \text{All Processes } j $$

播送

2. gather (收集)

  • 收集所有进程的 tensor 到目标进程的 gather_list
  • 非目标进程的 gather_list 必须为 None
收集

3. scatter (散播)

  • 将源进程的 scatter_list 分发到所有进程的 tensor
  • 源进程的 scatter_list 必须是长度为进程数的列表
Scatter

4. reduce (归约)

  • 功能:对组 $G$ 内所有进程的张量执行指定操作 $\text{op}$(如求和、求最大值等),并将结果存储在目标进程的张量 中。

  • 应用:通常用于多进程计算梯度后的梯度聚合(但 DDP 使用 all_reduce)。

减少

5. all_reduce(全归约)

  • 功能: 对所有进程的张量执行指定的归约操作(如求和 $\Sigma$、取最大值 $\text{max}$),并将最终结果存储回所有进程的张量中。

  • 用途: 这是 DDP 中同步梯度的核心操作。

$$ \text{tensor}j \leftarrow \sum^{N} \text{tensor}_i, \quad \forall j \in {1, \dots, N} $$

全归约

6. all_gather(全聚集)

  • 功能: 收集所有进程的张量,并将一个包含所有进程张量副本的列表发送给所有进程。
  • 用途: 常用于同步 $\text{Batch Normalization}$ 层统计信息,或收集所有进程的预测结果。

$$ \text{tensor_list}_j \leftarrow {\text{tensor}_1, \text{tensor}_2, \dots, \text{tensor}_N}, \quad \forall j \in {1, \dots, N} $$

全员聚集

7. reduce_scatter(归约-分散)

  • 功能: 结合 reducescatter。首先对所有进程的输入张量进行归约,然后将结果分散到各个进程的输出张量中。
  • 用途:ZeRO/Fully Sharded Data Parallel (FSDP) 等参数分片训练策略中的关键操作。

image

8. barrier(屏障)

  • 功能: 阻塞所有进程,直到进程组中的所有进程都达到屏障点后才继续执行。
  • 用途: 用于确保严格的进程同步。

6. 分布式训练启动工具

在实际生产环境中,通常不手动编写 mp.Process 启动脚本,而是使用 PyTorch 提供的启动工具。

6.1 torchrun (推荐)

torchrun是 PyTorch 官方推荐的分布式训练启动器。它支持弹性(elasticity)和容错(fault tolerance),能够自动处理进程的创建、初始化和资源分配。

1. 单节点多 GPU 训练

bash
# 假设有 4 个 GPU
torchrun --nproc_per_node=4 train_script.py [训练脚本参数]

2. 多节点分布式训练(例如,2 个节点)

在每个节点上运行相同的 torchrun 命令,但需指定唯一的 node_rank

bash
# 节点 1 (rank=0)
torchrun --nproc_per_node=GPU数量 \
         --nnodes=2 \
         --node_rank=0 \
         --master_addr="192.168.1.1" \
         --master_port=29500 \
         train_script.py [训练脚本参数]

# 节点 2 (rank=1)
torchrun --nproc_per_node=GPU数量 \
         --nnodes=2 \
         --node_rank=1 \
         --master_addr="192.168.1.1" \
         --master_port=29500 \
         train_script.py [训练脚本参数]

6.2 训练脚本的简化配置

当使用 torchrun 启动时,它会自动设置所有必需的环境变量(RANK, WORLD_SIZE, LOCAL_RANK 等),因此训练脚本的初始化可以大大简化:

python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed_model(model: torch.nn.Module) -> DDP:
    """
    使用 torchrun/env:// 方式设置分布式环境和 DDP 模型。
    """
    # 1. 初始化进程组(使用默认的 env:// 方法)
    # torchrun 已自动设置环境变量
    dist.init_process_group(backend='nccl')

    # 2. 获取当前进程的本地排名并设置设备
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # 3. 创建 DistributedDataParallel 模型
    ddp_model = DDP(model.to(local_rank),
                    device_ids=[local_rank],
                    output_device=local_rank)
    return ddp_model

# 在训练代码的末尾,添加资源清理
# dist.destroy_process_group()

最佳实践总结

  1. 后端选择: 始终优先使用 NCCL 后端进行 GPU 训练
  2. 初始化方法: 统一采用 env:// 环境变量方式
  3. 设备管理: 正确设置 local_rank 确保 GPU 设备隔离
  4. 资源管理: 训练完成后调用 destroy_process_group() 清理资源
  5. 进程管理: 使用 torchrun 作为分布式训练启动工具

Reference:

  1. PyTorch Distributed Training
  2. PyTorch Distributed Training Tutorial

Maintained by Robin