用对抗网络生成训练数据:CMU论文A-Fast-RCNN的Caffe实现

最近,卡耐基梅隆大学(CMU)的王小龙等人发表的论文《A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection》引起了很多人的关注。该研究将对抗学习的思路应用在图像识别问题中,通过对抗网络生成遮挡和变形图片样本来训练检测网络,取得了不错的效果。该论文已被 CVPR2017 大会接收。


论文链接:http://www.cs.cmu.edu/~xiaolonw/papers/CVPR2017_Adversarial_Det.pdf

Github:https://github.com/xiaolonw/adversarial-frcnn


论文:A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection


640-14.jpeg


摘要


如何确定物体探测器能够应对被遮蔽、不同角度或变形的图像?我们目前的解决方法是使用数据驱动的策略,收集一个巨大的数据集——覆盖所有条件下物体的样子,并希望通过模型训练能够让分类器学会把它们识别为同一个物体。但是数据集真的能够覆盖所有的情况吗?我们认为像分类、遮蔽与变形这样的特性也符合长尾理论。一些遮蔽与变形非常罕见,几乎永远不会发生,而我们希望训练出的模型是能够应付所有情况的。在本论文中,我们提出了一种新的解决方案。我们提出了一种对抗网络,可以自我生成遮蔽与变形例子。对抗的目标是生成物体探测器难以识别的例子。在我们的架构中,原识别器与它的对手共同进行学习。实验证明,我们的方法与 Fast-RCNN 相比,在 VOC07 上的 mAP 上的升幅为 2.3%,在 VOC2012 物体识别挑战中的 mAP 升幅为 2.6%。我们同时发布了本研究的代码。


640-7.jpeg

图 1:在论文中,我们提出了使用对抗网络来生成带有遮挡和变形的例子,从而让物体探测器难以进行分类。随着探测器的性能逐渐提升,对抗网络产生的图片质量也在提升。通过这种对抗策略,神经网络识别物体的准确性得到了进一步提升。


640-8.jpeg

图 2:该方法的 ASDN 网络架构以及如何与 Fast RCNN 结合的示意图。我们的 ASDN 网络使用输入图片加入 RoI 池化层中得到的补丁。ASDN 网络预测遮挡/极高光蒙版,然后将其用于丢弃特征值,并传递到 Fast-RCNN 分类塔。


640-9.jpeg

图 3:(a)模型预训练——寻找难度最高的遮挡用于训练 ASDN 网络。(b)ASDN 网络生成的遮挡蒙版事例,黑色区域在通过 FRCN 管道时被遮挡。


640-10.jpeg

图 4:ASDN 与 ASTN 网络组合架构示意。首先创建遮挡蒙版,随后旋转路径以产生用于训练的例子。


640-11.jpeg

表格 1:VOC 识别测试的平均精度,FRCN 指使用我们训练方式的 FRCN 成绩。


该研究的 Caffe 实现:A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection


介绍


本实现是 Caffe 版本的 A-Fast-RCNN。尽管我们在论文中的初始实现是在 Torch 上进行的。但 Caffe 的版本更加简单、快速和易于使用。我们发布了用 Adversarial Spatial Dropout Network 训练 A-Fast-RCNN 的训练数据的代码。


许可


本代码是在 MIT License 之下发布的(请参阅 LICENSE 文件获取详细信息)。


引用


如果你认为本内容对你的研究有帮助,可以进行引用:

@inproceedings{WangCVPR17afrcnn,
   Author = {Xiaolong Wang and Abhinav Shrivastava and Abhinav Gupta},
   Title = {A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection},
   Booktitle = {Conference on Computer Vision and Pattern Recognition ({CVPR})},
   Year = {2017}
}


免责声明


本实现是建立在 OHEM 代码的一个 fork 上的,后者又建立在 Faster R-CNN Python 代码和 Fast R-CNN 之上。请在使用时选择相应的研究论文加以引用。


OHEM:https://github.com/abhi2610/ohem

Faster R-CNN Python:https://github.com/rbgirshick/py-faster-rcnn

Fast R-CNN:https://github.com/rbgirshick/fast-rcnn

结果


640.png


注意:研究中记录的结果基于 VGG16 网络。


安装


请遵循 VOC 数据下载和安装规范,这方面与 Faster R-CNN Python 一样。


使用


想要运行代码,请输入:

./train.sh

它包括三个阶段的训练:

./experiments/scripts/fast_rcnn_std.sh  [GPU_ID]  VGG16 pascal_voc

这曾被用来进行标准 Fast-RCNN 一万次迭代的训练,你或许需要下载模型和 log。


模型:http://suo.im/2cgwYG

Log:http://suo.im/39gkhf

./experiments/scripts/fast_rcnn_adv_pretrain.sh  [GPU_ID]  VGG16 pascal_voc

在对抗网络的预训练阶段,可能会需要下载模型和 log:


模型:http://suo.im/2cgwYG

Log:http://suo.im/1TSiRh

./copy_model.h

用于复制上述两个模型的权重,用于初始化联合模型。

./experiments/scripts/fast_rcnn_adv.sh  [GPU_ID]  VGG16 pascal_voc

用于 detector 联合训练对抗网络,在这一步中你可能会需要下载模型和 log:

模型:http://suo.im/25uFFX

Log:http://suo.im/2UTbnC

理论深度学习理论论文图像识别CMUCaffeGAN