Pete Warden作者丁楠雅校对申利彬翻译

如何改善你的训练数据集?(附案例)

这张幻灯片是Andrej Karpathy 在Train AI 演讲的一部分,我很赞同它表达的观点。它充分体现了深度学习在研究和应用上的差异。学术论文几乎全部集中在新的和改进的模型上,使用的数据集是从公共数据集中选出的一小部分。相反,我认识的将深度学习作为实际应用的一部分人,他们大部分时间都在思考如何改善训练数据。

关于研究人员专注于模型架构有很多好的理由,但它确实意味着很少有资源可以引导那些专注于在生产中部署机器学习的人。我在会上的发言是“那些有效到不合常理的训练数据”。在这里我想稍微扩展一下,解释训练数据为什么如此重要,以及一些改进它的实用技巧。

因为工作的原因,我需要与许多研究人员和产品团队紧密合作。我对于改善数据可以带来效果提升的信念来源于我看到它们在构建模型时取得了巨大的成果。现实世界的大部分应用中,运用深度学习的最大障碍就是没有足够高的精度,而我看到提高精度的最快的方法就是改善训练数据集。即使你被困在延迟或存储大小等其他约束上,你可以通过更小的架构来换取一些性能特征,这样可以提高特定模型的准确性。

语音指令

我不能分享我对生产系统的大部分观察,但是我有一个开源例子可以证明同样的道理。去年,我用Tensorflow创建了一个简单的语音识别的例子,但是事实证明,没有现有的数据集可以很容易地用于训练模型。不过在很多志愿者慷慨的帮助下,我收集了60000个由他们说的短语音频片段。在此感谢“开放式语音录制网站”(Open Speech Recording site)的AIY团队帮我发起这个项目。最后得到的模型是可以使用的,但并没有达到我所希望的精度。

为了看看模型设计者的身份对我产生的局限性有多大,我使用相同的数据集发起了一个Kaggle比赛。参赛者的结果要比我最初的模型好很多,但即使有很多团队提出很多不同的方法,最后达到91%精度的只有很少的一部分人。对我来说,这意味着数据有一些根本上的错误,而且参赛者也确实发现了很多错误,比如不正确的标签或者截断的音频。更多的样本开始促使我关注数据集新版本中他们发现的固定的问题。

我查看了错误评价指标,来了解模型中哪些词的问题最多。结果发现“其它”类别(当语音被识别,但单词表不在模型的有限词汇表内)特别容易出错。为了解决这个问题,我增加了我们正在捕获的不同单词的数量,以提供更多样化的训练数据。

因为Kaggle参赛者提出的标签错误,我“众包”了一个额外的验证通道,要求人们听每个剪辑,并确保它可以匹配到期望的标签。另外,他们也发现了一些几乎无声或者被截断的文件,因此我编写了一个实用工具来做一些音频分析,并自动剔除糟糕的样本。尽管删除了一些糟糕的文件,最后我还是将总的说话数量增加到了100000。这要感谢更多志愿者和收费“众包”者的帮助。

为了帮助别人使用这个数据集(并从我的错误中学习),我将所有相关的事情和最新的精度结果写进了一篇论文。最重要的结论是,在不改变模型或测试数据的前提下,第一名的精度提高了4%,从85.4%提高到了89.7%。这个提高让人很激动,并且当人们在Android 或 Raspberry Pi 演示应用中使用该模型时,反映了更高的满意度。我相信如果我花时间在模型架构的调整上,尽管我知道我的模型不如最好的模型,最终我得到的精度的提高肯定没有现在的多。

论文:

https://arxiv.org/abs/1804.03209

这就是在生产环境中一次又一次地产生伟大结果的过程。但是如果你想做同样的事情,很难知道从哪里开始,你可以从我处理语音数据所使用的技巧中得到一些启发。为了更加明确,这里有一些我发现的有用的方法。

首先,了解你的数据

这似乎是显而易见的,但你的第一步应该是随机浏览你将要开始使用的训练数据。复制一些数据文件到你本地的机器上,然后花费几个小时预览它们。如果你的数据集是图片,可以使用类似MacOS’s的查找器来滚动缩略图视图,可以很快的检查完数千张图片。

对于音频,可以使用取景器播放预览,对于文本可以将随机片段转存到终端上。

在第一个版本的语音指令中,我没有花费足够的时间来做这些。这也是为什么Kaggle参赛者一开始使用这个数据集就会发现很多问题。经历这个过程我总觉得有点傻,但事后我再也没有后悔过。每次我做完这个过程,我都会从数据中发现一些重要的事情。比如是否各类别中例子的数量不均衡,损坏的数据(例如,用JPG文件扩展标记的PNG),不正确的标签,或者只是令人惊讶的组合。

