From 4f0c82e9a370514bb774a00873f590283295a0b9 Mon Sep 17 00:00:00 2001 From: Mark Date: Wed, 21 Feb 2024 21:52:18 -0800 Subject: [PATCH] Add translation in. Get LLM model from env var --- app/core/llm_framework/openai_langchain.py | 2 +- app/core/llm_framework/openai_vanilla.py | 22 ++++--- app/core/translation/__init__.py | 6 +- app/routers.py | 72 +++++++++------------- recipes/basic_chat.py | 2 +- recipes/postgres_openai_chat.py | 2 +- 6 files changed, 48 insertions(+), 58 deletions(-) diff --git a/app/core/llm_framework/openai_langchain.py b/app/core/llm_framework/openai_langchain.py index 1bc0a46..9ad303b 100644 --- a/app/core/llm_framework/openai_langchain.py +++ b/app/core/llm_framework/openai_langchain.py @@ -34,7 +34,7 @@ def __init__( self, # pylint: disable=super-init-not-called # FIXME : Ideal to be able to mock the __init__ from tests key: str = os.getenv("OPENAI_API_KEY", "dummy-for-test"), - model_name: str = "gpt-3.5-turbo", + model_name: str = os.getenv("OPENAI_LLM_NAME", "gpt-3.5-turbo"), vectordb: VectordbInterface = Chroma(), max_tokens_limit: int = int( os.getenv("OPENAI_MAX_TOKEN_LIMIT", "3052")), diff --git a/app/core/llm_framework/openai_vanilla.py b/app/core/llm_framework/openai_vanilla.py index 7ebf95c..07a99bb 100644 --- a/app/core/llm_framework/openai_vanilla.py +++ b/app/core/llm_framework/openai_vanilla.py @@ -30,7 +30,7 @@ def get_context(source_documents): return context -def get_pre_prompt(context): +def get_pre_prompt(context, response_language="English"): """Constructs a pre-prompt for the conversation, including the context""" chat_prefix = "The following is a conversation with an AI assistant for " chat_prefix += "Bible translators. The assistant is" @@ -39,6 +39,7 @@ def get_pre_prompt(context): chat_prefix + "Read the paragraph below and answer the question, using only the information" " in the context delimited by triple backticks. " + f" Your response should be in the {response_language} language." "If the question cannot be answered based on the context alone, " 'write "Sorry, I had trouble answering this question based on the ' "information I found\n" @@ -75,7 +76,7 @@ class OpenAIVanilla(LLMFrameworkInterface): # pylint: disable=too-few-public-me def __init__( self, # pylint: disable=super-init-not-called key: str = os.getenv("OPENAI_API_KEY"), - model_name: str = "gpt-3.5-turbo", + model_name: str = os.getenv("OPENAI_LLM_NAME", "gpt-3.5-turbo"), vectordb: VectordbInterface = None, # What should this be by default? ) -> None: """Sets the API key and initializes library objects if any""" @@ -90,22 +91,24 @@ def __init__( self.vectordb = vectordb def generate_text( - self, query: str, chat_history: List[Tuple[str, str]], **kwargs + self, + query: str, + chat_history: List[Tuple[str, str]], + response_language: str = "English", + **kwargs, ) -> dict: """Prompt completion for QA or Chat reponse, based on specific documents, if provided""" if len(kwargs) > 0: - log.warning( - "Unused arguments in VanillaOpenAI.generate_text(): ", **kwargs) + log.warning("Unused arguments in VanillaOpenAI.generate_text(): ", **kwargs) # Vectordb results are currently returned based on the whole chat history. # We'll need to figure out if this is optimal or not. - query_text = "\n".join( - [x[0] + "/n" + x[1][:50] + "\n" for x in chat_history]) + query_text = "\n".join([x[0] + "/n" + x[1][:50] + "\n" for x in chat_history]) query_text += "\n" + query source_documents = self.vectordb.get_relevant_documents(query_text) context = get_context(source_documents) - pre_prompt = get_pre_prompt(context) + pre_prompt = get_pre_prompt(context, response_language=response_language) prompt = append_query_to_prompt(pre_prompt, query, chat_history) print(f"{prompt=}") @@ -122,5 +125,4 @@ def generate_text( } except Exception as exe: - raise OpenAIException( - "While generating answer: " + str(exe)) from exe + raise OpenAIException("While generating answer: " + str(exe)) from exe diff --git a/app/core/translation/__init__.py b/app/core/translation/__init__.py index 445cb2b..69a3534 100644 --- a/app/core/translation/__init__.py +++ b/app/core/translation/__init__.py @@ -2,8 +2,8 @@ import boto3 -# with open('../iso639-1.json') as f: -# iso_639_1 = json.load(f) +with open('../iso639-1.json') as f: + iso_639_1 = json.load(f) def translate_text(text:str): @@ -20,7 +20,7 @@ def translate_text(text:str): return {} source_language_code = response.get('SourceLanguageCode', '').split('-')[0] # Extracting ISO 639-1 code part - # response['language'] = iso_639_1.get(source_language_code, "Unknown language") + response['language'] = iso_639_1.get(source_language_code, "Unknown language") return response diff --git a/app/routers.py b/app/routers.py index eee95a5..8ed91ef 100644 --- a/app/routers.py +++ b/app/routers.py @@ -46,8 +46,7 @@ POSTGRES_DB_PORT = os.getenv("POSTGRES_DB_PORT", "5432") POSTGRES_DB_NAME = os.getenv("POSTGRES_DB_NAME", "adotbcollection") CHROMA_DB_PATH = os.environ.get("CHROMA_DB_PATH", "chromadb_store") -CHROMA_DB_COLLECTION = os.environ.get( - "CHROMA_DB_COLLECTION", "adotbcollection") +CHROMA_DB_COLLECTION = os.environ.get("CHROMA_DB_COLLECTION", "adotbcollection") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") UPLOAD_PATH = "./uploaded-files/" @@ -205,8 +204,7 @@ def compose_vector_db_args(db_type, settings, embedding_config): elif embedding_config.embeddingType == schema.EmbeddingType.OPENAI: vectordb_args["embedding"] = OpenAIEmbedding() else: - raise GenericException( - "This embedding type is not supported (yet)!") + raise GenericException("This embedding type is not supported (yet)!") return vectordb_args @@ -276,7 +274,7 @@ async def websocket_chat_endpoint( question = chat_stack.transcription_framework.transcribe_audio( received_bytes ) - + start_human_q = schema.BotResponse( sender=schema.SenderType.USER, message=question, @@ -288,28 +286,26 @@ async def websocket_chat_endpoint( if len(question) > 0: translation_response = translate_text(question) - english_query_text = translation_response.translations[0].translated_text - query_language_code = translation_response.translations[0].detected_language_code - query_language = translation_response['language'] - print(f'{query_language=}') - print(f'{english_query_text=}') + english_query_text = translation_response["TranslatedText"] + query_language = translation_response["language"] bot_response = chat_stack.llm_framework.generate_text( - query=question, chat_history=chat_stack.chat_history + query=english_query_text, + chat_history=chat_stack.chat_history, + response_language=query_language, ) log.debug( "Human: {0}\nBot:{1}\nSources:{2}\n\n".format( question, - bot_response['answer'], - [item.metadata['source'] - for item in bot_response['source_documents']] + bot_response["answer"], + [ + item.metadata["source"] + for item in bot_response["source_documents"] + ], ) ) chat_stack.chat_history.append( - ( - bot_response["question"], - bot_response["answer"] - ) + (bot_response["question"], bot_response["answer"]) ) # Construct a response @@ -365,8 +361,7 @@ async def upload_sentences( ), vectordb_type: schema.DatabaseType = Query(schema.DatabaseType.CHROMA), vectordb_config: schema.DBSelector = Depends(schema.DBSelector), - embedding_config: schema.EmbeddingSelector = Depends( - schema.EmbeddingSelector), + embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector), token: SecretStr = Query( None, desc="Optional access token to be used if user accounts not present" ), @@ -417,8 +412,7 @@ async def upload_text_file( # pylint: disable=too-many-arguments ), vectordb_type: schema.DatabaseType = Query(schema.DatabaseType.CHROMA), vectordb_config: schema.DBSelector = Depends(schema.DBSelector), - embedding_config: schema.EmbeddingSelector = Depends( - schema.EmbeddingSelector), + embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector), token: SecretStr = Query( None, desc="Optional access token to be used if user accounts not present" ), @@ -481,15 +475,9 @@ async def upload_pdf_file( # pylint: disable=too-many-arguments file_processor_type: schema.FileProcessorType = Query( schema.FileProcessorType.LANGCHAIN ), - vectordb_type: schema.DatabaseType = Query( - schema.DatabaseType.CHROMA - ), - vectordb_config: schema.DBSelector = Depends( - schema.DBSelector - ), - embedding_config: schema.EmbeddingSelector = Depends( - schema.EmbeddingSelector - ), + vectordb_type: schema.DatabaseType = Query(schema.DatabaseType.CHROMA), + vectordb_config: schema.DBSelector = Depends(schema.DBSelector), + embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector), token: SecretStr = Query( None, desc="Optional access token to be used if user accounts not present" ), @@ -502,7 +490,9 @@ async def upload_pdf_file( # pylint: disable=too-many-arguments """ log.info("Access token used: %s", token) - vectordb_args = compose_vector_db_args(vectordb_type, vectordb_config, embedding_config) + vectordb_args = compose_vector_db_args( + vectordb_type, vectordb_config, embedding_config + ) data_stack = DataUploadPipeline() data_stack.set_vectordb(vectordb_type, **vectordb_args) @@ -532,6 +522,7 @@ async def upload_pdf_file( # pylint: disable=too-many-arguments data_stack.vectordb.add_to_collection(docs=docs) return {"message": "Documents added to DB"} + @router.post( "/upload/csv-file", response_model=schema.APIInfoResponse, @@ -551,8 +542,7 @@ async def upload_csv_file( # pylint: disable=too-many-arguments ), vectordb_type: schema.DatabaseType = Query(schema.DatabaseType.CHROMA), vectordb_config: schema.DBSelector = Depends(schema.DBSelector), - embedding_config: schema.EmbeddingSelector = Depends( - schema.EmbeddingSelector), + embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector), token: SecretStr = Query( None, desc="Optional access token to be used if user accounts not present" ), @@ -579,10 +569,10 @@ async def upload_csv_file( # pylint: disable=too-many-arguments elif col_delimiter == schema.CsvColDelimiter.TAB: col_delimiter = "\t" docs = data_stack.file_processor.process_file( - file_path = f"{UPLOAD_PATH}{file_obj.filename}", - file_type = schema.FileType.CSV, - col_delimiter = col_delimiter, - ) + file_path=f"{UPLOAD_PATH}{file_obj.filename}", + file_type=schema.FileType.CSV, + col_delimiter=col_delimiter, + ) data_stack.set_embedding( embedding_config.embeddingType, embedding_config.embeddingApiKey, @@ -634,8 +624,7 @@ async def check_job_status( async def get_source_tags( db_type: schema.DatabaseType = schema.DatabaseType.CHROMA, settings: schema.DBSelector = Depends(schema.DBSelector), - embedding_config: schema.EmbeddingSelector = Depends( - schema.EmbeddingSelector), + embedding_config: schema.EmbeddingSelector = Depends(schema.EmbeddingSelector), token: SecretStr = Query( None, desc="Optional access token to be used if user accounts not present" ), @@ -680,8 +669,7 @@ async def login( + "Please confirm your email and then try to log in again.", ) from exe - raise PermissionException( - "Unauthorized access. Invalid credentials.") from exe + raise PermissionException("Unauthorized access. Invalid credentials.") from exe return { "message": "User logged in successfully", diff --git a/recipes/basic_chat.py b/recipes/basic_chat.py index 8d36717..911ec5b 100644 --- a/recipes/basic_chat.py +++ b/recipes/basic_chat.py @@ -23,7 +23,7 @@ ) chat_stack.set_llm_framework(schema.LLMFrameworkType.LANGCHAIN, api_key=os.getenv('OPENAI_API_KEY'), - model='gpt-3.5-turbo', + model=os.getenv('OPENAI_LLM_NAME', 'gpt-3.5-turbo') vectordb=chat_stack.vectordb) ##### Checking DB has content ########## diff --git a/recipes/postgres_openai_chat.py b/recipes/postgres_openai_chat.py index 1b24c19..6fb49f4 100644 --- a/recipes/postgres_openai_chat.py +++ b/recipes/postgres_openai_chat.py @@ -31,7 +31,7 @@ user="test_user") chat_stack.set_llm_framework(schema.LLMFrameworkType.LANGCHAIN, api_key=os.getenv('OPENAI_API_KEY'), - model='gpt-3.5-turbo', + model=os.getenv('OPENAI_LLM_NAME', 'gpt-3.5-turbo'), vectordb=chat_stack.vectordb) chat_stack.label = RESOURCE_LABEL