Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

Mathieu Carrière作者和中华校对孙韬淳翻译

拓扑机器学习的神圣三件套:Gudhi,Scikit-Learn和Tensorflow(附链接&代码)

Hi大家好。今天,我想强调下在机器学习中拓扑数据分析(TDA,Topological Data Analysis)的力量,并展示如何配合三个Python库:Gudhi,Scikit-Learn和Tensorflow进行实践。

拓扑数据分析

首先,让我们谈谈TDA。它是数据科学中相对小众的一个领域,尤其是当与机器学习深度学习对比的时候。但是它正迅速成长,并引起了数据科学家的注意。很多初创企业和公司正积极把这些技术整合进它们的工具箱中(比如IBM,Fujitsu,Ayasdi),原因则是近年来它在多种应用领域的成功,包括生物学、时间序列、金融、科学可视化、计算机图形学等。未来我可能会写一个关于TDA一般用途和最佳实践的帖子,所以请大家等待下。

TDA:

https://en.wikipedia.org/wiki/Topological_data_analysis

IBM

https://researcher.watson.ibm.com/researcher/view_group.php?id=6585

Fujitsu:

https://www.fujitsu.com/global/about/resources/news/press-releases/2016/0216-01.html

Ayasdi:

https://www.ayasdi.com/platform/technology/

生物学:

https://www.ncbi.nlm.nih.gov/pubmed/28459448

时间序列:

https://www.ams.org/journals/notices/201905/rnoti-p686.pdf

金融:

https://papers.ssrn.com/sol3/papers.cfm?abstract_id=2931836

科学可视化:

https://topology-tool-kit.github.io/

计算机图形学:

http://www.lix.polytechnique.fr/~maks/papers/top_opt_SGP18.pdf

TDA的目标是对你数据的拓扑性质进行计算和编码,这意味着记录数据集中多样的连接成分,环,腔和高维结构。这非常有用,主要是因为其他描述符不可能计算这类信息。所以TDA真的储存了一组你不可能在其他地方找到的数据特征。现实情况是这类特征已被证明对提升机器学习预测能力很有用,所以如果你以前还没见过或听过这类特征,我来带你快速了解一下。
我已经写过很多这个主题的文章,你可以在Medium找到关于TDA的很多其他帖子,所以我不打算浪费时间在数学定义上面,而是通过解释TDA文献中的典型例子,来展示如何在你的数据集上应用TDA。

文章:
https://towardsdatascience.com/mixing-topology-and-deep-learning-with-perslay-2e60af69c321

帖子:
https://towardsdatascience.com/applied-topological-data-analysis-to-deep-learning-hands-on-arrhythmia-classification-48993d78f9e6

TDA的参考示例:点云分类

这个数据集在一篇开创性的TDA文章上介绍过。它由通过下述动力系统生成的轨迹来得到的点云集组成:

开创性的TDA文章
http://jmlr.org/papers/v18/16-337.html

一个动力系统的方程

这意味着我们将从一个单位正方形内随机抽取一个初始点,并通过上面的方程生成一个点的序列。这将给我们一个点云。现在我们可以根据意愿重复这个操作,得到一堆点云。这些点云的一个有趣的属性在于,根据你用来生成点序列的r参数的值,点云会有非常不一样且有意思的结构。比如,如果r=3.5,得到的点云似乎覆盖了整个单位正方形,但如果r=4.1,单位正方形的一些区域就是空的:换句话说,在你的点云里有好多洞。这对我们是个好消息:TDA可以直接计算这些结构是否能出现。

r=3.5(左)和r=4.1(右)计算出的点云。相当明显的是后者有个洞,但前者没有

TDA跟踪这些洞的方式实际上相当简单。想象给定半径为R的每个球的圆心都在你点云的每个点上。如果R=0,这些球的并集就是点云本身。如果R为无穷,那么球的并集是整个单位正方形。但如果R被很精心的选择,球的并集可能存在很多拓扑结构,比如,洞。

球并集的例子。对于中间图的并集,它清晰的组成了一个洞。整张图片被我“不要脸”地借用自我之前的一个帖子帖子

https://towardsdatascience.com/a-concrete-application-of-topological-data-analysis-86b89aa27586

那么,为了避免人工选择R的“好值”,TDA将针对每一个可能的R值(从0到无穷)计算球的并集,并记录每个洞出现或者消失时的半径,并对一些点使用这些半径值作为二维坐标。TDA的输出则是另一个点云,其中每个点代表一个洞:这叫做Rips持续图(Rips persistence diagram)。假设点云在一个numpy数组X中储存(shape为N*2),通过Gudhi,这个图可以用两行代码计算出来:这个漂亮的持续图由r=4.1对应的点云计算出。红色的点代表相连的成分,蓝色的点代表洞接下来我们将解决的任务则是给定点云预测r的值。

