Skip to content

Commit

Permalink
refactor - use futures
Browse files Browse the repository at this point in the history
  • Loading branch information
kzheng-bdai committed Apr 7, 2024
1 parent d55520c commit 65e075f
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 118 deletions.
179 changes: 75 additions & 104 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,13 @@ def wait_for_message(


def wait_for_messages(
node: rclpy.node.Node,
topics: typing.List,
mtypes: typing.List,
*,
node: typing.Optional[rclpy.node.Node] = None,
timeout_sec: typing.Optional[float] = None,
**kwargs: typing.Any,
) -> typing.Any:
) -> typing.Union[None, typing.List[MessageT]]:
"""Waits for messages to arrive at multiple topics within a given time window.
Uses message_filters.ApproximateTimeSynchronizer. This function blocks
Expand All @@ -102,115 +104,84 @@ def wait_for_messages(
node (Node): the node being attached
topics (list): List of topics
mtypes (list): List of message types, one for each topic.
timeout_sec (float or None): Time in seconds to wait. None if forever.
If exceeded timeout, self.messages will contain None for
each topic.
kwargs: additional arguments, including
delay (float) The delay in seconds for which the messages
could be synchronized.
could be synchronized (i.e. the time window).
allow_headerless (bool): Whether it's ok for there to be
no header in the messages.
sleep (float) the amount of time to wait before checking
whether messages are received
timeout (float or None): Time in seconds to wait. None if forever.
If exceeded timeout, self.messages will contain None for
each topic.
qos_profiles (Dict): maps from topic name to QoSProfile
"""
return _WaitForMessages(node, topics, mtypes, **kwargs).messages


class _WaitForMessages:
def __init__(
self,
node: rclpy.node.Node,
topics: typing.List,
mtypes: typing.List,
queue_size: int = 10,
delay: float = 0.2,
allow_headerless: bool = False,
sleep: float = 0.5,
timeout: typing.Optional[float] = None,
verbose: bool = False,
exception_on_timeout: bool = False,
qos_profiles: typing.Optional[typing.Dict[str, rclpy.qos.QoSProfile]] = None,
callback_group: typing.Optional[rclpy.callback_groups.CallbackGroup] = None,
) -> None:
self.node = node
self.messages: typing.Optional[typing.Tuple] = None
self.verbose = verbose
self.topics = topics
self.timeout = timeout
self.exception_on_timeout = exception_on_timeout
self.has_timed_out = False
if qos_profiles is None:
qos_profiles = {}
self.qos_profiles = qos_profiles
self.logger = rclpy.impl.rcutils_logger.RcutilsLogger(name="wait_for_messages")

