魔王报道

数学奥赛冠军都做不对的题,却被拿来考ML模型?GPT-3:我不行

为了衡量机器学习模型的数学求解能力,来自 UC 伯克利和芝加哥大学的研究者提出了一个包含 12, 500 道数学竞赛难题的新型数据集 MATH,以及帮助模型学习数学基础知识的预训练数据集 AMPS。研究发现,即使是大参数的 Transformer 模型准确率也很低。

许多学术研究探讨数学问题求解,但对于计算机而言这超出了其能力范畴。那么机器学习模型是否具备数学问题求解能力呢?

来自加州大学伯克利分校和芝加哥大学的研究者为此创建了一个新型数据集 MATH。该数据集包含 12, 500 道数学竞赛难题,每个数学题都有完整的逐步求解过程,可用来教机器学习模型生成答案和解释。为了促进未来研究,提升模型在 MATH 数据集上的准确率,研究者还创建了另一个大型辅助预训练数据集,它可以教模型数学基础知识。

尽管通过这些方法提升了模型在 MATH 数据集上的准确率,但实验结果表明,准确率仍然很低,即使 Transformer 模型也不例外。研究者还发现,仅靠增加预算和模型参数量并不能实现强大的数学推理能力。扩展 Transformer 能够自动解决大多数文本任务,但目前仍无法解决 MATH 问题。

该研究第一作者 Dan Hendrycks 发推表示:

国际数学奥林匹克竞赛(IMO)三金得主能达到 90% 的准确率,而 GPT-3 的准确率只能达到约 5%。

如果这一趋势持续下去,那么机器学习模型距离获得数学推理能力还很遥远。

数据集

这部分介绍两个新型数据集,一个是用于测试模型数学问题求解能力的 MATH 数据集,另一个是用于辅助预训练的 AMPS 数据集。

MATH 数据集

MATH 数据集包含 12, 500 个数学问题(其中 7500 个属于训练集,5000 个属于测试集),这些问题收集自 AMC 10、AMC 12、AIME 等数学竞赛(这些数学竞赛已经持续数十年,旨在评估美国最优秀的年轻数学人才的数学问题求解能力)。与大多数之前的研究不同,MATH 数据集中的大部分问题无法通过直接应用标准 K-12 数学工具来解决,人类解决这类问题通常需要用到问题求解技术和「启发式」方法。

基于这些数学问题,模型可以学习多种有用的问题求解启发式方法,且每个问题都有逐步求解过程和最终答案。具备逐步求解过程的问题示例参见下图 1:

该数据集的创建涉及以下重要步骤:

  • 问题分类:该数据集中的问题难度不同,并涉及多个主题,包括算术、代数、数论、计数与概率、几何、中级代数、预备微积分。研究者按照对人类而言从易到难的程度将问题难度等级标注为 1-5。

  • 格式化:使用 LATEX 和 Asymptote 矢量图语言将数学问题及其解进行统一格式化。

  • 自动评估生成的答案:MATH 数据集的独特设计使得研究者可以自动评估模型生成的答案,即使模型输出空间非常大。

  • 人类性能:为了估计人类性能,研究者从 MATH 测试集中随机采样了 20 个问题,交由高校学生回答。一位不喜欢数学的参与者答对了 8 道题(准确率 40%),两位喜欢数学的参与者分别答对了 14 题和 15 题,一位在 AMC 10 数学竞赛中拿到满分并多次参加 USAMO 竞赛的参与者答对了 18 道题,一位 IMO 三金得主也答对了 18 道题(准确率 90%)。这说明 MATH 数据集中的数学问题对于人类而言也是有一定难度的。

AMPS 数据集(可汗学院 + Mathematica)

预训练数据会对性能产生极大影响,而数学是在线文本的一小部分,于是该研究创建了一个大型多样化的数学预训练语料库。该预训练数据集 Auxiliary Mathematics Problems and Solutions (AMPS) 包括许多问题及 LATEX 格式的逐步求解过程。

