Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

机器之心编辑部报道

ICLR 2020 | 超越SOTA Transformer模型,哈佛、FAIR提出基于残差能量模型的文本生成

在本文中,来自哈佛大学、Facebook AI 研究院的研究者提出了一种基于残差能量模型的文本生成方法,效果超过 state-of-the-art 的 transformer 语言模型。这是能量模型在大规模文本生成中的首次成功应用,论文已入选 ICLR 2020。

论文链接:https://openreview.net/pdf?id=B1l4SgHKDH

近年来,随着 GPT-2、GPT-3 模型的出现,我们似乎已经可以使用语言模型生成以假乱真的文本。然而事实上,模型生成的文本存在明显的质量问题。

比如我们可以训练一个分类器去区分真实文本和语言模型生成的文本,而且可以达到非常高的准确率 [1,2]。那么,一个自然的问题是,我们能否使用这个分类器去提高文本生成的质量,以期达到更加以假乱真的水平呢?这就是本文的研究问题。

同时,本文还解答了另一个问题:由于传统的文本生成解码器只能使用单向模型,如何使用预训练的双向模型 BERT 改进文本生成解码器?

为了便于讨论,作者定义一段有 T 个词的文本为 x=x_1 x_2…x_T。它有可能是真实文本,也可能是一个语言模型 P_LM (x)生成的文本。他们训练了一个分类器 E_θ (x)去区分 x 是真实的(real)还是生成的:

这里的 σ 是 sigmoid 函数,以确保概率在 0-1 范围内。以下示意图展示了训练的目标:

一个好的分类器 E_θ (x)可以确保当 x 比较接近真实文本时,E_θ (x)比较小;而当 x 比较接近语言模型生成文本时,E_θ (x)比较大。利用 E_θ (x),可以修正语言模型 P_LM (x),从而得到一个新的文本生成模型 P_θ (x):

上式就是本文提出的残差能量模型(residual energy-based model),这里的 Z 是一个全局归一化常数。之所以叫它残差模型,是因为在修正,比如当 E_θ (x)≡0 时,

这个残差模型非常直观,当 x 比较「不真实」时,E_θ (x)比较大,因此在残差模型中的概率会低于未经修正前的

选择这样形式的模型是否有数学上的依据呢?事实上,作者的训练方法是噪声对抗训练(NCE)的一个特殊形式 [3,4]。理论保证详见论文中的定理 1,其结论是当 E_θ (x) 足够强大时(一般意味着足够多参数),目标函数的最优解是,亦即即使语言模型 P_LM (x)和真实文本有偏差,足够强大的 E_θ (x)和足够好的优化算法都可以使残差模型无限逼近真实文本分布。

虽然本文提出的模型具有很好的理论保证,但引入分类器 / 修正器 E_θ (x)引入了额外的参数。为什么不直接增加语言模型参数呢?这涉及到了语言模型 P_LM (x)和残差能量模型 P_θ (x)的本质区别:目前的语言模型 P_LM (x)一般是局部归一化(locally normalized)的,而 P_θ (x)是全局归一化的(globally normalized):

也就是说,P_LM (x)的模型在生成每个单词时,只能使用前面已经生成的单词的信息。因此我们只能使用单向的模型作为文本生成模型,而无法使用双向的模型。对比之下,E_θ (x_1 x_2…x_T )是直接取整个文本作为模型的输入,因此可以使用双向的模型,比如预训练的 BERT。由于不需要像 P_LM (x)那样每生成一个单词都归一化,因此全局归一化的 P_θ (x)更灵活。其实,P_LM (x)只是 P_θ (x)的一种特例。

虽然全局归一化的模型更灵活,但与 P_LM (x)不同,P_θ (x)不能从左至右逐词生成,因为 E_θ (x)需要以整个文本作为输入。对此,作者提出了基于 importance sampling 的生成方式:为了生成一个文本,作者

  1. 首先从 P_LM (x)中采样 N 个完整文本{x^1,x^2,…,x^N}

  2. 然后从这个样本集中进行采样:P(x=x^i)∝exp⁡(-E_θ (x^i ))

上述过程非常类似机器翻译和句法分析中的再排序算法(reranking),然而本文作者提出的算法有两点重要的改进:第一,他们的算法具有理论保证,当样本数 N 足够大,上述过程中采集的样本服从 P_θ (x)的分布;第二,再排序在第二步骤进行的是排序,而他们进行的是采样(初步实验证明排序的效果弱于采样,类似 [5] 中的观察)。

实验

最后简要介绍一下实验结果。本文主要使用的数据集 CC-News 规模非常大,有 160 亿个词 [6]。另外,作者选择的基线(baseline)是 GPT 级别的 state-of-the-art 语言模型。对如此大规模数据下基线模型的提高是非常有意义的。

首先,作为生成模型,作者使用自然语言处理中的常用指标 perplexity(PPL)衡量真实文本在模型下的概率。PPL 可以简化理解为正确生成每个词,模型平均需要猜几次。因此,PPL 越低越好。这里残差能量模型的 PPL 使用采样估计的上界,详见论文。


在上图中,BASE LM 是语言模型 P_LM (x),其余的(Joint 开头)都是残差能量模型。使用单向的 transformer 作为 E_θ (∙)(Joint UniT),PPL 略有降低,而使用双向的 transformer(Joint BiT-Base),PPL 比单向模型进一步下降(值得一提的是,传统的语言模型是没法使用双向 transformer 的)。最后两列展示了本文所提方法可以使用预训练的双向模型,这里作者使用了 BERT 的变种 Roberta-Base(Joint BiT-Base)和 Roberta-Large(Joint BiT-Large),效果得到了进一步的提升。

