Skip to content

Commit

Permalink
Improve prompting
Browse files Browse the repository at this point in the history
Compress checkpoint data
Add more fields back to hybrid query
  • Loading branch information
mbklein committed Dec 11, 2024
1 parent 185691b commit 892e4d2
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 17 deletions.
5 changes: 4 additions & 1 deletion chat/src/agent/agent_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from websocket import Websocket

Expand Down Expand Up @@ -27,6 +27,9 @@ def __init__(self, socket: Websocket, ref: str, *args: List[Any], **kwargs: Dict
self.ref = ref
super().__init__(*args, **kwargs)

def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], metadata: Optional[dict[str, Any]] = None, **kwargs: Dict[str, Any]):
    """Notify the websocket client that an LLM invocation has begun.

    Sends a ``{"type": "start", ...}`` frame carrying the model deployment
    name so the UI can display which model is answering.

    Args:
        serialized: Serialized representation of the LLM (unused here).
        prompts: The prompts sent to the model (unused here).
        metadata: Optional run metadata; ``model_deployment`` is read from it
            when present.
        **kwargs: Additional callback arguments (ignored).
    """
    # metadata defaults to None, so guard the lookup — otherwise this
    # callback raises AttributeError whenever no metadata is supplied.
    model_name = (metadata or {}).get("model_deployment")
    self.socket.send({"type": "start", "ref": self.ref, "message": {"model": model_name}})

def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
content = response.generations[0][0].text
if content != "":
Expand Down
3 changes: 0 additions & 3 deletions chat/src/agent/dynamodb_cleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def delete_thread(table_name, thread_id, region_name=os.getenv("AWS_REGION")):
items.extend(response.get('Items', []))

if not items:
print(f"No items found with thread_id: {thread_id}")
return

# Prepare delete requests in batches of 25 (DynamoDB limit for BatchWriteItem)
Expand All @@ -56,7 +55,5 @@ def delete_thread(table_name, thread_id, region_name=os.getenv("AWS_REGION")):
}
batch.delete_item(Key=key)

print(f"Successfully deleted {len(items)} items with thread_id: {thread_id}")

except ClientError as e:
print(f"An error occurred: {e.response['Error']['Message']}")
18 changes: 10 additions & 8 deletions chat/src/agent/dynamodb_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import langchain_core.messages as langchain_messages
import time
import json
import bz2
import base64

from langchain_core.runnables import RunnableConfig

Expand All @@ -30,11 +32,13 @@ def default(o):
raise TypeError(f'Object of type {o.__class__.__name__} is not JSON serializable')

json_str = json.dumps(obj, default=default)
return 'json', json_str
compressed_str = base64.b64encode(bz2.compress(json_str.encode("utf-8"))).decode("utf-8")
return 'compressed_json', compressed_str

def loads_typed(self, data: Tuple[str, Any]) -> Any:
type_, json_str = data

type_, compressed_str = data
json_str = bz2.decompress(base64.b64decode(compressed_str)).decode("utf-8")

def object_hook(dct):
if '__type__' in dct:
type_name = dct['__type__']
Expand All @@ -48,7 +52,6 @@ def object_hook(dct):

obj = json.loads(json_str, object_hook=object_hook)
return obj


class DynamoDBSaver(BaseCheckpointSaver):
"""A checkpoint saver that stores checkpoints in DynamoDB using JSON-compatible formats."""
Expand All @@ -74,7 +77,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
thread_id = config["configurable"]["thread_id"]
checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
checkpoint_id = get_checkpoint_id(config)

