From 52278e785f66278f154a3dfd04a7c11405744b4a Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Fri, 7 Mar 2025 23:46:22 +0100 Subject: [PATCH] feat: extend ros2 actions tool suite --- src/rai_core/rai/tools/ros2/actions.py | 120 ++++++++++++++++++++++--- 1 file changed, 110 insertions(+), 10 deletions(-) diff --git a/src/rai_core/rai/tools/ros2/actions.py b/src/rai_core/rai/tools/ros2/actions.py index c4168c53d..9edf076ca 100644 --- a/src/rai_core/rai/tools/ros2/actions.py +++ b/src/rai_core/rai/tools/ros2/actions.py @@ -19,14 +19,60 @@ "This is a ROS2 feature. Make sure ROS2 is installed and sourced." ) -from typing import Any, Callable, Dict, Type +import uuid +from collections import defaultdict +from functools import partial +from threading import Lock +from typing import Any, Callable, Dict, List, Type -from langchain_core.tools import BaseTool, tool # type: ignore +from langchain_core.tools import BaseTool, BaseToolkit # type: ignore from pydantic import BaseModel, Field from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage +class ROS2ActionToolkit(BaseToolkit): + name: str = "ros2_action" + description: str = "A toolkit for ROS2 actions" + connector: ROS2ARIConnector + action_results_store: Dict[str, Any] = {} + action_results_store_lock: Lock = Lock() + action_feedbacks_store: Dict[str, List[Any]] = defaultdict(list) + action_feedbacks_store_lock: Lock = Lock() + internal_action_id_mapping: Dict[str, str] = {} + + def get_tools(self) -> List[BaseTool]: + return [ + StartROS2ActionTool( + connector=self.connector, + feedback_callback=self._generic_feedback_callback, + on_done_callback=self._generic_on_done_callback, + ), + CancelROS2ActionTool(connector=self.connector), + GetROS2ActionFeedbackTool( + action_feedbacks_store=self.action_feedbacks_store, + action_feedbacks_store_lock=self.action_feedbacks_store_lock, + internal_action_id_mapping=self.internal_action_id_mapping, + ), + GetROS2ActionResultTool( + action_results_store=self.action_results_store, + action_results_store_lock=self.action_results_store_lock, + internal_action_id_mapping=self.internal_action_id_mapping, + ), + GetROS2ActionIDsTool( + internal_action_id_mapping=self.internal_action_id_mapping + ), + ] + + def _generic_feedback_callback(self, action_id: str, feedback: Any) -> None: + with self.action_feedbacks_store_lock: + self.action_feedbacks_store[action_id].append(feedback) + + def _generic_on_done_callback(self, action_id: str, result: Any) -> None: + with self.action_results_store_lock: + self.action_results_store[action_id] = result + + class StartROS2ActionToolInput(BaseModel): action_name: str = Field(..., description="The name of the action to start") action_type: str = Field(..., description="The type of the action") @@ -37,8 +83,9 @@ class StartROS2ActionToolInput(BaseModel): class StartROS2ActionTool(BaseTool): connector: ROS2ARIConnector - feedback_callback: Callable[[Any], None] = lambda _: None - on_done_callback: Callable[[Any], None] = lambda _: None + feedback_callback: Callable[[Any, str], None] = lambda _, __: None + on_done_callback: Callable[[Any, str], None] = lambda _, __: None + internal_action_id_mapping: Dict[str, str] = {} name: str = "start_ros2_action" description: str = "Start a ROS2 action" args_schema: Type[StartROS2ActionToolInput] = StartROS2ActionToolInput @@ -47,16 +94,61 @@ def _run( self, action_name: str, action_type: str, action_args: Dict[str, Any] ) -> str: message = ROS2ARIMessage(payload=action_args) + action_id = str(uuid.uuid4()) response = self.connector.start_action( message, action_name, - on_feedback=self.feedback_callback, - on_done=self.on_done_callback, + on_feedback=partial(self.feedback_callback, action_id), + on_done=partial(self.on_done_callback, action_id), msg_type=action_type, ) + self.internal_action_id_mapping[response] = action_id return "Action started with ID: " + response +class GetROS2ActionFeedbackToolInput(BaseModel): + action_id: str = Field(..., description="The ID of the action to get feedback for") + + +class GetROS2ActionFeedbackTool(BaseTool): + name: str = "get_ros2_action_feedback" + description: str = "Get the feedback of a ROS2 action by its action ID" + args_schema: Type[GetROS2ActionFeedbackToolInput] = GetROS2ActionFeedbackToolInput + + action_feedbacks_store: Dict[str, List[Any]] + action_feedbacks_store_lock: Lock + internal_action_id_mapping: Dict[str, str] = {} + + def _run(self, action_id: str) -> str: + with self.action_feedbacks_store_lock: + external_action_id = self.internal_action_id_mapping[action_id] + feedbacks = self.action_feedbacks_store[external_action_id] + self.action_feedbacks_store[external_action_id] = [] + return str(feedbacks) + + +class GetROS2ActionResultToolInput(BaseModel): + action_id: str = Field( + ..., description="The id of the action to get the result for" + ) + + +class GetROS2ActionResultTool(BaseTool): + name: str = "get_ros2_action_result" + description: str = "Get the result of a ROS2 action by its id" + args_schema: Type[GetROS2ActionResultToolInput] = GetROS2ActionResultToolInput + + action_results_store: Dict[str, Any] + action_results_store_lock: Lock + internal_action_id_mapping: Dict[str, str] = {} + + def _run(self, action_id: str) -> str: + with self.action_results_store_lock: + external_action_id = self.internal_action_id_mapping[action_id] + result = self.action_results_store[external_action_id] + return str(result) + + class CancelROS2ActionToolInput(BaseModel): action_id: str = Field(..., description="The ID of the action to cancel") @@ -72,7 +164,15 @@ def _run(self, action_id: str) -> str: return f"Action {action_id} cancelled" -@tool -def get_ros2_action_feedback(action_id: str) -> str: - """Get the feedback of a ROS2 action by its action ID""" - raise NotImplementedError("Not implemented") +class GetROS2ActionIDsToolInput(BaseModel): + pass + + +class GetROS2ActionIDsTool(BaseTool): + name: str = "get_ros2_action_ids" + description: str = "Get the IDs of all ROS2 actions" + args_schema: Type[GetROS2ActionIDsToolInput] = GetROS2ActionIDsToolInput + internal_action_id_mapping: Dict[str, str] = {} + + def _run(self) -> str: + return str(list(self.internal_action_id_mapping.keys()))