KV Cache 机制详解
在生成式 Transformer 模型中,键值状态(Key-Value States)缓存技术能够显著加速解码过程。本文将介绍 KV Cache 的核心原理、数学表达、基础实现以及性能影响。
注意:KV Cache 仅存在于自回归解码器模型(如 GPT、LLaMA 等)的生成阶段。BERT 等非生成式模型不使用 KV Cache。
1. 自回归解码的冗余计算
解码器以自回归方式工作:模型根据输入预测下一个 Token,再将该 Token 拼回输入进行下一步预测。由于因果掩码(Causal Mask)的限制,每个新 Token 只能关注它及之前的 Token,因此历史 Token 的 Key 和 Value 会在每一步被重复计算。
例如,在第三步输入 "She poured coffee" 时,"She" 与 "poured" 的 Key/Value 已经被计算过。KV Cache 通过持久化存储这些历史状态,避免重复计算。
1.1 朴素自回归生成示例
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class GreedySampler:
def __init__(self, model_name="gpt2-medium"):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
def __call__(self, prompt, max_new_tokens=10):
result = prompt
for i in range(max_new_tokens):
input_ids = self.tokenizer.encode(result, return_tensors="pt").to(self.device)
with torch.no_grad():
logits = self.model(input_ids).logits[0, -1, :]
next_token_id = torch.argmax(logits, dim=-1).item()
result += self.tokenizer.decode(next_token_id)
print(f"step {i}: {result}")
return result不使用 KV Cache 时,每步输入序列都会变长,计算量随之线性增长。
2. KV Cache 的工作原理
KV Cache 在推理时保存每个 Token 的 Key 和 Value 向量。生成第
2.1 计算复杂度对比
- 不使用 KV Cache:每步计算量为
,随序列长度平方增长。 - 使用 KV Cache:每步计算量降为
,仅需计算新 Token 与所有历史 Token 的注意力;显存开销为 。
2.2 注意力计算示意
设第
KV Cache 更新:
Query 计算:
注意力输出:
前馈网络:
3. 显存占用估算
对于一个拥有
其中系数 2 表示同时存储 Key 和 Value。
3.1 示例
以 Llama-2-7B 为例:
当序列长度达到 128K 时,单请求的 KV Cache 可达数十 GB,极易成为显存瓶颈。
4. 基础实现
4.1 手动 KV Cache 示例
import torch
import torch.nn as nn
import torch.nn.functional as F
class ModelArgs:
def __init__(self, dim=16, n_heads=2, max_seq_len=8, max_batch_size=1):
self.dim = dim
self.n_heads = n_heads
self.head_dim = dim // n_heads
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
class SelfAttention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.head_dim = args.head_dim
self.dim = args.dim
self.w_q = nn.Linear(self.dim, self.dim, bias=False)
self.w_k = nn.Linear(self.dim, self.dim, bias=False)
self.w_v = nn.Linear(self.dim, self.dim, bias=False)
self.w_o = nn.Linear(self.dim, self.dim, bias=False)
self.register_buffer(
"cache_k",
torch.zeros(args.max_batch_size, args.max_seq_len, self.n_heads, self.head_dim),
)
self.register_buffer(
"cache_v",
torch.zeros(args.max_batch_size, args.max_seq_len, self.n_heads, self.head_dim),
)
def forward(self, x: torch.Tensor, start_pos: int):
B, S, D = x.size()
H, Hd = self.n_heads, self.head_dim
q = self.w_q(x).view(B, S, H, Hd)
k = self.w_k(x).view(B, S, H, Hd)
v = self.w_v(x).view(B, S, H, Hd)
# 更新缓存
self.cache_k[:B, start_pos:start_pos + S] = k
self.cache_v[:B, start_pos:start_pos + S] = v
keys = self.cache_k[:B, :start_pos + S]
values = self.cache_v[:B, :start_pos + S]
q = q.transpose(1, 2) # (B, H, S, Hd)
k = keys.transpose(1, 2) # (B, H, Seq_KV, Hd)
v = values.transpose(1, 2) # (B, H, Seq_KV, Hd)
scores = torch.matmul(q, k.transpose(-2, -1)) / (Hd ** 0.5)
weights = F.softmax(scores, dim=-1)
output = torch.matmul(weights, v)
output = output.transpose(1, 2).contiguous().view(B, S, D)
return self.w_o(output)
# 测试:逐步生成一个序列
torch.manual_seed(42)
args = ModelArgs()
attn = SelfAttention(args)
sequence = torch.randn(1, args.max_seq_len, args.dim)
outputs = []
for i in range(args.max_seq_len):
x = sequence[:, i:i + 1, :]
y = attn(x, start_pos=i)
outputs.append(y)
final = torch.cat(outputs, dim=1)
print("最终输出 shape:", final.shape)4.2 使用 Transformers 的 DynamicCache
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(
model_id, torch_dtype=torch.bfloat16, device_map="cuda:0"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache()
messages = [{"role": "user", "content": "Hello, what's your name?"}]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
).to("cuda:0")
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
max_new_tokens = 10
for _ in range(max_new_tokens):
outputs = model(
**inputs,
cache_position=cache_position,
past_key_values=past_key_values,
use_cache=True,
)
next_token_ids = outputs.logits[:, -1:].argmax(-1)
attention_mask = torch.cat(
[inputs["attention_mask"], inputs["attention_mask"].new_ones((1, 1))],
dim=-1,
)
inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
cache_position = cache_position[-1:] + 1
print(tokenizer.batch_decode(
torch.cat([inputs["input_ids"], past_key_values.key_cache[0][:, -max_new_tokens:]], dim=-1),
skip_special_tokens=True,
)[0])4.3 传统缓存格式
在 DynamicCache 之前,缓存以嵌套张量元组形式存储。可使用 from_legacy_cache() 与 to_legacy_cache() 进行转换:
from transformers import DynamicCache
generation_outputs = model.generate(
**inputs,
return_dict_in_generate=True,
return_legacy_cache=True,
max_new_tokens=5,
)
cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
legacy_cache = cache.to_legacy_cache()5. 性能影响
使用 KV Cache 可显著降低推理时间。以 GPT-2 在 Tesla T4 上生成 1000 个新 Token 为例:
| 模式 | 平均耗时 | 标准差 |
|---|---|---|
| 使用 KV Cache | 11.885 s | ±0.272 s |
| 不使用 KV Cache | 56.197 s | ±1.855 s |
在大多数 Transformer 生成任务中,建议始终启用 use_cache=True。