杨旭韵作者H4O编辑

如何基于元学习方法进行有效的模型训练?四篇论文详细剖析元模型的学习原理和过程

本文以四篇最新论文为例,详细剖析了元模型的学习原理和过程。

机器学习领域,普通的基于学习的模型可以通过大量的数据来训练得到模型参数,并在某种特定任务上达到很不错的效果。但是这种学习方法限制了模型在很多应用场景下的可行性:在具体的现实情况中,大量数据的获取通常是有难度的,小样本学习机器学习领域目前正在研究的问题之一;另外,模型在训练过程中只接触了某一特定任务相关的数据样本,在面对新任务时,其适应能力和泛化能力较弱。


反观人类的学习方法,不仅仅是学会了一样任务,更重要的是具备学习能力,能够利用以往学习到的知识来指导学习新的任务。如何设计能够通过少量样本的训练来适应新任务的学习模型,是元学习解决的目标问题,实现的方式包括[1]:根据模型评估指标(如模型预测的精确度)学习一种映射关系函数(如排序),基于新任务的表示,找到对应的最优模型参数;学习任务层面的知识,而不仅仅是任务中的具体内容,如任务的分布、不同任务的特征表示;学习一个基模型,这个基模型的参数是基于以往多种任务的各个特定模型而得到的,等等。

图 1:什么是元学习(图源:http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/Meta1%20(v6).pdf)

下面从元学习的工程优化、解决局部最优和过拟合问题、模型解释性等方面详细解读和分析四篇论文。

一、"TaskNorm: Rethinking Batch Normalization for Meta-Learning"

核心:元模型训练阶段的工程优化


本文是发表于 ICML 2020 中的一篇文章[2],是剑桥大学、Invenia 实验室和微软研究院学者共同合作的研究成果,提出了一种适用于元学习在模型训练时的数据批量标准化方法。

深度学习中网络模型的训练通常基于梯度下降法,与模型学习效果相关的因素包括了学习步长(学习率)、网络初始化参数,并且当涉及深层网络训练时,还需要考虑梯度消失的问题。标准化层(normalization layer,NL)的提出,使得增加了标准化层的网络在训练时,能够使用更高的学习率,并且能够降低网络对于初始参数的敏感度,对于深层网络的训练更加重要。NL 的一般表示为:


其中,γ和β为学习的参数,μ和σ是标准化的统计量,a_n 和 a’_n 是输入和标准化后的输出。

图 1.1:元学习的训练集。这是图片分类的例子,在不同 episode 中,由不同的子类构成不同的分类任务;在相同的 episode 中,支持集和查询集包含了相同的子类。来自:https://www.jiqizhixin.com/articles/2019-07-01-8

元学习的训练数据集包括了 context set Dτ(也称为 support set,支持集)和 target set Tτ(也称为 query set,查询集),如图 1.1 所示。利用这个数据集进行两个阶段的训练:在内层(inner loop)阶段,使用 context set 来更新参数θ,得到特定任务的参数ψ;在外层(outer loop)阶段(fφ表示由θ生成ψ的一个过程,可能会引入额外的参数φ),对 target set 中的 input 进行预测,并得到目标损失函数


元学习中的分层框架(inner loop 和 outer loop 两层更新,如图 1.2 所示),可能会使得传统的批标准化方式(batch normalization,BN)失效:BN 的使用具有一定的前提条件,独立同分布 iid 条件,而元学习可能不满足这个条件,如果直接使用 BN 方法在元学习的网络模型中引入标准化层,可能会导致不理想的元模型效果。


作者提出了一种适用于元学习的标准化方式 --- 任务标准化(task normalization,TaskNorm),它能够提升模型训练的速率和稳定性,并且能够保持理想的测试效果;另外,它适用于不同大小的 context set,并没有受到很大的影响;而且这种标准化方式是非直推式的,因此在测试的时候能够适用于更多的情景(即更多样的图像分类任务)。在具体展开介绍 TaskNorm 之前,作者先对元学习的推理方式和几种常见的标准化方法进行简单介绍,并且说明了在元学习中对应不同的标准化方法的统计量μ和σ的计算和使用方式。

1.1 方法介绍

  • 直推学习(transductive meta-learning)和非直推学习(non-transductive meta-learning)


对于元学习,作者讨论了两种方式:直推学习和非直推学习。非直推学习的元测试(meta-test)阶段,在对测试集(和训练集类似,也包括了 context set 和 target set)中的单个样本进行类别预测时,仅仅使用 context set 以及输入的观测值。直推学习的元测试阶段,对单个样本进行预测时,不仅需要 context set 和观测值,还需要测试集中其他样本的观测值。作者认为,元学习中的标准化层需要是 * 非直推式 * 的,因为对于直推学习,作者认为它的两个问题:

1. 对 target set 的分布敏感。在 outer loop 时,需要用到 target set 的其他样本,即当前样本的预测输出还与其他样本的输入相关,因此这种方式相比于非直推学习的泛化性更弱。如果在元测试中使用的 target set 样本的类别平衡情况和训练时有差别,那么模型在测试时的分类效果可能并不会很好。

2. 直推学习利用到了更多的信息(相当于需要依赖的信息更多),因此如果将两种方法直接进行比较是不公平的。


  • 几种基本的标准化方式以及在元学习中的应用


批标准化(batch normalization,BN)。BN 在训练阶段和测试阶段的使用模式是不一样的。在元训练(meta-training)阶段,均值和方差的计算如下所示:


在 BN 中,输入的通道数不变,对每个通道、使用整个 batch 进行变换,这种标准化的方式没有涉及不同通道之间的数据交换。更直观一点,数据集输入的维度表示为 < B,C,W,H>,那么标准化计算量μ和σ的维度表示为 <1,C,1,1>。使用所有 batch 计算统计量有一个前提,就是假设了 batch 中的数据服从独立同分布。在测试阶段,使用的均值和方差是训练集所有数据的均值和方差。

元学习网络中直接使用批标准化(Conventional Usage of Batch Normalization,CBN),会有两个重要的问题:(1)在元测试阶段,使用的是根据元训练阶段数据集计算得到的μ和σ,可以认为这两个统计量是和元模型等效的参数。然而,训练时的数据集包括了所有不同的任务,独立同分布的条件只是在相同任务的数据之间满足、在不同任务之间不一定满足。作者将 CBN 应用在 MAML 方法 [3] 中,实验结果表明了该方法在预测任务上表现并不好。(2)当训练过程中使用的 batch-size 较小,得到的统计量可能并不准确时,模型的效果也会受到影响。

图 1.3:批标准化(BN),元学习训练和测试过程中直接使用 BN 的方式。图源:[2]

基于实例的标准化(Instance-based Normalization)。基于实例的标准化方式是非直推式的,统计量只根据当前实例(如单张图片)来计算μ和σ,并且不依赖于 context set 数据集的大小

1. 实例标准化(instance normalization,IN)。针对单张图片的 (H,W) 两个维度计算统计量(即每一张图只对 H 和 W 维度进行归一化),每一张图都有对应的统计量。该计算方式在元训练阶段(使用训练集)和元测试阶段(使用测试集)是一样的。

图 1.4:实例标准化(IN),元学习中 context set 和 target set 使用 IN 的方式。图源[2]

2. 层标准化(layer normalization,LN)。LN 针对图片单独进行变换,并考虑到了多个通道的维度。该计算方式在元训练阶段(使用训练集)和元测试阶段(使用测试集)是一样的。作者在后续提供的实验结果中,指出 LN 相比于其他标准化方式,在训练效率方面的表现较不足。

图 1.5:层标准化(LN),元学习中 context set 和 target set 使用 LN 的方式。图源:[2]

直推批标准化(transductive batch normalization,TBN)。相比于 CBN,TBN 的标准化方式在元测试阶段,并不是使用元训练阶段数据集的统计量,而是使用测试数据集(包括 context-set 或者是 target-set)来计算μ和σ。另外,TBN 会根据不同的任务分别计算各自的统计量。

虽然这种方法能够获得更好的效果,但是在元测试时,对于 target-set 的标准化处理使用了 target-set 全局的统计量,相当于测试的数据之间是存在某种信息交流和利用的,给了更多的先验信息,提升测试的准确率。这种方式在信息利用方面和非直推学习方式并不是对等的,因此不能直接比较 TBN 和其他的非直推方式。

图 1.6:直推式批标准化(TBN)。图源:[2]

1.2 任务标准化(Task Normalization, TASKNORM)

本质上,找到适用于元学习的标准化方法,关键在于找到合适的统计量μ和σ。根据标准化处理对于数据的独立同分布条件要求,对于元学习来说, μ和σ应该是任务级别的统计量,在一定程度上是融入任务模型参数ψ中。ψ是元模型通过适应 context set 而得到的任务模型的参数,因此在任务模型的推理阶段,用到的统计量μ和σ也应该能够从 context set 计算得到。

结合上述元学习对于标准化统计量的要求,作者首先提出了一种元批量标准化方法( meta-batch normalization,MetaBN)。对于每个任务,在 context set 中计算各自的均值和方差,这个统计量共用于 context set 和 target set;在元训练阶段和元测试阶段,是分别根据训练集中的 context set 和测试集中的 target set 得到各阶段的标准化统计量。但是,这种标准化方法仍然会受到 context set 大小的影响:当 context set 的 batch size 较小时,统计量的准确度不够高,会影响模型的预测效果。

图 1.7:MetaBN 方法和 TaskNorm 方法(包括 TaskNorm-L 和 TaskNorm-I)。图源:[2]

进一步地,作者保留了 MetaBN 的优点,结合基于实例的标准化方法不依赖数据集大小的特点,提出了本文的核心内容:任务标准化(TASKNORM)。TASKNORM 方法是在 MetaBN 的基础上,结合了 LN 或者是 IN,可以具体分为 TaskNorm-L 以及 TaskNorm-I 两种标准化方法:元训练(元测试)阶段,使用训练集(测试集)的 context set 得到统计量,context set 和 target set 都使用该统计量以及各自的 LN 或者 IN 的加权和,得到最终用于标准化的统计量,其中两部分统计量的权重由超参数α控制。此时的μ和σ的计算由下式得到:


其中,μ_{BN}和σ^2_{BN}是根据 context set 计算的统计量,μ+ 和σ+ 是根据层标准化(LN)或者是实例标准化(IN)得到的非直推式的统计量。这种结合方式的出发点是 * 解决使用少样本学习时存在的样本数量相关问题 *:当 context set 的样本量很少时,仅根据该数量集得到的统计量可能会得到关于该任务的不准确的数据;当结合其他统计量时,有助于提升训练效率以及模型的预测效果。

