From 6022271e6492f2adfbef8886e8b7abde69b7bb83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Boczek?= Date: Tue, 17 Sep 2024 09:59:55 +0200 Subject: [PATCH] feat(`rai_node`): handle task as ros2 action --- src/rai/rai/node.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/src/rai/rai/node.py b/src/rai/rai/node.py index 9700a8673..92a60f197 100644 --- a/src/rai/rai/node.py +++ b/src/rai/rai/node.py @@ -41,7 +41,7 @@ from std_srvs.srv import Trigger from rai.agents.state_based import Report, State -from rai.messages.multimodal import HumanMultimodalMessage +from rai.messages.multimodal import HumanMultimodalMessage, MultimodalMessage from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str from rai.tools.utils import wait_for_message from rai.utils.ros import NodeDiscovery, RosoutBuffer @@ -300,15 +300,41 @@ async def agent_loop(self, goal_handle: ServerGoalHandle): payload = State(messages=messages) - state: State = self.llm_app.invoke( + state = None + for state in self.llm_app.stream( payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT} - ) # type: ignore + ): - # ---- Share Action feedback ---- - # TODO(boczekbartek): add graph node to langgraph which will send ros2 action feedback to HMI + print(state.keys()) + graph_node_name = list(state.keys())[0] + if graph_node_name == "reporter": + continue + + msg = state[graph_node_name]["messages"][-1] + + if isinstance(msg, MultimodalMessage): + last_msg = msg.text + else: + last_msg = msg.content + + feedback_msg = TaskAction.Feedback() + feedback_msg.current_status = f"{graph_node_name}: {last_msg}" + + goal_handle.publish_feedback(feedback_msg) # ---- Share Action Result ---- - report = state["messages"][-1] + if state is None: + raise ValueError("No output from LLM") + print(state) + + graph_node_name = list(state.keys())[0] + if graph_node_name != "reporter": + raise ValueError(f"Unexpected output llm node: {graph_node_name}") + + report = state["reporter"]["messages"][ + -1 + ] # TODO define graph more strictly not as dict key + if not isinstance(report, Report): raise ValueError(f"Unexpected type of agent output: {type(report)}") @@ -333,6 +359,10 @@ def set_app(self, app: CompiledGraph): def get_robot_state(self) -> Dict[str, str]: state_dict = dict() + + if self.robot_state is None: + return state_dict + for t in self.state_subscribers: if t not in self.robot_state: msg = "No message yet"