Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

苏剑林作者

全新视角:用变分推断统一理解生成模型(VAE、GAN、AAE、ALI)

摘要:本文从一种新的视角阐述了变分推断,并证明了 EM 算法、VAE、GAN、AAE、ALI (BiGAN) 都可以作为变分推断的某个特例。其中,论文也表明了标准的 GAN 的优化目标是不完备的,这可以解释为什么 GAN 的训练需要谨慎地选择各个参数。最后,文中给出了一个可以改善这种不完备性的正则项,实验表明该正则项能增强 GAN 训练的稳定性。

前言

我小学开始就喜欢纯数学,后来也喜欢上物理,还学习过一段时间的理论物理,直到本科毕业时,我才慢慢进入机器学习领域。所以,哪怕在机器学习领域中,我的研究习惯还保留着数学和物理的风格:企图从最少的原理出发,理解、推导尽可能多的东西。这篇文章是我这个理念的结果之一,试图以变分推断作为出发点,来统一地理解深度学习中的各种模型,尤其是各种让人眼花缭乱的 GAN。

本文已经挂到 arXiv 上,需要读英文原稿的可以访问下方链接下载论文 Variational Inference: A Unified Framework of Generative Models and Some Revelations。 

■ 论文 | Variational Inference: A Unified Framework of Generative Models and Some Revelations

■ 链接 | https://www.paperweekly.site/papers/2117

■ 作者 | Jianlin Su

下面是文章的介绍。其实,中文版的信息可能还比英文版要稍微丰富一些,原谅我这蹩脚的英语。

近年来,深度生成模型,尤其是 GAN,取得了巨大的成功。现在我们已经可以找到数十个乃至上百个 GAN 的变种。然而,其中的大部分都是凭着经验改进的,鲜有比较完备的理论指导。

本文的目标是通过变分推断来给这些生成模型建立一个统一的框架。首先,本文先介绍了变分推断的一个新形式,这个新形式其实在本人以前的文章中就已经介绍过,它可以让我们在几行字之内导出变分自编码器(VAE)和 EM 算法。然后,利用这个新形式,我们能直接导出 GAN,并且发现标准 GAN 的 loss 实则是不完备的,缺少了一个正则项。如果没有这个正则项,我们就需要谨慎地调整参数,才能使得模型收敛

实际上,本文这个工作的初衷,就是要将 GAN 纳入到变分推断的框架下。目前看来,最初的意图已经达到了,结果让人欣慰。新导出的正则项实际上是一个副产品,并且幸运的是,在我们的实验中这个副产品生效了。

变分推断新解

假设 x 为显变量,z 为隐变量,p̃(x) 为 x 的证据分布,并且有:

我们希望 qθ(x) 能逼近 p̃(x),所以一般情况下我们会去最大化似然函数

这也等价于最小化 KL 散度 KL(p̃(x))‖q(x)):

但是由于积分可能难以计算,因此大多数情况下都难以直接优化。 

变分推断中,首先引入联合分布 p(x,z) 使得p̃(x)=∫p(x,z)dz,而变分推断的本质,就是将边际分布的 KL 散度 KL(p̃(x)‖q(x)) 改为联合分布的 KL 散度 KL(p(x,z)‖q(x,z)) 或 KL(q(x,z)‖p(x,z)),而:

意味着联合分布的 KL 散度是一个更强的条件(上界)。所以一旦优化成功,那么我们就得到q(x,z)→p(x,z),从而 ∫q(x,z)dz→∫p(x,z)dz=p̃ (x),即 ∫q(x,z)dz 成为了真实分布 p̃(x) 的一个近似。

当然,我们本身不是为了加强条件而加强,而是因为在很多情况下,KL(p(x,z)‖q(x,z)) 或 KL(q(x,z)‖p(x,z)) 往往比 KL(p̃(x)‖q(x)) 更加容易计算。所以变分推断是提供了一个可计算的方案。

VAE和EM算法

由上述关于变分推断的新理解,我们可以在几句话内导出两个基本结果:变分自编码器和 EM 算法。这部分内容,实际上在从最大似然到EM算法:一致的理解方式变分自编码器(二):从贝叶斯观点出发已经详细介绍过了。这里用简单几句话重提一下。

VAE

