Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

张皓作者

三次简化一张图:一招理解LSTM/GRU门控机制

RNN 在处理时序数据时十分成功。但是,对 RNN 及其变种 LSTM 和 GRU 结构的理解仍然是一个困难的任务。本文介绍一种理解 LSTM 和 GRU 的简单通用的方法。通过对 LSTMGRU 数学形式化的三次简化,最后将数据流形式画成一张图,可以简洁直观地对其中的原理进行理解与分析。此外,本文介绍的三次简化一张图的分析方法具有普适性,可广泛用于其他门控网络的分析。

1. RNN、梯度爆炸与梯度消失

1.1 RNN

近些年,深度学习模型在处理有非常复杂内部结构的数据时十分有效。例如,图像数据的像素之间的 2 维空间关系非常重要,CNN(convolution neural networks,卷积神经网络)处理这种空间关系十分有效。而时序数据(sequential data)的变长输入序列之间时序关系非常重要,RNN(recurrent neural networks,循环神经网络,注意和 recursive neural networks,递归神经网络的区别)处理这种时序关系十分有效。

我们使用下标 t 表示输入时序序列的不同位置,用 h_t 表示在时刻 t 的系统隐层状态向量,用 x_t 表示时刻 t 的输入。t 时刻的隐层状态向量 h_t 依赖于当前词 x_t 和前一时刻的隐层状态向量 h_(t-1):

其中 f 是一个非线性映射函数。一种通常的做法是计算 x_t 和 h_(t-1) 的线性变换后经过一个非线性激活函数,例如

其中 W_(xh) 和 W_(hh) 是可学习的参数矩阵,激活函数 tanh 独立地应用到其输入的每个元素。

为了对 RNN 的计算过程做一个可视化,我们可以画出下图:

图中左边是输入 x_t 和 h_(t-1)、右边是输出 h_t。计算从左向右进行,整个运算包括三步:输入 x_t 和 h_(t-1) 分别乘以 W_(xh) 和 W_(hh) 、相加、经过 tanh 非线性变换。

我们可以认为 h_t 储存了网络中的记忆(memory),RNN 学习的目标是使得 h_t 记录了在 t 时刻之前(含)的输入信息 x_1, x_2,..., x_t。在新词 x_t 输入到网络之后,之前的隐状态向量 h_(t-1) 就转换为和当前输入 x_t 有关的 h_t。

1.2 梯度爆炸与梯度消失

虽然理论上 RNN 可以捕获长距离依赖,但实际应用中,RNN 将会面临两个挑战:梯度爆炸(gradient explosion)和梯度消失(vanishing gradient)。

我们考虑一种简单情况,即激活函数是恒等(identity)变换,此时

在进行误差反向传播(error backpropagation)时,当我们已知损失函数对 t 时刻隐状态向量 h_t 的偏导数时,利用链式法则,我们计算损失函数对 t 时刻隐状态向量 h_0 的偏导数

我们可以利用 RNN 的依赖关系,沿时间维度展开,来计算

也就是说,在误差反向传播时我们需要反复乘以参数矩阵 W_(hh)。我们对矩阵  W_(hh) 进行奇异值分解(SVD)

其中 r 是矩阵 W_(hh) 的秩(rank)。因此,

那么我们最后要计算的目标

当 t 很大时,该偏导数取决于矩阵 W_(hh) 的最大的奇异值是大于 1 还是小于 1,要么结果太大,要么结果太小:


(1). 梯度爆炸。当 > 1,,那么

此时偏导数将会变得非常大,实际在训练时将会遇到 NaN 错误,会影响训练的收敛,甚至导致网络不收敛。这好比要把本国的产品卖到别的国家,结果被加了层层关税,等到了别国市场的时候,价格已经变得非常高,老百姓根本买不起。在 RNN 中,梯度(偏导数)就是价格,随着向前推移,梯度越來越大。这种现象称为梯度爆炸。