作者将权重α定义为一个参数化的变量,它和 context set 大小具有线性关系,表示为:α=sigmoid(scale|Dt| + offset)。其中 Dt 为 context set 元素个数,scale 和 offset 在元训练阶段是可学习的。α和 support set 大小之间存在线性关系式,表示为:α=sigmoid(scale|Dt| + offset)。其中 Dt 为 context set 的大小,scale 和 offset 是在元训练时学习得到的。

1.3 实验介绍

作者分别在小规模数据集和大规模数据集上进行少样本(few-shot)分类任务,对比几种标准化方法,验证本文提出的几个猜想:1)元学习对于标准化方式是比较敏感的;2)直推批标准化(TBN)比非直推批标准化的效果普遍要好;3)考虑了元学习数据集特性的方法如 TaskNorm,MetaBN 以及 RN 的效果,会比 CBN,BRN(batch renormalization),IN,LN 等没有考虑元学习数据特性的方法要好。在实验中,作者关注的指标包括模型预测的准确度和训练效率。

表 1.1:基于小数据集(mini imagenet 和 omniglot)的分类实验,此时仅考虑固定大小的 context set 和 target set。来自:[2]

表 1.2:基于大数据集 meta-dataset(包含了 13 个图像分类的数据集)的分类实验。来自:[2]

图 1.8:不同标准化方法得到的模型准确度和训练过程的对比图。图源:[2]

1.4 小结

本文提出了一种适用于元学习的标准化方法 TASKNORM,基于传统批标准化方法对统计量的计算进行改进。在计算用于数据标准化的统计量均值μ和方差σ^2 时,该方法考虑了任务内数据的独立同分布、任务间的数据不满足独立同分布条件,context set 大小的影响,以及考虑非直推式的学习方式,从而使得元学习模型能够应用在更多的场景。通过大量的对比实验,验证了使用 TASKNORM 方法能够提升元学习模型的训练效率和预测效果。

二、 "Meta-Learning with Warped Gradient Descent" (ICLR2020)

核心:解决基于梯度的元学习方法的参数局部最优问题


本文是发表于 ICLR 2020 中的一篇满分论文[4],由曼彻斯特大学、Alan 图灵研究机构和 DeepMind 的研究员提出了元学习中的梯度预处理计算方法。

