Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

Yitong Li等作者一鸣 路编译

还在脑补画面?这款GAN能把故事画出来

当我们阅读的时候,我们的头脑可以想象书中发生的事情,似乎文本可以转换为脑海中栩栩如生的画面。这种能力似乎是人类的「专利」。现在,机器也可以做到这一点了。来自杜克大学和微软等机构的研究人员开发了一种新的GAN网络——StoryGAN,它可以根据文本生成对应的故事插图。

阅读小说是一件很有趣的事情,但是没有插图的故事往往索然无味。特别是儿童书籍,缺乏插图可能会让故事变得无聊。

如下是大段的儿童故事文本,即使内容很精彩,读者也不一定有兴趣阅读下去。如果配以图片,内容会有趣很多。

近来,一些研究人员就基于以上文本生成了对应的故事图片!

虽然不是很精致,但是已经有了故事的大概样貌。

具体来说,他们是根据如下故事内容生成图片的:

以下还有更多例子,左边五张图为生成结果,右边五张图则是实际的插图。

这种神奇的操作是怎么来的呢?我们来看这项有趣的研究《StoryGAN: A Sequential Conditional GAN for Story Visualization》。

论文链接:https://arxiv.org/pdf/1812.02784v2.pdf

这项研究的作者提出了一个新的任务类型——故事可视化,即基于给定句子生成一系列对应的图像,每张图像对应一个句子。和视频生成不同的是,故事图像化较少关注生成图像的连续性,而是更多地强调多个动态场景和角色之间的连贯性。此类问题目前无法被任何单个图像或视频生成方法解决。因此,论文作者提出了一个新的故事-图像序列生成模型——StoryGAN,它基于条件序列生成对抗网络。这一模型的特别之处在于,它有一个深度语境编码器可以动态跟踪故事流,以及两个判别器用于判别故事和图像,以便提升图像质量和生成序列的连贯性。为了评价模型的效果,研究人员修改了已有的数据集 CLEVR-SV 和 Pororo-SV。实验结果表明,StoryGAN 在图像质量、语境连贯性和人类评分上都超过了当前最佳模型。

图 1:输入的故事是「Pororo 和 Crong 在一起钓鱼。Crong 看了一下水桶。Pororo 的鱼竿上有一条鱼。」每个句子都需要生成一幅对应的图片。

基于故事文本生成图像,难点是什么?

让模型基于自然语言学习生成有意义且连贯的图像序列是一个有挑战的任务,它需要对自然语言和图像都能够理解和推理。

该任务主要面临两项挑战。第一,图像序列必须连贯且完整地描述整个故事。这项任务和文本-图像生成任务紧密相关,因为图像需要基于很短的描述语句生成。

第二项挑战是如何有逻辑地呈现故事线。具体来说,图像中目标的外观和背景布局必须根据故事情节推进以恰当的方式呈现。

StoryGAN 如何解决这个难题

下图展示了 StoryGAN 的模型架构:

图 2:StoryGAN 架构图示。灰色实心圈中的变量分别代表输入故事 S 和单个句子 s_1,...,s_T,以及随机噪声生成器网络包括故事编码器、语境编码器和图像生成器。顶部有两个判别器,分别判断每个图像-句子对和图像-序列-故事对是真实数据还是生成数据。

给定一个多句子段落(故事),StoryGAN 使用循环神经网络(RNN),将之前生成的图像信息加入到根据当前句生成图像的过程中。语境信息从语境编码器中获得,包括堆叠的 GRU 单元和新提出的 Text2Gist 单元。语境编码器将当前句子和故事向量转换为高维特征向量(Gist),用于之后的图像生成任务。

当故事推进时,Gist 动态更新,以反映故事流中的目标变化和场景变化。在 Text2Gist 组件中,句子被转换为一个隐藏向量,并与经过滤波器的输入做卷积,从而调整以适应整个故事,所以我们可以通过修改滤波器的方式优化混合过程。类似的方法还有动态滤波(dynamic filtering)、注意力模型和元学习(meta-learning)。

为了保证生成图像序列的连贯性,研究人员采用了一个双层 GAN 网络。他们使用了一个图像级别判别器来衡量句子和生成图像之间的相关性,以及一个故事判别器来衡量生成图像序列和整个故事的整体匹配度。

