From 6df761198a6f4b4d6379fc64e53caefed2f0bd35 Mon Sep 17 00:00:00 2001
From: "Michael B. Klein"
Date: Wed, 4 Dec 2024 17:53:41 +0000
Subject: [PATCH] Begin to add LangGraph agent code

---
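Notes for reviewers: a minimal sketch of how the compiled graph in
chat/src/agent/agent.py might be driven once this lands. The thread id and
question text are hypothetical, and it assumes the AZURE_DEPLOYMENT_NAME and
OpenSearch environment this repo already configures. Because the graph is
compiled with a MemorySaver checkpointer, every invocation must supply a
thread_id so conversation state persists between runs:

    from agent.agent import app
    from langchain_core.messages import HumanMessage

    # Ties this call to a persisted conversation in the checkpointer
    # (the id itself is arbitrary / hypothetical).
    config = {"configurable": {"thread_id": "demo-thread"}}

    # "agent" calls the model; any tool calls route to the "tools" node,
    # which loops back to "agent" until the model stops requesting tools.
    result = app.invoke(
        {"messages": [HumanMessage(content="How many works are in each collection?")]},
        config=config,
    )
    print(result["messages"][-1].content)
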
 chat/src/agent/__init__.py                    |  0
 chat/src/agent/agent.py                       | 59 +++++++++++++++++++
 chat/src/agent/tools.py                       | 35 +++++++++++
 chat/src/handlers/opensearch_neural_search.py | 19 ++++++
 chat/src/requirements.txt                     |  9 +--
 5 files changed, 118 insertions(+), 4 deletions(-)
 create mode 100644 chat/src/agent/__init__.py
 create mode 100644 chat/src/agent/agent.py
 create mode 100644 chat/src/agent/tools.py

diff --git a/chat/src/agent/__init__.py b/chat/src/agent/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/chat/src/agent/agent.py b/chat/src/agent/agent.py
new file mode 100644
index 00000000..0e99fcf6
--- /dev/null
+++ b/chat/src/agent/agent.py
@@ -0,0 +1,59 @@
+import os
+
+from typing import Literal
+
+from agent.tools import search, aggregate
+from langgraph.checkpoint.memory import MemorySaver
+from langgraph.graph import END, START, StateGraph, MessagesState
+from langgraph.prebuilt import ToolNode
+from setup import openai_chat_client
+
+tools = [search, aggregate]
+
+tool_node = ToolNode(tools)
+
+model = openai_chat_client().bind_tools(tools)
+
+# Define the function that determines whether to continue or not
+def should_continue(state: MessagesState) -> Literal["tools", END]:
+    messages = state["messages"]
+    last_message = messages[-1]
+    # If the LLM makes a tool call, then we route to the "tools" node
+    if last_message.tool_calls:
+        return "tools"
+    # Otherwise, we stop (reply to the user)
+    return END
+
+
+# Define the function that calls the model
+def call_model(state: MessagesState):
+    messages = state["messages"]
+    response = model.invoke(messages, model=os.getenv("AZURE_DEPLOYMENT_NAME"))
+    # We return a list, because this will get added to the existing list
+    return {"messages": [response]}
+
+
+# Define a new graph
+workflow = StateGraph(MessagesState)
+
+# Define the two nodes we will cycle between
+workflow.add_node("agent", call_model)
+workflow.add_node("tools", tool_node)
+
+# Set the entrypoint as `agent`
+workflow.add_edge(START, "agent")
+
+# Add a conditional edge
+workflow.add_conditional_edges(
+    "agent",
+    should_continue,
+)
+
+# Add a normal edge from `tools` to `agent`
+workflow.add_edge("tools", "agent")
+
+# Initialize memory to persist state between graph runs
+checkpointer = MemorySaver()
+
+# Compile the graph
+app = workflow.compile(checkpointer=checkpointer, debug=True)
diff --git a/chat/src/agent/tools.py b/chat/src/agent/tools.py
new file mode 100644
index 00000000..bc081aaa
--- /dev/null
+++ b/chat/src/agent/tools.py
@@ -0,0 +1,35 @@
+import json
+
+from langchain_core.tools import tool
+from opensearch_client import opensearch_vector_store
+
+@tool(response_format="content_and_artifact")
+def search(query: str):
+    """Perform a semantic search of Northwestern University Library digital collections. When answering a search query, ground your answer in the returned results and cite each document's metadata."""
+    query_results = opensearch_vector_store.similarity_search(query, size=20)
+    return json.dumps(query_results, default=str), query_results
+
+@tool(response_format="content_and_artifact")
+def aggregate(aggregation_query: str):
+    """
+    Perform a quantitative aggregation on the OpenSearch index; aggregation_query must be one of the available fields listed below.
+
+    Available fields:
+    api_link, api_model, ark, collection.title.keyword, contributor.label.keyword, contributor.variants,
+    create_date, creator.variants, date_created, embedding_model, embedding_text_length,
+    folder_name, folder_number, genre.variants, id, identifier, indexed_at, language.variants,
+    legacy_identifier, library_unit, location.variants, modified_date, notes.note, notes.type,
+    physical_description_material, physical_description_size, preservation_level, provenance, published, publisher,
+    related_url.url, related_url.label, representative_file_set.aspect_ratio, representative_file_set.url, rights_holder,
+    series, status, style_period.label.keyword, style_period.variants, subject.label.keyword, subject.role,
+    subject.variants, table_of_contents, technique.label.keyword, technique.variants, title.keyword, visibility, work_type
+
+    Examples:
+    - Number of collections: collection.title.keyword
+    - Number of works by work type: work_type
+    """
+    try:
+        response = opensearch_vector_store.aggregations_search(aggregation_query)
+        return json.dumps(response, default=str), response
+    except Exception as e:
+        return json.dumps({"error": str(e)}), None
diff --git a/chat/src/handlers/opensearch_neural_search.py b/chat/src/handlers/opensearch_neural_search.py
index 0530eab0..a7616967 100644
--- a/chat/src/handlers/opensearch_neural_search.py
+++ b/chat/src/handlers/opensearch_neural_search.py
@@ -55,6 +55,25 @@ def similarity_search_with_score(
 
         return documents_with_scores
 
+    def aggregations_search(self, field: str, **kwargs: Any) -> dict:
+        """Perform a search with aggregations and return the aggregation results."""
+        dsl = {
+            "size": 0,
+            "aggs": {"aggregation_result": {"terms": {"field": field}}},
+        }
+
+        response = self.client.search(
+            index=self.index,
+            body=dsl,
+            params=(
+                {"search_pipeline": self.search_pipeline}
+                if self.search_pipeline
+                else None
+            ),
+        )
+
+        return response.get("aggregations", {})
+
     def add_texts(self, texts: List[str], metadatas: List[dict], **kwargs: Any) -> None:
         pass
 
diff --git a/chat/src/requirements.txt b/chat/src/requirements.txt
index 79a4f375..54462823 100644
--- a/chat/src/requirements.txt
+++ b/chat/src/requirements.txt
@@ -1,14 +1,15 @@
 # Runtime Dependencies
 boto3~=1.34
-honeybadger
+honeybadger~=0.20
 langchain~=0.2
 langchain-aws~=0.1
 langchain-openai~=0.1
+langgraph~=0.2
 openai~=1.35
-opensearch-py
+opensearch-py~=2.8
 pyjwt~=2.6.0
 python-dotenv~=1.0.0
-requests
-requests-aws4auth
+requests~=2.32
+requests-aws4auth~=1.3
 tiktoken~=0.7
 wheel~=0.40
\ No newline at end of file
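
Reviewer note on the aggregations contract: the aggregate tool passes its
aggregation_query argument straight through as the field of a terms
aggregation, so the request aggregations_search builds is fixed in shape. A
sketch of that request and of the slice of the response the method returns,
using a field from the tool's docstring (bucket values are illustrative, not
captured from this index):

    # Request body built by aggregations_search("collection.title.keyword"):
    dsl = {
        "size": 0,  # return no search hits, only aggregation buckets
        "aggs": {
            "aggregation_result": {
                "terms": {"field": "collection.title.keyword"}
            }
        },
    }

    # The method returns response["aggregations"], shaped like:
    # {
    #     "aggregation_result": {
    #         "doc_count_error_upper_bound": 0,
    #         "sum_other_doc_count": 0,
    #         "buckets": [{"key": "Some Collection", "doc_count": 42}, ...],
    #     }
    # }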