梯度爆炸相对比较好处理,可以用梯度裁剪(gradient clipping)来解决:

这好比是不管前面的关税怎么加,设置一个最高市场价格,通过这个最高市场价格保证老百姓是买的起的。在 RNN 中,不管梯度回传的时候大到什么程度,设置一个梯度的阈值,梯度最多是这么大。

(2). 梯度消失。当 < 1,,那么

此时偏导数将会变得十分接近 0,从而在梯度更新前后没有什么区别,这会使得网络捕获长距离依赖(long-term dependency)的能力下降。这好比打仗的时候往前线送粮食,送粮食的队伍自己也得吃粮食。当补给点离前线太远时,还没等送到,粮食在半路上就已经被吃完了。在 RNN 中,梯度(偏导数)就是粮食,随着向前推移,梯度逐渐被消耗殆尽。这种现象称为梯度消失。

梯度消失现象解决起来困难很多,如何缓解梯度消失是 RNN 及几乎其他所有深度学习方法研究的关键所在。LSTM 和 GRU 通过门(gate)机制控制 RNN 中的信息流动,用来缓解梯度消失问题。其核心思想是有选择性的处理输入。比如我们在看到一个商品的评论时

Amazing! This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!

我们会重点关注其中的一些词,对它们进行处理

Amazing! This box of cereal gave me a perfectly balanced breakfast, as all things should be. In only ate half of it but will definitely be buying again!

LSTM 和 GRU 的关键是会选择性地忽略其中一些词,不让其参与到隐层状态向量的

更新中,最后只保留相关的信息进行预测。

2. LSTM

2.1 LSTM 的数学形式

LSTM(Long Short-Term Memory)由 Hochreiter 和 Schmidhuber 提出,其数学上的形式化表示如下:

其中代表逐元素相乘,sigm 代表 sigmoid 函数

和 RNN 相比,LSTM 多了一个隐状态变量 c_t,称为细胞状态(cell state),用来记录信息。

这个公式看起来似乎十分复杂,为了更好的理解 LSTM 的机制,许多人用图来描述 LSTM 的计算过程。比如下面这张图:

似乎看完之后,对 LSTM 的理解仍然是一头雾水?这是因为这些图想把 LSTM 的所有细节一次性都展示出来,但是突然暴露这么多的细节会使你眼花缭乱,从而无处下手。

2.2 三次简化一张图

因此,本文提出的方法旨在简化门控机制中不重要的部分,从而更关注在 LSTM 的核心思想。整个过程是三次简化一张图,具体流程如下:

(1). 第一次简化:忽略门控单元 i_t 、f_t 、o_t 的来源。3 个门控单元的计算方法完全相同,都是由输入经过线性映射得到的,区别只是计算的参数不同:

使用相同计算方式的目的是它们都扮演了门控的角色,而使用不同参数的目的是为了误差反向传播时对三个门控单元独立地进行更新。在理解 LSTM 运行机制的时候,为了对图进行简化,我们不在图中标注三个门控单元的计算过程,并假定各门控单元是给定的。

(2). 第二次简化:考虑一维门控单元 i_t 、 f_t 、 o_t。LSTM 中对各维是独立进行门控的,所以为了表示和理解方便,我们只需要考虑一维情况,在理解 LSTM 原理之后,将一维推广到多维是很直接的。经过这两次简化,LSTM 的数学形式只有下面三行

由于门控单元变成了一维,所以向量和向量的逐元素相乘符号变成了数和向量相乘 · 。

(3). 第三次简化:各门控单元二值输出。门控单元 i_t 、f_t 、o_t 的由于经过了 sigmoid 激活函数,输出是范围是 [0, 1]。激活函数使用 sigmoid 的目的是为了近似 0/1 阶跃函数,这样 sigmoid 实数值输出单调可微,可以基于误差反向传播进行更新。

