-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdetect.py
115 lines (94 loc) · 3.94 KB
/
detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""
@Author: Matheus Teixeira de Sousa ([email protected])
Detect forks in images using a trained YOLOv7 ONNX model
"""
import cv2 as cv
import numpy as np
import onnxruntime as ort
from torch.utils.data import DataLoader
from utils.dataset import TestDataset
from utils.utils import plot_one_box, show_predicted_image, adjust_image
from os.path import exists, isdir
from os import makedirs
from random import randint
import argparse
def predict_bbox(session, images):
    """
    Predict bounding boxes from images

    Runs the session once per element of `images` and collects the first
    output of each run together with the metadata carried by that element.

    Args:
        session: object exposing `get_inputs()`, `get_outputs()` and
            `run(output_names, feed)`; the script passes an
            `onnxruntime.InferenceSession`.
        images: iterable of mappings with the keys 'image', 'ratio', 'dwdh'
            and 'name'.

    Returns:
        dict mapping "batch {i}" (i = position in `images`) to a dict with
        'preds' (first output of the run) and the element's 'ratio', 'dwdh'
        and 'name' values passed through unchanged.
    """
    output_names = [node.name for node in session.get_outputs()]
    # Feed the model through the input name it declares instead of a
    # hard-coded 'images', so exports with a different input name also work.
    input_name = session.get_inputs()[0].name

    dict_output = {}
    for i, samples in enumerate(images):
        # Divide by 255 (presumably mapping 0-255 pixel values to 0-1 --
        # confirm against the dataset code) and hand the runtime a
        # C-contiguous array.
        im = np.ascontiguousarray(samples['image'] / 255)
        out = session.run(output_names, {input_name: im})[0]
        dict_output[f"batch {i}"] = {
            "preds": out,
            "ratio": samples['ratio'],
            "dwdh": samples['dwdh'],
            "name": samples['name'],
        }
    return dict_output
if __name__ == '__main__':
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Predict with YOLOv7-fork ONNX model')
    parser.add_argument('-m', '--model', required=True, metavar='/path/to/model.onnx', help="Path to ONNX model.")
    parser.add_argument('-i', '--input', required=True, metavar='/path/to/images/ or /path/to/image.jpg', help="Path to images or path to image.")
    # type=int: without it a value given on the command line stays a string
    # and is passed as such to DataLoader(batch_size=...).
    parser.add_argument('-b', '--batch', default=1, type=int, help="Batch size. (Default: 1)")
    parser.add_argument('-s', '--save', default=False, action='store_true', help="Save predicted image. (Default: False)")
    parser.add_argument('-d', '--dontshow', default=False, action='store_true', help="Don't show predicted image. (Default: False)")
    parser.add_argument('-c', '--cuda', default=False, action='store_true', help="Set execution on GPU. (Default: False)")
    args = parser.parse_args()

    # Echo the effective settings
    for key, value in vars(args).items():
        if value is not None:
            print(f'{key.capitalize()}: {value}')
    print()

    # Check if the input is a dir
    input_isdir = isdir(args.input)

    # Load the model
    print('Loading model...', flush=True)
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if args.cuda else ['CPUExecutionProvider']
    session = ort.InferenceSession(args.model, providers=providers)

    # Take the image size from the model's first input; entries 2 and 3 of
    # its shape are used as height and width.
    input_shape = session.get_inputs()[0].shape
    h, w = input_shape[2], input_shape[3]

    # Load the images
    print('Loading images...', flush=True)
    if input_isdir:
        dataset = TestDataset(args.input, shape=(h, w))
        images = DataLoader(dataset, batch_size=args.batch, shuffle=False, num_workers=0)
    else:
        images = [adjust_image(args.input, shape=(h, w))]

    # Predict from images
    print('Making predictions...', flush=True)
    dict_output = predict_bbox(session, images)

    names = ['fork']
    colors = {name: [randint(0, 255) for _ in range(3)] for name in names}
    # colors = {name: [104, 184, 82] for name in names} # green

    if args.save:
        # exist_ok avoids the check-then-create race of a separate exists() test
        makedirs('data/responses', exist_ok=True)

    # For each image, plot the results
    # NOTE(review): the loop reads only element 0 of 'ratio' and 'name' and
    # loads source image i, so it appears to assume one image per batch --
    # confirm before relying on --batch > 1.
    print('Plotting results...', flush=True)
    for i, entry in enumerate(dict_output.values()):
        pred = entry['preds']
        ratio = float(entry['ratio'][0])
        dwdh = float(entry['dwdh'][0]), float(entry['dwdh'][1])
        name = entry['name'][0]

        # Load original image
        if input_isdir:
            image = dataset.__getsrc__(i)
        else:
            image = cv.imread(args.input)

        # Offset subtracted from (x0, y0, x1, y1); the same for every box of
        # this image, so build it once.
        offset = np.array(dwdh * 2)

        # Adjust bounding box to original image
        for prediction in pred:
            batch_id, x0, y0, x1, y1, cls_id, score = prediction
            box = np.array([x0, y0, x1, y1])
            box -= offset
            box /= ratio
            box = box.round().astype(np.int32).tolist()

            cls_id = int(cls_id)
            score = round(float(score), 3)
            label = names[cls_id]
            color = colors[label]
            label += ' ' + str(score)
            plot_one_box(box, image, label=label, color=color, line_thickness=1)

        if args.save:
            path = 'data/responses/' + name
            cv.imwrite(path, image)
        if not args.dontshow:
            show_predicted_image(image)