Skip to content

采样参数详解

大语言模型(LLM)在生成文本时,会先将上下文映射为 logits,再通过 Softmax 转换为概率分布,最后依据解码策略采样下一个 Token。本文将系统介绍 temperaturetop_ktop_pmin_p 等核心采样参数,以及 Transformers 中常用的生成配置参数。

1. 背景:从 logits 到概率

对于词汇表中的第 i 个 Token,模型输出 logit zi。标准 Softmax 将其转换为概率:

P(yi)=ezij=1Kezj

其中 K 为词汇表大小。采样参数即在这一概率分布上进一步施加约束。

2. temperature

temperature 控制输出随机性,通常取值在 (0,+) 之间。

2.1 带温度的 Softmax

P(yiτ)=ezi/τj=1Kezj/τ
  • τ0:趋近贪婪采样,仅选择最大 logit 的 Token。
  • τ=1:标准 Softmax。
  • τ+:趋近均匀分布。

2.2 温度对概率分布的影响

τ<1 时:

ziτ>zi放大高概率 Token 的优势

τ>1 时:

ziτ<zi压缩概率差异

输出分布的熵随温度升高而增大:

H(P)=i=1KP(yiτ)logP(yiτ)

2.3 实现示例

python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    scaled = logits / temperature
    exp = np.exp(scaled - np.max(scaled))  # 数值稳定
    return exp / np.sum(exp)

2.4 推荐设置

温度分布特点示例场景
0.1尖锐峰值,低多样性事实问答、代码补全
0.5适度平滑,兼顾确定性与多样性通用对话
1.0原始模型分布默认配置
2.0平坦分布,高多样性创意写作、故事生成

DeepSeek 官方推荐:

场景温度
代码生成 / 数学解题0.0
数据抽取 / 分析1.0
通用对话1.3
翻译1.3
创意写作 / 诗歌创作1.5

2.5 温度对梯度的影响

P(yiτ)zi=P(yiτ)(1P(yiτ))τ

温度越高,梯度越小,模型更新越温和。

2.6 可视化

下图展示了不同温度下概率分布的变化。左图为按 Token 索引排列的原始视图,右图为按概率降序排列的视图。运行以下绘图代码可生成图片。

绘图代码如下:

python
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax

np.random.seed(42)
logits = np.random.normal(0, 1, 50)
temperatures = [0.5, 1.0, 1.5]

# 原始视图
plt.figure(figsize=(12, 8))
original = softmax(logits)
plt.plot(original, "k--", alpha=0.5, label="Original (T=1.0)")
for temp in temperatures:
    probs = softmax(logits / temp)
    plt.plot(probs, "o-", label=f"T={temp}", alpha=0.7)
plt.title("Effect of Temperature on Probability Distribution")
plt.xlabel("Token Index")
plt.ylabel("Probability")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("temperature_effect_raw.png", dpi=300, bbox_inches="tight")
plt.close()

# 排序视图
plt.figure(figsize=(12, 8))
for temp in temperatures:
    probs = np.sort(softmax(logits / temp))[::-1]
    plt.plot(probs, "o-", label=f"T={temp}", alpha=0.7)
plt.title("Sorted Probability Distributions at Different Temperatures")
plt.xlabel("Rank (0 = highest)")
plt.ylabel("Probability (log scale)")
plt.yscale("log")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("temperature_effect_sorted.png", dpi=300, bbox_inches="tight")
plt.close()

3. top_k

top_k 限制采样池只保留概率最高的 k 个 Token,然后在该子集上重新归一化并采样。

  • k 较小:输出更可控,但多样性受限。
  • k 较大:候选更丰富,但可能引入低质量 Token。
  • 典型值:20–100。

4. top_p

top_p(核采样)动态选择累积概率超过 p 的最小 Token 集合,再在该集合内采样。

优势:

  • 自适应候选集大小。

  • 避免低概率 Token,同时保留足够多样性。

  • p1:接近完整词汇表采样,随机性高。

  • p0:仅保留极少数高概率 Token,输出更确定。

