-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference_kineva.py
65 lines (54 loc) · 2.08 KB
/
inference_kineva.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# KINEVA INFERENCE DEMO 08.10.2024
# REBOTNIX GmbH, all rights reserved.
from rb_engine.detect import Detector
import cv2
import numpy as np
def assign_colors_to_classes(classes, seed=None):
    """Map every class name to a randomly drawn colour.

    Args:
        classes: Iterable of hashable class names. If a name occurs more
            than once, the colour drawn last for it wins.
        seed: Optional integer. When given, colours are drawn from a
            private ``RandomState(seed)`` and are therefore reproducible
            across runs. When ``None`` (the default) NumPy's global random
            state is used, so colours only stay fixed within one run.

    Returns:
        dict: ``{class_name: (c0, c1, c2)}`` where every channel is a plain
        Python ``int`` in ``0..255``, suitable to pass directly as an
        OpenCV colour. Colours are random, so two classes are not
        guaranteed to receive different colours.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    # randint's upper bound is exclusive: 256 is required to make 255 reachable.
    return {
        class_name: tuple(int(channel) for channel in rng.randint(0, 256, size=3))
        for class_name in classes
    }
# Demo pipeline: load the detector, run it on a single image, draw one
# labelled rectangle per detection and save the annotated image.

# Alias of the detector class; the model itself is built by the call below.
detector = Detector
# Load the model from the weight file.
detect = detector(
    weight_file="./model/kineva_person_head.best.pth",
    conf_thres=0.25,
    nms_thres=0.55,
    input_size=[640, 640],
    fuse=False,
    fp16=False,
    use_decoder=False,
)
# Class names known to the model, indexed by category id.
class_names = detect.class_names
# One colour per class name.
class_colors = assign_colors_to_classes(class_names)

# Load the input image. cv2.imread signals failure by returning None instead
# of raising, so check explicitly before handing the image to the detector.
image_path = "demos/d12.jpg"
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError("Could not read image: {}".format(image_path))

# Run detection on a batch containing the single image.
# NOTE(review): meaning of the second positional argument is defined by
# rb_engine's Detector.__call__ - confirm against its documentation.
results = detect([img], False)

# Text style shared by every label.
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.6
font_thickness = 2

# Draw rectangles and labels on the image (modified in place).
for res in results:
    if res is None:
        continue
    # Each row carries the box coordinates followed by two score values and
    # the category id; only the coordinates and the category are drawn.
    for *corners, _obj, _conf, cate in res:
        x1, y1, x2, y2 = (int(value) for value in corners[:4])
        catname = class_names[int(cate)]
        color = class_colors[catname]
        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

        text_width, text_height = cv2.getTextSize(
            catname, font, font_scale, font_thickness
        )[0]
        # Default: label baseline slightly above the rectangle. If the label
        # background would start above the image, move it just inside the
        # top edge of the rectangle so it is not clipped.
        label_y = y1 - 10
        if label_y - text_height - 5 < 0:
            label_y = y1 + text_height + 10
        label_position = (x1, label_y)

        # Filled background box with 5 px padding around the text.
        box_top_left = (x1 - 5, label_y - text_height - 5)
        box_bottom_right = (x1 + text_width + 5, label_y + 5)
        cv2.rectangle(img, box_top_left, box_bottom_right, color, thickness=cv2.FILLED)
        cv2.putText(
            img,
            catname,
            label_position,
            font,
            font_scale,
            (255, 255, 255),
            thickness=font_thickness,
        )

# Write the annotated image; cv2.imwrite reports failure via its return value.
output_path = "demos/result.jpg"
if not cv2.imwrite(output_path, img):
    raise OSError("Could not write image: {}".format(output_path))