Sergei Ivanov等作者小舟、蛋酱编辑

让GBDT和GNN结合起来:Criteo AI Lab提出全新架构BGNN

GBDT 和 GNN 方法各有各的优势,现在,来自法国、俄罗斯两家机构的研究者将二者的优势结合起来,探索使用 GBDT 模型处理图结构数据。


图片



论文地址:https://openreview.net/pdf?id=ebS5NUfoMKL


神经网络(GNN)已经在学习图结构方面取得了巨大成功,应用于分子设计、计算机视觉、组合优化、推荐系统等多个方面。该领域的进展依靠于规范的 GNN 架构的存在,该架构将原始输入数据有效地编码为表达型表征,从而在新的数据集和任务上获得高质量的结果。

最近的一些研究主要集中于具有稀疏数据的 GNN 上,这些数据代表同质节点嵌入(例如 one-hot 编码的图统计)或者词袋表征。但是在许多情况下,具有详细信息和丰富语义的表格数据更为自然,现实世界的人工智能也更丰富。例如在社交网络中,每个人都有社会人口统计学特征(例如年龄、性别、毕业日期),这些特征在数据类型、规模和缺失值上有很大差异。带有表格数据图形的 GNN 的研究是缺少的,梯度提升决策树(GBDT)在具有此类异构数据的应用程序中占主导地位。


GBDT 非常适用于表格数据,因为它们具有以下特性:

  • 能够有效地学习表格数据中常见的具有超平面边界的决策空间;

  • 非常适合处理基数高、值缺失且比例不同的变量;

  • 它们为决策树或通过事后分析阶段的集合提供定性解释;

  • 在实际应用中,对于大量数据,它们收敛速度更快。


相反,GNN 的关键特征使它们同时考虑节点的邻域信息和节点特征来进行预测,这与 GBDT 不同,GBDT 需要额外的预处理分析才能为算法提供图摘要(例如通过无监督图嵌入)。此外,理论上已经证明,通过消息传递的 GNN 可以在其图输入上计算任何可由图灵机器计算的函数,即 GNN 是唯一在图上具有通用性的学习架构(近似化和可计算性)。


此外,与基于树的方法相比,基于梯度的神经网络学习具有多种优势:

  • GNN 中含有的关系归纳偏差减轻了手动设计捕获网络拓扑特征的需求;

  • 训练神经网络的端到端属性允许在依赖于应用程序的解决方案中将 GNN 进行多阶段或多组件集成;

  • 采用图网络的预训练表征丰富了迁移学习中的许多重要任务,例如无监督领域自适应、自监督学习主动学习机制。


显然,GBDT 和 GNN 方法都有着明显的优势,可以将二者的优势结合起来吗?此前所有尝试将梯度提升神经网络结合起来的方法在计算上都很繁琐,没有考虑图结构化数据,并且缺乏 GNN 架构中包含的关系偏向。


本研究是第一个探索使用 GBDT 模型处理图结构数据的研究。在这篇论文中,研究者提出了一种针对含表格数据的图的新型学习架构 BGNN,该架构将 GBDT 对表格节点特征的学习与 GNN 相结合,从而利用图的拓扑优化预测。这使 BGNN 可以继承梯度提升方法(异构学习和可解释性)和图网络(表征学习和端到端训练)的优势。

总体而言,该研究的贡献有:

  • 设计了一种新的通用体系架构,将 GBDT 和 GNN 组合为一个 pipeline;

  • 通过迭代添加适合 GNN 梯度更新的新树,该研究克服了 GBDT 的端到端训练的挑战,使得错误信号可从网络拓扑反向传播到 GBDT;

  • 研究者针对节点预测任务中的强基准对方法进行了广泛的评估,实验结果表明,在各种现实表格数据中,异构节点回归和节点分类任务的性能显著提高;

  • 由于训练过程中的损失收敛速度更快,该方法比其他 GNN SOTA 模型更有效。此外,学习到的表征在潜在空间中展现出可辨别的结构,这进一步证明了该方法的表达能力。



方法介绍

GBDT 和 GNN 的优化遵循不同的方法:GNN 的参数通过梯度下降进行优化,而 GBDT 则是迭代构建的,并且决策树在构建之后仍然保持固定(决策树基于特征空间的硬拆分,因此不可微分)。


一种简单的方法是仅在节点特征上训练 GBDT 模型,然后将得到的 GBDT 预测与原始输入一起用作 GNN 的新节点特征。在这种情况下,将通过神经网络进一步完善 GBDT 对图不敏感的预测。这种方法(被称为 Res-GNN)已经可以提高某些任务的 GNN 性能,但 GBDT 模型将完全忽略图的结构,并可能会丢失图的描述性特征,从而向 GNN 提供不准确的输入数据。

