From b1b14306550324786b97ca0baf06ff619e40693e Mon Sep 17 00:00:00 2001
From: Kalana Ratnayake
Date: Thu, 12 Sep 2024 17:20:12 +1000
Subject: [PATCH] added yolo into the node

---
Reviewer notes (ignored by git-am, not part of the commit message):

 * Layout: this patch had lost its line breaks. Line structure, blank
   context lines and indentation were reconstructed so that every hunk
   matches its "@@" line counts. Verify against the tree, or apply with
   whitespace tolerance (git apply --ignore-whitespace).
 * sync_callback: bbx_size_w / bbx_size_h were filled with the centre x
   coordinate; they now receive the box width / height.
 * sync_callback: self.tracking_msg is created once in __init__ and was
   only ever appended to, so each published message also contained the
   tracks of all previous frames. The per-frame arrays are now cleared
   before being filled.
 * sync_callback: a frame without detections produced a 1-D empty array,
   and the fallback had 5 columns although the tracker input is described
   as (x, y, x, y, conf, cls). Both paths now yield an (N, 6) array.
 * Renamed the misspelled attribute "synchornizer" to "synchronizer".
 * Hunk 5/6 line counts and the diffstat were updated for the 17 added
   lines; the "index" blob hash for node.py therefore no longer matches
   the post-image.
 * TODO(review): the unchanged tail of image_callback is not part of this
   patch and could not be inspected - check it for the same width/height
   and message-reset issues.

 boxmot_ros/node.py | 174 +++++++++++++++++++++++++++++++++++++++------
 requirements.txt   |   3 +-
 2 files changed, 157 insertions(+), 20 deletions(-)

diff --git a/boxmot_ros/node.py b/boxmot_ros/node.py
index b186de7..f5a5c25 100644
--- a/boxmot_ros/node.py
+++ b/boxmot_ros/node.py
@@ -4,14 +4,18 @@
 
 from pathlib import Path
 
+from message_filters import Subscriber
+from message_filters import ApproximateTimeSynchronizer
+
 from rclpy.node import Node
 from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy
 
-from sensor_msgs.msg import Image
+from sensor_msgs.msg import Image, PointCloud2
 from detection_msgs.msg import Detections
 
 from cv_bridge import CvBridge
 
+from ultralytics import YOLO
 from boxmot import BoTSORT, DeepOCSORT, OCSORT, HybridSORT, BYTETracker, StrongSORT
 
 class BoxmotROS(Node):
@@ -19,18 +23,24 @@ class BoxmotROS(Node):
     def __init__(self):
         super().__init__('boxmot_ros')
 
+        self.declare_parameter("yolo_model", "yolov8n.pt")
         self.declare_parameter("tracking_model", "deepocsort")
         self.declare_parameter("reid_model", "osnet_x0_25_msmt17.pt")
-        self.declare_parameter("input_topic", "/input")
+        self.declare_parameter("input_rgb_topic", "/camera/color/image_raw")
+        self.declare_parameter("input_depth_topic", "/camera/depth/points")
+        self.declare_parameter("subscribe_depth", False)
         self.declare_parameter("publish_annotated_image", False)
         self.declare_parameter("annotated_topic", "/boxmot_ros/annotated_image")
         self.declare_parameter("detailed_topic", "/boxmot_ros/tracking_result")
         self.declare_parameter("threshold", 0.25)
         self.declare_parameter("device", "cpu")
 
+        self.yolo_model = self.get_parameter("yolo_model").get_parameter_value().string_value
         self.tracking_model = self.get_parameter("tracking_model").get_parameter_value().string_value
         self.reid_model = self.get_parameter("reid_model").get_parameter_value().string_value
-        self.input_topic = self.get_parameter("input_topic").get_parameter_value().string_value
+        self.input_rgb_topic = self.get_parameter("input_rgb_topic").get_parameter_value().string_value
+        self.input_depth_topic = self.get_parameter("input_depth_topic").get_parameter_value().string_value
+        self.subscribe_depth = self.get_parameter("subscribe_depth").get_parameter_value().bool_value
         self.publish_annotated_image = self.get_parameter("publish_annotated_image").get_parameter_value().bool_value
         self.annotated_topic = self.get_parameter("annotated_topic").get_parameter_value().string_value
         self.detailed_topic = self.get_parameter("detailed_topic").get_parameter_value().string_value
