Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

Steeve Huang作者萝卜兔编辑整理

简单的图神经网络介绍

最近,Graph Neural Network(GNN)在很多领域日益普及,包括社交网络、知识图谱推荐系统甚至于生命科学。GNN在对节点关系建模方面表现十分突出,使得相关的研究领域取得了一定突破。本文旨在对GNN做一个简单的介绍,并介绍两种前沿算法,DeepWalk和GraphSage。

Graph

在学习GNN之前,先让我们了解一下什么是Graph。在计算机科学中,graph是一种数据结构,由两部分组成,顶点和边。图G可以通过顶点集V和边集E来描述:

根据顶点之间是否有方向,边可以分为无向和有向。

有向图

顶点又称节点,本文中两者可以互换。

Graph Neural Network

GNN是直接在图数据结构上运行的神经网络。GNN的典型应用便是节点分类。图中的每个节点都有一个标签,我们希望不需要标注数据,可以预测新的节点标签。本节将讲解论文《The graph neural network model》中的GNN算法,算得上第一个GNN。

在节点分类问题中,每个节点v的特征用xv表示,并且和标签tv相关联。给定部分标记的图G,目标是利用这些标记的节点来预测未标记节点的标签。网络学会用d维向量(状态)hv表示每个节点,其中包含其邻域信息。

其中,xco[v]表示与v连接的边的特征,hne[v]表示v的相邻节点嵌入特征,xne[v]表示v的相邻节点的特征。函数f是将这些输入投影到d维空间的传递函数。由于我们正在寻找hv的唯一解,可以应用Banach不动点定理并将上述等式重写为迭代更新过程。

HX分别表示所有hx的连接。通过将状态hv以及特征xv传递给输出函数g来计算GNN的输出。

f和这里的g都可以解释为前馈全连接神经网络。 L1损失可以直接表述如下:

再通过梯度下降优化。上述的GNN算法有三个限制:

1、如果放宽“固定点”的假设,可以利用多层感知机来学习更稳定的表示,并删除迭代更新过程。这是因为,在该提议中,不同的迭代使用传递函数f的相同参数,而不同MLP层中的不同参数允许分层特征提取。

2、它不能处理边信息(例如知识图中的不同边可能表示节点之间的不同关系)。

3、固定点可以阻止节点分布的多样化,因此可能不适合学习表示节点。

当然,已经有几种GNN的变体来解决上述问题,我在这里不展开讲解了。

DeepWalk

DeepWalk是以无监督的方式学习node embedding的算法。它的训练过程非常类似word embedding。动机是图表中的节点和语料库中的单词的分布遵循幂定律,如下图所示:

算法包括两个步骤:

1、在图中的节点上进行随机游走以生成节点序列;

2、使用skip-gram,根据步骤1中生成的节点序列学习每个节点的嵌入。

在随机游走的每个时间步骤,从前一节点的邻居统一采样下一个节点。然后将每个序列截短为长度为2|w|+1的子序列,其中w表示skip-gram中的窗口大小。

在提出DeepWalk的论文中,分层softmax用于解决由于节点数量庞大而导致的softmax计算成本高昂的问题。为了计算每个单独输出元素的softmax值,我们必须计算所有元素k的所有exk

因此,原始softmax的计算时间为O(|V|),其中V表示图中的顶点集。

分层softmax利用二叉树来处理问题。在这个二叉树中,所有叶子(上图中的v1v2,..., v8)都是图中的顶点。在每个内部节点中,有一个二元分类器来决定选择哪条路径。为了计算给定顶点vk的概率,可以简单地计算沿着从根节点到离开vk的路径中的每个子路径的概率。由于每个节点的子概率为1,因此所有顶点的概率之和等于1的特性仍然保持在分层softmax中。现在,元素的计算时间减少到O(log|V|),因为二叉树的最长路径由O(log|n|)限定,其中是n叶子的数量。

分层Softmax

在训练DeepWalk GNN之后,模型已经学习了每个节点的良好表示,如下图所示。不同的颜色表示输入图中的不同标签。我们可以看到,在输出图形(嵌入2维)中,具有相同标签的节点聚集在一起,而具有不同标签的大多数节点被正确分开。

然而,DeepWalk的主要问题是缺乏泛化能力。每当有新节点加入时,它必须重新训练模型以表示该节点。因此,这种GNN不适用于图中节点不断变化的动态图。

GraphSage

GraphSage提供解决上述问题的方案,以归纳方式学习每个节点的嵌入。具体而言,每个节点由其邻域的聚合表示。因此,即使在训练时间内看不到的新节点出现在图中,它仍然可以由其相邻节点正确地表示。下面显示了GraphSage的算法。

外层循环表示更新迭代次数,而hvk表示更新迭代k时节点v的特征。在每次更新迭代时,基于聚合函数,前一次迭代中vv邻域的特征以及权重矩阵Wk来更新hvk。本文提出了三种聚合函数:

1. Mean aggregator

平均聚合器获取节点及其所有邻域的特征的平均值。

与原始方程相比,它删除了上述伪代码中第5行的连接操作。此操作可以被视为“跳过连接”,本文稍后将证明可以在很大程度上提高模型的性能。

2. LSTM aggregator

