张莹作者杨茹茵编辑

深度互学习-Deep Mutual Learning:三人行必有我师

编者按:更高性能的深度神经网络往往伴随着愈加庞大的参数量,而大量的计算需求使其难以部署在移动端。为此,精巧的网络结构设计(如MobileNet、ShuffleNet)、模型压缩策略(剪枝二值化等)及其他优化方法应运而生。

Hinton等人在2015年提出的模型蒸馏算法,利用预训练好的大网络当作教师来向小网络传递知识,从而提高小网络性能。而模型蒸馏算法需要有提前预训练好的大网络,且仅可对小网络进行单向的知识传递。古人云“三人行必有我师焉”,本文作者提出了一种“深度互学习Deep Mutual Learning”策略,使得小网络之间能够互相学习共同进步。

1.研究动机

近几年来,深度神经网络计算机视觉语音识别、语言翻译等领域中取得了令人瞩目的成果,为了完成更加复杂的信息处理任务,网络在设计上不断增加深度或宽度,使得模型参数量越来越大,如经典的VGG、Inception、ResNet系列网络。尽管更深或更宽的神经网络取得了更好的性能,大量计算需求使得它们难以部署在资源条件有限的环境中,如手机、平板、车载等移动端应用。这促使研究者们采用各种各样的方法去探索更高效的模型,如更精巧的网络结构设计MobileNet和ShuffleNet,还有网络压缩、剪枝二值化,以及比较有趣的模型蒸馏等。

模型蒸馏算法由Hinton等人在2015年提出,利用一个预训练好的大网络当作教师来提供小网络额外的知识即平滑后的概率估计,实验表明小网络通过模仿大网络估计的类别概率,优化过程变得更容易,且表现出与大网络相近甚至更好的性能。然而模型蒸馏算法需要有提前预训练好的大网络,且大网络在学习过程中保持固定,仅对小网络进行单向的知识传递,难以从小网络的学习状态中得到反馈信息来对训练过程进行优化调整。

我们尝试探索一种能够学习到更强大小网络的训练机制—深度互学习,即采用多个网络同时进行训练,每个网络在训练过程中不仅接受来自真值标记的监督,还参考同伴网络的学习经验来进一步提升泛化能力。在整个过程中,两个网络之间不断分享学习经验,实现互相学习共同进步。

2.算法描述图1 深度互学习算法框架具体来说,每个网络在学习过程中有两个损失函数,一个是传统的监督损失函数,采用交叉熵损失来度量网络预测的目标类别与真实标签之间的差异,另一个是网络间的交互损失函数,采用KL散度来度量两个网络预测概率分布之间的差异。公式表示为

采用这两种损失函数,不仅可以使得网络学习到如何区分不同的类别,还能够使其参考另一个网络的概率估计来提升自身泛化能力。

接下来我们给出网络的优化策略。对于单块GPU,我们采用交替迭代的方式依次更新两个网络,当有多块GPU时,我们可以采用分布式训练,每次迭代时两个网络同时计算概率估计差异并更新模型参数。实验发现分布式训练可以获得更好的性能。目前关于分布式训练为何能比串行训练获得更好的性能还未有比较好的理论解释,一些研究者认为在分布式训练中每个worker对附近参数空间的探索实际上提高了模型在连续梯度下降方面的统计性能。

我们提出的互学习算法也很容易扩展到多网络学习和半监督学习场景中。当有K个网络时,深度互学习学习每个网络时将其余K-1个网络分别作为教师来提供学习经验。另外一种策略是将其余K-1个网络融合后得到一个教师来提供学习经验 。在半监督互学习场景中,我们对有标签的数据计算监督损失和交互损失,而针对无标签数据我们仅计算交互损失来帮助网络从训练数据中挖掘更多有用信息。

3.实验结果

我们首先在CIFAR-10和CIFAR-100上用不同的网络做了实验,从表中可以看出,所有不同的网络组合采用深度互学习算法均可以提升分类准确率,这表明了我们算法具有较高的灵活性,对网络结构的适应性较强。一般来说小网络从互学习训练中获益更多,比如Resnet-32和MobileNet。尽管WRN-28-10网络参数量很大,与其它网络进行互学习训练依然可以获得性能提升。因此,不同于模型蒸馏算法需要预训练大网络来帮助小网络提升性能,我们提出的深度互学习算法也可以帮助参与训练的大网络来提升其性能。

