Skip to content

昇腾 NPU 融合注意力算子详解与迁移指南

1. 背景与动机

在 Transformer 类模型(如 Llama、GPT、BERT 等)中,自注意力机制(Self-Attention)是计算量与内存占用的核心瓶颈。传统的 PyTorch 实现通常由 softmaxdropoutmatmul 以及多次 mask 操作等一系列离散算子组合而成。

这种"小算子组合"模式在 NPU 硬件上存在以下痛点:

  1. 访存带宽受限:每个中间结果都需要在 HBM(高带宽显存)和计算单元间反复读写,导致内存带宽被大量浪费。
  2. 算子调度开销:大量下发微小算子会增加 CPU 与 NPU 之间的通信与下发耗时。

npu_fusion_attention 是昇腾针对 NPU 亲和性深度优化的融合算子,借鉴 FlashAttention 的设计思想,通过片上 SRAM 缓存的分块计算,极大地减少了对 HBM 的访问频率,从而实现显著的性能加速和显存优化。

2. 算子数学原理

FlashAttention 的核心通过对计算逻辑的重组,将注意力机制的计算公式进行融合处理。标准的 Attention 计算公式如下:

Attention(Q,K,V)=Dropout(Softmax(Mask(QKscale+pse)))V

其中:

  • Q(Query)、K(Key)、V(Value)为输入张量;
  • dk 为 Head Dim(注意力头维度),通常 scale=1/dk
  • pse 为可选的位置编码(如 ALiBi)。

在融合算子内部,通过 Tiling(分块)技术和在线 Softmax 算法,实现在不显式存储 N×N 满秩注意力矩阵的情况下完成梯度回传,将空间复杂度从 O(N2) 降低至 O(N)

3. 函数原型

python
torch_npu.npu_fusion_attention(
    query, key, value, head_num, input_layout,
    pse=None, padding_mask=None, atten_mask=None,
    scale=1.0, keep_prob=1.0, pre_tokens=2147483647,
    next_tokens=2147483647, inner_precise=0,
    prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None,
    sparse_mode=0, gen_mask_parallel=True, sync=False,
    softmax_layout="NTD"
)

4. 关键参数详解

4.1 核心输入张量

  • query / key / value:核心计算张量,支持 float16bfloat16
    • 约束:三者数据类型必须一致,且 Head Dim(D)需满足 Dq=DkDv
  • head_num:整数类型,表示注意力头(Head)的数量。
  • input_layout:字符串类型,定义输入张量的维度排列格式。支持 BSHSBHBSNDBNSDTND(其中 TND 专为变长序列场景设计)。

4.2 功能性参数

参数名称类型描述备注
pseTensor位置编码(Position Embedding)可选,支持 ALiBi 等位置编码融合
atten_maskTensor注意力掩码1 表示遮蔽,0 表示保留;支持 BNSSB1SSSS 等形状
scaleFloat缩放因子对应公式中的 scale,默认为 1.0
keep_probFloatDropout 保留概率取值范围 (0,1],默认为 1.0

4.3 稀疏与变长控制

  • sparse_mode:稀疏模式选择。
    • 0:Default Mask(根据 pre_tokensnext_tokens 确定范围)
    • 2 / 3:Causal 模式(左上/右下顶点划分的下三角)
  • actual_seq_qlen / actual_seq_kvlen:在变长(Varlen)场景下,用于描述 Batch 中每个序列的累加长度(累加前缀和)。

5. 输出说明

接口返回一个包含 7 个元素的元组:

序号名称说明
1attention_out最终的计算结果 Tensor
2softmax_maxSoftmax 的 Max 中间结果,用于反向传播
3softmax_sumSoftmax 的 Sum 中间结果,用于反向传播
4logsumexp预留参数(暂未使用)
5–7seed / offset / numelsDropout 随机数生成的种子、偏移量及元素总数

6. 迁移与适配步骤

将原生代码迁移至融合算子通常分为三步:

6.1 第一步:识别离散逻辑

定位模型代码中计算 Attention 的部分,通常包含类似以下结构:

python
attn_weights = torch.matmul(q, k.transpose(-1, -2)) * scale
if mask is not None:
    attn_weights += mask
attn_probs = F.softmax(attn_weights, dim=-1)
output = torch.matmul(attn_probs, v)

6.2 第二步:对齐输入规格

确保 Tensor 的 Layout(如 BSH 即 Batch、Sequence、Hidden)符合 NPU 融合算子的对齐要求。通常要求 Head Dim 是 16 或 32 的倍数以触发高性能内核。

6.3 第三步:替换为融合接口

直接调用 npu_fusion_attention

python
import torch_npu

output = torch_npu.npu_fusion_attention(
    query, key, value, head_num,
    input_layout="BSH",
    pse=None,
    padding_mask=None,
    atten_mask=None,
    scale=1.0 / (head_dim ** 0.5),
    keep_prob=1.0
)

补充说明:如果使用混合精度训练,建议开启 AMP(自动混合精度),该算子对 float16bfloat16 具有极佳的加速效果。

7. 核心约束与注意事项

  1. 场景限制:该接口目前仅支持训练场景,且不支持 PyTorch 图模式(JIT/Symbolic)。
  2. 维度要求
    • Head Dim(D)取值范围:[1,768]
    • Sequence Length(S)最大支持 1M
    • 支持 GQA(Grouped-Query Attention)模式,即 Nq/Nkv 为正整数
  3. 精度与溢出:在大规模计算下,若计算量过大(受 B,S,N,D 影响)可能触发 AICore 超时错误。此时建议在模型脚本层面对 Sequence 轴进行切分处理。

8. 性能表现与调优建议

根据实测,使用融合算子替换离散算子后:

  • 吞吐量提升:在长序列(Sequence Length > 2048)场景下,训练吞吐量可提升 30%100% 以上。
  • 显存释放:由于无需存储完整的 Attention Score 矩阵,显存占用显著下降,允许使用更大的 Batch Size。

调优建议

  • 优先使用 BNSD 布局:在昇腾架构下,BNSD(Batch、Num_heads、Seq、Dim)通常能获得更直接的内存访问效率。
  • Varlen 场景使用 TND:处理 NLP 任务中的 Padding 数据时,推荐使用 TND 布局配合 actual_seq_len,比传统的 Padding + Mask 方式节省大量无效计算,效率提升约 20%50%
  • 启用 Causal 稀疏模式:如果显存压力较大,建议优先使能 sparse_mode=23 的 Causal 模式,配合压缩版的 atten_mask
  • 算子下沉:确保在 NPU 模式下,整个计算图尽量保持在 Device 侧,避免频繁的 Host-to-Device 拷贝。
  • 对齐优化:输入张量的各维度尽量对齐到 16 或 32 的倍数以获得最佳性能。

Maintained by Robin