Skip to content

Transformer 位置编码综述

不同于 RNN、CNN 等模型,Transformer 必须引入位置编码,因为纯粹的 Attention 模块无法捕捉输入顺序,即无法区分不同位置的 token。为此我们有两种选择:

  1. 将位置信息融入输入,这构成了绝对位置编码的一般做法;
  2. 调整 Attention 结构使其能分辨不同位置的 token,这构成了相对位置编码的一般做法。

Transformer 中的自注意力机制无法捕捉位置信息,因为其计算过程具有置换不变性(permutation invariant),打乱输入序列的顺序不会影响输出结果。

对于 Transformer 模型 f(),标记输入序列的两个向量 xm,xn,则 Transformer 具有全对称性

f(,xm,,xn,)=f(,xn,,xm,)

位置编码(Position Encoding) 通过将位置信息引入输入序列,以打破模型的全对称性。为简化问题,考虑在 m,n 位置处加上不同位置编码 pm,pn

f~(,xm,,xn,)=f(,xm+pm,,xn+pn,)

对上式进行二阶 Taylor 展开:

f~f+pmfxm+pnfxn+pm2fxm2pm+pn2fxn2pn绝对位置信息+pm2fxmxnpn相对位置信息

在上式中,第 2 至第 5 项仅依赖于单一位置,表示绝对位置信息;第 6 项包含 m,n 位置的交互项,表示相对位置信息。因此位置编码主要有两种实现形式:

  • 绝对位置编码(absolute PE):将位置信息加入到输入序列中,相当于引入索引的嵌入。例如 Sinusoidal、Learnable、FLOATER、Complex-order、RoPE。
  • 相对位置编码(relative PE):通过调整自注意力运算过程,使其能分辨不同 token 之间的相对位置。例如 XLNet、T5、DeBERTa、URPE。

1. 绝对位置编码(Absolute Position Encoding)

绝对位置编码是指在输入序列经过词嵌入后的第 k 个 token 向量 xkRd 中加入(add)位置向量 pkRd。其过程等价于先向输入拼接(concatenate)位置索引 k 的 one-hot 向量 pk:xk+pk,再进行词嵌入。因此绝对位置编码也被称为位置嵌入(Position Embedding)

1.1 三角函数式(Sinusoidal)位置编码

三角函数式(Sinusoidal)位置编码是原 Transformer 模型中使用的一种显式编码。以一维三角函数编码为例:

pk,2i=sin(k100002i/d),pk,2i+1=cos(k100002i/d),

其中 pk,2i,pk,2i+1 分别是位置索引 k 处的编码向量的第 2i,2i+1 个分量。一个长度为 32 的输入序列(每个输入向量的特征维度为 128)的 Sinusoidal 编码可视化如下:

以下是正弦位置编码的 PyTorch 实现:

python
import numpy as np
import torch


def sinusoidal_encoding_1d(seq_len, d_model):
    pos_table = np.array([
        [pos / np.power(10000, 2 * i / d_model) for i in range(d_model)]
        for pos in range(seq_len)
    ])
    # pos_table[0] 作用于 [CLS],不需要位置编码
    pos_table[1:, 0::2] = np.sin(pos_table[1:, 0::2])
    pos_table[1:, 1::2] = np.cos(pos_table[1:, 1::2])
    return torch.FloatTensor(pos_table)

三角函数式位置编码具有显式的生成规律,因此可以期望它具有一定的外推性。根据三角函数的性质,位置 α+β 处的编码向量可以表示为位置 α 和位置 β 的向量的组合,因此可以外推到任意位置:

sin(α+β)=sinαcosβ+cosαsinβ,cos(α+β)=cosαcosβsinαsinβ.

在图像领域,常用到二维形式的位置编码。以二维三角函数编码为例,需要分别对高度方向和宽度方向进行编码 p=[ph,pw]

