Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TransformFilter #98

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

import collections
import functools
import threading
from collections.abc import Sequence
from typing import Any, Optional

import tf2_ros
from message_filters import SimpleFilter
from rclpy.duration import Duration
from rclpy.task import Future
from rclpy.time import Time

from bdai_ros2_wrappers.logging import RcutilsLogger


class TransformFilter(SimpleFilter):
"""A :mod:`tf2_ros` driven message filter, ensuring user defined transforms' availability.

This filter passes a stamped transform along with filtered messages, from message frame ID
to the given target frame ID, looked up at the message timestamp. The message is assumed
to have a header. When filtering message tuples, only the first message header is observed.
"""

def __init__(
self,
f: SimpleFilter,
target_frame_id: str,
tf_buffer: tf2_ros.Buffer,
tolerance_sec: float,
logger: Optional[RcutilsLogger] = None,
) -> None:
"""Initializes the transform filter.

Args:
f: the upstream message filter.
target_frame_id: the target frame ID for transforms.
tf_buffer: a buffer of transforms to look up.
tolerance_sec: a tolerance, in seconds, to wait for late transforms
before abandoning any waits and filtering out the corresponding messages.
logger: an optional logger to notify the yser about any errors during filtering.
"""
super().__init__()
self._logger = logger
self._lock = threading.Lock()
self._waitqueue: collections.deque = collections.deque()
self._ongoing_wait: Optional[Future] = None
self._ongoing_wait_time: Optional[Time] = None
self.target_frame_id = target_frame_id
self.tf_buffer = tf_buffer
self.tolerance = Duration(seconds=tolerance_sec)
self.incoming_connection = f.registerCallback(self.add)

def _wait_callback(self, messages: Sequence[Any], future: Future) -> None:
if future.cancelled():
return
with self._lock:
try:
if future.result() is True:
source_frame_id = messages[0].header.frame_id
time = Time.from_msg(messages[0].header.stamp)
transform = self.tf_buffer.lookup_transform(
self.target_frame_id,
source_frame_id,
time,
)
self.signalMessage(*messages, transform)
except tf2_ros.TransformException as e:
if self._logger is not None:
self._logger.error(
(
"Got an exception during transform lookup: %s",
str(e),
),
)

if self._waitqueue:
messages = self._waitqueue.popleft()
source_frame_id = messages[0].header.frame_id
time = Time.from_msg(messages[0].header.stamp)
self._ongoing_wait = self.tf_buffer.wait_for_transform_async(
self.target_frame_id,
source_frame_id,
time,
)
self._ongoing_wait_time = time
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))
else:
self._ongoing_wait_time = None
self._ongoing_wait = None

def add(self, *messages: Any) -> None:
"""Adds new `messages` to the filter."""
with self._lock:
time = Time.from_msg(messages[0].header.stamp)
if self._ongoing_wait and not self._ongoing_wait.done() and time - self._ongoing_wait_time > self.tolerance:
self._ongoing_wait.cancel()
self._ongoing_wait = None
while self._waitqueue:
pending_messages = self._waitqueue[0]
pending_time = Time.from_msg(pending_messages[0].header.stamp)
if time - pending_time <= self.tolerance:
break
self._waitqueue.popleft()
self._waitqueue.append(messages)
if not self._ongoing_wait:
messages = self._waitqueue.popleft()
source_frame_id = messages[0].header.frame_id
time = Time.from_msg(messages[0].header.stamp)
self._ongoing_wait = self.tf_buffer.wait_for_transform_async(
self.target_frame_id,
source_frame_id,
time,
)
self._ongoing_wait_time = time
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))
73 changes: 73 additions & 0 deletions bdai_ros2_wrappers/test/test_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

from typing import List, Tuple

import tf2_ros
from geometry_msgs.msg import PoseStamped, TransformStamped
from message_filters import SimpleFilter

from bdai_ros2_wrappers.filters import TransformFilter


def test_transform_wait() -> None:
source = SimpleFilter()
tf_buffer = tf2_ros.Buffer()
tf_filter = TransformFilter(source, "map", tf_buffer, tolerance_sec=1.0)
sink: List[Tuple[PoseStamped, TransformStamped]] = []
tf_filter.registerCallback(lambda *msgs: sink.append(msgs))
assert len(sink) == 0
pose_message = PoseStamped()
pose_message.header.frame_id = "odom"
pose_message.header.stamp.sec = 1
pose_message.pose.position.x = 1.0
pose_message.pose.orientation.w = 1.0
source.signalMessage(pose_message)
assert len(sink) == 0
transform_message = TransformStamped()
transform_message.header.frame_id = "map"
transform_message.child_frame_id = "odom"
transform_message.transform.rotation.w = 1.0
tf_buffer.set_transform_static(transform_message, "pytest")
assert len(sink) == 1
filtered_pose_message, filtered_transform_message = sink.pop()
assert filtered_pose_message.header.frame_id == pose_message.header.frame_id
assert filtered_pose_message.pose.position.x == pose_message.pose.position.x
assert filtered_pose_message.pose.orientation.w == pose_message.pose.orientation.w
assert filtered_transform_message.header.frame_id == transform_message.header.frame_id
assert filtered_transform_message.child_frame_id == transform_message.child_frame_id
assert filtered_transform_message.transform.rotation.w == transform_message.transform.rotation.w


def test_old_transform_filtering() -> None:
source = SimpleFilter()
tf_buffer = tf2_ros.Buffer()
tf_filter = TransformFilter(source, "map", tf_buffer, tolerance_sec=2.0)
sink: List[Tuple[PoseStamped, TransformStamped]] = []
tf_filter.registerCallback(lambda *msgs: sink.append(msgs))
assert len(sink) == 0
first_pose_message = PoseStamped()
first_pose_message.header.frame_id = "odom"
first_pose_message.pose.position.x = 1.0
first_pose_message.pose.orientation.w = 1.0
source.signalMessage(first_pose_message)
assert len(sink) == 0
second_pose_message = PoseStamped()
second_pose_message.header.frame_id = "odom"
second_pose_message.header.stamp.sec = 10
second_pose_message.pose.position.x = 2.0
second_pose_message.pose.orientation.w = 1.0
source.signalMessage(second_pose_message)
assert len(sink) == 0
transform_message = TransformStamped()
transform_message.header.frame_id = "map"
transform_message.child_frame_id = "odom"
transform_message.transform.rotation.w = 1.0
tf_buffer.set_transform_static(transform_message, "pytest")
assert len(sink) == 1
filtered_pose_message, filtered_transform_message = sink.pop()
assert filtered_pose_message.header.frame_id == second_pose_message.header.frame_id
assert filtered_pose_message.pose.position.x == second_pose_message.pose.position.x
assert filtered_pose_message.pose.orientation.w == second_pose_message.pose.orientation.w
assert filtered_transform_message.header.frame_id == transform_message.header.frame_id
assert filtered_transform_message.child_frame_id == transform_message.child_frame_id
assert filtered_transform_message.transform.rotation.w == transform_message.transform.rotation.w
Loading