Tom White通过观察ImageNet得到了一些奇妙的发现,包括“太阳镜”标签实际上是一个古老的放大阳光的装置,用于“垃圾车”的魅力镜头,对不死女性的“斗篷”偏见。Andrej’s的工作是手工从ImageNet中分类照片,这也教会我关于数据集的很多东西。包括即使对于一个人来说,将所有不同品种的狗区分出来是有多难。

你将要做什么取决于你发现了什么。你应该在清洗数据之前总是进行一次这种数据观察,因为,对数据集的直观认识将会有助于你在接下来的流程中做决策。

快速选择一个模型

不要在选择模型上花费太多时间。如果你在做图片分类,可以参考AutoML,或者看看类似Tensorflow的模型库,再或者从Fast.AI搜集的例子中找一个解决类似问题的模型(http://www.fast.ai/)。重要的是尽快开始迭代,这样你就可以提前和真实用户一起尝试你的模型。你总是可以在以后得出一个改进的模型,并且也许可以得到更好的结果,但是你首先要得到数据。深度学习仍然遵循‘垃圾入,垃圾出’(“garbage in, garbage out”)的基本计算法则,所以即使是最好的模型也会受到训练集缺陷的限制。通过挑选一个模型并测试它,你将能够得知这些缺陷是什么并且开始改进它们。

AutoML

https://cloud.google .com/automl/

Fast.AI:

http://www.fast.ai/

为了加快你的迭代速度,可以尝试从一个已经在一个大的现有数据集上预先训练的模型开始,然后使用迁移学习在你收集的数据集(可能很小)上进行微调。

这通常比只在较小的数据集上进行训练的效果要好得多,而且速度快得多,并且你可以快速地了解如何调整数据收集策略。最重要的是,你可以把你的结果反馈到你的收集过程中,以适应你学习的情况,而不是在训练之前把收集数据作为一个单独的阶段来进行。

成为它之前先假装它

研究模型和生产模型的最大区别在于研究通常在开始时有明确的问题陈述,但是实际应用的要求被锁定在用户的意识行为中,并且只能随着时间的推移而被提取。

例如,在Jetpac中我们想要找到一张好的照片去展现在城市自动旅行指南中。我们开始时要求评价人给他们认为好的照片打一个标签,但最后我们看到了很多微笑的人的照片,因为他们就是这样解释这个问题的。我们把这些放在产品的模型中,看看测试用户是如何反应的。结果是他们没有留下深刻的印象,也没有被这些照片所鼓舞。

为了解决这个问题,我们重新定义了提问的问题:“这张照片会让你想去它所展示的地方吗?”。这使我们得到了更好的结果,但也反应出我们使用的工人是东南亚人,他们认为会议照片看起来令人很惊异,因为大饭店里充满了穿西装和拿红酒杯的人。

这种不匹配及时提醒了我们生活在“泡沫”里,但这也确实是一个现实的问题,因为我们美国的目标观众看到这些会议照片会感到沮丧和没有理想。最后,我们在JETPAC团队中的六个人手动评估了超过二百万张照片,因为我们比我们可以训练的任何人都要熟悉标准。

这是一个极端的例子,但是它证明了标记过程很大程度上取决于应用的需求。对大多数生产用例来说,存在一个要为模型找合适的问题去回答的过程,而且这才是关键所在。如果你用你的模型回答了错误的问题,你将永远无法在这个糟糕的基础上建立一个可靠的用户体验。

Thomas Hawk拍摄

我已经告诉你询问正确问题的唯一方法就是模仿你的应用,而不是一个人陷在机器学习循环中。因为有一个人在幕后,这有时被称为‘Wizard-of-Oz-ing’。我们让人们手动选择一些旅行指南的样本照片,而不是训练一个模型,然后使用来自测试用户的反馈来调整我们挑选图片的标准。

一旦我们从测试用户那里得到可靠的正向反馈,为了得到数百万张照片的训练集,我们会把制定的挑选照片的规则转换为标签集。然后,它训练了能够预测数十亿张照片质量的模型,但是它的DNA来自我们开发的原始手工规则。

在真实的数据上训练

在Jetpac,我们用来训练我们模型的图像来自相同的数据源(大部分来自Facebook和Instagram) ,也是我们想用在模型上的图像。我所看到的一个常见问题是训练数据集在重要的方面与模型最终会在生产中看到的输入不同。

目前世界上图像识别最大的数据库ImageNet

例如,我经常会看到团队在ImageNet上训练一个模型,但当他们试图在无人机或机器人中使用时就会碰到问题。原因ImageNet都是人拍摄的照片,这些照片有很多共同之处。它们是用手机或静态相机拍摄的,使用中性透镜,在大致的高度,白天或人工照明的条件下,把对象标记在中心突出的位置。

机器人和无人机使用的摄像机通常是高视野镜头。无论是从地面还是从上方,照明都很差,没有任何对象的智能框架,因此它们通常被裁剪。这种差异意味着如果你只接受一个从ImageNet的照片中训练出来的模型,并将其部署在这些设备上,那么你就会发现精确度不高。

有关你的训练数据偏离模型本来应该需要的训练数据,还存在很多微妙的形式。想象一下,你正在建造一个相机来识别野生动物,并利用世界各地的动物数据集进行训练。如果你只在Borneo丛林中部署,那么企鹅标签的正确率肯定是极低的。

如果南极照片被包含在训练数据中,那么它将有更高的几率将其他东西误认为企鹅,所以你的总错误率会比你排除那些训练中的图像更糟糕。有一些方法可以根据已知的先验信息来校准你的结果(例如,在丛林环境下大规模的企鹅的概率),但是使用一个反映产品实际遇到的情况的训练集更容易和更有效。

我发现,最好的方法是使用直接从实际应用程序得到的数据,这些数据与上面提到的Wizard of Oz方法很好地联系在一起。循环中的人成为初始数据集的打标签者,即使收集的标签数量很小,它们也会反映实际使用情况,并且对于迁移学习的一些初步实验应该是足够的。

遵循指标

当我在做语音指令的例子时,看到的最频繁的报告就是训练过程中的混淆矩阵。这里有一个例子,展示了如何在控制台中显示:

这看起来可能很吓人,但实际上它只是一张表格,显示了网络所犯的错误的细节。这里有一份更漂亮的标签版本:

表格中的每一行代表一组样本,其中真实的标签是相同的。每一列代表样本被预测为对应标签的次数。例如,高亮显示的一行代表所有实际上是无声的音频样本,如果你从左读到右,你可以看到那些预测正确的标签,每一个都落在预测无声的列中。

这告诉我们,这个模型可以很好地发现真正的无声样本,并且没有负样本。如果我们看一下展示有多少将音频预测为无声的一整列,就可以发现一些音频片段实际上是误分到无声的一列中的,这一列有很多假正例。

事实证明这个是很有帮助的,因为它可以让我更加仔细地分析那些被错误地归类为无声的片段,从而发现他们大部分是极其安静的录音。根据混淆矩阵提供的线索,我清除了低音量的音频片段,这帮助我提高了数据质量。

虽然大多数结果是有用的,但是我发现混淆矩阵是一个很好的折衷,因为它比仅仅一个精确值给的信息要多,却又没有呈现太多复杂的细节。在训练过程中观察数字的变化是很有用的,因为它可以告诉你模型正在努力学习的类别,并且可以让你在清理和扩展数据集时集中精力。

相似的方法

我最喜欢的一种理解我的模型如何解释训练数据的方法就是可视化。TensorBoard可以很好的支持这种探索,虽然它经常用来可视化词嵌入,但是我发现它几乎对每一层都很有用,工作原理也像词嵌入。例如,图像分类网络通常在最后一层的全连接层或者softmax之前有一层网络可以用来作为嵌入(这就是简单的迁移学习的例子,和TensorFlow for Poets(地址如下)工作流程很像)。

这些并不是严格意义上的嵌入,因为在训练过程中并没有任何机制去保证真正的嵌入布局中有理想的空间属性,但是对它们的向量进行聚类确实可以产生很多有趣的东西。

链接:

https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#2

举一个实际的例子,我合作的一个团队对某些动物的图像分类模型的高错误率感到很困惑。他们使用聚类可视化去观察训练数据中不同的类别是如何分布的。当他们在看“捷豹”这个类别时,很清楚的看到数据被分为两组之间的距离。

图片来自djblock99Dave Adams

这是他们看到的一幅图,一旦每个聚类的照片都显示出来,就可以很明显的发现许多捷豹品牌的汽车都被错误地贴上了捷豹猫的标签。如果团队成员知道了这些,那么就会去关注标注过程,并且可以意识到工人的方向和用于标注的用户界面不够完善。

有了这些信息,他们就能够改进标注者(人)的培训过程并且去修复标注工具。这可以将所有的汽车图像从捷豹类别中移除,并为这一类别提供了一个更好的模型。

聚类通过让你对训练集进行深刻的了解,可以让你得到与你探索数据相似的好处。但是,网络实际上是按照它自己的学习理解将输入数据排序分组,然后指导你探索数据。

人类很擅长在视觉信息中发现异常,因此将我们的直觉和计算机处理大量数据的能力结合起来是一种非常灵活的追踪数据集质量的解决方案。关于如何使用TensorBoard来做这件事超出了本文的范围(文章已经足够长了,我很感激你还在继续读下去)。但是如果你真的想提高你的结果,我强烈建议你熟悉这个工具。

收集数据不能停

我从来没有见过收集更多的数据不能提高模型准确性的例子,而且也有很多研究可以支持我的经验。

这张图片来自“重新审视那些有效到不合常理的训练数据”,并且展示了即使数据集已经增长到了数亿,图像分类模型的精度依然不断增加。

Facebook最近更加深入的使用大数据量,例如,在ImageNet分类中使用了数十亿个带有标签的Instagram图片,以达到新的记录精度。这表明,即使对于大型、高质量数据集的问题,增加训练集的大小仍然可以提高模型结果。

这意味着只要用户可以从更高精度的模型中受益,你就需要一个不断改善数据质量的策略。如果可以的话,找到一种创造性的方法,利用即使微弱的信号也可以得到更大的数据集。Facebook使用Instagram标签就是一个很好的例子。

还有一种方法是提高标注“管道”的智能性,例如通过增加由初始模型预测的建议标签的工具,这样可以使打标签的人快速做决定。这在刚开始可能有风险,但是在实际应用中受益往往超过了这种风险。

通过雇佣更多的人来给新的训练数据贴上标签来解决这个问题通常也是一项有价值的投资。不过因为这种花费通常没有预算,组织过程中会有很多困难。如果是一个非盈利的组织,则可以让你的支持者通过某种公共工具自愿贡献数据,这是一种在不花费钱的同时提高数据集规模的好方式。

当然任何组织都希望有一个产品,当它在正常使用时可以生成标注数据。

我不会太执着于这样的想法,它不符合很多现实世界的用例。即人们只是想尽快得到一个答案而并不涉及标签的复杂问题。如果你是一家创业公司,这是一个很好的投资项目,因为它就像是一台用于改进模型的永动机。

但是在清理或增加你接收到的数据时,几乎总是会有一些单位成本,因此,最后花的钱往往最终看起来更像是一个廉价版的商业众包,而不是真正免费的东西。

通往危险区域的高速公路

模型错误对产品用户的影响往往要大于由损失函数捕捉到的错误。你应该提前想到可能发生的最糟糕的结果,并为模型设计一个辅助程序来避免发生。这也许是一个你永远都不想预测的类别黑名单,因为假正例的代价太大。

或者你仅仅有一套简单算法去保证发生的结果不会超过你已经设定的参数边界。例如,你可能会保留一个永远不希望文本生成器输出的粗俗语言的列表,即使它们在训练集中,因为它们不适合出现在产品中。

因为我们不能总是知道未来可能会出现什么不好的结果,所以学习现实世界中的错误是很重要的。如果你有了合适的产品或市场,那么从现实中学习最简单的办法就是使用错误报告。

另外,当用户使用你的应用程序出现他们不想要的东西时,应该给用户一个便捷的反馈路径。如果可以的话,获取模型的全部输入,但是如果数据是敏感数据,那么仅仅知道错误的输出是什么也可以帮助你调查原因。这些类别可以用来决定收集更多什么样的数据,并且这些类别可以让你理解当前标签的质量。

一旦你对模型进行了新的修改,就会有一组先前产生了坏结果的输入,并且除了正常的测试集之外,还对它们进行单独的评估。这个有点像一个回归测试,并给你一个方法追踪你改进用户体验的效果如何,因为单一的模型精度度量永远不会完全捕捉到人们关心的一切。

通过看一些过去引起强烈反应的例子,你就有了一些独立证据表明你实际上是在为你的用户做得更好。如果在一些情况下因为数据太敏感而不能得到输入数据,可以使用内部测试或者内部实验来确定什么样的输入会产生这些错误,然后代替回归数据集中的那些数据。

故事是什么,昙花一现?

我希望我已经说服你花更多的时间在你的数据上,并且给你了一些关于如何投入精力改进它的想法。对数据领域的关注并没有它值得的那么多,而且我真的觉得我在这里的建议仅仅是涉及数据表面。

我很感谢所有与我分享他们的策略的人,另外我希望我可以从更多的人那里听到你已经取得成功的方法。我认为会有越来越多的机构将工程师团队专门用于数据集的改进,而不是让机器学习研究人员来推动进展。我期待着看到整个领域的发展。

我总是惊叹于即使是在有着严重缺陷训练数据的情况下模型依然可以运作良好。因此我迫不及待的想看看随着我们数据集质量的提高我们可以做些什么。

入门数据集
2
相关数据
混淆矩阵技术
Confusion matrix

混淆矩阵也称误差矩阵,是表示精度评价的一种标准格式,用n行n列的矩阵形式来表示。具体评价指标有总体精度、制图精度、用户精度等,这些精度指标从不同的侧面反映了图像分类的精度。在人工智能中,混淆矩阵(confusion matrix)是可视化工具,特别用于监督学习,在无监督学习一般叫做匹配矩阵。矩阵的每一行表示预测类中的实例,而每一列表示实际类中的实例(反之亦然)。 这个名字源于这样一个事实,即很容易看出系统是否混淆了两个类。

无人机技术
Drones

无人机(Uncrewed vehicle、Unmanned vehicle、Drone)或称无人载具是一种无搭载人员的载具。通常使用遥控、导引或自动驾驶来控制。可在科学研究、军事、休闲娱乐用途上使用。

机器学习技术
Machine Learning

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

损失函数技术
Loss function

在数学优化,统计学,计量经济学,决策理论,机器学习和计算神经科学等领域,损失函数或成本函数是将一或多个变量的一个事件或值映射为可以直观地表示某种与之相关“成本”的实数的函数。

参数技术
parameter

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

语音识别技术
Speech Recognition

自动语音识别是一种将口头语音转换为实时可读文本的技术。自动语音识别也称为语音识别(Speech Recognition)或计算机语音识别(Computer Speech Recognition)。自动语音识别是一个多学科交叉的领域,它与声学、语音学、语言学、数字信号处理理论、信息论、计算机科学等众多学科紧密相连。由于语音信号的多样性和复杂性,目前的语音识别系统只能在一定的限制条件下获得满意的性能,或者说只能应用于某些特定的场合。自动语音识别在人工智能领域占据着极其重要的位置。

迁移学习技术
Transfer learning

迁移学习是一种机器学习方法,就是把为任务 A 开发的模型作为初始点,重新使用在为任务 B 开发模型的过程中。迁移学习是通过从已学习的相关任务中转移知识来改进学习的新任务,虽然大多数机器学习算法都是为了解决单个任务而设计的,但是促进迁移学习的算法的开发是机器学习社区持续关注的话题。 迁移学习对人类来说很常见,例如,我们可能会发现学习识别苹果可能有助于识别梨,或者学习弹奏电子琴可能有助于学习钢琴。

词嵌入技术
Word embedding

词嵌入是自然语言处理(NLP)中语言模型与表征学习技术的统称。概念上而言,它是指把一个维数为所有词的数量的高维空间嵌入到一个维数低得多的连续向量空间中,每个单词或词组被映射为实数域上的向量。

深度学习技术
Deep learning

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

张量技术
Tensor

张量是一个可用来表示在一些矢量、标量和其他张量之间的线性关系的多线性函数,这些线性关系的基本例子有内积、外积、线性映射以及笛卡儿积。其坐标在 维空间内,有 个分量的一种量,其中每个分量都是坐标的函数,而在坐标变换时,这些分量也依照某些规则作线性变换。称为该张量的秩或阶(与矩阵的秩和阶均无关系)。 在数学里,张量是一种几何实体,或者说广义上的“数量”。张量概念包括标量、矢量和线性算子。张量可以用坐标系统来表达,记作标量的数组,但它是定义为“不依赖于参照系的选择的”。张量在物理和工程学中很重要。例如在扩散张量成像中,表达器官对于水的在各个方向的微分透性的张量可以用来产生大脑的扫描图。工程上最重要的例子可能就是应力张量和应变张量了,它们都是二阶张量,对于一般线性材料他们之间的关系由一个四阶弹性张量来决定。

TensorFlow技术
TensorFlow

TensorFlow是一个开源软件库,用于各种感知和语言理解任务的机器学习。目前被50个团队用于研究和生产许多Google商业产品,如语音识别、Gmail、Google 相册和搜索,其中许多产品曾使用过其前任软件DistBelief。

数据派THU
数据派THU

发布清华大学数据科学相关科研动态、教学成果及线下活动。

THU数据派
THU数据派

THU数据派"基于清华,放眼世界",以扎实的理工功底闯荡“数据江湖”。发布全球大数据资讯,定期组织线下活动,分享前沿产业动态。了解清华大数据,敬请关注姐妹号“数据派THU”。

返回顶部