Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add evaluator #31

Merged
merged 15 commits into from
Dec 1, 2024
20 changes: 20 additions & 0 deletions tests/test_vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,23 @@ def test_vicinity_delete_and_query(vicinity_instance: Vicinity, items: list[str]

# Check that the queried item is in the results
assert "item3" in returned_items


def test_vicinity_evaluate(vicinity_instance: Vicinity, vectors: np.ndarray) -> None:
    """
    Test the evaluate method of the Vicinity instance.

    :param vicinity_instance: A Vicinity instance.
    :param vectors: The full dataset vectors used to build the index.
    """
    # Use a small slice of the dataset as queries so the test stays fast.
    queries = vectors[:10]
    qps, recall = vicinity_instance.evaluate(vectors, queries)

    # QPS must be strictly positive; recall is a fraction in [0, 1].
    assert qps > 0.0
    assert 0.0 <= recall <= 1.0

    # Evaluation only supports cosine/euclidean: forcing an unsupported
    # metric on the backend must make evaluate() raise ValueError.
    vicinity_instance.backend.arguments.metric = "manhattan"
    with pytest.raises(ValueError):
        vicinity_instance.evaluate(vectors, queries)
76 changes: 75 additions & 1 deletion vicinity/vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import time
from io import open
from pathlib import Path
from time import perf_counter
from typing import Any, Sequence, Union

import numpy as np
import orjson
from numpy import typing as npt

from vicinity.backends import AbstractBackend, get_backend_class
from vicinity import Metric
from vicinity.backends import AbstractBackend, BasicBackend, get_backend_class
from vicinity.datatypes import Backend, PathLike

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -83,6 +85,11 @@ def dim(self) -> int:
"""The dimensionality of the vectors."""
return self.backend.dim

@property
def metric(self) -> str:
    """The metric used by the backend."""
    backend_arguments = self.backend.arguments
    return backend_arguments.metric

def query(
self,
vectors: npt.NDArray,
Expand Down Expand Up @@ -229,3 +236,70 @@ def delete(self, tokens: Sequence[str]) -> None:
# Delete items starting from the highest index
for index in sorted(curr_indices, reverse=True):
self.items.pop(index)

def evaluate(
    self,
    full_vectors: npt.NDArray,
    query_vectors: npt.NDArray,
    k: int = 10,
    epsilon: float = 1e-3,
) -> tuple[float, float]:
    """
    Evaluate the Vicinity instance on the given query vectors.

    Computes recall and measures QPS (Queries Per Second).
    For recall calculation, the same methodology is used as in the ann-benchmarks repository.

    NOTE: this is only supported for Cosine and Euclidean metric backends.

    :param full_vectors: The full dataset vectors used to build the index.
    :param query_vectors: The query vectors to evaluate.
    :param k: The number of nearest neighbors to retrieve.
    :param epsilon: The epsilon threshold for recall calculation.
    :return: A tuple of (QPS, recall).
    :raises ValueError: If the metric is not supported by the BasicBackend.
    """
    # Resolve the backend's metric and reject anything the exact
    # BasicBackend cannot compute; any ValueError (from parsing or from
    # the explicit unsupported-metric check) is rewrapped with a message
    # that lists the supported options.
    try:
        metric_enum = Metric.from_string(self.metric)
        if metric_enum not in BasicBackend.supported_metrics:
            raise ValueError(f"Unsupported metric '{metric_enum.value}' for BasicBackend.")
        basic_metric = metric_enum.value
    except ValueError as e:
        raise ValueError(
            f"Unsupported metric '{self.metric}' for evaluation with BasicBackend. "
            f"Supported metrics are: {[m.value for m in BasicBackend.supported_metrics]}"
        ) from e

    # Exact-search index over the full dataset serves as ground truth.
    gt_vicinity = Vicinity.from_vectors_and_items(
        vectors=full_vectors,
        items=self.items,
        backend_type=Backend.BASIC,
        metric=basic_metric,
    )

    # Keep only the distances from the exact neighbor lists.
    gt_distances = [[dist for _, dist in neighbors] for neighbors in gt_vicinity.query(query_vectors, k=k)]

    # Time the approximate batch query to derive QPS.
    start_time = perf_counter()
    run_results = self.query(query_vectors, k=k)
    elapsed_time = perf_counter() - start_time

    num_queries = len(query_vectors)
    qps = num_queries / elapsed_time if elapsed_time > 0 else float("inf")

    approx_distances = [[dist for _, dist in neighbors] for neighbors in run_results]

    # ann-benchmarks recall: for each query, count approximate hits whose
    # distance falls within the k-th ground-truth distance plus epsilon.
    recalls = [
        sum(dist <= exact[k - 1] + epsilon for dist in approx) / k
        for exact, approx in zip(gt_distances, approx_distances)
    ]

    mean_recall = float(np.mean(recalls))
    return qps, mean_recall