Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

Xue Bin Peng 等作者Geek AI、路编译openreview选自

有效稳定对抗模型训练过程,伯克利提出变分判别器瓶颈


近期,加州大学伯克利分校的研究者提出一种新型简单而通用的方法变分判别器瓶颈(VDB),利用信息瓶颈约束判别器内信息流,通过对观测结果和判别器内部表征之间的互信息进行约束来稳定对抗性模型的训练过程。该论文已被 ICLR 2019 接收,获得了6、10、8的评分。

对抗性学习方法为具有复杂的内部关联结构的高维数据分布的建模提供了一种很有发展前景的方法。这些方法通常使用判别器来监督生成器的训练,从而产生与原始数据极为相似、难以区分的样本。生成对抗网络(GAN)就是对抗性学习方法的一个实例,它可以用于高保真的图像生成任务(Goodfellow et al., 2014; Karrasrt et al.,2017)和其他高维数据的生成(Vondrick et al.,2016;Xie et al.,2018;Donahue et al.,2018)。在逆向强化学习(inverse reinforcement learning)框架中也可以使用对抗性方法学习奖励函数,或者直接生成模仿学习的专家演示样例(Ho & Ermon, 2016)。然而,对抗性学习方法的优化问题面临着很大的挑战,如何平衡生成器和判别器的性能就是其中之一。一个具有很高准确率的判别器可能会产生信息量较少的梯度,但是一个弱的判别器也可能会不利于提高生成器的学习能力。这些挑战引起了人们对对抗性学习算法的各种稳定方法的广泛兴趣(Arjovsky et al., 2017; Kodali et al., 2017; Berthelot et al., 2017)。

本研究提出了一种简单的对抗性学习正则化方法,该方法利用信息瓶颈的变分近似约束从输入到判别器的信息流。通过对输入的观测数据和判别器的内部表征之间的互信息施加约束,我们可以促使判别器学习到使原始数据和生成器的数据分布有很多重叠的数据表征,从而有效地调整判别器的准确率并维持生成器能够带有足够信息量的梯度。这一使对抗性学习稳定的方法可以看作是实例噪声的自适应方差(Salimans et al., 2016; Sønderby et al., 2016; Arjovsky & Bottou, 2017)。然而,该研究证明了这种方法的自适应特性至关重要。约束判别器内部表征和输入之间的互信息可以使正则化项能够直接限制判别器的准确率,这可以自动完成对噪声大小的选择,并将这样的噪声应用到该输入的压缩表征上,该表征是经过专门优化的,能够对生成器生成的数据和原始数据分布之间最明显的差异进行建模。

这项工作的主要贡献是变分判别器瓶颈(variational discriminator bottleneck,VDB),如图 1 所示,这是一种用于对抗性学习的自适应随机正则化方法,可以显著提高其在不同应用领域上的性能。该方法可以很容易地应用于各种任务和架构。首先,研究者在一组具有挑战性的模仿学习任务上评估了该方法,这些任务包括从模拟人形机器人的 mocap 数据(动作捕捉数据)中学习高度复杂的技巧。该研究提出的方法还使模拟机器人能够直接从原始视频演示样例中学习动态连续的控制技能,相对于使用对抗性模仿学习的之前工作有很大的改进。研究者进一步评估了逆向强化学习技术的有效性,该技术可以从演示样例中恢复奖励函数,用于训练未来的策略。最后,研究者将该框架应用于生成对抗网络的图像生成任务上,在许多情况下,使用 VDB 可以提高模型的性能。

图 1:该研究提出的方法是通用的,可应用于大量对抗学习任务。左:使用对抗模仿学习进行运动模拟。中:图像生成。右:通常逆向强化学习学习可迁移奖励函数。

论文:VARIATIONAL DISCRIMINATOR BOTTLENECK: IMPROVING IMITATION LEARNING, INVERSE RL, AND GANS BY CONSTRAINING INFORMATION FLOW 

  • 论文链接:https://openreview.net/forum?id=HyxPx3R9tm

  • arXiv链接:https://arxiv.org/pdf/1810.00821.pdf

摘要:对抗性学习方法已经得到了广泛的应用,但是众所周知,对抗性模型的训练是很不稳定的。有效地平衡生成器和判别器的性能是至关重要的,因为一个判别器如果达到很高的准确率,就会产生信息量较少的梯度。本研究提出了一种利用信息瓶颈约束判别器内信息流的简单而通用的技术。通过对观测结果和判别器内部表征之间的互信息进行约束,我们可以有效地调整判别器的准确率,保持有用的、信息量较大的梯度。我们证明了我们提出的变分判别器瓶颈(VDB)可以在三个不同的对抗式学习算法应用领域中显著提升模型的性能。我们初步的评估研究了 VDB 对动态连续控制技能(如跑步)的模仿学习的适用性,证明我们的方法可以直接从原始视频演示样例中学习这些技能,大大超过之前的对抗性模仿学习方法的性能。VDB 还可以与逆向强化学习结合,学习可在新的环境下迁移并重新优化的简洁奖励函数。最后,我们证明了 VDB 可以更有效地训练用于生成图像的 GAN,相对于之前的稳定方法取得了一定的提升。

