From e154f6c3b9b8559cec61b7c4429f60908cd7d848 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 6 Sep 2024 12:43:55 +0200 Subject: [PATCH 01/12] feat: `rai_node` can handle only 1 mission at once + mission result format improvements --- src/rai/rai/agents/state_based.py | 3 + src/rai/rai/node.py | 106 +++++++++++++++--------------- 2 files changed, 57 insertions(+), 52 deletions(-) diff --git a/src/rai/rai/agents/state_based.py b/src/rai/rai/agents/state_based.py index 384121b57..4dd2d9192 100644 --- a/src/rai/rai/agents/state_based.py +++ b/src/rai/rai/agents/state_based.py @@ -76,6 +76,9 @@ class Report(BaseModel): steps: List[str] = Field( ..., title="Steps", description="The steps taken to solve the problem" ) + success: bool = Field( + ..., title="Success", description="Whether the problem was solved" + ) response_to_user: str = Field( ..., title="Response", description="The response to the user" ) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 784a39003..c185347be 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -17,7 +17,6 @@ import time from collections import deque from dataclasses import dataclass, field -from pprint import pformat from typing import Any, Callable, Deque, Dict, List, Literal, Optional, Tuple import rcl_interfaces.msg @@ -34,7 +33,7 @@ from langchain_openai import ChatOpenAI from langgraph.graph.graph import CompiledGraph from rclpy.action.graph import get_action_names_and_types -from rclpy.action.server import ActionServer +from rclpy.action.server import ActionServer, GoalResponse, ServerGoalHandle from rclpy.node import Node from rclpy.qos import ( DurabilityPolicy, @@ -45,7 +44,7 @@ ) from std_srvs.srv import Trigger -from rai.agents.state_based import State +from rai.agents.state_based import Report, State from rai.messages.multimodal import HumanMultimodalMessage from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str from rai.tools.utils import wait_for_message @@ -233,20 +232,8 @@ def __init__( Trigger, "rai_whoami_identity_service" ) - self.DISCOVERY_FREQ = 2.0 - self.DISCOVERY_DEPTH = 5 - self.callback_group = rclpy.callback_groups.MutuallyExclusiveCallbackGroup() - self.qos_profile = QoSProfile( - history=HistoryPolicy.KEEP_LAST, - depth=1, - reliability=ReliabilityPolicy.BEST_EFFORT, - durability=DurabilityPolicy.VOLATILE, - liveliness=LivelinessPolicy.AUTOMATIC, - ) - - self.state_subscribers = dict() self.initialize_robot_state_interfaces(self.state_topics) self.system_prompt = self.initialize_system_prompt(system_prompt) @@ -310,6 +297,14 @@ def initialize_robot_state_interfaces(self, topics): self.state_subscribers[topic] = subscriber +def parse_task_goal(ros_action_goal: TaskAction.Goal) -> Dict[str, Any]: + return dict( + task=ros_action_goal.task, + description=ros_action_goal.description, + priority=ros_action_goal.priority, + ) + + class RaiNode(RaiGenericBaseNode): def __init__( self, @@ -332,11 +327,7 @@ def __init__( **kwargs, ) - # ---------- ROS Parameters ---------- - self.task_topic = "/task_addition_requests" - # ---------- ROS configuration ---------- - self.rosout_sub = self.create_subscription( rcl_interfaces.msg.Log, "/rosout", @@ -347,8 +338,14 @@ def __init__( # ---------- Task Queue ---------- self.task_action_server = ActionServer( - self, TaskAction, "perform_task", self.agent_loop + self, + TaskAction, + "perform_task", + execute_callback=self.agent_loop, + goal_callback=self.goal_callback, ) + # Node is busy when task is executed. Only 1 task is allowed + self.busy = False # ---------- LLM Agents ---------- self.AGENT_RECURSION_LIMIT = 100 @@ -358,45 +355,50 @@ def __init__( # self.agent_loop_thread = Thread(target=self.agent_loop) # self.agent_loop_thread.start() - def agent_loop(self, goal_handle: TaskAction.Goal): - self.get_logger().info(f"Received goal handle: {goal_handle}") - action_request = goal_handle.request - task = dict( - task=action_request.task, - description=action_request.description, - priority=action_request.priority, - ) - self.get_logger().info(f"Received task: {task}") - - # ---- LLM Task Handling ---- - messages = [ - SystemMessage(content=self.system_prompt), - HumanMessage(content=f"Task: {task}"), - ] + def goal_callback(self, _) -> GoalResponse: + """Accept or reject a client request to begin an action.""" + response = GoalResponse.REJECT if self.busy else GoalResponse.ACCEPT + self.get_logger().info(f"Received goal request. Response: {response}") + return response + + async def agent_loop(self, goal_handle: ServerGoalHandle): + self.busy = True + try: + action_request: TaskAction.Goal = goal_handle.request + task: Dict[str, Any] = parse_task_goal( + action_request + ) # TODO(boczekbartek): base model and json + + self.get_logger().info(f"Received task: {task}") + + # ---- LLM Task Handling ---- + messages = [ + SystemMessage(content=self.system_prompt), + HumanMessage(content=f"Task: {task}"), + ] - payload = State(messages=messages) + payload = State(messages=messages) - state: State = self.llm_app.invoke( - payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT} - ) # type: ignore + state: State = self.llm_app.invoke( + payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT} + ) # type: ignore - # ---- Share Action feedback ---- - # TODO(boczekbartek): add graph node to langgraph which will send ros2 action feedback to HMI + # ---- Share Action feedback ---- + # TODO(boczekbartek): add graph node to langgraph which will send ros2 action feedback to HMI - # ---- Share Action Result ---- - report = state["messages"][-1] - report = pformat(report.json()) + # ---- Share Action Result ---- + report: Report = state["messages"][-1].content - result = TaskAction.Result() - result.success = ( - True # TODO(boczekbartek): ask llm if the action has been successful - ) - result.report = report + result = TaskAction.Result() + result.success = report.success + result.report = report.response_to_user - self.get_logger().info(f"Finished task:\n{report}") - self.clear_state() + self.get_logger().info(f"Finished task:\n{result}") + self.clear_state() - return report + return result + finally: + self.busy = False def set_app(self, app: CompiledGraph): self.llm_app = app From 260fb8727c63dca7cfb3a929d25566a978bd013c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 6 Sep 2024 16:22:07 +0200 Subject: [PATCH 02/12] render tools description and args in system prompt --- examples/rosbot-xl-generic-node-demo.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/rosbot-xl-generic-node-demo.py b/examples/rosbot-xl-generic-node-demo.py index 662f8fca5..dda43ca30 100644 --- a/examples/rosbot-xl-generic-node-demo.py +++ b/examples/rosbot-xl-generic-node-demo.py @@ -20,6 +20,7 @@ import rclpy.qos import rclpy.subscription import rclpy.task +from langchain.tools.render import render_text_description_and_args from langchain_openai import ChatOpenAI from rai.agents.state_based import create_state_based_agent @@ -68,10 +69,9 @@ def main(): "/wait", ] - SYSTEM_PROMPT = "You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. " - "Do not make assumptions about the environment you are currently in. " - "Use the tooling provided to gather information about the environment." - "You can use ros2 topics, services and actions to operate." + # TODO(boczekbartek): refactor system prompt + + SYSTEM_PROMPT = "" node = RaiNode( llm=ChatOpenAI( @@ -94,6 +94,18 @@ def main(): state_retriever = node.get_robot_state + SYSTEM_PROMPT = f"""You are an autonomous robot connected to ros2 environment. Your main goal is to fulfill the user's requests. + Do not make assumptions about the environment you are currently in. + Use the tooling provided to gather information about the environment: + + {render_text_description_and_args(tools)} + + You can use ros2 topics, services and actions to operate. """ + + node.get_logger().info(f"{SYSTEM_PROMPT=}") + + node.system_prompt = node.initialize_system_prompt(SYSTEM_PROMPT) + app = create_state_based_agent( llm=llm, tools=tools, From 37c4ca7a8674221b0ebea60a8416961801ec49c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 6 Sep 2024 16:22:32 +0200 Subject: [PATCH 03/12] fix: artifact database --- src/rai/rai/agents/state_based.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/rai/rai/agents/state_based.py b/src/rai/rai/agents/state_based.py index 4dd2d9192..f250d03d6 100644 --- a/src/rai/rai/agents/state_based.py +++ b/src/rai/rai/agents/state_based.py @@ -17,6 +17,7 @@ import pickle import time from functools import partial +from pathlib import Path from typing import ( Any, Callable, @@ -84,15 +85,31 @@ class Report(BaseModel): ) -def get_stored_artifacts(tool_call_id: str) -> List[Any]: - with open("artifact_database.pkl", "rb") as file: - artifact_database = pickle.load(file) +def get_stored_artifacts( + tool_call_id: str, db_path="artifact_database.pkl" +) -> List[Any]: + # TODO(boczekbartek): refactor + db_path = Path(db_path) + if not db_path.is_file(): + return [] + + with db_path.open("rb") as db: + artifact_database = pickle.load(db) if tool_call_id in artifact_database: return artifact_database[tool_call_id] + return [] -def store_artifacts(tool_call_id: str, artifacts: List[Any]): +def store_artifacts( + tool_call_id: str, artifacts: List[Any], db_path="artifact_database.pkl" +): + # TODO(boczekbartek): refactor + db_path = Path(db_path) + if not db_path.is_file(): + artifact_database = {} + with open("artifact_database.pkl", "wb") as file: + pickle.dump(artifact_database, file) with open("artifact_database.pkl", "rb") as file: artifact_database = pickle.load(file) if tool_call_id not in artifact_database: From e7c59124c0c9d5acf4ce6362a924fcf313352d32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 6 Sep 2024 16:22:49 +0200 Subject: [PATCH 04/12] fix: topics whitelisting --- src/rai/rai/node.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index c185347be..d26321a9c 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -130,6 +130,7 @@ def dict(self): class RaiBaseNode(Node): def __init__( self, + whitelist: Optional[List[str]] = None, *args, **kwargs, ): @@ -143,7 +144,7 @@ def __init__( self.DISCOVERY_FREQ, self.discovery, ) - self.ros_discovery_info = NodeDiscovery(whitelist=None) + self.ros_discovery_info = NodeDiscovery(whitelist=whitelist) self.discovery() self.qos_profile = QoSProfile( history=HistoryPolicy.KEEP_LAST, @@ -214,10 +215,9 @@ def __init__( *args, **kwargs, ): - super().__init__(node_name, *args, **kwargs) + super().__init__(node_name=node_name, whitelist=whitelist, *args, **kwargs) self.llm = llm - self.whitelist = whitelist self.robot_state = dict() self.state_topics = observe_topics if observe_topics is not None else [] self.state_postprocessors = ( @@ -387,7 +387,9 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): # TODO(boczekbartek): add graph node to langgraph which will send ros2 action feedback to HMI # ---- Share Action Result ---- - report: Report = state["messages"][-1].content + report = state["messages"][-1] + if not isinstance(report, Report): + raise ValueError(f"Unexpected type of agent output: {type(report)}") result = TaskAction.Result() result.success = report.success From 6826e61c7e2b705bb56867daa6b6e48e77358537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Fri, 6 Sep 2024 16:23:28 +0200 Subject: [PATCH 05/12] revert: handling error in tool --- src/rai/rai/tools/ros/native.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/rai/rai/tools/ros/native.py b/src/rai/rai/tools/ros/native.py index 2ae58b05f..9bcdd26ea 100644 --- a/src/rai/rai/tools/ros/native.py +++ b/src/rai/rai/tools/ros/native.py @@ -72,8 +72,6 @@ class Ros2BaseTool(BaseTool): node: rclpy.node.Node = Field(..., exclude=True, required=True) args_schema: Type[Ros2BaseInput] = Ros2BaseInput - handle_tool_error = True - handle_validation_error = True @property def logger(self) -> RcutilsLogger: From 977105c5b2757647b641bebd3e90c23e19dd6ede Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Mon, 9 Sep 2024 07:16:53 +0200 Subject: [PATCH 06/12] feat(`ros2_native_tools`): improve action description --- src/rai/rai/tools/ros/native_actions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai/rai/tools/ros/native_actions.py index c464b6f90..7a9a7c0c5 100644 --- a/src/rai/rai/tools/ros/native_actions.py +++ b/src/rai/rai/tools/ros/native_actions.py @@ -56,7 +56,7 @@ def _run(self): class Ros2RunActionSync(Ros2BaseTool): name: str = "Ros2RunAction" description: str = ( - "A tool for running a ros2 action. Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result" + "A tool for running a ros2 action. Make sure you know the action interface first!!! Actions might take some time to execute and are blocking - you will not be able to check their feedback, only will be informed about the result" ) args_schema: Type[Ros2ActionRunnerInput] = Ros2ActionRunnerInput From cf0d5afd64ae02e1059d10ada69c2f3db46a5262 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 10 Sep 2024 10:28:41 +0200 Subject: [PATCH 07/12] refactor and cleanup --- examples/rosbot-xl-generic-node-demo.py | 5 +- src/rai/rai/node.py | 88 ++----------------------- src/rai/rai/tools/time.py | 13 +--- src/rai/rai/utils/ros.py | 76 +++++++++++++++++++++ 4 files changed, 89 insertions(+), 93 deletions(-) create mode 100644 src/rai/rai/utils/ros.py diff --git a/examples/rosbot-xl-generic-node-demo.py b/examples/rosbot-xl-generic-node-demo.py index dda43ca30..b2bbbb614 100644 --- a/examples/rosbot-xl-generic-node-demo.py +++ b/examples/rosbot-xl-generic-node-demo.py @@ -24,7 +24,7 @@ from langchain_openai import ChatOpenAI from rai.agents.state_based import create_state_based_agent -from rai.node import RaiNode, describe_ros_image, wait_for_2s +from rai.node import RaiNode, describe_ros_image from rai.tools.ros.native import ( GetCameraImage, GetMsgFromTopic, @@ -32,6 +32,7 @@ ) from rai.tools.ros.native_actions import Ros2RunActionSync from rai.tools.ros.tools import GetOccupancyGridTool +from rai.tools.time import WaitForSecondsTool def main(): @@ -84,7 +85,7 @@ def main(): ) tools = [ - wait_for_2s, + WaitForSecondsTool(), GetMsgFromTopic(node=node), Ros2RunActionSync(node=node), GetCameraImage(node=node), diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index d26321a9c..9700a8673 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -15,9 +15,7 @@ import functools import time -from collections import deque -from dataclasses import dataclass, field -from typing import Any, Callable, Deque, Dict, List, Literal, Optional, Tuple +from typing import Any, Callable, Dict, List, Literal, Optional import rcl_interfaces.msg import rclpy @@ -27,9 +25,7 @@ import rclpy.subscription import rclpy.task import sensor_msgs.msg -from langchain.tools import tool from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.prompts import ChatPromptTemplate from langchain_openai import ChatOpenAI from langgraph.graph.graph import CompiledGraph from rclpy.action.graph import get_action_names_and_types @@ -48,85 +44,10 @@ from rai.messages.multimodal import HumanMultimodalMessage from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str from rai.tools.utils import wait_for_message +from rai.utils.ros import NodeDiscovery, RosoutBuffer from rai_interfaces.action import Task as TaskAction -class RosoutBuffer: - def __init__(self, llm, bufsize: int = 100) -> None: - self.bufsize = bufsize - self._buffer: Deque[str] = deque() - self.template = ChatPromptTemplate.from_messages( - [ - ( - "system", - "Shorten the following log keeping its format - for example merge simillar or repeating lines", - ), - ("human", "{rosout}"), - ] - ) - llm = llm - self.llm = self.template | llm - - def clear(self): - self._buffer.clear() - - def append(self, line: str): - self._buffer.append(line) - if len(self._buffer) > self.bufsize: - self._buffer.popleft() - - def get_raw_logs(self, last_n: int = 30) -> str: - return "\n".join(list(self._buffer)[-last_n:]) - - def summarize(self): - if len(self._buffer) == 0: - return "No logs" - buffer = self.get_raw_logs() - response = self.llm.invoke({"rosout": buffer}) - return str(response.content) - - -@tool -def wait_for_2s(): - """Wait for 2 seconds""" - time.sleep(2) - - -@dataclass -class NodeDiscovery: - topics_and_types: Dict[str, str] = field(default_factory=dict) - services_and_types: Dict[str, str] = field(default_factory=dict) - actions_and_types: Dict[str, str] = field(default_factory=dict) - whitelist: Optional[List[str]] = field(default_factory=list) - - def set(self, topics, services, actions): - def to_dict(info: List[Tuple[str, List[str]]]) -> Dict[str, str]: - return {k: v[0] for k, v in info} - - self.topics_and_types = to_dict(topics) - self.services_and_types = to_dict(services) - self.actions_and_types = to_dict(actions) - if self.whitelist is not None: - self.__filter(self.whitelist) - - def __filter(self, whitelist: List[str]): - for d in [ - self.topics_and_types, - self.services_and_types, - self.actions_and_types, - ]: - to_remove = [k for k in d if k not in whitelist] - for k in to_remove: - d.pop(k) - - def dict(self): - return { - "topics_and_types": self.topics_and_types, - "services_and_types": self.services_and_types, - "actions_and_types": self.actions_and_types, - } - - class RaiBaseNode(Node): def __init__( self, @@ -391,6 +312,11 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): if not isinstance(report, Report): raise ValueError(f"Unexpected type of agent output: {type(report)}") + if report.success: + goal_handle.succeed() + else: + goal_handle.abort() + result = TaskAction.Result() result.success = report.success result.report = report.response_to_user diff --git a/src/rai/rai/tools/time.py b/src/rai/rai/tools/time.py index 3ce638440..3b2bd44d8 100644 --- a/src/rai/rai/tools/time.py +++ b/src/rai/rai/tools/time.py @@ -17,19 +17,9 @@ from typing import Type from langchain.pydantic_v1 import BaseModel, Field -from langchain.tools import tool from langchain_core.tools import BaseTool -@tool -def sleep_max_5s(n: int): - """Wait n seconds, max 5s""" - if n > 5: - n = 5 - - time.sleep(n) - - class WaitForSecondsToolInput(BaseModel): """Input for the WaitForSecondsTool tool.""" @@ -44,11 +34,14 @@ class WaitForSecondsTool(BaseTool): "A tool for waiting. " "Useful for pausing execution for a specified number of seconds. " "Input should be the number of seconds to wait." + "Maximum allowed time is 5 seconds" ) args_schema: Type[WaitForSecondsToolInput] = WaitForSecondsToolInput def _run(self, seconds: int): """Waits for the specified number of seconds.""" + if seconds > 5: + seconds = 5 time.sleep(seconds) return f"Waited for {seconds} seconds." diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py new file mode 100644 index 000000000..4c637916a --- /dev/null +++ b/src/rai/rai/utils/ros.py @@ -0,0 +1,76 @@ +from collections import deque +from dataclasses import dataclass, field +from typing import Deque, Dict, List, Optional, Tuple + +from langchain_core.language_models import BaseChatModel +from langchain_core.prompts import ChatPromptTemplate + + +class RosoutBuffer: + def __init__(self, llm: BaseChatModel, bufsize: int = 100) -> None: + self.bufsize = bufsize + self._buffer: Deque[str] = deque() + self.template = ChatPromptTemplate.from_messages( + [ + ( + "system", + "Shorten the following log keeping its format - for example merge simillar or repeating lines", + ), + ("human", "{rosout}"), + ] + ) + llm = llm + self.llm = self.template | llm + + def clear(self): + self._buffer.clear() + + def append(self, line: str): + self._buffer.append(line) + if len(self._buffer) > self.bufsize: + self._buffer.popleft() + + def get_raw_logs(self, last_n: int = 30) -> str: + return "\n".join(list(self._buffer)[-last_n:]) + + def summarize(self): + if len(self._buffer) == 0: + return "No logs" + buffer = self.get_raw_logs() + response = self.llm.invoke({"rosout": buffer}) + return str(response.content) + + +@dataclass +class NodeDiscovery: + topics_and_types: Dict[str, str] = field(default_factory=dict) + services_and_types: Dict[str, str] = field(default_factory=dict) + actions_and_types: Dict[str, str] = field(default_factory=dict) + whitelist: Optional[List[str]] = field(default_factory=list) + + def set(self, topics, services, actions): + def to_dict(info: List[Tuple[str, List[str]]]) -> Dict[str, str]: + return {k: v[0] for k, v in info} + + self.topics_and_types = to_dict(topics) + self.services_and_types = to_dict(services) + self.actions_and_types = to_dict(actions) + if self.whitelist is not None: + self.__filter(self.whitelist) + + def __filter(self, whitelist: List[str]): + for d in [ + self.topics_and_types, + self.services_and_types, + self.actions_and_types, + ]: + to_remove = [k for k in d if k not in whitelist] + for k in to_remove: + d.pop(k) + + def dict(self): + return { + "topics_and_types": self.topics_and_types, + "services_and_types": self.services_and_types, + "actions_and_types": self.actions_and_types, + } From 0b9e58d7cd1430ec786cb0d9a7ec619de4f6a014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Mon, 16 Sep 2024 16:26:29 +0200 Subject: [PATCH 08/12] feat(`text_hmi`): expand missing messages --- src/rai_hmi/rai_hmi/text_hmi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rai_hmi/rai_hmi/text_hmi.py b/src/rai_hmi/rai_hmi/text_hmi.py index 66ce71810..d9dd0876f 100644 --- a/src/rai_hmi/rai_hmi/text_hmi.py +++ b/src/rai_hmi/rai_hmi/text_hmi.py @@ -228,8 +228,9 @@ def display_agent_message( return # we do not handle system messages elif isinstance(message, MissionMessage): logger.info("Displaying mission message") - avatar, content = message.render_steamlit() - st.chat_message("bot", avatar=avatar).markdown(content) + with st.expander(label=message.STATUS): + avatar, content = message.render_steamlit() + st.chat_message("bot", avatar=avatar).markdown(content) else: raise ValueError("Unknown message type") From a5a3f350339a313e18c06d24fcc74f8f85efa0bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 17 Sep 2024 09:59:05 +0200 Subject: [PATCH 09/12] refactor: convert to f-string --- src/rai/rai/agents/state_based.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rai/rai/agents/state_based.py b/src/rai/rai/agents/state_based.py index f250d03d6..2a1c298a5 100644 --- a/src/rai/rai/agents/state_based.py +++ b/src/rai/rai/agents/state_based.py @@ -303,7 +303,7 @@ def retriever_wrapper( info = str_output(retrieved_info) state["messages"].append( HumanMultimodalMessage( - content="Retrieved state: {}".format(info), images=images, audios=audios + content=f"Retrieved state: {info}", images=images, audios=audios ) ) return state From 25dd876e01cc059d80ff8f34c17fcee0cfbd2728 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 17 Sep 2024 09:59:35 +0200 Subject: [PATCH 10/12] make ros2 action tool more robust --- src/rai/rai/tools/ros/native_actions.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/rai/rai/tools/ros/native_actions.py b/src/rai/rai/tools/ros/native_actions.py index 7a9a7c0c5..740e6e22c 100644 --- a/src/rai/rai/tools/ros/native_actions.py +++ b/src/rai/rai/tools/ros/native_actions.py @@ -82,6 +82,10 @@ def _build_msg( def _run( self, action_name: str, action_type: str, action_goal_args: Dict[str, Any] ): + if action_name[0] != "/": + action_name = "/" + action_name + self.node.get_logger().info(f"Action name corrected to: {action_name}") + try: goal_msg, msg_cls = self._build_msg(action_type, action_goal_args) except Exception as e: @@ -89,7 +93,13 @@ def _run( client = ActionClient(self.node, msg_cls, action_name) + retries = 0 while not client.wait_for_server(timeout_sec=1.0): + retries += 1 + if retries > 5: + raise Exception( + f"Action server '{action_name}' is not available. Make sure `action_name` is correct..." + ) self.node.get_logger().info( f"'{action_name}' action server not available, waiting..." ) From 165b51daaa1c9ebd9957dd211aaa0aa27951a54a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 17 Sep 2024 09:59:55 +0200 Subject: [PATCH 11/12] feat(`rai_node`): handle task as ros2 action --- src/rai/rai/node.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 9700a8673..92a60f197 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -41,7 +41,7 @@ from std_srvs.srv import Trigger from rai.agents.state_based import Report, State -from rai.messages.multimodal import HumanMultimodalMessage +from rai.messages.multimodal import HumanMultimodalMessage, MultimodalMessage from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str from rai.tools.utils import wait_for_message from rai.utils.ros import NodeDiscovery, RosoutBuffer @@ -300,15 +300,41 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): payload = State(messages=messages) - state: State = self.llm_app.invoke( + state = None + for state in self.llm_app.stream( payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT} - ) # type: ignore + ): - # ---- Share Action feedback ---- - # TODO(boczekbartek): add graph node to langgraph which will send ros2 action feedback to HMI + print(state.keys()) + graph_node_name = list(state.keys())[0] + if graph_node_name == "reporter": + continue + + msg = state[graph_node_name]["messages"][-1] + + if isinstance(msg, MultimodalMessage): + last_msg = msg.text + else: + last_msg = msg.content + + feedback_msg = TaskAction.Feedback() + feedback_msg.current_status = f"{graph_node_name}: {last_msg}" + + goal_handle.publish_feedback(feedback_msg) # ---- Share Action Result ---- - report = state["messages"][-1] + if state is None: + raise ValueError("No output from LLM") + print(state) + + graph_node_name = list(state.keys())[0] + if graph_node_name != "reporter": + raise ValueError(f"Unexpected output llm node: {graph_node_name}") + + report = state["reporter"]["messages"][ + -1 + ] # TODO define graph more strictly not as dict key + if not isinstance(report, Report): raise ValueError(f"Unexpected type of agent output: {type(report)}") @@ -333,6 +359,10 @@ def set_app(self, app: CompiledGraph): def get_robot_state(self) -> Dict[str, str]: state_dict = dict() + + if self.robot_state is None: + return state_dict + for t in self.state_subscribers: if t not in self.robot_state: msg = "No message yet" From dcbb778db9622576a76fb8b8663990dab3095c82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 17 Sep 2024 10:18:49 +0200 Subject: [PATCH 12/12] fix: missing licence --- src/rai/rai/utils/ros.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/rai/rai/utils/ros.py b/src/rai/rai/utils/ros.py index 4c637916a..c281de74e 100644 --- a/src/rai/rai/utils/ros.py +++ b/src/rai/rai/utils/ros.py @@ -1,3 +1,18 @@ +# Copyright (C) 2024 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from collections import deque from dataclasses import dataclass, field from typing import Deque, Dict, List, Optional, Tuple