Skip to content

旋转式位置编码(RoPE)

旋转式位置编码(Rotary Position Embedding,RoPE)最早由论文 RoFormer: Enhanced Transformer with Rotary Position Embedding 提出,是一种将相对位置信息集成到 self-attention 中并提升 Transformer 架构性能的位置编码方式。目前广受关注的 LLaMA 模型也采用了该位置编码方式。

1. 基本概念

首先定义一个长度为 N 的输入序列为:

SN={tokeni}i=1N

其中 tokeni 表示输入序列中第 i 个 token,而输入序列 SN 对应的 embedding 表示为:

EN={xi}i=1N

其中 xiRd 表示第 i 个 token 对应的 d 维词嵌入向量。

在执行 self-attention 之前,会用词嵌入向量计算 query、key、value 向量并同时加入位置信息:

qm=fq(xm,m),kn=fk(xn,n),vn=fv(xn,n),

其中 qm 表示第 m 个 token 对应的词向量 xm 集成位置信息 m 之后的 query 向量;knvn 则表示第 n 个 token 对应的词向量 xn 集成位置信息 n 之后的 key 和 value 向量。

基于 Transformer 的位置编码方法都着重于构造合适的 f{q,k,v} 函数形式。计算第 m 个词嵌入向量 xm 对应的 self-attention 输出结果时,qm 与所有 kn 计算 attention score,然后将 attention score 乘以对应的 vn 再求和,得到输出向量 om

am,n=exp(qmknd)j=1Nexp(qmkjd),om=n=1Nam,nvn

2. 绝对位置编码

对于位置编码,常规做法是在计算 query、key 和 value 向量之前,先计算一个位置编码向量 pi 加到词嵌入 xi 上。位置编码向量 pi 同样是 d 维向量,然后再乘以对应的变换矩阵 W{q,k,v}

f{q,k,v}(xi,i)=W{q,k,v}(xi+pi)

而经典的位置编码向量 pi 的计算方式是:

pi,2t=sin(i100002t/d),pi,2t+1=cos(i100002t/d),

其中 pi,2t 表示 d 维位置向量 pi 中第 2t 个分量(偶数索引位置)的计算公式,而 pi,2t+1 对应第 2t+1 个分量(奇数索引位置)的计算公式。

3. 旋转式位置编码

接下来介绍 Rotary Transformer(RoFormer)模型。它的主要改动是引入"旋转式位置编码(Rotary Position Embedding,RoPE)",这是一种配合 Attention 机制能达到"以绝对位置编码的方式实现相对位置编码"的设计。正因如此,它也是目前唯一一种可用于线性 Attention 的相对位置编码。

3.1 基本思路

在 RoPE 中,出发点是"通过绝对位置编码的方式实现相对位置编码"。这一设计既有理论上的优雅之处,也有实践上的实用价值,例如它可以扩展到线性 Attention 中。

在机器学习中,我们通常只关注实数,但对于旋转嵌入来说,使用复数作为空间的基域在数学上更为方便。先考虑二维情形,然后借助复数来求解。将 query 向量和 key 向量的元素视为单个复数,我们使用 Cd/2 而非通常的 Rd 空间来表示。具体而言,不再将 q=(q1,q2,q3,q4,,qd) 视为 d 维实数向量,而是将其视为 q=(q1+iq2,q3+iq4,,qd1+iqd)Cd/2。若 d 为奇数,可用零虚部填充以确保对齐。

qk 分别为 query 向量和 key 向量,mn 分别为相应 token 的绝对位置。假设 f(x,) 是一个函数,它接收位于位置 的嵌入 x,并输出一个包含相对位置信息的新嵌入。我们假设通过下述运算来给 q,k 添加绝对位置信息:

q~m=f(q,m),k~n=f(k,n)

也就是说,分别为 q,k 设计操作 f(,m),f(,n),使得经过该操作后,q~m,k~n 就带有了位置 m,n 的绝对位置信息。Attention 的核心运算是内积,因此我们希望内积的结果带有相对位置信息,假设存在恒等关系:

f(q,m),f(k,n)=g(q,k,mn)

因此需要给出该恒等式的一个尽可能简单的解。求解过程还需要初始条件,显然可以合理地设 f(q,0)=qf(k,0)=k

