Skip to content

Commit

Permalink
Add agent handler unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesLoder authored and mbklein committed Dec 17, 2024
1 parent 5611be8 commit f3e646d
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 21 deletions.
9 changes: 3 additions & 6 deletions chat/src/agent/agent_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,9 @@ def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
case "discover_fields":
pass
case "search":
try:
result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
docs: List[Dict[str, Any]] = [{k: doc.metadata.get(k) for k in result_fields} for doc in output.artifact]
self.socket.send({"type": "search_result", "ref": self.ref, "message": docs})
except json.decoder.JSONDecodeError as e:
print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
docs: List[Dict[str, Any]] = [{k: doc.metadata.get(k) for k in result_fields} for doc in output.artifact]
self.socket.send({"type": "search_result", "ref": self.ref, "message": docs})
case _:
print(f"Unhandled tool_end message: {output}")

Expand Down
169 changes: 154 additions & 15 deletions chat/test/agent/test_agent_handler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import unittest
from unittest import TestCase
from unittest.mock import patch
from unittest.mock import MagicMock
import sys

sys.path.append("./src")

from agent.agent_handler import AgentHandler
from agent.search_agent import SearchAgent
from langchain_core.language_models.fake_chat_models import FakeListChatModel
from langgraph.checkpoint.memory import MemorySaver
from websocket import Websocket

class MockClient:
def __init__(self):
Expand All @@ -18,13 +15,155 @@ def post_to_connection(self, Data, ConnectionId):
return Data

class TestAgentHandler(TestCase):
@patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver())
def test_search_agent_invoke_simple(self, mock_create_saver):
websocket_client = MockClient()
websocket = Websocket(client=websocket_client, connection_id="test_connection_id", ref="test_ref")
expected_response = "This is a mocked LLM response."
chat_model = FakeListChatModel(responses=[expected_response], )
search_agent = SearchAgent(model=chat_model, streaming=True)
callbacks = [AgentHandler(websocket, "test_ref")]
search_agent.invoke(question="What is the capital of France?", ref="test_ref", callbacks=callbacks)
print(websocket_client.received)
def setUp(self):
self.mock_socket = MagicMock()
self.ref = "test_ref"
self.handler = AgentHandler(socket=self.mock_socket, ref=self.ref)

def test_on_llm_start(self):
# Given metadata that includes model name
metadata = {"ls_model_name": "test_model"}

# When on_llm_start is called
self.handler.on_llm_start(serialized={}, prompts=["Hello"], metadata=metadata)

# Then verify the socket was called with the correct start message
self.mock_socket.send.assert_called_once_with({
"type": "start",
"ref": self.ref,
"message": {"model": "test_model"}
})

def test_on_llm_end_with_content(self):
# Mocking LLMResult and Generations
class MockMessage:
def __init__(self, text, response_metadata):
self.text = text
self.message = self # For simplicity, reuse same object for .message
self.response_metadata = response_metadata

class MockLLMResult:
def __init__(self, text, stop_reason="end_turn"):
self.generations = [[MockMessage(text, {"stop_reason": stop_reason})]]

# When response has content and end_turn stop reason
response = MockLLMResult("Here is the answer", stop_reason="end_turn")
self.handler.on_llm_end(response)

# Verify "stop" and "answer" and then "final_message" were sent
expected_calls = [
unittest.mock.call({"type": "stop", "ref": self.ref}),
unittest.mock.call({"type": "answer", "ref": self.ref, "message": "Here is the answer"}),
unittest.mock.call({"type": "final_message", "ref": self.ref})
]
self.mock_socket.send.assert_has_calls(expected_calls, any_order=False)

def test_on_llm_new_token(self):
# When a new token arrives
self.handler.on_llm_new_token("hello")

# Then verify the socket sent a token message
self.mock_socket.send.assert_called_once_with({
"type": "token",
"ref": self.ref,
"message": "hello"
})

def test_on_tool_start(self):
# When tool starts
self.handler.on_tool_start({"name": "test_tool"}, "input_value")

# Verify the tool_start message
self.mock_socket.send.assert_called_once_with({
"type": "tool_start",
"ref": self.ref,
"message": {
"tool": "test_tool",
"input": "input_value"
}
})

def test_on_tool_end_search(self):
# Mock tool output
class MockDoc:
def __init__(self, metadata):
self.metadata = metadata

class MockToolMessage:
def __init__(self, name, artifact):
self.name = name
self.artifact = artifact

artifact = [
MockDoc({"id": 1, "title": "Result 1", "visibility": "public", "work_type": "article", "thumbnail": "img1"}),
MockDoc({"id": 2, "title": "Result 2", "visibility": "private", "work_type": "document", "thumbnail": "img2"})
]

output = MockToolMessage("search", artifact)
self.handler.on_tool_end(output)

# Verify search_result message was sent
expected_message = [
{"id": 1, "title": "Result 1", "visibility": "public", "work_type": "article", "thumbnail": "img1"},
{"id": 2, "title": "Result 2", "visibility": "private", "work_type": "document", "thumbnail": "img2"}
]

self.mock_socket.send.assert_called_once_with({
"type": "search_result",
"ref": self.ref,
"message": expected_message
})

def test_on_tool_end_aggregate(self):
class MockToolMessage:
def __init__(self, name, artifact):
self.name = name
self.artifact = artifact

output = MockToolMessage("aggregate", {"aggregation_result": {"count": 10}})
self.handler.on_tool_end(output)

# Verify aggregation_result message was sent
self.mock_socket.send.assert_called_once_with({
"type": "aggregation_result",
"ref": self.ref,
"message": {"count": 10}
})

def test_on_tool_end_discover_fields(self):
class MockToolMessage:
def __init__(self, name, artifact):
self.name = name
self.artifact = artifact

output = MockToolMessage("discover_fields", {})
self.handler.on_tool_end(output)

self.mock_socket.send.assert_not_called()

def test_on_tool_end_unknown(self):
class MockToolMessage:
def __init__(self, name, artifact):
self.name = name
self.artifact = artifact

output = MockToolMessage("unknown", {})
self.handler.on_tool_end(output)

self.mock_socket.send.assert_not_called()

def test_on_agent_finish(self):
self.handler.on_agent_finish(finish={})
self.mock_socket.send.assert_called_once_with({
"type": "final",
"ref": self.ref,
"message": "Finished"
})

class TestAgentHandlerErrors(TestCase):
def test_missing_socket(self):
with self.assertRaises(ValueError) as context:
AgentHandler(socket=None, ref="abc123")

self.assertIn("Socket not provided to agent callback handler", str(context.exception))

0 comments on commit f3e646d

Please sign in to comment.