if self.verbose:
self.logger.info("initializing message filter ApproximateTimeSynchronizer")
self.subs: typing.Optional[typing.List] = [
self._message_filters_subscriber(mtype, topic, callback_group=callback_group)
for topic, mtype in zip(topics, mtypes, strict=True) # type: ignore
]
self.ts = message_filters.ApproximateTimeSynchronizer(
self.subs,
queue_size,
delay,
allow_headerless=allow_headerless,
node = node or scope.node()
if node is None:
raise ValueError("no ROS 2 node available (did you use bdai_ros2_wrapper.process.main?)")
future = wait_for_messages_async(topics, mtypes, node=node, **kwargs)
if not wait_for_future(future, timeout_sec, context=node.context):
future.cancel()
return None
return future.result()


def wait_for_messages_async(
topics: typing.List,
mtypes: typing.List,
*,
queue_size: int = 10,
delay: float = 0.2,
allow_headerless: bool = False,
sleep: float = 0.5,
verbose: bool = False,
exception_on_timeout: bool = False,
qos_profiles: typing.Optional[typing.Dict[str, rclpy.qos.QoSProfile]] = None,
node: typing.Optional[rclpy.node.Node] = None,
callback_group: typing.Optional[rclpy.callback_groups.CallbackGroup] = None,
) -> rclpy.task.Future:
"""Asynchronous version of wait_for_messages"""
node = node or scope.node()
if node is None:
raise ValueError("no ROS 2 node available (did you use bdai_ros2_wrapper.process.main?)")
if qos_profiles is None:
qos_profiles = {}
subs: typing.List[message_filters.Subscriber] = []
for topic, mtype in zip(topics, mtypes):
qos_profile = qos_profiles.get(topic, 10)
subs.append(
message_filters.Subscriber(node, mtype, topic, qos_profile=qos_profile, callback_group=callback_group),
)
self.ts.registerCallback(self._cb)

try:
self._start_time = self.node.get_clock().now()
rate = self.node.create_rate(1.0 / sleep)
while self.messages is None and not self.has_timed_out:
if self.check_messages_received():
break
rate.sleep()
finally:
self.destroy_subs()

def _message_filters_subscriber(
self,
mtype: typing.Any,
topic: str,
callback_group: typing.Optional[rclpy.callback_groups.CallbackGroup] = None,
) -> message_filters.Subscriber:
if topic in self.qos_profiles:
return message_filters.Subscriber(
self.node,
mtype,
topic,
qos_profile=self.qos_profiles[topic],
callback_group=callback_group,
)
return message_filters.Subscriber(self.node, mtype, topic, callback_group=callback_group)

def check_messages_received(self) -> bool:
if self.messages is not None:
self.logger.info("WaitForMessages: Received messages! Done!")
return True
if self.verbose:
self.logger.info("WaitForMessages: waiting for messages from {}".format(self.topics))
_dt = self.node.get_clock().now() - self._start_time
if self.timeout is not None and _dt.nanoseconds * 1e-9 > self.timeout:
self.logger.error("WaitForMessages: timeout waiting for messages")
self.messages = (None,) * len(self.topics)
self.has_timed_out = True
if self.exception_on_timeout:
raise TimeoutError("WaitForMessages: timeout waiting for messages")
return False

def _cb(self, *messages: typing.Any) -> None:
if self.messages is not None:
return
if self.verbose:
self.logger.info("WaitForMessages: callback got messages!")
self.messages = messages

def destroy_subs(self) -> None:
"""destroy all message filter subscribers"""
if self.subs is not None:
for mf_sub in self.subs:
self.node.destroy_subscription(mf_sub.sub)
self.subs = None
future = rclpy.task.Future()

def callback(*messages: typing.List[MessageT]) -> None:
if not future.done():
future.set_result(messages)

ts = message_filters.ApproximateTimeSynchronizer(
subs,
queue_size,
delay,
allow_headerless=allow_headerless,
)
ts.registerCallback(callback)
future.add_done_callback(lambda future: _destroy_subs(node, subs))
return future


def _message_filters_subscriber(
node: rclpy.node.Node,
mtype: typing.Any,
topic: str,
qos_profile: typing.Union[rclpy.qos.QoSProfile, int] = 1,
**kwargs: typing.Any,
) -> message_filters.Subscriber:
return message_filters.Subscriber(node, mtype, topic, qos_profile=qos_profile, **kwargs)


def _destroy_subs(node: rclpy.node.Node, subs: typing.List[message_filters.Subscriber]) -> None:
"""destroy all message filter subscribers"""
for mf_sub in subs:
node.destroy_subscription(mf_sub.sub)
34 changes: 20 additions & 14 deletions bdai_ros2_wrappers/test/test_wait_for_messages.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
# Copyright (c) 2023 Boston Dynamics AI Institute LLC. All rights reserved.
from enum import EnumMeta
from typing import Any

import std_msgs.msg
from rclpy.qos import QoSDurabilityPolicy, QoSProfile

from bdai_ros2_wrappers.node import Node
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.subscription import wait_for_messages


TRANSIENT_LOCAL = QoSProfile(depth=10, durability=QoSDurabilityPolicy.TRANSIENT_LOCAL)

class NodeFoo(Node):

def __init__(self, **kwargs: Any) -> None:
super().__init__("foo", **kwargs)
self.pub = self.create_publisher(std_msgs.msg.String, "/test1", TRANSIENT_LOCAL)
self.pub = self.create_publisher(std_msgs.msg.String, "/test1", 10)
self.timer = self.create_timer(0.5, self.publish_msg)

def publish_msg(self) -> None:
self.pub.publish(std_msgs.msg.String(data="hello from foo"))


Expand All @@ -37,31 +35,39 @@ def test_wait_for_messages(ros: ROSAwareScope) -> None:
node_wfm = ros.node

messages = wait_for_messages(
node_wfm,
["/test1", "/test2"],
[std_msgs.msg.String, std_msgs.msg.String],
allow_headerless=True,
verbose=True,
delay=0.5,
timeout=20,
qos_profiles={"/test1": TRANSIENT_LOCAL},
timeout_sec=20,
node=node_wfm,
)
assert messages == (None, None) or messages == (
assert messages is None or messages == (
std_msgs.msg.String(data="hello from foo"),
std_msgs.msg.String(data="hello from bar"),
)

messages = wait_for_messages(
node_wfm,
["/test2", "/test1"],
[std_msgs.msg.String, std_msgs.msg.String],
allow_headerless=True,
verbose=True,
delay=0.5,
timeout=20,
qos_profiles={"/test1": TRANSIENT_LOCAL}
timeout_sec=20,
node=node_wfm,
)
assert messages == (None, None) or messages == (
assert messages is None or messages == (
std_msgs.msg.String(data="hello from bar"),
std_msgs.msg.String(data="hello from foo"),
)

# Test the case where the topic doesn't exist - timeout expected
messages = wait_for_messages(
["/test3", "/test1"],
[std_msgs.msg.String, std_msgs.msg.String],
allow_headerless=True,
timeout_sec=1,
node=node_wfm,
)
assert messages is None

0 comments on commit 65e075f

Please sign in to comment.