Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Nov 16, 2023
1 parent 0a17e2f commit f76d36d
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 117 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ help:
echo "make test-python | run python tests"
echo "make cover-node | run node tests with coverage"
echo "make cover-python | run python tests with coverage"
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json python/requirements.txt python/src/requirements.txt
.aws-sam/build.toml: ./template.yaml node/package-lock.json node/src/package-lock.json chat/src/requirements.txt
sam build --cached --parallel
deps-node:
cd node && npm ci
Expand Down
237 changes: 144 additions & 93 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import boto3
import json
import os
import setup

import boto3
import tiktoken
from helpers.apitoken import ApiToken
from helpers.prompts import document_template, prompt_template
from src.helpers.apitoken import ApiToken
from src.helpers.prompts import document_template, prompt_template
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.prompts import PromptTemplate
from openai.error import InvalidRequestError
from src.setup import (
weaviate_client,
weaviate_vector_store,
openai_chat_client,
)

DEFAULT_INDEX = "DCWork"

DEFAULT_INDEX = "Work"
DEFAULT_KEY = "title"
DEFAULT_K = 10
MAX_K = 100
Expand Down Expand Up @@ -40,118 +46,163 @@ def on_llm_new_token(self, token: str, **kwargs):
self.socket.send({"token": token})


def handler(event, context):
try:
payload = json.loads(event.get("body", "{}"))
class EventConfig:
    """Typed view over an incoming websocket chat event.

    Parses the Lambda event body once and exposes the derived configuration
    (auth, prompt, retrieval settings) as attributes. Attributes are assigned
    in dependency order: the original code read ``self.payload`` and
    ``self.prompt_text`` before they were assigned.
    """

    def __init__(self, event):
        # Keep the raw event: setup_websocket reads requestContext from it.
        self.event = event
        # Parse the body first — nearly every other attribute reads from it.
        self.payload = json.loads(event.get("body", "{}"))
        self.api_token = ApiToken(signed_token=self.payload.get("auth"))
        self.is_logged_in = self.api_token.is_logged_in()
        self.debug_mode = self._get_debug_mode()
        self.index_name = self.payload.get("index", DEFAULT_INDEX)
        self.k = min(self.payload.get("k", DEFAULT_K), MAX_K)
        self.question = self.payload.get("question")
        # ref is echoed back in responses so the client can correlate them.
        self.ref = self.payload.get("ref")
        self.text_key = self.payload.get("text_key", DEFAULT_KEY)
        # text_key must be known before it can be filtered out of the list.
        self.attributes = [
            item
            for item in self._get_attributes()
            if item not in [self.text_key, "source"]
        ]
        # Only superusers may override the prompt template.
        self.prompt_text = (
            self.payload.get("prompt", prompt_template())
            if self.api_token.is_superuser()
            else prompt_template()
        )
        self.prompt = PromptTemplate(
            template=self.prompt_text, input_variables=["question", "context"]
        )
        self.document_prompt = PromptTemplate(
            template=document_template(self.attributes),
            input_variables=["page_content", "source"] + self.attributes,
        )

