-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnanodet_model_infer.py
130 lines (116 loc) · 5.34 KB
/
nanodet_model_infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import cv2
import os
import time
import torch
import argparse
from nanodet.util import cfg, load_config, Logger
from nanodet.model.arch import build_model
from nanodet.util import load_model_weight
from nanodet.data.transform import Pipeline
from nanodet.util.path import mkdir
class Predictor(object):
    """Wrap a NanoDet model for single-image inference and result drawing.

    Args:
        cfg: NanoDet config node; ``cfg.model`` and ``cfg.data.val`` are read.
        model_path: Checkpoint path passed to ``torch.load``.
        logger: Logger forwarded to ``load_model_weight``.
        device: Torch device string the model and input tensors are moved to.
    """

    def __init__(self, cfg, model_path, logger, device='cuda:0'):
        self.cfg = cfg
        self.device = device
        model = build_model(cfg.model)
        # Map all storages to CPU while loading; the model is moved to
        # ``device`` once weights are in place.
        ckpt = torch.load(model_path, map_location=lambda storage, loc: storage)
        load_model_weight(model, ckpt, logger)
        if cfg.model.arch.backbone.name == 'RepVGG':
            import copy
            from nanodet.model.backbone.repvgg import repvgg_det_model_convert
            # Deep-copy so switching the backbone to deploy mode does not
            # mutate the caller's cfg (otherwise a later build_model(cfg.model)
            # would silently produce a deploy-mode model).
            deploy_config = copy.deepcopy(cfg.model)
            deploy_config.arch.backbone.update({'deploy': True})
            deploy_model = build_model(deploy_config)
            model = repvgg_det_model_convert(model, deploy_model)
        self.model = model.to(device).eval()
        self.pipeline = Pipeline(cfg.data.val.pipeline, cfg.data.val.keep_ratio)

    def inference(self, img):
        """Run the model on one image.

        Args:
            img: Either a file path (read with ``cv2.imread``) or an already
                loaded image array whose first two dims are (height, width).

        Returns:
            ``(meta, results)``: ``meta`` is the preprocessed sample dict
            (``img_info``, ``raw_img``, and ``img`` as a batched tensor on
            ``self.device``); ``results`` is whatever ``model.inference``
            returns.

        Raises:
            FileNotFoundError: If ``img`` is a path that OpenCV cannot read.
        """
        img_info = {'id': 0}
        if isinstance(img, str):
            path = img
            img_info['file_name'] = os.path.basename(path)
            img = cv2.imread(path)
            if img is None:
                # cv2.imread reports failure by returning None, not raising.
                raise FileNotFoundError('could not read image: {}'.format(path))
        else:
            img_info['file_name'] = None
        height, width = img.shape[:2]
        img_info['height'] = height
        img_info['width'] = width
        meta = dict(img_info=img_info,
                    raw_img=img,
                    img=img)
        meta = self.pipeline(meta, self.cfg.data.val.input_size)
        # Assumes the pipeline output is HWC (hence transpose to CHW) — then
        # add a batch dimension of 1.
        meta['img'] = torch.from_numpy(meta['img'].transpose(2, 0, 1)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            results = self.model.inference(meta)
        return meta, results

    def visualize(self, dets, meta, class_names, score_thres, wait=0):
        """Draw ``dets`` on ``meta['raw_img']`` and return the drawn image.

        ``wait`` is accepted for API compatibility but is not used.
        """
        start = time.perf_counter()
        result_img = self.model.head.show_result_new(meta['raw_img'], dets, class_names, score_thres=score_thres, show=False)
        print('viz time: {:.3f}s'.format(time.perf_counter() - start))
        return result_img
def get_image_list(path, image_ext=('.jpg', '.jpeg', '.webp', '.bmp', '.png')):
    """Recursively collect image file paths under ``path``.

    Args:
        path: Root directory to walk (a missing directory yields ``[]``).
        image_ext: Extensions, with leading dot, treated as images. Matching
            is case-insensitive, so ``photo.JPG`` is included.

    Returns:
        List of matching file paths in ``os.walk`` order.
    """
    # The original referenced an undefined global ``image_ext`` (NameError);
    # the allowed extensions are now an explicit, overridable parameter.
    exts = {ext.lower() for ext in image_ext}
    image_names = []
    for maindir, _subdirs, file_names in os.walk(path):
        for filename in file_names:
            if os.path.splitext(filename)[1].lower() in exts:
                image_names.append(os.path.join(maindir, filename))
    return image_names
# ### Load model
# Uncomment and adjust the config/checkpoint paths below to create the
# module-level `predictor` that get_realtime_infer() relies on.
# local_rank = 0
# logger = Logger(local_rank, use_tensorboard=False)
# load_config(cfg, "/home/ashish/Documents/DeepLearning/nanodet/config/nanodet-m.yml")
# predictor = Predictor(cfg, "/home/ashish/Documents/DeepLearning/nanodet/pytorch-model/nanodet_m.ckpt", logger, device='cpu')
def get_realtime_infer(image_np, predictor=None, class_names=None, score_thres=0.35):
    """Run detection on one frame and return the annotated image.

    Args:
        image_np: Image array (or image path) forwarded to
            ``predictor.inference``.
        predictor: Predictor to use. Defaults to a module-level ``predictor``
            if one has been defined (see the commented "load model" section).
        class_names: Class labels passed to ``predictor.visualize``; defaults
            to ``cfg.class_names``.
        score_thres: Score threshold forwarded to ``predictor.visualize``.

    Returns:
        The image returned by ``predictor.visualize``.

    Raises:
        RuntimeError: If no predictor is passed and no module-level
            ``predictor`` exists.
    """
    if predictor is None:
        # The module-level setup is commented out by default, so look the
        # global up explicitly instead of failing with a bare NameError.
        predictor = globals().get('predictor')
        if predictor is None:
            raise RuntimeError(
                'get_realtime_infer: no predictor available; pass one '
                'explicitly or create a module-level `predictor`.')
    if class_names is None:
        class_names = cfg.class_names
    meta, res = predictor.inference(image_np)
    return predictor.visualize(res[0], meta, class_names, score_thres)
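# NOTE: main() below is disabled. Before re-enabling it, define parse_args()
# (it is not defined in this file) and release `cap` / `vid_writer` after the
# video loop.
# def main():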
# args = parse_args()
# local_rank = 0
# # torch.backends.cudnn.enabled = True
# # torch.backends.cudnn.benchmark = True
# load_config(cfg, args.config)
# logger = Logger(local_rank, use_tensorboard=False)
# predictor = Predictor(cfg, args.model, logger, device='cpu')
# logger.log('Press "Esc", "q" or "Q" to exit.')
# current_time = time.localtime()
# if args.demo == 'image':
# if os.path.isdir(args.path):
# files = get_image_list(args.path)
# else:
# files = [args.path]
# files.sort()
# for image_name in files:
# meta, res = predictor.inference(image_name)
# result_image = predictor.visualize(res[0], meta, cfg.class_names, 0.35)
# if args.save_result:
# save_folder = os.path.join(cfg.save_dir, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
# mkdir(local_rank, save_folder)
# save_file_name = os.path.join(save_folder, os.path.basename(image_name))
# cv2.imwrite(save_file_name, result_image)
# ch = cv2.waitKey(0)
# if ch == 27 or ch == ord('q') or ch == ord('Q'):
# break
# elif args.demo == 'video' or args.demo == 'webcam':
# cap = cv2.VideoCapture(args.path if args.demo == 'video' else args.camid)
# width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) # float
# height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) # float
# fps = cap.get(cv2.CAP_PROP_FPS)
# save_folder = os.path.join(cfg.save_dir, time.strftime("%Y_%m_%d_%H_%M_%S", current_time))
# mkdir(local_rank, save_folder)
# save_path = os.path.join(save_folder, args.path.split('/')[-1]) if args.demo == 'video' else os.path.join(save_folder, 'camera.mp4')
# print(f'save_path is {save_path}')
# vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height)))
# while True:
# ret_val, frame = cap.read()
# if ret_val:
# meta, res = predictor.inference(frame)
# result_frame = predictor.visualize(res[0], meta, cfg.class_names, 0.35)
# if args.save_result:
# vid_writer.write(result_frame)
# ch = cv2.waitKey(1)
# if ch == 27 or ch == ord('q') or ch == ord('Q'):
# break
# else:
# break