通过Gudhi+Scikit-Learn进行拓扑机器学习

持续图很简洁,是不是?它们存在的问题则是,从不同点云计算出的持续图可能有不同数量的点(因为点云可能有不同数量的洞)。所以如果你想用Scikit-Learn从持续图中预测r,不幸的是,没有直接的方法,因为这些库预期输入是一个结构化的向量。这也是为什么目前大量的工作是关于将这些持续图转化为固定长度的欧几里得向量,或者是开发对应的核。这很棒,但是你应该使用哪种呢?
不要担心!Gudhi再一次给你解决办法。通过它的表达(representation)模块,你不仅可以计算所有的向量和核,甚至也可以使用Scikit-Learn来交叉验证并且(或)选择最佳的一种。就像下面这么简单:

表达

https://gudhi.inria.fr/python/latest/representations.html

在前面的代码中,我尝试了带切片Wasserstein核和持续权重Gaussian核的核SVM、带有Persistence Images的C-SVM,带有Persistence Landscapes的随机森林,和一个带有所谓的持久图之间瓶颈距离(bottleneck distance)的简单KNN。在Gudhi中还有许多其他的可能,所以你一定要试试!如果想了解更多细节你也可以看看Gudhi的Tutorial。

带切片Wasserstein核:

http://proceedings.mlr.press/v70/carriere17a/carriere17a.pdf

持续权重Gaussian核:

http://proceedings.mlr.press/v48/kusano16.html

Persistence Images:

http://jmlr.org/papers/v18/16-337.html

Persistence Landscapes:

http://www.jmlr.org/papers/volume16/bubenik15a/bubenik15a.pdf

Gudhi的Tutorial:

https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-representations.ipynb

用Gudhi和Tensorflow/Pytorch进行拓扑优化

我很确信你目前已经成为了TDA的爱好者。如果你仍不相信,我还有其他的东西给你,这是受这篇论文启发。想象你现在想解决一个更难的问题:我想让你给我一个点云,这个点云的持续图有尽可能多的点。换句话说,你需要生成一个有好多洞的点云。

论文:
https://arxiv.org/abs/1905.12200

我可以看见你额头上出汗了。但我是很仁慈的,转眼间就能让你知道Gudhi(1)可以做这个。想一想:当你生成一个持续图时,这个图中不同点的坐标并不受全部的初始点云影响,是不是?对于这个持续图的一个给定点p,p的坐标仅依赖于在初始点云中组成p对应洞的点的位置,以一种简单的方式:这些坐标仅是球的并集使得这个洞出现或者消失时候的半径;或者,等价表达是,这些点中的最大的成对距离。而Gudhi(2)可以通过它的persistence_pairs()函数找出这些关系。梯度则可以简单的定义成欧几里得距离函数的导数(正式定义见这篇论文)。

Gudhi(1):

http://gudhi.gforge.inria.fr/python/latest/

Gudhi(2):

https://gudhi.inria.fr/python/latest/

这篇论文:

https://sites.google.com/view/hiraoka-lab-en/research/mathematical-research/continuation-of-point-cloud-data-via-persistence-diagram

接下来让我们写两个函数,第一个从点云中计算Rips持续图,第二个计算持续图点集的导数。为了可读性我简化了一点点代码,实际的代码可以从这里找到。

https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-optimization.ipynb

现在让我们把函数封装进Tensorflow函数中(对Pytorch同样简单),并定义一个损失loss,这个损失是持续图点到其对角线的距离的相反数。这将迫使图有很多点,它们的纵坐标比横坐标大得多。这样的话,一个点云会有很多大尺寸的洞。

现在我们开始优化!这是epochs 0,20,90的结果:

好多洞,好漂亮。。我们是不是在梦里。如果你想往前看看,使用其它的损失,查阅这个Gudhi的tutorial。

https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-optimization.ipynb

最后的话

这个帖子仅是一瞥由Gudhi,Scikit-Learn和Tensorflow提供的众多可能性。我希望我可以使你相信,在你的流程中整合TDA已经成为很简单的事情。即使许多TDA应用已经在文献中出现,肯定还有更多的应用需要去发现!

原文标题:
The Holy Trinity of Topological Machine Learning: Gudhi, Scikit-Learn and Tensorflow
原文链接:
https://towardsdatascience.com/the-holy-trinity-of-topological-machine-learning-gudhi-scikit-learn-and-tensorflow-pytorch-3cda2aa249b5
THU数据派
THU数据派

THU数据派"基于清华,放眼世界",以扎实的理工功底闯荡“数据江湖”。发布全球大数据资讯,定期组织线下活动,分享前沿产业动态。了解清华大数据,敬请关注姐妹号“数据派THU”。

