Skip to content

KV Cache 机制详解

在生成式 Transformer 模型中,键值状态(Key-Value States)缓存技术能够显著加速解码过程。本文将介绍 KV Cache 的核心原理、数学表达、基础实现以及性能影响。

注意:KV Cache 仅存在于自回归解码器模型(如 GPT、LLaMA 等)的生成阶段。BERT 等非生成式模型不使用 KV Cache。

1. 自回归解码的冗余计算

解码器以自回归方式工作:模型根据输入预测下一个 Token,再将该 Token 拼回输入进行下一步预测。由于因果掩码(Causal Mask)的限制,每个新 Token 只能关注它及之前的 Token,因此历史 Token 的 Key 和 Value 会在每一步被重复计算。

例如,在第三步输入 "She poured coffee" 时,"She" 与 "poured" 的 Key/Value 已经被计算过。KV Cache 通过持久化存储这些历史状态,避免重复计算。

1.1 朴素自回归生成示例

python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class GreedySampler:
    def __init__(self, model_name="gpt2-medium"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

    def __call__(self, prompt, max_new_tokens=10):
        result = prompt
        for i in range(max_new_tokens):
            input_ids = self.tokenizer.encode(result, return_tensors="pt").to(self.device)
            with torch.no_grad():
                logits = self.model(input_ids).logits[0, -1, :]
            next_token_id = torch.argmax(logits, dim=-1).item()
            result += self.tokenizer.decode(next_token_id)
            print(f"step {i}: {result}")
        return result

不使用 KV Cache 时,每步输入序列都会变长,计算量随之线性增长。

2. KV Cache 的工作原理

KV Cache 在推理时保存每个 Token 的 Key 和 Value 向量。生成第 t+1 个 Token 时,只需计算当前 Token 的 Query、Key、Value,将 Key 和 Value 追加到缓存中,再与历史 Key/Value 一起做注意力计算。

2.1 计算复杂度对比

  • 不使用 KV Cache:每步计算量为 O(t2),随序列长度平方增长。
  • 使用 KV Cache:每步计算量降为 O(t),仅需计算新 Token 与所有历史 Token 的注意力;显存开销为 O(t)

2.2 注意力计算示意

设第 l 层在第 t 步的输入为 xtlRb×1×h,其中 b 为 batch size,h 为隐藏维度。

KV Cache 更新

KlConcat(Kl,xtlWKl)VlConcat(Vl,xtlWVl)

Query 计算

qtl=xtlWQl

注意力输出

otl=softmax(qtlKlh)VlWOl+xtl

前馈网络

xtl+1=factivation(otlW1l)W2l+otl

3. 显存占用估算

对于一个拥有 L 层、H 个注意力头、头维度为 A、batch size 为 B、序列长度为 S、精度为 Pbyte 的模型,KV Cache 占用的显存约为:

KV Cache Size (Bytes)=2×B×S×L×H×A×Pbyte

其中系数 2 表示同时存储 Key 和 Value。

3.1 示例

以 Llama-2-7B 为例:L=32H=32A=128,FP16(Pbyte=2),Batch Size = 1,序列长度 S=1024

2×1×1024×32×32×128×2=512×10242 Bytes0.5 GB

当序列长度达到 128K 时,单请求的 KV Cache 可达数十 GB,极易成为显存瓶颈。

4. 基础实现

4.1 手动 KV Cache 示例

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelArgs:
    def __init__(self, dim=16, n_heads=2, max_seq_len=8, max_batch_size=1):
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size

class SelfAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.head_dim
        self.dim = args.dim

        self.w_q = nn.Linear(self.dim, self.dim, bias=False)
        self.w_k = nn.Linear(self.dim, self.dim, bias=False)
        self.w_v = nn.Linear(self.dim, self.dim, bias=False)
        self.w_o = nn.Linear(self.dim, self.dim, bias=False)

        self.register_buffer(
            "cache_k",
            torch.zeros(args.max_batch_size, args.max_seq_len, self.n_heads, self.head_dim),
        )
        self.register_buffer(
            "cache_v",
            torch.zeros(args.max_batch_size, args.max_seq_len, self.n_heads, self.head_dim),
        )

    def forward(self, x: torch.Tensor, start_pos: int):
        B, S, D = x.size()
        H, Hd = self.n_heads, self.head_dim

        q = self.w_q(x).view(B, S, H, Hd)
        k = self.w_k(x).view(B, S, H, Hd)
        v = self.w_v(x).view(B, S, H, Hd)

        # 更新缓存
        self.cache_k[:B, start_pos:start_pos + S] = k
        self.cache_v[:B, start_pos:start_pos + S] = v

        keys = self.cache_k[:B, :start_pos + S]
        values = self.cache_v[:B, :start_pos + S]

        q = q.transpose(1, 2)            # (B, H, S, Hd)
        k = keys.transpose(1, 2)         # (B, H, Seq_KV, Hd)
        v = values.transpose(1, 2)       # (B, H, Seq_KV, Hd)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (Hd ** 0.5)
        weights = F.softmax(scores, dim=-1)
        output = torch.matmul(weights, v)

        output = output.transpose(1, 2).contiguous().view(B, S, D)
        return self.w_o(output)

# 测试:逐步生成一个序列
torch.manual_seed(42)
args = ModelArgs()
attn = SelfAttention(args)
sequence = torch.randn(1, args.max_seq_len, args.dim)
outputs = []
for i in range(args.max_seq_len):
    x = sequence[:, i:i + 1, :]
    y = attn(x, start_pos=i)
    outputs.append(y)
final = torch.cat(outputs, dim=1)
print("最终输出 shape:", final.shape)

4.2 使用 Transformers 的 DynamicCache

python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda:0"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

past_key_values = DynamicCache()
messages = [{"role": "user", "content": "Hello, what's your name?"}]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to("cuda:0")

cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
max_new_tokens = 10

for _ in range(max_new_tokens):
    outputs = model(
        **inputs,
        cache_position=cache_position,
        past_key_values=past_key_values,
        use_cache=True,
    )
    next_token_ids = outputs.logits[:, -1:].argmax(-1)
    attention_mask = torch.cat(
        [inputs["attention_mask"], inputs["attention_mask"].new_ones((1, 1))],
        dim=-1,
    )
    inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
    cache_position = cache_position[-1:] + 1

print(tokenizer.batch_decode(
    torch.cat([inputs["input_ids"], past_key_values.key_cache[0][:, -max_new_tokens:]], dim=-1),
    skip_special_tokens=True,
)[0])

4.3 传统缓存格式

DynamicCache 之前,缓存以嵌套张量元组形式存储。可使用 from_legacy_cache()to_legacy_cache() 进行转换:

python
from transformers import DynamicCache

generation_outputs = model.generate(
    **inputs,
    return_dict_in_generate=True,
    return_legacy_cache=True,
    max_new_tokens=5,
)
cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
legacy_cache = cache.to_legacy_cache()

5. 性能影响

使用 KV Cache 可显著降低推理时间。以 GPT-2 在 Tesla T4 上生成 1000 个新 Token 为例:

模式平均耗时标准差
使用 KV Cache11.885 s±0.272 s
不使用 KV Cache56.197 s±1.855 s

在大多数 Transformer 生成任务中,建议始终启用 use_cache=True

参考

Maintained by Robin