Skip to content

Commit

Permalink
add overwrite/skip option in indexing
Browse files: browse the repository at this point in the history
  • Loading branch information
fschlatt committed Nov 6, 2024
1 parent 7a7fff8 commit 63733f5
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions lightning_ir/lightning_utils/callbacks.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,11 +51,13 @@ def __init__(
self,
index_config: IndexConfig,
index_dir: Path | str | None = None,
overwrite: bool = False,
verbose: bool = False,
) -> None:
super().__init__()
self.index_config = index_config
self.index_dir = index_dir
self.overwrite = overwrite
self.verbose = verbose
self.indexer: Indexer

Expand All @@ -70,6 +72,14 @@ def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
datasets = [dataloader.dataset for dataloader in dataloaders]
if not all(isinstance(dataset, DocDataset) for dataset in datasets):
raise ValueError("Expected DocDatasets for indexing")
if not self.overwrite:
for dataset in datasets:
index_dir = self.get_index_dir(pl_module, dataset)
if index_dir.exists():
trainer.datamodule.inference_datasets.remove(dataset)
trainer.print(
f"Index dir {index_dir} already exists. Skipping this dataset. Set overwrite=True to overwrite"
)

def get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
index_dir = self.index_dir
Expand Down

0 comments on commit 63733f5

Please sign in to comment.