Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

思源、一鸣编译

利用LSTM思想来做CNN剪枝,北大提出Gate Decorator

利用LSTM基本思想门控机制进行剪枝?让模型自己决定哪些卷积核可以扔。

还记得在理解 LSTM 的时候,我们会发现,它用一种门控机制记住重要的信息而遗忘不重要的信息。在此之后,很多机器学习方法都受到了门控机制的影响,包括 Highway Network 和 GRU 等等。北大的研究者同样也是,它们将门控机制加入到 CNN 剪枝中,让模型自己决定哪些滤波器不太重要,那么它们就可以删除了。

其实对滤波器进行剪枝是一种最为有效的、用于加速和压缩卷积神经网络的方法。在这篇论文中,来自北大的研究者提出了一种全局滤波器剪枝的算法,名为「门装饰器(gate decorator)」。这一算法可以通过将输出和通道方向的尺度因子(门)相乘,进而改变标准的 CNN 模块。当这种尺度因子被设0的时候,就如同移除了对应的滤波器。

研究人员使用了泰勒展开,用于估计因设定了尺度因子为 0 时对损失函数造成的影响,并用这种估计值来给全局滤波器的重要性进行打分排序。接着,研究者移除哪些不重要的滤波器。在剪枝后,研究人员将所有的尺度因子合并到原始的模块中,因此不需要引入特别的运算或架构。此外,为了提升剪枝的准确率,研究者还提出了一种迭代式的剪枝架构—— Tick-Tock。

图 1:滤波器剪枝图示。第 i 个层有4个滤波器(通道)。如果移除其中一个,对应的特征映射就会消失,而输入 i+1 层的通道也会变为3。

扩展实验说明了研究者提出的方法的效果。例如,研究人员在 ResNet-56 上达到了剪枝比例最好的 SOTA,减少了 70% 的每秒浮点运算次数,但没有带来明显的准确率降低。

在 ImageNet 上训练的 ResNet-50 上,研究者减少了 40% 的每秒浮点运算次数,且在 top-1 准确率上超过了基线模型 0.31%。在研究中使用了多种数据,包括 CIFAR-10、CIFAR-100、CUB-200、ImageNet ILSVRC-12 和 PASCAL VOC 2011。

本文的主要贡献包括两个部分:第一部分是「门装饰器」算法,用于解决 GFIR 问题。第二部分是 Tick-Tock 剪枝框架,用于提升剪枝准确率。

具体而言,研究者展示了如何将门装饰器用于批归一化操作,并将这种方法命名为门批归一化(GBN)。给定预训练模型,研究者在剪枝前将归一化模块转换成门批归一化剪枝结束后,他们将门批归一化还原为批归一化。通过这样的方法,不需要给模型引入特殊的运算或架构。

  • 论文地址:https://arxiv.org/abs/1909.08174

  • 实现地址:https://github.com/youzhonghui/gate-decorator-pruning

门控剪枝到底怎么做

那么到底怎样使用门控机制解决全局滤波器重要性排序呢?研究者表示他们会先将 Gate Decorator 应用到批归一化机制中,然后使用一种名为 Tick-Tock 的迭代剪枝框架来获得更好的剪枝准确率,最后再采用分组剪枝(Group Pruning)技术解决待条件的剪枝问题,例如剪枝带残差连接的网络。

上面简要展示了叙述了门控剪枝三步走,后面会做一个简单的介绍,当然更详细的内容可查阅原论文。

门控批归一化

研究者将 Gate Decorator应用到批归一化中,并将该模块称之为门控批归一化(GBN),门控批归一化如下方程7所示,它和标准批归一化的不同之处在于 φ arrow的门控选择。其中 φ arrow 是 φ 的一个向量,c 是 Z_in 的通道数。

如果 φ arrow 中的元素是零,那么就表示它对应的通道被裁减了。此外,对于不使用BN 的网络,我们也可以直接将 Gate Decorator 应用到卷积运算中,从而达到门控剪枝的效果。

Tick-Tock 剪枝框架

研究者还引进了一种迭代式的剪枝框架,从而提升剪枝准确率,他们将该框架称为Tick-Tok。其中 Tick 阶段会在训练数据的子集上执行,卷积核会被设定为不可更新状态。而 Tock 阶段使用全部训练数据,并将稀疏约束 φ 添加到损失函数中。

图2:Tick-Tock剪枝框架图示。

其中 Tick 阶段主要希望能实现以下三个目标:加速剪枝过程;计算每一个滤波器的重要性分数 Θ;降低前面剪枝引起的内部协变量迁移问题。

在 Tick 阶段中,研究者会在训练数据的子集中训练一个 Epoch,我们仅允许门控  φ 和最终的线性层能更新,这样能大大降低小数据集上的过拟合风险。通过训练后,模型会根据重要性分数 Θ 排序所有的滤波器,并将不那么重要的滤波器移除。