ph,2i=sin(h100002i/d),ph,2i+1=cos(h100002i/d),pw,2i=sin(w100002i/d),pw,2i+1=cos(w100002i/d).
python
import math
import torch


def positional_encoding_2d(d_model, height, width):
    """
    :param d_model: dimension of the model
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model x height x width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with odd dimension (got dim={:d})".format(d_model)
        )
    pe = torch.zeros(d_model, height, width)
    # Each dimension uses half of d_model
    d_model = d_model // 2
    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    return pe

1.2 可学习(Learnable)位置编码

可学习(Learnable)位置编码是指将位置编码作为可训练参数。例如输入序列(经过嵌入层后)的大小为 n×d,则随机初始化一个 PRn×d 的矩阵作为位置编码,随训练过程更新。

可学习位置编码的缺点是缺乏外推性,即如果预训练序列的最大长度为 n,则无法处理长度超过 n 的序列。此时可以将超过 n 部分的位置编码随机初始化并微调。

1.3 FLOATER:递归式位置编码

原则上,RNN 模型不需要位置编码,其结构本身就具备学习位置信息的能力(因为递归本身意味着可以训练一个"计数"模型)。因此,如果在输入后先接一层 RNN,再接 Transformer,理论上就不需要额外的位置编码。

同理,也可以用 RNN 模型来学习一种绝对位置编码。例如从一个向量 p0 出发,通过递归格式 pk+1=f(pk) 得到各个位置的编码向量。如果位置编码能递归生成,则其生成结构自带学习位置信息的能力。

FLOATER 使用神经常微分方程(Neural ODE)构建的连续动力系统对位置编码进行递归建模:

p(t)=p(s)+sth(τ,p(τ);θh)dτ

理论上,基于递归模型的 FLOATER 位置编码也具有较好的外推性,同时比三角函数式位置编码更具灵活性(例如可以证明三角函数式位置编码是 FLOATER 的某个特解)。但递归形式的位置编码牺牲了一定的并行性,可能成为速度瓶颈。

1.4 Complex-order:复数式位置编码

绝对位置编码等价于词表索引 j 的词嵌入 x(j) 与位置索引 k 的嵌入 pk 的求和函数 f(j,k)=xk(j)+pk。Complex-order 方法则直接将该函数建模为一个复值函数:

f(j,k)=rjei(θj+ωjk)

其中振幅 rj、角频率 ωj 和初相位 θj 是待学习的参数(为同一个词设置三组词嵌入)。振幅 rj 仅与词表索引 j 有关,相当于该词的词嵌入;角频率 ωj 表示该词对位置的敏感程度;相位 θj+ωjpos 引入该词在文本中的位置信息。

1.5 RoPE:旋转式位置编码

旋转式位置编码(Rotary Position Embedding,RoPE)是一种通过绝对位置编码的方式实现相对位置编码的方案。其核心思想是在构造查询矩阵 q 和键矩阵 k 时,根据其绝对位置引入旋转矩阵 R

补充说明:关于 RoPE 的完整数学推导与实现细节,请参阅旋转式位置编码(RoPE)

qi=RixiWQ,kj=RjxjWK

旋转矩阵 R 设计为正交矩阵,且应满足 RiRj=Rji,使得后续注意力矩阵的计算中隐式地包含相对位置信息:

(RixiWQ)(RjxjWK)=(xiWQ)RiRjxjWK=(xiWQ)RjixjWK

1.6 层次化位置编码

在可学习的位置编码中,假设学习到序列长度为 n 的编码,则难以外推到序列长度 >n 的场合。层次化位置编码通过对现有编码进行层次分解,利用 n 个编码构造长度为 n2 的一系列编码。

假设学习到位置编码 p1,p2,,pn;现构造位置编码 q1,q2,,qn2,其由基编码 u1,u2,,un 分层构造:

q(i1)×n+j=αui+(1α)uj

其中 α0.5 是为了区分 (i,j)(j,i) 两种不同的情况。假设前 n 个构造的编码与学习到的编码一致:

qi=pi,i=1,2,,n

则可以解出基编码:

ui=piαp11α,i=1,2,,n

2. 相对位置编码(Relative Position Encoding)

相对位置编码并非直接建模每个输入 token 的位置信息,而是在计算注意力矩阵时考虑当前向量与待交互向量之间的相对距离。由于自然语言通常更依赖于相对位置,因此相对位置编码一般也有优秀的表现。

从绝对位置编码出发,其形式相当于在输入中添加绝对位置的表示。对应的完整自注意力机制运算如下:

{qi=(xi+pi)WQ,kj=(xj+pj)WK,vj=(xj+pj)WV,αi,j=softmax(qikj),oi=jαi,jvj,

其中 softmaxj 维度归一化,此处向量均为行向量。将 qikj 展开:

qikj=(xi+pi)WQWK(xj+pj)=(xiWQ+piWQ)(WKxj+WKpj)

为了引入相对位置信息,Google 移除了第一项中的位置编码,并将第二项 pjWK 替换为二元位置向量 Ri,jK,变为:

αi,j=softmax(xiWQ(xjWK+Ri,jK))

注意到绝对位置编码相当于在自注意力运算中引入了一系列 piWQ,(pjWK),pjWV 项。而相对位置编码的出发点是将这些项调整为与相对位置 (i,j) 有关的向量 Ri,j

2.1 经典相对位置编码

在经典的相对位置编码设置中,移除了与 xi 的位置编码项 piWQ 相关的项,并将 xj 的位置编码项 pjWV,pjWK 替换为相对位置向量 Ri,jV,Ri,jK

αij=softmax{xiWQ(WK)xj+xiWQ(Ri,jK)},zi=j=1nαij(xjWV+Ri,jV),

相对位置向量 Ri,jV,Ri,jK 可以设置为三角函数式或可学习参数,且通常只考虑相对位置 pminijpmax 的情况:

Ri,jK=wclip(ji,pmin,pmax)K{wpminK,,wpmaxK},Ri,jV=wclip(ji,pmin,pmax)V{wpminV,,wpmaxV}.

2.2 XLNet 式

XLNet 式位置编码实际源自 Transformer-XL 的论文《Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context》。由于采用 Transformer-XL 架构的 XLNet 模型在一定程度上超越了 BERT,Transformer-XL 才广为人知,因此这种位置编码通常也被冠以 XLNet 之名。

在 XLNet 模型中,移除了值向量的位置编码 pj,并将注意力计算中 xj 的位置编码 pj 替换为相对位置向量 Rij(设置为三角函数式编码),xi 的位置编码 pi 设置为可学习向量 u,v

αij=softmax{xiWQ(WK)xj+xiWQ(WK)Rij+uWQ(WK)xj+vWQ(WK)Rij},zi=j=1nαijxjWV.

2.3 T5 式

T5 模型出自论文《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》,其中采用了一种更简洁的相对位置编码方案。思路同样源自展开式。分析各项含义,可分别理解为"输入-输入"、"输入-位置"、"位置-输入"、"位置-位置"四项注意力的组合。如果认为输入信息与位置信息应当独立(解耦),则不应有过多交互。因此,"输入-位置"项 (xi,pj) 和"位置-输入"项 (pi,xj) 可以删除,而 piWQWKpj 实际上只是一个仅依赖于 (i,j) 的标量,可以直接作为参数训练。简化为:

αij=softmax{xiWQ(WK)xj+ri,j},zi=j=1nαijxjWV.

一维形式的 T5 式相对位置编码实现过程如下:

python
import torch
import torch.nn as nn
from einops import rearrange


class Attention(nn.Module):
    def __init__(self, dim, seq_len, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

        # Positional bias
        self.pos_bias = nn.Embedding(seq_len, heads)

        q_pos = torch.arange(seq_len)
        k_pos = torch.arange(seq_len)
        pos_indices = (q_pos[:, None] - k_pos[None, :]).abs()
        self.register_buffer("pos_indices", pos_indices)

    def apply_pos_bias(self, fmap):
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, "i j h -> () h i j")
        return fmap + (bias / self.scale)

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # 引入相对位置编码
        dots = self.apply_pos_bias(dots)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)

二维形式的 T5 式相对位置编码实现过程如下:

python
import torch
import torch.nn as nn
from einops import rearrange


class Attention2D(nn.Module):
    def __init__(self, dim, fmap_size, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

        # Positional bias
        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        q_range = torch.arange(fmap_size)
        k_range = torch.arange(fmap_size)

        q_pos = torch.stack(torch.meshgrid(q_range, q_range, indexing="ij"), dim=-1)
        k_pos = torch.stack(torch.meshgrid(k_range, k_range, indexing="ij"), dim=-1)

        q_pos, k_pos = map(lambda t: rearrange(t, "i j c -> (i j) c"), (q_pos, k_pos))
        rel_pos = (q_pos[:, None, ...] - k_pos[None, :, ...]).abs()

        x_rel, y_rel = rel_pos.unbind(dim=-1)
        pos_indices = (x_rel * fmap_size) + y_rel

        self.register_buffer("pos_indices", pos_indices)

    def apply_pos_bias(self, fmap):
        bias = self.pos_bias(self.pos_indices)
        bias = rearrange(bias, "i j h -> () h i j")
        return fmap + (bias / self.scale)

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # 引入相对位置编码
        dots = self.apply_pos_bias(dots)

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)

2.4 DeBERTa 式

在 DeBERTa 模型中,移除了值向量的位置编码 pj 以及注意力计算中的位置-位置注意力项 (pi,pj),并将注意力计算中 xi,xj 的位置编码 pi,pj 替换为相对位置向量 Rj,i,Ri,j

αij=softmax{xiWQ(WK)xj+xiWQ(WK)Ri,j+Rj,iWQ(WK)xj},zi=j=1nαijxjWV.

2.5 Universal RPE(URPE)

注意到在相对位置编码中,如果移除值向量的位置编码 pj,会使模型丧失通用函数近似能力。通用相对位置编码(Universal RPE)引入如下约束:

zi=j=1nαijcijxjWV

其中 C=[cij] 是一个可训练的 Toeplitz 矩阵:cij=g(ij),它与 Attention 矩阵逐元素相乘。尽管这使得 Attention 矩阵不再是按行的概率矩阵,但恢复了模型的通用近似性。

3. 其他位置编码

绝对位置编码和相对位置编码虽然花样繁多,但仍属经典范畴。除此之外,还有一些非常规方案同样能够表达位置编码。

3.1 CNN 式

尽管经典的 CNN 用于 NLP 的工作《Convolutional Sequence to Sequence Learning》加入了位置编码,但一般的 CNN 模型(尤其是图像领域的 CNN 模型)并未额外加位置编码。那么 CNN 模型究竟如何捕捉位置信息?

一种可能的解释是卷积核的各向异性使其能分辨不同方向的相对位置。然而 ICLR 2020 的论文《How Much Position Information Do Convolutional Neural Networks Encode?》给出了一个出人意料的答案:CNN 模型的位置信息是由 Zero Padding 泄漏的。

为使卷积编码过程中的特征图保持一定大小,通常会对输入填充(padding)一定数量的 0。该论文表明此操作使模型具备了识别位置信息的能力。也就是说,卷积核的各向异性固然重要,但最根本的因素是 Zero Padding 的存在。可以推测,模型实际提取的是当前位置与 padding 边界的相对距离。

不过,这一能力依赖于 CNN 的局部性。像 Attention 这种全局性的无先验结构并不适用。对于只关心 Transformer 位置编码方案的读者,权当扩展视野。

参考文献

Maintained by Robin