KL散度的近似计算方法
本文探讨在无法解析计算的情况下,利用蒙特卡洛方法对KL散度进行近似估计的技术。我们重点分析一种在实际代码中广泛使用的技巧:使用 $\frac{1}{2}(\log p(x) - \log q(x))^2$ 的样本均值来估计 $KL[q, p]$,而非标准的 $\log \frac{q(x)}{p(x)}$ 形式。我们将解释该估计量为何是一个良好(尽管有偏)的近似,并进一步讨论如何构造一个无偏且低方差的替代估计量。
KL散度的定义
KL散度(Kullback-Leibler Divergence)衡量两个概率分布 $q(x)$ 和 $p(x)$ 之间的差异,其定义如下:
$$ KL[q, p] = \mathbb{E}_{x \sim q} \left[ \log \frac{q(x)}{p(x)} \right] = \int q(x) \log \frac{q(x)}{p(x)} , dx $$
当 $q$ 与 $p$ 接近时,KL散度趋近于零;当两者差异较大时,KL散度增大。KL散度非负且不对称,即 $KL[q, p] \neq KL[p, q]$。
蒙特卡洛估计的动机
在许多实际应用中,我们虽然能够计算任意给定点 $x$ 的概率密度值 $p(x)$ 和 $q(x)$,但无法对整个空间进行解析积分或求和。这通常由以下原因导致:
- 计算复杂度过高:精确计算涉及高维积分或大规模求和,计算成本过高。
- 缺乏闭式解:某些分布组合下,KL散度无解析表达式。
- 工程实现简化:在某些场景(如强化学习中的策略优化),KL散度仅作为诊断指标,只需通过采样估计即可,无需精确值。
因此,蒙特卡洛方法成为一种自然的选择。
常见的KL散度估计量
假设我们可以从分布 $q$ 中独立采样得到样本 ${x_1, x_2, \dots, x_N}$,目标是构造 $KL[q, p]$ 的估计量。一个理想的估计量应具备以下特性:
- 无偏性:估计量的期望等于真实KL值。
- 低方差:估计结果稳定,减少采样噪声的影响。
1. 标准估计量 $k_1$
最直接的无偏估计量是基于KL散度定义本身:
$$ k_1 = \log \frac{q(x)}{p(x)} = -\log r, \quad \text{其中 } r = \frac{p(x)}{q(x)} $$
该估计量满足 $\mathbb{E}_q[k_1] = KL[q, p]$,但其方差较高。这是因为当 $p(x) > q(x)$ 时,$\log \frac{q(x)}{p(x)} < 0$,而KL散度本身始终非负,导致估计值在零附近剧烈波动。
2. 有偏但低方差估计量 $k_2$
另一种常用估计量为:
$$ k_2 = \frac{1}{2} (\log r)^2 = \frac{1}{2} \left( \log \frac{p(x)}{q(x)} \right)^2 $$
该估计量始终非负,且实验表明其方差显著低于 $k_1$。尽管 $\mathbb{E}_q[k_2] \neq KL[q, p]$,但其偏差在 $p \approx q$ 时非常小。下面我们解释其背后的理论依据。
为什么 $k_2$ 是一个良好的近似?
关键在于 $k_2$ 的期望属于一类称为 $f$-散度($f$-divergence)的广义距离度量。$f$-散度定义为:
$$ D_f(p, q) = \mathbb{E}_{x \sim q} \left[ f\left( \frac{p(x)}{q(x)} \right) \right] $$
其中 $f: \mathbb{R}^+ \to \mathbb{R}$ 是一个凸函数,且 $f(1) = 0$。KL散度本身是 $f$-散度的特例:
- $KL[q, p]$ 对应 $f(x) = -\log x$
- $KL[p, q]$ 对应 $f(x) = x \log x$
对于估计量 $k_2$,其对应的 $f$ 函数为 $f(x) = \frac{1}{2} (\log x)^2$。
一个重要结论是:当两个分布 $p$ 和 $q$ 接近时,所有具有可微 $f$ 的 $f$-散度在二阶泰勒展开下都近似于KL散度。具体地,考虑参数化分布族 $p_\theta$,在 $\theta = 0$ 附近展开:
$$ D_f(p_0, p_\theta) = \frac{f''(1)}{2} \theta^\top F \theta + O(|\theta|^3) $$
其中 $F$ 是Fisher信息矩阵。对于 $f(x) = -\log x$ 和 $f(x) = \frac{1}{2} (\log x)^2$,均有 $f''(1) = 1$,因此当 $p \approx q$ 时,$KL[q, p]$ 与 $\mathbb{E}_q[k_2]$ 在局部具有相同的二次结构,偏差较小。
构造无偏且低方差的估计量
能否构造一个既无偏又低方差的KL散度估计量?一种通用方差缩减技术是 控制变量法(Control Variates)。其思想是:从原始估计量中减去一个期望为零但与之负相关的辅助项。
已知 $\mathbb{E}_q[r] = \mathbb{E}_q\left[\frac{p(x)}{q(x)}\right] = 1$,因此 $r - 1$ 是一个自然的控制变量。考虑如下形式的估计量:
$$ k_\lambda = -\log r + \lambda (r - 1) $$
由于 $\mathbb{E}q[r - 1] = 0$,故 $\mathbb{E}q[k\lambda] = KL[q, p]$,即对任意 $\lambda$,$k\lambda$ 都是无偏的。最优 $\lambda$ 可通过最小化方差求得,但其表达式依赖于 $p$ 和 $q$ 的具体形式,通常难以解析计算。
一个简单而有效的选择是 $\lambda = 1$,此时:
$$ k_3 = (r - 1) - \log r $$
该估计量具有以下优点:
- 无偏性:$\mathbb{E}_q[k_3] = KL[q, p]$
- 非负性:由不等式 $\log x \leq x - 1$ 可知 $k_3 \geq 0$
- 几何解释:$k_3$ 是凸函数 $-\log x$ 与其在 $x=1$ 处切线 $-(x-1)$ 之间的垂直距离,属于 Bregman散度 的特例。
推广到其他 $f$-散度
上述构造方法可推广至任意 $f$-散度。设 $f$ 为凸函数,且 $f(1) = 0$,则以下表达式是 $D_f(p, q)$ 的一个无偏、非负估计量:
$$ \hat{D}_f(p, q) = f(r) - f'(1)(r - 1) $$
该式表示凸函数 $f(x)$ 与其在 $x=1$ 处切线之间的差距。
对于 $KL[p, q]$,$f(x) = x \log x$,$f'(1) = 1$,因此估计量为:
$$ r \log r - (r - 1) $$
对于 $KL[q, p]$,$f(x) = -\log x$,$f'(1) = -1$,因此估计量为:
$$ (r - 1) - \log r \quad (\text{即 } k_3) $$
实验比较
我们通过数值实验比较三种 $KL[q, p]$ 估计量的性能。设 $q = \mathcal{N}(0, 1)$,真实KL散度通过 torch.distributions.kl_divergence 计算。
实验设置1:小KL散度($p = \mathcal{N}(0.1, 1)$,真实KL = 0.005)
| 估计量 | 相对偏差(偏差 / 真实值) | 相对标准差(标准差 / 真实值) |
|---|---|---|
| $k_1 = -\log r$ | 0(无偏) | 20 |
| $k_2 = \frac{1}{2}(\log r)^2$ | 0.002(0.2%) | 1.42 |
| $k_3 = (r - 1) - \log r$ | 0(无偏) | 1.42 |
结论:当 $p \approx q$ 时,$k_2$ 的偏差极小,且方差显著低于 $k_1$。
实验设置2:较大KL散度($p = \mathcal{N}(1, 1)$,真实KL = 0.5)
| 估计量 | 相对偏差(偏差 / 真实值) | 相对标准差(标准差 / 真实值) |
|---|---|---|
| $k_1 = -\log r$ | 0(无偏) | 2 |
| $k_2 = \frac{1}{2}(\log r)^2$ | 0.25(25%) | 1.73 |
| $k_3 = (r - 1) - \log r$ | 0(无偏) | 1.7 |
结论:随着 $p$ 与 $q$ 差异增大,$k_2$ 的偏差显著上升,而 $k_3$ 在保持无偏的同时方差更低,是更优的选择。
代码实现
import torch.distributions as dis
# 定义分布
p = dis.Normal(loc=0.0, scale=1.0)
q = dis.Normal(loc=0.1, scale=1.0) # 可替换为 loc=1.0
# 采样
x = q.sample(sample_shape=(10_000_000,))
# 真实KL散度(用于比较)
true_kl = dis.kl_divergence(q, p) # 注意:PyTorch中是 KL[q, p]
print("真实KL散度:", true_kl.item())
# 计算比率对数
log_r = p.log_prob(x) - q.log_prob(x) # log(p(x)/q(x))
# 三种估计量
k1 = -log_r # 标准估计量
k2 = 0.5 * (log_r ** 2) # 有偏低方差估计量
k3 = (log_r.exp() - 1) - log_r # 无偏低方差估计量 (r - 1 - log r)
# 输出相对偏差和相对标准差
for name, k in zip(['k1', 'k2', 'k3'], [k1, k2, k3]):
bias = (k.mean() - true_kl) / true_kl
std_ratio = k.std() / true_kl
print(f"{name}: 相对偏差 = {bias:.4f}, 相对标准差 = {std_ratio:.4f}")总结
本文系统分析了KL散度的蒙特卡洛估计方法,得出以下结论:
- $k_2 = \frac{1}{2}(\log r)^2$ 在 $p \approx q$ 时偏差极小且方差低,适合微小更新的场景(如策略梯度方法中的KL约束)。
- $k_3 = (r - 1) - \log r$ 是一个无偏、非负、低方差的估计量,适用于一般情况,是 $k_1$ 和 $k_2$ 的严格改进。
- 该构造思想可推广至任意 $f$-散度,通过凸函数与其切线的差距构建良好的估计量。
在实际应用中,推荐优先使用 $k_3$ 作为KL散度的估计量,尤其在需要无偏性和稳定性时。
参考文献:
- Joschu's Blog: KL Approximation
- 知乎专栏: KL散度的近似与估计