Skip to content

Commit

Permalink
Add thread-safe message filters (#110)
Browse files Browse the repository at this point in the history
Follow-up to #109. It turns out that upstream `message_filters` are not quite safe to use in multi-threaded applications (at least in Python). This patch introduces thread-safe equivalents that are API compatible with upstream ones.
  • Loading branch information
mhidalgo-bdai authored Jul 29, 2024
1 parent a583e79 commit 4d07bf2
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 27 deletions.
9 changes: 4 additions & 5 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from typing import Any, Callable, Iterable, Iterator, List, Optional

import tf2_ros
from message_filters import ApproximateTimeSynchronizer, SimpleFilter
from rclpy.node import Node
from rclpy.task import Future

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.filters import SimpleAdapter, TransformFilter, Tunnel
from bdai_ros2_wrappers.filters import Adapter, ApproximateTimeSynchronizer, Filter, TransformFilter, Tunnel
from bdai_ros2_wrappers.utilities import Tape


Expand All @@ -17,7 +16,7 @@ class MessageFeed:

def __init__(
self,
link: SimpleFilter,
link: Filter,
*,
history_length: Optional[int] = None,
node: Optional[Node] = None,
Expand All @@ -39,7 +38,7 @@ def __init__(
node.context.on_shutdown(self._tape.close)

@property
def link(self) -> SimpleFilter:
def link(self) -> Filter:
"""Gets the underlying message connection."""
return self._link

Expand Down Expand Up @@ -129,7 +128,7 @@ def __init__(
kwargs: all other keyword arguments are forwarded
for `MessageFeed` initialization.
"""
super().__init__(SimpleAdapter(feed.link, fn), **kwargs)
super().__init__(Adapter(feed.link, fn), **kwargs)
self._feed = feed

@property
Expand Down
112 changes: 101 additions & 11 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,110 @@

import collections
import functools
import itertools
import threading
from collections.abc import Sequence
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, Optional, Protocol, Tuple

import message_filters
import tf2_ros
from message_filters import SimpleFilter
from rclpy.duration import Duration
from rclpy.node import Node
from rclpy.task import Future
from rclpy.time import Time

from bdai_ros2_wrappers.logging import RcutilsLogger


class TransformFilter(SimpleFilter):
class SimpleFilterProtocol(Protocol):
"""Protocol for `message_filters.SimpleFilter` subclasses."""

def registerCallback(self, callback: Callable, *args: Any) -> int:
"""Register callable to be called on filter output."""

def signalMessage(self, *messages: Any) -> None:
"""Feed one or more `messages` to the filter."""


class Filter(SimpleFilterProtocol):
"""A threadsafe `message_filters.SimpleFilter` compliant message filter."""

def __init__(self) -> None:
self.__lock = threading.Lock()
self._connection_sequence = itertools.count()
self.callbacks: Dict[int, Tuple[Callable, Tuple]] = {}

def registerCallback(self, fn: Callable, *args: Any) -> int:
"""Register callable to be called on filter output.
Args:
fn: callback callable.
args: optional positional arguments to supply on call.
Returns:
a unique connection identifier.
"""
with self.__lock:
connection = next(self._connection_sequence)
self.callbacks[connection] = (fn, args)
return connection

def unregisterCallback(self, connection: int) -> None:
"""Unregister a callback.
Args:
connection: unique identifier for the callback.
"""
with self.__lock:
del self.callbacks[connection]

def signalMessage(self, *messages: Any) -> None:
"""Feed one or more `messages` to the filter."""
with self.__lock:
callbacks = list(self.callbacks.values())

for fn, args in callbacks:
fn(*(messages + args))


class Subscriber(Filter):
"""A threadsafe `message_filters.Subscriber` equivalent."""

def __init__(self, node: Node, *args: Any, **kwargs: Any) -> None:
"""Initializes the `Subscriber` instance.
All positional and keyword arguments are forwarded
to `rclpy.node.Node.create_subscription`.
"""
super().__init__()
self.sub = node.create_subscription(
*args,
callback=self.signalMessage,
**kwargs,
)

def __getattr__(self, name: str) -> Any:
return getattr(self.subscription, name)


class ApproximateTimeSynchronizer(Filter):
"""A threadsafe `message_filters.ApproximateTimeSynchronizer` equivalent."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initializes the `ApproximateTimeSynchronizer` instance.
All positional and keyword arguments are forwarded
to the underlying `message_filters.ApproximateTimeSynchronizer`.
"""
super().__init__()
self._unsafe_synchronizer = message_filters.ApproximateTimeSynchronizer(*args, **kwargs)
self._unsafe_synchronizer.registerCallback(self.signalMessage)

def __getattr__(self, name: str) -> Any:
return getattr(self._unsafe_synchronizer, name)


class TransformFilter(Filter):
"""A :mod:`tf2_ros` driven message filter, ensuring user defined transforms' availability.
This filter passes a stamped transform along with filtered messages, from message frame ID
Expand All @@ -25,7 +115,7 @@ class TransformFilter(SimpleFilter):

def __init__(
self,
upstream: SimpleFilter,
upstream: Filter,
target_frame_id: str,
tf_buffer: tf2_ros.Buffer,
tolerance_sec: float,
Expand Down Expand Up @@ -117,29 +207,29 @@ def add(self, *messages: Any) -> None:
self._ongoing_wait.add_done_callback(functools.partial(self._wait_callback, messages))


class SimpleAdapter(SimpleFilter):
class Adapter(Filter):
"""A message filter for data adaptation."""

def __init__(self, upstream: SimpleFilter, fn: Callable) -> None:
def __init__(self, upstream: Filter, fn: Callable) -> None:
"""Initializes the adapter.
Args:
upstream: the upstream message filter.
fn: adapter implementation as a callable.
"""
super().__init__()
self.do_adapt = fn
self.fn = fn
self.connection = upstream.registerCallback(self.add)

def add(self, *messages: Any) -> None:
"""Adds new `messages` to the adapter."""
self.signalMessage(self.do_adapt(*messages))
self.signalMessage(self.fn(*messages))


class Tunnel(SimpleFilter):
class Tunnel(Filter):
"""A message filter that simply forwards messages but can be detached."""

def __init__(self, upstream: SimpleFilter) -> None:
def __init__(self, upstream: Filter) -> None:
"""Initializes the tunnel.
Args:
Expand All @@ -151,4 +241,4 @@ def __init__(self, upstream: SimpleFilter) -> None:

def close(self) -> None:
"""Closes the tunnel, disconnecting it from upstream."""
del self.upstream.callbacks[self.connection]
self.upstream.unregisterCallback(self.connection)
2 changes: 1 addition & 1 deletion bdai_ros2_wrappers/bdai_ros2_wrappers/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from collections.abc import Sequence
from typing import Any, Optional, Type, Union, cast

from message_filters import ApproximateTimeSynchronizer, Subscriber
from rclpy.callback_groups import CallbackGroup
from rclpy.node import Node
from rclpy.qos import QoSProfile
from rclpy.task import Future

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.feeds import MessageFeed
from bdai_ros2_wrappers.filters import ApproximateTimeSynchronizer, Subscriber
from bdai_ros2_wrappers.futures import wait_for_future
from bdai_ros2_wrappers.type_hints import Msg as MessageT

Expand Down
12 changes: 6 additions & 6 deletions bdai_ros2_wrappers/test/test_feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
TransformStamped,
TwistStamped,
)
from message_filters import SimpleFilter

from bdai_ros2_wrappers.feeds import (
AdaptedMessageFeed,
FramedMessageFeed,
MessageFeed,
SynchronizedMessageFeed,
)
from bdai_ros2_wrappers.filters import Filter
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.utilities import ensure


def test_framed_message_feed(ros: ROSAwareScope) -> None:
tf_buffer = tf2_ros.Buffer()
pose_message_feed = MessageFeed(SimpleFilter())
pose_message_feed = MessageFeed(Filter())
framed_message_feed = FramedMessageFeed(
pose_message_feed,
target_frame_id="map",
Expand Down Expand Up @@ -50,8 +50,8 @@ def test_framed_message_feed(ros: ROSAwareScope) -> None:


def test_synchronized_message_feed(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed(SimpleFilter())
twist_message_feed = MessageFeed(SimpleFilter())
pose_message_feed = MessageFeed(Filter())
twist_message_feed = MessageFeed(Filter())
synchronized_message_feed = SynchronizedMessageFeed(
pose_message_feed,
twist_message_feed,
Expand All @@ -78,7 +78,7 @@ def test_synchronized_message_feed(ros: ROSAwareScope) -> None:


def test_adapted_message_feed(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed(SimpleFilter())
pose_message_feed = MessageFeed(Filter())
position_message_feed = AdaptedMessageFeed(
pose_message_feed,
fn=lambda message: message.pose.position,
Expand All @@ -98,7 +98,7 @@ def test_adapted_message_feed(ros: ROSAwareScope) -> None:


def test_message_feed_recalls(ros: ROSAwareScope) -> None:
pose_message_feed = MessageFeed(SimpleFilter())
pose_message_feed = MessageFeed(Filter())

latest_message: Optional[PoseStamped] = None

Expand Down
7 changes: 3 additions & 4 deletions bdai_ros2_wrappers/test/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@

import tf2_ros
from geometry_msgs.msg import PoseStamped, TransformStamped
from message_filters import SimpleFilter

from bdai_ros2_wrappers.filters import TransformFilter
from bdai_ros2_wrappers.filters import Filter, TransformFilter


def test_transform_wait() -> None:
source = SimpleFilter()
source = Filter()
tf_buffer = tf2_ros.Buffer()
tf_filter = TransformFilter(source, "map", tf_buffer, tolerance_sec=1.0)
sink: List[Tuple[PoseStamped, TransformStamped]] = []
Expand Down Expand Up @@ -39,7 +38,7 @@ def test_transform_wait() -> None:


def test_old_transform_filtering() -> None:
source = SimpleFilter()
source = Filter()
tf_buffer = tf2_ros.Buffer()
tf_filter = TransformFilter(source, "map", tf_buffer, tolerance_sec=2.0)
sink: List[Tuple[PoseStamped, TransformStamped]] = []
Expand Down

0 comments on commit 4d07bf2

Please sign in to comment.