Skip to content

大模型训练 MFU 计算方法

1. 概述

在深度学习领域,MFU(Model FLOPs Utilization,模型浮点运算利用率) 是衡量大模型训练效率的核心指标。它描述了模型在实际训练过程中有效利用硬件理论计算能力的程度。对于 GPT 和 Llama 等参数量巨大的 Transformer 模型,监控并优化 MFU 是提升训练吞吐量、降低计算成本的关键。

2. MFU 的核心概念

2.1 定义

MFU 定义为模型在实际训练过程中每秒执行的有效浮点运算次数与硬件理论峰值算力的比值:

MFU=Actual FLOPsTheoretical Peak FLOPs

2.2 关键要素

  • Actual FLOPs(实际达到的 FLOPs):指模型执行前向传播(Forward)、反向传播(Backward)及优化器更新等所需的所有计算。
  • Theoretical Peak FLOPs(硬件理论峰值):硬件(如 NVIDIA H100 或 Ascend 910)在特定精度(如 FP16/BF16)下的标称最大吞吐量。

2.3 核心意义

  • 诊断瓶颈:低 MFU 通常意味着存在通信带宽瓶颈(Communication Bound)、内存带宽限制(Memory Bound)或算子实现效率低下。
  • 优化指导:通过调整 Batch Size、并行策略(TP/PP/DP)或使用混合精度训练来逼近硬件极限。

3. Transformer 模型的 FLOPs 推导

以标准的 GPT 架构为例,详细推导其在训练过程中每步所需的浮点运算量(FLOPs),为评估大模型训练效率提供理论基础。

3.1 FLOPs 定义

  • FLOPs:表示完成一次矩阵乘法所需的浮点操作次数。
  • 对于一个 Am×k×Bk×n 的矩阵乘法:
    • 所需 FLOPs 数量为:2×m×k×n
    • 包含:
      • m×k×n 次乘法操作;
      • m×k×(n1) 次加法操作;
      • 总计约 2mkn

3.2 Transformer 层的 FLOPs 分析

3.2.1 符号定义

符号含义符号含义
bBatch Size(批大小)hHidden Dimension(隐藏层维度)
sSequence Length(序列长度)VVocabulary Size(词表大小)
LNumber of Layers(层数)4hMLP 中间层维度

3.2.2 Self-Attention 模块 FLOPs

3.2.3 公式

Q,K,V=xWQ,xWK,xWVS=softmax(QKh)C=SVO=CWO

3.2.4 FLOPs 计算

操作输入形状输出形状FLOPs
[b,s,h][b,s,h](Q/K/V 投影)[b,s,h]×[h,h][b,s,h]3×2bsh2=6bsh2
[b,s,h][b,s,s]QK[b,s,h]×[b,h,s][b,s,s]2bs2h
[b,s,s][b,s,h]SV[b,s,s]×[b,s,h][b,s,h]2bs2h
[b,s,h][b,s,h](投影回 hidden)[b,s,h]×[h,h][b,s,h]2bsh2

总计 Attention FLOPs

6bsh2+2bs2h+2bs2h+2bsh2=8bsh2+4bs2h

3.2.5 MLP 模块 FLOPs

3.2.6 结构

  • 升维:h4h
  • GeLU 激活
  • 降维:4hh

3.2.7 FLOPs 计算

操作输入输出FLOPs
[b,s,h][b,s,4h][b,s,h]×[h,4h]2bsh4h=8bsh2
[b,s,4h][b,s,h][b,s,4h]×[4h,h]2bs4hh=8bsh2

总计 MLP FLOPs

8bsh2+8bsh2=16bsh2

3.3 单层 Transformer 总 FLOPs

忽略轻微的逐元素操作(如 LayerNorm、Softmax、GeLU),单层前向传播总 FLOPs 为:

FLOPsforward_layer=24bsh2+4bs2h

3.4 Vocabulary Embedding FLOPs

3.4.1 操作

  • 将 token ID 映射为 embedding 向量。
  • 形状变换:[b,s][b,s,h][b,s,V]

3.4.2 FLOPs

[b,s,h]×[h,V]2bshV

注意:此步骤通常只在前向传播中执行一次,但在反向传播中也有类似计算。

3.5 模型总 FLOPs

假设 GPT 模型有 L 层 Transformer 层,则:

组件FLOPs 公式
Self-Attention8bsh2+4bs2h
MLP16bsh2
Embedding2bshV
单个 Transformer 层24bsh2+4bs2h
整模型总 FLOPsL×(24bsh2+4bs2h)+2bshV

4. 反向传播与全流程计算

4.1 梯度计算原理

在反向传播(Backward)中,对于每一个线性变换 y=wx,需要计算:

  • 对权重的梯度:Lw=Lyx → 一次矩阵乘法;
  • 对输入的梯度:Lx=Lyw → 一次矩阵乘法。

结论:在训练中,反向传播的计算量约为前向传播的 2 倍

4.2 训练总 FLOPs

包含前向和反向传播的单步总计算量约为:

Total FLOPs3×Forward FLOPs

注意:若考虑重计算/激活值检查点(Recompute),倍数会增至约 4 倍。

