多头潜在注意力(MLA)
Multi-Head Latent Attention(MLA)是 DeepSeek 提出的一种高效注意力机制,通过对 KV 缓存进行低秩联合压缩,在显著降低推理阶段显存占用的同时,保持与标准多头注意力(MHA)相当的模型性能。其核心思想是将键(Key)和值(Value)投影到低维潜在空间,仅缓存压缩后的隐向量,并在计算时通过上投影恢复。以下是 MLA 的详细数学推导。
约定:所有计算使用行向量,即
。
1. Q 的计算
其中:
是查询的压缩隐向量; 表示查询压缩维度; 、 分别是查询的下投影和上投影矩阵; 用于生成携带 RoPE 的解耦查询。
2. KV 的计算
其中:
是键值的压缩隐向量; 表示 KV 压缩维度; 是下投影矩阵; 是键和值的上投影矩阵; 用于生成携带 RoPE 的解耦键; 表示应用旋转位置编码的操作。
注意:对于 MLA,仅需缓存蓝色框中的向量(
和 ),从而显著减少 KV 缓存大小,同时保持与标准多头注意力(MHA)相当的性能。
最终,注意力查询(
其中
3. 实际参数配置
此外:
和 可合并, 。 和 可合并, 。
4. 矩阵吸收(Absorb)
考虑如下计算:
其中:
是输入隐状态(hidden states); 、 是权重矩阵; 是 absorb 后的等效权重矩阵。
直接计算的 FLOPs 为:
合并权重后计算的 FLOPs 为:
当
因此不一定需要合并两个权重矩阵。
不考虑 RoPE 部分,仅从
警告:此处 "Absorb" 的真实含义是利用矩阵乘法结合律,优先将
与 结合,并缓存压缩隐向量 。它并非合并权重矩阵,"Absorb" 这一命名具有一定误导性。
4.1 为什么计算时不把 合并
对单个 token、单个 head,FLOPs 分别为:
- 分开计算:
- 合并计算:
合并后计算量反而是原来的 3 倍。
4.2 为什么 Prefill 阶段显式计算 k 和 v,而 Decode 阶段不需要
假设输入 shape 如下:
4.2.1 Prefill 阶段( )
FLOPs 对比:
Prefill 阶段 Normal 更快,且此阶段是计算瓶颈,故显式计算
4.2.2 Decode 阶段( )
FLOPs 对比:
虽然缓存 k 的计算量最小(极限为 Absorb 的 1/4),但 Decode 阶段瓶颈是显存带宽。
4.2.3 内存读取量对比(bfloat16 精度)
MLA(Absorb):
标准 MHA:
内存读取比例:
- 当
时,比值 ; - 极限情况(
):比值 。
因此,Decode 阶段采用 Absorb 方式,可大幅降低显存带宽压力,并复用 MQA(Multi-Query Attention)实现。
5. 矩阵吸收问题总结
"矩阵吸收"的本质是如何应用矩阵乘法结合律:
其中
决策依据应综合权衡:
- 计算量(FLOPs)
- 显存读写量(Memory Traffic)
- 当前阶段瓶颈(计算 or 带宽)
可借助 Roofline Model 进行系统性分析。