既然 sigmoid 激活函数是为了近似 0/1 阶跃函数,那么,在进行 LSTM 理解分析的时候,为了理解方便,我们认为各门控单元 {0, 1} 二值输出,即门控单元扮演了电路中开关的角色,用于控制信息的通断。

(4). 一张图。将三次简化的结果用电路图表述出来,左边是输入,右边是输出。在 LSTM 中,有一点需要特别注意,LSTM 中的细胞状态 c_t 实质上起到了 RNN 中隐层单元 h_t 的作用,这点在其他文献资料中不常被提到,所以整个图的输入是 x_t 和  c_{t-1},而不是 x_t 和 h_(t-1)。为了方便画图,我们需要将公式做最后的调整

最终结果如下:

和 RNN 相同的是,网络接受两个输入,得到一个输出。其中使用了两个参数矩阵  W_(xc) 和 W_(hc),以及 tanh 激活函数。不同之处在于,LSTM 中通过 3 个门控单元 i_t 、f_t 、o_t 来对的信息交互进行控制。当 i_t=1(开关闭合)、f_t=0(开关打开)、o_t=1(开关闭合)时,LSTM 退化为标准的 RNN。

2.3 LSTM 各单元作用分析

根据这张图,我们可以对 LSTM 中各单元作用进行分析:

  • 输出门 o_(t-1):输出门的目的是从细胞状态 c_(t-1) 产生隐层单元 h_(t-1)。并不是 c_(t-1) 中的全部信息都和隐层单元 h_(t-1) 有关,c_(t-1) 可能包含了很多对 h_(t-1) 无用的信息。因此,o_t 的作用就是判断 c_(t-1) 中哪些部分是对 h_(t-1) 有用的,哪些部分是无用的。

  • 输入门 i_t。i_t 控制当前词 x_t 的信息融入细胞状态 c_t。在理解一句话时,当前词 x_t 可能对整句话的意思很重要,也可能并不重要。输入门的目的就是判断当前词 x_t 对全局的重要性。当 i_t 开关打开的时候,网络将不考虑当前输入  x_t。

  • 遗忘门 f_t: f_t 控制上一时刻细胞状态 c_(t-1) 的信息融入细胞状态 c_t。在理解一句话时,当前词 x_t 可能继续延续上文的意思继续描述,也可能从当前词 x_t 开始描述新的内容,与上文无关。和输入门 i_t 相反,f_t 不对当前词 x_t 的重要性作判断,而判断的是上一时刻的细胞状态c_(t-1)对计算当前细胞状态 c_t 的重要性。当 f_t 开关打开的时候,网络将不考虑上一时刻的细胞状态 c_(t-1)。

  • 细胞状态 c_t :c_t 综合了当前词 x_t 和前一时刻细胞状态 c_(t-1) 的信息。这和 ResNet 中的残差逼近思想十分相似,通过从 c_(t-1) 到 c_t 的「短路连接」,梯度得已有效地反向传播。当 f_t 处于闭合状态时,c_t 的梯度可以直接沿着最下面这条短路线传递到c_(t-1),不受参数 W_(xh) 和 W_(hh) 的影响,这是 LSTM 能有效地缓解梯度消失现象的关键所在。

3. GRU

3.1 GRU 的数学形式

GRU 是另一种十分主流的 RNN 衍生物。RNN 和 LSTM 都是在设计网络结构用于缓解梯度消失问题,只不过是网络结构有所不同。GRU 在数学上的形式化表示如下:

3.2 三次简化一张图

为了理解 GRU 的设计思想,我们再一次运用三次简化一张图的方法来进行分析:

(1). 第一次简化:忽略门控单元 z_t 和 r_t 的来源。

(2). 考虑一维门控单元 z_t 和 r_t。经过这两次简化,GRU 的数学形式是以下两行

(3). 第三次简化:各门控单元二值输出。这里和 LSTM 略有不同的地方在于,当 z_t=1 时h_t = h_(t-1) ;而当 z_t = 0 时,h_t =。因此,z_t 扮演的角色是一个个单刀双掷开关。

