Skip to content
Robin's Blog
Go back

KL散度的近似计算方法

Edit page

本文探讨在无法解析计算的情况下,利用蒙特卡洛方法对KL散度进行近似估计的技术。我们重点分析一种在实际代码中广泛使用的技巧:使用 12(logp(x)logq(x))2\frac{1}{2}(\log p(x) - \log q(x))^2 的样本均值来估计 KL[q,p]KL[q, p],而非标准的 logq(x)p(x)\log \frac{q(x)}{p(x)} 形式。我们将解释该估计量为何是一个良好(尽管有偏)的近似,并进一步讨论如何构造一个无偏且低方差的替代估计量。

KL散度的定义

KL散度(Kullback-Leibler Divergence)衡量两个概率分布 q(x)q(x)p(x)p(x) 之间的差异,其定义如下:

KL[q,p]=Exq[logq(x)p(x)]=q(x)logq(x)p(x)dxKL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \int q(x) \log \frac{q(x)}{p(x)} \, dx

qqpp 接近时,KL散度趋近于零;当两者差异较大时,KL散度增大。KL散度非负且不对称,即 KL[q,p]KL[p,q]KL[q, p] \neq KL[p, q]

蒙特卡洛估计的动机

在许多实际应用中,我们虽然能够计算任意给定点 xx 的概率密度值 p(x)p(x)q(x)q(x),但无法对整个空间进行解析积分或求和。这通常由以下原因导致:

  1. 计算复杂度过高:精确计算涉及高维积分或大规模求和,计算成本过高。
  2. 缺乏闭式解:某些分布组合下,KL散度无解析表达式。
  3. 工程实现简化:在某些场景(如强化学习中的策略优化),KL散度仅作为诊断指标,只需通过采样估计即可,无需精确值。

因此,蒙特卡洛方法成为一种自然的选择。

常见的KL散度估计量

假设我们可以从分布 qq 中独立采样得到样本 {x1,x2,,xN}\{x_1, x_2, \dots, x_N\},目标是构造 KL[q,p]KL[q, p] 的估计量。一个理想的估计量应具备以下特性:

1. 标准估计量 k1k_1

最直接的无偏估计量是基于KL散度定义本身:

k1=logq(x)p(x)=logr,其中 r=p(x)q(x)k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中 } r = \frac{p(x)}{q(x)}

该估计量满足 Eq[k1]=KL[q,p]\mathbb{E}_q[k_1] = KL[q, p],但其方差较高。这是因为当 p(x)>q(x)p(x) > q(x) 时,logq(x)p(x)<0\log \frac{q(x)}{p(x)} < 0,而KL散度本身始终非负,导致估计值在零附近剧烈波动。

2. 有偏但低方差估计量 k2k_2

另一种常用估计量为:

k2=12(logr)2=12(logp(x)q(x))2k_2 = \frac{1}{2} (\log r)^2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2

该估计量始终非负,且实验表明其方差显著低于 k1k_1。尽管 Eq[k2]KL[q,p]\mathbb{E}_q[k_2] \neq KL[q, p],但其偏差在 pqp \approx q 时非常小。下面我们解释其背后的理论依据。

为什么 k2k_2 是一个良好的近似?

关键在于 k2k_2 的期望属于一类称为 ff-散度ff-divergence)的广义距离度量。ff-散度定义为:

Df(p,q)=Exq[f(p(x)q(x))]D_f(p, q) = \mathbb{E}_{x \sim q} \left[ f\left( \frac{p(x)}{q(x)} \right) \right]

其中 f:R+Rf: \mathbb{R}^+ \to \mathbb{R} 是一个凸函数,且 f(1)=0f(1) = 0。KL散度本身是 ff-散度的特例:

对于估计量 k2k_2,其对应的 ff 函数为 f(x)=12(logx)2f(x) = \frac{1}{2} (\log x)^2

一个重要结论是:当两个分布 ppqq 接近时,所有具有可微 ffff-散度在二阶泰勒展开下都近似于KL散度。具体地,考虑参数化分布族 pθp_\theta,在 θ=0\theta = 0 附近展开:

Df(p0,pθ)=f(1)2θFθ+O(θ3)D_f(p_0, p_\theta) = \frac{f''(1)}{2} \theta^\top F \theta + O(\|\theta\|^3)

其中 FF 是Fisher信息矩阵。对于 f(x)=logxf(x) = -\log xf(x)=12(logx)2f(x) = \frac{1}{2} (\log x)^2,均有 f(1)=1f''(1) = 1,因此当 pqp \approx q 时,KL[q,p]KL[q, p]Eq[k2]\mathbb{E}_q[k_2] 在局部具有相同的二次结构,偏差较小。

构造无偏且低方差的估计量

能否构造一个既无偏又低方差的KL散度估计量?一种通用方差缩减技术是 控制变量法(Control Variates)。其思想是:从原始估计量中减去一个期望为零但与之负相关的辅助项。

