解码策略详解
解码策略(Decoding Strategy)决定了语言模型在每一步如何选择下一个 Token,从而直接影响生成文本的流畅性、多样性与事实性。本文将从最基础的贪婪搜索出发,逐步介绍束搜索、采样、Top-K/Top-p、推测解码、对比搜索、DoLa 以及自定义解码方法。
1. 自回归生成的概率视角
自回归语言模型将序列概率分解为条件概率的连乘:
其中
模型首先为每个候选 Token 输出 logits
其中
2. 贪婪搜索
贪婪搜索(Greedy Search)在每个时间步选择概率最高的 Token:
这种方法实现简单、计算高效,但容易陷入局部最优和重复循环,且可能错过隐藏在高概率词后面的更长优质序列。
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("I enjoy walking with my cute dog", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=40, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))3. 束搜索
束搜索(Beam Search)在每个时间步保留 num_beams > 1 即可启用,通常配合 early_stopping=True 使用。
outputs = model.generate(
**inputs,
max_new_tokens=40,
num_beams=5,
early_stopping=True,
)3.1 避免重复
可通过 no_repeat_ngram_size=2 禁止重复 2-gram,但需注意对固定短语(如 "New York")的误伤。
outputs = model.generate(
**inputs,
max_new_tokens=40,
num_beams=5,
no_repeat_ngram_size=2,
early_stopping=True,
)3.2 返回多个候选
设置 num_return_sequences <= num_beams 可返回概率最高的多个候选序列,便于下游选择或重排序。
outputs = model.generate(
**inputs,
max_new_tokens=40,
num_beams=5,
num_return_sequences=3,
early_stopping=True,
)适用场景:束搜索在输出长度相对可预测的任务(如机器翻译、摘要)中表现更好;在开放式生成中可能产生重复或不自然的文本。
4. 采样与温度
采样(Sampling)根据模型输出的完整概率分布随机选择 Token,能够生成更多样化的内容。通过 do_sample=True 启用。
4.1 温度缩放
温度
:趋近贪婪解码,输出确定性高。 :标准 Softmax。 :分布趋于平坦,输出多样性增加。
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7)5. Top-K 与 Top-p 采样
5.1 Top-K 采样
Top-K 采样只保留概率最高的
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_k=50)5.2 Top-p(核)采样
Top-p(Nucleus Sampling)动态选择累积概率超过
相比固定
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.92, top_k=0)两者也可联合使用:top_k=50, top_p=0.95。
6. 高级解码方法
6.1 推测解码
推测解码(Speculative Decoding)使用一个较小的辅助模型(Draft Model)或 N-gram 机制提前生成多个候选 Token,主模型在一次前向传播中并行验证,从而减少解码步数、加速生成。
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
assistant = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=50)变体包括:
- Prompt Lookup Decoding:利用输入文本中的 n-gram 作为候选 Token,适合摘要等输入导向任务。
- Self-speculative Decoding:使用模型早期层输出作为草稿,再通过后续层验证,可复用主模型缓存。
- Universal Assisted Decoding(UAD):允许主模型与辅助模型使用不同分词器,通过最长公共子序列对齐 Token。
6.2 对比搜索
对比搜索(Contrastive Search)在采样时惩罚与已生成 Token 过于相似的候选,从而在较长序列中抑制重复。
outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4)6.3 DoLa
DoLa(Decoding by Contrasting Layers)通过对比最终层与早期层的 logits 差异,增强事实性并减少幻觉。对于短答案任务建议使用 "high",长答案推理任务建议使用 "low"。
outputs = model.generate(
**inputs,
max_new_tokens=50,
dola_layers="high",
repetition_penalty=1.2,
do_sample=False,
)6.4 多样性束搜索
多样性束搜索(Diverse Beam Search)将束分组,并在组间施加多样性惩罚,适合需要多个差异较大候选的场景。
outputs = model.generate(
**inputs,
max_new_tokens=50,
num_beams=6,
num_beam_groups=3,
diversity_penalty=1.0,
do_sample=False,
)7. 自定义解码方法
Transformers 支持从模型仓库加载自定义 generate 实现,只需仓库包含 custom_generate/generate.py 与 custom_generate/requirements.txt,并在 README.md 中添加 custom_generate 标签。调用时通过 custom_generate 参数指定仓库名称:
outputs = model.generate(
**inputs,
custom_generate="transformers-community/custom_generate_example",
trust_remote_code=True,
)自定义方法需提供一个以 model 为第一个参数的 generate 函数,并可自由扩展输入输出行为。
8. 方法对比
| 解码方法 | 优点 | 缺点 | 典型场景 |
|---|---|---|---|
| 贪婪搜索 | 快速、简单 | 易重复、易错过全局最优 | 短答案、确定性任务 |
| 束搜索 | 考虑多路径、结果更优 | 易重复、调参复杂 | 翻译、摘要 |
| Top-K 采样 | 控制候选范围、结果多样 | 固定 | 通用生成 |
| Top-p 采样 | 动态候选集、生成自然 | 仍可能重复 | 对话、创意写作 |
| 推测解码 | 显著加速解码 | 需要辅助模型或机制 | 高吞吐在线服务 |
| 对比搜索 | 抑制长序列重复 | 增加计算开销 | 长文本生成 |
| DoLa | 提升事实性 | 不适合小模型 | 知识密集型问答 |