谷歌大脑:像BigGAN那样生成高清大图不一定需要大量图像标签

原版的 GAN 是一种无监督学习,我们只要准备大量真实数据就行了。而如果要像 BigGAN 那样在 ImageNet 上生成高保真度的图像,我们还是需要大量类别信息。本研究介绍了如何在没有标注或有少量标注数据的情况下生成高保真图像,这大大缩小了条件GAN 与无监督 GAN 的差距。

正如 GoodFellow 所言,尽管 GAN 本身是无监督的,但高保真自然图像的生成(通常在 ImageNet 上训练)取决于能否访问大量标注数据。这并不奇怪,因为标签会在训练过程中引入丰富的辅助信息,从而有效地将极具挑战性的图像生成任务分成语义上有意义的子任务。

但是,实际上大部分数据是未标注的,而标注通常成本较高,还容易出错。虽然无监督图像生成近期取得了一些进展,但就样本质量而言,条件模型和无监督模型之间的差距还是很大的。

图 1:基线方法和本文提出方法的 FID 得分。垂直线表示使用了所有标注数据的基线(BigGAN)。本文提出的方法(S^3GAN)仅用 10% 的标注数据就能够媲美当前最佳水平的基线模型,用 20% 的标注数据就超过了基线。

本文使用生成对抗网络,大大缩小了条件模型和无监督模型在高保真图像生成方面的差距。本文利用了两个简单但强大的概念:

  • 监督学习:通过自监督来学习训练数据的语义特征提取器,然后用生成的特征表示来指导 GAN 训练过程。

  • 监督学习:通过标注训练图像的较小子集来推断出整个训练集的标签,然后将推断出来的标签用作 GAN 训练的条件信息。

本文贡献如下:

  • 提出并研究了多种方法,来减少或完全删去用于自然图像生成任务的真值标注信息。

  • 以无监督生成的方式在 ImageNet 上达到了新的 SOTA,仅用 10% 的标注数据就在 128x128 的图像上达到了当前 SOTA 结果,仅用 20% 的标注数据就实现了新的 SOTA(由 FID 衡量)。

  • 开源了实验中使用的所有代码:http://github.com/google/compare_gan

论文 :High-Fidelity Image Generation With Fewer Labels

论文地址:https://arxiv.org/abs/1903.02271

摘要:深度生成模型正在成为当代机器学习的基石。近期针对条件生成对抗网络的研究表明,自然图像的学习复杂度、高维度分布也成为了可以解决的问题。虽然最新的模型能够生成高分辨率、高保真的多种自然图像,但它们极度依赖大量标注数据。在本文中,我们展示了如何利用在自监督、半监督学习领域取得的进步,在无监督 ImageNet 合成和条件环境下实现超越 SOTA 模型的性能。特别是,我们提出的方法仅用 10% 的标注数据就能媲美当前条件模型 BigGAN 在 ImageNet 上的 SOTA 采样质量,仅用 20% 的标注数据就超越了它。

减少标注数据的需求

简而言之,我们并不会为判别器提供经手动标注的真实图像,而是提供推断的标注。为了获得这些标签,我们将利用自监督和半监督学习的最新进展。在解释这些方法前,我们首先探讨一下标签信息在 SOTA GAN 中发挥了什么作用。以下阐述会先假定我们比较熟悉 Goodfellow 等人提出的 GAN 框架。

为了向判别器提供标签信息,我们采用了 Miyato&Koyama(2018)提出的线性映射层。在原版的 GAN 中(unconditional),判别器 D 会学习预测输入图像 x 到底是真实的还是由生成器 G 生成的。我们可以将判别器分解为一个学习的判别器表征 D˜ 和判别函数 c_r/f,其中 D˜ 馈送到一个线性分类器中,也就是说判别器可以表示为 c_r/f(D˜(x))。

映射判别器中,它可以学习到判别器每一个类别的嵌入向量表征 D˜(x)。因此给定一张图像与标签(x, y),决定样本是真实值还是伪造值的元素有两个:(a) 表征 D˜(x) 本身是与真实数据一致的,(b) 同时表征 D˜(x) 也和类别 y 的真实数据相一致。正式而言,判别器采样自 D(x, y) = c_r/f(D˜(x)) + P(D˜(x), y),其中

为应用到特征向量 x tiled 和 one-hot 标注向量 y 的线性映射层。作为生成器,标注信息 y 会通过类别受限的 BatchNorm 进行合并。带有映射判别器的受限 GAN 在图三中有展示:

图 3:带有映射判别器的 Conditional GAN。

我们首先使用 SOTA 自监督方法学习真实数据的表征,在此表征上执行聚类,并使用不同的集群作为类别的替代。

图 4:CLUSTERING:通过无监督聚类处理求解自监督任务获得的表征。

结果和讨论

本文的主要目标是以无监督的方式或用较少的标注数据来达到(或超过)全监督 BigGAN 的性能。接下来,本文将据此分析本文方法的优势和缺点。

研究者对基线模型 BigGAN 重新实现,获得了 8.4 的 FID 分数和 75.0 的 IS 分数,复现了 Brock 等人(2019)的结果。研究者在训练的动态过程中发现了一些不同之处,将在 5.4 中讨论。

