Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

简化版Transformer来了,网友:年度论文

从大模型的根源开始优化。

Transformer 架构可以说是近期深度学习领域许多成功案例背后的主力军。构建深度 Transformer 架构的一种简单方法是将多个相同的 Transformer 「块」(block)依次堆叠起来,但每个「块」都比较复杂,由许多不同的组件组成,需要以特定的排列组合才能实现良好的性能。

自从 2017 年 Transformer 架构诞生以来,研究者们基于其推出了大量衍生研究,但几乎没有改动过 Transformer 「块」。

那么问题来了,标准 Transformer 块是否可以简化?

在最近的一篇论文中,来自 ETH Zurich 的研究者讨论了如何在不影响收敛特性和下游任务性能的情况下简化 LLM 所必需的标准 Transformer 块。基于信号传播理论和经验证据,他们发现可以移除一些部分,比如残差连接、归一化层(LayerNorm)、投影和值参数以及 MLP 序列化子块(有利于并行布局),以简化类似 GPT 的解码器架构以及编码器式 BERT 模型。

对于每个涉及的组件,研究者都探讨了是否可以在不降低训练速度的情况下将其移除(包括每次更新步骤和运行时间),以及为此需要 Transformer 块进行哪些架构修改。

图片

论文链接:https://arxiv.org/pdf/2311.01906.pdf

Lightning AI 创始人、机器学习研究者 Sebastian Raschka 将这项研究称为自己的「年度最爱论文之一」:

图片

但也有研究者质疑:「这很难评,除非我看过完整的训练过程。如果没有归一化层,也没有残差连接,如何能在大于 1 亿参数的网络中进行扩展?

图片

Sebastian Raschka 表示赞同:「是的,他们试验的架构相对较小,这是否能推广到数十亿参数的 Transformer 上还有待观察。」但他仍然表示这项工作令人印象深刻,并认为成功移除残差连接是完全合理的(考虑到其初始化方案)。

对此,图灵奖得主 Yann LeCun 的评价是:「我们仅仅触及了深度学习架构领域的皮毛。这是一个高维空间,因此体积几乎完全包含在表面中,但我们只触及了表面的一小部分。

图片

为什么需要简化 Transformer 块?

研究者表示,在不影响训练速度的前提下简化 Transformer 块是一个有趣的研究问题。

首先,现代神经网络架构设计复杂,包含许多组件,而这些不同组件在神经网络训练动态中所扮演的角色,以及它们之间如何相互作用,人们对此尚不清楚。这个问题事关深度学习理论与实践之间存在的差距,因此非常重要。

信号传播理论(Signal propagation)已被证明具有影响力,因为它能够激励深度神经网络架构中的实际设计选择。信号传播研究了初始化时神经网络中几何信息的演化,通过跨输入的分层表征的内积来捕捉,在训练深度神经网络方面取得了许多令人印象深刻的成果。

然而,目前该理论只考虑初始化时的模型,而且往往只考虑初始前向传递,因此无法揭示深度神经网络训练动态的许多复杂问题,例如残差连接对训练速度的助益。虽然信号传播对修改动机至关重要,但研究者表示,他们不能仅从理论上就得出简化的 Transformer 模块,还要依靠经验见解。

在实际应用方面,考虑到目前训练和部署大型 Transformer 模型的高昂成本,Transformer 架构的训练和推理流水线的任何效率提升都代表着巨大的潜在节约意义。如果能够通过移除非必要组件来简化 Transformer 模块,既能减少参数数量,又能提高模型的吞吐量。

这篇论文也提到,移除残差连接、值参数、投影参数和序列化子块之后,可以同时做到在训练速度和下游任务性能方面与标准 Transformer 相匹配。最终,研究者将参数量减少了 16%,并观察到训练和推理时间的吞吐量增加了 16%。

如何简化 Transformer 块?

研究者结合信号传播理论和经验观察,介绍了如何从 Pre-LN 模块出发,生成最简单的 Transformer 块(如下图)。

图片

在论文第四章的每一个小节,作者分别介绍了如何在不影响训练速度的情况下每次删除一个块组件。

这一部分的所有实验都在 CodeParrot 数据集上使用了一个 18-block 768-width 的因果仅解码器类 GPT 模型,这个数据集足够大,因此当作者处于单个训练 epoch 模式时,泛化差距非常小(见图 2),这使得他们可以专注于训练速度。

图片

删除残差连接

研究者首先考虑删除注意力子块中的残差连接。在公式(1)的符号中,这相当于将 α_SA 固定为 0。简单地移除注意力残差连接会导致信号退化,即秩崩溃(rank collapse),从而导致可训练性差。在论文 4.1 部分,研究者详细解释了他们的方法。

图片

删除投影 / 值参数

从图 3 中可以得出结论,完全移除值和投影参数 W^V、W^P 是可能的,而且每次更新的训练速度损失最小。也就是说,当 β_V = β_P = 0 和 identity 初始化的

图片

时,在相同的训练步数后,本研究基本上能达到 Pre-LN 块的性能。在这种情况下,在整个训练过程中都有 W^V = W^P = I,即值和投影参数是一致的。作者在 4.2 节介绍了详细方法。

图片

删除 MLP 子块残差连接

与上述几个模块相比,删除 MLP 子块残差连接要更具挑战性。与之前的研究一样,作者发现,在使用 Adam 时,如果没有 MLP 残差连接,通过信号传播使激活更加线性仍会导致每次更新训练速度的显著下降,如图 22 所示。