已知 Eq[r]=Eq[p(x)q(x)]=1\mathbb{E}_q[r] = \mathbb{E}_q\left[\frac{p(x)}{q(x)}\right] = 1,因此 r1r - 1 是一个自然的控制变量。考虑如下形式的估计量:

kλ=logr+λ(r1)k_\lambda = -\log r + \lambda (r - 1)

由于 Eq[r1]=0\mathbb{E}_q[r - 1] = 0,故 Eq[kλ]=KL[q,p]\mathbb{E}_q[k_\lambda] = KL[q, p],即对任意 λ\lambdakλk_\lambda 都是无偏的。最优 λ\lambda 可通过最小化方差求得,但其表达式依赖于 ppqq 的具体形式,通常难以解析计算。

一个简单而有效的选择是 λ=1\lambda = 1,此时:

k3=(r1)logrk_3 = (r - 1) - \log r

该估计量具有以下优点:

推广到其他 ff-散度

上述构造方法可推广至任意 ff-散度。设 ff 为凸函数,且 f(1)=0f(1) = 0,则以下表达式是 Df(p,q)D_f(p, q) 的一个无偏、非负估计量:

D^f(p,q)=f(r)f(1)(r1)\hat{D}_f(p, q) = f(r) - f'(1)(r - 1)

该式表示凸函数 f(x)f(x) 与其在 x=1x=1 处切线之间的差距。

实验比较

我们通过数值实验比较三种 KL[q,p]KL[q, p] 估计量的性能。设 q=N(0,1)q = \mathcal{N}(0, 1),真实KL散度通过 torch.distributions.kl_divergence 计算。

实验设置1:小KL散度(p=N(0.1,1)p = \mathcal{N}(0.1, 1),真实KL = 0.005)

估计量相对偏差(偏差 / 真实值)相对标准差(标准差 / 真实值)
k1=logrk_1 = -\log r0(无偏)20
k2=12(logr)2k_2 = \frac{1}{2}(\log r)^20.002(0.2%)1.42
k3=(r1)logrk_3 = (r - 1) - \log r0(无偏)1.42

结论:当 pqp \approx q 时,k2k_2 的偏差极小,且方差显著低于 k1k_1

实验设置2:较大KL散度(p=N(1,1)p = \mathcal{N}(1, 1),真实KL = 0.5)

估计量相对偏差(偏差 / 真实值)相对标准差(标准差 / 真实值)
k1=logrk_1 = -\log r0(无偏)2
k2=12(logr)2k_2 = \frac{1}{2}(\log r)^20.25(25%)1.73
k3=(r1)logrk_3 = (r - 1) - \log r0(无偏)1.7

结论:随着 ppqq 差异增大,k2k_2 的偏差显著上升,而 k3k_3 在保持无偏的同时方差更低,是更优的选择。

代码实现

import torch.distributions as dis

# 定义分布
p = dis.Normal(loc=0.0, scale=1.0)
q = dis.Normal(loc=0.1, scale=1.0)  # 可替换为 loc=1.0

# 采样
x = q.sample(sample_shape=(10_000_000,))

# 真实KL散度(用于比较)
true_kl = dis.kl_divergence(q, p)  # 注意:PyTorch中是 KL[q, p]
print("真实KL散度:", true_kl.item())

# 计算比率对数
log_r = p.log_prob(x) - q.log_prob(x)  # log(p(x)/q(x))

# 三种估计量
k1 = -log_r  # 标准估计量
k2 = 0.5 * (log_r ** 2)  # 有偏低方差估计量
k3 = (log_r.exp() - 1) - log_r  # 无偏低方差估计量 (r - 1 - log r)

# 输出相对偏差和相对标准差
for name, k in zip(['k1', 'k2', 'k3'], [k1, k2, k3]):
    bias = (k.mean() - true_kl) / true_kl
    std_ratio = k.std() / true_kl
    print(f"{name}: 相对偏差 = {bias:.4f}, 相对标准差 = {std_ratio:.4f}")

总结

本文系统分析了KL散度的蒙特卡洛估计方法,得出以下结论:

  1. k2=12(logr)2k_2 = \frac{1}{2}(\log r)^2pqp \approx q 时偏差极小且方差低,适合微小更新的场景(如策略梯度方法中的KL约束)。
  2. k3=(r1)logrk_3 = (r - 1) - \log r 是一个无偏、非负、低方差的估计量,适用于一般情况,是 k1k_1k2k_2 的严格改进。
  3. 该构造思想可推广至任意 ff-散度,通过凸函数与其切线的差距构建良好的估计量。

在实际应用中,推荐优先使用 k3k_3 作为KL散度的估计量,尤其在需要无偏性和稳定性时。


参考文献


Edit page
Share this post:

Previous Post
从 Online Softmax 到 FlashAttention
Next Post
Trackio:Hugging Face 开源的免费实验追踪库,wandb 的即插即用替代方案