理论机器学习数据可视化PythonGudhiTensorFlowSciKit-Learn
1
相关数据
IBM机构

是美国一家跨国科技公司及咨询公司,总部位于纽约州阿蒙克市。IBM主要客户是政府和企业。IBM生产并销售计算机硬件及软件,并且为系统架构和网络托管提供咨询服务。截止2013年,IBM已在全球拥有12个研究实验室和大量的软件开发基地。IBM虽然是一家商业公司,但在材料、化学、物理等科学领域却也有很高的成就,利用这些学术研究为基础,发明很多产品。比较有名的IBM发明的产品包括硬盘、自动柜员机、通用产品代码、SQL、关系数据库管理系统、DRAM及沃森。

https://www.ibm.com/us-en/
相关技术
深度学习技术

深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。 深度学习是机器学习中一种基于对数据进行表征学习的算法,至今已有数种深度学习框架,如卷积神经网络和深度置信网络和递归神经网络等已被应用在计算机视觉、语音识别、自然语言处理、音频识别与生物信息学等领域并获取了极好的效果。

数据分析技术

数据分析是一类统计方法,其主要特点是多维性和描述性。有些几何方法有助于揭示不同的数据之间存在的关系,并绘制出统计信息图,以更简洁的解释这些数据中包含的主要信息。其他一些用于收集数据,以便弄清哪些是同质的,从而更好地了解数据。 数据分析可以处理大量数据,并确定这些数据最有用的部分。

计算机图形技术

图像数据处理、计算机图像(英语:Computer Graphics)是指用计算机所创造的图形。更具体的说,就是在计算机上用专门的软件和硬件用来表现和控制图像数据。

权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

数据科学技术

数据科学,又称资料科学,是一门利用数据学习知识的学科,其目标是通过从数据中提取出有价值的部分来生产数据产品。它结合了诸多领域中的理论和技术,包括应用数学、统计、模式识别、机器学习、数据可视化、数据仓库以及高性能计算。数据科学通过运用各种相关的数据来帮助非专业人士理解问题。

导数技术

导数(Derivative)是微积分中的重要基础概念。当函数y=f(x)的自变量x在一点x_0上产生一个增量Δx时,函数输出值的增量Δy与自变量增量Δx的比值在Δx趋于0时的极限a如果存在,a即为在x0处的导数,记作f'(x_0) 或 df(x_0)/dx。

随机森林技术

在机器学习中,随机森林是一个包含多个决策树的分类器,并且其输出的类别是由个别树输出的类别的众数而定。 Leo Breiman和Adele Cutler发展出推论出随机森林的算法。而"Random Forests"是他们的商标。这个术语是1995年由贝尔实验室的Tin Kam Ho所提出的随机决策森林(random decision forests)而来的。这个方法则是结合Breimans的"Bootstrap aggregating"想法和Ho的"random subspace method" 以建造决策树的集合。

欧几里得距离技术

在数学中,欧几里得距离或欧几里得度量是欧几里得空间中两点间“普通”(即直线)距离。 使用这个距离,欧氏空间成为度量空间。

云计算技术

云计算(英语:cloud computing),是一种基于互联网的计算方式,通过这种方式,共享的软硬件资源和信息可以按需求提供给计算机各种终端和其他设备。

动力系统技术

动态系统(dynamical system)是数学上的一个概念。动态系统是一种固定的规则,它描述一个给定空间(如某个物理系统的状态空间)中所有点随时间的变化情况。例如描述钟摆晃动、管道中水的流动,或者湖中每年春季鱼类的数量,凡此等等的数学模型都是动态系统。 在动态系统中有所谓状态的概念,状态是一组可以被确定下来的实数。状态的微小变动对应这组实数的微小变动。这组实数也是一种流形的几何空间坐标。动态系统的演化规则是一组函数的固定规则,它描述未来状态如何依赖于当前状态的。这种规则是确定性的,即对于给定的时间间隔内,从现在的状态只能演化出一个未来的状态。 若只是在一系列不连续的时间点考察系统的状态,则这个动态系统为离散动态系统;若时间连续,就得到一个连续动态系统。如果系统以一种连续可微的方式依赖于时间,我们就称它为一个光滑动态系统。

交叉验证技术

交叉验证,有时亦称循环估计, 是一种统计学上将数据样本切割成较小子集的实用方法。于是可以先在一个子集上做分析, 而其它子集则用来做后续对此分析的确认及验证。 一开始的子集被称为训练集。而其它的子集则被称为验证集或测试集。交叉验证的目标是定义一个数据集到“测试”的模型在训练阶段,以便减少像过拟合的问题,得到该模型将如何衍生到一个独立的数据集的提示。

暂无评论
暂无评论~