train.py
# coding: utf-8
from __future__ import division, print_function
import tensorflow as tf
from tensorflow.python.framework import graph_util
import numpy as np
import argparse
import logging
from utils.data_utils import parse_data
from utils.misc_utils import parse_anchors, read_class_names, shuffle_and_overwrite, update_dict, make_summary, config_learning_rate, config_optimizer, list_add
from utils.eval_utils import evaluate_on_cpu, evaluate_on_gpu
from utils.nms_utils import gpu_nms
from model.yolov3 import yolov3
from model.yolov3_tiny import yolov3_tiny
net_name = 'yolov3_tiny'
anchors_name = 'yolov3_tiny'
body_name = 'darknet19'
data_name = 'hardhat'
restore_path = ''
is_train_scratch = True
#################
# ArgumentParser
#################
parser = argparse.ArgumentParser(description="%s training procedure."%net_name)
# some paths
parser.add_argument("--train_file", type=str, default="./data/my_data/train.txt",
help="The path of the training txt file.")
parser.add_argument("--val_file", type=str, default="./data/my_data/test.txt",
help="The path of the validation txt file.")
parser.add_argument("--restore_path", type=str, default="/home/rashid/YOLOv3_tiny/checkpoint/yolov3_tiny_hardhat/model-step_6000_loss_1.212937_lr_0.0008153726",
help="The path of the weights to restore.")
parser.add_argument("--save_dir", type=str, default="./checkpoint/%s_%s/"%(net_name,data_name),
help="The directory of the weights to save.")
parser.add_argument("--log_dir", type=str, default="./data/logs/%s_%s/"%(net_name,data_name),
help="The directory to store the tensorboard log files.")
parser.add_argument("--progress_log_path", type=str, default="./data/%s_progress.log"%data_name,
help="The path to record the training progress.")
parser.add_argument("--anchor_path", type=str, default="./data/tiny_yolo_anchors.txt",
help="The path of the anchor txt file.")
parser.add_argument("--class_name_path", type=str, default="./data/data.names",
help="The path of the class names.")
# some numbers
parser.add_argument("--batch_size", type=int, default=8,
help="The batch size for training.")
parser.add_argument("--img_size", nargs='*', type=int, default=[416, 416],
help="Resize the input image to `img_size`, size format: [width, height]")
parser.add_argument("--total_epoches", type=int, default=100,
help="Total epoches to train.")
parser.add_argument("--train_evaluation_freq", type=int, default=500,
help="Evaluate on the training batch after some steps.")
parser.add_argument("--print_evaluation_freq", type=int, default=50,
help="Print the result of evaluate on the training batch after some steps.")
parser.add_argument("--val_evaluation_freq", type=int, default=2000,
help="Evaluate on the whole validation dataset after some steps.")
parser.add_argument("--save_freq", type=int, default=2000,
help="Save the model after some steps.")
parser.add_argument("--num_threads", type=int, default=4,
help="Number of threads for image processing used in tf.data pipeline.")
parser.add_argument("--prefetech_buffer", type=int, default=3,
help="Prefetech_buffer used in tf.data pipeline.")
# learning rate and optimizer
parser.add_argument("--optimizer_name", type=str, default='adam',
help="The optimizer name. Chosen from [sgd, momentum, adam, rmsprop]")
parser.add_argument("--save_optimizer", type=lambda x: (str(x).lower() == 'true'), default=False,
help="Whether to save the optimizer parameters into the checkpoint file.")
parser.add_argument("--learning_rate_init", type=float, default=1.0e-3,
help="The initial learning rate.")
parser.add_argument("--lr_type", type=str, default='exponential',
help="The learning rate type. Chosen from [fixed, exponential]")
parser.add_argument("--lr_decay_freq", type=int, default=1000,
help="The learning rate decay frequency. Used when chosen exponential lr_type.")
parser.add_argument("--lr_decay_factor", type=float, default=0.96,
help="The learning rate decay factor. Used when chosen exponential lr_type.")
parser.add_argument("--lr_lower_bound", type=float, default=1e-8,
help="The minimum learning rate. Used when chosen exponential lr type.")
# finetune
parser.add_argument("--restore_part", nargs='*', type=str, default=['%s/%s_body'%(net_name,body_name)],
help="Partially restore part of the model for finetuning. Set [None] to restore the whole model.")
parser.add_argument("--update_part", nargs='*', type=str, default=['%s/%s_head'%(net_name,net_name)],
help="Partially restore part of the model for finetuning. Set [None] to train the whole model.")
# warm up strategy
parser.add_argument("--use_warm_up", type=lambda x: (str(x).lower() == 'true'), default=True,
help="Whether to use warm up strategy.")
parser.add_argument("--warm_up_lr", type=float, default=5e-5,
help="Warm up learning rate.")
parser.add_argument("--warm_up_epoch", type=int, default=1,
help="Warm up training epoches.")
args = parser.parse_args()
# args params
args.anchors = parse_anchors(args.anchor_path)
args.classes = read_class_names(args.class_name_path)
args.class_num = len(args.classes)
with open(args.train_file, 'r') as f:
    args.train_img_cnt = len(f.readlines())
with open(args.val_file, 'r') as f:
    args.val_img_cnt = len(f.readlines())
args.train_batch_num = int(np.floor(float(args.train_img_cnt) / args.batch_size))
args.val_batch_num = int(np.floor(float(args.val_img_cnt) / args.batch_size))
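# note: np.floor drops any trailing partial batch, so each epoch iterates over
# train_batch_num / val_batch_num full batches and skips the remainder.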
# setting loggers
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s',
datefmt='%a, %d %b %Y %H:%M:%S', filename=args.progress_log_path, filemode='w')
# setting placeholders
is_training = tf.placeholder(dtype=tf.bool, name="phase_train")
handle_flag = tf.placeholder(tf.string, [], name='iterator_handle_flag')
##################
# tf.data pipeline
##################
# Use a `feedable iterator` so we can switch between the training and validation datasets
# through a single handle placeholder. The train txt file is shuffled manually because
# tf.data's shuffle is very slow on large files.
shuffle_and_overwrite(args.train_file)
train_dataset = tf.data.TextLineDataset(args.train_file)
train_dataset = train_dataset.apply(tf.data.experimental.map_and_batch(
lambda x: tf.py_func(parse_data, [x, args.class_num, args.img_size, args.anchors, 'train'], [tf.float32, tf.float32, tf.float32, tf.float32]),
num_parallel_calls=args.num_threads, batch_size=args.batch_size))
train_dataset = train_dataset.prefetch(args.prefetech_buffer)
val_dataset = tf.data.TextLineDataset(args.val_file)
val_dataset = val_dataset.apply(tf.data.experimental.map_and_batch(
lambda x: tf.py_func(parse_data, [x, args.class_num, args.img_size, args.anchors, 'val'], [tf.float32, tf.float32, tf.float32, tf.float32]),
num_parallel_calls=args.num_threads, batch_size=args.batch_size))
val_dataset = val_dataset.prefetch(args.prefetech_buffer)
# creating two dataset iterators
train_iterator = train_dataset.make_initializable_iterator()
val_iterator = val_dataset.make_initializable_iterator()
# creating two dataset handles
train_handle = train_iterator.string_handle()
val_handle = val_iterator.string_handle()
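# string_handle() returns a scalar string tensor identifying each iterator; the handle values
# are fetched once in the session and fed into `handle_flag` to pick a dataset at run time,
# so we can switch between training and validation data without rebuilding the graph.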
# select a specific iterator based on the passed handle
dataset_iterator = tf.data.Iterator.from_string_handle(handle_flag, train_dataset.output_types,
train_dataset.output_shapes)
# get an element from the chosen dataset iterator
image, y_true_13, y_true_26, y_true_52 = dataset_iterator.get_next()
if net_name == 'yolov3':
y_true = [y_true_13, y_true_26, y_true_52]
else:
y_true = [y_true_13, y_true_26]
# tf.data pipeline will lose the data shape, so we need to set it manually
image.set_shape([None, args.img_size[1], args.img_size[0], 3])
for y in y_true:
y.set_shape([None, None, None, None, None])
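# In the usual YOLOv3 ground-truth layout each y_true tensor is rank 5:
# [batch, grid_h, grid_w, anchors_per_scale, 5 + num_classes] (box xywh, objectness, one-hot class).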
##################
# Model definition
##################
# define yolo-v3 model here
yolo_model = yolov3(args.class_num, args.anchors) if net_name == 'yolov3' else yolov3_tiny(args.class_num, args.anchors)
# name the input tensor for the exported .pb graph
image = tf.identity(image, name='inputs')
with tf.variable_scope(net_name):
pred_feature_maps = yolo_model.forward(image, is_training=is_training)
loss = yolo_model.compute_loss(pred_feature_maps, y_true)
y_pred = yolo_model.predict(pred_feature_maps)
# name the output tensors for the exported .pb graph
boxes, confs, probs = y_pred
boxes = tf.identity(boxes, name='boxes')
confs = tf.identity(confs, name='confs')
probs = tf.identity(probs, name='probs')
################
# register the gpu nms operation here for the following evaluation scheme
pred_boxes_flag = tf.placeholder(tf.float32, [1, None, None])
pred_scores_flag = tf.placeholder(tf.float32, [1, None, None])
gpu_nms_op = gpu_nms(pred_boxes_flag, pred_scores_flag, args.class_num)
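# pred_boxes_flag / pred_scores_flag are fed with one image's predicted boxes and per-class
# scores during evaluation (roughly [1, num_boxes, 4] and [1, num_boxes, num_classes]).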
################
# train the whole model from scratch
if is_train_scratch:
args.restore_part = ['None']
args.update_part = ['None']
if args.restore_part == ['None']:
args.restore_part = [None]
if args.update_part == ['None']:
args.update_part = [None]
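# The string 'None' from the command line maps to Python None here; passing [None] to
# get_variables_to_restore() matches every variable, i.e. the whole model is restored / updated.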
saver_to_restore = tf.train.Saver(var_list=tf.contrib.framework.get_variables_to_restore(include=args.restore_part))
update_vars = tf.contrib.framework.get_variables_to_restore(include=args.update_part)
tf.summary.scalar('train_batch_statistics/total_loss', loss[0])
tf.summary.scalar('train_batch_statistics/loss_xy', loss[1])
tf.summary.scalar('train_batch_statistics/loss_wh', loss[2])
tf.summary.scalar('train_batch_statistics/loss_conf', loss[3])
tf.summary.scalar('train_batch_statistics/loss_class', loss[4])
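# keeping global_step in LOCAL_VARIABLES means the default Saver does not save or restore it,
# so the step counter restarts at 0 for each run.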
global_step = tf.Variable(0, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
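# Warm-up: hold the learning rate at warm_up_lr for the first warm_up_epoch epoch(s), then switch
# to the configured schedule with the step counter offset by the warm-up steps.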
if args.use_warm_up:
learning_rate = tf.cond(tf.less(global_step, args.train_batch_num * args.warm_up_epoch),
lambda: args.warm_up_lr, lambda: config_learning_rate(args, global_step - args.train_batch_num * args.warm_up_epoch))
else:
learning_rate = config_learning_rate(args, global_step)
tf.summary.scalar('learning_rate', learning_rate)
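# The Saver only captures variables that exist when it is constructed: building it before the
# optimizer (save_optimizer == False) leaves the optimizer's slot variables (e.g. Adam moments)
# out of the checkpoint; building it afterwards includes them.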
if not args.save_optimizer:
saver_to_save = tf.train.Saver()
optimizer = config_optimizer(args.optimizer_name, learning_rate)
if args.save_optimizer:
saver_to_save = tf.train.Saver()
# set dependencies for BN ops
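# (batch-norm moving mean/variance updates live in tf.GraphKeys.UPDATE_OPS and are not part of
# the loss, so they must be run together with the train op via control dependencies)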
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss[0], var_list=update_vars, global_step=global_step)
with tf.Session() as sess:
sess.run([tf.global_variables_initializer(), tf.local_variables_initializer(), train_iterator.initializer])
train_handle_value, val_handle_value = sess.run([train_handle, val_handle])
if args.restore_path:
saver_to_restore.restore(sess, args.restore_path)
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter(args.log_dir, sess.graph)
print('\n----------- start to train -----------\n')
for epoch in range(args.total_epoches):
for i in range(args.train_batch_num):
_, summary, y_pred_, y_true_, loss_, global_step_, lr = sess.run([train_op, merged, y_pred, y_true, loss, global_step, learning_rate],
feed_dict={is_training: True, handle_flag: train_handle_value})
writer.add_summary(summary, global_step=global_step_)
# print the evaluation
if global_step_ % args.print_evaluation_freq == 0 and global_step_ > 0:
info = "Epoch: {}, global_step: {}, lr: {:.8f}, total_loss: {:.3f}, loss_xy: {:.3f}, loss_wh: {:.3f}, loss_conf: {:.3f}, loss_class: {:.3f}".format(
epoch, global_step_, lr, loss_[0], loss_[1], loss_[2], loss_[3], loss_[4])
print(info)
logging.info(info)
# evaluation on the training batch
if global_step_ % args.train_evaluation_freq == 0 and global_step_ > 0:
# recall, precision = evaluate_on_cpu(y_pred_, y_true_, args.class_num, calc_now=True)
recall, precision = evaluate_on_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, y_pred_, y_true_, args.class_num, calc_now=True)
info = "===> batch recall: {:.3f}, batch precision: {:.3f} <===".format(recall, precision)
print(info)
logging.info(info)
writer.add_summary(make_summary('evaluation/train_batch_recall', recall), global_step=global_step_)
writer.add_summary(make_summary('evaluation/train_batch_precision', precision), global_step=global_step_)
# start to save
# NOTE: this is just a demo; set your own conditions for when to save the weights.
if global_step_ % args.save_freq == 0 and global_step_ > 0:
if loss_[0] <= 10. or (epoch > 20 and global_step_ % (args.save_freq * 4) == 0):
saver_to_save.save(sess, args.save_dir + 'model-step_{}_loss_{:4f}_lr_{:.7g}'.format(global_step_, loss_[0], lr))
# switch to validation dataset for evaluation
if global_step_ % args.val_evaluation_freq == 0 and global_step_ > 0:
sess.run(val_iterator.initializer)
true_positive_dict, true_labels_dict, pred_labels_dict = {}, {}, {}
val_loss = [0., 0., 0., 0., 0.]
for j in range(args.val_batch_num):
y_pred_, y_true_, loss_ = sess.run([y_pred, y_true, loss],
feed_dict={is_training: False, handle_flag: val_handle_value})
true_positive_dict_tmp, true_labels_dict_tmp, pred_labels_dict_tmp = \
evaluate_on_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag,
y_pred_, y_true_, args.class_num, calc_now=False)
true_positive_dict = update_dict(true_positive_dict, true_positive_dict_tmp)
true_labels_dict = update_dict(true_labels_dict, true_labels_dict_tmp)
pred_labels_dict = update_dict(pred_labels_dict, pred_labels_dict_tmp)
val_loss = list_add(val_loss, loss_)
# the 1e-6 term guards against division by zero when there are no
# ground-truth or predicted objects
recall = float(sum(true_positive_dict.values())) / (sum(true_labels_dict.values()) + 1e-6)
precision = float(sum(true_positive_dict.values())) / (sum(pred_labels_dict.values()) + 1e-6)
info = "===> Epoch: {}, global_step: {}, recall: {:.3f}, precision: {:.3f}, total_loss: {:.3f}, loss_xy: {:.3f}, loss_wh: {:.3f}, loss_conf: {:.3f}, loss_class: {:.3f}".format(
epoch, global_step_, recall, precision, val_loss[0] / args.val_batch_num, val_loss[1] / args.val_batch_num, val_loss[2] / args.val_batch_num, val_loss[3] / args.val_batch_num, val_loss[4] / args.val_batch_num)
print(info)
logging.info(info)
writer.add_summary(make_summary('evaluation/val_recall', recall), global_step=epoch)
writer.add_summary(make_summary('evaluation/val_precision', precision), global_step=epoch)
writer.add_summary(make_summary('validation_statistics/total_loss', val_loss[0] / args.val_batch_num), global_step=epoch)
writer.add_summary(make_summary('validation_statistics/loss_xy', val_loss[1] / args.val_batch_num), global_step=epoch)
writer.add_summary(make_summary('validation_statistics/loss_wh', val_loss[2] / args.val_batch_num), global_step=epoch)
writer.add_summary(make_summary('validation_statistics/loss_conf', val_loss[3] / args.val_batch_num), global_step=epoch)
writer.add_summary(make_summary('validation_statistics/loss_class', val_loss[4] / args.val_batch_num), global_step=epoch)
# manually shuffle the training data in a new epoch
shuffle_and_overwrite(args.train_file)
sess.run(train_iterator.initializer)
# save the frozen .pb graph; "boxes", "confs" and "probs" are the output nodes
print('saving the pb ...')
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['boxes', 'confs','probs'])
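# convert_variables_to_constants freezes the current weights into the graph as constants,
# so model.pb can later be loaded for inference without a separate checkpoint.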
with tf.gfile.FastGFile(args.save_dir+'model.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())