diff --git a/src/rai_core/rai/tools/ros2/actions.py b/src/rai_core/rai/tools/ros2/actions.py index 9edf076c..0b03c407 100644 --- a/src/rai_core/rai/tools/ros2/actions.py +++ b/src/rai_core/rai/tools/ros2/actions.py @@ -26,7 +26,7 @@ from typing import Any, Callable, Dict, List, Type from langchain_core.tools import BaseTool, BaseToolkit # type: ignore -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage @@ -41,6 +41,10 @@ class ROS2ActionToolkit(BaseToolkit): action_feedbacks_store_lock: Lock = Lock() internal_action_id_mapping: Dict[str, str] = {} + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + def get_tools(self) -> List[BaseTool]: return [ StartROS2ActionTool( diff --git a/src/rai_core/rai/tools/ros2/topics.py b/src/rai_core/rai/tools/ros2/topics.py index 82dbfb34..1ce58e86 100644 --- a/src/rai_core/rai/tools/ros2/topics.py +++ b/src/rai_core/rai/tools/ros2/topics.py @@ -20,14 +20,15 @@ ) import json -from typing import Any, Dict, Literal, Tuple, Type +from typing import Any, Dict, List, Literal, Tuple, Type import rosidl_runtime_py.set_message import rosidl_runtime_py.utilities from cv_bridge import CvBridge from langchain.tools import BaseTool +from langchain_core.tools import BaseToolkit from langchain_core.utils import stringify_dict -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from sensor_msgs.msg import CompressedImage, Image from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage @@ -36,6 +37,26 @@ from rai.tools.ros2.utils import ros2_message_to_dict +class ROS2TopicsToolkit(BaseToolkit): + name: str = "ROS2TopicsToolkit" + description: str = "A toolkit for ROS2 topics" + connector: ROS2ARIConnector + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + def get_tools(self) -> List[BaseTool]: + return [ + PublishROS2MessageTool(connector=self.connector), + ReceiveROS2MessageTool(connector=self.connector), + GetROS2ImageTool(connector=self.connector), + GetROS2TransformTool(connector=self.connector), + GetROS2TopicsNamesAndTypesTool(connector=self.connector), + GetROS2MessageInterfaceTool(connector=self.connector), + ] + + class PublishROS2MessageToolInput(BaseModel): topic: str = Field(..., description="The topic to publish the message to") message: Dict[str, Any] = Field(..., description="The message to publish")