diff --git a/app/core/llm_framework/openai_vanilla.py b/app/core/llm_framework/openai_vanilla.py index 7ebf95c..05d500c 100644 --- a/app/core/llm_framework/openai_vanilla.py +++ b/app/core/llm_framework/openai_vanilla.py @@ -103,7 +103,7 @@ def generate_text( query_text = "\n".join( [x[0] + "/n" + x[1][:50] + "\n" for x in chat_history]) query_text += "\n" + query - source_documents = self.vectordb.get_relevant_documents(query_text) + source_documents = self.vectordb._get_relevant_documents(query_text) #pylint: disable=protected-access context = get_context(source_documents) pre_prompt = get_pre_prompt(context) prompt = append_query_to_prompt(pre_prompt, query, chat_history) diff --git a/app/core/vectordb/__init__.py b/app/core/vectordb/__init__.py index e4a7399..9521fe7 100644 --- a/app/core/vectordb/__init__.py +++ b/app/core/vectordb/__init__.py @@ -1,6 +1,6 @@ """Interface definition and common implemetations for vectordb classes""" import os -from typing import List +from typing import List, Any from abc import abstractmethod, ABC import schema @@ -13,10 +13,8 @@ class VectordbInterface(ABC): db_host: str = None # Host name to connect to a remote DB deployment db_port: str = None # Port to connect to a remote DB deployment - db_path: str = "chromadb_store" # Path for a local DB, if that is being used - collection_name: str = ( - "aDotBCollection" # Collection to connect to a remote/local DB - ) + db_path: str = None # Path for a local DB, if that is being used(in chroma) + collection_name: str = None # Collection to connect to a remote/local DB db_conn = None def __init__(self, host, port, path, collection_name) -> None: @@ -29,7 +27,7 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None: return @abstractmethod - def get_relevant_documents(self, query: str, **kwargs) -> List: + def _get_relevant_documents(self, query: str, run_manager: Any=None, **kwargs) -> List: """Similarity search on the vector store""" @abstractmethod diff --git a/app/core/vectordb/chroma.py b/app/core/vectordb/chroma.py index dec36c4..9264037 100644 --- a/app/core/vectordb/chroma.py +++ b/app/core/vectordb/chroma.py @@ -1,6 +1,6 @@ """Implemetations for vectordb interface for chroma""" import os -from typing import List +from typing import List, Any from core.vectordb import VectordbInterface from core.embedding.sentence_transformers import SentenceTransformerEmbedding @@ -8,9 +8,8 @@ from custom_exceptions import ChromaException import chromadb -from chromadb.config import Settings -# pylint: disable=too-few-public-methods, unused-argument, R0801, super-init-not-called +# pylint: disable=too-few-public-methods, unused-argument, R0801 QUERY_LIMIT = os.getenv("CHROMA_DB_QUERY_LIMIT", "10") @@ -20,17 +19,16 @@ class Chroma(VectordbInterface): db_host: str = None # Host name to connect to a remote DB deployment db_port: str = None # Port to connect to a remote DB deployment db_path: str = "chromadb_store" # Path for a local DB, if that is being used - collection_name: str = ( - "adotbcollection" # Collection to connect to a remote/local DB - ) - db_conn = None - db_client = None - embedding_function = None + collection_name: str = "adotbcollection" # Collection to connect to a remote/local DB + db_conn: Any = None + db_client: Any = None + embedding_function: Any = None def __init__( self, host=None, port=None, path="chromadb_store", collection_name=None - ) -> None: # pylint: disable=super-init-not-called + ) -> None: """Instanciate a chroma client""" + VectordbInterface.__init__(self, host, port, path, collection_name) if host: self.db_host = host if port: @@ -42,10 +40,7 @@ def __init__( # This method connects to the DB that get stored on the server itself # where the app is running try: - chroma_client = chromadb.Client( - Settings(chroma_db_impl="duckdb+parquet", - persist_directory=path) - ) + chroma_client = chromadb.PersistentClient(path=self.db_path) except Exception as exe: raise ChromaException( "While initializing client: " + str(exe)) from exe @@ -62,13 +57,7 @@ def __init__( # each user's API request to let each user connect to the DB # that he prefers and has rights for(fastapi's Depends()) try: - chroma_client = chromadb.Client( - Settings( - chroma_api_impl="rest", - chroma_server_host=host, - chroma_server_http_port=port, - ) - ) + chroma_client = chromadb.HttpClient(host=self.db_host, port=self.db_port) except Exception as exe: raise ChromaException( "While initializing client: " + str(exe)) from exe @@ -95,14 +84,17 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None: { "label": doc.label, "media": ",".join(doc.media), - "links": ",".join(doc.links), + "links": ",".join([str(link) for link in doc.links]), } ) metas.append(meta) if docs[0].embedding is None: embeddings = None else: - embeddings = [doc.embedding for doc in docs] + embeddings = [] + for doc in docs: + vector = [float(item) for item in doc.embedding] + embeddings.append(vector) try: self.db_conn.add( embeddings=embeddings, @@ -110,11 +102,10 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None: metadatas=metas, ids=[doc.docId for doc in docs], ) - self.db_client.persist() except Exception as exe: raise ChromaException("While adding data: " + str(exe)) from exe - def get_relevant_documents(self, query: str, **kwargs) -> List: + def _get_relevant_documents(self, query: str, run_manager=None, **kwargs) -> List: """Similarity search on the vector store""" results = self.db_conn.query( query_texts=[query], @@ -131,7 +122,3 @@ def get_available_labels(self) -> List[str]: labels = [meta["label"] for meta in rows["metadatas"]] labels = list(set(labels)) return labels - - def __del__(self): - """To persist DB upon app close""" - self.db_client.persist() diff --git a/app/core/vectordb/chroma4langchain.py b/app/core/vectordb/chroma4langchain.py index addf106..ae8dedd 100644 --- a/app/core/vectordb/chroma4langchain.py +++ b/app/core/vectordb/chroma4langchain.py @@ -1,17 +1,19 @@ """Implemetations for vectordb interface for chroma""" import os -from typing import List +from typing import List, Any from langchain.schema import Document as LangchainDocument from langchain.schema import BaseRetriever +from langchain.callbacks.manager import CallbackManagerForRetrieverRun +from langchain.callbacks.manager import AsyncCallbackManagerForRetrieverRun + from core.vectordb import VectordbInterface from core.embedding.sentence_transformers import SentenceTransformerEmbedding import schema from custom_exceptions import ChromaException import chromadb -from chromadb.config import Settings -# pylint: disable=too-few-public-methods, unused-argument, R0801, super-init-not-called +# pylint: disable=too-few-public-methods, unused-argument, R0801 QUERY_LIMIT = os.getenv("CHROMA_DB_QUERY_LIMIT", "10") @@ -21,17 +23,17 @@ class Chroma(VectordbInterface, BaseRetriever): db_host: str = None # Host name to connect to a remote DB deployment db_port: str = None # Port to connect to a remote DB deployment db_path: str = "chromadb_store" # Path for a local DB, if that is being used - collection_name: str = ( - "aDotBCollection" # Collection to connect to a remote/local DB - ) - db_conn = None - db_client = None - embedding_function = None + collection_name: str = "adotbcollection" # Collection to connect to a remote/local DB + db_conn: Any = None + db_client: Any = None + embedding_function: Any = None def __init__( self, host=None, port=None, path="chromadb_store", collection_name=None - ) -> None: # pylint: disable=super-init-not-called + ) -> None: """Instanciate a chroma client""" + VectordbInterface.__init__(self, host, port, path, collection_name) + BaseRetriever.__init__(self) if host: self.db_host = host if port: @@ -43,10 +45,7 @@ def __init__( # This method connects to the DB that get stored on the server itself # where the app is running try: - chroma_client = chromadb.Client( - Settings(chroma_db_impl="duckdb+parquet", - persist_directory=path) - ) + chroma_client = chromadb.PersistentClient(path=self.db_path) except Exception as exe: raise ChromaException( "While initializing client: " + str(exe)) from exe @@ -63,13 +62,7 @@ def __init__( # each user's API request to let each user connect to the DB # that he prefers and has rights for(fastapi's Depends()) try: - chroma_client = chromadb.Client( - Settings( - chroma_api_impl="rest", - chroma_server_host=host, - chroma_server_http_port=port, - ) - ) + chroma_client = chromadb.HttpClient(host=self.db_host, port=self.db_port) except Exception as exe: raise ChromaException( "While initializing client: " + str(exe)) from exe @@ -96,14 +89,17 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None: { "label": doc.label, "media": ",".join(doc.media), - "links": ",".join(doc.links), + "links": ",".join([str(link) for link in doc.links]), } ) metas.append(meta) if docs[0].embedding is None: embeddings = None else: - embeddings = [doc.embedding for doc in docs] + embeddings = [] + for doc in docs: + vector = [float(item) for item in doc.embedding] + embeddings.append(vector) try: self.db_conn.add( embeddings=embeddings, @@ -111,11 +107,12 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None: metadatas=metas, ids=[doc.docId for doc in docs], ) - self.db_client.persist() except Exception as exe: raise ChromaException("While adding data: " + str(exe)) from exe - def get_relevant_documents(self, query: str, **kwargs) -> List[LangchainDocument]: + def _get_relevant_documents( + self, query: str, run_manager: CallbackManagerForRetrieverRun | None = None, **kwargs + ) -> List[LangchainDocument]: """Similarity search on the vector store""" results = self.db_conn.query( query_texts=[query], @@ -128,8 +125,8 @@ def get_relevant_documents(self, query: str, **kwargs) -> List[LangchainDocument for doc, id_ in zip(results["documents"][0], results["ids"][0]) ] - async def aget_relevant_documents( - self, query: str, **kwargs + async def _aget_relevant_documents( + self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun| None = None, **kwargs ) -> List[LangchainDocument]: """Similarity search on the vector store""" results = self.db_conn.query( @@ -150,7 +147,3 @@ def get_available_labels(self) -> List[str]: labels = [meta["label"] for meta in rows["metadatas"]] labels = list(set(labels)) return labels - - def __del__(self): - """To persist DB upon app close""" - self.db_client.persist() diff --git a/app/core/vectordb/postgres4langchain.py b/app/core/vectordb/postgres4langchain.py index d05ca7f..206d371 100644 --- a/app/core/vectordb/postgres4langchain.py +++ b/app/core/vectordb/postgres4langchain.py @@ -1,9 +1,11 @@ """Implemetations for vectordb interface for postgres with vector store""" import math import os -from typing import List, Optional +from typing import List, Optional, Any from langchain.schema import Document as LangchainDocument from langchain.schema import BaseRetriever +from langchain.callbacks.manager import CallbackManagerForRetrieverRun +from langchain.callbacks.manager import AsyncCallbackManagerForRetrieverRun from core.vectordb import VectordbInterface from core.embedding import EmbeddingInterface import schema @@ -28,24 +30,27 @@ class Postgres( db_host: str = os.environ.get("POSTGRES_DB_HOST", "localhost") db_port: str = os.environ.get("POSTGRES_DB_PORT", "5432") db_path: Optional[str] = None # Path for a local DB, if that is being used - collection_name: str = os.environ.get( - "POSTGRES_DB_NAME", "adotbcollection") - db_user = os.environ.get("POSTGRES_DB_USER", "admin") - db_password = os.environ.get("POSTGRES_DB_PASSWORD", "secret") + collection_name: str = os.environ.get("POSTGRES_DB_NAME", "adotbcollection") + db_user: str = os.environ.get("POSTGRES_DB_USER", "admin") + db_password: str = os.environ.get("POSTGRES_DB_PASSWORD", "secret") embedding: EmbeddingInterface = None - db_client = None - + db_client: Any = None + db_conn:Any = None + labels:List[str] = [] + query_limit:int = None + max_cosine_distance: str = None def __init__( self, embedding: EmbeddingInterface = None, - host=None, - port=None, - path=None, - collection_name=None, - # pylint: disable=super-init-not-called - **kwargs, - ) -> None: # pylint: disable=super-init-not-called + host:str=None, + port:int=None, + path:str=None, + collection_name:str=None, + **kwargs:str, + ) -> None: """Instantiate a chroma client""" + VectordbInterface.__init__(self, host, port, path, collection_name) + BaseRetriever.__init__(self) # You MUST set embedding with PGVector, # since with this DB type the embedding # dimension size always hard-coded on init @@ -142,7 +147,7 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None: doc.text, doc.label, doc.media, - doc.links, + str(doc.links), doc.embedding, ] ) @@ -164,7 +169,7 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None: doc.text, doc.label, doc.media, - doc.links, + str(doc.links), doc.embedding, doc.docId, ), @@ -196,7 +201,9 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None: except Exception as exe: raise PostgresException("While adding data: " + str(exe)) from exe - def get_relevant_documents(self, query: list, **kwargs) -> List[LangchainDocument]: + def _get_relevant_documents( + self, query: list, run_manager: CallbackManagerForRetrieverRun| None = None, **kwargs + ) -> List[LangchainDocument]: """Similarity search on the vector store""" query_doc = schema.Document(docId="xxx", text=query) try: @@ -249,8 +256,8 @@ def get_relevant_documents(self, query: list, **kwargs) -> List[LangchainDocumen for doc in records ] - async def aget_relevant_documents( - self, query: list, **kwargs + async def _aget_relevant_documents( + self, query: list, run_manager: AsyncCallbackManagerForRetrieverRun| None = None, **kwargs ) -> List[LangchainDocument]: """Similarity search on the vector store""" query_doc = schema.Document(docId="xxx", text=query) diff --git a/app/schema.py b/app/schema.py index e6168be..d02918f 100644 --- a/app/schema.py +++ b/app/schema.py @@ -2,21 +2,22 @@ # pylint: disable=too-many-lines, fixme from typing import List from enum import Enum -from pydantic import BaseModel, Field, AnyUrl, constr, SecretStr # , validator +from pydantic import StringConstraints, BaseModel, Field, AnyUrl, SecretStr # , validator +from typing_extensions import Annotated class APIInfoResponse(BaseModel): """Response with only a message""" - message: str = Field(..., example="App is up and running") + message: str = Field(..., examples=["App is up and running"]) class APIErrorResponse(BaseModel): """Common error response format""" - error: str = Field(..., example="Database Error") + error: str = Field(..., examples=["Database Error"]) details: str = Field(..., - example="Violation of unique constraint blah blah blah") + examples=["Violation of unique constraint blah blah blah"]) class FileProcessorType(str, Enum): @@ -78,16 +79,16 @@ class CsvColDelimiter(str, Enum): TAB = "tab" -HostnPortPattern = constr(regex=r"^.*:\d+$") +HostnPortPattern = Annotated[str, StringConstraints(pattern=r"^.*:\d+$")] class DBSelector(BaseModel): """The credentials to connect to a remotely hosted chorma DB""" # dbTech: str = Field(DatabaseTech.CHROMA, desc="Technology choice like chroma, pinecone etc") - dbHostnPort: HostnPortPattern = Field( + dbHostnPort: HostnPortPattern | None = Field( None, - example="api.vachanengine.org:6000", + examples=["api.vachanengine.org:6000"], desc="Host and port name to connect to a remote DB deployment", ) dbPath: str = Field( @@ -98,9 +99,9 @@ class DBSelector(BaseModel): desc="Collection to connect to in a local/remote DB." + "One collection should use single embedding type for all docs", ) - dbUser: str = Field( + dbUser: str | None = Field( None, desc="Creds to connect to the server or remote db") - dbPassword: SecretStr = Field( + dbPassword: SecretStr | None = Field( None, desc="Creds to connect to the server or remote db" ) @@ -112,10 +113,10 @@ class EmbeddingSelector(BaseModel): EmbeddingType.HUGGINGFACE_DEFAULT, desc="EmbeddingType used for storing and searching documents in vectordb", ) - embeddingApiKey: str = Field( + embeddingApiKey: str | None = Field( None, desc="If using a cloud service, like OpenAI, the key obtained from them" ) - embeddingModelName: str = Field( + embeddingModelName: str | None = Field( None, desc="If there is a model we can choose to use from the available" ) @@ -123,10 +124,10 @@ class EmbeddingSelector(BaseModel): class LLMFrameworkSelector(BaseModel): """The credentials and configs to be used in the LLM and its framework""" - llmApiKey: str = Field( + llmApiKey: str | None = Field( None, desc="If using a cloud service, like OpenAI, the key from them" ) - llmModelName: str = Field( + llmModelName: str | None = Field( None, desc="The model to be used for chat completion") @@ -137,18 +138,18 @@ class ChatPipelineSelector(BaseModel): LLMFrameworkType.LANGCHAIN, desc="The framework through which LLM access is handled", ) - llmApiKey: str = Field( + llmApiKey: str | None = Field( None, desc="If using a cloud service, like OpenAI, the key from them" ) - llmModelName: str = Field( + llmModelName: str | None = Field( None, desc="The model to be used for chat completion") vectordbType: DatabaseType = Field( DatabaseType.POSTGRES, desc="The Database to be connected to. Same one used for dataupload", ) - dbHostnPort: HostnPortPattern = Field( + dbHostnPort: HostnPortPattern | None = Field( None, - example="api.vachanengine.org:6000", + examples=["api.vachanengine.org:6000"], desc="Host and port name to connect to a remote DB deployment", ) dbPath: str = Field( @@ -159,19 +160,19 @@ class ChatPipelineSelector(BaseModel): desc="Collection to connect to in a local/remote DB." + "One collection should use single embedding type for all docs", ) - dbUser: str = Field( + dbUser: str | None = Field( None, desc="Creds to connect to the server or remote db") - dbPassword: SecretStr = Field( + dbPassword: SecretStr | None = Field( None, desc="Creds to connect to the server or remote db" ) embeddingType: EmbeddingType = Field( EmbeddingType.HUGGINGFACE_DEFAULT, desc="EmbeddingType used for storing and searching documents in vectordb", ) - embeddingApiKey: str = Field( + embeddingApiKey: str | None = Field( None, desc="If using a cloud service, like OpenAI, the key obtained from them" ) - embeddingModelName: str = Field( + embeddingModelName: str | None = Field( None, desc="If there is a model we can choose to use from the available" ) transcriptionFrameworkType: AudioTranscriptionType = Field( @@ -203,20 +204,20 @@ class ChatResponseType(str, Enum): class BotResponse(BaseModel): """Chat response from server to UI or user app""" - message: str = Field(..., example="Good Morning to you too!") - sender: SenderType = Field(..., example="You or BOT") - sources: List[str] = Field( + message: str = Field(..., examples=["Good Morning to you too!"]) + sender: SenderType = Field(..., examples=["You or BOT"]) + sources: List[str] | None = Field( None, - example=[ + examples=[[ "https://www.biblegateway.com/passage/?search=Genesis+1%3A1&version=NIV", "https://git.door43.org/Door43-Catalog/en_tw/src/branch/master/" + "bible/other/creation.md", - ], + ]], ) - media: List[AnyUrl] = Field( - None, example=["https://www.youtube.com/watch?v=teu7BCZTgDs"] + media: List[AnyUrl] | None = Field( + None, examples=[["https://www.youtube.com/watch?v=teu7BCZTgDs"]] ) - type: ChatResponseType = Field(..., example="answer or error") + type: ChatResponseType = Field(..., examples=["answer or error"]) class JobStatus(str, Enum): @@ -231,14 +232,14 @@ class JobStatus(str, Enum): class Job(BaseModel): """Response object of Background Job status check""" - jobId: int = Field(..., example=100000) - status: JobStatus = Field(..., example="started") - output: dict = Field( + jobId: int = Field(..., examples=[100000]) + status: JobStatus = Field(..., examples=["started"]) + output: dict | None = Field( None, - example={ + examples=[{ "error": "Gateway Error", "details": "Connect to OpenAI terminated unexpectedly", - }, + }], ) @@ -247,18 +248,18 @@ class Document(BaseModel): docId: str = Field( ..., - example="NIV Bible Mat 1:1-20", + examples=["NIV Bible Mat 1:1-20"], desc="Unique for a sentence. Used by the LLM to specify which document " + "it answers from. Better to combine the source tag and a serial number.", ) text: str = Field( ..., desc="The sentence to be vectorised and used for question answering" ) - embedding: List[float] = Field( + embedding: List[float] | None = Field( None, desc="vector embedding for the text field") label: str = Field( "open-access", - example="paratext user manual or bible or door-43-users", + examples=["paratext user manual or bible or door-43-users"], desc="The common tag for all sentences under a set. " + "Used for specifying access rules and filtering during querying", ) @@ -276,7 +277,7 @@ class Document(BaseModel): {}, desc="Any additional data that needs to go along with the text," + " as per usecases. Could help in pre- and/or post-processing", - example={"displayimages": True}, + examples=[{"displayimages": True}], ) diff --git a/app/tests/conftest.py b/app/tests/conftest.py index 1c6feb8..e754431 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -1,10 +1,10 @@ """ Test fixtures and stuff""" import os -import shutil import pytest import psycopg2 from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT +import chromadb @pytest.fixture @@ -15,9 +15,9 @@ def fresh_db(): collection_name = "adotdcollection_test" # Chroma Specific clean up - chroma_db_path = "chromadb_store_test" - if os.path.exists(chroma_db_path): - shutil.rmtree(chroma_db_path) + chroma_db_path = 'chromadb_store_test' + chroma_client = chromadb.PersistentClient(path=chroma_db_path) + _ = chroma_client.get_or_create_collection(name=collection_name) # Postgres specific cleanup pg_db_host = os.environ.get("POSTGRES_DB_HOST", "localhost") @@ -43,8 +43,9 @@ def fresh_db(): try: yield {"dbPath": chroma_db_path, "collectionName": collection_name} finally: - if os.path.exists(chroma_db_path): - shutil.rmtree(chroma_db_path) + chroma_client = chromadb.PersistentClient(path='chromadb_store_test') + chroma_client.delete_collection(name=collection_name) + db_conn = psycopg2.connect( user=pg_db_user, password=pg_db_password, diff --git a/requirements.txt b/requirements.txt index 99ffbda..1dada06 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,20 @@ -fastapi==0.95.1 +fastapi==0.104.1 uvicorn==0.22.0 -pydantic==1.10.7 +pydantic==2.4.2 websockets openai==0.27.6 Jinja2==3.1.2 -chromadb==0.3.22 +chromadb==0.4.15 pytest==7.3.1 pytest-mock==3.11.1 httpx pylint==2.17.4 -langchain==0.0.165 +langchain==0.0.331 python-multipart==0.0.6 psycopg2==2.9.6 pgvector==0.1.8 supabase==1.0.3 duckdb==0.7.1 tiktoken==0.5.1 +sentence-transformers==2.2.2 + \ No newline at end of file