From 5aa699ac3aa69eb260e99727c2650cb91fcf0a0e Mon Sep 17 00:00:00 2001 From: mhidalgo-bdai <144129882+mhidalgo-bdai@users.noreply.github.com> Date: Mon, 25 Nov 2024 12:57:38 -0300 Subject: [PATCH] Deferred message filter/feed activation for race-free setups (#132) Signed-off-by: Michel Hidalgo --- .../bdai_ros2_wrappers/feeds.py | 48 ++-- .../bdai_ros2_wrappers/filters.py | 215 +++++++++++++++--- .../bdai_ros2_wrappers/subscription.py | 13 +- bdai_ros2_wrappers/test/test_subscription.py | 24 ++ 4 files changed, 244 insertions(+), 56 deletions(-) diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py index d5187f0..fa5d61e 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py @@ -57,7 +57,7 @@ def __init__( self._link.registerCallback( lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]), ) - node.context.on_shutdown(self._tape.close) + node.context.on_shutdown(self.close) @property def link(self) -> Filter: @@ -178,10 +178,17 @@ def stream( timeout_sec=timeout_sec, ) - def close(self) -> None: - """Closes the message feed.""" + def start(self) -> None: + """Start the message feed.""" + self._link.start() + + def stop(self) -> None: + """Stop the message feed.""" + self._link.stop() self._tape.close() + close = stop + class AdaptedMessageFeed(MessageFeed[MessageT]): """A message feed decorator to simplify adapter patterns.""" @@ -190,6 +197,8 @@ def __init__( self, feed: MessageFeed, fn: Callable[..., MessageT], + *, + autostart: bool = True, **kwargs: Any, ) -> None: """Initializes the message feed. @@ -199,8 +208,9 @@ def __init__( fn: message adapting callable. kwargs: all other keyword arguments are forwarded for `MessageFeed` initialization. + autostart: whether to start feeding messages immediately or not. """ - super().__init__(Adapter(feed.link, fn), **kwargs) + super().__init__(Adapter(feed.link, fn, autostart=autostart), **kwargs) self._feed = feed @property @@ -208,10 +218,10 @@ def feed(self) -> MessageFeed: """Gets the upstream message feed.""" return self._feed - def close(self) -> None: - """Closes this message feed and the upstream one as well.""" - self._feed.close() - super().close() + def stop(self) -> None: + """Stop this message feed and the upstream one as well.""" + self._feed.stop() + super().stop() class FramedMessageFeed(MessageFeed[MessageT]): @@ -226,6 +236,7 @@ def __init__( tf_buffer: Optional[tf2_ros.Buffer] = None, history_length: Optional[int] = None, node: Optional[Node] = None, + autostart: bool = True, ) -> None: """Initializes the message feed. @@ -238,6 +249,7 @@ def __init__( history_length: optional historic data size, defaults to 1. node: optional node for the underlying native subscription, defaults to the current process node. + autostart: whether to start feeding messages immediately or not. """ if node is None: node = scope.ensure_node() @@ -251,6 +263,7 @@ def __init__( tf_buffer, tolerance_sec, node.get_logger(), + autostart=autostart, ), history_length=history_length, node=node, @@ -262,10 +275,10 @@ def feed(self) -> MessageFeed[MessageT]: """Gets the upstream message feed.""" return self._feed - def close(self) -> None: - """Closes this message feed and the upstream one as well.""" - self._feed.close() - super().close() + def stop(self) -> None: + """Stop this message feed and the upstream one as well.""" + self._feed.stop() + super().stop() class SynchronizedMessageFeed(MessageFeed): @@ -279,6 +292,7 @@ def __init__( allow_headerless: bool = False, history_length: Optional[int] = None, node: Optional[Node] = None, + autostart: bool = True, ) -> None: """Initializes the message feed. @@ -291,6 +305,7 @@ def __init__( history_length: optional historic data size, defaults to 1. node: optional node for the underlying native subscription, defaults to the current process node. + autostart: whether to start feeding messages immediately or not. """ super().__init__( ApproximateTimeSynchronizer( @@ -298,6 +313,7 @@ def __init__( queue_size, delay, allow_headerless=allow_headerless, + autostart=autostart, ), history_length=history_length, node=node, @@ -309,8 +325,8 @@ def feeds(self) -> Iterable[MessageFeed]: """Gets all aggregated message feeds.""" return self._feeds - def close(self) -> None: - """Closes this message feed and all upstream ones as well.""" + def stop(self) -> None: + """Stop this message feed and all upstream ones as well.""" for feed in self._feeds: - feed.close() - super().close() + feed.stop() + super().stop() diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py index 9cbc7a1..62dfa7f 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py @@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, Optional, Protocol, Tuple import message_filters +import rclpy.subscription import tf2_ros from rclpy.duration import Duration from rclpy.node import Node @@ -30,10 +31,50 @@ def signalMessage(self, *messages: Any) -> None: class Filter(SimpleFilterProtocol): """A threadsafe `message_filters.SimpleFilter` compliant message filter.""" - def __init__(self) -> None: - self.__lock = threading.Lock() + def __init__(self, autostart: bool = True) -> None: + """Initialize filter. + + Args: + autostart: whether to start filtering on instantiation or not. + """ + self._stopped = self._started = False + self._connection_lock = threading.Lock() self._connection_sequence = itertools.count() self.callbacks: Dict[int, Tuple[Callable, Tuple]] = {} + if autostart: + self.start() + + def _start(self) -> None: + """Hook for start logic customization""" + + def start(self) -> None: + """Start filtering. + + Raises: + RuntimeError: if filtering has been stopped already. + """ + with self._connection_lock: + if self._stopped: + raise RuntimeError("filtering already stopped") + if not self._started: + self._start() + self._started = True + + def _stop(self) -> None: + """Hook for stop logic customization""" + + def stop(self) -> None: + """Stop filtering. + + Raises: + RuntimeError: if filter has not been started. + """ + with self._connection_lock: + if not self._started: + raise RuntimeError("filter not started") + if not self._stopped: + self._stop() + self._stopped = True def registerCallback(self, fn: Callable, *args: Any) -> int: """Register callable to be called on filter output. @@ -44,8 +85,13 @@ def registerCallback(self, fn: Callable, *args: Any) -> int: Returns: a unique connection identifier. + + Raises: + RuntimeError: if filter has been stopped. """ - with self.__lock: + with self._connection_lock: + if self._stopped: + raise RuntimeError("filter stopped") connection = next(self._connection_sequence) self.callbacks[connection] = (fn, args) return connection @@ -56,12 +102,24 @@ def unregisterCallback(self, connection: int) -> None: Args: connection: unique identifier for the callback. """ - with self.__lock: + with self._connection_lock: del self.callbacks[connection] def signalMessage(self, *messages: Any) -> None: - """Feed one or more `messages` to the filter.""" - with self.__lock: + """Feed one or more `messages` to the filter. + + Args: + messages: messages to be forwarded through the filter. + + Raises: + RuntimeError: if filter is not active + (either not started or already stopped). + """ + with self._connection_lock: + if self._stopped: + raise RuntimeError("filter stopped") + if not self._started: + raise RuntimeError("filter not started") callbacks = list(self.callbacks.values()) for fn, args in callbacks: @@ -71,38 +129,78 @@ def signalMessage(self, *messages: Any) -> None: class Subscriber(Filter): """A threadsafe `message_filters.Subscriber` equivalent.""" - def __init__(self, node: Node, *args: Any, **kwargs: Any) -> None: + def __init__(self, node: Node, *args: Any, autostart: bool = True, **kwargs: Any) -> None: """Initializes the `Subscriber` instance. - All positional and keyword arguments are forwarded - to `rclpy.node.Node.create_subscription`. + Args: + node: ROS 2 node to subscribe with. + args: positional arguments to forward to `rclpy.node.Node.create_subscription`. + autostart: whether to start filtering on instantiation or not. + kwargs: keyword arguments to forward to `rclpy.node.Node.create_subscription`. """ - super().__init__() - self.sub = node.create_subscription( + self.node = node + self.options = (args, kwargs) + self._subscription: Optional[rclpy.subscription.Subscription] = None + super().__init__(autostart) + + def _start(self) -> None: + if self._subscription is not None: + raise RuntimeError("subscriber already subscribed") + args, kwargs = self.options + self._subscription = self.node.create_subscription( *args, callback=self.signalMessage, **kwargs, ) + def _stop(self) -> None: + if self._subscription is None: + raise RuntimeError("subscriber not subscribed") + self.node.destroy_subscription(self._subscription) + self._subscription = None + + close = Filter.stop + def __getattr__(self, name: str) -> Any: - return getattr(self.subscription, name) + return getattr(self._subscription, name) class ApproximateTimeSynchronizer(Filter): """A threadsafe `message_filters.ApproximateTimeSynchronizer` equivalent.""" - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, upstreams: Sequence[Filter], *args: Any, autostart: bool = True, **kwargs: Any) -> None: """Initializes the `ApproximateTimeSynchronizer` instance. - All positional and keyword arguments are forwarded - to the underlying `message_filters.ApproximateTimeSynchronizer`. + Args: + upstreams: message filters to be synchronized. + args: positional arguments to forward to `message_filters.ApproximateTimeSynchronizer`. + autostart: whether to start filtering on instantiation or not. + kwargs: keyword arguments to forward to `message_filters.ApproximateTimeSynchronizer`. """ - super().__init__() + self.upstreams = list(upstreams) + self.options = (args, kwargs) + self._unsafe_synchronizer: Optional[message_filters.ApproximateTimeSynchronizer] = None + super().__init__(autostart) + + def _start(self) -> None: + if self._unsafe_synchronizer is not None: + raise RuntimeError("synchronizer already connected") + args, kwargs = self.options self._unsafe_synchronizer = message_filters.ApproximateTimeSynchronizer( + self.upstreams, *args, **kwargs, ) self._unsafe_synchronizer.registerCallback(self.signalMessage) + for upstream in self.upstreams: + upstream.start() + + def _stop(self) -> None: + if self._unsafe_synchronizer is None: + raise RuntimeError("synchronizer not connected") + for upstream in self.upstreams: + upstream.stop() + self._unsafe_synchronizer = None def __getattr__(self, name: str) -> Any: return getattr(self._unsafe_synchronizer, name) @@ -123,6 +221,8 @@ def __init__( tf_buffer: tf2_ros.Buffer, tolerance_sec: float, logger: Optional[RcutilsLogger] = None, + *, + autostart: bool = True, ) -> None: """Initializes the transform filter. @@ -133,22 +233,37 @@ def __init__( tolerance_sec: a tolerance, in seconds, to wait for late transforms before abandoning any waits and filtering out the corresponding messages. logger: an optional logger to notify the yser about any errors during filtering. + autostart: whether to start filtering on instantiation or not. """ - super().__init__() + self.__lock = threading.RLock() self._logger = logger - self._lock = threading.RLock() self._waitqueue: collections.deque = collections.deque() self._ongoing_wait: Optional[Future] = None self._ongoing_wait_time: Optional[Time] = None + self._connection: Optional[int] = None self.target_frame_id = target_frame_id self.tf_buffer = tf_buffer self.tolerance = Duration(seconds=tolerance_sec) - self.connection = upstream.registerCallback(self.add) + self.upstream = upstream + super().__init__(autostart) + + def _start(self) -> None: + if self._connection is not None: + raise RuntimeError("filter already connected") + self._connection = self.upstream.registerCallback(self.add) + self.upstream.start() + + def _stop(self) -> None: + if self._connection is None: + raise RuntimeError("filter not connected") + self.upstream.stop() + self.upstream.unregisterCallback(self._connection) + self._connection = None def _wait_callback(self, messages: Sequence[Any], future: Future) -> None: if future.cancelled(): return - with self._lock: + with self.__lock: try: if future.result() is True: source_frame_id = messages[0].header.frame_id @@ -186,8 +301,8 @@ def _wait_callback(self, messages: Sequence[Any], future: Future) -> None: self._ongoing_wait = None def add(self, *messages: Any) -> None: - """Adds new `messages` to the filter.""" - with self._lock: + """Add `messages` to the filter.""" + with self.__lock: time = Time.from_msg(messages[0].header.stamp) if self._ongoing_wait and not self._ongoing_wait.done() and time - self._ongoing_wait_time > self.tolerance: self._ongoing_wait.cancel() @@ -217,7 +332,7 @@ def add(self, *messages: Any) -> None: class Adapter(Filter): """A message filter for data adaptation.""" - def __init__(self, upstream: Filter, fn: Callable) -> None: + def __init__(self, upstream: Filter, fn: Callable, *, autostart: bool = True) -> None: """Initializes the adapter. Args: @@ -225,13 +340,28 @@ def __init__(self, upstream: Filter, fn: Callable) -> None: fn: a callable that takes messages as arguments and returns some data to be signaled (i.e. propagated down the filter chain). If none is returned, no message signaling will occur. + autostart: whether to start filtering on instantiation or not. """ - super().__init__() self.fn = fn - self.connection = upstream.registerCallback(self.add) + self.upstream = upstream + self._connection: Optional[int] = None + super().__init__(autostart) + + def _start(self) -> None: + if self._connection is not None: + raise RuntimeError("adapter already connected") + self._connection = self.upstream.registerCallback(self.add) + self.upstream.start() + + def _stop(self) -> None: + if self._connection is None: + raise RuntimeError("adapter not connected") + self.upstream.stop() + self.upstream.unregisterCallback(self._connection) + self._connection = None def add(self, *messages: Any) -> None: - """Adds new `messages` to the adapter.""" + """Add `messages` to the filter.""" result = self.fn(*messages) if result is not None: self.signalMessage(result) @@ -240,16 +370,37 @@ def add(self, *messages: Any) -> None: class Tunnel(Filter): """A message filter that simply forwards messages but can be detached.""" - def __init__(self, upstream: Filter) -> None: + def __init__(self, upstream: Filter, *, autostart: bool = True) -> None: """Initializes the tunnel. Args: upstream: the upstream message filter. + autostart: whether to start filtering on instantiation or not. """ - super().__init__() - self.upstream = upstream - self.connection = upstream.registerCallback(self.signalMessage) + self.upstream: Optional[Filter] = upstream + self._connection: Optional[int] = None + super().__init__(autostart) + + def _start(self) -> None: + if self.upstream is None: + raise RuntimeError("tunnel closed") + if self._connection is not None: + raise RuntimeError("tunnel already connected") + self._connection = self.upstream.registerCallback(self.signalMessage) + self.upstream.start() + + def _stop(self) -> None: + if self.upstream is None: + raise RuntimeError("tunnel closed") + if self._connection is None: + raise RuntimeError("tunnel not connected") + self.upstream.stop() + self.upstream.unregisterCallback(self._connection) + self._connection = None def close(self) -> None: - """Closes the tunnel, disconnecting it from upstream.""" - self.upstream.unregisterCallback(self.connection) + """Closes the tunnel, simply disconnecting it from upstream.""" + with self._connection_lock: + if self._connection is not None and self.upstream is not None: + self.upstream.unregisterCallback(self._connection) + self.upstream = None diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py index 3f7de7f..59aa39d 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py @@ -30,6 +30,7 @@ def __init__( qos_profile: Optional[Union[QoSProfile, int]] = None, history_length: Optional[int] = None, node: Optional[Node] = None, + autostart: bool = True, **kwargs: Any, ) -> None: """Initializes the subscription. @@ -42,6 +43,7 @@ def __init__( history_length: optional historic data size, defaults to 1 node: optional node for the underlying native subscription, defaults to the current process node. + autostart: whether to start feeding messages immediately or not. kwargs: other keyword arguments are used to create the underlying native subscription. See `rclpy.node.Node.create_subscription` documentation for further reference. """ @@ -57,6 +59,7 @@ def __init__( message_type, topic_name, qos_profile=qos_profile, + autostart=autostart, **kwargs, ), history_length=history_length, @@ -126,11 +129,6 @@ def topic_name(self) -> str: """Gets the name of the topic subscribed.""" return self._topic_name - def close(self) -> None: - """Closes the subscription.""" - self._node.destroy_subscription(self.subscriber.sub) - super().close() - # Aliases for improved readability cancel = MessageFeed.close unsubscribe = MessageFeed.close @@ -298,10 +296,9 @@ def callback(*messages: Sequence[Any]) -> None: sync.registerCallback(callback) def cleanup_subscribers(_: Future) -> None: - nonlocal node, subscribers - assert node is not None + nonlocal subscribers for sub in subscribers: - node.destroy_subscription(sub.sub) + sub.close() future.add_done_callback(cleanup_subscribers) return future diff --git a/bdai_ros2_wrappers/test/test_subscription.py b/bdai_ros2_wrappers/test/test_subscription.py index 025d095..513ab67 100644 --- a/bdai_ros2_wrappers/test/test_subscription.py +++ b/bdai_ros2_wrappers/test/test_subscription.py @@ -123,6 +123,30 @@ def deferred_publish() -> None: assert expected_sequence_numbers == historic_numbers +def test_deferred_start_subscription(ros: ROSAwareScope) -> None: + """Asserts that deferred subscription start works as expected.""" + assert ros.node is not None + pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE) + sequence = Subscription( + Int8, + "sequence", + DEFAULT_QOS_PROFILE, + node=ros.node, + autostart=False, + ) + assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0) + assert sequence.matched_publishers == 1 + + pub.publish(Int8(data=1)) + + future = sequence.update + assert not future.done() + sequence.start() + + assert wait_for_future(future, timeout_sec=5.0) + assert cast(Int8, ensure(sequence.latest)).data == 1 + + def test_subscription_cancelation(ros: ROSAwareScope) -> None: """Asserts that cancelling a subscription works as expected.""" assert ros.node is not None