Skip to content

Commit

Permalink
Merge pull request #123 from kavitharaju/test-chat-endpoint
Browse files Browse the repository at this point in the history
Test chat endpoint
  • Loading branch information
kavitharaju authored Dec 16, 2023
2 parents 46158a8 + e94af8e commit aeb9d61
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_on_push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ jobs:

- name: Run Pytest
working-directory: ./app
run: python3 -m pytest -s tests/test_basics.py tests/test_dataupload.py
run: python3 -m pytest -s -m "not with_llm"
2 changes: 1 addition & 1 deletion app/core/llm_framework/openai_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class LangchainOpenAI(LLMFrameworkInterface):
def __init__(
self, # pylint: disable=super-init-not-called
# FIXME : Ideal to be able to mock the __init__ from tests
key: str = os.getenv("OPENAI_API_KEY", "dummy-for-test"),
key: str = os.getenv("OPENAI_API_KEY"),
model_name: str = "gpt-3.5-turbo",
vectordb: VectordbInterface = Chroma(),
max_tokens_limit: int = int(
Expand Down
3 changes: 2 additions & 1 deletion app/core/llm_framework/openai_vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Tuple

import openai
from openai import ChatCompletion

from core.llm_framework import LLMFrameworkInterface
from core.vectordb import VectordbInterface
Expand Down Expand Up @@ -111,7 +112,7 @@ def generate_text(
print(f"{prompt=}")

try:
response = openai.ChatCompletion.create(
response = ChatCompletion.create(
model=self.model_name,
temperature=0,
messages=[{"role": "user", "content": prompt}],
Expand Down
13 changes: 8 additions & 5 deletions app/core/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ def __init__(
file_processor: FileProcessorInterface = LangchainLoader,
embedding: EmbeddingInterface = SentenceTransformerEmbedding(),
vectordb: VectordbInterface = Chroma(),
llm_framework: LLMFrameworkInterface = LangchainOpenAI(),
llm_framework: LLMFrameworkInterface = LangchainOpenAI,
llm_api_key: str | None = None,
transcription_framework: AudioTranscriptionInterface = WhisperAudioTranscription,
transcription_api_key: str | None = None,
) -> None:
"""Instantiate with default tech stack"""
super().__init__(file_processor, embedding, vectordb)
Expand All @@ -138,7 +140,8 @@ def __init__(
self.embedding = embedding
self.vectordb = vectordb
self.llm_framework = llm_framework
self.transcription_framework = transcription_framework()
self.transcription_framework = transcription_framework(key=transcription_api_key)
self.llm_framework = llm_framework(key=llm_api_key)

def set_llm_framework(
self,
Expand All @@ -159,7 +162,7 @@ def set_llm_framework(
path=vectordb.db_path,
collection_name=vectordb.collection_name,
)
self.llm_framework = LangchainOpenAI(vectordb=vectordb)
self.llm_framework = LangchainOpenAI(vectordb=vectordb, key=api_key)
elif choice == schema.LLMFrameworkType.VANILLA:
if isinstance(vectordb, Chroma):
vectordb = ChromaLC(
Expand All @@ -168,7 +171,7 @@ def set_llm_framework(
path=vectordb.db_path,
collection_name=vectordb.collection_name,
)
self.llm_framework = OpenAIVanilla(vectordb=vectordb)
self.llm_framework = OpenAIVanilla(vectordb=vectordb, key=api_key)

def set_transcription_framework(
self,
Expand All @@ -181,4 +184,4 @@ def set_transcription_framework(
self.transcription_framework.api_key = api_key
self.transcription_framework.model_name = model_name
if choice == schema.AudioTranscriptionType.WHISPER:
self.transcription_framework = WhisperAudioTranscription()
self.transcription_framework = WhisperAudioTranscription(key=api_key)
8 changes: 6 additions & 2 deletions app/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ async def websocket_chat_endpoint(
if token:
log.info("User, connecting with token, %s", token)
await websocket.accept()
chat_stack = ConversationPipeline(user="XXX", labels=labels)
chat_stack = ConversationPipeline(user="XXX",
labels=labels,
transcription_api_key=settings.transcriptionApiKey,
llm_api_key=settings.llmApiKey)

vectordb_args = compose_vector_db_args(
settings.vectordbType,
Expand All @@ -246,7 +249,8 @@ async def websocket_chat_endpoint(
chat_stack.set_llm_framework(
settings.llmFrameworkType, vectordb=chat_stack.vectordb, **llm_args
)
chat_stack.set_transcription_framework(settings.transcriptionFrameworkType)
chat_stack.set_transcription_framework(settings.transcriptionFrameworkType,
api_key=settings.transcriptionApiKey)

# Not implemented using custom embeddings

Expand Down
3 changes: 3 additions & 0 deletions app/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ class ChatPipelineSelector(BaseModel):
AudioTranscriptionType.WHISPER,
desc="The framework through which audio transcription is handled",
)
transcriptionApiKey: str | None = Field(
None, desc="If using a cloud service, like OpenAI, the key obtained from them"
)


# class UserPrompt(BaseModel): # not using this as we recieve string from websocket
Expand Down
56 changes: 38 additions & 18 deletions app/tests/test_chat_on_websocket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""Test the chat websocket"""
import os
import json
import pytest
from app import schema
from . import client

from .test_dataupload import test_data_upload_markdown
from .test_dataupload import test_data_upload_processed_sentences

Expand All @@ -9,12 +13,13 @@
# pylint: disable=too-many-function-args
COMMON_CONNECTION_ARGS = {
"user": "xxx",
"llmFrameworkType": "openai-langchain",
"vectordbType": "chroma-db",
"dbPath": "chromadb_store_test",
"collectionName": "aDotBCollection_test",
"labels": ["NIV bible", "ESV-Bible", "translationwords"],
"llmFrameworkType": schema.LLMFrameworkType.VANILLA.value,
"vectordbType": schema.DatabaseType.POSTGRES.value,
"collectionName": "adotdcollection_test",
"labels": ["NIV bible", "ESV-Bible", "translationwords", "open-access"],
"token": admin_token,
"transcriptionApiKey":"dummy-key-for-openai",
"llmApiKey":"dummy-key-for-openai",
}


Expand All @@ -30,41 +35,56 @@ def assert_positive_bot_response(resp_json):
assert resp_json["sender"] in ["Bot", "You"]


def test_chat_websocket_connection(fresh_db):
def test_chat_websocket_connection(mocker, fresh_db):
"""Check if websocket is connecting to and is bot responding"""
mocker.patch("app.routers.Supabase.check_token",
return_value={"user_id": "1111"})
mocker.patch("app.routers.Supabase.check_role", return_value=True)
mocker.patch("app.routers.Supabase.get_accessible_labels", return_value=["mock-label"])
mocker.patch("app.core.llm_framework.openai_vanilla.ChatCompletion.create",
return_value={"choices":[{"message":{"content":"Mock response"}}]})
args = COMMON_CONNECTION_ARGS.copy()
args["dbPath"] = fresh_db["dbPath"]
args["collectionName"] = fresh_db["collectionName"]
with client.websocket_connect("/chat", params=args) as websocket:
websocket.send_text("Hello")
websocket.send_bytes(json.dumps({"message":"Hello"}).encode('utf-8'))
data = websocket.receive_json()
assert_positive_bot_response(data)
assert "Mock" in data['message']


def test_chat_based_on_translationwords(fresh_db):
@pytest.mark.with_llm
def test_chat_based_on_data_in_db(mocker, fresh_db):
"""Add some docs and ask questions on it, with default configs"""
mocker.patch("app.routers.Supabase.check_token",
return_value={"user_id": "1111"})
mocker.patch("app.routers.Supabase.check_role", return_value=True)
mocker.patch("app.routers.Supabase.get_accessible_labels",
return_value=COMMON_CONNECTION_ARGS['labels'])
args = COMMON_CONNECTION_ARGS.copy()
args["dbPath"] = fresh_db["dbPath"]
args["collectionName"] = fresh_db["collectionName"]

test_data_upload_markdown(fresh_db, None, None, None)
test_data_upload_processed_sentences(fresh_db, None, None, None)
test_data_upload_markdown(mocker, schema.DatabaseType.POSTGRES,
schema.FileProcessorType.LANGCHAIN, fresh_db)
test_data_upload_processed_sentences(mocker, schema.DatabaseType.POSTGRES, fresh_db)

with client.websocket_connect("/chat", params=args) as websocket:
websocket.send_text("Hello")
data = websocket.receive_json()
assert_positive_bot_response(data)

websocket.send_text("Can you tell me about angels?")
websocket.send_bytes(
json.dumps({"message":"Can you tell me about angels?"}).encode('utf-8'))
data = websocket.receive_json()
assert_positive_bot_response(data)
assert "angel" in data["message"].lower()
assert "spirit" in data["message"].lower()
assert "god" in data["message"].lower()

websocket.send_text("What happended in the beginning?")
data = websocket.receive_json()
assert_positive_bot_response(data)
assert "heaven" in data["message"].lower()
assert "god" in data["message"].lower()
assert "earth" in data["message"].lower()
# websocket.send_bytes(
# json.dumps({"message":"How was earth in the beginning?"}).encode('utf-8'))
# data = websocket.receive_json()
# assert_positive_bot_response(data)
# assert "heaven" in data["message"].lower()
# assert "god" in data["message"].lower()
# assert "earth" in data["message"].lower()
3 changes: 3 additions & 0 deletions app/tests/test_dataupload.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def test_data_upload_processed_sentences(mocker, vectordb, fresh_db):
"dbPath": fresh_db["dbPath"],
"collectionName": fresh_db["collectionName"],
"embeddingType": schema.EmbeddingType.HUGGINGFACE_DEFAULT.value,
"llmApiKey":"dummy-value",
},
json=SENT_DATA,
)
Expand Down Expand Up @@ -101,6 +102,7 @@ def test_data_upload_markdown(mocker, vectordb, chunker, fresh_db):
"dbPath": fresh_db["dbPath"],
"collectionName": fresh_db["collectionName"],
"token": ADMIN_TOKEN,
"llmApiKey":"dummy-value",
}
# json={"vectordb_config": fresh_db}
)
Expand Down Expand Up @@ -132,6 +134,7 @@ def test_data_upload_csv(mocker, vectordb, fresh_db):
"dbPath": fresh_db["dbPath"],
"collectionName": fresh_db["collectionName"],
"token": ADMIN_TOKEN,
"llmApiKey":"dummy-value",
},
json={"vectordb_config": fresh_db},
)
Expand Down

0 comments on commit aeb9d61

Please sign in to comment.