(4). 一张图。将三次简化的结果用电路图表述出来,左边是输入,右边是输出。

与 LSTM 相比,GRU 将输入门 i_t 和遗忘门 f_t 融合成单一的更新门 z_t,并且融合了细胞状态 c_t 和隐层单元 h_t。当 r_t=1(开关闭合)、 z_t=0(开关连通上面)GRU 退化为标准的 RNN。

3.3 GRU 各单元作用分析

根据这张图, 我们可以对 GRU 的各单元作用进行分析:

  • 重置门 r_t : r_t 用于控制前一时刻隐层单元 h_(t-1) 对当前词 x_t 的影响。如果 h_(t-1) 对 x_t 不重要,即从当前词 x_t 开始表述了新的意思,与上文无关。那么开关 r_t 可以打开,使得 h_(t-1) 对 x_t 不产生影响。

  • 更新门 z_t : z_t 用于决定是否忽略当前词 x_t。类似于 LSTM 中的输入门 i_t,z_t 可以判断当前词 x_t 对整体意思的表达是否重要。当 z_t 开关接通下面的支路时,我们将忽略当前词 x_t,同时构成了从 h_(t-1) 到 h_t 的短路连接,这使得梯度得已有效地反向传播。和 LSTM 相同,这种短路机制有效地缓解了梯度消失现象,这个机制于 highway networks 十分相似。

4. 小结

尽管 RNN、LSTM、和 GRU 的网络结构差别很大,但是他们的基本计算单元是一致的,都是对 x_t 和 h_t 做一个线性映射加 tanh 激活函数,见三个图的红色框部分。他们的区别在于如何设计额外的门控机制控制梯度信息传播用以缓解梯度消失现象。LSTM 用了 3 个门、GRU 用了 2 个,那能不能再少呢?MGU(minimal gate unit)尝试对这个问题做出回答,它只有一个门控单元。最后留个小练习,参考 LSTM 和 GRU 的例子,你能不能用三次简化一张图的方法来分析一下 MGU 呢?

参考文献

  1. Yoshua Bengio, Patrice Y. Simard, and Paolo Frasconi. Learning long-term dependencies with gradient descent is difficult. IEEE Transactions on Neural Networks 5(2): 157-166, 1994.

  2. Kyunghyun Cho, Bart van Merrienboer, Çaglar Gülçehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using RNN encoder-decoder for statistical machine translation. In EMNLP, pages 1724-1734, 2014.

  3. Junyoung Chung, Çaglar Gülçehre, KyungHyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. In NIPS Workshop, pages 1-9, 2014.

  4. Felix Gers. Long short-term memory in recurrent neural networks. PhD Dissertation, Ecole Polytechnique Fédérale de Lausanne, 2001.

  5. Ian J. Goodfellow, Yoshua Bengio, and Aaron C. Courville. Deep learning. Adaptive Computation and Machine Learning, MIT Press, ISBN 978-0-262-03561-3, 2016.

  6. Alex Graves. Supervised sequence labelling with recurrent neural networks. Studies in Computational Intelligence 385, Springer, ISBN 978-3-642-24796-5, 2012.

  7. Klaus Greff, Rupesh Kumar Srivastava, Jan Koutník, Bas R. Steunebrink, and Jürgen Schmidhuber. LSTM: A search space odyssey. IEEE Transactions on Neural Networks and Learning Systems. 28(10): 2222-2232, 2017.

  8. Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770-778, 2016.

  9. Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings in deep residual networks. In ECCV, pages 630-645, 2016.

  10. Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural Computation 9(8): 1735-1780, 1997.

  11. Rafal Józefowicz, Wojciech Zaremba, and Ilya Sutskever. An empirical exploration of recurrent network architectures. In ICML, pages 2342-2350, 2015.

  12. Zachary Chase Lipton. A critical review of recurrent neural networks for sequence learning. CoRR abs/1506.00019, 2015.

  13. Razvan Pascanu, Tomas Mikolov, and Yoshua Bengio. On the difficulty of training recurrent neural networks. In ICML, pages 1310-1318, 2013.

  14. Rupesh Kumar Srivastava, Klaus Greff, and Jürgen Schmidhuber. Highway networks. In ICML Workshop, pages 1-6, 2015.

  15. Guo-Bing Zhou, Jianxin Wu, Chen-Lin Zhang, and Zhi-Hua Zhou. Minimal gated unit for recurrent neural networks. International Journal of Automation and Computing, 13(3): 226-234, 2016.

