Skip to content

Commit

Permalink
Test and please pre-commit linters
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai committed Nov 21, 2024
1 parent 190dbc1 commit 633f292
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 29 deletions.
32 changes: 15 additions & 17 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self._link.registerCallback(
lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]),
)
node.context.on_shutdown(self._tape.close)
node.context.on_shutdown(self.close)

@property
def link(self) -> Filter:
Expand Down Expand Up @@ -180,12 +180,10 @@ def start(self) -> None:
def stop(self) -> None:
"""Stop the message feed."""
self._link.stop()

def close(self) -> None:
"""Closes the message feed."""
self._link.stop()
self._tape.close()

close = stop


class AdaptedMessageFeed(MessageFeed[MessageT]):
"""A message feed decorator to simplify adapter patterns."""
Expand Down Expand Up @@ -214,10 +212,10 @@ def feed(self) -> MessageFeed:
"""Gets the upstream message feed."""
return self._feed

def close(self) -> None:
"""Closes this message feed and the upstream one as well."""
self._feed.close()
super().close()
def stop(self) -> None:
"""Stop this message feed and the upstream one as well."""
self._feed.stop()
super().stop()


class FramedMessageFeed(MessageFeed[MessageT]):
Expand Down Expand Up @@ -271,10 +269,10 @@ def feed(self) -> MessageFeed[MessageT]:
"""Gets the upstream message feed."""
return self._feed

def close(self) -> None:
"""Closes this message feed and the upstream one as well."""
self._feed.close()
super().close()
def stop(self) -> None:
"""Stop this message feed and the upstream one as well."""
self._feed.stop()
super().stop()


class SynchronizedMessageFeed(MessageFeed):
Expand Down Expand Up @@ -321,8 +319,8 @@ def feeds(self) -> Iterable[MessageFeed]:
"""Gets all aggregated message feeds."""
return self._feeds

def close(self) -> None:
"""Closes this message feed and all upstream ones as well."""
def stop(self) -> None:
"""Stop this message feed and all upstream ones as well."""
for feed in self._feeds:
feed.close()
super().close()
feed.stop()
super().stop()
38 changes: 29 additions & 9 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, autostart: bool = True) -> None:
Args:
autostart: whether to start filtering on instantiation or not.
"""
self._active = False
self._stopped = self._started = False
self._connection_lock = threading.Lock()
self._connection_sequence = itertools.count()
self.callbacks: Dict[int, Tuple[Callable, Tuple]] = {}
Expand All @@ -48,21 +48,33 @@ def _start(self) -> None:
"""Hook for start logic customization"""

def start(self) -> None:
"""Start filtering."""
"""Start filtering.
Raises:
RuntimeError: if filtering has been stopped already.
"""
with self._connection_lock:
if not self._active:
if self._stopped:
raise RuntimeError("filtering already stopped")
if not self._started:
self._start()
self._active = True
self._started = True

def _stop(self) -> None:
"""Hook for stop logic customization"""

def stop(self) -> None:
"""Stop filtering."""
"""Stop filtering.
Raises:
RuntimeError: if filter has not been started.
"""
with self._connection_lock:
if not self._active:
if not self._started:
raise RuntimeError("filter not started")
if not self._stopped:
self._stop()
self._active = False
self._stopped = True

def registerCallback(self, fn: Callable, *args: Any) -> int:
"""Register callable to be called on filter output.
Expand All @@ -73,8 +85,13 @@ def registerCallback(self, fn: Callable, *args: Any) -> int:
Returns:
a unique connection identifier.
Raises:
RuntimeError: if filter has been stopped.
"""
with self._connection_lock:
if self._stopped:
raise RuntimeError("filter stopped")
connection = next(self._connection_sequence)
self.callbacks[connection] = (fn, args)
return connection
Expand All @@ -95,10 +112,13 @@ def signalMessage(self, *messages: Any) -> None:
messages: messages to be forwarded through the filter.
Raises:
RuntimeError: if filter has not been started yet.
RuntimeError: if filter is not active
(either not started or already stopped).
"""
with self._connection_lock:
if not self._active:
if self._stopped:
raise RuntimeError("filter stopped")
if not self._started:
raise RuntimeError("filter not started")
callbacks = list(self.callbacks.values())

Expand Down
5 changes: 2 additions & 3 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,9 @@ def callback(*messages: Sequence[Any]) -> None:
sync.registerCallback(callback)

def cleanup_subscribers(_: Future) -> None:
nonlocal node, subscribers
assert node is not None
nonlocal subscribers
for sub in subscribers:
node.destroy_subscription(sub.sub)
sub.close()

future.add_done_callback(cleanup_subscribers)
return future
24 changes: 24 additions & 0 deletions bdai_ros2_wrappers/test/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,30 @@ def deferred_publish() -> None:
assert expected_sequence_numbers == historic_numbers


def test_deferred_start_subscription(ros: ROSAwareScope) -> None:
"""Asserts that deferred subscription start works as expected."""
assert ros.node is not None
pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
sequence = Subscription(
Int8,
"sequence",
DEFAULT_QOS_PROFILE,
node=ros.node,
autostart=False,
)
assert wait_for_future(sequence.publisher_matches(1), timeout_sec=5.0)
assert sequence.matched_publishers == 1

pub.publish(Int8(data=1))

future = sequence.update
assert not future.done()
sequence.start()

assert wait_for_future(future, timeout_sec=5.0)
assert cast(Int8, ensure(sequence.latest)).data == 1


def test_subscription_cancelation(ros: ROSAwareScope) -> None:
"""Asserts that cancelling a subscription works as expected."""
assert ros.node is not None
Expand Down

0 comments on commit 633f292

Please sign in to comment.