AMPS 数据集包含 10 万个收集自可汗学院的数学问题,和约 500 万通过手动设计 Mathematica 脚本生成的问题。该研究使用 Mathematica 的计算机代数系统生成数学问题,是为了便于操作分数、超越数和解析函数。

这些问题涉及多个主题,包括代数、微积分、计数与统计、几何、线性代数,以及数论(参见下表 1)。

实验

模型性能

研究者通过实验调查了模型在 MATH 数据集上的性能,发现即使最优模型的准确率也很低。此外,与大多数基于文本的数据集不同,该数据集上的准确率增速随着模型规模的扩大而越来越慢。如果这一趋势继续,则要想在 MATH 数据集上取得较大进展,我们需要的不只是模型扩展,而是算法改进。

下表 2 表明,最小模型 GPT-2(0.1 billion 参数量,基线模型)在 MATH 数据集多个主题上的平均准确率为 5.4%,而 GPT-2(1.5 billion 参数量,参数量是基线模型的 15 倍)的平均准确率为 6.9%,相比基线提升了 28%。这表明与大部分其它基于文本的任务不同,在 MATH 数据集上增加模型参数确实有所帮助,但模型的绝对准确率仍然很低,且增速缓慢。

此外,研究者测试了使用 AMPS 预训练的效果。未经 AMPS 预训练时,GPT-2 (1.5B) 模型在 MATH 数据集上的准确率为 5.5%;而经过 AMPS 预训练后,GPT-2 (1.5B) 在 MATH 数据集上的准确率为 6.9%(参见表 2),准确率提升了 25%。也就是说,AMPS 预训练对准确率的提升效果相当于参数量 15 倍增加的效果,这表明 AMPS 预训练数据集是有价值的。

逐步求解

研究者对逐步求解过程进行了实验,发现模型在得到答案前先生成逐步求解过程会导致准确率下降。研究者利用 GPT-2 (1.5B) 进行评估,发现模型性能有所下降,从 6.9% 下降到了 5.3%。

研究者还对这些生成的逐步求解过程进行了定性评估,发现尽管很多步骤看似与问题相关,但其实存在逻辑问题。示例参见下图 3、4:

图 3:问题、GPT-2 (1.5B) 模型生成的逐步解、真值解。

图 4:问题、生成解和真值解示例。

不过,研究人员发现逐步求解仍能带来一定好处:提供部分真值逐步求解过程可以提升性能,在训练过程中为模型提供逐步求解过程可以提升准确率。下图 6 展示了 GPT-2 (0.7B) 模型使用不同部分求解过程的准确率变化。

入门数学问题GPT-3
相关数据
机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

逻辑技术

人工智能领域用逻辑来理解智能推理问题;它可以提供用于分析编程语言的技术,也可用作分析、表征知识或编程的工具。目前人们常用的逻辑分支有命题逻辑(Propositional Logic )以及一阶逻辑(FOL)等谓词逻辑。

微积分技术

微积分(Calculus)是高等数学中研究函数的微分(Differentiation)、积分(Integration)以及有关概念和应用的数学分支。它是数学的一个基础学科。内容主要包括极限、微分学、积分学及其应用。微分学包括求导数的运算,是一套关于变化率的理论。它使得函数、速度、加速度和曲线的斜率等均可用一套通用的符号进行讨论。积分学,包括求积分的运算,为定义和计算面积、体积等提供一套通用的方法 。

线性代数技术

线性代数是数学的一个分支,它的研究对象是向量,向量空间(或称线性空间),线性变换和有限维的线性方程组。向量空间是现代数学的一个重要课题;因而,线性代数被广泛地应用于抽象代数和泛函分析中;通过解析几何,线性代数得以被具体表示。线性代数的理论已被泛化为算子理论。由于科学研究中的非线性模型通常可以被近似为线性模型,使得线性代数被广泛地应用于自然科学和社会科学中。

算术技术

算术(英语:arithmetic)是数学最古老且最简单的一个分支,几乎被每个人使用着,从日常生活上简单的算数到高深的科学及工商业计算都会用到。一般而言,算术这一词指的是记录数字某些运算基本性质的数学分支。

推荐文章
暂无评论
暂无评论~