Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

Geek AI编译

textgenrnn:只需几行代码即可训练文本生成网络

本文是一个 GitHub 项目,介绍了 textgenrnn,一个基于 Keras/TensorFlow 的 Python 3 模块。只需几行代码即可训练文本生成网络。

项目地址:https://github.com/minimaxir/textgenrnn?reddit=1

通过简简单单的几行代码,使用预训练神经网络生成文本,或者在任意文本数据集上训练你自己的任意规模和复杂度的文本生成神经网络。

textgenrnn 是一个基于 Keras/TensorFlow 的 Python 3 模块,用于创建 char-rnn,具有许多很酷炫的特性:

  • 它是一个使用注意力权重(attention-weighting)和跳跃嵌入(skip-embedding)等先进技术的现代神经网络架构,用于加速训练并提升模型质量。

  • 能够在字符层级和词层级上进行训练和预测。

  • 能够设置 RNN 的大小、层数,以及是否使用双向 RNN。

  • 能够对任何通用的输入文本文件进行训练。

  • 能够在 GPU 上训练模型,然后在 CPU 上使用这些模型。

  • 在 GPU 上训练时能够使用强大的 CuDNN 实现 RNN,这比标准的 LSTM 实现大大加速了训练时间。

  • 能够使用语境标签训练模型,能够更快地学习并在某些情况下产生更好的结果。

