Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

机器之心编辑部机器之心报道

在表格数据上,为什么基于树的模型仍然优于深度学习?

为什么基于树的机器学习方法,如 XGBoost 和随机森林在表格数据上优于深度学习?本文给出了这种现象背后的原因,他们选取了 45 个开放数据集,并定义了一个新基准,对基于树的模型和深度模型进行比较,总结出三点原因来解释这种现象。

深度学习在图像、语言甚至音频等领域取得了巨大的进步。然而,在处理表格数据上,深度学习却表现一般。由于表格数据具有特征不均匀、样本量小、极值较大等特点,因此很难找到相应的不变量。

基于树的模型不可微,不能与深度学习模块联合训练,因此创建特定于表格的深度学习架构是一个非常活跃的研究领域。许多研究都声称可以击败或媲美基于树的模型,但他们的研究遭到很多质疑。

事实上,对表格数据的学习缺乏既定基准,这样一来研究人员在评估他们的方法时就有很多自由度。此外,与其他机器学习子域中的基准相比,大多数在线可用的表格数据集都很小,这使得评估更加困难。

为了缓解这些担忧,来自法国国家信息与自动化研究所、索邦大学等机构的研究者提出了一个表格数据基准,其能够评估最新的深度学习模型,并表明基于树的模型在中型表格数据集上仍然是 SOTA。

对于这一结论,文中给出了确凿的证据,在表格数据上,使用基于树的方法比深度学习(甚至是现代架构)更容易实现良好的预测,研究者并探明了其中的原因。
图片
论文地址:https://hal.archives-ouvertes.fr/hal-03723551/document

值得一提的是,论文作者之一是 Gaël Varoquaux ,他是 Scikit-learn 计划的领导者之一。目前该项目在 GitHub 上已成为最流行的机器学习库之一。而由 Gaël Varoquaux 参与的文章《Scikit-learn: Machine learning in Python》,引用量达 58949。
图片
本文贡献可总结为:

该研究为表格数据创建了一个新的基准(选取了 45 个开放数据集),并通过 OpenML 共享这些数据集,这使得它们易于使用。

该研究在表格数据的多种设置下比较了深度学习模型和基于树的模型,并考虑了选择超参数的成本。该研究还分享了随机搜索的原始结果,这将使研究人员能够廉价地测试新算法以获得固定的超参数优化预算。

在表格数据上,基于树的模型仍然优于深度学习方法

基准参考 45 个表格数据集,选择基准如下 :

  • 异构列,列应该对应不同性质的特征,从而排除图像或信号数据集。
  • 维度低,数据集 d/n 比率低于 1/10。
  • 无效数据集,删除可用信息很少的数据集。
  • I.I.D.(独立同分布)数据,移除类似流的数据集或时间序列。
  • 真实世界数据,删除人工数据集,但保留一些模拟数据集。
  • 数据集不能太小,删除特征太少(< 4)和样本太少(< 3 000)的数据集。
  • 删除过于简单的数据集。
  • 删除扑克和国际象棋等游戏的数据集,因为这些数据集目标都是确定性的。
 
在基于树的模型中,研究者选择了 3 种 SOTA 模型:Scikit Learn 的 RandomForest,GradientBoostingTrees (GBTs) , XGBoost

该研究对深度模型进行了以下基准测试:MLP、Resnet 、FT Transformer、SAINT 。

图 1 和图 2 给出了不同类型数据集的基准测试结果

图片
图片
实证调查:为什么基于树的模型在表格数据上仍然优于深度学习

归纳偏差。基于树的模型在各种超参数选择中击败了神经网络。事实上,处理表格数据的最佳方法有两个共有属性:它们是集成方法、bagging(随机森林)或 boosting(XGBoost、GBT),而这些方法中使用的弱学习器是决策树。

发现 1:神经网络(NN)倾向于过度平滑的解决方案

如图 3 所示,对于较小的尺度,平滑训练集上的目标函数会显着降低基于树的模型的准确率,但几乎不会影响 NN。这些结果表明,数据集中的目标函数并不平滑,与基于树的模型相比,NN 难以适应这些不规则函数。这与 Rahaman 等人的发现一致,他们发现 NN 偏向于低频函数。基于决策树的模型学习分段(piece-wise)常函数,没有这样的偏置。
图片
发现 2:非信息特征更能影响类似 MLP 的 NN

