-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathretinanet_quanteval.py
306 lines (256 loc) · 11 KB
/
retinanet_quanteval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
#!/usr/bin/env python3
# pylint: disable=E0401,E1101,W0621,R0915,R0914,R0912
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2022 of Qualcomm Innovation Center, Inc. All rights reserved.
#
# @@-COPYRIGHT-END-@@
# =============================================================================
"""Quantsim evaluation script for retinanet"""
# pylint:disable = import-error, wrong-import-order
# adding this due to docker image not setup yet
from glob import glob
import urllib.request
import argparse
import os
import progressbar
from tqdm import tqdm
import tensorflow as tf
from aimet_tensorflow.quantsim import save_checkpoint, load_checkpoint
from aimet_tensorflow.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow import quantsim
from keras_retinanet import models
from keras_retinanet.utils.coco_eval import evaluate_coco
from keras_retinanet.utils.image import read_image_bgr, preprocess_image, resize_image
from keras import backend as K
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Keras RetinaNet
# AIMET
def download_weights():
    """
    Download the pretrained RetinaNet weights and the AIMET quantsim config file.

    Both files are saved in the current working directory and are fetched only
    when not already present. Each download is written to a temporary ``.part``
    file and renamed into place once complete, so an interrupted transfer cannot
    leave a truncated file that the existence check would later accept as done.
    :return: None
    """
    downloads = (
        (
            "resnet50_coco_best_v2.1.0.h5",
            "https://github.com/fizyr/keras-retinanet/releases/download/0.5.1/resnet50_coco_best_v2.1.0.h5",
        ),
        # Config file
        (
            "default_config.json",
            "https://raw.githubusercontent.com/quic/aimet/release-aimet-1.22/TrainingExtensions/common/src/python/aimet_common/quantsim_config/default_config.json",
        ),
    )
    for filename, url in downloads:
        if os.path.exists(filename):
            continue
        partial_path = filename + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            # Atomic rename: the final name only appears once the file is complete.
            os.replace(partial_path, filename)
        finally:
            # Only still present if the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
# pylint: disable=W0613
# pylint: disable=W0612
def _remove_previous_run_files():
    """Delete ``model*`` / ``checkpoint*`` files left in the working directory by a prior run.

    Matching directories are skipped, because ``os.remove`` cannot delete them.
    """
    for path in glob("model*") + glob("checkpoint*"):
        if os.path.isfile(path):
            os.remove(path)


def _make_calibration_forward_pass(cocopath, in_tensor, out_tensor, num_images=10):
    """
    Build the ``forward_pass(session, args)`` callback passed to ``compute_encodings``.

    :param cocopath: top-level COCO directory; images are read from ``images/train2017``
    :param in_tensor: name of the graph input tensor that is fed each image
    :param out_tensor: names of the graph output tensors that are run
    :param num_images: maximum number of calibration images to run
    :return: the callback
    :raises ValueError: (from the callback) if no ``.jpg`` images are found
    """
    image_dir = os.path.join(cocopath, "images", "train2017")

    def forward_pass(session, args):
        # Sorted so the calibration subset does not depend on directory listing order.
        image_paths = sorted(glob(os.path.join(image_dir, "*.jpg")))
        if not image_paths:
            raise ValueError("No calibration images found under " + image_dir)
        for image_path in tqdm(image_paths[:num_images]):
            image = preprocess_image(read_image_bgr(image_path))
            image, _ = resize_image(image)
            session.run(out_tensor, feed_dict={in_tensor: [image]})

    return forward_pass


def _build_quantsim(session, input_ops, output_ops):
    """Wrap ``session`` in an AIMET QuantizationSimModel configured by default_config.json."""
    return quantsim.QuantizationSimModel(
        session,
        input_ops,
        output_ops,
        config_file="default_config.json",
    )


