Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

哈佛大学开源自然语言项目:带有注意神经网络的Seq2seq学习

在带有(可选)注意(attention)的标准seq2seq模型的Torch实现中,其编码器-解码器(encoder-decoder)模型是LSTM。编码器可以是一个双向LSTM。此外还能在字符嵌入(character embeddings)上运行一个卷积神经网络然后再运行一个 highway network,从而将字符(而不是输入的词嵌入)作为输入来使用。

该注意模型来源于发表于EMNLP2015大会上的论文《Effective Approaches to Attention-based Neural Machine Translation》。我们使用该论文中的带有输入-反馈方式的全局通用注意力模型( global-general-attention model with the input-feeding approach)。输入-反馈作为可选项,也可去掉。

其中的字符模型来源于AAAI2016上的论文《Character-Aware Neural Language Models

在基准模型之上有很多其他可选模型,要感谢谢SYSTRAN小伙伴们的提供

准备

Python

  • h5py

  • numpy

Lua

需要下面的packages:

  • hdf5

  • nn

  • nngraph


GPU的使用还需要:

  • cutorch

  • cunn


如果你运行了字符模型,你还要装载:

  • cudnn

  • luautf8

快速入门


我们现在要从data/文件夹中调用一些样本数据。首先运行数据处理代码

python preprocess.py --srcfile data/src-train.txt --targetfile data/targ-train.txt
--srcvalfile data/src-val.txt --targetvalfile data/targ-val.txt --outputfile data/demo


这将打开源/目标列/有效文件s (src-train.txt, targ-train.txt, src-val.txt, targ-val.txt),并生成一些hdf5文件供Lua使用。


demo.src.dict:源词汇字典映射到索引。 demo.targ.dict: 目标词汇字典映射到索引映射。 demo-val.hdf5 hdf5文件包含了验证数据。


在新数据上做预测时需要用到 *.dict 文件


现在运行模型


th train.lua -data_file data/demo-train.hdf5 -val_data_file data/demo-val.hdf5 -savefile demo-model


这会运行默认的模型,它由一个在encoder/decoder上都带有500个隐藏单元的双层LSTM。你也可在这个集群中添加-gpuid 1 来使用GPU1。

th evaluate.lua -model demo-model_final.t7 -src_file data/src-val.txt -output_file pred.txt
-src_dict data/demo.src.dict -targ_dict data/demo.targ.dict


这会将预测输出到 o pred.txt 中。由于演示数据库很小,所以这个预测的结果可能会很糟。你可以尝试在更大的数据库上试试!比如可以下数百万条载翻译或总结的平行句子。

详细过程请见:https://github.com/harvardnlp/seq2seq-attn

入门开源NLP工程哈佛大学GitHubSeq2Seq
暂无评论
暂无评论~