Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

在Transformer时代重塑RNN,RWKV将非Transformer架构扩展到数百亿参数

Transformer 模型在几乎所有自然语言处理(NLP)任务中都带来了革命,但其在序列长度上的内存和计算复杂性呈二次方增长。相比之下,循环神经网络(RNNs)在内存和计算需求上呈线性增长,但由于并行化和可扩展性的限制,很难达到与 Transformer 相同的性能水平。本文提出了一种新颖的模型架构,Receptance Weighted Key Value(RWKV),将 Transformer 的高效可并行训练与 RNN 的高效推理相结合。实验证明,RWKV 的性能与相同规模的 Transformer 相当。

深度学习技术在人工智能领域取得了重大进展,在各种科学和工业应用中发挥了关键作用。这些应用通常涉及复杂的序列数据处理任务,包括自然语言理解、对话式人工智能、时间序列分析等,其中用到的技术主要包括循环神经网络(RNNs)、卷积神经网络(CNNs)和 Transformer 等。


不过,这些方法各自存在不同的缺点,从而限制了它们在某些场景下的效率。循环神经网络(RNNs)面临着梯度消失的问题,使得它们难以对长序列进行训练。此外,在训练过程中无法在时间维度上并行化,进而限制了其可扩展性。另一方面,卷积神经网络(CNNs)只擅长捕捉局部模式,在处理长程依赖方面还很欠缺,而这对于许多序列处理任务至关重要。


Transformer 模型由于其处理局部和长程依赖关系的能力以及可并行化训练的特点而成为一个强大的替代方案,如 GPT-3、ChatGPT、GPT-4、LLaMA 和 Chinchilla 等都展示了这种架构的能力,推动了自然语言处理领域的前沿。尽管取得了这些重大进展,Transformer 中固有的自注意力机制带来了独特的挑战,主要是由于其二次复杂度造成的。这种复杂性使得该架构在涉及长输入序列或资源受限情况下计算成本高昂且占用内存。这也促使了大量研究的发布,旨在改善 Transformer 的扩展性,但往往以牺牲一些特性为代价。


为了应对这些挑战,一个由 27 所大学、研究机构组成的开源研究团队,联合发表论文《 RWKV: Reinventing RNNs for the Transformer Era 》,文中介绍了一种新型模型:RWKV(Receptance Weighted Key Value),这是一种新颖的架构,有效地结合了 RNN 和 Transformer 的优点,同时规避了两者的缺点。RWKV 设计精良,能够缓解 Transformer 所带来的内存瓶颈和二次方扩展问题,实现更有效的线性扩展,同时保留了使 Transformer 在这个领域占主导的一些性质。



  • 论文地址:https://arxiv.org/pdf/2305.13048.pdf

  • RWKV 模型下载:https://huggingface.co/BlinkDL/rwkv-4-raven

  • Demo 地址:https://www.codewithgpu.com/i/app/BlinkDL/ChatRWKV/RWKV-4-Raven-7B


本文利用线性注意力机制,允许将模型定义为 Transformer 或 RNN,从而在训练期间并行化计算,并在推理过程中保持恒定的计算和内存复杂性,使其成为第一个可扩展到数百亿参数的非 Transformer 架构。


RWKV 其中的一个特征是它能够提供并行训练和强大的可扩展性,类似于 Transformer。此外,该研究对 RWKV 中的注意力机制进行了重新阐述,引入了线性注意力的一个变体,避开了传统点积(dot-product)token 交互,转而采用更有效的通道导向注意力( channel directed attention )。这种方法与传统的 Transformer 架构形成了鲜明的对比,其中特定的 token 交互主导了注意力。在 RWKV 中,线性注意力的实施是无需近似的,这在效率上提供了显著的改进,并增强了可扩展性,详见表 1。


