Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

用微信控制深度学习训练的Keras插件

创意来源


深度学习训练是一个非常耗时、枯燥的过程:一次训练少则几个小时,多则数天,而且中途能人为干预的机会更是少之又少,在大部分时间里能做的只有等待。

不少人都有这样强迫症,脚本开始运行后会不停的看准确率和 loss,就像像球迷看球赛、股民盯报价一样刺激。一般来说,想要监控深度学习训练,只能使用 ssh 或者 Tensorboard。一旦需要外出,只能通过手机监控深度学习,操作十分麻烦,于是只能无奈地守在电脑前空耗生命。

我们急需一种办法能够将我们从电脑前解放,让监控深度学习变得简单方便,不再受时间、空间和平台的限制。在当今社会,有什么东西是在任何时间任何地点我们都能使用的呢?答案只有一个:微信。

微信是每个中国人接入社会的重要接口,我们用微信在地铁上看鸡汤、在商场使用电子支付、瘫在沙发上刷朋友圈……总之,微信几乎包办了我们生活中的一切,让人产生一种错觉,在不久的未来手机上只需要装微信一个软件就够了。那么,如果微信能把监控深度学习也包办了岂不是妙哉?

在知乎上看到@Coldwings - 利用微信监管你的TF训练(https://zhuanlan.zhihu.com/p/25597975?group_id=822180572054048768)的用微信监控 Tensorflow 训练的项目,很受启发,决定开发一款用微信监控 Keras 训练的插件。欢迎 fork 这个项目的 GitHub:https://github.com/QuantumLiu/wechat_callback。

原理介绍

本插件基于 python3.5,python2.7 需要将_thread 改为 thread。当 import 程序文件时,将首先通过 ItChat(https://github.com/littlecodersh/ItChat)在服务器扫码登录微信(网页版)。

插件的主体是 sendmessage(),一个 Keras 的 keras.callbacks.Callback() 类,训练时被传入 fit() 方法的 callbacklist。在 on_train_begin(self, logs={}) 时,利用 @itchat.msg_register(TEXT) 注册针对不同命令的响应方法,启动新线程开始实时监控命令。

对命令的识别采用的是 python 的 if any((k in text) for k in cmdlist): 方法,即只要发送的消息内容包含符合条件的关键词即可出发,冗余字符不影响命令的触发。

训练过程中,由手机微信与“文件传输助手”间的通讯来实现命令的传递与内容反馈。


如图,“获取图表{batches}”和“gpu[MEMORY TEMPERATURE]”是使用者本人发送的命令,服务器响应命令,以给文件助手发送消息的方式反馈内容(图表、汇报和状态)。

主要功能


目前已经实现了被动监控、主动查询、远程关机/停止训练等多项功能。

1. 实时监控:在每个epoch结束后,自动发送本epoch训练信息以及两张分别代表所有batch和epoch信息的图表至文件传输助手。


2. 主动查询:在训练开始后的任意时刻,发送特定格式的指令,可获得指定查询项的信息。目前支持 batch 和 epoch 的各个指标的信息、显卡状态信息。


3. 远程控制:当你觉得训练已经收敛,或者因任何原因需要停止训练时,可以优雅地终止训练甚至关机。Keras 的 fit 方法中,可以通过在 callback 设置 self.model.stop_training = True 来实现在当前 epoch 结束时终止训练,否则只能 Ctrl+C 暴力停止。利用本插件,可以使用特殊格式的指令来指定停止 epoch、立刻停止训练,甚至关机和取消关机。


例子与讲解

1. 准备工作

git clone https://github.com/QuantumLiu/wechat_callback.git

cd wechat_callback

需要用到的库:itchat, keras, numpy, scipy, matplotlib, _thread(py3)

请确保 nvidia-smi 可用,如果 windows 的 cmd 里找不到命令,请手动将 nvidia-smi.exe 所在位置添加进环境变量。

2. 运行测试脚本

python wechat_test.py

解析:

在 wechat_test.py 的开头,首先 import  wechat_utils

import wechat_utils #will login automaticly#wechat_utils.sendmessage()isthe callback class#wechat_utils.sendmessage()是 keras 的回调类,fit 时传入 callbacklist

在 wechat_utils.py 中:

# Automaticly login when imported #在被import时自动登录#==============================================================================itchat.auto_login(enableCmdQR=0.5,hotReload=True)itchat.dump_login_status()#dump

可以看到,当 wechat_utils 被 import 时会调用 itchat.auto_login(),不出意外的话,将会在命令行显示二维码,需要使用手机微信扫码登录你的微信账号。


在测试脚本里我使用 numpy.random 来生成训练数据,搭建了一个多层的 FC 网络

model = Sequential()model.add(Dense(2048, input_dim=784))model.add(Activation('relu'))for i in range(9):    model.add(Dense(2048))    model.add(Activation('relu'))model.add(Dense(1,activation='sigmoid')) x=np.random.rand(nb_sample,dim)   y=np.random.randint(2,size=(nb_sample,1))

调用插件非常简单,只需要在 fit 时把 wechat_utils.sendmessage() 这个 keras 的 Callback 类传入 Callbacklist。

model.fit(x=train_x,y=train_y,batch_size=batch_size,nb_epoch=60,validation_data=(val_x,val_y),callbacks=[wechat_utils.sendmessage()])

于是训练开始,手机会收到如下反馈:


现在,我们可以向它发送查询指令,指令一般包括关键词和参数,以获取图表为例,包含以下任意关键词将被识别为获取图表指令:

[u'获取图表','Show me the figure']

参数则用{}或[]来指定,所有的指令均支持不指定参数,获取图表的默认参数是查询所有信息,例如:


‘Show me the figure’ 触发了指令,{batches} 表示查询 batches 级别信息,[losshinge] 表示查询 loss 和 hinge 指标(一般的,同一属性参数用空格隔开)。

同理,['GPU','gpu',u'显卡'] 是 gpu 状态查询的关键词,用[]指定参数,如图,查询了 gpu 的显存和温度。GPU 参数是根据 nvidia-smi 的预置参数确定的,全部都是大写,具体可查询属性请看 GitHub 的 readme 或者阅读源码。


关机指令关键词是 [u'关机','Shut down','Shut down the computer',u'别浪费电了',u'洗洗睡吧'],使用 {sec} 和 [name] 指定等待时间和保存文件名,文件名不包括.h5。默认保存模型,如果不想保存,可以在消息中包含 [u'不保存模型',"don't save"] 比如:

Shut down now{120},don't save

取消关机只需要包含 [u'取消','cancel','aaaa'] 就可以了,也就是说如果着急的话打一串 a 发过去也是可以的。


立刻停止训练的关键词是 ['Stopnow',"That's enough",u'停止训练',u'放弃治疗'] (《西部世界》看多了)。


指定停止 epoch 的关键词是 ‘Stop at’,参数可以直接用整数表示,不需要 []。

命令与关键词列表

1. 远程停止训练

关键词列表:

['Stop now',"That'senough",u'停止训练',u'放弃治疗']  

说明:发送的消息中包含任意一项都可触发命令。将在当前epoch结束后停止训练。无参数。

2. 远程关机

关键词列表:

[u'关机','Shut down','Shut down the computer',u'别浪费电了',u'洗洗睡吧']

说明:发在指定秒数后关机,用 {sec} 和 [name] 指定参数等待时间和保存文件名,文件名不包括 .h5。如果同时包含 [u'不保存模型',"don't save"],则不会保存模型。例:发送 'Shut down now [test]{120}',电脑将在 120 秒后关机,将模型保存为 test.h5。若发送 'Shut down now{120},don't save',则模型将不会被保存。

3. 取消关机

关键词列表:

[u'取消','cancel','aaaa']

说明:发送的消息中包含任意一项都可触发命令。无参数。

4. 获取图表

关键词列表:

[u'获取图表','Show me the figure']

说明:发送的消息中包含任意一项都可触发命令。通过 [metrics] 和 {level} 指定参数,如果没有指定则皆默认为 ’all'。例:手机发送"获取图表[loss]{batches}",会收到一个 jpg 格式的 loss 随 batches 变化的图片。手机发送"获取图表",则会得到两张图片,分别是所有指标随 batch 和 epoch 的变化。

5. 指定训练停止轮数

关键词列表:

'Stop at '

参数可以直接用整数表示,不需要 []。例:手机发送“Stop at:8”,训练将在 epoch8 完成后停止。

6. 查询显卡状态

关键词列表:

['GPU','gpu',u'显卡']

参数使用 [TYPE] 来指定,GPU 参数是根据 nvidia-smi 的预置参数确定的,全部都是大写,在 [] 内用空格分隔。

可用参数列表:

['MEMORY', 'UTILIZATION', 'ECC', 'TEMPERATURE','POWER', 'CLOCK', 'COMPUTE', 'PIDS', 'PERFORMANCE','SUPPORTED_CLOCKS,PAGE_RETIREMENT', 'ACCOUNTING']

例:发送'gpu[MEMORY]'或者'显卡[MEMORY]'查询显存使用;发送'GPU[MEMORYTEMPERATURE]'查询显存和温度。

总结

这个项目从有想法算起到写注释、开 GitHub、写知乎不过两天半,做的很匆忙也很粗糙,特别是画图的细节和多线程的处理。

我只是一名阿语专业大一学生(休学ing),水平十分有限,恳请各位多加指点,提高我的姿势水平,如果这个项目能给你带来一点点便利或者灵感,那么我将感到十分荣幸与欣慰。

再次感谢 @Coldwings 的原创创意。


PaperWeekly
PaperWeekly

推荐、解读、讨论和报道人工智能前沿论文成果的学术平台。

工程工程深度学习模型训练TensorFlow
2
暂无评论
暂无评论~