Skip to content

Commit

Permalink
Merge branch 'fastscan-reconstruct' of github.com:alisafaya/faiss int…
Browse files Browse the repository at this point in the history
…o fastscan-reconstruct
  • Loading branch information
alisafaya committed Jan 7, 2025
2 parents 03e06dc + 33734dc commit 74a5893
Show file tree
Hide file tree
Showing 24 changed files with 843 additions and 70 deletions.
4 changes: 2 additions & 2 deletions .github/actions/build_conda/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ runs:
- name: Install conda build tools
shell: ${{ steps.choose_shell.outputs.shell }}
run: |
conda install -y "conda!=24.11.0"
conda install -y "conda-build!=24.11.0"
conda install -y -q "conda!=24.11.0"
conda install -y -q "conda-build!=24.11.0"
- name: Fix CI failure
shell: ${{ steps.choose_shell.outputs.shell }}
if: runner.os != 'Windows'
Expand Down
41 changes: 11 additions & 30 deletions .github/workflows/build-pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ jobs:
uses: ./.github/actions/build_cmake
with:
opt_level: avx512
linux-x86_64-AVX512_SPR-cmake:
name: Linux x86_64 AVX512_SPR (cmake)
needs: linux-x86_64-cmake
runs-on: faiss-aws-m7i.large
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Build and Test (cmake)
uses: ./.github/actions/build_cmake
with:
opt_level: avx512_spr
linux-x86_64-GPU-cmake:
name: Linux x86_64 GPU (cmake)
needs: linux-x86_64-cmake
Expand Down Expand Up @@ -132,36 +143,6 @@ jobs:
fetch-tags: true
- name: Build and Package (conda)
uses: ./.github/actions/build_conda
linux-x86_64-GPU-CUVS-CUDA11-8-0-conda:
name: Linux x86_64 GPU w/ cuVS conda (CUDA 11.8.0)
runs-on: 4-core-ubuntu-gpu-t4
env:
CUDA_ARCHS: "70-real;72-real;75-real;80;86-real"
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
fetch-tags: true
- uses: ./.github/actions/build_conda
with:
cuvs: "ON"
cuda: "11.8.0"
linux-x86_64-GPU-CUVS-CUDA12-4-0-conda:
name: Linux x86_64 GPU w/ cuVS conda (CUDA 12.4.0)
runs-on: 4-core-ubuntu-gpu-t4
env:
CUDA_ARCHS: "70-real;72-real;75-real;80;86-real"
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
fetch-tags: true
- uses: ./.github/actions/build_conda
with:
cuvs: "ON"
cuda: "12.4.0"
windows-x86_64-conda:
name: Windows x86_64 (conda)
needs: linux-x86_64-cmake
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@
/tests/gtest/
faiss/python/swigfaiss_avx2.swig
faiss/python/swigfaiss_avx512.swig
faiss/python/swigfaiss_avx512_spr.swig
faiss/python/swigfaiss_sve.swig
12 changes: 7 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ set(CMAKE_CXX_STANDARD 17)

list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

# Valid values are "generic", "avx2", "avx512", "sve".
# Valid values are "generic", "avx2", "avx512", "avx512_spr", "sve".
option(FAISS_OPT_LEVEL "" "generic")
option(FAISS_ENABLE_GPU "Enable support for GPU indexes." ON)
option(FAISS_ENABLE_CUVS "Enable cuVS for GPU indexes." OFF)
option(FAISS_ENABLE_ROCM "Enable ROCm for GPU indexes." OFF)
option(FAISS_ENABLE_PYTHON "Build Python extension." ON)
option(FAISS_ENABLE_C_API "Build C API." OFF)
option(FAISS_ENABLE_EXTRAS "Build extras like benchmarks and demos" ON)
option(FAISS_USE_LTO "Enable Link-Time optimization" OFF)

if(FAISS_ENABLE_GPU)
Expand Down Expand Up @@ -103,10 +104,11 @@ if(FAISS_ENABLE_C_API)
add_subdirectory(c_api)
endif()

add_subdirectory(demos)
add_subdirectory(benchs)
add_subdirectory(tutorial/cpp)

if(FAISS_ENABLE_EXTRAS)
add_subdirectory(demos)
add_subdirectory(benchs)
add_subdirectory(tutorial/cpp)
endif()

# CTest must be included in the top level to enable `make test` target.
include(CTest)
Expand Down
47 changes: 36 additions & 11 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pre-release nightly builds.