相反,该研究提出对 GBDT 和 GNN 进行端到端的训练,称为 BGNN(Boost-GNN)。该研究首先应用 GBDT,然后再应用 GNN。但考虑到最终预测的质量,该研究对它们进行了优化,BGNN 的训练如图 1 所示。已经构建好的决策树由于其离散的结构而无法正确调整,因此该研究通过添加新的树来迭代地更新 GBDT 模型,使其近似于 GNN 损失函数


图片

图 1:BGNN 的训练过程,计算每一个 epoch 的步数。

下图算法 1 展示了 BGNN 模型的训练算法,它可用于半监督节点回归与分类这样的任意节点级别的预热问题。

图片


更多实验结果读者可以查看原论文。
理论GNNGBDT
相关数据
人工智能技术

在学术研究领域,人工智能通常指能够感知周围环境并采取行动以实现最优的可能结果的智能体(intelligent agent)

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

梯度提升技术

梯度提升是用于回归和分类问题的机器学习技术,其以弱预测模型(通常为决策树)的集合的形式产生预测模型。 它像其他增强方法一样以阶段式方式构建模型,并且通过允许优化任意可微损失函数来推广它们。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

表征学习技术

在机器学习领域,表征学习(或特征学习)是一种将原始数据转换成为能够被机器学习有效开发的一种技术的集合。在特征学习算法出现之前,机器学习研究人员需要利用手动特征工程(manual feature learning)等技术从原始数据的领域知识(domain knowledge)建立特征,然后再部署相关的机器学习算法。虽然手动特征工程对于应用机器学习很有效,但它同时也是很困难、很昂贵、很耗时、并依赖于强大专业知识。特征学习弥补了这一点,它使得机器不仅能学习到数据的特征,并能利用这些特征来完成一个具体的任务。

计算机视觉技术

计算机视觉(CV)是指机器感知环境的能力。这一技术类别中的经典任务有图像形成、图像处理、图像提取和图像的三维推理。目标识别和面部识别也是很重要的研究领域。

推荐系统技术

推荐系统(RS)主要是指应用协同智能(collaborative intelligence)做推荐的技术。推荐系统的两大主流类型是基于内容的推荐系统和协同过滤(Collaborative Filtering)。另外还有基于知识的推荐系统(包括基于本体和基于案例的推荐系统)是一类特殊的推荐系统,这类系统更加注重知识表征和推理。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

梯度下降技术

梯度下降是用于查找函数最小值的一阶迭代优化算法。 要使用梯度下降找到函数的局部最小值,可以采用与当前点的函数梯度(或近似梯度)的负值成比例的步骤。 如果采取的步骤与梯度的正值成比例,则接近该函数的局部最大值,被称为梯度上升。

图灵机技术

图灵机,又称确定型图灵机,是英国数学家艾伦·图灵于1936年提出的一种抽象计算模型,其更抽象的意义为一种数学逻辑机,可以看作等价于任何有限逻辑数学过程的终极强大逻辑机器。

图神经网络技术

图网络即可以在社交网络或其它基于图形数据上运行的一般深度学习架构,它是一种基于图结构的广义神经网络。图网络一般是将底层图形作为计算图,并通过在整张图上传递、转换和聚合节点特征信息,从而学习神经网络基元以生成单节点嵌入向量。生成的节点嵌入向量可作为任何可微预测层的输入,并用于节点分类或预测节点之间的连接,完整的模型可以通过端到端的方式训练。

主动学习技术

主动学习是半监督机器学习的一个特例,其中学习算法能够交互式地查询用户(或其他信息源)以在新的数据点处获得期望的输出。 在统计学文献中,有时也称为最佳实验设计。

图网技术

ImageNet 是一个计算机视觉系统识别项目, 是目前世界上图像识别最大的数据库。

图网络技术

2018年6月,由 DeepMind、谷歌大脑、MIT 和爱丁堡大学等公司和机构的 27 位科学家共同提交了论文《Relational inductive biases, deep learning, and graph networks》,该研究提出了一个基于关系归纳偏置的 AI 概念:图网络(Graph Networks)。研究人员称,该方法推广并扩展了各种神经网络方法,并为操作结构化知识和生成结构化行为提供了新的思路。

自监督学习技术

一个例子中的内容特别多,而用一个例子做一个任务,就等于把其他的内容浪费了,因此我们需要从一个样本中找出多个任务。比如说遮挡图片的一个特定部分,用没遮挡部分来猜遮挡的部分是一个任务。那么通过遮挡不同的部分,就可以用一个样本完成不同任务。Yann Lecun描述的这个方法被业界称作「自监督学习」

节点分类技术

节点分类任务是算法必须通过查看其邻居的标签来确定样本的标记(表示为节点)的任务。

迁移学习技术

迁移学习 是属于机器学习的一种研究领域。它专注于存储已有问题的解决模型,并将其利用在其他不同但相关问题上。比如说,用来辨识汽车的知识(或者是模型)也可以被用来提升识别卡车的能力。计算机领域的迁移学习和心理学常常提到的学习迁移在概念上有一定关系,但是两个领域在学术上的关系非常有限。

推荐文章
暂无评论
暂无评论~