Ruijia Xu等作者

鸡尾酒网络 DCTN:源分布结合律引导的迁移学习框架

来自中山大学、哈尔滨工业大学以及商汤科技公司的研究人员联合提出了一种名为「鸡尾酒网络」(DCTN)的深度迁移学习框架,将现有的单源域适应过程推广到了更加真实和通用的多源域适应场景。DCTN 启发于 2009 年 Y. Mansou [1] 的源分布结合律(source distribution combining rule)。具体而言,DCTN 通过多路对抗机制学习领域无关的特征表达,依据对抗相似性分数作为多源分布结合律的权重和各自的源分类器合作从而对目标域样本进行联合识别,并使用高置信度的伪标注样本对特征表达进行再适应从而引入更强的判别性能。实验中,DCTN 在 Office31, Image-CLEF 以及新提出的基于同时迁移四个数据源的 Digit-five 评测数据集上均取得了比较显著的性能提升。该论文已经被 CVPR 2018 大会接收。

一、简介

随着大规模数据的不断产生和依靠人力进行信息标注的困难,域适应迁移方法逐渐成为机器学习领域中一项非常重要的研究课题。域适应学习旨在适配不同领域数据间的特征分布,提升不同领域间分类器迁移后的性能表现,解决目标域数据缺乏标注信息的难题。域适应迁移学习同时也是工业界的一项关键技术手段,在人脸识别自动驾驶和医学影像等垂直领域均具有较强的应用需求。比如在自动驾驶领域,如何最小化虚拟环境与真实环境、其他城市和当前城市的领域偏差?又比如在医学影像领域,如何综合利用多源医疗影像给出全面诊断?这些都是领域迁移学习极具应用潜力的场景。

值得注意的是,我们在生活中搜集到的数据源往往是来自各种渠道的。不同渠道的得到的数据源分别与目标域的数据存在偏移现象,而且这些数据源之间也存在偏移。一种更值得广泛考虑的情况是,多个数据源之间的类别也具有差异性,这对多源迁移学习带来了新的挑战 (单源域适应与多源域适应的区别于联系见图 1)。然而,虽然深度迁移学习已经在单领域适应问题上取得各种研究进展,如何运用深度网络去解决多领域适应问题目前依然处于接近空白的阶段。

图 1.(a)单元域适应;(b)多源域适应

二、鸡尾酒背后的理论驱动

多数据源的迁移学习研究可以追溯到 J. Blitzer [1] 和 Y. Mansour [2] 的理论工作。其中 [1] 提出了第一个多源迁移学习的学习上界,为以后的多源域适应学习模型设计打下了基础;[2] 则提出了迁移目标域分布由多个源域分布混合组成。基于这种设定下,他们提出了源分布结合律(source distribution combining rule)。直观而言,他们认为不同的源域学习应该具有自己的分类器,而非使用一个单一分类器去统一所有源域与目标域分类结果。同时,不同源域跟目标域之间分别有相似度,那么与目标域更相似的源域,它的分类器对于目标域的数据进行分类的结果应该更可信。根据以上两点,目标域的分类结果应该由多个源域的分类结果加权而成,而每个源分类权重应该要反映对应源域与目标域的相似性。

三、鸡尾酒网络(Deep CockTail Network) 

图 2. 鸡尾酒网络 DCTN

为满足多源分布结合律,我们提出 Deep CockTail Network(鸡尾酒网络 DCTN)。在图 2 的具体数据流中,我们利用共享特征网络对所有源域以及目标域进行特征建模,然后利用多路对抗域适应技术(基于单路对抗域适应(adversarial domain adaptation)下的扩展,对抗域适应的共享特征网络对应于生成对抗学习 (GAN) 里面的生成器),每个源域分别与目标域进行两两组合对抗学习域不变特征。同时每个源域也分别进行监督学习,训练基于不同源类别下的多个 softmax 分类器。注意到,基于对抗学习的建模,我们在得到共享特征网络的同时,也可以得到多个源分别和目标域对抗的判别器。这些判别器在对于目标域的数据,可以分别给出与每一个源域之间的混淆分数(perplexity score)去衡量该源域与目标域之间的相似性。因此,对于每一个来自目标域的数据,我们首先利用不同源下的 softmax 分类器给出其多个分类结果。然后,基于每一个类别,我们找到包含该类别的所有源域 softmax 分类概率,再基于这些源域与目标域的混淆分数,对分类概率取加权平均得到每个类别的分数。简而言之就是,越跟目标域相识的源域混淆度会更高,意味着其分类结果更可信从而具有更高的加权权值。需要注意的是,我们并没有直接作用于所有 softmax 分类器上反而是基于每个类别分别进行加权平均处理。这是因为在我们的假设下,每个源的类别不一定共享,从而 softmax 结果不能简单相加。当然,我们的方法也适用于所有源共享类别的情况,这样我们的公式会等价于直接将 softmax 分类结果进行加权相加。

