Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

带自注意力机制的生成对抗网络,实现效果怎样?

在前一段时间,Han Zhang 和 Goodfellow 等研究者提出添加了自注意力机制生成对抗网络,这种网络可使用全局特征线索来生成高分辨率细节。本文介绍了自注意力生成对抗网络的 PyTorch 实现,读者也可以尝试这一新型生成对抗网络

项目地址:https://github.com/heykeetae/Self-Attention-GAN

这个资源库提供了一个使用 PyTorch 实现的 SAGAN。其中作者准备了 wgan-gp 和 wgan-hinge 损失函数,但注意 wgan-gp 有时与谱归一化(spectral normalization)是不匹配的;因此,作者会移除模型所有的谱归一化来适应 wgan-gp。

在这个实现中,自注意机制会应用到生成器和鉴别器的两个网络层。像素级的自注意力会增加 GPU 资源的调度成本,且每个像素有不同的注意力掩码。Titan X GPU 大概可选择的批量大小为 8,你可能需要减少自注意力模块的数量来减少内存消耗。

目前更新状态:

  • 注意力可视化 (LSUN Church-outdoor)

  • 无监督设置(现未使用标签)

  • 已应用:Spectral Normalization(代码来自 https://github.com/christiancosgrove/pytorch-spectral-normalization-gan)

  • 已实现:自注意力模块(self-attention module)、两时间尺度更新规则(TTUR)、wgan-hinge 损失函数和 wgan-gp 损失函数

结果

下图展示了 LSUN 中的注意力结果 (epoch #8):

SAGAN 在 LSUN church-outdoor 数据集上的逐像素注意力结果。这表示自注意力模块的无监督训练依然有效,即使注意力图本身并不具有可解释性。更好的图片生成结果以后会添加,上面这些是在生成器第层 3 和层 4 中的自注意力的可视化,它们的尺寸依次是 16 x 16 和 32 x 32,每一张都包含 64 张注意力图的可视化。要可视化逐像素注意力机制,我们只能如左右两边的数字显示选择一部分像素。

CelebA 数据集 (epoch on the left, 还在训练中):

LSUN church-outdoor 数据集 (epoch on the left, 还在训练中):

训练环境:

  • Python 3.5+ (https://www.continuum.io/downloads)

  • PyTorch 0.3.0 (http://pytorch.org/)

用法

1. 克隆版本库

$ git clone https://github.com/heykeetae/Self-Attention-GAN.git

$ cd Self-Attention-GAN

2. 下载数据集 (CelebA 或 LSUN)

$ bash download.sh CelebA
or
$ bash download.sh LSUN

3. 训练

$ python python main.py --batch_size 6 --imsize 64 --dataset celeb --adv_loss hinge --version sagan_celeb
or
$ python python main.py --batch_size 6 --imsize 64 --dataset lsun --adv_loss hinge --version sagan_lsun

4. 享受结果吧~

$ cd samples/sagan_celeb
or
$ cd samples/sagan_lsun

每 100 次迭代生成一次样本,抽样率可根据参数 --sample_step (ex,—sample_step 100) 控制。

论文:Self-Attention Generative Adversarial Networks



论文地址:https://arxiv.org/abs/1805.08318


在此论文中,我们提出了自注意生成式对抗网络(SAGAN),能够为图像生成任务实现注意力驱动的、长范围的依存关系建模。传统的卷积 GAN 只根据低分辨特征图中的空间局部点生成高分辨率细节(detail)。在 SAGAN 中,可使用所有特征点的线索来生成高分辨率细节,而且鉴别器能检查图片相距较远部分的细微细节特征是否彼此一致。不仅如此,近期研究表明鉴别器调节可影响 GAN 的表现。根据这个观点,我们在 GAN 生成器中加入了谱归一化(spectral normalization),并发现这样可以提高训练动力学。我们所提出的 SAGAN 达到了当前最优水平,在极具挑战性的 ImageNet 数据集中将最好的 inception 分数记录从 36.8 提高到 52.52,并将 Frechet Inception 距离从 27.62 减少到 18.65。注意力层的可视化展现了生成器可利用其附近环境对物体形状做出反应,而不是直接使用固定形状的局部区域。

工程GitHub生成对抗网络注意力机制
3
相关数据
调度技术

调度在计算机中是分配工作所需资源的方法。资源可以指虚拟的计算资源,如线程、进程或数据流;也可以指硬件资源,如处理器、网络连接或扩展卡。 进行调度工作的程序叫做调度器。调度器通常的实现使得所有计算资源都处于忙碌状态,允许多位用户有效地同时共享系统资源,或达到指定的服务质量。 see planning for more details

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

注意力机制技术

我们可以粗略地把神经注意机制类比成一个可以专注于输入内容的某一子集(或特征)的神经网络. 注意力机制最早是由 DeepMind 为图像分类提出的,这让「神经网络在执行预测任务时可以更多关注输入中的相关部分,更少关注不相关的部分」。当解码器生成一个用于构成目标句子的词时,源句子中仅有少部分是相关的;因此,可以应用一个基于内容的注意力机制来根据源句子动态地生成一个(加权的)语境向量(context vector), 然后网络会根据这个语境向量而不是某个固定长度的向量来预测词。

规范化技术

规范化:将属性数据按比例缩放,使之落入一个小的特定区间,如-1.0 到1.0 或0.0 到1.0。 通过将属性数据按比例缩放,使之落入一个小的特定区间,如0.0到1.0,对属性规范化。对于距离度量分类算法,如涉及神经网络或诸如最临近分类和聚类的分类算法,规范化特别有用。如果使用神经网络后向传播算法进行分类挖掘,对于训练样本属性输入值规范化将有助于加快学习阶段的速度。对于基于距离的方法,规范化可以帮助防止具有较大初始值域的属性与具有较小初始值域的属相相比,权重过大。有许多数据规范化的方法,包括最小-最大规范化、z-score规范化和按小数定标规范化。

生成对抗网络技术

生成对抗网络是一种无监督学习方法,是一种通过用对抗网络来训练生成模型的架构。它由两个网络组成:用来拟合数据分布的生成网络G,和用来判断输入是否“真实”的判别网络D。在训练过程中,生成网络-G通过接受一个随机的噪声来尽量模仿训练集中的真实图片去“欺骗”D,而D则尽可能的分辨真实数据和生成网络的输出,从而形成两个网络的博弈过程。理想的情况下,博弈的结果会得到一个可以“以假乱真”的生成模型。

推荐文章
暂无评论
暂无评论~