Aravind Srinivas等作者Racoon、Jamin编译

BAIR最新RL算法超越谷歌Dreamer,性能提升2.8倍

pixel-based RL 算法逆袭,BAIR 提出将对比学习与 RL 相结合的算法,其 sample-efficiency 匹敌 state-based RL。

此次研究的本质在于回答一个问题—使用图像作为观测值(pixel-based)的 RL 是否能够和以坐标状态作为观测值的 RL 一样有效?传统意义上,大家普遍认为以图像为观测值的 RL 数据效率较低,通常需要一亿个交互的 step 来解决 Atari 游戏那样的基准测试任务。

研究人员介绍了 CURL:一种用于强化学习的无监督对比表征。CURL 使用对比学习的方式从原始像素中提取高阶特征,并在提取的特征之上执行异策略控制。在 DeepMind Control Suite 和 Atari Games 中的复杂任务上,CURL 优于以前的 pixel-based 的方法(包括 model-based 和 model-free),在 100K 交互步骤基准测试中,其性能分别提高了 2.8 倍以及 1.6 倍。在 DeepMind Control Suite 上,CURL 是第一个几乎与基于状态特征方法的 sample-efficiency 和性能所匹配的基于图像的算法。


  • 论文链接:https://arxiv.org/abs/2004.04136

  • 网站:https://mishalaskin.github.io/curl/

  • GitHub 链接:https://github.com/MishaLaskin/curl


背景介绍

CURL 是将对比学习与 RL 相结合的通用框架。理论上,可以在 CURL pipeline 中使用任一 RL 算法,无论是同策略还是异策略。对于连续控制基准而言(DM Control),研究团队使用了较为熟知的 Soft Actor-Critic(SAC)(Haarnoja et al., 2018) ;而对于离散控制基准(Atari),研究团队使用了 Rainbow DQN(Hessel et al., 2017))。下面,我们简要回顾一下 SAC,Rainbow DQN 以及对比学习。

Soft Actor Critic

SAC 是一种异策略 RL 算法,它优化了随机策略,以最大化预期的轨迹回报。像其他 SOTA 端到端的 RL 算法一样,SAC 在从状态观察中解决任务时非常有效,但却无法从像素中学习有效的策略。

Rainbow

最好将 Rainbow DQN(Hessel et al., 2017)总结为在原来应用 Nature DQN 之上的多项改进(Mnih et al., 2015)。具体来说,深度 Q 网络(DQN)(Mnih et al., 2015)将异策略算法 Q-Learning 与卷积神经网络作为函数逼近器相结合,将原始像素映射到动作价值函数里。

除此之外,价值分布强化学习(Bellemare et al., 2017)提出了一种通过 C51 算法预测可能值函数 bin 上的分布技术。Rainbow DQN 将上述所有技术组合在单一的异策略算法中,用以实现 Atari 基准的最新 sample efficiency。此外,Rainbow 还使用了多步回报(Sutton et al.,1998)。

对比学习

CURL 的关键部分是使用对比无监督学习来学习高维数据的丰富表示的能力。对比学习可以理解为可区分的字典查找任务。给定一个查询 q、键 K= {k_0, k_1, . . . } 以及一个明确的 K(关于 q)P(K) = ({k+}, K \ {k+}) 分区,对比学习的目标是确保 q 与 k +的匹配程度比 K \ {k +} 中的任何的键都更大。在对比学习中,q,K,k +和 K \ {k +} 也分别称为锚点(anchor),目标(targets),正样本(positive), 负样本(negatives)。

CURL 具体实现

CURL 通过将训练对比目标作为批更新时的辅助损失函数,在最小程度上改变基础 RL 算法。在实验中,研究者将 CURL 与两个无模型 RL 算法一同训练——SAC 用于 DMControl 实验,Rainbow DQN 用于 Atari 实验。

总体框架概述

