Skip to content

Commit

Permalink
feat: ROS2 topic toolkit
Browse files Browse the repository at this point in the history
fix: allow arbitrary types in toolkit definitions
  • Loading branch information
maciejmajek committed Mar 7, 2025
1 parent 52278e7 commit a44d590
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/rai_core/rai/tools/ros2/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import Any, Callable, Dict, List, Type

from langchain_core.tools import BaseTool, BaseToolkit # type: ignore
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage

Expand All @@ -41,6 +41,10 @@ class ROS2ActionToolkit(BaseToolkit):
action_feedbacks_store_lock: Lock = Lock()
internal_action_id_mapping: Dict[str, str] = {}

model_config = ConfigDict(
arbitrary_types_allowed=True,
)

def get_tools(self) -> List[BaseTool]:
return [
StartROS2ActionTool(
Expand Down
25 changes: 23 additions & 2 deletions src/rai_core/rai/tools/ros2/topics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
)

import json
from typing import Any, Dict, Literal, Tuple, Type
from typing import Any, Dict, List, Literal, Tuple, Type

import rosidl_runtime_py.set_message
import rosidl_runtime_py.utilities
from cv_bridge import CvBridge
from langchain.tools import BaseTool
from langchain_core.tools import BaseToolkit
from langchain_core.utils import stringify_dict
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from sensor_msgs.msg import CompressedImage, Image

from rai.communication.ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
Expand All @@ -36,6 +37,26 @@
from rai.tools.ros2.utils import ros2_message_to_dict


class ROS2TopicsToolkit(BaseToolkit):
name: str = "ROS2TopicsToolkit"
description: str = "A toolkit for ROS2 topics"
connector: ROS2ARIConnector

model_config = ConfigDict(
arbitrary_types_allowed=True,
)

def get_tools(self) -> List[BaseTool]:
return [
PublishROS2MessageTool(connector=self.connector),
ReceiveROS2MessageTool(connector=self.connector),
GetROS2ImageTool(connector=self.connector),
GetROS2TransformTool(connector=self.connector),
GetROS2TopicsNamesAndTypesTool(connector=self.connector),
GetROS2MessageInterfaceTool(connector=self.connector),
]


class PublishROS2MessageToolInput(BaseModel):
topic: str = Field(..., description="The topic to publish the message to")
message: Dict[str, Any] = Field(..., description="The message to publish")
Expand Down

0 comments on commit a44d590

Please sign in to comment.