Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

ECCV 2020 | 华为诺亚与高校合作提出基于元强化学习的跨任务可迁移网络架构搜索方案

在近日召开的国际计算机视觉顶会 ECCV 2020 中,来自华为诺亚方舟实验室、香港大学、中山大学、香港科技大学的研究者合作提出了一种基于元强化学习(Meta-RL)的网络架构搜索方案,通过构建多个小型任务对搜索策略进行预训练,从而在迁移到目标任务上时能取得更快更好的搜索效果。

近年来,神经网络架构搜索(NAS)取得了许多突破,但许多算法仍受限于特定的搜索空间及视觉任务,比如大多数算法都是面向一个固定的数据集,无法在面对多项任务时,对跨任务的知识进行重复利用,因此无法实现对搜索策略进行跨任务的高效迁移学习

本文提出的 CATCH 是一种基于元强化学习(Meta-RL)的网络架构搜索方案,它通过构建多个小型任务对搜索策略进行预训练,从而在迁移到目标任务上时能取得更快更好的搜索效果。元学习(Meta-learning)和强化学习(RL)的结合使得 CATCH 能有效地适应新任务,且由于各类搜索空间都能被建构为序列决策(sequential decision-making)类问题,因此这类方法可以适用于多种搜索空间。近年来的 NAS 研究进展迅速,但大多关注单一任务,本篇论文也是目前少数率先研究多领域(分类、检测、分割)架构搜索的算法之一。

论文地址:https://www.catch-nas.com/


背景和动机

当前的 NAS 方法已经在多个领域产生了超过了人工设计的神经网络,不过这些方法在可迁移性和搜索效率上还有待提高。NAS 的技术在未来有很大的应用潜能,但这些美好愿景的实现很大程度要求搜索算法具备一些能力,如: (1) 有效处理大量任务;(2)广泛适用于不同的搜索空间(search space);(3)保持其在各种任务下的搜索表现。这些特征是当前的许多算法所忽略的,主要体现在:

1. 搜索策略缺乏在多个任务之间的迁移能力。许多算法只能在遇到新任务时从头开始重复而低效地进行搜索。

2. 对源任务搜索结果的直接部署无法保证最优表现。例如当前通常的做法是将 CIFAR-10 分类任务的搜索结果直接部署到 ImageNet 分类任务,这样的做法无法保证直接部署的网络结构是最优的,且当任务的性质差距较大时(如 MNIST 与 ImageNet),直接部署也并不合理。

3. 一些算法的搜索空间比较受限。 例如当前一些可微分算法(DARTS)只能应用于微观结构(cell-structure)的搜索,尚未适用于更广泛结构的搜索,缺乏通用性。


为了解决这些问题和增强 NAS 算法的可迁移性,作者提出了 CATCH,一个基于元强化学习的跨任务可迁移网络架构搜索方案。如图 1 所示,CATCH 框架中的搜索代理(即 CATCHer)充当决策者。作者首先通过构建资源耗费低的元训练任务,对 CATCHer 进行元训练,然后将其部署到目标任务以快速适应。

CATCH 方法概述

作者的灵感来自元强化学习(meta-RL)。在 meta-RL 中,相同动作空间(action space)但奖励函数(reward function)不同的问题可以被看作为不同的任务(task)。在 NAS 中,大部分问题的搜索空间是不变的,即动作空间一致;但数据集(例如 CIFAR-10 与 ImageNet)或者视觉领域(例如分类、检测、分割)的更改都会改变奖励函数,因此根据定义可以将这些数据集(无论是否在同一领域)视为不同的任务。从这个视角来看,NAS 的跨领域、跨数据集高效搜索问题,就可以被转化为元强化学习中对多个任务的快速迁移与适应的一个问题。

强化学习在这里之所以有效,很大程度也依赖于许多数据集的网络设计中存在着一些共性,这些共性在过去成就了迁移学习(如将一个优秀的 ImageNet 分类网络迁移到 COCO 检测任务上),也为元强化学习创造了快速适应(fast adaptation)的空间。

搜索过程

如上图所示,在每一次任务的搜索过程中,CATCH 会首先选取任意一个网络进行训练、评估,并获得(模型 - 奖励)的一个元组(m,r)用于初始化搜索历史。

接下来 CATCHer 的三个核心组件会分别发挥作用:

1. 任务信息编码器(Context Encoder):任务信息编码器通过变分推断(Variational Inference)的方式将搜索历史编码为任务表征 z,指导 RL 控制器和网络评估器的表现。 作者将任务表征 z 建模为具有对角协方差矩阵的多元高斯分布,在编码的过程中,编码器旨在估计后验分布 p(z|c_{1:N})。由于 c_{1:N} 只与任务有关,因此可以被分解为高斯因子的乘积,

2. 其中 f_\phi 用于预测 p(z|c_i)任务信息均值和标准差的神经网络。任务表征 z 从分布中随机采样得到,这样的设定有利于将对任务的不确定性加入模型,并缓解稀疏奖励(sparse reward)的问题实现有效探索。

3.RL 控制器(RL Controller):RL 控制器进行序列决策,生成候选网络(Candidate Networks)。网络的生成可以被视为决策问题,其中 RL 控制器的每个动作都决定了最终架构的一个属性。该属性可以是在微观结构搜索中形成特定的操作类型(例如,跳跃连接、卷积操作等),也可以是在宏观结构(macro skeleton)搜索中形成网络的形状(例如,宽度、深度等)。

网络 m 可以表示成序列决策列表 [a_1, a_2, …, a_L]。在每个时间步长,RL 控制器输入已经完成的决策和任务表征 z,并输出选择接下来动作的分布,从中采样相应的动作。控制器随机采样 M 个网络作为候选网络。

4. 网络评估器(Network Evaluator)。网络评估器用于预测候选网络的性能,并确定选取预测值最高的网络进行实际训练。


优化过程

所有这三个组件都可以进行端到端优化。RL 控制器(以 \ theta_c 为参数)使用 RL 算法近端策略优化(PPO)进行训练:

网络评估器通过优化 Huber Loss 进行训练,其中采用了优先经验回放(Prioritized Experience Replay)的技巧,从而提高采样效率。

任务信息编码器的优化过程将上述两个优化目标作为其优化目标的一部分。每个任务的最终变分下界(Variational Lower Bound)为

公式中 KL 散度近似于约束 z 和 c 之间的互信息的变分信息瓶颈(Variational Information Bottleneck),此信息瓶颈可作为正则化,以避免过拟合训练任务。p(z)是单位高斯先验。由于 (1) 任务表征 z 用作控制器和评估器的输入,并且 (2) p(z)和 q_\phi(z|c)是高斯分布,且其中 KL 散度使用其均值和标准差进行计算,因此可以使用重参数化技巧(Reparameterization Trick)将梯度端对端反向传播到任务信息编码器进行更新。

CATCH 框架包括两个阶段:如算法 1 所示的元训练阶段(Meta-training Phase)和适应阶段(Adaptation Phase)。在元训练阶段,我们在一组资源消耗低的元训练任务中训练 CATCHer。此阶段的主要目标是为任务信息编码器提供多样化的任务,使其产生有意义的表征。在适应阶段,经过元训练的 CATCHer 在任务信息编码的指导下来有效找到目标任务的优秀网络。

实验和结果

作者在两个不同的搜索空间上证明了 CATCH 的有效性和通用性,分别是微观结构搜索空间和基于残差模块(Residual Block)的宏观结构搜索空间。同时也探索了从分类任务迁移到不同领域任务(目标检测语义分割)的可能性。

微观结构搜索空间

作者从 ImageNet16(图片缩小为 16x16)数据集中任意抽取 10/20/30 类,构建了 25 个元训练图像分类任务。CATCH 在元训练过后分别迁移到 CIFAR-10、CIFAR-100、ImageNet16-120(图片缩小为 16x16,抽取前 120 类)分类任务上进行搜索。并基于 NAS-Bench-201 基准数据集进行快速的算法评估,与其他 NAS 算法进行了比较。实验证明和 sample-based 和 one-shot 的方法进行比较都能够快速适应目标任务找到最优网络。

基于残差模块的宏观结构搜索空间

作者将 ImageNet 数据集中的图片任意缩小为 16x16、32x32、224x224 并任意抽取 10/20/30 类,同样构建了 25 个元训练图像分类任务。在元训练过后分别迁移到 ImageNet 图像分类、COCO 目标检测以及 CityScapes 语义分割任务上进行网络搜索。得到了具有竞争力的搜索结果。

总结

