Skip to content

Commit

Permalink
update examples
Browse files Browse the repository at this point in the history
  • Loading branch information
staoxiao committed Mar 30, 2022
1 parent 8336124 commit 72756bf
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 5 deletions.
4 changes: 2 additions & 2 deletions LibVQ/base_index/BaseIndex.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def hard_negative(self,
score, search_results = self.search(query_embeddings, topk=topk, batch_size=batch_size, nprobe=nprobe)
query2hardneg = {}
for qid, neighbors in enumerate(search_results):
neg = list(filter(lambda x: x not in ground_truths[qid], neighbors))
query2hardneg[qid] = [x for x in neg if x != -1]
neg = list(filter(lambda x: x not in ground_truths[qid] and x != -1, neighbors))
query2hardneg[qid] = neg
return query2hardneg

def generate_virtual_traindata(self,
Expand Down
4 changes: 2 additions & 2 deletions examples/MSMARCO/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ pleaser refer to [dataset.README.md](../../LibVQ/dataset/README.md)
python ./basic_index/faiss_index.py \
--preprocess_dir ./data/passage/preprocess \
--embeddings_dir ./data/passage/evaluate/co-condenser \
--index_method ivf_pq \
--index_method ivf_opq \
--ivf_centers_num 10000 \
--subvector_num 32 \
--subvector_num 64 \
--subvector_bits 8 \
--nprobe 100
```
Expand Down
2 changes: 1 addition & 1 deletion examples/MSMARCO/basic_index/faiss_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from LibVQ.base_index import FaissIndex
from LibVQ.dataset.dataset import load_rel

faiss.omp_set_num_threads(32)


if __name__ == '__main__':
parser = HfArgumentParser((IndexArguments, DataArguments, ModelArguments, TrainingArguments))
Expand Down
5 changes: 5 additions & 0 deletions examples/MSMARCO/learnable_index/train_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
sys.path.append('./')
import os
import pickle
import gc

import faiss
import numpy as np
Expand Down Expand Up @@ -62,6 +63,8 @@
batch_size=64,
nprobe=min(index_args.ivf_centers_num, 500))
pickle.dump(trainquery2hardneg, open(neg_file, 'wb'))
del trainquery2hardneg
gc.collect()

data_args.save_ckpt_dir = f'./saved_ckpts/{training_args.training_mode}_{index_args.index_method}/'

Expand Down Expand Up @@ -131,6 +134,8 @@
pickle.dump(query2neg,
open(os.path.join(data_args.embeddings_dir, f"train-queries-virtual_hardneg.pickle"), 'wb'))

del query2neg, query2pos
gc.collect()

# distill with no label data
if training_args.training_mode == 'distill_index_nolabel':
Expand Down
7 changes: 7 additions & 0 deletions examples/MSMARCO/learnable_index/train_index_and_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
sys.path.append('./')
import os
import pickle
import gc

import faiss
import numpy as np
Expand Down Expand Up @@ -73,6 +74,9 @@
nprobe=min(index_args.ivf_centers_num, 500))
pickle.dump(trainquery2hardneg, open(neg_file, 'wb'))

del trainquery2hardneg
gc.collect()

data_args.save_ckpt_dir = f'./saved_ckpts/{training_args.training_mode}_{index_args.index_method}/'

# contrastive learning
Expand Down Expand Up @@ -192,6 +196,9 @@
pickle.dump(query2neg,
open(os.path.join(data_args.embeddings_dir, f"train-queries-virtual_hardneg.pickle"), 'wb'))

del query2pos, query2neg
gc.collect()


# distill with no label data
if training_args.training_mode == 'distill_index-and-query-encoder_nolabel':
Expand Down
7 changes: 7 additions & 0 deletions examples/NQ/learnable_index/train_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
sys.path.append('./')
import os
import pickle
import gc

import faiss
import numpy as np
Expand Down Expand Up @@ -64,6 +65,9 @@
nprobe=index_args.ivf_centers_num)
pickle.dump(trainquery2hardneg, open(neg_file, 'wb'))

del trainquery2hardneg
gc.collect()

data_args.save_ckpt_dir = f'./saved_ckpts/{training_args.training_mode}_{index_args.index_method}/'

# contrastive learning
Expand Down Expand Up @@ -131,6 +135,9 @@
pickle.dump(query2neg,
open(os.path.join(data_args.embeddings_dir, f"train-queries-virtual_hardneg.pickle"), 'wb'))

del query2neg, query2pos
gc.collect()


# distill with no label data
if training_args.training_mode == 'distill_index_nolabel':
Expand Down
6 changes: 6 additions & 0 deletions examples/NQ/learnable_index/train_index_and_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
sys.path.append('./')
import os
import pickle
import gc

import faiss
import numpy as np
Expand Down Expand Up @@ -74,6 +75,8 @@
batch_size=64,
nprobe=index_args.ivf_centers_num)
pickle.dump(trainquery2hardneg, open(neg_file, 'wb'))
del trainquery2hardneg
gc.collect()

data_args.save_ckpt_dir = f'./saved_ckpts/{training_args.training_mode}_{index_args.index_method}/'

Expand Down Expand Up @@ -173,6 +176,9 @@
pickle.dump(query2neg,
open(os.path.join(data_args.embeddings_dir, f"train-queries-virtual_hardneg.pickle"), 'wb'))

del query2pos, query2neg
gc.collect()


# distill with no label data
if training_args.training_mode == 'distill_index-and-query-encoder_nolabel':
Expand Down

0 comments on commit 72756bf

Please sign in to comment.