3.2 求解过程

在复数中有 q,k=Re[qk]Re[] 代表复数的实部,所以有:

Re[f(q,m)f(k,n)]=g(q,k,mn)

简单起见,假设存在复数 g(q,k,mn),使得 f(q,m)f(k,n)=g(q,k,mn)。然后用复数的指数形式,设:

f(q,m)=Rf(q,m)eiΘf(q,m),f(k,n)=Rf(k,n)eiΘf(k,n),g(q,k,mn)=Rg(q,k,mn)eiΘg(q,k,mn).

则:

Rf(q,m)Rf(k,n)=Rg(q,k,mn),Θf(q,m)Θf(k,n)=Θg(q,k,mn).

对于第一个方程,代入 m=n 得到:

Rf(q,m)Rf(k,m)=Rg(q,k,0)=Rf(q,0)Rf(k,0)=qk

最后一个等号源于初始条件 f(q,0)=qf(k,0)=k。因此可以直接设 Rf(q,m)=qRf(k,m)=k,即它不依赖于 m。至于第二个方程,同样代入 m=n 得到:

Θf(q,m)Θf(k,m)=Θg(q,k,0)=Θf(q,0)Θf(k,0)=Θ(q)Θ(k)

这里的 Θ(q),Θ(k)q,k 本身的幅角,最后一个等号同样源于初始条件。根据上式可得:

Θf(q,m)Θ(q)=Θf(k,m)Θ(k)

所以 Θf(q,m)Θ(q) 应是一个只与 m 相关、与 q 无关的函数,记为 φ(m),即 Θf(q,m)=Θ(q)+φ(m)。接着代入 n=m1,整理得到:

φ(m)φ(m1)=Θg(q,k,1)+Θ(k)Θ(q)

{φ(m)} 是等差数列,代入初始值 φ(0)=0,φ(1)=θ,解得 φ(m)=mθ

将前面所有的公式推导汇总,即可得到 Rotary Position Embedding 的最终表达式:

f(q,m)=Rf(q,m)eiΘf(q,m)=qei(Θ(q)+mθ)=j=1d/2qjeimθjej

因此,对于任意的 0<επ2N,其中 N 是最大序列长度。当按元素计算 qk 时,以 j 作为元素索引,RoPE 可以表示如下:

RoPE(x,m)=xemiε,RoPE(qj,m),RoPE(kj,n)=qjemiε,kjeniε=qjkjemiεeniε=qjkje(mn)iε=RoPE(qjkj,mn).

由于与复数相比,计算机更喜欢实数和矩阵,因此将此表达式转换为矩阵方程很方便:

f(q,m)=(M1M2Md/2)(q1q2qd)=ΘmQm=ΘmWqxm

其中:

Mj=(cosmθjsinmθjsinmθjcosmθj)

Θm 为块对角矩阵,Wq 为可学习的 query 权重,xm 为位置 m 处的嵌入。

3.3 编码形式

综上,我们得到二维情况下用复数表示的 RoPE:

f(q,m)=Rf(q,m)eiΘf(q,m)=qei(Θ(q)+mθ)=qeimθ

根据复数乘法的几何意义,该变换实际上对应着向量的旋转,因此称之为"旋转式位置编码"。它还可以写成矩阵形式:

f(q,m)=(cosmθsinmθsinmθcosmθ)(q0q1)

由于内积满足线性叠加性,任意偶数维的 RoPE 都可以表示为二维情形的拼接,即:

(cosmθ0sinmθ00000sinmθ0cosmθ0000000cosmθ1sinmθ10000sinmθ1cosmθ1000000cosmθd/21sinmθd/210000sinmθd/21cosmθd/21)Rm(q0q1q2q3qd2qd1)

也就是说,给位置为 m 的向量 q 乘上矩阵 Rm、位置为 n 的向量 k 乘上矩阵 Rn,用变换后的 Q,K 序列做 Attention,则 Attention 就自动包含相对位置信息,因为成立恒等式:

(Rmq)(Rnk)=qRmRnk=qRnmk

值得指出的是,Rm 是一个正交矩阵,它不会改变向量的模长,因此通常不会影响原模型的稳定性。