表1 数据集CIFAR-10与CIFAR-100实验结果我们在ImageNet上也做了实验,从图2中可以看出采用互学习训练均可以提升网络在大规模分类任务上的性能。
图2 ImageNet实验结果针对多网络互学习,我们从图3看出增加网络数量可以提升互学习策略下的单个网络性能,这说明更多教师网络提供了更多学习经验,帮助网络学习到更好的特征。另一方面,多网络互学习中多个独立教师(DML)的性能会优于融合教师(DML_e),这说明多个不同教师网络可以提供更多样化的学习经验,更有益于每个网络的学习。
图3 多网络互学习实验结果针对半监督学习,从图4中可以看出,仅采用有标签数据参与训练时,深度互学习策略可以提高算法分类准确率。而当我们将未标记数据加入互学习训练中,网络的性能可以得到进一步提升,当标记样本数量较少时,其优势更明显。
图4 半监督深度互学习实验结果4.作用机制分析

那么,为什么互学习机制能起作用呢?为什么网络从头开始互学习训练也能收敛到更好的解而不是被互相拉低?当两个网络均从头开始训练时额外的知识从哪里来?为什么约束两个网络的概率估计相近可以提升泛化能力?经过互学习训练后两个网络是不是更相似了?

首先,为什么网络从头开始互学习训练也能收敛到更好的解而不是被互相拉低?直观解释如下:每个网络一开始采用随机初始化,类别概率估计接近于均匀分布,这使得它们在训练初期的监督损失较大,交互损失较小,每个网络主要由传统的监督损失函数引导,这样可以保证网络的性能在逐渐提升。随着模型参数更新,每个网络在自己的学习过程中获得不同的知识,它们对样本类别的概率估计也会有所不同,这时交互损失开始促进网络互相参考学习经验。

接下来是最关键的问题,为什么互学习机制起作用?当两个网络均从头开始训练时额外的知识从哪里来?为什么约束两个网络的概率估计相近可以提升泛化能力?我们从三个角度来尝试理解这些问题。

首先我们认为类别概率估计蕴含了网络挖掘到的数据本质规律。网络的泛化能力越强,则表示网络越有可能挖掘到了数据的内在本质特性,并可以通过类别概率估计表现出来。例如我们希望网络学习区分猫、狗、桌子三个类别,如图5所示,网络在对猫进行分类时除了要最大化猫的类别概率估计,还会给错误类别如狗和桌子分配一定概率,尽管该概率值很低,但我们仍希望分配给狗的概率要大于分配给桌子的概率,即希望网络除了学习到猫的特征,还能学习到和狗共有的一些特征,认为猫与狗的类别距离要小于猫与桌子的类别距离。这样网络在新的测试数据上就更有可能捕捉猫的多种特性,表现出较强的泛化能力。真值标签提供的信息仅包含样本是否属于某一类,但缺少不同类别之间的联系,而网络输出的类别概率估计则能够在一定程度上恢复该信息,因此网络之间进行类别概率估计交互可以传递学习到的数据分布特性,从而帮助网络改善泛化性能。

其次我们认为约束类别概率相近起到正则化作用。深度神经网络在训练过程中一般采用one-hot-vector方式编码真实类别分布,即认为观测样本属于某一类时,其概率值为1,否则为0。InceptionV3论文中认为这种真值标签编码会使得模型在训练过程中对预测结果太过确信,容易导致过拟合,于是提出标签平滑(Label Smoothing)策略,将正确类的概率分配一些给错误类,防止模型把预测值过度集中在较大概率上。Chaudhar等在ICLR2017论文中提出增加熵正则,约束网络预测输出的概率稍微平滑一点。在互学习算法中,当我们将网络2的类别概率传递给网络1时,本质上也是提供额外的类别先验约束,防止网络1过度拟合真值标签的0-1分布,有效降低过拟合发生概率。然而不一样的是,标签平滑和熵正则的类别概率约束是盲目的,而互学习算法中会有更多类别信息。