图 2:左图:变分判别器瓶颈概览。编码器首先将样本 x 映射到潜在分布 E(z|x) 上。接着,训练判别器从潜在分布中对样本 z 进行分类,将信息瓶颈 I(X, Z) ≤ I_c 作用于 z。右图:被训练用来通过不同 KL 边界 I_c 区分两个高斯分布的判别器的可视化结果。

图 3:模拟人形机器人正在执行各种技能。VAIL 能够从 mocap 数据中逼真地模仿各种技能。

图 4:比较 VAIL 与其他运动模仿方法的学习曲线。使用模拟特征与参考运动形态之间的平均关节旋转误差来测量性能。每种方法都使用 3 个随机种子进行评估。

表 1:人形机器人在运动模拟任务中的平均关节旋转误差(弧度值)。除了使用(Peng et al., 2018)中人为设计的奖励函数训练的策略,VAIL 在所有技能的评估中都优于其它方法。

图 7:左图:C 形迷宫和 S 形迷宫。当在左边的迷宫中训练时,AIRL 学得对于训练任务过拟合的奖励,因此这个奖励不能迁移到右边的迷宫中。相比之下,VAIRL 学习了一种更平滑的奖励函数,可以实现更可靠的迁移。右图:两个训练迷宫的翻转测试的性能。我们报告了 5 次运行后的模仿学习任务的平均返回值(±std. dev)以及用于生成演示样例的单个专家的平均返回值。

图 8: 在 CIFAR-10 数据集上使用 VGAN 和其它方法的对比结果,这里使用 Frechet Inception 距离(FID)作为评价指标。

图 9:在 CIFAR-10、CelebA 128×128 和 CelebAHQ 1024×1024 数据集上使用 VGAN 得到的随机图像样本。

理论ICLRICLR 2019对抗训练模仿学习图像生成信息瓶颈论文
4
相关数据
高斯分布技术

正态分布是一个非常常见的连续概率分布。由于中心极限定理(Central Limit Theorem)的广泛应用,正态分布在统计学上非常重要。中心极限定理表明,由一组独立同分布,并且具有有限的数学期望和方差的随机变量X1,X2,X3,...Xn构成的平均随机变量Y近似的服从正态分布当n趋近于无穷。另外众多物理计量是由许多独立随机过程的和构成,因而往往也具有正态分布。

学习曲线技术

在机器学习领域,学习曲线通常是表现学习准确率随着训练次数/时长/数据量的增长而变化的曲线

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

过拟合技术

过拟合是指为了得到一致假设而使假设变得过度严格。避免过拟合是分类器设计中的一个核心任务。通常采用增大数据量和测试样本集的方法对分类器性能进行评价。

正则化技术

当模型的复杂度增大时,训练误差会逐渐减小并趋向于0;而测试误差会先减小,达到最小值后又增大。当选择的模型复杂度过大时,过拟合现象就会发生。这样,在学习时就要防止过拟合。进行最优模型的选择,即选择复杂度适当的模型,以达到使测试误差最小的学习目的。

生成对抗网络技术

生成对抗网络是一种无监督学习方法,是一种通过用对抗网络来训练生成模型的架构。它由两个网络组成:用来拟合数据分布的生成网络G,和用来判断输入是否“真实”的判别网络D。在训练过程中,生成网络-G通过接受一个随机的噪声来尽量模仿训练集中的真实图片去“欺骗”D,而D则尽可能的分辨真实数据和生成网络的输出,从而形成两个网络的博弈过程。理想的情况下,博弈的结果会得到一个可以“以假乱真”的生成模型。

模仿学习技术

模仿学习(Imitation Learning)背后的原理是是通过隐含地给学习器关于这个世界的先验信息,就能执行、学习人类行为。在模仿学习任务中,智能体(agent)为了学习到策略从而尽可能像人类专家那样执行一种行为,它会寻找一种最佳的方式来使用由该专家示范的训练集(输入-输出对)。

强化学习技术

强化学习是一种试错方法,其目标是让软件智能体在特定环境中能够采取回报最大化的行为。强化学习在马尔可夫决策过程环境中主要使用的技术是动态规划(Dynamic Programming)。流行的强化学习方法包括自适应动态规划(ADP)、时间差分(TD)学习、状态-动作-回报-状态-动作(SARSA)算法、Q 学习、深度强化学习(DQN);其应用包括下棋类游戏、机器人控制和工作调度等。

推荐文章
暂无评论
暂无评论~