Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

新型池化层sort_pool2d实现更快更好敛:表现优于最大池化层(附代码实现)

近日,Sahil Singla 在 Medium 上发表了一篇题为《A new kind of pooling layer for faster and sharper convergence》的文章,提出一种可实现更快更好收敛的新型池化层 sort_pool2d,表现优于最大池化层,同时解决了最大池化层无法使用来自多层激活函数信息的问题,以及反向传播只会提升最大池化的激活函数的问题。作者还给出了该池化层的代码实现。作者表示,介绍此研究的论文将会提交至 arXiv 上。
  • sort_pool2d 代码:https://github.com/singlasahil14/sortpool2d/blob/master/sortpool2d_test.py
  • sort_pool2d 实现:https://github.com/singlasahil14/sortpool2d/blob/master/sort_pool2d.py

现在的最大池化层(应用于当前几乎所有顶级的计算机视觉任务,甚至是一些 NLP 任务中)大约去掉了 75% 的激活函数。本文试图设计一种新型池化层 sort_pool2d,摆脱与之关联的一些问题。

问题如下:

  1. 空间信息损失:当去掉 75% 的激活函数时,关于其来源的信息就会丢失。
  2. 最大池化层无法使用来自多层激活函数的信息。
  3. 反向传播只会提升最大池化的激活函数,即使其他激活函数的值可能出现错误。

本文想要设计一种新型池化层,尽可能多地解决这些问题。在这一过程中,我想出了一个非常简单的技巧,可以解决问题 2 和 3。

想法与动机

按照渐增的顺序排列 4 个激活函数,而不是选择最大的那个。用 4 个权重 [w1,w2,w3,w4] 与之相乘,并添加 4 个值。

这一想法背后的动机也非常简单:

用这种方式,网络依然能够学习对应于 [w1,w2,w3,w4] = [1,0,0,0] 的良好的、旧的最大池化。

后面的层可以获取更多信息。因此,在非最大激活函数可用于降低损失函数时,网络只可学习使用其他值。

梯度流过上一层中的所有 4 个值(相比之下,最大池化层只有一个值)。

因此我希望基于上述原因,这一想法能够比最大池化层做的更好。这是一个非常少见的深度学习实验,其结果与我设想的简直一模一样。

具体定义

设池化之前的层的输出为张量 T,大小为 [B, H, W, C]。定义一个超参数 pool_range,它可以是 [1,2,3,4] 中的任意一个。pool_range 指定激活函数(按照排列顺序保存)的数量。假设要被池化的张量 T 有 4 个激活函数,我首先按照 [a1, a2, a3, a4] 的顺序排列它们,其中 a1 ≥ a2 ≥ a3 ≥ a4。接着保留其中的第一个 pool_range,我称之为激活向量。

我将 pool_range 的权重向量定义为 [w{1},.... w{pool_range}]。这里需要注意的是,如果这些权重中的任何一个是负值,则激活向量按强度排序且采用加权平均的假设将不成立。因此,我没有直接使用权重,而是在权重向量上取一个 softmax,并将结果乘以激活向量。为了证明添加 softmax 的重要性,我在 cluttered-mnist 数据集上进行了一个 toy experiment,当 pool_range=3 时看看有无 softmax 的区别。以下是测试数据集上的结果:


cluttered-mnist 数据集测试数据上的精确度(accuracy)对比与交叉熵对比

很明显,softmax 在这里是赢家。

我本来还可以对不同的通道使用不同的权重,但是为了便于与最大池化进行对比,我在不同通道上使用了 4 个相同的权重。

实现细节

我在 TnsorFlow 中写了该层的代码。TnsorFlow 的 top_k 层在 CPU 上非常快但是在 GPU 上非常慢。为了对这 4 个浮点数进行排序,我亲自写了程序。测试 sort_pool2d 的代码请参见 https://github.com/singlasahil14/sortpool2d/blob/master/sortpool2d_test.py。导入并将其作为层来实现的代码请参见 https://github.com/singlasahil14/sortpool2d/blob/master/sort_pool2d.py。

结果

