参与:思源

模型转代码:XGBoost等模型也能快速转纯C或Java代码

你是否常训练炫酷的机器学习模型,用来分析数据或执行某些新奇的任务?你是否发现你的模型只能在一定开发环境上才能跑起来,很难部署也很难加入其它程序?今天我们将介绍一个炫酷的工具,它可以把构建在 scikit-learn 或 XGBoost 等库上的 ML 模型直接转化为不需要任何依赖项的 Java/Python/C 源代码。

  • 项目地址:https://github.com/BayesWitnesses/m2cgen/

那么转化为 Java/Python/C 源代码有什么用呢?想象一下如果我们使用 ML 框架(scikit-learn\XGBoost\LightGBM)训练了一个模型,现在我们希望把这个模型做成应用或嵌入到已有的模型中,那么我们肯定需要考虑这些问题:

  • 如果产品环境并没有 Python 运行时怎么办?

  • 如果产品不能通过云服务器进行计算,只能在本地进行怎么办?

  • ML 模型的推断速度太慢又怎么办?

这些问题都很难解决,也是开发者在做自己项目时常遇到的问题。如果我们能将用 Python 和 ML 库构建的模型转换一下,变成纯 Java 或 C 写的代码,且这些代码不会依赖各种库,那么部署或嵌入不就简单了么。在 m2cgen 这个项目中,它就可以将 ML 模型转化为不带有依赖项的纯代码。

m2cgen(Model 2 Code Generator)是一个轻量级的 Python 库,它能快速便捷地将已训练统计模型转化为 Python、C 和 Java 代码。目前 m2cgen 已经支持各种分类模型与回归模型,包括支持向量机决策树随机森林梯度提升树等,注意这些都是统计机器学习模型,深度神经网络还是老老实实使用 DL 框架吧。

模型转换效果

我们可以通过几个简单案例了解 m2cgen 是如何转换为纯代码的,简单而言即把模型架构和权重显化了。如下所示如果我们简单地训练一个线性回归模型,当然代码看着简单是因为我们直接调用了 scikit-learn 库中的模型。

from sklearn.datasets import load_boston
from sklearn import linear_model
import m2cgen as m2c

boston = load_boston()
X, y = boston.data, boston.target

estimator = linear_model.LinearRegression()
estimator.fit(X, y)

code = m2c.export_to_java(estimator)

上面最后一行将 scikit-learn 中的线性回归模型转化为 Java 代码,注意这个模型已经拟合了训练数据,或者说已经完成了训练。转化后的代码如下所示:

public class Model {

    public static double score(double[] input) {
        return (((((((((((((36.45948838508965) + ((input[0]) * (-0.10801135783679647))) + ((input[1]) * (0.04642045836688297))) + ((input[2]) * (0.020558626367073608))) + ((input[3]) * (2.6867338193449406))) + ((input[4]) * (-17.76661122830004))) + ((input[5]) * (3.8098652068092163))) + ((input[6]) * (0.0006922246403454562))) + ((input[7]) * (-1.475566845600257))) + ((input[8]) * (0.30604947898516943))) + ((input[9]) * (-0.012334593916574394))) + ((input[10]) * (-0.9527472317072884))) + ((input[11]) * (0.009311683273794044))) + ((input[12]) * (-0.5247583778554867));
    }
}

如上 return 后面的语句,它就是一个线性回归的表达式,每一个 input[ * ] 都是一种特征,它后面的数值就是训练后的权重。所以整个线性模型有 13 个特征及对应权重,以及另外一个偏置项。

我们还可以找到更多的案例,如果我们用 XGBoost 训练一个简单的分类模型,我们可以看到转化的代码会大量使用 if-else 大法,不过我们本身也不用维护生成的代码,所以这种结构也没什么关系了。

import numpy as np
def score(input):
    if (input[2]) >= (2.5999999):
        var0 = -0.0731707439
    else:
        var0 = 0.142857149
    if (input[2]) >= (2.5999999):
        var1 = -0.0705206916
    else:
        var1 = 0.12477719
    var2 = np.exp(((0.5) + (var0)) + (var1))
    if (input[2]) >= (2.5999999):
        if (input[2]) >= (4.85000038):
            var3 = -0.0578680299
        else:
            var3 = 0.132596686
    else:
        var3 = -0.0714285821
    if (input[2]) >= (2.5999999):
        if (input[2]) >= (4.85000038):
            var4 = -0.0552999191
        else:
            var4 = 0.116139404
    else:
        var4 = -0.0687687024
    var5 = np.exp(((0.5) + (var3)) + (var4))
    if (input[2]) >= (4.85000038):
        if (input[3]) >= (1.75):
            var6 = 0.142011836
        else:
            var6 = 0.0405405387
    else:
        if (input[3]) >= (1.6500001):
            var6 = 0.0428571403
        else:
            var6 = -0.0730659068
    if (input[2]) >= (4.85000038):
        if (input[3]) >= (1.75):
            var7 = 0.124653712
        else:
            var7 = 0.035562478
    else:
        if (input[3]) >= (1.6500001):
            var7 = 0.0425687581
        else:
            var7 = -0.0704230517
    var8 = np.exp(((0.5) + (var6)) + (var7))
    var9 = ((var2) + (var5)) + (var8)
    return np.asarray([(var2) / (var9), (var5) / (var9), (var8) / (var9)])

不过上面这种代码也非常合理,本身决策树就可以视为一种 if-else 的规则集合,不同输入特征 input[ * ] 满足不同的条件就能得到不同的值,这些值最后能联合计算分类结果。