@@ -39,6 +49,9 @@ def __init__(self):
 
         self.bridge = CvBridge()
 
+        self.model = YOLO(self.yolo_model)
+        self.model.fuse()
+
         if self.tracking_model == "deepocsort":
             self.tracker = DeepOCSORT( model_weights=Path(self.reid_model), # which ReID model to use
                                         device=self.device,
@@ -73,7 +86,15 @@ def __init__(self):
             history=QoSHistoryPolicy.KEEP_LAST,
             depth=1
         )
-        self.subscription = self.create_subscription(Detections, self.input_topic, self.image_callback, qos_profile=self.subscriber_qos_profile)
+        if self.subscribe_depth:
+            self.rgb_message_filter = Subscriber(self, Image, self.input_rgb_topic, qos_profile=self.subscriber_qos_profile)
+            self.depth_message_filter = Subscriber(self, PointCloud2, self.input_depth_topic, qos_profile=self.subscriber_qos_profile)
+
+            self.synchronizer = ApproximateTimeSynchronizer([self.rgb_message_filter, self.depth_message_filter], 10, 1)
+            self.synchronizer.registerCallback(self.sync_callback)
+
+        else:
+            self.subscription = self.create_subscription(Image, self.input_rgb_topic, self.image_callback, qos_profile=self.subscriber_qos_profile)
 
         self.publisher_results = self.create_publisher(Detections, self.detailed_topic, 10)
 
@@ -85,32 +106,149 @@ def __init__(self):
 
         self.time = 0
         self.tracking_msg = Detections()
+        self.class_list_set = False
 
-    def image_callback(self, received_msg):
+
+    def sync_callback(self, rgb_msg, depth_msg):
+        """Detect and track objects in a time-synchronised RGB / point-cloud pair.
+
+        The point cloud is not processed here; it is only forwarded in the
+        published message as ``source_depth``.
+        """
         start = time.time_ns()
 
-        self.input_image = self.bridge.imgmsg_to_cv2(received_msg.source_rgb, desired_encoding="bgr8")
+        self.input_image = self.bridge.imgmsg_to_cv2(rgb_msg, desired_encoding="bgr8")
+
+        self.result = self.model.predict(source = self.input_image,
+                                         conf=self.threshold,
+                                         device=self.device,
+                                         verbose=False)
+
+        if (not self.class_list_set) and (self.result is not None):
+            for i in range(len(self.result[0].names)):
+                self.tracking_msg.full_class_list.append(self.result[0].names.get(i))
+            self.class_list_set = True
 
-        if received_msg.detections:
+        if self.result is not None:
             detection_list = []
 
-            for i in range(len(received_msg.class_id)):
+            for bbox, cls, conf in zip(self.result[0].boxes.xywh, self.result[0].boxes.cls, self.result[0].boxes.conf):
                 detection = []
 
