Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EK-30] Fix TOCTTOU race in SingleGoalMultipleActionServers implementation #78

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from typing import Any, Callable, List, Optional, Tuple

from rclpy.action import ActionServer, CancelResponse, GoalResponse
from rclpy.action.server import ServerGoalHandle
from rclpy.action.server import GoalEvent, RCLError, ServerGoalHandle
from rclpy.callback_groups import CallbackGroup
from rclpy.impl.rcutils_logger import RcutilsLogger
from rclpy.node import Node

from bdai_ros2_wrappers.type_hints import Action, ActionType
from bdai_ros2_wrappers.utilities import synchronized


class SingleGoalMultipleActionServers:
Expand All @@ -25,22 +26,22 @@ def __init__(
) -> None:
"""Constructor"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the docstring with when you would want to set nosync to True?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 224a408.

self._node = node
self._goal_handle_lock = threading.Lock()
self._goal_handle: Optional[ServerGoalHandle] = None
self._goal_lock = threading.Lock()
self._action_servers = []
for action_type, action_topic, execute_callback, callback_group in action_server_parameters:
self._action_servers.append(
ActionServer(
node,
action_type,
action_topic,
execute_callback=execute_callback,
goal_callback=self.goal_callback,
handle_accepted_callback=self.handle_accepted_callback,
cancel_callback=self.cancel_callback,
callback_group=callback_group,
),
self._execution_lock = threading.Lock()
self._action_servers = [
ActionServer(
node,
action_type,
action_topic,
execute_callback=synchronized(execute_callback, self._execution_lock),
goal_callback=self.goal_callback,
handle_accepted_callback=self.handle_accepted_callback,
cancel_callback=self.cancel_callback,
callback_group=callback_group,
)
for action_type, action_topic, execute_callback, callback_group in action_server_parameters
]

def get_logger(self) -> RcutilsLogger:
"""Returns the ros logger"""
Expand All @@ -58,12 +59,14 @@ def goal_callback(self, goal: Action.Goal) -> GoalResponse:

def handle_accepted_callback(self, goal_handle: ServerGoalHandle) -> None:
"""Callback triggered when an action is accepted."""
with self._goal_lock:
with self._goal_handle_lock:
# This server only allows one goal at a time
if self._goal_handle is not None and self._goal_handle.is_active:
self.get_logger().info("Aborting previous goal")
# Abort the existing goal
self._goal_handle.abort()
if self._goal_handle is not None:
self.get_logger().info("Canceling previous goal")
try:
self._goal_handle._update_state(GoalEvent.CANCEL_GOAL)
except RCLError as ex:
self.get_logger().debug(f"Failed to cancel goal: {ex}")
self._goal_handle = goal_handle

goal_handle.execute()
Expand Down
35 changes: 35 additions & 0 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,38 @@ def _wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return callable_(*args, **kwargs)

return _wrapper


def synchronized(
func: typing.Optional[typing.Callable] = None,
lock: typing.Optional[threading.Lock] = None,
) -> typing.Callable:
"""Wraps `func` to synchronize invocations, optionally taking a user defined `lock`.

This function can be used as a decorator, like:

@synchronized
def my_function(...):
...

or