4.3 经验简化公式

对于大模型,当 hs 时,公式可近似为:

Total FLOPs72bsh2L(1+s6h+V12hL)

5. Llama 系列模型的 MFU 特化计算

Llama 模型在标准 Transformer 基础上引入了 GQA(Grouped Query Attention)SwiGLU 激活函数,其计算分布略有不同。

5.1 模型结构简要说明

  • 使用 Llama 2 架构:包含多层 Transformer 解码器。
  • 关键组件:
    • Attention 模块:含 RoPE、QKV 投影、Softmax、输出投影;
    • FeedForward 模块:含两个线性层与 SiLU 激活函数;
    • RMSNorm:归一化操作;
    • lm_head:最终输出层。

注意WqWkWv 分别为 Q/K/V 的权重矩阵;Wo 为输出投影;W1W2W3 为 FFN 中的权重。

5.1.1 Llama 架构差异点

  • GQA 影响KV 的投影维度降为 h/rr 为 head 比例系数)。
  • SwiGLU 影响:FFN 层由三个线性矩阵 W1W2W3 组成,中间维度为 h^(通常为 83h)。

5.1.2 符号定义

符号含义
bbatch_size
Lnum_layers(层数)
sseq_length(序列长度)
hhidden_size(隐藏维度)
nnum_heads(注意力头数)
dhead_dim(每个头的维度)
h^intermediate_size(FFN 中间维度)
vvocab_size(词表大小)
rrepeat(重复因子,通常为 n/nkv
mffn_dim_multiplier(FFN 扩展倍数,一般为 4)

关系式

h=n×dh^83h×m

5.2 各模块 FLOPs 详细分析

5.2.1 Attention 模块(×L 层)

操作输入 输出FLOPs
Wq:(b,s,h)(b,s,h)[b,s,h]×[h,h]2bsh2
Wk:(b,s,h)(b,s,h/r)[b,s,h]×[h,h/r]2bsh2/r
Wv:(b,s,h)(b,s,h/r)[b,s,h]×[h,h/r]2bsh2/r
QK:(b,n,s,d)(b,n,s,s)[b,n,s,d]×[b,n,d,s]2bs2h
scoreV:(b,n,s,s)(b,n,s,d)[b,n,s,s]×[b,n,s,d]2bs2h
Wo:(b,s,h)(b,s,h)[b,s,h]×[h,h]2bsh2

总计 Attention FLOPs(每层)

2bsh2+2bsh2r+2bsh2r+2bs2h+2bs2h+2bsh2=4bsh2+4bsh2r+4bs2h

5.2.2 FeedForward 模块(×L 层)

操作输入 输出FLOPs
W3:(b,s,h)(b,s,h^)[b,s,h]×[h,h^]2bshh^
W1:(b,s,h)(b,s,h^)[b,s,h]×[h,h^]2bshh^
W2:(b,s,h^)(b,s,h)[b,s,h^]×[h^,h]2bshh^

总计 FFN FLOPs(每层)

6bshh^

5.2.3 lm_head 模块

操作输入 输出FLOPs
Whead:(b,s,h)(b,s,v)[b,s,h]×[h,v]2bshv

注意:该操作仅在最后一步执行,但对总 FLOPs 贡献显著。

5.3 总 FLOPs 推导

5.3.1 前向传播总 FLOPs(单次)

FLOPs=L×[4bsh2+4bsh2r+4bs2h+6bshh^]+2bshv

提取公因子 4Lbsh2

FLOPs=4Lbsh2(1+1r+sh+3h^2h+v2Lh)

代入 h^83hm

FLOPs4Lbsh2(1+1r+4m+sh+v2Lh)

进一步近似:

  • 1rnnkv,常忽略或合并;
  • 当模型规模较大时,v2Lh 项相对较小。

5.3.2 反向传播 FLOPs

  • 反向传播 FLOPs 前向的 2 倍
  • 因此,总训练 FLOPs 3 × 前向 FLOPs

提示:MFU 计算时需使用总 FLOPs(前向 + 反向 + 优化器),但此处以前向为主进行估算。

5.4 实例分析:Llama2-70B

给定参数:

  • s=4096
  • h=8192
  • L=80
  • r=8
  • h^=28672
  • v=32000

代入公式:

FLOPs4×80×b×81922(1+18+4×4+40968192+320002×80×8192)

对于 70B 模型,由于参数量极大,其 6bs×Params 的估算方法非常接近精确值。在计算 MFU 时,需精确代入 h=8192h^=28672 等超参数,以获得准确的算力需求分析。

6. 总结与优化建议

优化方向操作手段对 MFU 的影响
计算密度增大 Batch Size显著提升,减少算子调度和通信开销
通信优化调整 TP/PP 比例降低通信延迟,减少算力闲置
算子融合使用 FlashAttention降低显存带宽压力,提高计算利用率
精度转换使用 BF16/FP8提升单位时间内的理论吞吐量

专家目标:在千卡规模的集群训练中,将 MFU 稳定在 55% - 70% 是大模型训练达到工业级性能的标志。

Maintained by Robin