Yufeng Xiong作者Joni编辑

使用强化学习在模型测试时修改模型行为

机器学习流程中,训练阶段往往需要大量计算资源和时间,而推理阶段则不然,所需的计算资源要少得多,而且有时还需要即时响应。另外,训练后的模型在实际应用时可能还会遇到特定的约束条件。针对这些约束条件重新训练模型会导致成本过高。为此,谷歌大脑的研究者提出了使用强化学习修改测试时间的模型行为的方法,在这一问题上得到不错的结果。机器之心技术分析师解读了这项研究,本文为解读的中文版。

论文地址:https://openreview.net/pdf?id=Hk8-lkHKe

1 引言

在测试时间,机器学习模型往往需要在特定的预算约束下实现高预测准确度;而在训练时间则没有这样的约束。比如,在搜索引擎中,生成搜索结果的速度往往和结果的准确度一样重要。在自动医疗诊断中,因为医疗检验往往成本高昂,所以目标是能在一定预算内基于一系列测试正确诊断任意给定的病人。

在这些案例中,往往存在预测准确度(得到能让人接受的模型表现)和预测成本(满足特定约束条件)之间的权衡。但在训练时间就将这些约束明确设置在模型中却并不好,原因如下:

  1. 对于诊断这类情况,还需要人工设计和再训练一个新模型,这样做没效果。
  2. 这让人无法在推理时间以与输入无关的方式调整这些约束。

在这篇论文中,研究者研究了一种可以根据每个输入的情况在测试时间修改模型行为的方法。该方法包含两部分:构成器模型(Composer Model)和策略偏好(Policy Preferences)。其中构成器模型由一组模块和一个控制器网络构成,可使用强化学习来训练以探究中间激活。通过使用策略偏好的概念,我们可以在推理时间自适应地调整控制器策略。

2 构成器模型

构成器模型由一个控制器网络和一组子模块构成,如图 1 所示。这些子模块都是神经网络,可被组织成被称为“元层(metalayers)”的分组。

图 1:构成器模型示意图。这个构成器有 3 个元层,每个元层包含 2 个模块。圆圈表示随机性节点;正方形表示确定性节点。

如上图所示,x 为输入,f_i 表示第 i 个元层。当 i=0 时,f_0 被称为“stem”,这是一个特殊的簿记层(book-keeping layer)。当 i>0 时,f_i 由 m_i 个函数构成,这些函数分别表示第 i 层中的各个子模块。这些模块可以有不同的架构和参数数量等,这一点很重要。一旦选择了不同元层中的模块,我们就可以使用随机梯度下降来训练这个定义后的神经网络

控制器网络由 n 个函数构成(g_1, ..., g_n),其输出是一个对应元层的策略概率分布。其详情可以表示成一个等式序列:

其中,c_i 是第 i 步在控制器的策略分布上采样得到的一个样本。因为控制器网络处理的是一个模块序列,所以可用 RNN 来实现它,其中控制器函数共享权重和隐藏状态。在这篇论文中,研究者使用了 REINFORCE [1] 来训练控制器网络,以最大化奖励函数 log p(y|x, c_1:n),其中 c_1:n 表示 c_1, ..., c_n。论文附录 A 解释了 REINFORCE 的详情。

3 策略偏好

为了阐释对不同策略的偏好,我们可以增强控制器的奖励函数 log p(y|x, c_1:n)。为了能在测试时间修改偏好,研究者在奖励函数上添加了一个成本 C(p_1:n, c_1:n, γ),其中 γ 是一个偏好值,我们可以根据每个输入修改,也可根据每个 mini-batch 修改。

其基本思想是,参数 γ 是控制器网络的一个输入;在训练时间,我们从一个代表偏好概率的分布中采样 γ。有了这样的设置,控制器可以学习根据训练时间使用的不同 γ 值来更改策略。在推理时间,还可以修改 γ 以适应变化的偏好。

在这篇论文中,研究者描述了策略偏好的两个示例:略看偏好(Glimpse Preferences)或熵偏好(Entropy Preference)。但是,策略偏好并不只限于这两个实例,它们也可在更通用的强化学习设置中应用于构成器模型,这可留待未来研究。

3.1 略看偏好

计算时间 [2,3] 等资源消耗是模型在测试时间需要考虑的重要情况。假设模型可以略看与不同成本和有用性相关的输入,并且看的范围越大,需要的参数就越多,那么该成本可定义为:

其中 β_i 和 c_i 表示第 i 个元层中参数向量的所选分量。这里的成本 C_1 是基本每个输入的成本。

3.2 熵偏好

研究者还引入了另一个成本,应用于每个批(batch)的情况,以抵消控制器将更多概率分配给训练更多的模块的情况(这有损多样性)。

注意,N 是一个批中示例的数量,p_i (j) 是第 i 个元层中第 j 个批元素的模块概率向量。

4 实验结果

在实验中,研究者引入了一个修改过的 MNIST 版本——Wide-MNIST 数据集,并将其用于测试构成器模型和策略偏好组合后的系统。Wide-MNIST 数据集中的图像大小为 28*56,而且数字会出现在图像的左半部份或右半部份。所以总类别数为 20,其中 10 类左侧数字(0-9),10 类右侧数字(0-9)。然后使用两个子模块训练这个构成器模型,其中小模块仅略看输入的左边部分,有更多参数的大模块会略看整个输入。这里的策略偏好既包含略看偏好,也包含熵偏好。图 2 给出了详细结果。