你可以使用 textgenrnn,并且在该 Colaboratory Notebook(https://drive.google.com/file/d/1mMKGnVxirJnqDViH7BDJxFqWrsXlPSoK/view?usp=sharing)中免费使用 GPU 训练任意文本文件。

示例

from textgenrnn import textgenrnn

textgen = textgenrnn()
textgen.generate()

[Spoiler] Anyone else find this post and their person that was a little more than I really like the Star Wars in the fire or health and posting a personal house of the 2016 Letter for the game in a report of my backyard.

该模型可以很容易地在新的文本上进行训练,甚至可以在仅仅输入一次数据之后生成合适的文本。

textgen.train_from_file('hacker-news-2000.txt', num_epochs=1)
textgen.generate()

Project State Project Firefox

这个模型的权重比较小(占磁盘上 2 MB 的空间),它们可以很容易地被保存并加载到新的 textgenrnn 实例中。因此,你可以使用经过数百次数据输入训练的模型。(实际上,textgenrnn 的学习能力过于强大了,以至于你必须大大提高温度(Temperature)来得到有创造性的输出。)

textgen_2 = textgenrnn('/weights/hacker_news.hdf5')
textgen_2.generate(3, temperature=1.0)

Why we got money “regular alter”


Urburg to Firefox acquires Nelf Multi Shamn


Kubernetes by Google’s Bern

您还可以训练一个支持词级别嵌入和双向 RNN 层的新模型。

使用方法

textgenrnn 可以通过 pip 从 pypi(https://pypi.python.org/pypi/textgenrnn)中安装:

pip3 install textgenrnn
  • 你可以在该 Jupyter Notebook(https://github.com/minimaxir/textgenrnn/blob/master/docs/textgenrnn-demo.ipynb)中查看常见的功能和配置选项的演示案例。

  • /datasets 包含用于训练 textgenrnn 的 Hacker News 和 Reddit data 示例数据集。

  • /weights 包含在上述的数据集上进一步预训练的模型,它可以被加载到 textgenrnn 中。

  • /output 包含从上述预训练模型中生成文本的示例。

神经网络架构及实现

textgenrnn 基于 Andrej Karpathy 的 char-rnn 项目(https://github.com/karpathy/char-rnn),并且融入了一些最新的优化,如处理非常小的文本序列的能力。

本文涉及到的预训练模型遵循 DeepMoji 的神经网络架构(https://github.com/bfelbo/DeepMoji/blob/master/deepmoji/model_def.py)的启发。对于默认的模型,textgenrnn 接受最多 40 个字符的输入,它将每个字符转换为 100 维的字符嵌入向量,并将这些向量输入到一个包含 128 个神经元的长短期记忆(LSTM)循环层中。接着,这些输出被传输至另一个包含 128 个神经元的 LSTM 中。以上所有三层都被输入到一个注意力层中,用来给最重要的时序特征赋权,并且将它们取平均(由于嵌入层和第一个 LSTM 层是通过跳跃连接与注意力层相连的,因此模型的更新可以更容易地向后传播并且防止梯度消失)。该输出被映射到最多 394 个不同字符的概率分布上,这些字符是序列中的下一个字符,包括大写字母、小写字母、标点符号和表情。(如果在新的数据集上训练一个新模型,可以配置所有上面提到的数值参数。)

或者,如果可以获得每个文本文档的语境标签,则可以在语境模式下训练模型。在这种模式下,模型会学习给定语境的文本,这样循环层就会学习到非语境化的语言。前面提到的只包含文本的路径可以借助非语境化层提升性能;总之,这比单纯使用文本训练的模型训练速度更快,且具备更好的定量和定性的模型性能。

软件包包含的模型权重是基于(通过 BigQuery)在 Reddit 上提交的成千上万的文本文档训练的,它们来自各种各样的 subreddit 板块。此外,该网络还采用了上文提到的非语境方法,从而提高训练的性能,同时减少作者的偏见。

当使用 textgenrnn 在新的文本数据集上对模型进行微调时,所有的层都会被重新训练。然而,由于原始的预训练网络最初具备鲁棒性强得多的「知识」,新的 textgenrnn 最终能够训练地更快、更准确,并且可以学习原始数据集中未出现的新关系。(例如:预训练的字符嵌入包含所有可能的现代互联网语法类型中的字符语境。)

此外,重新训练是通过基于动量的优化器和线性衰减的学习率实现的,这两种方法都可以防止梯度爆炸,并且大大降低模型在长时间训练后发散的可能性。

注意事项

即使使用经过严格训练的神经网络,你也不能每次都能得到高质量的文本。这就是使用神经网络文本生成的博文(http://aiweirdness.com/post/170685749687/candy-heart-messages-written-by-a-neural-network)或推文(https://twitter.com/botnikstudios/status/955870327652970496)通常生成大量文本,然后挑选出最好的那些再进行编辑的主要原因。

不同的数据集得到的结果差异很大。因为预训练的神经网络相对来说较小,因此它不能像上述博客展示的 RNN 那样存储大量的数据。为了获得最佳结果,请使用至少包含 2000-5000 个文档的数据集。如果数据集较小,你需要在调用训练方法和/或从头开始训练一个新模型时,通过调高 num_epochs 参数来对模型进行更长时间的训练。即便如此,目前也没有一个判断模型」好坏」的启发式方法。

你并不一定需要用 GPU 重新训练 textgenrnn,但是在 CPU 上训练花费的时间较长。如果你使用 GPU 训练,我建议你增加 batch_size 参数,获得更好的硬件利用率。

未来计划

  • 更多正式文档;

  • 一个使用 tensorflow.js 的基于 web 的实现(由于网络规模小,效果特别好);

  • 一种将注意力层输出可视化的方法,以查看神经网络是如何「学习」的;

  • 有监督的文本生成模式:允许模型显示 top n 选项,并且由用户选择生成的下一个字符/单词(https://fivethirtyeight.com/features/some-like-it-bot/);

  • 一个允许将模型架构用于聊天机器人对话的模式(也许可以作为单独的项目发布);

  • 对语境进行更深入的探索(语境位置 + 允许多个语境标签);

  • 一个更大的预训练网络,它能容纳更长的字符序列和对语言的更深入理解,生成更好的语句;

  • 层次化的作用于词级别模型的 softmax 激活函数(Keras 对此有很好的支持);

  • 在 Volta/TPU 上进行超高速训练的 FP16 浮点运算(Keras 对此有很好的支持)。

使用 textgenrnn 的项目

  • Tweet Generator:训练一个为任意数量的 Twitter 用户生成推文而优化的神经网络。

工程文本生成GitHub
3
暂无评论
暂无评论~