diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py new file mode 100644 index 0000000..1b438dc --- /dev/null +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. + +import collections +import functools +import threading +from collections.abc import Sequence +from typing import Any, Optional + +import tf2_ros +from message_filters import SimpleFilter +from rclpy.duration import Duration +from rclpy.task import Future +from rclpy.time import Time + +from bdai_ros2_wrappers.logging import RcutilsLogger + + +class TransformFilter(SimpleFilter): + """A :mod:`tf2_ros` driven message filter, ensuring user defined transforms' availability. + + This filter passes a stamped transform along with filtered messages, from message frame ID + to the given target frame ID, looked up at the message timestamp. The message is assumed + to have a header. When filtering message tuples, only the first message header is observed. + """ + + def __init__( + self, + f: SimpleFilter, + target_frame_id: str, + tf_buffer: tf2_ros.Buffer, + tolerance_sec: float, + logger: Optional[RcutilsLogger] = None, + ) -> None: + """Initializes the transform filter. + + Args: + f: the upstream message filter. + target_frame_id: the target frame ID for transforms. + tf_buffer: a buffer of transforms to look up. + tolerance_sec: a tolerance, in seconds, to wait for late transforms + before abandoning any waits and filtering out the corresponding messages. + logger: an optional logger to notify the yser about any errors during filtering. + """ + super().__init__() + self._logger = logger + self._lock = threading.Lock() + self._waitqueue: collections.deque = collections.deque() + self._ongoing_wait: Optional[Future] = None + self._ongoing_wait_time: Optional[Time] = None + self.target_frame_id = target_frame_id + self.tf_buffer = tf_buffer + self.tolerance = Duration(seconds=tolerance_sec) + self.incoming_connection = f.registerCallback(self.add) + + def _wait_callback(self, messages: Sequence[Any], future: Future) -> None: + if future.cancelled(): + return + with self._lock: + try: + if future.result() is True: + source_frame_id = messages[0].header.frame_id + time = Time.from_msg(messages[0].header.stamp) + transform = self.tf_buffer.lookup_transform( + self.target_frame_id, + source_frame_id, + time, + ) + self.signalMessage(*messages, transform) + except tf2_ros.TransformException as e: + if self._logger is not None: + self._logger.error( + ( + "Got an exception during transform lookup: %s", + str(e), + ), + ) + + if self._waitqueue: + messages = self._waitqueue.popleft() + source_frame_id = messages[0].header.frame_id + time = Time.from_msg(messages[0].header.stamp) + self._ongoing_wait = self.tf_buffer.wait_for_transform_async( + self.target_frame_id, + source_frame_id, + time, + ) + self._ongoing_wait_time = time + self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages)) + else: + self._ongoing_wait_time = None + self._ongoing_wait = None + + def add(self, *messages: Any) -> None: + """Adds new `messages` to the filter.""" + with self._lock: + time = Time.from_msg(messages[0].header.stamp) + if self._ongoing_wait and not self._ongoing_wait.done() and time - self._ongoing_wait_time > self.tolerance: + self._ongoing_wait.cancel() + self._ongoing_wait = None + while self._waitqueue: + pending_messages = self._waitqueue[0] + pending_time = Time.from_msg(pending_messages[0].header.stamp) + if time - pending_time <= self.tolerance: + break + self._waitqueue.popleft() + self._waitqueue.append(messages) + if not self._ongoing_wait: + messages = self._waitqueue.popleft() + source_frame_id = messages[0].header.frame_id + time = Time.from_msg(messages[0].header.stamp) + self._ongoing_wait = self.tf_buffer.wait_for_transform_async( + self.target_frame_id, + source_frame_id, + time, + ) + self._ongoing_wait_time = time + self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages)) diff --git a/bdai_ros2_wrappers/test/test_filters.py b/bdai_ros2_wrappers/test/test_filters.py new file mode 100644 index 0000000..afa8fe6 --- /dev/null +++ b/bdai_ros2_wrappers/test/test_filters.py @@ -0,0 +1,73 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. + +from typing import List, Tuple + +import tf2_ros +from geometry_msgs.msg import PoseStamped, TransformStamped +from message_filters import SimpleFilter + +from bdai_ros2_wrappers.filters import TransformFilter + + +def test_transform_wait() -> None: + source = SimpleFilter() + tf_buffer = tf2_ros.Buffer() + tf_filter = TransformFilter(source, "map", tf_buffer, tolerance_sec=1.0) + sink: List[Tuple[PoseStamped, TransformStamped]] = [] + tf_filter.registerCallback(lambda *msgs: sink.append(msgs)) + assert len(sink) == 0 + pose_message = PoseStamped() + pose_message.header.frame_id = "odom" + pose_message.header.stamp.sec = 1 + pose_message.pose.position.x = 1.0 + pose_message.pose.orientation.w = 1.0 + source.signalMessage(pose_message) + assert len(sink) == 0 + transform_message = TransformStamped() + transform_message.header.frame_id = "map" + transform_message.child_frame_id = "odom" + transform_message.transform.rotation.w = 1.0 + tf_buffer.set_transform_static(transform_message, "pytest") + assert len(sink) == 1 + filtered_pose_message, filtered_transform_message = sink.pop() + assert filtered_pose_message.header.frame_id == pose_message.header.frame_id + assert filtered_pose_message.pose.position.x == pose_message.pose.position.x + assert filtered_pose_message.pose.orientation.w == pose_message.pose.orientation.w + assert filtered_transform_message.header.frame_id == transform_message.header.frame_id + assert filtered_transform_message.child_frame_id == transform_message.child_frame_id + assert filtered_transform_message.transform.rotation.w == transform_message.transform.rotation.w + + +def test_old_transform_filtering() -> None: + source = SimpleFilter() + tf_buffer = tf2_ros.Buffer() + tf_filter = TransformFilter(source, "map", tf_buffer, tolerance_sec=2.0) + sink: List[Tuple[PoseStamped, TransformStamped]] = [] + tf_filter.registerCallback(lambda *msgs: sink.append(msgs)) + assert len(sink) == 0 + first_pose_message = PoseStamped() + first_pose_message.header.frame_id = "odom" + first_pose_message.pose.position.x = 1.0 + first_pose_message.pose.orientation.w = 1.0 + source.signalMessage(first_pose_message) + assert len(sink) == 0 + second_pose_message = PoseStamped() + second_pose_message.header.frame_id = "odom" + second_pose_message.header.stamp.sec = 10 + second_pose_message.pose.position.x = 2.0 + second_pose_message.pose.orientation.w = 1.0 + source.signalMessage(second_pose_message) + assert len(sink) == 0 + transform_message = TransformStamped() + transform_message.header.frame_id = "map" + transform_message.child_frame_id = "odom" + transform_message.transform.rotation.w = 1.0 + tf_buffer.set_transform_static(transform_message, "pytest") + assert len(sink) == 1 + filtered_pose_message, filtered_transform_message = sink.pop() + assert filtered_pose_message.header.frame_id == second_pose_message.header.frame_id + assert filtered_pose_message.pose.position.x == second_pose_message.pose.position.x + assert filtered_pose_message.pose.orientation.w == second_pose_message.pose.orientation.w + assert filtered_transform_message.header.frame_id == transform_message.header.frame_id + assert filtered_transform_message.child_frame_id == transform_message.child_frame_id + assert filtered_transform_message.transform.rotation.w == transform_message.transform.rotation.w