图 2:更改在 Wide-MNIST 数据集上训练的模型在测试时间的略看偏好。注意,在 Composer 曲线上的所有数据点都来自单个训练后的模型。左图:构成器由两个模块构成,其中小模块仅略看输入的左边部分,有更多参数的大模块会略看整个输入。构成器在不同的略看偏好值下进行了评估,得到了不同的平均参数使用值。基准模型是通过随机选择大模块和小模块而随时间变化创建的,以便得到不同的平均参数使用值。构成器曲线和基准曲线之间的差距代表了该控制器基于其输入样本“智能地”适应模块使用情况的程度。右图:热图表示了控制器分配给每个模块的概率,对于这 20 个标签中的每一个,都在整个测试集上进行了平均。左边一列对应小模块,右边一列对应大模块。上半部分 10 行对应“右边”数字,下半部分 10 行对应“左边”数字。随着略看偏好降低,控制器首先会选择使用小模块来处理“左边”的数字,因为这不会给分类准确度造成太大损失。只有当略看偏好非常低时,控制器才会将“右边”的数字分配给小模块。

定性实验

为了进一步阐释构成器模型和策略偏好的优势,作者还执行了一些定性实验,其中有的使用了策略偏好来训练构成器模型,有的没有。

5.1 模块限定

研究者在 MNIST 和 CIFAR-10 上分别进行了实验,并且得到了近似的结果。如图 3 所示,控制器在某种程度上根据相似度将数字分成了 3 组:{2, 3, 5, 6, 8}、{0, 1} 和 {4, 7, 9}。

图 3:对于每层有 3 个模块的构成器,在一个随机 minibatch 的第一个元层中的模块选择。每个模块中参数的数量从上到下增加,第二个模块的参数数量是第一个模块的两倍,第三个模块的参数数量是其三倍。这个构成器是使用一个惩罚所选模块的参数数量的策略偏好训练的。每一行都对应于一个模块,模块大小从上到下增大。底层模块似乎充当了 4,7,9 鉴别器。

在 CIFAR-10 数据上的实验得到了类似的结果,如图 4 所示。控制器自动将数据分成了两组:人造物体(飞机、汽车、船、卡车)和自然事物(鸟、猫、鹿、狗、蛙、马)。

图 4:这张热图展示了控制器为 CIFAR-10 中的每个类别分配的在第一个元层中使用每个模块的概率。注意控制器实际上已经将数据分成了两组——人造物体和自然事物。

5.2 熵偏好

为了研究批熵偏好和非批熵偏好的影响,研究者在 MNIST 数据集上执行了测试。如图 5 中的热图所示。左图:使用批熵偏好的模块被利用得更加平等。右图:使用非批熵偏好时,某些模块被用得更多,而有些模块则从未被使用过。

图 5:对于一个在 MNIST 上训练的带有 1 个元层和 4 个模块的构成器,批熵偏好对模块选择频率的影响。这是与前一张图片类似的模块选择热图。左图:使用熵偏好,模块被使用得更加平等。在 10 万步之后,这一次运行的模块选择和类别标签之间的互信息是 0.9 nats。右图:使用普通的熵惩罚,某些模块会“取得领先”,另一些则永远没法追上。在 10 万步之后,这一次运行的互信息仅有 0.43 nats。这项测试是为该偏好使用一个常数零方差值执行的,所以在这个案例中,非批熵偏好等效于标准的熵惩罚——这涉及到分立的熵惩罚(每个样本一个)的求和。

6 分析师简评

研究者在这篇论文中提出的技术可以:(1)在测试时间而不是训练阶段改变计算资源的量;(2)使用强化学习确定推理阶段的资源量。与之前的部分计算相比 [4],这项研究更有前景并且有多项优势:首先,它允许在训练完成之后在测试时间动态适应计算资源。其次,构成器模型可自适应地决定在更简单的样本上使用更少的计算资源。最后,我们可以直观且可解释性地查看在推理阶段不同资源使用了哪些输入。

参考文献

1. Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8(3-4):229–256, 1992. 

2. Michael Figurnov, Maxwell D. Collins, Yukun Zhu, Li Zhang, Jonathan Huang, Dmitry P. Vetrov, and Ruslan Salakhutdinov. Spatially adaptive computation time for residual networks. CoRR, abs/1612.02297, 2016. URL http://arxiv.org/abs/1612.02297

3. Alex Graves. Adaptive computation time for recurrent neural networks. CoRR, abs/1603.08983, 2016. URL http://arxiv.org/abs/1603.08983

4. Futamura, Y., 1983. Partial computation of programs. In RIMS Symposia on Software Science and Engineering (pp. 1-35). Springer Berlin Heidelberg.

技术分析
相关数据
权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

随机梯度下降技术

梯度下降(Gradient Descent)是遵循成本函数的梯度来最小化一个函数的过程。这个过程涉及到对成本形式以及其衍生形式的认知,使得我们可以从已知的给定点朝既定方向移动。比如向下朝最小值移动。 在机器学习中,我们可以利用随机梯度下降的方法来最小化训练模型中的误差,即每次迭代时完成一次评估和更新。 这种优化算法的工作原理是模型每看到一个训练实例,就对其作出预测,并重复迭代该过程到一定的次数。这个流程可以用于找出能导致训练数据最小误差的模型的系数。

强化学习技术

强化学习是一种试错方法,其目标是让软件智能体在特定环境中能够采取回报最大化的行为。强化学习在马尔可夫决策过程环境中主要使用的技术是动态规划(Dynamic Programming)。流行的强化学习方法包括自适应动态规划(ADP)、时间差分(TD)学习、状态-动作-回报-状态-动作(SARSA)算法、Q 学习、深度强化学习(DQN);其应用包括下棋类游戏、机器人控制和工作调度等。

暂无评论
暂无评论~