该研究表示,开发 RWKV 的主要动机是弥补神经网络架构在计算效率和表达能力之间的差距。它为处理涉及数十亿参数的大规模模型的任务提供了一个有希望且可行的解决方案,以极低的计算成本展现出强有力的竞争性。


实验结果表明,RWKV 可以成为一个有价值的工具,用于解决各个领域扩展和部署人工智能模型的各种挑战,特别是那些涉及序列数据处理的领域。RWKV 为下一代更可持续、计算效率更高的序列处理任务的 AI 模型铺平了道路。


总结而言,本文的贡献如下:


  • 引入了 RWKV 网络架构,该架构结合了 RNN 和 Transformer 的优点,同时减轻了它们已知的限制。

  • 本文提出了一个新的注意力机制重构,进而提出线性注意力,避开了与标准 Transformer 模型相关的二次复杂性。

  • 本文在基准数据集上进行了一系列全面的实验,展示了 RWKV 在处理涉及大规模模型和长距离依赖任务上的性能、效率和可扩展性。

  • 发布了预训练模型,其大小从 1.69 亿到 140 亿的参数不等,这些模型是在 Pile 上训练的。


值得注意的是,论文参与机构之一的 EleutherAI 表示:这篇论文还不是最终版本,后续会不断完善。



RWKV 模型


RWKV 架构的名称来源于时间混合和通道混合块中使用的四个主要模型元素,分别如下:


  • R:Receptance 向量,用于接收以往信息;

  • W:权重(weight)是位置权重衰减向量,是可训练的模型参数

  • K:键(Key)是类似于传统注意力中 K 的向量;

  • V:值(Value)是类似于传统注意力中 V 的向量。


每一时间步的主要元素之间的交互是相乘增加的,具体如下图 2 所示。


架构细节


RWKV 架构由一系列堆叠的残差块组成,每个残差块又由具有循环结构的时间混合和通道混合子块组成。


循环被表示为当前输入和前一个时间步的输入之间的线性插值(研究者称这种技术为时移混合或 token shift,如下图 3 所示),该插值可以针对输入嵌入的每个线性投影进行独立调整(比如时间混合中的 R、K 和 V,通道混合中的 R 和 K),并作为公式 14 中形式化的 WKV 的时变更新。


类 Transformer 的并行化


RWKV 可以在时间并行模式下进行高效地并行化,让人联想到 Transformer。单个层中一个 batch 序列的时间复杂度为 O (BTd^2 ),它主要由矩阵乘法 W_□,  □ ∈ {r, k, v, o}(假设 B 个序列、T 个最大 token 和 d 个通道)。同时更新注意力分数 wkv_t 需要串行扫描,并且复杂度为 O (BTd)。


类 RNN 的序列解码


在循环网络中,将状态 t 时的输出用作状态 t+1 时的输入很常见。这在语言模型的自回归解码推理中尤为明显,要求每一个 token 在馈入下一步之前必须进行计算,从而使 RWKV 可以利用类 RNN 结构(即时序模式)。在这种情况下,RWKV 可以方便地循环用于推理解码,从而利用每个输出 token 仅依赖于最新状态的优势。


然后 RWKV 充当 RNN 解码器,在序列长度方面保持恒定速度和内存占用,从而更高效地处理更长的序列。相比之下,自注意力通常需要 KV 缓存相对于序列长度呈线性增长,这会导致效率下降,并随序列长度增加消耗更多内存和时间。


软件实现


RWKV 最初使用 PyTorch 深度学习库和自定义 CUDA 内核(它用于 WKV 计算)来实现。尽管 RWKV 是一个通用循环网络,但其当前的实现主要集中在语言建模任务(RWKV-LM)。该模型架构包含了一个嵌入层,为此研究者遵循第 4.7 节中的设置,并按照第 4.6 节中的原则依次应用几个相同的残差块,具体如上图 2 和 3 所示。


梯度稳定性和层堆叠


RWKV 架构被设计为 Transformer 和 RNN 的融合,与传统的 RNN 相比,Transformers 具有稳定梯度和更深层次架构的优势,同时推理效率高。


