LearnableIndex.md

from LibVQ.learnable_index import LearnableIndex

This class fine-tunes the index and can achieve better retrieval performance than traditional indexes (e.g., scann, faiss). For efficiency, it is implemented on top of faiss, so it supports all faiss operations (e.g., GPU acceleration). Specifically, you can get the underlying faiss index with index = learnable_index.index (where learnable_index is a LearnableIndex instance) and apply faiss operations to it.

1. Prepare embedding
LibVQ needs two sets of vectors: the document embeddings, which contain every vector you are going to search over, and the query embeddings.

2. Construct index

Currently it supports five index methods: ivf, ivf_pq, ivf_opq, pq, and opq.

learnable_index = LearnableIndex(doc_embeddings = doc_embeddings,
                                 ivf_centers_num = 10000, # the number of centers in ivf
                                 subvector_num = 32,      # the number of codebooks in pq
                                 subvector_bits = 8,      # the number of codewords (2^subvector_bits) in each codebook
                                 index_method = 'ivf_opq',# ivf, ivf_opq, ivf_pq, pq, opq
                                 dist_mode = 'ip'         # distance metric
                                 )
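To make the subvector_num and subvector_bits settings more concrete, the following sketch works out the size of a product-quantized code. It is plain arithmetic and not part of LibVQ: the helper name pq_code_stats and the embedding dimension of 768 are assumptions chosen for illustration.

```python
def pq_code_stats(dim, subvector_num, subvector_bits):
    """Return basic size figures for a product quantizer (illustrative helper)."""
    if dim % subvector_num != 0:
        raise ValueError("dim must be divisible by subvector_num")
    subvector_dim = dim // subvector_num         # dimensions covered by each codebook
    codewords_per_codebook = 2 ** subvector_bits
    code_bits = subvector_num * subvector_bits   # one codeword id per codebook
    code_bytes = (code_bits + 7) // 8
    raw_bytes = dim * 4                          # uncompressed float32 vector
    return {
        "subvector_dim": subvector_dim,
        "codewords_per_codebook": codewords_per_codebook,
        "code_bytes": code_bytes,
        "compression_ratio": raw_bytes / code_bytes,
    }


# dim=768 is an assumed embedding size; the other values match the call above.
stats = pq_code_stats(dim=768, subvector_num=32, subvector_bits=8)
print(stats)
```

With these settings each document is stored as 32 one-byte codeword ids, so raising subvector_num or subvector_bits trades memory for a finer approximation.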

3. Train index
The fit function trains the parameters of pq and ivf (if the index has one) based on the relevance relationship between queries and docs.

learnable_index.fit(query_embeddings = query_embeddings,
                    doc_embeddings = doc_embeddings,
                    rel_data = None,   # relevance relationship between query and doc; if None, training data is generated automatically
                    loss_method = 'distill',
                    epochs=5)
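As an illustration of what a "relevance relationship between query and doc" can look like, the sketch below ranks docs for each query by exact inner product and keeps the top few. This is only a toy stand-in: the function brute_force_neighbors and the dict-of-lists layout are assumptions, and neither the rel_data format LibVQ expects nor the way it generates training data automatically is specified here.

```python
import heapq


def inner_product(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def brute_force_neighbors(query_embeddings, doc_embeddings, topk):
    """Map each query index to its topk doc indices, ranked by inner product."""
    neighbors = {}
    for qid, query in enumerate(query_embeddings):
        scored = (
            (inner_product(query, doc), did)
            for did, doc in enumerate(doc_embeddings)
        )
        best = heapq.nlargest(topk, scored)  # highest scores first
        neighbors[qid] = [did for _, did in best]
    return neighbors


# Tiny hand-made vectors, purely for illustration.
docs = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]]
queries = [[1.0, 0.1], [0.0, 2.0]]
toy_rel = brute_force_neighbors(queries, docs, topk=2)
print(toy_rel)
```

Exhaustive scoring like this is only practical for small collections; it is shown to clarify the idea, not as a recommended way to build training data.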

For distributed training on multiple GPUs, you can use fit_with_multi_gpus.

4. Search

scores, ann_items = learnable_index.search(query_embeddings = query_embeddings,
                                       topk = 100,
                                       nprobe = 10)
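One way to check whether training helped is to measure recall on the returned items. The sketch below assumes ann_items can be treated as one list of doc ids per query and that you have ground-truth relevant ids for each query; the helper recall_at_k and the toy data are hypothetical and not part of LibVQ.

```python
def recall_at_k(ann_items, ground_truth, k):
    """Average, over queries, of the fraction of relevant docs found in the top k."""
    total = 0.0
    counted = 0
    for retrieved, relevant in zip(ann_items, ground_truth):
        if not relevant:
            continue  # skip queries without labelled relevant docs
        hits = len(set(retrieved[:k]) & set(relevant))
        total += hits / len(relevant)
        counted += 1
    return total / counted if counted else 0.0


# Toy results for two queries (assumed layout: one id list per query).
toy_ann_items = [[3, 7, 1], [4, 2, 9]]
toy_ground_truth = [[7], [5, 2]]
recall = recall_at_k(toy_ann_items, toy_ground_truth, k=2)
print(recall)
```

Comparing this number before and after fit, and across different nprobe values, gives a rough picture of the accuracy and speed trade-off.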

5. Save and Load

Save the index:

learnable_index.save_index(saved_index_file)

Load the index:

learnable_index = LearnableIndex(index_method='ivf_opq',
                                 init_index_file=saved_index_file)

You can also load it as a faiss index:

import faiss
index = faiss.read_index(saved_index_file)

For more information, please refer to the examples MSMARCO/train_index or NQ/train_index.