Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

Mamba可以替代Transformer,但它们也能组合起来使用

1+1>2。

Transformer 很厉害,但并不完美,尤其是在处理长序列方面。而状态空间模型(SSM)则在长序列上的表现相当不俗。早在去年就有研究者提出可使用 SSM 替代 Transformer,参见文章《预训练无需注意力,扩展到4096个token不成问题,与BERT相当》,前些天基于 SSM 方法的 Mamba 更是异军突起,推理吞吐量达到了 Transformer 的五倍之多,参阅《五倍吞吐量,性能全面包围Transformer:新架构Mamba引爆AI圈》。

但实际上,SSM 和 Transformer 并不是非此即彼的两种架构,它们完全可以组合起来!

近日公布的一篇 NeurIPS 2023 论文《Block-State Transformers》就采用了这种做法,其不仅能轻松支持 65k token 长度的超长输入,而且计算效率还非常高,速度相比使用循环单元的 Transformer 足可提升十倍之多!这篇论文也得到了 Mamba 作者 Tri Dao 的点赞,他表示:「SSM 和Transformer 似乎可以互补。」

图片

但在我们介绍这种新方法前,先简单说说 Transformer。在许多不同的自然语言处理(NLP)任务上,Transformer 的表现都非常出色。可以说 Transformer 已经很大相当程度上替代了循环神经网络。不仅如此,它也正在图像和视频等 NLP 之外的领域大展拳脚。

其成功的原因有很多,包括计算效率和架构层面的归纳偏差,这让它们非常适合在自然语言任务进行大规模训练。在计算方面,Transformer 能以并行方式处理输入序列的 token,从而使其能充分利用现代加速器硬件。此外,注意力机制让 Transformer 可以找到更长序列之间的关系,其方式是在推断下一个 token 时读取从过去 token 提取的所有信息。相比于 RNN 和 LSTM,自注意力有两个优势:(1) 存储信息以及将这些信息直接用作上下文的能力得到了极大提升,(2) 在更长序列上能更稳定地训练。

尽管 Transformer 相比 RNN 有很多优势,但它在输入序列长度的扩展上依然存在问题,其中涉及计算性能和质量等方面的原因。更进一步说,Transformer 的运行时间会随输入序列长度的增长成二次方增长,这会让训练这些模型的成本越来越高。

此外,众所周知使用注意力的 Transformer 在长输入分类任务上表现不佳。最基本的 Transformer 在长序列上训练时可能不稳定,而且其 token 重要度聚焦在当前时间步骤周围约 50 个 token 的局部感受野中。

近来,越来越多的研究表明状态空间模型(SSM)可以替代 Transformer,因为 SSM 可以捕获极长序列之中的依赖关系,同时还有更高的计算效率和更好的并行化能力。

尽管 SSM 依然属于自回归序列模型,但其底层的线性时间不变式动态系统可使用基于快速傅立叶变换(FFT)的可并行化卷积算子来高效地处理序列,而且这个过程的复杂度仅为 𝒪(𝐿 log 𝐿),其中 𝐿 是序列的长度。此外,借用在线函数近似的方法,通过推导循环更新规则,可以确保在长序列上保留过去的信息,甚至可达成千上万个时间步骤。在 Long-Range Arena 基准上,SSM 甚至超过了 Transformer 一大截,参阅机器之心报道《六项任务、多种数据类型,谷歌、DeepMind提出高效Transformer评估基准》。

尽管 SSM 在长程分类任务上很成功,但如果要用作通用语言建模的现成可用序列模型,SSM 还完全赶不上 Transformer。

近期又有研究《Long Range Language Modeling via Gated State Spaces》认为 Transformer 和 SSM 完全可以互补。

DeepMind 等机构提出的新架构 Block-State Transformer(BST)将强大的基于局部注意力的归纳偏差与长期上下文建模能力组合到了一起,做成了单一层。

图片

论文地址:https://arxiv.org/pdf/2306.09539.pdf

据介绍,该模型能在处理长输入序列的同时整合注意力机制来预测下一个 token。相比于基于 Transformer 的层,BST 是完全可并行化的,能扩展用于更长得多的序列,同时速度还能快 10 倍。

在每一层 BST 中,有一个 SSM 将输入的整个序列映射进一个同样长度的「上下文」序列。这个 SSM 子层使用基于 FFT 的卷积。然后将这个上下文序列分成大小相等的上下文块,这个大小即为窗口长度 W;然后再将每个上下文块输入一个 Transformer 层,其注意力关注的是大小为 W 的子序列。之后对输入 token 嵌入块与对应的上下文状态块使用交叉注意力,如图 1 所示。

图片

注意,通过将 SSM 用作一种上下文化的方法,就可以完全不需要序列循环,这样一来就能以完全并行的方式运行这种 SSM-Transformer 混合层。

最后的运行时间复杂度可以表示成一个和:𝒪(𝑊²)+𝒪(𝐿 log 𝐿),其中前一项表示 Transformer 子层的时间复杂度,后一项是 SSM 子层的时间复杂度