在 VAE 中,我们设 q(x,z)=q(x|z)q(z),p(x,z)=p̃(x)p(z|x),其中 q(x|z),p(z|x) 带有未知参数高斯分布而 q(z) 是标准高斯分布。最小化的目标是:

其中 log(x) 没有包含优化目标,可以视为常数,而对 (x) 的积分则转化为对样本的采样,从而:

因为 q(x|z),p(z|x) 为带有神经网络高斯分布,这时候 KL(p(z|x)‖q(z)) 可以显式地算出,而通过重参数技巧来采样一个点完成积分 ∫p(z|x)logq(x|z)dz 的估算,可以得到 VAE 最终要最小化的 loss:



EM算法

在 VAE 中我们对后验分布做了约束,仅假设它是高斯分布,所以我们优化的是高斯分布参数。如果不作此假设,那么直接优化原始目标 (5),在某些情况下也是可操作的,但这时候只能采用交替优化的方式:先固定 p(z|x),优化 q(x|z),那么就有:

完成这一步后,我们固定 q(x,z),优化 p(z|x),先将 q(x|z)q(z) 写成 q(z|x)q(x) 的形式:

那么有:

由于现在对 p(z|x) 没有约束,因此可以直接让 p(z|x)=q(z|x) 使得 loss 等于 0。也就是说,p(z|x) 有理论最优解:

(8),(11) 的交替执行,构成了 EM 算法的求解步骤。这样,我们从变分推断框架中快速得到了 EM 算法。

变分推断下的GAN

在这部分内容中,我们介绍了一般化的将 GAN 纳入到变分推断中的方法,这将引导我们得到 GAN 的新理解,以及一个有效的正则项。 

一般框架

同 VAE 一样,GAN 也希望能训练一个生成模型 q(x|z),来将 q(z)=N(z;0,I) 映射为数据集分布(x),不同于 VAE 中将 q(x|z) 选择为高斯分布,GAN 的选择是:

其中 δ(x) 是狄拉克 δ 函数,G(z) 即为生成器的神经网络

一般我们会认为 z 是一个隐变量,但由于 δ 函数实际上意味着单点分布,因此可以认为 xz的关系已经是一一对应的,所以 x 的关系已经“不够随机”,在 GAN 中我们认为它不是隐变量(意味着我们不需要考虑后验分布 p(z|x))。

事实上,在 GAN 中仅仅引入了一个二元的隐变量 y 来构成联合分布:

这里 p1=1−p0 描述了一个二元概率分布,我们直接取 p1=p0=1/2。另一方面,我们设 p(x,y)=p(y|x)p̃(x),p(y|x) 是一个条件伯努利分布。而优化目标是另一方向的 KL(q(x,y)‖p(x,y)):

一旦成功优化,那么就有 q(x,y)→p(x,y),那么:

从而 q(x)→p̃(x),完成了生成模型的构建。 

现在我们优化对象有 p(y|x) 和 G(x),记 p(1|x)=D(x),这就是判别器。类似 EM 算法,我们进行交替优化:先固定 G(z),这也意味着 q(x) 固定了,然后优化 p(y|x),这时候略去常量,得到优化目标为:

然后固定 D(x) 来优化 G(x),这时候相关的 loss 为:

这里包含了我们不知道的 p̃(x),但是假如 D(x) 模型具有足够的拟合能力,那么跟 (11) 式同理,D(x) 的最优解应该是:

这里的是前一阶段的 q(x)。从中解出 q̃(x),代入 loss 得到:

基本分析

可以看到,第一项就是标准的 GAN 生成器所采用的 loss 之一。

多出来的第二项,描述了新分布与旧分布之间的距离。这两项 loss 是对抗的,因为希望新旧分布尽量一致,但是如果判别器充分优化的话,对于旧分布中的样本,D(x) 都很小(几乎都被识别为负样本),所以 −logD(x) 会相当大,反之亦然。这样一来,整个 loss 一起优化的话,模型既要“传承”旧分布,同时要在往新方向 p(1|y) 探索,在新旧之间插值。

我们知道,目前标准的 GAN 的生成器 loss 都不包含,这事实上造成了 loss 的不完备。假设有一个优化算法总能找到 G(z) 的理论最优解、并且 G(z) 具有无限的拟合能力,那么 G(z) 只需要生成唯一一个使得 D(x) 最大的样本(不管输入的 z 是什么),这就是模型坍缩。这样说的话,理论上它一定会发生。

