Skip to content

Commit

Permalink
Merge branch 'main' into mhidalgo-bdai/time-utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
mhidalgo-bdai authored Jun 11, 2024
2 parents c3262c8 + 2c14e4b commit cc69e9e
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 5 deletions.
36 changes: 33 additions & 3 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rclpy.task import Future

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.filters import TransformFilter
from bdai_ros2_wrappers.filters import SimpleAdapter, TransformFilter
from bdai_ros2_wrappers.utilities import Tape


Expand All @@ -27,8 +27,7 @@ def __init__(
Args:
link: Wrapped message filter, connecting this message feed with its source.
history_length: optional historic data size, defaults to 1
node: optional node for the underlying native subscription, defaults to
the current process node.
node: optional node for lifetime control, defaults to the current process node.
"""
if node is None:
node = scope.ensure_node()
Expand Down Expand Up @@ -103,6 +102,37 @@ def close(self) -> None:
self._tape.close()


class AdaptedMessageFeed(MessageFeed):
"""A message feed decorator to simplify adapter patterns."""

def __init__(
self,
feed: MessageFeed,
fn: Callable,
**kwargs: Any,
) -> None:
"""Initializes the message feed.
Args:
feed: the upstream (ie. decorated) message feed.
fn: message adapting callable.
kwargs: all other keyword arguments are forwarded
for `MessageFeed` initialization.
"""
super().__init__(SimpleAdapter(feed.link, fn), **kwargs)
self._feed = feed

@property
def feed(self) -> MessageFeed:
"""Gets the upstream message feed."""
return self._feed

def close(self) -> None:
"""Closes this message feed and the upstream one as well."""
self._feed.close()
super().close()


class FramedMessageFeed(MessageFeed):
"""A message feed decorator, incorporating transforms using a `TransformFilter` instance."""

Expand Down
21 changes: 20 additions & 1 deletion bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import functools
import threading
from collections.abc import Sequence
from typing import Any, Optional
from typing import Any, Callable, Optional

import tf2_ros
from message_filters import SimpleFilter
Expand Down Expand Up @@ -115,3 +115,22 @@ def add(self, *messages: Any) -> None:
)
self._ongoing_wait_time = time
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))


class SimpleAdapter(SimpleFilter):
"""A message filter for data adaptation."""

def __init__(self, f: SimpleFilter, fn: Callable) -> None:
"""Initializes the adapter.
Args:
f: the upstream message filter.
fn: adapter implementation as a callable.
"""
super().__init__()
self.do_adapt = fn
self.incoming_connection = f.registerCallback(self.add)

def add(self, *messages: Any) -> None:
"""Adds new `messages` to the adapter."""
self.signalMessage(self.do_adapt(*messages))
20 changes: 19 additions & 1 deletion bdai_ros2_wrappers/bdai_ros2_wrappers/futures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.
from threading import Event
from typing import Awaitable, Callable, Optional, Protocol, TypeVar, Union, runtime_checkable
from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar, Union, runtime_checkable

from rclpy.context import Context
from rclpy.utilities import get_default_context
Expand Down Expand Up @@ -84,3 +84,21 @@ def wait_for_future(
event.set()
event.wait(timeout=timeout_sec)
return proper_future.done()


def unwrap_future(
future: AnyFuture,
timeout_sec: Optional[float] = None,
*,
context: Optional[Context] = None,
) -> Any:
"""Fetch future result when it is done.
Note this function may block and may raise if the future does or it times out
waiting for it. See wait_for_future() documentation for further reference on
arguments taken.
"""
proper_future = as_proper_future(future)
if not wait_for_future(proper_future, timeout_sec, context=context):
raise ValueError("cannot unwrap future that is not done")
return proper_future.result()
21 changes: 21 additions & 0 deletions bdai_ros2_wrappers/test/test_feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from message_filters import SimpleFilter

from bdai_ros2_wrappers.feeds import (
AdaptedMessageFeed,
FramedMessageFeed,
MessageFeed,
SynchronizedMessageFeed,
Expand Down Expand Up @@ -73,3 +74,23 @@ def test_synchronized_message_feed(ros: ROSAwareScope) -> None:
pose_message, twist_message = ensure(synchronized_message_feed.latest)
assert pose_message.pose.position.x == expected_pose_message.pose.position.x
assert twist_message.twist.linear.x == expected_twist_message.twist.linear.x


def test_adapted_message_feed(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed(SimpleFilter())
position_message_feed = AdaptedMessageFeed(
pose_message_feed,
fn=lambda message: message.pose.position,
)

expected_pose_message = PoseStamped()
expected_pose_message.header.frame_id = "odom"
expected_pose_message.header.stamp.sec = 1
expected_pose_message.pose.position.x = 1.0
expected_pose_message.pose.position.z = -1.0
expected_pose_message.pose.orientation.w = 1.0
pose_message_feed.link.signalMessage(expected_pose_message)

position_message = ensure(position_message_feed.latest)
# no copies are expected, thus an identity check is valid
assert position_message is expected_pose_message.pose.position

0 comments on commit cc69e9e

Please sign in to comment.