参与:路作者:edvardHua

实时单人姿态估计,在自己手机上就能实现 : ) 安卓和iOS都可以哦~

本文介绍了如何使用 TensorFlow 在智能机上(包括安卓和 iOS 设备)执行实时单人姿态估计。

GitHub 地址:https://github.com/edvardHua/PoseEstimationForMobile

该 repo 使用 TensorFlow 实现 CPM 和 Hourglass 模型。这里未使用常规的卷积,而是在模型内部使用了反向卷积(又叫 Mobilenet V2),以便执行实时推断。

注:你可以修改网络架构,来训练更高 PCKh 的模型。架构地址:https://github.com/edvardHua/PoseEstimationForMobile/tree/master/training/src

该库包含:

  • 训练 CPM 和 Hourglass 模型的代码;

  • 安卓 demo 的源代码;

  • IOS demo 的源代码。

下面的 gif 是在 Mi Mix2s 上截取的(~60 FPS)

你可以下载以下 apk,在自己的设备上进行测试。

  • PoseEstimation-Mace.apk:https://raw.githubusercontent.com/edvardHua/PoseEstimationForMobile/master/release/PoseEstimation-Mace.apk

  • PoseEstimation-TFlite.apk:https://raw.githubusercontent.com/edvardHua/PoseEstimationForMobile/master/release/PoseEstimation-TFlite.apk

训练

依赖项

  • Python3

  • TensorFlow >= 1.4

  • Mace

数据集

训练数据集:https://drive.google.com/open?id=1zahjQWhuKIYWRRI2ZlHzn65Ug_jIiC4l

将其解压缩,获取以下文件结构:

# root @ ubuntu in ~/hdd/ai_challenger
$ tree -L 1 ..
├── ai_challenger_train.json
├── ai_challenger_valid.json
├── train
└── valid

该训练数据集仅包含单人图像,数据来源是 AI Challenger 竞赛。共包含 22446 个训练样本和 1500 个测试样本。

该 repo 作者使用 tf-pose-estimation 库中的数据增强代码将标注迁移为 COCO 格式。tf-pose-estimation 库:https://github.com/ildoonet/tf-pose-estimation

参数

训练步骤中,使用 experiments 文件夹中的 cfg 文件传输参数

以下是 mv2_cpm.cfg 文件的内容:

[Train]
model: 'mv2_cpm'
checkpoint: False
datapath: '/root/hdd/ai_challenger'
imgpath: '/root/hdd/'
visible_devices: '0, 1, 2'
multiprocessing_num: 8
max_epoch: 1000
lr: '0.001'
batchsize: 5
decay_rate: 0.95
input_width: 192
input_height: 192
n_kpoints: 14
scale: 2
modelpath: '/root/hdd/trained/mv2_cpm/models'
logpath: '/root/hdd/trained/mv2_cpm/log'
num_train_samples: 20000
per_update_tensorboard_step: 500
per_saved_model_step: 2000
pred_image_on_tensorboard: True

该 cfg 文件覆盖模型的所有参数,在 network_mv2_cpm.py 中仍有一些参数

使用 nvidia-docker 训练

通过以下命令构建 docker:

cd training/docker
docker build -t single-pose .

或者

docker pull edvardhua/single-pose

然后运行以下命令,训练模型:

nvidia-docker run -it -d \
-v <dataset_path>:/data5 -v <training_code_path>/training:/workspace \
-p 6006:6006 -e LOG_PATH=/root/hdd/trained/mv2_cpm/log \
-e PARAMETERS_FILE=experiments/mv2_cpm.cfg edvardhua/single-pose

此外,它还在 port 6006 上创建了 tensorboard。确保安装了 nvidia-docker。

按一般方法训练

1. 安装依赖项:

cd training
pip3 install -r requirements.txt

