Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow adapted feeds to filter messages #123

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

from typing import Any, Callable, Generator, Generic, Iterable, List, Literal, Optional, TypeVar, Union, overload
from typing import (
Any,
Callable,
Generator,
Generic,
Iterable,
List,
Literal,
Optional,
TypeVar,
Union,
overload,
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

black decided it was time to reformat 🤷‍♂️

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a line length limit for black.


import tf2_ros
from rclpy.node import Node

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.filters import Adapter, ApproximateTimeSynchronizer, Filter, TransformFilter, Tunnel
from bdai_ros2_wrappers.filters import (
Adapter,
ApproximateTimeSynchronizer,
Filter,
TransformFilter,
Tunnel,
)
from bdai_ros2_wrappers.futures import FutureLike
from bdai_ros2_wrappers.utilities import Tape

Expand Down Expand Up @@ -36,7 +54,9 @@ def __init__(
history_length = 1
self._link = link
self._tape: Tape[MessageT] = Tape(history_length)
self._link.registerCallback(lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]))
self._link.registerCallback(
lambda *msgs: self._tape.write(msgs if len(msgs) > 1 else msgs[0]),
)
node.context.on_shutdown(self._tape.close)

@property
Expand All @@ -59,7 +79,10 @@ def update(self) -> FutureLike[MessageT]:
"""Gets the future to the next message yet to be received."""
return self._tape.future_write

def matching_update(self, matching_predicate: Callable[[MessageT], bool]) -> FutureLike[MessageT]:
def matching_update(
self,
matching_predicate: Callable[[MessageT], bool],
) -> FutureLike[MessageT]:
"""Gets a future to the next matching message yet to be received.

Args:
Expand Down
21 changes: 16 additions & 5 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
to the underlying `message_filters.ApproximateTimeSynchronizer`.
"""
super().__init__()
self._unsafe_synchronizer = message_filters.ApproximateTimeSynchronizer(*args, **kwargs)
self._unsafe_synchronizer = message_filters.ApproximateTimeSynchronizer(
*args,
**kwargs,
)
self._unsafe_synchronizer.registerCallback(self.signalMessage)

def __getattr__(self, name: str) -> Any:
Expand Down Expand Up @@ -175,7 +178,9 @@ def _wait_callback(self, messages: Sequence[Any], future: Future) -> None:
time,
)
self._ongoing_wait_time = time
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))
self._ongoing_wait.add_done_callback(
functools.partial(self._wait_callback, messages),
)
else:
self._ongoing_wait_time = None
self._ongoing_wait = None
Expand Down Expand Up @@ -204,7 +209,9 @@ def add(self, *messages: Any) -> None:
time,
)
self._ongoing_wait_time = time
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))
self._ongoing_wait.add_done_callback(
functools.partial(self._wait_callback, messages),
)


class Adapter(Filter):
Expand All @@ -215,15 +222,19 @@ def __init__(self, upstream: Filter, fn: Callable) -> None:

Args:
upstream: the upstream message filter.
fn: adapter implementation as a callable.
fn: a callable that takes messages as arguments and returns some
data to be signaled (i.e. propagated down the filter chain).
If none is returned, no message signaling will occur.
"""
super().__init__()
self.fn = fn
self.connection = upstream.registerCallback(self.add)

def add(self, *messages: Any) -> None:
"""Adds new `messages` to the adapter."""
self.signalMessage(self.fn(*messages))
result = self.fn(*messages)
Copy link
Contributor Author

@mhidalgo-bdai mhidalgo-bdai Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amessing-bdai ended up going with the trivial implementation. There is a way, an async way, for fn to nicely embed complex, stateful logic in a functional way, but it was more trouble than it was worth. If one needs lots of bookkeeping and control flow for an adapter, it may be about time to write a new filter.

if result is not None:
self.signalMessage(result)


class Tunnel(Filter):
Expand Down
28 changes: 28 additions & 0 deletions bdai_ros2_wrappers/test/test_feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,34 @@ def test_adapted_message_feed(ros: ROSAwareScope) -> None:
assert position_message is expected_pose_message.pose.position


def test_masked_message_feed(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed[PoseStamped](Filter())
position_masking_feed = AdaptedMessageFeed[Point](
pose_message_feed,
fn=lambda message: message if message.pose.position.x > 0.0 else None,
)
expected_pose_message0 = PoseStamped()
expected_pose_message0.header.frame_id = "odom"
expected_pose_message0.header.stamp.sec = 1
expected_pose_message0.pose.position.x = -1.0
expected_pose_message0.pose.position.z = -1.0
expected_pose_message0.pose.orientation.w = 1.0
pose_message_feed.link.signalMessage(expected_pose_message0)
assert position_masking_feed.latest is None

expected_pose_message1 = PoseStamped()
expected_pose_message1.header.frame_id = "odom"
expected_pose_message1.header.stamp.sec = 2
expected_pose_message1.pose.position.x = 1.0
expected_pose_message1.pose.position.z = -1.0
expected_pose_message1.pose.orientation.w = 1.0
pose_message_feed.link.signalMessage(expected_pose_message1)

pose_message: Point = ensure(position_masking_feed.latest)
# no copies are expected, thus an identity check is valid
assert pose_message is expected_pose_message1


def test_message_feed_recalls(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed[PoseStamped](Filter())

Expand Down
Loading