吴仕超作者

Tree-CNN:一招解决深度学习中的「灾难性遗忘」

本期推荐的论文笔记来自 PaperWeekly 社区用户 @Cratial深度学习领域一直存在一个比较严重的问题——“灾难性遗忘”,即一旦使用新的数据集去训练已有的模型,该模型将会失去对原数据集识别的能力。

为解决这一问题,本文提出了树卷积神经网络,通过先将物体分为几个大类,然后再将各个大类依次进行划分、识别,就像树一样不断地开枝散叶,最终叶节点得到的类别就是我们所要识别的类。

关于作者:吴仕超,东北大学硕士生,研究方向为脑机接口、驾驶疲劳检测和机器学习

■ 论文 | Tree-CNN: A Deep Convolutional Neural Network for Lifelong Learning

■ 链接 | https://www.paperweekly.site/papers/1839

■ 作者 | Deboleena Roy / Priyadarshini Panda / Kaushik Roy

网络结构及学习策略

网络结构 

Tree-CNN 模型借鉴了层分类器,树卷积神经网络由节点构成,和数据结构中的树一样,每个节点都有自己的 ID、父亲(Parent)及孩子(Children),网(Net,处理图像的卷积神经网络),LT("Labels Transform",就是每个节点所对应的标签,对于根节点和枝节点来说,可以是对最终分类类别的一种划分,对于叶节点来说,就是最终的分类类别),其中最顶部为树的根节点。

本文提出的网络结构如下图所示。对于一张图像,首先会将其送到根节点网络去分类得到“super-classes”,然后根据所识别到的“super-classes”,将图像送入对应的节点做进一步分类,得到一个更“具体”的类别,依次进行递推,直到分类出我们想要的类。

▲ 图1

其实这就和人的识别过程相似,例如有下面一堆物品:数学书、语文书、物理书、橡皮、铅笔。如果要识别物理书,我们可能要经历这样的过程,先在这一堆中找到书,然后可能还要在书里面找到理科类的书,然后再从理科类的书中找到物理书,同样我们要找铅笔的话,我们可能需要先找到文具类的物品,然后再从中找到铅笔。

学习策略 

在识别方面,Tree-CNN 的思想很简单。如图 1 所示,主要就是从根节点出发,输出得到一个图像属于各个大类的概率,根据最大概率所对应的位置将识别过程转移到下一节点,这样最终我们能够到达叶节点,叶节点对应得到的就是我们要识别的结果。整个过程如图 2 所示。

▲ 图2

如果仅按照上面的思路去做识别,其实并没有太大的意义,不仅使识别变得很麻烦,而且在下面的实验中也证明了采用该方法所得到的识别率并不会有所提高。而这篇论文最主要的目的就是要解决我们在前面提到的“灾难性遗忘问题”,即文中所说的达到“lifelong”的效果。 

对于新给的类别,我们将这些类的图像输入到根节点网络中,根节点的输出为 OK×M×I,其中 K、M、I 分别为根节点的孩子数、新类别数、每类的图像数。

然后利用式(1)来求得每类图像的输出平均值 Oavg,再使用 softmax 来计算概率情况。以概率分布表示该类与根节点下面子类的相似程度。对于第 m 类,我们按照其概率分布进行排列,得到公式(3)。

根据根节点得到的概率分布,文中分别对下面三种情况进行了讨论: 

  • 当输出概率中最大概率大于设定的阈值,则说明该类别和该位置对应的子节点有很大的关系,因此将该类别加到该子节点上; 

  • 若输出概率中有多个概率值大于设定的阈值,就联合多个子节点来共同组成新的子节点; 

  • 如果所有的输出概率值都小于阈值,那么就为新类别增加新的子节点,这个节点是一个叶节点。 

同样,我们将会对别的支节点继续上面的操作。通过上面的这些操作,实现对新类别的学习,文中称这种学习方式为 incremental/lifelong learning。

实验方法与结果分析

在这部分,作者分别针对 CIFAR-10 及 CIFAT-100 数据集上进行了测试。 

实验方法 

1. CIFAR-10 

在 CIFAR-10 的实验中,作者选取 6 类图像作为初始训练集,又将 6 类中的为汽车、卡车设定为交通工具类,将猫、狗、马设为动物类,因此构建出的初始树的结构如图 3(a)所示

▲ 图3

具体网络结构如图 4 所示,根节点网络是包含两层卷积、两层池化卷积神经网络,支节点是包含 3 层卷积卷积神经网络

▲ 图4