RWKV 模型具有用于更新类似注意力分数的单步过程,其中包括一个依赖于时间的 softmax 操作,该操作有助于数值稳定性并防止梯度消失(有关严格证明,请参见附录 F)。直观地说,此操作可确保梯度沿最相关的路径传播。Layer normalization (Ba et al., 2016) 是架构的另一个关键方面,它通过稳定梯度、解决梯度消失和爆炸问题来增强深度神经网络的训练动态。


利用时间结构进行时序数据处理


RWKV 通过三种机制的组合来捕获和传播时序信息:循环、时间衰减和 token shift。


RWKV 时间混合块中的循环是模型捕获序列元素之间复杂关系和随时间传播局部信息的能力的基础。


时间衰减机制(等式 14 中的 e^−w 和 e^u)保持了对序列元素之间位置关系的敏感性。通过逐渐减少以往信息随时间的影响,该模型保留了时间局部性和进展感,这对于时序处理至关重要。


token shift 或 time-shift 混合或(图 3 中的对角线箭头),也有助于模型适应时序数据。通过在当前输入和前一个时间步输入之间进行线性插值,模型自然地聚合和门控输入通道中的信息。


实验结果


实验的重点是回答以下问题:


  • RQ1:在参数数量和训练 token 数量相等的情况下,RWKV 与二次 transformer 架构相比具有竞争力吗?

  • RQ2:增加参数数量时,RWKV 是否仍然具有与二次 transformer 架构相竞争的能力?

  • RQ3:当 RWKV 模型被训练用于开源二次 transformer 无法高效处理的上下文长度时,增加 RWKV 的参数是否能够获得更好的语言建模损失?


首先是回答 RQ1 和 RQ2 问题,从图 4 可以看出,在六个基准测试中(Winogrande、PIQA、ARC-C、ARC-E、LAMBADA 和 SciQ),RWKV 与开源二次复杂度 transformer 模型 Pythia、OPT 和 BLOOM 具有相当的竞争力。RWKV 甚至在四个任务(PIQA、OBQA、ARC-E 和 COPA)中胜过了 Pythia 和 GPT-Neo。



对于 RQ3,图 5 显示,增加上下文长度会导致 Pile 上的测试损失降低,这表明 RWKV 能够有效利用较长的上下文信息。


理论RNNTransformer
1
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

自然语言理解技术

自然语言理解是人工智能的核心课题之一,也被广泛认为是最困难和最具标志性的任务。最经典的两个人工智能思想实验——图灵测试和中文房间,都是围绕自然语言理解来构建的。自然语言理解在人工智能技术体系中的重要性不言而喻,它一方面承载着机器和人的交流,另一方面直达知识和逻辑。自然语言理解也是人工智能学者孜孜以求的圣杯,机器学习的巨擘 Michael I. Jordan 就曾经在 Reddit 上的 AMA(Ask Me Anything)栏目中畅想用十亿美元建立一个专门用于自然语言理解的实验室。

重构技术

代码重构(英语:Code refactoring)指对软件代码做任何更动以增加可读性或者简化结构而不影响输出结果。 软件重构需要借助工具完成,重构工具能够修改代码同时修改所有引用该代码的地方。在极限编程的方法学中,重构需要单元测试来支持。

自注意力技术

自注意力(Self-attention),有时也称为内部注意力,它是一种涉及单序列不同位置的注意力机制,并能计算序列的表征。自注意力在多种任务中都有非常成功的应用,例如阅读理解、摘要概括、文字蕴含和语句表征等。自注意力这种在序列内部执行 Attention 的方法可以视为搜索序列内部的隐藏关系,这种内部关系对于翻译以及序列任务的性能非常重要。

人工智能技术

在学术研究领域,人工智能通常指能够感知周围环境并采取行动以实现最优的可能结果的智能体(intelligent agent)

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