CURL 使用的实例判别方法(instance discrimination)类似于 SimCLR、MoC 和 CPC。大多数深度强化学习框架采用一系列堆叠在一起的图像作为输入。因此,算法在多个堆叠的帧中进行实例判别,而不是单一的图像实例。

研究者发现,使用类似于 MoCo 的动量编码流程(momentum encoding)来处理目标,在 RL 中性能较好。最后,研究者使用一个类似于 CPC 中的双线性内积来处理 InfoNCE score 方程,研究者发现效果比 MoCo 和 SimCLR 中的单位范数向量积(unit norm vector products)要好。对比表征和 RL 算法一同进行训练,同时从对比目标和 Q 函数中获得梯度。总体框架如下图所示。

图 2:CURL 总体框架示意图

判别目标

选择关于一个锚点的正、负样本是对比表征学习的其中一个关键组成部分。

不同于在同一张图像上的 image-patches,判别变换后的图像实例优化带有 InfoNCE 损失项的简化实例判别目标函数,并需要最小化对结构的调整。在 RL 设定下,选择更简化判别目标的理由主要有如下两点:

  • 鉴于 RL 算法十分脆弱,复杂的判别目标可能导致 RL 目标不稳定。

  • RL 算法在动态生成的数据集上进行训练,复杂的判别目标可能会显著增加训练所需时间。


因此,CURL 使用实例判别而不是 patch 判别。我们可将类似于 SimCLR 和 MoCo 这样的对比实例判别设置,看做最大化一张图像与其对应增广版本之间的共同信息。

查询-键值对的生成

类似于在图像设定下的实例判别,锚点和正观测值是来自同一幅图像的两个不同增广值,而负观测值则来源于其他图像。CURL 主要依靠随机裁切数据增广方法,从原始渲染图像中随机裁切一个正方形的 patch。

研究者在批数据上使用随机数据增广,但在同一堆帧之间保持一致,以保留观测值时间结构的信息。数据增广流程如图 3 所示。

图 3: 使用随机裁剪产生锚点与其正样本过程的直观展示。

相似度量

区分目标中的另一个决定因素是用于测量查询键对之间的内部乘积。CURL 采用双线性内积 sim(q,k)= q^TW_k,其中 W 是学习的参数矩阵。研究团队发现这种相似性度量的性能优于最近在计算机视觉(如 MoCo 和 SimCLR)中最新的对比学习方法中使用的标准化点积。

动量目标编码

在 CURL 中使用对比学习的目标是训练从高维像素中能映射到更多语义隐状态的编码器。InfoNCE 是一种无监督的损失,它通过学习编码器 f_q 和 f_k 将原始锚点(查询)x_q 和目标(关键字)x_k 映射到潜在值 q = f_q(x_q) 和 k = f_k(x_k) 上,在此团队应用相似点积。通常在锚点和目标映射之间共享相同的编码器,即 f_q = f_k。

CURL 将帧-堆栈实例的识别与目标的动量编码结合在一起,同时 RL 是在编码器特征之上执行的。

CURL 对比学习伪代码(PyTorch 风格)


实验

研究者评估(i)sample-efficiency,方法具体为测量表现最佳的基线需要多少个交互步骤才能与 100k 交互步骤的 CURL 性能相匹配,以及(ii)通过测量 CURL 取得的周期回报值与最佳表现基线的比例来对性能层面的 100k 步骤进行衡量。换句话说,当谈到数据或 sample-efficiency 时,其实指的是(i),而当谈起性能时则指的是(ii)。

DMControl

在 DMControl 实验中的主要发现:


  1. CURL 是我们在每个 DMControl 环境上进行基准测试的 SOTA ImageBased RL 算法,用于根据现有的 Image-based 的基准进行采样效率测试。在 DMControl100k 上,CURL 的性能比 Dreamer(Hafner 等人,2019)高 2.8 倍,这是一种领先的 model-based 的方法,并且数据效率高 9.9 倍。

  2. 从图 7 所示的大多数 16 种 DMControl 环境中的状态开始,仅靠像素操作的 CURL 几乎可以进行匹配(有时甚至超过)SAC 的采样效率。它是基于 model-based,model-free,有辅助任务或者是没有辅助任务。

  3. 在 50 万步之内,CURL 解决了 16 个 DMControl 实验中的大多数(收敛到接近 1000 的最佳分数)。它在短短 10 万步的时间内就具有与 SOTA 相似性能的竞争力,并且大大优于该方案中的其他方法。


