Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

KAUST & Intel ISL作者

CVPR 2020:基于贪心思想的CNN/GCN网络结构搜索算法SGAS

本工作通过贪心的搜索方式减轻了NAS中模型排名在搜索和最后评估不一致的问题。是一种更优更快的网络结构搜索算法,并同时支持CNN和GCN的搜索。代码已开源,想在图像,点云,生物图数据上做网络结构搜索的同学都可以试一试。

Code: https://github.com/lightaime/sgas

Arxiv: https://arxiv.org/abs/1912.00195

Project webpage: https://www.deepgcns.org/auto/sgas

相关工作

网络结构搜索(Neural Architecture Search, 简称NAS) 是一种神经网络结构自动化设计的技术。NAS基于相应算法在特定的样本集内自动设计出高性能的网络结构。这些自动搜索出的网络结构在某些任务上已经媲美或超过了人类专家手工设计的网络结构。

早期NAS的算法是基于强化学习(Zoph et al.[1])或进化算法(Real et al.[2])。这些算法计算成本高昂,阻碍了其广泛应用。 近来,Liu et al.[3]提出了一种高效的可微分的网络结构搜索算法:可微分网络结构搜索(Differentiable Architecture Search, 简称DARTS)。DARTS的提出使得网络结构搜索在单卡一天内完成搜索。后续许多工作都基于DARTS基础上进行改进,比如SNAS,FBNet,ProxylessNAS,P-DARTS,GDAS,MdeNAS,PC-DARTS,FairDARTS等等。

背景知识:DARTS

DARTS采用基于单元(Cell)的搜索方法进行网络结构搜索。Cell是一个网络子模块,可以自由堆叠多次形成卷积网络。DARTS通过学习cell的结构,完成对网络的结构搜索。Cell是由N个节点的有序序列组成的有向无环图(如图1)。 Cell中每个节点x^(i)是卷积网络中的特征图,每个有向边(i,j)代表一种对x的运算o(i, j) (如3x3的卷积)。 一个cell具有两个输入节点,一个输出节点和多个中间结点。 Cell的输入节点被定义为前两层的输出。 Cell的输出是对所有中间节点进行归约运算(例如concatenation)后的结果。每个中间节点由它之前结点经过算子op变换后相加得到:

DARTS为了实现可微搜索,提出了搜索空间的continuous relaxation机制,利用softmax函数来学习所有可能候选运算op的权重

其中,$\mathcal(O)$ 表示搜索空间中的候选运算(例如卷积,最大池化,零)等,零表示没有运算(边)。 其中一对节点(i,j)之间的运算由向量\alpha_{i, j}参数化。运算的结果是每种可能运算结果的加权求和。向量\alpha_{i, j}的维度为搜索空间长度|O|。因而,DARTS将网络结构搜索的任务简化为了学习一组连续变量\alpha= {\alpha(i, j)},如图1所示。DARTS在搜索结束阶段,通过argmax得到权重最大的候选运算op当作该边的搜索结果:

图1 DARTS的单元结构以及网络结构搜索示意图[3]

DARTS存在的部分问题

一般NAS的流程是分为搜索阶段和评估阶段,在训练集与验证集上进行结构的搜索(搜索阶段),然后在测试集上进行模型的评估(评估阶段)。DARTS在搜索结束的阶段直接得到网络结构,如图1和公式3所示。 这种策略导致在搜索和评估阶段派生的网络结构的性能相关性非常低。网络在搜索阶段的效果,并不能反映其在评估阶段的真实效果。我们将这种现象称之为:退化的搜索评估相关性(degenerate search-evaluation correlation)。我们认为造成这种想象的主要原因是:(1)搜索阶段和评估阶段设置的不一致性;(2)权重共享(weight sharing)造成的副作用。

举个例子:假设我们只搜索3种候选运算,skip-connect,3x3卷积,5x5卷积。如果搜索时分配的权重分别是skip-connect (0.34),3x3卷积(0.33),5x5卷积(0.33),最后所选的操作会是没有可学习参数的skip-connect,如果所有的边都是这种情况,那么最后的网络在评价阶段性能就会很差,然而在搜索阶段这个网络和权重分配为skip-connect (0.33),3x3卷积(0.34),5x5卷积(0.33),最后会得到3x3卷积的网络性能几乎不会有区别。这一现象并不仅仅发生在DARTS上,也发生其他大部分NAS算法上,这严重影响了NAS的性能。