def quantize_retinanet(model_path, cocopath, action):
    """
    Build, optionally quantize, and checkpoint one RetinaNet variant for later evaluation.

    Loads the Keras RetinaNet model, retrieves the Keras backend TF session and,
    depending on ``action``, saves one of:
      * original_fp32  - the unmodified graph (plain TF checkpoint)
      * original_int8  - an AIMET quantsim checkpoint of the unmodified graph
      * optimized_fp32 - batch-norm-folded quantsim session saved with a plain TF saver
      * optimized_int8 - an AIMET quantsim checkpoint of the batch-norm-folded graph
    Quantsim encodings are calibrated on up to 10 COCO train2017 images.
    Output is written to ``./<action>/model.ckpt``.
    :param model_path: directory containing resnet50_coco_best_v2.1.0.h5
    :param cocopath: path to the top-level COCO dataset
    :param action: one of original_fp32, original_int8, optimized_fp32, optimized_int8
    :return: None
    :raises ValueError: if ``action`` is not one of the four variants, or if no
        calibration images are found for a quantsim variant
    """
    model_path = os.path.join(model_path, "resnet50_coco_best_v2.1.0.h5")
    # Loading builds the model graph that the Keras backend session below runs;
    # the returned model object itself is not needed afterwards.
    models.load_model(model_path, backbone_name="resnet50")

    # Clean weights from the prior run to avoid mismatch errors.
    _remove_previous_run_files()
    for dir_name in ("original_fp32", "original_int8", "optimized_fp32", "optimized_int8"):
        os.makedirs(dir_name, exist_ok=True)

    # AIMET APIs need a TF session, so retrieve the one from the Keras backend.
    session = K.get_session()

    in_tensor = "input_1:0"
    out_tensor = [
        "filtered_detections/map/TensorArrayStack/TensorArrayGatherV3:0",
        "filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3:0",
        "filtered_detections/map/TensorArrayStack_2/TensorArrayGatherV3:0",
    ]
    # Op name without the ":0" output index.
    input_ops = [in_tensor.partition(":")[0]]
    # P3/BiasAdd .. P7/BiasAdd are used as the output ops for BN folding and quantsim.
    selected_ops = ["P" + str(i) + "/BiasAdd" for i in range(3, 8)]
    forward_pass = _make_calibration_forward_pass(cocopath, in_tensor, out_tensor)

    if action == "original_fp32":
        saver = tf.train.Saver()
        saver.save(session, "./original_fp32/model.ckpt")
    elif action == "original_int8":
        sim = _build_quantsim(session, input_ops, selected_ops)
        sim.compute_encodings(forward_pass, None)
        save_checkpoint(sim, "./original_int8/model.ckpt", "model")
    elif action == "optimized_fp32":
        session, _ = fold_all_batch_norms(session, input_ops, selected_ops)
        sim = _build_quantsim(session, input_ops, selected_ops)
        sim.compute_encodings(forward_pass, None)
        # sim.session may hold a different graph from the Keras default graph, so
        # build the saver inside that graph so it captures sim.session's variables.
        with sim.session.graph.as_default():
            saver = tf.train.Saver()
        saver.save(sim.session, "./optimized_fp32/model.ckpt")
    elif action == "optimized_int8":
        session, _ = fold_all_batch_norms(session, input_ops, selected_ops)
        # Same quantsim config file as every other quantized variant.
        sim = _build_quantsim(session, input_ops, selected_ops)
        sim.compute_encodings(forward_pass, None)
        save_checkpoint(sim, "./optimized_int8/model.ckpt", "model")
    else:
        raise ValueError(
            "--action must be one of: original_fp32, original_int8, optimized_fp32, optimized_int8"
        )
# Fail fast with an actionable message if the installed ``progressbar`` module is
# not progressbar2, detected here by the absence of a callable
# ``progressbar.progressbar``. An explicit raise is used instead of ``assert``
# so the check still runs under ``python -O``.
if not callable(getattr(progressbar, "progressbar", None)):
    raise ImportError(
        "Using wrong progressbar module, install 'progressbar2' instead."
    )
def evaluate(generator, action, threshold=0.05):
    """
    Evaluate a checkpointed RetinaNet variant on the COCO validation set.

    Loads the checkpoint that ``quantize_retinanet`` wrote to ``./<action>/model.ckpt``
    and runs it through keras-retinanet's ``evaluate_coco``. FP32 variants are
    restored with ``import_meta_graph`` into a fresh graph; INT8 variants are
    restored with AIMET's ``load_checkpoint``.
    :param generator: generator for validation dataset
    :param action: one of original_fp32, original_int8, optimized_fp32, optimized_int8
    :param threshold: score threshold forwarded to ``evaluate_coco``
    :return: whatever ``evaluate_coco`` returns
    :raises ValueError: if ``action`` is not one of the four variants
    """
    in_tensor = "input_1:0"
    out_tensor = [
        "filtered_detections/map/TensorArrayStack/TensorArrayGatherV3:0",
        "filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3:0",
        "filtered_detections/map/TensorArrayStack_2/TensorArrayGatherV3:0",
    ]
    if action in ("original_fp32", "optimized_fp32"):
        checkpoint = "./" + action + "/model.ckpt"
        # Import into a fresh graph so the restored nodes keep their exact names
        # (e.g. "input_1:0") instead of possibly being renamed to avoid the Keras
        # model that quantize_retinanet built in the default graph.
        graph = tf.Graph()
        with graph.as_default(), tf.Session(graph=graph) as new_sess:
            saver = tf.train.import_meta_graph(checkpoint + ".meta")
            saver.restore(new_sess, checkpoint)
            model = TFRunWrapper(new_sess, in_tensor, out_tensor)
            return evaluate_coco(generator, model, threshold)
    if action in ("original_int8", "optimized_int8"):
        new_quantsim = load_checkpoint("./" + action + "/model.ckpt", "model")
        try:
            model = TFRunWrapper(new_quantsim.session, in_tensor, out_tensor)
            return evaluate_coco(generator, model, threshold)
        finally:
            # The loaded quantsim owns its session; release it after evaluation.
            new_quantsim.session.close()
    raise ValueError(
        "--action must be one of: original_fp32, original_int8, optimized_fp32, optimized_int8"
    )
