Commit

WIP responses over websocket

bmquinn committed Dec 10, 2024
1 parent 9343301 commit 65b2c13

Showing 4 changed files with 36 additions and 42 deletions.
chat/src/agent/agent_handler.py (8 additions, 32 deletions)

@@ -8,39 +8,15 @@
 from langchain.schema import AgentFinish, AgentAction
 
 
-class AgentHandler(BaseCallbackHandler):
-    def on_llm_new_token(self, token: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
-        socket: Websocket = metadata.get("socket", None)
-
-        if socket is None:
-            raise ValueError("Socket not defined in agent handler via metadata")
-
-        socket.send("test")
-
-    # def on_tool_start(self, serialized: Dict[str, Any], input_str: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
-    #     print(f"on_tool_start: {serialized, input_str}")
-    #     # socket: Websocket = metadata.get("socket", None)
-
-    #     # if socket is None:
-    #     #     raise ValueError("Socket not defined in agent handler via metadata")
-
-    #     # socket.send(f"🎉 I'm working on it")
-
-    # def on_tool_end(self, output: str, **kwargs: Any) -> Any:
-    #     print(f"on_tool_end: {output}")
-    #     pass
-
-    # def on_text(self, text: str, **kwargs: Any) -> Any:
-    #     print(f"on_text: {text}")
-    #     pass
-
-    # def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
-    #     print(f"on_agent_action: {action}")
-    #     pass
-
+class AgentHandler(BaseCallbackHandler):
     def on_tool_start(self, serialized: Dict[str, Any], input_str: str, metadata: Optional[dict[str, Any]], **kwargs: Any) -> Any:
-        socket: Websocket = metadata.get("socket", None)
+        socket: Websocket = metadata.get("socket")
         if socket is None:
             raise ValueError("Socket not defined in agent handler via metadata")
 
-        socket.send(f"🎉 I'm working on it")
+        socket.send({"type": "tool_start", "message": serialized})
+
+    def on_tool_end(self, output, **kwargs):
+        print("on_tool_end output:")
+        print(output)
+
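
For context on the new handler's behavior, here is a minimal sketch, not part of the commit, of driving on_tool_start directly with a stand-in socket. The StubSocket class, the sample serialized/input_str values, and the import path are illustrative assumptions.

    from agent.agent_handler import AgentHandler  # import path assumed from the repo layout

    class StubSocket:
        """Stand-in for chat/src/websocket.py's Websocket; prints instead of posting."""
        def send(self, data):
            print("would send over websocket:", data)
            return data

    handler = AgentHandler()
    # LangChain passes the run metadata into on_tool_start; the handler expects a
    # "socket" entry there and raises ValueError if it is missing.
    handler.on_tool_start(
        serialized={"name": "search"},          # hypothetical tool description
        input_str="collections about trains",   # hypothetical tool input
        metadata={"socket": StubSocket()},
    )
    # prints: would send over websocket: {'type': 'tool_start', 'message': {'name': 'search'}}
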
chat/src/agent/search_agent.py (12 additions, 3 deletions)

@@ -3,10 +3,14 @@
 from typing import Literal
 
 from agent.tools import search, aggregate
+from langchain_core.runnables.config import RunnableConfig
+from langchain_core.messages.base import BaseMessage
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START, StateGraph, MessagesState
 from langgraph.prebuilt import ToolNode
 from setup import openai_chat_client
+from websocket import Websocket
+
 
 tools = [search, aggregate]
 
@@ -26,10 +30,15 @@ def should_continue(state: MessagesState) -> Literal["tools", END]:
 
 
 # Define the function that calls the model
-def call_model(state: MessagesState):
+def call_model(state: MessagesState, config: RunnableConfig):
     messages = state["messages"]
-    response = model.invoke(messages, model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))
+    socket = config["configurable"].get("socket", None)
+    response: BaseMessage = model.invoke(messages, model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))
     # We return a list, because this will get added to the existing list
+    # if socket is not none and the response content is not an empty string
+    if socket is not None and response.content != "":
+        print("Sending response to socket")
+        socket.send({"type": "answer", "message": response.content})
     return {"messages": [response]}
 
 
@@ -56,4 +65,4 @@ def call_model(state: MessagesState):
 checkpointer = MemorySaver()
 
 # Compile the graph
-search_agent = workflow.compile(checkpointer=checkpointer, debug=False)
+search_agent = workflow.compile(checkpointer=checkpointer)
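
A minimal sketch, assuming the module layout above, of how the socket reaches call_model at invoke time: LangGraph passes the RunnableConfig into each node, so anything placed under config["configurable"] (here the socket alongside the thread_id) is available inside the node without going through graph state. The StubSocket, the question text, and the thread_id value are illustrative.

    from agent.search_agent import search_agent  # the compiled graph above; path assumed
    from langchain_core.messages import HumanMessage

    class StubSocket:
        def send(self, data):
            print("would stream to client:", data)

    result = search_agent.invoke(
        {"messages": [HumanMessage(content="What collections mention trains?")]},
        config={"configurable": {"thread_id": "demo-thread", "socket": StubSocket()}},
    )
    # call_model reads config["configurable"]["socket"] and, for a non-empty answer,
    # sends {"type": "answer", "message": ...} before the final state is returned.
    print(result["messages"][-1].content)
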
chat/src/handlers/chat.py (12 additions, 4 deletions)

@@ -72,10 +72,18 @@ def handler(event, context):
     )
 
     callbacks = [AgentHandler()]
-    response = search_agent.invoke(
-        {"messages": [HumanMessage(content=config.question)]},
-        config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks, "metadata": {"socket": config.socket}},
-    )
+    try:
+        search_agent.invoke(
+            {"messages": [HumanMessage(content=config.question)]},
+            config={"configurable": {"thread_id": config.ref, "socket": config.socket}, "callbacks": callbacks, "metadata": {"socket": config.socket}},
+            debug=False
+        )
+    except Exception as e:
+        print(f"Error: {e}")
+        error_response = {"type": "error", "message": "An unexpected error occurred. Please try again later."}
+        if config.socket:
+            config.socket.send(error_response)
+        return {"statusCode": 500, "body": json.dumps(error_response)}
 
     return {"statusCode": 200}

chat/src/websocket.py (4 additions, 3 deletions)

@@ -8,14 +8,15 @@ def __init__(self, client=None, endpoint_url=None, connection_id=None, ref=None):
         self.ref = ref if ref else {}
 
     def send(self, data):
-        if isinstance(data, str):
-            data = {"message": data}
-        data["ref"] = self.ref
+        # if isinstance(data, str):
+        #     data = {"message": data}
+        # data["ref"] = self.ref
         data_as_bytes = bytes(json.dumps(data), "utf-8")
 
         if self.connection_id == "debug":
             print(data)
         else:
+            print(f"Sending data to {self.connection_id}: {data}")
             self.client.post_to_connection(Data=data_as_bytes, ConnectionId=self.connection_id)
         return data
 
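Finally, a sketch of how this Websocket wrapper might be used. The endpoint URL, connection id, and the assumption that the caller builds the boto3 Management API client itself are illustrative, not taken from the repository.

    import boto3
    from websocket import Websocket  # chat/src/websocket.py

    # Real connection: each payload is JSON-encoded and posted to the client
    # through API Gateway's Management API (hypothetical URL and connection id).
    client = boto3.client(
        "apigatewaymanagementapi",
        endpoint_url="https://abc123.execute-api.us-east-1.amazonaws.com/prod",
    )
    socket = Websocket(client=client, connection_id="JZ1aBcdEIAMCJxw=")
    socket.send({"type": "answer", "message": "partial response"})

    # Debug path: with connection_id="debug" the wrapper only prints the payload,
    # so the same code runs locally without an open websocket connection.
    debug_socket = Websocket(connection_id="debug")
    debug_socket.send({"type": "answer", "message": "partial response"})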
