Transformer 注意力机制:MHA、MQA 与 GQA
1. 背景
Transformer(Vaswani et al., 2017)架构的提出彻底改变了自然语言处理(NLP)领域。该架构最初基于编码器-解码器(Encoder-Decoder)结构,随后演化出一系列变体:如仅包含编码器的 BERT(Devlin et al., 2018),以及仅包含解码器的 GPT(Radford et al., 2018)系列。目前主流的大型语言模型(LLM),如 LLaMA(Touvron et al., 2023)和 GPT-4,大多延续了仅解码器(Decoder-only)的架构。
2. 符号定义
| 符号 | 含义 |
|---|---|
| 批量大小(Batch Size) | |
| 序列长度(Sequence Length) | |
| 隐藏层维度 / 模型维度(Model Dimension) | |
| 注意力头数量(Number of Attention Heads) | |
| 分组数量(Group Number),用于 GQA | |
| 每个注意力头的维度,通常 | |
| 输入张量, | |
| 经过线性变换后的查询(Query)、键(Key)、值(Value)矩阵 | |
| 映射矩阵, | |
| 输出映射矩阵, | |
| 第 | |
| MQA 中所有头共享的键和值矩阵 |
3. Transformer 中的注意力机制
Transformer 的核心在于自注意力机制(Self-Attention),它赋予了模型动态捕捉序列内部长程依赖的能力。
对于输入序列
Transformer 采用的是缩放点积注意力(Scaled Dot-Product Attention)。其基本思想是计算查询与键之间的相关性,并将其作为权重对值进行加权求和:
3.1 多头注意力(Multi-Head Attention, MHA)
MHA 通过将
其中每个头的计算公式为:
MHA 的优势:
- 多维度特征捕捉:不同头可以关注序列中不同的语法或语义特征。
- 增强表达能力:通过子空间集成,提升了模型对复杂依赖关系的建模精度。
- 计算并行性:各头的计算逻辑相互独立,适合 GPU/TPU 硬件加速。
3.1.1 缩放因子 的必要性
引入缩放因子的主要目的是维持数值稳定性,防止 Softmax 函数进入梯度饱和区:
防止梯度消失:若点积结果过大,Softmax 的输出会集中在极小或极大的区域,导致导数接近于 0。
数学推导:假设
和 的各分量是独立同分布的随机变量,且满足均值为 0、方差为 1。则其点积 的方差为 。 通过除以
,可以使缩放后点积的方差恢复为 1:
3.2 多查询注意力(Multi-Query Attention, MQA)
MQA(Shazeer, 2019)是一种旨在提升推理效率的变体。在 MQA 中,所有的查询头共享同一组键(Key)和值(Value)。
其核心逻辑如下:
核心价值:显著减少了推理阶段 KV Cache 的显存占用和访存开销(Memory Bandwidth),这对长文本生成尤为重要。
3.3 分组查询注意力(Grouped-Query Attention, GQA)
GQA(Ainslie, 2023)是 MHA 与 MQA 的折中方案,它在保持推理效率的同时,尽可能保留多头机制的表达能力。
GQA 将查询头分为
- 若
,则等同于 MQA。 - 若
,则等同于 MHA。
3.4 三者对比总结
- MHA:
个 Query 头, 个 KV 头。性能最优,但推理时 KV Cache 显存压力大。 - MQA:
个 Query 头, 个 KV 头。推理速度最快,显存占用最低,但可能损失一定的模型容量。 - GQA:
个 Query 头, 个 KV 头( )。在速度与性能之间取得最佳平衡,是目前主流大模型(如 Llama 3)的首选。
4. 复杂度分析
4.1 时间复杂度(Time Complexity)
无论是 MHA、MQA 还是 GQA,对于完整序列的一次性前向传播,其计算复杂度量级是相同的。
- 矩阵乘法
:复杂度为 。 - 加权求和(与
相乘):复杂度同样为 。 - 总体量级:
。注意,注意力机制的计算开销随序列长度 呈二次方增长。
增量解码(Incremental Decoding)场景:
在 LLM 推理时,利用 KV Cache 缓存历史信息。每生成一个新 Token,只需计算当前 Query 与历史 KV 的关联:
- 单步复杂度:
。
4.2 空间复杂度(Space Complexity)
空间复杂度主要由参数量和中间激活值(KV Cache)组成。
- 参数量:
四个矩阵的参数量均为 ,总参数量约为 。MQA/GQA 虽然减少了 KV 头的数量,但由于投影矩阵的维度变化,其参数量微减,通常仍视为 。 - KV Cache 显存占用:这是 MQA/GQA 优化的核心。
- MHA:每个 Token 需要存储
个数值。 - MQA:每个 Token 仅需存储
个数值。显存占用降低为原来的 。 - GQA:每个 Token 需要存储
个数值。显存占用介于两者之间。
- MHA:每个 Token 需要存储
结论
在大语言模型时代,显存带宽往往是推理性能的瓶颈(Memory-Bound)。MQA 通过极致的共享策略解决了访存效率问题,但可能影响复杂任务的表现;GQA 则通过灵活的分组机制,在推理延迟、显存占用与模型效果之间找到了黄金平衡点,已成为当前工业界的事实标准。