由于 Rm 的稀疏性,直接用矩阵乘法来实现会浪费算力,推荐通过下述方式来实现 RoPE:

(q0q1q2q3qd2qd1)(cosmθ0cosmθ0cosmθ1cosmθ1cosmθd/21cosmθd/21)+(q1q0q3q2qd1qd2)(sinmθ0sinmθ0sinmθ1sinmθ1sinmθd/21sinmθd/21)

其中 是逐元素相乘,即 NumPy、TensorFlow 等计算框架中的 * 运算。从这个实现也可以看到,RoPE 可以视为三角函数式位置编码的变体。

3.4 LLaMA 模型中的 RoPE

LLaMA 模型使用了 Rotary Position Embedding。对于 Q 的第 m 个位置向量 q,通过以下方式注入位置编码。

3.4.1 Step 1:初始化 θ 矩阵

(θ0θ1θd/21θ0θ1θd/21θ0θ1θd/21θ0θ1θd/212θ02θ12θd/212θ02θ12θd/21mθ0mθ1mθd/21mθ0mθ1mθd/21)

3.4.2 Step 2:计算 cos 矩阵和 sin 矩阵

(cosθ0cosθ1cosθd/21cosθ0cosθ1cosθd/21cosθ0cosθ1cosθd/21cosθ0cosθ1cosθd/21cos2θ0cos2θ1cos2θd/21cos2θ0cos2θ1cos2θd/21cosmθ0cosmθ1cosmθd/21cosmθ0cosmθ1cosmθd/21)

3.4.3 Step 3:计算 Query 向量

python
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)

对应公式:

(q0q1qd/21qd/2qd2qd1)(cosmθ0cosmθ1cosmθd/21cosmθ0cosmθ1cosmθd/21)+(qd/2 qd/2+1qd1q0q1qd/21)(sinmθ0sinmθ1sinmθd/21sinmθ0sinmθ1sinmθd/21)

4. RoPE 证明过程

4.1 简单证明

简单起见,先假设 qm,kn 是所在位置分别为 m,n 的二维行向量。既然是二维,可以将其当作复数来运算。Attention 的关键之处在于向量的内积,用复数表示为:

qm,kn=Re[qmkn]

其中 是共轭复数,右端的乘法是普通的复数乘法,Re[] 表示取结果的实部。上式意味着:如果将 qm,kn 分别乘以 eimθ,einθ 变成 qmeimθ,kneinθ,那么就相当于给它们配上了绝对位置编码(因为显式地依赖绝对位置 m,n)。然后代入内积,有:

qmeimθ,kneinθ=Re[(qmeimθ)(kneinθ)]=Re[qmknei(mn)θ]

值得注意的是,内积只依赖于相对位置 mn。这就巧妙地将绝对位置与相对位置融合在了一起。

由上述结果可知,对于位置为 n 的二维实数向量 [x,y],将其当作复数运算并乘以 einθ,得到恒等式:

(x+yi)einθ=(xcosnθysinnθ)+i(xsinnθ+ycosnθ)

这意味着,通过

(xy)(xcosnθysinnθxsinnθ+ycosnθ)=(xy)cosnθ+(yx)sinnθ

来赋予 [x,y] 绝对位置信息,那么在 Attention 运算时就等价于相对位置编码。如果是多于二维的向量,可以每两维为一组执行同样的运算,每组的 θ 可以不同。

这样一来,我们得到了一种融合绝对位置与相对位置的位置编码方案。从形式上看它类似乘性的绝对位置编码:通过在 q,k 中施加该位置编码,效果等价于相对位置编码。如果还需要显式的绝对位置信息,则可以同时在 v 上施加该编码。

4.2 完整证明

假定 query 向量 qm 和 key 向量 kn 之间的内积操作可以用函数 g 表示,该函数的输入是词嵌入向量 xmxn 和它们之间的相对位置 mn

fq(xm,m),fk(xn,n)=g(xm,xn,mn)

我们的目标是找到一个等价的位置编码方式,使得上述关系成立,即构造出函数 fg,使得上述等式成立。

假定词嵌入向量的维度为二维 d=2,这样就可以利用二维平面上向量的几何性质。论文中提出了满足上述关系的 fg 的形式如下:

fq(xm,m)=(Wqxm)eimθ,fk(xn,n)=(Wkxn)einθ,g(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ],

这里 Re 表示复数的实部。

首先看到上述 fg 公式中有个指数函数 eix,这是欧拉公式,其中 x 表示任意实数,e 是自然对数的底数,i 是复数中的虚数单位。根据欧拉公式有:

eix=cosx+isinx

即上述指数函数可以表示为实部为 cosx、虚部为 sinx 的复数。欧拉公式建立了指数函数、三角函数和复数之间的桥梁。

则上述 fg 公式中:

eimθ=cos(mθ)+isin(mθ),einθ=cos(nθ)+isin(nθ),ei(mn)θ=cos((mn)θ)+isin((mn)θ).

然后看回公式:

fq(xm,m)=(Wqxm)eimθ

其中 Wq 是个二维矩阵,xm 是个二维向量,相乘结果也是一个二维向量,用 qm 表示:

qm=(qm(1)qm(2))=Wqxm=(Wq(11)Wq(12)Wq(21)Wq(22))(xm(1)xm(2))

首先将 qm 表示成复数形式:

qm=[qm(1),qm(2)]=qm(1)+iqm(2)

接着:

fq(xm,m)=(Wqxm)eimθ=qmeimθ

其实就是两个复数相乘:

(a+ib)(c+id)=ac+ibc+iad+i2bd=(acbd)+i(bc+ad)

复数乘法使用分配律,并利用 i2=1 的性质。代入可得:

qmeimθ=(qm(1)+iqm(2))(cos(mθ)+isin(mθ))

复习一下复数乘法的性质:

qmeimθ=(qm(1)+iqm(2))(cos(mθ)+isin(mθ))=(qm(1)cos(mθ)qm(2)sin(mθ))+i(qm(2)cos(mθ)+qm(1)sin(mθ))

将结果重新表达成实数向量形式就是:

qmeimθ=[qm(1)cos(mθ)qm(2)sin(mθ),qm(2)cos(mθ)+qm(1)sin(mθ)]

因此:

fq(xm,m)=(Wqxm)eimθ=qmeimθ=[qm(1)cos(mθ)qm(2)sin(mθ),qm(2)cos(mθ)+qm(1)sin(mθ)]=(cos(mθ)sin(mθ)sin(mθ)cos(mθ))(qm(1)qm(2))

看到这里会发现,这就是 query 向量乘以了一个旋转矩阵。这就是"旋转位置编码"名称的由来。

同理,fk 可以表示成下面的式子:

fk(xn,n)=(cosnθsinnθsinnθcosnθ)(Wk(11)Wk(12)Wk(21)Wk(22))(xn(1)xn(2))=(cosnθsinnθsinnθcosnθ)(kn(1)kn(2))

最后还有个函数 g

g(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ]

其中 Re[x] 表示复数 x 的实部,而 (Wkxn) 表示复数 Wkxn 的共轭。

复习一下共轭复数的定义:

z=a+ib,z=aib

所以可得:

Wqxm=qm=qm(1)+iqm(2),Wkxn=kn=kn(1)+ikn(2),(Wkxn)=kn=kn(1)ikn(2),ei(mn)θ=cos((mn)θ)+isin((mn)θ).

继续可得:

g(xm,xn,mn)=Re[(Wqxm)(Wkxn)ei(mn)θ]=Re[(qm(1)+iqm(2))(kn(1)ikn(2))(cos((mn)θ)+isin((mn)θ))]=Re[((qm(1)kn(1)+qm(2)kn(2))+i(qm(2)kn(1)qm(1)kn(2)))(cos((mn)θ)+isin((mn)θ))]=(qm(1)kn(1)+qm(2)kn(2))cos((mn)θ)(qm(2)kn(1)qm(1)kn(2))sin((mn)θ)

接下来我们就要证明函数 g 的计算公式是成立的。

首先回顾一下 attention 操作,位置 m 的 query 和位置 n 的 key 会做一个内积操作:

fq(xm,m)=[qm(1)cos(mθ)qm(2)sin(mθ),qm(2)cos(mθ)+qm(1)sin(mθ)],fk(xn,n)=[kn(1)cos(nθ)kn(2)sin(nθ),kn(2)cos(nθ)+kn(1)sin(nθ)],fq(xm,m),fk(xn,n)=(qm(1)cos(mθ)qm(2)sin(mθ))(kn(1)cos(nθ)kn(2)sin(nθ))+(qm(2)cos(mθ)+qm(1)sin(mθ))(kn(2)cos(nθ)+kn(1)sin(nθ))=qm(1)kn(1)cos(mθ)cos(nθ)qm(1)kn(2)cos(mθ)sin(nθ)qm(2)kn(1)sin(mθ)cos(nθ)+qm(2)kn(2)sin(mθ)sin(nθ)+qm(2)kn(2)cos(mθ)cos(nθ)+qm(2)kn(1)cos(mθ)sin(nθ)+qm(1)kn(2)sin(mθ)cos(nθ)+qm(1)kn(1)sin(mθ)sin(nθ)

接着继续之前先复习一下三角函数的和差公式:

sin(a+b)=sinacosb+cosasinb,sin(ab)=sinacosbcosasinb,cos(a+b)=cosacosbsinasinb,cos(ab)=cosacosb+sinasinb.

回到上面的式子,整理得到:

fq(xm,m),fk(xn,n)=qm(1)kn(1)(cos(mθ)cos(nθ)+sin(mθ)sin(nθ))+qm(1)kn(2)(cos(mθ)sin(nθ)+sin(mθ)cos(nθ))+qm(2)kn(1)(sin(mθ)cos(nθ)+cos(mθ)sin(nθ))+qm(2)kn(2)(sin(mθ)sin(nθ)+cos(mθ)cos(nθ))=qm(1)kn(1)cos((mn)θ)+qm(1)kn(2)sin((mn)θ)qm(2)kn(1)sin((mn)θ)+qm(2)kn(2)cos((mn)θ)=(qm(1)kn(1)+qm(2)kn(2))cos((mn)θ)+(qm(1)kn(2)qm(2)kn(1))sin((mn)θ)=(qm(1)kn(1)+qm(2)kn(2))cos((mn)θ)(qm(2)kn(1)qm(1)kn(2))sin((mn)θ)=g(xm,xn,mn)

这就证明了上述关系成立:位置 m 的 query 和位置 n 的 key 的内积即为函数 g

把上面的式子用矩阵向量乘的形式来表达就是:

fq(xm,m),fk(xn,n)=[(cos(mθ)sin(mθ)sin(mθ)cos(mθ))(qm(1)qm(2))][(cos(nθ)sin(nθ)sin(nθ)cos(nθ))(kn(1)kn(2))]=(qm(1)qm(2))(cos(mθ)sin(mθ)sin(mθ)cos(mθ))(cos(nθ)sin(nθ)sin(nθ)cos(nθ))(kn(1)kn(2))=(qm(1)qm(2))(cos(mθ)cos(nθ)+sin(mθ)sin(nθ)cos(mθ)sin(nθ)+sin(mθ)cos(nθ)sin(mθ)cos(nθ)+cos(mθ)sin(nθ)sin(mθ)sin(nθ)+cos(mθ)cos(nθ))(kn(1)kn(2))=(qm(1)qm(2))(cos((mn)θ)sin((mn)θ)sin((mn)θ)cos((mn)θ))(kn(1)kn(2))

上面的推导假定词嵌入维度为 2 维向量。对于 d2 的通用情况,将词嵌入向量元素按两两一组分组,每组应用同样的旋转操作,且每组的旋转角度计算方式如下:

θj=100002(j1)/d,j[1,2,,d/2]

综上,RoPE 的 self-attention 操作流程为:对 token 序列中的每个词嵌入向量,首先计算对应的 query 和 key 向量,然后对每个 token 位置计算对应的旋转位置编码,接着对 query 和 key 向量的元素按两两一组应用旋转变换,最后计算 query 和 key 之间的内积得到 self-attention 的计算结果。

5. RoPE 的性质

5.1 远程衰减

可以看到,RoPE 形式上和 Sinusoidal 位置编码有一定相似性,只不过 Sinusoidal 位置编码是加性的,而 RoPE 可视为乘性的。在 θi 的选择上,同样沿用了 Sinusoidal 位置编码的方案,即 θi=100002i/d,它可以带来一定的远程衰减性。

