Skip to content

Commit

Permalink
added support for lancedb as vectorstore
Browse files Browse the repository at this point in the history
Signed-off-by: akashAD <[email protected]>
  • Loading branch information
akashAD98 committed Sep 2, 2024
1 parent bf8182a commit 57e5c30
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/configure_it.md
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ For the similar cache of text, only cache store and vector store are needed. If
- docarray
- usearch
- redis
- lancedb

### object store

Expand Down
1 change: 1 addition & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ Support vector database
- Zilliz Cloud
- FAISS
- ChromaDB
- LanceDB

> [Example code](https://github.com/zilliztech/GPTCache/blob/main/examples/data_manager/vector_store.py)
Expand Down
1 change: 1 addition & 0 deletions examples/data_manager/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def run():
'docarray',
'redis',
'weaviate',
'lancedb',
]
for vector_store in vector_stores:
cache_base = CacheBase('sqlite')
Expand Down
92 changes: 92 additions & 0 deletions gptcache/manager/vector_data/lancedb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import List, Optional

import numpy as np

from gptcache.manager.vector_data.base import VectorBase, VectorData
from gptcache.utils import import_lancedb, import_torch

# NOTE(review): torch is not referenced anywhere in this module; confirm
# whether this check is still required before removing it.
import_torch()

# Run the lancedb dependency check BEFORE importing the package it guards.
# Calling it after ``import lancedb`` cannot help when the package is missing,
# because the import would already have raised.
import_lancedb()

import lancedb  # pylint: disable=wrong-import-position
# presumably pyarrow is pulled in as a lancedb dependency, so it is imported
# after the check as well; verify against the lancedb package requirements.
import pyarrow as pa  # pylint: disable=wrong-import-position


class LanceDB(VectorBase):
    """Vector store: LanceDB

    The table is created lazily on the first call to :meth:`mul_add`, because
    the vector dimension is only known once the first vector arrives. Until
    then the store behaves as an empty store.

    :param persist_directory: The directory to persist, defaults to '/tmp/lancedb'.
        ``None`` is treated the same as the default.
    :type persist_directory: str
    :param table_name: The name of the table in LanceDB, defaults to 'gptcache'.
    :type table_name: str
    :param top_k: The number of the vectors results to return, defaults to 1.
    :type top_k: int
    """

    def __init__(
        self,
        persist_directory: Optional[str] = "/tmp/lancedb",
        table_name: str = "gptcache",
        top_k: int = 1,
    ):
        # The factory in manager.py forwards ``None`` when the caller gives no
        # directory, which would otherwise be passed to lancedb.connect().
        if persist_directory is None:
            persist_directory = "/tmp/lancedb"

        self._persist_directory = persist_directory
        self._table_name = table_name
        self._top_k = top_k

        # Initialize LanceDB database
        self._db = lancedb.connect(self._persist_directory)

        # Open the table if it already exists; otherwise defer creation until
        # the first insertion, when the vector dimension is known.
        if self._table_name in self._db.table_names():
            self._table = self._db.open_table(self._table_name)
        else:
            self._table = None

    def mul_add(self, datas: List[VectorData]):
        """Add multiple vectors to the LanceDB table.

        Ids are stored as strings. The table is created on the first
        non-empty call, with a fixed-size vector column whose size is taken
        from the first vector.

        :param datas: The vectors to insert; an empty list is a no-op.
        :type datas: List[VectorData]
        """
        if not datas:
            # Nothing to insert, and no vector to infer the dimension from.
            return

        rows = [{"id": str(item.id), "vector": item.data.tolist()} for item in datas]

        if self._table is None:
            vector_dim = len(rows[0]["vector"])
            schema = pa.schema([
                pa.field("id", pa.string()),
                pa.field("vector", pa.list_(pa.float32(), list_size=vector_dim)),
            ])
            self._table = self._db.create_table(self._table_name, schema=schema)

        self._table.add(rows)

    def search(self, data: np.ndarray, top_k: int = -1):
        """Search for the most similar vectors in the LanceDB table.

        :param data: The query vector.
        :type data: np.ndarray
        :param top_k: The number of results to return; -1 means use the
            ``top_k`` given at construction time.
        :type top_k: int
        :return: A list of ``(distance, id)`` tuples; empty when the table
            does not exist yet or holds no rows.
        """
        # The table is None until the first insertion.
        if self._table is None or len(self._table) == 0:
            return []

        if top_k == -1:
            top_k = self._top_k

        results = self._table.search(data.tolist()).limit(top_k).to_list()
        return [(row["_distance"], int(row["id"])) for row in results]

    def delete(self, ids: List[int]):
        """Delete vectors from the LanceDB table based on IDs.

        :param ids: The ids to remove; an empty list is a no-op.
        :type ids: List[int]
        """
        if self._table is None or not ids:
            return

        # Ids are stored as strings, so they are quoted in the predicate;
        # embedded single quotes are doubled to keep the filter well-formed.
        quoted = ", ".join(
            "'" + str(item_id).replace("'", "''") + "'" for item_id in ids
        )
        # One filter for the whole batch instead of one delete per id.
        self._table.delete(f"id IN ({quoted})")

    def rebuild(self, ids: Optional[List[int]] = None):
        """Rebuild the index, if applicable. Currently a no-op that returns True."""
        return True

    def flush(self):
        """Flush changes to disk (if necessary). Currently a no-op."""

    def close(self):
        """Close the connection to LanceDB. Currently a no-op."""

    def count(self):
        """Return the total number of vectors in the table (0 if not created yet)."""
        if self._table is None:
            return 0
        return len(self._table)
