FAIR提出突触可塑算法:让神经网络学会记住和忘却

Facebook AI 研究院近日联合 KU Leuven 提出了一种由著名神经科学定律——赫泊规则启发的在线学习算法。研究表明,这种方法可以使模型根据当前任务保留过去任务的重要特征,灵活地适应新环境;并且可以无监督地应用于任何预训练模型,而不受基于损失函数方法的限制。

论文:Memory Aware Synapses: Learning what (not) to forget


论文地址:https://arxiv.org/abs/1711.09601

人类可以持续不断地学习,陈旧且不常用的知识会被新信息覆盖,但重要且常用的知识不会被随意擦除。目前在人工学习系统中,终生学习(lifelong learning,LLL)主要关注在任务中积累知识和克服灾难性忘却问题(catastrophic forgetting)。在这篇论文中,我们指出,给定有限的模型容量和无限的将要学习的新信息的时候,需要选择对知识进行保留还是擦除。由突触可塑性所启发,我们提出了一种在线学习方法,基于网络对数据的激活频率,以无监督的方式计算神经网络参数的「重要性」。在学习了一个任务之后,每当有样本馈送到网络中,就会基于预测输出对参数变化的敏感度,测量网络的每个参数的重要性。当学习一个新任务的时候,会对重要参数的改变进行惩罚(即阻碍该变化)。我们证明了我们的方法的一个局域版本正好是赫泊规则(Hebb's rule)在识别神经元之间的重要连接的直接应用。我们在一系列的目标识别任务和持续学习向量的挑战性问题上测试了我们的方法,取得了当前最佳的结果,展示了根据需求调整参数的重要性的能力。


图 1. 研究人员提出的持续学习模式。

正如大多数终生学习论文所述,任务是按照序列学习的。在这里我们假设,在任务学习之间,智能体是被激活且持续学习的。在这样的过程中它会看到此前任务中未标记的样本。这种信息可以用来更新模型参数中一些重要的权重。频繁出现的类有更大的贡献。这样,智能体就可以明白哪些类别是重要的,不能被遗忘。作为结果,这些类知识在学习新任务时不会被抹去。

新研究的主要贡献可以总结为:

  • 首先,这是一种新的 LLL 方法——Memory Aware Synapses(MAS)。它基于函数逼近而不是损失函数优化,当学习重要性的权重的时候不需要使用标签。从而该方法可以应用于无标签数据,例如真实的测试环境。
  • 其次,我们证明了我们的 LLL 方法和赫泊学习规律的联系,可以视其为我们方法的局域版本。
  • 最后,我们在目标识别和事实学习(例如,<主, 谓, 宾>三元组,使用向量而不是 softmax 输出)任务中都达到了当前最佳性能。


图 2. 和基于损失函数优化的方法不同,我们的方法基于输入-输出的函数对参数的敏感度(梯度)。(a)在训练第一个任务的同时,(基于损失的方法)测量损失函数对参数变化的敏感度以表示参数重要性。(b)相对的,我们在训练完成之后,使用无标记数据计算输出函数对参数变化的敏感度,测量参数的重要性。(c)当学习一个新任务的时候,对重要参数的改变进行惩罚。

目标识别


表 1. 目标识别的分类准确率(%)。重要性的权重Ω_ij 是在训练数据上计算的。加粗的数据表示当前最佳。


表 2. 目标识别的分类准确率(%)。使用训练数据和测试数据(无标签)计算重要性的权重Ω_ij 的结果对比。

两个任务的实验

我们随机地将事实分成两部分以作为数据的两个批量,B_1 和 B_2,并将任务设置为从 B_1 到 B_2 的迁移。


表 3. 在由 6DS 数据集随机分成的两个任务场景中进行事实学习的平均准确率。


表 4. 对测试条件的适应能力。分别在 B_11 和 B_12(由 B_1 分成的两个子集)上学习重要性的权重。在由 6DS 数据集随机分成的两个任务场景中进行事实学习的平均准确率。

更长的任务序列


表 5. 在由 6DS 数据集分成的 4 个不相交任务场景中进行事实学习的平均准确率。

适应性测试


图 4. 每完成 4 个任务序列中的一个之后,测试对 6DS 数据集的(关于体育运动的)子集的平均准确率。

其中 g-MAS(粉色线)学习到该子集是重要的,需要保留,并显著地防止了对该子集的忘却。联合训练方法(Joint Training,黑色虚线)作为参考,但实际上它违反了 LLL 的设置,因为它是同时训练所有的数据。

理论理论论文FAIR在线学习
2