元学习领域有一个重要的问题,是学会一种更新规则,能够快速适应新的任务。处理这个问题的方式通常有两种:训练网络来产生更新(学习更新方式);或者是学习一个比较好的初始化模型或者是比例因子,应用于基于梯度更新的学习方法(学习和梯度更新相关的因素)。前者容易导致不收敛的效果,后者在少样本(few-shot 任务中的适应效果可能不太好。

作者结合前面说的两种方式,提出一种弯曲梯度下降(warped gradient descent)的方法,它主要学习一个参数化预处理矩阵,该矩阵是通过在 task-learner 网络模型的各层之间交叉放置非线性激活层(即弯曲层,warped layers)而产生。在网络训练时,这些 warp 层提供了一种更新方式,而它的参数是 meta-learned,在模型训练过程中是不经过梯度回传的。

为了验证这种梯度更新方式的有效性,作者还将这种弯曲梯度方法应用在少样本学习,标准的有监督学习,持续学习和强化学习等多种设定下进行实验。

2.1 方法介绍

在基于梯度更新的元学习中,task-learner 元参数的更新规则表示为 U(θ; ξ):= θ-α∇L(θ),初始参数θ_0 的元学习过程可表示为:


这类方法由于依赖于梯度更新的轨迹,会存在一些问题:梯度的计算会涉及到较大的计算量;容易受到梯度爆炸或者是梯度消失情况的影响;置信度分配问题。将损失函数 L 抽象成一个曲面,该曲面的情况会影响参数调整的效果,并且此时的参数空间不一定是合理的、不一定适用于不同任务的空间。

针对这几个问题,作者首先了介绍了一种结合预处理的梯度更新通用规则,表示为:


其中,P 表示一个用于预处理梯度的曲面。为了更好地拆分预处理模块的参数和 task-learner 的参数,作者使用了一种更为灵活的结构:在多层网络模型中插入全局参数化的 warp 层。最为简单的一种插入方式表示为:


h 是网络的隐藏层,w 是插入的 warp 层。在梯度回传时,对于 warp 层使用的是 Jacobian 矩阵(Dx 和 Dθ)来计算:


  • warp-layers 的具体原理和计算流程


如图 2.1 所示,是 warp 层在 task-learner 中的使用和计算流程。对于 task learner f(x),隐藏层之间(h1 和 h2)嵌入 warp 层(ω1 和ω2):在前向计算时,warp 层相当于激活层;在任务适应阶段(task adaptation)的后向回传中,warp 层通过 Dω来提供梯度。这就是本文提出的用于网络参数更新的 WarpGrad 方法。

图 2.1:warp 层及 WarpGrad 计算的示意图。图源:[4]

通过曲面的图示来更形象地展示 WarpGrad 起到的作用,如图 2.2 所示。在理想的 W 空间曲面,能够产生梯度上的预处理,找出梯度下降的最大方向。

图 2.2:上一行表示 WarpGrad 学习到的元几何(meta-geometry)P 曲面;下一行表示不同任务的损失函数 W 曲面,其中黑线是普通梯度下降的方向,紫色是利用元几何得到的梯度下降的方向。图源:[4]

考虑到 warp 层具有几何曲面的表示意义,作者提出 warp 层实际上是近似一个矩阵 G,该矩阵是一个正定的矩阵向量,用于度量流形的曲率。

Ω表示 warp-layers 起到的作用,它相当于通过重参数化(ω)来近似于最快的梯度下降方向:


在 P - 空间和 W - 空间上的梯度表示为:


其中,γ=Ω(θ; φ)表示从 P 空间映射到 W 空间的映射参数,并且


P 空间的参数梯度和 W 空间的参数梯度之间的转换关系如图 2.3 所示:

图 2.3:P 空间的θ参数梯度等价于 W 空间的γ参数梯度。图源:[4]

Warp 层参数控制了理想曲面的生成,本质上控制了 task learner 的收敛目标。因此,为了积累所有任务的信息帮助提升任务适应的过程,warp 层参数是通过元学习来训练得到的,目标函数表示为:


  • Warp-layer 参数的学习方式


作者定义了一个高层的任务τ=(h, L_{meta}, L_{task}),L_{meta}作为元训练的目标损失函数,用于 warp 参数的适应学习;L_{task}作为任务适应的目标函数,用于θ参数的适应学习。


上式对于φ的学习,依赖于 L-task,会涉及到二阶梯度的计算。作者进一步做梯度截断(stop gradient),使得φ的更新只涉及一阶梯度。



图 2.4:warpgrad 应用于在线元学习和离线元学习的算法流程。图源:[4]

2.2 实验介绍

在实验部分,作者在元学习方法 MAML[3]和 Leap[5]方法中引入 WarpGrad 的更新方式,在两个数据集(miniImageNet 和 tieredImageNet)上做少样本(few-shot)学习和多样本(multi-shot)学习,使用了 WarpGrad 方法的元学习模型能够超过普通元学习模型在分类任务上的准确率

图 2.5:使用 warpgrad 方法进行少样本学习和多样本学习的对比实验。图源:[4]

作者还验证了 WarpGrad 方法对模型在不同任务上的泛化能力的作用。如图 2.6 所示,在不同任务数量的实验中,Warp-Leap 模型的测试准确率明显高于其他几种基准方法。

图 2.6:对比不同方法在不同任务数量实验中的准确率。图源:[4]

2.3 小结

本文提出了一种更为泛化的基于梯度的元学习方法 WarpGrad,在网络中引入 warp 层用于预处理原始梯度,该方法的特点包括:(1)WarpGrad 方法本质上是一种基于梯度的更新方式,它的创新之处在于对梯度进行了预处理,所以它也具有梯度下降法的特性,能够保证训练模型的收敛;(2)warp 层构造了梯度预处理的分布,而这个分布所具有的几何曲面能够从任务学习者中分离出来;(3)warp 层的参数是通过任务和对应轨迹来元学习得到的,根据局部的信息来获得任务分布相关的属性;(4)相比于用预处理矩阵来直接对梯度进行处理,warp 层在网络模型中同时参与了前向计算和后向梯度回传,是一种更为有效的学习方法。

三、"Meta-Learning without Memorization"

核心:解决任务层面的过拟合问题


本文是由 Google brain 团队和 UT Austin 学者发表于 ICLR 2020 中的一篇文章[6],它探讨了元学习模型的记忆问题并提出解决方法。

在分类任务中,当图片和类别标签并不是互斥的(mutually-exclusive)时(如在分类任务 1 中,狗的类别标签是 2;在分类任务 3 中,狗的类别标签仍然是 2),分类模型做的事情其实是直接将类别标签和图片中的数据特征对应起来。此时,训练得到元模型可能 * 无法 * 很好地应用在新的分类任务上:在训练阶段,模型不需要适应训练数据集、就可以在测试数据集上达到较好地效果;而在推理阶段,适应能力较弱的模型,则无法适应新任务的训练数据集,很难在新任务的测试数据集上达到理想效果。

图 3.1:Meta-learning 的图模型表示。图源:[6]

结合元学习的图模型来进一步理解这个问题的定义。M 是元训练数据集,包括了在元训练阶段的训练数据集 D(support set)和测试数据集 D*(query set),θ是元模型参数,φ是特定任务模型参数(task-specific parameters)。q(θ|M)表示基于元训练数据的元参数分布,q(φ|D, theta)表示基于任务训练(per-task training)的任务参数分布,q(y*|x*, φ, θ)表示预测的分布:


那什么是记忆问题?就是 y * 的计算,可以独立于φ和 Di,完全依赖于θ和 x*,即 q(y*|x*, φ, θ)=q(y*|x*, θ)。此时,在测试数据集上的预测结果可以直接根据元模型参数θ来得到,而不需要经过通过适应 D 而得到优化后的参数φ来进行预测的过程。

3.1 方法介绍

在本文中,作者给出了记忆问题的数学形式,引入互信息(mutual information)这个概念:在元学习中的完全记忆,指的是模型在预测 y 时忽略任务训练数据集 D 的信息,即 y 和 D 之间的互信息为 0,表示为 I(y;D|x,θ)=0。为了同时达到低误差,以及 y * 和 (x*,θ) 之间的低互信息,需要利用任务训练数据 D 来做预测,即增大 I(y*;D|x*, θ),从而减少记忆问题。

在本文中,作者提出元正则项(meta-regularizer, MR),基于信息论来提供一个通用的、不需要在任务分布上设置限制条件的方法,解决元学习的记忆问题。更具体地,分别是:激活项上的元正则化(meta regularization on activations),权重上的元正则化(meta regularization on weights)。

激活项上的元正则化在上图中,当给定 theta 时,y * 和 x * 之间的信息流,包括了 y * 和 x * 之间的直接依赖,以及经过数据集 D 的间接依赖。作者提出,通过引入了一个中间变量 z*,有 q(ˆy* |x* , φ, θ) = ∫ q(ˆy* |z* , φ, θ)q(z* |x* , θ) dz*,控制 \ hat{y}* 和 x * 之间的信息流来解决记忆问题。

图 3.2:引入中间变量 z 的元学习的图模型,。图源:[6]

此时,为了引导模型有效地利用任务训练数据 D,增大的互信息目标变为 I(y*;D|z*, θ),通过如下的推导,等价于增大互信息 I(x*;y*|θ)和减小 KL 散度 E[D_{KL}(q(z*|x*,θ) || r(z*))]:


对于上式左项的互信息,假如 I(x*;y*|θ)=0,并且存在记忆问题(I(y*;D|x*,θ)=0)时,那么有 q(y*|x*, θ, D)=q(y*|x*, θ)=q(y*|θ),即预测结果 y * 并不依赖于观测值 x*,显然这样的模型并不会得到理想的预测准确度。因此,最小化损失函数(如式 (1))有助于引导互信息 I(y*;D|x*,θ) 或者是 I(x*;y*|θ)的最大化,所以在引入中间变量 z * 后,需要做的就是最小化 KL 散度,最终的损失函数表示为:


但是,作者在实验过程中发现这种方法在一些情况下并不能避免记忆问题,并进一步提出了另一种元正则化方法。

权重上的元正则化作者提出,通过惩罚元模型参数,减少元参数所带有的任务信息,从而降低模型对于任务的记忆能力、解决记忆问题。对于元参数θ中包含的训练任务信息,可以表示为 I(y*1:N,D1:N; θ|x*1:N ),它的上确界有:


参数的惩罚项即为最后的 KL 散度,该惩罚项实际上是限制模型参数的复杂度:如果模型需要去记住所有任务的信息,那么模型非常复杂;所以限制模型的复杂度,在一定程度上能够减少元参数包含的任务信息。但是,作者并没有完全限制模型参数的复杂度,在实际应用中,仍允许部分模型参数对任务训练数据进行处理,因此只是在部分参数θ上执行该惩罚项(模型的其他参数则表示为θ~),最后损失函数可以表示为:


3.2 实验介绍

本文分别在分类任务和回归任务上进行对比实验,在这些任务中图片标签和图片数据本身是非互斥的,用于验证元正则化方法在记忆问题上的有效性。如表 3.1 和 3.2 所示,使用了元正则化(MR)的方法,相比于其他的元学习基准方法,在分类任务和回归任务上都能明显获得更好的效果。

表 3.1:图片标签非互斥的回归任务(均方差),A 表示使用了激活项上的元正则化,W 表示使用了权重上的元正则化。来自:[6]

表 3.2:图片标签非互斥的分类任务(准确率)。来自:[6]

3.3 小结

本文从信息论的角度,提出了一种适用于不同的元学习方法的元正则化(MR)方法。该方法可以用在标签没有打乱(或者是很难打乱)的任务中,能够提升元学习方法在更多场景中的适用性和可行性,在一定程度上解决元学习的记忆问题。

四、"Unraveling Meta-Learning: Understanding Feature Representations for Few-Shot Tasks"

核心:探讨元模型特征表示模块的作用(元学习方法的可解释性)


本文是由马里兰大学的学者发表于 ICML 2020 中的一篇文章[7]。在少样本分类(few-shot classification)任务的场景中,元学习方法能够提供一个快速适应新任务(new tasks)或者是新域(new domains)的基础模型。然而,很少有工作去探讨模型达到不错效果的深层原因,如元学习方法中特征提取模块(feature extractor)提取得到的特征表示的不同之处是什么。

本文提出,相比于普通学习得到的特征表示,元学习得到的特征表示(meta-learned representations)是有区别的、更有助于少样本学习。使用元学习的特征表示能够提升少样本学习的效果,本文作者归为两种不同的机制:(1)固定特征提取模块参数,只更新(微调)最后的分类层(classification layer)参数。在这种机制下,类别数据点在特征空间中会更加聚集,那么在微调时,分类边界对于提供的样本会没那么敏感。(2)在模型参数空间寻找最优点作为基础模型,该最优点接近大部分特定任务(task-specific)模型参数的最优点,那么在面对新的特定任务时,能够通过几步的梯度计算,将基础模型更新为适用于新任务的特定模型。

进一步地,作者分别探讨上述两种机制的作用,定义了几种正则项,并结合正则项提出了几种带正则化的模型训练方法,通过实验验证了相关猜想以及正则化训练方法的有效性。

4.1 基于特征聚集的正则化方法

  • 4.1.1 在特征空间的类别特征点聚集


作者先讨论第一种机制,即微调时固定特征提取模块、只更新分类层,使用这类机制的元学习方法包括 ProtoNet[8],R2-D2[9]和 MetaOptNet[10]。这类方法能够达到好的分类效果,猜想是特征提取模块已经能够做到很好的特征区分、从而对于新的分类任务也能够实现少样本学习。

特征点聚集对于少样本学习的重要性。如下图所示,当类别的特征点是分散的、类间相隔较近时,选取少量样本来训练分割平面容易导致较大的分割误差;而当类别的特征点是聚集的、类间相隔较远,训练得到的分割平面准确度较高,分割平面对于样本选取的依赖较弱。

图 4.1:特征点聚集对于少样本训练分割平面准确度的重要性。图源:[7]

然后,作者通过对比元学习的 ProtoNet 和传统训练的网络模型的特征提取效果,验证了元学习方法在特征点聚集上做得更好,虽然没有直接证明特征点聚集对于少样本学习的必要性,但是为接下来提出的基于特征点聚集的正则项提供了重要的思路和启发。

图 4.2:ProtoNet 和经典分类网络在 mini-ImageNet 数据集上提取的特征进行可视化(使用 LDA 处理元学习和经典分类器提取的特征,可视化映射到二维空间的特征)。图源:[7]

本文考虑特征聚集的评估指标(feature clustering, FC),定义为类内方差和类间方差的占比。根据 FC 的定义,本文给出了特征聚集的正则项(feature clustering regularizer, R_fc)定义:


其中,f_{θ}(x_i,j)是特征提取模块 f_{θ}对样本 x 给出的特征表示,μ_i 是第 i 类的特征向量均值,μ是所有数据的特征向量均值。作者基于 R2-D2 和 MetaOptNet 的网络结构,结合交叉熵损失函数和该正则项,作为传统的训练方法的损失函数,在 mini-ImageNet 数据集和 CIFAR-FS 数据集上进行 1-shot 和 5-shot 的实验,对比使用元训练的方法和不使用该正则项的传统训练方法。

如表 4.1 所示,相比于没有用 R_fc 训练的网络效果,使用 R_fc 来训练网络,能够和元学习网络达到类似的高分。这进一步说明了使用 R_fc 可以得到类似于元学习网络得到的特征表示,那么元学习方法实际上也有做特征聚集的工作。

更进一步地,作者探讨 特征点聚集分割平面对数据样本不变性两者之间的联系,提出了超平面方差的正则项(hyperplane variation regularizer):


对于两个类别的特征点(A 类的 x1 和 x2,B 类的 y1 和 y2),该正则项衡量了不同类别数据点之间的距离向量的差异。当超平面对于数据样本有较强不变性时,该正则项的值越小。同样地,作者使用该正则项进行对比实验,效果和 Rfc 类似,比没有使用 Rhv 的传统训练方法的到的模型的分类效果要好。

表 4.1:使用 Rfc 或者是 Rhv 的对比实验结果。来自:[7]

前面的实验中,考虑的元学习训练方式是第一种机制,那对于微调时不会固定特征提取模块的元学习训练方式(比如 MAML 方法),情况又是怎样的呢?作者将 MAML 方法和迁移学习方法对比,发现 MAML 模型的效果并没有比传统训练模型的 feature seperation 效果更优,说明了特征聚集的提升作用,并不是元学习训练中会有的普遍现象,而是特定地存在于使用第一种机制的元训练模型中。于是接下来,作者对于元学习第二种机制的有效性进行了探讨和分析。

4.2 权重聚集的正则化方法(weight-clustering regularization)

  • 4.1.2 在参数空间的任务损失函数的最优点聚集


接下来讨论没有固定特征提取模块的元模型,这类模型的参数能够很好地适应新任务。对于 Reptile[10],作者提出了一种假设:该方法寻找的模型参数,是接近于很多任务的最优点,所以能够在微调之后在这些任务上达到较好的效果。为了验证这个猜想,本文将 Reptile 方法表示为类似于一致性最优化方法的形式(consensus optimization,使用一项惩罚来促进各个特定任务的模型收敛到共同的参数),最小化的目标函数为:



θ~ 是 task-specific 参数,θ是一致值(实际上是元参数),左项是针对任务 p 的损失函数,右项是距离惩罚项,引导模型参数收敛到一个一致值的附近。虽然 Reptile 实际上并没有很明显地使用第二项来得到最优的 task-specific 参数,但是它使用了θ作为 task-specific 模型的初始化参数,隐式地促使θ~ 是在θ附近。

为了验证参数聚集的作用,作者在原始 reptile 算法中内部循环(inner loop)的损失函数加上如下一项,进而提出权重聚集(Weight Clustering)方法


该项给出了针对某个任务 i 的模型参数θ^~_i 与当前训练批次所有任务的模型参数θ^~_p 的均值之间的距离。通过将 Reptile 方法结合该正则项,能够更显式地促使训练模型的参数聚集,在 1-shot 和 5-shot 实验中都能获得更优于传统训练方法、一阶 MAML 方法(FOMAML)和原始 Reptile 方法的效果。

图 4.4:使用了参数聚集正则化的 Reptile 算法(红色椭圆即为参数聚集相关的正则项)。图源:[7]

表 4.2:通过在 mini-ImageNet 上的对比实验,验证了增加惩罚项 Ri(即表中 W-Clustering 所示)对于模型效果的提升作用。来自:[7]

4.3 小结

本文对于元学习训练方法在少样本学习场景中的有效性进行了深入探讨,并提出了元学习得到的数据特征表示是不同于普通训练方法得到的数据特征表示的猜想。本文根据这个猜想设计了具有特征聚集特性权重聚集特性两种正则项,并分别应用到迁移学习方法和原始元学习方法中,验证了正则项对于提升模型效果的作用。

参考文献

[1] Vanschoren J. "Meta-Learning: A Survey". Arxiv:1810.03548, 2018.
[2] Bronskill, John, Jonathan Gordon, James Requeima, Sebastian Nowozin and R. Turner. "TaskNorm: Rethinking Batch Normalization for Meta-Learning". Proceedings of the 37th International Conference on Machine Learning (ICML), 2020.
[3] Chelsea Finn, Pieter Abbeel, and Sergey Levine. "Model-agnostic meta-learning for fast adaptation of deep networks". Proceedings of the 34th International Conference on Machine Learning (ICML), 2017.
[4] Flennerhag, Sebastian, Andrei A. Rusu, Razvan Pascanu, H. Yin and Raia Hadsell. "Meta-Learning with Warped Gradient Descent". ArXiv:1909.00025, 2020.
[5]Flennerhag, Sebastian, Moreno, Pablo G., Lawrence, Neil D., and Damianou, Andreas. Transferring knowledge across learning processes. In International Conference on Learning Representations, 2019.
[5] Yin, Mingzhang, G. Tucker, M. Zhou, S. Levine and Chelsea Finn. "Meta-Learning without Memorization". ArXiv: 1912.03820, 2020.
[6] Goldblum, Micah, S. Reich, Liam Fowl, Renkun Ni, V. Cherepanova and T. Goldstein. "Unraveling Meta-Learning: Understanding Feature Representations for Few-Shot Tasks" Proceedings of the 37th International Conference on Machine Learning (ICML), 2020.
[7] Snell, J., Swersky, K., and Zemel, R. Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems, pp. 4077–4087, 2017.
[8] Bertinetto, L., Henriques, J. F., Torr, P. H., and Vedaldi, A. Meta-learning with differentiable closed-form solvers. arXiv preprint arXiv:1805.08136, 2018.
[9] Lee, K., Maji, S., Ravichandran, A., and Soatto, S. Metalearning with differentiable convex optimization. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 10657–10665, 2019.
[10] Nichol, A. and Schulman, J. Reptile: a scalable metalearning algorithm. arXiv preprint arXiv:1803.02999, 2:2, 2018.

分析师介绍:

杨旭韵,工程硕士,主要研究方向是强化学习模仿学习以及元学习。现从事工业机器人相关的技术研究工作,主要负责机器学习算法落地应用的工作。

关于机器之心全球分析师网络 Synced Global Analyst Network

机器之心全球分析师网络是由机器之心发起的全球性人工智能专业知识共享网络。在过去的四年里,已有数百名来自全球各地的 AI 领域专业学生学者、工程专家、业务专家,利用自己的学业工作之余的闲暇时间,通过线上分享、专栏解读、知识库构建、报告发布、评测及项目咨询等形式与全球 AI 社区共享自己的研究思路、工程经验及行业洞察等专业知识,并从中获得了自身的能力成长、经验积累及职业发展。
理论元学习
1
相关数据
DeepMind机构

DeepMind是一家英国的人工智能公司。公司创建于2010年,最初名称是DeepMind科技(DeepMind Technologies Limited),在2014年被谷歌收购。在2010年由杰米斯·哈萨比斯,谢恩·列格和穆斯塔法·苏莱曼成立创业公司。继AlphaGo之后,Google DeepMind首席执行官杰米斯·哈萨比斯表示将研究用人工智能与人类玩其他游戏,例如即时战略游戏《星际争霸II》(StarCraft II)。深度AI如果能直接使用在其他各种不同领域,除了未来能玩不同的游戏外,例如自动驾驶、投资顾问、音乐评论、甚至司法判决等等目前需要人脑才能处理的工作,基本上也可以直接使用相同的神经网上去学而习得与人类相同的思考力。

https://deepmind.com/
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

交叉熵技术

交叉熵(Cross Entropy)是Loss函数的一种(也称为损失函数或代价函数),用于描述模型预测值与真实值的差距大小

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

人工智能技术

在学术研究领域,人工智能通常指能够感知周围环境并采取行动以实现最优的可能结果的智能体(intelligent agent)

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

学习率技术

在使用不同优化器(例如随机梯度下降,Adam)神经网络相关训练中,学习速率作为一个超参数控制了权重更新的幅度,以及训练的速度和精度。学习速率太大容易导致目标(代价)函数波动较大从而难以找到最优,而弱学习速率设置太小,则会导致收敛过慢耗时太长

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

元学习技术

元学习是机器学习的一个子领域,是将自动学习算法应用于机器学习实验的元数据上。现在的 AI 系统可以通过大量时间和经验从头学习一项复杂技能。但是,我们如果想使智能体掌握多种技能、适应多种环境,则不应该从头开始在每一个环境中训练每一项技能,而是需要智能体通过对以往经验的再利用来学习如何学习多项新任务,因此我们不应该独立地训练每一个新任务。这种学习如何学习的方法,又叫元学习(meta-learning),是通往可持续学习多项新任务的多面智能体的必经之路。

知识库技术

知识库是用于知识管理的一种特殊的数据库,以便于有关领域知识的采集、整理以及提取。知识库中的知识源于领域专家,它是求解问题所需领域知识的集合,包括基本事实、规则和其它有关信息。

梯度下降技术

梯度下降是用于查找函数最小值的一阶迭代优化算法。 要使用梯度下降找到函数的局部最小值,可以采用与当前点的函数梯度(或近似梯度)的负值成比例的步骤。 如果采取的步骤与梯度的正值成比例,则接近该函数的局部最大值,被称为梯度上升。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

监督学习技术

监督式学习(Supervised learning),是机器学习中的一个方法,可以由标记好的训练集中学到或建立一个模式(函数 / learning model),并依此模式推测新的实例。训练集是由一系列的训练范例组成,每个训练范例则由输入对象(通常是向量)和预期输出所组成。函数的输出可以是一个连续的值(称为回归分析),或是预测一个分类标签(称作分类)。

目标函数技术

目标函数f(x)就是用设计变量来表示的所追求的目标形式,所以目标函数就是设计变量的函数,是一个标量。从工程意义讲,目标函数是系统的性能标准,比如,一个结构的最轻重量、最低造价、最合理形式;一件产品的最短生产时间、最小能量消耗;一个实验的最佳配方等等,建立目标函数的过程就是寻找设计变量与目标的关系的过程,目标函数和设计变量的关系可用曲线、曲面或超曲面表示。

工业机器人技术

工业机器人是面向工业加工制造的可自动控制,多用途,需有三轴及以上可编程的固定或可移动机械手。其系统中包括带有执行机构的机械手以及示教控制器。 它可以依靠自身控制能力来执行预设的轨迹及动作。典型应用包括焊接,刷漆,组装,采集和放置等工作。工业机器人完成工作具有高效性,持久性和准确性。目前常用的工业机器人包括关节机器人,SCARA机器人,并联机器人和直角坐标机器人等。

迁移学习技术

迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都是为了解决单个任务而设计的,但是促进迁移学习的算法的开发是机器学习社区持续关注的话题。 迁移学习对人类来说很常见,例如,我们可能会发现学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。

过拟合技术

过拟合是指为了得到一致假设而使假设变得过度严格。避免过拟合是分类器设计中的一个核心任务。通常采用增大数据量和测试样本集的方法对分类器性能进行评价。

独立同分布技术

在概率论与统计学中,独立同分布(缩写为IID)是指一组随机变量中每个变量的概率分布都相同,且这些随机变量互相独立。一组随机变量独立同分布并不意味着它们的样本空间中每个事件发生概率都相同。例如,投掷非均匀骰子得到的结果序列是独立同分布的,但掷出每个面朝上的概率并不相同。

查询技术

一般来说,查询是询问的一种形式。它在不同的学科里涵义有所不同。在信息检索领域,查询指的是数据库和信息系统对信息检索的精确要求

正则化技术

当模型的复杂度增大时,训练误差会逐渐减小并趋向于0;而测试误差会先减小,达到最小值后又增大。当选择的模型复杂度过大时,过拟合现象就会发生。这样,在学习时就要防止过拟合。进行最优模型的选择,即选择复杂度适当的模型,以达到使测试误差最小的学习目的。

批次技术

模型训练的一次迭代(即一次梯度更新)中使用的样本集。

图像分类技术

图像分类,根据各自在图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法。它利用计算机对图像进行定量分析,把图像或图像中的每个像元或区域划归为若干个类别中的某一种,以代替人的视觉判读。

模仿学习技术

模仿学习(Imitation Learning)背后的原理是是通过隐含地给学习器关于这个世界的先验信息,就能执行、学习人类行为。在模仿学习任务中,智能体(agent)为了学习到策略从而尽可能像人类专家那样执行一种行为,它会寻找一种最佳的方式来使用由该专家示范的训练集(输入-输出对)。

强化学习技术

强化学习是一种试错方法,其目标是让软件智能体在特定环境中能够采取回报最大化的行为。强化学习在马尔可夫决策过程环境中主要使用的技术是动态规划(Dynamic Programming)。流行的强化学习方法包括自适应动态规划(ADP)、时间差分(TD)学习、状态-动作-回报-状态-动作(SARSA)算法、Q 学习、深度强化学习(DQN);其应用包括下棋类游戏、机器人控制和工作调度等。

信息论技术

信息论是在信息可以量度的基础上,研究有效地和可靠地传递信息的科学,它涉及信息量度、信息特性、信息传输速率、信道容量、干扰对信息传输的影响等方面的知识。通常把上述范围的信息论称为狭义的信息论,又因为它的创始人是香农,故又称为香农信息论。

机器之心机构

机器之心,成立于2014年,是国内最具影响力、最专业、唯一用于国际品牌的人工智能信息服务与产业服务平台。目前机器之心已经建立起涵盖媒体、数据、活动、研究及咨询、线下物理空间于一体的业务体系,为各类人工智能从业者提供综合信息服务和产业服务。

https://www.jiqizhixin.com/
Infor机构

Infor是一家跨国企业软件公司,总部设在美国纽约市。Infor专注于通过云计算作为服务交付给组织的业务应用。最初专注于从财务系统和企业资源规划(ERP)到供应链和客户关系管理的软件, Infor在2010年开始专注于工业利基市场的软件,以及用户友好的软件设计。Infor通过Amazon Web Services和各种开源软件平台部署云应用。

www.infor.com
相关技术
小样本学习技术

人类非常擅长通过极少量的样本识别一个新物体,比如小孩子只需要书中的一些图片就可以认识什么是“斑马”,什么是“犀牛”。在人类的快速学习能力的启发下,研究人员希望机器学习模型在学习了一定类别的大量数据后,对于新的类别,只需要少量的样本就能快速学习,这就是 Few-shot Learning 要解决的问题。

暂无评论
暂无评论~