具体证明如下:将 q,k 两两分组后,加上 RoPE 后的内积可以用复数乘法表示为:

(Rmq)(Rnk)=Re[i=0d/21q[2i:2i+1]k[2i:2i+1]ei(mn)θi]

hi=q[2i:2i+1]k[2i:2i+1]Sj=i=0j1ei(mn)θi,并约定 hd/2=0,S0=0,由 Abel 变换(分部求和法)可以得到:

i=0d/21q[2i:2i+1]k[2i:2i+1]ei(mn)θi=i=0d/21hi(Si+1Si)=i=0d/21Si+1(hi+1hi)

所以:

|i=0d/21q[2i:2i+1]k[2i:2i+1]ei(mn)θi|=|i=0d/21Si+1(hi+1hi)|i=0d/21|Si+1||hi+1hi|(maxi|hi+1hi|)i=0d/21|Si+1|

因此可以考察 1d/2i=1d/2|Si| 随相对距离的变化情况来体现衰减性。可以观察到随着相对距离增大,内积结果呈现衰减趋势。因此,选择 θi=100002i/d 确实能带来一定的远程衰减性。

5.2 线性场景

最后指出,RoPE 是目前唯一一种可用于线性 Attention 的相对位置编码。这是因为其他相对位置编码直接基于 Attention 矩阵进行操作,而线性 Attention 并不事先计算 Attention 矩阵,因此无法应用。RoPE 以绝对位置编码的方式实现相对位置编码,不需要操作 Attention 矩阵,因而具备应用到线性 Attention 的可能性。

线性 Attention 的常见形式是:

Attention(Q,K,V)i=j=1nsim(qi,kj)vjj=1nsim(qi,kj)=j=1nϕ(qi)φ(kj)vjj=1nϕ(qi)φ(kj)

其中 ϕ,φ 是值域非负的激活函数。可以看到,线性 Attention 也是基于内积的,因此很自然的想法是将 RoPE 插入到内积中:

j=1n[Riϕ(qi)][Rjφ(kj)]vjj=1n[Riϕ(qi)][Rjφ(kj)]

但这样存在的问题是,内积 [Riϕ(qi)][Rjφ(kj)] 可能为负数,因此不再是常规的概率注意力,且分母有为零的风险,可能带来优化上的不稳定。考虑到 Ri,Rj 都是正交矩阵,不改变向量的模长,因此可以抛弃常规的概率归一化要求,使用如下运算作为一种新的线性 Attention:

j=1n[Riϕ(qi)][Rjφ(kj)]vjj=1nϕ(qi)φ(kj)

也就是说,RoPE 只插入分子中,分母保持不变。这样的注意力不再是基于概率的(注意力矩阵不再满足非负归一性),但某种意义上也是一种归一化方案。目前也没有证据表明非概率式的注意力效果更差(例如 Nyströmformer 也未严格依据概率分布构建注意力)。因此将其作为候选方案之一进行实验,初步实验结果显示这样的线性 Attention 也是有效的。

5.3 RoPE 的长度扩展

在 LLM 的应用中,有一个非常重要的参数——上下文长度(max context length)。更长的上下文长度允许进行更多轮次的对话、对更长的文本进行总结分析,也允许生成更长的文章。然而在训练 LLM 时,训练语料大部分不够长,许多 LLM 训练时设计的最大文本长度仅为 2k(即最长 2048 个 token)。那么,能否在训练时使用较短的文本,而在推理时扩展到长文本上呢?

这是可行的,可以对 RoPE 进行长度扩展。下面介绍三种扩展方案。

5.3.1 直接外推

直接外推即继续沿用现有位置编码公式,不做任何修改。在扩展长度不太大时(例如由 2k 扩展到 2.5k),此方法对性能的影响不大。旋转位置编码只与相对位置 mn 的大小有关,通常具有远程衰减性,即相对距离越大的两个 token 相关性越弱。

