Skip to content

从 Online Softmax 到 FlashAttention

注意力机制是 Transformer 的核心组件,但标准实现在长序列下面临显存与计算效率双重瓶颈。本文从 Online Softmax 讲起,逐步推导 FlashAttention 如何通过分块(Tiling)与在线融合,将注意力计算降至接近单轮遍历,并在 GPU 片上内存中完成计算。

1. 标准自注意力

忽略多头、批次维度以及缩放因子 1d,标准自注意力可写作:

O=softmax(QK)V

其中 Q,K,V,ORL×dL 为序列长度,d 为每个注意力头的维度。

计算分为三步:

X=QK,A=softmax(X),O=AV

其中 X 为 Pre-softmax Logits,A 为注意力分数矩阵,O 为输出矩阵。直接实现需要存储 L×L 的中间矩阵,显存开销随序列长度平方增长。

2. 矩阵分块优化

2.1 朴素矩阵乘法

python
import numpy as np

def naive_matmul(A, B):
    n, m = A.shape
    m2, p = B.shape
    assert m == m2, "维度不匹配"
    C = np.zeros((n, p))
    for i in range(n):
        for j in range(p):
            for k in range(m):
                C[i, j] += A[i, k] * B[k, j]
    return C

2.2 分块矩阵乘法

将大矩阵分解为可放入高速缓存的小块(Tiles),通过数据复用最大化计算效率:

python
import numpy as np

def tiled_matmul(A, B, tile_size):
    n, m = A.shape
    m2, p = B.shape
    assert m == m2, "维度不匹配"

    C = np.zeros((n, p))
    for ii in range(0, n, tile_size):
        for jj in range(0, p, tile_size):
            for kk in range(0, m, tile_size):
                i_end = min(ii + tile_size, n)
                j_end = min(jj + tile_size, p)
                k_end = min(kk + tile_size, m)
                for i in range(ii, i_end):
                    for j in range(jj, j_end):
                        for k in range(kk, k_end):
                            C[i, j] += A[i, k] * B[k, j]
    return C

分块技术显著减少了缓存未命中,是 FlashAttention 在 GPU SRAM 中完成计算的基础。

3. Softmax 的数值稳定性

Softmax 将一组数值转换为概率分布:

softmax(xi)=exij=1nexj

直接计算 exi 容易导致数值溢出。例如 FP16 最大可表示值约为 65504,而 e11.1 已超出此范围。

3.1 安全 Softmax

通过减去最大值保证指数非正:

softmax(xi)=eximj=1nexjm,m=maxjxj

标准安全 Softmax 需要三次遍历:

  1. 求最大值:m=maxjxj
  2. 求和:d=j=1nexjm
  3. 归一化:ai=eximd
python
import numpy as np

def softmax_3pass(x):
    n = len(x)
    m = np.max(x)
    exp_x = np.exp(x - m)
    d = np.sum(exp_x)
    return exp_x / d

在 Transformer 中,xiQK 动态计算。如果不存储所有 Logits,就需要三次访问 QK,I/O 效率极低。

4. Online Softmax:两轮遍历

Online Softmax 通过维护一个替代统计量,将遍历次数从三轮减至两轮。定义:

di=j=1iexjmi,mi=maxjixj

其递推关系为:

di=di1emi1mi+eximi

算法:2-pass Online Softmax

  1. 第一遍遍历 i=1n

    mi=max(mi1,xi)di=di1emi1mi+eximi
  2. 第二遍遍历 i=1n

    ai=eximndn
python
import numpy as np

def softmax_online(x):
    n = len(x)
    m = x[0]
    d = 1.0
    for i in range(1, n):
        if x[i] > m:
            d = d * np.exp(m - x[i]) + 1.0
            m = x[i]
        else:
            d += np.exp(x[i] - m)
    return np.exp(x - m) / d

Online Softmax 将全局最大值与分母的更新合并到同一轮遍历中,为 FlashAttention 的进一步融合奠定了基础。

5. FlashAttention:单轮遍历

FlashAttention 的核心洞察是:最终目标是输出矩阵 O,而非注意力分数矩阵 A。因此可以尝试直接建立 O 的在线递推形式。

5.1 输出向量的递推

对第 k 行(各行独立),定义:

oi=j=1iexjmidivj

其中 xj=qkkjvjV 的第 j 行。可以证明 oN 即为该行的最终输出。

oioi1 的递推关系为:

oi=oi1di1emi1midi+eximidivi

该式仅依赖 di,di1,mi,mi1,xivi,因此可将整个自注意力计算融合到一个循环中。

算法:FlashAttention(逐 Token)

遍历 i=1n

  1. xi=qkki
  2. mi=max(mi1,xi)
  3. di=di1emi1mi+eximi
  4. oi=oi1di1emi1midi+eximidivi

最终 O[k,:]=oN

5.2 分块版 FlashAttention

实际实现中按块处理,每块包含 b 个 Token:

算法:FlashAttention(分块版)

遍历 i=1#tiles

  1. xi=qkK[(i1)b:ib]
  2. mi(local)=maxjxi[j]
  3. mi=max(mi1,mi(local))
  4. di=di1emi1mi+j=1bexi[j]mi
  5. oi=oi1di1emi1midi+j=1bexi[j]midiV[(i1)b+j]
python
import numpy as np

def flash_attention(Q, K, V, k):
    """
    单头 FlashAttention 的逐 Token 参考实现。
    返回 O[k, :] = softmax(Q[k, :] @ K.T) @ V。
    """
    N = K.shape[0]
    m_prev = float("-inf")
    d_prev = 0.0
    o_prev = np.zeros_like(V[0, :])

    for i in range(N):
        x_i = np.dot(Q[k, :], K[i, :])
        m_i = max(m_prev, x_i)
        d_i = d_prev * np.exp(m_prev - m_i) + np.exp(x_i - m_i)
        o_i = o_prev * (d_prev * np.exp(m_prev - m_i) / d_i) + \
              (np.exp(x_i - m_i) / d_i) * V[i, :]

        m_prev, d_prev, o_prev = m_i, d_i, o_i

    return o_prev

FlashAttention 通过在线融合机制,将原本至少三遍的注意力计算压缩到一遍,在 GPU 片上内存中完成计算,避免 O(L2) 的中间显存开销。

6. 技术优势

6.1 显存效率

FlashAttention 避免存储巨大的 L×L 注意力分数矩阵,显著降低显存带宽消耗,使在有限显存下处理超长序列(如 100K+ 上下文)成为可能。

6.2 计算性能

通过将 QK、Softmax 与 V 乘法融合为单一 CUDA Kernel,FlashAttention 实现了:

  • 大幅减少内存访问次数
  • 计算与内存访问重叠
  • 更高的 GPU 资源利用率

总结

技术遍历次数核心思想
传统 Softmax3 轮分别求最大值、求和、归一化
Online Softmax2 轮在线更新最大值与分母
FlashAttention1 轮(效果)直接在线更新输出矩阵,端到端融合

FlashAttention 的创新不仅提升了 Transformer 的训练与推理速度,更为超长上下文理解、文档处理等应用提供了可行的技术方案。

参考

  1. Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS.
  2. Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax. arXiv:1805.02867.

Maintained by Robin