谷歌论文新突破:通过辅助损失提升RNN学习长期依赖关系的能力

本文提出了一种简单的方法,通过在原始函数中加入辅助损失改善 RNN 捕捉长期依赖关系的能力,并在各种设置下评估了该方法,包括用长达 16,000 的序列对一张图的逐个像素进行分类,以及对一个真实的基准文件进行分类;和其他常用模型和大小相当的转换器相比,该方法在性能和资源使用效率方面的表现都非常突出。

介绍

大量人工智能应用的前提是首先理解序列中事件间的长期依赖关系。例如,在自然语言处理中,有时就必须要对书中描述的远距离事件之间的关系有所了解,这样才能回答问题。一般而言,现在是通过梯度下降和带有循环网络的 BPTT(Rumelhart et al., 1986)解决这一问题的。然而,通过梯度下降方法学习长期依赖性很难,因为借助 BPTT 计算的梯度在训练过程中有消失或爆炸的倾向。除此以外,如果想要使 BPTT 起作用,人们需要存储中间过程的隐藏状态。内存需求与序列长度成正比,使得这种方法难以处理大问题。


图 1:本文方法概述。辅助损失改善了循环网络的内存问题,主任务的 BPTT 需要的步骤少了一些。

也有人提出过若干个有望解决这些问题的方法。首先,可以使用 LSTM(Hochreiter & Schmidhuber, 1997)代替常用的循环神经网络,这可以改善循环网络中的梯度流的问题。此外,还可以使用梯度裁减(Pascanu et al., 2013)提高 LSTM 训练过程的稳定性。最后,为了减少内存方面的需求,可以使用截断的 BPTT 或合成梯度(Jaderberg et al., 2017)定期存储隐藏层的状态(Gruslys et al., 2016; Chen et al., 2016)。

卷积神经网络也可以消除长期的依赖关系问题,因为内核较大,而且像 ResNets(He et al., 2016)这样的深度网络允许跨越图像中相距较远的两个部分学习长期依赖关系。但这样就会用到完全不同的架构,我们可以对此进行权衡。例如,在训练过程中,模型的输入(一张图像或者一个序列)以及中间的激活都要存储在内存中。在推断时,典型的 CNN 需 O(n) 的存储空间(n 代表输入的大小)。尽管由于训练和推断的计算需要随机存取到内存 O(n),但变换器(Vaswani et al., 2017)也有相似的问题,并且严重一些。

图 2:本文方法的草图。对每个随机定位点,也就是 F 点而言,我们在这个位置上建立辅助损失。左图:我们预测了 F 点前的一段随机序列 BCD。将 B 点插入解码器网络以开始重建,而 C 点和 D 点可以选择是否馈送。右图:我们通过在主窗口堆叠辅助 RNN 对子序列 GHI 进行预测。在这两种情况中,辅助损失的梯度都被截断,通过这种方式来保证 BPTT 总体消耗维持不变。

RNN 的优势在于,假设 BPTT 的长度为 l,训练就需要 O(l) 的内存。这是一个用 PTB 数据集(Marcus et al., 1994)训练语言模型的典型实例,这样做 100 万个符号序列的状态就永远不会重置。因此,从理论上讲 RNN 可以从极远的距离学到这种关系。此外,RNN 的推断也需要 O(l) 的内存,因为 RNN 不需要「回头」。

在这篇论文中,我们提出一种正交技术以进一步解决循环网络单纯依赖 BPTT 的缺陷。该方法介绍了一种无监督辅助损失,可以重建/预测锚点前后的一部分随机序列。实现这个方法,只需要几步有监督损失的 BPTT。

论文结果表明无监督辅助损失显著改善了 LSTM 的优化和泛化能力。此外,如果使用这一方法,无需在训练过程中执行冗长的 BPTT 以获得良好的结果。因此,该方法适用于长序列,在此之前,这些长序列中出现的梯度消失/爆炸问题以及冗长的 BPTT 消耗问题都是模型训练中的重要瓶颈。

实验采用的序列长达 16,000 个元素,带有辅助损失的 LSTM 训练得更快并使用了更少的内存,而采用完整的反向传播训练 LSTM 则非常困难。

方法

假设目标是使用循环网络阅读序列并分类。我们随机采样一个或多个锚点,并在每个锚点插入无监督辅助损失。

3.1. 重建辅助损失

在重建过去事件时,我们取样了锚点之前的子序列,并将第一段标记序列插入解码器网络;然后我们要求解码器网络预测出剩下的子序列。整个过程如图 2 左图所示。

我们推断,如果拟预测序列离定位点足够近,解码重建过去事件所需的 BPTT 的步骤就会非常少。另外,随着训练的进一步加强,定位点会在循环网络中充当临时存储的角色来记录序列中过去的事件。如果我们选择了足够多的定位点,就会在整段序列上建立足够多的存储,当我们到序列末端时,分类器会记住序列从而更好地进行分类。因此,分类器仅需几步反向传播步骤对 LSTM 的权重进行微调,因为网络已经通过优化的辅助损失很好地对输入序列的嵌入进行了学习。

3.2. 预测辅助损失

本文考虑的另一种辅助损失类似于语言模型损失,如图 2 右图所示。这种情况要求解码器网络在子序列中从锚点出发预测出所给序列的下一段标记序列。这类无监督辅助损失第一次是 Dai & Le (2015) 提出的,他们将其应用于整个输入序列。但我们将其应用在长期依赖关系学习的扩展方案中,因此我们仅将这种损失应用在随机锚点之后的子序列中。

3.3. 训练

我们将前一种方法称为 r-LSTM , 将后一种方法称为 p-LSTM(r 和 p 分别代表重建和预测),在两个阶段对这两个模型进行训练。第一阶段是单纯的无监督预训练,在这一方法中辅助损失取最小值。而在第二阶段中,执行的是半监督学习,在这一阶段中我们取主要目标损失 L_supervised 和 L_auxiliary 的总和最小值。用定期采样(Bengio et al., 2015a)的方法训练执行重建操作的辅助 LSTM。

表 2:在 MNIST、pMNIST 和 CIFAR10 上测试的准确率


图 3:上图: StanfordDogs 的 8 个级别序列长度测试的准确度。下图:运行具有 128 个训练实例的单个小批次的时间,以秒为测量单位。


图 5 :辅助损失对训练和测试准确率的影响

论文:Learning Longer-term Dependencies in RNNs with Auxiliary Losses

论文链接:https://arxiv.org/abs/1803.00144

尽管训练循环神经网络(RNNs)最近仍有进展,但在长序列中捕捉长期依赖关系仍旧是根本的挑战。现在一般会用通过时间的反向传播(BPTT)解决这一问题,但这很难应用于极长的序列。本文提出了一种简单的方法,可以通过在原始函数中加入辅助损失改善 RNN 捕捉长期依赖关系的能力。辅助损失强制 RNN 在序列中重建之前的事件或是预测接下来的事件,这样的操作可以截断长序列中的反馈,还可以提高 BPTT 整体的能力。我们在各种设置下评估了所述方法,包括用长达 16,000 的序列对一张图的逐个像素进行分类,以及对一个真实的基准文件进行分类。和其他常用模型和大小相当的转换器相比,我们的方法在性能和资源使用效率方面的表现都非常突出。我们进一步分析了辅助损失在优化和正则化方面的积极影响,和没有反向传播相比,几乎不会出现极端情况。

理论谷歌循环神经网络损失函数谷歌大脑
2
返回顶部