From 633f29286756061d782e6152b822f26999899ad7 Mon Sep 17 00:00:00 2001 From: Michel Hidalgo Date: Thu, 21 Nov 2024 15:21:24 -0300 Subject: [PATCH] Test and please pre-commit linters Signed-off-by: Michel Hidalgo --- .../bdai_ros2_wrappers/feeds.py | 32 ++++++++-------- .../bdai_ros2_wrappers/filters.py | 38 ++++++++++++++----- .../bdai_ros2_wrappers/subscription.py | 5 +-- bdai_ros2_wrappers/test/test_subscription.py | 24 ++++++++++++ 4 files changed, 70 insertions(+), 29 deletions(-) diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py index 3152710..280fae8 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py @@ -57,7 +57,7 @@ def __init__( self._link.registerCallback( lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]), ) - node.context.on_shutdown(self._tape.close) + node.context.on_shutdown(self.close) @property def link(self) -> Filter: @@ -180,12 +180,10 @@ def start(self) -> None: def stop(self) -> None: """Stop the message feed.""" self._link.stop() - - def close(self) -> None: - """Closes the message feed.""" - self._link.stop() self._tape.close() + close = stop + class AdaptedMessageFeed(MessageFeed[MessageT]): """A message feed decorator to simplify adapter patterns.""" @@ -214,10 +212,10 @@ def feed(self) -> MessageFeed: """Gets the upstream message feed.""" return self._feed - def close(self) -> None: - """Closes this message feed and the upstream one as well.""" - self._feed.close() - super().close() + def stop(self) -> None: + """Stop this message feed and the upstream one as well.""" + self._feed.stop() + super().stop() class FramedMessageFeed(MessageFeed[MessageT]): @@ -271,10 +269,10 @@ def feed(self) -> MessageFeed[MessageT]: """Gets the upstream message feed.""" return self._feed - def close(self) -> None: - """Closes this message feed and the upstream one as well.""" - self._feed.close() - super().close() + def stop(self) -> None: + """Stop this message feed and the upstream one as well.""" + self._feed.stop() + super().stop() class SynchronizedMessageFeed(MessageFeed): @@ -321,8 +319,8 @@ def feeds(self) -> Iterable[MessageFeed]: """Gets all aggregated message feeds.""" return self._feeds - def close(self) -> None: - """Closes this message feed and all upstream ones as well.""" + def stop(self) -> None: + """Stop this message feed and all upstream ones as well.""" for feed in self._feeds: - feed.close() - super().close() + feed.stop() + super().stop() diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py index 4b14791..e0270ca 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py @@ -37,7 +37,7 @@ def __init__(self, autostart: bool = True) -> None: Args: autostart: whether to start filtering on instantiation or not. """ - self._active = False + self._stopped = self._started = False self._connection_lock = threading.Lock() self._connection_sequence = itertools.count() self.callbacks: Dict[int, Tuple[Callable, Tuple]] = {} @@ -48,21 +48,33 @@ def _start(self) -> None: """Hook for start logic customization""" def start(self) -> None: - """Start filtering.""" + """Start filtering. + + Raises: + RuntimeError: if filtering has been stopped already. + """ with self._connection_lock: - if not self._active: + if self._stopped: + raise RuntimeError("filtering already stopped") + if not self._started: self._start() - self._active = True + self._started = True def _stop(self) -> None: """Hook for stop logic customization""" def stop(self) -> None: - """Stop filtering.""" + """Stop filtering. + + Raises: + RuntimeError: if filter has not been started. + """ with self._connection_lock: - if not self._active: + if not self._started: + raise RuntimeError("filter not started") + if not self._stopped: self._stop() - self._active = False + self._stopped = True def registerCallback(self, fn: Callable, *args: Any) -> int: """Register callable to be called on filter output. @@ -73,8 +85,13 @@ def registerCallback(self, fn: Callable, *args: Any) -> int: Returns: a unique connection identifier. + + Raises: + RuntimeError: if filter has been stopped. """ with self._connection_lock: + if self._stopped: + raise RuntimeError("filter stopped") connection = next(self._connection_sequence) self.callbacks[connection] = (fn, args) return connection @@ -95,10 +112,13 @@ def signalMessage(self, *messages: Any) -> None: messages: messages to be forwarded through the filter. Raises: - RuntimeError: if filter has not been started yet. + RuntimeError: if filter is not active + (either not started or already stopped). """ with self._connection_lock: - if not self._active: + if self._stopped: + raise RuntimeError("filter stopped") + if not self._started: raise RuntimeError("filter not started") callbacks = list(self.callbacks.values()) diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py index 600d172..59aa39d 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py @@ -296,10 +296,9 @@ def callback(*messages: Sequence[Any]) -> None: sync.registerCallback(callback) def cleanup_subscribers(_: Future) -> None: - nonlocal node, subscribers - assert node is not None + nonlocal subscribers for sub in subscribers: - node.destroy_subscription(sub.sub) + sub.close() future.add_done_callback(cleanup_subscribers) return future diff --git a/bdai_ros2_wrappers/test/test_subscription.py b/bdai_ros2_wrappers/test/test_subscription.py index 025d095..513ab67 100644 --- a/bdai_ros2_wrappers/test/test_subscription.py +++ b/bdai_ros2_wrappers/test/test_subscription.py @@ -123,6 +123,30 @@ def deferred_publish() -> None: assert expected_sequence_numbers == historic_numbers +def test_deferred_start_subscription(ros: ROSAwareScope) -> None: + """Asserts that deferred subscription start works as expected.""" + assert ros.node is not None + pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE) + sequence = Subscription( + Int8, + "sequence", + DEFAULT_QOS_PROFILE, + node=ros.node, + autostart=False, + ) + assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0) + assert sequence.matched_publishers == 1 + + pub.publish(Int8(data=1)) + + future = sequence.update + assert not future.done() + sequence.start() + + assert wait_for_future(future, timeout_sec=5.0) + assert cast(Int8, ensure(sequence.latest)).data == 1 + + def test_subscription_cancelation(ros: ROSAwareScope) -> None: """Asserts that cancelling a subscription works as expected.""" assert ros.node is not None