diff --git a/.github/workflows/neurips23.yml b/.github/workflows/neurips23.yml
index d8a72627..82c38bbd 100644
--- a/.github/workflows/neurips23.yml
+++ b/.github/workflows/neurips23.yml
@@ -33,6 +33,10 @@ jobs:
           - algorithm: vamana
             dataset: random-xs
             track: ood
+          # Test faissplus entry
+          - algorithm: faissplus
+            dataset: random-filter-s
+            track: filter
           - algorithm: shnsw
             dataset: sparse-small
             track: sparse
diff --git a/neurips23/filter/faissplus/Dockerfile b/neurips23/filter/faissplus/Dockerfile
new file mode 100644
index 00000000..5bfb528e
--- /dev/null
+++ b/neurips23/filter/faissplus/Dockerfile
@@ -0,0 +1,23 @@
+FROM neurips23
+
+RUN apt update && apt install -y wget swig
+RUN wget https://repo.anaconda.com/archive/Anaconda3-2023.03-0-Linux-x86_64.sh
+RUN bash Anaconda3-2023.03-0-Linux-x86_64.sh -b
+
+ENV PATH /root/anaconda3/bin:$PATH
+ENV CONDA_PREFIX /root/anaconda3/
+
+RUN conda install -c pytorch faiss-cpu
+COPY install/requirements_conda.txt ./
+# conda doesn't like some of our packages, use pip
+RUN python3 -m pip install -r requirements_conda.txt
+
+COPY neurips23/filter/faissplus/bow_id_selector.swig ./
+
+RUN swig -c++ -python -I$CONDA_PREFIX/include -Ifaiss bow_id_selector.swig
+RUN g++ -shared -O3 -g -fPIC bow_id_selector_wrap.cxx -o _bow_id_selector.so \
+    -I $( python -c "import distutils.sysconfig ; print(distutils.sysconfig.get_python_inc())" ) \
+    -I $CONDA_PREFIX/include $CONDA_PREFIX/lib/libfaiss_avx2.so -Ifaiss
+
+RUN python3 -c 'import faiss; print(faiss.IndexFlatL2); print(faiss.__version__)'
+
diff --git a/neurips23/filter/faissplus/bow_id_selector.swig b/neurips23/filter/faissplus/bow_id_selector.swig
new file mode 100644
index 00000000..748b9db9
--- /dev/null
+++ b/neurips23/filter/faissplus/bow_id_selector.swig
@@ -0,0 +1,185 @@
+
+%module bow_id_selector
+
+/*
+To compile when Faiss is installed via conda:
+
+swig -c++ -python -I$CONDA_PREFIX/include bow_id_selector.swig && \
+g++ -shared -O3 -g -fPIC bow_id_selector_wrap.cxx -o _bow_id_selector.so \
+    -I $( python -c "import distutils.sysconfig ; print(distutils.sysconfig.get_python_inc())" ) \
+    -I $CONDA_PREFIX/include $CONDA_PREFIX/lib/libfaiss_avx2.so
+
+*/
+
+
+// Put C++ includes here
+%{
+
+#include <faiss/impl/FaissException.h>
+#include <faiss/impl/IDSelector.h>
+
+%}
+
+// to get uint32_t and friends
+%include <stdint.i>
+
+// This means: assume what's declared in these .h files is provided
+// by the Faiss module.
+%import(module="faiss") "faiss/MetricType.h"
+%import(module="faiss") "faiss/impl/IDSelector.h"
+
+// functions to be parsed here
+
+// This is important to release GIL and do Faiss exception handling
+%exception {
+    Py_BEGIN_ALLOW_THREADS
+    try {
+        $action
+    } catch(faiss::FaissException & e) {
+        PyEval_RestoreThread(_save);
+
+        if (PyErr_Occurred()) {
+            // some previous code already set the error type.
+        } else {
+            PyErr_SetString(PyExc_RuntimeError, e.what());
+        }
+        SWIG_fail;
+    } catch(std::bad_alloc & ba) {
+        PyEval_RestoreThread(_save);
+        PyErr_SetString(PyExc_MemoryError, "std::bad_alloc");
+        SWIG_fail;
+    }
+    Py_END_ALLOW_THREADS
+}
+
+
+// any class or function declared below will be made available
+// in the module.
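+//
+// A sketch of the intended Python-side usage (this mirrors what faiss.py in
+// this directory does through make_bow_id_selector; the selector is passed
+// to the IVF search via faiss.SearchParametersIVF):
+//
+//   sel = bow_id_selector.IDSelectorBOW(nb, faiss.swig_ptr(lims),
+//                                       faiss.swig_ptr(indices))
+//   sel.set_query_words(int(w1), int(w2))  # w2 = -1 for single-word queries
+//   params = faiss.SearchParametersIVF(sel=sel, nprobe=nprobe)
+//   D, I = index.search(xq, k, params=params)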
+%inline %{
+
+struct IDSelectorBOW : faiss::IDSelector {
+    size_t nb;
+    using TL = int32_t;
+    const TL *lims;
+    const int32_t *indices;
+    int32_t w1 = -1, w2 = -1;
+
+    IDSelectorBOW(
+        size_t nb, const TL *lims, const int32_t *indices):
+        nb(nb), lims(lims), indices(indices) {}
+
+    void set_query_words(int32_t w1, int32_t w2) {
+        this->w1 = w1;
+        this->w2 = w2;
+    }
+
+    // binary search in the indices array
+    bool find_sorted(TL l0, TL l1, int32_t w) const {
+        while (l1 > l0 + 1) {
+            TL lmed = (l0 + l1) / 2;
+            if (indices[lmed] > w) {
+                l1 = lmed;
+            } else {
+                l0 = lmed;
+            }
+        }
+        return indices[l0] == w;
+    }
+
+    bool is_member(faiss::idx_t id) const {
+        TL l0 = lims[id], l1 = lims[id + 1];
+        if (l1 <= l0) {
+            return false;
+        }
+        if(!find_sorted(l0, l1, w1)) {
+            return false;
+        }
+        if(w2 >= 0 && !find_sorted(l0, l1, w2)) {
+            return false;
+        }
+        return true;
+    }
+
+    ~IDSelectorBOW() override {}
+};
+
+
+struct IDSelectorBOWBin : IDSelectorBOW {
+    /** with additional binary filtering */
+    faiss::idx_t id_mask;
+
+    IDSelectorBOWBin(
+        size_t nb, const TL *lims, const int32_t *indices, faiss::idx_t id_mask):
+        IDSelectorBOW(nb, lims, indices), id_mask(id_mask) {}
+
+    faiss::idx_t q_mask = 0;
+
+    void set_query_words_mask(int32_t w1, int32_t w2, faiss::idx_t q_mask) {
+        set_query_words(w1, w2);
+        this->q_mask = q_mask;
+    }
+
+    bool is_member(faiss::idx_t id) const {
+        if (q_mask & ~id) {
+            return false;
+        }
+        return IDSelectorBOW::is_member(id & id_mask);
+    }
+
+    ~IDSelectorBOWBin() override {}
+};
+
+
+size_t intersect_sorted_c(
+        size_t n1, const int32_t *a1,
+        size_t n2, const int32_t *a2,
+        int32_t *res)
+{
+    if (n1 == 0 || n2 == 0) {
+        return 0;
+    }
+    size_t i1 = 0, i2 = 0, i = 0;
+    for(;;) {
+        if (a1[i1] < a2[i2]) {
+            i1++;
+            if (i1 >= n1) {
+                return i;
+            }
+        } else if (a1[i1] > a2[i2]) {
+            i2++;
+            if (i2 >= n2) {
+                return i;
+            }
+        } else { // equal
+            res[i++] = a1[i1++];
+            i2++;
+            if (i1 >= n1 || i2 >= n2) {
+                return i;
+            }
+        }
+    }
+}
+
+%}
+
+
+%pythoncode %{
+
+import numpy as np
+
+# example additional function that converts the passed-in numpy arrays to
+# C++ pointers
+def intersect_sorted(a1, a2):
+    n1, = a1.shape
+    n2, = a2.shape
+    res = np.empty(n1 + n2, dtype=a1.dtype)
+    nres = intersect_sorted_c(
+        n1, faiss.swig_ptr(a1),
+        n2, faiss.swig_ptr(a2),
+        faiss.swig_ptr(res)
+    )
+    return res[:nres]
+
+%}
+
+
diff --git a/neurips23/filter/faissplus/config.yaml b/neurips23/filter/faissplus/config.yaml
new file mode 100644
index 00000000..bb8ce6fd
--- /dev/null
+++ b/neurips23/filter/faissplus/config.yaml
@@ -0,0 +1,66 @@
+random-filter-s:
+  faissplus:
+    docker-tag: neurips23-filter-faissplus
+    module: neurips23.filter.faissplus.faiss
+    constructor: FAISS
+    base-args: ["@metric"]
+    run-groups:
+      base:
+        args: |
+          [{"indexkey": "IVF1024,SQ8"}]
+        query-args: |
+          [{"nprobe": 1},
+           {"nprobe": 2},
+           {"nprobe": 4}]
+random-s:
+  faissplus:
+    docker-tag: neurips23-filter-faissplus
+    module: neurips23.filter.faissplus.faiss
+    constructor: FAISS
+    base-args: ["@metric"]
+    run-groups:
+      base:
+        args: |
+          [{"indexkey": "IVF1024,SQ8"}]
+        query-args: |
+          [{"nprobe": 1},
+           {"nprobe": 2},
+           {"nprobe": 4}]
+yfcc-10M-unfiltered:
+  faissplus:
+    docker-tag: neurips23-filter-faissplus
+    module: neurips23.filter.faissplus.faiss
+    constructor: FAISS
+    base-args: ["@metric"]
+    run-groups:
+      base:
+        args: |
+          [{"indexkey": "IVF16384,SQ8", "binarysig": true, "threads": 16}]
+        query-args: |
+          [{"nprobe": 1}, {"nprobe": 4}, {"nprobe": 16}, {"nprobe": 64}]
+yfcc-10M:
+  faissplus:
+    docker-tag: neurips23-filter-faissplus
+    module: neurips23.filter.faissplus.faiss
+    constructor: FAISS
+    base-args: ["@metric"]
+    run-groups:
+      base:
+        args: |
+          [{"indexkey": "IVF11264,SQ8",
+            "binarysig": true,
+            "threads": 16
+          }]
+        query-args: |
+          [
+            {"nprobe": 34, "mt_threshold": 0.00031},
+            {"nprobe": 32, "mt_threshold": 0.0003},
+            {"nprobe": 32, "mt_threshold": 0.00031},
+            {"nprobe": 34, "mt_threshold": 0.0003},
+            {"nprobe": 34, "mt_threshold": 0.00035},
+            {"nprobe": 32, "mt_threshold": 0.00033},
+            {"nprobe": 30, "mt_threshold": 0.00033},
+            {"nprobe": 32, "mt_threshold": 0.00035},
+            {"nprobe": 34, "mt_threshold": 0.00033},
+            {"nprobe": 40, "mt_threshold": 0.0003}
+          ]
diff --git a/neurips23/filter/faissplus/faiss.py b/neurips23/filter/faissplus/faiss.py
new file mode 100644
index 00000000..f45d00a7
--- /dev/null
+++ b/neurips23/filter/faissplus/faiss.py
@@ -0,0 +1,287 @@
+import pdb
+import pickle
+import numpy as np
+import os
+
+from multiprocessing.pool import ThreadPool
+
+import faiss
+
+from neurips23.filter.base import BaseFilterANN
+from benchmark.datasets import DATASETS
+from benchmark.dataset_io import download_accelerated
+
+import bow_id_selector
+
+def csr_get_row_indices(m, i):
+    """ get the non-0 column indices for row i in matrix m """
+    return m.indices[m.indptr[i] : m.indptr[i + 1]]
+
+def make_bow_id_selector(mat, id_mask=0):
+    sp = faiss.swig_ptr
+    if id_mask == 0:
+        return bow_id_selector.IDSelectorBOW(mat.shape[0], sp(mat.indptr), sp(mat.indices))
+    else:
+        return bow_id_selector.IDSelectorBOWBin(
+            mat.shape[0], sp(mat.indptr), sp(mat.indices), id_mask
+        )
+
+def set_invlist_ids(invlists, l, ids):
+    n, = ids.shape
+    ids = np.ascontiguousarray(ids, dtype='int64')
+    assert invlists.list_size(l) == n
+    faiss.memcpy(
+        invlists.get_ids(l),
+        faiss.swig_ptr(ids), n * 8
+    )
+
+
+def csr_to_bitcodes(matrix, bitsig):
+    """ Compute binary codes for the rows of the matrix: each binary code is
+    the OR of bitsig for non-0 entries of the row.
+    """
+    indptr = matrix.indptr
+    indices = matrix.indices
+    n = matrix.shape[0]
+    bit_codes = np.zeros(n, dtype='int64')
+    for i in range(n):
+        # print(bitsig[indices[indptr[i]:indptr[i + 1]]])
+        bit_codes[i] = np.bitwise_or.reduce(bitsig[indices[indptr[i]:indptr[i + 1]]])
+    return bit_codes
+
+
+class BinarySignatures:
+    """ binary signatures that encode vectors """
+
+    def __init__(self, meta_b, proba_1):
+        nvec, nword = meta_b.shape
+        # number of bits reserved for the vector ids
+        self.id_bits = int(np.ceil(np.log2(nvec)))
+        # number of bits for the binary signature
+        self.sig_bits = nbits = 63 - self.id_bits
+
+        # select binary signatures for the vocabulary
+        rs = np.random.RandomState(123)  # we rely on this to be reproducible!
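+        # rs.rand(nword, nbits) < proba_1 draws an nbits-wide random signature
+        # per vocabulary word with P(bit = 1) = proba_1; packbits packs each
+        # boolean row into bytes, and zero-padding to 8 bytes lets every row
+        # be viewed as a single int64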
+        bitsig = np.packbits(rs.rand(nword, nbits) < proba_1, axis=1)
+        bitsig = np.pad(bitsig, ((0, 0), (0, 8 - bitsig.shape[1]))).view("int64").ravel()
+        self.bitsig = bitsig
+
+        # signatures for all the metadata matrix
+        self.db_sig = csr_to_bitcodes(meta_b, bitsig) << self.id_bits
+
+        # mask to keep only the ids
+        self.id_mask = (1 << self.id_bits) - 1
+
+    def query_signature(self, w1, w2):
+        """ compute the query signature for 1 or 2 words """
+        sig = self.bitsig[w1]
+        if w2 != -1:
+            sig |= self.bitsig[w2]
+        return int(sig << self.id_bits)
+
+
+class FAISS(BaseFilterANN):
+
+    def __init__(self, metric, index_params):
+        self._index_params = index_params
+        self._metric = metric
+        print(index_params)
+        self.indexkey = index_params.get("indexkey", "IVF32768,SQ8")
+        self.binarysig = index_params.get("binarysig", True)
+        self.binarysig_proba1 = index_params.get("binarysig_proba1", 0.1)
+        self.metadata_threshold = 1e-3
+        self.nt = index_params.get("threads", 1)
+
+
+    def fit(self, dataset):
+        ds = DATASETS[dataset]()
+        if ds.search_type() == "knn_filtered" and self.binarysig:
+            print("preparing binary signatures")
+            meta_b = ds.get_dataset_metadata()
+            self.binsig = BinarySignatures(meta_b, self.binarysig_proba1)
+            print("writing to", self.binarysig_name(dataset))
+            pickle.dump(self.binsig, open(self.binarysig_name(dataset), "wb"), -1)
+        else:
+            self.binsig = None
+
+        if ds.search_type() == "knn_filtered":
+            self.meta_b = ds.get_dataset_metadata()
+            self.meta_b.sort_indices()
+
+        index = faiss.index_factory(ds.d, self.indexkey)
+        xb = ds.get_dataset()
+        print("train")
+        index.train(xb)
+        print("populate")
+        if self.binsig is None:
+            index.add(xb)
+        else:
+            ids = np.arange(ds.nb) | self.binsig.db_sig
+            index.add_with_ids(xb, ids)
+
+        self.index = index
+        self.nb = ds.nb
+        self.xb = xb
+        self.ps = faiss.ParameterSpace()
+        self.ps.initialize(self.index)
+        print("store", self.index_name(dataset))
+        faiss.write_index(index, self.index_name(dataset))
+
+
+    def index_name(self, name):
+        return f"data/{name}.{self.indexkey}.faissindex"
+
+    def binarysig_name(self, name):
+        return f"data/{name}.{self.indexkey}.binarysig"
+
+
+    def load_index(self, dataset):
+        """
+        Load the index for dataset. Returns False if index
+        is not available, True otherwise.
+
+        Checking the index usually involves the dataset name
+        and the index build parameters passed during construction.
+        """
+        if not os.path.exists(self.index_name(dataset)):
+            if 'url' not in self._index_params:
+                return False
+
+            print('Downloading index in background. This can take a while.')
+            download_accelerated(self._index_params['url'], self.index_name(dataset), quiet=True)
+
+        print("Loading index")
+
+        self.index = faiss.read_index(self.index_name(dataset))
+
+        self.ps = faiss.ParameterSpace()
+        self.ps.initialize(self.index)
+
+        ds = DATASETS[dataset]()
+
+        if ds.search_type() == "knn_filtered" and self.binarysig:
+            if not os.path.exists(self.binarysig_name(dataset)):
+                print("preparing binary signatures")
+                meta_b = ds.get_dataset_metadata()
+                self.binsig = BinarySignatures(meta_b, self.binarysig_proba1)
+            else:
+                print("loading binary signatures")
+                self.binsig = pickle.load(open(self.binarysig_name(dataset), "rb"))
+        else:
+            self.binsig = None
+
+        if ds.search_type() == "knn_filtered":
+            self.meta_b = ds.get_dataset_metadata()
+            self.meta_b.sort_indices()
+
+        self.nb = ds.nb
+        self.xb = ds.get_dataset()
+
+        return True
+
+    def index_files_to_store(self, dataset):
+        """
+        Specify a triplet with the local directory path of index files,
+        the common prefix name of index component(s) and a list of
+        index components that need to be uploaded to (after build)
+        or downloaded from (for search) cloud storage.
+
+        For the local directory path under the docker environment, please use
+        a directory under
+        data/indices/track(T1 or T2)/algo.__str__()/DATASETS[dataset]().short_name()
+        """
+        raise NotImplementedError()
+
+    def query(self, X, k):
+        nq = X.shape[0]
+        self.I = -np.ones((nq, k), dtype='int32')
+        bs = 1024
+        for i0 in range(0, nq, bs):
+            _, self.I[i0:i0+bs] = self.index.search(X[i0:i0+bs], k)
+
+
+    def filtered_query(self, X, filter, k):
+        print('running filtered query')
+        nq = X.shape[0]
+        self.I = -np.ones((nq, k), dtype='int32')
+        meta_b = self.meta_b
+        meta_q = filter
+        docs_per_word = meta_b.T.tocsr()
+        ndoc_per_word = docs_per_word.indptr[1:] - docs_per_word.indptr[:-1]
+        freq_per_word = ndoc_per_word / self.nb
+
+        def process_one_row(q):
+            faiss.omp_set_num_threads(1)
+            qwords = csr_get_row_indices(meta_q, q)
+            assert qwords.size in (1, 2)
+            w1 = qwords[0]
+            freq = freq_per_word[w1]
+            if qwords.size == 2:
+                w2 = qwords[1]
+                freq *= freq_per_word[w2]
+            else:
+                w2 = -1
+            if freq < self.metadata_threshold:
+                # metadata first
+                docs = csr_get_row_indices(docs_per_word, w1)
+                if w2 != -1:
+                    docs = bow_id_selector.intersect_sorted(
+                        docs, csr_get_row_indices(docs_per_word, w2))
+
+                assert len(docs) >= k, pdb.set_trace()
+                xb_subset = self.xb[docs]
+                _, Ii = faiss.knn(X[q : q + 1], xb_subset, k=k)
+
+                self.I[q, :] = docs[Ii.ravel()]
+            else:
+                # IVF first, filtered search
+                sel = make_bow_id_selector(meta_b, self.binsig.id_mask if self.binsig else 0)
+                if self.binsig is None:
+                    sel.set_query_words(int(w1), int(w2))
+                else:
+                    sel.set_query_words_mask(
+                        int(w1), int(w2), self.binsig.query_signature(w1, w2))
+
+                params = faiss.SearchParametersIVF(sel=sel, nprobe=self.nprobe)
+
+                _, Ii = self.index.search(
+                    X[q:q+1], k, params=params
+                )
+                Ii = Ii.ravel()
+                if self.binsig is None:
+                    self.I[q] = Ii
+                else:
+                    # we'll just assume there are enough results
+                    # valid = Ii != -1
+                    # I[q, valid] = Ii[valid] & binsig.id_mask
+                    self.I[q] = Ii & self.binsig.id_mask
+
+
+        if self.nt <= 1:
+            for q in range(nq):
+                process_one_row(q)
+        else:
+            faiss.omp_set_num_threads(self.nt)
+            pool = ThreadPool(self.nt)
+            list(pool.map(process_one_row, range(nq)))
+
+    def get_results(self):
+        return self.I
+
+    def set_query_arguments(self, query_args):
+        faiss.cvar.indexIVF_stats.reset()
+        self.qas = query_args
+        if "nprobe" in query_args:
+            self.nprobe = query_args['nprobe']
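+            # ParameterSpace reaches the IVF index even when it is wrapped
+            # (e.g. by a pretransform), so nprobe lands on the right object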
+            self.ps.set_index_parameters(self.index, f"nprobe={query_args['nprobe']}")
+        else:
+            self.nprobe = 1
+        if "mt_threshold" in query_args:
+            self.metadata_threshold = query_args['mt_threshold']
+        else:
+            self.metadata_threshold = 1e-3
+
+    def __str__(self):
+        return f'Faiss({self.indexkey}, {self.qas})'
+
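+
+# Rough arithmetic behind mt_threshold (illustrative numbers, not taken from
+# the datasets): with nb = 10**7 as in yfcc-10M, a query word matching ~1000
+# vectors has freq = 1e-4 < mt_threshold ~= 3e-4, so it takes the
+# metadata-first path (brute-force over the matching docs); a common word
+# matching 10**5 vectors (freq = 1e-2) goes through the IVF search with the
+# BOW id selector instead.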