Skip to content

重参数化与强化学习

1. 为什么需要重参数化?随机函数的梯度计算问题

在机器学习领域,目标函数常常需要计算如下形式的随机函数:

L(θ)=Eqθ(z)[l(θ,z)]=l(θ,z)qθ(z)dz.

其梯度为:

θL(θ)=θEqθ(z)[l(θ,z)]=θl(θ,z)qθ(z)dz=θl(θ,z)qθ(z)dz+l(θ,z)θqθ(z)dz=I1+I2.

上式中的第一项可以使用蒙特卡洛法近似求解:

I1=θl(θ,z)qθ(z)dz1Ss=1Sθl(θ,zs),zsqθ.

但第二项无法使用朴素蒙特卡洛法来近似,因为它是关于一个概率分布的梯度,而蒙特卡洛法只能计算某个函数关于采样的概率分布的期望。这里一般有两种主流办法来处理它。

2. 得分函数估计 Score Function Estimator

得分函数(score function)是对数似然函数关于概率分布参数的梯度,即

sθ(x)=θlogqθ(x).

通过使用对数导数技巧(log-derivative trick),可以将概率密度函数的梯度转化为概率密度函数与得分函数的乘积:

θqθ(z)=qθ(z)θqθ(z)qθ(z)=qθ(z)θlogqθ(z).

上文中第二项 I2 可以写成:

I2=l(θ,z)θqθ(z)dz=l(θ,z)qθ(z)θlogqθ(z)dz=Eqθ(z)[l(θ,z)θlogqθ(z)]1Ss=1Sl(θ,zs)θlogqθ(zs),zsqθ.

3. 重参数化技巧 Reparameterization Trick

实践中,得分函数估计带来的计算方差往往太大。为了解决这个问题,可以使用重参数化方法。该方法被广泛应用到例如变分自编码器(VAE)这样的经典机器学习模型中。

3.1 重参数化怎么做?具体原理详解

对于高斯分布这类可参数化的标准分布,其分布参数与采样过程是可以独立分离的。它可以被分离为一组独立的参数与一个新的随机变量,分离后各个参数与新的随机变量无关。因而新的目标函数可以表示为新随机变量分布的期望,这种分离方法称为重参数化技巧。

3.2 高斯分布的参数化

这里以常用的高斯分布为例解释重参数化。假设随机变量 xRD 服从高斯分布:

xN(μ,Σ).

我们可以将这个高斯分布进一步拆解为来自于一个标准高斯分布 ϵN(0,I) 的线性变换:

x=μ+LϵN(μ,LLT),

式中 L 是一个下三角方阵。当进一步简化假设各维度变量独立分布时,可以将 L 退化为一个对角矩阵:

L=[σ11σ21σ22σD1σD2σDD][σ11σ22σDD]=diag(σ).

通过这种参数化方法,可以使用有限个自然参数来近似表征复杂连续的概率分布。在上述例子中,对于这个高斯分布,仅需要 2D 个参数来描述即可。

当然,我们可以根据数据分布和需要,对概率分布进行多种形式的分解,比如 Gamma 分布、Beta 分布以及 Student-t 分布等等。这样可以把概率分布簇的参数与概率分布自身的随机性拆分和剥离,从而有益于计算和求导。

具体的实用例子:在 VAE 中,假设算法中的隐变量服从高斯分布 zqϕ(zx),那么这个隐变量可以被重参数化表达为

z=μϕ(x)+ϵσϕ(x),

从而使目标函数从原来的 Eqϕ(zx)[f(z)] 变为 Eq(ϵ)[f(z)]

3.3 使用条件

重参数化法的使用条件是:

  • 目标函数 l(θ,z) 对参数 zθ 都可微分;
  • 分布 z 的采样可以从某个独立分布 ϵ 获得,然后使用一组参数 θ 将其转化为分布 z,即 z=r(θ,ϵ)

3.4 重参数化梯度

将重参数化后的形式代入,原先的随机目标函数可以变为:

L(θ)=Eqθ(z)[l(θ,z)]=l(θ,z)qθ(z)dz=l(θ,r(θ,ϵ))q(ϵ)dϵ=Eq(ϵ)[l(θ,r(θ,ϵ))].

那么我们就可以计算其梯度:

θL(θ)=θEq(ϵ)[l(θ,r(θ,ϵ))]=Eq(ϵ)[θl(θ,r(θ,ϵ))]1Ss=1Sθl(θ,r(θ,ϵs)),ϵsq.

这个梯度称为重参数化梯度(reparameterization gradient / pathwise derivative)。

4. 为什么要用重参数化?两类优势

4.1 优势一:穿透随机性的求导

我们可以把概率分布重参数化后的参数理解为概率分布的“坐标”,因而可以把目标函数对于概率分布的梯度,转移到目标函数对这些坐标分量的梯度计算上去。对这些自然参数的求导有着明确的数学含义,比如 μ 是指概率分布的均值变化单位数值所带来的目标函数的变化量。

通过这种方法,可以近似认为:计算某个函数对该分布的导数,等价于计算该函数对该概率分布自然参数的导数。而在很多机器学习算法的实际使用中,假如最终的目标函数与某个概率分布有关,那么对目标函数的梯度就可以一直向后传递至对该概率分布的自然参数中,而不受随机性带来的影响。由此,可以便捷地构建基于梯度下降优化的各种机器学习模型(例如 VAE),从而解决一系列实用性问题。

(图 1:重参数化求导的原理解释图)

4.2 优势二:降低方差