PPL 的降低证明了:从概率模型的角度,本文提出的模型是优于基线模型的。但该模型能否生成更以假乱真的文本呢?下面的表格中,作者做了人工评测的实验,验证了该模型的确可以得到更好的文本:

最后,作者给出了一个具体例子,直观理解残差模型如何修正改进语言模型 P_LM (x)。

前文指出过,此项研究的生成过程是先采样一些样本,然后使用〖-E〗_θ (x)作为分数从这些样本中进行再次采样。以上的 Joint Bit-Base Worst 是〖-E〗_θ (x)最低的样本(也就是分类器认为最不像真实文本的)。这个样本中,词组「these grants」重复了两次。重复生成词组是目前语言模型的常见问题 [5],因此分类器会根据这个特点,很容易判断出这句话不是真实文本,由此在再采样过程中,这个分数很低的样本基本不可能被采样到。值得一提的是,本文提出的模型训练时并没有明确要求它不生成重复词组,但分类器自动发现重复词组是一个语言模型生成文本的明显特征,因此残差能量模型生成的重复词组明显减少(详见论文)。

总结来看,残差能量模型是比 state-of-the-art 的 transformer 语言模型效果更好的全局归一化模型。它的训练方式只是训练一个辨别真实文本还是语言模型生成的分类器,因此非常简单稳定,同时还拥有 NCE 带来的理论正确保证。

作者在实验中使用了语言模型作为测试任务,但实际上很容易推广到条件生成,比如机器翻译或者文本摘要。

另外,作者提出的能量模型和 GAN 的思路有很大不同:GAN 使用分类判别器的目的是改进生成器,最后并没有使用分类判别器;而残差能量模型最终使用分类器,而且训练过程中不去试图改变分类器,因此训练过程更加稳定。最后,全局归一化(globally normalized)的能量模型虽然在 Yann Lecun 等人看来是未来的重要方向(https://iclr.cc/virtual_2020/speaker_7.html),但目前还没有得到广泛重视。作者认为这里有很多未来工作的可能方向,比如和隐变量结合等。

引用:

[1]: Bakhtin, Anton, Yuntian Deng, Sam Gross, Myle Ott, Marc'Aurelio Ranzato, and Arthur Szlam."Energy-Based Models for Text." arXiv (2020): arXiv-2004.

[2]: Zellers, Rowan, Ari Holtzman, Hannah Rashkin, Yonatan Bisk, Ali Farhadi, Franziska Roesner, and Yejin Choi. "Defending against neural fake news." In Advances in Neural Information Processing Systems, pp. 9051-9062. 2019.

[3]: Gutmann, Michael, and Aapo Hyvärinen. "Noise-contrastive estimation: A new estimation principle for unnormalized statistical models." In Proceedings of the Thirteenth International Conference on Artificial Intelligence and Statistics, pp. 297-304. 2010.

[4]: Ma, Zhuang, and Michael Collins. "Noise contrastive estimation and negative sampling for conditional models: Consistency and statistical efficiency." arXiv preprint arXiv:1809.01812 (2018).

[5]: Holtzman, Ari, Jan Buys, Li Du, Maxwell Forbes, and Yejin Choi. "The curious case of neural text degeneration." arXiv preprint arXiv:1904.09751 (2019).

[6]: Liu, Yinhan, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, and Veselin Stoyanov. "Roberta: A robustly optimized bert pretraining approach." arXiv preprint arXiv:1907.11692 (2019).

入门哈佛大学Facebook AI Research(FAIR)文本生成
相关数据
参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

对抗训练技术

对抗训练涉及两个模型的联合训练:一个模型是生成器,学习生成假样本,目标是骗过另一个模型;这另一个模型是判别器,通过对比真实数据学习判别生成器生成样本的真伪,目标是不要被骗。一般而言,两者的目标函数是相反的。

自然语言处理技术

自然语言处理(英语:natural language processing,缩写作 NLP)是人工智能和语言学领域的分支学科。此领域探讨如何处理及运用自然语言;自然语言认知则是指让电脑“懂”人类的语言。自然语言生成系统把计算机数据转化为自然语言。自然语言理解系统把自然语言转化为计算机程序更易于处理的形式。

生成模型技术

在概率统计理论中, 生成模型是指能够随机生成观测数据的模型,尤其是在给定某些隐含参数的条件下。 它给观测值和标注数据序列指定一个联合概率分布。 在机器学习中,生成模型可以用来直接对数据建模(例如根据某个变量的概率密度函数进行数据采样),也可以用来建立变量间的条件概率分布。

语言模型技术

语言模型经常使用在许多自然语言处理方面的应用,如语音识别,机器翻译,词性标注,句法分析和资讯检索。由于字词与句子都是任意组合的长度,因此在训练过的语言模型中会出现未曾出现的字串(资料稀疏的问题),也使得在语料库中估算字串的机率变得很困难,这也是要使用近似的平滑n元语法(N-gram)模型之原因。

文本生成技术

文本生成是生成文本的任务,其目的是使人类书写文本难以区分。

GPT-2技术

GPT-2是OpenAI于2019年2月发布的基于 transformer 的大型语言模型,包含 15 亿参数、在一个 800 万网页数据集上训练而成。据介绍,该模型是对 GPT 模型的直接扩展,在超出 10 倍的数据量上进行训练,参数量也多出了 10 倍。在性能方面,该模型能够生产连贯的文本段落,在许多语言建模基准上取得了 SOTA 表现。而且该模型在没有任务特定训练的情况下,能够做到初步的阅读理解、机器翻译、问答和自动摘要。

推荐文章
暂无评论
暂无评论~