From f76d36d99c86380b607bd8b4d2e99a3ee0422c69 Mon Sep 17 00:00:00 2001
From: Brendan Quinn
Date: Thu, 16 Nov 2023 02:20:32 +0000
Subject: [PATCH] WIP

---
 Makefile                  |   2 +-
 chat/src/handlers/chat.py | 237 +++++++++++++++++++++++---------------
 chat/src/setup.py         |  50 ++++----
 3 files changed, 172 insertions(+), 117 deletions(-)

diff --git a/Makefile b/Makefile
index 41f41972..325f973b 100644
--- a/Makefile
+++ b/Makefile
@@ -22,7 +22,7 @@ help:
 	echo "make test-python | run python tests"
 	echo "make cover-node | run node tests with coverage"
 	echo "make cover-python | run python tests with coverage"
-.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json python/requirements.txt python/src/requirements.txt
+.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/src/requirements.txt
 	sam build --cached --parallel
 deps-node:
 	cd node && npm ci
diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py
index 3ac4c293..6c93646d 100644
--- a/chat/src/handlers/chat.py
+++ b/chat/src/handlers/chat.py
@@ -1,16 +1,22 @@
-import boto3
 import json
 import os
-import setup
+
+import boto3
 import tiktoken
-from helpers.apitoken import ApiToken
-from helpers.prompts import document_template, prompt_template
+from src.helpers.apitoken import ApiToken
+from src.helpers.prompts import document_template, prompt_template
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
 from langchain.prompts import PromptTemplate
 from openai.error import InvalidRequestError
+from src.setup import (
+    weaviate_client,
+    weaviate_vector_store,
+    openai_chat_client,
+)
 
-DEFAULT_INDEX = "DCWork"
+
+DEFAULT_INDEX = "Work"
 DEFAULT_KEY = "title"
 DEFAULT_K = 10
 MAX_K = 100
@@ -40,118 +46,163 @@ def on_llm_new_token(self, token: str, **kwargs):
         self.socket.send({"token": token})
 
 
-def handler(event, context):
-    try:
-        payload = json.loads(event.get("body", "{}"))
+class EventConfig:
+    def __init__(self, event):
+        self.event = event
+        self.payload = json.loads(event.get("body", "{}"))
+        self.api_token = ApiToken(signed_token=self.payload.get("auth"))
+        self.debug_mode = self._get_debug_mode()
+        self.index_name = self.payload.get("index", DEFAULT_INDEX)
+        self.is_logged_in = self.api_token.is_logged_in()
+        self.k = min(self.payload.get("k", DEFAULT_K), MAX_K)
+        self.question = self.payload.get("question")
+        self.ref = self.payload.get("ref")
+        self.text_key = self.payload.get("text_key", DEFAULT_KEY)
+        self.attributes = [
+            item
+            for item in self._get_attributes()
+            if item not in [self.text_key, "source"]
+        ]
+        self.document_prompt = PromptTemplate(
+            template=document_template(self.attributes),
+            input_variables=["page_content", "source"] + self.attributes,
+        )
+        self.prompt_text = (
+            self.payload.get("prompt", prompt_template())
+            if self.api_token.is_superuser()
+            else prompt_template()
+        )
+        self.prompt = PromptTemplate(
+            template=self.prompt_text, input_variables=["question", "context"]
+        )
 
-        request_context = event.get("requestContext", {})
+    def setup_websocket(self):
+        request_context = self.event.get("requestContext", {})
         connection_id = request_context.get("connectionId")
         endpoint_url = f'https://{request_context.get("domainName")}/{request_context.get("stage")}'
-        ref = payload.get("ref")
-        socket = Websocket(
+        ref = self.payload.get("ref")
+        self.socket = Websocket(
             connection_id=connection_id, endpoint_url=endpoint_url, ref=ref
         )
 
-        api_token = ApiToken(signed_token=payload.get("auth"))
-        if not api_token.is_logged_in():
-            socket.send({"statusCode": 401, "body": "Unauthorized"})
-            return {"statusCode": 401, "body": "Unauthorized"}
-
-        debug_mode = payload.get("debug", False) and api_token.is_superuser()
-
-        question = payload.get("question")
-        index_name = payload.get("index", payload.get("index", DEFAULT_INDEX))
-        print(f"Searching index {index_name}")
-        text_key = payload.get("text_key", DEFAULT_KEY)
-        attributes = [
-            item
-            for item in get_attributes(
-                index_name, payload if api_token.is_superuser() else {}
-            )
-            if item not in [text_key, "source"]
-        ]
+    def setup_llm_request(self):
+        self._setup_vector_store()
+        self._setup_chat_client()
+        self._setup_chain()
 
-        weaviate = setup.weaviate_vector_store(
-            index_name=index_name, text_key=text_key, attributes=attributes + ["source"]
+    def _setup_vector_store(self):
+        self.weaviate = weaviate_vector_store(
+            index_name=self.index_name,
+            text_key=self.text_key,
+            attributes=self.attributes + ["source"],
         )
 
-        client = setup.openai_chat_client(
-            callbacks=[StreamingSocketCallbackHandler(socket, debug_mode)],
+    def _setup_chat_client(self):
+        self.client = openai_chat_client(
+            callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
             streaming=True,
         )
 
-        prompt_text = (
-            payload.get("prompt", prompt_template())
-            if api_token.is_superuser()
-            else prompt_template()
-        )
-        prompt = PromptTemplate(
-            template=prompt_text, input_variables=["question", "context"]
-        )
-
-        document_prompt = PromptTemplate(
-            template=document_template(attributes),
-            input_variables=["page_content", "source"] + attributes,
-        )
-
-        k = min(payload.get("k", DEFAULT_K), MAX_K)
-        docs = weaviate.similarity_search(question, k=k, additional="certainty")
-        chain = load_qa_with_sources_chain(
-            client,
+    def _setup_chain(self):
+        self.chain = load_qa_with_sources_chain(
+            self.client,
             chain_type="stuff",
-            prompt=prompt,
-            document_prompt=document_prompt,
+            prompt=self.prompt,
+            document_prompt=self.document_prompt,
             document_variable_name="context",
             verbose=to_bool(os.getenv("VERBOSE")),
         )
 
-        try:
-            doc_response = [doc.__dict__ for doc in docs]
-            original_question = {"question": question, "source_documents": doc_response}
-            socket.send(original_question)
-            response = chain({"question": question, "input_documents": docs})
-            if debug_mode:
-                final_response = {
-                    "answer": response["output_text"],
-                    "attributes": attributes,
-                    "isSuperuser": api_token.is_superuser(),
-                    "prompt": prompt_text,
-                    "ref": ref,
-                    "k": k,
-                    "original_question": original_question,
-                    "token_counts": {
-                        "question": count_tokens(question),
-                        "answer": count_tokens(response["output_text"]),
-                        "prompt": count_tokens(prompt_text),
-                        "source_documents": count_tokens(doc_response),
-                    },
-                }
-            else:
-                final_response = {"answer": response["output_text"], "ref": ref}
-        except InvalidRequestError as err:
-            final_response = {
-                "question": question,
-                "error": str(err),
-                "source_documents": [],
-            }
-
-        socket.send(final_response)
+    def _get_debug_mode(self):
+        debug = self.payload.get("debug", False)
+        return debug and self.api_token.is_superuser()
+
+    def _get_attributes(self):
+        request_attributes = self.payload.get("attributes", None)
+        if request_attributes is not None:
+            return request_attributes
+
+        client = weaviate_client()
+        schema = client.schema.get(self.index_name)
+        names = [prop["name"] for prop in schema.get("properties")]
+        print(f"Retrieved attributes: {names}")
+        return names
+
+
+def handler(event, _context):
+    try:
+        config = EventConfig(event)
+        config.setup_websocket()
+
+        if not config.is_logged_in:
+            config.socket.send({"statusCode": 401, "body": "Unauthorized"})
+            return {"statusCode": 401, "body": "Unauthorized"}
+
+        config.setup_llm_request()
+        final_response = prepare_response(config)
+        config.socket.send(final_response)
         return {"statusCode": 200}
     except Exception as err:
         print(event)
         raise err
 
 
-def get_attributes(index, payload):
-    request_attributes = payload.get("attributes", None)
-    if request_attributes is not None:
-        return request_attributes
+def get_and_send_original_question(config, docs):
+    doc_response = [doc.__dict__ for doc in docs]
+    original_question = {
+        "question": config.question,
+        "source_documents": doc_response,
+    }
+    config.socket.send(original_question)
+    return original_question
 
-    client = setup.weaviate_client()
-    schema = client.schema.get(index)
-    names = [prop["name"] for prop in schema.get("properties")]
-    print(f"Retrieved attributes: {names}")
-    return names
+
+def gather_token_counts(config, response, original_question):
+    return {
+        "question": count_tokens(config.question),
+        "answer": count_tokens(response["output_text"]),
+        "prompt": count_tokens(config.prompt_text),
+        "source_documents": count_tokens(original_question["source_documents"]),
+    }
+
+
+def prepare_debug_response(config, response, original_question):
+    return {
+        "answer": response["output_text"],
+        "attributes": config.attributes,
+        "is_superuser": config.api_token.is_superuser(),
+        "prompt": config.prompt_text,
+        "ref": config.ref,
+        "k": config.k,
+        "original_question": original_question,
+        "token_counts": gather_token_counts(config, response, original_question),
+    }
+
+
+def prepare_normal_response(config, response):
+    return {"answer": response["output_text"], "ref": config.ref}
+
+
+def prepare_response(config):
+    try:
+        docs = config.weaviate.similarity_search(
+            config.question, k=config.k, additional="certainty"
+        )
+        original_question = get_and_send_original_question(config, docs)
+        response = config.chain({"question": config.question, "input_documents": docs})
+        if config.debug_mode:
+            prepared_response = prepare_debug_response(
+                config, response, original_question
+            )
+        else:
+            prepared_response = prepare_normal_response(config, response)
+    except InvalidRequestError as err:
+        prepared_response = {
+            "question": config.question,
+            "error": str(err),
+            "source_documents": [],
+        }
+    return prepared_response
 
 
 def count_tokens(val):
diff --git a/chat/src/setup.py b/chat/src/setup.py
index da9dbbf1..c0279fdf 100644
--- a/chat/src/setup.py
+++ b/chat/src/setup.py
@@ -4,33 +4,37 @@
 import os
 import weaviate
 
+
 def openai_chat_client(**kwargs):
-    deployment = os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
-    key = os.getenv("AZURE_OPENAI_API_KEY")
-    resource = os.getenv("AZURE_OPENAI_RESOURCE_NAME")
-    version = "2023-07-01-preview"
-
-    return AzureChatOpenAI(deployment_name=deployment,
-                           openai_api_key=key,
-                           openai_api_base=f"https://{resource}.openai.azure.com/",
-                           openai_api_version=version,
-                           **kwargs)
-
+    deployment = kwargs.pop(
+        "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
+    )
+    key = kwargs.pop("openai_api_key", os.getenv("AZURE_OPENAI_API_KEY"))
+    resource = kwargs.pop("openai_api_base", os.getenv("AZURE_OPENAI_RESOURCE_NAME"))
+    version = kwargs.pop("openai_api_version", "2023-07-01-preview")
+
+    base_url = f"https://{resource}.openai.azure.com/" if resource else None
+
+    return AzureChatOpenAI(
+        deployment_name=deployment,
+        openai_api_key=key,
+        openai_api_base=base_url,
+        openai_api_version=version,
+        **kwargs,
+    )
+
 
 def weaviate_client():
-    weaviate_url = os.environ['WEAVIATE_URL']
-    weaviate_api_key = os.environ['WEAVIATE_API_KEY']
-    auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key)
+    weaviate_url = os.environ["WEAVIATE_URL"]
+    weaviate_api_key = os.environ["WEAVIATE_API_KEY"]
+    auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key)
+
+    return weaviate.Client(url=weaviate_url, auth_client_secret=auth_config)
 
-    return weaviate.Client(
-        url=weaviate_url,
-        auth_client_secret=auth_config
-    )
 
 def weaviate_vector_store(index_name: str, text_key: str, attributes: List[str] = []):
     client = weaviate_client()
 
-    return Weaviate(client=client,
-                    index_name=index_name,
-                    text_key=text_key,
-                    attributes=attributes)
+    return Weaviate(
+        client=client, index_name=index_name, text_key=text_key, attributes=attributes
+    )
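
Reviewer note (not part of the patch): a minimal, self-contained sketch of the request flow the refactor aims at, useful for eyeballing the EventConfig -> websocket -> chain -> prepare_response sequence without AWS, Weaviate, or OpenAI. All Fake* classes and the run_flow helper are hypothetical stand-ins for illustration only; they mirror the shapes used in chat.py but are not in the codebase.

import json


class FakeSocket:
    # Stands in for the Websocket helper; records what would be sent.
    def __init__(self):
        self.sent = []

    def send(self, payload):
        self.sent.append(payload)


class FakeDoc:
    # Stands in for a langchain Document-like object with a __dict__.
    def __init__(self, title):
        self.page_content = title
        self.source = "https://example.org/" + title


class FakeWeaviate:
    # Stands in for the weaviate_vector_store() return value.
    def similarity_search(self, question, k, additional=None):
        return [FakeDoc("doc-%d" % i) for i in range(k)]


class FakeChain:
    # Stands in for the load_qa_with_sources_chain() return value.
    def __call__(self, inputs):
        return {"output_text": "stub answer to: " + inputs["question"]}


def run_flow(event):
    # Mirrors prepare_response()'s happy path: parse the payload, run the
    # search, echo the question with its source documents, run the chain,
    # then send the normal (non-debug) response shape.
    payload = json.loads(event.get("body", "{}"))
    socket = FakeSocket()
    weaviate = FakeWeaviate()
    chain = FakeChain()

    question = payload.get("question")
    k = min(payload.get("k", 10), 100)  # DEFAULT_K / MAX_K from chat.py

    docs = weaviate.similarity_search(question, k=k, additional="certainty")
    original_question = {
        "question": question,
        "source_documents": [doc.__dict__ for doc in docs],
    }
    socket.send(original_question)

    response = chain({"question": question, "input_documents": docs})
    socket.send({"answer": response["output_text"], "ref": payload.get("ref")})
    return socket.sent


if __name__ == "__main__":
    event = {"body": json.dumps({"question": "What is in the collection?", "k": 3, "ref": "abc"})}
    for message in run_flow(event):
        print(message)

Running this prints the two websocket messages a client should expect per question: first the echoed question with its source documents, then the answer keyed by ref.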