由于图中的节点没有任何顺序,因此它们通过置换这些节点来随机分配顺序。

3. Pooling aggregator

此运算符在相邻集上执行逐元素池化功能。下面显示了max-pooling示例:

可以用mean-pooling或任何其他对称池化函数替换。文章指出pooling aggregator执行最佳,而mean-pooling和max-pooling具有相似的性能。本文使用max-pooling作为默认聚合函数。

损失函数定义如下:

其中uv共同出现在固定长度的随机游走中,而vn是不与u共同出现的负样本。这种损失函数鼓励具有类似嵌入的节点更接近,而那些相距很远的节点在投影空间中分离。通过这种方法,节点将获得越来越多关于其邻域的信息。

GraphSage通过聚合其附近的节点,可以为看不见的节点生成可表示的嵌入。它允许将节点嵌入应用于涉及动态图的域,其中图的结构不断变化。例如,Pinterest采用了GraphSage的扩展版本PinSage作为其内容发现系统的核心。

总结

我们已经学习了图形神经网络,DeepWalk和GraphSage的基础知识。 GNN在复杂图形结构建模中的强大功能确实令人惊讶。鉴于其有效性,我相信,在不久的将来,GNN将在人工智能的发展中发挥重要作用。

相关参考

原文链接:

https://towardsdatascience.com/a-gentle-introduction-to-graph-neural-network-basics-deepwalk-and-graphsage-db5d540d50b3

参考论文:

https://arxiv.org/pdf/1812.08434.pdf

http://www.perozzi.net/publications/14_kdd_deepwalk.pdf

https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf

http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.1015.7227&rep=rep1&type=pdf

极验
极验

极验是全球顶尖的交互安全技术服务商,于2012年在武汉成立。全球首创 “行为式验证技术” ,利用生物特征与人工智能技术解决交互安全问题,为企业抵御恶意攻击防止资产损失提供一站式解决方案。

入门图神经网络GNNDeepWalkGraphSage
9
相关数据
池化技术

池化(Pooling)是卷积神经网络中的一个重要的概念,它实际上是一种形式的降采样。有多种不同形式的非线性池化函数,而其中“最大池化(Max pooling)”是最为常见的。它是将输入的图像划分为若干个矩形区域,对每个子区域输出最大值。直觉上,这种机制能够有效的原因在于,在发现一个特征之后,它的精确位置远不及它和其他特征的相对位置的关系重要。池化层会不断地减小数据的空间大小,因此参数的数量和计算量也会下降,这在一定程度上也控制了过拟合。通常来说,CNN的卷积层之间都会周期性地插入池化层。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

人工智能技术

在学术研究领域,人工智能通常指能够感知周围环境并采取行动以实现最优的可能结果的智能体(intelligent agent)

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

伪代码技术

伪代码,又称为虚拟代码,是高层次描述算法的一种方法。它不是一种现实存在的编程语言;它可能综合使用多种编程语言的语法、保留字,甚至会用到自然语言。 它以编程语言的书写形式指明算法的职能。相比于程序语言它更类似自然语言。它是半形式化、不标准的语言。

知识图谱技术

知识图谱本质上是语义网络,是一种基于图的数据结构,由节点(Point)和边(Edge)组成。在知识图谱里,每个节点表示现实世界中存在的“实体”,每条边为实体与实体之间的“关系”。知识图谱是关系的最有效的表示方式。通俗地讲,知识图谱就是把所有不同种类的信息(Heterogeneous Information)连接在一起而得到的一个关系网络。知识图谱提供了从“关系”的角度去分析问题的能力。 知识图谱这个概念最早由Google提出,主要是用来优化现有的搜索引擎。不同于基于关键词搜索的传统搜索引擎,知识图谱可用来更好地查询复杂的关联信息,从语义层面理解用户意图,改进搜索质量。比如在Google的搜索框里输入Bill Gates的时候,搜索结果页面的右侧还会出现Bill Gates相关的信息比如出生年月,家庭情况等等。

推荐系统技术

推荐系统(RS)主要是指应用协同智能(collaborative intelligence)做推荐的技术。推荐系统的两大主流类型是基于内容的推荐系统和协同过滤(Collaborative Filtering)。另外还有基于知识的推荐系统(包括基于本体和基于案例的推荐系统)是一类特殊的推荐系统,这类系统更加注重知识表征和推理。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

梯度下降技术

梯度下降是用于查找函数最小值的一阶迭代优化算法。 要使用梯度下降找到函数的局部最小值,可以采用与当前点的函数梯度(或近似梯度)的负值成比例的步骤。 如果采取的步骤与梯度的正值成比例,则接近该函数的局部最大值,被称为梯度上升。

语料库技术

语料库一词在语言学上意指大量的文本,通常经过整理,具有既定格式与标记;事实上,语料库英文 "text corpus" 的涵意即为"body of text"。

分类问题技术

分类问题是数据挖掘处理的一个重要组成部分,在机器学习领域,分类问题通常被认为属于监督式学习(supervised learning),也就是说,分类问题的目标是根据已知样本的某些特征,判断一个新的样本属于哪种已知的样本类。根据类别的数量还可以进一步将分类问题划分为二元分类(binary classification)和多元分类(multiclass classification)。

推荐文章
暂无评论
暂无评论~