Add metrics-gathering callback handler
* Make search agent a class
* Wire model, stream_response input options back up
* Add jupyter notebook for rapid iteration
* Remove errant #* comment
Showing 12 changed files with 481 additions and 264 deletions.
Large diffs are not rendered by default.
New file (the metrics callback handler): @@ -0,0 +1,32 @@

```python
from typing import Any, Dict
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.messages.tool import ToolMessage
import json

class MetricsHandler(BaseCallbackHandler):
    def __init__(self, *args, **kwargs):
        self.accumulator = {}
        self.answers = []
        self.artifacts = []
        super().__init__(*args, **kwargs)

    def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
        for generation in response.generations[0]:
            self.answers.append(generation.text)
            for k, v in generation.message.usage_metadata.items():
                if k not in self.accumulator:
                    self.accumulator[k] = v
                else:
                    self.accumulator[k] += v

    def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
        match output.name:
            case "aggregate":
                self.artifacts.append({"type": "aggregation", "artifact": output.artifact.get("aggregation_result", {})})
            case "search":
                try:
                    source_urls = [doc.metadata["api_link"] for doc in output.artifact]
                    self.artifacts.append({"type": "source_urls", "artifact": source_urls})
                except json.decoder.JSONDecodeError as e:
                    print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
```
The agent module is rewritten around a `SearchAgent` class:
```diff
@@ -1,72 +1,81 @@
 import os
 
-from typing import Literal
+from typing import Literal, List
 
-from agent.s3_saver import S3Saver
+from agent.s3_saver import S3Saver, delete_checkpoints
 from agent.tools import aggregate, discover_fields, search
+from langchain_aws import ChatBedrock
+from langchain_core.messages import HumanMessage
 from langchain_core.messages.base import BaseMessage
+from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages.system import SystemMessage
 from langgraph.graph import END, START, StateGraph, MessagesState
 from langgraph.prebuilt import ToolNode
-from setup import chat_client
 
-# Keep your answer concise and keep reading time under 45 seconds.
-
-system_message = """
+DEFAULT_SYSTEM_MESSAGE = """
 Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
 support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown
 links using the document's canonical_link field. Do not include intermediate messages explaining your process.
 """
 
-tools = [discover_fields, search, aggregate]
-
-tool_node = ToolNode(tools)
-
-model = chat_client(streaming=True).bind_tools(tools)
-
-# Define the function that determines whether to continue or not
-def should_continue(state: MessagesState) -> Literal["tools", END]:
-    messages = state["messages"]
-    last_message = messages[-1]
-    # If the LLM makes a tool call, then we route to the "tools" node
-    if last_message.tool_calls:
-        return "tools"
-    # Otherwise, we stop (reply to the user)
-    return END
-
-
-# Define the function that calls the model
-def call_model(state: MessagesState):
-    messages = [SystemMessage(content=system_message)] + state["messages"]
-    response: BaseMessage = model.invoke(messages)  # , model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
-    # We return a list, because this will get added to the existing list
-    # if socket is not none and the response content is not an empty string
-    return {"messages": [response]}
-
-# Define a new graph
-workflow = StateGraph(MessagesState)
-
-# Define the two nodes we will cycle between
-workflow.add_node("agent", call_model)
-workflow.add_node("tools", tool_node)
-
-# Set the entrypoint as `agent`
-workflow.add_edge(START, "agent")
-
-# Add a conditional edge
-workflow.add_conditional_edges(
-    "agent",
-    should_continue,
-)
-
-# Add a normal edge from `tools` to `agent`
-workflow.add_edge("tools", "agent")
-
-# checkpointer = DynamoDBSaver(os.getenv("CHECKPOINT_TABLE"), os.getenv("CHECKPOINT_WRITES_TABLE"), os.getenv("AWS_REGION", "us-east-1"))
-
-checkpointer = S3Saver(
-    bucket_name=os.getenv("CHECKPOINT_BUCKET_NAME"),
-    region_name=os.getenv("AWS_REGION"),
-    compression="gzip")
-
-search_agent = workflow.compile(checkpointer=checkpointer)
+class SearchAgent:
+    def __init__(
+        self,
+        *,
+        checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME"),
+        system_message: str = DEFAULT_SYSTEM_MESSAGE,
+        **kwargs):
+
+        self.checkpoint_bucket = checkpoint_bucket
+
+        tools = [discover_fields, search, aggregate]
+        tool_node = ToolNode(tools)
+        model = ChatBedrock(**kwargs).bind_tools(tools)
+
+        # Define the function that determines whether to continue or not
+        def should_continue(state: MessagesState) -> Literal["tools", END]:
+            messages = state["messages"]
+            last_message = messages[-1]
+            # If the LLM makes a tool call, then we route to the "tools" node
+            if last_message.tool_calls:
+                return "tools"
+            # Otherwise, we stop (reply to the user)
+            return END
+
+
+        # Define the function that calls the model
+        def call_model(state: MessagesState):
+            messages = [SystemMessage(content=system_message)] + state["messages"]
+            response: BaseMessage = model.invoke(messages)  # , model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
+            # We return a list, because this will get added to the existing list
+            # if socket is not none and the response content is not an empty string
+            return {"messages": [response]}
+
+        # Define a new graph
+        workflow = StateGraph(MessagesState)
+
+        # Define the two nodes we will cycle between
+        workflow.add_node("agent", call_model)
+        workflow.add_node("tools", tool_node)
+
+        # Set the entrypoint as `agent`
+        workflow.add_edge(START, "agent")
+
+        # Add a conditional edge
+        workflow.add_conditional_edges("agent", should_continue)
+
+        # Add a normal edge from `tools` to `agent`
+        workflow.add_edge("tools", "agent")
+
+        checkpointer = S3Saver(bucket_name=checkpoint_bucket, compression="gzip")
+        self.search_agent = workflow.compile(checkpointer=checkpointer)
+
+    def invoke(self, question: str, ref: str, *, callbacks: List[BaseCallbackHandler] = [], forget: bool = False, **kwargs):
+        if forget:
+            delete_checkpoints(self.checkpoint_bucket, ref)
+
+        return self.search_agent.invoke(
+            {"messages": [HumanMessage(content=question)]},
+            config={"configurable": {"thread_id": ref}, "callbacks": callbacks},
+            **kwargs
+        )
```
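With the handler and the class together, metrics become a per-call concern: `invoke` passes `callbacks` straight into the run config, and each `ref` maps to a checkpoint thread, so calls that reuse a ref continue the conversation unless `forget=True` deletes it first. A rough usage sketch, assuming `CHECKPOINT_BUCKET_NAME` is set and that the `ChatBedrock` kwargs (the model id shown is illustrative) are valid for the AWS account:

```python
# Hypothetical usage of the class-based agent with the metrics callback.
agent = SearchAgent(model_id="anthropic.claude-3-sonnet-20240229-v1:0", streaming=True)

metrics = MetricsHandler()
response = agent.invoke(
    "Which collections mention quilts?",
    ref="conversation-1234",  # checkpoint thread id; reuse it to continue the chat
    callbacks=[metrics],
    forget=True,              # start fresh: delete prior checkpoints for this ref
)

print(response["messages"][-1].content)  # final answer from the agent loop
print(metrics.accumulator)               # token usage accumulated across LLM calls
```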