Skip to content

Commit

Permalink
feat: add label filtering to pinecone client.
Browse files Browse the repository at this point in the history
Signed-off-by: wxywb <[email protected]>
  • Loading branch information
wxywb authored and alwayslove2013 committed Jan 23, 2025
1 parent c9d2775 commit f906b60
Showing 1 changed file with 33 additions and 6 deletions.
39 changes: 33 additions & 6 deletions vectordb_bench/backend/clients/pinecone/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

import logging
from contextlib import contextmanager
from typing import Type
import pinecone
from vectordb_bench.backend.filter import Filter, FilterType
from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
from .config import PineconeConfig

import pinecone

Expand All @@ -15,12 +20,18 @@


class Pinecone(VectorDB):
supported_filter_types: list[FilterType] = [
FilterType.NonFilter,
FilterType.Int,
FilterType.Label,
]
def __init__(
self,
dim: int,
db_config: dict,
db_case_config: DBCaseConfig,
drop_old: bool = False,
with_scalar_labels: bool = False,
**kwargs,
):
"""Initialize wrapper around the Pinecone vector database."""
Expand All @@ -33,6 +44,7 @@ def __init__(
pc = pinecone.Pinecone(api_key=self.api_key)
index = pc.Index(self.index_name)

self.with_scalar_labels = with_scalar_labels
if drop_old:
index_stats = index.describe_index_stats()
index_dim = index_stats["dimension"]
Expand All @@ -44,6 +56,7 @@ def __init__(
index.delete(delete_all=True, namespace=namespace)

self._metadata_key = "meta"
self._scalar_label_field = "label"

@classmethod
def config_cls(cls) -> type[DBConfig]:
Expand All @@ -66,6 +79,7 @@ def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
labels_data: list[str] = None,
**kwargs,
) -> (int, Exception):
assert len(embeddings) == len(metadata)
Expand All @@ -75,10 +89,13 @@ def insert_embeddings(
batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
insert_datas = []
for i in range(batch_start_offset, batch_end_offset):
metadata_dict = {self._metadata_key: metadata[i]}
if self.with_scalar_labels:
metadata_dict[self._scalar_label_field] = labels_data[i]
insert_data = (
str(metadata[i]),
embeddings[i],
{self._metadata_key: metadata[i]},
metadata_dict,
)
insert_datas.append(insert_data)
self.index.upsert(insert_datas)
Expand All @@ -91,17 +108,27 @@ def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
timeout: int | None = None,
) -> list[int]:
pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}}
pinecone_filters = self.expr
try:
res = self.index.query(
top_k=k,
vector=query,
filter=pinecone_filters,
)["matches"]
except Exception as e:
log.warning(f"Error querying index: {e}")
raise e from e
return [int(one_res["id"]) for one_res in res]
print(f"Error querying index: {e}")
raise e
id_res = [int(one_res["id"]) for one_res in res]
return id_res

def prepare_filter(self, filter: Filter):
    """Translate a benchmark ``Filter`` into a Pinecone metadata filter.

    The resulting expression is stored on ``self.expr``, which
    ``search_embedding`` forwards as the ``filter`` argument of the
    index query.

    Args:
        filter: The filter to apply to subsequent searches. Its ``type``
            selects the branch; ``int_value`` / ``label_value`` supply the
            comparison operand.

    Raises:
        ValueError: If ``filter.type`` is not one of the filter types
            handled here.
    """
    if filter.type == FilterType.NonFilter:
        # No metadata restriction on the query.
        self.expr = None
    elif filter.type == FilterType.Int:
        # Integer ids are written under ``self._metadata_key`` by
        # ``insert_embeddings``; the instance has no ``_scalar_id_field``
        # attribute, so the filter must target the same metadata key.
        self.expr = {self._metadata_key: {"$gte": filter.int_value}}
    elif filter.type == FilterType.Label:
        self.expr = {self._scalar_label_field: {"$eq": filter.label_value}}
    else:
        raise ValueError(f"Not support Filter for Pinecone - {filter}")

0 comments on commit f906b60

Please sign in to comment.