Skip to content

Commit

Permalink
Fix up mocking and SearchAgent test
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Dec 17, 2024
1 parent 2e4f4f8 commit 63c828c
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

logger = logging.getLogger(__name__)

def create_checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
def checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME")

return S3Saver(bucket_name=checkpoint_bucket, **kwargs)
7 changes: 2 additions & 5 deletions chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal, List

from agent.checkpoint_factory import create_checkpoint_saver
from agent.checkpoints import checkpoint_saver
from agent.tools import aggregate, discover_fields, search
from langchain_core.messages import HumanMessage
from langchain_core.messages.base import BaseMessage
Expand All @@ -21,12 +21,9 @@ def __init__(
self,
model: BaseModel,
*,
streaming: bool = True,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
**kwargs
):
self.streaming = streaming

tools = [discover_fields, search, aggregate]
tool_node = ToolNode(tools)

Expand Down Expand Up @@ -71,7 +68,7 @@ def call_model(state: MessagesState):
# Add a normal edge from `tools` to `agent`
workflow.add_edge("tools", "agent")

self.checkpointer = create_checkpoint_saver()
self.checkpointer = checkpoint_saver()
self.search_agent = workflow.compile(checkpointer=self.checkpointer)

def invoke(self, question: str, ref: str, *, callbacks: List[BaseCallbackHandler] = [], forget: bool = False, **kwargs):
Expand Down
10 changes: 2 additions & 8 deletions chat/src/handlers/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
from langchain_aws import ChatBedrock
from langchain_core.language_models.base import BaseModel

MODEL_OVERRIDE: BaseModel = None

def chat_model(**kwargs):
return MODEL_OVERRIDE or ChatBedrock(**kwargs)

def set_model_override(model: BaseModel):
global MODEL_OVERRIDE
MODEL_OVERRIDE = model
def chat_model(**kwargs) -> BaseModel:
return ChatBedrock(**kwargs)
31 changes: 15 additions & 16 deletions chat/test/agent/test_search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,30 @@
sys.path.append('./src')

from agent.search_agent import SearchAgent
from handlers.model import chat_model, set_model_override
from langchain_core.language_models.fake_chat_models import FakeListChatModel
from langgraph.checkpoint.memory import MemorySaver


class TestSearchAgent(TestCase):

@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
@patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver())
def test_search_agent_init(self, mock_create_saver):
set_model_override(FakeListChatModel(responses=["fake response"]))
search_agent = SearchAgent(model=chat_model("test"), streaming=True)
chat_model = FakeListChatModel(responses=["fake response"])
search_agent = SearchAgent(model=chat_model, streaming=True)
self.assertIsNotNone(search_agent)

@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
@patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver())
def test_search_agent_invoke_simple(self, mock_create_saver):
expected_response = "This is a mocked LLM response."
set_model_override(FakeListChatModel(responses=[expected_response]))

search_agent = SearchAgent(model=chat_model("test"), streaming=True)
chat_model = FakeListChatModel(responses=[expected_response])
search_agent = SearchAgent(model=chat_model, streaming=True)
result = search_agent.invoke(question="What is the capital of France?", ref="test_ref")

self.assertIn("messages", result)
self.assertGreater(len(result["messages"]), 0)
self.assertEqual(result["messages"][-1].content, expected_response)

@patch('agent.search_agent.create_checkpoint_saver')
@patch('agent.search_agent.checkpoint_saver')
def test_search_agent_invocation(self, mock_create_saver):
# Create a memory saver instance with a Mock for delete_checkpoints
memory_saver = MemorySaver()
Expand All @@ -39,8 +37,8 @@ def test_search_agent_invocation(self, mock_create_saver):
mock_create_saver.return_value = memory_saver

# Test that the SearchAgent invokes the model with the correct messages
set_model_override(FakeListChatModel(responses=["first response", "second response"]))
search_agent = SearchAgent(model=chat_model("test"), streaming=True)
chat_model = FakeListChatModel(responses=["first response", "second response"])
search_agent = SearchAgent(model=chat_model, streaming=True)

# First invocation with some question
result_1 = search_agent.invoke(question="First question?", ref="test_ref")
Expand All @@ -55,17 +53,18 @@ def test_search_agent_invocation(self, mock_create_saver):
memory_saver.delete_checkpoints.assert_not_called()


@patch('agent.search_agent.create_checkpoint_saver')
@patch('agent.search_agent.checkpoint_saver')
def test_search_agent_invoke_forget(self, mock_create_saver):
# Create a memory saver instance with a Mock for delete_checkpoints
memory_saver = MemorySaver()
from unittest.mock import Mock
memory_saver.delete_checkpoints = Mock()
mock_create_saver.return_value = memory_saver


# Test `forget=True` to ensure that state is reset and doesn't carry over between invocations
set_model_override(FakeListChatModel(responses=["first response", "second response"]))
search_agent = SearchAgent(model=chat_model("test"), streaming=True)
chat_model = FakeListChatModel(responses=["first response", "second response"])
search_agent = SearchAgent(model=chat_model, streaming=True)

# First invocation with some question
result_1 = search_agent.invoke(question="First question?", ref="test_ref")
Expand All @@ -77,8 +76,8 @@ def test_search_agent_invoke_forget(self, mock_create_saver):
self.assertEqual(result_2["messages"][-1].content, "second response")

# Now invoke with forget=True, resetting the state
set_model_override(FakeListChatModel(responses=["fresh response"]))
search_agent = SearchAgent(model=chat_model("test"), streaming=True)
new_chat_model = FakeListChatModel(responses=["fresh response"])
search_agent = SearchAgent(model=new_chat_model, streaming=True)

# Forget the state for "test_ref"
result_3 = search_agent.invoke(question="Third question after forgetting?", ref="test_ref", forget=True)
Expand Down
8 changes: 4 additions & 4 deletions chat/test/handlers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ def __init__(self):
class TestHandler(TestCase):

@patch.object(ApiToken, 'is_logged_in', return_value=False)
@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
@patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver())
def test_handler_unauthorized(self, mock_create_saver, mock_is_logged_in):
event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
self.assertEqual(handler(event, MockContext()), {'statusCode': 401, 'body': 'Unauthorized'})

@patch.object(ApiToken, 'is_logged_in', return_value=True)
@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
@patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver())
def test_handler_success(self, mock_create_saver, mock_is_logged_in):
set_model_override(FakeListChatModel(responses=["fake response"]))
event = {
Expand All @@ -46,7 +46,7 @@ def test_handler_success(self, mock_create_saver, mock_is_logged_in):
self.assertEqual(handler(event, MockContext()), {'statusCode': 200})

@patch.object(ApiToken, 'is_logged_in', return_value=True)
@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
@patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver())
def test_handler_question_missing(self, mock_create_saver, mock_is_logged_in):
mock_client = MockClient()
mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
Expand All @@ -57,7 +57,7 @@ def test_handler_question_missing(self, mock_create_saver, mock_is_logged_in):
self.assertEqual(response["message"], "Question cannot be blank")

@patch.object(ApiToken, 'is_logged_in', return_value=True)
@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
@patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver())
def test_handler_question_typo(self, mock_create_saver, mock_is_logged_in):
mock_client = MockClient()
mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
Expand Down

0 comments on commit 63c828c

Please sign in to comment.