Add metrics-gathering callback handler
Make search agent a class
Wire model, stream_response input options back up
Add jupyter notebook for rapid iteration

Remove errant #* comment
mbklein committed Dec 19, 2024
1 parent 4f1d4ed commit da4505b
Showing 12 changed files with 481 additions and 264 deletions.
1 change: 0 additions & 1 deletion .github/workflows/test-python.yml
@@ -14,7 +14,6 @@ jobs:
     env:
       AWS_ACCESS_KEY_ID: ci
       AWS_SECRET_ACCESS_KEY: ci
-      SKIP_LLM_REQUEST: 'True'
     steps:
       - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
6 changes: 3 additions & 3 deletions Makefile
@@ -52,15 +52,15 @@ test-node: deps-node
 deps-python:
 	cd chat/src && pip install -r requirements.txt && pip install -r requirements-dev.txt
 cover-python: deps-python
-	cd chat && export SKIP_LLM_REQUEST=True && coverage run --source=src -m unittest -v && coverage report --skip-empty
+	cd chat && coverage run --source=src -m unittest -v && coverage report --skip-empty
 cover-html-python: deps-python
-	cd chat && export SKIP_LLM_REQUEST=True && coverage run --source=src -m unittest -v && coverage html --skip-empty
+	cd chat && coverage run --source=src -m unittest -v && coverage html --skip-empty
 style-python: deps-python
 	cd chat && ruff check .
 style-python-fix: deps-python
 	cd chat && ruff check --fix .
 test-python: deps-python
-	cd chat && __SKIP_SECRETS__=true SKIP_LLM_REQUEST=True PYTHONPATH=src:test python -m unittest discover -v
+	cd chat && __SKIP_SECRETS__=true PYTHONPATH=src:test python -m unittest discover -v
 python-version:
 	cd chat && python --version
 build: .aws-sam/build.toml
344 changes: 344 additions & 0 deletions chat-playground/playground.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion chat/src/agent/agent_handler.py
@@ -28,7 +28,7 @@ def __init__(self, socket: Websocket, ref: str, *args: List[Any], **kwargs: Dict
         super().__init__(*args, **kwargs)
 
     def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], metadata: Optional[dict[str, Any]] = None, **kwargs: Dict[str, Any]):
-        self.socket.send({"type": "start", "ref": self.ref, "message": {"model": metadata.get("model_deployment")}})
+        self.socket.send({"type": "start", "ref": self.ref, "message": {"model": metadata.get("ls_model_name")}})
 
     def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
         response_generation = response.generations[0][0]
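The swap from `model_deployment` to `ls_model_name` tracks LangChain's standard `ls_*` tracing metadata, which chat-model integrations populate automatically. A hedged sketch of what `on_llm_start` can now expect — the exact key set varies by integration, and every value below is invented for illustration:

# Illustrative callback metadata -- keys follow LangChain's standard ls_*
# tracing parameters; values here are assumptions, not from this commit.
metadata = {
    "ls_provider": "amazon_bedrock",
    "ls_model_name": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
    "ls_model_type": "chat",
}

metadata.get("model_deployment")  # None -- the old key is not in the standard set
metadata.get("ls_model_name")     # the actual model id, as sent to the socket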
32 changes: 32 additions & 0 deletions chat/src/agent/metrics_handler.py
@@ -0,0 +1,32 @@
+from typing import Any, Dict
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+from langchain_core.messages.tool import ToolMessage
+import json
+
+class MetricsHandler(BaseCallbackHandler):
+    def __init__(self, *args, **kwargs):
+        self.accumulator = {}
+        self.answers = []
+        self.artifacts = []
+        super().__init__(*args, **kwargs)
+
+    def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
+        for generation in response.generations[0]:
+            self.answers.append(generation.text)
+            for k, v in generation.message.usage_metadata.items():
+                if k not in self.accumulator:
+                    self.accumulator[k] = v
+                else:
+                    self.accumulator[k] += v
+
+    def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
+        match output.name:
+            case "aggregate":
+                self.artifacts.append({"type": "aggregation", "artifact": output.artifact.get("aggregation_result", {})})
+            case "search":
+                try:
+                    source_urls = [doc.metadata["api_link"] for doc in output.artifact]
+                    self.artifacts.append({"type": "source_urls", "artifact": source_urls})
+                except json.decoder.JSONDecodeError as e:
+                    print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
3 changes: 2 additions & 1 deletion chat/src/agent/s3_saver.py
@@ -3,6 +3,7 @@
 import base64
 import bz2
 import gzip
+import os
 import time
 from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, List
 from langchain_core.runnables import RunnableConfig
@@ -118,7 +119,7 @@ class S3Saver(BaseCheckpointSaver):
     def __init__(
         self,
         bucket_name: str,
-        region_name: str = 'us-west-2',
+        region_name: str = os.getenv("AWS_REGION"),
         endpoint_url: Optional[str] = None,
         compression: Optional[str] = None,
     ) -> None:
125 changes: 67 additions & 58 deletions chat/src/agent/search_agent.py
@@ -1,72 +1,81 @@
 import os
 
-from typing import Literal
+from typing import Literal, List
 
-from agent.s3_saver import S3Saver
+from agent.s3_saver import S3Saver, delete_checkpoints
 from agent.tools import aggregate, discover_fields, search
+from langchain_aws import ChatBedrock
+from langchain_core.messages import HumanMessage
 from langchain_core.messages.base import BaseMessage
+from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages.system import SystemMessage
 from langgraph.graph import END, START, StateGraph, MessagesState
 from langgraph.prebuilt import ToolNode
-from setup import chat_client
 
-# Keep your answer concise and keep reading time under 45 seconds.
-
-system_message = """
+DEFAULT_SYSTEM_MESSAGE = """
 Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
 support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown
 links using the document's canonical_link field. Do not include intermediate messages explaining your process.
 """
 
-tools = [discover_fields, search, aggregate]
-
-tool_node = ToolNode(tools)
-
-model = chat_client(streaming=True).bind_tools(tools)
-
-# Define the function that determines whether to continue or not
-def should_continue(state: MessagesState) -> Literal["tools", END]:
-    messages = state["messages"]
-    last_message = messages[-1]
-    # If the LLM makes a tool call, then we route to the "tools" node
-    if last_message.tool_calls:
-        return "tools"
-    # Otherwise, we stop (reply to the user)
-    return END
-
-
-# Define the function that calls the model
-def call_model(state: MessagesState):
-    messages = [SystemMessage(content=system_message)] + state["messages"]
-    response: BaseMessage = model.invoke(messages)  # , model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
-    # We return a list, because this will get added to the existing list
-    # if socket is not none and the response content is not an empty string
-    return {"messages": [response]}
-
-# Define a new graph
-workflow = StateGraph(MessagesState)
-
-# Define the two nodes we will cycle between
-workflow.add_node("agent", call_model)
-workflow.add_node("tools", tool_node)
-
-# Set the entrypoint as `agent`
-workflow.add_edge(START, "agent")
-
-# Add a conditional edge
-workflow.add_conditional_edges(
-    "agent",
-    should_continue,
-)
-
-# Add a normal edge from `tools` to `agent`
-workflow.add_edge("tools", "agent")
-
-# checkpointer = DynamoDBSaver(os.getenv("CHECKPOINT_TABLE"), os.getenv("CHECKPOINT_WRITES_TABLE"), os.getenv("AWS_REGION", "us-east-1"))
-
-checkpointer = S3Saver(
-    bucket_name=os.getenv("CHECKPOINT_BUCKET_NAME"),
-    region_name=os.getenv("AWS_REGION"),
-    compression="gzip")
-
-search_agent = workflow.compile(checkpointer=checkpointer)
+class SearchAgent:
+    def __init__(
+        self,
+        *,
+        checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME"),
+        system_message: str = DEFAULT_SYSTEM_MESSAGE,
+        **kwargs):
+
+        self.checkpoint_bucket = checkpoint_bucket
+
+        tools = [discover_fields, search, aggregate]
+        tool_node = ToolNode(tools)
+        model = ChatBedrock(**kwargs).bind_tools(tools)
+
+        # Define the function that determines whether to continue or not
+        def should_continue(state: MessagesState) -> Literal["tools", END]:
+            messages = state["messages"]
+            last_message = messages[-1]
+            # If the LLM makes a tool call, then we route to the "tools" node
+            if last_message.tool_calls:
+                return "tools"
+            # Otherwise, we stop (reply to the user)
+            return END
+
+        # Define the function that calls the model
+        def call_model(state: MessagesState):
+            messages = [SystemMessage(content=system_message)] + state["messages"]
+            response: BaseMessage = model.invoke(messages)  # , model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
+            # We return a list, because this will get added to the existing list
+            # if socket is not none and the response content is not an empty string
+            return {"messages": [response]}
+
+        # Define a new graph
+        workflow = StateGraph(MessagesState)
+
+        # Define the two nodes we will cycle between
+        workflow.add_node("agent", call_model)
+        workflow.add_node("tools", tool_node)
+
+        # Set the entrypoint as `agent`
+        workflow.add_edge(START, "agent")
+
+        # Add a conditional edge
+        workflow.add_conditional_edges("agent", should_continue)
+
+        # Add a normal edge from `tools` to `agent`
+        workflow.add_edge("tools", "agent")
+
+        checkpointer = S3Saver(bucket_name=checkpoint_bucket, compression="gzip")
+        self.search_agent = workflow.compile(checkpointer=checkpointer)
+
+    def invoke(self, question: str, ref: str, *, callbacks: List[BaseCallbackHandler] = [], forget: bool = False, **kwargs):
+        if forget:
+            delete_checkpoints(self.checkpoint_bucket, ref)
+
+        return self.search_agent.invoke(
+            {"messages": [HumanMessage(content=question)]},
+            config={"configurable": {"thread_id": ref}, "callbacks": callbacks},
+            **kwargs
+        )
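A hedged usage sketch of the new class. Constructor keyword arguments pass straight through to `ChatBedrock`; the model id below matches the default added in `event_config.py` further down, and the question and ref are invented for the example.

# Hypothetical usage sketch -- question and ref values are illustrative only.
from agent.search_agent import SearchAgent

agent = SearchAgent(
    model="us.anthropic.claude-3-5-sonnet-20241022-v2:0",  # forwarded to ChatBedrock
    streaming=True,
)

# ref doubles as the LangGraph thread_id, so reusing it continues a
# conversation; forget=True deletes that thread's S3 checkpoints first.
result = agent.invoke("What works mention the Chicago World's Fair?", ref="demo-thread-1")
print(result["messages"][-1].content)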
32 changes: 2 additions & 30 deletions chat/src/event_config.py
@@ -27,36 +27,6 @@ class EventConfig:
     Default values are set for the following properties which can be overridden in the payload message.
     """
 
-    DEFAULT_ATTRIBUTES = [
-        "accession_number",
-        "alternate_title",
-        "api_link",
-        "canonical_link",
-        "caption",
-        "collection",
-        "contributor",
-        "date_created",
-        "date_created_edtf",
-        "description",
-        "genre",
-        "id",
-        "identifier",
-        "keywords",
-        "language",
-        "notes",
-        "physical_description_material",
-        "physical_description_size",
-        "provenance",
-        "publisher",
-        "rights_statement",
-        "subject",
-        "table_of_contents",
-        "thumbnail",
-        "title",
-        "visibility",
-        "work_type",
-    ]
-
     api_token: ApiToken = field(init=False)
     debug_mode: bool = field(init=False)
     event: dict = field(default_factory=dict)
@@ -66,6 +36,7 @@ class EventConfig:
     is_superuser: bool = field(init=False)
     k: int = field(init=False)
     max_tokens: int = field(init=False)
+    model: str = field(init=False)
     payload: dict = field(default_factory=dict)
     prompt_text: str = field(init=False)
     prompt: ChatPromptTemplate = field(init=False)
@@ -88,6 +59,7 @@ def __post_init__(self):
         self.is_superuser = self.api_token.is_superuser()
         self.k = self._get_k()
         self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS)
+        self.model = self._get_payload_value_with_superuser_check("model", "us.anthropic.claude-3-5-sonnet-20241022-v2:0")
        self.prompt_text = self._get_prompt_text()
         self.request_context = self.event.get("requestContext", {})
         self.question = self.payload.get("question")
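A sketch of the payload option this enables. Judging by its name, the existing `_get_payload_value_with_superuser_check` helper (not shown in this diff) honors the override only for superuser tokens; all field values below are invented for the example.

# Illustrative payload -- field names besides "model" follow the existing
# EventConfig payload conventions; values are assumptions, not from the commit.
payload = {
    "question": "What works mention the Chicago World's Fair?",
    "max_tokens": 1024,
    # Presumably honored only for superuser tokens; otherwise the default
    # "us.anthropic.claude-3-5-sonnet-20241022-v2:0" is used.
    "model": "us.anthropic.claude-3-5-haiku-20241022-v1:0",
}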
88 changes: 30 additions & 58 deletions chat/src/handlers/chat.py
@@ -6,47 +6,13 @@
 from datetime import datetime
 from event_config import EventConfig
 # from honeybadger import honeybadger
-from agent.s3_saver import delete_checkpoints
-from agent.search_agent import search_agent
-from langchain_core.messages import HumanMessage
+from agent.search_agent import SearchAgent
 from agent.agent_handler import AgentHandler
+from agent.metrics_handler import MetricsHandler
 
 # honeybadger.configure()
 # logging.getLogger("honeybadger").addHandler(logging.StreamHandler())
 
-RESPONSE_TYPES = {
-    "base": ["answer", "ref"],
-    "debug": [
-        "answer",
-        "attributes",
-        "deployment_name",
-        "is_superuser",
-        "k",
-        "prompt",
-        "question",
-        "ref",
-        "temperature",
-        "text_key",
-        "token_counts",
-    ],
-    "log": [
-        "answer",
-        "deployment_name",
-        "is_superuser",
-        "k",
-        "prompt",
-        "question",
-        "ref",
-        "size",
-        "source_documents",
-        "temperature",
-        "token_counts",
-        "is_dev_team",
-    ],
-    "error": ["question", "error", "source_documents"],
-}
-
-
 def handler(event, context):
     config = EventConfig(event)
     socket = event.get("socket", None)
@@ -56,29 +22,16 @@ def handler(event, context):
         config.socket.send({"type": "error", "message": "Unauthorized"})
         return {"statusCode": 401, "body": "Unauthorized"}
 
-    if config.forget:
-        delete_checkpoints(os.getenv("CHECKPOINT_BUCKET_NAME"), config.ref)
-
     if config.question is None or config.question == "":
         config.socket.send({"type": "error", "message": "Question cannot be blank"})
         return {"statusCode": 400, "body": "Question cannot be blank"}
 
-    log_group = os.getenv("METRICS_LOG_GROUP")
-    log_stream = context.log_stream_name
-    if log_group and ensure_log_stream_exists(log_group, log_stream):
-        log_client = boto3.client("logs")
-        log_events = [{"timestamp": timestamp(), "message": "Hello world"}]
-        log_client.put_log_events(
-            logGroupName=log_group, logStreamName=log_stream, logEvents=log_events
-        )
-
-    callbacks = [AgentHandler(config.socket, config.ref)]
+    metrics = MetricsHandler()
+    callbacks = [AgentHandler(config.socket, config.ref), metrics]
+    search_agent = SearchAgent(model=config.model, streaming=True)
     try:
-        search_agent.invoke(
-            {"messages": [HumanMessage(content=config.question)]},
-            config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks},
-            debug=False
-        )
+        search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=callbacks)
+        log_metrics(context, metrics, config)
     except Exception as e:
         print(f"Error: {e}")
         print(traceback.format_exc())
@@ -90,10 +43,29 @@
     return {"statusCode": 200}
 
 
-def reshape_response(response, type):
-    return {k: response[k] for k in RESPONSE_TYPES[type]}
-
-
+def log_metrics(context, metrics, config):
+    log_group = os.getenv("METRICS_LOG_GROUP")
+    log_stream = context.log_stream_name
+    if log_group and ensure_log_stream_exists(log_group, log_stream):
+        log_client = boto3.client("logs")
+        log_events = [{
+            "timestamp": timestamp(),
+            "message": json.dumps({
+                "answer": metrics.answers,
+                "is_dev_team": config.api_token.is_dev_team(),
+                "is_superuser": config.api_token.is_superuser(),
+                "k": config.k,
+                "model": config.model,
+                "question": config.question,
+                "ref": config.ref,
+                "artifacts": metrics.artifacts,
+                "token_counts": metrics.accumulator,
+            })
+        }]
+        log_client.put_log_events(
+            logGroupName=log_group, logStreamName=log_stream, logEvents=log_events
+        )
 
 def ensure_log_stream_exists(log_group, log_stream):
     log_client = boto3.client("logs")
     try:
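For reference, a sketch of the JSON document each metrics log event now carries. The key set mirrors `log_metrics` above; every value below is invented for illustration.

# Illustrative CloudWatch log message -- only the key names come from
# log_metrics above; all values are made up.
{
    "answer": ["Two works reference the 1893 exposition: ..."],
    "is_dev_team": False,
    "is_superuser": False,
    "k": 10,
    "model": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
    "question": "What works mention the Chicago World's Fair?",
    "ref": "demo-thread-1",
    "artifacts": [{"type": "source_urls", "artifact": ["https://example.org/api/works/1234"]}],
    "token_counts": {"input_tokens": 1024, "output_tokens": 256, "total_tokens": 1280},
}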