diff --git a/chat/src/agent/search_agent.py b/chat/src/agent/search_agent.py index 6ee3996..32b85ba 100644 --- a/chat/src/agent/search_agent.py +++ b/chat/src/agent/search_agent.py @@ -4,9 +4,9 @@ from agent.s3_saver import S3Saver, delete_checkpoints from agent.tools import aggregate, discover_fields, search -from langchain_aws import ChatBedrock from langchain_core.messages import HumanMessage from langchain_core.messages.base import BaseMessage +from langchain_core.language_models.chat_models import BaseModel from langchain_core.callbacks import BaseCallbackHandler from langchain_core.messages.system import SystemMessage from langgraph.graph import END, START, StateGraph, MessagesState @@ -21,6 +21,7 @@ class SearchAgent: def __init__( self, + model: BaseModel, *, checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME"), system_message: str = DEFAULT_SYSTEM_MESSAGE, @@ -30,7 +31,12 @@ def __init__( tools = [discover_fields, search, aggregate] tool_node = ToolNode(tools) - model = ChatBedrock(**kwargs).bind_tools(tools) + + try: + model = model.bind_tools(tools) + except NotImplementedError: + print("Model does not support tool binding") + pass # Define the function that determines whether to continue or not def should_continue(state: MessagesState) -> Literal["tools", END]: diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py index 07aea7c..9fb5b24 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers/chat.py @@ -9,6 +9,7 @@ from agent.search_agent import SearchAgent from agent.agent_handler import AgentHandler from agent.metrics_handler import MetricsHandler +from handlers.model import chat_model # honeybadger.configure() # logging.getLogger("honeybadger").addHandler(logging.StreamHandler()) @@ -28,7 +29,7 @@ def handler(event, context): metrics = MetricsHandler() callbacks = [AgentHandler(config.socket, config.ref), metrics] - search_agent = SearchAgent(model=config.model, streaming=True) + search_agent = SearchAgent(model=chat_model(event), streaming=True) try: search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=callbacks) log_metrics(context, metrics, config) diff --git a/chat/src/handlers/model.py b/chat/src/handlers/model.py new file mode 100644 index 0000000..8d94ed5 --- /dev/null +++ b/chat/src/handlers/model.py @@ -0,0 +1,12 @@ +from event_config import EventConfig +from langchain_aws import ChatBedrock +from langchain_core.language_models.base import BaseModel + +MODEL_OVERRIDE: BaseModel = None + +def chat_model(event: EventConfig): + return MODEL_OVERRIDE or ChatBedrock(model=event.model) + +def set_model_override(model: BaseModel): + global MODEL_OVERRIDE + MODEL_OVERRIDE = model \ No newline at end of file diff --git a/chat/test/handlers/test_chat.py b/chat/test/handlers/test_chat.py index d714941..c0e30ea 100644 --- a/chat/test/handlers/test_chat.py +++ b/chat/test/handlers/test_chat.py @@ -9,10 +9,11 @@ from unittest import mock, TestCase from unittest.mock import patch from handlers.chat import handler -from agent.search_agent import SearchAgent from helpers.apitoken import ApiToken from websocket import Websocket -from event_config import EventConfig + +from langchain_core.language_models.fake_chat_models import FakeListChatModel +from handlers.model import set_model_override class MockClient: def __init__(self): @@ -56,14 +57,16 @@ def mock_response(**kwargs): }, ) + class TestHandler(TestCase): - def test_handler_unauthorized(self): + def test_handler_unauthorized(self): event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")} self.assertEqual(handler(event, MockContext()), {'body': 'Unauthorized', 'statusCode': 401}) @patch.object(ApiToken, 'is_logged_in') def test_handler_success(self, mock_is_logged_in): mock_is_logged_in.return_value = True + set_model_override(FakeListChatModel(responses=["one", "two", "three"])) event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test"), "body": '{"question": "Question?"}' } self.assertEqual(handler(event, MockContext()), {'statusCode': 200}) diff --git a/chat/test/test_event_config.py b/chat/test/test_event_config.py index ad52e21..401c841 100644 --- a/chat/test/test_event_config.py +++ b/chat/test/test_event_config.py @@ -1,11 +1,10 @@ # ruff: noqa: E402 import json -import os import sys sys.path.append('./src') from event_config import EventConfig -from unittest import TestCase, mock +from unittest import TestCase class TestEventConfig(TestCase): def test_defaults(self):