if checkpoint_id:
# Need to scan for item with matching checkpoint_id
response = self.table.query(
Expand Down Expand Up @@ -139,7 +142,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:

# Retrieve metadata directly
metadata = item['metadata']
metadata = self.serde.loads(metadata)
metadata = self.serde.loads_typed(("compressed_json", metadata))
# Get parent config if any
parent_checkpoint_id = item.get('parent_checkpoint_id')
if parent_checkpoint_id:
Expand Down Expand Up @@ -197,12 +200,11 @@ def list(

# Retrieve the checkpoint data directly
checkpoint_data = item['checkpoint']

checkpoint = self.serde.loads_typed((checkpoint_type, checkpoint_data))

# Retrieve metadata directly
metadata = item['metadata']
metadata = self.serde.loads(metadata)
metadata = self.serde.loads_typed(("compressed_json", metadata))

parent_checkpoint_id = item.get('parent_checkpoint_id')
if parent_checkpoint_id:
Expand Down
12 changes: 10 additions & 2 deletions chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@
from agent.dynamodb_saver import DynamoDBSaver
from agent.tools import aggregate, discover_fields, search
from langchain_core.messages.base import BaseMessage
from langchain_core.messages.system import SystemMessage
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from setup import openai_chat_client

# Keep your answer concise and keep reading time under 45 seconds.

# System prompt prepended to the conversation on every model call (see
# call_model below). NOTE(review): the "45 seconds" comment above reads like
# a leftover draft of this prompt — confirm whether it should be folded in.
system_message = """
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown
links using the document's canonical_link field.
"""

# Tools the agent is allowed to call; imported from agent.tools above.
tools = [discover_fields, search, aggregate]

# LangGraph node that executes whichever tool the model requests.
tool_node = ToolNode(tools)
Expand All @@ -28,13 +37,12 @@ def should_continue(state: MessagesState) -> Literal["tools", END]:

# Define the function that calls the model
def call_model(state: MessagesState):
    """Invoke the chat model on the accumulated conversation state.

    Prepends the module-level system prompt to the message history, invokes
    the model, and returns the response wrapped for LangGraph to append.

    Args:
        state: Graph state whose ``"messages"`` key holds the history.

    Returns:
        dict: ``{"messages": [response]}`` — a single-element list, because
        LangGraph merges it into the existing message list.
    """
    # Removed the redundant bare assignment (messages = state["messages"])
    # that immediately got overwritten — it was a dead store.
    messages = [SystemMessage(content=system_message)] + state["messages"]
    # Deployment id is resolved from the environment at call time so the
    # same graph can target different Azure OpenAI deployments.
    response: BaseMessage = model.invoke(messages, model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))
    return {"messages": [response]}


# Define a new graph
workflow = StateGraph(MessagesState)

Expand Down
4 changes: 3 additions & 1 deletion chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import boto3
import json
import os
import traceback
from datetime import datetime
from event_config import EventConfig
# from honeybadger import honeybadger
Expand Down Expand Up @@ -78,11 +79,12 @@ def handler(event, context):
try:
search_agent.invoke(
{"messages": [HumanMessage(content=config.question)]},
config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks},
config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks, "metadata": {"model_deployment": os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")}},
debug=False
)
except Exception as e:
print(f"Error: {e}")
print(traceback.format_exc())
error_response = {"type": "error", "message": "An unexpected error occurred. Please try again later."}
if config.socket:
config.socket.send(error_response)
Expand Down
9 changes: 7 additions & 2 deletions chat/src/helpers/hybrid_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k:
result = {
"size": kwargs.get("size", 20),
"_source": {
"include": ["title", "description", "collection.title", "id", "collection.id"],
"exclude": ["embedding"]
"include": ["abstract", "accession_number", "alternate_title", "api_link", "ark", "canonical_link",
"collection", "contributor", "creator", "date_created_edtf", "description", "genre",
"id", "iiif_manifest", "language", "library_unit", "license", "location", "physical_description_material",
"physical_description_size", "provenance", "publisher", "rights_holder", "rights_statement",
"scope_and_contents", "series", "style_period", "subject", "table_of_contents", "technique",
"thumbnail", "title", "visibility", "work_type"],
"exclude": ["embedding", "embedding_model"]
},
"query": {
"hybrid": {
Expand Down

0 comments on commit 892e4d2

Please sign in to comment.