Filter/cufe #214

Merged
merged 13 commits into from
Nov 6, 2023
4 changes: 2 additions & 2 deletions .github/workflows/neurips23.yml
@@ -28,8 +28,8 @@ jobs:
dataset: random-xs
track: ood
- algorithm: cufe
-            dataset: sparse-small
-            track: sparse
+            dataset: random-s
+            track: filter
- algorithm: vamana
dataset: random-xs
track: ood
25 changes: 25 additions & 0 deletions neurips23/filter/cufe/Dockerfile
@@ -0,0 +1,25 @@
FROM neurips23

RUN apt update && apt install -y wget swig
RUN wget https://repo.anaconda.com/archive/Anaconda3-2023.03-0-Linux-x86_64.sh
RUN bash Anaconda3-2023.03-0-Linux-x86_64.sh -b

ENV PATH /root/anaconda3/bin:$PATH
ENV CONDA_PREFIX /root/anaconda3/

RUN conda install -c pytorch faiss-cpu
COPY install/requirements_conda.txt ./
# conda doesn't like some of our packages, use pip
RUN python3 -m pip install -r requirements_conda.txt

COPY neurips23/filter/cufe/bow_id_selector.swig ./

RUN swig -c++ -python -I$CONDA_PREFIX/include -Ifaiss bow_id_selector.swig
RUN g++ -shared -O3 -g -fPIC bow_id_selector_wrap.cxx -o _bow_id_selector.so \
-I $( python -c "import distutils.sysconfig ; print(distutils.sysconfig.get_python_inc())" ) \
-I $CONDA_PREFIX/include $CONDA_PREFIX/lib/libfaiss_avx2.so -Ifaiss

RUN python3 -c 'import faiss; print(faiss.IndexFlatL2); print(faiss.__version__)'



102 changes: 102 additions & 0 deletions neurips23/filter/cufe/README.md
@@ -0,0 +1,102 @@

# Faiss baseline for the Filtered search track

The database of size $N=10^7$ can be seen as the combination of:

- a matrix $M$ of size $N \times d$ of embedding vectors (called `xb` in the code). $d=192$.
- a sparse matrix $M_\mathrm{meta}$ of size $N \times v$, entry $i,j$ is set to 1 iff word $j$ is applicable to vector $i$. $v=200386$, called `meta_b` in the code (a [CSR matrix](https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html))

The Faiss baseline for the filtered search track is based on two distinct data structures, a word-based inverted file and a Faiss `IndexIVFFlat`.
Each of the two data structures supports filtered search, but in a different way.

The search is based on a query vector $q\in \mathbb{R}^d$ and associated query words $w_1, w_2$ (there are one or two query words).
The search results are the database vectors that include /all/ query words and that are nearest to $q$ in $L_2$ distance.

## Word-based inverted file

This is a term-based inverted file that maps each word to the vectors (docs) that contain that word.
In the code it is a CSR matrix called `docs_per_word` (it's just the transposed version of `meta_b`).
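
As an illustration only, the transposition can be sketched in plain Python, using lists instead of the SciPy CSR matrices that the code relies on. The function name `build_docs_per_word` and the toy data are hypothetical.

```python
def build_docs_per_word(words_per_doc, n_words):
    """Transpose a doc -> words mapping into a word -> docs mapping.

    words_per_doc[i] is the list of word ids attached to vector i
    (one row of meta_b). The result plays the role of docs_per_word.
    """
    docs_per_word = [[] for _ in range(n_words)]
    for doc_id, words in enumerate(words_per_doc):
        for w in words:
            # doc ids are appended in increasing order,
            # so every per-word list stays sorted
            docs_per_word[w].append(doc_id)
    return docs_per_word


# toy metadata: 4 vectors, vocabulary of 3 words
meta_b = [[0, 2], [1], [0, 1, 2], [2]]
docs_per_word = build_docs_per_word(meta_b, n_words=3)
```

Each per-word list comes out sorted by vector id, which is what makes a linear-time intersection of two lists possible.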

At search time, the subset (`subset`) of vectors eligible for results depends on the number of query words:

- if there is a single word $w_1$ then it's just the set of non-0 entries in row $w_1$ of the `docs_per_word` matrix.
This can be extracted at no cost

- if there are two words $w_1$ and $w_2$ then the sets of non-0 entries of rows $w_1$ and $w_2$ are intersected.
This is done with `np.intersect1d` or with the C++ function `intersect_sorted`, which is faster (linear in the number of non-0 entries of the two rows).
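
The linear-time intersection of two sorted id lists can be sketched in pure Python as follows. This is a simplified rendering of the two-pointer merge, not the C++ implementation itself.

```python
def intersect_sorted(a1, a2):
    """Intersect two sorted lists of ids with a two-pointer merge.

    Runs in O(len(a1) + len(a2)) because each step advances
    at least one of the two cursors.
    """
    res = []
    i1 = i2 = 0
    while i1 < len(a1) and i2 < len(a2):
        if a1[i1] < a2[i2]:
            i1 += 1
        elif a1[i1] > a2[i2]:
            i2 += 1
        else:  # equal: the id appears in both rows
            res.append(a1[i1])
            i1 += 1
            i2 += 1
    return res


subset = intersect_sorted([0, 2, 5, 9], [2, 3, 9, 11])
```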

Once this subset is selected, the result is found by searching for the top-k vectors among the corresponding rows of $M$.
The result is exact and the search is most efficient when the subset is small (i.e. the words are discriminative enough to filter the results well).
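
A minimal sketch of this exact search, assuming plain Python lists for the vectors (the real code operates on NumPy arrays and Faiss routines). `search_subset` and the toy data are hypothetical names.

```python
import heapq


def search_subset(xb, q, subset, k):
    """Exact top-k search restricted to the rows of xb listed in subset.

    Returns (squared L2 distance, id) pairs, nearest first.
    """
    def sq_l2(x, y):
        return sum((a - b) ** 2 for a, b in zip(x, y))

    # only the eligible rows are ever compared to the query
    return heapq.nsmallest(k, ((sq_l2(xb[i], q), i) for i in subset))


xb = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
q = [0.9, 0.1]
# vector 1 is the closest overall, but it is not in the eligible subset
result = search_subset(xb, q, subset=[0, 2, 3], k=2)
```

The cost is proportional to the subset size, which is why this path pays off for rare words.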

## IndexIVFFlat structure

This is a Faiss [`IndexIVFFlat`](https://github.com/facebookresearch/faiss/wiki/The-index-factory#encodings) called `index`.

By default the index performs unfiltered search, ie. the nearest vectors to $q$ can be retrieved.
The accuracy of this search depends on the number of visited centroids of the `IndexIVFFlat` (parameter `nprobe`, the larger the more accurate and the slower).

One solution would be to over-fetch vectors and perform filtering post-hoc using the words in the result list.
However, it is unclear /how much/ we should overfetch.

Therefore, another solution is to use the Faiss [filtering functionality](https://github.com/facebookresearch/faiss/wiki/Setting-search-parameters-for-one-query#searching-in-a-subset-of-elements), ie. provide a callback function that is called for each vector id to decide if it should be considered as a result or not.

The callback function is implemented in C++ in the class `IDSelectorBOW`.
For vector id $i$ it looks up row $i$ of $M_\mathrm{meta}$ and performs a binary search for $w_1$ to check whether that word belongs to the words associated with vector $i$.
If $w_2$ is also provided, it does the same for $w_2$.
The callback returns true only if all terms are present.
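
The membership test can be sketched in Python on raw CSR arrays. This is a simplified stand-in for the C++ callback, using the standard `bisect` module instead of a hand-written binary search. The array names `lims` and `indices` follow the C++ struct, and the toy data is hypothetical.

```python
from bisect import bisect_left


def is_member(lims, indices, doc_id, w1, w2=-1):
    """Return True if vector doc_id carries w1 (and w2 when w2 >= 0).

    lims and indices are the CSR row pointers and sorted column ids of
    the metadata matrix: row i lives in indices[lims[i]:lims[i + 1]].
    """
    l0, l1 = lims[doc_id], lims[doc_id + 1]

    def find_sorted(w):
        # binary search restricted to the slice of this row
        pos = bisect_left(indices, w, l0, l1)
        return pos < l1 and indices[pos] == w

    if not find_sorted(w1):
        return False
    if w2 >= 0 and not find_sorted(w2):
        return False
    return True


# CSR layout of the toy rows [[0, 2], [1], [0, 1, 2], [2]]
lims = [0, 2, 3, 6, 7]
indices = [0, 2, 1, 0, 1, 2, 2]
```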

### Binary filtering

The issue is that this callback is relatively slow because (1) it requires accessing the $M_\mathrm{meta}$ matrix, which causes cache misses, and (2) it performs an iterative binary search.
Since the callback is called in the tightest inner loop of the search function, and since the IVF search tends to perform many vector comparisons, this has a non-negligible performance impact.

To speed up this test, we can use a nifty piece of bit manipulation.
The idea is that the vector ids are 63 bits long (they are 64-bit integers, but negative values are reserved, so we cannot use the sign bit).
However, since $N=10^7$ we use only $\lceil \log_2 N \rceil = 24$ bits of these, leaving 63-24 = 39 bits that are always 0.

Now, we associate to each word $j$ a 39-bit signature $S[j]$, and to each set of words the binary `or` of these signatures.
The query is represented by $s_\mathrm{q} = S[w_1] \vee S[w_2]$.
Database entry $i$ with words $W_i$ is represented by $s_i = \vee_{w\in W_i} S[w]$.
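
A minimal sketch of how such signatures could be generated and combined, in pure Python. The function names, the seed and the exact bit layout (signature stored in the high bits, above the 24 id bits) are assumptions made for illustration. The default `p=0.1` is the density of 1 bits that this README reports as working well.

```python
import random

N_ID_BITS = 24                 # ceil(log2(10**7)) bits hold the vector id
N_SIG_BITS = 63 - N_ID_BITS    # 39 spare bits hold the signature


def make_signatures(n_words, p=0.1, seed=0):
    """One random 39-bit signature per word, each bit set with probability p."""
    rng = random.Random(seed)
    return [
        sum(1 << b for b in range(N_SIG_BITS) if rng.random() < p)
        for _ in range(n_words)
    ]


def set_signature(S, words):
    """Binary OR of the signatures of a set of words."""
    sig = 0
    for w in words:
        sig |= S[w]
    return sig


def pack_id(sig, doc_id):
    """Store the signature in the otherwise unused high bits of the id."""
    return (sig << N_ID_BITS) | doc_id


S = make_signatures(n_words=5)
s_i = set_signature(S, [1, 3])        # signature of a vector with words {1, 3}
packed = pack_id(s_i, 1234567)        # still fits in 63 bits
```

Masking `packed` with `(1 << N_ID_BITS) - 1` recovers the original vector id, so the signature travels with the id at no extra memory cost.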

Then we have the following implication: if $\\{w_1, w_2\\} \subset W_i$ then all 1 bits of $s_\mathrm{q}$ are also set to 1 in $s_i$.

$$\\{w_1, w_2\\} \subset W_i \Rightarrow \neg s_i \wedge s_\mathrm{q} = 0$$

Which is equivalent to:

$$\neg s_i \wedge s_\mathrm{q} \neq 0 \Rightarrow \\{w_1, w_2\\} \not\subset W_i $$

Of course, this is an implication, not an equivalence.
Therefore, it can only rule out database vectors.
However, the binary test is very cheap to perform (it uses a few machine instructions on data that is already in machine registers), so it can be used as a pre-filter before applying the full membership test to the remaining candidates.
This is implemented in the `IDSelectorBOWBin` object.
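
The one-way nature of the test can be illustrated with a small Python sketch. The hand-picked 4-bit signatures are hypothetical and only serve to show one rejection and one false positive.

```python
def may_match(s_i, s_q):
    """Cheap pre-filter on signatures.

    False means vector i certainly lacks at least one query word.
    True only means "not ruled out": the exact test must still run.
    """
    return (~s_i & s_q) == 0


# hand-picked 4-bit signatures for four words
S = {0: 0b0011, 1: 0b0100, 2: 0b0001, 3: 0b1000}
s_i = S[0] | S[1]                      # vector i has words {0, 1}

hit = may_match(s_i, S[0] | S[1])      # both query words are present
rejected = may_match(s_i, S[3])        # a bit of word 3 is missing in s_i
false_positive = may_match(s_i, S[2])  # word 2 is absent, but its bits are covered by word 0
```

The third case is the reason the full binary search on $M_\mathrm{meta}$ is still needed after the pre-filter passes.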

The remaining degree of freedom is how to choose the binary signatures, because this rule is always valid, but its filtering ability depends on the choice of the signatures $S$.
After a few tests (see [this notebook](https://gist.github.com/mdouze/75103e4cef436510ac9b834f9a77496f#file-eval_binary_signatures-ipynb)) it seems that a random signature with 0.1 probability for 1s filters out 80% of negative tests.
Adjusting this to the frequency of the words did not seem to yield better results.

## Choosing between the two implementations

The two implementations are complementary: the word-first implementation gives exact results, and has a strong filtering ability for rare words.
The `IndexIVFFlat` implementation gives approximate results and is more relevant for words that are more common, where a significant subset of vectors are indeed relevant.

Therefore, there should be a rule to choose between the two, and the relevant metric is the size of the subset of vectors to consider.
We can use statistics on the words, i.e. $\mathrm{nocc}[j]$ is the number of times word $j$ appears in the dataset (this is just the column-wise sum of $M_\mathrm{meta}$).

For a single query word $w_1$, the fraction of relevant indices is just $f = \mathrm{nocc}[w_1] / N$.
For two query words, it is more complicated to compute but an estimate is given by $f = \mathrm{nocc}[w_1] \times \mathrm{nocc}[w_2] / N^2$ (this estimate assumes words are independent, which is incorrect).

Therefore, the rule that we use is based on a threshold $\tau$ (called `metadata_threshold` in the code) :

- if $f < \tau$ then use the word-first search

- otherwise use the IVFFlat based index
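
The selection rule can be sketched as a small Python function. The name `choose_strategy`, the default threshold and the toy word counts are illustrative assumptions, not values taken from the code.

```python
def choose_strategy(nocc, n_total, w1, w2=-1, metadata_threshold=1e-4):
    """Return (strategy, f), where strategy is "word_first" or "ivf".

    f estimates the fraction of database vectors that pass the filter.
    """
    f = nocc[w1] / n_total
    if w2 >= 0:
        # independence assumption: P(w1 and w2) ~ P(w1) * P(w2)
        f *= nocc[w2] / n_total
    strategy = "word_first" if f < metadata_threshold else "ivf"
    return strategy, f


# toy occurrence counts for three words in a database of 10**7 vectors
nocc = {0: 500, 1: 2_000_000, 2: 3_000_000}
N = 10**7

rare, _ = choose_strategy(nocc, N, w1=0)            # f = 5e-5
common, _ = choose_strategy(nocc, N, w1=1)          # f = 0.2
pair, _ = choose_strategy(nocc, N, w1=1, w2=2)      # f = 0.2 * 0.3
```

Because words are in practice correlated, the two-word estimate can be off in either direction, so the threshold is best tuned empirically.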

Note that the optimal threshold also depends on the target accuracy: since the IVFFlat search is not exact, the trade-off between the two implementations shifts when a higher accuracy is desired. See https://github.com/harsha-simhadri/big-ann-benchmarks/pull/105#issuecomment-1539842223 .


## Code layout

The code is in `faiss.py`, with performance-critical parts implemented in C++ and wrapped with SWIG in `bow_id_selector.swig`.
SWIG directly exposes the C++ classes and functions in Python.

186 changes: 186 additions & 0 deletions neurips23/filter/cufe/bow_id_selector.swig
@@ -0,0 +1,186 @@

%module bow_id_selector

/*
To compile when Faiss is installed via conda:

swig -c++ -python -I$CONDA_PREFIX/include bow_id_selector.swig && \
g++ -shared -O3 -g -fPIC bow_id_selector_wrap.cxx -o _bow_id_selector.so \
-I $( python -c "import distutils.sysconfig ; print(distutils.sysconfig.get_python_inc())" ) \
-I $CONDA_PREFIX/include $CONDA_PREFIX/lib/libfaiss_avx2.so

*/


// Put C++ includes here
%{

#include <faiss/impl/FaissException.h>
#include <faiss/impl/IDSelector.h>

%}

// to get uint32_t and friends
%include <stdint.i>

// This means: assume what's declared in these .h files is provided
// by the Faiss module.
%import(module="faiss") "faiss/MetricType.h"
%import(module="faiss") "faiss/impl/IDSelector.h"

// functions to be parsed here

// This is important to release the GIL and do Faiss exception handling
%exception {
Py_BEGIN_ALLOW_THREADS
try {
$action
} catch(faiss::FaissException & e) {
PyEval_RestoreThread(_save);

if (PyErr_Occurred()) {
// some previous code already set the error type.
} else {
PyErr_SetString(PyExc_RuntimeError, e.what());
}
SWIG_fail;
} catch(std::bad_alloc & ba) {
PyEval_RestoreThread(_save);
PyErr_SetString(PyExc_MemoryError, "std::bad_alloc");
SWIG_fail;
}
Py_END_ALLOW_THREADS
}


// any class or function declared below will be made available
// in the module.
%inline %{

struct IDSelectorBOW : faiss::IDSelector {
size_t nb;
using TL = int32_t;
const TL *lims;
const int32_t *indices;
int32_t w1 = -1, w2 = -1;

IDSelectorBOW(
size_t nb, const TL *lims, const int32_t *indices):
nb(nb), lims(lims), indices(indices) {}

void set_query_words(int32_t w1, int32_t w2) {
this->w1 = w1;
this->w2 = w2;
}

// binary search in the indices array
bool find_sorted(TL l0, TL l1, int32_t w) const {
while (l1 > l0 + 1) {
TL lmed = (l0 + l1) / 2;
if (indices[lmed] > w) {
l1 = lmed;
} else {
l0 = lmed;
}
}
return indices[l0] == w;
}

bool is_member(faiss::idx_t id) const {
TL l0 = lims[id], l1 = lims[id + 1];
if (l1 <= l0) {
return false;
}
if(!find_sorted(l0, l1, w1)) {
return false;
}
if(w2 >= 0 && !find_sorted(l0, l1, w2)) {
return false;
}
return true;
}

~IDSelectorBOW() override {}
};


struct IDSelectorBOWBin : IDSelectorBOW {
/** with additional binary filtering */
faiss::idx_t id_mask;

IDSelectorBOWBin(
size_t nb, const TL *lims, const int32_t *indices, faiss::idx_t id_mask):
IDSelectorBOW(nb, lims, indices), id_mask(id_mask) {}

faiss::idx_t q_mask = 0;

void set_query_words_mask(int32_t w1, int32_t w2, faiss::idx_t q_mask) {
set_query_words(w1, w2);
this->q_mask = q_mask;
}

bool is_member(faiss::idx_t id) const {
if (q_mask & ~id) {
return false;
}
return IDSelectorBOW::is_member(id & id_mask);
}

~IDSelectorBOWBin() override {}
};


size_t intersect_sorted_c(
size_t n1, const int32_t *a1,
size_t n2, const int32_t *a2,
int32_t *res)
{
if (n1 == 0 || n2 == 0) {
return 0;
}
size_t i1 = 0, i2 = 0, i = 0;
for(;;) {
if (a1[i1] < a2[i2]) {
i1++;
if (i1 >= n1) {
return i;
}
} else if (a1[i1] > a2[i2]) {
i2++;
if (i2 >= n2) {
return i;
}
} else { // equal
res[i++] = a1[i1++];
i2++;
if (i1 >= n1 || i2 >= n2) {
return i;
}
}
}
}

%}


%pythoncode %{

import numpy as np
import faiss  # provides swig_ptr, used below

# example additional function that converts the passed-in numpy arrays to
# C++ pointers
def intersect_sorted(a1, a2):
    n1, = a1.shape
    n2, = a2.shape
    # the intersection cannot be larger than the smaller input
    res = np.empty(min(n1, n2), dtype=a1.dtype)
    nres = intersect_sorted_c(
        n1, faiss.swig_ptr(a1),
        n2, faiss.swig_ptr(a2),
        faiss.swig_ptr(res)
    )
    return res[:nres]

%}
66 changes: 66 additions & 0 deletions neurips23/filter/cufe/config.yaml
@@ -0,0 +1,66 @@
random-filter-s:
  cufe:
    docker-tag: neurips23-filter-cufe
    module: neurips23.filter.cufe.faissCUFE
    constructor: faissCUFE
    base-args: ["@metric"]
    run-groups:
      base:
        args: |
          [{"indexkey": "IVF1024,SQ8"}]
        query-args: |
          [{"nprobe": 1},
           {"nprobe": 2},
           {"nprobe": 4}]
random-s:
  cufe:
    docker-tag: neurips23-filter-cufe
    module: neurips23.filter.cufe.faissCUFE
    constructor: faissCUFE
    base-args: ["@metric"]
    run-groups:
      base:
        args: |
          [{"indexkey": "IVF1024,SQ8"}]
        query-args: |
          [{"nprobe": 1},
           {"nprobe": 2},
           {"nprobe": 4}]
yfcc-10M-unfiltered:
  cufe:
    docker-tag: neurips23-filter-cufe
    module: neurips23.filter.cufe.faissCUFE
    constructor: faissCUFE
    base-args: ["@metric"]
    run-groups:
      base:
        args: |
          [{"indexkey": "IVF16384,SQ8", "binarysig": true, "threads": 16}]
        query-args: |
          [{"nprobe": 1}, {"nprobe": 4}, {"nprobe": 16}, {"nprobe": 64}]
yfcc-10M:
  cufe:
    docker-tag: neurips23-filter-cufe
    module: neurips23.filter.cufe.faissCUFE
    constructor: faissCUFE
    base-args: ["@metric"]
    run-groups:
      base:
        args: |
          [{"indexkey": "IVF4096,SQ8",
            "binarysig": true,
            "threads": 16
          }]
        query-args: |
          [{"nprobe": 4, "mt_threshold": 0.0003},
           {"nprobe": 16, "mt_threshold": 0.0003},
           {"nprobe": 4, "mt_threshold": 0.0001},
           {"nprobe": 16, "mt_threshold": 0.0001},
           {"nprobe": 10, "mt_threshold": 0.0001},
           {"nprobe": 8, "mt_threshold": 0.0003},
           {"nprobe": 32, "mt_threshold": 0.00033},
           {"nprobe": 30, "mt_threshold": 0.00033},
           {"nprobe": 12, "mt_threshold": 0.0002},
           {"nprobe": 16, "mt_threshold": 0.00033}
          ]