时间复杂度技术

在计算机科学中,算法的时间复杂度是一个函数,它定量描述了该算法的运行时间。这是一个代表算法输入值的字符串的长度的函数。时间复杂度常用大O符号表述,不包括这个函数的低阶项和首项系数。使用这种方式时,时间复杂度可被称为是渐近的,亦即考察输入值大小趋近无穷时的情况。例如,如果一个算法对于任何大小为 n (必须比 n0 大)的输入,它至多需要 5n3 + 3n 的时间运行完毕,那么它的渐近时间复杂度是 O(n3)。

注意力机制技术

我们可以粗略地把神经注意机制类比成一个可以专注于输入内容的某一子集(或特征)的神经网络. 注意力机制最早是由 DeepMind 为图像分类提出的,这让「神经网络在执行预测任务时可以更多关注输入中的相关部分,更少关注不相关的部分」。当解码器生成一个用于构成目标句子的词时,源句子中仅有少部分是相关的;因此,可以应用一个基于内容的注意力机制来根据源句子动态地生成一个(加权的)语境向量(context vector), 然后网络会根据这个语境向量而不是某个固定长度的向量来预测词。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

卷积神经网络技术

卷积神经网路(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。卷积神经网路由一个或多个卷积层和顶端的全连通层(对应经典的神经网路)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网路能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网路在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网路,卷积神经网路需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。 卷积网络是一种专门用于处理具有已知的、网格状拓扑的数据的神经网络。例如时间序列数据,它可以被认为是以一定时间间隔采样的一维网格,又如图像数据,其可以被认为是二维像素网格。

插值技术

数学的数值分析领域中,内插或称插值(英语:interpolation)是一种通过已知的、离散的数据点,在范围内推求新数据点的过程或方法。求解科学和工程的问题时,通常有许多数据点借由采样、实验等方法获得,这些数据可能代表了有限个数值函数,其中自变量的值。而根据这些数据,我们往往希望得到一个连续的函数(也就是曲线);或者更密集的离散方程与已知数据互相吻合,这个过程叫做拟合。

长距离依赖技术

也作“长距离调序”问题,在机器翻译中,比如中英文翻译,其语言结构差异比较大,词语顺序存在全局变化,不容易被捕捉

自然语言处理技术

自然语言处理(英语:natural language processing,缩写作 NLP)是人工智能和语言学领域的分支学科。此领域探讨如何处理及运用自然语言;自然语言认知则是指让电脑“懂”人类的语言。自然语言生成系统把计算机数据转化为自然语言。自然语言理解系统把自然语言转化为计算机程序更易于处理的形式。

堆叠技术

堆叠泛化是一种用于最小化一个或多个泛化器的泛化误差率的方法。它通过推导泛化器相对于所提供的学习集的偏差来发挥其作用。这个推导的过程包括:在第二层中将第一层的原始泛化器对部分学习集的猜测进行泛化,以及尝试对学习集的剩余部分进行猜测,并且输出正确的结果。当与多个泛化器一起使用时,堆叠泛化可以被看作是一个交叉验证的复杂版本,利用比交叉验证更为复杂的策略来组合各个泛化器。当与单个泛化器一起使用时,堆叠泛化是一种用于估计(然后纠正)泛化器的错误的方法,该泛化器已经在特定学习集上进行了训练并被询问了特定问题。

深度神经网络技术

深度神经网络(DNN)是深度学习的一种框架,它是一种具备至少一个隐层的神经网络。与浅层神经网络类似,深度神经网络也能够为复杂非线性系统提供建模,但多出的层次为模型提供了更高的抽象层次,因而提高了模型的能力。

语言模型技术

统计式的语言模型是借由一个几率分布,而指派几率给字词所组成的字串。语言模型经常使用在许多自然语言处理方面的应用,如语音识别,机器翻译,词性标注,句法分析和资讯检索。

推荐文章
暂无评论
暂无评论~