Skip to content

Commit

Permalink
Add message feeds for synchronization and framing.
Browse files Browse the repository at this point in the history
Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai committed May 31, 2024
1 parent e618797 commit 1d157d6
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 9 deletions.
112 changes: 108 additions & 4 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

from typing import Any, Callable, Iterator, List, Optional
from typing import Any, Callable, Iterable, Iterator, List, Optional

from message_filters import SimpleFilter
import tf2_ros
from message_filters import ApproximateTimeSynchronizer, SimpleFilter
from rclpy.node import Node
from rclpy.task import Future

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.filters import TransformFilter
from bdai_ros2_wrappers.utilities import Tape


Expand Down Expand Up @@ -34,7 +36,7 @@ def __init__(
history_length = 1
self._link = link
self._tape = Tape(history_length)
self._link.registerCallback(self._tape.write)
self._link.registerCallback(lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]))
node.context.on_shutdown(self._tape.close)

@property
Expand All @@ -50,7 +52,7 @@ def history(self) -> List[Any]:
@property
def latest(self) -> Optional[Any]:
"""Gets the latest message received, if any."""
return next(self._tape.content(), None)
return self._tape.head

@property
def update(self) -> Future:
Expand Down Expand Up @@ -99,3 +101,105 @@ def stream(
def close(self) -> None:
"""Closes the message feed."""
self._tape.close()


class FramedMessageFeed(MessageFeed):
"""A message feed decorator, incorporating transforms using a `TransformFilter` instance."""

def __init__(
self,
feed: MessageFeed,
target_frame_id: str,
*,
tolerance_sec: float = 1.0,
tf_buffer: Optional[tf2_ros.Buffer] = None,
history_length: Optional[int] = None,
node: Optional[Node] = None,
) -> None:
"""Initializes the message feed.
Args:
feed: the upstream message feed.
target_frame_id: the target frame ID for transforms.
tf_buffer: optional buffer of transforms to look up. If none is provided
a transforms' buffer and a listener will be instantiated.
tolerance_sec: optional tolerance, in seconds, to wait for late transforms.
history_length: optional historic data size, defaults to 1.
node: optional node for the underlying native subscription, defaults to
the current process node.
"""
if node is None:
node = scope.ensure_node()
if tf_buffer is None:
tf_buffer = tf2_ros.Buffer()
self._tf_listener = tf2_ros.TransformListener(tf_buffer, node)
super().__init__(
TransformFilter(
feed.link,
target_frame_id,
tf_buffer,
tolerance_sec,
node.get_logger(),
),
history_length=history_length,
node=node,
)
self._feed = feed

@property
def feed(self) -> MessageFeed:
"""Gets the upstream message feed."""
return self._feed

def close(self) -> None:
"""Closes this message feed and the upstream one as well."""
self._feed.close()
super().close()


class SynchronizedMessageFeed(MessageFeed):
"""A message feeds' aggregator using a `message_filters.ApproximateTimeSynchronizer` instance."""

def __init__(
self,
*feeds: MessageFeed,
queue_size: int = 10,
delay: float = 0.2,
allow_headerless: bool = False,
history_length: Optional[int] = None,
node: Optional[Node] = None,
) -> None:
"""Initializes the message feed.
Args:
feeds: upstream message feeds to be synchronized.
queue_size: the message queue size for synchronization.
delay: the maximum delay, in seconds, between messages for synchronization to succeed.
allow_headerless: whether it's OK for there to be no header in the messages (in which
case, the ROS time of arrival will be used).
history_length: optional historic data size, defaults to 1.
node: optional node for the underlying native subscription, defaults to
the current process node.
"""
super().__init__(
ApproximateTimeSynchronizer(
[f.link for f in feeds],
queue_size,
delay,
allow_headerless=allow_headerless,
),
history_length=history_length,
node=node,
)
self._feeds = feeds

@property
def feeds(self) -> Iterable[MessageFeed]:
"""Gets all aggregated message feeds."""
return self._feeds

def close(self) -> None:
"""Closes this message feed and all upstream ones as well."""
for feed in self._feeds:
feed.close()
super().close()
2 changes: 1 addition & 1 deletion bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
"""
super().__init__()
self._logger = logger
self._lock = threading.Lock()
self._lock = threading.RLock()
self._waitqueue: collections.deque = collections.deque()
self._ongoing_wait: Optional[Future] = None
self._ongoing_wait_time: Optional[Time] = None
Expand Down
27 changes: 27 additions & 0 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ def write(self, data: Any) -> bool:
self._future_matching_writes.remove(item)
return True

@property
def head(self) -> Optional[Any]:
"""Returns the data tape head, if any."""
with self._lock:
if self._content is None:
return None
if len(self._content) == 0:
return None
return self._content[0]

def content(
self,
*,
Expand Down Expand Up @@ -450,3 +460,20 @@ def take_kwargs(func: Callable, kwargs: Mapping) -> Tuple[Mapping, Mapping]:
else:
dropped[name] = value
return taken, dropped


def ensure(value: Optional[Any]) -> Any:
"""Ensures `value` is not None or fails trying."""
if value is None:
frame = inspect.currentframe()
assert frame is not None
frame = frame.f_back
assert frame is not None
traceback = inspect.getframeinfo(frame)
message = f"{traceback.filename}:{traceback.lineno}: "
if traceback.code_context is not None:
message += "".join(traceback.code_context).strip("\n ") + " failed"
else:
message += "ensure() failed"
raise ValueError(message)
return value
75 changes: 75 additions & 0 deletions bdai_ros2_wrappers/test/test_feeds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.


import tf2_ros
from geometry_msgs.msg import (
PoseStamped,
TransformStamped,
TwistStamped,
)
from message_filters import SimpleFilter

from bdai_ros2_wrappers.feeds import (
FramedMessageFeed,
MessageFeed,
SynchronizedMessageFeed,
)
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.utilities import ensure


def test_framed_message_feed(ros: ROSAwareScope) -> None:
tf_buffer = tf2_ros.Buffer()
pose_message_feed = MessageFeed(SimpleFilter())
framed_message_feed = FramedMessageFeed(
pose_message_feed,
target_frame_id="map",
tf_buffer=tf_buffer,
node=ros.node,
)

expected_transform_message = TransformStamped()
expected_transform_message.header.frame_id = "map"
expected_transform_message.child_frame_id = "odom"
expected_transform_message.transform.translation.y = 1.0
expected_transform_message.transform.rotation.w = 1.0
tf_buffer.set_transform_static(expected_transform_message, "pytest")

expected_pose_message = PoseStamped()
expected_pose_message.header.frame_id = "odom"
expected_pose_message.header.stamp.sec = 1
expected_pose_message.pose.position.x = 1.0
expected_pose_message.pose.orientation.w = 1.0
pose_message_feed.link.signalMessage(expected_pose_message)

pose_message, transform_message = ensure(framed_message_feed.latest)
assert pose_message.pose.position.x == expected_pose_message.pose.position.x
assert transform_message.transform.translation.y == expected_transform_message.transform.translation.y


def test_synchronized_message_feed(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed(SimpleFilter())
twist_message_feed = MessageFeed(SimpleFilter())
synchronized_message_feed = SynchronizedMessageFeed(
pose_message_feed,
twist_message_feed,
node=ros.node,
)

expected_pose_message = PoseStamped()
expected_pose_message.header.frame_id = "odom"
expected_pose_message.header.stamp.sec = 1
expected_pose_message.pose.position.x = 1.0
expected_pose_message.pose.orientation.w = 1.0
pose_message_feed.link.signalMessage(expected_pose_message)

expected_twist_message = TwistStamped()
expected_twist_message.header.frame_id = "base_link"
expected_twist_message.header.stamp.sec = 1
expected_twist_message.twist.linear.x = 1.0
expected_twist_message.twist.angular.z = 1.0
twist_message_feed.link.signalMessage(expected_twist_message)

pose_message, twist_message = ensure(synchronized_message_feed.latest)
assert pose_message.pose.position.x == expected_pose_message.pose.position.x
assert twist_message.twist.linear.x == expected_twist_message.twist.linear.x
7 changes: 4 additions & 3 deletions bdai_ros2_wrappers/test/test_subscription.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.
import itertools
import time
from typing import Any, Iterator, cast
from typing import Any, Iterator

from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile
from std_msgs.msg import Int8, String
Expand All @@ -10,6 +10,7 @@
from bdai_ros2_wrappers.node import Node
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.subscription import Subscription, wait_for_message, wait_for_messages
from bdai_ros2_wrappers.utilities import ensure

DEFAULT_QOS_PROFILE = QoSProfile(
durability=DurabilityPolicy.TRANSIENT_LOCAL,
Expand Down Expand Up @@ -44,7 +45,7 @@ def test_subscription_future_wait(ros: ROSAwareScope) -> None:
pub.publish(Int8(data=1))

assert wait_for_future(sequence.update, timeout_sec=5.0)
assert cast(Int8, sequence.latest).data == 1
assert ensure(sequence.latest).data == 1


def test_subscription_matching_future_wait(ros: ROSAwareScope) -> None:
Expand Down Expand Up @@ -132,7 +133,7 @@ def deferred_cancellation() -> None:
assert len(historic_numbers) == 1
assert historic_numbers[0] == 1

assert cast(Int8, sequence.latest).data == 1
assert ensure(sequence.latest).data == 1


def test_wait_for_messages(ros: ROSAwareScope) -> None:
Expand Down
11 changes: 10 additions & 1 deletion bdai_ros2_wrappers/test/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import argparse

from bdai_ros2_wrappers.utilities import either_or, namespace_with
import pytest

from bdai_ros2_wrappers.utilities import either_or, ensure, namespace_with


def test_either_or() -> None:
Expand All @@ -18,3 +20,10 @@ def test_namespace_with() -> None:
assert namespace_with("", "foo") == "foo"
assert namespace_with("/", "foo") == "/foo"
assert namespace_with("foo", "bar") == "foo/bar"


def test_ensure() -> None:
data = None
with pytest.raises(ValueError) as excinfo:
ensure(data)
assert "ensure(data) failed" in str(excinfo.value)

0 comments on commit 1d157d6

Please sign in to comment.