故事编码器

如图 2 粉色区域所示,故事编码器 E(·) 对故事 S 进行随机映射,得到低维嵌入向量 h_0。h_0 编码了整个故事,并作为隐藏层的初始状态输入到语境编码器中。

语境编码器

在序列图像生成任务中,角色、动作、背景等信息经常变化,每张图像可能都不相同。这里需要解决两个问题:

  • 如何在背景改变时有效地更新语境信息。

  • 如何在生成每张图像时将新的输入和随机噪声结合,从而可视化角色的变化(变化可能非常大)。

为了解决这两个问题,研究人员使用了一种基于深度 RNN 的语境编码器结构,用于在序列图像生成过程中捕捉语境信息。

这个深度循环神经网络包括两个隐藏层。底层使用标准 GRU 单元,而顶层使用论文提出的 Text2Gist 单元,它是 GRU 的一种变体。

在时间步 t,GRU 层接受句子 s_t 和等距高斯噪声ε_t 作为输入,并输出向量 i_t。Text2Gist 单元将 GRU 的输出 i_t 和故事语境向量 h_t 结合(h_t 来自故事编码器),生成 o_t。o_t 编码了在时间步 t 需要生成图像的所有必要信息。h_t 由 Text2Gist 更新,以反映潜在的语境信息变化。

如下为以上过程的公式:

o_t 是 Gist 向量,因为它分别结合了来自 h_t-1 的全局语境和 i_t 在时间步 t 的局部语境信息。故事编码器则初始化了 h_0,而 g_0 则是从等距高斯噪声分布中随机采样得到。根据以上信息,在时间步 t,Text2Gist 的内部运算过程如下:

在公式中,z_t 和 r_t 分别是更新门(update gate)和重置门(reset gate)的输出。更新门决定上一个时间步的信息要保留多少,而重置门决定从 h_t-1 中遗忘多少信息。σ_z 、σ_r 和 σ_h 是非线性 sigmoid 函数。和标准的 GRU 不同,输出 o_t 实际上是 Filter(i_t) 和 h_t 的卷积结果。

学习 i_t 学习的目的是适应 h_t。具体来说,Filter(·) 将向量 i_t 转换为多通道过滤器,其大小是 C_out × 1 × 1 × len(h_t ),C_out 表示输出通道数量。这一过程使用神经网络。由于 h_t 是向量,这个滤波器作为 1D 标准卷积层使用。

Text2Gist 中的这种卷积操作混合了来自 h_t 的全局语境信息,以及来自 i_t 的局部语境信息。由于 i_t 编码了 S 中 s_t 和 h_t 的信息,即它编码了整个故事的信息。而卷积操作可以被视为帮助 s_t 从生成过程中挑选重要的信息。实验结果表明,Text2Gist 比传统的循环神经网络在故事可视化上更加高效。

判别器

StoryGAN 使用两个判别器,分别对应图像和故事。这两个判别器分别确保局部和全局的故事可视化连贯性。

图 3:故事判别器的结构。图像和故事文本特征的内积作为输入馈送到全连接层,并使用 sigmoid 非线性函数预测是生成的还是真实的故事对。

算法

StoryGAN 的伪代码如算法 1 所示:

StoryGAN 的算法伪代码

在训练中,研究人员使用 Adam 优化器进行参数更新。他们发现,不同的 mini-batch 大小可以加快训练收敛的速度。在每轮训练中,在不同的时间步更新生成器和判别器也有很多好处。具体的网络和训练细节可以在附录 A 中找到。

实验

数据集

由于没有现有的数据集进行训练,研究人员根据现有的 CLEVR [19] 和 Pororo [21] 数据集进行了修改,制作了 CLEVR-SV 和 Pororo-SV 两个数据集。

1. CLEVR-SV 数据集

原版的 CLEVR 数据集用于视觉问答任务。研究人员使用如下方法将其改造为 CLEVR-SV:

  • 将一个故事中最多的目标数量限制在 4 个。

  • 目标为金属或橡胶制的物体,有八种颜色和两种尺寸。

  • 目标的形状可以是圆柱体、立方体或球体。

  • 目标每次增加一个,直到形成一个由四幅图像序列构成的故事。

研究人员生成了 10000 个图像序列用于训练,以及 3000 个图像序列用于测试。

