Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Label filter for pinecone #460

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 40 additions & 26 deletions vectordb_bench/backend/clients/pinecone/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import pinecone

from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB
from .config import PineconeConfig
from vectordb_bench.backend.filter import Filter, FilterOp

from ..api import DBCaseConfig, VectorDB

log = logging.getLogger(__name__)

Expand All @@ -15,12 +16,19 @@


class Pinecone(VectorDB):
supported_filter_types: list[FilterOp] = [
FilterOp.NonFilter,
FilterOp.NumGE,
FilterOp.StrEqual,
]

def __init__(
self,
dim: int,
db_config: dict,
db_case_config: DBCaseConfig,
drop_old: bool = False,
with_scalar_labels: bool = False,
**kwargs,
):
"""Initialize wrapper around the Pinecone vector database."""
Expand All @@ -33,6 +41,7 @@ def __init__(
pc = pinecone.Pinecone(api_key=self.api_key)
index = pc.Index(self.index_name)

self.with_scalar_labels = with_scalar_labels
if drop_old:
index_stats = index.describe_index_stats()
index_dim = index_stats["dimension"]
Expand All @@ -43,15 +52,8 @@ def __init__(
log.info(f"Pinecone index delete namespace: {namespace}")
index.delete(delete_all=True, namespace=namespace)

self._metadata_key = "meta"

@classmethod
def config_cls(cls) -> type[DBConfig]:
return PineconeConfig

@classmethod
def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]:
return EmptyDBCaseConfig
self._scalar_id_field = "meta"
self._scalar_label_field = "label"

@contextmanager
def init(self):
Expand All @@ -66,42 +68,54 @@ def insert_embeddings(
self,
embeddings: list[list[float]],
metadata: list[int],
labels_data: list[str] | None = None,
**kwargs,
) -> (int, Exception):
) -> tuple[int, Exception]:
assert len(embeddings) == len(metadata)
insert_count = 0
try:
for batch_start_offset in range(0, len(embeddings), self.batch_size):
batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
insert_datas = []
for i in range(batch_start_offset, batch_end_offset):
metadata_dict = {self._scalar_id_field: metadata[i]}
if self.with_scalar_labels:
metadata_dict[self._scalar_label_field] = labels_data[i]
insert_data = (
str(metadata[i]),
embeddings[i],
{self._metadata_key: metadata[i]},
metadata_dict,
)
insert_datas.append(insert_data)
self.index.upsert(insert_datas)
insert_count += batch_end_offset - batch_start_offset
except Exception as e:
return (insert_count, e)
return (len(embeddings), None)
return insert_count, e
return len(embeddings), None

def search_embedding(
self,
query: list[float],
k: int = 100,
filters: dict | None = None,
timeout: int | None = None,
) -> list[int]:
pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}}
try:
res = self.index.query(
top_k=k,
vector=query,
filter=pinecone_filters,
)["matches"]
except Exception as e:
log.warning(f"Error querying index: {e}")
raise e from e
pinecone_filters = self.expr
res = self.index.query(
top_k=k,
vector=query,
filter=pinecone_filters,
)["matches"]
return [int(one_res["id"]) for one_res in res]

def prepare_filter(self, filters: Filter):
    """Build the Pinecone metadata filter for ``filters`` and store it on ``self.expr``.

    Mapping by ``filters.type``:

    * ``FilterOp.NonFilter`` -> ``None`` (no filter).
    * ``FilterOp.NumGE`` -> ``$gte`` on the scalar id field, compared
      against ``filters.int_value``.
    * ``FilterOp.StrEqual`` -> ``$eq`` on the scalar label field, compared
      against ``filters.label_value``.

    Raises:
        ValueError: if ``filters.type`` is none of the above.
    """
    filter_kind = filters.type

    if filter_kind == FilterOp.NonFilter:
        self.expr = None
        return

    if filter_kind == FilterOp.NumGE:
        lower_bound = {"$gte": filters.int_value}
        self.expr = {self._scalar_id_field: lower_bound}
        return

    if filter_kind == FilterOp.StrEqual:
        # Pinecone accepts either "$in" or "$eq" for this match; an
        # equivalent form would be
        # {self._scalar_label_field: {"$in": [filters.label_value]}}.
        exact_label = {"$eq": filters.label_value}
        self.expr = {self._scalar_label_field: exact_label}
        return

    # Any other filter type is not handled by this client.
    msg = f"Not support Filter for Pinecone - {filters}"
    raise ValueError(msg)