理论LSTMGRURNN梯度消失梯度爆炸
4
相关数据
Fethi Bougares人物

法国缅因大学副教授。研究领域:机器学习、深度学习、神经机器翻译、自动语音识别 (ASR)、统计机器翻译 (SMT)、阿拉伯语形态分析。

Sepp Hochreiter人物

Sepp Hochreiter 是一名德国计算机科学家。 1991 年,Sepp Hochreiter 发表了德语论文,探讨了循环神经网络的梯度随着序列长度增加倾向于消失或爆炸。与 Yoshua Bengio 的相关工作几乎同时,并且开发了 LSTM 的雏形。

相关技术
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

激活函数技术

在 计算网络中, 一个节点的激活函数定义了该节点在给定的输入或输入的集合下的输出。标准的计算机芯片电路可以看作是根据输入得到"开"(1)或"关"(0)输出的数字网络激活函数。这与神经网络中的线性感知机的行为类似。 一种函数(例如 ReLU 或 S 型函数),用于对上一层的所有输入求加权和,然后生成一个输出值(通常为非线性值),并将其传递给下一层。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

奇异值分解技术

类似于特征分解将矩阵分解成特征向量和特征值,奇异值分解(singular value decomposition, SVD)将矩阵分解为奇异向量(singular vector)和奇异值(singular value)。通过分解矩阵,我们可以发现矩阵表示成数组元素时不明显的函数性质。而相比较特征分解,奇异值分解有着更为广泛的应用,这是因为每个实数矩阵都有一个奇异值分解,但未必都有特征分解。例如,非方阵型矩阵没有特征分解,这时只能使用奇异值分解。

导数技术

导数(Derivative)是微积分中的重要基础概念。当函数y=f(x)的自变量x在一点x_0上产生一个增量Δx时,函数输出值的增量Δy与自变量增量Δx的比值在Δx趋于0时的极限a如果存在,a即为在x0处的导数,记作f'(x_0) 或 df(x_0)/dx。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

卷积神经网络技术

卷积神经网路(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。卷积神经网路由一个或多个卷积层和顶端的全连通层(对应经典的神经网路)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网路能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网路在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网路,卷积神经网路需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。 卷积网络是一种专门用于处理具有已知的、网格状拓扑的数据的神经网络。例如时间序列数据,它可以被认为是以一定时间间隔采样的一维网格,又如图像数据,其可以被认为是二维像素网格。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

长距离依赖技术

也作“长距离调序”问题,在机器翻译中,比如中英文翻译,其语言结构差异比较大,词语顺序存在全局变化,不容易被捕捉

链式法则技术

是求复合函数导数的一个法则, 是微积分中最重要的法则之一。

梯度消失问题技术

梯度消失指的是随着网络深度增加,参数的梯度范数指数式减小的现象。梯度很小,意味着参数的变化很缓慢,从而使得学习过程停滞,直到梯度变得足够大,而这通常需要指数量级的时间。这种思想至少可以追溯到 Bengio 等人 1994 年的论文:「Learning long-term dependencies with gradient descent is difficult」,目前似乎仍然是人们对深度神经网络的训练困难的偏好解释。

遗忘门技术

LSTM或GRU中特有的机制

暂无评论
暂无评论~