Add translation in. Get LLM model from env var
woodwardmw committed Feb 22, 2024
1 parent a3b1ca9 commit 4f0c82e
Showing 6 changed files with 48 additions and 58 deletions.
2 changes: 1 addition & 1 deletion app/core/llm_framework/openai_langchain.py
@@ -34,7 +34,7 @@ def __init__(
self, # pylint: disable=super-init-not-called
# FIXME : Ideal to be able to mock the __init__ from tests
key: str = os.getenv("OPENAI_API_KEY", "dummy-for-test"),
model_name: str = "gpt-3.5-turbo",
model_name: str = os.getenv("OPENAI_LLM_NAME", "gpt-3.5-turbo"),
vectordb: VectordbInterface = Chroma(),
max_tokens_limit: int = int(
os.getenv("OPENAI_MAX_TOKEN_LIMIT", "3052")),
22 changes: 12 additions & 10 deletions app/core/llm_framework/openai_vanilla.py
@@ -30,7 +30,7 @@ def get_context(source_documents):
return context


def get_pre_prompt(context):
def get_pre_prompt(context, response_language="English"):
"""Constructs a pre-prompt for the conversation, including the context"""
chat_prefix = "The following is a conversation with an AI assistant for "
chat_prefix += "Bible translators. The assistant is"
@@ -39,6 +39,7 @@ def get_pre_prompt(context)
chat_prefix
+ "Read the paragraph below and answer the question, using only the information"
" in the context delimited by triple backticks. "
f" Your response should be in the {response_language} language."
"If the question cannot be answered based on the context alone, "
'write "Sorry, I had trouble answering this question based on the '
"information I found\n"
@@ -75,7 +76,7 @@ class OpenAIVanilla(LLMFrameworkInterface): # pylint: disable=too-few-public-methods
def __init__(
self, # pylint: disable=super-init-not-called
key: str = os.getenv("OPENAI_API_KEY"),
model_name: str = "gpt-3.5-turbo",
model_name: str = os.getenv("OPENAI_LLM_NAME", "gpt-3.5-turbo"),
vectordb: VectordbInterface = None, # What should this be by default?
) -> None:
"""Sets the API key and initializes library objects if any"""
@@ -90,22 +91,24 @@ def __init__(
self.vectordb = vectordb

def generate_text(
self, query: str, chat_history: List[Tuple[str, str]], **kwargs
self,
query: str,
chat_history: List[Tuple[str, str]],
response_language: str = "English",
**kwargs,
) -> dict:
"""Prompt completion for QA or Chat reponse, based on specific documents,
if provided"""
if len(kwargs) > 0:
log.warning(
"Unused arguments in VanillaOpenAI.generate_text(): ", **kwargs)
log.warning("Unused arguments in VanillaOpenAI.generate_text(): ", **kwargs)

# Vectordb results are currently returned based on the whole chat history.
# We'll need to figure out if this is optimal or not.
query_text = "\n".join(
[x[0] + "/n" + x[1][:50] + "\n" for x in chat_history])
query_text = "\n".join([x[0] + "/n" + x[1][:50] + "\n" for x in chat_history])
query_text += "\n" + query
source_documents = self.vectordb.get_relevant_documents(query_text)
context = get_context(source_documents)
pre_prompt = get_pre_prompt(context)
pre_prompt = get_pre_prompt(context, response_language=response_language)
prompt = append_query_to_prompt(pre_prompt, query, chat_history)
print(f"{prompt=}")

@@ -122,5 +125,4 @@ def generate_text(
}

except Exception as exe:
raise OpenAIException(
"While generating answer: " + str(exe)) from exe
raise OpenAIException("While generating answer: " + str(exe)) from exe
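A hypothetical call site for the new signature (the vector DB object and the question are invented; `OpenAIVanilla` and the keyword names are from the diff):

```python
llm = OpenAIVanilla(vectordb=my_vectordb)  # my_vectordb: some VectordbInterface
result = llm.generate_text(
    query="What does 'covenant' mean in Genesis 9?",
    chat_history=[],                 # list of (human, bot) message tuples
    response_language="French",      # forwarded to get_pre_prompt()
)
print(result["answer"])              # answer text, written in French
```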
6 changes: 3 additions & 3 deletions app/core/translation/__init__.py
@@ -2,8 +2,8 @@

import boto3

# with open('../iso639-1.json') as f:
# iso_639_1 = json.load(f)
with open('../iso639-1.json') as f:
iso_639_1 = json.load(f)


def translate_text(text:str):
@@ -20,7 +20,7 @@ def translate_text(text:str):
return {}

source_language_code = response.get('SourceLanguageCode', '').split('-')[0] # Extracting ISO 639-1 code part
# response['language'] = iso_639_1.get(source_language_code, "Unknown language")
response['language'] = iso_639_1.get(source_language_code, "Unknown language")

return response
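From the fields used here and in routers.py below (`TranslatedText`, `SourceLanguageCode`), `translate_text` is evidently a wrapper around AWS Translate via boto3. A minimal sketch of the underlying call (the client construction and region are assumptions):

```python
import boto3

translate = boto3.client("translate", region_name="us-east-1")  # region assumed
response = translate.translate_text(
    Text="Habari yako?",
    SourceLanguageCode="auto",  # let AWS detect the source language
    TargetLanguageCode="en",
)
# Keys the surrounding code relies on:
#   response["TranslatedText"]      -> "How are you?"
#   response["SourceLanguageCode"]  -> e.g. "sw" (may carry a region suffix,
#                                      hence the .split('-')[0] above)
```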

72 changes: 30 additions & 42 deletions app/routers.py
@@ -46,8 +46,7 @@
POSTGRES_DB_PORT = os.getenv("POSTGRES_DB_PORT", "5432")
POSTGRES_DB_NAME = os.getenv("POSTGRES_DB_NAME", "adotbcollection")
CHROMA_DB_PATH = os.environ.get("CHROMA_DB_PATH", "chromadb_store")
CHROMA_DB_COLLECTION = os.environ.get(
"CHROMA_DB_COLLECTION", "adotbcollection")
CHROMA_DB_COLLECTION = os.environ.get("CHROMA_DB_COLLECTION", "adotbcollection")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

UPLOAD_PATH = "./uploaded-files/"
@@ -205,8 +204,7 @@ def compose_vector_db_args(db_type, settings, embedding_config):
elif embedding_config.embeddingType == schema.EmbeddingType.OPENAI:
vectordb_args["embedding"] = OpenAIEmbedding()
else:
raise GenericException(
"This embedding type is not supported (yet)!")
raise GenericException("This embedding type is not supported (yet)!")

return vectordb_args

@@ -276,7 +274,7 @@ async def websocket_chat_endpoint(
question = chat_stack.transcription_framework.transcribe_audio(
received_bytes
)

start_human_q = schema.BotResponse(
sender=schema.SenderType.USER,
message=question,
@@ -288,28 +286,26 @@

if len(question) > 0:
translation_response = translate_text(question)
english_query_text = translation_response.translations[0].translated_text
query_language_code = translation_response.translations[0].detected_language_code
query_language = translation_response['language']
print(f'{query_language=}')
print(f'{english_query_text=}')
english_query_text = translation_response["TranslatedText"]
query_language = translation_response["language"]
bot_response = chat_stack.llm_framework.generate_text(
query=question, chat_history=chat_stack.chat_history
query=english_query_text,
chat_history=chat_stack.chat_history,
response_language=query_language,
)
log.debug(
"Human: {0}\nBot:{1}\nSources:{2}\n\n".format(
question,
bot_response['answer'],
[item.metadata['source']
for item in bot_response['source_documents']]
bot_response["answer"],
[
item.metadata["source"]
for item in bot_response["source_documents"]
],
)
)

chat_stack.chat_history.append(
(
bot_response["question"],
bot_response["answer"]
)
(bot_response["question"], bot_response["answer"])
)

# Construct a response
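One edge worth flagging in this flow: `translate_text()` returns `{}` when the AWS call fails (see the translation module above), so the dictionary lookups here would raise `KeyError`. A defensive variant, as a sketch rather than the repo's code:

```python
translation_response = translate_text(question)
if translation_response:
    english_query_text = translation_response["TranslatedText"]
    query_language = translation_response["language"]
else:
    # The translate call failed; degrade to treating the raw question
    # as English instead of raising KeyError.
    english_query_text = question
    query_language = "English"
```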
@@ -365,8 +361,7 @@ async def upload_sentences(
),
vectordb_type: schema.DatabaseType = Query(schema.DatabaseType.CHROMA),
vectordb_config: schema.DBSelector = Depends(schema.DBSelector),
embedding_config: schema.EmbeddingSelector = Depends(
schema.EmbeddingSelector),
embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector),
token: SecretStr = Query(
None, desc="Optional access token to be used if user accounts not present"
),
@@ -417,8 +412,7 @@ async def upload_text_file( # pylint: disable=too-many-arguments
),
vectordb_type: schema.DatabaseType = Query(schema.DatabaseType.CHROMA),
vectordb_config: schema.DBSelector = Depends(schema.DBSelector),
embedding_config: schema.EmbeddingSelector = Depends(
schema.EmbeddingSelector),
embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector),
token: SecretStr = Query(
None, desc="Optional access token to be used if user accounts not present"
),
@@ -481,15 +475,9 @@ async def upload_pdf_file( # pylint: disable=too-many-arguments
file_processor_type: schema.FileProcessorType = Query(
schema.FileProcessorType.LANGCHAIN
),
vectordb_type: schema.DatabaseType = Query(
schema.DatabaseType.CHROMA
),
vectordb_config: schema.DBSelector = Depends(
schema.DBSelector
),
embedding_config: schema.EmbeddingSelector = Depends(
schema.EmbeddingSelector
),
vectordb_type: schema.DatabaseType = Query(schema.DatabaseType.CHROMA),
vectordb_config: schema.DBSelector = Depends(schema.DBSelector),
embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector),
token: SecretStr = Query(
None, desc="Optional access token to be used if user accounts not present"
),
@@ -502,7 +490,9 @@ async def upload_pdf_file( # pylint: disable=too-many-arguments
"""
log.info("Access token used: %s", token)

vectordb_args = compose_vector_db_args(vectordb_type, vectordb_config, embedding_config)
vectordb_args = compose_vector_db_args(
vectordb_type, vectordb_config, embedding_config
)
data_stack = DataUploadPipeline()
data_stack.set_vectordb(vectordb_type, **vectordb_args)

@@ -532,6 +522,7 @@ async def upload_pdf_file( # pylint: disable=too-many-arguments
data_stack.vectordb.add_to_collection(docs=docs)
return {"message": "Documents added to DB"}


@router.post(
"/upload/csv-file",
response_model=schema.APIInfoResponse,
@@ -551,8 +542,7 @@ async def upload_csv_file( # pylint: disable=too-many-arguments
),
vectordb_type: schema.DatabaseType = Query(schema.DatabaseType.CHROMA),
vectordb_config: schema.DBSelector = Depends(schema.DBSelector),
embedding_config: schema.EmbeddingSelector = Depends(
schema.EmbeddingSelector),
embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector),
token: SecretStr = Query(
None, desc="Optional access token to be used if user accounts not present"
),
@@ -579,10 +569,10 @@ async def upload_csv_file( # pylint: disable=too-many-arguments
elif col_delimiter == schema.CsvColDelimiter.TAB:
col_delimiter = "\t"
docs = data_stack.file_processor.process_file(
file_path = f"{UPLOAD_PATH}{file_obj.filename}",
file_type = schema.FileType.CSV,
col_delimiter = col_delimiter,
)
file_path=f"{UPLOAD_PATH}{file_obj.filename}",
file_type=schema.FileType.CSV,
col_delimiter=col_delimiter,
)
data_stack.set_embedding(
embedding_config.embeddingType,
embedding_config.embeddingApiKey,
@@ -634,8 +624,7 @@ async def check_job_status(
async def get_source_tags(
db_type: schema.DatabaseType = schema.DatabaseType.CHROMA,
settings: schema.DBSelector = Depends(schema.DBSelector),
embedding_config: schema.EmbeddingSelector = Depends(
schema.EmbeddingSelector),
embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector),
token: SecretStr = Query(
None, desc="Optional access token to be used if user accounts not present"
),
@@ -680,8 +669,7 @@ async def login(
+ "Please confirm your email and then try to log in again.",
) from exe

raise PermissionException(
"Unauthorized access. Invalid credentials.") from exe
raise PermissionException("Unauthorized access. Invalid credentials.") from exe

return {
"message": "User logged in successfully",
2 changes: 1 addition & 1 deletion recipes/basic_chat.py
@@ -23,7 +23,7 @@
)
chat_stack.set_llm_framework(schema.LLMFrameworkType.LANGCHAIN,
api_key=os.getenv('OPENAI_API_KEY'),
model='gpt-3.5-turbo',
model=os.getenv('OPENAI_LLM_NAME', 'gpt-3.5-turbo'),
vectordb=chat_stack.vectordb)

##### Checking DB has content ##########
2 changes: 1 addition & 1 deletion recipes/postgres_openai_chat.py
@@ -31,7 +31,7 @@
user="test_user")
chat_stack.set_llm_framework(schema.LLMFrameworkType.LANGCHAIN,
api_key=os.getenv('OPENAI_API_KEY'),
model='gpt-3.5-turbo',
model=os.getenv('OPENAI_LLM_NAME', 'gpt-3.5-turbo'),
vectordb=chat_stack.vectordb)
chat_stack.label = RESOURCE_LABEL
