Version upgrades (#116)
* Pydantic upgrade: regex to pattern in constr (see the sketch after the changed-files summary below)

* pydantic upgrade: Fix issues in the vectordb module due to inheriting from BaseRetriever

* Chroma upgrade: Remove persist() calls

* Chroma upgrade: change deprecated client creation methods

* Pydantic Upgrade: example to list of examples using bump-pydantic

* Langchain upgrade: get_relevant_documents to _get_relevant_documents

* Langchain upgrade: keep run_manager optional so the vanilla framework can use it

* pydantic upgrade: Specify None in the type of optional fields (also shown in the sketch below)

* pydantic upgrade: explicitly convert url to str before inserting into Postgres

* upgrade pydantic/chroma: explicitly convert numpy floats to Python floats in embeddings before adding to chroma

* Upgrade Chroma: Use delete_collection instead of deleting datapath in tests

* Upgrade pydantic: Explicitly convert Url to str before loading to chroma

* Fix linting issues

* Edit requirements.txt with new versions

* Include sentence-transformers in requirements.txt

* Fix type conversion of list of urls
kavitharaju authored Nov 18, 2023
1 parent 6a0479d commit 75931ec
Showing 8 changed files with 123 additions and 134 deletions.
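The "regex to pattern in constr" and "Specify None in the type of optional fields" items above touch files that are not visible in the diffs below. A minimal sketch of both Pydantic v2 patterns, using a hypothetical model (names are illustrative, not taken from this repo):

from typing import Optional

from pydantic import BaseModel, constr


class SentencePayload(BaseModel):
    # Pydantic v1 accepted constr(regex=...); v2 renames the keyword to pattern.
    doc_id: constr(pattern=r"^[A-Za-z0-9_-]+$")
    # Pydantic v1 treated `label: str = None` as implicitly optional;
    # v2 requires None to appear in the annotation itself.
    label: Optional[str] = None


payload = SentencePayload(doc_id="GEN_1_1")
print(payload.model_dump())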
2 changes: 1 addition & 1 deletion app/core/llm_framework/openai_vanilla.py
@@ -103,7 +103,7 @@ def generate_text(
query_text = "\n".join(
[x[0] + "/n" + x[1][:50] + "\n" for x in chat_history])
query_text += "\n" + query
source_documents = self.vectordb.get_relevant_documents(query_text)
source_documents = self.vectordb._get_relevant_documents(query_text) #pylint: disable=protected-access
context = get_context(source_documents)
pre_prompt = get_pre_prompt(context)
prompt = append_query_to_prompt(pre_prompt, query, chat_history)
10 changes: 4 additions & 6 deletions app/core/vectordb/__init__.py
@@ -1,6 +1,6 @@
"""Interface definition and common implemetations for vectordb classes"""
import os
from typing import List
from typing import List, Any
from abc import abstractmethod, ABC

import schema
@@ -13,10 +13,8 @@ class VectordbInterface(ABC):

db_host: str = None # Host name to connect to a remote DB deployment
db_port: str = None # Port to connect to a remote DB deployment
db_path: str = "chromadb_store" # Path for a local DB, if that is being used
collection_name: str = (
"aDotBCollection" # Collection to connect to a remote/local DB
)
db_path: str = None # Path for a local DB, if that is being used(in chroma)
collection_name: str = None # Collection to connect to a remote/local DB
db_conn = None

def __init__(self, host, port, path, collection_name) -> None:
@@ -29,7 +27,7 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None:
return

@abstractmethod
def get_relevant_documents(self, query: str, **kwargs) -> List:
def _get_relevant_documents(self, query: str, run_manager: Any=None, **kwargs) -> List:
"""Similarity search on the vector store"""

@abstractmethod
45 changes: 16 additions & 29 deletions app/core/vectordb/chroma.py
@@ -1,16 +1,15 @@
"""Implemetations for vectordb interface for chroma"""
import os
from typing import List
from typing import List, Any

from core.vectordb import VectordbInterface
from core.embedding.sentence_transformers import SentenceTransformerEmbedding
import schema
from custom_exceptions import ChromaException

import chromadb
from chromadb.config import Settings

# pylint: disable=too-few-public-methods, unused-argument, R0801, super-init-not-called
# pylint: disable=too-few-public-methods, unused-argument, R0801
QUERY_LIMIT = os.getenv("CHROMA_DB_QUERY_LIMIT", "10")


@@ -20,17 +19,16 @@ class Chroma(VectordbInterface):
db_host: str = None # Host name to connect to a remote DB deployment
db_port: str = None # Port to connect to a remote DB deployment
db_path: str = "chromadb_store" # Path for a local DB, if that is being used
collection_name: str = (
"adotbcollection" # Collection to connect to a remote/local DB
)
db_conn = None
db_client = None
embedding_function = None
collection_name: str = "adotbcollection" # Collection to connect to a remote/local DB
db_conn: Any = None
db_client: Any = None
embedding_function: Any = None

def __init__(
self, host=None, port=None, path="chromadb_store", collection_name=None
) -> None: # pylint: disable=super-init-not-called
) -> None:
"""Instanciate a chroma client"""
VectordbInterface.__init__(self, host, port, path, collection_name)
if host:
self.db_host = host
if port:
@@ -42,10 +40,7 @@ def __init__(
# This method connects to the DB that get stored on the server itself
# where the app is running
try:
chroma_client = chromadb.Client(
Settings(chroma_db_impl="duckdb+parquet",
persist_directory=path)
)
chroma_client = chromadb.PersistentClient(path=self.db_path)
except Exception as exe:
raise ChromaException(
"While initializing client: " + str(exe)) from exe
@@ -62,13 +57,7 @@ def __init__(
# each user's API request to let each user connect to the DB
# that he prefers and has rights for(fastapi's Depends())
try:
chroma_client = chromadb.Client(
Settings(
chroma_api_impl="rest",
chroma_server_host=host,
chroma_server_http_port=port,
)
)
chroma_client = chromadb.HttpClient(host=self.db_host, port=self.db_port)
except Exception as exe:
raise ChromaException(
"While initializing client: " + str(exe)) from exe
@@ -95,26 +84,28 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None:
{
"label": doc.label,
"media": ",".join(doc.media),
"links": ",".join(doc.links),
"links": ",".join([str(link) for link in doc.links]),
}
)
metas.append(meta)
if docs[0].embedding is None:
embeddings = None
else:
embeddings = [doc.embedding for doc in docs]
embeddings = []
for doc in docs:
vector = [float(item) for item in doc.embedding]
embeddings.append(vector)
try:
self.db_conn.add(
embeddings=embeddings,
documents=[doc.text for doc in docs],
metadatas=metas,
ids=[doc.docId for doc in docs],
)
self.db_client.persist()
except Exception as exe:
raise ChromaException("While adding data: " + str(exe)) from exe

def get_relevant_documents(self, query: str, **kwargs) -> List:
def _get_relevant_documents(self, query: str, run_manager=None, **kwargs) -> List:
"""Similarity search on the vector store"""
results = self.db_conn.query(
query_texts=[query],
@@ -131,7 +122,3 @@ def get_available_labels(self) -> List[str]:
labels = [meta["label"] for meta in rows["metadatas"]]
labels = list(set(labels))
return labels

def __del__(self):
"""To persist DB upon app close"""
self.db_client.persist()
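The "Use delete_collection instead of deleting datapath in tests" item refers to a test file that is not shown here. A minimal sketch of that cleanup pattern, assuming a local client and the collection name used above:

import chromadb

# Tear down test data through the client API instead of removing the
# chromadb_store directory from disk.
client = chromadb.PersistentClient(path="chromadb_store")
client.get_or_create_collection(name="adotbcollection")
client.delete_collection(name="adotbcollection")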
55 changes: 24 additions & 31 deletions app/core/vectordb/chroma4langchain.py
@@ -1,17 +1,19 @@
"""Implemetations for vectordb interface for chroma"""
import os
from typing import List
from typing import List, Any
from langchain.schema import Document as LangchainDocument
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.callbacks.manager import AsyncCallbackManagerForRetrieverRun

from core.vectordb import VectordbInterface
from core.embedding.sentence_transformers import SentenceTransformerEmbedding
import schema
from custom_exceptions import ChromaException

import chromadb
from chromadb.config import Settings

# pylint: disable=too-few-public-methods, unused-argument, R0801, super-init-not-called
# pylint: disable=too-few-public-methods, unused-argument, R0801
QUERY_LIMIT = os.getenv("CHROMA_DB_QUERY_LIMIT", "10")


@@ -21,17 +23,17 @@ class Chroma(VectordbInterface, BaseRetriever):
db_host: str = None # Host name to connect to a remote DB deployment
db_port: str = None # Port to connect to a remote DB deployment
db_path: str = "chromadb_store" # Path for a local DB, if that is being used
collection_name: str = (
"aDotBCollection" # Collection to connect to a remote/local DB
)
db_conn = None
db_client = None
embedding_function = None
collection_name: str = "adotbcollection" # Collection to connect to a remote/local DB
db_conn: Any = None
db_client: Any = None
embedding_function: Any = None

def __init__(
self, host=None, port=None, path="chromadb_store", collection_name=None
) -> None: # pylint: disable=super-init-not-called
) -> None:
"""Instanciate a chroma client"""
VectordbInterface.__init__(self, host, port, path, collection_name)
BaseRetriever.__init__(self)
if host:
self.db_host = host
if port:
@@ -43,10 +45,7 @@ def __init__(
# This method connects to the DB that get stored on the server itself
# where the app is running
try:
chroma_client = chromadb.Client(
Settings(chroma_db_impl="duckdb+parquet",
persist_directory=path)
)
chroma_client = chromadb.PersistentClient(path=self.db_path)
except Exception as exe:
raise ChromaException(
"While initializing client: " + str(exe)) from exe
@@ -63,13 +62,7 @@ def __init__(
# each user's API request to let each user connect to the DB
# that he prefers and has rights for(fastapi's Depends())
try:
chroma_client = chromadb.Client(
Settings(
chroma_api_impl="rest",
chroma_server_host=host,
chroma_server_http_port=port,
)
)
chroma_client = chromadb.HttpClient(host=self.db_host, port=self.db_port)
except Exception as exe:
raise ChromaException(
"While initializing client: " + str(exe)) from exe
@@ -96,26 +89,30 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None:
{
"label": doc.label,
"media": ",".join(doc.media),
"links": ",".join(doc.links),
"links": ",".join([str(link) for link in doc.links]),
}
)
metas.append(meta)
if docs[0].embedding is None:
embeddings = None
else:
embeddings = [doc.embedding for doc in docs]
embeddings = []
for doc in docs:
vector = [float(item) for item in doc.embedding]
embeddings.append(vector)
try:
self.db_conn.add(
embeddings=embeddings,
documents=[doc.text for doc in docs],
metadatas=metas,
ids=[doc.docId for doc in docs],
)
self.db_client.persist()
except Exception as exe:
raise ChromaException("While adding data: " + str(exe)) from exe

def get_relevant_documents(self, query: str, **kwargs) -> List[LangchainDocument]:
def _get_relevant_documents(
self, query: str, run_manager: CallbackManagerForRetrieverRun | None = None, **kwargs
) -> List[LangchainDocument]:
"""Similarity search on the vector store"""
results = self.db_conn.query(
query_texts=[query],
@@ -128,8 +125,8 @@ def get_relevant_documents(self, query: str, **kwargs) -> List[LangchainDocument
for doc, id_ in zip(results["documents"][0], results["ids"][0])
]

async def aget_relevant_documents(
self, query: str, **kwargs
async def _aget_relevant_documents(
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun| None = None, **kwargs
) -> List[LangchainDocument]:
"""Similarity search on the vector store"""
results = self.db_conn.query(
@@ -150,7 +147,3 @@ def get_available_labels(self) -> List[str]:
labels = [meta["label"] for meta in rows["metadatas"]]
labels = list(set(labels))
return labels

def __del__(self):
"""To persist DB upon app close"""
self.db_client.persist()
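For context on the rename above: in recent LangChain releases, BaseRetriever is a pydantic model whose public get_relevant_documents() builds the callback run manager and then calls the private _get_relevant_documents() hook, which is why these classes now implement the underscored method. A toy retriever illustrating that contract (a sketch, not this repo's class; the commit also keeps run_manager optional so the vanilla framework can call the hook directly):

from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document


class ToyRetriever(BaseRetriever):
    """Returns canned documents; stands in for the Chroma-backed retriever."""

    docs: List[Document] = []

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        # LangChain's public get_relevant_documents()/invoke() creates
        # run_manager and delegates to this hook.
        return [doc for doc in self.docs if query.lower() in doc.page_content.lower()]


retriever = ToyRetriever(docs=[Document(page_content="In the beginning ...")])
results = retriever.get_relevant_documents("beginning")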
45 changes: 26 additions & 19 deletions app/core/vectordb/postgres4langchain.py
@@ -1,9 +1,11 @@
"""Implemetations for vectordb interface for postgres with vector store"""
import math
import os
from typing import List, Optional
from typing import List, Optional, Any
from langchain.schema import Document as LangchainDocument
from langchain.schema import BaseRetriever
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.callbacks.manager import AsyncCallbackManagerForRetrieverRun
from core.vectordb import VectordbInterface
from core.embedding import EmbeddingInterface
import schema
@@ -28,24 +30,27 @@ class Postgres(
db_host: str = os.environ.get("POSTGRES_DB_HOST", "localhost")
db_port: str = os.environ.get("POSTGRES_DB_PORT", "5432")
db_path: Optional[str] = None # Path for a local DB, if that is being used
collection_name: str = os.environ.get(
"POSTGRES_DB_NAME", "adotbcollection")
db_user = os.environ.get("POSTGRES_DB_USER", "admin")
db_password = os.environ.get("POSTGRES_DB_PASSWORD", "secret")
collection_name: str = os.environ.get("POSTGRES_DB_NAME", "adotbcollection")
db_user: str = os.environ.get("POSTGRES_DB_USER", "admin")
db_password: str = os.environ.get("POSTGRES_DB_PASSWORD", "secret")
embedding: EmbeddingInterface = None
db_client = None

db_client: Any = None
db_conn:Any = None
labels:List[str] = []
query_limit:int = None
max_cosine_distance: str = None
def __init__(
self,
embedding: EmbeddingInterface = None,
host=None,
port=None,
path=None,
collection_name=None,
# pylint: disable=super-init-not-called
**kwargs,
) -> None: # pylint: disable=super-init-not-called
host:str=None,
port:int=None,
path:str=None,
collection_name:str=None,
**kwargs:str,
) -> None:
"""Instantiate a chroma client"""
VectordbInterface.__init__(self, host, port, path, collection_name)
BaseRetriever.__init__(self)
# You MUST set embedding with PGVector,
# since with this DB type the embedding
# dimension size always hard-coded on init
@@ -142,7 +147,7 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None:
doc.text,
doc.label,
doc.media,
doc.links,
str(doc.links),
doc.embedding,
]
)
@@ -164,7 +169,7 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None:
doc.text,
doc.label,
doc.media,
doc.links,
str(doc.links),
doc.embedding,
doc.docId,
),
@@ -196,7 +201,9 @@ def add_to_collection(self, docs: List[schema.Document], **kwargs) -> None:
except Exception as exe:
raise PostgresException("While adding data: " + str(exe)) from exe

def get_relevant_documents(self, query: list, **kwargs) -> List[LangchainDocument]:
def _get_relevant_documents(
self, query: list, run_manager: CallbackManagerForRetrieverRun| None = None, **kwargs
) -> List[LangchainDocument]:
"""Similarity search on the vector store"""
query_doc = schema.Document(docId="xxx", text=query)
try:
@@ -249,8 +256,8 @@ def get_relevant_documents(self, query: list, **kwargs) -> List[LangchainDocumen
for doc in records
]

async def aget_relevant_documents(
self, query: list, **kwargs
async def _aget_relevant_documents(
self, query: list, run_manager: AsyncCallbackManagerForRetrieverRun| None = None, **kwargs
) -> List[LangchainDocument]:
"""Similarity search on the vector store"""
query_doc = schema.Document(docId="xxx", text=query)
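The repeated str(...) conversions around links in these diffs follow from another Pydantic v2 change: AnyUrl/HttpUrl values are no longer str subclasses, so anything handed to psycopg or Chroma metadata needs an explicit cast. A minimal illustration with a hypothetical model:

from typing import List

from pydantic import AnyUrl, BaseModel


class LinkedDoc(BaseModel):
    links: List[AnyUrl] = []


doc = LinkedDoc(links=["https://example.com/page"])
# In Pydantic v1 each link was a str subclass and could be stored directly;
# in v2 it is a Url object, so convert before writing to the database.
links_for_db = [str(link) for link in doc.links]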