重参数化与强化学习
1. 为什么需要重参数化?随机函数的梯度计算问题
在机器学习领域,目标函数常常需要计算如下形式的随机函数:
其梯度为:
上式中的第一项可以使用蒙特卡洛法近似求解:
但第二项无法使用朴素蒙特卡洛法来近似,因为它是关于一个概率分布的梯度,而蒙特卡洛法只能计算某个函数关于采样的概率分布的期望。这里一般有两种主流办法来处理它。
2. 得分函数估计 Score Function Estimator
得分函数(score function)是对数似然函数关于概率分布参数的梯度,即
通过使用对数导数技巧(log-derivative trick),可以将概率密度函数的梯度转化为概率密度函数与得分函数的乘积:
上文中第二项
3. 重参数化技巧 Reparameterization Trick
实践中,得分函数估计带来的计算方差往往太大。为了解决这个问题,可以使用重参数化方法。该方法被广泛应用到例如变分自编码器(VAE)这样的经典机器学习模型中。
3.1 重参数化怎么做?具体原理详解
对于高斯分布这类可参数化的标准分布,其分布参数与采样过程是可以独立分离的。它可以被分离为一组独立的参数与一个新的随机变量,分离后各个参数与新的随机变量无关。因而新的目标函数可以表示为新随机变量分布的期望,这种分离方法称为重参数化技巧。
3.2 高斯分布的参数化
这里以常用的高斯分布为例解释重参数化。假设随机变量
我们可以将这个高斯分布进一步拆解为来自于一个标准高斯分布
式中
通过这种参数化方法,可以使用有限个自然参数来近似表征复杂连续的概率分布。在上述例子中,对于这个高斯分布,仅需要
当然,我们可以根据数据分布和需要,对概率分布进行多种形式的分解,比如 Gamma 分布、Beta 分布以及 Student-
具体的实用例子:在 VAE 中,假设算法中的隐变量服从高斯分布
从而使目标函数从原来的
3.3 使用条件
重参数化法的使用条件是:
- 目标函数
对参数 和 都可微分; - 分布
的采样可以从某个独立分布 获得,然后使用一组参数 将其转化为分布 ,即 。
3.4 重参数化梯度
将重参数化后的形式代入,原先的随机目标函数可以变为:
那么我们就可以计算其梯度:
这个梯度称为重参数化梯度(reparameterization gradient / pathwise derivative)。
4. 为什么要用重参数化?两类优势
4.1 优势一:穿透随机性的求导
我们可以把概率分布重参数化后的参数理解为概率分布的“坐标”,因而可以把目标函数对于概率分布的梯度,转移到目标函数对这些坐标分量的梯度计算上去。对这些自然参数的求导有着明确的数学含义,比如
通过这种方法,可以近似认为:计算某个函数对该分布的导数,等价于计算该函数对该概率分布自然参数的导数。而在很多机器学习算法的实际使用中,假如最终的目标函数与某个概率分布有关,那么对目标函数的梯度就可以一直向后传递至对该概率分布的自然参数中,而不受随机性带来的影响。由此,可以便捷地构建基于梯度下降优化的各种机器学习模型(例如 VAE),从而解决一系列实用性问题。

(图 1:重参数化求导的原理解释图)
4.2 优势二:降低方差
通过重参数化,可以解决使用 Score Function Method 进行求取梯度时可能导致的高方差问题 [1]。具体来说,分别展开两种方法的梯度公式。在 Score Function Method 中:
在 Reparameterization Method 中:
对比上述两式,Score Function Method 会额外引入:
当
而重参数化之所以可以避免这一项,是因为使用了独立于分布参数的采样,故此项梯度为 0。
(不过,我们可以构造某些特殊的函数,让上式中的各项产生相关性,从而使得“重参数化法的梯度方差比 Score Function Method 更小”这个命题不成立。但在统计意义上,即实际应用中的绝大多数情况,使用重参数化可以有效降低方差。)
此外,在实践中,较低的采样样本数会带来较大的方差,但较高的采样样本数又会降低计算效率。因此,如果使用重参数化技巧直接对目标函数进行求导,降低梯度的方差,就可以让我们使用更少的采样样本,从而也对提升训练效率有帮助。
5. 强化学习中的重参数化:SAC 实例
在强化学习 SAC 算法中 [2],策略通过 SAC 中定义的 soft value function(即引入最大熵的价值函数)来指导策略优化。具体来说,SAC 算法希望找到可以最大化 soft value function 的策略,即:
在实现中,首先需要策略部分(Actor)输出当前状态对应的动作,然后交给价值函数(Critic)判断好坏得到目标函数,并用梯度下降进行优化。梯度将会从 Critic 沿着动作一路回传到 Actor。
在连续动作空间上,需要对策略使用高斯分布进行重参数化,并使用
需要注意的是,这里需要一个额外的矫正操作,因为使用
其中
所以代码实现中,我们需要修改压缩后的概率密度函数的 log prob 为:
考虑 SAC 算法中使用的 double Q-trick,最终重参数化后的策略学习目标函数为:
对于一个 batch 为
6. PPO 与重参数化的关系
回顾 PPO 的目标函数:
可以发现,虽然公式中存在着

(图 2:PPO 策略部分优化计算图,只有蓝色部分回传梯度,紫色部分只是参与计算。)
参考文献
- Xu M, Quiroz M, Kohn R, et al. Variance reduction properties of the reparameterization trick[C]//The 22nd International Conference on Artificial Intelligence and Statistics. PMLR, 2019: 2711-2720.
- Haarnoja T, Zhou A, Abbeel P, et al. Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor[C]//International Conference on Machine Learning. PMLR, 2018: 1861-1870.