最后,我们认为网络在训练过程中会参考同伴网络的经验来调整自己的学习过程,最终能够收敛到一个更平缓的极小值点,从而具备更好的泛化性能。关于神经网络泛化性能的一些研究认为,尽管深度神经网络可以找到很多解(即网络学习到的参数)使得训练损失降到零,但一些解能够比其它解具有更好的泛化性能,其原因在于这些解处于更平缓的极小点,这意味着小的波动不会对网络的预测结果造成剧烈影响。

那么我们的深度互学习算法是不是帮助网络找到了一个更平缓的极小点呢?我们进行了实验验证,首先我们观测了两种训练策略下网络在训练数据集上的损失函数变化,从图(a)可以看出单独训练及互学习训练的网络都可以充分拟合训练数据,训练集上的分类准确率都可以达到100%,且训练损失都可以降到几乎相同的极小值。这说明深度互学习算法并没有帮助网络找一个更深的极小值点来帮助网络在训练集上实现损失更小,而是有可能找到了一个深度相同但更平缓的极小值点。

图5 深度互学习作用机制分析为了验证该猜想,我们对两种策略训练好的网络参数添加高斯噪声,并在图(b)中比较了添加不同方差高斯噪声后网络损失函数值的变化。从图中可以看出,单独训练的网络在添加噪声后损失函数值波动很大,而互学习训练网络的损失函数值则增加很小。该实验现象表明深度互学习算法帮助网络找到了一个更平缓的极小点,针对噪声具有更强的鲁棒性,从而具有更好的泛化性能。

那么深度互学习是如何帮助网络找到更好的解呢? 我们注意到深度互学习算法要求一个网络1的概率估计与同伴网络2的概率估计相匹配,网络1在某个类别上估计概率为为零而网络2估计不为零时,就会产生比较大的惩罚。因此当多个网络参与训练时,每个网络针对样本估计的概率值会分布在不同的类别上,监督损失函数会使得网络在第一最大类上产生较大的概率估计,而剩余的概率值会依次分布在第二最大类及之后的类别上。当两个网络类别概率估计在这些第二类别有差异时,KL损失函数会使两个网络相互妥协,每个网络将分出一些概率值给更接近真值类的第二最大类及之后类别,帮助网络挖掘更多类别信息来找到更好的解。从图上可以看出,采用深度互学习算法可以使得训练集上类别概率分布估计更平缓,且不同类别的相对距离也更明显。

5.结论与展望

我们提出了一个简单有效的互学习算法,通过采用两个网络联合训练来提升深度神经网络的泛化性能。该算法不仅可以用于训练高效的小网络,也可以进一步提升大网络性能,且容易扩展到多网络学习及半监督学习场景中。我们对算法的作用机制进行了探索分析,尝试从网络泛化能力和寻找到解的性质来分析深度互学习算法有效的原因。

代码:

https://github.com/YingZhangDUT/Deep-Mutual-Learning

论文:

http://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Deep_Mutual_Learning_CVPR_2018_paper.pdf

作者简介:

张莹,大连理工大学2015级博士生,导师卢湖川教授,研究方向为行人搜索,包括行人再识别和跨模态行人搜索。目前已发表论文9篇,其中第一作者论文5篇,包括2篇CVPR,ECCV等。2016年赴博二期间就读于伦敦玛丽女王大学进行联合培养,指导教师为向滔教授,合作导师Timothy M. Hospedales。
深度学习大讲堂
深度学习大讲堂

高质量原创内容平台,邀请学术界、工业界一线专家撰稿,致力于推送人工智能与深度学习最新技术、产品和活动信息。

理论半监督学习监督学习分布式计算技术智慧社会深度互学习
4
相关数据
张莹人物

佐治亚理工学院电气与计算机工程系副教授,感知器与智能系统实验室主任。

半监督学习技术

