diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py
index c61d23dc..d87942e5 100644
--- a/vectordb_bench/backend/clients/pinecone/pinecone.py
+++ b/vectordb_bench/backend/clients/pinecone/pinecone.py
@@ -2,16 +2,12 @@
 import logging
 from contextlib import contextmanager
-from typing import Type
-import pinecone
-from vectordb_bench.backend.filter import Filter, FilterType
-from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
-from .config import PineconeConfig
 import pinecone
-from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB
-from .config import PineconeConfig
+from vectordb_bench.backend.filter import Filter, FilterType
+
+from ..api import DBCaseConfig, VectorDB
 
 log = logging.getLogger(__name__)
@@ -25,6 +21,7 @@ class Pinecone(VectorDB):
         FilterType.Int,
         FilterType.Label,
     ]
+
     def __init__(
         self,
         dim: int,
@@ -55,17 +52,9 @@ def __init__(
         log.info(f"Pinecone index delete namespace: {namespace}")
         index.delete(delete_all=True, namespace=namespace)
 
-        self._metadata_key = "meta"
+        self._scalar_id_field = "meta"
         self._scalar_label_field = "label"
 
-    @classmethod
-    def config_cls(cls) -> type[DBConfig]:
-        return PineconeConfig
-
-    @classmethod
-    def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]:
-        return EmptyDBCaseConfig
-
     @contextmanager
     def init(self):
         pc = pinecone.Pinecone(api_key=self.api_key)
@@ -79,9 +68,9 @@ def insert_embeddings(
         self,
         embeddings: list[list[float]],
         metadata: list[int],
-        labels_data: list[str] = None,
+        labels_data: list[str] | None = None,
         **kwargs,
-    ) -> (int, Exception):
+    ) -> tuple[int, Exception]:
         assert len(embeddings) == len(metadata)
         insert_count = 0
         try:
@@ -89,7 +78,7 @@ def insert_embeddings(
                 batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
                 insert_datas = []
                 for i in range(batch_start_offset, batch_end_offset):
-                    metadata_dict = {self._metadata_key: metadata[i]}
+                    metadata_dict = {self._scalar_id_field: metadata[i]}
                     if self.with_scalar_labels:
                         metadata_dict[self._scalar_label_field] = labels_data[i]
                     insert_data = (
@@ -101,8 +90,8 @@ def insert_embeddings(
                 self.index.upsert(insert_datas)
                 insert_count += batch_end_offset - batch_start_offset
         except Exception as e:
-            return (insert_count, e)
-        return (len(embeddings), None)
+            return insert_count, e
+        return len(embeddings), None
 
     def search_embedding(
         self,
@@ -111,24 +100,22 @@ def search_embedding(
         timeout: int | None = None,
     ) -> list[int]:
         pinecone_filters = self.expr
-        try:
-            res = self.index.query(
-                top_k=k,
-                vector=query,
-                filter=pinecone_filters,
-            )["matches"]
-        except Exception as e:
-            print(f"Error querying index: {e}")
-            raise e
-        id_res = [int(one_res["id"]) for one_res in res]
-        return id_res
-
-    def prepare_filter(self, filter: Filter):
-        if filter.type == FilterType.NonFilter:
+        res = self.index.query(
+            top_k=k,
+            vector=query,
+            filter=pinecone_filters,
+        )["matches"]
+        return [int(one_res["id"]) for one_res in res]
+
+    def prepare_filter(self, filters: Filter):
+        if filters.type == FilterType.NonFilter:
             self.expr = None
-        elif filter.type == FilterType.Int:
-            self.expr = {self._scalar_id_field: {"$gte": filter.int_value}}
-        elif filter.type == FilterType.Label:
-            self.expr = {self._scalar_label_field : {"$eq": filter.label_value}}
+        elif filters.type == FilterType.Int:
+            self.expr = {self._scalar_id_field: {"$gte": filters.int_value}}
+        elif filters.type == FilterType.Label:
+            # both "in" and "==" are supported
+            self.expr = {self._scalar_label_field: {"$eq": filters.label_value}}
+            # self.expr = {self._scalar_label_field: {"$in": [filters.label_value]}}
         else:
-            raise ValueError(f"Not support Filter for Pinecone - {filter}")
+            msg = f"Not support Filter for Pinecone - {filters}"
+            raise ValueError(msg)