2. Pororo-SV 数据集

Pororo 数据集原本用来进行视频问答,每个一秒的视频片段都有超过一个手写描述,40 个视频片段构成一个完整的故事。每个故事有一些问题和答案对。整个数据集有 16K 个时长一秒的视频片段,以及 13 个不同角色,而手写描述平均有 13.6 个词,包括发生了什么,以及视频中是哪个角色。这些视频片段总共组成了 408 个故事。

图 10:Pororo 数据集中出现的角色。

研究人员将 Pororo 数据集进行了改造。他们将每个视频片段的描述作为故事的文本输入。对于每个视频片段,随机提取一帧画面(采样率为 30Hz)作为真实的图像样本用于训练。五个连续的图像组成一个完整故事。最后,研究人员制作了 15,336 个描述-故事对,其中 13000 个用于训练,剩余的 2336 个用于测试。该数据集被称为 Pororo-SV。

研究人员对比了 StoryGAN 和其他模型的表现,并通过计算分类准确率、人类评分等方式评估结果。

CLEVR-SV 结果

图 5:不同模型在 CLEVR-SV 上生成结果的对比。

表 1:不同模型的生成结果和真实结果的结构相似性(SSIM)分数。

Pororo-SV 结果

图 6:两个故事中,不同模型的生成结果对比。

表 2:角色分类准确率上界为分类器在真实图像中的分类准确率

表 3:人类在评价生成图像时在不同指标上的打分。指标包括:图像质量、故事连贯性、相关性。±表示标准差。

表 4:基于排序的人类评估结果。±表示标准差。StoryGAN 获得了最高的平均排序,而 ImageGAN 表现最差。

理论GAN微软杜克大学
相关数据
视觉问答技术

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

伪代码技术

伪代码,又称为虚拟代码,是高层次描述算法的一种方法。它不是一种现实存在的编程语言;它可能综合使用多种编程语言的语法、保留字,甚至会用到自然语言。 它以编程语言的书写形式指明算法的职能。相比于程序语言它更类似自然语言。它是半形式化、不标准的语言。

元学习技术

元学习是机器学习的一个子领域,是将自动学习算法应用于机器学习实验的元数据上。现在的 AI 系统可以通过大量时间和经验从头学习一项复杂技能。但是,我们如果想使智能体掌握多种技能、适应多种环境,则不应该从头开始在每一个环境中训练每一项技能,而是需要智能体通过对以往经验的再利用来学习如何学习多项新任务,因此我们不应该独立地训练每一个新任务。这种学习如何学习的方法,又叫元学习(meta-learning),是通往可持续学习多项新任务的多面智能体的必经之路。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

图像生成技术

图像生成(合成)是从现有数据集生成新图像的任务。

生成模型技术

在概率统计理论中, 生成模型是指能够随机生成观测数据的模型,尤其是在给定某些隐含参数的条件下。 它给观测值和标注数据序列指定一个联合概率分布。 在机器学习中,生成模型可以用来直接对数据建模(例如根据某个变量的概率密度函数进行数据采样),也可以用来建立变量间的条件概率分布。

生成对抗网络技术

生成对抗网络是一种无监督学习方法,是一种通过用对抗网络来训练生成模型的架构。它由两个网络组成:用来拟合数据分布的生成网络G,和用来判断输入是否“真实”的判别网络D。在训练过程中,生成网络-G通过接受一个随机的噪声来尽量模仿训练集中的真实图片去“欺骗”D,而D则尽可能的分辨真实数据和生成网络的输出,从而形成两个网络的博弈过程。理想的情况下,博弈的结果会得到一个可以“以假乱真”的生成模型。

优化器技术

优化器基类提供了计算梯度loss的方法,并可以将梯度应用于变量。优化器里包含了实现了经典的优化算法,如梯度下降和Adagrad。 优化器是提供了一个可以使用各种优化算法的接口,可以让用户直接调用一些经典的优化算法,如梯度下降法等等。优化器(optimizers)类的基类。这个类定义了在训练模型的时候添加一个操作的API。用户基本上不会直接使用这个类,但是你会用到他的子类比如GradientDescentOptimizer, AdagradOptimizer, MomentumOptimizer(tensorflow下的优化器包)等等这些算法。

推荐文章
暂无评论
暂无评论~