图片

他们还尝试了 Looks Linear 初始化的各种变体,包括高斯权重、正交权重或恒等权重,但都无济于事。因此,他们在整个工作中使用标准激活(例如 ReLU)和 MLP 子块中的初始化。

作者转向并行 MHA 和 MLP 子块的概念,这在几个近期的大型 transformer 模型中已被证明很受欢迎,例如 PALM 和 ViT-22B。并行 transformer 块如下图所示。

图片

作者在论文 4.3 节详细介绍了移除 MLP 子块残差连接的具体操作。

删除归一化层

最后一个被删除的是归一化层,这样就得到了图 1 右上角的最简块。从信号传播初始化的角度来看,作者可以在本节简化的任何阶段移除归一化层。他们的想法是,Pre-LN 块中的归一化会隐式地降低残差分支的权重,而这种有利的效果可以通过另一种机制在没有归一化层的情况下复制:要么在使用残差连接时明确降低残差分支的权重,要么将注意力矩阵偏向 identity / 将 MLP 非线性转化为「更」线性。

由于作者在修改过程中考虑到了这些机制(如降低 MLP β_FF 和 Shaped Attention 的权重),因此无需进行归一化处理。作者在第 4.4 节介绍了更多信息。

实验结果

深度扩展

鉴于信号传播理论通常关注很大的深度,而这种情况下通常会出现信号退化。因此一个很自然的问题就是,本文的简化 transformer 块所提高的训练速度是否也能扩展到更大的深度?

从图 6 中可以观察到,将深度从 18 个块扩展到 72 个块后,本研究的模型和 Pre-LN transformer 的性能都得到了提高,这表明本研究中的简化模型不仅训练速度更快,而且还能利用更大的深度所提供的额外能力。事实上,在使用归一化时,本研究中的简化块和 Pre-LN 的每次更新轨迹在不同深度下几乎没有区别。

图片

BERT

接下来,作者展示了他们的简化块性能除了适用于自回归解码器之外,还适用于不同的数据集和架构,以及下游任务。他们选择了双向仅编码器 BERT 模型的流行设置,用于掩蔽语言建模,并采用下游 GLUE 基准

如图 7 所示,在 24 小时运行时内,与(Crammed)Pre-LN 基线相比,本研究的简化块可以媲美掩蔽语言建模任务的预训练速度。另一方面,在不修改值和投影的情况下删除残差连接再次导致训练速度的显著下降。在图 24 中,作者提供了 microbatch 步骤的等效图。

图片

图片

此外,在表 1 中,研究者发现他们的方法在 GLUE 基准上经过微调后,性能与 Crammed BERT 基准相当。

图片

他们在表 2 中对下游任务进行了细分。为了进行公平比较,他们使用了与 Geiping & Goldstein (2023) 相同的微调协议(5 个 epoch、各任务超参数恒定、dropout regularisation)。

图片

效率提升

在表 1 中,研究者还详细列出了使用不同 Transformer 块的模型在掩蔽语言建模任务中的参数数量和训练速度。他们以预训练 24 小时内所采取的 microbatch 步骤数与基线 Pre-LN Crammed BERT 的比率计算了速度。结论是,模型使用的参数减少了 16%,SAS-P 和 SAS 的每次迭代速度分别比 Pre-LN 块快 16% 和 9%。

可以注意到,在这里的实现中,并行块只比 Pre-LN 块快 5%,而 Chowdhery et al.(2022 )观察到的训练速度则快 15%,这表明通过更优化的实现,整个训练速度有可能进一步提高。与 Geiping & Goldstein(2023 年)一样,此处实现也使用了 PyTorch 中的自动算子融合技术 (Sarofeen et al., 2022)。

更长的训练

最后,考虑到当前在更多数据上长时间训练较小模型的趋势,研究者讨论了简化块在长时间训练后是否仍能达到 Pre-LN 块的训练速度。为此,他们在 CodeParrot 上使用图 5 中的模型,并使用 3 倍 token 进行训练。准确地说,是在批大小为 128、序列长度为 128 的情况下进行了约 120K 步(而不是 40K 步)的训练,这将导致约 2B 个 token。

从图 8 可以看出,当使用更多的 token 进行训练时,简化的 SAS 和 SAS-P 代码块的训练速度仍然与 PreLN 代码块相当,甚至优于 PreLN 代码块。

图片

更多研究细节,可参考原论文。

工程ETH ZurichTransformer
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

堆叠技术

堆叠泛化是一种用于最小化一个或多个泛化器的泛化误差率的方法。它通过推导泛化器相对于所提供的学习集的偏差来发挥其作用。这个推导的过程包括:在第二层中将第一层的原始泛化器对部分学习集的猜测进行泛化,以及尝试对学习集的剩余部分进行猜测,并且输出正确的结果。当与多个泛化器一起使用时,堆叠泛化可以被看作是一个交叉验证的复杂版本,利用比交叉验证更为复杂的策略来组合各个泛化器。当与单个泛化器一起使用时,堆叠泛化是一种用于估计(然后纠正)泛化器的错误的方法,该泛化器已经在特定学习集上进行了训练并被询问了特定问题。

深度神经网络技术

深度神经网络(DNN)是深度学习的一种框架,它是一种具备至少一个隐层的神经网络。与浅层神经网络类似,深度神经网络也能够为复杂非线性系统提供建模,但多出的层次为模型提供了更高的抽象层次,因而提高了模型的能力。

推荐文章
暂无评论
暂无评论~