From 927b22eeb7ea6930c9dd9092080f32c9d6cca7f4 Mon Sep 17 00:00:00 2001 From: "Kavitha.Raju" Date: Mon, 11 Dec 2023 12:23:41 +0530 Subject: [PATCH 1/4] Fix test for chat endpoint to work with default tech stack --- .github/workflows/check_on_push.yml | 2 +- app/core/llm_framework/openai_vanilla.py | 3 +- app/tests/test_chat_on_websocket.py | 59 ++++++++++++++++-------- 3 files changed, 44 insertions(+), 20 deletions(-) diff --git a/.github/workflows/check_on_push.yml b/.github/workflows/check_on_push.yml index 5b46b90..b1413fe 100644 --- a/.github/workflows/check_on_push.yml +++ b/.github/workflows/check_on_push.yml @@ -43,4 +43,4 @@ jobs: - name: Run Pytest working-directory: ./app - run: python3 -m pytest -s tests/test_basics.py tests/test_dataupload.py + run: python3 -m pytest -s -m "not with_llm" diff --git a/app/core/llm_framework/openai_vanilla.py b/app/core/llm_framework/openai_vanilla.py index 05d500c..96609cd 100644 --- a/app/core/llm_framework/openai_vanilla.py +++ b/app/core/llm_framework/openai_vanilla.py @@ -3,6 +3,7 @@ from typing import List, Tuple import openai +from openai import ChatCompletion from core.llm_framework import LLMFrameworkInterface from core.vectordb import VectordbInterface @@ -110,7 +111,7 @@ def generate_text( print(f"{prompt=}") try: - response = openai.ChatCompletion.create( + response = ChatCompletion.create( model=self.model_name, temperature=0, messages=[{"role": "user", "content": prompt}], diff --git a/app/tests/test_chat_on_websocket.py b/app/tests/test_chat_on_websocket.py index 6a0cdd9..ad99a23 100644 --- a/app/tests/test_chat_on_websocket.py +++ b/app/tests/test_chat_on_websocket.py @@ -1,6 +1,10 @@ """Test the chat websocket""" import os +import json +import pytest from . import client +from app import schema + from .test_dataupload import test_data_upload_markdown from .test_dataupload import test_data_upload_processed_sentences @@ -9,11 +13,10 @@ # pylint: disable=too-many-function-args COMMON_CONNECTION_ARGS = { "user": "xxx", - "llmFrameworkType": "openai-langchain", - "vectordbType": "chroma-db", - "dbPath": "chromadb_store_test", - "collectionName": "aDotBCollection_test", - "labels": ["NIV bible", "ESV-Bible", "translationwords"], + "llmFrameworkType": schema.LLMFrameworkType.VANILLA.value, + "vectordbType": schema.DatabaseType.POSTGRES.value, + "collectionName": "adotdcollection_test", + "labels": ["NIV bible", "ESV-Bible", "translationwords", "open-access"], "token": admin_token, } @@ -30,41 +33,61 @@ def assert_positive_bot_response(resp_json): assert resp_json["sender"] in ["Bot", "You"] -def test_chat_websocket_connection(fresh_db): +def test_chat_websocket_connection(mocker, fresh_db): """Check if websocket is connecting to and is bot responding""" + mocker.patch("app.routers.Supabase.check_token", + return_value={"user_id": "1111"}) + mocker.patch("app.routers.Supabase.check_role", return_value=True) + mocker.patch("app.routers.Supabase.get_accessible_labels", return_value=["mock-label"]) + mocker.patch("app.core.llm_framework.openai_vanilla.ChatCompletion.create", + return_value={"choices":[{"message":{"content":"Mock response"}}]}) + # mocker.patch("app.routers.ConversationPipeline.llm_framework.generate_text", + # return_value={"question": "Mock Question", + # "answer": "Mock answer", + # "source_documents": [], + # }) args = COMMON_CONNECTION_ARGS.copy() args["dbPath"] = fresh_db["dbPath"] args["collectionName"] = fresh_db["collectionName"] with client.websocket_connect("/chat", params=args) as websocket: - websocket.send_text("Hello") + websocket.send_bytes(json.dumps({"message":"Hello"}).encode('utf-8')) data = websocket.receive_json() assert_positive_bot_response(data) + assert "Mock" in data['message'] - -def test_chat_based_on_translationwords(fresh_db): +@pytest.mark.with_llm +def test_chat_based_on_translationwords_with_llm(mocker, fresh_db): """Add some docs and ask questions on it, with default configs""" + mocker.patch("app.routers.Supabase.check_token", + return_value={"user_id": "1111"}) + mocker.patch("app.routers.Supabase.check_role", return_value=True) + mocker.patch("app.routers.Supabase.get_accessible_labels", + return_value=COMMON_CONNECTION_ARGS['labels']) + # mocker.patch("app.core.llm_framework.openai_vanilla.ChatCompletion.create", + # return_value={"choices":[{"message":{"content":"Mock response"}}]}) args = COMMON_CONNECTION_ARGS.copy() args["dbPath"] = fresh_db["dbPath"] args["collectionName"] = fresh_db["collectionName"] - test_data_upload_markdown(fresh_db, None, None, None) - test_data_upload_processed_sentences(fresh_db, None, None, None) + test_data_upload_markdown(mocker, schema.DatabaseType.POSTGRES, + schema.FileProcessorType.LANGCHAIN, fresh_db) + test_data_upload_processed_sentences(mocker, schema.DatabaseType.POSTGRES, fresh_db) with client.websocket_connect("/chat", params=args) as websocket: websocket.send_text("Hello") data = websocket.receive_json() assert_positive_bot_response(data) - websocket.send_text("Can you tell me about angels?") + websocket.send_bytes(json.dumps({"message":"Can you tell me about angels?"}).encode('utf-8')) data = websocket.receive_json() assert_positive_bot_response(data) assert "angel" in data["message"].lower() assert "spirit" in data["message"].lower() assert "god" in data["message"].lower() - websocket.send_text("What happended in the beginning?") - data = websocket.receive_json() - assert_positive_bot_response(data) - assert "heaven" in data["message"].lower() - assert "god" in data["message"].lower() - assert "earth" in data["message"].lower() + # websocket.send_bytes(json.dumps({"message":"How was earth in the beginning?"}).encode('utf-8')) + # data = websocket.receive_json() + # assert_positive_bot_response(data) + # assert "heaven" in data["message"].lower() + # assert "god" in data["message"].lower() + # assert "earth" in data["message"].lower() From 9d6b5f13e5cb959c2c444b4ad2da78a6f7539fcf Mon Sep 17 00:00:00 2001 From: "Kavitha.Raju" Date: Mon, 11 Dec 2023 12:29:51 +0530 Subject: [PATCH 2/4] Fix linting issues --- app/tests/test_chat_on_websocket.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/app/tests/test_chat_on_websocket.py b/app/tests/test_chat_on_websocket.py index ad99a23..c0adf14 100644 --- a/app/tests/test_chat_on_websocket.py +++ b/app/tests/test_chat_on_websocket.py @@ -2,8 +2,8 @@ import os import json import pytest -from . import client from app import schema +from . import client from .test_dataupload import test_data_upload_markdown from .test_dataupload import test_data_upload_processed_sentences @@ -41,11 +41,6 @@ def test_chat_websocket_connection(mocker, fresh_db): mocker.patch("app.routers.Supabase.get_accessible_labels", return_value=["mock-label"]) mocker.patch("app.core.llm_framework.openai_vanilla.ChatCompletion.create", return_value={"choices":[{"message":{"content":"Mock response"}}]}) - # mocker.patch("app.routers.ConversationPipeline.llm_framework.generate_text", - # return_value={"question": "Mock Question", - # "answer": "Mock answer", - # "source_documents": [], - # }) args = COMMON_CONNECTION_ARGS.copy() args["dbPath"] = fresh_db["dbPath"] args["collectionName"] = fresh_db["collectionName"] @@ -56,15 +51,13 @@ def test_chat_websocket_connection(mocker, fresh_db): assert "Mock" in data['message'] @pytest.mark.with_llm -def test_chat_based_on_translationwords_with_llm(mocker, fresh_db): +def test_chat_based_on_data_in_db(mocker, fresh_db): """Add some docs and ask questions on it, with default configs""" mocker.patch("app.routers.Supabase.check_token", return_value={"user_id": "1111"}) mocker.patch("app.routers.Supabase.check_role", return_value=True) mocker.patch("app.routers.Supabase.get_accessible_labels", return_value=COMMON_CONNECTION_ARGS['labels']) - # mocker.patch("app.core.llm_framework.openai_vanilla.ChatCompletion.create", - # return_value={"choices":[{"message":{"content":"Mock response"}}]}) args = COMMON_CONNECTION_ARGS.copy() args["dbPath"] = fresh_db["dbPath"] args["collectionName"] = fresh_db["collectionName"] @@ -78,14 +71,16 @@ def test_chat_based_on_translationwords_with_llm(mocker, fresh_db): data = websocket.receive_json() assert_positive_bot_response(data) - websocket.send_bytes(json.dumps({"message":"Can you tell me about angels?"}).encode('utf-8')) + websocket.send_bytes( + json.dumps({"message":"Can you tell me about angels?"}).encode('utf-8')) data = websocket.receive_json() assert_positive_bot_response(data) assert "angel" in data["message"].lower() assert "spirit" in data["message"].lower() assert "god" in data["message"].lower() - # websocket.send_bytes(json.dumps({"message":"How was earth in the beginning?"}).encode('utf-8')) + # websocket.send_bytes( + # json.dumps({"message":"How was earth in the beginning?"}).encode('utf-8')) # data = websocket.receive_json() # assert_positive_bot_response(data) # assert "heaven" in data["message"].lower() From 43cfd053da2404bf518b05eff52a4623a0041fd3 Mon Sep 17 00:00:00 2001 From: "Kavitha.Raju" Date: Mon, 11 Dec 2023 13:54:04 +0530 Subject: [PATCH 3/4] handle no OPENAI key case in tests --- app/core/llm_framework/openai_langchain.py | 2 +- app/core/pipeline/__init__.py | 13 ++++++++----- app/routers.py | 8 ++++++-- app/schema.py | 3 +++ app/tests/test_chat_on_websocket.py | 4 +++- app/tests/test_dataupload.py | 3 +++ 6 files changed, 24 insertions(+), 9 deletions(-) diff --git a/app/core/llm_framework/openai_langchain.py b/app/core/llm_framework/openai_langchain.py index 1bc0a46..872d1ac 100644 --- a/app/core/llm_framework/openai_langchain.py +++ b/app/core/llm_framework/openai_langchain.py @@ -33,7 +33,7 @@ class LangchainOpenAI(LLMFrameworkInterface): def __init__( self, # pylint: disable=super-init-not-called # FIXME : Ideal to be able to mock the __init__ from tests - key: str = os.getenv("OPENAI_API_KEY", "dummy-for-test"), + key: str = os.getenv("OPENAI_API_KEY"), model_name: str = "gpt-3.5-turbo", vectordb: VectordbInterface = Chroma(), max_tokens_limit: int = int( diff --git a/app/core/pipeline/__init__.py b/app/core/pipeline/__init__.py index 86b00f9..0dbac3a 100644 --- a/app/core/pipeline/__init__.py +++ b/app/core/pipeline/__init__.py @@ -126,8 +126,10 @@ def __init__( file_processor: FileProcessorInterface = LangchainLoader, embedding: EmbeddingInterface = SentenceTransformerEmbedding(), vectordb: VectordbInterface = Chroma(), - llm_framework: LLMFrameworkInterface = LangchainOpenAI(), + llm_framework: LLMFrameworkInterface = LangchainOpenAI, + llm_api_key: str | None = None, transcription_framework: AudioTranscriptionInterface = WhisperAudioTranscription, + transcription_api_key: str | None = None, ) -> None: """Instantiate with default tech stack""" super().__init__(file_processor, embedding, vectordb) @@ -138,7 +140,8 @@ def __init__( self.embedding = embedding self.vectordb = vectordb self.llm_framework = llm_framework - self.transcription_framework = transcription_framework() + self.transcription_framework = transcription_framework(key=transcription_api_key) + self.llm_framework = llm_framework(key=llm_api_key) def set_llm_framework( self, @@ -159,7 +162,7 @@ def set_llm_framework( path=vectordb.db_path, collection_name=vectordb.collection_name, ) - self.llm_framework = LangchainOpenAI(vectordb=vectordb) + self.llm_framework = LangchainOpenAI(vectordb=vectordb, api_key=api_key) elif choice == schema.LLMFrameworkType.VANILLA: if isinstance(vectordb, Chroma): vectordb = ChromaLC( @@ -168,7 +171,7 @@ def set_llm_framework( path=vectordb.db_path, collection_name=vectordb.collection_name, ) - self.llm_framework = OpenAIVanilla(vectordb=vectordb) + self.llm_framework = OpenAIVanilla(vectordb=vectordb, key=api_key) def set_transcription_framework( self, @@ -181,4 +184,4 @@ def set_transcription_framework( self.transcription_framework.api_key = api_key self.transcription_framework.model_name = model_name if choice == schema.AudioTranscriptionType.WHISPER: - self.transcription_framework = WhisperAudioTranscription() + self.transcription_framework = WhisperAudioTranscription(key=api_key) diff --git a/app/routers.py b/app/routers.py index 63fe576..e245519 100644 --- a/app/routers.py +++ b/app/routers.py @@ -228,7 +228,10 @@ async def websocket_chat_endpoint( if token: log.info("User, connecting with token, %s", token) await websocket.accept() - chat_stack = ConversationPipeline(user="XXX", labels=labels) + chat_stack = ConversationPipeline(user="XXX", + labels=labels, + transcription_api_key=settings.transcriptionApiKey, + llm_api_key=settings.llmApiKey) vectordb_args = compose_vector_db_args( settings.vectordbType, @@ -246,7 +249,8 @@ async def websocket_chat_endpoint( chat_stack.set_llm_framework( settings.llmFrameworkType, vectordb=chat_stack.vectordb, **llm_args ) - chat_stack.set_transcription_framework(settings.transcriptionFrameworkType) + chat_stack.set_transcription_framework(settings.transcriptionFrameworkType, + api_key=settings.transcriptionApiKey) # Not implemented using custom embeddings diff --git a/app/schema.py b/app/schema.py index d02918f..c56a228 100644 --- a/app/schema.py +++ b/app/schema.py @@ -179,6 +179,9 @@ class ChatPipelineSelector(BaseModel): AudioTranscriptionType.WHISPER, desc="The framework through which audio transcription is handled", ) + transcriptionApiKey: str | None = Field( + None, desc="If using a cloud service, like OpenAI, the key obtained from them" + ) # class UserPrompt(BaseModel): # not using this as we recieve string from websocket diff --git a/app/tests/test_chat_on_websocket.py b/app/tests/test_chat_on_websocket.py index c0adf14..c8bf574 100644 --- a/app/tests/test_chat_on_websocket.py +++ b/app/tests/test_chat_on_websocket.py @@ -18,6 +18,8 @@ "collectionName": "adotdcollection_test", "labels": ["NIV bible", "ESV-Bible", "translationwords", "open-access"], "token": admin_token, + "transcriptionApiKey":"dummy-key-for-openai", + "llmApiKey":"dummy-key-for-openai", } @@ -33,7 +35,7 @@ def assert_positive_bot_response(resp_json): assert resp_json["sender"] in ["Bot", "You"] -def test_chat_websocket_connection(mocker, fresh_db): +def test_chat_websocket_connection(mocker, fresh_db, monkeypatch): """Check if websocket is connecting to and is bot responding""" mocker.patch("app.routers.Supabase.check_token", return_value={"user_id": "1111"}) diff --git a/app/tests/test_dataupload.py b/app/tests/test_dataupload.py index 44d184e..342f6e7 100644 --- a/app/tests/test_dataupload.py +++ b/app/tests/test_dataupload.py @@ -61,6 +61,7 @@ def test_data_upload_processed_sentences(mocker, vectordb, fresh_db): "dbPath": fresh_db["dbPath"], "collectionName": fresh_db["collectionName"], "embeddingType": schema.EmbeddingType.HUGGINGFACE_DEFAULT.value, + "llmApiKey":"dummy-value", }, json=SENT_DATA, ) @@ -101,6 +102,7 @@ def test_data_upload_markdown(mocker, vectordb, chunker, fresh_db): "dbPath": fresh_db["dbPath"], "collectionName": fresh_db["collectionName"], "token": ADMIN_TOKEN, + "llmApiKey":"dummy-value", } # json={"vectordb_config": fresh_db} ) @@ -132,6 +134,7 @@ def test_data_upload_csv(mocker, vectordb, fresh_db): "dbPath": fresh_db["dbPath"], "collectionName": fresh_db["collectionName"], "token": ADMIN_TOKEN, + "llmApiKey":"dummy-value", }, json={"vectordb_config": fresh_db}, ) From fc0b6e237d140a75b20539900dd57d72565f4791 Mon Sep 17 00:00:00 2001 From: "Kavitha.Raju" Date: Mon, 11 Dec 2023 14:04:32 +0530 Subject: [PATCH 4/4] Fix linting issues --- app/core/pipeline/__init__.py | 2 +- app/tests/test_chat_on_websocket.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/app/core/pipeline/__init__.py b/app/core/pipeline/__init__.py index 0dbac3a..5d53015 100644 --- a/app/core/pipeline/__init__.py +++ b/app/core/pipeline/__init__.py @@ -162,7 +162,7 @@ def set_llm_framework( path=vectordb.db_path, collection_name=vectordb.collection_name, ) - self.llm_framework = LangchainOpenAI(vectordb=vectordb, api_key=api_key) + self.llm_framework = LangchainOpenAI(vectordb=vectordb, key=api_key) elif choice == schema.LLMFrameworkType.VANILLA: if isinstance(vectordb, Chroma): vectordb = ChromaLC( diff --git a/app/tests/test_chat_on_websocket.py b/app/tests/test_chat_on_websocket.py index c8bf574..1efe735 100644 --- a/app/tests/test_chat_on_websocket.py +++ b/app/tests/test_chat_on_websocket.py @@ -35,7 +35,7 @@ def assert_positive_bot_response(resp_json): assert resp_json["sender"] in ["Bot", "You"] -def test_chat_websocket_connection(mocker, fresh_db, monkeypatch): +def test_chat_websocket_connection(mocker, fresh_db): """Check if websocket is connecting to and is bot responding""" mocker.patch("app.routers.Supabase.check_token", return_value={"user_id": "1111"})