- 开源不易,麻烦给个【Star】
- Github: https://github.com/PanJinquan/Pytorch-Base-Trainer
- pip安装包: https://pypi.org/project/basetrainer/
- 博客地址:https://panjinquan.blog.csdn.net/article/details/122662902
考虑到深度学习训练过程都有一套约定成俗的流程,鄙人借鉴Keras开发了一套基础训练库: Pytorch-Base-Trainer(PBT); 这是一个基于Pytorch开发的基础训练库,支持以下特征:
- [x] 支持多卡训练训练(DP模式)和分布式多卡训练(DDP模式),参考build_model_parallel
- [x] 支持argparse命令行指定参数,也支持config.yaml配置文件
- [x] 支持最优模型保存ModelCheckpoint
- [x] 支持自定义回调函数Callback
- [x] 支持NNI模型剪枝(L1/L2-Pruner,FPGM-Pruner Slim-Pruner)nni_pruning
- [x] 非常轻便,安装简单
诚然,诸多大公司已经开源基础库,如MMClassification,MMDetection等库; 但碍于这些开源库安装麻烦,依赖库多,版本差异大等问题;鄙人开发了一套比较基础的训练Pipeline: Pytorch-Base-Trainer(PBT), 基于PBT可以快速搭建自己的训练工程; 目前,基于PBT完成了通用分类库(PBTClassification),通用检测库(PBTDetection),通用语义分割库( PBTSegmentation)以及,通用姿态检测库(PBTPose)
基于PBT框架训练的模型,已经形成了一套完整的Android端上部署流程,支持CPU和GPU
PS:受商业保护,目前,仅开源Pytorch-Base-Trainer(PBT),基于PBT的分类,检测和分割以及姿态估计训练库,暂不开源。
- 源码安装
git clone https://github.com/PanJinquan/Pytorch-Base-Trainer
cd Pytorch-Base-Trainer
bash setup.sh #pip install dist/basetrainer-*.*.*.tar.gz
# 安装方法1:(有延时,可能不是最新版本)
pip install basetrainer
# 安装方法2:(从pypi源下载最新版本)
pip install --upgrade basetrainer -i https://pypi.org/simple
# Linux or macOS
python3 -m pip install --upgrade nni
# Windows
python -m pip install --upgrade nni
PBT基础训练库定义了一个基类(Base),所有训练引擎(Engine)以及回调函数(Callback)都会继承基类。
Engine
类实现了训练/测试的迭代方法(如on_batch_begin,on_batch_end),其迭代过程参考如下,
用户可以根据自己的需要自定义迭代过程:
self.on_train_begin()
for epoch in range(num_epochs):
self.set_model() # 设置模型
# 开始训练
self.on_epoch_begin() # 开始每个epoch调用
for inputs in self.train_dataset:
self.on_batch_begin() # 每次迭代开始时回调
self.run_step() # 每次迭代返回outputs, losses
self.on_train_summary() # 每次迭代,训练结束时回调
self.on_batch_end() # 每次迭代结束时回调
# 开始测试
self.on_test_begin()
for inputs in self.test_dataset:
self.run_step() # 每次迭代返回outputs, losses
self.on_test_summary() # 每次迭代,测试结束时回调
self.on_test_end() # 结束测试
# 结束当前epoch
self.on_epoch_end()
self.on_train_end()
EngineTrainer
类继承Engine
类,用户需要继承该类,并实现相关接口:
接口 | 说明 |
---|---|
build_train_loader | 定义训练数据 |
build_test_loader | 定义测试数据 |
build_model | 定义模型 |
build_optimizer | 定义优化器 |
build_criterion | 定义损失函数 |
build_callbacks | 定义回调函数 |
另外,EngineTrainer
类还是实现了两个重要的类方法(build_dataloader和build_model_parallel),用于构建分布式训练
每个回调函数都需要继承(Callback),用户在回调函数中,可实现对迭代方法输入/输出的处理,例如:
basetrainer
使用方法可以参考example.py,构建自己的训练器,可通过如下步骤实现:
- step1:
新建一个类
ClassificationTrainer
,继承trainer.EngineTrainer
- step2: 实现接口
def build_train_loader(self, cfg, **kwargs):
"""定义训练数据"""
raise NotImplementedError("build_train_loader not implemented!")
in_file, 'rst', format='md', outputfile="README.rst", encoding='utf-8')
def build_test_loader(self, cfg, **kwargs):
"""定义测试数据"""
raise NotImplementedError("build_test_loader not implemented!")
def build_model(self, cfg, **kwargs):
"""定于训练模型"""
raise NotImplementedError("build_model not implemented!")
def build_optimizer(self, cfg, **kwargs):
"""定义优化器"""
raise NotImplementedError("build_optimizer not implemented!")
def build_criterion(self, cfg, **kwargs):
"""定义损失函数"""
raise NotImplementedError("build_criterion not implemented!")
def build_callbacks(self, cfg, **kwargs):
"""定义回调函数"""
raise NotImplementedError("build_callbacks not implemented!")
- step3: 在初始化中调用
build
def __init__(self, cfg):
super(ClassificationTrainer, self).__init__(cfg)
...
self.build(cfg)
...
- step4:
实例化
ClassificationTrainer
,并使用launch
启动分布式训练
def main(cfg):
t = ClassificationTrainer(cfg)
return t.run()
if __name__ == "__main__":
parser = get_parser()
args = parser.parse_args()
cfg = setup_config.parser_config(args)
launch(main,
num_gpus_per_machine=len(cfg.gpu_id),
dist_url="tcp://127.0.0.1:28661",
num_machines=1,
machine_rank=0,
distributed=cfg.distributed,
args=(cfg,))
basetrainer
使用方法可以参考example.py
# 单进程多卡训练
python example.py --gpu_id 0 1 # 使用命令行参数
python example.py --config_file configs/config.yaml # 使用yaml配置文件
# 多进程多卡训练(分布式训练)
python example.py --config_file configs/config.yaml --distributed # 使用yaml配置文件
- 目标支持的backbone有:resnet[18,34,50,101], ,mobilenet_v2等,详见backbone等 ,其他backbone可以自定义添加
- 训练参数可以通过两种方法指定: (1) 通过argparse命令行指定 (2)通过`config.yaml <configs/config.yaml>`__配置文件,当存在同名参数时,以配置文件为默认值
- 学习率调整策略
scheduler | 说明 | lr-epoch曲线图 |
---|---|---|
multi_step | 阶梯学习率调整策略 | |
cosine | 余弦退火学习率调整策略 | |
ExpLR | 指数衰减学习率调整策略 | |
LambdaLR | Lambda学习率调整策略 |
目前训练过程可视化工具是使用Tensorboard,使用方法:
tensorboard --logdir=path/to/log/
作者 | PKing |
---|---|
联系方式 | [email protected] |