Cancel pending work for destroyed nodes (#80)
Signed-off-by: Michel Hidalgo <[email protected]>
mhidalgo-bdai authored Mar 26, 2024
1 parent 5dd9fee commit 03038ab
Showing 3 changed files with 98 additions and 6 deletions.
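
In short: Node now records destruction requests, and the executor's internal Task wrapper consults that flag before dispatching ready callbacks, cancelling any bound to a node whose destruction was requested. A minimal usage sketch of the resulting behavior, assuming a dedicated context as in the new test below (the node name is illustrative, not from this commit):

import rclpy
from rclpy.context import Context

from bdai_ros2_wrappers.executors import AutoScalingMultiThreadedExecutor
from bdai_ros2_wrappers.node import Node

context = Context()
rclpy.init(context=context)
node = Node("example_node", context=context)
executor = AutoScalingMultiThreadedExecutor(max_threads=1, context=context)
executor.add_node(node)

node.destroy_node()                 # records the request (see node.py below)
assert node.destruction_requested

# Callbacks for this node still queued in the executor are cancelled rather
# than run the next time work is dispatched (see executors.py below).
executor.remove_node(node)
executor.shutdown()
context.try_shutdown()
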
24 changes: 18 additions & 6 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/executors.py
@@ -510,17 +510,29 @@ class AutoScalingMultiThreadedExecutor(rclpy.executors.Executor):
class Task:
"""A bundle of an executable task and its associated entity."""

def __init__(self, task: rclpy.task.Task, entity: typing.Optional[rclpy.executors.WaitableEntityType]) -> None:
def __init__(
self,
task: rclpy.task.Task,
entity: typing.Optional[rclpy.executors.WaitableEntityType],
node: typing.Optional[rclpy.node.Node],
) -> None:
self.task = task
self.entity = entity
if node is not None and hasattr(node, "destruction_requested"):
self.valid = lambda: not node.destruction_requested # type: ignore
else:
self.valid = lambda: True
self.callback_group = entity.callback_group if entity is not None else None

def __call__(self) -> None:
"""Run or resume a task
"""Run or resume a task.
See rclpy.task.Task documentation for further reference.
"""
self.task.__call__()
if not self.valid():
self.cancel()
return
self.task()

def __getattr__(self, name: str) -> typing.Any:
return getattr(self.task, name)
@@ -591,7 +603,7 @@ def _do_spin_once(self, *args: typing.Any, **kwargs: typing.Any) -> None:
with self._spin_lock:
try:
task, entity, node = self.wait_for_ready_callbacks(*args, **kwargs)
task = AutoScalingMultiThreadedExecutor.Task(task, entity)
task = AutoScalingMultiThreadedExecutor.Task(task, entity, node)
with self._shutdown_lock:
if self._is_shutdown:
# Ignore task, let shutdown clean it up.
@@ -665,8 +677,8 @@ def shutdown(self, timeout_sec: typing.Optional[float] = None) -> bool:
with self._spin_lock:
# rclpy.executors.Executor base implementation leaves tasks
# unawaited upon shutdown. Do the housekeeping.
for task, entity, _ in self._tasks:
task = AutoScalingMultiThreadedExecutor.Task(task, entity)
for task, entity, node in self._tasks:
task = AutoScalingMultiThreadedExecutor.Task(task, entity, node)
task.cancel()
return done

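The hasattr guard above keeps plain rclpy nodes (and executor-owned tasks with no associated node) working unchanged; only nodes exposing a destruction_requested flag opt into cancellation. A standalone illustration of that branch using stand-in objects, not executor internals:

def make_validity_check(node):
    # Mirrors the guard added above: only nodes exposing a
    # destruction_requested flag can invalidate their pending work.
    if node is not None and hasattr(node, "destruction_requested"):
        return lambda: not node.destruction_requested
    return lambda: True


class FlaggedNode:      # stand-in for bdai_ros2_wrappers.node.Node
    destruction_requested = False


class PlainNode:        # stand-in for a vanilla rclpy.node.Node
    pass


flagged, plain = FlaggedNode(), PlainNode()
assert make_validity_check(plain)()       # plain rclpy nodes: always valid
assert make_validity_check(None)()        # executor-owned tasks have no node
assert make_validity_check(flagged)()     # valid until destruction is requested
flagged.destruction_requested = True
assert not make_validity_check(flagged)()
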
11 changes: 11 additions & 0 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/node.py
@@ -28,10 +28,21 @@ def __init__(self, *args: Any, default_callback_group: Optional[CallbackGroup] =
if default_callback_group is None:
default_callback_group = NonReentrantCallbackGroup()
self._default_callback_group_override = default_callback_group
self._destruction_requested = False
super().__init__(*args, **kwargs)

@property
def default_callback_group(self) -> CallbackGroup:
"""Get the default callback group."""
# NOTE(hidmic): this overrides the hardcoded default group in rclpy.node.Node implementation
return self._default_callback_group_override

@property
def destruction_requested(self) -> bool:
"""Checks whether destruction was requested or not."""
return self._destruction_requested

def destroy_node(self) -> None:
"""Overrides node destruction API."""
self._destruction_requested = True
super().destroy_node()
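
A minimal sketch of the new flag's lifecycle on the wrapper Node, assuming the default rclpy context (the node name is hypothetical):

import rclpy

from bdai_ros2_wrappers.node import Node

rclpy.init()
node = Node("flag_demo_node")           # hypothetical node name
assert not node.destruction_requested   # fresh node, nothing requested yet

node.destroy_node()                     # flips the flag, then destroys as usual
assert node.destruction_requested

rclpy.shutdown()
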
69 changes: 69 additions & 0 deletions bdai_ros2_wrappers/test/test_node.py
@@ -0,0 +1,69 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

import threading
from typing import Generator

import pytest
import rclpy
from rclpy.context import Context
from std_srvs.srv import Trigger

from bdai_ros2_wrappers.executors import AutoScalingMultiThreadedExecutor
from bdai_ros2_wrappers.node import Node


@pytest.fixture
def ros_context() -> Generator[Context, None, None]:
"""A fixture yielding a managed rclpy.context.Context instance."""
context = Context()
rclpy.init(context=context)
try:
yield context
finally:
context.try_shutdown()


def test_node_destruction_during_execution(ros_context: Context) -> None:
"""Asserts that node destructionthe autoscaling multithreaded executor scales to attend a
synchronous service call from a "one-shot" timer callback, serviced by
the same executor.
"""

def dummy_server_callback(_: Trigger.Request, response: Trigger.Response) -> Trigger.Response:
response.success = True
return response

node = Node("pytest_node", context=ros_context)
node.create_service(Trigger, "/dummy/trigger", dummy_server_callback)
client = node.create_client(Trigger, "/dummy/trigger")

executor = AutoScalingMultiThreadedExecutor(max_threads=1, context=ros_context)
executor.add_node(node)

barrier = threading.Barrier(2)
try:
# First smoke test the executor with a service invocation
future = client.call_async(Trigger.Request())
executor.spin_until_future_complete(future, timeout_sec=5.0)
assert future.done() and future.result().success
# Then block its sole worker thread
executor.create_task(lambda: barrier.wait())
executor.spin_once()
# Then queue node destruction
executor.create_task(lambda: node.destroy_node())
executor.spin_once()
assert not node.destruction_requested # still queued
# Then queue another service invocation
future = client.call_async(Trigger.Request())
executor.spin_once()
# Unblock worker thread in executor
barrier.wait()
# Check that executor wraps up early due to node destruction
executor.spin_until_future_complete(future, timeout_sec=5.0)
assert node.destruction_requested
assert executor.thread_pool.wait(timeout=5.0)
assert not future.done() # future response will never be resolved
finally:
barrier.reset()
executor.remove_node(node)
executor.shutdown()