只要有支持并行计算的硬件,相较于 Block-Recurrent Transformer 的 𝒪(𝐿𝑊),这是一个重大提升。此外,由于硬件施加的限制,SSM 在完整序列上的运行时间复杂度与 Block Transformer 在 token 块上的运行时间复杂度相当,这进一步意味着 BST 层不存在速度瓶颈。该团队使用包含数十万 token 的序列通过实验验证了这一点。
 
方法

这里研究的是通过仅解码器语言模型实现下一 token 预测的问题。

对状态空间的前置说明

状态空间模型可以分为两大类:

状态空间:结构化核S4、S5、S4D、DSS遵循卷积核的一种结构化初始化,方式是展开一种线性时间不变式(LTI)动态系统,如下所示:

图片

其中的参数包括状态矩阵 𝚨∈ℝ^{N×N},向量 𝐁∈ℝ^{N×1}、𝐂∈ℝ^{1×N}、𝐃∈ℝ^{1×1}。SSM 会将一维的输入信号 u_k 映射成一维的输出信号 y_k。

显式参数化的过滤器。不同于结构化核,还可以将卷积核参数化为可训练的权重并优化它们。但是,这会导致性能很差,除非对这些核使用特定类型的正则化方法。替代 Transformer 的无注意力模型中也有使用可训练核的,比如 Hyena 涉及到沿核对权重进行指数衰减。
 
Block-State Transformer(BST)层

Block-State Transformer 层将 SSM 与 Block Transformer 组合到了一起。在每一次训练迭代中,都会从一个长文档采样一个包含 L 个 token 的序列。然后嵌入该 token 并将其馈送给模型。这个模型由堆叠的 Block-State Transformer 层构成。每一层 BST 都会选择性地包含一个 SSM 子层,其负责为 Block Transformer 层提供长程上下文,这与 Block-Recurrent Transformer(BRECT)单元的工作方式类似。这个 SSM 子层的输入是前一层的 token 嵌入序列,输出则是一个长度同样为 L 的序列。

这个输出经过了上下文编码,也就是说每个时间步骤的项目都可能包含有关该序列中元素之前的所有时间步骤的信息。他们从上下文序列收集一定数量 S 的「上下文状态」,并使得 S ≪ L。

这些上下文状态会被馈送给 Block Transformer,以替代 Block-Recurrent Transformer 中的「循环状态向量」。如图 1 右侧所示,后续操作保持不变,只是无需再运行 BRECT 单元的循环单元,因为现在是通过 SSM 来维护上下文。除了上下文状态,Block Transformer 的输入中还有长度 W 的 token 嵌入的块/窗口;然后在这个窗口与上下文状态上使用交叉注意力。然后将这个交叉注意力操作的输出与自注意力在输入嵌入上的输出连接起来,之后是一个简单的投影。

SSM 不仅能在更长时间尺度上保留信息,而且使用 SSM 来维持上下文状态以替代循环单元,可以得到计算效率更高的层。通过将 SSM 整合进 Transformer 层,可以移除循环部分,从而让 Block-State Transformer 层可以完全并行化。
 
上下文状态

尽管从技术上看,最新的 SSM 输出包含有关整个序列的信息,但仅从最后的状态检索单个 token 可能是不可行的。为了弥补这一点,该团队将一系列状态连接了起来,对应于最新的 token 块。这与 BRECT 采用的方法类似。这种表征可以通过冗余来确保可检索性和易访问性。

在新提出的方法中,上下文状态是使用 SSM 的输出构建的,并会被馈送给 Transformer 的注意力头。这些上下文状态的构建方式有很多。为了引导设计决策,该团队考虑了多种设计方案,包括使用单头(Single-Head)、多头(Multi-Head)或多过滤器(Multi-Filter)。其中单头设计见图 1。下图 2 则展示了多头和多过滤器的设计方案。

图片

比较下来,多过滤器的记忆状态的冗余最少,多头次之,单头的冗余最大。

结果

该团队在 PG19、GitHub 和 arXiv 三个数据集上进行了实验,检验了新提出的 BST 在不同长度的英语文本、latex 科学文章和源代码上的效果。下表 1 总结了实验结果。

图片

下图 3 则给出了长度泛化分析并报告了困惑度。实验中,新模型和基准模型的参数数量都约为 4 亿,训练时的序列长度为 4k,测试中的序列长度为 {512, 16k, 65k}。

可以看到,在 PG19、GitHub 和 arXiv 上,当序列长度为 65k 时,BST:SH:S4-L 的困惑度最好。

图片

在效率方面,下图 4 左给出了 BST 层在 GPU 上的基准测试结果。

可以看到 SSM 带来了非常显著的增长——比包含循环单元的 Block-Recurrent Transformer 快 6-11 倍;即使在序列长度达到 65k token 时,还依然能有 6 倍的提升,而这时候硬件就已经开始饱和了。当使用结构化的 SSM 时,计算复杂度与 SSM 的内部记忆状态大小 N 紧密相关。对于报告的性能,N = 16。

图片

研究者表示,如果使用其它自动微分框架中近期引入的更快的针对硬件的 I/O 感知型实现,BST 方法的速度还能更快。

