Skip to content

Commit

Permalink
feat(rai_node): handle task as ros2 action
Browse files Browse the repository at this point in the history
  • Loading branch information
boczekbartek committed Sep 17, 2024
1 parent 0468302 commit 6022271
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions src/rai/rai/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from std_srvs.srv import Trigger

from rai.agents.state_based import Report, State
from rai.messages.multimodal import HumanMultimodalMessage
from rai.messages.multimodal import HumanMultimodalMessage, MultimodalMessage
from rai.tools.ros.utils import convert_ros_img_to_base64, import_message_from_str
from rai.tools.utils import wait_for_message
from rai.utils.ros import NodeDiscovery, RosoutBuffer
Expand Down Expand Up @@ -300,15 +300,41 @@ async def agent_loop(self, goal_handle: ServerGoalHandle):

payload = State(messages=messages)

state: State = self.llm_app.invoke(
state = None
for state in self.llm_app.stream(
payload, {"recursion_limit": self.AGENT_RECURSION_LIMIT}
) # type: ignore
):

# ---- Share Action feedback ----
# TODO(boczekbartek): add graph node to langgraph which will send ros2 action feedback to HMI
print(state.keys())
graph_node_name = list(state.keys())[0]
if graph_node_name == "reporter":
continue

msg = state[graph_node_name]["messages"][-1]

if isinstance(msg, MultimodalMessage):
last_msg = msg.text
else:
last_msg = msg.content

feedback_msg = TaskAction.Feedback()
feedback_msg.current_status = f"{graph_node_name}: {last_msg}"

goal_handle.publish_feedback(feedback_msg)

# ---- Share Action Result ----
report = state["messages"][-1]
if state is None:
raise ValueError("No output from LLM")
print(state)

graph_node_name = list(state.keys())[0]
if graph_node_name != "reporter":
raise ValueError(f"Unexpected output llm node: {graph_node_name}")

report = state["reporter"]["messages"][
-1
] # TODO define graph more strictly not as dict key

if not isinstance(report, Report):
raise ValueError(f"Unexpected type of agent output: {type(report)}")

Expand All @@ -333,6 +359,10 @@ def set_app(self, app: CompiledGraph):

def get_robot_state(self) -> Dict[str, str]:
state_dict = dict()

if self.robot_state is None:
return state_dict

for t in self.state_subscribers:
if t not in self.robot_state:
msg = "No message yet"
Expand Down

0 comments on commit 6022271

Please sign in to comment.