From 03038ab08cca6d6e1cb83996779d4dd6893108aa Mon Sep 17 00:00:00 2001
From: mhidalgo-bdai <144129882+mhidalgo-bdai@users.noreply.github.com>
Date: Tue, 26 Mar 2024 19:28:30 -0300
Subject: [PATCH] Cancel pending work for destroyed nodes (#80)

Signed-off-by: Michel Hidalgo <mhidalgo@theaiinstitute.com>
---
 .../bdai_ros2_wrappers/executors.py           | 24 +++++--
 bdai_ros2_wrappers/bdai_ros2_wrappers/node.py | 11 +++
 bdai_ros2_wrappers/test/test_node.py          | 69 +++++++++++++++++++
 3 files changed, 98 insertions(+), 6 deletions(-)
 create mode 100644 bdai_ros2_wrappers/test/test_node.py

diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py
index 3f5a819..43809fb 100644
--- a/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py
+++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py
@@ -510,17 +510,29 @@ class AutoScalingMultiThreadedExecutor(rclpy.executors.Executor):
     class Task:
         """A bundle of an executable task and its associated entity."""
 
-        def __init__(self, task: rclpy.task.Task, entity: typing.Optional[rclpy.executors.WaitableEntityType]) -> None:
+        def __init__(
+            self,
+            task: rclpy.task.Task,
+            entity: typing.Optional[rclpy.executors.WaitableEntityType],
+            node: typing.Optional[rclpy.node.Node],
+        ) -> None:
             self.task = task
             self.entity = entity
+            if node is not None and hasattr(node, "destruction_requested"):
+                self.valid = lambda: not node.destruction_requested  # type: ignore
+            else:
+                self.valid = lambda: True
             self.callback_group = entity.callback_group if entity is not None else None
 
         def __call__(self) -> None:
-            """Run or resume a task
+            """Run or resume a task.
 
             See rclpy.task.Task documentation for further reference.
             """
-            self.task.__call__()
+            if not self.valid():
+                self.cancel()
+                return
+            self.task()
 
         def __getattr__(self, name: str) -> typing.Any:
             return getattr(self.task, name)
@@ -591,7 +603,7 @@ def _do_spin_once(self, *args: typing.Any, **kwargs: typing.Any) -> None:
         with self._spin_lock:
             try:
                 task, entity, node = self.wait_for_ready_callbacks(*args, **kwargs)
-                task = AutoScalingMultiThreadedExecutor.Task(task, entity)
+                task = AutoScalingMultiThreadedExecutor.Task(task, entity, node)
                 with self._shutdown_lock:
                     if self._is_shutdown:
                         # Ignore task, let shutdown clean it up.
@@ -665,8 +677,8 @@ def shutdown(self, timeout_sec: typing.Optional[float] = None) -> bool:
             with self._spin_lock:
                 # rclpy.executors.Executor base implementation leaves tasks
                 # unawaited upon shutdown. Do the housekeepng.
-                for task, entity, _ in self._tasks:
-                    task = AutoScalingMultiThreadedExecutor.Task(task, entity)
+                for task, entity, node in self._tasks:
+                    task = AutoScalingMultiThreadedExecutor.Task(task, entity, node)
                     task.cancel()
         return done
 
diff --git a/bdai_ros2_wrappers/bdai_ros2_wrappers/node.py b/bdai_ros2_wrappers/bdai_ros2_wrappers/node.py
index 185cc28..9eb6489 100644
--- a/bdai_ros2_wrappers/bdai_ros2_wrappers/node.py
+++ b/bdai_ros2_wrappers/bdai_ros2_wrappers/node.py
@@ -28,6 +28,7 @@ def __init__(self, *args: Any, default_callback_group: Optional[CallbackGroup] =
         if default_callback_group is None:
             default_callback_group = NonReentrantCallbackGroup()
         self._default_callback_group_override = default_callback_group
+        self._destruction_requested = False
         super().__init__(*args, **kwargs)
 
     @property
@@ -35,3 +36,13 @@ def default_callback_group(self) -> CallbackGroup:
         """Get the default callback group."""
         # NOTE(hidmic): this overrides the hardcoded default group in rclpy.node.Node implementation
         return self._default_callback_group_override
+
+    @property
+    def destruction_requested(self) -> bool:
+        """Checks whether destruction was requested or not."""
+        return self._destruction_requested
+
+    def destroy_node(self) -> None:
+        """Overrides node destruction API."""
+        self._destruction_requested = True
+        super().destroy_node()
diff --git a/bdai_ros2_wrappers/test/test_node.py b/bdai_ros2_wrappers/test/test_node.py
new file mode 100644
index 0000000..df97259
--- /dev/null
+++ b/bdai_ros2_wrappers/test/test_node.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024 Boston Dynamics AI Institute Inc.  All rights reserved.
+
+import threading
+from typing import Generator
+
+import pytest
+import rclpy
+from rclpy.context import Context
+from std_srvs.srv import Trigger
+
+from bdai_ros2_wrappers.executors import AutoScalingMultiThreadedExecutor
+from bdai_ros2_wrappers.node import Node
+
+
+@pytest.fixture
+def ros_context() -> Generator[Context, None, None]:
+    """A fixture yielding a managed rclpy.context.Context instance."""
+    context = Context()
+    rclpy.init(context=context)
+    try:
+        yield context
+    finally:
+        context.try_shutdown()
+
+
+def test_node_destruction_during_execution(ros_context: Context) -> None:
+    """Asserts that requesting node destruction while the executor still has
+    pending work queued causes that work to be cancelled, so the executor
+    winds down early instead of servicing it.
+    """
+
+    def dummy_server_callback(_: Trigger.Request, response: Trigger.Response) -> Trigger.Response:
+        response.success = True
+        return response
+
+    node = Node("pytest_node", context=ros_context)
+    node.create_service(Trigger, "/dummy/trigger", dummy_server_callback)
+    client = node.create_client(Trigger, "/dummy/trigger")
+
+    executor = AutoScalingMultiThreadedExecutor(max_threads=1, context=ros_context)
+    executor.add_node(node)
+
+    barrier = threading.Barrier(2)
+    try:
+        # First smoke test the executor with a service invocation
+        future = client.call_async(Trigger.Request())
+        executor.spin_until_future_complete(future, timeout_sec=5.0)
+        assert future.done() and future.result().success
+        # Then block its sole worker thread
+        executor.create_task(lambda: barrier.wait())
+        executor.spin_once()
+        # Then queue node destruction
+        executor.create_task(lambda: node.destroy_node())
+        executor.spin_once()
+        assert not node.destruction_requested  # still queued
+        # Then queue another service invocation
+        future = client.call_async(Trigger.Request())
+        executor.spin_once()
+        # Unblock worker thread in executor
+        barrier.wait()
+        # Check that executor wraps up early due to node destruction
+        executor.spin_until_future_complete(future, timeout_sec=5.0)
+        assert node.destruction_requested
+        assert executor.thread_pool.wait(timeout=5.0)
+        assert not future.done()  # future response will never be resolved
+    finally:
+        barrier.reset()
+        executor.remove_node(node)
+        executor.shutdown()