Skip to content

Commit

Permalink
Add MODEL_OVERRIDE global
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Dec 16, 2024
1 parent 11e64e4 commit 7e4cae7
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 8 deletions.
10 changes: 8 additions & 2 deletions chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from agent.s3_saver import S3Saver, delete_checkpoints
from agent.tools import aggregate, discover_fields, search
from langchain_aws import ChatBedrock
from langchain_core.messages import HumanMessage
from langchain_core.messages.base import BaseMessage
from langchain_core.language_models.chat_models import BaseModel
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages.system import SystemMessage
from langgraph.graph import END, START, StateGraph, MessagesState
Expand All @@ -21,6 +21,7 @@
class SearchAgent:
def __init__(
self,
model: BaseModel,
*,
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME"),
system_message: str = DEFAULT_SYSTEM_MESSAGE,
Expand All @@ -30,7 +31,12 @@ def __init__(

tools = [discover_fields, search, aggregate]
tool_node = ToolNode(tools)
model = ChatBedrock(**kwargs).bind_tools(tools)

try:
model = model.bind_tools(tools)
except NotImplementedError:
print("Model does not support tool binding")
pass

# Define the function that determines whether to continue or not
def should_continue(state: MessagesState) -> Literal["tools", END]:
Expand Down
3 changes: 2 additions & 1 deletion chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from agent.search_agent import SearchAgent
from agent.agent_handler import AgentHandler
from agent.metrics_handler import MetricsHandler
from handlers.model import chat_model

# honeybadger.configure()
# logging.getLogger("honeybadger").addHandler(logging.StreamHandler())
Expand All @@ -28,7 +29,7 @@ def handler(event, context):

metrics = MetricsHandler()
callbacks = [AgentHandler(config.socket, config.ref), metrics]
search_agent = SearchAgent(model=config.model, streaming=True)
search_agent = SearchAgent(model=chat_model(event), streaming=True)
try:
search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=callbacks)
log_metrics(context, metrics, config)
Expand Down
12 changes: 12 additions & 0 deletions chat/src/handlers/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from event_config import EventConfig
from langchain_aws import ChatBedrock
from langchain_core.language_models.base import BaseModel

MODEL_OVERRIDE: BaseModel = None

def chat_model(event: EventConfig):
return MODEL_OVERRIDE or ChatBedrock(model=event.model)

def set_model_override(model: BaseModel):
global MODEL_OVERRIDE
MODEL_OVERRIDE = model
9 changes: 6 additions & 3 deletions chat/test/handlers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
from unittest import mock, TestCase
from unittest.mock import patch
from handlers.chat import handler
from agent.search_agent import SearchAgent
from helpers.apitoken import ApiToken
from websocket import Websocket
from event_config import EventConfig

from langchain_core.language_models.fake_chat_models import FakeListChatModel
from handlers.model import set_model_override

class MockClient:
def __init__(self):
Expand Down Expand Up @@ -56,14 +57,16 @@ def mock_response(**kwargs):
},
)


class TestHandler(TestCase):
def test_handler_unauthorized(self):
def test_handler_unauthorized(self):
event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
self.assertEqual(handler(event, MockContext()), {'body': 'Unauthorized', 'statusCode': 401})

@patch.object(ApiToken, 'is_logged_in')
def test_handler_success(self, mock_is_logged_in):
mock_is_logged_in.return_value = True
set_model_override(FakeListChatModel(responses=["one", "two", "three"]))
event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test"), "body": '{"question": "Question?"}' }
self.assertEqual(handler(event, MockContext()), {'statusCode': 200})

Expand Down
3 changes: 1 addition & 2 deletions chat/test/test_event_config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# ruff: noqa: E402
import json
import os
import sys
sys.path.append('./src')

from event_config import EventConfig
from unittest import TestCase, mock
from unittest import TestCase

class TestEventConfig(TestCase):
def test_defaults(self):
Expand Down

0 comments on commit 7e4cae7

Please sign in to comment.