@synchronized(lock=my_lock)
def my_function(...):
...
"""
if lock is None:
lock = threading.Lock()
assert lock is not None

def _decorator(func: typing.Callable) -> typing.Callable:
@functools.wraps(func)
def __wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
with lock: # type: ignore
return func(*args, **kwargs)

return __wrapper

if func is None:
return _decorator
return _decorator(func)
148 changes: 78 additions & 70 deletions bdai_ros2_wrappers/test/test_single_goal_multiple_action_servers.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,116 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.
# Copyright (c) 2023-2024 Boston Dynamics AI Institute Inc. All rights reserved.
import array
import time
from threading import Semaphore
from typing import Tuple

import pytest
from example_interfaces.action import Fibonacci
from rclpy.action.server import GoalStatus, ServerGoalHandle
from rclpy.action.server import ServerGoalHandle

from bdai_ros2_wrappers.action_client import ActionClientWrapper
from bdai_ros2_wrappers.action import Actionable
from bdai_ros2_wrappers.futures import wait_for_future
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.single_goal_multiple_action_servers import SingleGoalMultipleActionServers


def execute_callback(goal_handle: ServerGoalHandle) -> Fibonacci.Result:
"""Executor callback for a server that does fibonacci"""
sequence = [0, 1]
for i in range(1, goal_handle.request.order):
sequence.append(sequence[i] + sequence[i - 1])

goal_handle.succeed()

result = Fibonacci.Result()
result.sequence = sequence
return result
@pytest.fixture
def action_triplet(ros: ROSAwareScope) -> Tuple[Semaphore, Actionable, Actionable]:
semaphore = Semaphore(0)

def execute_callback(goal_handle: ServerGoalHandle) -> Fibonacci.Result:
nonlocal semaphore
semaphore.acquire()

def execute_callback_wrong_fib(goal_handle: ServerGoalHandle) -> Fibonacci.Result:
"""Different executor for another server that does fibonacci wrong"""
# time delay to make interrupting easier
time.sleep(1)
sequence = [0, 1]
for i in range(1, goal_handle.request.order):
sequence.append(sequence[i] * sequence[i - 1])
sequence = [0, 1]
for i in range(1, goal_handle.request.order):
sequence.append(sequence[i] + sequence[i - 1])

result = None
if goal_handle.status != GoalStatus.STATUS_ABORTED:
goal_handle.succeed()
result = Fibonacci.Result()
result.sequence = sequence
else:
result = Fibonacci.Result()
result.sequence = [-1]
if not goal_handle.is_cancel_requested:
result.sequence = sequence
goal_handle.succeed()
else:
goal_handle.canceled()
return result

return result
def reversed_execute_callback(goal_handle: ServerGoalHandle) -> Fibonacci.Result:
nonlocal semaphore
semaphore.acquire()

sequence = [0, 1]
for i in range(1, goal_handle.request.order):
sequence.append(sequence[i] + sequence[i - 1])

result = Fibonacci.Result()
if not goal_handle.is_cancel_requested:
result.sequence = list(reversed(sequence))
goal_handle.succeed()
else:
goal_handle.canceled()
return result

@pytest.fixture
def action_triplet(
ros: ROSAwareScope,
) -> Tuple[SingleGoalMultipleActionServers, ActionClientWrapper, ActionClientWrapper]:
action_parameters = [
(Fibonacci, "fibonacci", execute_callback, None),
(Fibonacci, "fibonacci_wrong", execute_callback_wrong_fib, None),
(Fibonacci, "fibonacci/compute", execute_callback, None),
(Fibonacci, "fibonacci/compute_reversed", reversed_execute_callback, None),
]
assert ros.node is not None
action_server = SingleGoalMultipleActionServers(ros.node, action_parameters)
action_client_a = ActionClientWrapper(Fibonacci, "fibonacci", ros.node)
action_client_b = ActionClientWrapper(Fibonacci, "fibonacci_wrong", ros.node)
return action_server, action_client_a, action_client_b
SingleGoalMultipleActionServers(ros.node, action_parameters)
compute_fibonacci = Actionable(Fibonacci, "fibonacci/compute", ros.node)
compute_fibonacci_reversed = Actionable(Fibonacci, "fibonacci/compute_reversed", ros.node)
return semaphore, compute_fibonacci, compute_fibonacci_reversed


def test_actions_in_sequence(
action_triplet: Tuple[SingleGoalMultipleActionServers, ActionClientWrapper, ActionClientWrapper],
action_triplet: Tuple[Semaphore, Actionable, Actionable],
) -> None:
"""Tests out normal operation with multiple action servers and clients"""
_, action_client_a, action_client_b = action_triplet
semaphore, compute_fibonacci, compute_fibonacci_reversed = action_triplet

semaphore.release() # allow for action A
semaphore.release() # allow for action B

goal = Fibonacci.Goal()
goal.order = 5
# use first client
result = action_client_a.send_goal_and_wait("action_request_a", goal=goal, timeout_sec=5)
assert result is not None
result = compute_fibonacci(goal)
expected_result = array.array("i", [0, 1, 1, 2, 3, 5])
assert result.sequence == expected_result
# use second client
result = action_client_b.send_goal_and_wait("action_request_b", goal=goal, timeout_sec=5)
assert result is not None
expected_result = array.array("i", [0, 1, 0, 0, 0, 0])
result = compute_fibonacci_reversed(goal)
expected_result = array.array("i", [5, 3, 2, 1, 1, 0])
assert result.sequence == expected_result


def test_action_interruption(
ros: ROSAwareScope,
action_triplet: Tuple[SingleGoalMultipleActionServers, ActionClientWrapper, ActionClientWrapper],
def test_same_action_interruption(
action_triplet: Tuple[Semaphore, Actionable, Actionable],
) -> None:
"""This test should start a delayed request from another client
then make an immediate request to interrupt the last request.
semaphore, compute_fibonacci, _ = action_triplet

Due to the threading and reliance on sleeps this test might be
tempermental on other machines.
"""
_, action_client_a, action_client_b = action_triplet
goal = Fibonacci.Goal()
goal.order = 5
action_a = compute_fibonacci.asynchronously(goal)
action_b = compute_fibonacci.asynchronously(goal)
semaphore.release() # allow for action A
semaphore.release() # allow for action B
assert wait_for_future(action_a.finalization, timeout_sec=5.0)
assert action_a.cancelled
assert wait_for_future(action_b.finalization, timeout_sec=5.0)
assert action_b.succeeded
expected_result = array.array("i", [0, 1, 1, 2, 3, 5])
assert action_b.result.sequence == expected_result

def deferred_request() -> None:
# time delay to give other action time to get started before interrupting
time.sleep(0.3)
goal = Fibonacci.Goal()
goal.order = 5
action_client_a.send_goal_and_wait("deferred_action_request", goal=goal, timeout_sec=2)

assert ros.executor is not None
ros.executor.create_task(deferred_request)
def test_different_action_interruption(
action_triplet: Tuple[Semaphore, Actionable, Actionable],
) -> None:
semaphore, compute_fibonacci, compute_fibonacci_reversed = action_triplet

# immediately start the request for other goal
goal = Fibonacci.Goal()
goal.order = 5
result = action_client_b.send_goal_and_wait("action_request", goal=goal, timeout_sec=5)
assert result is None
action_a = compute_fibonacci.asynchronously(goal)
action_b = compute_fibonacci_reversed.asynchronously(goal)
semaphore.release() # allow for action A
semaphore.release() # allow for action B
assert wait_for_future(action_a.finalization, timeout_sec=5.0)
assert action_a.cancelled
assert wait_for_future(action_b.finalization, timeout_sec=5.0)
assert action_b.succeeded
expected_result = array.array("i", [5, 3, 2, 1, 1, 0])
assert action_b.result.sequence == expected_result
Loading