From 03038ab08cca6d6e1cb83996779d4dd6893108aa Mon Sep 17 00:00:00 2001 From: mhidalgo-bdai <144129882+mhidalgo-bdai@users.noreply.github.com> Date: Tue, 26 Mar 2024 19:28:30 -0300 Subject: [PATCH] Cancel pending work for destroyed nodes (#80) Signed-off-by: Michel Hidalgo <mhidalgo@theaiinstitute.com> --- .../bdai_ros2_wrappers/executors.py | 24 +++++-- bdai_ros2_wrappers/bdai_ros2_wrappers/node.py | 11 +++ bdai_ros2_wrappers/test/test_node.py | 69 +++++++++++++++++++ 3 files changed, 98 insertions(+), 6 deletions(-) create mode 100644 bdai_ros2_wrappers/test/test_node.py diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py index 3f5a819..43809fb 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py @@ -510,17 +510,29 @@ class AutoScalingMultiThreadedExecutor(rclpy.executors.Executor): class Task: """A bundle of an executable task and its associated entity.""" - def __init__(self, task: rclpy.task.Task, entity: typing.Optional[rclpy.executors.WaitableEntityType]) -> None: + def __init__( + self, + task: rclpy.task.Task, + entity: typing.Optional[rclpy.executors.WaitableEntityType], + node: typing.Optional[rclpy.node.Node], + ) -> None: self.task = task self.entity = entity + if node is not None and hasattr(node, "destruction_requested"): + self.valid = lambda: not node.destruction_requested # type: ignore + else: + self.valid = lambda: True self.callback_group = entity.callback_group if entity is not None else None def __call__(self) -> None: - """Run or resume a task + """Run or resume a task. See rclpy.task.Task documentation for further reference. 
""" - self.task.__call__() + if not self.valid(): + self.cancel() + return + self.task() def __getattr__(self, name: str) -> typing.Any: return getattr(self.task, name) @@ -591,7 +603,7 @@ def _do_spin_once(self, *args: typing.Any, **kwargs: typing.Any) -> None: with self._spin_lock: try: task, entity, node = self.wait_for_ready_callbacks(*args, **kwargs) - task = AutoScalingMultiThreadedExecutor.Task(task, entity) + task = AutoScalingMultiThreadedExecutor.Task(task, entity, node) with self._shutdown_lock: if self._is_shutdown: # Ignore task, let shutdown clean it up. @@ -665,8 +677,8 @@ def shutdown(self, timeout_sec: typing.Optional[float] = None) -> bool: with self._spin_lock: # rclpy.executors.Executor base implementation leaves tasks # unawaited upon shutdown. Do the housekeepng. - for task, entity, _ in self._tasks: - task = AutoScalingMultiThreadedExecutor.Task(task, entity) + for task, entity, node in self._tasks: + task = AutoScalingMultiThreadedExecutor.Task(task, entity, node) task.cancel() return done diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/node.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/node.py index 185cc28..9eb6489 100644 --- a/bdai_ros2_wrappers/bdai_ros2_wrappers/node.py +++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/node.py @@ -28,6 +28,7 @@ def __init__(self, *args: Any, default_callback_group: Optional[CallbackGroup] = if default_callback_group is None: default_callback_group = NonReentrantCallbackGroup() self._default_callback_group_override = default_callback_group + self._destruction_requested = False super().__init__(*args, **kwargs) @property @@ -35,3 +36,13 @@ def default_callback_group(self) -> CallbackGroup: """Get the default callback group.""" # NOTE(hidmic): this overrides the hardcoded default group in rclpy.node.Node implementation return self._default_callback_group_override + + @property + def destruction_requested(self) -> bool: + """Checks whether destruction was requested or not.""" + return self._destruction_requested 
+ + def destroy_node(self) -> None: + """Overrides node destruction API.""" + self._destruction_requested = True + super().destroy_node() diff --git a/bdai_ros2_wrappers/test/test_node.py b/bdai_ros2_wrappers/test/test_node.py new file mode 100644 index 0000000..df97259 --- /dev/null +++ b/bdai_ros2_wrappers/test/test_node.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved. + +import threading +from typing import Generator + +import pytest +import rclpy +from rclpy.context import Context +from std_srvs.srv import Trigger + +from bdai_ros2_wrappers.executors import AutoScalingMultiThreadedExecutor +from bdai_ros2_wrappers.node import Node + + +@pytest.fixture +def ros_context() -> Generator[Context, None, None]: + """A fixture yielding a managed rclpy.context.Context instance.""" + context = Context() + rclpy.init(context=context) + try: + yield context + finally: + context.try_shutdown() + + +def test_node_destruction_during_execution(ros_context: Context) -> None: + """Asserts that the autoscaling multithreaded executor cancels work + that is still pending for a node after that node's destruction has + been requested. 
+ """ + + def dummy_server_callback(_: Trigger.Request, response: Trigger.Response) -> Trigger.Response: + response.success = True + return response + + node = Node("pytest_node", context=ros_context) + node.create_service(Trigger, "/dummy/trigger", dummy_server_callback) + client = node.create_client(Trigger, "/dummy/trigger") + + executor = AutoScalingMultiThreadedExecutor(max_threads=1, context=ros_context) + executor.add_node(node) + + barrier = threading.Barrier(2) + try: + # First smoke test the executor with a service invocation + future = client.call_async(Trigger.Request()) + executor.spin_until_future_complete(future, timeout_sec=5.0) + assert future.done() and future.result().success + # Then block its sole worker thread + executor.create_task(lambda: barrier.wait()) + executor.spin_once() + # Then queue node destruction + executor.create_task(lambda: node.destroy_node()) + executor.spin_once() + assert not node.destruction_requested # still queued + # Then queue another service invocation + future = client.call_async(Trigger.Request()) + executor.spin_once() + # Unblock worker thread in executor + barrier.wait() + # Check that executor wraps up early due to node destruction + executor.spin_until_future_complete(future, timeout_sec=5.0) + assert node.destruction_requested + assert executor.thread_pool.wait(timeout=5.0) + assert not future.done() # future response will never be resolved + finally: + barrier.reset() + executor.remove_node(node) + executor.shutdown()