Commit

fix overwriting
fschlatt committed Nov 7, 2024
1 parent d786809 commit 0796b89
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions lightning_ir/lightning_utils/callbacks.py
@@ -64,15 +64,8 @@ def __init__(
     def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> None:
         if stage != "test":
             raise ValueError("IndexCallback can only be used in test stage")
-
-    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
-        dataloaders = trainer.test_dataloaders
-        if dataloaders is None:
-            raise ValueError("No test_dataloaders found")
-        datasets = [dataloader.dataset for dataloader in dataloaders]
-        if not all(isinstance(dataset, DocDataset) for dataset in datasets):
-            raise ValueError("Expected DocDatasets for indexing")
         if not self.overwrite:
+            datasets = list(trainer.datamodule.inference_datasets)
             for dataset in datasets:
                 index_dir = self.get_index_dir(pl_module, dataset)
                 if index_dir.exists():
@@ -81,6 +74,14 @@ def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
                         f"Index dir {index_dir} already exists. Skipping this dataset. Set overwrite=True to overwrite"
                     )
 
+    def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
+        dataloaders = trainer.test_dataloaders
+        if dataloaders is None:
+            raise ValueError("No test_dataloaders found")
+        datasets = [dataloader.dataset for dataloader in dataloaders]
+        if not all(isinstance(dataset, DocDataset) for dataset in datasets):
+            raise ValueError("Expected DocDatasets for indexing")
+
     def get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
         index_dir = self.index_dir
         if index_dir is None:
@@ -112,6 +113,13 @@ def log_to_pg(self, info: Dict[str, Any], trainer: Trainer):
         if pg is not None:
             pg.set_postfix(info)
 
+    def on_test_batch_start(
+        self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
+    ) -> None:
+        if batch_idx == 0:
+            self.indexer = self.get_indexer(trainer, pl_module, dataloader_idx)
+        super().on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
+
     def on_test_batch_end(
         self,
         trainer: Trainer,
@@ -121,11 +129,6 @@ def on_test_batch_end(
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        if batch_idx == 0:
-            if hasattr(self, "indexer"):
-                self.indexer.save()
-            self.indexer = self.get_indexer(trainer, pl_module, dataloader_idx)
-
         batch = self.gather(pl_module, batch)
         outputs = self.gather(pl_module, outputs)
 
@@ -140,6 +143,9 @@ def on_test_batch_end(
             },
             trainer,
         )
+        if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
+            assert hasattr(self, "indexer")
+            self.indexer.save()
         return super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
 
     def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
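For orientation, a minimal usage sketch of the callback after this commit. It is hypothetical: the class names (IndexCallback, BiEncoderModule, DocDataset) and the datamodule's inference_datasets attribute come from the diff above, but the import paths, constructor arguments, and dataset id are assumptions rather than confirmed lightning_ir API.

    from lightning import Trainer  # assumes Lightning >= 2.0 package layout

    # Assumed import paths; the diff only shows the class names.
    from lightning_ir import BiEncoderModule, DocDataset, LightningIRDataModule
    from lightning_ir.lightning_utils.callbacks import IndexCallback

    # Hypothetical checkpoint and dataset ids, for illustration only.
    module = BiEncoderModule(model_name_or_path="some/bi-encoder-checkpoint")
    datamodule = LightningIRDataModule(inference_datasets=[DocDataset("msmarco-passage")])

    # With overwrite=False, setup() now checks every inference dataset up front
    # and skips any whose index directory already exists, instead of the old
    # per-batch logic in on_test_batch_end that could overwrite an index mid-run.
    callback = IndexCallback(index_dir="./indexes", overwrite=False)

    trainer = Trainer(callbacks=[callback])
    trainer.test(module, datamodule=datamodule)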
