Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

看见鸟头认识鸟,看见鸟尾巴就不认识鸟了?热点图你不懂CNN

当前卷积神经网络(CNN)越来越深,动辄一个几十层以上的黑箱。CNN到底是怎么分类的?用的是图像的纹理特征还是形状特征?特斯拉基于 CNN 的视觉系统为什么有缺陷?


深度网络在社会中的应用已越来越普遍,其安全性就变得尤为重要。尤其是近年来,所谓的长尾问题(long-tail problem)越来越普遍,工业应用中经常需要搞清楚深度网络每个错误的来源。因此,用人类能看懂的方式解释 CNN 的分类,一直是领域内的一个热点。研究者们提出了很多方法,其中最常见的一大类方法是 Saliency Map 或者叫热点图(heatmap),一般通过某种算法在图像上优化出一个热点图,来显示哪些区域在 CNN 的分类中是重要的。近年来,各种热点图算法五花八门,基于梯度的、基于优化的,有数百种不同的热点图算法。然而,热点图真的能解释 CNN 是怎么分类的么?


论文地址:https://arxiv.org/abs/2011.06733

NeurIPS 2021 的一篇论文《Structured Attention Graphs for Understanding Deep Image Classifications》表明了使用单张热点图解释 CNN 的局限性:下图 (a) 取自 ImageNet,其分类为 “美洲金翅雀”,CNN 可以正确将其分类。那么,CNN 到底是用了金翅雀的什么特征呢?(b), (c) 展示了两种热点图算法的结果。其中一张(Grad CAM)认为鸟的头部和翅膀更为重要,而另一张(I-GOS)则认为翅膀和腹部更为重要,哪种热点图说的更对呢?


答案其实是 “都对,也都不对!” 如果做一个干预性(interventional) 的实验:把图像所有部分遮挡住,只露出一小块送入训练好的 CNN 进行分类,会发现,如图 (d), (e), (f) 的三种完全不同的遮挡方式,CNN 都能以很高的置信度(predicted class-conditional probability)对遮挡后的图片给出正确的分类。


可见,热点图只能表示出 “这个区域对 CNN 分类是重要的”,但它并不是唯一重要的,所以这个解释就不全面。CNN 分类时,也许用的是“或” 的逻辑:如果看到了区域 1,那么这是只金翅雀,或者看到区域 2,也足够判断这是金翅雀,或者看到区域 3,也足够判断这是金翅雀了。说明无论哪种热点图,都没有全面地解释 CNN 的所有分类依据。本质上,热点图没法表示 “Region 1 或 Region 2 或 Region 3” 这种 “或” 的逻辑

实际上,机器学习的可解释性中的一个大难题,就是如何把像 CNN 这种做出连续判断的黑箱系统转化为人类更熟悉、更容易理解的离散的逻辑系统,比如说用一个(概率)逻辑表达式来表示 CNN 的分类方式。之前学者们提出过一些使用决策树来近似 CNN 的方法,但是对于复杂的 CNN 来说,用简单的决策树很难近似它的分类决策。如果能用如上这种逻辑表达式来表示 CNN 的分类,从而让人更容易理解,无疑是在可解释性的研究中迈出了一大步。

借助搜索的方法解释 CNN 的分类

那么,首先需要解决的问题是,这种 “或” 的逻辑,究竟是偶然现象,还是普遍存在的呢?在该论文中,研究者们创造性地提出使用搜索的方法来枚举 CNN 所有可能的热点图。

该研究将图像分成 7x7 的网格,共 49 个小块(patch),这些小块的组合可以产生不同的遮挡方式(mask),如果遮挡之后,CNN 的分类结果和全图的分类结果相同,并且预测的概率也差不多,就可以认为这种遮挡方式之后露出的小块足以让 CNN 做出分类。这里使用集束搜索(beam search)的方式搜索出所有这样的遮挡图像。Beam search 首先将图像全部遮挡住,然后搜索所有只露出一个小块的图像,将每个图像输入 CNN 后,得到此图属于和全图同一类的概率,然后保留 K 个预测概率最高的小块 mask 进入下一轮。下一轮中在每个 mask 上试着增加一个小块,并同样保留 K 个预测概率最高的 mask,直到可以完全解释 CNN 在此图上的分类为止。



