diff --git a/ada_feeding/CMakeLists.txt b/ada_feeding/CMakeLists.txt index ed5e6277..5db4f1b1 100644 --- a/ada_feeding/CMakeLists.txt +++ b/ada_feeding/CMakeLists.txt @@ -62,13 +62,17 @@ install(DIRECTORY if(BUILD_TESTING) find_package(ament_cmake_pytest REQUIRED) set(_pytest_tests + tests/__init__.py # not technically a test, but necessary for other tests + tests/helpers.py # not technically a test, but necessary for other tests + tests/test_eventually_swiss.py + tests/test_scoped_behavior.py # Add other test files here ) foreach(_test_path ${_pytest_tests}) get_filename_component(_test_name ${_test_path} NAME_WE) ament_add_pytest_test(${_test_name} ${_test_path} APPEND_ENV PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR} - TIMEOUT 60 + TIMEOUT 10 WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} ) endforeach() diff --git a/ada_feeding/ada_feeding/behaviors/move_to.py b/ada_feeding/ada_feeding/behaviors/move_to.py index 9fdea0e1..8ace09d9 100644 --- a/ada_feeding/ada_feeding/behaviors/move_to.py +++ b/ada_feeding/ada_feeding/behaviors/move_to.py @@ -395,7 +395,7 @@ def terminate(self, new_status: py_trees.common.Status) -> None: # A termination request has not succeeded until the MoveIt2 action server is IDLE with self.moveit2_lock: while self.moveit2.query_state() != MoveIt2State.IDLE: - self.node.logger.info( + self.logger.info( f"MoveTo Update MoveIt2State not Idle {time.time()} {terminate_requested_time} " f"{self.terminate_timeout_s}" ) diff --git a/ada_feeding/ada_feeding/decorators/__init__.py b/ada_feeding/ada_feeding/decorators/__init__.py index 3a4f9236..077755fa 100644 --- a/ada_feeding/ada_feeding/decorators/__init__.py +++ b/ada_feeding/ada_feeding/decorators/__init__.py @@ -22,3 +22,6 @@ from .set_joint_path_constraint import SetJointPathConstraint from .set_position_path_constraint import SetPositionPathConstraint from .set_orientation_path_constraint import SetOrientationPathConstraint + +# On Preempt +from .on_preempt import OnPreempt diff --git a/ada_feeding/ada_feeding/decorators/on_preempt.py b/ada_feeding/ada_feeding/decorators/on_preempt.py new file mode 100644 index 00000000..508183b0 --- /dev/null +++ b/ada_feeding/ada_feeding/decorators/on_preempt.py @@ -0,0 +1,115 @@ +""" +NOTE: This is a multi-tick version of the decorator discussed in +https://github.com/splintered-reality/py_trees/pull/427 . Once a +multi-tick version of that decorator is merged into py_trees, this +decorator should be removed in favor of the main py_trees one. +""" + +import time +import typing + +from py_trees import behaviour, common +from py_trees.decorators import Decorator + + +class OnPreempt(Decorator): + """ + Behaves identically to :class:`~py_trees.decorators.PassThrough` except + that if it gets preempted (i.e., `terminate(INVALID)` is called on it) + while its status is :data:`~py_trees.common.Status.RUNNING`, it will + tick `on_preempt` either: (a) for a single tick; or (b) until `on_preempt` + reaches a status other than :data:`~py_trees.common.Status.RUNNING` or + times out. Note that `on_preempt` may be a behavior that exists elsewhere + in the tree, or it may be a separate behavior. + + This is useful to cleanup, restore a context switch or to + implement a finally-like behaviour. + + .. seealso:: :meth:`py_trees.idioms.eventually`, :meth:`py_trees.idioms.eventually_swiss` + """ + + # pylint: disable=too-many-arguments + # This is acceptable, to give users maximum control over how this decorator + # behaves. + def __init__( + self, + name: str, + child: behaviour.Behaviour, + on_preempt: behaviour.Behaviour, + single_tick: bool = True, + period_ms: int = 0, + timeout: typing.Optional[float] = None, + ): + """ + Initialise with the standard decorator arguments. + + Args: + name: the decorator name + child: the child to be decorated + on_preempt: the behaviour or subtree to tick on preemption + single_tick: if True, tick the child once on preemption. Else, + tick the child until it reaches a status other than + :data:`~py_trees.common.Status.RUNNING`. + period_ms: how long to sleep between ticks (in milliseconds) + if `single_tick` is False. If 0, then do not sleep. + timeout: how long (sec) to wait for the child to reach a status + other than :data:`~py_trees.common.Status.RUNNING` if + `single_tick` is False. If None, then do not timeout. + """ + super().__init__(name=name, child=child) + self.on_preempt = on_preempt + self.single_tick = single_tick + self.period_ms = period_ms + self.timeout = timeout + + def update(self) -> common.Status: + """ + Just reflect the child status. + + Returns: + the behaviour's new status :class:`~py_trees.common.Status` + """ + return self.decorated.status + + def stop(self, new_status: common.Status) -> None: + """ + Check if the child is running (dangling) and stop it if that is the case. + + This function departs from the standard :meth:`~py_trees.decorators.Decorator.stop` + in that it *first* stops the child, and *then* stops the decorator. + + Args: + new_status (:class:`~py_trees.common.Status`): the behaviour is transitioning + to this new status + """ + self.logger.debug(f"{self.__class__.__name__}.stop({new_status})") + # priority interrupt handling + if new_status == common.Status.INVALID: + self.decorated.stop(new_status) + # if the decorator returns SUCCESS/FAILURE and should stop the child + if self.decorated.status == common.Status.RUNNING: + self.decorated.stop(common.Status.INVALID) + self.terminate(new_status) + self.status = new_status + + def terminate(self, new_status: common.Status) -> None: + """Tick the child behaviour once.""" + self.logger.debug( + f"{self.__class__.__name__}.terminate({self.status}->{new_status})" + ) + if new_status == common.Status.INVALID and self.status == common.Status.RUNNING: + terminate_start_s = time.monotonic() + # Tick the child once + self.on_preempt.tick_once() + # If specified, tick until the child reaches a non-RUNNING status + if not self.single_tick: + while self.on_preempt.status == common.Status.RUNNING and ( + self.timeout is None + or time.monotonic() - terminate_start_s < self.timeout + ): + if self.period_ms > 0: + time.sleep(self.period_ms / 1000.0) + self.on_preempt.tick_once() + # Do not need to stop the child here - this method + # is only called by Decorator.stop() which will handle + # that responsibility immediately after this method returns. diff --git a/ada_feeding/ada_feeding/idioms/__init__.py b/ada_feeding/ada_feeding/idioms/__init__.py index 0d841124..db6e1b7e 100644 --- a/ada_feeding/ada_feeding/idioms/__init__.py +++ b/ada_feeding/ada_feeding/idioms/__init__.py @@ -3,5 +3,7 @@ project. """ from .add_pose_path_constraints import add_pose_path_constraints +from .eventually_swiss import eventually_swiss from .pre_moveto_config import pre_moveto_config from .retry_call_ros_service import retry_call_ros_service +from .scoped_behavior import scoped_behavior diff --git a/ada_feeding/ada_feeding/idioms/eventually_swiss.py b/ada_feeding/ada_feeding/idioms/eventually_swiss.py new file mode 100644 index 00000000..bbbba6a7 --- /dev/null +++ b/ada_feeding/ada_feeding/idioms/eventually_swiss.py @@ -0,0 +1,114 @@ +""" +NOTE: This is a preempt-handling version of the idiom discussed in +https://github.com/splintered-reality/py_trees/pull/427 . Once a +preempt-handling version of that idiom is merged into py_trees, this +idiom should be removed in favor of the main py_trees one. +""" + +import typing + +from py_trees import behaviour, behaviours, composites + +from ada_feeding.decorators import OnPreempt + + +def eventually_swiss( + name: str, + workers: typing.List[behaviour.Behaviour], + on_failure: behaviour.Behaviour, + on_success: behaviour.Behaviour, + on_preempt: behaviour.Behaviour, + on_preempt_single_tick: bool = True, + on_preempt_period_ms: int = 0, + on_preempt_timeout: typing.Optional[float] = None, + return_on_success_status: bool = True, +) -> behaviour.Behaviour: + """ + Implement a multi-tick, general purpose 'try-except-else'-like pattern. + + This is a swiss knife version of the eventually idiom + that facilitates a multi-tick response for specialised + handling work sequence's completion status. Specifically, this idiom + guarentees the following: + 1. The on_success behaviour is ticked only if the workers all return SUCCESS. + 2. The on_failure behaviour is ticked only if at least one worker returns FAILURE. + 3. The on_preempt behaviour is ticked only if `stop(INVALID)` is called on the + root behaviour returned from this idiom while the root behaviour's status is + :data:`~py_trees.common.Status.RUNNING`. + + The return status of this idiom in non-preemption cases is: + - If the workers all return SUCCESS: + - If `return_on_success_status` is True, then the status of the root behaviour + returned from this idiom is status of `on_success`. + - If `return_on_success_status` is False, then the status of the root behaviour + returned from this idiom is :data:`~py_trees.common.Status.SUCCESS`. + - If at least one worker returns FAILURE, return :data:`~py_trees.common.Status.FAILURE`. + + .. graphviz:: dot/eventually-swiss.dot + + Args: + name: the name to use for the idiom root + workers: the worker behaviours or subtrees + on_success: the behaviour or subtree to tick on work success + on_failure: the behaviour or subtree to tick on work failure + on_preempt: the behaviour or subtree to tick on work preemption + on_preempt_single_tick: if True, tick the on_preempt behaviour once + on preemption. Else, tick the on_preempt behaviour until it + reaches a status other than :data:`~py_trees.common.Status.RUNNING`. + on_preempt_period_ms: how long to sleep between ticks (in milliseconds) + if `on_preempt_single_tick` is False. If 0, then do not sleep. + on_preempt_timeout: how long (sec) to wait for the on_preempt behaviour + to reach a status other than :data:`~py_trees.common.Status.RUNNING` + if `on_preempt_single_tick` is False. If None, then do not timeout. + return_on_success_status: if True, pass the `on_success` status to the + root, else return :data:`~py_trees.common.Status.SUCCESS`. + + Returns: + :class:`~py_trees.behaviour.Behaviour`: the root of the eventually_swiss subtree + + .. seealso:: :meth:`py_trees.idioms.eventually`, :ref:`py-trees-demo-eventually-swiss-program` + """ + # pylint: disable=too-many-arguments, too-many-locals + # This is acceptable, to give users maximum control over how this swiss-knife + # idiom behaves. + # pylint: disable=abstract-class-instantiated + # behaviours.Failure and behaviours.Success are valid instantiations + + workers_sequence = composites.Sequence( + name="Workers", + memory=True, + children=workers, + ) + on_failure_return_status = composites.Sequence( + name="On Failure Return Failure", + memory=True, + children=[on_failure, behaviours.Failure(name="Failure")], + ) + on_failure_subtree = composites.Selector( + name="On Failure", + memory=True, + children=[workers_sequence, on_failure_return_status], + ) + if return_on_success_status: + on_success_return_status = on_success + else: + on_success_return_status = composites.Selector( + name="On Success Return Success", + memory=True, + children=[on_success, behaviours.Success(name="Success")], + ) + on_success_subtree = composites.Sequence( + name="On Success", + memory=True, + children=[on_failure_subtree, on_success_return_status], + ) + root = OnPreempt( + name=name, + child=on_success_subtree, + on_preempt=on_preempt, + single_tick=on_preempt_single_tick, + period_ms=on_preempt_period_ms, + timeout=on_preempt_timeout, + ) + + return root diff --git a/ada_feeding/ada_feeding/idioms/scoped_behavior.py b/ada_feeding/ada_feeding/idioms/scoped_behavior.py new file mode 100644 index 00000000..94d666a0 --- /dev/null +++ b/ada_feeding/ada_feeding/idioms/scoped_behavior.py @@ -0,0 +1,132 @@ +""" +This module defines the `scoped_behavior` idiom, which is a way to run a main +behavior within the scope of a pre and post behavior. + +In expected usage, the pre behavior will open or create a resources, the main +behavior will use those resources, and the post behavior will close or delete the +resources. The idiom guarentees the following: + 1. The main behavior will not be ticked unless the pre behavior returns + SUCCESS. + 2. The behavior returned by this idiom will not reach a terminal (non-RUNNING) + status until the post behavior has been ticked to a terminal status. In + other words, regardless of whether the main behavior returns SUCCESS, + FAILURE, or if the idiom is preempted (e.g., had `stop(INVALID)` called + on it), the post behavior will still be ticked till a terminal status. + 3. The root behavior's terminal status will be FAILURE if the pre behavior + returns FAILURE, else it will be the main behavior's terminal status. + +Note the following nuances: + 1. If the main behaviour reaches SUCCESS or FAILURE, the post behaviour will + be ticked asynchronously during the standard `tick()` of the tree. However, + if the idiom is preempted, the post behaviour will be ticked synchronously, + as part of the `stop(INVALID)` code of the tree, e.g., progression of + the `stop(INVALID)` code will be blocked until the post behaviour reaches + a terminal status. + 2. It is possible that the post behavior will be ticked to completion multiple + times. For example, consider the case where the main behavior succeeds, + the post behavior succeeds, and then the idiom is preempted. Therefore, + the post behavior should be designed in a way that it can be run to completion + multiple times sequentially, without negative side effects. +""" + +# Standard imports +from typing import List, Optional + +# Third-party imports +import py_trees +from py_trees.behaviours import BlackboardToStatus, UnsetBlackboardVariable +from py_trees.decorators import ( + FailureIsSuccess, + StatusToBlackboard, +) + +# Local imports +from ada_feeding.decorators import OnPreempt + + +# pylint: disable=too-many-arguments +# One over is fine. +def scoped_behavior( + name: str, + pre_behavior: py_trees.behaviour.Behaviour, + workers: List[py_trees.behaviour.Behaviour], + post_behavior: py_trees.behaviour.Behaviour, + on_preempt_period_ms: int = 0, + on_preempt_timeout: Optional[float] = None, + status_blackboard_key: Optional[str] = None, +) -> py_trees.behaviour.Behaviour: + """ + Returns a behavior that runs the main behavior within the scope of the pre + and post behaviors. See the module docstring for more details. + + Parameters + ---------- + name: The name to associate with this behavior. + pre_behavior: The behavior to run before the main behavior. + workers: The behaviors to run in the middle. + post_behavior: The behavior to run after the main behavior. + on_preempt_period_ms: How long to sleep between ticks (in milliseconds) + if the behavior gets preempted. If 0, then do not sleep. + on_preempt_timeout: How long (sec) to wait for the behavior to reach a + terminal status if the behavior gets preempted. If None, then do not + timeout. + status_blackboard_key: The blackboard key to use to store the status of + the behavior. If None, use `/{name}/scoped_behavior_status`. + """ + if status_blackboard_key is None: + status_blackboard_key = f"/{name}/scoped_behavior_status" + + main_sequence = py_trees.composites.Sequence( + name="Scoped Behavior", + memory=True, + ) + + # First, unset the status variable. + unset_status = UnsetBlackboardVariable( + name="Unset Status", key=status_blackboard_key + ) + main_sequence.children.append(unset_status) + + # Then, execute the pre behavior and the workers + pre_and_workers_sequence = py_trees.composites.Sequence( + name="Pre & Workers", + children=[pre_behavior] + workers, + memory=True, + ) + write_workers_status = StatusToBlackboard( + name="Write Pre & Workers Status", + child=pre_and_workers_sequence, + variable_name=status_blackboard_key, + ) + workers_branch = FailureIsSuccess( + name="Pre & Workers Branch", + child=write_workers_status, + ) + main_sequence.children.append(workers_branch) + + # Then, execute the post behavior + post_branch = FailureIsSuccess( + name="Post Branch", + child=post_behavior, + ) + main_sequence.children.append(post_branch) + + # Finally, write the status of the main behavior to the blackboard. + write_status = BlackboardToStatus( + name="Write Status", + variable_name=status_blackboard_key, + ) + main_sequence.children.append(write_status) + + # To handle preemptions, we place the main behavior into an OnPreempt + # decorator, with `post` as the preemption behavior. + root = OnPreempt( + name=name, + child=main_sequence, + on_preempt=post_behavior, + single_tick=False, + period_ms=on_preempt_period_ms, + timeout=on_preempt_timeout, + ) + + return root diff --git a/ada_feeding/ada_feeding/trees/move_from_mouth_tree.py b/ada_feeding/ada_feeding/trees/move_from_mouth_tree.py index 54a1e296..f307734b 100644 --- a/ada_feeding/ada_feeding/trees/move_from_mouth_tree.py +++ b/ada_feeding/ada_feeding/trees/move_from_mouth_tree.py @@ -20,7 +20,7 @@ # Local imports from ada_feeding.behaviors import ModifyCollisionObject, ModifyCollisionObjectOperation -from ada_feeding.idioms import pre_moveto_config +from ada_feeding.idioms import pre_moveto_config, scoped_behavior from ada_feeding.idioms.bite_transfer import ( get_toggle_collision_object_behavior, ) @@ -352,45 +352,24 @@ def gen_remove_in_front_of_wheelchair_wall() -> None: return retval # Link all the behaviours together in a sequence with memory - move_from_mouth = py_trees.composites.Sequence( + root = py_trees.composites.Sequence( name=name + " Main", memory=True, children=[ # For now, we only re-tare the F/T sensor once, since no large forces # are expected during transfer. pre_moveto_config_behavior, - allow_wheelchair_collision, - move_to_staging_configuration, - gen_disallow_wheelchair_collision(), - add_in_front_of_wheelchair_wall, - move_to_end_configuration, - gen_remove_in_front_of_wheelchair_wall(), - ], - ) - move_from_mouth.logger = logger - - # Create a cleanup branch for the behaviors that should get executed if - # the main tree has a failure - cleanup_tree = py_trees.composites.Sequence( - name=name + " Cleanup", - memory=True, - children=[ - gen_disallow_wheelchair_collision(), - gen_remove_in_front_of_wheelchair_wall(), - ], - ) - - # If move_from_mouth fails, we still want to do some cleanup (e.g., turn - # face detection off). - root = py_trees.composites.Selector( - name=name, - memory=True, - children=[ - move_from_mouth, - # Even though we are cleaning up the tree, it should still - # pass the failure up. - py_trees.decorators.SuccessIsFailure( - name + " Cleanup Root", cleanup_tree + scoped_behavior( + name=name + " AllowWheelchairCollisionScope", + pre_behavior=allow_wheelchair_collision, + main_behaviors=[move_to_staging_configuration], + post_behavior_fn=gen_disallow_wheelchair_collision, + ), + scoped_behavior( + name=name + " AddInFrontOfWheelchairWallScope", + pre_behavior=add_in_front_of_wheelchair_wall, + main_behaviors=[move_to_end_configuration], + post_behavior_fn=gen_remove_in_front_of_wheelchair_wall, ), ], ) diff --git a/ada_feeding/ada_feeding/trees/move_to_configuration_with_ft_thresholds_tree.py b/ada_feeding/ada_feeding/trees/move_to_configuration_with_ft_thresholds_tree.py index 85c45b40..0a7daa89 100644 --- a/ada_feeding/ada_feeding/trees/move_to_configuration_with_ft_thresholds_tree.py +++ b/ada_feeding/ada_feeding/trees/move_to_configuration_with_ft_thresholds_tree.py @@ -6,6 +6,7 @@ """ # Standard imports +from functools import partial import logging from typing import List, Set @@ -14,7 +15,7 @@ from rclpy.node import Node # Local imports -from ada_feeding.idioms import pre_moveto_config +from ada_feeding.idioms import pre_moveto_config, scoped_behavior from ada_feeding.idioms.bite_transfer import get_toggle_watchdog_listener_behavior from ada_feeding.trees import MoveToTree, MoveToConfigurationTree @@ -23,6 +24,9 @@ class MoveToConfigurationWithFTThresholdsTree(MoveToTree): """ A behavior tree that moves the robot to a specified configuration, after re-taring the FT sensor and setting specific FT thresholds. + + TODO: Add the ability to pass force-torque thresholds to revert after the + motion, and then set the force-torque thresholds in the scoped behavior. """ # pylint: disable=too-many-instance-attributes, too-many-arguments @@ -182,46 +186,35 @@ def create_move_to_tree( logger=logger, ) - # Combine them in a sequence with memory - main_tree = py_trees.composites.Sequence( - name=name, - memory=True, - children=[pre_moveto_behavior, move_to_configuration_root], - ) - main_tree.logger = logger - if self.toggle_watchdog_listener: # If there was a failure in the main tree, we want to ensure to turn # the watchdog listener back on # pylint: disable=duplicate-code # This is similar to any other tree that needs to cleanup pre_moveto_config - turn_watchdog_listener_on = get_toggle_watchdog_listener_behavior( + turn_watchdog_listener_on_fn = partial( + get_toggle_watchdog_listener_behavior, name, turn_watchdog_listener_on_prefix, True, logger, ) - # Create a cleanup branch for the behaviors that should get executed if - # the main tree has a failure - cleanup_tree = turn_watchdog_listener_on - - # If main_tree fails, we still want to do some cleanup. - root = py_trees.composites.Selector( + # Create the main tree + root = scoped_behavior( + name=name + " ToggleWatchdogListenerOffScope", + pre_behavior=pre_moveto_behavior, + main_behaviors=[move_to_configuration_root], + post_behavior_fn=turn_watchdog_listener_on_fn, + ) + root.logger = logger + else: + # Combine them in a sequence with memory + root = py_trees.composites.Sequence( name=name, memory=True, - children=[ - main_tree, - # Even though we are cleaning up the tree, it should still - # pass the failure up. - py_trees.decorators.SuccessIsFailure( - name + " Cleanup Root", cleanup_tree - ), - ], + children=[pre_moveto_behavior, move_to_configuration_root], ) root.logger = logger - else: - root = main_tree tree = py_trees.trees.BehaviourTree(root) return tree diff --git a/ada_feeding/ada_feeding/trees/move_to_mouth_tree.py b/ada_feeding/ada_feeding/trees/move_to_mouth_tree.py index c18bf6bb..f467d5cb 100644 --- a/ada_feeding/ada_feeding/trees/move_to_mouth_tree.py +++ b/ada_feeding/ada_feeding/trees/move_to_mouth_tree.py @@ -29,7 +29,7 @@ from ada_feeding.helpers import ( POSITION_GOAL_CONSTRAINT_NAMESPACE_PREFIX, ) -from ada_feeding.idioms import pre_moveto_config +from ada_feeding.idioms import pre_moveto_config, scoped_behavior from ada_feeding.idioms.bite_transfer import ( get_toggle_collision_object_behavior, get_toggle_face_detection_behavior, @@ -440,48 +440,27 @@ def create_move_to_tree( ) # Link all the behaviours together in a sequence with memory - move_to_mouth = py_trees.composites.Sequence( + root = py_trees.composites.Sequence( name=name + " Main", memory=True, children=[ - turn_face_detection_on, # For now, we only re-tare the F/T sensor once, since no large forces # are expected during transfer. pre_moveto_config_behavior, move_to_staging_configuration, - detect_face, + scoped_behavior( + name=name + " FaceDetectionOnScope", + pre_behavior=turn_face_detection_on, + main_behaviors=[detect_face], + post_behavior_fn=gen_turn_face_detection_off, + ), compute_target_position, move_head, - allow_wheelchair_collision, - move_to_target_pose, - gen_disallow_wheelchair_collision(), - gen_turn_face_detection_off(), - ], - ) - move_to_mouth.logger = logger - - # Create a cleanup branch for the behaviors that should get executed if - # the main tree has a failure - cleanup_tree = py_trees.composites.Sequence( - name=name + " Cleanup", - memory=True, - children=[ - gen_disallow_wheelchair_collision(), - gen_turn_face_detection_off(), - ], - ) - - # If move_to_mouth fails, we still want to do some cleanup (e.g., turn - # face detection off). - root = py_trees.composites.Selector( - name=name, - memory=True, - children=[ - move_to_mouth, - # Even though we are cleaning up the tree, it should still - # pass the failure up. - py_trees.decorators.SuccessIsFailure( - name + " Cleanup Root", cleanup_tree + scoped_behavior( + name=name + " AllowWheelchairCollisionScope", + pre_behavior=allow_wheelchair_collision, + main_behaviors=[move_to_target_pose], + post_behavior_fn=gen_disallow_wheelchair_collision, ), ], ) diff --git a/ada_feeding/package.xml b/ada_feeding/package.xml index b59a917f..134946f4 100644 --- a/ada_feeding/package.xml +++ b/ada_feeding/package.xml @@ -12,6 +12,7 @@ ament_lint_auto ament_lint_common + ament_cmake_pytest rcl_interfaces ros2launch diff --git a/ada_feeding/tests/__init__.py b/ada_feeding/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ada_feeding/tests/helpers.py b/ada_feeding/tests/helpers.py new file mode 100644 index 00000000..b7545f2e --- /dev/null +++ b/ada_feeding/tests/helpers.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +""" +This module defines unit tests for the eventually_swiss idiom. +""" + +# Standard imports +import time +from typing import List, Optional, Union + +# Third-party imports +import py_trees +from py_trees.blackboard import Blackboard + + +class TickCounterWithTerminateTimestamp(py_trees.behaviours.TickCounter): + """ + This class is identical to TickCounter, except that it also stores the + timestamp when the behavior terminated. + """ + + def __init__( + self, + name: str, + duration: int, + completion_status: py_trees.common.Status, + ns: str = "/", + ): + """ + Initialise the behavior. + """ + super().__init__( + name=name, duration=duration, completion_status=completion_status + ) + self.termination_new_status = None + self.termination_timestamp = None + self.num_times_ticked_to_non_running_status = 0 + + # Create a blackboard client to store this behavior's status, + # counter, termination_new_status, and termination_timestamp. + self.blackboard = self.attach_blackboard_client(name=name, namespace=ns) + self.blackboard.register_key(key="status", access=py_trees.common.Access.WRITE) + self.blackboard.register_key(key="counter", access=py_trees.common.Access.WRITE) + self.blackboard.register_key( + key="num_times_ticked_to_non_running_status", + access=py_trees.common.Access.WRITE, + ) + self.blackboard.register_key( + key="termination_new_status", access=py_trees.common.Access.WRITE + ) + self.blackboard.register_key( + key="termination_timestamp", access=py_trees.common.Access.WRITE + ) + + # Initialize the blackboard + self.blackboard.status = py_trees.common.Status.INVALID + self.blackboard.counter = 0 + self.blackboard.termination_new_status = self.termination_new_status + self.blackboard.termination_timestamp = self.termination_timestamp + + def initialise(self) -> None: + """Reset the tick counter.""" + self.counter = 0 + + # Update the blackboard. + self.blackboard.status = self.status + self.blackboard.counter = self.counter + self.blackboard.termination_new_status = self.termination_new_status + self.blackboard.termination_timestamp = self.termination_timestamp + + def update(self) -> py_trees.common.Status: + """ + Update the behavior. + """ + new_status = super().update() + + if new_status != py_trees.common.Status.RUNNING: + self.num_times_ticked_to_non_running_status += 1 + + # Update the blackboard. + self.blackboard.status = new_status + self.blackboard.counter = self.counter + self.blackboard.num_times_ticked_to_non_running_status = ( + self.num_times_ticked_to_non_running_status + ) + + return new_status + + def terminate(self, new_status: py_trees.common.Status) -> None: + """ + Terminate the behavior. + """ + self.termination_new_status = new_status + self.termination_timestamp = time.time() + + # Update the blackboard. + self.blackboard.termination_new_status = self.termination_new_status + self.blackboard.termination_timestamp = self.termination_timestamp + # Although self.status will be set in the `stop` method that called + # this, it won't set on the blackboard. So we set that here. + self.blackboard.status = new_status + + +def check_count_status( + behaviors: List[Union[TickCounterWithTerminateTimestamp, str]], + counts: List[int], + statuses: List[py_trees.common.Status], + num_times_ticked_to_non_running_statuses: List[int], + descriptor: str = "", +) -> None: + """ + Takes in a list of TickCounter behaviors and checks that their counts and + statuses are correct. + + Parameters + ---------- + behaviors: The list of behaviors to check. The values are either behaviors, + in which case the attributes of the behavior will directly be checked, + or strings, which is the blackboard namespace where the behavior has + stored its attributes. + counts: The expected counts for each behavior. + statuses: The expected statuses for each behavior. + num_times_ticked_to_non_running_statuses: The expected number of times each + behavior had a tick resulting in a non-running status. + """ + assert ( + len(behaviors) == len(counts) == len(statuses) + ), "lengths of behaviors, counts, and statuses must be equal" + + for i, behavior in enumerate(behaviors): + # Get the actual count and status + if isinstance(behavior, str): + name = behavior + actual_count = Blackboard().get( + Blackboard.separator.join([behavior, "counter"]) + ) + actual_status = Blackboard().get( + Blackboard.separator.join([behavior, "status"]) + ) + actual_num_times_ticked_to_non_running_statuses = Blackboard().get( + Blackboard.separator.join( + [behavior, "num_times_ticked_to_non_running_status"] + ) + ) + else: + name = behavior.name + actual_count = behavior.counter + actual_status = behavior.status + actual_num_times_ticked_to_non_running_statuses = ( + behavior.num_times_ticked_to_non_running_status + ) + + # Check the actual count and status against the expected ones + assert actual_count == counts[i], ( + f"behavior '{name}' actual count {actual_count}, " + f"expected count {counts[i]}, " + f"{descriptor}" + ) + assert actual_status == statuses[i], ( + f"behavior '{name}' actual status {actual_status}, " + f"expected status {statuses[i]}, " + f"{descriptor}" + ) + assert ( + actual_num_times_ticked_to_non_running_statuses + == num_times_ticked_to_non_running_statuses[i] + ), ( + f"behavior '{name}' actual num_times_ticked_to_non_running_statuses " + f"{actual_num_times_ticked_to_non_running_statuses}, " + f"expected num_times_ticked_to_non_running_statuses " + f"{num_times_ticked_to_non_running_statuses[i]}, " + f"{descriptor}" + ) + + +def check_termination_new_statuses( + behaviors: List[Union[TickCounterWithTerminateTimestamp, str]], + statuses: List[Optional[py_trees.common.Status]], + descriptor: str = "", +) -> None: + """ + Checkes that `terminate` either has not been called on the behavior, or + that it has been called with the correct new status. + + Parameters + ---------- + behaviors: The list of behaviors to check. The values are either behaviors, + in which case the attributes of the behavior will directly be checked, + or strings, which is the blackboard namespace where the behavior has + stored its attributes. + statuses: The expected new statuses for each behavior when `terminate` was + called, or `None` if `terminate` was not expected to be called. + """ + assert len(behaviors) == len( + statuses + ), "lengths of behaviors and statuses must be equal" + + for i, behavior in enumerate(behaviors): + # Get the actual termination_new_status + if isinstance(behavior, str): + name = behavior + actual_termination_new_status = Blackboard().get( + Blackboard.separator.join([behavior, "termination_new_status"]) + ) + else: + name = behavior.name + actual_termination_new_status = behavior.termination_new_status + + # Check the actual termination_new_status against the expected one + if statuses[i] is None: + assert actual_termination_new_status is None, ( + f"behavior '{name}' expected termination_new_status None, actual " + f"termination_new_status {actual_termination_new_status}, " + f"{descriptor}" + ) + else: + assert actual_termination_new_status == statuses[i], ( + f"behavior '{name}' actual termination_new_status " + f"{actual_termination_new_status}, expected termination_new_status " + f"{statuses[i]}, {descriptor}" + ) + + +def check_termination_order( + behaviors: List[Union[TickCounterWithTerminateTimestamp, str]], + descriptor: str = "", +) -> None: + """ + Checks that the behaviors terminated in the correct order. + + Parameters + ---------- + behaviors: The list of behaviors to check, in the order that `terminate` + should have been called on them. The values are either behaviors, in + which case the attributes of the behavior will directly be checked, or + strings, which is the blackboard namespace where the behavior has stored + its attributes. + """ + for i in range(len(behaviors) - 1): + # Get the actual termination_timestamp + if isinstance(behaviors[i], str): + curr_name = behaviors[i] + actual_curr_termination_timestamp = Blackboard().get( + Blackboard.separator.join([behaviors[i], "termination_timestamp"]) + ) + else: + curr_name = behaviors[i].name + actual_curr_termination_timestamp = behaviors[i].termination_timestamp + + if isinstance(behaviors[i + 1], str): + next_name = behaviors[i + 1] + actual_next_termination_timestamp = Blackboard().get( + Blackboard.separator.join([behaviors[i + 1], "termination_timestamp"]) + ) + else: + next_name = behaviors[i + 1].name + actual_next_termination_timestamp = behaviors[i + 1].termination_timestamp + + # Check the actual termination_timestamp against the expected one + assert actual_curr_termination_timestamp <= actual_next_termination_timestamp, ( + f"behavior '{curr_name}' terminated after behavior " + f"'{next_name}', when it should have terminated before, " + f"{descriptor}" + ) diff --git a/ada_feeding/tests/test_eventually_swiss.py b/ada_feeding/tests/test_eventually_swiss.py new file mode 100644 index 00000000..b5a5e8c0 --- /dev/null +++ b/ada_feeding/tests/test_eventually_swiss.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python3 +""" +This module defines unit tests for the eventually_swiss idiom. +""" + +# Standard imports +from enum import Enum +from functools import partial +from typing import List + +# Third-party imports +import py_trees + +# Local imports +from ada_feeding.idioms import eventually_swiss +from .helpers import ( + TickCounterWithTerminateTimestamp, + check_count_status, + check_termination_new_statuses, + check_termination_order, +) + +# pylint: disable=duplicate-code +# `test_scoped_behavior` and `test_eventually_swiss` have similar code because +# they are similar idioms. That is okay. +# pylint: disable=redefined-outer-name +# When generating tests, we use global variables with the same names as +# variables in the functions. That is okay, since the functions don't need +# access to the global variables. + + +class ExecutionCase(Enum): + """ + Tree execution can broadly fall into one of the below cases. + """ + + NONE = 0 + HASNT_STARTED = 1 + WORKER_RUNNING = 2 + WORKER_TERMINATED_CALLBACK_RUNNING = 3 + TREE_TERMINATED = 4 + + +def generate_test( + worker_duration: int, + worker_completion_status: py_trees.common.Status, + on_success_duration: int, + on_success_completion_status: py_trees.common.Status, + on_failure_duration: int, + on_failure_completion_status: py_trees.common.Status, + on_preempt_duration: int, + on_preempt_completion_status: py_trees.common.Status, + return_on_success_status: bool, +): + """ + Generates a worker, on_success, on_failure, and on_preempt behavior with the + specified durations and completion statuses. + + Note that this always generates the multi-tick version of eventually_swiss. + + Parameters + ---------- + worker_duration: The number of ticks it takes for the worker to terminate. + worker_completion_status: The completion status of the worker. + on_success_duration: The number of ticks it takes for `on_success` to terminate. + on_success_completion_status: The completion status of `on_success`. + on_failure_duration: The number of ticks it takes for `on_failure` to terminate. + on_failure_completion_status: The completion status of `on_failure`. + on_preempt_duration: The number of ticks it takes for `on_preempt` to terminate. + on_preempt_completion_status: The completion status of `on_preempt`. + return_on_success_status: If True, return `on_success` status. Else, return + """ + # pylint: disable=too-many-arguments + # Necessary to create a versatile test generation function. + + # Setup the test + worker = TickCounterWithTerminateTimestamp( + name="Worker", + duration=worker_duration, + completion_status=worker_completion_status, + ns="/worker", + ) + on_success = TickCounterWithTerminateTimestamp( + name="On Success", + duration=on_success_duration, + completion_status=on_success_completion_status, + ns="/on_success", + ) + on_failure = TickCounterWithTerminateTimestamp( + name="On Failure", + duration=on_failure_duration, + completion_status=on_failure_completion_status, + ns="/on_failure", + ) + on_preempt = TickCounterWithTerminateTimestamp( + name="On Preempt", + duration=on_preempt_duration, + completion_status=on_preempt_completion_status, + ns="/on_preempt", + ) + root = eventually_swiss( + name="Eventually Swiss", + workers=[worker], + on_success=on_success, + on_failure=on_failure, + on_preempt=on_preempt, + on_preempt_single_tick=False, + return_on_success_status=return_on_success_status, + ) + return root, worker, on_success, on_failure, on_preempt + + +def combined_test( + worker_completion_status: py_trees.common.Status, + callback_completion_status: py_trees.common.Status, + global_num_cycles: int = 2, + preempt_times: List[ExecutionCase] = [ExecutionCase.NONE, ExecutionCase.NONE], + return_on_success_status: bool = True, +) -> None: + """ + This function ticks the root to completion `global_num_cycles` times and checks + the following three cases: + + Case WORKER_RUNNING: + - While the worker is RUNNING, `on_success`, `on_failure`, and `on_preempt` + should not be ticked, none of the functions should be terminated, and the + root should be running. + + When the worker succeeds: + Case WORKER_TERMINATED_CALLBACK_RUNNING: + - While `on_success` is RUNNING, `on_failure` and `on_preempt` should + not be ticked, none of the functions should be terminated, and the + root should be running. + Case TREE_TERMINATED: + - When `on_success` returns `callback_completion_status`, `on_failure` + and `on_preempt` should not be ticked, none of the functions should + be terminated, and the root should return `callback_completion_status`. + + When the worker fails: + Case WORKER_TERMINATED_CALLBACK_RUNNING: + - While `on_failure` is RUNNING, `on_success` and `on_preempt` should + not be ticked, none of the functions should be terminated, and the + root should be running. + Case TREE_TERMINATED: + - When `on_failure` returns `callback_completion_status`, `on_success` + and `on_preempt` should not be ticked, none of the functions should + be terminated, and the root should return FAILURE. + + Additionally, this function can terminate the tree up to once per cycle. + For cycle i, this function will terminate the tree depending on the value of + `preempt_times[i]`: + - None: Don't terminate this cycle. + - WORKER_RUNNING: Terminate the tree after the first tick when the worker + is RUNNING. + - WORKER_TERMINATED_CALLBACK_RUNNING: Terminate the tree after the first + tick when the worker has terminated and the callback is RUNNING. + - TREE_TERMINATED: Terminate the tree after the tick + when the worker has terminated and the callback has terminated (i.e., after + the tick where the root returns a non-RUNNING status). + After terminating the tree, in the first two cases, this function checks that + the tree has not ticked `worker`, `on_success`, or `on_failure` any more, but + has ticked `on_preempt` to completion. It also checks that the tree has + terminated in the correct order: `worker` -> `on_success`/`on_failure` -> `on_preempt`. + In the third case, since the tree has already reached a non-RUNNING status, + `on_preempt` should not be run, and this function verifies that. + + Parameters + ---------- + worker_completion_status: The completion status of the worker. + callback_completion_status: The completion status of the callback. + global_num_cycles: The number of times to tick the root to completion. + preempt_times: A list of ExecutionCase values, one for each cycle. If None, + don't preempt the tree during that cycle. + """ + # pylint: disable=too-many-locals, too-many-branches, too-many-statements + # pylint: disable=too-many-nested-blocks, too-many-arguments + # This is where the bulk of the work to test eventually_swiss is done, so + # it's hard to reduce the number of locals, branches, and statements. + # pylint: disable=dangerous-default-value + # A default value of a list is fine in this case. + + assert len(preempt_times) >= global_num_cycles, "Malformed test case." + + # Setup the test + worker_duration = 2 + callback_duration = 3 + other_callbacks_completion_status = py_trees.common.Status.SUCCESS + root, worker, on_success, on_failure, on_preempt = generate_test( + worker_duration=worker_duration, + worker_completion_status=worker_completion_status, + on_success_duration=callback_duration, + on_success_completion_status=( + callback_completion_status + if worker_completion_status == py_trees.common.Status.SUCCESS + else other_callbacks_completion_status + ), + on_failure_duration=callback_duration, + on_failure_completion_status=( + callback_completion_status + if worker_completion_status == py_trees.common.Status.FAILURE + else other_callbacks_completion_status + ), + on_preempt_duration=callback_duration, + on_preempt_completion_status=other_callbacks_completion_status, + return_on_success_status=return_on_success_status, + ) + + # Get the number of ticks it should take to terminate this tree. + num_ticks_to_terminate = worker_duration + callback_duration + 1 + + # Initialize the expected counts, statuses, termination_new_statuses, and + # root status for the tests + behaviors = [worker, on_success, on_failure, on_preempt] + expected_counts = [0, 0, 0, 0] + expected_statuses = [ + py_trees.common.Status.INVALID, + py_trees.common.Status.INVALID, + py_trees.common.Status.INVALID, + py_trees.common.Status.INVALID, + ] + expected_num_times_ticked_to_non_running_statuses = [0, 0, 0, 0] + expected_termination_new_statuses = [None, None, None, None] + + # Tick the tree + for num_cycles in range(global_num_cycles): + execution_case = ExecutionCase.HASNT_STARTED + for num_ticks in range(1, num_ticks_to_terminate + 2): + descriptor = f"num_ticks {num_ticks}, num_cycles {num_cycles}" + + # Preempt if requested + if preempt_times[num_cycles] == execution_case: + root.stop(py_trees.common.Status.INVALID) + descriptor += " after preemption" + + # Update the expected termination of all behaviors but `on_preempt` + termination_order_on_success = [] + termination_order_on_failure = [] + for i in range(3): + if expected_statuses[i] != py_trees.common.Status.INVALID: + expected_statuses[i] = py_trees.common.Status.INVALID + expected_termination_new_statuses[ + i + ] = py_trees.common.Status.INVALID + if i == 1: + termination_order_on_success.append(behaviors[i]) + elif i == 2: + termination_order_on_failure.append(behaviors[i]) + else: + termination_order_on_success.append(behaviors[i]) + termination_order_on_failure.append(behaviors[i]) + root_expected_status = py_trees.common.Status.INVALID + # `on_preempt` should only get ticked if the worker/callback + # have not yet terminated. If they have terminated, the root + # is considered complete and there is no reason to run `on_preempt`. + if execution_case in [ + ExecutionCase.WORKER_RUNNING, + ExecutionCase.WORKER_TERMINATED_CALLBACK_RUNNING, + ]: + # `on_preempt` should get ticked to completion + expected_counts[3] = callback_duration + 1 + expected_num_times_ticked_to_non_running_statuses[3] += 1 + + # Because `on_preempt` is not officially a part of the tree, + # it won't get called as part of the preemption. So it's + # status will be its terminal status. + expected_statuses[3] = other_callbacks_completion_status + expected_termination_new_statuses[ + 3 + ] = other_callbacks_completion_status + termination_order_on_success.append(behaviors[3]) + termination_order_on_failure.append(behaviors[3]) + + # Run the preemption tests + check_count_status( + behaviors=behaviors, + counts=expected_counts, + statuses=expected_statuses, + num_times_ticked_to_non_running_statuses=( + expected_num_times_ticked_to_non_running_statuses + ), + descriptor=descriptor, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=expected_termination_new_statuses, + descriptor=descriptor, + ) + check_termination_order(termination_order_on_success, descriptor) + check_termination_order(termination_order_on_failure, descriptor) + assert ( + root.status == root_expected_status + ), f"root status {root.status} is not {root_expected_status}, {descriptor}" + + # End this cycle + break + + if num_ticks == num_ticks_to_terminate + 1: + # End this cycle. We only go past the ticks to terminate in case + # the tree is preempted after termination. + break + if num_ticks == 1: + # The worker's count gets re-initialized at the beginning of every cycle. + expected_counts[0] = 0 + # The worker, on_success, and on_failure get reset to INVALID at the + # beginning of every cycle. + if num_cycles > 0: + if expected_statuses[0] != py_trees.common.Status.INVALID: + expected_statuses[0] = py_trees.common.Status.INVALID + expected_termination_new_statuses[ + 0 + ] = py_trees.common.Status.INVALID + if ( + worker_completion_status == py_trees.common.Status.SUCCESS + and expected_statuses[1] != py_trees.common.Status.INVALID + ): + expected_termination_new_statuses[ + 1 + ] = py_trees.common.Status.INVALID + expected_statuses[1] = py_trees.common.Status.INVALID + if ( + worker_completion_status == py_trees.common.Status.FAILURE + and expected_statuses[2] != py_trees.common.Status.INVALID + ): + expected_termination_new_statuses[ + 2 + ] = py_trees.common.Status.INVALID + expected_statuses[2] = py_trees.common.Status.INVALID + + # Tick the tree + root.tick_once() + # Get the expected counts, statuses, termination_new_statuses, and + # root status. + # The worker is still running. WORKER_RUNNING case. + if num_ticks <= worker_duration: + execution_case = ExecutionCase.WORKER_RUNNING + expected_counts[0] += 1 # The worker should have gotten ticked. + expected_statuses[0] = py_trees.common.Status.RUNNING + root_expected_status = py_trees.common.Status.RUNNING + # The worker has terminated, but the success/failure callback is still running. + # WORKER_TERMINATED_CALLBACK_RUNNING case. + elif num_ticks <= worker_duration + callback_duration: + execution_case = ExecutionCase.WORKER_TERMINATED_CALLBACK_RUNNING + if num_ticks == worker_duration + 1: + # The worker terminates on the first tick after `worker_duration` + expected_counts[0] += 1 + expected_num_times_ticked_to_non_running_statuses[0] += 1 + # on_success and on_failure only gets reinitialized after the + # worker terminates. + expected_counts[1] = 0 + expected_counts[2] = 0 + # The worker status gets set + expected_statuses[0] = worker_completion_status + expected_termination_new_statuses[0] = worker_completion_status + elif worker_completion_status == py_trees.common.Status.FAILURE: + # The Selector with memory unnecessarily sets previous children to + # INVALID the tick after they fail, hence the below switch. + # https://github.com/splintered-reality/py_trees/blob/0d5b39f2f6333c504406d8a63052c456c6bd1ce5/py_trees/composites.py#L427 + expected_statuses[0] = py_trees.common.Status.INVALID + expected_termination_new_statuses[ + 0 + ] = py_trees.common.Status.INVALID + if worker_completion_status == py_trees.common.Status.SUCCESS: + expected_counts[1] += 1 + expected_statuses[1] = py_trees.common.Status.RUNNING + elif worker_completion_status == py_trees.common.Status.FAILURE: + expected_counts[2] += 1 + expected_statuses[2] = py_trees.common.Status.RUNNING + else: + assert ( + False + ), f"Unexpected worker_completion_status {worker_completion_status}." + root_expected_status = py_trees.common.Status.RUNNING + # The success/failure callback has terminated. + # TREE_TERMINATED case. + elif num_ticks == num_ticks_to_terminate: + execution_case = ExecutionCase.TREE_TERMINATED + if worker_completion_status == py_trees.common.Status.SUCCESS: + expected_counts[1] += 1 + expected_statuses[1] = callback_completion_status + expected_num_times_ticked_to_non_running_statuses[1] += 1 + expected_termination_new_statuses[1] = callback_completion_status + elif worker_completion_status == py_trees.common.Status.FAILURE: + expected_counts[2] += 1 + expected_statuses[2] = callback_completion_status + expected_num_times_ticked_to_non_running_statuses[2] += 1 + expected_termination_new_statuses[2] = callback_completion_status + else: + assert ( + False + ), f"Unexpected worker_completion_status {worker_completion_status}." + if worker_completion_status == py_trees.common.Status.SUCCESS: + if return_on_success_status: + root_expected_status = callback_completion_status + else: + root_expected_status = py_trees.common.Status.SUCCESS + elif worker_completion_status == py_trees.common.Status.FAILURE: + root_expected_status = py_trees.common.Status.FAILURE + else: + assert ( + False + ), f"Unexpected worker_completion_status {worker_completion_status}." + else: + assert False, ( + f"Should not get here, num_ticks {num_ticks}, " + f"num_cycles {num_cycles}" + ) + + # Run the tests + check_count_status( + behaviors=behaviors, + counts=expected_counts, + statuses=expected_statuses, + num_times_ticked_to_non_running_statuses=( + expected_num_times_ticked_to_non_running_statuses + ), + descriptor=descriptor, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=expected_termination_new_statuses, + descriptor=descriptor, + ) + assert ( + root.status == root_expected_status + ), f"root status {root.status} is not {root_expected_status}, {descriptor}" + + +################################################################################ +# Generate all tests with 2 cycles +################################################################################ + +# Set the cases to iterate over +status_cases = [ + py_trees.common.Status.SUCCESS, + py_trees.common.Status.FAILURE, +] +preempt_cases = [ + ExecutionCase.NONE, + ExecutionCase.HASNT_STARTED, + ExecutionCase.WORKER_RUNNING, + ExecutionCase.WORKER_TERMINATED_CALLBACK_RUNNING, + ExecutionCase.TREE_TERMINATED, +] +for worker_completion_status in status_cases: + for callback_completion_status in status_cases: + for return_on_success_status in [True, False]: + for first_preempt in preempt_cases: + for second_preempt in preempt_cases: + test_name = ( + f"test_worker_{worker_completion_status.name}_callback_" + f"{callback_completion_status.name}_ret_succ_" + f"{return_on_success_status}_" + f"first_preempt_{first_preempt.name}_second_preempt_" + f"{second_preempt.name}" + ) + globals()[test_name] = partial( + combined_test, + worker_completion_status=worker_completion_status, + callback_completion_status=callback_completion_status, + preempt_times=[first_preempt, second_preempt], + return_on_success_status=return_on_success_status, + ) diff --git a/ada_feeding/tests/test_scoped_behavior.py b/ada_feeding/tests/test_scoped_behavior.py new file mode 100644 index 00000000..131629a1 --- /dev/null +++ b/ada_feeding/tests/test_scoped_behavior.py @@ -0,0 +1,699 @@ +#!/usr/bin/env python3 +""" +This module defines unit tests for the scoped_behaviour idiom. +""" + +# Standard imports +from enum import Enum +from functools import partial +from typing import List, Optional + +# Third-party imports +import py_trees + +# Local imports +from ada_feeding.idioms import scoped_behavior +from .helpers import ( + TickCounterWithTerminateTimestamp, + check_count_status, + check_termination_new_statuses, + check_termination_order, +) + +# pylint: disable=duplicate-code +# `test_scoped_behavior` and `test_eventually_swiss` have similar code because +# they are similar idioms. That is okay. +# pylint: disable=redefined-outer-name +# When generating tests, we use global variables with the same names as +# variables in the functions. That is okay, since the functions don't need +# access to the global variables. + + +class ExecutionCase(Enum): + """ + Tree execution can broadly fall into one of the below cases. + """ + + NONE = 0 + HASNT_STARTED = 1 + PRE_RUNNING = 2 + WORKERS_RUNNING = 3 + POST_RUNNING = 4 + TREE_TERMINATED = 5 + + +def generate_test( + pre_duration: int, + pre_completion_status: py_trees.common.Status, + worker_duration: int, + worker_completion_status: py_trees.common.Status, + post_duration: int, + post_completion_status: py_trees.common.Status, + worker_override: Optional[py_trees.behaviour.Behaviour] = None, + suffix: str = ",", +): + """ + Generates a worker, pre, and post behavior with the + specified durations and completion statuses. + + Parameters + ---------- + pre_duration: The duration of the pre behavior. + pre_completion_status: The completion status of the pre behavior. + worker_duration: The duration of the worker behavior. + worker_completion_status: The completion status of the worker behavior. + post_duration: The duration of the post behavior. + post_completion_status: The completion status of the post behavior. + worker_override: If not None, this behavior will be used instead of the + default worker behavior. + """ + # pylint: disable=too-many-arguments + # Necessary to create a versatile test generation function. + + # Setup the test + pre = TickCounterWithTerminateTimestamp( + name="Pre" + suffix, + duration=pre_duration, + completion_status=pre_completion_status, + ns="/pre" + suffix, + ) + if worker_override is None: + worker = TickCounterWithTerminateTimestamp( + name="Worker" + suffix, + duration=worker_duration, + completion_status=worker_completion_status, + ns="/worker" + suffix, + ) + else: + worker = worker_override + + post = TickCounterWithTerminateTimestamp( + name="Post" + suffix, + duration=post_duration, + completion_status=post_completion_status, + ns="/post" + suffix, + ) + + root = scoped_behavior( + name="Root" + suffix, + pre_behavior=pre, + workers=[worker], + post_behavior=post, + ) + return root, pre, worker, post + + +def combined_test( + pre_completion_status: py_trees.common.Status, + worker_completion_status: py_trees.common.Status, + post_completion_status: py_trees.common.Status, + global_num_cycles: int = 2, + preempt_times: List[ExecutionCase] = [ExecutionCase.NONE, ExecutionCase.NONE], +) -> None: + """ + This function ticks the root to completion `global_num_cycles` times and checks + the following three cases: + + Case PRE_RUNNING: + - While `pre` is RUNNING, `worker` and `post` should not be ticked, none of + the functions should be terminated, and the root should be running. + + Case WORKERS_RUNNING: + - While `worker` is RUNNING, `post` should not be ticked, only `pre` should be + terminated, and the root should be running. + + Case POST_RUNNING: + - While `post` is RUNNING, only `pre` and `worker` should be terminated, and + the root should be running. + + Case TREE_TERMINATED: + - When the root returns a non-RUNNING status, `pre`, `worker`, and `post` + should be terminated, and the root should return the correct status. + + Additionally, this function can terminate the tree up to once per cycle. + For cycle i, this function will terminate the tree depending on the value of + `preempt_times[i]`: + - None: Don't terminate this cycle. + - PRE_RUNNING: Terminate the tree after the first tick when the pre is + RUNNING. + - WORKERS_RUNNING: Terminate the tree after the first tick when the worker + is RUNNING. + - POST_RUNNING: Terminate the tree after the first tick when the post is + RUNNING. + - TREE_TERMINATED: Terminate the tree after the tick when the root returns + a non-RUNNING status. + After terminating the tree, in the first three cases, this function checks that + the tree has not ticked `pre` or `worker` any more, but + has ticked `post` to completion. It also checks that the tree has + terminated in the correct order: `pre` -> `worker` -> `pose`. + In the third case, since the tree has already reached a non-RUNNING status, + nothing should change other than the statuses of the behaviors. + + Parameters + ---------- + pre_completion_status: The completion status of the pre behavior. + worker_completion_status: The completion status of the worker behavior. + post_completion_status: The completion status of the post behavior. + global_num_cycles: The number of times to tick the tree to completion. + preempt_times: A list of ExecutionCase enums, one for each cycle. + """ + # pylint: disable=too-many-locals, too-many-branches, too-many-statements + # pylint: disable=too-many-nested-blocks + # This is where the bulk of the work to test eventually_swiss is done, so + # it's hard to reduce the number of locals, branches, and statements. + # pylint: disable=dangerous-default-value + # A default value of a list is fine in this case. + + assert len(preempt_times) >= global_num_cycles, "Malformed test case." + + # Setup the test + pre_duration = 3 + worker_duration = 2 + post_duration = 6 + root, pre, worker, post = generate_test( + pre_duration=pre_duration, + pre_completion_status=pre_completion_status, + worker_duration=worker_duration, + worker_completion_status=worker_completion_status, + post_duration=post_duration, + post_completion_status=post_completion_status, + ) + + # Get the number of ticks it should take to terminate this tree. + num_ticks_to_terminate = ( + pre_duration + + ( + worker_duration + if pre_completion_status == py_trees.common.Status.SUCCESS + else 0 + ) + + post_duration + + 1 + ) + + # Initialize the expected counts, statuses, termination_new_statuses, and + # root status for the tests + behaviors = [pre, worker, post] + expected_counts = [0, 0, 0] + expected_statuses = [ + py_trees.common.Status.INVALID, + py_trees.common.Status.INVALID, + py_trees.common.Status.INVALID, + ] + expected_num_times_ticked_to_non_running_statuses = [0, 0, 0] + expected_termination_new_statuses = [None, None, None] + + # Tick the tree + preempted_in_previous_cycle = False + for num_cycles in range(global_num_cycles): + execution_case = ExecutionCase.HASNT_STARTED + for num_ticks in range(1, num_ticks_to_terminate + 2): + descriptor = f"num_ticks {num_ticks}, num_cycles {num_cycles}" + + # Preempt if requested + if preempt_times[num_cycles] == execution_case: + root.stop(py_trees.common.Status.INVALID) + descriptor += " after preemption" + + # Update the expected termination of all behaviors but `post` + termination_order = [] + for i in range(2): + if expected_statuses[i] != py_trees.common.Status.INVALID: + expected_statuses[i] = py_trees.common.Status.INVALID + expected_termination_new_statuses[ + i + ] = py_trees.common.Status.INVALID + termination_order.append(behaviors[i]) + root_expected_status = py_trees.common.Status.INVALID + # `post` should only get ticked on preemption if the worker/callback + # have not yet terminated. If they have terminated, the root + # is considered complete and there is no reason to run `post` again. + if execution_case == ExecutionCase.HASNT_STARTED: + # In this cases, `post` should not get ticked as part of + # preemption. Its status will only be set to INVALID if neither + # it nor its parent is INVALID. The only case where `post` is + # not INVALID but its parent is is if the tree was preempted + # in the previous cycle. + if ( + expected_statuses[2] != py_trees.common.Status.INVALID + and not preempted_in_previous_cycle + ): + expected_statuses[2] = py_trees.common.Status.INVALID + expected_termination_new_statuses[ + 2 + ] = py_trees.common.Status.INVALID + preempted_in_previous_cycle = False + elif execution_case == ExecutionCase.TREE_TERMINATED: + # In this cases, `post` should not get ticked as part of + # preemption. Its status will be set to INVALID through the + # normal termination process. + expected_statuses[2] = py_trees.common.Status.INVALID + expected_termination_new_statuses[ + 2 + ] = py_trees.common.Status.INVALID + else: + preempted_in_previous_cycle = True + # `post` should get ticked to completion + expected_counts[2] = post_duration + 1 + + # Because `post` is not officially a part of the tree, + # it won't get called as part of the preemption. So it's + # status will be its terminal status. + expected_statuses[2] = post_completion_status + expected_num_times_ticked_to_non_running_statuses[2] += 1 + expected_termination_new_statuses[2] = post_completion_status + termination_order.append(behaviors[2]) + + # Run the preemption tests + check_count_status( + behaviors=behaviors, + counts=expected_counts, + statuses=expected_statuses, + num_times_ticked_to_non_running_statuses=( + expected_num_times_ticked_to_non_running_statuses + ), + descriptor=descriptor, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=expected_termination_new_statuses, + descriptor=descriptor, + ) + check_termination_order(termination_order, descriptor) + assert ( + root.status == root_expected_status + ), f"root status {root.status} is not {root_expected_status}, {descriptor}" + + # End this cycle + break + + if num_ticks == num_ticks_to_terminate + 1: + # End this cycle. We only go past the ticks to terminate in case + # the tree is preempted after termination. + preempted_in_previous_cycle = False + break + if num_ticks == 1: + # The pre's count gets re-initialized at the beginning of every cycle. + expected_counts[0] = 0 + # The pre and worker get reset to INVALID at the + # beginning of every cycle. Post gets reset to INVALID only if the + # last cycle was not preempted. + if num_cycles > 0: + for i in range(3): + if i < 2 or preempt_times[num_cycles - 1] == ExecutionCase.NONE: + expected_statuses[i] = py_trees.common.Status.INVALID + if expected_termination_new_statuses[i] is not None: + expected_termination_new_statuses[ + i + ] = py_trees.common.Status.INVALID + + # Tick the tree + root.tick_once() + # Get the expected counts, statuses, termination_new_statuses, and + # root status. + # The pre is still running. PRE_RUNNING case. + if num_ticks <= pre_duration: + execution_case = ExecutionCase.PRE_RUNNING + expected_counts[0] += 1 # The pre should have gotten ticked. + expected_statuses[0] = py_trees.common.Status.RUNNING + root_expected_status = py_trees.common.Status.RUNNING + # The pre succeeded, but the worker is still running. + # WORKERS_RUNNING case. + elif ( + pre_completion_status == py_trees.common.Status.SUCCESS + and num_ticks <= pre_duration + worker_duration + ): + execution_case = ExecutionCase.WORKERS_RUNNING + if num_ticks == pre_duration + 1: + # The pre terminates on the first tick after `pre_duration` + expected_counts[0] += 1 + # The worker only gets re-initialized after the pre terminates. + expected_counts[1] = 0 + # The pre's status gets set + expected_statuses[0] = pre_completion_status + expected_num_times_ticked_to_non_running_statuses[0] += 1 + expected_termination_new_statuses[0] = pre_completion_status + expected_counts[1] += 1 + expected_statuses[1] = py_trees.common.Status.RUNNING + root_expected_status = py_trees.common.Status.RUNNING + # The pre succeeded and the worker has terminated. + # POST_RUNNING case. + elif ( + pre_completion_status == py_trees.common.Status.SUCCESS + and num_ticks <= pre_duration + worker_duration + post_duration + ): + execution_case = ExecutionCase.POST_RUNNING + if num_ticks == pre_duration + worker_duration + 1: + # The worker terminates on the first tick after `worker_duration` + expected_counts[1] += 1 + # Post only gets reinitialized after the worker terminates. + expected_counts[2] = 0 + # The worker status gets set + expected_statuses[1] = worker_completion_status + expected_num_times_ticked_to_non_running_statuses[1] += 1 + expected_termination_new_statuses[1] = worker_completion_status + expected_counts[2] += 1 + expected_statuses[2] = py_trees.common.Status.RUNNING + root_expected_status = py_trees.common.Status.RUNNING + # The pre failed, but the post is still running. + # POST_RUNNING case. + elif ( + pre_completion_status == py_trees.common.Status.FAILURE + and num_ticks <= pre_duration + post_duration + ): + execution_case = ExecutionCase.POST_RUNNING + if num_ticks == pre_duration + 1: + # The pre terminates on the first tick after `pre_duration` + expected_counts[0] += 1 + # Post only gets reinitialized after the worker terminates. + expected_counts[2] = 0 + # The pre's status gets set + expected_statuses[0] = pre_completion_status + expected_num_times_ticked_to_non_running_statuses[0] += 1 + expected_termination_new_statuses[0] = pre_completion_status + expected_counts[2] += 1 + expected_statuses[2] = py_trees.common.Status.RUNNING + root_expected_status = py_trees.common.Status.RUNNING + # The post has terminated. TREE_TERMINATED case. + elif num_ticks == num_ticks_to_terminate: + execution_case = ExecutionCase.TREE_TERMINATED + expected_counts[2] += 1 + expected_statuses[2] = post_completion_status + expected_num_times_ticked_to_non_running_statuses[2] += 1 + expected_termination_new_statuses[2] = post_completion_status + root_expected_status = ( + py_trees.common.Status.FAILURE + if pre_completion_status == py_trees.common.Status.FAILURE + else worker_completion_status + ) + else: + assert False, ( + f"Should not get here, num_ticks {num_ticks}, " + f"num_cycles {num_cycles}" + ) + + # Run the tests + check_count_status( + behaviors=behaviors, + counts=expected_counts, + statuses=expected_statuses, + num_times_ticked_to_non_running_statuses=( + expected_num_times_ticked_to_non_running_statuses + ), + descriptor=descriptor, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=expected_termination_new_statuses, + descriptor=descriptor, + ) + assert ( + root.status == root_expected_status + ), f"root status {root.status} is not {root_expected_status}, {descriptor}" + + +################################################################################ +# Generate all tests with 2 cycles +################################################################################ + +# Set the status cases to iterate over +status_cases = [py_trees.common.Status.SUCCESS, py_trees.common.Status.FAILURE] +for pre_completion_status in status_cases: + # Set the preempt cases to iterate over + preempt_cases = [ + ExecutionCase.NONE, + ExecutionCase.HASNT_STARTED, + ExecutionCase.PRE_RUNNING, + ] + if pre_completion_status == py_trees.common.Status.SUCCESS: + preempt_cases.append(ExecutionCase.WORKERS_RUNNING) + preempt_cases += [ExecutionCase.POST_RUNNING, ExecutionCase.TREE_TERMINATED] + + for worker_completion_status in status_cases: + for post_completion_status in status_cases: + for first_preempt in preempt_cases: + for second_preempt in preempt_cases: + test_name = ( + f"test_pre_{pre_completion_status.name}_worker_" + f"{worker_completion_status.name}_post_{post_completion_status.name}_" + f"first_preempt_{first_preempt.name}_second_preempt_{second_preempt.name}" + ) + globals()[test_name] = partial( + combined_test, + pre_completion_status=pre_completion_status, + worker_completion_status=worker_completion_status, + post_completion_status=post_completion_status, + preempt_times=[first_preempt, second_preempt], + ) + +################################################################################ +# Test Nested Scoped Behaviors +################################################################################ + + +class NestedExecutionCase(Enum): + """ + With a single nested sequence, execution can broadly fall into one of the + below cases. + """ + + NONE = 0 + HASNT_STARTED = 1 + PRE1_RUNNING = 2 + PRE2_RUNNING = 3 + WORKERS_RUNNING = 4 + POST1_RUNNING = 5 + POST2_RUNNING = 6 + TREE_TERMINATED = 7 + + +def nested_behavior_tests( + preempt_time: NestedExecutionCase, +): + """ + In the test of nested scope, we will assume all behaviors succeed, because + success/failure was already tested above. We will also only tick the tree + for one cycle, because multiple cycles were tested above. The main goal of + this test is to ensure the following: + - NONE: If the tree is not preempted, both post-behaviors should be ticked + to completion. + - PRE1_RUNNING: If the tree is preempted while pre1 is running, post1 should + be ticked to completion, and post2 should not be ticked. + - PRE2_RUNNING, WORKERS_RUNNING, POST1_RUNNING, POST2_RUNNING: In all of + these cases, post1 and post2 should be ticked to completion. + - TREE_TERMINATED: If the tree is preempted after the tree has terminated, + post1 and post2 should not be ticked. + """ + # pylint: disable=too-many-branches, too-many-statements + # Necessary to test all the cases + + pre1_duration = 3 + pre2_duration = 2 + worker_duration = 2 + post1_duration = 6 + post2_duration = 4 + worker_override, pre2, worker, post2 = generate_test( + pre_duration=pre2_duration, + pre_completion_status=py_trees.common.Status.SUCCESS, + worker_duration=worker_duration, + worker_completion_status=py_trees.common.Status.SUCCESS, + post_duration=post2_duration, + post_completion_status=py_trees.common.Status.SUCCESS, + suffix="2", + ) + root, pre1, _, post1 = generate_test( + pre_duration=pre1_duration, + pre_completion_status=py_trees.common.Status.SUCCESS, + worker_duration=0, + worker_completion_status=py_trees.common.Status.INVALID, + post_duration=post1_duration, + post_completion_status=py_trees.common.Status.SUCCESS, + worker_override=worker_override, + suffix="1", + ) + behaviors = [pre1, pre2, worker, post2, post1] + + # Get the number of ticks to terminate the tree + num_ticks_to_terminate = ( + pre1_duration + + pre2_duration + + worker_duration + + post1_duration + + post2_duration + + 1 + ) + + if preempt_time == NestedExecutionCase.NONE: + for _ in range(num_ticks_to_terminate): + root.tick_once() + check_count_status( + behaviors=behaviors, + counts=[ + pre1_duration + 1, + pre2_duration + 1, + worker_duration + 1, + post2_duration + 1, + post1_duration + 1, + ], + statuses=[py_trees.common.Status.SUCCESS] * 5, + num_times_ticked_to_non_running_statuses=[1] * 5, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=[py_trees.common.Status.SUCCESS] * 5, + ) + check_termination_order(behaviors) + assert ( + root.status == py_trees.common.Status.SUCCESS + ), f"root status {root.status} is not SUCCESS" + elif preempt_time == NestedExecutionCase.HASNT_STARTED: + root.stop(py_trees.common.Status.INVALID) + check_count_status( + behaviors=behaviors, + counts=[0] * 5, + statuses=[py_trees.common.Status.INVALID] * 5, + num_times_ticked_to_non_running_statuses=[0] * 5, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=[None] * 5, + ) + elif preempt_time == NestedExecutionCase.PRE1_RUNNING: + for _ in range(1): + root.tick_once() + root.stop(py_trees.common.Status.INVALID) + check_count_status( + behaviors=behaviors, + counts=[1, 0, 0, 0, post1_duration + 1], + statuses=[py_trees.common.Status.INVALID] * 4 + + [py_trees.common.Status.SUCCESS], + num_times_ticked_to_non_running_statuses=[0] * 4 + [1], + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=[py_trees.common.Status.INVALID] + + [None] * 3 + + [py_trees.common.Status.SUCCESS], + ) + check_termination_order([pre1, post1]) + elif preempt_time == NestedExecutionCase.PRE2_RUNNING: + for _ in range(pre1_duration + 1): + root.tick_once() + root.stop(py_trees.common.Status.INVALID) + check_count_status( + behaviors=behaviors, + counts=[pre1_duration + 1, 1, 0, post2_duration + 1, post1_duration + 1], + statuses=[py_trees.common.Status.INVALID] * 3 + + [py_trees.common.Status.SUCCESS] * 2, + num_times_ticked_to_non_running_statuses=[1] + [0] * 2 + [1] * 2, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=[py_trees.common.Status.INVALID] * 2 + + [None] + + [py_trees.common.Status.SUCCESS] * 2, + ) + check_termination_order([pre2, post2, post1]) + elif preempt_time == NestedExecutionCase.WORKERS_RUNNING: + for _ in range(pre1_duration + pre2_duration + 1): + root.tick_once() + root.stop(py_trees.common.Status.INVALID) + check_count_status( + behaviors=behaviors, + counts=[ + pre1_duration + 1, + pre2_duration + 1, + 1, + post2_duration + 1, + post1_duration + 1, + ], + statuses=[py_trees.common.Status.INVALID] * 3 + + [py_trees.common.Status.SUCCESS] * 2, + num_times_ticked_to_non_running_statuses=[1] * 2 + [0] + [1] * 2, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=[py_trees.common.Status.INVALID] * 3 + + [py_trees.common.Status.SUCCESS] * 2, + ) + check_termination_order([pre1, pre2, worker, post2, post1]) + elif preempt_time == NestedExecutionCase.POST2_RUNNING: + for _ in range(pre1_duration + pre2_duration + worker_duration + 1): + root.tick_once() + root.stop(py_trees.common.Status.INVALID) + check_count_status( + behaviors=behaviors, + counts=[ + pre1_duration + 1, + pre2_duration + 1, + worker_duration + 1, + post2_duration + 1, + post1_duration + 1, + ], + statuses=[py_trees.common.Status.INVALID] * 3 + + [py_trees.common.Status.SUCCESS] * 2, + num_times_ticked_to_non_running_statuses=[1] * 3 + [1] * 2, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=[py_trees.common.Status.INVALID] * 3 + + [py_trees.common.Status.SUCCESS] * 2, + ) + check_termination_order([pre1, pre2, worker, post2, post1]) + elif preempt_time == NestedExecutionCase.POST1_RUNNING: + for _ in range( + pre1_duration + pre2_duration + worker_duration + post2_duration + 1 + ): + root.tick_once() + root.stop(py_trees.common.Status.INVALID) + check_count_status( + behaviors=behaviors, + counts=[ + pre1_duration + 1, + pre2_duration + 1, + worker_duration + 1, + post2_duration + 1, + post1_duration + 1, + ], + # This is crucial -- POST2 should get terminated through the standard means, + # not through OnPreempt. + statuses=[py_trees.common.Status.INVALID] * 4 + + [py_trees.common.Status.SUCCESS], + # This is crucial -- POST2 should only get ticked to completion once. + num_times_ticked_to_non_running_statuses=[1] * 5, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=[py_trees.common.Status.INVALID] * 4 + + [py_trees.common.Status.SUCCESS], + ) + check_termination_order([pre1, pre2, worker, post2, post1]) + elif preempt_time == NestedExecutionCase.TREE_TERMINATED: + for _ in range(num_ticks_to_terminate): + root.tick_once() + root.stop(py_trees.common.Status.INVALID) + check_count_status( + behaviors=behaviors, + counts=[ + pre1_duration + 1, + pre2_duration + 1, + worker_duration + 1, + post2_duration + 1, + post1_duration + 1, + ], + statuses=[py_trees.common.Status.INVALID] * 5, + num_times_ticked_to_non_running_statuses=[1] * 5, + ) + check_termination_new_statuses( + behaviors=behaviors, + statuses=[py_trees.common.Status.INVALID] * 5, + ) + check_termination_order([pre1, pre2, worker, post2, post1]) + + +for preempt_time in NestedExecutionCase: + test_name = f"test_nested_{preempt_time.name}" + globals()[test_name] = partial( + nested_behavior_tests, + preempt_time=preempt_time, + )