从 Online Softmax 到 FlashAttention
注意力机制是 Transformer 的核心组件,但标准实现在长序列下面临显存与计算效率双重瓶颈。本文从 Online Softmax 讲起,逐步推导 FlashAttention 如何通过分块(Tiling)与在线融合,将注意力计算降至接近单轮遍历,并在 GPU 片上内存中完成计算。
1. 标准自注意力
忽略多头、批次维度以及缩放因子
其中
计算分为三步:
其中
2. 矩阵分块优化
2.1 朴素矩阵乘法
import numpy as np
def naive_matmul(A, B):
n, m = A.shape
m2, p = B.shape
assert m == m2, "维度不匹配"
C = np.zeros((n, p))
for i in range(n):
for j in range(p):
for k in range(m):
C[i, j] += A[i, k] * B[k, j]
return C2.2 分块矩阵乘法
将大矩阵分解为可放入高速缓存的小块(Tiles),通过数据复用最大化计算效率:
import numpy as np
def tiled_matmul(A, B, tile_size):
n, m = A.shape
m2, p = B.shape
assert m == m2, "维度不匹配"
C = np.zeros((n, p))
for ii in range(0, n, tile_size):
for jj in range(0, p, tile_size):
for kk in range(0, m, tile_size):
i_end = min(ii + tile_size, n)
j_end = min(jj + tile_size, p)
k_end = min(kk + tile_size, m)
for i in range(ii, i_end):
for j in range(jj, j_end):
for k in range(kk, k_end):
C[i, j] += A[i, k] * B[k, j]
return C分块技术显著减少了缓存未命中,是 FlashAttention 在 GPU SRAM 中完成计算的基础。
3. Softmax 的数值稳定性
Softmax 将一组数值转换为概率分布:
直接计算
3.1 安全 Softmax
通过减去最大值保证指数非正:
标准安全 Softmax 需要三次遍历:
- 求最大值:
- 求和:
- 归一化:
import numpy as np
def softmax_3pass(x):
n = len(x)
m = np.max(x)
exp_x = np.exp(x - m)
d = np.sum(exp_x)
return exp_x / d在 Transformer 中,
4. Online Softmax:两轮遍历
Online Softmax 通过维护一个替代统计量,将遍历次数从三轮减至两轮。定义:
其递推关系为:
算法:2-pass Online Softmax
第一遍遍历
到 : 第二遍遍历
到 :
import numpy as np
def softmax_online(x):
n = len(x)
m = x[0]
d = 1.0
for i in range(1, n):
if x[i] > m:
d = d * np.exp(m - x[i]) + 1.0
m = x[i]
else:
d += np.exp(x[i] - m)
return np.exp(x - m) / dOnline Softmax 将全局最大值与分母的更新合并到同一轮遍历中,为 FlashAttention 的进一步融合奠定了基础。
5. FlashAttention:单轮遍历
FlashAttention 的核心洞察是:最终目标是输出矩阵
5.1 输出向量的递推
对第
其中
该式仅依赖
算法:FlashAttention(逐 Token)
遍历
最终
5.2 分块版 FlashAttention
实际实现中按块处理,每块包含
算法:FlashAttention(分块版)
遍历
import numpy as np
def flash_attention(Q, K, V, k):
"""
单头 FlashAttention 的逐 Token 参考实现。
返回 O[k, :] = softmax(Q[k, :] @ K.T) @ V。
"""
N = K.shape[0]
m_prev = float("-inf")
d_prev = 0.0
o_prev = np.zeros_like(V[0, :])
for i in range(N):
x_i = np.dot(Q[k, :], K[i, :])
m_i = max(m_prev, x_i)
d_i = d_prev * np.exp(m_prev - m_i) + np.exp(x_i - m_i)
o_i = o_prev * (d_prev * np.exp(m_prev - m_i) / d_i) + \
(np.exp(x_i - m_i) / d_i) * V[i, :]
m_prev, d_prev, o_prev = m_i, d_i, o_i
return o_prevFlashAttention 通过在线融合机制,将原本至少三遍的注意力计算压缩到一遍,在 GPU 片上内存中完成计算,避免
6. 技术优势
6.1 显存效率
FlashAttention 避免存储巨大的
6.2 计算性能
通过将
- 大幅减少内存访问次数
- 计算与内存访问重叠
- 更高的 GPU 资源利用率
总结
| 技术 | 遍历次数 | 核心思想 |
|---|---|---|
| 传统 Softmax | 3 轮 | 分别求最大值、求和、归一化 |
| Online Softmax | 2 轮 | 在线更新最大值与分母 |
| FlashAttention | 1 轮(效果) | 直接在线更新输出矩阵,端到端融合 |
FlashAttention 的创新不仅提升了 Transformer 的训练与推理速度,更为超长上下文理解、文档处理等应用提供了可行的技术方案。
参考
- Dao, T., et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS.
- Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax. arXiv:1805.02867.