旋转式位置编码(RoPE)
旋转式位置编码(Rotary Position Embedding,RoPE)最早由论文 RoFormer: Enhanced Transformer with Rotary Position Embedding 提出,是一种将相对位置信息集成到 self-attention 中并提升 Transformer 架构性能的位置编码方式。目前广受关注的 LLaMA 模型也采用了该位置编码方式。
1. 基本概念
首先定义一个长度为
其中
其中
在执行 self-attention 之前,会用词嵌入向量计算 query、key、value 向量并同时加入位置信息:
其中
基于 Transformer 的位置编码方法都着重于构造合适的
2. 绝对位置编码
对于位置编码,常规做法是在计算 query、key 和 value 向量之前,先计算一个位置编码向量
而经典的位置编码向量
其中
3. 旋转式位置编码
接下来介绍 Rotary Transformer(RoFormer)模型。它的主要改动是引入"旋转式位置编码(Rotary Position Embedding,RoPE)",这是一种配合 Attention 机制能达到"以绝对位置编码的方式实现相对位置编码"的设计。正因如此,它也是目前唯一一种可用于线性 Attention 的相对位置编码。
3.1 基本思路
在 RoPE 中,出发点是"通过绝对位置编码的方式实现相对位置编码"。这一设计既有理论上的优雅之处,也有实践上的实用价值,例如它可以扩展到线性 Attention 中。
在机器学习中,我们通常只关注实数,但对于旋转嵌入来说,使用复数作为空间的基域在数学上更为方便。先考虑二维情形,然后借助复数来求解。将 query 向量和 key 向量的元素视为单个复数,我们使用
也就是说,分别为
因此需要给出该恒等式的一个尽可能简单的解。求解过程还需要初始条件,显然可以合理地设
3.2 求解过程
在复数中有
简单起见,假设存在复数
则:
对于第一个方程,代入
最后一个等号源于初始条件
这里的
所以
即
将前面所有的公式推导汇总,即可得到 Rotary Position Embedding 的最终表达式:
因此,对于任意的
由于与复数相比,计算机更喜欢实数和矩阵,因此将此表达式转换为矩阵方程很方便:
其中:
3.3 编码形式
综上,我们得到二维情况下用复数表示的 RoPE:
根据复数乘法的几何意义,该变换实际上对应着向量的旋转,因此称之为"旋转式位置编码"。它还可以写成矩阵形式:
由于内积满足线性叠加性,任意偶数维的 RoPE 都可以表示为二维情形的拼接,即:
也就是说,给位置为
值得指出的是,
由于
其中 * 运算。从这个实现也可以看到,RoPE 可以视为三角函数式位置编码的变体。
3.4 LLaMA 模型中的 RoPE
LLaMA 模型使用了 Rotary Position Embedding。对于
3.4.1 Step 1:初始化 矩阵
3.4.2 Step 2:计算 矩阵和 矩阵
3.4.3 Step 3:计算 Query 向量
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)2
对应公式:
4. RoPE 证明过程
4.1 简单证明
简单起见,先假设
其中
值得注意的是,内积只依赖于相对位置
由上述结果可知,对于位置为
这意味着,通过
来赋予
这样一来,我们得到了一种融合绝对位置与相对位置的位置编码方案。从形式上看它类似乘性的绝对位置编码:通过在
4.2 完整证明
假定 query 向量
我们的目标是找到一个等价的位置编码方式,使得上述关系成立,即构造出函数
假定词嵌入向量的维度为二维
这里
首先看到上述
即上述指数函数可以表示为实部为
则上述
然后看回公式:
其中
首先将
接着:
其实就是两个复数相乘:
复数乘法使用分配律,并利用
复习一下复数乘法的性质:
将结果重新表达成实数向量形式就是:
因此:
看到这里会发现,这就是 query 向量乘以了一个旋转矩阵。这就是"旋转位置编码"名称的由来。
同理,
最后还有个函数
其中
复习一下共轭复数的定义:
所以可得:
继续可得:
接下来我们就要证明函数
首先回顾一下 attention 操作,位置
接着继续之前先复习一下三角函数的和差公式:
回到上面的式子,整理得到:
这就证明了上述关系成立:位置
把上面的式子用矩阵向量乘的形式来表达就是:
上面的推导假定词嵌入维度为 2 维向量。对于
综上,RoPE 的 self-attention 操作流程为:对 token 序列中的每个词嵌入向量,首先计算对应的 query 和 key 向量,然后对每个 token 位置计算对应的旋转位置编码,接着对 query 和 key 向量的元素按两两一组应用旋转变换,最后计算 query 和 key 之间的内积得到 self-attention 的计算结果。
5. RoPE 的性质
5.1 远程衰减
可以看到,RoPE 形式上和 Sinusoidal 位置编码有一定相似性,只不过 Sinusoidal 位置编码是加性的,而 RoPE 可视为乘性的。在
具体证明如下:将
记
所以:
因此可以考察
5.2 线性场景
最后指出,RoPE 是目前唯一一种可用于线性 Attention 的相对位置编码。这是因为其他相对位置编码直接基于 Attention 矩阵进行操作,而线性 Attention 并不事先计算 Attention 矩阵,因此无法应用。RoPE 以绝对位置编码的方式实现相对位置编码,不需要操作 Attention 矩阵,因而具备应用到线性 Attention 的可能性。
线性 Attention 的常见形式是:
其中
但这样存在的问题是,内积
也就是说,RoPE 只插入分子中,分母保持不变。这样的注意力不再是基于概率的(注意力矩阵不再满足非负归一性),但某种意义上也是一种归一化方案。目前也没有证据表明非概率式的注意力效果更差(例如 Nyströmformer 也未严格依据概率分布构建注意力)。因此将其作为候选方案之一进行实验,初步实验结果显示这样的线性 Attention 也是有效的。
5.3 RoPE 的长度扩展
在 LLM 的应用中,有一个非常重要的参数——上下文长度(max context length)。更长的上下文长度允许进行更多轮次的对话、对更长的文本进行总结分析,也允许生成更长的文章。然而在训练 LLM 时,训练语料大部分不够长,许多 LLM 训练时设计的最大文本长度仅为 2k(即最长 2048 个 token)。那么,能否在训练时使用较短的文本,而在推理时扩展到长文本上呢?
这是可行的,可以对 RoPE 进行长度扩展。下面介绍三种扩展方案。
5.3.1 直接外推
直接外推即继续沿用现有位置编码公式,不做任何修改。在扩展长度不太大时(例如由 2k 扩展到 2.5k),此方法对性能的影响不大。旋转位置编码只与相对位置
因此,如果模型已从训练数据中学习到 token 之间在 0-2k 范围内合适的衰减规律,将其应用到 0-2.5k 通常也没有问题。但若扩展到更长的长度(例如从 2k 扩展到 32k),直接外推通常会严重影响性能。因为学习到的衰减规律可能在 5k 处就完全衰减为零,导致无法捕捉超过 5k 相对距离的 token 之间的相互作用。
总结:直接外推对衰减规律在长距离情况下的使用容易出现问题。为减少性能影响,可以让训练好的模型在更长的上下文上做少量步骤的微调。
5.3.2 线性内插
线性内插需要改变位置编码公式,等效于将位置序号等比例缩小。
例如从 2k 扩展到 32k 时,等效于将位置序号缩小为原来的 1/16。线性内插未改变模型学习到的衰减规律的应用范围,不做微调时其效果一般优于直接外推方案。但当扩展倍数非常大时(如从 2k 到 32k),性能也会明显受影响。原因在于短距离情况下的使用受到较大影响:本来距离为 1 的两个 token,扩展后相当于距离为 1/16,而衰减规律在短距离时可能变化率极大,对相关性的评估可能偏离合理值。
应用线性内插时,在长文本上做少量步骤的微调也能明显改善性能。
5.3.3 NTK 扩展方式
这种方式综合了外推和内插的优点,做长度扩展后即使不微调也能保持较好的性能。
前面的分析表明:直接外推对衰减规律在长距离情况下的使用容易出问题,在短距离下不受影响;线性内插对衰减规律在短距离下的使用容易出问题,在长距离下影响较小。那么能否将两者综合——在短距离情况下具有外推特性(与扩展前基本一致),在长距离情况下具有内插特性(缩放到扩展前的范围)?
观察 RoPE 位置编码的元素计算公式,可以发现
为了在短距离情况下具有外推特性、长距离情况下具有内插特性,可以设计一个与频率相关的位置序号缩放因子:在最高频时取值为 1(与扩展前一致),在最低频时恰好为缩放倍数的倒数(缩放到扩展前的范围)。一种有效的选择方案是对 base 做指数缩放。NTK 扩展方式的要点是高频外推、低频内插,实现方法是直接对底数 base 进行缩放,类似进制编码转换。采用 NTK 扩展到长文本,即使不做微调,性能也仅略有下降。
6. 代码实现
旋转位置嵌入的简单实现使用前面所示的块对角矩阵形式。在实践中,这种实现方式效率较低,更优化的形式很容易获得。RoPE 的原始实现可在 roformer 和 bert4keras 中找到。
此外,在 x-transformers、GPT-Neo、GPT-NeoX 和 Mesh Transformer JAX 中也实现了旋转位置嵌入。以下是从这些代码库中提取的 PyTorch 实现。
import torch
class Rotary(torch.nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def forward(self, x, seq_dim=1):
seq_len = x.shape[seq_dim]
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[:, None, None, :]
self.sin_cached = emb.sin()[:, None, None, :]
return self.cos_cached, self.sin_cached
# Rotary pos emb helpers
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
# dim=-1 triggers a bug in torch < 1.8.0
return torch.cat((-x2, x1), dim=x1.ndim - 1)
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
总结
从理论上看,RoPE 与 Sinusoidal 位置编码有相通之处,但 RoPE 不依赖泰勒展开,更具严谨性与可解释性。从预训练模型 RoFormer 的结果来看,RoPE 具有良好的外推性,应用到 Transformer 中体现出较好的处理长文本的能力。此外,RoPE 是目前唯一一种可用于线性 Attention 的相对位置编码。
参考文献
[1] RoFormer: Enhanced Transformer with Rotary Position Embedding
[2] Euler's Formula
[3] List of Trigonometric Identities
[4] LLaMA
[5] 旋转矩阵
[6] Jianlin Su. 让研究人员绞尽脑汁的 Transformer 位置编码. https://kexue.fm/archives/8130, 2021. [Online; accessed 18-April-2021].
[7] Jianlin Su. Transformer 升级之路:2、博采众长的旋转式位置编码. https://kexue.fm/archives/8265, 2021. [Online; accessed 18-April-2021].
[8] Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv preprint arXiv:2104.09864, 2021.