Skip to content

Commit

Permalink
Initial faiss integration, wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled committed Nov 15, 2024
1 parent 88cb61e commit 9dc0c11
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 12 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pynndescent = [
"numpy>=1.24.0"
]
annoy = ["annoy"]
faiss = ["faiss-cpu"]

[project.urls]
"Homepage" = "https://github.com/MinishLab"
Expand Down
41 changes: 35 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@

random_gen = np.random.default_rng(42)

# FAISS index types the test fixtures are parametrized over.
# Only "flat" is enabled; the remaining index types are listed but commented
# out, so they are not exercised by the tests.
FAISS_INDEX_TYPES = [
    "flat",
    # "ivf",
    # "hnsw",
    # "lsh",
    # "scalar",
    # "pq",
    # "ivf_scalar",
    # "ivfpq",
    # "ivfpqr"
]


@pytest.fixture(scope="session")
def items() -> list[str]:
Expand All @@ -18,22 +31,38 @@ def items() -> list[str]:
@pytest.fixture(scope="session")
def vectors() -> np.ndarray:
"""Fixture providing an array of vectors corresponding to items."""
return random_gen.random((100, 5))
return random_gen.random((100, 8))


@pytest.fixture(scope="session")
def query_vector() -> np.ndarray:
"""Fixture providing a query vector."""
return random_gen.random(5)
return random_gen.random(8)


# (backend, index_type) pairs used to parametrize fixtures: one entry per
# enabled FAISS index type, followed by every other backend with no index type.
BACKEND_PARAMS = [
    *((Backend.FAISS, faiss_index_type) for faiss_index_type in FAISS_INDEX_TYPES),
    (Backend.BASIC, None),
    (Backend.HNSW, None),
    (Backend.ANNOY, None),
    (Backend.PYNNDESCENT, None),
]

@pytest.fixture(params=list(Backend))

@pytest.fixture(params=BACKEND_PARAMS)
def backend_type(request: pytest.FixtureRequest) -> "tuple[Backend, str | None]":
    """
    Fixture parametrizing over every entry of BACKEND_PARAMS.

    :param request: The pytest request carrying the current parameter.
    :return: A (backend, index_type) pair; index_type is None for non-FAISS backends.
    """
    # Each BACKEND_PARAMS entry is a tuple, not a bare Backend, so consumers
    # must unpack it (e.g. `backend_type[0]` for the backend itself).
    return request.param


@pytest.fixture
def vicinity_instance(backend_type: Backend, items: list[str], vectors: np.ndarray) -> Vicinity:
"""Fixture creating a Vicinity instance with the given backend, items, and vectors."""
@pytest.fixture(params=BACKEND_PARAMS)
def vicinity_instance(request: pytest.FixtureRequest, items: list[str], vectors: np.ndarray) -> Vicinity:
    """Fixture building a Vicinity instance for each (backend, index type) pair in BACKEND_PARAMS."""
    backend_type, index_type = request.param

    if backend_type != Backend.FAISS:
        # Non-FAISS backends are built without an `index_type` argument.
        return Vicinity.from_vectors_and_items(vectors, items, backend_type=backend_type)

    # FAISS additionally receives the index type under test and a fixed nlist of 4.
    return Vicinity.from_vectors_and_items(
        vectors,
        items,
        backend_type=backend_type,
        index_type=index_type,
        nlist=4,
    )
24 changes: 18 additions & 6 deletions tests/test_vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,38 @@
from vicinity.datatypes import Backend


def test_vicinity_init(backend_type: Backend, items: list[str], vectors: np.ndarray) -> None:
def test_vicinity_init(backend_type: tuple[Backend, str], items: list[str], vectors: np.ndarray) -> None:
"""
Test Vicinity.init.
:param backend_type: A (backend, index type) pair; only the backend (first element) is used by this test.
:param items: A list of item names.
:param vectors: An array of vectors.
"""
vicinity = Vicinity.from_vectors_and_items(vectors, items, backend_type=backend_type)
backend = backend_type[0]
vicinity = Vicinity.from_vectors_and_items(vectors, items, backend_type=backend)
assert len(vicinity) == len(items)
assert vicinity.items == items
assert vicinity.dim == vectors.shape[1]

vectors = np.random.default_rng(42).random((len(items) - 1, 5))

