Add parameter in initialising Milvus store
Add parameter (use_partition_key) in initialising Milvus store
ziyi-curiousthing authored and SimFG committed Jan 30, 2024
1 parent 640117c commit 5f110cd
Showing 2 changed files with 22 additions and 10 deletions.
gptcache/manager/vector_data/manager.py (4 additions, 0 deletions)

@@ -76,6 +76,8 @@ class VectorBase:
     :type local_mode: bool
     :param local_data: required when local_mode is True.
     :type local_data: str
+    :param use_partition_key: if true, use partition key feature in milvus.
+    :type use_partition_key: bool
     :param url: the connection url for PostgreSQL database, defaults to 'postgresql://postgres@localhost:5432/postgres'
     :type url: str
@@ -125,6 +127,7 @@ def get(name, **kwargs):
             search_params = kwargs.get("search_params", None)
             local_mode = kwargs.get("local_mode", False)
             local_data = kwargs.get("local_data", "./milvus_data")
+            use_partition_key = kwargs.get("use_partition_key", False)
             vector_base = Milvus(
                 host=host,
                 port=port,
@@ -138,6 +141,7 @@ def get(name, **kwargs):
                 search_params=search_params,
                 local_mode=local_mode,
                 local_data=local_data,
+                use_partition_key=use_partition_key
             )
         elif name == "faiss":
             from gptcache.manager.vector_data.faiss import Faiss
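With the manager change above, the flag can be passed through gptcache's VectorBase factory alongside the existing Milvus options. A minimal sketch, assuming the usual from gptcache.manager import VectorBase entry point, a Milvus server on the default localhost:19530, and a made-up embedding dimension of 128:

from gptcache.manager import VectorBase

# "milvus" routes to the Milvus backend shown above; use_partition_key is the new flag.
vector_base = VectorBase(
    "milvus",
    host="localhost",
    port="19530",
    dimension=128,            # example value; must match the embedding size in use
    use_partition_key=True,   # adds a VARCHAR partition_key field to the collection schema
)
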
gptcache/manager/vector_data/milvus.py (18 additions, 10 deletions)

@@ -74,7 +74,8 @@ def __init__(
         index_params: dict = None,
         search_params: dict = None,
         local_mode: bool = False,
-        local_data: str = "./milvus_data"
+        local_data: str = "./milvus_data",
+        use_partition_key: bool = False
     ):
         if dimension <= 0:
             raise ValueError(
@@ -85,6 +86,7 @@ def __init__(
         self.dimension = dimension
         self.top_k = top_k
         self.index_params = index_params
+        self.use_partition_key = use_partition_key
         if self._local_mode:
             self._create_local(port, local_data)
         self._connect(host, port, user, password, secure)
@@ -131,16 +133,19 @@ def _create_collection(self, collection_name):
                     is_primary=True,
                     auto_id=False,
                 ),
-                FieldSchema(
-                    name="partition_key",
-                    dtype=DataType.VARCHAR,
-                    max_length=256,
-                    is_partition_key=True,
-                ),
                 FieldSchema(
                     name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.dimension
                 ),
             ]
+            if self.use_partition_key:
+                schema.append(
+                    FieldSchema(
+                        name="partition_key",
+                        dtype=DataType.VARCHAR,
+                        max_length=256,
+                        is_partition_key=True,
+                    )
+                )
             schema = CollectionSchema(schema)
             self.col = Collection(
                 collection_name,
@@ -170,8 +175,11 @@ def _create_collection(self, collection_name):
         self.col.load()

     def mul_add(self, datas: List[VectorData], **kwargs):
-        partition_key = kwargs.get("partition_key") or ""
-        self.col.insert([{"id": data.id, "embedding": np.array(data.data).astype("float32"), "partition_key": partition_key} for data in datas])
+        if self.use_partition_key:
+            partition_key = kwargs.get("partition_key") or "default"
+            self.col.insert([{"id": data.id, "embedding": np.array(data.data).astype("float32"), "partition_key": partition_key} for data in datas])
+        else:
+            self.col.insert([{"id": data.id, "embedding": np.array(data.data).astype("float32")} for data in datas])

     def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
         if top_k == -1:
@@ -182,7 +190,7 @@ def search(self, data: np.ndarray, top_k: int = -1, **kwargs):
             anns_field="embedding",
             param=self.search_params,
             limit=top_k,
-            expr=f'partition_key=="{partition_key}"' if partition_key else None,
+            expr=f'partition_key=="{partition_key}"' if (self.use_partition_key and partition_key) else None,
         )
         return list(zip(search_result[0].distances, search_result[0].ids))
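On the Milvus store itself, the flag changes both writes and reads: mul_add attaches a caller-supplied partition_key (falling back to "default") to every inserted row, and search only adds the partition_key=="..." expression when the flag is enabled. A minimal sketch of that flow, assuming the Milvus and VectorData classes at the paths above, a Milvus server on the default host/port, and made-up vectors plus a made-up "user_42" key:

import numpy as np
from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.milvus import Milvus

store = Milvus(dimension=4, use_partition_key=True)  # example dimension

# Each inserted row carries the caller's partition key (or "default" if omitted).
store.mul_add(
    [VectorData(id=1, data=np.array([0.1, 0.2, 0.3, 0.4], dtype="float32"))],
    partition_key="user_42",
)

# With the flag enabled, the search is scoped by expr='partition_key=="user_42"';
# without it, no expr filter is applied and all rows are searched.
hits = store.search(
    np.array([0.1, 0.2, 0.3, 0.4], dtype="float32"),
    top_k=1,
    partition_key="user_42",
)

Note that the partition_key field is only added to the collection schema when use_partition_key is enabled at collection-creation time.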
