diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index c1f3e46d..ed9650cc 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -21,7 +21,7 @@ jobs: with: python-version: '3.9' cache-dependency-path: chat/src/requirements.txt - - run: pip install -r requirements.txt + - run: pip install -r requirements.txt && pip install -r requirements-dev.txt working-directory: ./chat/src - name: Check code style run: ruff check . diff --git a/Makefile b/Makefile index ee8622d4..e2d9e6a4 100644 --- a/Makefile +++ b/Makefile @@ -28,8 +28,10 @@ help: echo "make cover-python | run python tests with coverage" .aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/dependencies/requirements.txt chat/src/requirements.txt sed -Ei.orig 's/^(\s+)#\*\s/\1/' template.yaml + sed -Ei.orig 's/^(\s+)#\*\s/\1/' chat/template.yaml sam build --cached --parallel mv template.yaml.orig template.yaml + mv chat/template.yaml.orig chat/template.yaml deps-node: cd node/src ;\ npm list >/dev/null 2>&1 ;\ @@ -48,7 +50,7 @@ style-node: deps-node test-node: deps-node cd node && npm run test deps-python: - cd chat/src && pip install -r requirements.txt + cd chat/src && pip install -r requirements.txt && pip install -r requirements-dev.txt cover-python: deps-python cd chat && export SKIP_WEAVIATE_SETUP=True && coverage run --source=src -m unittest -v && coverage report --skip-empty cover-html-python: deps-python diff --git a/chat/dependencies/requirements.txt b/chat/dependencies/requirements.txt index d784fe19..6c4a743b 100644 --- a/chat/dependencies/requirements.txt +++ b/chat/dependencies/requirements.txt @@ -1,12 +1,13 @@ -boto3~=1.34.13 +boto3~=1.34 honeybadger -langchain -langchain-community -openai~=0.27.8 +langchain~=0.2 +langchain-aws~=0.1 +langchain-openai~=0.1 +openai~=1.35 opensearch-py pyjwt~=2.6.0 python-dotenv~=1.0.0 requests requests-aws4auth -tiktoken~=0.4.0 -wheel~=0.40.0 \ No newline at end of file +tiktoken~=0.7 +wheel~=0.40 \ No newline at end of file diff --git a/chat/src/event_config.py b/chat/src/event_config.py index cb339455..e5d3ae08 100644 --- a/chat/src/event_config.py +++ b/chat/src/event_config.py @@ -2,8 +2,8 @@ import json from dataclasses import dataclass, field -from langchain.chains.qa_with_sources import load_qa_with_sources_chain -from langchain.prompts import PromptTemplate + +from langchain_core.prompts import ChatPromptTemplate from setup import ( opensearch_client, opensearch_vector_store, @@ -19,6 +19,7 @@ DOCUMENT_VARIABLE_NAME = "context" K_VALUE = 5 MAX_K = 100 +MAX_TOKENS = 1000 TEMPERATURE = 0.2 TEXT_KEY = "id" VERSION = "2024-02-01" @@ -42,19 +43,21 @@ class EventConfig: azure_resource_name: str = field(init=False) debug_mode: bool = field(init=False) deployment_name: str = field(init=False) - document_prompt: PromptTemplate = field(init=False) + document_prompt: ChatPromptTemplate = field(init=False) event: dict = field(default_factory=dict) is_logged_in: bool = field(init=False) k: int = field(init=False) + max_tokens: int = field(init=False) openai_api_version: str = field(init=False) payload: dict = field(default_factory=dict) prompt_text: str = field(init=False) - prompt: PromptTemplate = field(init=False) + prompt: ChatPromptTemplate = field(init=False) question: str = field(init=False) ref: str = field(init=False) request_context: dict = field(init=False) temperature: float = field(init=False) socket: Websocket = field(init=False, default=None) + stream_response: bool = field(init=False) text_key: str = field(init=False) def __post_init__(self): @@ -67,17 +70,17 @@ def __post_init__(self): self.deployment_name = self._get_deployment_name() self.is_logged_in = self.api_token.is_logged_in() self.k = self._get_k() + self.max_tokens = self._get_max_tokens() self.openai_api_version = self._get_openai_api_version() self.prompt_text = self._get_prompt_text() self.request_context = self.event.get("requestContext", {}) self.question = self.payload.get("question") self.ref = self.payload.get("ref") + self.stream_response = self.payload.get("stream_response", not self.debug_mode) self.temperature = self._get_temperature() self.text_key = self._get_text_key() self.document_prompt = self._get_document_prompt() - self.prompt = PromptTemplate( - template=self.prompt_text, input_variables=["question", "context"] - ) + self.prompt = ChatPromptTemplate.from_template(self.prompt_text) def _get_payload_value_with_superuser_check(self, key, default): if self.api_token.is_superuser(): @@ -124,6 +127,9 @@ def _get_openai_api_version(self): "openai_api_version", VERSION ) + def _get_max_tokens(self): + return self._get_payload_value_with_superuser_check("max_tokens", MAX_TOKENS) + def _get_prompt_text(self): return self._get_payload_value_with_superuser_check("prompt", prompt_template()) @@ -134,10 +140,7 @@ def _get_text_key(self): return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY) def _get_document_prompt(self): - return PromptTemplate( - template=document_template(self.attributes), - input_variables=["title", "id"] + self.attributes, - ) + return ChatPromptTemplate.from_template(document_template(self.attributes)) def debug_message(self): return { @@ -170,28 +173,18 @@ def setup_websocket(self, socket=None): def setup_llm_request(self): self._setup_vector_store() self._setup_chat_client() - self._setup_chain() def _setup_vector_store(self): self.opensearch = opensearch_vector_store() def _setup_chat_client(self): self.client = openai_chat_client( - deployment_name=self.deployment_name, - openai_api_base=self.azure_endpoint, + azure_deployment=self.deployment_name, + azure_endpoint=self.azure_endpoint, openai_api_version=self.openai_api_version, - callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)], + callbacks=[StreamingSocketCallbackHandler(self.socket, stream=self.stream_response)], streaming=True, - ) - - def _setup_chain(self): - self.chain = load_qa_with_sources_chain( - self.client, - chain_type=CHAIN_TYPE, - prompt=self.prompt, - document_prompt=self.document_prompt, - document_variable_name=DOCUMENT_VARIABLE_NAME, - verbose=self._to_bool(os.getenv("VERBOSE")), + max_tokens=self.max_tokens ) def _is_debug_mode_enabled(self): diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py index a14ade00..04e92242 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers/chat.py @@ -4,7 +4,7 @@ import os from datetime import datetime from event_config import EventConfig -from helpers.response import prepare_response +from helpers.response import Response from honeybadger import honeybadger honeybadger.configure() @@ -35,7 +35,8 @@ def handler(event, context): if not os.getenv("SKIP_WEAVIATE_SETUP"): config.setup_llm_request() - final_response = prepare_response(config) + response = Response(config) + final_response = response.prepare_response() config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base')) log_group = os.getenv('METRICS_LOG_GROUP') diff --git a/chat/src/handlers/streaming_socket_callback_handler.py b/chat/src/handlers/streaming_socket_callback_handler.py index 5bc1d012..8fe32272 100644 --- a/chat/src/handlers/streaming_socket_callback_handler.py +++ b/chat/src/handlers/streaming_socket_callback_handler.py @@ -1,11 +1,22 @@ from langchain.callbacks.base import BaseCallbackHandler from websocket import Websocket +from typing import Any +from langchain_core.outputs.llm_result import LLMResult class StreamingSocketCallbackHandler(BaseCallbackHandler): - def __init__(self, socket: Websocket, debug_mode: bool): + def __init__(self, socket: Websocket, stream: bool = True): self.socket = socket - self.debug_mode = debug_mode + self.stream = stream def on_llm_new_token(self, token: str, **kwargs): - if self.socket and not self.debug_mode: + if len(token) > 0 and self.socket and self.stream: return self.socket.send({"token": token}) + + def on_llm_end(self, response: LLMResult, **kwargs: Any): + try: + finish_reason = response.generations[0][0].generation_info["finish_reason"] + if self.socket: + return self.socket.send({"end": {"reason": finish_reason}}) + except Exception as err: + finish_reason = f'Unknown ({str(err)})' + print(f"Stream ended: {finish_reason}") diff --git a/chat/src/helpers/metrics.py b/chat/src/helpers/metrics.py index 6cb13efb..336895b9 100644 --- a/chat/src/helpers/metrics.py +++ b/chat/src/helpers/metrics.py @@ -4,7 +4,7 @@ def token_usage(config, response, original_question): data = { "question": count_tokens(config.question), - "answer": count_tokens(response["output_text"]), + "answer": count_tokens(response), "prompt": count_tokens(config.prompt_text), "source_documents": count_tokens(original_question["source_documents"]), } diff --git a/chat/src/helpers/prompts.py b/chat/src/helpers/prompts.py index 1b08d1d9..9d660427 100644 --- a/chat/src/helpers/prompts.py +++ b/chat/src/helpers/prompts.py @@ -2,16 +2,15 @@ def prompt_template() -> str: - return """Please provide an answer to the question based on the documents provided. Include specific details from the documents that support your answer. Each document is identified by a 'title' and a unique 'source' UUID: + return """Please provide a brief answer to the question based on the documents provided. Include specific details from the documents that support your answer. Keep your answer concise. Each document is identified by a 'title' and a unique 'source' UUID: -Documents: -{context} -Answer in raw markdown. When referencing a document by title, link to it using its UUID like this: [title](https://dc.library.northwestern.edu/items/UUID). For example: [Judy Collins, Jackson Hole Folk Festival](https://dc.library.northwestern.edu/items/f1ca513b-7d13-4af6-ad7b-8c7ffd1d3a37). Suggest keyword searches using this format: [keyword](https://dc.library.northwestern.edu/search?q=keyword). Offer a variety of search terms that cover different aspects of the topic. Include as many direct links to Digital Collections searches as necessary for a thorough study. The `collection` field contains information about the collection the document belongs to. In the summary, mention the top 1 or 2 collections, explain why they are relevant and link to them using the collection title and id: [collection['title']](https://dc.library.northwestern.edu/collections/collection['id']), for example [World War II Poster Collection](https://dc.library.northwestern.edu/collections/faf4f60e-78e0-4fbf-96ce-4ca8b4df597a): - -Question: -{question} -""" + Documents: + {context} + Answer in raw markdown. When referencing a document by title, link to it using its UUID like this: [title](https://dc.library.northwestern.edu/items/UUID). For example: [Judy Collins, Jackson Hole Folk Festival](https://dc.library.northwestern.edu/items/f1ca513b-7d13-4af6-ad7b-8c7ffd1d3a37). Suggest keyword searches using this format: [keyword](https://dc.library.northwestern.edu/search?q=keyword). Offer a variety of search terms that cover different aspects of the topic. Include as many direct links to Digital Collections searches as necessary for a thorough study. The `collection` field contains information about the collection the document belongs to. In the summary, mention the top 1 or 2 collections, explain why they are relevant and link to them using the collection title and id: [collection['title']](https://dc.library.northwestern.edu/collections/collection['id']), for example [World War II Poster Collection](https://dc.library.northwestern.edu/collections/faf4f60e-78e0-4fbf-96ce-4ca8b4df597a): + Question: + {question} + """ def document_template(attributes: Optional[List[str]] = None) -> str: if attributes is None: diff --git a/chat/src/helpers/response.py b/chat/src/helpers/response.py index 374e4482..79715a79 100644 --- a/chat/src/helpers/response.py +++ b/chat/src/helpers/response.py @@ -1,37 +1,6 @@ from helpers.metrics import token_usage -from openai.error import InvalidRequestError - -def debug_response(config, response, original_question): - return { - "answer": response["output_text"], - "attributes": config.attributes, - "azure_endpoint": config.azure_endpoint, - "deployment_name": config.deployment_name, - "is_superuser": config.api_token.is_superuser(), - "k": config.k, - "openai_api_version": config.openai_api_version, - "prompt": config.prompt_text, - "question": config.question, - "ref": config.ref, - "temperature": config.temperature, - "text_key": config.text_key, - "token_counts": token_usage(config, response, original_question), - } - -def get_and_send_original_question(config, docs): - doc_response = [] - for doc in docs: - doc_dict = doc.__dict__ - metadata = doc_dict.get('metadata', {}) - new_doc = {key: extract_prompt_value(metadata.get(key)) for key in config.attributes if key in metadata} - doc_response.append(new_doc) - - original_question = { - "question": config.question, - "source_documents": doc_response, - } - config.socket.send(original_question) - return original_question +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import RunnableLambda, RunnablePassthrough def extract_prompt_value(v): if isinstance(v, list): @@ -40,29 +9,76 @@ def extract_prompt_value(v): return [v.get('label')] else: return v - -def prepare_response(config): - try: - subquery = { - "match": { - "all_titles": { - "query": config.question, - "operator": "AND", - "analyzer": "english" - } + +class Response: + def __init__(self, config): + self.config = config + self.store = {} + + def debug_response_passthrough(self): + def debug_response(config, response, original_question): + return { + "answer": response, + "attributes": config.attributes, + "azure_endpoint": config.azure_endpoint, + "deployment_name": config.deployment_name, + "is_superuser": config.api_token.is_superuser(), + "k": config.k, + "openai_api_version": config.openai_api_version, + "prompt": config.prompt_text, + "question": config.question, + "ref": config.ref, + "temperature": config.temperature, + "text_key": config.text_key, + "token_counts": token_usage(config, response, original_question), } - } - docs = config.opensearch.similarity_search( - query=config.question, k=config.k, subquery=subquery, _source={"excludes": ["embedding"]} - ) - original_question = get_and_send_original_question(config, docs) - response = config.chain({"question": config.question, "input_documents": docs}) - prepared_response = debug_response(config, response, original_question) - except InvalidRequestError as err: - prepared_response = { - "question": config.question, - "error": str(err), - "source_documents": [], - } - return prepared_response + return RunnableLambda(lambda x: debug_response(self.config, x, self.original_question)) + + def original_question_passthrough(self): + def get_and_send_original_question(docs): + source_documents = [] + for doc in docs["context"]: + doc.metadata = {key: extract_prompt_value(doc.metadata.get(key)) for key in self.config.attributes if key in doc.metadata} + source_document = doc.metadata.copy() + source_document["content"] = doc.page_content + source_documents.append(source_document) + + original_question = { + "question": self.config.question, + "source_documents": source_documents, + } + self.config.socket.send(original_question) + self.original_question = original_question + return docs + + return RunnablePassthrough(get_and_send_original_question) + + def prepare_response(self): + try: + subquery = { + "match": { + "all_titles": { + "query": self.config.question, + "operator": "AND", + "analyzer": "english" + } + } + } + retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "subquery": subquery, "_source": {"excludes": ["embedding"]}}) + chain = ( + {"context": retriever, "question": RunnablePassthrough()} + | self.original_question_passthrough() + | self.config.prompt + | self.config.client + | StrOutputParser() + | self.debug_response_passthrough() + ) + response = chain.invoke(self.config.question) + except Exception as err: + response = { + "question": self.config.question, + "error": str(err), + "source_documents": [], + } + return response diff --git a/chat/src/requirements-dev.txt b/chat/src/requirements-dev.txt new file mode 100644 index 00000000..13699722 --- /dev/null +++ b/chat/src/requirements-dev.txt @@ -0,0 +1,3 @@ +# Dev/Test Dependencies +ruff~=0.1.0 +coverage~=7.3.2 diff --git a/chat/src/requirements.txt b/chat/src/requirements.txt index 2864f79a..79a4f375 100644 --- a/chat/src/requirements.txt +++ b/chat/src/requirements.txt @@ -1,17 +1,14 @@ # Runtime Dependencies -boto3~=1.34.13 +boto3~=1.34 honeybadger -langchain -langchain-community -openai~=0.27.8 +langchain~=0.2 +langchain-aws~=0.1 +langchain-openai~=0.1 +openai~=1.35 opensearch-py pyjwt~=2.6.0 python-dotenv~=1.0.0 requests requests-aws4auth -tiktoken~=0.4.0 -wheel~=0.40.0 - -# Dev/Test Dependencies -ruff~=0.1.0 -coverage~=7.3.2 +tiktoken~=0.7 +wheel~=0.40 \ No newline at end of file diff --git a/chat/src/setup.py b/chat/src/setup.py index e57a9c20..a2a7a3a4 100644 --- a/chat/src/setup.py +++ b/chat/src/setup.py @@ -1,4 +1,4 @@ -from langchain_community.chat_models import AzureChatOpenAI +from langchain_openai import AzureChatOpenAI from handlers.opensearch_neural_search import OpenSearchNeuralSearch from opensearchpy import OpenSearch, RequestsHttpConnection from requests_aws4auth import AWS4Auth @@ -17,7 +17,6 @@ def openai_chat_client(**kwargs): ) def opensearch_client(region_name=os.getenv("AWS_REGION")): - print(region_name) session = boto3.Session(region_name=region_name) awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials()) endpoint = os.getenv("OPENSEARCH_ENDPOINT") diff --git a/chat/template.yaml b/chat/template.yaml index 265a40f4..04d59379 100644 --- a/chat/template.yaml +++ b/chat/template.yaml @@ -184,18 +184,18 @@ Resources: Action: lambda:InvokeFunction FunctionName: !Ref ChatFunction Principal: apigateway.amazonaws.com - ChatDependencies: - Type: AWS::Serverless::LayerVersion - Properties: - LayerName: - Fn::Sub: "${AWS::StackName}-dependencies" - Description: Dependencies for streaming chat function - ContentUri: ./dependencies - CompatibleRuntimes: - - python3.10 - LicenseInfo: "Apache-2.0" - Metadata: - BuildMethod: python3.10 + #* ChatDependencies: + #* Type: AWS::Serverless::LayerVersion + #* Properties: + #* LayerName: + #* Fn::Sub: "${AWS::StackName}-dependencies" + #* Description: Dependencies for streaming chat function + #* ContentUri: ./dependencies + #* CompatibleRuntimes: + #* - python3.10 + #* LicenseInfo: "Apache-2.0" + #* Metadata: + #* BuildMethod: python3.10 ChatFunction: Type: AWS::Serverless::Function Properties: @@ -203,8 +203,8 @@ Resources: Runtime: python3.10 Architectures: - x86_64 - Layers: - - !Ref ChatDependencies + #* Layers: + #* - !Ref ChatDependencies MemorySize: 1024 Handler: handlers/chat.handler Timeout: 300 @@ -240,8 +240,8 @@ Resources: - logs:CreateLogStream - logs:PutLogEvents Resource: !Sub "${ChatMetricsLog.Arn}:*" - Metadata: - BuildMethod: nodejs20.x + #* Metadata: + #* BuildMethod: nodejs20.x ChatMetricsLog: Type: AWS::Logs::LogGroup Properties: diff --git a/chat/test/handlers/test_streaming_socket_callback_handler.py b/chat/test/handlers/test_streaming_socket_callback_handler.py index 5293a6a2..27d6cb2e 100644 --- a/chat/test/handlers/test_streaming_socket_callback_handler.py +++ b/chat/test/handlers/test_streaming_socket_callback_handler.py @@ -7,7 +7,9 @@ StreamingSocketCallbackHandler, ) from websocket import Websocket - +from langchain_core.outputs.llm_result import LLMResult +from langchain_core.outputs import ChatGeneration +from langchain_core.messages.ai import AIMessage class MockClient: @@ -16,11 +18,32 @@ def post_to_connection(self, Data, ConnectionId): class TestMyStreamingSocketCallbackHandler(TestCase): def test_on_new_llm_token(self): - handler = StreamingSocketCallbackHandler(Websocket(client=MockClient()), False) + handler = StreamingSocketCallbackHandler(Websocket(client=MockClient())) result = handler.on_llm_new_token(token="test") self.assertEqual(result, {'token': 'test', 'ref': {}}) - self.assertFalse(handler.debug_mode) + self.assertTrue(handler.stream) + def test_on_llm_end(self): + handler = StreamingSocketCallbackHandler(Websocket(client=MockClient())) + payload = LLMResult( + generations=[[ + ChatGeneration( + text='LLM Response', + generation_info={'finish_reason': 'stop', 'model_name': 'llm-model', 'system_fingerprint': 'fp_012345678'}, + message=AIMessage( + content='LLM Response', + response_metadata={'finish_reason': 'stop', 'model_name': 'llm-model', 'system_fingerprint': 'fp_012345678'}, + id='run-5da4fbbc-b663-4851-809d-fd11c27c5b76-0' + ) + ) + ]], + llm_output=None, + run=None + ) + result = handler.on_llm_end(payload) + self.assertEqual(result, {'end': {'reason': 'stop'}, 'ref': {}}) + self.assertTrue(handler.stream) + def test_debug_mode(self): - handler = StreamingSocketCallbackHandler(Websocket(client=MockClient()), debug_mode=True) - self.assertTrue(handler.debug_mode) + handler = StreamingSocketCallbackHandler(Websocket(client=MockClient()), stream=False) + self.assertFalse(handler.stream) diff --git a/chat/test/helpers/test_metrics.py b/chat/test/helpers/test_metrics.py index 4c65d0e2..5c593b8a 100644 --- a/chat/test/helpers/test_metrics.py +++ b/chat/test/helpers/test_metrics.py @@ -47,11 +47,11 @@ def test_token_usage(self): result = token_usage(config, response, original_question) expected_result = { - "answer": 6, - "prompt": 302, + "answer": 12, + "prompt": 314, "question": 15, "source_documents": 1, - "total": 324 + "total": 342 } self.assertEqual(result, expected_result)