如何定义 “完全解释 CNN 在此图上的分类”?假设 CNN 在全图上预测的分类概率(predicted class-conditional probability) 为,作者定义极小充分解释 (Minimal sufficient explanation ,MSE)为极小的,可以使 CNN 给出至少倍分类概率的部分遮挡图片。极小的意义是在这个部分遮挡图片中去掉任何一个小块,都会导致分类概率低于倍。这个在文中取 90%,即部分遮挡图片的分类概率要超过全图的 90%,才被认为是个极小充分解释。找到极小充分解释后,即称图片被解释了。

在 ImageNet 数据集上对每张图像进行搜索实验的结果见下图和表:


上图描述了有多少图像可以被少量的小块解释。从图中可以看出,通过 beam search 方法,有 80% 的图像可以被不超过 10 个小块所解释——在全图 49 个小块中,10 个小块大致是 20%,即说明 80% 的 ImageNet 图像,可以仅显示 20% 的内容即可被 CNN 分类。另外值得注意的是,用 beam search 搜索的方法,在同样的块数时,可以比 Grad-CAM/I-GOS 一类优化 / 梯度的热点图方法解释更多的图片—如显示 10 个小块时,Grad-CAM/I-GOS 只能解释 50% 左右的图像,而通过搜索方法可以解释 80% 的图像!说明即使仅仅是搜索热点图,搜索方法也远好于基于梯度或优化的方法!


上表描述了每个图像平均存在解释的个数:如果要求小块之间不能重叠,采用 K=15 的 beam search 可以为每个图像找到平均 1.87 个解释,如果要求小块之间可以有一块重叠,则每个图像平均可以找到 4.51 个解释!由此可见,试图通过单张热点图来解释 CNN 在一张图像上的分类依据,是想得过于简单了。

既然很多图像都有多个解释,那么接下来应该考虑一下如何更好地显示这些解释。简单的方法是将其罗列,但该文选择的方式是对每一个解释给出一个树状结构,从而体现去掉某些小块对 CNN 输出概率的影响。作者称之为 Structural Attention Graph(SAG)。如下图所示,红框的是被去掉的块,高亮的部分是输入 CNN 的块。可以看到,最左边的解释中去掉上方的块之后,CNN 的分类概率从 97% 降到 10%,而去掉下方的块后,CNN 的分类概率仅仅降到 53%,从中可以发现不同的块对 CNN 分类概率的不同影响。


那么,这种新的解释效果如何呢?作者们在 60 个用户身上进行了 user study。为了测试不同的解释方式能否让用户明白 CNN 的分类机制,user study 中用户需要预测 CNN 在被遮挡小块上的行为,如下图:


用户需要在左右两张部分遮挡的图片中选择 CNN 会对哪一图预测的概率更高。其实,很多验证 CNN 解释方法的用户研究经常陷入一个误区即去询问用户 CNN 提供的解释是否能帮助用户更好的分类。但是,对于日常的图片,用户不需要看 CNN 就已经都认识了!所以 CNN 当然没法提高用户的分类精度。然而,用户却不一定能准确的说出 CNN 什么时候会判断正确,什么时候会判断错误。而这种对 CNN 内部机理的了解,或称为用户对 CNN 的 Mental model,是决定 CNN 模型能否部署在实用应用中的关键。因此,本文的用户实验侧重研究用户根据给出的解释,是否能判断 CNN 在一个新场景(此处使用一种用户没有见过的遮挡方式)下是否判断正确。你能从下面的 SAG 图中,猜出上面问题的答案吗?


用户实验发现,普通的热点图方法 GradCAM 和 I-GOS 对此任务没有任何显著的帮助,而该文提出的 SAG 则可以提高用户答题准确率 30% 之多。

总的来说,该论文从热点图方法不能全面解释 CNN 的局限性开始,致力于寻找一种方法把 CNN 的分类结果解释为逻辑表达式的形式,从而更全面地解释 CNN 的分类依据。其中,该研究创造性地采用了搜索而非优化的方法寻找 CNN 在每张图片上全部可能符合 “或” 逻辑的解释。并提出了 Structural Attention Graph(SAG)来把多个存在 “或” 关系的热点图用树状形式展示给用户。这种新的解释方式,跳出了单张热点图的窠臼,给深度学习的可解释性提供了新的思路。
理论NeurIPS 2021卷积神经网络
1
相关数据
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

准确率技术

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

逻辑技术

人工智能领域用逻辑来理解智能推理问题;它可以提供用于分析编程语言的技术,也可用作分析、表征知识或编程的工具。目前人们常用的逻辑分支有命题逻辑(Propositional Logic )以及一阶逻辑(FOL)等谓词逻辑。

暂无评论
暂无评论~