Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing ActionableProtocol #116

Merged
merged 1 commit into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 105 additions & 11 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/action.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
# Copyright (c) 2024 Boston Dynamics AI Institute Inc. All rights reserved.

import inspect
from typing import Any, Callable, Generator, Generic, Iterator, List, Optional, Type, TypeVar, Union, overload
from typing import Any, Callable, Generator, Generic, Iterator, List, Optional, Protocol, Type, TypeVar, Union, overload

import action_msgs.msg
from rclpy.action.client import ActionClient, ClientGoalHandle
from rclpy.node import Node
from rclpy.task import Future

import bdai_ros2_wrappers.scope as scope
from bdai_ros2_wrappers.callables import ComposableCallable, VectorizingCallable
from bdai_ros2_wrappers.callables import ComposableCallable, ComposedCallable, VectorizedCallable, VectorizingCallable
from bdai_ros2_wrappers.futures import FutureConvertible, FutureLike, wait_for_future
from bdai_ros2_wrappers.utilities import Tape

Expand Down Expand Up @@ -46,9 +46,9 @@ class ActionCancelled(ActionException):
pass


ActionGoalT = TypeVar("ActionGoalT")
ActionFeedbackT = TypeVar("ActionFeedbackT")
ActionResultT = TypeVar("ActionResultT")
ActionGoalT = TypeVar("ActionGoalT", contravariant=True)
ActionResultT = TypeVar("ActionResultT", covariant=True)
ActionFeedbackT = TypeVar("ActionFeedbackT", covariant=True)


class ActionFuture(FutureConvertible[ActionResultT], Generic[ActionResultT, ActionFeedbackT]):
Expand Down Expand Up @@ -344,6 +344,100 @@ def _response_callback(cancel_goal_future: Future) -> None:
return cancellation_future


class ActionableProtocol(Protocol[ActionGoalT, ActionResultT, ActionFeedbackT]):
"""Ergonomic protocol to call actions in ROS 2."""

@property
def action_client(self) -> ActionClient:
"""Get the underlying action client."""

def wait_for_server(self, *args: Any, **kwargs: Any) -> bool:
"""Wait for action to become available."""

@overload
def synchronous(
self,
goal: Optional[ActionGoalT] = ...,
*,
feedback_callback: Optional[Callable[[ActionFeedbackT], None]] = ...,
timeout_sec: Optional[float] = ...,
) -> ActionResultT:
"""Invoke action synchronously.

Args:
goal: target action goal, or a default initialized one if none is provided.
feedback_callback: optional action feedback callback.
timeout_sec: optional action timeout, in seconds. If a timeout is specified and it
expires, the action goal will be cancelled and the call will raise. Note this
timeout is local to the caller.

Returns:
the action result.

Raises:
ActionTimeout: if the action timed out.
ActionRejected: if the action was not accepted.
ActionCancelled: if the action was cancelled.
ActionAborted: if the action was aborted.
RuntimeError: if there is an internal server error.
"""

@overload
def synchronous(
self,
goal: Optional[ActionGoalT] = ...,
*,
feedback_callback: Optional[Callable[[ActionFeedbackT], None]] = ...,
timeout_sec: Optional[float] = ...,
nothrow: bool = ...,
) -> Optional[ActionResultT]:
"""Invoke action synchronously.

Args:
goal: target action goal, or a default initialized one if none is provided.
feedback_callback: optional action feedback callback.
timeout_sec: optional action timeout, in seconds. If a timeout is specified and it
expires, the action goal will be cancelled and the call will raise. Note this
timeout is local to the caller.
nothrow: when set, errors do not raise exceptions.

Returns:
the action result or None on timeout or failure.
"""

def asynchronous(
self,
goal: Optional[ActionGoalT] = ...,
*,
track_feedback: Union[int, bool] = ...,
) -> ActionFuture[ActionResultT, ActionFeedbackT]:
"""Invoke action asynchronously.

Args:
goal: target action goal, or a default initialized one if none is provided.
track_feedback: whether and how to track action feedback. Other than a boolean to
enable or disable tracking, a positive integer may be provided to cap feedback buffer
size.

Returns:
the future action outcome.
"""

def compose(self, func: Callable) -> ComposedCallable:
"""Compose this actionable with the given `func`.

Args:
func: callable to be composed, assumed synchronous.

Returns:
the composed generalized callable.
"""

@property
def vectorized(self) -> VectorizedCallable:
"""Get a vectorized version of this actionable."""


class Actionable(Generic[ActionGoalT, ActionResultT, ActionFeedbackT], ComposableCallable, VectorizingCallable):
"""An ergonomic interface to call actions in ROS 2.

Expand Down Expand Up @@ -406,8 +500,8 @@ def synchronous(
feedback_callback: optional action feedback callback.
timeout_sec: optional action timeout, in seconds. If a timeout is specified and it
expires, the action will be cancelled and the call will raise. Note this timeout is
local to the caller. It may take some time for the action to be cancelled, that is
if cancellation is not rejected by the action server.
local to the caller. It may take some time for the action to be cancelled, if cancellation
is not rejected by the action server.

Returns:
the action result.
Expand Down Expand Up @@ -436,8 +530,8 @@ def synchronous(
feedback_callback: optional action feedback callback.
timeout_sec: optional action timeout, in seconds. If a timeout is specified and it
expires, the action will be cancelled and the call will raise. Note this timeout is
local to the caller. It may take some time for the action to be cancelled, that is
if cancellation is not rejected by the action server.
local to the caller. It may take some time for the action to be cancelled, if cancellation
is not rejected by the action server.
nothrow: when set, errors will not raise exceptions.

Returns:
Expand All @@ -454,7 +548,7 @@ def synchronous(
) -> Optional[ActionResultT]:
"""Invoke action synchronously.

Check available overloads documentation.
See `Actionable.synchronous()` overloads documentation.
"""
if goal is None:
goal = self.action_type.Goal()
Expand Down Expand Up @@ -494,7 +588,7 @@ def asynchronous(
may be provided to cap feedback buffer size.

Returns:
the action future.
the future action outcome.
"""
feedback_tape: Optional[Tape[ActionFeedbackT]] = None
if goal is None:
Expand Down
10 changes: 5 additions & 5 deletions bdai_ros2_wrappers/bdai_ros2_wrappers/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def synchronous(
the service response.

Raises:
ServiceTimeout: if the service request timed out and `nothrow` was not set.
ServiceError: if the service request failed and `nothrow` was not set.
ServiceTimeout: if the service request timed out.
ServiceError: if the service request failed.
"""

@overload
Expand All @@ -89,7 +89,7 @@ def synchronous(
the service response or None on timeout or failure.
"""

def asynchronous(self, request: Optional[ServiceRequestT] = None) -> FutureLike[ServiceResponseT]:
def asynchronous(self, request: Optional[ServiceRequestT] = ...) -> FutureLike[ServiceResponseT]:
"""Invoke service asynchronously.

Args:
Expand Down Expand Up @@ -156,8 +156,8 @@ def synchronous(
the service response.

Raises:
ServiceTimeout: if the service request timed out and `nothrow` was not set.
ServiceError: if the service request failed and `nothrow` was not set.
ServiceTimeout: if the service request timed out.
ServiceError: if the service request failed.
"""

@overload
Expand Down
Loading