diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py index b431869..3a4c865 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py @@ -1,12 +1,14 @@ # Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. -from typing import Any, Callable, Iterator, List, Optional +from typing import Any, Callable, Iterable, Iterator, List, Optional -from message_filters import SimpleFilter +import tf2_ros +from message_filters import ApproximateTimeSynchronizer, SimpleFilter from rclpy.node import Node from rclpy.task import Future import bdai_ros2_wrappers.scope as scope +from bdai_ros2_wrappers.filters import TransformFilter from bdai_ros2_wrappers.utilities import Tape @@ -34,7 +36,7 @@ def __init__( history_length = 1 self._link = link self._tape = Tape(history_length) - self._link.registerCallback(self._tape.write) + self._link.registerCallback(lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0])) node.context.on_shutdown(self._tape.close) @property @@ -50,7 +52,7 @@ def history(self) -> List[Any]: @property def latest(self) -> Optional[Any]: """Gets the latest message received, if any.""" - return next(self._tape.content(), None) + return self._tape.head @property def update(self) -> Future: @@ -99,3 +101,105 @@ def stream( def close(self) -> None: """Closes the message feed.""" self._tape.close() + + +class FramedMessageFeed(MessageFeed): + """A message feed decorator, incorporating transforms using a `TransformFilter` instance.""" + + def __init__( + self, + feed: MessageFeed, + target_frame_id: str, + *, + tolerance_sec: float = 1.0, + tf_buffer: Optional[tf2_ros.Buffer] = None, + history_length: Optional[int] = None, + node: Optional[Node] = None, + ) -> None: + """Initializes the message feed. + + Args: + feed: the upstream message feed. + target_frame_id: the target frame ID for transforms. + tf_buffer: optional buffer of transforms to look up. If none is provided + a transforms' buffer and a listener will be instantiated. + tolerance_sec: optional tolerance, in seconds, to wait for late transforms. + history_length: optional historic data size, defaults to 1. + node: optional node for the underlying native subscription, defaults to + the current process node. + """ + if node is None: + node = scope.ensure_node() + if tf_buffer is None: + tf_buffer = tf2_ros.Buffer() + self._tf_listener = tf2_ros.TransformListener(tf_buffer, node) + super().__init__( + TransformFilter( + feed.link, + target_frame_id, + tf_buffer, + tolerance_sec, + node.get_logger(), + ), + history_length=history_length, + node=node, + ) + self._feed = feed + + @property + def feed(self) -> MessageFeed: + """Gets the upstream message feed.""" + return self._feed + + def close(self) -> None: + """Closes this message feed and the upstream one as well.""" + self._feed.close() + super().close() + + +class SynchronizedMessageFeed(MessageFeed): + """A message feeds' aggregator using a `message_filters.ApproximateTimeSynchronizer` instance.""" + + def __init__( + self, + *feeds: MessageFeed, + queue_size: int = 10, + delay: float = 0.2, + allow_headerless: bool = False, + history_length: Optional[int] = None, + node: Optional[Node] = None, + ) -> None: + """Initializes the message feed. + + Args: + feeds: upstream message feeds to be synchronized. + queue_size: the message queue size for synchronization. + delay: the maximum delay, in seconds, between messages for synchronization to succeed. + allow_headerless: whether it's OK for there to be no header in the messages (in which + case, the ROS time of arrival will be used). + history_length: optional historic data size, defaults to 1. + node: optional node for the underlying native subscription, defaults to + the current process node. + """ + super().__init__( + ApproximateTimeSynchronizer( + [f.link for f in feeds], + queue_size, + delay, + allow_headerless=allow_headerless, + ), + history_length=history_length, + node=node, + ) + self._feeds = feeds + + @property + def feeds(self) -> Iterable[MessageFeed]: + """Gets all aggregated message feeds.""" + return self._feeds + + def close(self) -> None: + """Closes this message feed and all upstream ones as well.""" + for feed in self._feeds: + feed.close() + super().close() diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py index 1b438dc..dd4fb80 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py @@ -43,7 +43,7 @@ def __init__( """ super().__init__() self._logger = logger - self._lock = threading.Lock() + self._lock = threading.RLock() self._waitqueue: collections.deque = collections.deque() self._ongoing_wait: Optional[Future] = None self._ongoing_wait_time: Optional[Time] = None diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py index b646cdf..8ce9723 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py @@ -200,6 +200,16 @@ def write(self, data: Any) -> bool: self._future_matching_writes.remove(item) return True + @property + def head(self) -> Optional[Any]: + """Returns the data tape head, if any.""" + with self._lock: + if self._content is None: + return None + if len(self._content) == 0: + return None + return self._content[0] + def content( self, *, @@ -450,3 +460,20 @@ def take_kwargs(func: Callable, kwargs: Mapping) -> Tuple[Mapping, Mapping]: else: dropped[name] = value return taken, dropped + + +def ensure(value: Optional[Any]) -> Any: + """Ensures `value` is not None or fails trying.""" + if value is None: + frame = inspect.currentframe() + assert frame is not None + frame = frame.f_back + assert frame is not None + traceback = inspect.getframeinfo(frame) + message = f"{traceback.filename}:{traceback.lineno}: " + if traceback.code_context is not None: + message += "".join(traceback.code_context).strip("\n ") + " failed" + else: + message += "ensure() failed" + raise ValueError(message) + return value diff --git a/bdai_ros2_wrappers/test/test_feeds.py b/bdai_ros2_wrappers/test/test_feeds.py new file mode 100644 index 0000000..eb6f285 --- /dev/null +++ b/bdai_ros2_wrappers/test/test_feeds.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. + + +import tf2_ros +from geometry_msgs.msg import ( + PoseStamped, + TransformStamped, + TwistStamped, +) +from message_filters import SimpleFilter + +from bdai_ros2_wrappers.feeds import ( + FramedMessageFeed, + MessageFeed, + SynchronizedMessageFeed, +) +from bdai_ros2_wrappers.scope import ROSAwareScope +from bdai_ros2_wrappers.utilities import ensure + + +def test_framed_message_feed(ros: ROSAwareScope) -> None: + tf_buffer = tf2_ros.Buffer() + pose_message_feed = MessageFeed(SimpleFilter()) + framed_message_feed = FramedMessageFeed( + pose_message_feed, + target_frame_id="map", + tf_buffer=tf_buffer, + node=ros.node, + ) + + expected_transform_message = TransformStamped() + expected_transform_message.header.frame_id = "map" + expected_transform_message.child_frame_id = "odom" + expected_transform_message.transform.translation.y = 1.0 + expected_transform_message.transform.rotation.w = 1.0 + tf_buffer.set_transform_static(expected_transform_message, "pytest") + + expected_pose_message = PoseStamped() + expected_pose_message.header.frame_id = "odom" + expected_pose_message.header.stamp.sec = 1 + expected_pose_message.pose.position.x = 1.0 + expected_pose_message.pose.orientation.w = 1.0 + pose_message_feed.link.signalMessage(expected_pose_message) + + pose_message, transform_message = ensure(framed_message_feed.latest) + assert pose_message.pose.position.x == expected_pose_message.pose.position.x + assert transform_message.transform.translation.y == expected_transform_message.transform.translation.y + + +def test_synchronized_message_feed(ros: ROSAwareScope) -> None: + pose_message_feed = MessageFeed(SimpleFilter()) + twist_message_feed = MessageFeed(SimpleFilter()) + synchronized_message_feed = SynchronizedMessageFeed( + pose_message_feed, + twist_message_feed, + node=ros.node, + ) + + expected_pose_message = PoseStamped() + expected_pose_message.header.frame_id = "odom" + expected_pose_message.header.stamp.sec = 1 + expected_pose_message.pose.position.x = 1.0 + expected_pose_message.pose.orientation.w = 1.0 + pose_message_feed.link.signalMessage(expected_pose_message) + + expected_twist_message = TwistStamped() + expected_twist_message.header.frame_id = "base_link" + expected_twist_message.header.stamp.sec = 1 + expected_twist_message.twist.linear.x = 1.0 + expected_twist_message.twist.angular.z = 1.0 + twist_message_feed.link.signalMessage(expected_twist_message) + + pose_message, twist_message = ensure(synchronized_message_feed.latest) + assert pose_message.pose.position.x == expected_pose_message.pose.position.x + assert twist_message.twist.linear.x == expected_twist_message.twist.linear.x diff --git a/bdai_ros2_wrappers/test/test_subscription.py b/bdai_ros2_wrappers/test/test_subscription.py index 3f9fac0..1ad820c 100644 --- a/bdai_ros2_wrappers/test/test_subscription.py +++ b/bdai_ros2_wrappers/test/test_subscription.py @@ -1,7 +1,7 @@ # Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. import itertools import time -from typing import Any, Iterator, cast +from typing import Any, Iterator from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile from std_msgs.msg import Int8, String @@ -10,6 +10,7 @@ from bdai_ros2_wrappers.node import Node from bdai_ros2_wrappers.scope import ROSAwareScope from bdai_ros2_wrappers.subscription import Subscription, wait_for_message, wait_for_messages +from bdai_ros2_wrappers.utilities import ensure DEFAULT_QOS_PROFILE = QoSProfile( durability=DurabilityPolicy.TRANSIENT_LOCAL, @@ -44,7 +45,7 @@ def test_subscription_future_wait(ros: ROSAwareScope) -> None: pub.publish(Int8(data=1)) assert wait_for_future(sequence.update, timeout_sec=5.0) - assert cast(Int8, sequence.latest).data == 1 + assert ensure(sequence.latest).data == 1 def test_subscription_matching_future_wait(ros: ROSAwareScope) -> None: @@ -132,7 +133,7 @@ def deferred_cancellation() -> None: assert len(historic_numbers) == 1 assert historic_numbers[0] == 1 - assert cast(Int8, sequence.latest).data == 1 + assert ensure(sequence.latest).data == 1 def test_wait_for_messages(ros: ROSAwareScope) -> None: diff --git a/bdai_ros2_wrappers/test/test_utilities.py b/bdai_ros2_wrappers/test/test_utilities.py index 366fafb..a286002 100644 --- a/bdai_ros2_wrappers/test/test_utilities.py +++ b/bdai_ros2_wrappers/test/test_utilities.py @@ -2,7 +2,9 @@ import argparse -from bdai_ros2_wrappers.utilities import either_or, namespace_with +import pytest + +from bdai_ros2_wrappers.utilities import either_or, ensure, namespace_with def test_either_or() -> None: @@ -18,3 +20,10 @@ def test_namespace_with() -> None: assert namespace_with("", "foo") == "foo" assert namespace_with("/", "foo") == "/foo" assert namespace_with("foo", "bar") == "foo/bar" + + +def test_ensure() -> None: + data = None + with pytest.raises(ValueError) as excinfo: + ensure(data) + assert "ensure(data) failed" in str(excinfo.value)