那么,给我们的启发是什么呢?我们设:

也就是说,假设当前模型的参数改变量为 Δθ,那么展开到二阶得到:

我们已经指出一个完备的 GAN 生成器的损失函数应该要包含,如果不包含的话,那么就要通过各种间接手段达到这个效果,上述近似表明额外的损失约为 (Δθ⋅c)2,这就要求我们不能使得它过大,也就是不能使得 Δθ 过大(在每个阶段 c 可以近似认为是一个常数)。

而我们用的是基于梯度下降的优化算法,所以 Δθ 正比于梯度,因此标准 GAN 训练时的很多 trick,比如梯度裁剪、用 adam 优化器、用 BN,都可以解释得通了,它们都是为了稳定梯度,使得 θ 不至于过大,同时,G(z) 的迭代次数也不能过多,因为过多同样会导致 Δθ 过大。

还有,这部分的分析只适用于生成器,而判别器本身并不受约束,因此判别器可以训练到最优。

正则项

现在我们从中算出一些真正有用的内容,直接对进行估算,以得到一个可以在实际训练中使用的正则项。直接计算是难以进行的,但我们可以用 KL(q(x,z)‖q̃(x,z)) 去估算它:

因为有极限:

所以可以将 δ(x) 看成是小方差的高斯分布,代入算得也就是我们有:

所以完整生成器的 loss 可以选为:

也就是说,可以用新旧生成样本的距离作为正则项,正则项保证模型不会过于偏离旧分布。

下面的两个在人脸数据 CelebA 上的实验表明这个正则项是生效的。实验代码修改自:

https://github.com/LynnHo/DCGAN-LSGAN-WGAN-WGAN-GP-Tensorflow

实验一:普通的 DCGAN 网络,每次迭代生成器和判别器各训练一个 batch。

▲ 不带正则项,在25个epoch之后模型开始坍缩

▲ 带有正则项,模型能一直稳定训练

实验二:普通的 DCGAN 网络,但去掉 BN,每次迭代生成器和判别器各训练五个 batch。

▲ 不带正则项,模型收敛速度比较慢

▲ 带有正则项,模型更快“步入正轨”

GAN相关模型

对抗自编码器Adversarial Autoencoders,AAE)和对抗推断学习(Adversarially Learned Inference,ALI)这两个模型是 GAN 的变种之一,也可以被纳入到变分推断中。当然,有了前述准备后,这仅仅就像两道作业题罢了。 

有意思的是,在 ALI 之中,我们有一些反直觉的结果。

GAN视角下的AAE

事实上,只需要在 GAN 的论述中,将 x,z 的位置交换,就得到了 AAE 的框架。 

具体来说,AAE 希望能训练一个编码模型 p(z|x),来将真实分布 q̃(x) 映射为标准高斯分布 q(z)=N(z;0,I),而:

其中 E(x) 即为编码器的神经网络

同 GAN 一样,AAE 引入了一个二元的隐变量 y,并有:

同样直接取 p1=p0=1/2。另一方面,我们设 q(z,y)=q(y|z)q(z),这里的后验分布 p(y|z) 是一个输入为 z 的二元分布,然后去优化 KL(p(z,y)‖q(z,y)):

现在我们优化对象有 q(y|z) 和 E(x),记 q(0|z)=D(z),依然交替优化:先固定 E(x),这也意味着 p(z) 固定了,然后优化 q(y|z),这时候略去常量,得到优化目标为:

然后固定 D(z) 来优化 E(x),这时候相关的 loss 为:

利用 D(z) 的理论最优解,代入 loss 得到:

一方面,同标准 GAN 一样,谨慎地训练,我们可以去掉第二项,得到:

另外一方面,我们可以得到编码器后再训练一个解码器 G(z),但是如果所假设的 E(x),G(z) 的拟合能力是充分的,重构误差可以足够小,那么将 G(z) 加入到上述 loss 中并不会干扰 GAN 的训练,因此可以联合训练:

反直觉的ALI版本

ALI 像是 GAN 和 AAE 的融合,另一个几乎一样的工作是 Bidirectional GAN (BiGAN)。相比于 GAN,它将 z 也作为隐变量纳入到变分推断中。具体来说,在 ALI 中有:

以及 p(x,z,y)=p(y|x,z)p(z|x)p̃(x),然后去优化 KL(q(x,z,y)‖p(x,z,y)):

等价于最小化:

现在优化的对象有 p(y|x,z),p(z|x),q(x|z),记 p(1|x,z)=D(x,z),而 p(z|x) 是一个带有编码器E的高斯分布或狄拉克分布,q(x|z) 是一个带有生成器 G 的高斯分布或狄拉克分布。依然交替优化:先固定 E,G,那么与 D 相关的 loss 为:

跟 VAE 一样,对 p(z|x) 和 q(x|z) 的期望可以通过“重参数”技巧完成。接着固定 D 来优化 G,E,因为这时候有 E 又有 G,整个 loss 没得化简,还是 (37) 那样。但利用 D 的最优解:

可以转化为:

由于 q(x|z),p(x|z) 都是高斯分布,事实上后两项我们可以具体地算出来(配合重参数技巧),但同标准 GAN 一样,谨慎地训练,我们可以简单地去掉后面两项,得到:

这就是我们导出的 ALI 的生成器和编码器的 loss,它跟标准的 ALI 结果有所不同。标准的 ALI(包括普通的 GAN)将其视为一个极大极小问题,所以生成器和编码器的 loss 为:

或:

它们都不等价于 (41)。针对这个差异,事实上笔者也做了实验,结果表明这里的 ALI 有着和标准的 ALI 同样的表现,甚至可能稍好一些(可能是我的自我良好的错觉,所以就没有放图了)。这说明,将对抗网络视为一个极大极小问题仅仅是一个直觉行为,并非总应该如此。

结论综述

本文的结果表明了变分推断确实是一个推导和解释生成模型的统一框架,包括 VAE 和 GAN。通过变分推断的新诠释,我们介绍了变分推断是如何达到这个目的的。 

当然,本文不是第一篇提出用变分推断研究 GAN 这个想法的文章。在《On Unifying Deep Generative Models》一文中,其作者也试图用变分推断统一 VAE 和 GAN,也得到了一些启发性的结果。但笔者觉得那不够清晰。事实上,我并没有完全读懂这篇文章,我不大确定,这篇文章究竟是将 GAN 纳入到了变分推断中了,还是将 VAE 纳入到了 GAN 中。相对而言,我觉得本文的论述更加清晰、明确一些。 

看起来变分推断还有很大的挖掘空间,等待着我们去探索。

PaperWeekly
PaperWeekly

推荐、解读、讨论和报道人工智能前沿论文成果的学术平台。

入门VAEGANAAEALI
5
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

高斯分布技术

正态分布是一个非常常见的连续概率分布。由于中心极限定理(Central Limit Theorem)的广泛应用,正态分布在统计学上非常重要。中心极限定理表明,由一组独立同分布,并且具有有限的数学期望和方差的随机变量X1,X2,X3,...Xn构成的平均随机变量Y近似的服从正态分布当n趋近于无穷。另外众多物理计量是由许多独立随机过程的和构成,因而往往也具有正态分布。

重构技术

代码重构(英语:Code refactoring)指对软件代码做任何更动以增加可读性或者简化结构而不影响输出结果。 软件重构需要借助工具完成,重构工具能够修改代码同时修改所有引用该代码的地方。在极限编程的方法学中,重构需要单元测试来支持。

变分自编码器技术

变分自编码器可用于对先验数据分布进行建模。从名字上就可以看出,它包括两部分:编码器和解码器。编码器将数据分布的高级特征映射到数据的低级表征,低级表征叫作本征向量(latent vector)。解码器吸收数据的低级表征,然后输出同样数据的高级表征。变分编码器是自动编码器的升级版本,其结构跟自动编码器是类似的,也由编码器和解码器构成。在自动编码器中,需要输入一张图片,然后将一张图片编码之后得到一个隐含向量,这比原始方法的随机取一个随机噪声更好,因为这包含着原图片的信息,然后隐含向量解码得到与原图片对应的照片。但是这样其实并不能任意生成图片,因为没有办法自己去构造隐藏向量,所以它需要通过一张图片输入编码才知道得到的隐含向量是什么,这时就可以通过变分自动编码器来解决这个问题。解决办法就是在编码过程给它增加一些限制,迫使其生成的隐含向量能够粗略的遵循一个标准正态分布,这就是其与一般的自动编码器最大的不同。这样生成一张新图片就比较容易,只需要给它一个标准正态分布的随机隐含向量,这样通过解码器就能够生成想要的图片,而不需要给它一张原始图片先编码。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

