谷歌提出贝叶斯循环神经网络:优于传统RNN

谷歌研究者最近在 arXiv 上发布了一篇论文,介绍了一种新的网络:贝叶斯循环神经网络(Bayesian Recurrent Neural Networks),在该论文中,谷歌还介绍并开源了两个实验实现。机器之心对该研究进行了简要介绍。

7.png

在本研究中,我们探索了一种用于循环神经网络的直接的变分贝叶斯方案(variational Bayes scheme)。首先,我们表明对截断的通过时间的反向传播(truncated backpropagation through time)进行一点简单的改进就能在训练过程只需消耗一点点额外的计算成本的情况下得到良好的质量不确定性估计和优异的正则化。其次,我们说明了一种全新的后验近似(posterior approximation)可以如何进一步提升贝叶斯 RNN 的表现。我们在近似的后验中整合了局部梯度信息,以围绕当前批的统计情况(current batch statistics)对其进行锐化。该技术并不限于循环神经网络,而且可被更广泛地应用于训练贝叶斯神经网络。我们还通过实验表明贝叶斯 RNN 在一个语言建模基准和一个图像描述任务上优于传统的 RNN,同时也说明了这些每种方法在其它多种用于训练它们的方案上对我们的模型实现了提升。我们还为语言模型的不确定度研究引入了一个新的基准,以便未来我们可以轻松地比较各种方法。

1 引言

本研究有以下贡献:

  • 我们表明通过反向传播的贝叶斯(BBB:Bayes by Backprop)可以被有效地应用于 RNN

  • 我们开发了一种全新的技术,其可以减少 BBB 的方差,而且其可被广泛地应用于其它最大似然框架

  • 我们在两个被广泛研究过的基准上实现了表现提升,并极大地超越了已有的正则化技术,比如 dropout

  • 我们引入了一个新的用于研究语言模型的不确定性的基准

2 通过反向传播的贝叶斯

算法 1 给出了用于最小化的通过反向传播的贝叶斯的蒙特卡洛过程(Bayes by Backprop Monte Carlo procedure),其涉及到后验 q(θ) 的平均值和标准偏差。

033117_1829_MicrosoftUp1.png

3 通过时间的反向传播

一个 RNN 可以使用通过时间的反向传播(backpropagation through time)来在一个长度为 T 的序列上进行训练,其中该 RNN 被展开 T 次而成为一个前向网络。也就是说,通过使用输入 x1,x2,...,xT 和初始状态 s0 来构建该前向网络:

286904983071613429.png

其中 sT 是该 RNN 的最终状态。我们应该指的是进行 T 步如 (3) 所示的 RNN 核展开,通过57d7c971e1c98.png。其中,x1:T 是输入向量的序列,而 s1:T 是对应状态的序列。注意该算法的截断版本(truncated version)可以被看作是将 s0 作为之前批的最后状态 sT。

4 使用通过时间的反向传播的截断贝叶斯(Truncated Bayes by Backprop Through Time)

FvqYVz.png

图 1:BBB 应用于一个 RNN 的图示

图 1 给出将 BBB 应用于 RNN 的示意,其中该 RNN 的权重矩阵是根据分布(通过 BBB 学到的)而绘出的。但是,这种直接的应用有两个问题:什么时候对该 RNN 的参数采样,怎么衡量 (2) 的 KL 正则化器的贡献。

10415_670d3369559cd4f24e79046d6372a0e9.png

下面的算法 2 中,我们简要地说明了 BBB 对 RNN 的适应。

Cover5th.png

5 后验锐化(Posterior Sharpening)

算法 3 给出了实际中学习执行的方式:

Capture.PNG

不同于一般的 BBB (其中 KL 项在推理过程中可以忽略,参见补充材料),在后验锐化下进行推理,我们有两种选择。第一种涉及到使用 q(ϕ) 并忽略任何 KL 项,类似于一般的 BBB。第二种涉及到使用 q(θ|(x, y)),这需要用 KL [q(θ|ϕ,(x, y)) || p(θ|ϕ)] 项得出一个困惑度(perplexity)上界(下界用对数概率产生,参见补充材料)。下一节提供了这两种方法的比较。

6 相关工作(略)

7 实验

我们给出了我们的方法在一个语言建模基准和一个图像描述生成任务上的结果。

7.1. 语言建模(Language Modelling)

开源地址:https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py

duwo4tifgyjsjdqxzyxr.png

表 1:在 Penn Treebank 语言建模任务上的词级困惑度(越低越好)

7.2. 图像描述生成(Image Caption Generation)

开源地址:https://github.com/tensorflow/models/tree/master/im2txt

race_for_AI_q3-16_1.png

图 3:在 MSCOCO 开发集上的图像描述结果

cocotama_ftw_by_cansin13art-d9o0s1i.png

表 2:在 MSCOCO 开发集上的图像描述结果

讨论、致谢、参考文献和补充材料(略)

理论谷歌贝叶斯论文理论循环神经网络实现
返回顶部