通过重参数化,可以解决使用 Score Function Method 进行求取梯度时可能导致的高方差问题 [1]。具体来说,分别展开两种方法的梯度公式。在 Score Function Method 中:

θLSF(θ)=θl(θ,z)qθ(z)dz+l(θ,z)θqθ(z)dz=Eqθ(z)[θl(θ,z)+l(θ,z)θlogqθ(z)]=Eqθ(z)[l(θ,z)θ+l(θ,z)zzθ+l(θ,z)logqθ(z)θ].

在 Reparameterization Method 中:

θLReparam(θ)=θl(θ,z)qθ(z)dz=θl(θ,r(θ,ϵ))q(ϵ)dϵ=Eq(ϵ)[θl(θ,r(θ,ϵ))]=Eq(ϵ)[l(θ,r)θ+l(θ,r)rr(θ,ϵ)θ].

对比上述两式,Score Function Method 会额外引入:

Eqθ(z)[l(θ,z)θlogqθ(z)].

l(θ,z) 和分布无关时,该项为零。但在一般情况下,该项的存在会带来额外的梯度方差。因此一般来说:

Vq(ϵ)(l(θ,r)θ+l(θ,r)rr(θ,ϵ)θ)Vqθ(z)(l(θ,z)θ+l(θ,z)zzθ+l(θ,z)logqθ(z)θ).

而重参数化之所以可以避免这一项,是因为使用了独立于分布参数的采样,故此项梯度为 0。

(不过,我们可以构造某些特殊的函数,让上式中的各项产生相关性,从而使得“重参数化法的梯度方差比 Score Function Method 更小”这个命题不成立。但在统计意义上,即实际应用中的绝大多数情况,使用重参数化可以有效降低方差。)

此外,在实践中,较低的采样样本数会带来较大的方差,但较高的采样样本数又会降低计算效率。因此,如果使用重参数化技巧直接对目标函数进行求导,降低梯度的方差,就可以让我们使用更少的采样样本,从而也对提升训练效率有帮助。

5. 强化学习中的重参数化:SAC 实例

在强化学习 SAC 算法中 [2],策略通过 SAC 中定义的 soft value function(即引入最大熵的价值函数)来指导策略优化。具体来说,SAC 算法希望找到可以最大化 soft value function 的策略,即:

π=argmaxπVπ(s),Vπ(s)=Eaπ[Qπ(s,a)]+αH(π(s))=Eaπ[Qπ(s,a)αlogπ(as)].

在实现中,首先需要策略部分(Actor)输出当前状态对应的动作,然后交给价值函数(Critic)判断好坏得到目标函数,并用梯度下降进行优化。梯度将会从 Critic 沿着动作一路回传到 Actor。

在连续动作空间上,需要对策略使用高斯分布进行重参数化,并使用 tanh 函数压缩最终动作的输出范围到 [1,1],即采样到的动作通过下式获得:

aθ(s,ϵ)=tanh(μθ(s)+ϵσθ(s)),ϵN(0,I).

需要注意的是,这里需要一个额外的矫正操作,因为使用 tanh 函数压缩最终动作 a 会改变原本无界动作 au 的概率密度函数,即:

π(as)=p(aus)|detdadau|1,

其中

dadau=diag(1tanh2(au)).

所以代码实现中,我们需要修改压缩后的概率密度函数的 log prob 为:

logπ(as)=logp(aus)d=1Dlog(1tanh2(au,d)).

考虑 SAC 算法中使用的 double Q-trick,最终重参数化后的策略学习目标函数为:

θ=argmaxθEsϵN(0,I)[minj=1,2Qϕj(s,aθ(s,ϵ))αlogπθ(aθ(s,ϵ)s)].

对于一个 batch 为 B 的数据,这一批次数据产生的梯度为:

θJ(θ)=θ1|B|sB(minj=1,2Qϕj(s,aθ(s,ϵ))αlogπθ(aθ(s,ϵ)s))=1|B|sB(θ[minj=1,2Qϕj(s,aθ(s,ϵ))]αθlogπθ(aθ(s,ϵ)s)).

6. PPO 与重参数化的关系

回顾 PPO 的目标函数:

Et[min(pθ(atst)pθk(atst)A^θk(st,at),clip(pθ(atst)pθk(atst),1ϵ,1+ϵ)A^θk(st,at))].

可以发现,虽然公式中存在着 pθ(as) 的形式,但是与 SAC 中不同:训练时用到的动作 at 并不由将要更新的策略函数的参数 θ 生成,而是由旧策略(收集数据的策略)函数的参数 θk 生成。因此,在 PPO 中仅是借用重参数化的形式,将策略函数参数化为某个概率函数(比如连续动作空间使用高斯分布),即可让梯度通过概率函数的定义式直接回传。公式中的动作 at、状态 st、优势函数估计 A^θk 只是为了更新参数 θ 而存在的数据,它们本身并不回传任何梯度。

Data: s,a,A^θk,logitθk,Loss: πθ(as)πθk(as)A^θk.

(图 2:PPO 策略部分优化计算图,只有蓝色部分回传梯度,紫色部分只是参与计算。)

参考文献

  1. Xu M, Quiroz M, Kohn R, et al. Variance reduction properties of the reparameterization trick[C]//The 22nd International Conference on Artificial Intelligence and Statistics. PMLR, 2019: 2711-2720.
  2. Haarnoja T, Zhou A, Abbeel P, et al. Soft actor-critic: Off-policy maximum entropy deep reinforcement learning with a stochastic actor[C]//International Conference on Machine Learning. PMLR, 2018: 1861-1870.

Maintained by Robin