项目细节

工具的安装很简单,直接用 pip 就行了:

pip install m2cgen

除了前面那样在代码中调用转换工具,我们还能通过命令行使用序列化的模型目标(pickle protocol)生成代码:

$ m2cgen <pickle_file> --language <language> [--indent <indent>]
         [--class_name <class_name>] [--package_name <package_name>]
         [--recursion-limit <recursion_limit>]

目前项目支持以下分类和回归模型的转换:


分类模型输出结果:



工程SciKit-LearnGitHubPython
3
相关数据
权重技术

线性模型中特征的系数,或深度网络中的边。训练线性模型的目标是确定每个特征的理想权重。如果权重为 0,则相应的特征对模型来说没有任何贡献。

机器学习技术

机器学习是人工智能的一个分支,是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。

梯度提升技术

梯度提升是用于回归和分类问题的机器学习技术,其以弱预测模型(通常为决策树)的集合的形式产生预测模型。 它像其他增强方法一样以阶段式方式构建模型,并且通过允许优化任意可微损失函数来推广它们。

统计模型技术

统计模型[stochasticmodel;statisticmodel;probabilitymodel]指以概率论为基础,采用数学统计方法建立的模型。有些过程无法用理论分析方法导出其模型,但可通过试验测定数据,经过数理统计法求得各变量之间的函数关系,称为统计模型。常用的数理统计分析方法有最大事后概率估算法、最大似然率辨识法等。常用的统计模型有一般线性模型、广义线性模型和混合模型。统计模型的意义在对大量随机事件的规律性做推断时仍然具有统计性,因而称为统计推断。常用的统计模型软件有SPSS、SAS、Stata、SPLM、Epi-Info、Statistica等。

随机森林技术

在机器学习中,随机森林是一个包含多个决策树的分类器,并且其输出的类别是由个别树输出的类别的众数而定。 Leo Breiman和Adele Cutler发展出推论出随机森林的算法。而"Random Forests"是他们的商标。这个术语是1995年由贝尔实验室的Tin Kam Ho所提出的随机决策森林(random decision forests)而来的。这个方法则是结合Breimans的"Bootstrap aggregating"想法和Ho的"random subspace method" 以建造决策树的集合。

线性回归技术

在现实世界中,存在着大量这样的情况:两个变量例如X和Y有一些依赖关系。由X可以部分地决定Y的值,但这种决定往往不很确切。常常用来说明这种依赖关系的最简单、直观的例子是体重与身高,用Y表示他的体重。众所周知,一般说来,当X大时,Y也倾向于大,但由X不能严格地决定Y。又如,城市生活用电量Y与气温X有很大的关系。在夏天气温很高或冬天气温很低时,由于室内空调、冰箱等家用电器的使用,可能用电就高,相反,在春秋季节气温不高也不低,用电量就可能少。但我们不能由气温X准确地决定用电量Y。类似的例子还很多,变量之间的这种关系称为“相关关系”,回归模型就是研究相关关系的一个有力工具。

支持向量机技术

在机器学习中,支持向量机是在分类与回归分析中分析数据的监督式学习模型与相关的学习算法。给定一组训练实例,每个训练实例被标记为属于两个类别中的一个或另一个,SVM训练算法创建一个将新的实例分配给两个类别之一的模型,使其成为非概率二元线性分类器。SVM模型是将实例表示为空间中的点,这样映射就使得单独类别的实例被尽可能宽的明显的间隔分开。然后,将新的实例映射到同一空间,并基于它们落在间隔的哪一侧来预测所属类别。

XGBoost技术

XGBoost是一个开源软件库,为C ++,Java,Python,R,和Julia提供了渐变增强框架。 它适用于Linux,Windows,MacOS。从项目描述来看,它旨在提供一个“可扩展,便携式和分布式的梯度提升(GBM,GBRT,GBDT)库”。 除了在一台机器上运行,它还支持分布式处理框架Apache Hadoop,Apache Spark和Apache Flink。 由于它是许多机器学习大赛中获胜团队的首选算法,因此它已经赢得了很多人的关注。

深度神经网络技术

深度神经网络(DNN)是深度学习的一种框架,它是一种具备至少一个隐层的神经网络。与浅层神经网络类似,深度神经网络也能够为复杂非线性系统提供建模,但多出的层次为模型提供了更高的抽象层次,因而提高了模型的能力。

360机构

奇虎360科技有限公司,是中国领先的互联网和手机安全产品及服务供应商。据第三方统计,按照用户数量计算,360是中国领先的互联网安全公司,用户6亿,市场渗透率96.6%;中国领先的移动互联网安全公司,用户数近8亿,市场渗透率近70%;中国领先的浏览器公司之一,活跃用户达到4亿,渗透率超过70%。 360致力于通过提供高品质的免费安全服务,为中国互联网用户解决上网时遇到的各种安全问题。面对互联网时代木马、病毒、流氓软件、钓鱼欺诈网页等多元化的安全威胁,360以互联网的思路解决网络安全问题。360是免费安全的首倡者,认为互联网安全像搜索、电子邮箱、即时通讯一样,是互联网的基础服务,应该免费。为此,360安全卫士、360杀毒等系列安全产品免费提供给中国数亿互联网用户。同时,360开发了全球规模和技术均领先的云安全体系,能够快速识别并清除新型木马病毒以及钓鱼、挂马恶意网页,全方位保护用户的上网安全。

https://www.360.cn/
推荐文章
暂无评论
暂无评论~