Skip to content

Commit

Permalink
Indicate final message in socket
Browse files Browse the repository at this point in the history
  • Loading branch information
kdid committed Dec 13, 2024
1 parent e61928d commit 1905eba
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion chat/src/agent/agent_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,15 @@ def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], metadata:
self.socket.send({"type": "start", "ref": self.ref, "message": {"model": metadata.get("model_deployment")}})

def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
content = response.generations[0][0].text
response_generation = response.generations[0][0]
content = response_generation.text
stop_reason = response_generation.message.response_metadata.get("stop_reason", "unknown")
if content != "":
self.socket.send({"type": "stop", "ref": self.ref})
self.socket.send({"type": "answer", "ref": self.ref, "message": content})
if stop_reason == "end_turn":
self.socket.send({"type": "final_message", "ref": self.ref})


def on_llm_new_token(self, token: str, **kwargs: Dict[str, Any]):
if token != "":
Expand All @@ -59,3 +64,6 @@ def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
case _:
print(f"Unhandled tool_end message: {output}")

def on_agent_finish(self, finish, **kwargs):
self.socket.send({"type": "final", "ref": self.ref, "message": "Finished"})

0 comments on commit 1905eba

Please sign in to comment.