半监督学习属于无监督学习(没有任何标记的训练数据)和监督学习(完全标记的训练数据)之间。许多机器学习研究人员发现,将未标记数据与少量标记数据结合使用可以显着提高学习准确性。对于学习问题的标记数据的获取通常需要熟练的人类代理(例如转录音频片段)或物理实验(例如,确定蛋白质的3D结构或确定在特定位置处是否存在油)。因此与标签处理相关的成本可能使得完全标注的训练集不可行,而获取未标记的数据相对便宜。在这种情况下,半监督学习可能具有很大的实用价值。半监督学习对机器学习也是理论上的兴趣,也是人类学习的典范。

交叉熵技术

交叉熵(Cross Entropy)是Loss函数的一种(也称为损失函数或代价函数),用于描述模型预测值与真实值的差距大小

二值化技术

二值化是将像素图像转换为二进制图像的过程。

VGG技术

2014年,牛津大学提出了另一种深度卷积网络VGG-Net,它相比于AlexNet有更小的卷积核和更深的层级。AlexNet前面几层用了11×11和5×5的卷积核以在图像上获取更大的感受野,而VGG采用更小的卷积核与更深的网络提升参数效率。VGG-Net 的泛化性能较好,常用于图像特征的抽取目标检测候选框生成等。VGG最大的问题就在于参数数量,VGG-19基本上是参数量最多的卷积网络架构。VGG-Net的参数主要出现在后面两个全连接层,每一层都有4096个神经元,可想而至这之间的参数会有多么庞大。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

剪枝技术

剪枝顾名思义,就是删去一些不重要的节点,来减小计算或搜索的复杂度。剪枝在很多算法中都有很好的应用,如:决策树,神经网络,搜索算法,数据库的设计等。在决策树和神经网络中,剪枝可以有效缓解过拟合问题并减小计算复杂度;在搜索算法中,可以减小搜索范围,提高搜索效率。

概率分布技术

概率分布(probability distribution)或简称分布,是概率论的一个概念。广义地,它指称随机变量的概率性质--当我们说概率空间中的两个随机变量具有同样的分布(或同分布)时,我们是无法用概率来区别它们的。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

计算机视觉技术

计算机视觉(CV)是指机器感知环境的能力。这一技术类别中的经典任务有图像形成、图像处理、图像提取和图像的三维推理。目标识别和面部识别也是很重要的研究领域。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

梯度下降技术

梯度下降是用于查找函数最小值的一阶迭代优化算法。 要使用梯度下降找到函数的局部最小值,可以采用与当前点的函数梯度(或近似梯度)的负值成比例的步骤。 如果采取的步骤与梯度的正值成比例,则接近该函数的局部最大值,被称为梯度上升。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

过拟合技术

过拟合是指为了得到一致假设而使假设变得过度严格。避免过拟合是分类器设计中的一个核心任务。通常采用增大数据量和测试样本集的方法对分类器性能进行评价。

语音识别技术

自动语音识别是一种将口头语音转换为实时可读文本的技术。自动语音识别也称为语音识别(Speech Recognition)或计算机语音识别(Computer Speech Recognition)。自动语音识别是一个多学科交叉的领域,它与声学、语音学、语言学、数字信号处理理论、信息论、计算机科学等众多学科紧密相连。由于语音信号的多样性和复杂性,目前的语音识别系统只能在一定的限制条件下获得满意的性能,或者说只能应用于某些特定的场合。自动语音识别在人工智能领域占据着极其重要的位置。

正则化技术

当模型的复杂度增大时,训练误差会逐渐减小并趋向于0;而测试误差会先减小,达到最小值后又增大。当选择的模型复杂度过大时,过拟合现象就会发生。这样,在学习时就要防止过拟合。进行最优模型的选择,即选择复杂度适当的模型,以达到使测试误差最小的学习目的。

深度神经网络技术

深度神经网络(DNN)是深度学习的一种框架,它是一种具备至少一个隐层的神经网络。与浅层神经网络类似,深度神经网络也能够为复杂非线性系统提供建模,但多出的层次为模型提供了更高的抽象层次,因而提高了模型的能力。

推荐文章
暂无评论
暂无评论~