武广作者合肥工业大学硕士生学校图像生成研究方向

UC Berkeley提出变分判别器瓶颈,有效提高对抗学习平衡性

本期推荐的论文笔记来自 PaperWeekly 社区用户 @TwistedW作者今天要解读的是 UC Berkeley 投稿 ICLR 2019 的工作。

对抗学习中判别器一直保持着强大的侵略优势,造成了对抗中的不平衡。本文采用变分判别器瓶颈(Variational Discriminator Bottleneck,VDB),通过对数据样本和编码到的特征空间的互信息进行限制,提高判别器的判别难度,进而提高了对抗学习中的平衡性。实验表明 VDB 思想可以在 GAN、模仿学习和逆强化学习上取得不小的进步。

引言

生成对抗网络中判别器在二分类游戏上表现了强大的区分能力,RSGAN 使用相对判别器将真假样本混合利用“图灵测试”的思想削弱了判别器的能力,T-GANs 将 RSGAN 一般化到其它 GAN 模型下,判别器得到限制在整体上平衡了生成器和判别器,可以使 GAN 训练上更加稳定。VDB 则通过对判别器加上互信息瓶颈来限制判别器的能力。

论文引入

GAN 存在两大固有问题,一个是生成上多样性不足;另一个就是当判别器训练到最优时,生成器的梯度消失。造成梯度消失的原因在于生成样本和真实样本在分布上是不交叠的,WGAN [1]提出可以通过加入噪声来强制产生交叠,但是如何控制噪声加入以及能否保证交叠都是存在问题的。WGAN 以及它的改进虽然在 GAN 训练中稳定性上提高了,但是对于样本真假的二分类判别上,判别器展现了过于强大的能力,这样打破了对抗上的平衡问题,最终还是造成训练阶段的不稳定(不平衡,生成质量提不上去)。 

RSGAN 提出了采用相对判别器通过区分真假样本混合在一起判断真假,这样判别器不再是判断真或假,还要在一堆样本下将真假样本分开。这样对于判别器的要求提高了,难度上来后自然会进一步平衡训练,

关于 RSGAN 的进一步理解可参看RSGAN:对抗模型中的“图灵测试”思想T-GANs 更是进一步将 RSGAN 一般化,让RSGAN中的混合真假样本的思想得到充分应用,具体了解,可参看T-GANs:基于“图灵测试”的生成对抗模型。 

我们今天要解读的文章是变分判别器瓶颈(Variational Discriminator Bottleneck,VBD)。论文通过对互信息加上限制来削弱判别器的能力,从而平衡网络的训练。这种对判别器互信息限制,不仅可以用在 GAN 的训练上,对于模仿学习和逆强化学习都有很大的提高。由于我更加关注 VDB 在 GAN 上的应用,所以在模仿学习强化学习方面将只做简短介绍,把重点放在 VDB 在 GAN 上的作用。 

在开启正文前,我们一起看一下互信息瓶颈限制在监督学习上的正则作用。这个思想在 16 年被 Alemi 提出,原文叫 Deep Variational Information Bottleneck [2]。我们有数据集 {xi,yi},其中 xi 为数据样本,yi 为对应的标签,通过最大似然估计优化模型:

这种最大似然估计方法往往会造成过拟合的现象,这时候就需要一定的正则化。变分互信息瓶颈则是鼓励模型仅关注最具辨别力的特征,从而对模型做一定的限制。

为了实现这种信息瓶颈,需要引入编码器对样本特征先做提取 E(z|x) 将样本编码到特征空间 z,通过对样本 x 和特征空间 z 的互信息 I(X,Z) 做限制,即 I(X,Z)≤Ic,则正则化目标:

此时最大似然估计就是对模型 q(y|z) 操作的,实现将特征空间 z 到标签 y,互信息定义为:

这里的 p(x) 为数据样本的分布,p(x,z)=p(x)E(z|x),计算分布 p(z)=∫p(x)E(z|x)dx是困难的,p(z) 是数据编码得到的,这个分布是很难刻画的,但是使用边际的近似 r(z) 可以获得变分下界。

取 KL[p(z)‖r(z)]=∫p(z)logp(z)−∫p(z)logr(z)≥0,此时 ∫p(z)logp(z)≥∫p(z)logr(z),I(X,Z) 可以表示为:

这提供了正则化的上界,J̃(q,E)≥J(q,E)。

优化的时候可以采取拉格朗日系数 β。我们从整体上分析一下这个互信息的瓶颈限制,互信息反应的是两个变量的相关程度,而我们得到的特征空间 z 是由 x 编码得到的,理论上已知 x 就可确定 z,x 和 z 是完全相关的,也就是 x 和 z 的互信息是较大的。

而现在限制了互信息的值,这样就切断了一部分 x 和 z 的相关性,保留的相关性是 x 和 z 最具辨别力的特征,而其它相关性较低的特征部分将被限制掉,从而使得模型不至于过度学习,从而实现正则化的思想。