在 Tock 阶段前,Tick 阶段能重复 T 次。Tock 阶段会微调网络以降低总体误差,这些误差可能是由于一处滤波器造成的。此外,Tock 阶段和一般的微调过程有两大不同:微调比 Tock 要训练更多的 Epoch;微调并不会给损失函数加上稀疏性约束。

分组剪枝:解决带约束的剪枝问题

ResNet 和其变体包含残差连接,也就是在两个残差块产生的特征图上执行元素级的加法。如果单独修剪每个层的滤波器,可能会导致残差连接中特征图对不齐。这可以视为一种带约束的剪枝问题,我们希望剪枝是在对齐特征图的条件下完成的。

为了解决无法对齐的问题,作者们提出了分组剪枝:将通过纯残差方式连接的 GBN 分配给同一组。纯残差连接是指在侧分支上没有卷积层的一种方式,如图3所示。

图3:组剪枝展示。同样颜色的GBN属于同一组。

每一组可以视为一个 Virtual GBN,它的所有组成卷积共享了相同的剪枝模式。并且在分组中,滤波器的重要性分数就是成员卷积分数的和。

实验设置和数据集

数据集

研究者使用了多种数据集,包括 CIFAR-10,CIFAR-100,CUB-200,  ImageNet ILSVRC-12和 PASCAL VOC 2011。CIFAR-10 数据集包括了50K的训练数据和10K的测试数据。CIFAR-100和CIFAR-10相同,但有100个类别,每个类别有600张图片。CUB-200包括了将近6000张训练图片和5700张测试图片,涵盖了200种鸟类。ImageNet ILSVRC-12有128万训练图像和50K的测试图像,覆盖1000个类别。研究者还使用了PASCAL VOC 2011分割数据集和其扩展数据集SBD,它有20个类别,共8498张训练样本图片和2857张测试样本图片。

剪枝的模型

研究者使用了三种网络架构进行剪枝:VGGNet、ResNet和FCN。所有的网络都使用SGD进行训练,权重衰减和动量超参数分别设定为10-4和0.9。

研究者使用了多种训练数据和不同的批大小对这些网络进行了训练,同时加入了一些数据增强的方法。

剪枝阶段,研究者在每个Tick阶段剪去ResNet0.2%的滤波器,在VGG和FCN上减去1%的滤波器。在每10个Tick操作后进行一次Tock操作。

剪枝效果

表1:在 ResNet-56上,使用CIFAR-10训练的模型剪枝后的表现。基线准确率为93.1%。

表 2:在ResNet-50上,使用ImageNe训练的模型剪枝后的表现。P.Top-1、P.Top-5 分别表示 top-1和 top-5剪枝后的模型在验证集上的单中心裁剪准确率。[Top-1] ↓ 和 [Top-5] ↓分别表示剪枝后模型准确率和基线模型相比的下降情况。Global 表示这一剪枝方法是否是全局滤波器剪枝算法。

图4:VGG-16-M在CUB-200数据集上的剪枝效果。

下图5的基线模型是VGG-16-M,他在CIFAR-100上的测试准确率为73.19%。其中「shrunk」版表示将所有卷积层的通道数减半,因此将FLOPs降低到了基线模型的1/4,从头训练后它的测试准确率会降低1.98%。「pruned」版表示采用Tick-Tock框架进行剪枝的结果,它的测试准确率会降低1.3%。

如果我们从头训练「pruned」版模型,那么它的准确率能达到71.02%,相当于降低了2.17%。不过重要的是,「pruned」版模型的参数量只有「shrunk」版模型的1/3。

图5:两种网络的效果和通道数对比,它们有相同的FLOPs。

理论北大Gate Decorator
7
相关数据
剪枝技术

剪枝顾名思义,就是删去一些不重要的节点,来减小计算或搜索的复杂度。剪枝在很多算法中都有很好的应用,如:决策树,神经网络,搜索算法,数据库的设计等。在决策树和神经网络中,剪枝可以有效缓解过拟合问题并减小计算复杂度;在搜索算法中,可以减小搜索范围,提高搜索效率。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

验证集技术

验证数据集是用于调整分类器超参数(即模型结构)的一组数据集,它有时也被称为开发集(dev set)。

卷积神经网络技术

卷积神经网路(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。卷积神经网路由一个或多个卷积层和顶端的全连通层(对应经典的神经网路)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网路能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网路在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网路,卷积神经网路需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。 卷积网络是一种专门用于处理具有已知的、网格状拓扑的数据的神经网络。例如时间序列数据,它可以被认为是以一定时间间隔采样的一维网格,又如图像数据,其可以被认为是二维像素网格。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

批归一化技术

批归一化(Batch Normalization,BN)由谷歌于2015年提出,是一个深度神经网络训练的技巧,它不仅可以加快模型的收敛速度,还能在一定程度上缓解深层网络中的“梯度弥散”问题,从而使得训练深层网络模型更加容易和稳定。目前BN已经成为几乎所有卷积神经网络的标配技巧了。从字面意思看来Batch Normalization(简称BN)就是对每一批数据进行归一化。

推荐文章
暂无评论
暂无评论~