request_context = event.get("requestContext", {})
def setup_websocket(self):
request_context = self.event.get("requestContext", {})
connection_id = request_context.get("connectionId")
endpoint_url = f'https://{request_context.get("domainName")}/{request_context.get("stage")}'
ref = payload.get("ref")
socket = Websocket(
ref = self.payload.get("ref")
self.socket = Websocket(
connection_id=connection_id, endpoint_url=endpoint_url, ref=ref
)

api_token = ApiToken(signed_token=payload.get("auth"))
if not api_token.is_logged_in():
socket.send({"statusCode": 401, "body": "Unauthorized"})
return {"statusCode": 401, "body": "Unauthorized"}

debug_mode = payload.get("debug", False) and api_token.is_superuser()

question = payload.get("question")
index_name = payload.get("index", payload.get("index", DEFAULT_INDEX))
print(f"Searching index {index_name}")
text_key = payload.get("text_key", DEFAULT_KEY)
attributes = [
item
for item in get_attributes(
index_name, payload if api_token.is_superuser() else {}
)
if item not in [text_key, "source"]
]
    def setup_llm_request(self):
        """Prepare the retrieval + LLM pipeline for this request.

        Order matters: the chain built last reads ``self.client`` created by
        the chat-client step.
        """
        self._setup_vector_store()
        self._setup_chat_client()
        self._setup_chain()

weaviate = setup.weaviate_vector_store(
index_name=index_name, text_key=text_key, attributes=attributes + ["source"]
def _setup_vector_store(self):
self.weaviate = weaviate_vector_store(
index_name=self.index_name,
text_key=self.text_key,
attributes=self.attributes + ["source"],
)

client = setup.openai_chat_client(
callbacks=[StreamingSocketCallbackHandler(socket, debug_mode)],
def _setup_chat_client(self):
self.client = openai_chat_client(
callbacks=[StreamingSocketCallbackHandler(self.socket, self.debug_mode)],
streaming=True,
)

prompt_text = (
payload.get("prompt", prompt_template())
if api_token.is_superuser()
else prompt_template()
)
prompt = PromptTemplate(
template=prompt_text, input_variables=["question", "context"]
)

document_prompt = PromptTemplate(
template=document_template(attributes),
input_variables=["page_content", "source"] + attributes,
)

k = min(payload.get("k", DEFAULT_K), MAX_K)
docs = weaviate.similarity_search(question, k=k, additional="certainty")
chain = load_qa_with_sources_chain(
client,
def _setup_chain(self):
self.chain = load_qa_with_sources_chain(
self.client,
chain_type="stuff",
prompt=prompt,
document_prompt=document_prompt,
prompt=self.prompt,
document_prompt=self.document_prompt,
document_variable_name="context",
verbose=to_bool(os.getenv("VERBOSE")),
)

try:
doc_response = [doc.__dict__ for doc in docs]
original_question = {"question": question, "source_documents": doc_response}
socket.send(original_question)
response = chain({"question": question, "input_documents": docs})
if debug_mode:
final_response = {
"answer": response["output_text"],
"attributes": attributes,
"isSuperuser": api_token.is_superuser(),
"prompt": prompt_text,
"ref": ref,
"k": k,
"original_question": original_question,
"token_counts": {
"question": count_tokens(question),
"answer": count_tokens(response["output_text"]),
"prompt": count_tokens(prompt_text),
"source_documents": count_tokens(doc_response),
},
}
else:
final_response = {"answer": response["output_text"], "ref": ref}
except InvalidRequestError as err:
final_response = {
"question": question,
"error": str(err),
"source_documents": [],
}

socket.send(final_response)
def _get_debug_mode(self):
debug = self.payload.get("debug", False)
return debug and self.api_token.is_superuser()

def _get_attributes(self):
request_attributes = self.payload.get("attributes", None)
if request_attributes is not None:
return request_attributes

client = weaviate_client()
schema = client.schema.get(self.index_name)
names = [prop["name"] for prop in schema.get("properties")]
print(f"Retrieved attributes: {names}")
return names


def handler(event, _context):
    """Lambda entry point for chat requests arriving over a websocket.

    Returns an API Gateway status dict; the actual answer is streamed back
    over the websocket opened by ``EventConfig.setup_websocket``.
    """
    try:
        config = EventConfig(event)
        config.setup_websocket()

        if not config.is_logged_in:
            config.socket.send({"statusCode": 401, "body": "Unauthorized"})
            return {"statusCode": 401, "body": "Unauthorized"}

        config.setup_llm_request()
        final_response = prepare_response(config)
        config.socket.send(final_response)
        return {"statusCode": 200}
    except Exception:
        # NOTE(review): the raw event body includes the signed auth token —
        # consider redacting before logging.
        print(event)
        # Bare raise preserves the original traceback idiomatically.
        raise


def get_attributes(index, payload):
request_attributes = payload.get("attributes", None)
if request_attributes is not None:
return request_attributes
def get_and_send_original_question(config, docs):
    """Send the question plus retrieved source documents over the socket.

    Returns the payload so the caller can embed it in debug output.
    """
    payload = {
        "question": config.question,
        "source_documents": [document.__dict__ for document in docs],
    }
    config.socket.send(payload)
    return payload

client = setup.weaviate_client()
schema = client.schema.get(index)
names = [prop["name"] for prop in schema.get("properties")]
print(f"Retrieved attributes: {names}")
return names

def count_tokens(config, response, original_question):
    """Collect the texts whose token usage is reported in debug responses.

    NOTE(review): despite the name, this returns the raw texts rather than
    token counts, and it collides with the one-argument ``count_tokens``
    defined later in this module — confirm which definition is intended.
    """
    texts = {
        "question": config.question,
        "answer": response["output_text"],
        "prompt": config.prompt_text,
        "source_documents": original_question["source_documents"],
    }
    return texts


def prepare_debug_response(config, response, original_question):
    """Build the verbose response payload sent to superusers in debug mode."""
    # NOTE(review): config.ref is read here, but EventConfig's setup_websocket
    # (as written) only binds a local `ref` — verify the attribute exists.
    debug_payload = {
        "answer": response["output_text"],
        "attributes": config.attributes,
        "is_superuser": config.api_token.is_superuser(),
        "prompt": config.prompt_text,
        "ref": config.ref,
        "k": config.k,
        "original_question": original_question,
    }
    debug_payload["token_counts"] = count_tokens(config, response, original_question)
    return debug_payload


def prepare_normal_response(config, response):
    """Build the minimal response payload: the answer plus the request ref."""
    answer = response["output_text"]
    return {"answer": answer, "ref": config.ref}


def prepare_response(config):
    """Run similarity search plus the QA chain and build a response payload.

    Falls back to an error payload when OpenAI rejects the request — e.g.
    when the stuffed context exceeds the model's limit.
    """
    try:
        documents = config.weaviate.similarity_search(
            config.question, k=config.k, additional="certainty"
        )
        original_question = get_and_send_original_question(config, documents)
        response = config.chain(
            {"question": config.question, "input_documents": documents}
        )
        if config.debug_mode:
            return prepare_debug_response(config, response, original_question)
        return prepare_normal_response(config, response)
    except InvalidRequestError as err:
        return {
            "question": config.question,
            "error": str(err),
            "source_documents": [],
        }


def count_tokens(val):
Expand Down
50 changes: 27 additions & 23 deletions chat/src/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,37 @@
import os
import weaviate


def openai_chat_client(**kwargs):
    """Build an AzureChatOpenAI client, with env-var defaults.

    Connection settings may be overridden via keyword arguments; anything
    else in ``kwargs`` (callbacks, streaming, ...) is forwarded unchanged.

    Fix: the overridable keys are *popped* out of kwargs — the previous
    ``kwargs.get`` left them in, so a caller-supplied override was passed to
    AzureChatOpenAI twice (TypeError: duplicate keyword argument).
    """
    deployment = kwargs.pop(
        "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
    )
    key = kwargs.pop("openai_api_key", os.getenv("AZURE_OPENAI_API_KEY"))
    # NOTE(review): "openai_api_base" here is treated as a *resource name*
    # and wrapped into a full URL below — confirm callers expect that.
    resource = kwargs.pop("openai_api_base", os.getenv("AZURE_OPENAI_RESOURCE_NAME"))
    version = kwargs.pop("openai_api_version", "2023-07-01-preview")

    base_url = f"https://{resource}.openai.azure.com/" if resource else None

    return AzureChatOpenAI(
        deployment_name=deployment,
        openai_api_key=key,
        openai_api_base=base_url,
        openai_api_version=version,
        **kwargs,
    )


def weaviate_client():
    """Create an authenticated Weaviate client from environment variables.

    Raises KeyError if WEAVIATE_URL or WEAVIATE_API_KEY is unset — failing
    fast is preferable to connecting with empty credentials. (The diff had
    left duplicated assignments and two return statements; deduplicated.)
    """
    weaviate_url = os.environ["WEAVIATE_URL"]
    weaviate_api_key = os.environ["WEAVIATE_API_KEY"]
    auth_config = weaviate.AuthApiKey(api_key=weaviate_api_key)

    return weaviate.Client(url=weaviate_url, auth_client_secret=auth_config)

def weaviate_vector_store(index_name: str, text_key: str, attributes: List[str] = None):
    """Build a LangChain Weaviate vector store over the given index.

    Fixes: removes the duplicated diff-residue lines (two ``client =``
    assignments / two returns) and replaces the shared mutable default
    ``attributes=[]`` with a None sentinel.
    """
    if attributes is None:
        attributes = []
    client = weaviate_client()

    return Weaviate(
        client=client, index_name=index_name, text_key=text_key, attributes=attributes
    )

0 comments on commit f76d36d

Please sign in to comment.