25 changes: 24 additions & 1 deletion gptcache/manager/vector_data/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class VectorBase:
`Chromadb` (with `top_k`, `client_settings`, `persist_directory`, `collection_name` params),
`Hnswlib` (with `index_file_path`, `dimension`, `top_k`, `max_elements` params).
`pgvector` (with `url`, `collection_name`, `index_params`, `top_k`, `dimension` params).
`lancedb` (with `persist_directory`, `table_name`, `top_k` params).
:param name: the name of the vectorbase; 'milvus', 'faiss', 'chromadb' and 'hnswlib' are supported now.
:type name: str
Expand Down Expand Up @@ -89,6 +90,14 @@ class VectorBase:
:param persist_directory: the directory to persist, defaults to '.chromadb/' in the current directory.
:type persist_directory: str
:param client_settings: the setting for Chromadb.
:param persist_directory: The directory to persist, defaults to '/tmp/lancedb'.
:type persist_directory: str
:param table_name: The name of the table in LanceDB, defaults to 'gptcache'.
:type table_name: str
:param top_k: The number of the vectors results to return, defaults to 1.
:type top_k: int
:param index_path: the path to hnswlib index, defaults to 'hnswlib_index.bin'.
:type index_path: str
:param max_elements: max_elements of hnswlib, defaults 100000.
Expand Down Expand Up @@ -264,7 +273,7 @@ def get(name, **kwargs):
from gptcache.manager.vector_data.weaviate import Weaviate

url = kwargs.get("url", None)
auth_client_secret = kwargs.get("auth_client_secret", None)
auth_client_secret = kwargs.get("auth_client_secrets", None)
timeout_config = kwargs.get("timeout_config", WEAVIATE_TIMEOUT_CONFIG)
proxies = kwargs.get("proxies", None)
trust_env = kwargs.get("trust_env", False)
Expand All @@ -289,6 +298,20 @@ def get(name, **kwargs):
class_schema=class_schema,
top_k=top_k,
)

elif name == "lancedb":
from gptcache.manager.vector_data.lancedb import LanceDB

persist_directory = kwargs.get("persist_directory", None)
table_name = kwargs.get("table_name", COLLECTION_NAME)
top_k: int = kwargs.get("top_k", TOP_K)

vector_base = LanceDB(
persist_directory=persist_directory,
table_name=table_name,
top_k=top_k,
)

else:
raise NotFoundError("vector store", name)
return vector_base
3 changes: 3 additions & 0 deletions gptcache/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"import_redis",
"import_qdrant",
"import_weaviate",
"import_lancedb",
]

import importlib.util
Expand Down Expand Up @@ -147,6 +148,8 @@ def import_duckdb():
_check_library("duckdb", package="duckdb")
_check_library("duckdb-engine", package="duckdb-engine")

def import_lancedb():
    """Run ``_check_library`` for the ``lancedb`` library and package."""
    library_name = "lancedb"
    _check_library(library_name, package=library_name)

def import_sql_client(db_name):
if db_name == "postgresql":
Expand Down

0 comments on commit 57e5c30

Please sign in to comment.