- The CPU-only faiss-cpu conda package is currently available on Linux (x86-64 and aarch64), OSX (arm64 only), and Windows (x86-64)
- faiss-gpu, containing both CPU and GPU indices, is available on Linux (x86-64 only) for CUDA 11.4 and 12.1
- faiss-gpu-raft containing both CPU and GPU indices provided by NVIDIA RAFT, is available on Linux (x86-64 only) for CUDA 11.8 and 12.1.
- faiss-gpu-raft [^1] package containing GPU indices provided by [NVIDIA RAFT](https://github.com/rapidsai/raft/) version 24.06, is available on Linux (x86-64 only) for CUDA 11.8 and 12.4.

To install the latest stable release:

Expand All @@ -23,10 +23,9 @@ $ conda install -c pytorch -c nvidia -c rapidsai -c conda-forge faiss-gpu-raft=1
# GPU(+CPU) version using AMD ROCm not yet available
```

For faiss-gpu, the nvidia channel is required for CUDA, which is not
published in the main anaconda channel.
For faiss-gpu, the nvidia channel is required for CUDA, which is not published in the main anaconda channel.

For faiss-gpu-raft, the nvidia, rapidsai and conda-forge channels are required.
For faiss-gpu-raft, the rapidsai, conda-forge and nvidia channels are required.

Nightly pre-release packages can be installed as follows:

Expand All @@ -37,8 +36,11 @@ $ conda install -c pytorch/label/nightly faiss-cpu
# GPU(+CPU) version
$ conda install -c pytorch/label/nightly -c nvidia faiss-gpu=1.9.0

# GPU(+CPU) version with NVIDIA RAFT
conda install -c pytorch -c nvidia -c rapidsai -c conda-forge faiss-gpu-raft=1.9.0 pytorch pytorch-cuda numpy
# GPU(+CPU) version with NVIDIA cuVS (package built with CUDA 12.4)
conda install -c pytorch -c rapidsai -c conda-forge -c nvidia pytorch/label/nightly::faiss-gpu-cuvs 'cuda-version>=12.0,<=12.5'

# GPU(+CPU) version with NVIDIA cuVS (package built with CUDA 11.8)
conda install -c pytorch -c rapidsai -c conda-forge -c nvidia pytorch/label/nightly::faiss-gpu-cuvs 'cuda-version>=11.4,<=11.8'

# GPU(+CPU) version using AMD ROCm not yet available
```
Expand Down Expand Up @@ -68,7 +70,7 @@ $ conda install -c conda-forge faiss-cpu
# GPU version
$ conda install -c conda-forge faiss-gpu

# AMD ROCm version not yet available
# NVIDIA cuVS and AMD ROCm version not yet available
```

You can tell which channel your conda packages come from by using `conda list`.
Expand All @@ -95,6 +97,8 @@ The optional requirements are:
- the CUDA toolkit,
- for AMD GPUs:
- AMD ROCm,
- for using NVIDIA cuVS implementations:
- libcuvs=24.12
- for the python bindings:
- python 3,
- numpy,
Expand All @@ -103,6 +107,19 @@ The optional requirements are:
Indications for specific configurations are available in the [troubleshooting
section of the wiki](https://github.com/facebookresearch/faiss/wiki/Troubleshooting).

### Building with NVIDIA cuVS

The libcuvs dependency should be installed via conda:
1. With CUDA 12.0 - 12.5:
```
conda install -c rapidsai -c conda-forge -c nvidia libcuvs=24.12 'cuda-version>=12.0,<=12.5'
```
2. With CUDA 11.4 - 11.8:
```
conda install -c rapidsai -c conda-forge -c nvidia libcuvs=24.12 'cuda-version>=11.4,<=11.8'
```
For more ways to install cuVS 24.12, refer to the [RAPIDS Installation Guide](https://docs.rapids.ai/install).

## Step 1: invoking CMake

``` shell
Expand All @@ -118,9 +135,9 @@ Several options can be passed to CMake, among which:
values are `ON` and `OFF`),
- `-DFAISS_ENABLE_PYTHON=OFF` in order to disable building python bindings
(possible values are `ON` and `OFF`),
- `-DFAISS_ENABLE_CUVS=ON` in order to enable building the cuVS implementations
of the IVF-Flat and IVF-PQ GPU-accelerated indices (default is `OFF`, possible
values are `ON` and `OFF`)
- `-DFAISS_ENABLE_CUVS=ON` in order to use the NVIDIA cuVS implementations
of the IVF-Flat, IVF-PQ and [CAGRA](https://arxiv.org/pdf/2308.15136) GPU-accelerated indices (default is `OFF`, possible values are `ON` and `OFF`).
Note: `-DFAISS_ENABLE_GPU` must be set to `ON` when enabling this option.
- `-DBUILD_TESTING=OFF` in order to disable building C++ tests,
- `-DBUILD_SHARED_LIBS=ON` in order to build a shared library (possible values
are `ON` and `OFF`),
Expand All @@ -131,7 +148,7 @@ Several options can be passed to CMake, among which:
optimization options (enables `-O3` on gcc for instance),
- `-DFAISS_OPT_LEVEL=avx2` in order to enable the required compiler flags to
generate code using optimized SIMD/Vector instructions. Possible values are below:
- On x86-64, `generic`, `avx2` and `avx512`, by increasing order of optimization,
- On x86-64, `generic`, `avx2`, `avx512`, and `avx512_spr` (for avx512 features available since Intel(R) Sapphire Rapids), by increasing order of optimization,
- On aarch64, `generic` and `sve`, by increasing order of optimization,
- `-DFAISS_USE_LTO=ON` in order to enable [Link-Time Optimization](https://en.wikipedia.org/wiki/Link-time_optimization) (default is `OFF`, possible values are `ON` and `OFF`).
- BLAS-related options:
Expand Down Expand Up @@ -180,6 +197,12 @@ For AVX512:
$ make -C build -j faiss_avx512
```

For AVX512 features available since Intel(R) Sapphire Rapids:

``` shell
$ make -C build -j faiss_avx512_spr
```

This will ensure the creation of necessary files when building and installing the python package.

## Step 3: Building the python bindings (optional)
Expand Down Expand Up @@ -296,3 +319,5 @@ and you can run
$ python demos/demo_auto_tune.py
```
to test the GPU code.

[^1]: The vector search and clustering algorithms in NVIDIA RAFT have been formally migrated to [NVIDIA cuVS](https://github.com/rapidsai/cuvs). This package is being renamed to `faiss-gpu-cuvs` in the next stable release, which will use these GPU implementations from the pre-compiled `libcuvs=24.12` binary.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ The GPU implementation can accept input from either CPU or GPU memory. On a serv

## Installing

Faiss comes with precompiled libraries for Anaconda in Python, see [faiss-cpu](https://anaconda.org/pytorch/faiss-cpu) and [faiss-gpu](https://anaconda.org/pytorch/faiss-gpu). The library is mostly implemented in C++, the only dependency is a [BLAS](https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms) implementation. Optional GPU support is provided via CUDA or AMD ROCm, and the Python interface is also optional. It compiles with cmake. See [INSTALL.md](INSTALL.md) for details.
Faiss comes with precompiled libraries for Anaconda in Python, see [faiss-cpu](https://anaconda.org/pytorch/faiss-cpu), [faiss-gpu](https://anaconda.org/pytorch/faiss-gpu) and [faiss-gpu-cuvs](https://anaconda.org/pytorch/faiss-gpu-cuvs). The library is mostly implemented in C++, the only dependency is a [BLAS](https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms) implementation. Optional GPU support is provided via CUDA or AMD ROCm, and the Python interface is also optional. The backend GPU implementations of NVIDIA [cuVS](https://github.com/rapidsai/cuvs) can also be enabled optionally. It compiles with cmake. See [INSTALL.md](INSTALL.md) for details.

## How Faiss works

Expand Down
3 changes: 3 additions & 0 deletions benchs/bench_fw/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,8 @@ def __hash__(self):
return hash(str(self))

def get_name(self):
if self.desc_name is not None:
return self.desc_name
name = self.index_desc.get_name()
name += IndexBaseDescriptor.param_dict_to_name(self.search_params)
name += self.query_dataset.get_filename(KnnDescriptor.FILENAME_PREFIX)
Expand All @@ -350,6 +352,7 @@ def get_name(self):
name += "rec."
else:
name += "knn."
self.desc_name = name
return name

def flat_name(self):
Expand Down
13 changes: 12 additions & 1 deletion cmake/link_to_faiss_lib.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

function(link_to_faiss_lib target)
if(NOT FAISS_OPT_LEVEL STREQUAL "avx2" AND NOT FAISS_OPT_LEVEL STREQUAL "avx512" AND NOT FAISS_OPT_LEVEL STREQUAL "sve")
if(NOT FAISS_OPT_LEVEL STREQUAL "avx2" AND NOT FAISS_OPT_LEVEL STREQUAL "avx512" AND NOT FAISS_OPT_LEVEL STREQUAL "avx512_spr" AND NOT FAISS_OPT_LEVEL STREQUAL "sve")
target_link_libraries(${target} PRIVATE faiss)
endif()

Expand All @@ -27,6 +27,17 @@ function(link_to_faiss_lib target)
target_link_libraries(${target} PRIVATE faiss_avx512)
endif()

if(FAISS_OPT_LEVEL STREQUAL "avx512_spr")
if(NOT WIN32)
# Architecture mode to support AVX512 extensions available since Intel (R) Sapphire Rapids.
# Ref: https://networkbuilders.intel.com/solutionslibrary/intel-avx-512-fp16-instruction-set-for-intel-xeon-processor-based-products-technology-guide
target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-march=sapphirerapids -mtune=sapphirerapids>)
else()
target_compile_options(${target} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
endif()
target_link_libraries(${target} PRIVATE faiss_avx512_spr)
endif()

if(FAISS_OPT_LEVEL STREQUAL "sve")
if(NOT WIN32)
if("${CMAKE_CXX_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG} " MATCHES "(^| )-march=native")
Expand Down
1 change: 1 addition & 0 deletions contrib/torch/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# the kmeans can produce both torch and numpy centroids
from faiss.contrib.clustering import kmeans


class DatasetAssign:
"""Wrapper for a tensor that offers a function to assign the vectors
to centroids. All other implementations offer the same interface"""
Expand Down
61 changes: 52 additions & 9 deletions contrib/torch/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,47 @@
This contrib module contains Pytorch code for quantization.
"""

import numpy as np
import torch
import faiss

from faiss.contrib import torch_utils
import math
from faiss.contrib.torch import clustering
# the kmeans can produce both torch and numpy centroids


class Quantizer:
    """Base interface shared by the quantizers in this module.

    The ``train``, ``encode`` and ``decode`` methods defined here are
    placeholders that do nothing and return ``None``; subclasses are
    expected to override them.
    """

    def __init__(self, d, code_size):
        """
        d: dimension of vectors
        code_size: nb of bytes of the code (per vector)
        """
        self.d = d
        self.code_size = code_size

    def train(self, x):
        """
        takes a n-by-d array and performs training
        """
        pass

    def encode(self, x):
        """
        takes a n-by-d float array, encodes to an n-by-code_size uint8 array
        """
        pass

    def decode(self, codes):
        """
        takes a n-by-code_size uint8 array, returns a n-by-d array
        """
        pass


class VectorQuantizer(Quantizer):

def __init__(self, d, k):
    """
    d: dimension of vectors
    k: number of distinct code values; code_size is the number of
       bytes needed to hold log2(k) bits
    """
    # math.log2 works on plain Python ints (torch.log2 only accepts
    # tensors, so it raises on an int k).
    code_size = int(math.ceil(math.log2(k) / 8))
    # Calling the base initializer through the class requires passing the
    # instance explicitly; without it `d` is bound as `self`.
    Quantizer.__init__(self, d, code_size)
    self.k = k

Expand All @@ -42,12 +56,41 @@ def train(self, x):


class ProductQuantizer(Quantizer):
    """Quantizer that cuts each d-dimensional vector into M sub-vectors
    and encodes every sub-vector with its own codebook of 2**nbits
    entries."""

    def __init__(self, d, M, nbits):
        """ M: number of subvectors, d%M == 0
        nbits: number of bits that each subvector is encoded into
        """
        assert d % M == 0
        assert nbits == 8  # todo: implement other nbits values
        code_size = int(math.ceil(M * nbits / 8))
        # Quantizer.__init__ stores d and code_size on the instance.
        Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits

    def train(self, x):
        """
        takes a n-by-d tensor and fills self.codebook, a tensor of shape
        (M, 2**nbits, d // M) created with the device and dtype of x
        """
        nc = 2 ** self.nbits
        sd = self.d // self.M
        self.codebook = torch.zeros(
            (self.M, nc, sd), device=x.device, dtype=x.dtype)
        for m in range(self.M):
            # d % M == 0, so sub-vector m is the column range [m*sd, (m+1)*sd)
            xsub = x[:, m * sd:(m + 1) * sd]
            data = clustering.DatasetAssign(xsub.contiguous())
            self.codebook[m] = clustering.kmeans(nc, data)

    def encode(self, x):
        """
        takes a n-by-d tensor, returns a n-by-code_size uint8 tensor whose
        column m holds the index returned by a 1-nearest-neighbor search
        of sub-vector m in self.codebook[m]
        """
        sd = self.d // self.M
        # No device is passed, so codes live on torch's default device.
        codes = torch.zeros((x.shape[0], self.code_size), dtype=torch.uint8)
        for m in range(self.M):
            xsub = x[:, m * sd:(m + 1) * sd]
            _, I = faiss.knn(xsub.contiguous(), self.codebook[m], 1)
            codes[:, m] = I.ravel()
        return codes

    def decode(self, codes):
        """
        takes a n-by-code_size uint8 tensor, returns a n-by-d tensor built
        by looking up each code in its codebook and concatenating the
        M sub-vectors
        """
        # uint8 codes are widened to int64 before being used as indices
        vectors = [
            self.codebook[m, codes[:, m].long(), :] for m in range(self.M)
        ]
        # (n, M, sd) -> (n, M * sd): sub-vectors laid out side by side
        stacked_vectors = torch.stack(vectors, dim=1)
        cbd = self.codebook.shape[-1]
        return stacked_vectors.reshape(-1, cbd * self.M)
Loading

0 comments on commit 74a5893

Please sign in to comment.