VDB 正是把这个用在监督学习的正则思想用到了判别器上,从而在 GAN、模仿学习和逆强化学习上都取得了不小的提升。

总结一下 VDB 的优势:

  • 判别器信息瓶颈是对抗性学习的自适应随机正则化方法,可显著提高各种不同应用领域的性能;

  • 在 GAN、模仿学习和逆强化学习上取得性能上的改进。

VDB在GAN中的实现

VDB其实是在 Deep Variational Information Bottleneck [2] 的基础上将互信息思想引入到判别器下,如果上面描述的互信息瓶颈读懂的话,这一块将很好理解。

对于传统 GAN,我们先定义下各个变量(保持和原文一致)。真实数据样本分布 p∗(x),生成样本分布 G(x),判别器为 D,生成器为 G,目标函数为:

类似于 Deep Variational Information Bottleneck [2],文章也是先对数据样本做了 Encoder,经数据编码到特征空间下,这样一来降低了数据的维度,同时将真假样本都做低维映射,更加可能实现一定的交叠。

当然这个不是文章的重点,文章的重点还是为了在互信息上实行瓶颈限制。将数据编码得到的 z 和数据 x 的互信息做瓶颈限制,我们先看目标函数,再来解释为什么做了瓶颈限制可以降低判别器的能力。

这里强调一下,这个我们待会再进一步分析,同样可以通过引入拉格朗日系数优化目标函数

我们分析一下限制互信息瓶颈在 GAN 中起到的作用,同样的互信息是样本 x 和它经过编码得到的特征空间 z。互信息表示变量间的相关程度,通过限制 x 和 z 的相关性,对于很具有辨识性的特征,判别器将可以区分真假,但是经过信息瓶颈限制把样本和特征空间相关性不足的特征限制住,这样判别器就增加了区分样本真假的难度。

判别器在这个二分类游戏下只能通过相关性很强的特征来判断真假,对于限制条件下,这个的作用是对整体样本的互信息都进行限制,这样真假样本都进行了混淆,判别器判断难度提高,游戏得到进一步平衡。 

文章通过实验进一步说明了判别器加入信息瓶颈的作用,通过对两个不同的高斯分布进行区别,左侧认为是假(判为 0),右侧认为是真(判为 1),经过信息瓶颈限制 Ic 的调整,得到的结果如下图: 

我们知道,在二分类下信息熵最小是 1bit(当两个事件等概率发生时),由于 x 和 z 是完全相关,我们可以理解理想状态此时的互信息最小是 1bit,当不断减小瓶颈 Ic 的值,上图中由 10 降到 0.1,这个过程中判别器区分两个分布的界限越来越弱,达到了限制判别器能力的效果。

对于网络的优化,主要是对 β 的更新上:

这个互信息瓶颈还可以用在模范学习和逆强化学习上,都取得了一定的改进,感兴趣的可以查看原文进一步了解。

实验

VDB 在 GAN 中的应用实验,作者对 CIFAR10 做了各个模型的 FID 定量对比。为了改善 VDB 在 GAN 上的性能,作者在 VDB 和 GAN 中加入了梯度惩罚,命名为 VGAN-GP。

这样可谓是又进一步限制了判别器,反正实验效果是有所提升,可以猜测作者用到的 GAN 的损失函数肯定基于 WGAN,文中说了代码即将公布,在没看到源码前只能猜测一下。

不过,通过后文实验做到了 1024 × 1024 可以看出,作者所在的实验室一定不简单,跑得动 1024 的图,只能表示一下敬意。

最后,来看一下作者展示的视频 Demo。

总结

在本文中,作者提出了变判别器瓶颈,这是一种用于对抗学习的一般正则化技术。实验表明,VDB 广泛适用于各种领域,并且在许多具有挑战性的任务方面比以前的技术产生了显着的改进。

通过对判别器加入信息瓶颈,限制了判别器的能力,使得对抗中保持平衡,提高了训练的稳定性。这种正则化思想可以在各类 GAN 模型下适用,后续还要对 VDB 做进一步实验上的分析。

参考文献

[1] Martin Arjovsky, Soumith Chintala, and Léon Bottou. Wasserstein generative adversarial networks. In International Conference on Machine Learning, pages 214–223, 2017.

[2] Alexander A. Alemi, Ian Fischer, Joshua V. Dillon, and Kevin Murphy. Deep variational information bottleneck. CoRR, abs/1612.00410, 2016.

PaperWeekly
PaperWeekly

推荐、解读、讨论和报道人工智能前沿论文成果的学术平台。

理论强化对抗学习UC Berkeley最大似然估计监督学习T-GANsWGAN图灵测试生成对抗网络强化学习GANVDB模仿学习
2
相关数据
机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