表 1. 在 500k(DMControl500k)和 100k(DMControl100k)环境步长基准下,CURL 和 DMControl 基准上获得的基线得分。

图 4. 相对于 SLAC、PlaNet、Pixel SAC 和 State SAC 基线,平均 10 个 seeds 的 CURL 耦合 SAC 性能。

图 6. 要获得与 CURL 在 100k 训练步骤中所得分相同的分数,需要先行采用领先的 pixel-based 方法 Dreamer 的步骤数。

图 7. 将 CURL 与 state-based 的 SAC 进行比较,在 16 个所选 DMControl 环境中的每个环境上运行 2 个 seeds。

Atari

在 Atari 实验中的主要发现:

  1. 就大多数 26 项 Atari100k 实验的数据效率而言,CURL 是 SOTA PixelBased RL 算法。平均而言,在 Atari100k 上,CURL 的性能比 SimPLe 高 1.6 倍,而 Efficient Rainbow DQN 则高 2.5 倍。

  2. CURL 达到 24%的人类标准化分数(HNS),而 SimPLe 和 Efficient Rainbow DQN 分别达到 13.5%和 14.7%。CURL,SimPLe 和 Efficient Rainbow DQN 的平均 HNS 分别为 37.3%,39%和 23.8%。

  3. CURL 在三款游戏 JamesBond(98.4%HNS),Freeway(94.2%HNS)和 Road Runner(86.5%HNS)上几乎可以与人类的效率相提并论,这在所有 pixel-based 的 RL 算法中均属首例。


表 2. 通过 CURL 和以 10 万个时间步长(Atari100k)为标准所获得的分数。CURL 在 26 个环境中的 14 个环境中实现了 SOTA。

项目介绍

安装

所有相关项都在 conda_env.yml 文件中。它们可以手动安装,也可以使用以下命令安装:

conda env create -f conda_env.yml

使用说明

要从基于图像的观察中训练 CURL agent 完成 cartpole swingup 任务,请从该目录的根目录运行 bash script/run.sh。run.sh 文件包含以下命令,也可以对其进行修改以尝试不同的环境/超参数

CUDA_VISIBLE_DEVICES=0 python train.py \
    --domain_name cartpole \
    --task_name swingup \
    --encoder_type pixel \
    --action_repeat 8 \
    --save_tb --pre_transform_image_size 100 --image_size 84 \
    --work_dir ./tmp \
    --agent curl_sac --frame_stack 3 \
    --seed -1 --critic_lr 1e-3 --actor_lr 1e-3 --eval_freq 10000 --batch_size 128 --num_train_steps 1000000

在控制台中,应该看到如下所示的输出:

| train | E: 221 | S: 28000 | D: 18.1 s | R: 785.2634 | BR: 3.8815 | A_LOSS: -305.7328 | CR_LOSS: 190.9854 | CU_LOSS: 0.0000
| train | E: 225 | S: 28500 | D: 18.6 s | R: 832.4937 | BR: 3.9644 | A_LOSS: -308.7789 | CR_LOSS: 126.0638 | CU_LOSS: 0.0000
| train | E: 229 | S: 29000 | D: 18.8 s | R: 683.6702 | BR: 3.7384 | A_LOSS: -311.3941 | CR_LOSS: 140.2573 | CU_LOSS: 0.0000
| train | E: 233 | S: 29500 | D: 19.6 s | R: 838.0947 | BR: 3.7254 | A_LOSS: -316.9415 | CR_LOSS: 136.5304 | CU_LOSS: 0.0000

