Skip to content

解码策略详解

解码策略(Decoding Strategy)决定了语言模型在每一步如何选择下一个 Token,从而直接影响生成文本的流畅性、多样性与事实性。本文将从最基础的贪婪搜索出发,逐步介绍束搜索、采样、Top-K/Top-p、推测解码、对比搜索、DoLa 以及自定义解码方法。

1. 自回归生成的概率视角

自回归语言模型将序列概率分解为条件概率的连乘:

P(w1:TW0)=t=1TP(wtw1:t1,W0)

其中 W0 为初始上下文,w1:0=,序列长度 T 在生成过程中动态确定,直到采样到结束符(EOS)为止。

模型首先为每个候选 Token 输出 logits zi,再通过 Softmax 转换为概率分布:

P(yi)=ezij=1Kezj

其中 K 为词汇表大小。解码策略即在此基础上决定如何从分布中选取下一个 Token。

2. 贪婪搜索

贪婪搜索(Greedy Search)在每个时间步选择概率最高的 Token:

wt=argmaxwP(ww1:t1)

这种方法实现简单、计算高效,但容易陷入局部最优和重复循环,且可能错过隐藏在高概率词后面的更长优质序列。

python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("I enjoy walking with my cute dog", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=40, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

3. 束搜索

束搜索(Beam Search)在每个时间步保留 b 个概率最高的部分序列(束),最终选择整体概率最大的完整序列。设置 num_beams > 1 即可启用,通常配合 early_stopping=True 使用。

python
outputs = model.generate(
    **inputs,
    max_new_tokens=40,
    num_beams=5,
    early_stopping=True,
)

3.1 避免重复

可通过 no_repeat_ngram_size=2 禁止重复 2-gram,但需注意对固定短语(如 "New York")的误伤。

python
outputs = model.generate(
    **inputs,
    max_new_tokens=40,
    num_beams=5,
    no_repeat_ngram_size=2,
    early_stopping=True,
)

3.2 返回多个候选

设置 num_return_sequences <= num_beams 可返回概率最高的多个候选序列,便于下游选择或重排序。

python
outputs = model.generate(
    **inputs,
    max_new_tokens=40,
    num_beams=5,
    num_return_sequences=3,
    early_stopping=True,
)

适用场景:束搜索在输出长度相对可预测的任务(如机器翻译、摘要)中表现更好;在开放式生成中可能产生重复或不自然的文本。

4. 采样与温度

采样(Sampling)根据模型输出的完整概率分布随机选择 Token,能够生成更多样化的内容。通过 do_sample=True 启用。

4.1 温度缩放

温度 τ 控制 Softmax 分布的尖锐程度:

P(yiτ)=ezi/τj=1Kezj/τ
  • τ0:趋近贪婪解码,输出确定性高。
  • τ=1:标准 Softmax。
  • τ>1:分布趋于平坦,输出多样性增加。
python
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7)

5. Top-K 与 Top-p 采样

5.1 Top-K 采样

Top-K 采样只保留概率最高的 K 个 Token,然后在该子集上重新归一化并采样。典型取值在 20–100 之间。

python
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_k=50)

5.2 Top-p(核)采样

Top-p(Nucleus Sampling)动态选择累积概率超过 p 的最小 Token 集合:

Vtop-p=min{VVwVP(w)p}

相比固定 K,Top-p 能根据当前分布形状自适应调整候选池大小,在实践中通常更稳定。

python
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_p=0.92, top_k=0)

两者也可联合使用:top_k=50, top_p=0.95

6. 高级解码方法

6.1 推测解码

推测解码(Speculative Decoding)使用一个较小的辅助模型(Draft Model)或 N-gram 机制提前生成多个候选 Token,主模型在一次前向传播中并行验证,从而减少解码步数、加速生成。

python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-1.7B")
assistant = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M")

inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt")
outputs = model.generate(**inputs, assistant_model=assistant, max_new_tokens=50)

变体包括:

  • Prompt Lookup Decoding:利用输入文本中的 n-gram 作为候选 Token,适合摘要等输入导向任务。
  • Self-speculative Decoding:使用模型早期层输出作为草稿,再通过后续层验证,可复用主模型缓存。
  • Universal Assisted Decoding(UAD):允许主模型与辅助模型使用不同分词器,通过最长公共子序列对齐 Token。

6.2 对比搜索

对比搜索(Contrastive Search)在采样时惩罚与已生成 Token 过于相似的候选,从而在较长序列中抑制重复。

python
outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4)

6.3 DoLa

DoLa(Decoding by Contrasting Layers)通过对比最终层与早期层的 logits 差异,增强事实性并减少幻觉。对于短答案任务建议使用 "high",长答案推理任务建议使用 "low"

python
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    dola_layers="high",
    repetition_penalty=1.2,
    do_sample=False,
)

6.4 多样性束搜索

多样性束搜索(Diverse Beam Search)将束分组,并在组间施加多样性惩罚,适合需要多个差异较大候选的场景。

python
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_beams=6,
    num_beam_groups=3,
    diversity_penalty=1.0,
    do_sample=False,
)

7. 自定义解码方法

Transformers 支持从模型仓库加载自定义 generate 实现,只需仓库包含 custom_generate/generate.pycustom_generate/requirements.txt,并在 README.md 中添加 custom_generate 标签。调用时通过 custom_generate 参数指定仓库名称:

python
outputs = model.generate(
    **inputs,
    custom_generate="transformers-community/custom_generate_example",
    trust_remote_code=True,
)

自定义方法需提供一个以 model 为第一个参数的 generate 函数,并可自由扩展输入输出行为。

8. 方法对比

解码方法优点缺点典型场景
贪婪搜索快速、简单易重复、易错过全局最优短答案、确定性任务
束搜索考虑多路径、结果更优易重复、调参复杂翻译、摘要
Top-K 采样控制候选范围、结果多样固定 K 适应性差通用生成
Top-p 采样动态候选集、生成自然仍可能重复对话、创意写作
推测解码显著加速解码需要辅助模型或机制高吞吐在线服务
对比搜索抑制长序列重复增加计算开销长文本生成
DoLa提升事实性不适合小模型知识密集型问答

参考

Maintained by Robin