RLHF 中 KL 散度的近似计算
本文探讨在无法解析计算的情况下,利用蒙特卡洛方法对 KL 散度进行近似估计的技术。我们重点分析一种在实际代码中广泛使用的技巧:使用
1. KL 散度的定义
KL 散度(Kullback-Leibler Divergence)衡量两个概率分布
当
2. 蒙特卡洛估计的动机
在许多实际应用中,我们虽然能够计算任意给定点
- 计算复杂度过高:精确计算涉及高维积分或大规模求和,计算成本过高。
- 缺乏闭式解:某些分布组合下,KL 散度无解析表达式。
- 工程实现简化:在某些场景(如强化学习中的策略优化),KL 散度仅作为诊断指标,只需通过采样估计即可,无需精确值。
因此,蒙特卡洛方法成为一种自然的选择。
3. 常见的 KL 散度估计量
假设我们可以从分布
- 无偏性:估计量的期望等于真实 KL 值。
- 低方差:估计结果稳定,减少采样噪声的影响。
3.1 标准估计量
最直接的无偏估计量是基于 KL 散度定义本身:
该估计量满足
3.2 有偏但低方差估计量
另一种常用估计量为:
该估计量始终非负,且实验表明其方差显著低于
4. 为什么 是一个良好的近似?
关键在于
其中
对应 对应
对于估计量
一个重要结论是:当两个分布
其中
5. 构造无偏且低方差的估计量
能否构造一个既无偏又低方差的 KL 散度估计量?一种通用方差缩减技术是 控制变量法(Control Variates)。其思想是:从原始估计量中减去一个期望为零但与之负相关的辅助项。
已知
由于
一个简单而有效的选择是
该估计量具有以下优点:
- 无偏性:
- 非负性:由不等式
可知 - 几何解释:
是凸函数 与其在 处切线 之间的垂直距离,属于 Bregman 散度 的特例。
6. 推广到其他 -散度
上述构造方法可推广至任意
该式表示凸函数
对于
, , ,因此估计量为: 对于
, , ,因此估计量为:
7. 实验比较
我们通过数值实验比较三种 torch.distributions.kl_divergence 计算。
7.1 实验设置 1:小 KL 散度( ,真实 KL = 0.005)
| 估计量 | 相对偏差(偏差 / 真实值) | 相对标准差(标准差 / 真实值) |
|---|---|---|
| 0(无偏) | 20 | |
| 0.002(0.2%) | 1.42 | |
| 0(无偏) | 1.42 |
结论:当
7.2 实验设置 2:较大 KL 散度( ,真实 KL = 0.5)
| 估计量 | 相对偏差(偏差 / 真实值) | 相对标准差(标准差 / 真实值) |
|---|---|---|
| 0(无偏) | 2 | |
| 0.25(25%) | 1.73 | |
| 0(无偏) | 1.7 |
结论:随着
8. 代码实现
import torch.distributions as dis
# 定义分布
p = dis.Normal(loc=0.0, scale=1.0)
q = dis.Normal(loc=0.1, scale=1.0) # 可替换为 loc=1.0
# 采样
x = q.sample(sample_shape=(10_000_000,))
# 真实KL散度(用于比较)
true_kl = dis.kl_divergence(q, p) # 注意:PyTorch中是 KL[q, p]
print("真实KL散度:", true_kl.item())
# 计算比率对数
log_r = p.log_prob(x) - q.log_prob(x) # log(p(x)/q(x))
# 三种估计量
k1 = -log_r # 标准估计量
k2 = 0.5 * (log_r ** 2) # 有偏低方差估计量
k3 = (log_r.exp() - 1) - log_r # 无偏低方差估计量 (r - 1 - log r)
# 输出相对偏差和相对标准差
for name, k in zip(['k1', 'k2', 'k3'], [k1, k2, k3]):
bias = (k.mean() - true_kl) / true_kl
std_ratio = k.std() / true_kl
print(f"{name}: 相对偏差 = {bias:.4f}, 相对标准差 = {std_ratio:.4f}")总结
本文系统分析了 KL 散度的蒙特卡洛估计方法,得出以下结论:
在 时偏差极小且方差低,适合微小更新的场景(如策略梯度方法中的 KL 约束)。 是一个无偏、非负、低方差的估计量,适用于一般情况,是 和 的严格改进。 - 该构造思想可推广至任意
-散度,通过凸函数与其切线的差距构建良好的估计量。
在实际应用中,推荐优先使用
参考文献:
- Joschu's Blog: KL Approximation
- 知乎专栏:KL 散度的近似与估计