考虑到不同源域数据分布间的差异性,某些源域的样本对提升目标域分类性能有较大的正向迁移表现,而有些源域的样本域适应性能则较弱甚至带来相对的负面迁移影响。为此,我们设计了如下基于域间困难样本的梯度回传策略,具体可参考图 3 的算法流程。

图 3. 域间困难样本的梯度回传策略

基于多路对抗域适应下,我们进一步提出分类再适应机制。我们基于各源域的分类器和对抗相似性分数对目标域图片进行识别,选取高置信度的目标域伪标记样本微调特征提取器和多源分类器。如此下两个域适应学习进行交替迭代直至模型收敛。整个 DCTN 的训练可参考图 4 的算法流程。

四、实验

图 4. DCTN 的学习过程

论文在 Office-31、ImageCLEF 等主流域适应数据集上进行了实验。其中 Office-31 数据集来自 Amazon(电商图片)、Webcam(网络摄像头拍摄图片)、DSLR(单反相机拍摄图片) 三个视觉领域,共包含 4652 张图片 31 个类别标签。论文以单源最佳模型、多源合并模型等作为评测标准,与 DAN、RevGrad 等流行算法进行了充分对比。

多源域适应框架在 Office-31 和 ImageCLEF 数据集上的分类准确率

除了常规的域适应设置,我们还进行了包含类别偏差的实验探索。仍然以 Office-31 数据集为例,令两个源域分别包含前三分之二和后三分之二的类别标签,据此进行多源域适应过程,对目标域的图片类别进行预测。

类别偏差设置下在 A,D→W 任务上的分类准确率

类别偏差设置下在 I,P→C 任务上的分类准确率

最后,我们构建了四对一的多源迁移学习任务标准 Digit-five,同时给出在该多源迁移学习标准下的两个多源域适应任务结果。我们可以看出 DCTN 明显优于目前主流的深度迁移学习算法。

论文:Deep Cocktail Network: Multi-source Unsupervised Domain Adaptation with Category Shift

项目链接:http://www.sysu-hcp.net/deep-cocktail-network/

[1] J. Blitzer, K. Crammer, A. Kulesza, F. Pereira, and J. Wortman. Learning bounds for domain adaptation. In Advances in neural information processing systems, pages 129–136, 2008. 

[2] Y. Mansour, M. Mohri , and A. Rostamizadeh . Domain adaptation with multiple sources. In Advances in neural information processing systems

理论
1
相关数据
收敛技术
Convergence

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

人脸识别技术
Facial recognition

广义的人脸识别实际包括构建人脸识别系统的一系列相关技术,包括人脸图像采集、人脸定位、人脸识别预处理、身份确认以及身份查找等;而狭义的人脸识别特指通过人脸进行身份确认或者身份查找的技术或系统。 人脸识别是一项热门的计算机技术研究领域,它属于生物特征识别技术,是对生物体(一般特指人)本身的生物特征来区分生物体个体。

机器学习技术
Machine Learning

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

混乱度技术
perplexity

衡量概率分布或概率模型预测样本能力的一个度量单位,其可以被用来比较概率模型的好坏,值越低表示在预测样本方面的效果越好。

自动驾驶技术
self-driving

从 20 世纪 80 年代首次成功演示以来(Dickmanns & Mysliwetz (1992); Dickmanns & Graefe (1988); Thorpe et al. (1988)),自动驾驶汽车领域已经取得了巨大进展。尽管有了这些进展,但在任意复杂环境中实现完全自动驾驶导航仍被认为还需要数十年的发展。原因有两个:首先,在复杂的动态环境中运行的自动驾驶系统需要人工智能归纳不可预测的情境,从而进行实时推论。第二,信息性决策需要准确的感知,目前大部分已有的计算机视觉系统有一定的错误率,这是自动驾驶导航所无法接受的。

监督学习技术
Supervised learning

监督式学习(Supervised learning),是机器学习中的一个方法,可以由标记好的训练集中学到或建立一个模式(函数 / learning model),并依此模式推测新的实例。训练集是由一系列的训练范例组成,每个训练范例则由输入对象(通常是向量)和预期输出所组成。函数的输出可以是一个连续的值(称为回归分析),或是预测一个分类标签(称作分类)。

迁移学习技术
Transfer learning

迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都是为了解决单个任务而设计的,但是促进迁移学习的算法的开发是机器学习社区持续关注的话题。 迁移学习对人类来说很常见,例如,我们可能会发现学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。

权重技术
Weight

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

准确率技术
Accuracy

分类模型的正确预测所占的比例。在多类别分类中,准确率的定义为:正确的预测数/样本总数。 在二元分类中,准确率的定义为:(真正例数+真负例数)/样本总数

商汤机构
SenseTime

返回顶部