From 65b2c13ede53f8292646267bc31003842ad3ccf4 Mon Sep 17 00:00:00 2001
From: Brendan Quinn
Date: Tue, 10 Dec 2024 17:03:27 +0000
Subject: [PATCH] WIP responses over websocket

---
 chat/src/agent/agent_handler.py | 40 +++++++--------------------
 chat/src/agent/search_agent.py  | 15 ++++++++++---
 chat/src/handlers/chat.py       | 16 +++++++++----
 chat/src/websocket.py           |  7 +++---
 4 files changed, 36 insertions(+), 42 deletions(-)

diff --git a/chat/src/agent/agent_handler.py b/chat/src/agent/agent_handler.py
index 23454db..c4dd9d2 100644
--- a/chat/src/agent/agent_handler.py
+++ b/chat/src/agent/agent_handler.py
@@ -8,39 +8,15 @@ from langchain.schema import AgentFinish, AgentAction
 
 
-class AgentHandler(BaseCallbackHandler):
-    def on_llm_new_token(self, token: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
-        socket: Websocket = metadata.get("socket", None)
-
-        if socket is None:
-            raise ValueError("Socket not defined in agent handler via metadata")
-
-        socket.send("test")
-
-    # def on_tool_start(self, serialized: Dict[str, Any], input_str: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
-    #     print(f"on_tool_start: {serialized, input_str}")
-    #     # socket: Websocket = metadata.get("socket", None)
-
-    #     # if socket is None:
-    #     #     raise ValueError("Socket not defined in agent handler via metadata")
-
-    #     # socket.send(f"🎉 I'm working on it")
-
-    # def on_tool_end(self, output: str, **kwargs: Any) -> Any:
-    #     print(f"on_tool_end: {output}")
-    #     pass
-
-    # def on_text(self, text: str, **kwargs: Any) -> Any:
-    #     print(f"on_text: {text}")
-    #     pass
-
-    # def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
-    #     print(f"on_agent_action: {action}")
-    #     pass
-
+class AgentHandler(BaseCallbackHandler):
     def on_tool_start(self, serialized: Dict[str, Any], input_str: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
-        socket: Websocket = metadata.get("socket", None)
+        socket: Websocket = metadata.get("socket")
 
         if socket is None:
             raise ValueError("Socket not defined in agent handler via metadata")
 
-        socket.send(f"🎉 I'm working on it")
\ No newline at end of file
+        socket.send({"type": "tool_start", "message": serialized})
+
+    def on_tool_end(self, output, **kwargs):
+        print("on_tool_end output:")
+        print(output)
+        
\ No newline at end of file
diff --git a/chat/src/agent/search_agent.py b/chat/src/agent/search_agent.py
index 8fe2c3d..f55f70c 100644
--- a/chat/src/agent/search_agent.py
+++ b/chat/src/agent/search_agent.py
@@ -3,10 +3,14 @@ from typing import Literal
 
 from agent.tools import search, aggregate
+from langchain_core.runnables.config import RunnableConfig
+from langchain_core.messages.base import BaseMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph, MessagesState
 from langgraph.prebuilt import ToolNode
 from setup import openai_chat_client
+from websocket import Websocket
+
 
 
 tools = [search, aggregate]
 
@@ -26,10 +30,15 @@ def should_continue(state: MessagesState) -> Literal["tools", END]:
 
 
 # Define the function that calls the model
-def call_model(state: MessagesState):
+def call_model(state: MessagesState, config: RunnableConfig):
     messages = state["messages"]
-    response = model.invoke(messages, model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))
+    socket = config["configurable"].get("socket", None)
+    response: BaseMessage = model.invoke(messages, model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))
     # We return a list, because this will get added to the existing list
+    # if socket is not none and the response content is not an empty string
+    if socket is not None and response.content != "":
+        print("Sending response to socket")
+        socket.send({"type": "answer", "message": response.content})
     return {"messages": [response]}
 
 
@@ -56,4 +65,4 @@ def call_model(state: MessagesState):
 checkpointer = MemorySaver()
 
 # Compile the graph
-search_agent = workflow.compile(checkpointer=checkpointer, debug=False)
+search_agent = workflow.compile(checkpointer=checkpointer)
diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py
index 923c56e..df4ce09 100644
--- a/chat/src/handlers/chat.py
+++ b/chat/src/handlers/chat.py
@@ -72,10 +72,18 @@ def handler(event, context):
     )
 
     callbacks = [AgentHandler()]
-    response = search_agent.invoke(
-        {"messages": [HumanMessage(content=config.question)]},
-        config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks, "metadata": {"socket": config.socket}},
-    )
+    try:
+        search_agent.invoke(
+            {"messages": [HumanMessage(content=config.question)]},
+            config={"configurable": {"thread_id": config.ref, "socket": config.socket}, "callbacks": callbacks, "metadata": {"socket": config.socket}},
+            debug=False
+        )
+    except Exception as e:
+        print(f"Error: {e}")
+        error_response = {"type": "error", "message": "An unexpected error occurred. Please try again later."}
+        if config.socket:
+            config.socket.send(error_response)
+        return {"statusCode": 500, "body": json.dumps(error_response)}
 
     return {"statusCode": 200}
 
diff --git a/chat/src/websocket.py b/chat/src/websocket.py
index 0607420..9670868 100644
--- a/chat/src/websocket.py
+++ b/chat/src/websocket.py
@@ -8,14 +8,15 @@ def __init__(self, client=None, endpoint_url=None, connection_id=None, ref=None)
         self.ref = ref if ref else {}
 
     def send(self, data):
-        if isinstance(data, str):
-            data = {"message": data}
-        data["ref"] = self.ref
+        # if isinstance(data, str):
+        #     data = {"message": data}
+        # data["ref"] = self.ref
 
         data_as_bytes = bytes(json.dumps(data), "utf-8")
 
         if self.connection_id == "debug":
             print(data)
         else:
+            print(f"Sending data to {self.connection_id}: {data}")
            self.client.post_to_connection(Data=data_as_bytes, ConnectionId=self.connection_id)
 
         return data
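Note (not part of the patch): every payload this change pushes through Websocket.send() is a dict of the shape {"type": ..., "message": ...}, with the types "tool_start", "answer", and "error" appearing above. The sketch below is a minimal, hypothetical client-side dispatcher for those frames; the function name and the print statements are illustrative only, not code from this repository.

    import json


    def handle_frame(raw: str) -> None:
        """Dispatch one websocket frame based on the "type" field used in this patch."""
        payload = json.loads(raw)
        kind = payload.get("type")

        if kind == "tool_start":
            # Sent from AgentHandler.on_tool_start with the serialized tool info
            print(f"agent started a tool: {payload.get('message')}")
        elif kind == "answer":
            # Sent from call_model with the model's response content
            print(f"answer: {payload.get('message')}")
        elif kind == "error":
            # Sent from the chat handler when search_agent.invoke raises
            print(f"server error: {payload.get('message')}")
        else:
            print(f"unrecognized frame: {payload}")


    if __name__ == "__main__":
        # Example frame matching the shape produced by call_model above
        handle_frame(json.dumps({"type": "answer", "message": "Hello"}))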