def create_generator(args, preprocess_image):
    """
    Build the COCO ``val2017`` data generator used for evaluation.

    :param args: configuration object exposing ``dataset_path``,
        ``image_min_side`` and ``image_max_side``
    :param preprocess_image: input preprocessing function handed to the generator
    :return: a ``CocoGenerator`` over the ``val2017`` split with group shuffling disabled
    """
    # pylint:disable = import-outside-toplevel
    from keras_retinanet.preprocessing.coco import CocoGenerator

    return CocoGenerator(
        args.dataset_path,
        "val2017",
        preprocess_image=preprocess_image,
        image_min_side=args.image_min_side,
        image_max_side=args.image_max_side,
        shuffle_groups=False,
    )
def parse_args(args):
    """
    Parse the command-line arguments of the evaluation script.

    :param args: argument list to parse (``None`` means ``sys.argv[1:]``)
    :return: ``argparse.Namespace`` with ``dataset_path`` and ``action``
    """
    parser = argparse.ArgumentParser(
        description="Evaluation script for a RetinaNet network."
    )
    parser.add_argument(
        "--dataset-path",
        # Needed by every action: the validation generator is always built from it.
        required=True,
        help="Path to the COCO dataset directory (e.g. /tmp/COCO).",
    )
    parser.add_argument(
        "--action",
        help="model variant to quantize and evaluate",
        # argparse does not check defaults against ``choices``, so the default
        # must itself be a valid variant or it only fails later at runtime.
        default="optimized_int8",
        # A list keeps the --help output in a stable order.
        choices=["original_fp32", "original_int8", "optimized_fp32", "optimized_int8"],
    )
    return parser.parse_args(args)
class TFRunWrapper:
    """Adapter giving a raw TF session the ``predict_on_batch`` method of a Keras model.

    keras-retinanet's ``evaluate_coco`` expects a model object it can call
    ``predict_on_batch`` on; this wrapper forwards that call to ``session.run``,
    feeding the batch into the input tensor and fetching the output tensors.
    """

    def __init__(self, tf_session, in_tensor, out_tensor):
        # tf_session: session holding the graph to execute
        # in_tensor:  name of the tensor that receives the batch
        # out_tensor: tensor name(s) fetched on every call
        self.sess, self.in_tensor, self.out_tensor = tf_session, in_tensor, out_tensor

    def predict_on_batch(self, input_name):
        """Run the wrapped graph on the batch ``input_name`` and return the fetched values."""
        feeds = {self.in_tensor: input_name}
        return self.sess.run(self.out_tensor, feed_dict=feeds)
class ModelConfig:
    """Evaluation settings: fixed defaults, overridden by any same-named parsed argument."""

    # Baseline values applied before the parsed command-line arguments are copied in.
    _DEFAULTS = {
        "model_path": "./",
        "score_threshold": 0.05,
        "iou_threshold": 0.5,
        "max_detections": 100,
        "image_min_side": 800,
        "image_max_side": 1333,
        "quantsim_config_file": "default_config.json",
    }

    def __init__(self, args):
        merged = dict(self._DEFAULTS)
        # Every attribute of ``args`` becomes an attribute here, winning over defaults.
        merged.update(vars(args))
        for name, value in merged.items():
            setattr(self, name, value)
def main(args=None):
    """Script entry point: parse arguments, fetch assets, build the requested variant, evaluate it.

    :param args: optional argument list (defaults to the process command line)
    """
    config = ModelConfig(parse_args(args))
    download_weights()
    preprocess = models.backbone("resnet50").preprocess_image
    generator = create_generator(config, preprocess)
    quantize_retinanet(config.model_path, config.dataset_path, config.action)
    evaluate(generator, config.action, config.score_threshold)


if __name__ == "__main__":
    main()