From 892e4d203d8766828ac97a8878e2361bbd8883b5 Mon Sep 17 00:00:00 2001 From: "Michael B. Klein" Date: Wed, 11 Dec 2024 23:16:58 +0000 Subject: [PATCH] Improve prompting Compress checkpoint data Add more fields back to hybrid query --- chat/src/agent/agent_handler.py | 5 ++++- chat/src/agent/dynamodb_cleaner.py | 3 --- chat/src/agent/dynamodb_saver.py | 18 ++++++++++-------- chat/src/agent/search_agent.py | 12 ++++++++++-- chat/src/handlers/chat.py | 4 +++- chat/src/helpers/hybrid_query.py | 9 +++++++-- 6 files changed, 34 insertions(+), 17 deletions(-) diff --git a/chat/src/agent/agent_handler.py b/chat/src/agent/agent_handler.py index 1823bef..be6fbaa 100644 --- a/chat/src/agent/agent_handler.py +++ b/chat/src/agent/agent_handler.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from websocket import Websocket @@ -27,6 +27,9 @@ def __init__(self, socket: Websocket, ref: str, *args: List[Any], **kwargs: Dict self.ref = ref super().__init__(*args, **kwargs) + def on_llm_start(self, serialized: dict[str, Any], prompts: list[str], metadata: Optional[dict[str, Any]] = None, **kwargs: Dict[str, Any]): + self.socket.send({"type": "start", "ref": self.ref, "message": {"model": metadata.get("model_deployment")}}) + def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]): content = response.generations[0][0].text if content != "": diff --git a/chat/src/agent/dynamodb_cleaner.py b/chat/src/agent/dynamodb_cleaner.py index d2c9346..8ac78fe 100644 --- a/chat/src/agent/dynamodb_cleaner.py +++ b/chat/src/agent/dynamodb_cleaner.py @@ -44,7 +44,6 @@ def delete_thread(table_name, thread_id, region_name=os.getenv("AWS_REGION")): items.extend(response.get('Items', [])) if not items: - print(f"No items found with thread_id: {thread_id}") return # Prepare delete requests in batches of 25 (DynamoDB limit for BatchWriteItem) @@ -56,7 +55,5 @@ def delete_thread(table_name, thread_id, region_name=os.getenv("AWS_REGION")): } batch.delete_item(Key=key) - print(f"Successfully deleted {len(items)} items with thread_id: {thread_id}") - except ClientError as e: print(f"An error occurred: {e.response['Error']['Message']}") diff --git a/chat/src/agent/dynamodb_saver.py b/chat/src/agent/dynamodb_saver.py index 316810c..9387c75 100644 --- a/chat/src/agent/dynamodb_saver.py +++ b/chat/src/agent/dynamodb_saver.py @@ -5,6 +5,8 @@ import langchain_core.messages as langchain_messages import time import json +import bz2 +import base64 from langchain_core.runnables import RunnableConfig @@ -30,11 +32,13 @@ def default(o): raise TypeError(f'Object of type {o.__class__.__name__} is not JSON serializable') json_str = json.dumps(obj, default=default) - return 'json', json_str + compressed_str = base64.b64encode(bz2.compress(json_str.encode("utf-8"))).decode("utf-8") + return 'compressed_json', compressed_str def loads_typed(self, data: Tuple[str, Any]) -> Any: - type_, json_str = data - + type_, compressed_str = data + json_str = bz2.decompress(base64.b64decode(compressed_str)).decode("utf-8") + def object_hook(dct): if '__type__' in dct: type_name = dct['__type__'] @@ -48,7 +52,6 @@ def object_hook(dct): obj = json.loads(json_str, object_hook=object_hook) return obj - class DynamoDBSaver(BaseCheckpointSaver): """A checkpoint saver that stores checkpoints in DynamoDB using JSON-compatible formats.""" @@ -74,7 +77,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"].get("checkpoint_ns", "") checkpoint_id = get_checkpoint_id(config) - + if checkpoint_id: # Need to scan for item with matching checkpoint_id response = self.table.query( @@ -139,7 +142,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: # Retrieve metadata directly metadata = item['metadata'] - metadata = self.serde.loads(metadata) + metadata = self.serde.loads_typed(("compressed_json", metadata)) # Get parent config if any parent_checkpoint_id = item.get('parent_checkpoint_id') if parent_checkpoint_id: @@ -197,12 +200,11 @@ def list( # Retrieve the checkpoint data directly checkpoint_data = item['checkpoint'] - checkpoint = self.serde.loads_typed((checkpoint_type, checkpoint_data)) # Retrieve metadata directly metadata = item['metadata'] - metadata = self.serde.loads(metadata) + metadata = self.serde.loads_typed(("compressed_json", metadata)) parent_checkpoint_id = item.get('parent_checkpoint_id') if parent_checkpoint_id: diff --git a/chat/src/agent/search_agent.py b/chat/src/agent/search_agent.py index e027a67..f250b43 100644 --- a/chat/src/agent/search_agent.py +++ b/chat/src/agent/search_agent.py @@ -5,10 +5,19 @@ from agent.dynamodb_saver import DynamoDBSaver from agent.tools import aggregate, discover_fields, search from langchain_core.messages.base import BaseMessage +from langchain_core.messages.system import SystemMessage from langgraph.graph import END, START, StateGraph, MessagesState from langgraph.prebuilt import ToolNode from setup import openai_chat_client +# Keep your answer concise and keep reading time under 45 seconds. + +system_message = """ +Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that +support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown +links using the document's canonical_link field. +""" + tools = [discover_fields, search, aggregate] tool_node = ToolNode(tools) @@ -28,13 +37,12 @@ def should_continue(state: MessagesState) -> Literal["tools", END]: # Define the function that calls the model def call_model(state: MessagesState): - messages = state["messages"] + messages = [SystemMessage(content=system_message)] + state["messages"] response: BaseMessage = model.invoke(messages, model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")) # We return a list, because this will get added to the existing list # if socket is not none and the response content is not an empty string return {"messages": [response]} - # Define a new graph workflow = StateGraph(MessagesState) diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py index 2530876..3db7b6b 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers/chat.py @@ -2,6 +2,7 @@ import boto3 import json import os +import traceback from datetime import datetime from event_config import EventConfig # from honeybadger import honeybadger @@ -78,11 +79,12 @@ def handler(event, context): try: search_agent.invoke( {"messages": [HumanMessage(content=config.question)]}, - config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks}, + config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks, "metadata": {"model_deployment": os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")}}, debug=False ) except Exception as e: print(f"Error: {e}") + print(traceback.format_exc()) error_response = {"type": "error", "message": "An unexpected error occurred. Please try again later."} if config.socket: config.socket.send(error_response) diff --git a/chat/src/helpers/hybrid_query.py b/chat/src/helpers/hybrid_query.py index 76c4c76..ecf3d38 100644 --- a/chat/src/helpers/hybrid_query.py +++ b/chat/src/helpers/hybrid_query.py @@ -15,8 +15,13 @@ def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: result = { "size": kwargs.get("size", 20), "_source": { - "include": ["title", "description", "collection.title", "id", "collection.id"], - "exclude": ["embedding"] + "include": ["abstract", "accession_number", "alternate_title", "api_link", "ark", "canonical_link", + "collection", "contributor", "creator", "date_created_edtf", "description", "genre", + "id", "iiif_manifest", "language", "library_unit", "license", "location", "physical_description_material", + "physical_description_size", "provenance", "publisher", "rights_holder", "rights_statement", + "scope_and_contents", "series", "style_period", "subject", "table_of_contents", "technique", + "thumbnail", "title", "visibility", "work_type"], + "exclude": ["embedding", "embedding_model"] }, "query": { "hybrid": {