with pytest.raises(ValueError):
vicinity = Vicinity.from_vectors_and_items(vectors, items, backend_type=backend_type)
vicinity = Vicinity.from_vectors_and_items(vectors, items, backend_type=backend)


def test_vicinity_from_vectors_and_items(backend_type: Backend, items: list[str], vectors: np.ndarray) -> None:
def test_vicinity_from_vectors_and_items(
backend_type: tuple[Backend, str], items: list[str], vectors: np.ndarray
) -> None:
"""
Test Vicinity.from_vectors_and_items.
:param backend_type: A (backend, index type) pair; only the backend (first element) is used by this test.
:param items: A list of item names.
:param vectors: An array of vectors.
"""
vicinity = Vicinity.from_vectors_and_items(vectors, items, backend_type=backend_type)
backend: Backend = backend_type[0]
vicinity = Vicinity.from_vectors_and_items(vectors, items, backend_type=backend)

assert len(vicinity) == len(items)
assert vicinity.items == items
Expand Down Expand Up @@ -92,7 +96,6 @@ def test_vicinity_delete(vicinity_instance: Vicinity, items: list[str], vectors:
"""
Test Vicinity.delete method by verifying that the vector for a deleted item is not returned in subsequent queries.
:param backend_type: The backend type to use.
:param vicinity_instance: A Vicinity instance.
:param items: List of item names.
:param vectors: Array of vectors corresponding to items.
Expand All @@ -101,6 +104,15 @@ def test_vicinity_delete(vicinity_instance: Vicinity, items: list[str], vectors:
# Don't test delete for Annoy and Pynndescent backend
return

if vicinity_instance.backend.backend_type == Backend.FAISS and vicinity_instance.backend.arguments.index_type in {
"pq",
"scalar",
"ivfpq",
"ivfpqr",
}:
# Skip delete test for FAISS index types that do not support deletion
return

# Get the vector corresponding to "item2"
item2_index = items.index("item2")
item2_vector = vectors[item2_index]
Expand Down
31 changes: 31 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions vicinity/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,10 @@ def get_backend_class(backend: Backend | str) -> type[AbstractBackend]:

return PyNNDescentBackend

elif backend == Backend.FAISS:
from vicinity.backends.faiss import FaissBackend

return FaissBackend


__all__ = ["get_backend_class", "AbstractBackend"]
157 changes: 157 additions & 0 deletions vicinity/backends/faiss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

import faiss
import numpy as np
from numpy import typing as npt

from vicinity.backends.base import AbstractBackend, BaseArgs
from vicinity.datatypes import Backend, QueryResult


@dataclass
class FaissArgs(BaseArgs):
    """Arguments describing how a FAISS index was built.

    The fields correspond one-to-one to the keyword arguments of
    ``FaissBackend.from_vectors`` (plus the vector dimensionality).
    """

    # Dimensionality of the indexed vectors.
    dim: int = 0
    # Which FAISS index structure to build.
    index_type: Literal["flat", "ivf", "hnsw", "lsh", "scalar", "pq", "ivf_scalar", "ivfpq", "ivfpqr"] = "flat"
    # Distance metric used by the index.
    metric: Literal["l2", "inner_product"] = "l2"
    # Number of inverted lists; read by the "ivf", "ivf_scalar", "ivfpq" and "ivfpqr" index types.
    nlist: int = 100
    # Passed to the "hnsw", "pq", "ivfpq" and "ivfpqr" constructors.
    m: int = 8
    # Passed to the "lsh", "pq", "ivfpq" and "ivfpqr" constructors.
    nbits: int = 8
    # Refinement bit count; only read by the "ivfpqr" index type.
    refine_nbits: int = 8
    # When True, a hashtable direct map is enabled on IVF indexes.
    direct_map: bool = False


class FaissBackend(AbstractBackend[FaissArgs]):
    """Nearest-neighbour backend that stores its vectors in a FAISS index."""

    argument_class = FaissArgs

    def __init__(
        self,
        index: faiss.Index,
        arguments: FaissArgs,
    ) -> None:
        """
        Initialize the backend using a FAISS index.

        :param index: An already constructed (and populated) FAISS index.
        :param arguments: The arguments the index was built with.
        """
        super().__init__(arguments)
        self.index = index
        # Enable DirectMap if specified and supported by index type. Doing it
        # here covers both freshly built and loaded indexes.
        if isinstance(index, faiss.IndexIVF) and arguments.direct_map:
            index.set_direct_map_type(faiss.DirectMap.Hashtable)

    @staticmethod
    def _as_float32(vectors: npt.NDArray) -> npt.NDArray:
        """Return `vectors` as a C-contiguous float32 array, the layout FAISS operates on."""
        return np.ascontiguousarray(vectors, dtype=np.float32)

    @staticmethod
    def _build_index(  # noqa: C901
        dim: int,
        index_type: str,
        faiss_metric: int,
        nlist: int,
        m: int,
        nbits: int,
        refine_nbits: int,
    ) -> faiss.Index:
        """Construct an empty, untrained FAISS index of the requested type."""

        def new_flat() -> faiss.Index:
            # Exact index matching the metric; also used as the coarse quantizer of IVF indexes.
            return faiss.IndexFlatL2(dim) if faiss_metric == faiss.METRIC_L2 else faiss.IndexFlatIP(dim)

        qt_8bit = faiss.ScalarQuantizer.QT_8bit

        if index_type == "flat":
            return new_flat()
        if index_type == "ivf":
            return faiss.IndexIVFFlat(new_flat(), dim, nlist, faiss_metric)
        if index_type == "hnsw":
            # The metric is passed explicitly; without it FAISS falls back to L2.
            return faiss.IndexHNSWFlat(dim, m, faiss_metric)
        if index_type == "lsh":
            return faiss.IndexLSH(dim, nbits)
        if index_type == "scalar":
            return faiss.IndexScalarQuantizer(dim, qt_8bit, faiss_metric)
        if index_type == "pq":
            return faiss.IndexPQ(dim, m, nbits, faiss_metric)
        if index_type == "ivf_scalar":
            return faiss.IndexIVFScalarQuantizer(new_flat(), dim, nlist, qt_8bit, faiss_metric)
        if index_type == "ivfpq":
            return faiss.IndexIVFPQ(new_flat(), dim, nlist, m, nbits, faiss_metric)
        if index_type == "ivfpqr":
            # NOTE(review): this constructor takes no metric argument here, so
            # only the coarse quantizer follows `metric` - confirm this is intended.
            return faiss.IndexIVFPQR(new_flat(), dim, nlist, m, nbits, m, refine_nbits)
        raise ValueError(f"Unsupported FAISS index type: {index_type}")

    @classmethod
    def from_vectors(
        cls: type[FaissBackend],
        vectors: npt.NDArray,
        index_type: Literal["flat", "ivf", "hnsw", "lsh", "scalar", "pq", "ivf_scalar", "ivfpq", "ivfpqr"] = "flat",
        metric: Literal["l2", "inner_product"] = "l2",
        nlist: int = 100,
        m: int = 8,
        nbits: int = 8,
        refine_nbits: int = 8,
        direct_map: bool = False,
        **kwargs: Any,
    ) -> FaissBackend:
        """
        Create a new instance from vectors.

        :param vectors: A 2D array with one vector per row.
        :param index_type: The FAISS index structure to build.
        :param metric: "l2" selects the L2 metric; any other value selects inner product.
        :param nlist: Number of inverted lists for the IVF index types.
        :param m: Passed to the HNSW and PQ based index constructors.
        :param nbits: Passed to the LSH and PQ based index constructors.
        :param refine_nbits: Refinement bit count for the "ivfpqr" index type.
        :param direct_map: Whether to enable a hashtable direct map on IVF indexes.
        :param **kwargs: Ignored; accepted so extra arguments do not raise.
        :return: A backend wrapping the populated index.
        :raises ValueError: If `index_type` is not one of the supported types.
        """
        vectors = cls._as_float32(vectors)
        dim = vectors.shape[1]

        faiss_metric = faiss.METRIC_L2 if metric == "l2" else faiss.METRIC_INNER_PRODUCT
        index = cls._build_index(dim, index_type, faiss_metric, nlist, m, nbits, refine_nbits)

        # Train every index that reports it needs training, not only the IVF
        # ones: the scalar and product quantizers must also be trained before
        # vectors can be added.
        if not index.is_trained:
            index.train(vectors)
        index.add(vectors)

        arguments = FaissArgs(
            dim=dim,
            index_type=index_type,
            metric=metric,
            nlist=nlist,
            m=m,
            nbits=nbits,
            refine_nbits=refine_nbits,
            direct_map=direct_map,
        )
        # The DirectMap, if requested, is enabled by __init__.
        return cls(index=index, arguments=arguments)

    def __len__(self) -> int:
        """Return the number of vectors in the index."""
        return self.index.ntotal

    @property
    def backend_type(self) -> Backend:
        """The type of the backend."""
        return Backend.FAISS

    @property
    def dim(self) -> int:
        """Get the dimension of the space."""
        return self.index.d

    def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
        """
        Perform a k-NN search in the FAISS index.

        :param vectors: A 2D array with one query vector per row.
        :param k: The number of neighbours to retrieve per query.
        :return: One (indices, distances) pair per query vector. A pair may hold
            fewer than `k` entries when the index cannot supply `k` neighbours.
        """
        # Asking for more neighbours than there are vectors makes FAISS pad the
        # result with id -1, so never request more than the index holds.
        top_k = min(k, len(self))
        if top_k < 1:
            return [(np.empty(0, dtype=np.int64), np.empty(0, dtype=np.float32)) for _ in vectors]

        distances, indices = self.index.search(self._as_float32(vectors), top_k)
        results = []
        for idx, dist in zip(indices, distances):
            # Approximate indexes can still return -1 padding; drop it so it is
            # never interpreted as a real position.
            found = idx >= 0
            results.append((idx[found], dist[found]))
        return results

    def insert(self, vectors: npt.NDArray) -> None:
        """Insert vectors into the backend."""
        self.index.add(self._as_float32(vectors))

    def delete(self, indices: list[int]) -> None:
        """
        Delete vectors from the backend, if supported.

        :param indices: The ids of the vectors to remove.
        :raises NotImplementedError: If the index type does not support deletion.
        """
        if not hasattr(self.index, "remove_ids"):
            raise NotImplementedError("This FAISS index type does not support deletion.")

        id_selector = faiss.IDSelectorBatch(np.array(indices, dtype=np.int64))
        try:
            self.index.remove_ids(id_selector)
        except RuntimeError as error:
            # The attribute exists on every index, so lack of support only shows
            # up as a runtime failure of the call itself; translate it.
            raise NotImplementedError("This FAISS index type does not support deletion.") from error

    def threshold(self, vectors: npt.NDArray, threshold: float, max_k: int = 100) -> list[npt.NDArray]:
        """
        Query vectors within a distance threshold.

        At most `max_k` nearest candidates are inspected per query vector, so
        matches beyond the `max_k` nearest neighbours are not reported.

        NOTE(review): the raw FAISS scores are compared with `<`. FAISS L2
        indexes report squared distances and inner product is a similarity -
        confirm the intended threshold semantics for each metric.

        :param vectors: A 2D array with one query vector per row.
        :param threshold: Candidates whose score is strictly below this value are kept.
        :param max_k: Maximum number of candidates inspected per query vector.
        :return: One array of matching indices per query vector.
        """
        top_k = min(max_k, len(self))
        if top_k < 1:
            return [np.empty(0, dtype=np.int64) for _ in vectors]

        distances, indices = self.index.search(self._as_float32(vectors), top_k)
        # Exclude -1 padding explicitly: its score can compare below the threshold.
        return [idx[(idx >= 0) & (dist < threshold)] for dist, idx in zip(distances, indices)]

    def save(self, base_path: Path) -> None:
        """Save the FAISS index and arguments."""
        faiss.write_index(self.index, str(base_path / "index.faiss"))
        self.arguments.dump(base_path / "arguments.json")

    @classmethod
    def load(cls: type[FaissBackend], base_path: Path) -> FaissBackend:
        """Load a FAISS index and arguments."""
        arguments = FaissArgs.load(base_path / "arguments.json")
        index = faiss.read_index(str(base_path / "index.faiss"))
        return cls(index=index, arguments=arguments)
1 change: 1 addition & 0 deletions vicinity/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ class Backend(str, Enum):
BASIC = "basic"
ANNOY = "annoy"
PYNNDESCENT = "pynndescent"
FAISS = "faiss"

0 comments on commit 9dc0c11

Please sign in to comment.