还需要安装 cocoapi (https://github.com/cocodataset/cocoapi)。

2. 编辑 experiments 文件夹中的参数文件,它包含几乎所有参数和训练中需要定义的其他配置。之后,传输参数文件,开始训练:

cd training
python3 src/train.py experiments/mv2_cpm.cfg

在 3 张英伟达 1080Ti 显卡上经过 12 个小时的训练后,该模型几乎收敛。以下是对应的 tensorboard 图。

基准(PCKh)

运行以下命令,评估 PCKh 值。

python3 src/benchmark.py --frozen_pb_path=hourglass/model-360000.pb \
--anno_json_path=/root/hdd/ai_challenger/ai_challenger_valid.json \
--img_path=/root/hdd \
--output_node_name=hourglass_out_3

预训练模型

  • CPM:https://github.com/edvardHua/PoseEstimationForMobile/tree/master/release/cpm_model 

  • Hourglass:https://github.com/edvardHua/PoseEstimationForMobile/tree/master/release/hourglass_model

安卓 demo

由于 mace 框架,你可以使用 GPU 在安卓智能机上运行该模型。

按照以下命令将模型转换为 mace 格式:

cd <your-mace-path># You transer hourglass or cpm model by changing `yml` file.
python tools/converter.py convert --config=<PoseEstimationForMobilePath>/release/mace_ymls/cpm.yml

然后根据 mace 文档的说明,将模型集成到安卓设备中。

至于如何调用模型、解析输出,可以参见安卓源代码:https://github.com/edvardHua/PoseEstimationForMobile/tree/master/android_demo。

一些芯片的平均推断时间基准如下所示:

以下是该 repo 作者构建该 demo 的环境:

  • 操作系统:macOS 10.13.6(mace 目前不支持 windows)

  • Android Studio:3.0.1

  • NDK 版本:r16

在构建 mace-demo 时,不同环境可能会遇到不同的错误。为避免这种情况,作者建议使用 docker。

docker pull registry.cn-hangzhou.aliyuncs.com/xiaomimace/mace-dev-lite

docker run -it
 --privileged -d --name mace-dev 
 --net=host 
 -v to/you/path/PoseEstimationForMobile/android_demo/demo_mace:/demo_mace 
 registry.cn-hangzhou.aliyuncs.com/xiaomimace/mace-dev-lite

docker run -it --privileged -d --name mace-dev --net=host \
 -v to/you/path/PoseEstimationForMobile/android_demo/demo_mace:/demo_mace \
 registry.cn-hangzhou.aliyuncs.com/xiaomimace/mace-dev-lite
# Enter to docker
docker exec -it mace-dev bash
# Exec command inside the dockercd /demo_mace && ./gradlew build

或者将模型转换为 tflite:

# Convert to frozen pb.cd training
python3 src/gen_frozen_pb.py \
--checkpoint=<you_training_model_path>/model-xxx --output_graph=<you_output_model_path>/model-xxx.pb \
--size=192 --model=mv2_cpm_2
# If you update tensorflow to 1.9, run following command.
python3 src/gen_tflite_coreml.py \
--frozen_pb=forzen_graph.pb \
--input_node_name='image' \
--output_node_name='Convolutional_Pose_Machine/stage_5_out' \
--output_path='./' \
--type=tflite
 # Convert to tflite.# See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/mobile/tflite/devguide.md for more information.
bazel-bin/tensorflow/contrib/lite/toco/toco \
--input_file=<you_output_model_path>/model-xxx.pb \
--output_file=<you_output_tflite_model_path>/mv2-cpm.tflite \
--input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
--inference_type=FLOAT \
--input_shape="1,192,192,3" \
--input_array='image' \
--output_array='Convolutional_Pose_Machine/stage_5_out'

然后,将 tflite 文件放在 android_demo/app/src/main/assets 中,修改 ImageClassifierFloatInception.kt 中的参数

............// parameters need to modify in ImageClassifierFloatInception.kt/**     * Create ImageClassifierFloatInception instance     *     * @param imageSizeX Get the image size along the x axis.     * @param imageSizeY Get the image size along the y axis.     * @param outputW The output width of model     * @param outputH The output height of model     * @param modelPath Get the name of the model file stored in Assets.     * @param numBytesPerChannel Get the number of bytes that is used to store a single     * color channel value.     */
    fun create(
      activity: Activity,
      imageSizeX: Int = 192,
      imageSizeY: Int = 192,
      outputW: Int = 96,
      outputH: Int = 96,
      modelPath: String = "mv2-cpm.tflite",
      numBytesPerChannel: Int = 4
    ): ImageClassifierFloatInception =
      ImageClassifierFloatInception(
          activity,
          imageSizeX,
          imageSizeY,
          outputW,
          outputH,
          modelPath,
          numBytesPerChannel)............

最后,将该项目导入 Android Studio,在智能机设备上运行。

iOS Demo

首先,将模型转换为 CoreML 模型:

# Convert to frozen pb.cd training
python3 src/gen_frozen_pb.py \
--checkpoint=<you_training_model_path>/model-xxx --output_graph=<you_output_model_path>/model-xxx.pb \
--size=192 --model=mv2_cpm_2
# Run the following command to get mlmodel
python3 src/gen_tflite_coreml.py \
--frozen_pb=forzen_graph.pb \
--input_node_name='image' \
--output_node_name='Convolutional_Pose_Machine/stage_5_out' \
--output_path='./' \
--type=coreml


然后,按照 PoseEstimation-CoreML 中的说明来操作(https://github.com/tucan9389/PoseEstimation-CoreML)。

工程GitHub移动端计算机视觉
71
相关数据
基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

超参数技术

在机器学习中,超参数是在学习过程开始之前设置其值的参数。 相反,其他参数的值是通过训练得出的。 不同的模型训练算法需要不同的超参数,一些简单的算法(如普通最小二乘回归)不需要。 给定这些超参数,训练算法从数据中学习参数。相同种类的机器学习模型可能需要不同的超参数来适应不同的数据模式,并且必须对其进行调整以便模型能够最优地解决机器学习问题。 在实际应用中一般需要对超参数进行优化,以找到一个超参数元组(tuple),由这些超参数元组形成一个最优化模型,该模型可以将在给定的独立数据上预定义的损失函数最小化。

TensorFlow技术

TensorFlow是一个开源软件库,用于各种感知和语言理解任务的机器学习。目前被50个团队用于研究和生产许多Google商业产品,如语音识别、Gmail、Google 相册和搜索,其中许多产品曾使用过其前任软件DistBelief。

操作系统技术

操作系统(英语:operating system,缩写作 OS)是管理计算机硬件与软件资源的计算机程序,同时也是计算机系统的内核与基石。操作系统需要处理如管理与配置内存、决定系统资源供需的优先次序、控制输入与输出设备、操作网络与管理文件系统等基本事务。操作系统也提供一个让用户与系统交互的操作界面。

360机构

奇虎360科技有限公司,是中国领先的互联网和手机安全产品及服务供应商。据第三方统计,按照用户数量计算,360是中国领先的互联网安全公司,用户6亿,市场渗透率96.6%;中国领先的移动互联网安全公司,用户数近8亿,市场渗透率近70%;中国领先的浏览器公司之一,活跃用户达到4亿,渗透率超过70%。 360致力于通过提供高品质的免费安全服务,为中国互联网用户解决上网时遇到的各种安全问题。面对互联网时代木马、病毒、流氓软件、钓鱼欺诈网页等多元化的安全威胁,360以互联网的思路解决网络安全问题。360是免费安全的首倡者,认为互联网安全像搜索、电子邮箱、即时通讯一样,是互联网的基础服务,应该免费。为此,360安全卫士、360杀毒等系列安全产品免费提供给中国数亿互联网用户。同时,360开发了全球规模和技术均领先的云安全体系,能够快速识别并清除新型木马病毒以及钓鱼、挂马恶意网页,全方位保护用户的上网安全。

https://www.360.cn/
很cool,学习下