肯德尔系数\tau[4]可用于量化搜索评估相关性( search-evaluation correlation)。肯德尔系数介于-1到1,-1表示完全负相关,1表示完全正相关。 如果肯德尔系数为0,则分布完全独立。理想的NAS方法应具有较高的搜索评估相关性\tau。 我们以DARTS 为例, 在CIFAR-10数据集上运行10次,分别根据搜索准确性和最终评估准确性排名,计算其肯德尔系数。 一阶和二阶的DARTS的肯德尔系数分别仅为0.16和-0.29。 因此,DARTS算法的搜索评估相关性极低,无法根据DARTS在搜索阶段的效果预测模型测试阶段的准确性。

图2 搜索-测试相关性可以用肯德尔系数衡量。 常见的NAS算法,如DARTS,肯德尔系数低,无法根据其搜索阶段对最终测试的准确性做出可靠的预测。

SGAS详细方案

针对退化的搜索评估相关性这一重要的问题,我们提出了SGAS(Sequential Greedy Architecture Search),一种顺序贪心决策的搜索算法。图2 SGAS的网络结构搜索示意图我们基于贪心算法的思想将网络结构的搜索问题,转化为逐步地选择一条边并确定其运算的子问题。实验证明,依次解决这些简单的子问题,可以让最终结构具有更高的搜索-测试相关性。算法的迭代过程如算法1所示。

在每个决策时期,我们根据预先确定的选择标准选择一条边(i^{+}, j^{+})。通过用公式(3)得到这条边的运算,并替换相应的混合运算o^{\bar}。所选择的运算,即是所选的边基于贪心的最优选择。每当确定好一条边的运算,我们就不再需要这条边的结构参数\alpha(i^{+}, j^{+}))以及混合操作中其余路径的权重,我们可以将这条边从后续的优化中去除。这样可以带给我们一个额外的好处是:优化问题得到了剪枝,进而可以提高搜索的效率。一条边被剪枝后,剩下的超网络以及参数形成一个新的子问题,该子问题将被以相同的算法迭代求解。在搜索阶段的最后,便得到一个没有权重共享的离散子网络,如图2所示。SGAS算法基于顺序贪心算法,减少了在搜索阶段和评价阶段的模型不一致性和权重共享的副作用,使得模型的搜索-测试相关性最大化。

在SGAS中,选择标准的设计至关重要。我们考虑影响边的选择的三个重要因素:边的重要性,选择确定性和选择稳定性。

边的重要性:如果这条边的非零运算选择的可能性越高,表明这条边越重要。

选择确定性:熵(entropy)是分布用度量不确定性的常。 非零运算的归一化权重可以看作是一种分布:

我们将选择确定性定义为一减去操作分布的归一化熵:

选择稳定性:为了让选择更稳定,我们需要考虑选择确定性的历史分布。直方图相交[48]是检测分布变化的常用方法,我们利用直方图相交来计算第T步中的前K时间的边的平均选择稳定性:

根据这三个影响边选择的三个因素,我们提出了两个选择指标:

指标1 :具有高的边重要性和高的边确定性的边将被选择,公式为:

指标2:在指标1的基础上,被选择的边也应该具有较高的稳定性:

这里,normalize(·) 指 Min-Max标准化。

实验结果

我们搜索了CNN和GCN网络结构,并在CIFAR,ImageNet图像分类,ModelNet点云分类,PPI生物图数据节点分类上达到了SOTA效果。

CNN

我们将SGAS用到CNN的网络结构搜索中, CNN网络结构由普通单元(normal cell) 和 归约单元(reduction cell)组成。普通单元保持特征图大小不变,归约单元缩小特征图至½. CNN任务中,搜索空间由8个运算组成:skip-connect, max-pool-3×3, avg-pool-3×3, sep-conv-3×3, sep-conv5×5, dil-conv-3×3, dil-conv-5×5, zero。

SGAS在CIFAR-10的训练集与验证集搜索结构,并在测试集上进行测试,结果如表1所示:

SGAS在CIFAR-10的训练集与验证集搜索结构,并在ImageNet测试集上进行测试,结果如表2所示:

我们的SGAS在性能超越了手工设计的网络结构以及其他NAS算法。

GCN

我们是同时将SGAS用到GCN的网络结构搜索中的。 GCN网络结构由普通单元(normal cell) 组成。其搜索空间由10个运算组成:conv-1×1, MRConv, EdgeConv, GAT, SemiGCN, GIN, SAGE, RelSAGE, skip-connect, and zero operation。

SGAS在ModelNet10的训练集与测试集搜索结构,并在ModelNet40训练集和测试集上进行训练与测试,结果如表3所示:

我们也将SGAS应用到生物信息图的结点预测上。我们在PPI (protein protein intersection) 数据集的训练集与验证集搜索结构,并在PPI的训练集和测试集上进行训练与测试,结果如表4所示:

我们SGAS在GCN上的实验,超越了之前最好的模型。我们在ModelNet40以及PPI数据集上成为了新的state-of-the-art. 

更多的结果与分析可以阅读我们的论文https://arxiv.org/abs/1912.00195(包含附录), 访问我们的项目网页 https://www.deepgcns.org/auto/sgas或 参考我们的开源代码 https://github.com/lightaime/sgas。 

[1] Barret Zoph and Quoc V Le. Neural architecture search with reinforcement learning. arXiv preprint arXiv:1611.01578, 2016.

[2] Esteban Real, Alok Aggarwal, Yanping Huang, and Quoc V Le. Regularized evolution for image classifier architecture search. In Proceedings of the AAAI Conference on Artificial Intelligence, volume 33, pages 4780–4789, 2019.

[3] Hanxiao Liu, Karen Simonyan, and Yiming Yang. Darts: Differentiable architecture search. arXiv preprint arXiv:1806.09055, 2018.

[4] Maurice G Kendall. A new measure of rank correlation. Biometrika, 30(1/2):81–93, 1938.

理论CVPR 2020CNNGCN
3
相关数据
权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

剪枝技术

剪枝顾名思义,就是删去一些不重要的节点,来减小计算或搜索的复杂度。剪枝在很多算法中都有很好的应用,如:决策树,神经网络,搜索算法,数据库的设计等。在决策树和神经网络中,剪枝可以有效缓解过拟合问题并减小计算复杂度;在搜索算法中,可以减小搜索范围,提高搜索效率。

有向无环图技术

在图论中,如果一个有向图从任意顶点出发无法经过若干条边回到该点,则这个图是一个有向无环图(DAG图)。 因为有向图中一个点经过两种路线到达另一个点未必形成环,因此有向无环图未必能转化成树,但任何有向树均为有向无环图。

验证集技术

验证数据集是用于调整分类器超参数(即模型结构)的一组数据集,它有时也被称为开发集(dev set)。

最大池化技术

最大池化(max-pooling)即取局部接受域中值最大的点。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

贪心算法技术

贪心法,又称贪心算法、贪婪算法、或称贪婪法,是一种在每一步选择中都采取在当前状态下最好或最优(即最有利)的选择,从而希望导致结果是最好或最优的算法。比如在旅行推销员问题中,如果旅行员每次都选择最近的城市,那这就是一种贪心算法。

图像分类技术

图像分类,根据各自在图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法。它利用计算机对图像进行定量分析,把图像或图像中的每个像元或区域划归为若干个类别中的某一种,以代替人的视觉判读。

强化学习技术

强化学习是一种试错方法,其目标是让软件智能体在特定环境中能够采取回报最大化的行为。强化学习在马尔可夫决策过程环境中主要使用的技术是动态规划(Dynamic Programming)。流行的强化学习方法包括自适应动态规划(ADP)、时间差分(TD)学习、状态-动作-回报-状态-动作(SARSA)算法、Q 学习、深度强化学习(DQN);其应用包括下棋类游戏、机器人控制和工作调度等。

堆叠技术

堆叠泛化是一种用于最小化一个或多个泛化器的泛化误差率的方法。它通过推导泛化器相对于所提供的学习集的偏差来发挥其作用。这个推导的过程包括:在第二层中将第一层的原始泛化器对部分学习集的猜测进行泛化,以及尝试对学习集的剩余部分进行猜测,并且输出正确的结果。当与多个泛化器一起使用时,堆叠泛化可以被看作是一个交叉验证的复杂版本,利用比交叉验证更为复杂的策略来组合各个泛化器。当与单个泛化器一起使用时,堆叠泛化是一种用于估计(然后纠正)泛化器的错误的方法,该泛化器已经在特定学习集上进行了训练并被询问了特定问题。

节点分类技术

节点分类任务是算法必须通过查看其邻居的标签来确定样本的标记(表示为节点)的任务。

结构搜索技术

深度学习提供了这样一种承诺:它可以绕过手动特征工程的流程,通过端对端的方式联合学习中间表征与统计模型。 然而,神经网络架构本身通常由专家以艰苦的、一事一议的方式临时设计出来。 神经网络架构搜索(NAS)被誉为一条减轻痛苦之路,它可以自动识别哪些网络优于手工设计的网络。

量化技术

深度学习中的量化是指,用低位宽数字的神经网络近似使用了浮点数的神经网络的过程。

暂无评论
暂无评论~