本文是跨任务可迁移 NAS 的早期工作之一,提出了基于元强化学习的跨任务可迁移网络架构搜索方案 CATCH。 CATCH 主要通过在大量元训练任务上对搜索策略进行预训练,同时获得提取任务表征的能力来实现在目标任务上的快速适应。在两个搜索空间上的实验显示了将 CATCH 扩展到大型数据集和各种视觉领域方面的潜力。

理论华为诺亚方舟实验室NAS元强化学习
相关数据
高斯分布技术

正态分布是一个非常常见的连续概率分布。由于中心极限定理(Central Limit Theorem)的广泛应用,正态分布在统计学上非常重要。中心极限定理表明,由一组独立同分布,并且具有有限的数学期望和方差的随机变量X1,X2,X3,...Xn构成的平均随机变量Y近似的服从正态分布当n趋近于无穷。另外众多物理计量是由许多独立随机过程的和构成,因而往往也具有正态分布。

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

变分推断技术

see Variational Bayesian methods (approximation)

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

元学习技术

元学习是机器学习的一个子领域,是将自动学习算法应用于机器学习实验的元数据上。现在的 AI 系统可以通过大量时间和经验从头学习一项复杂技能。但是,我们如果想使智能体掌握多种技能、适应多种环境,则不应该从头开始在每一个环境中训练每一项技能,而是需要智能体通过对以往经验的再利用来学习如何学习多项新任务,因此我们不应该独立地训练每一个新任务。这种学习如何学习的方法,又叫元学习(meta-learning),是通往可持续学习多项新任务的多面智能体的必经之路。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

协方差矩阵技术

在统计学与概率论中,协方差矩阵(也称离差矩阵、方差-协方差矩阵)是一个矩阵,其 i, j 位置的元素是第 i 个与第 j 个随机向量(即随机变量构成的向量)之间的协方差。这是从标量随机变量到高维度随机向量的自然推广。

迁移学习技术

迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都是为了解决单个任务而设计的,但是促进迁移学习的算法的开发是机器学习社区持续关注的话题。 迁移学习对人类来说很常见,例如,我们可能会发现学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。

过拟合技术

过拟合是指为了得到一致假设而使假设变得过度严格。避免过拟合是分类器设计中的一个核心任务。通常采用增大数据量和测试样本集的方法对分类器性能进行评价。

正则化技术

当模型的复杂度增大时,训练误差会逐渐减小并趋向于0;而测试误差会先减小,达到最小值后又增大。当选择的模型复杂度过大时,过拟合现象就会发生。这样,在学习时就要防止过拟合。进行最优模型的选择,即选择复杂度适当的模型,以达到使测试误差最小的学习目的。

语义分割技术

语义分割,简单来说就是给定一张图片,对图片中的每一个像素点进行分类。图像语义分割是AI领域中一个重要的分支,是机器视觉技术中关于图像理解的重要一环。

图像分类技术

图像分类,根据各自在图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法。它利用计算机对图像进行定量分析,把图像或图像中的每个像元或区域划归为若干个类别中的某一种,以代替人的视觉判读。

强化学习技术

强化学习是一种试错方法,其目标是让软件智能体在特定环境中能够采取回报最大化的行为。强化学习在马尔可夫决策过程环境中主要使用的技术是动态规划(Dynamic Programming)。流行的强化学习方法包括自适应动态规划(ADP)、时间差分(TD)学习、状态-动作-回报-状态-动作(SARSA)算法、Q 学习、深度强化学习(DQN);其应用包括下棋类游戏、机器人控制和工作调度等。

目标检测技术

一般目标检测(generic object detection)的目标是根据大量预定义的类别在自然图像中确定目标实例的位置,这是计算机视觉领域最基本和最有挑战性的问题之一。近些年兴起的深度学习技术是一种可从数据中直接学习特征表示的强大方法,并已经为一般目标检测领域带来了显著的突破性进展。

结构搜索技术

深度学习提供了这样一种承诺:它可以绕过手动特征工程的流程,通过端对端的方式联合学习中间表征与统计模型。 然而,神经网络架构本身通常由专家以艰苦的、一事一议的方式临时设计出来。 神经网络架构搜索(NAS)被誉为一条减轻痛苦之路,它可以自动识别哪些网络优于手工设计的网络。

推荐文章
暂无评论
暂无评论~