我在不同的数据集和架构上尝试了这一想法,发现其性能全部优于基线最大池化。所有实验使用 pool_range 的 4 个值:1,2,3,4。pool_range=1 对应最大池化。

以下是我的实验结果:

cluttered-mnist

cluttered-mnist 和 fashion-mnist 上的 toy experiment


cluttered-mnist 数据集训练数据上的精确度与交叉熵对比



cluttered-mnist 数据集测试数据上的精确度与交叉熵对比


测试数据训练中的最佳精确度与交叉熵的值

网络获得的训练损失和精确度是相同的,但是 pool_range = 2,3,4 的验证精确度要远好于标准的最大池化。

fashion-mnist

fashion-mnis 数据集训练数据上的精确度与交叉熵对比


fashion-mnis 数据集测试数据上的精确度与交叉熵对比


训练过程中,测试数据上的最优精确度和交叉熵的值

pool_range>1 时,结果要好得多。

当前最佳模型上的实验

resnet 上的 cifar-100


cifar-10 数据集训练数据上的精确度与交叉熵对比


cifar-10 数据集测试数据上的精确度与交叉熵对比


训练过程中,测试数据上的最优精确度与交叉熵的值

pool_range>1 时,结果更好。

resnet 上的 cifar-100


在 cifar-100 数据集训练数据上的精确度与交叉熵对比


在 cifar-100 数据集测试数据上的精确度与交叉熵对比


训练过程中,在测试数据上的最优精确度与交叉熵的值

pool_range>1 时,结果更好。这里的结果优于 cifar-10 的结果,因为 cifar-10 拥有的每个类别的数据较少。这表明这个想法对解决每个类别数据较少的问题效果很好。

匹配网络中的 omniglot

我尝试使用论文《Matching Networks for one shot learning》提出的架构在 omniglot 数据集上对比 20 种方式的一次分类结果。


omniglot 数据集训练数据上的精确度与损失对比


omniglot 数据集验证数据上的精确度与损失对比


训练过程中,在验证数据上最优精确度与损失的值

注意:该实现使用该论文提出的已正则化的当前最佳的实现。因此,这些改进超过很多现有技巧。

论文《学习记忆罕见事件》(Learning to Remember Rare Events)中的 omniglot

我尝试使用《学习记忆罕见事件》论文中提出的架构在 omniglot 数据集上对比 5 种方式的一次分类结果。


omniglot 数据集验证数据上的损失对比


一次和二次分别在 omniglot 数据集验证数据上的精确度对比

pool_range=2、pool_range=4 时的收敛比使用基线最大池化要快得多。

这次加速再次超越该论文当前最佳的实现结果。因此这些改进超过很多现有技巧。

重现结果的代码和命令行参数

  • 所有实验可以从该 repo 中重现(地址:https://github.com/singlasahil14/sortpool2d)。
  • 重现 cluttered-mnist 和 fashion-mnist 数据集上结果的命令行参数地址:https://github.com/singlasahil14/sortpool2d/tree/master/resnet。
  • 重现使用 resnet 架构在 cifar10 和 cifar100 上的结果的命令行参数地址:https://github.com/singlasahil14/sortpool2d/tree/master/resnet。
  • 重现使用《Matching Networks for one shot learning》提出的架构在 omniglot 上的结果的命令行参数地址:https://github.com/singlasahil14/sortpool2d/tree/master/matching_networks。
  • 重现使用《学习记忆罕见事件》提出的架构在 omniglot 上的结果的命令行参数地址:https://github.com/singlasahil14/sortpool2d/tree/master/learning_to_remember_rare_events。

结论

这一池化层(我将其称之为 sort_pool2d)在所有数据集和架构中的表现大大优于 max_pool2d。而计算时间的优势也很大。通过编写高度优化的 C 语言代码和 CUDA 代码,我们还可以进一步优化每次迭代的时间。

虽然这一方式并不能解决空间信息丢失的问题。但是它为解决这个问题提出了一个很有意义的方向。


入门池化层GitHub工程sort_pool2d实现
3
暂无评论
暂无评论~