From 93db5a907c1d8ce26d47c94f446ce0dce7d84517 Mon Sep 17 00:00:00 2001 From: Min Tian Date: Thu, 13 Feb 2025 11:40:24 +0800 Subject: [PATCH] Label filter for pinecone (#460) * feat: add label filtering to pinecone client. Signed-off-by: wxywb * fix: pinecone filters Signed-off-by: min.tian --------- Signed-off-by: wxywb Signed-off-by: min.tian Co-authored-by: wxywb --- .../backend/clients/pinecone/pinecone.py | 66 +++++++++++-------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index 1a681b33..9c2b3888 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -5,8 +5,9 @@ import pinecone -from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB -from .config import PineconeConfig +from vectordb_bench.backend.filter import Filter, FilterOp + +from ..api import DBCaseConfig, VectorDB log = logging.getLogger(__name__) @@ -15,12 +16,19 @@ class Pinecone(VectorDB): + supported_filter_types: list[FilterOp] = [ + FilterOp.NonFilter, + FilterOp.NumGE, + FilterOp.StrEqual, + ] + def __init__( self, dim: int, db_config: dict, db_case_config: DBCaseConfig, drop_old: bool = False, + with_scalar_labels: bool = False, **kwargs, ): """Initialize wrapper around the milvus vector database.""" @@ -33,6 +41,7 @@ def __init__( pc = pinecone.Pinecone(api_key=self.api_key) index = pc.Index(self.index_name) + self.with_scalar_labels = with_scalar_labels if drop_old: index_stats = index.describe_index_stats() index_dim = index_stats["dimension"] @@ -43,15 +52,8 @@ def __init__( log.info(f"Pinecone index delete namespace: {namespace}") index.delete(delete_all=True, namespace=namespace) - self._metadata_key = "meta" - - @classmethod - def config_cls(cls) -> type[DBConfig]: - return PineconeConfig - - @classmethod - def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]: - return EmptyDBCaseConfig + self._scalar_id_field = "meta" + self._scalar_label_field = "label" @contextmanager def init(self): @@ -66,8 +68,9 @@ def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], + labels_data: list[str] | None = None, **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: assert len(embeddings) == len(metadata) insert_count = 0 try: @@ -75,33 +78,44 @@ def insert_embeddings( batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings)) insert_datas = [] for i in range(batch_start_offset, batch_end_offset): + metadata_dict = {self._scalar_id_field: metadata[i]} + if self.with_scalar_labels: + metadata_dict[self._scalar_label_field] = labels_data[i] insert_data = ( str(metadata[i]), embeddings[i], - {self._metadata_key: metadata[i]}, + metadata_dict, ) insert_datas.append(insert_data) self.index.upsert(insert_datas) insert_count += batch_end_offset - batch_start_offset except Exception as e: - return (insert_count, e) - return (len(embeddings), None) + return insert_count, e + return len(embeddings), None def search_embedding( self, query: list[float], k: int = 100, - filters: dict | None = None, timeout: int | None = None, ) -> list[int]: - pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}} - try: - res = self.index.query( - top_k=k, - vector=query, - filter=pinecone_filters, - )["matches"] - except Exception as e: - log.warning(f"Error querying index: {e}") - raise e from e + pinecone_filters = self.expr + res = self.index.query( + top_k=k, + vector=query, + filter=pinecone_filters, + )["matches"] return [int(one_res["id"]) for one_res in res] + + def prepare_filter(self, filters: Filter): + if filters.type == FilterOp.NonFilter: + self.expr = None + elif filters.type == FilterOp.NumGE: + self.expr = {self._scalar_id_field: {"$gte": filters.int_value}} + elif filters.type == FilterOp.StrEqual: + # both "in" and "==" are supported + # for example, self.expr = {self._scalar_label_field: {"$in": [filters.label_value]}} + self.expr = {self._scalar_label_field: {"$eq": filters.label_value}} + else: + msg = f"Not support Filter for Pinecone - {filters}" + raise ValueError(msg)