当新的类别出现时(文中将 CIFAR-10 另外 4 个类别作为新类别),按照文中的学习策略,我们先利用根节点的网络对四种类别的图片进行分类,得到的输出情况如图 5 所示,从图中可以看出,在根节点的识别中 Frog、Deer、Bird 被分类为动物的概率很高,Airplane 被分类为交通工具的概率较高。

▲ 图5

根据文中的策略,Frog、Deer、Bird 将会被加入到动物类节点,同样 Airplane 将会被加入到交通工具类节点。经过 incremental/lifelong learning 后的 Tree-CNN 的结构如图 3(b)所示。 具体训练过程如图 6 所示。

▲ 图6

为了对比 Tree-CNN 的效果,作者又搭建了一个包含 4 层卷积神经网络,并分别通过调节全连接层、全连接 +conv1、全连接 +conv1+conv2、全连接 +conv1+conv2+conv3、全连接 +conv1+conv2+conv3+conv4 的参数来进行微调。 

2. CIFAR-100 

对于 CIFAR-100 数据集,作者将 100 类数据分为 10 组,每组包含 10 类样本。在网络方面,作者将根节点网络的卷积层改为 3,并改变了全连接层的输出数目。

实验结果分析

在这部分,作者通过设置两个参数来衡量 Tree-CNN 的性能

其中,Training Effort 表示 incremental learning 网络的更改程度,即可以衡量“灾难性遗忘”的程度,参数改变的程度越高,遗忘度越强。 

图 7 比较了在 CIFAR-10 上微调网络和 Tree-CNN 的识别效果对比,可以看出相对于微调策略,Tree-CNN 的 Training Effort 仅比微调全连接层高,而准确率却能超出微调全连接层 +conv1。

▲ 图7

这一现象在 CIFAR-100 中表现更加明显。

▲ 图8

从图 7、图 8 中可以看出 Tree-CNN 的准确率已经和微调整个网络相差无几,但是在 Training Effort 上却远小于微调整个网络。 

从图 9 所示分类结果中可以看出,在各个枝节点中,具有相同的特性的类被分配在相同的枝节点中。这一情况在 CIFAR-100 所得到的 Tree-CNN 最终的结构中更能体现出来。

除了一些叶节点外,在语义上具有相同特征的物体会被分类到同一支节点下,如图 10 所示。

▲ 图10

总结与分析

本文虽然在一定程度上减少了神经网络“灾难性遗忘”问题,但是从整篇文章来看,本文并没能使网络的识别准确率得到提升,反而,相对于微调整个网络来说,准确率还有所降低。

此外,本文搭建的网络实在太多,虽然各个子网络的网络结构比较简单,但是调节网络会很费时。

入门
相关数据
神经网络技术
Neural Network

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

卷积神经网络技术
Convolutional neural network

卷积神经网路(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。卷积神经网路由一个或多个卷积层和顶端的全连通层(对应经典的神经网路)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网路能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网路在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网路,卷积神经网路需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。 卷积网络是一种专门用于处理具有已知的、网格状拓扑的数据的神经网络。例如时间序列数据,它可以被认为是以一定时间间隔采样的一维网格,又如图像数据,其可以被认为是二维像素网格。

机器学习技术
Machine Learning

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

池化技术
Pooling

池化(Pooling)是卷积神经网络中的一个重要的概念,它实际上是一种形式的降采样。有多种不同形式的非线性池化函数,而其中“最大池化(Max pooling)”是最为常见的。它是将输入的图像划分为若干个矩形区域,对每个子区域输出最大值。直觉上,这种机制能够有效的原因在于,在发现一个特征之后,它的精确位置远不及它和其他特征的相对位置的关系重要。池化层会不断地减小数据的空间大小,因此参数的数量和计算量也会下降,这在一定程度上也控制了过拟合。通常来说,CNN的卷积层之间都会周期性地插入池化层。

参数技术
parameter

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

深度学习技术
Deep learning

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法。观测值(例如一幅图像)可以使用多种方式来表示,如每个像素强度值的向量,或者更抽象地表示成一系列边、特定形状的区域等。而使用某些特定的表示方法更容易从实例中学习任务(例如,人脸识别或面部表情识别)。 近年来监督式深度学习方法(以反馈算法训练CNN、LSTM等)获得了空前的成功,而基于半监督或非监督式的方法(如DBM、DBN、stacked autoencoder)虽然在深度学习兴起阶段起到了重要的启蒙作用,但仍处在研究阶段并已获得不错的进展。在未来,非监督式学习将是深度学习的重要研究方向,因为人和动物的学习大多是非监督式的,我们通过观察来发现世界的构造,而不是被提前告知所有物体的名字。 至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

准确率技术
Accuracy

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

概率分布技术
Probability distribution

卷积技术
Convolution

返回顶部