高斯分布技术

正态分布是一个非常常见的连续概率分布。由于中心极限定理(Central Limit Theorem)的广泛应用,正态分布在统计学上非常重要。中心极限定理表明,由一组独立同分布,并且具有有限的数学期望和方差的随机变量X1,X2,X3,...Xn构成的平均随机变量Y近似的服从正态分布当n趋近于无穷。另外众多物理计量是由许多独立随机过程的和构成,因而往往也具有正态分布。

最大似然估计技术

极大似然估计是统计学中用来估计概率模型参数的一种方法

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

信息熵技术

在信息论中,熵是接收的每条消息中包含的信息的平均量,又被称为信息熵、信源熵、平均自信息量。这里,“消息”代表来自分布或数据流中的事件、样本或特征。熵的单位通常为比特,但也用Sh、nat、Hart计量,取决于定义用到对数的底。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

监督学习技术

监督式学习(Supervised learning),是机器学习中的一个方法,可以由标记好的训练集中学到或建立一个模式(函数 / learning model),并依此模式推测新的实例。训练集是由一系列的训练范例组成,每个训练范例则由输入对象(通常是向量)和预期输出所组成。函数的输出可以是一个连续的值(称为回归分析),或是预测一个分类标签(称作分类)。

目标函数技术

目标函数f(x)就是用设计变量来表示的所追求的目标形式,所以目标函数就是设计变量的函数,是一个标量。从工程意义讲,目标函数是系统的性能标准,比如,一个结构的最轻重量、最低造价、最合理形式;一件产品的最短生产时间、最小能量消耗;一个实验的最佳配方等等,建立目标函数的过程就是寻找设计变量与目标的关系的过程,目标函数和设计变量的关系可用曲线、曲面或超曲面表示。

过拟合技术

过拟合是指为了得到一致假设而使假设变得过度严格。避免过拟合是分类器设计中的一个核心任务。通常采用增大数据量和测试样本集的方法对分类器性能进行评价。

正则化技术

当模型的复杂度增大时,训练误差会逐渐减小并趋向于0;而测试误差会先减小,达到最小值后又增大。当选择的模型复杂度过大时,过拟合现象就会发生。这样,在学习时就要防止过拟合。进行最优模型的选择,即选择复杂度适当的模型,以达到使测试误差最小的学习目的。

图灵测试技术

图灵测试(英语:Turing test,又译图灵试验)是图灵于1950年提出的一个关于判断机器是否能够思考的著名试验,测试某机器是否能表现出与人等价或无法区分的智能。测试的谈话仅限于使用唯一的文本管道,例如计算机键盘和屏幕,这样的结果是不依赖于计算机把单词转换为音频的能力。

生成对抗网络技术

生成对抗网络是一种无监督学习方法,是一种通过用对抗网络来训练生成模型的架构。它由两个网络组成:用来拟合数据分布的生成网络G,和用来判断输入是否“真实”的判别网络D。在训练过程中,生成网络-G通过接受一个随机的噪声来尽量模仿训练集中的真实图片去“欺骗”D,而D则尽可能的分辨真实数据和生成网络的输出,从而形成两个网络的博弈过程。理想的情况下,博弈的结果会得到一个可以“以假乱真”的生成模型。

WGAN技术

就其本质而言,任何生成模型的目标都是让模型(习得地)的分布与真实数据之间的差异达到最小。然而,传统 GAN 中的判别器 D 并不会当模型与真实的分布重叠度不够时去提供足够的信息来估计这个差异度——这导致生成器得不到一个强有力的反馈信息(特别是在训练之初),此外生成器的稳定性也普遍不足。 Wasserstein GAN 在原来的基础之上添加了一些新的方法,让判别器 D 去拟合模型与真实分布之间的 Wasserstein 距离。Wassersterin 距离会大致估计出「调整一个分布去匹配另一个分布还需要多少工作」。此外,其定义的方式十分值得注意,它甚至可以适用于非重叠的分布。

模仿学习技术

模仿学习(Imitation Learning)背后的原理是是通过隐含地给学习器关于这个世界的先验信息,就能执行、学习人类行为。在模仿学习任务中,智能体(agent)为了学习到策略从而尽可能像人类专家那样执行一种行为,它会寻找一种最佳的方式来使用由该专家示范的训练集(输入-输出对)。

强化学习技术

强化学习是一种试错方法,其目标是让软件智能体在特定环境中能够采取回报最大化的行为。强化学习在马尔可夫决策过程环境中主要使用的技术是动态规划(Dynamic Programming)。流行的强化学习方法包括自适应动态规划(ADP)、时间差分(TD)学习、状态-动作-回报-状态-动作(SARSA)算法、Q 学习、深度强化学习(DQN);其应用包括下棋类游戏、机器人控制和工作调度等。

推荐文章
暂无评论
暂无评论~