5. min_p

min_p 通过相对阈值排除低概率 Token:只保留概率大于 min_p × 最高概率 Token 的候选。它是对 Top-p 的补充,适合需要精细控制多样性的场景。

例如 min_p=0.1 会排除概率不到最高概率 Token 10% 的所有选项。

6. 参数协同工作机制

temperature=0.8, top_k=50, top_p=0.9, min_p=0.01 为例,执行流程如下:

  1. 使用温度缩放 logits。
  2. 保留 Top-K 候选 Token。
  3. 从中筛选累积概率 0.9 且相对概率 \ge \text{min_p} 的 Token。
  4. 重新 Softmax 归一化。
  5. 采样得到下一个 Token。

默认无约束配置为:temperature=1.0, top_k=-1, top_p=1.0, min_p=None

7. Transformers 生成参数速查

7.1 控制输出长度

  • max_lengthint,默认 20):生成 Token 的最大总长度(输入 + 输出)。若同时设置 max_new_tokens,则后者覆盖该值。
  • max_new_tokensint):生成的最大新 Token 数。
  • min_length / min_new_tokens:类似地控制最小长度。
  • early_stoppingboolstr,默认 False):控制束搜索停止条件。 "never" 表示经典束搜索。
  • max_timefloat):允许生成的最大运行时间(秒)。
  • stop_stringsstrList[str]):命中时停止生成。

7.2 控制生成策略

  • do_samplebool,默认 False):是否使用采样;否则使用贪婪解码。
  • num_beamsint,默认 1):束数量,大于 1 时启用束搜索。
  • num_beam_groupsint,默认 1):多样性束搜索的分组数。
  • penalty_alphafloat):对比搜索中平衡置信度与退化惩罚的值。
  • dola_layersstrList[int]):DoLa 解码使用的层。

7.3 控制缓存

  • use_cachebool,默认 True):是否使用 KV Cache 加速解码。
  • cache_implementationstr,默认 None):指定缓存类,如 "dynamic""static""offloaded_static""sliding_window""hybrid""quantized" 等。
  • cache_configCacheConfigdict):缓存类参数。
  • return_legacy_cachebool,默认 True):是否返回旧格式缓存。

7.4 控制 logits 与采样

  • temperaturefloat,默认 1.0)
  • top_kint,默认 50)
  • top_pfloat,默认 1.0)
  • min_pfloat):最小相对概率阈值。
  • typical_pfloat,默认 1.0):典型性采样阈值。
  • epsilon_cutoff / eta_cutofffloat,默认 0.0):截断采样阈值。
  • repetition_penaltyfloat,默认 1.0):重复惩罚。
  • diversity_penaltyfloat,默认 0.0):多样性束搜索惩罚。
  • no_repeat_ngram_sizeint,默认 0):禁止重复的 n-gram 长度。
  • bad_words_ids / force_words_ids:禁用或强制生成的 Token ID。
  • renormalize_logitsbool,默认 False):是否在 logits 处理器后重新归一化。

7.5 特殊 Token 与输出控制

  • pad_token_idbos_token_ideos_token_id:分别指定填充、序列开始、序列结束 Token。
  • num_return_sequencesint,默认 1):每个输入返回的序列数。
  • output_attentionsoutput_hidden_statesoutput_scoresoutput_logitsreturn_dict_in_generate:控制额外输出。

7.6 辅助生成

  • assistant_model:推测解码的辅助模型。
  • num_assistant_tokens:每次辅助模型生成的候选 Token 数。
  • prompt_lookup_num_tokensmax_matching_ngram_size:Prompt Lookup Decoding 参数。
  • assistant_early_exit:Self-speculative Decoding 的早期退出层。

7.7 编译相关

  • compile_configCompileConfig):静态缓存下的编译配置。
  • disable_compilebool):是否禁用前向编译优化。

参考

Maintained by Robin