Skip to content

多头潜在注意力(MLA)

Multi-Head Latent Attention(MLA)是 DeepSeek 提出的一种高效注意力机制,通过对 KV 缓存进行低秩联合压缩,在显著降低推理阶段显存占用的同时,保持与标准多头注意力(MHA)相当的模型性能。其核心思想是将键(Key)和值(Value)投影到低维潜在空间,仅缓存压缩后的隐向量,并在计算时通过上投影恢复。以下是 MLA 的详细数学推导。

约定:所有计算使用行向量,即 y=xW

1. Q 的计算

ctQ=htWDQ,[qt,1C;qt,2C;;qt,nhC]=qtC=ctQWUQ,[qt,1R;qt,2R;;qt,nhR]=qtR=RoPE(ctQWQR),qt,i=[qt,iC;qt,iR],

其中:

  • ctQRdc 是查询的压缩隐向量;
  • dc(dhnh) 表示查询压缩维度;
  • WDQRd×dcWUQRdc×dhnh 分别是查询的下投影和上投影矩阵;
  • WQRRdc×dhRnh 用于生成携带 RoPE 的解耦查询。

2. KV 的计算

ctKV=htWDKV,[kt,1C;kt,2C;;kt,nhC]=ktC=ctKVWUK,ktR=RoPE(htWKR),kt,i=[kt,iC;ktR],[vt,1C;vt,2C;;vt,nhC]=vtC=ctKVWUV,

其中:

  • ctKVRdc 是键值的压缩隐向量;
  • dc(dhnh) 表示 KV 压缩维度;
  • WDKVRd×dc 是下投影矩阵;
  • WUK,WUVRdc×dhnh 是键和值的上投影矩阵;
  • WKRRd×dhR 用于生成携带 RoPE 的解耦键;
  • RoPE() 表示应用旋转位置编码的操作。

注意:对于 MLA,仅需缓存蓝色框中的向量(ctKVktR),从而显著减少 KV 缓存大小,同时保持与标准多头注意力(MHA)相当的性能。

最终,注意力查询(qt,i)、键(kj,i)和值(vj,iC)组合得到最终输出 ut

ot,i=j=1tSoftmaxj(qt,ikj,idh+dhR)vj,iC,ut=[ot,1;ot,2;;ot,nh]WO,

其中 WORdhnh×d 是输出投影矩阵。

3. 实际参数配置

  • d=hidden_size=7168
  • dc=kv_lora_rank=512
  • dc=q_lora_rank=1536
  • nh=num_heads=128
  • dh=qk_nope_head_dim=128
  • dhR=qk_rope_head_dim=64

此外:

  • WUQWQR 可合并,q_head_dim=qk_nope_head_dim+qk_rope_head_dim=192
  • WDKVWKR 可合并,kv_lora_rank+qk_rope_head_dim=576

4. 矩阵吸收(Absorb)

考虑如下计算:

Y=XAB,C=AB

其中:

  • XRm×d 是输入隐状态(hidden states);
  • ARd×dcBRdc×n 是权重矩阵;
  • CRd×n 是 absorb 后的等效权重矩阵。

直接计算的 FLOPs 为:

2mddc+2mndc=2mdc(d+n)

合并权重后计算的 FLOPs 为:

2mdn

dc 较小时,通常有:

dn>dc(d+n)

因此不一定需要合并两个权重矩阵

不考虑 RoPE 部分,仅从 cQcKV 计算 qiki(第 i 个 head):

qiki=cQWiUQ(cKVWiUK),=cQWiUQ(WiUK)(cKV),=qi(WiUK)(cKV),(Absorb)=qi(cKVWiUK),(Normal)

警告:此处 "Absorb" 的真实含义是利用矩阵乘法结合律,优先将 qWUK 结合,并缓存压缩隐向量 cKV。它并非合并权重矩阵,"Absorb" 这一命名具有一定误导性。

4.1 为什么计算时不把 WiUQ(WiUK) 合并

对单个 token、单个 head,FLOPs 分别为:

  • 分开计算:2dh(dc+dc)=524,288
  • 合并计算:2dcdc=1,572,864=3×524,288

合并后计算量反而是原来的 3 倍。

4.2 为什么 Prefill 阶段显式计算 k 和 v,而 Decode 阶段不需要

假设输入 shape 如下:

  • q:(b,nh,sq,dh)
  • cKV:(b,1,skv,dc)
  • WUK:(dc,nhdh)

4.2.1 Prefill 阶段(sq=skv=s

FLOPs 对比:

TNormal=2bnhdhs(dc+s),TAbsorb=2bnhdcs(dh+s),TNormalTAbsorb=dh(dc+s)dc(dh+s)=s+5124s+512(14,1)

Prefill 阶段 Normal 更快,且此阶段是计算瓶颈,故显式计算 qk

4.2.2 Decode 阶段(sq=1,skv=s

FLOPs 对比:

TNormalK=2bnhdh(dc+s),(缓存 k)TNormalL=2bnhdh(dcs+s),(缓存 latent)TAbsorb=2bnhdc(dh+s),TNormalKTAbsorb=dh(dc+s)dc(dh+s)=s+5124s+512(0.25,1),TNormalLTAbsorb=513s4s+512(0.99,128.25)

虽然缓存 k 的计算量最小(极限为 Absorb 的 1/4),但 Decode 阶段瓶颈是显存带宽

4.2.3 内存读取量对比(bfloat16 精度)

  • MLA(Absorb)(b,nh,1,dc)×(b,1,s,dc)

    MMLA=2bdc(nh+s)
  • 标准 MHA(b,nh,1,dh)×(b,nh,s,dh)

    MMHA=2bdhnh(1+s)

内存读取比例:

MMLAMMHA=dc(nh+s)dhnh(1+s)=128+s32(1+s)
  • s=20 时,比值 0.22
  • 极限情况(s):比值 1/32

因此,Decode 阶段采用 Absorb 方式,可大幅降低显存带宽压力,并复用 MQA(Multi-Query Attention)实现

5. 矩阵吸收问题总结

"矩阵吸收"的本质是如何应用矩阵乘法结合律

Y=(XA)B=X(AB),Z=(XW)Y=X(WY),

其中 A,B,W 均为权重矩阵。

决策依据应综合权衡:

  • 计算量(FLOPs)
  • 显存读写量(Memory Traffic)
  • 当前阶段瓶颈(计算 or 带宽)

可借助 Roofline Model 进行系统性分析。

Maintained by Robin