快速开启你的第一个项目:TensorFlow项目架构模板

作为最为流行的深度学习资源库,TensorFlow 是帮助深度学习新方法走向实现的强大工具。它为大多数深度学习领域中使用的常用语言提供了大量应用程序接口。对于开发者和研究人员来说,在开启新的项目前首先面临的问题是:如何构建一个简单明了的结构,本文或许可以为你带来帮助。

项目链接:https://github.com/Mrgemy95/Tensorflow-Project-Template


TensorFlow 项目模板


简洁而精密的结构对于深度学习项目来说是必不可少的,在经过多次练习和 TensorFlow 项目开发之后,本文作者提出了一个结合简便性、优化文件结构和良好 OOP 设计的 TensorFlow 项目模板。该模板可以帮助你快速启动自己的 TensorFlow 项目,直接从实现自己的核心思想开始。

这个简单的模板可以帮助你直接从构建模型、训练等任务开始工作。

目录


  • 概述

  • 详述

  • 项目架构

  • 文件夹结构

  • 主要组件

  • 模型

  • 训练器

  • 数据加载器

  • 记录器

  • 配置

  • Main

  • 未来工作

概述


简言之,本文介绍的是这一模板的使用方法,例如,如果你希望实现 VGG 模型,那么你应该:

在模型文件夹中创建一个名为 VGG 的类,由它继承「base_model」类

  1.   class VGGModel(BaseModel):

  2.        def __init__(self, config):

  3.            super(VGGModel, self).__init__(config)

  4.            #call the build_model and init_saver functions.

  5.            self.build_model()

  6.            self.init_saver()

覆写这两个函数 "build_model",在其中执行你的 VGG 模型;以及定义 TensorFlow 保存的「init_saver」,随后在 initalizer 中调用它们。

  1.    def build_model(self):

  2.        # here you build the tensorflow graph of any model you want and also define the loss.

  3.        pass

  4.     def init_saver(self):

  5.        #here you initalize the tensorflow saver that will be used in saving the checkpoints.

  6.        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)

在 trainers 文件夹中创建 VGG 训练器,继承「base_train」类。

  1.        class VGGTrainer(BaseTrain):

  2.        def __init__(self, sess, model, data, config, logger):

  3.            super(VGGTrainer, self).__init__(sess, model, data, config, logger)

覆写这两个函数「train_step」、「train_epoch」,在其中写入训练过程的逻辑。

  1.       def train_epoch(self):

  2.        """

  3.       implement the logic of epoch:

  4.       -loop ever the number of iteration in the config and call teh train step

  5.       -add any summaries you want using the sammary

  6.        """

  7.        pass

  8.    def train_step(self):

  9.        """

  10.       implement the logic of the train step

  11.       - run the tensorflow session

  12.       - return any metrics you need to summarize

  13.       """

  14.        pass

在主文件中创建会话,创建以下对象:「Model」、「Logger」、「Data_Generator」、「Trainer」与配置:

  1.      sess = tf.Session()

  2.    # create instance of the model you want

  3.    model = VGGModel(config)

  4.    # create your data generator

  5.    data = DataGenerator(config)

  6.    # create tensorboard logger

  7.    logger = Logger(sess, config)

向所有这些对象传递训练器对象,通过调用「trainer.train()」开始训练。

  1.       trainer = VGGTrainer(sess, model, data, config, logger)

  2.    # here you train your model

  3.    trainer.train()

你会看到模板文件、一个示例模型和训练文件夹,向你展示如何快速开始你的第一个模型。

详述


模型架构



文件夹结构


  1.       ├──  base

  2.   ├── base_model.py   - this file contains the abstract class of the model.

  3.   └── ease_train.py - this file contains the abstract class of the trainer.

  4. ├── model               -This folder contains any model of your project.

  5.   └── example_model.py

  6. ├── trainer             -this folder contains trainers of your project.

  7.   └── example_trainer.py

  8.  

  9. ├──  mains              - here's the main/s of your project (you may need more than one main.

  10. │                        

  11. │  

  12. ├──  data _loader  

  13. │    └── data_generator.py  - here's the data_generator that responsible for all data handling.

  14. └── utils

  15.     ├── logger.py

  16.     └── any_other_utils_you_need

主要组件


模型

  • 基础模型

基础模型是一个必须由你所创建的模型继承的抽象类,其背后的思路是:绝大多数模型之间都有很多东西是可以共享的。基础模型包含:

  • Save-此函数可保存 checkpoint 至桌面。

  • Load-此函数可加载桌面上的 checkpoint。

  • Cur-epoch、Global_step counters-这些变量会跟踪训练 epoch 和全局步。

  • Init_Saver-一个抽象函数,用于初始化保存和加载 checkpoint 的操作,注意:请在要实现的模型中覆盖此函数。

  • Build_model-是一个定义模型的抽象函数,注意:请在要实现的模型中覆盖此函数。

  • 你的模型

以下是你在模型中执行的地方。因此,你应该:

  • 创建你的模型类并继承 base_model 类。

  • 覆写 "build_model",在其中写入你想要的 tensorflow 模型。

  • 覆写"init_save",在其中你创建 tensorflow 保存器,以用它保存和加载检查点。

  • 在 initalizer 中调用"build_model" 和 "init_saver"

训练器

  • 基础训练器

基础训练器(Base trainer)是一个只包装训练过程的抽象的类。

  • 你的训练器

以下是你应该在训练器中执行的。

  • 创建你的训练器类,并继承 base_trainer 类。

  • 覆写这两个函数,在其中你执行每一步和每一 epoch 的训练过程。

数据加载器

这些类负责所有的数据操作和处理,并提供一个可被训练器使用的易用接口。

记录器(Logger)

这个类负责 tensorboard 总结。在你的训练器中创建一个有关所有你想要的 tensorflow 变量的词典,并将其传递给 logger.summarize()。

配置

我使用 Json 作为配置方法,接着解析它,因此写入所有你想要的配置,然后用"utils/config/process_config"解析它,并把这个配置对象传递给所有其他对象。

Main

以下是你整合的所有之前的部分。

1. 解析配置文件。

2. 创建一个 TensorFlow 会话。

3. 创建 "Model"、"Data_Generator" 和 "Logger"实例,并解析所有它们的配置。

4. 创建一个"Trainer"实例,并把之前所有的对象传递给它。

5. 现在你可通过调用"Trainer.train()"训练你的模型。

未来工作


未来,该项目计划通过新的 TensorFlow 数据集 API 替代数据加载器。

工程
暂无评论
暂无评论~
返回顶部