# convert_weights_pb.py
# Forked from luxonis/yolo2openvino.
"""
Tool used to convert Darknet model weights (.weights + .cfg) to TensorFlow model weights (.pb).
Converts example pretrained (COCO) yolov4-tiny by default.
"""
import tensorflow as tf
from models import yolo_v3, yolo_v3_tiny, yolo_v4, yolo_v4_tiny
from utils.anchors import Anchors
from utils.utils import load_weights, load_names, detections_boxes, freeze_graph
# Command-line flags for the converter.
#
# The flags are reached through the ``tf.compat.v1`` namespace, matching every
# other TensorFlow call in this file (placeholder, variable_scope, Session,
# app.run). The bare ``tf.app`` module is not available under TensorFlow 2.x.
flags = tf.compat.v1.app.flags
FLAGS = flags.FLAGS

# Model selection.
flags.DEFINE_integer('yolo', 4, 'Yolo version. Possible options: 3 or 4.')
flags.DEFINE_bool('tiny', True, 'Use tiny version of Yolo')

# Input / output files.
flags.DEFINE_string('class_names', './examples/coco.names', 'File with class names. Usually a file ending in .names.')
flags.DEFINE_string('weights_file', './examples/yolov4-tiny.weights', 'Binary file with Yolo\' weights')
flags.DEFINE_string('output', './examples/yolov4-tiny.pb', 'Frozen tensorflow protobuf model output path')

# Graph layout and input shape. Either --size (square input) or both
# --height and --width describe the network input.
flags.DEFINE_string('data_format', 'NHWC', 'Data format: NCHW (gpu only) / NHWC')
flags.DEFINE_integer('size', 416, 'Image size. If both, height and width, are not provided, a square input shape of size as set with this flag will be used')
flags.DEFINE_integer('height', None, 'Input image height. If height is set, width must also be set. The size flag will be ignored.', short_name='h')
flags.DEFINE_integer('width', None, 'Input image width. If width is set, height must also be set. The size flag will be ignored.', short_name='w')

# Optional anchor override, given as a flat comma-separated list.
flags.DEFINE_list('anchors', None, 'List of anchors. If not set default anchors for YoloV3, YoloV4, and YoloV3/V4-tiny will be set.', short_name='a')
def _select_model(yolo_version, tiny):
    """Return ``(model_fn, default_anchors)`` for the requested Yolo variant.

    Raises:
        ValueError: if ``yolo_version`` is neither 3 nor 4.
    """
    if yolo_version == 3:
        if tiny:
            return yolo_v3_tiny.yolo_v3_tiny, Anchors.YOLOV3TINY.value
        return yolo_v3.yolo_v3, Anchors.YOLOV3.value
    if yolo_version == 4:
        if tiny:
            return yolo_v4_tiny.yolo_v4_tiny, Anchors.YOLOV4TINY.value
        return yolo_v4.yolo_v4, Anchors.YOLOV4.value
    raise ValueError(f"{yolo_version} is not supported Yolo version. Supported versions: 3, 4.")


def _pair_anchors(flat_anchors):
    """Group a flat ``[a0, b0, a1, b1, ...]`` sequence into ``(a, b)`` tuples.

    Raises:
        ValueError: if the sequence has an odd number of values, since the
            trailing value would otherwise be dropped silently.
    """
    if len(flat_anchors) % 2 != 0:
        raise ValueError(
            f"Anchors must be given as pairs of values, got {len(flat_anchors)} values."
        )
    return [(flat_anchors[i], flat_anchors[i + 1]) for i in range(0, len(flat_anchors), 2)]


def _resolve_input_shape(height, width, size):
    """Return the ``(height, width)`` of the network input.

    An explicit height/width pair takes precedence over ``size``; ``size``
    alone yields a square input.

    Raises:
        ValueError: if only one of ``height``/``width`` is given, or if none
            of the three values is set.
    """
    if height is not None and width is not None:
        return height, width
    if height is not None or width is not None:
        # Falling back to ``size`` here would silently ignore the one
        # dimension the user did provide.
        raise ValueError("Height and width flags must be set together. Please specify both!")
    if size is not None:
        return size, size
    raise ValueError("Neither size nor width and height flags are set. Please specify input shape!")


def main(argv=None):
    """Build the selected Yolo graph, load the weights and write a frozen .pb.

    All settings are read from the module-level ``FLAGS``.

    Args:
        argv: Unused. Present because the TensorFlow app runner invokes
            ``main`` with the remaining command-line arguments.

    Raises:
        ValueError: if the Yolo version is unsupported, an anchor value is not
            an integer, the anchors do not form pairs, only one of
            height/width is set, or no input shape is specified.
    """
    model, default_anchors = _select_model(FLAGS.yolo, FLAGS.tiny)

    # --anchors arrives as a list of strings; the defaults are used as-is.
    if FLAGS.anchors is None:
        flat_anchors = default_anchors
    else:
        flat_anchors = [int(a) for a in FLAGS.anchors]
    anchors = _pair_anchors(flat_anchors)

    classes = load_names(FLAGS.class_names)

    # Set input shape.
    height, width = _resolve_input_shape(FLAGS.height, FLAGS.width, FLAGS.size)
    if FLAGS.height is not None and FLAGS.size is not None:
        print("Width and height are set, size flag will be ignored!")
    # NOTE(review): the placeholder is always laid out as
    # [batch, height, width, channels], regardless of --data_format;
    # presumably the model handles NCHW internally -- confirm.
    inputs = tf.compat.v1.placeholder(tf.float32, [None, height, width, 3], "inputs")

    with tf.compat.v1.variable_scope('detector'):
        detections = model(inputs, len(classes), anchors, data_format=FLAGS.data_format)
        load_ops = load_weights(tf.compat.v1.global_variables(scope='detector'), FLAGS.weights_file)

    # Sets the output nodes in the current session. The return value is not
    # needed here; the call is kept for its effect on the graph.
    detections_boxes(detections)

    print("Starting conversion with the following parameters:")
    print(f"Yolo version: {FLAGS.yolo}")
    print(f"Anchors: {anchors}")
    print(f"Shape: {inputs}")
    print(f"Classes: {classes}")

    with tf.compat.v1.Session() as sess:
        # Run the ops returned by load_weights before freezing, so the
        # frozen graph is written after the weights have been loaded.
        sess.run(load_ops)
        freeze_graph(sess, FLAGS.output)
if __name__ == '__main__':
    # Let the TensorFlow app runner parse the command-line flags and then
    # invoke main().
    tf.compat.v1.app.run(main=main)