本模型使用了全卷积网络解决服装关键点定位问题,给出了基于tensorflow的代码模型。
tensorflow1.3.0, opencv2.4, python3.5, numpy 1.13.3
用于调整模型的所有参数都放在'config.cgf'里面。
training_txt_file : 存放图片信息的csv文件路径
img_directory : 图片存放路径
img_size : 图片的大小
hm_size : 热图的大小
num_joints : 关键点的数量
joint_list: 关键点列表
name : 训练模型名
nFeats: 卷积层中特征的数量
nStacks: stack的数量
nModules : 模型数量
nLow : 下采样的数
dropout_rate : 训练结尾神经元舍弃的比例
batch_size : batch的大小
nEpochs : epoch的大小
epoch_size : 每个epoch中循环次数
learning_rate: 学习率
learning_rate_decay: 学习率的衰减速度(一般设为(0.0-0.99))
decay_step : 学习率衰减次数
valid_iteration : 用于validation的数量
log_dir_test : 测试文件的路径
log_dir_train : 训练文件的路径
saver_step : 写入训练文件的步长
saver_directory:保存训练模型的路径
运行train/train_launcher.py,其中初始学习率默认设为0.00025,训练过程中,每隔3000步,学习率衰减百分比0.96
$ cd train/
$ python train_launcher.py
首先将训练生成的model放到test文件夹下, 运行test/test_launcher.py进行测试,结果保存为csv文件
$ cd test/
$ python test_launcher.py
所有的训练数据在train/data文件夹下,测试数据在test/data文件夹下
测试过程的中间文件都保存在log_dir文件夹下
Project
train
data
Images
Annotations
model.py
mydatagen.py
train_launcher.py
myconfig.cfg
test
data
Images
model.py
predictClass.py
myconfig.cfg
test_launcher.py
使用Hourglass-network作为基础, 将每个特征点做二维高斯处理,得到一个64X64的heatmap,与网络输出结果进行比较, 根据不同类型的服装,设置不同的损失函数,并进行训练和优化