Skip to content

Commit

Permalink
feat: extend ros2 actions tool suite
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejmajek committed Mar 7, 2025
1 parent 4678efa commit 52278e7
Showing 1 changed file with 110 additions and 10 deletions.
120 changes: 110 additions & 10 deletions src/rai_core/rai/tools/ros2/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,60 @@
"This is a ROS2 feature. Make sure ROS2 is installed and sourced."
)

from typing import Any, Callable, Dict, Type
import uuid
from collections import defaultdict
from functools import partial
from threading import Lock
from typing import Any, Callable, Dict, List, Type

from langchain_core.tools import BaseTool, tool # type: ignore
from langchain_core.tools import BaseTool, BaseToolkit # type: ignore
from pydantic import BaseModel, Field

from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage


class ROS2ActionToolkit(BaseToolkit):
name: str = "ros2_action"
description: str = "A toolkit for ROS2 actions"
connector: ROS2ARIConnector
action_results_store: Dict[str, Any] = {}
action_results_store_lock: Lock = Lock()
action_feedbacks_store: Dict[str, List[Any]] = defaultdict(list)
action_feedbacks_store_lock: Lock = Lock()
internal_action_id_mapping: Dict[str, str] = {}

def get_tools(self) -> List[BaseTool]:
return [
StartROS2ActionTool(
connector=self.connector,
feedback_callback=self._generic_feedback_callback,
on_done_callback=self._generic_on_done_callback,
),
CancelROS2ActionTool(connector=self.connector),
GetROS2ActionFeedbackTool(
action_feedbacks_store=self.action_feedbacks_store,
action_feedbacks_store_lock=self.action_feedbacks_store_lock,
internal_action_id_mapping=self.internal_action_id_mapping,
),
GetROS2ActionResultTool(
action_results_store=self.action_results_store,
action_results_store_lock=self.action_results_store_lock,
internal_action_id_mapping=self.internal_action_id_mapping,
),
GetROS2ActionIDsTool(
internal_action_id_mapping=self.internal_action_id_mapping
),
]

def _generic_feedback_callback(self, action_id: str, feedback: Any) -> None:
with self.action_feedbacks_store_lock:
self.action_feedbacks_store[action_id].append(feedback)

def _generic_on_done_callback(self, action_id: str, result: Any) -> None:
with self.action_results_store_lock:
self.action_results_store[action_id] = result


class StartROS2ActionToolInput(BaseModel):
action_name: str = Field(..., description="The name of the action to start")
action_type: str = Field(..., description="The type of the action")
Expand All @@ -37,8 +83,9 @@ class StartROS2ActionToolInput(BaseModel):

class StartROS2ActionTool(BaseTool):
connector: ROS2ARIConnector
feedback_callback: Callable[[Any], None] = lambda _: None
on_done_callback: Callable[[Any], None] = lambda _: None
feedback_callback: Callable[[Any, str], None] = lambda _, __: None
on_done_callback: Callable[[Any, str], None] = lambda _, __: None
internal_action_id_mapping: Dict[str, str] = {}
name: str = "start_ros2_action"
description: str = "Start a ROS2 action"
args_schema: Type[StartROS2ActionToolInput] = StartROS2ActionToolInput
Expand All @@ -47,16 +94,61 @@ def _run(
self, action_name: str, action_type: str, action_args: Dict[str, Any]
) -> str:
message = ROS2ARIMessage(payload=action_args)
action_id = str(uuid.uuid4())
response = self.connector.start_action(
message,
action_name,
on_feedback=self.feedback_callback,
on_done=self.on_done_callback,
on_feedback=partial(self.feedback_callback, action_id),
on_done=partial(self.on_done_callback, action_id),
msg_type=action_type,
)
self.internal_action_id_mapping[response] = action_id
return "Action started with ID: " + response


class GetROS2ActionFeedbackToolInput(BaseModel):
action_id: str = Field(..., description="The ID of the action to get feedback for")


class GetROS2ActionFeedbackTool(BaseTool):
name: str = "get_ros2_action_feedback"
description: str = "Get the feedback of a ROS2 action by its action ID"
args_schema: Type[GetROS2ActionFeedbackToolInput] = GetROS2ActionFeedbackToolInput

action_feedbacks_store: Dict[str, List[Any]]
action_feedbacks_store_lock: Lock
internal_action_id_mapping: Dict[str, str] = {}

def _run(self, action_id: str) -> str:
with self.action_feedbacks_store_lock:
external_action_id = self.internal_action_id_mapping[action_id]
feedbacks = self.action_feedbacks_store[external_action_id]
self.action_feedbacks_store[external_action_id] = []
return str(feedbacks)


class GetROS2ActionResultToolInput(BaseModel):
action_id: str = Field(
..., description="The id of the action to get the result for"
)


class GetROS2ActionResultTool(BaseTool):
name: str = "get_ros2_action_result"
description: str = "Get the result of a ROS2 action by its id"
args_schema: Type[GetROS2ActionResultToolInput] = GetROS2ActionResultToolInput

action_results_store: Dict[str, Any]
action_results_store_lock: Lock
internal_action_id_mapping: Dict[str, str] = {}

def _run(self, action_id: str) -> str:
with self.action_results_store_lock:
external_action_id = self.internal_action_id_mapping[action_id]
result = self.action_results_store[external_action_id]
return str(result)


class CancelROS2ActionToolInput(BaseModel):
action_id: str = Field(..., description="The ID of the action to cancel")

Expand All @@ -72,7 +164,15 @@ def _run(self, action_id: str) -> str:
return f"Action {action_id} cancelled"


@tool
def get_ros2_action_feedback(action_id: str) -> str:
"""Get the feedback of a ROS2 action by its action ID"""
raise NotImplementedError("Not implemented")
class GetROS2ActionIDsToolInput(BaseModel):
pass


class GetROS2ActionIDsTool(BaseTool):
name: str = "get_ros2_action_ids"
description: str = "Get the IDs of all ROS2 actions"
args_schema: Type[GetROS2ActionIDsToolInput] = GetROS2ActionIDsToolInput
internal_action_id_mapping: Dict[str, str] = {}

def _run(self) -> str:
return str(list(self.internal_action_id_mapping.keys()))

0 comments on commit 52278e7

Please sign in to comment.