Skip to content

Commit

Permalink
Label filter for pinecone (zilliztech#460)
Browse files Browse the repository at this point in the history
* feat: add label filtering to pinecone client.

Signed-off-by: wxywb <[email protected]>

* fix: pinecone filters

Signed-off-by: min.tian <[email protected]>

---------

Signed-off-by: wxywb <[email protected]>
Signed-off-by: min.tian <[email protected]>
Co-authored-by: wxywb <[email protected]>
  • Loading branch information
alwayslove2013 and wxywb committed Feb 17, 2025
1 parent b6825de commit 84175cc
Showing 1 changed file with 40 additions and 26 deletions.
66 changes: 40 additions & 26 deletions vectordb_bench/backend/clients/pinecone/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import pinecone

from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB
from .config import PineconeConfig
from vectordb_bench.backend.filter import Filter, FilterOp

from ..api import DBCaseConfig, VectorDB

log = logging.getLogger(__name__)

Expand All @@ -15,12 +16,19 @@


class Pinecone(VectorDB):
supported_filter_types: list[FilterOp] = [
FilterOp.NonFilter,
FilterOp.NumGE,
FilterOp.StrEqual,
]

def __init__(
self,
dim: int,
db_config: dict,
db_case_config: DBCaseConfig,
drop_old: bool = False,
with_scalar_labels: bool = False,
**kwargs,
):
"""Initialize wrapper around the milvus vector database."""
Expand All @@ -33,6 +41,7 @@ def __init__(
pc = pinecone.Pinecone(api_key=self.api_key)
index = pc.Index(self.index_name)

self.with_scalar_labels = with_scalar_labels
if drop_old:
index_stats = index.describe_index_stats()
index_dim = index_stats["dimension"]
Expand All @@ -43,15 +52,8 @@ def __init__(
log.info(f"Pinecone index delete namespace: {namespace}")
index.delete(delete_all=True, namespace=namespace)

self._metadata_key = "meta"

@classmethod
def config_cls(cls) -> type[DBConfig]:
return PineconeConfig

@classmethod
def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]:
return EmptyDBCaseConfig
self._scalar_id_field = "meta"
self._scalar_label_field = "label"

@contextmanager
def init(self):
Expand All @@ -66,42 +68,54 @@ def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
labels_data: list[str] | None = None,
**kwargs,
) -> (int, Exception):
) -> tuple[int, Exception]:
assert len(embeddings) == len(metadata)
insert_count = 0
try:
for batch_start_offset in range(0, len(embeddings), self.batch_size):
batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
insert_datas = []
for i in range(batch_start_offset, batch_end_offset):
metadata_dict = {self._scalar_id_field: metadata[i]}
if self.with_scalar_labels:
metadata_dict[self._scalar_label_field] = labels_data[i]
insert_data = (
str(metadata[i]),
embeddings[i],
{self._metadata_key: metadata[i]},
metadata_dict,
)
insert_datas.append(insert_data)
self.index.upsert(insert_datas)
insert_count += batch_end_offset - batch_start_offset
except Exception as e:
return (insert_count, e)
return (len(embeddings), None)
return insert_count, e
return len(embeddings), None

def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
timeout: int | None = None,
) -> list[int]:
pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}}
try:
res = self.index.query(
top_k=k,
vector=query,
filter=pinecone_filters,
)["matches"]
except Exception as e:
log.warning(f"Error querying index: {e}")
raise e from e
pinecone_filters = self.expr
res = self.index.query(
top_k=k,
vector=query,
filter=pinecone_filters,
)["matches"]
return [int(one_res["id"]) for one_res in res]

def prepare_filter(self, filters: Filter):
if filters.type == FilterOp.NonFilter:
self.expr = None
elif filters.type == FilterOp.NumGE:
self.expr = {self._scalar_id_field: {"$gte": filters.int_value}}
elif filters.type == FilterOp.StrEqual:
# both "in" and "==" are supported
# for example, self.expr = {self._scalar_label_field: {"$in": [filters.label_value]}}
self.expr = {self._scalar_label_field: {"$eq": filters.label_value}}
else:
msg = f"Not support Filter for Pinecone - {filters}"
raise ValueError(msg)

0 comments on commit 84175cc

Please sign in to comment.