更多技术细节和实验结果参阅原论文。
产业NeurIPS 2023
相关数据
DeepMind机构

DeepMind是一家英国的人工智能公司。公司创建于2010年,最初名称是DeepMind科技(DeepMind Technologies Limited),在2014年被谷歌收购。在2010年由杰米斯·哈萨比斯,谢恩·列格和穆斯塔法·苏莱曼成立创业公司。继AlphaGo之后,Google DeepMind首席执行官杰米斯·哈萨比斯表示将研究用人工智能与人类玩其他游戏,例如即时战略游戏《星际争霸II》(StarCraft II)。深度AI如果能直接使用在其他各种不同领域,除了未来能玩不同的游戏外,例如自动驾驶、投资顾问、音乐评论、甚至司法判决等等目前需要人脑才能处理的工作,基本上也可以直接使用相同的神经网上去学而习得与人类相同的思考力。

https://deepmind.com/
权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

感知技术

知觉或感知是外界刺激作用于感官时,脑对外界的整体的看法和理解,为我们对外界的感官信息进行组织和解释。在认知科学中,也可看作一组程序,包括获取信息、理解信息、筛选信息、组织信息。与感觉不同,知觉反映的是由对象的各样属性及关系构成的整体。

自注意力技术

自注意力(Self-attention),有时也称为内部注意力,它是一种涉及单序列不同位置的注意力机制,并能计算序列的表征。自注意力在多种任务中都有非常成功的应用,例如阅读理解、摘要概括、文字蕴含和语句表征等。自注意力这种在序列内部执行 Attention 的方法可以视为搜索序列内部的隐藏关系,这种内部关系对于翻译以及序列任务的性能非常重要。

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

时间复杂度技术

在计算机科学中,算法的时间复杂度是一个函数,它定量描述了该算法的运行时间。这是一个代表算法输入值的字符串的长度的函数。时间复杂度常用大O符号表述,不包括这个函数的低阶项和首项系数。使用这种方式时,时间复杂度可被称为是渐近的,亦即考察输入值大小趋近无穷时的情况。例如,如果一个算法对于任何大小为 n (必须比 n0 大)的输入,它至多需要 5n3 + 3n 的时间运行完毕,那么它的渐近时间复杂度是 O(n3)。

注意力机制技术

我们可以粗略地把神经注意机制类比成一个可以专注于输入内容的某一子集(或特征)的神经网络. 注意力机制最早是由 DeepMind 为图像分类提出的,这让「神经网络在执行预测任务时可以更多关注输入中的相关部分,更少关注不相关的部分」。当解码器生成一个用于构成目标句子的词时,源句子中仅有少部分是相关的;因此,可以应用一个基于内容的注意力机制来根据源句子动态地生成一个(加权的)语境向量(context vector), 然后网络会根据这个语境向量而不是某个固定长度的向量来预测词。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

正则化技术

当模型的复杂度增大时,训练误差会逐渐减小并趋向于0;而测试误差会先减小,达到最小值后又增大。当选择的模型复杂度过大时,过拟合现象就会发生。这样,在学习时就要防止过拟合。进行最优模型的选择,即选择复杂度适当的模型,以达到使测试误差最小的学习目的。

自然语言处理技术

自然语言处理(英语:natural language processing,缩写作 NLP)是人工智能和语言学领域的分支学科。此领域探讨如何处理及运用自然语言;自然语言认知则是指让电脑“懂”人类的语言。自然语言生成系统把计算机数据转化为自然语言。自然语言理解系统把自然语言转化为计算机程序更易于处理的形式。

堆叠技术

堆叠泛化是一种用于最小化一个或多个泛化器的泛化误差率的方法。它通过推导泛化器相对于所提供的学习集的偏差来发挥其作用。这个推导的过程包括:在第二层中将第一层的原始泛化器对部分学习集的猜测进行泛化,以及尝试对学习集的剩余部分进行猜测,并且输出正确的结果。当与多个泛化器一起使用时,堆叠泛化可以被看作是一个交叉验证的复杂版本,利用比交叉验证更为复杂的策略来组合各个泛化器。当与单个泛化器一起使用时,堆叠泛化是一种用于估计(然后纠正)泛化器的错误的方法,该泛化器已经在特定学习集上进行了训练并被询问了特定问题。

机器之心机构

机器之心,成立于2014年,是国内最具影响力、最专业、唯一用于国际品牌的人工智能信息服务与产业服务平台。目前机器之心已经建立起涵盖媒体、数据、活动、研究及咨询、线下物理空间于一体的业务体系,为各类人工智能从业者提供综合信息服务和产业服务。

https://www.jiqizhixin.com/
感受野技术

一个感觉神经元的感受野是指这个位置里适当的刺激能够引起该神经元反应的区域。感受野一词主要是指听觉系统、本体感觉系统和视觉系统中神经元的一些性质。

语言模型技术

统计式的语言模型是借由一个几率分布,而指派几率给字词所组成的字串。语言模型经常使用在许多自然语言处理方面的应用,如语音识别,机器翻译,词性标注,句法分析和资讯检索。

推荐文章
暂无评论
暂无评论~