From 700d7cd78c988bdd2d5a804e95d28122994be2fe Mon Sep 17 00:00:00 2001 From: Michel Hidalgo Date: Wed, 2 Oct 2024 12:45:44 -0300 Subject: [PATCH] Allow adapted feeds to filter messages Signed-off-by: Michel Hidalgo --- .../bdai_ros2_wrappers/feeds.py | 31 ++++++++++++++++--- .../bdai_ros2_wrappers/filters.py | 21 ++++++++++--- bdai_ros2_wrappers/test/test_feeds.py | 28 +++++++++++++++++ 3 files changed, 71 insertions(+), 9 deletions(-) diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py index 9cdab12..a7a06d5 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py @@ -1,12 +1,30 @@ # Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. -from typing import Any, Callable, Generator, Generic, Iterable, List, Literal, Optional, TypeVar, Union, overload +from typing import ( + Any, + Callable, + Generator, + Generic, + Iterable, + List, + Literal, + Optional, + TypeVar, + Union, + overload, +) import tf2_ros from rclpy.node import Node import bdai_ros2_wrappers.scope as scope -from bdai_ros2_wrappers.filters import Adapter, ApproximateTimeSynchronizer, Filter, TransformFilter, Tunnel +from bdai_ros2_wrappers.filters import ( + Adapter, + ApproximateTimeSynchronizer, + Filter, + TransformFilter, + Tunnel, +) from bdai_ros2_wrappers.futures import FutureLike from bdai_ros2_wrappers.utilities import Tape @@ -36,7 +54,9 @@ def __init__( history_length = 1 self._link = link self._tape: Tape[MessageT] = Tape(history_length) - self._link.registerCallback(lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0])) + self._link.registerCallback( + lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]), + ) node.context.on_shutdown(self._tape.close) @property @@ -59,7 +79,10 @@ def update(self) -> FutureLike[MessageT]: """Gets the future to the next message yet to be received.""" return self._tape.future_write - def matching_update(self, matching_predicate: Callable[[MessageT], bool]) -> FutureLike[MessageT]: + def matching_update( + self, + matching_predicate: Callable[[MessageT], bool], + ) -> FutureLike[MessageT]: """Gets a future to the next matching message yet to be received. Args: diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py index 7afd4c5..9cbc7a1 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py @@ -98,7 +98,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: to the underlying `message_filters.ApproximateTimeSynchronizer`. """ super().__init__() - self._unsafe_synchronizer = message_filters.ApproximateTimeSynchronizer(*args, **kwargs) + self._unsafe_synchronizer = message_filters.ApproximateTimeSynchronizer( + *args, + **kwargs, + ) self._unsafe_synchronizer.registerCallback(self.signalMessage) def __getattr__(self, name: str) -> Any: @@ -175,7 +178,9 @@ def _wait_callback(self, messages: Sequence[Any], future: Future) -> None: time, ) self._ongoing_wait_time = time - self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages)) + self._ongoing_wait.add_done_callback( + functools.partial(self._wait_callback, messages), + ) else: self._ongoing_wait_time = None self._ongoing_wait = None @@ -204,7 +209,9 @@ def add(self, *messages: Any) -> None: time, ) self._ongoing_wait_time = time - self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages)) + self._ongoing_wait.add_done_callback( + functools.partial(self._wait_callback, messages), + ) class Adapter(Filter): @@ -215,7 +222,9 @@ def __init__(self, upstream: Filter, fn: Callable) -> None: Args: upstream: the upstream message filter. - fn: adapter implementation as a callable. + fn: a callable that takes messages as arguments and returns some + data to be signaled (i.e. propagated down the filter chain). + If none is returned, no message signaling will occur. """ super().__init__() self.fn = fn @@ -223,7 +232,9 @@ def __init__(self, upstream: Filter, fn: Callable) -> None: def add(self, *messages: Any) -> None: """Adds new `messages` to the adapter.""" - self.signalMessage(self.fn(*messages)) + result = self.fn(*messages) + if result is not None: + self.signalMessage(result) class Tunnel(Filter): diff --git a/bdai_ros2_wrappers/test/test_feeds.py b/bdai_ros2_wrappers/test/test_feeds.py index 04c2b1c..9bdc2ac 100644 --- a/bdai_ros2_wrappers/test/test_feeds.py +++ b/bdai_ros2_wrappers/test/test_feeds.py @@ -101,6 +101,34 @@ def test_adapted_message_feed(ros: ROSAwareScope) -> None: assert position_message is expected_pose_message.pose.position +def test_masked_message_feed(ros: ROSAwareScope) -> None: + pose_message_feed = MessageFeed[PoseStamped](Filter()) + position_masking_feed = AdaptedMessageFeed[Point]( + pose_message_feed, + fn=lambda message: message if message.pose.position.x > 0.0 else None, + ) + expected_pose_message0 = PoseStamped() + expected_pose_message0.header.frame_id = "odom" + expected_pose_message0.header.stamp.sec = 1 + expected_pose_message0.pose.position.x = -1.0 + expected_pose_message0.pose.position.z = -1.0 + expected_pose_message0.pose.orientation.w = 1.0 + pose_message_feed.link.signalMessage(expected_pose_message0) + assert position_masking_feed.latest is None + + expected_pose_message1 = PoseStamped() + expected_pose_message1.header.frame_id = "odom" + expected_pose_message1.header.stamp.sec = 2 + expected_pose_message1.pose.position.x = 1.0 + expected_pose_message1.pose.position.z = -1.0 + expected_pose_message1.pose.orientation.w = 1.0 + pose_message_feed.link.signalMessage(expected_pose_message1) + + pose_message: Point = ensure(position_masking_feed.latest) + # no copies are expected, thus an identity check is valid + assert pose_message is expected_pose_message1 + + def test_message_feed_recalls(ros: ROSAwareScope) -> None: pose_message_feed = MessageFeed[PoseStamped](Filter())