因此,如果模型已从训练数据中学习到 token 之间在 0-2k 范围内合适的衰减规律,将其应用到 0-2.5k 通常也没有问题。但若扩展到更长的长度(例如从 2k 扩展到 32k),直接外推通常会严重影响性能。因为学习到的衰减规律可能在 5k 处就完全衰减为零,导致无法捕捉超过 5k 相对距离的 token 之间的相互作用。

总结:直接外推对衰减规律在长距离情况下的使用容易出现问题。为减少性能影响,可以让训练好的模型在更长的上下文上做少量步骤的微调。

5.3.2 线性内插

线性内插需要改变位置编码公式,等效于将位置序号等比例缩小。

例如从 2k 扩展到 32k 时,等效于将位置序号缩小为原来的 1/16。线性内插未改变模型学习到的衰减规律的应用范围,不做微调时其效果一般优于直接外推方案。但当扩展倍数非常大时(如从 2k 到 32k),性能也会明显受影响。原因在于短距离情况下的使用受到较大影响:本来距离为 1 的两个 token,扩展后相当于距离为 1/16,而衰减规律在短距离时可能变化率极大,对相关性的评估可能偏离合理值。

应用线性内插时,在长文本上做少量步骤的微调也能明显改善性能。

5.3.3 NTK 扩展方式

这种方式综合了外推和内插的优点,做长度扩展后即使不微调也能保持较好的性能。

前面的分析表明:直接外推对衰减规律在长距离情况下的使用容易出问题,在短距离下不受影响;线性内插对衰减规律在短距离下的使用容易出问题,在长距离下影响较小。那么能否将两者综合——在短距离情况下具有外推特性(与扩展前基本一致),在长距离情况下具有内插特性(缩放到扩展前的范围)?

观察 RoPE 位置编码的元素计算公式,可以发现 i 越大,三角函数对应的角频率系数越小(即越低频),三角函数变化越慢。由此可得到直观结论:短距离之间的差异主要体现在高频分量(i 较小)上;长距离之间的差异主要体现在低频分量(i 较大)上。

为了在短距离情况下具有外推特性、长距离情况下具有内插特性,可以设计一个与频率相关的位置序号缩放因子:在最高频时取值为 1(与扩展前一致),在最低频时恰好为缩放倍数的倒数(缩放到扩展前的范围)。一种有效的选择方案是对 base 做指数缩放。NTK 扩展方式的要点是高频外推、低频内插,实现方法是直接对底数 base 进行缩放,类似进制编码转换。采用 NTK 扩展到长文本,即使不做微调,性能也仅略有下降。

6. 代码实现

旋转位置嵌入的简单实现使用前面所示的块对角矩阵形式。在实践中,这种实现方式效率较低,更优化的形式很容易获得。RoPE 的原始实现可在 roformerbert4keras 中找到。

此外,在 x-transformersGPT-NeoGPT-NeoXMesh Transformer JAX 中也实现了旋转位置嵌入。以下是从这些代码库中提取的 PyTorch 实现。

python
import torch


class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached


# Rotary pos emb helpers

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    # dim=-1 triggers a bug in torch < 1.8.0
    return torch.cat((-x2, x1), dim=x1.ndim - 1)


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

总结

从理论上看,RoPE 与 Sinusoidal 位置编码有相通之处,但 RoPE 不依赖泰勒展开,更具严谨性与可解释性。从预训练模型 RoFormer 的结果来看,RoPE 具有良好的外推性,应用到 Transformer 中体现出较好的处理长文本的能力。此外,RoPE 是目前唯一一种可用于线性 Attention 的相对位置编码。

参考文献

[1] RoFormer: Enhanced Transformer with Rotary Position Embedding

[2] Euler's Formula

[3] List of Trigonometric Identities

[4] LLaMA

[5] 旋转矩阵

[6] Jianlin Su. 让研究人员绞尽脑汁的 Transformer 位置编码. https://kexue.fm/archives/8130, 2021. [Online; accessed 18-April-2021].

[7] Jianlin Su. Transformer 升级之路:2、博采众长的旋转式位置编码. https://kexue.fm/archives/8265, 2021. [Online; accessed 18-April-2021].

[8] Jianlin Su, Yu Lu, Shengfeng Pan, Bo Wen, and Yunfeng Liu. RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv preprint arXiv:2104.09864, 2021.

Maintained by Robin