Skip to content

Commit

Permalink
[EK-30] Fix TOCTTOU race in SingleGoalMultipleActionServers impleme…
Browse files Browse the repository at this point in the history
…ntation (#78)

Signed-off-by: Michel Hidalgo <[email protected]>
  • Loading branch information
mhidalgo-bdai authored Mar 19, 2024
1 parent 316430d commit facc8b2
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from typing import Any, Callable, List, Optional, Tuple

from rclpy.action import ActionServer, CancelResponse, GoalResponse
from rclpy.action.server import ServerGoalHandle
from rclpy.action.server import GoalEvent, RCLError, ServerGoalHandle
from rclpy.callback_groups import CallbackGroup
from rclpy.impl.rcutils_logger import RcutilsLogger
from rclpy.node import Node

from bdai_ros2_wrappers.type_hints import Action, ActionType
from bdai_ros2_wrappers.utilities import synchronized


class SingleGoalMultipleActionServers:
Expand All @@ -22,25 +23,39 @@ def __init__(
self,
node: Node,
action_server_parameters: List[Tuple[ActionType, str, Callable, Optional[CallbackGroup]]],
nosync: bool = False,
) -> None:
"""Constructor"""
"""Constructor.
Args:
node: ROS 2 node to use for action servers.
action_server_parameters: tuples per action server, listing action type, action name,
action execution callback, and action callback group (which may be None).
nosync: whether to synchronize action execution callbacks using locks or not.
Set to True when action execution callback already enforce mutual exclusion.
"""
self._node = node
self._goal_handle_lock = threading.Lock()
self._goal_handle: Optional[ServerGoalHandle] = None
self._goal_lock = threading.Lock()
self._action_servers = []
for action_type, action_topic, execute_callback, callback_group in action_server_parameters:
self._action_servers.append(
ActionServer(
node,
action_type,
action_topic,
execute_callback=execute_callback,
goal_callback=self.goal_callback,
handle_accepted_callback=self.handle_accepted_callback,
cancel_callback=self.cancel_callback,
callback_group=callback_group,
),
if not nosync:
execution_lock = threading.Lock()
action_server_parameters = [
(action_type, action_topic, synchronized(execute_callback, execution_lock), callback_group)
for action_type, action_topic, execute_callback, callback_group in action_server_parameters
]
self._action_servers = [
ActionServer(
node,
action_type,
action_topic,
execute_callback=execute_callback,
goal_callback=self.goal_callback,
handle_accepted_callback=self.handle_accepted_callback,
cancel_callback=self.cancel_callback,
callback_group=callback_group,
)
for action_type, action_topic, execute_callback, callback_group in action_server_parameters
]

def get_logger(self) -> RcutilsLogger:
"""Returns the ros logger"""
Expand All @@ -58,12 +73,14 @@ def goal_callback(self, goal: Action.Goal) -> GoalResponse:

def handle_accepted_callback(self, goal_handle: ServerGoalHandle) -> None:
"""Callback triggered when an action is accepted."""
with self._goal_lock:
with self._goal_handle_lock:
# This server only allows one goal at a time
if self._goal_handle is not None and self._goal_handle.is_active:
self.get_logger().info("Aborting previous goal")
# Abort the existing goal
self._goal_handle.abort()
if self._goal_handle is not None:
self.get_logger().info("Canceling previous goal")
try:
self._goal_handle._update_state(GoalEvent.CANCEL_GOAL)
except RCLError as ex:
self.get_logger().debug(f"Failed to cancel goal: {ex}")
self._goal_handle = goal_handle

goal_handle.execute()
Expand Down
35 changes: 35 additions & 0 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,38 @@ def _wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return callable_(*args, **kwargs)

return _wrapper


def synchronized(
func: typing.Optional[typing.Callable] = None,
lock: typing.Optional[threading.Lock] = None,
) -> typing.Callable:
"""Wraps `func` to synchronize invocations, optionally taking a user defined `lock`.
This function can be used as a decorator, like:
@synchronized
def my_function(...):
...
or
@synchronized(lock=my_lock)
def my_function(...):
...
"""
if lock is None:
lock = threading.Lock()
assert lock is not None

def _decorator(func: typing.Callable) -> typing.Callable:
@functools.wraps(func)
def __wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
with lock: # type: ignore
return func(*args, **kwargs)

return __wrapper

if func is None:
return _decorator
return _decorator(func)
180 changes: 103 additions & 77 deletions bdai_ros2_wrappers/test/test_single_goal_multiple_action_servers.py
Original file line number Diff line number Diff line change
@@ -1,108 +1,134 @@
# Copyright (c) 2023 Boston Dynamics AI Institute Inc. All rights reserved.
# Copyright (c) 2023-2024 Boston Dynamics AI Institute Inc. All rights reserved.
import array
import time
from typing import Tuple
from threading import Barrier, Lock
from typing import Iterable, Tuple

import pytest
from example_interfaces.action import Fibonacci
from rclpy.action.server import GoalStatus, ServerGoalHandle
from rclpy.action.server import ServerGoalHandle

from bdai_ros2_wrappers.action_client import ActionClientWrapper
from bdai_ros2_wrappers.action import Actionable
from bdai_ros2_wrappers.futures import wait_for_future
from bdai_ros2_wrappers.scope import ROSAwareScope
from bdai_ros2_wrappers.single_goal_multiple_action_servers import SingleGoalMultipleActionServers


def execute_callback(goal_handle: ServerGoalHandle) -> Fibonacci.Result:
"""Executor callback for a server that does fibonacci"""
sequence = [0, 1]
for i in range(1, goal_handle.request.order):
sequence.append(sequence[i] + sequence[i - 1])

goal_handle.succeed()

result = Fibonacci.Result()
result.sequence = sequence
return result


def execute_callback_wrong_fib(goal_handle: ServerGoalHandle) -> Fibonacci.Result:
"""Different executor for another server that does fibonacci wrong"""
# time delay to make interrupting easier
time.sleep(1)
sequence = [0, 1]
for i in range(1, goal_handle.request.order):
sequence.append(sequence[i] * sequence[i - 1])

result = None
if goal_handle.status != GoalStatus.STATUS_ABORTED:
goal_handle.succeed()
result = Fibonacci.Result()
result.sequence = sequence
else:
result = Fibonacci.Result()
result.sequence = [-1]

return result


@pytest.fixture
def action_triplet(
ros: ROSAwareScope,
) -> Tuple[SingleGoalMultipleActionServers, ActionClientWrapper, ActionClientWrapper]:
def action_triplet(ros: ROSAwareScope) -> Iterable[Tuple[Barrier, Actionable, Actionable]]:
lock = Lock()
barrier = Barrier(2)

def execute_callback(goal_handle: ServerGoalHandle) -> Fibonacci.Result:
nonlocal barrier, lock

if not barrier.broken:
barrier.wait()

with lock:
sequence = [0, 1]
for i in range(1, goal_handle.request.order):
sequence.append(sequence[i] + sequence[i - 1])

if not barrier.broken:
barrier.wait()

result = Fibonacci.Result()
if not goal_handle.is_cancel_requested:
result.sequence = sequence
goal_handle.succeed()
else:
goal_handle.canceled()
return result

def reversed_execute_callback(goal_handle: ServerGoalHandle) -> Fibonacci.Result:
nonlocal barrier, lock

if not barrier.broken:
barrier.wait()

with lock:
sequence = [0, 1]
for i in range(1, goal_handle.request.order):
sequence.append(sequence[i] + sequence[i - 1])

if not barrier.broken:
barrier.wait()

result = Fibonacci.Result()
if not goal_handle.is_cancel_requested:
result.sequence = list(reversed(sequence))
goal_handle.succeed()
else:
goal_handle.canceled()
return result

action_parameters = [
(Fibonacci, "fibonacci", execute_callback, None),
(Fibonacci, "fibonacci_wrong", execute_callback_wrong_fib, None),
(Fibonacci, "fibonacci/compute", execute_callback, None),
(Fibonacci, "fibonacci/compute_reversed", reversed_execute_callback, None),
]
assert ros.node is not None
action_server = SingleGoalMultipleActionServers(ros.node, action_parameters)
action_client_a = ActionClientWrapper(Fibonacci, "fibonacci", ros.node)
action_client_b = ActionClientWrapper(Fibonacci, "fibonacci_wrong", ros.node)
return action_server, action_client_a, action_client_b
SingleGoalMultipleActionServers(ros.node, action_parameters, nosync=True)
compute_fibonacci = Actionable(Fibonacci, "fibonacci/compute", ros.node)
compute_fibonacci_reversed = Actionable(Fibonacci, "fibonacci/compute_reversed", ros.node)

try:
yield barrier, compute_fibonacci, compute_fibonacci_reversed
finally:
barrier.abort()


def test_actions_in_sequence(
action_triplet: Tuple[SingleGoalMultipleActionServers, ActionClientWrapper, ActionClientWrapper],
action_triplet: Tuple[Barrier, Actionable, Actionable],
) -> None:
"""Tests out normal operation with multiple action servers and clients"""
_, action_client_a, action_client_b = action_triplet
barrier, compute_fibonacci, compute_fibonacci_reversed = action_triplet

barrier.abort() # avoid synchronization

goal = Fibonacci.Goal()
goal.order = 5
# use first client
result = action_client_a.send_goal_and_wait("action_request_a", goal=goal, timeout_sec=5)
assert result is not None
result = compute_fibonacci(goal)
expected_result = array.array("i", [0, 1, 1, 2, 3, 5])
assert result.sequence == expected_result
# use second client
result = action_client_b.send_goal_and_wait("action_request_b", goal=goal, timeout_sec=5)
assert result is not None
expected_result = array.array("i", [0, 1, 0, 0, 0, 0])
result = compute_fibonacci_reversed(goal)
expected_result = array.array("i", [5, 3, 2, 1, 1, 0])
assert result.sequence == expected_result


def test_action_interruption(
ros: ROSAwareScope,
action_triplet: Tuple[SingleGoalMultipleActionServers, ActionClientWrapper, ActionClientWrapper],
def test_same_action_interruption(
action_triplet: Tuple[Barrier, Actionable, Actionable],
) -> None:
"""This test should start a delayed request from another client
then make an immediate request to interrupt the last request.
barrier, compute_fibonacci, _ = action_triplet

Due to the threading and reliance on sleeps this test might be
tempermental on other machines.
"""
_, action_client_a, action_client_b = action_triplet
goal = Fibonacci.Goal()
goal.order = 5
action_a = compute_fibonacci.asynchronously(goal)
barrier.wait(timeout=5.0) # let action A start
action_b = compute_fibonacci.asynchronously(goal)
# Actions B and A will allow each other to start and finish, respectively
assert wait_for_future(action_a.finalization, timeout_sec=5.0)
assert action_a.cancelled
barrier.wait(timeout=5.0) # let action B finish
assert wait_for_future(action_b.finalization, timeout_sec=5.0)
assert action_b.succeeded
expected_result = array.array("i", [0, 1, 1, 2, 3, 5])
assert action_b.result.sequence == expected_result

def deferred_request() -> None:
# time delay to give other action time to get started before interrupting
time.sleep(0.3)
goal = Fibonacci.Goal()
goal.order = 5
action_client_a.send_goal_and_wait("deferred_action_request", goal=goal, timeout_sec=2)

assert ros.executor is not None
ros.executor.create_task(deferred_request)
def test_different_action_interruption(
action_triplet: Tuple[Barrier, Actionable, Actionable],
) -> None:
barrier, compute_fibonacci, compute_fibonacci_reversed = action_triplet

# immediately start the request for other goal
goal = Fibonacci.Goal()
goal.order = 5
result = action_client_b.send_goal_and_wait("action_request", goal=goal, timeout_sec=5)
assert result is None
action_a = compute_fibonacci.asynchronously(goal)
barrier.wait(timeout=5.0) # let action A start
action_b = compute_fibonacci_reversed.asynchronously(goal)
# Actions B and A will allow each other to start and finish, respectively
assert wait_for_future(action_a.finalization, timeout_sec=5.0)
assert action_a.cancelled
barrier.wait(timeout=5.0) # let action B finish
assert wait_for_future(action_b.finalization, timeout_sec=5.0)
assert action_b.succeeded
expected_result = array.array("i", [5, 3, 2, 1, 1, 0])
assert action_b.result.sequence == expected_result

0 comments on commit facc8b2

Please sign in to comment.