cartpole swing up 的最高分数约为 845 分。而且,CURL 如何以小于 50k 的步长解决 visual cartpole。根据使用者的 GPU 不同而定,大约需要一个小时的训练。同时作为参考,最新的端到端方法 D4PG 需要 50M 的 timesteps 来解决相同的问题。

Log abbreviation mapping:

train - training episode
E - total number of episodes 
S - total number of environment steps
D - duration in seconds to train 1 episode
R - mean episode reward
BR - average reward of sampled batch
A_LOSS - average loss of actor
CR_LOSS - average loss of critic
CU_LOSS - average loss of the CURL encoder

与运行相关的所有数据都存储在指定的 working_dir 中。若要启用模型或视频保存,请使用--save_model 或--save_video。而对于所有可用的标志,需要检查 train.py。使用 tensorboard 运行来进行可视化:

tensorboard --logdir log --port 6006

同时在浏览器中转到 localhost:6006。如果运行异常,可以尝试使用 ssh 进行端口转发。

对于使用 GPU 加速渲染,确保在计算机上安装了 EGL 并设置了 export MUJOCO_GL = egl。


理论RLDreamer谷歌BAIR强化学习
1
相关数据
DeepMind机构

DeepMind是一家英国的人工智能公司。公司创建于2010年,最初名称是DeepMind科技(DeepMind Technologies Limited),在2014年被谷歌收购。在2010年由杰米斯·哈萨比斯,谢恩·列格和穆斯塔法·苏莱曼成立创业公司。继AlphaGo之后,Google DeepMind首席执行官杰米斯·哈萨比斯表示将研究用人工智能与人类玩其他游戏,例如即时战略游戏《星际争霸II》(StarCraft II)。深度AI如果能直接使用在其他各种不同领域,除了未来能玩不同的游戏外,例如自动驾驶、投资顾问、音乐评论、甚至司法判决等等目前需要人脑才能处理的工作,基本上也可以直接使用相同的神经网上去学而习得与人类相同的思考力。

https://deepmind.com/
范数技术

范数(norm),是具有“长度”概念的函数。在线性代数、泛函分析及相关的数学领域,是一个函数,其为向量空间内的所有向量赋予非零的正长度或大小。半范数反而可以为非零的向量赋予零长度。

深度强化学习技术

强化学习(Reinforcement Learning)是主体(agent)通过与周围环境的交互来进行学习。强化学习主体(RL agent)每采取一次动作(action)就会得到一个相应的数值奖励(numerical reward),这个奖励表示此次动作的好坏。通过与环境的交互,综合考虑过去的经验(exploitation)和未知的探索(exploration),强化学习主体通过试错的方式(trial and error)学会如何采取下一步的动作,而无需人类显性地告诉它该采取哪个动作。强化学习主体的目标是学习通过执行一系列的动作来最大化累积的奖励(accumulated reward)。 一般来说,真实世界中的强化学习问题包括巨大的状态空间(state spaces)和动作空间(action spaces),传统的强化学习方法会受限于维数灾难(curse of dimensionality)。借助于深度学习中的神经网络,强化学习主体可以直接从原始输入数据(如游戏图像)中提取和学习特征知识,然后根据提取出的特征信息再利用传统的强化学习算法(如TD Learning,SARSA,Q-Learnin)学习控制策略(如游戏策略),而无需人工提取或启发式学习特征。这种结合了深度学习的强化学习方法称为深度强化学习。

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

损失函数技术

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

伪代码技术

伪代码,又称为虚拟代码,是高层次描述算法的一种方法。它不是一种现实存在的编程语言;它可能综合使用多种编程语言的语法、保留字,甚至会用到自然语言。 它以编程语言的书写形式指明算法的职能。相比于程序语言它更类似自然语言。它是半形式化、不标准的语言。

表征学习技术

