diff --git a/lightning_ir/callbacks/callbacks.py b/lightning_ir/callbacks/callbacks.py index 80d2e8e..6cfdbdd 100644 --- a/lightning_ir/callbacks/callbacks.py +++ b/lightning_ir/callbacks/callbacks.py @@ -413,6 +413,8 @@ def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_id dataset = dataloaders[dataset_idx].dataset index_dir = self._get_index_dir(pl_module, dataset) + if self.searcher is not None and self.searcher.index_dir == index_dir: + return self.searcher searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu) return searcher