forked from rh-aiservices-bu/object-detection-rest
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathobject_rendering.py
49 lines (40 loc) · 1.33 KB
/
object_rendering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import random
from PIL import Image, ImageDraw, ImageFont
from matplotlib import pyplot as plt
import numpy as np
def draw_boxes(image_path, model_output, scaling, padding, class_labels):
    """Draw detection boxes and labels on an image and show it with Matplotlib.

    Each detection's coordinates are mapped back from the model's input space
    to the original image: the padding is subtracted first, then the result
    is divided by ``scaling``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        model_output: Iterable of detections, each unpacking to
            ``(x0, y0, x1, y1, score, cls_id)``.
        scaling: Divisor applied to the coordinates. A scalar applies to all
            four coordinates; a 2-element ``(sx, sy)`` is repeated to
            ``(sx, sy, sx, sy)``; a 4-element sequence is used as-is.
        padding: Offset subtracted from the coordinates, normally
            ``(pad_x, pad_y)``, which is repeated to
            ``(pad_x, pad_y, pad_x, pad_y)``. A 4-element sequence is used
            as-is; a scalar is subtracted from every coordinate.
        class_labels: Sequence of class names indexed by ``int(cls_id)``.

    Returns:
        None. The annotated image is drawn on the current Matplotlib figure.
    """
    # Context manager closes the underlying file once the RGB copy exists
    # (convert() returns a new, fully loaded image).
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", size=16)
    except OSError:
        # Arial is frequently unavailable (e.g. minimal Linux containers).
        font = ImageFont.load_default()

    # One random outline colour per class name.
    colors = {
        name: tuple(random.randint(0, 255) for _ in range(3))
        for name in class_labels
    }

    # Loop-invariant: expand padding/scaling to one value per coordinate.
    # np.resize repeats its (flattened) input, so (px, py) -> (px, py, px, py).
    pad = np.resize(np.asarray(padding, dtype=np.float64), 4)
    scale = np.resize(np.asarray(scaling, dtype=np.float64), 4)

    for x0, y0, x1, y1, score, cls_id in model_output:
        # Float dtype so the arithmetic works even for integer model outputs.
        box = np.array([x0, y0, x1, y1], dtype=np.float64)
        box = (box - pad) / scale
        box = box.round().astype(np.int32).tolist()

        name = class_labels[int(cls_id)]
        label = f"{name} {round(float(score), 3)}"

        draw.rectangle(box, outline=colors[name], width=2)
        # Put the label just above the box, but never above the image top.
        draw.text(
            (box[0], max(box[1] - 15, 0)),
            label,
            fill=(0, 255, 0),
            font=font,
        )

    # Convert to an array for display with Matplotlib on the current figure.
    fig = plt.gcf()
    fig.set_size_inches(24, 12)
    plt.axis('off')
    plt.imshow(np.array(image))