图 7:本研究提出的无监督方法获得的中值 FID 分数。垂直线表示 BIGGAN 实现的 FID 中值,该实现为所有训练图像使用标签。尽管无监督方法和全监督方法之间的差距仍然很大,但与单标签和随机标签相比,使用预训练自监督表征(聚类)提高了样本质量,从而在 IMAGENET 上实现了新的 SOTA 结果。

表 2:无监督方法获得的中值 FID 和 IS 分数(平均值和标准差见附录中的表 14)。

表 3:使用自监督和半监督损失(见 3.1)在的 IMAGENET 验证集上获得的 Top-1 和 top-5 误差率(%)。虽然与全监督 IMAGENET 分类任务相比,这些模型显然不是当前最优,但标签的质量已堪匹敌、在某些情况下甚至还可以改进当前最优的 GAN 自然图像合成结果。

表 4:预训练 vs 联合训练方法以及自监督方法在 GAN 训练期间的效果。尽管联合训练方法优于全监督方法,但预训练方法更胜一筹。在任何情况下,自监督在 GAN 训练过程中都很有用。

表 5:使用硬(预测)标签训练得到的模型要比软(预测)标签训练模型更好(均值和标准差参见附录中的表 13)。

图 8:垂直线表示使用所有标注数据实现的 BIGGAN 的 FID 中值。本研究提出的 S^3 GAN 方法使用 10% 的 ground-truth 标注数据的表现就能与 SOTA BIGGAN 模型相媲美,使用 20% 的标注数据后性能就能超越 SOTA 模型。

理论Ian GoodfellowGoogle Brain
3
相关数据
线性分类器技术

机器学习通过使用对象的特征来识别它所属的类(或组)来进行统计分类。线性分类器通过基于特征的线性组合的值进行分类决策。 对象的特征也称为特征值,通常在称为特征向量的向量中呈现给机器。

半监督学习技术

半监督学习属于无监督学习(没有任何标记的训练数据)和监督学习(完全标记的训练数据)之间。许多机器学习研究人员发现,将未标记数据与少量标记数据结合使用可以显着提高学习准确性。对于学习问题的标记数据的获取通常需要熟练的人类代理(例如转录音频片段)或物理实验(例如,确定蛋白质的3D结构或确定在特定位置处是否存在油)。因此与标签处理相关的成本可能使得完全标注的训练集不可行,而获取未标记的数据相对便宜。在这种情况下,半监督学习可能具有很大的实用价值。半监督学习对机器学习也是理论上的兴趣,也是人类学习的典范。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

验证集技术

验证数据集是用于调整分类器超参数(即模型结构)的一组数据集,它有时也被称为开发集(dev set)。

深度生成模型技术

深度生成模型基本都是以某种方式寻找并表达(多变量)数据的概率分布。有基于无向图模型(马尔可夫模型)的联合概率分布模型,另外就是基于有向图模型(贝叶斯模型)的条件概率分布。前者的模型是构建隐含层(latent)和显示层(visible)的联合概率,然后去采样。基于有向图的则是寻找latent和visible之间的条件概率分布,也就是给定一个随机采样的隐含层,模型可以生成数据。 生成模型的训练是一个非监督过程,输入只需要无标签的数据。除了可以生成数据,还可以用于半监督的学习。比如,先利用大量无标签数据训练好模型,然后利用模型去提取数据特征(即从数据层到隐含层的编码过程),之后用数据特征结合标签去训练最终的网络模型。另一种方法是利用生成模型网络中的参数去初始化监督训练中的网络模型,当然,两个模型需要结构一致。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

监督学习技术

监督式学习(Supervised learning),是机器学习中的一个方法,可以由标记好的训练集中学到或建立一个模式(函数 / learning model),并依此模式推测新的实例。训练集是由一系列的训练范例组成,每个训练范例则由输入对象(通常是向量)和预期输出所组成。函数的输出可以是一个连续的值(称为回归分析),或是预测一个分类标签(称作分类)。

图像生成技术

图像生成(合成)是从现有数据集生成新图像的任务。

生成对抗网络技术

生成对抗网络是一种无监督学习方法,是一种通过用对抗网络来训练生成模型的架构。它由两个网络组成:用来拟合数据分布的生成网络G,和用来判断输入是否“真实”的判别网络D。在训练过程中,生成网络-G通过接受一个随机的噪声来尽量模仿训练集中的真实图片去“欺骗”D,而D则尽可能的分辨真实数据和生成网络的输出,从而形成两个网络的博弈过程。理想的情况下,博弈的结果会得到一个可以“以假乱真”的生成模型。

聚类技术

将物理或抽象对象的集合分成由类似的对象组成的多个类的过程被称为聚类。由聚类所生成的簇是一组数据对象的集合,这些对象与同一个簇中的对象彼此相似,与其他簇中的对象相异。“物以类聚,人以群分”,在自然科学和社会科学中,存在着大量的分类问题。聚类分析又称群分析,它是研究(样品或指标)分类问题的一种统计分析方法。聚类分析起源于分类学,但是聚类不等于分类。聚类与分类的不同在于,聚类所要求划分的类是未知的。聚类分析内容非常丰富,有系统聚类法、有序样品聚类法、动态聚类法、模糊聚类法、图论聚类法、聚类预报法等。

推荐文章
暂无评论
暂无评论~