-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Neurips23] ParlayANN Submission for OOD track (#186)
* initial commit * added default alpha * fixed bad dockerfile * cache bust * fixed timeout * added additional search configs to get past .9 * one more query config * added two pass arg * fixing arg in diskann dockerfile * committing to switch branches * committing to switch branches * committing to switch branches * added vamana.py * fixed issue in file detection * finalizing before PR * changes requested for PR * changes for PR * initial commit * added two pass arg * added default alpha * cache bust * added additional search configs to get past .9 * one more query config * committing to switch branches * committing to switch branches * committing to switch branches * added vamana.py * fixed issue in file detection * finalizing before PR * changes requested for PR --------- Co-authored-by: Ben Landrum <[email protected]> Co-authored-by: Magdalen Dobson <[email protected]>
- Loading branch information
1 parent
9b3a10a
commit 72b61f6
Showing
5 changed files
with
232 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
FROM neurips23 | ||
|
||
RUN apt update | ||
RUN apt install -y software-properties-common | ||
RUN add-apt-repository -y ppa:git-core/ppa | ||
RUN apt update | ||
RUN DEBIAN_FRONTEND=noninteractive apt install -y git make cmake g++ libaio-dev libgoogle-perftools-dev libunwind-dev clang-format libboost-dev libboost-program-options-dev libmkl-full-dev libcpprest-dev python3.10 | ||
|
||
|
||
ARG CACHEBUST=1 | ||
RUN git clone -b ood_v2 https://github.com/cmuparlay/ParlayANN.git && cd ParlayANN && git submodule update --init --recursive && cd python && pip install pybind11 && bash compile.sh | ||
# WORKDIR /home/app/ParlayANN | ||
# RUN git submodule update --init --recursive | ||
# WORKDIR /home/app/ParlayANN/python | ||
|
||
# RUN pip install pybind11 | ||
|
||
# RUN bash compile.sh | ||
|
||
ENV PYTHONPATH=$PYTHONPATH:/home/app/ParlayANN/python | ||
|
||
# ENV PARLAY_NUM_THREADS=8 | ||
|
||
WORKDIR /home/app |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
random-xs: | ||
vamana: | ||
docker-tag: neurips23-ood-vamana | ||
module: neurips23.ood.vamana.vamana | ||
constructor: vamana | ||
base-args: ["@metric"] | ||
run-groups: | ||
base: | ||
args: | | ||
[{"R":30, "L":50, "alpha":1.2}] | ||
query-args: | | ||
[{"Ls":50, "T":8}] | ||
text2image-10M: | ||
vamana: | ||
docker-tag: neurips23-ood-vamana | ||
module: neurips23.ood.vamana.vamana | ||
constructor: vamana | ||
base-args: ["@metric"] | ||
run-groups: | ||
base: | ||
args: | | ||
[{"R":55, "L":500, "alpha":1.0, "two_pass":1, "use_query_data":1, "compress":1}] | ||
query-args: | | ||
[ | ||
{"Ls":70, "T":8}, | ||
{"Ls":80, "T":8}, | ||
{"Ls":90, "T":8}, | ||
{"Ls":95, "T":8}, | ||
{"Ls":100, "T":8}, | ||
{"Ls":105, "T":8}, | ||
{"Ls":110, "T":8}, | ||
{"Ls":120, "T":8}, | ||
{"Ls":125, "T":8}, | ||
{"Ls":150, "T":8}] | ||
vamana-singlepass: | ||
docker-tag: neurips23-ood-vamana | ||
module: neurips23.ood.vamana.vamana | ||
constructor: vamana | ||
base-args: ["@metric"] | ||
run-groups: | ||
base: | ||
args: | | ||
[{"R":64, "L":500}] | ||
query-args: | | ||
[{"Ls":30, "T":8}, | ||
{"Ls":50, "T":8}, | ||
{"Ls":70, "T":8}, | ||
{"Ls":100, "T":8}, | ||
{"Ls":113, "T":8}, | ||
{"Ls":125, "T":8}, | ||
{"Ls":150, "T":8}, | ||
{"Ls":175, "T":8}, | ||
{"Ls":200, "T":8}] | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
from __future__ import absolute_import | ||
import psutil | ||
import os | ||
import time | ||
import numpy as np | ||
import wrapper as pann | ||
|
||
from neurips23.ood.base import BaseOODANN | ||
from benchmark.datasets import DATASETS, download_accelerated, BASEDIR | ||
from benchmark.dataset_io import download | ||
|
||
class vamana(BaseOODANN): | ||
def __init__(self, metric, index_params): | ||
self.name = "vamana" | ||
if (index_params.get("R")==None): | ||
print("Error: missing parameter R") | ||
return | ||
if (index_params.get("L")==None): | ||
print("Error: missing parameter L") | ||
return | ||
self._index_params = index_params | ||
self._metric = self.translate_dist_fn(metric) | ||
|
||
self.R = int(index_params.get("R")) | ||
self.L = int(index_params.get("L")) | ||
self.alpha = float(index_params.get("alpha", 1.0)) | ||
self.two_pass = bool(index_params.get("two_pass", False)) | ||
self.use_query_data = bool(index_params.get("use_query_data", False)) | ||
self.compress_vectors = bool(index_params.get("compress", False)) | ||
|
||
def index_name(self): | ||
return f"R{self.R}_L{self.L}_alpha{self.alpha}" | ||
|
||
def create_index_dir(self, dataset): | ||
index_dir = os.path.join(os.getcwd(), "data", "indices", "ood") | ||
os.makedirs(index_dir, mode=0o777, exist_ok=True) | ||
index_dir = os.path.join(index_dir, 'vamana') | ||
os.makedirs(index_dir, mode=0o777, exist_ok=True) | ||
index_dir = os.path.join(index_dir, dataset.short_name()) | ||
os.makedirs(index_dir, mode=0o777, exist_ok=True) | ||
index_dir = os.path.join(index_dir, self.index_name()) | ||
os.makedirs(index_dir, mode=0o777, exist_ok=True) | ||
return os.path.join(index_dir, self.index_name()) | ||
|
||
def translate_dist_fn(self, metric): | ||
if metric == 'euclidean': | ||
return 'Euclidian' | ||
elif metric == 'ip': | ||
return 'mips' | ||
else: | ||
raise Exception('Invalid metric') | ||
|
||
def translate_dtype(self, dtype:str): | ||
if dtype == 'float32': | ||
return 'float' | ||
else: | ||
return dtype | ||
|
||
def prepare_sample_info(self, index_dir): | ||
if(self.use_query_data): | ||
#download the additional sample points for the ood index | ||
self.sample_points_path = "data/text2image1B/query_sample_200000.fbin" | ||
sample_qs_large_url = "https://storage.yandexcloud.net/yr-secret-share/ann-datasets-5ac0659e27/T2I/query.private.1M.fbin" | ||
bytes_to_download = 8 + 200000*4*200 | ||
download(sample_qs_large_url, self.sample_points_path, bytes_to_download) | ||
header = np.memmap(self.sample_points_path, shape=2, dtype='uint32', mode="r+") | ||
header[0] = 200000 | ||
|
||
self.secondary_index_dir = index_dir + ".secondary" | ||
self.secondary_gt_dir = self.secondary_index_dir + ".gt" | ||
else: | ||
self.sample_points_path = "" | ||
self.secondary_index_dir = "" | ||
self.secondary_gt_dir = "" | ||
|
||
def prepare_compressed_info(self): | ||
if(self.compress_vectors): | ||
self.compressed_vectors_path = "data/text2image1B/compressed_10M.fbin" | ||
else: | ||
self.compressed_vectors_path = "" | ||
|
||
def fit(self, dataset): | ||
""" | ||
Build the index for the data points given in dataset name. | ||
""" | ||
ds = DATASETS[dataset]() | ||
d = ds.d | ||
|
||
index_dir = self.create_index_dir(ds) | ||
|
||
self.prepare_sample_info(index_dir) | ||
self.prepare_compressed_info() | ||
|
||
if hasattr(self, 'index'): | ||
print("Index already exists") | ||
return | ||
else: | ||
start = time.time() | ||
# ds.ds_fn is the name of the dataset file but probably needs a prefix | ||
pann.build_vamana_index(self._metric, self.translate_dtype(ds.dtype), ds.get_dataset_fn(), self.sample_points_path, | ||
self.compressed_vectors_path, index_dir, self.secondary_index_dir, self.secondary_gt_dir, self.R, self.L, self.alpha, | ||
self.two_pass) | ||
end = time.time() | ||
print("Indexing time: ", end - start) | ||
print(f"Wrote index to {index_dir}") | ||
|
||
self.index = pann.load_vamana_index(self._metric, self.translate_dtype(ds.dtype), ds.get_dataset_fn(), self.compressed_vectors_path, | ||
self.sample_points_path, index_dir, self.secondary_index_dir, self.secondary_gt_dir, ds.nb, d) | ||
print("Index loaded") | ||
|
||
def query(self, X, k): | ||
nq, d = X.shape | ||
self.res, self.query_dists = self.index.batch_search(X, nq, k, self.Ls) | ||
|
||
def set_query_arguments(self, query_args): | ||
self._query_args = query_args | ||
self.Ls = 0 if query_args.get("Ls") is None else query_args.get("Ls") | ||
self.search_threads = self._query_args.get("T", 16) | ||
os.environ["PARLAY_NUM_THREADS"] = str(self.search_threads) | ||
|
||
def load_index(self, dataset): | ||
ds = DATASETS[dataset]() | ||
d = ds.d | ||
|
||
index_dir = self.create_index_dir(ds) | ||
self.prepare_sample_info(index_dir) | ||
self.prepare_compressed_info() | ||
|
||
print("Trying to load...") | ||
|
||
try: | ||
file_size = os.path.getsize(index_dir) | ||
print(f"File Size in Bytes is {file_size}") | ||
except FileNotFoundError: | ||
file_size = 0 | ||
print("File not found.") | ||
|
||
if file_size != 0: | ||
try: | ||
self.index = pann.load_vamana_index(self._metric, self.translate_dtype(ds.dtype), ds.get_dataset_fn(), | ||
self.compressed_vectors_path, self.sample_points_path, index_dir, | ||
self.secondary_index_dir, self.secondary_gt_dir, ds.nb, d) | ||
print("Index loaded") | ||
return True | ||
except: | ||
print("Index not found") | ||
return False | ||
else: | ||
print("Index not found") | ||
return False |