在机器学习领域,表征学习(或特征学习)是一种将原始数据转换成为能够被机器学习有效开发的一种技术的集合。在特征学习算法出现之前,机器学习研究人员需要利用手动特征工程(manual feature learning)等技术从原始数据的领域知识(domain knowledge)建立特征,然后再部署相关的机器学习算法。虽然手动特征工程对于应用机器学习很有效,但它同时也是很困难、很昂贵、很耗时、并依赖于强大专业知识。特征学习弥补了这一点,它使得机器不仅能学习到数据的特征,并能利用这些特征来完成一个具体的任务。

计算机视觉技术

计算机视觉(CV)是指机器感知环境的能力。这一技术类别中的经典任务有图像形成、图像处理、图像提取和图像的三维推理。目标识别和面部识别也是很重要的研究领域。

卷积神经网络技术

卷积神经网路(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。卷积神经网路由一个或多个卷积层和顶端的全连通层(对应经典的神经网路)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网路能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网路在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网路,卷积神经网路需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。 卷积网络是一种专门用于处理具有已知的、网格状拓扑的数据的神经网络。例如时间序列数据,它可以被认为是以一定时间间隔采样的一维网格,又如图像数据,其可以被认为是二维像素网格。

映射技术

映射指的是具有某种特殊结构的函数,或泛指类函数思想的范畴论中的态射。 逻辑和图论中也有一些不太常规的用法。其数学定义为:两个非空集合A与B间存在着对应关系f,而且对于A中的每一个元素x,B中总有有唯一的一个元素y与它对应,就这种对应为从A到B的映射,记作f:A→B。其中,y称为元素x在映射f下的象,记作:y=f(x)。x称为y关于映射f的原象*。*集合A中所有元素的象的集合称为映射f的值域,记作f(A)。同样的,在机器学习中,映射就是输入与输出之间的对应关系。

目标函数技术

目标函数f(x)就是用设计变量来表示的所追求的目标形式,所以目标函数就是设计变量的函数,是一个标量。从工程意义讲,目标函数是系统的性能标准,比如,一个结构的最轻重量、最低造价、最合理形式;一件产品的最短生产时间、最小能量消耗;一个实验的最佳配方等等,建立目标函数的过程就是寻找设计变量与目标的关系的过程,目标函数和设计变量的关系可用曲线、曲面或超曲面表示。

查询技术

一般来说,查询是询问的一种形式。它在不同的学科里涵义有所不同。在信息检索领域,查询指的是数据库和信息系统对信息检索的精确要求

动量技术

优化器的一种,是模拟物理里动量的概念,其在相关方向可以加速SGD,抑制振荡,从而加快收敛

强化学习技术

强化学习是一种试错方法,其目标是让软件智能体在特定环境中能够采取回报最大化的行为。强化学习在马尔可夫决策过程环境中主要使用的技术是动态规划(Dynamic Programming)。流行的强化学习方法包括自适应动态规划(ADP)、时间差分(TD)学习、状态-动作-回报-状态-动作(SARSA)算法、Q 学习、深度强化学习(DQN);其应用包括下棋类游戏、机器人控制和工作调度等。

堆叠技术

堆叠泛化是一种用于最小化一个或多个泛化器的泛化误差率的方法。它通过推导泛化器相对于所提供的学习集的偏差来发挥其作用。这个推导的过程包括:在第二层中将第一层的原始泛化器对部分学习集的猜测进行泛化,以及尝试对学习集的剩余部分进行猜测,并且输出正确的结果。当与多个泛化器一起使用时,堆叠泛化可以被看作是一个交叉验证的复杂版本,利用比交叉验证更为复杂的策略来组合各个泛化器。当与单个泛化器一起使用时,堆叠泛化是一种用于估计(然后纠正)泛化器的错误的方法,该泛化器已经在特定学习集上进行了训练并被询问了特定问题。

连续控制技术

连续控制代指需要进行连续控制的任务,经典例子包括推杆摆动,3D人形运动等等。

暂无评论
暂无评论~