-                clid = received_msg.class_id[i]
-                conf = received_msg.confidence[i]
+                cx = int(bbox[0])
+                cy = int(bbox[1])
+                sw = int(bbox[2])
+                sh = int(bbox[3])
+
+                x1 = cx - (sw/2)
+                y1 = cy - (sh/2)
+                x2 = cx + (sw/2)
+                y2 = cy + (sh/2)
+
+                detection = [x1, y1, x2, y2, float(conf), int(cls)]
+
+                detection_list.append(detection)
+
+            # reshape keeps the 6-column layout even when detection_list is empty
+            detection_numpy = np.array(detection_list).reshape(-1, 6)
+        else:
+            detection_numpy = np.empty((0, 6))
+
+        # input is of shape (x, y, x, y, conf, cls)
+        # output is of shape (x, y, x, y, id, conf, cls, ind)
+
+        self.result_tracks = self.tracker.update(detection_numpy, self.input_image)
+
+        if self.result_tracks is not None:
+
+            self.tracking_msg.header = rgb_msg.header
+            self.tracking_msg.source_rgb = rgb_msg
+            self.tracking_msg.source_depth = depth_msg
+
+            # self.tracking_msg is reused for every frame, so drop the previous
+            # frame's tracks before appending the current ones.
+            self.tracking_msg.bbx_center_x = []
+            self.tracking_msg.bbx_center_y = []
+            self.tracking_msg.bbx_size_w = []
+            self.tracking_msg.bbx_size_h = []
+            self.tracking_msg.class_id = []
+            self.tracking_msg.tracking_id = []
+            self.tracking_msg.confidence = []
+
+            for track in self.result_tracks:
+
+                x1 = track[0].astype('int')
+                y1 = track[1].astype('int')
+                x2 = track[2].astype('int')
+                y2 = track[3].astype('int')
+
+                tracking_id = track[4].astype('int')
+                confidence = track[5].astype('float')
+                class_id = track[6].astype('int')
+
+                cx = (x2 + x1)/2
+                cy = (y2 + y1)/2
+                sw = x2 - x1
+                sh = y2 - y1
+
+                self.tracking_msg.bbx_center_x.append(int(cx))
+                self.tracking_msg.bbx_center_y.append(int(cy))
+                self.tracking_msg.bbx_size_w.append(int(sw))
+                self.tracking_msg.bbx_size_h.append(int(sh))
+                self.tracking_msg.class_id.append(class_id)
+                self.tracking_msg.tracking_id.append(tracking_id)
+                self.tracking_msg.confidence.append(confidence)
+
+            self.publisher_results.publish(self.tracking_msg)
+
+            if self.publish_annotated_image:
+                self.output_image = self.tracker.plot_results(self.input_image, show_trajectories=True)
+                result_msg = self.bridge.cv2_to_imgmsg(self.output_image, encoding="bgr8")
+
+                self.publisher_image.publish(result_msg)
+
+        self.counter += 1
+        self.time += time.time_ns() - start
+
+        if (self.counter == 100):
+            self.get_logger().info('Callback execution time for 100 loops: %d ms' % ((self.time/100)/1000000))
+            self.time = 0
+            self.counter = 0
+
+
+    def image_callback(self, rgb_image):
+        """RGB-only entry point: run YOLO on the image and build the tracker input."""
+        start = time.time_ns()
+
+        self.input_image = self.bridge.imgmsg_to_cv2(rgb_image, desired_encoding="bgr8")
+
+        self.result = self.model.predict(source = self.input_image,
+                                         conf=self.threshold,
+                                         device=self.device,
+                                         verbose=False)
+
+        if (not self.class_list_set) and (self.result is not None):
+            for i in range(len(self.result[0].names)):
+                self.tracking_msg.full_class_list.append(self.result[0].names.get(i))
+            self.class_list_set = True
+
+        if self.result is not None:
+            detection_list = []
+
+            for bbox, cls, conf in zip(self.result[0].boxes.xywh, self.result[0].boxes.cls, self.result[0].boxes.conf):
+                detection = []
 
-                cx = received_msg.bbx_center_x[i]
-                cy = received_msg.bbx_center_y[i]
-                sw = received_msg.bbx_size_w[i]
-                sh = received_msg.bbx_size_h[i]
+                cx = int(bbox[0])
+                cy = int(bbox[1])
+                sw = int(bbox[2])
+                sh = int(bbox[3])
 
                 x1 = cx - (sw/2)
                 y1 = cy - (sh/2)
                 x2 = cx + (sw/2)
                 y2 = cy + (sh/2)
 
-                detection = [x1, y1, x2, y2, conf, clid]
+                detection = [x1, y1, x2, y2, float(conf), int(cls)]
 
                 detection_list.append(detection)
 
@@ -125,10 +263,8 @@ def image_callback(self, received_msg):
 
         if self.result_tracks is not None:
 
-            self.tracking_msg.header = received_msg.header
-            self.tracking_msg.source_rgb = received_msg.source_rgb
-            self.tracking_msg.source_depth = received_msg.source_depth
-            self.tracking_msg.full_class_list = received_msg.full_class_list
+            self.tracking_msg.header = rgb_image.header
+            self.tracking_msg.source_rgb = rgb_image
 
             for track in self.result_tracks:
 
diff --git a/requirements.txt b/requirements.txt
index a6259bf..9a96d6d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,2 @@
-boxmot
\ No newline at end of file
+boxmot
+ultralytics
\ No newline at end of file