From 2bbc195e7d07bf4d330c6656aff51f5c5d02bb15 Mon Sep 17 00:00:00 2001 From: wxywb Date: Fri, 20 Dec 2024 07:35:29 +0000 Subject: [PATCH 1/2] feat: add label filtering to pinecone client. Signed-off-by: wxywb --- .../backend/clients/pinecone/pinecone.py | 39 ++++++++++++++++--- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index 1a681b33..c61d23dc 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -2,6 +2,11 @@ import logging from contextlib import contextmanager +from typing import Type +import pinecone +from vectordb_bench.backend.filter import Filter, FilterType +from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType +from .config import PineconeConfig import pinecone @@ -15,12 +20,18 @@ class Pinecone(VectorDB): + supported_filter_types: list[FilterType] = [ + FilterType.NonFilter, + FilterType.Int, + FilterType.Label, + ] def __init__( self, dim: int, db_config: dict, db_case_config: DBCaseConfig, drop_old: bool = False, + with_scalar_labels: bool = False, **kwargs, ): """Initialize wrapper around the milvus vector database.""" @@ -33,6 +44,7 @@ def __init__( pc = pinecone.Pinecone(api_key=self.api_key) index = pc.Index(self.index_name) + self.with_scalar_labels = with_scalar_labels if drop_old: index_stats = index.describe_index_stats() index_dim = index_stats["dimension"] @@ -44,6 +56,7 @@ def __init__( index.delete(delete_all=True, namespace=namespace) self._metadata_key = "meta" + self._scalar_label_field = "label" @classmethod def config_cls(cls) -> type[DBConfig]: @@ -66,6 +79,7 @@ def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], + labels_data: list[str] = None, **kwargs, ) -> (int, Exception): assert len(embeddings) == len(metadata) @@ -75,10 +89,13 @@ def insert_embeddings( batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings)) insert_datas = [] for i in range(batch_start_offset, batch_end_offset): + metadata_dict = {self._metadata_key: metadata[i]} + if self.with_scalar_labels: + metadata_dict[self._scalar_label_field] = labels_data[i] insert_data = ( str(metadata[i]), embeddings[i], - {self._metadata_key: metadata[i]}, + metadata_dict, ) insert_datas.append(insert_data) self.index.upsert(insert_datas) @@ -91,10 +108,9 @@ def search_embedding( self, query: list[float], k: int = 100, - filters: dict | None = None, timeout: int | None = None, ) -> list[int]: - pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}} + pinecone_filters = self.expr try: res = self.index.query( top_k=k, @@ -102,6 +118,17 @@ def search_embedding( filter=pinecone_filters, )["matches"] except Exception as e: - log.warning(f"Error querying index: {e}") - raise e from e - return [int(one_res["id"]) for one_res in res] + print(f"Error querying index: {e}") + raise e + id_res = [int(one_res["id"]) for one_res in res] + return id_res + + def prepare_filter(self, filter: Filter): + if filter.type == FilterType.NonFilter: + self.expr = None + elif filter.type == FilterType.Int: + self.expr = {self._scalar_id_field: {"$gte": filter.int_value}} + elif filter.type == FilterType.Label: + self.expr = {self._scalar_label_field : {"$eq": filter.label_value}} + else: + raise ValueError(f"Not support Filter for Pinecone - {filter}") From 4cf912c2330ce93de658bdb4e9f9e736ae3dd79b Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Thu, 13 Feb 2025 11:34:51 +0800 Subject: [PATCH 2/2] fix: pinecone filters Signed-off-by: min.tian --- .../backend/clients/pinecone/pinecone.py | 75 ++++++++----------- 1 file changed, 31 insertions(+), 44 deletions(-) diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index c61d23dc..9c2b3888 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -2,16 +2,12 @@ import logging from contextlib import contextmanager -from typing import Type -import pinecone -from vectordb_bench.backend.filter import Filter, FilterType -from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType -from .config import PineconeConfig import pinecone -from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB -from .config import PineconeConfig +from vectordb_bench.backend.filter import Filter, FilterOp + +from ..api import DBCaseConfig, VectorDB log = logging.getLogger(__name__) @@ -20,11 +16,12 @@ class Pinecone(VectorDB): - supported_filter_types: list[FilterType] = [ - FilterType.NonFilter, - FilterType.Int, - FilterType.Label, + supported_filter_types: list[FilterOp] = [ + FilterOp.NonFilter, + FilterOp.NumGE, + FilterOp.StrEqual, ] + def __init__( self, dim: int, @@ -55,17 +52,9 @@ def __init__( log.info(f"Pinecone index delete namespace: {namespace}") index.delete(delete_all=True, namespace=namespace) - self._metadata_key = "meta" + self._scalar_id_field = "meta" self._scalar_label_field = "label" - @classmethod - def config_cls(cls) -> type[DBConfig]: - return PineconeConfig - - @classmethod - def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]: - return EmptyDBCaseConfig - @contextmanager def init(self): pc = pinecone.Pinecone(api_key=self.api_key) @@ -79,9 +68,9 @@ def insert_embeddings( self, embeddings: list[list[float]], metadata: list[int], - labels_data: list[str] = None, + labels_data: list[str] | None = None, **kwargs, - ) -> (int, Exception): + ) -> tuple[int, Exception]: assert len(embeddings) == len(metadata) insert_count = 0 try: @@ -89,7 +78,7 @@ def insert_embeddings( batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings)) insert_datas = [] for i in range(batch_start_offset, batch_end_offset): - metadata_dict = {self._metadata_key: metadata[i]} + metadata_dict = {self._scalar_id_field: metadata[i]} if self.with_scalar_labels: metadata_dict[self._scalar_label_field] = labels_data[i] insert_data = ( @@ -101,8 +90,8 @@ def insert_embeddings( self.index.upsert(insert_datas) insert_count += batch_end_offset - batch_start_offset except Exception as e: - return (insert_count, e) - return (len(embeddings), None) + return insert_count, e + return len(embeddings), None def search_embedding( self, @@ -111,24 +100,22 @@ def search_embedding( timeout: int | None = None, ) -> list[int]: pinecone_filters = self.expr - try: - res = self.index.query( - top_k=k, - vector=query, - filter=pinecone_filters, - )["matches"] - except Exception as e: - print(f"Error querying index: {e}") - raise e - id_res = [int(one_res["id"]) for one_res in res] - return id_res - - def prepare_filter(self, filter: Filter): - if filter.type == FilterType.NonFilter: + res = self.index.query( + top_k=k, + vector=query, + filter=pinecone_filters, + )["matches"] + return [int(one_res["id"]) for one_res in res] + + def prepare_filter(self, filters: Filter): + if filters.type == FilterOp.NonFilter: self.expr = None - elif filter.type == FilterType.Int: - self.expr = {self._scalar_id_field: {"$gte": filter.int_value}} - elif filter.type == FilterType.Label: - self.expr = {self._scalar_label_field : {"$eq": filter.label_value}} + elif filters.type == FilterOp.NumGE: + self.expr = {self._scalar_id_field: {"$gte": filters.int_value}} + elif filters.type == FilterOp.StrEqual: + # both "in" and "==" are supported + # for example, self.expr = {self._scalar_label_field: {"$in": [filters.label_value]}} + self.expr = {self._scalar_label_field: {"$eq": filters.label_value}} else: - raise ValueError(f"Not support Filter for Pinecone - {filter}") + msg = f"Not support Filter for Pinecone - {filters}" + raise ValueError(msg)