表格数据集包含许多非信息( uninformative)特征,对于每个数据集,该研究根据特征的重要性会选择丢弃一定比例的特征(通常按随机森林排序)。从图 4 可以看出,去除一半以上的特征对 GBT 的分类准确率影响不大。
图片
图 5 可以看到移除非信息特征 (5a) 减少了 MLP (Resnet) 与其他模型(FT Transformers 和基于树的模型)之间的性能差距 ,而添加非信息特征会扩大差距,这表明 MLP 对非信息特征的鲁棒性较差。在图 5a 中,当研究者移除更大比例的特征时,相应的也会删除有用信息特征。图 5b 表明,去除这些特征所带来的准确率下降可以通过去除非信息特征来补偿,与其他模型相比,这对 MLP 更有帮助(同时,该研究还删除了冗余特性,也不会影响模型性能)。
图片
发现 3:通过旋转,数据是非不变的

与其他模型相比,为什么 MLP 更容易受到无信息特征的影响?其中一个答案是,MLP 是旋转不变的:当对训练集和测试集特征应用旋转时,在训练集上学习 MLP 并在测试集上进行评估,这一过程是不变的。事实上,任何旋转不变的学习过程都具有最坏情况下的样本复杂度,该复杂度至少在不相关特征的数量上呈线性增长。直观地说,为了去除无用特征,旋转不变算法必须首先找到特征的原始方向,然后选择信息最少的特征。
 
图 6a 显示了当对数据集进行随机旋转时的测试准确率变化,证实只有 Resnets 是旋转不变的。值得注意的是,随机旋转颠倒了性能顺序:结果是 NN 在基于树的模型之上,Resnets 在 FT Transformer 之上,这表明旋转不变性是不可取的。事实上,表格数据通常具有单独含义,例如年龄、体重等。

图 6b 中显示:删除每个数据集中最不重要的一半特征(在旋转之前),会降低除 Resnets 之外的所有模型的性能,但与没有删除特征使用所有特征时相比,相比较而言,下降的幅度较小。
图片
原文链接:https://twitter.com/GaelVaroquaux/status/1549422403889
理论表格数据
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

超参数优化技术

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

随机森林技术

在机器学习中,随机森林是一个包含多个决策树的分类器,并且其输出的类别是由个别树输出的类别的众数而定。 Leo Breiman和Adele Cutler发展出推论出随机森林的算法。而"Random Forests"是他们的商标。这个术语是1995年由贝尔实验室的Tin Kam Ho所提出的随机决策森林(random decision forests)而来的。这个方法则是结合Breimans的"Bootstrap aggregating"想法和Ho的"random subspace method" 以建造决策树的集合。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

集成方法技术

在统计学和机器学习中,集成方法使用多种学习算法来获得比单独使用任何组成学习算法更好的预测性能。

目标函数技术

目标函数f(x)就是用设计变量来表示的所追求的目标形式,所以目标函数就是设计变量的函数,是一个标量。从工程意义讲,目标函数是系统的性能标准,比如,一个结构的最轻重量、最低造价、最合理形式;一件产品的最短生产时间、最小能量消耗;一个实验的最佳配方等等,建立目标函数的过程就是寻找设计变量与目标的关系的过程,目标函数和设计变量的关系可用曲线、曲面或超曲面表示。

独立同分布技术

在概率论与统计学中,独立同分布(缩写为IID)是指一组随机变量中每个变量的概率分布都相同,且这些随机变量互相独立。一组随机变量独立同分布并不意味着它们的样本空间中每个事件发生概率都相同。例如,投掷非均匀骰子得到的结果序列是独立同分布的,但掷出每个面朝上的概率并不相同。

随机搜索技术

XGBoost技术

XGBoost是一个开源软件库,为C ++,Java,Python,R,和Julia提供了渐变增强框架。 它适用于Linux,Windows,MacOS。从项目描述来看,它旨在提供一个“可扩展,便携式和分布式的梯度提升(GBM,GBRT,GBDT)库”。 除了在一台机器上运行,它还支持分布式处理框架Apache Hadoop,Apache Spark和Apache Flink。 由于它是许多机器学习大赛中获胜团队的首选算法,因此它已经赢得了很多人的关注。

推荐文章
暂无评论
暂无评论~