张量技术

张量是一个可用来表示在一些矢量、标量和其他张量之间的线性关系的多线性函数,这些线性关系的基本例子有内积、外积、线性映射以及笛卡儿积。其坐标在 维空间内,有 个分量的一种量,其中每个分量都是坐标的函数,而在坐标变换时,这些分量也依照某些规则作线性变换。称为该张量的秩或阶(与矩阵的秩和阶均无关系)。 在数学里,张量是一种几何实体,或者说广义上的“数量”。张量概念包括标量、矢量和线性算子。张量可以用坐标系统来表达,记作标量的数组,但它是定义为“不依赖于参照系的选择的”。张量在物理和工程学中很重要。例如在扩散张量成像中,表达器官对于水的在各个方向的微分透性的张量可以用来产生大脑的扫描图。工程上最重要的例子可能就是应力张量和应变张量了,它们都是二阶张量,对于一般线性材料他们之间的关系由一个四阶弹性张量来决定。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

梯度下降技术

梯度下降是用于查找函数最小值的一阶迭代优化算法。 要使用梯度下降找到函数的局部最小值,可以采用与当前点的函数梯度(或近似梯度)的负值成比例的步骤。 如果采取的步骤与梯度的正值成比例,则接近该函数的局部最大值,被称为梯度上升。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

似然函数技术

在数理统计学中,似然函数是一种关于统计模型中的参数的函数,表示模型参数中的似然性。 似然函数在统计推断中有重大作用,如在最大似然估计和费雪信息之中的应用等等。“ 似然性”与“或然性”或“概率”意思相近,都是指某种事件发生的可能性,但是在统计学中,“似然性”和“或然性”或“概率”又有明确的区分。

生成模型技术

在概率统计理论中, 生成模型是指能够随机生成观测数据的模型,尤其是在给定某些隐含参数的条件下。 它给观测值和标注数据序列指定一个联合概率分布。 在机器学习中,生成模型可以用来直接对数据建模(例如根据某个变量的概率密度函数进行数据采样),也可以用来建立变量间的条件概率分布。

对抗学习推理技术

对抗学习推理(ALI)模型是一个深度定向生成模型,它利用对抗过程共同学习一个生成网络和一个推理网络。 这个模型构成了一种将高效推理与生成对抗网络(GAN)框架相结合的新方法。

WGAN技术

就其本质而言,任何生成模型的目标都是让模型(习得地)的分布与真实数据之间的差异达到最小。然而,传统 GAN 中的判别器 D 并不会当模型与真实的分布重叠度不够时去提供足够的信息来估计这个差异度——这导致生成器得不到一个强有力的反馈信息(特别是在训练之初),此外生成器的稳定性也普遍不足。 Wasserstein GAN 在原来的基础之上添加了一些新的方法,让判别器 D 去拟合模型与真实分布之间的 Wasserstein 距离。Wassersterin 距离会大致估计出「调整一个分布去匹配另一个分布还需要多少工作」。此外,其定义的方式十分值得注意,它甚至可以适用于非重叠的分布。

优化器技术

优化器基类提供了计算梯度loss的方法,并可以将梯度应用于变量。优化器里包含了实现了经典的优化算法,如梯度下降和Adagrad。 优化器是提供了一个可以使用各种优化算法的接口,可以让用户直接调用一些经典的优化算法,如梯度下降法等等。优化器(optimizers)类的基类。这个类定义了在训练模型的时候添加一个操作的API。用户基本上不会直接使用这个类,但是你会用到他的子类比如GradientDescentOptimizer, AdagradOptimizer, MomentumOptimizer(tensorflow下的优化器包)等等这些算法。

对抗自编码器技术

对抗自编码器通过使用对抗学习(adversarial learning)避免了使用 KL 散度。在该架构中,训练一个新网络来有区分地预测样本是来自自编码器的隐藏代码还是来自用户确定的先验分布 p(z)。编码器的损失函数现在由重建损失函数与判别器网络(discriminator network)的损失函数组成。

推荐文章
暂无评论
暂无评论~