Skip to content

Commit

Permalink
Add matching Subscription update futures
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai committed Apr 26, 2024
1 parent ed2a9b9 commit 3b1ac5d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 4 deletions.
15 changes: 13 additions & 2 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.

from collections.abc import Sequence
from typing import Any, Iterator, Optional, TypeVar, Union
from typing import Any, Callable, Iterator, Optional, TypeVar, Union

import message_filters
from rclpy.callback_groups import CallbackGroup
Expand Down Expand Up @@ -77,9 +77,20 @@ def latest(self) -> Optional[Any]:

@property
def update(self) -> Future:
"""Gets the a future to the next message yet to be received."""
"""Gets the future to the next message yet to be received."""
return self._message_tape.future_write

def matching_update(self, matching_predicate: Callable[[Any], bool]) -> Future:
"""Gets a future to the next matching message yet to be received.
Args:
matching_predicate: a boolean predicate to match incoming messages.
Returns:
a future.
"""
return self._message_tape.future_matching_write(matching_predicate)

def stream(
self,
*,
Expand Down
34 changes: 32 additions & 2 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import queue
import threading
import warnings
from typing import Any, Callable, Iterator, List, Optional
from typing import Any, Callable, Iterator, List, Optional, Tuple

import rclpy.clock
import rclpy.duration
Expand Down Expand Up @@ -140,18 +140,41 @@ def __init__(self, max_length: Optional[int] = None) -> None:
if max_length is None or max_length > 0:
self._content = collections.deque(maxlen=max_length)
self._future_write: Optional[Future] = None
self._future_matching_writes: List[Tuple[Callable[[Any], bool], Future]] = []
self._closed = False

@property
def future_write(self) -> Future:
"""Gets the a future to the next data yet to be written."""
"""Gets the future to the next data yet to be written."""
with self._lock:
if self._future_write is None:
self._future_write = Future()
if self._closed:
self._future_write.cancel()
return self._future_write

def future_matching_write(self, matching_predicate: Callable[[Any], bool]) -> Future:
"""Gets a future to the next matching data yet to be written.
Args:
matching_predicate: a boolean predicate to match written data.
Returns:
a future.
"""
with self._lock:
future_write = Future()
if not self._closed:
self._future_matching_writes.append(
(
matching_predicate,
future_write,
),
)
else:
future_write.cancel()
return future_write

def write(self, data: Any) -> bool:
"""Write the data tape."""
with self._lock:
Expand All @@ -168,6 +191,11 @@ def write(self, data: Any) -> bool:
if self._future_write is not None:
self._future_write.set_result(data)
self._future_write = None
for item in list(self._future_matching_writes):
matching_predicate, future_write = item
if matching_predicate(data):
future_write.set_result(data)
self._future_matching_writes.remove(item)
return True

def content(
Expand Down Expand Up @@ -244,6 +272,8 @@ def close(self) -> None:
stream.interrupt()
if self._future_write is not None:
self._future_write.cancel()
for _, future_write in self._future_matching_writes:
future_write.cancel()


def synchronized(
Expand Down
20 changes: 20 additions & 0 deletions bdai_ros2_wrappers/test/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,26 @@ def test_subscription_future_wait(ros: ROSAwareScope) -> None:
assert cast(Int8, sequence.latest).data == 1


def test_subscription_matching_future_wait(ros: ROSAwareScope) -> None:
"""Asserts that waiting for a matching subscription update works as expected."""
assert ros.node is not None
pub = ros.node.create_publisher(Int8, "sequence", DEFAULT_QOS_PROFILE)
sequence = Subscription(Int8, "sequence", DEFAULT_QOS_PROFILE, node=ros.node)

def deferred_publish() -> None:
time.sleep(0.5)
for num in range(5):
pub.publish(Int8(data=num))

assert ros.executor is not None
ros.executor.create_task(deferred_publish)

future = sequence.matching_update(lambda message: message.data == 3)
assert wait_for_future(future, timeout_sec=5.0)
message = future.result()
assert message.data == 3


def test_subscription_iteration(ros: ROSAwareScope) -> None:
"""Asserts that iterating over subscription messages works as expected."""
assert ros.node is not None
Expand Down

0 comments on commit 3b1ac5d

Please sign in to comment.