
Commit

cleanup
fschlatt committed Dec 16, 2024
1 parent 0e68bdb commit 3e9816c
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions lightning_ir/callbacks/callbacks.py
@@ -1,9 +1,11 @@
"""Module containing callbacks for indexing, searching, ranking, and registering custom datasets."""

+from __future__ import annotations

import itertools
from dataclasses import is_dataclass
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Sequence, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, TypeVar

import pandas as pd
import torch
@@ -41,10 +43,10 @@ def _format_large_number(number: float) -> str:
class _GatherMixin:
"""Mixin to gather dataclasses across all processes"""

-def gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
+def _gather(self, pl_module: LightningIRModule, dataclass: T) -> T:
if is_dataclass(dataclass):
return dataclass.__class__(
-**{k: self.gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
+**{k: self._gather(pl_module, getattr(dataclass, k)) for k in dataclass.__dataclass_fields__}
)
return pl_module.all_gather(dataclass)
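The rename to _gather marks the helper as private; its behavior is unchanged: it rebuilds a dataclass field by field, all-gathering each leaf across processes. A minimal sketch for a flat dataclass (the Output class below is illustrative, not part of the library):

from dataclasses import dataclass
import torch

@dataclass
class Output:  # hypothetical stand-in for a batch/output dataclass
    scores: torch.Tensor
    ids: torch.Tensor

# For this flat dataclass,
#     gathered = self._gather(pl_module, out)
# is equivalent to rebuilding it from all-gathered leaves:
#     Output(scores=pl_module.all_gather(out.scores),
#            ids=pl_module.all_gather(out.ids))
# Nested dataclasses are handled by the same recursion; non-dataclass
# values fall straight through to pl_module.all_gather.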

@@ -54,7 +56,7 @@ class _IndexDirMixin:

index_dir: Path | str | None

-def get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
+def _get_index_dir(self, pl_module: BiEncoderModule, dataset: DocDataset) -> Path:
index_dir = self.index_dir
if index_dir is None:
default_index_dir = Path(pl_module.config.name_or_path)
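Only the start of the renamed _get_index_dir is visible here; it falls back to a default derived from the model's name_or_path when no index_dir is configured. A hedged sketch of that resolution (the layout past the truncated lines is an assumption):

from pathlib import Path

def resolve_index_dir(index_dir, name_or_path: str, dataset_id: str) -> Path:
    # Assumption: the default index location is derived from the model path;
    # the real method may compose the final directory differently.
    if index_dir is None:
        index_dir = Path(name_or_path) / "indexes"
    return Path(index_dir) / dataset_id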
@@ -70,11 +72,11 @@ def get_index_dir(pl_module: BiEncoderModule, dataset: DocDataset) -> Path
class _OverwriteMixin:
"""Mixin to skip datasets (for indexing or searching) if they already exist"""

-overwrite: bool
_get_save_path: Callable[[Trainer, LightningModule, int], Path]

def _remove_overwrite_datasets(self, trainer: Trainer, pl_module: LightningIRModule, stage: str) -> None:
-if not self.overwrite:
+overwrite = getattr(self, "overwrite", False)
+if not overwrite:
datasets = list(trainer.datamodule.inference_datasets)
remove_datasets = []
for dataset_idx in range(len(datasets)):
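Replacing the class-level overwrite: bool annotation and the direct self.overwrite access with getattr(self, "overwrite", False) matters at runtime: a bare annotation creates no attribute, so a class mixing this in without ever assigning overwrite would previously have raised AttributeError. In isolation:

class Legacy:
    overwrite: bool  # annotation only: no attribute exists at runtime

cb = Legacy()
# cb.overwrite          # old access pattern: raises AttributeError
overwrite = getattr(cb, "overwrite", False)  # new pattern: defaults to False
assert overwrite is False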
@@ -96,7 +98,7 @@ def __init__(
overwrite: bool = False,
verbose: bool = False,
) -> None:
"""Callback to index documents using an :py:class:`Indexer`.
"""Callback to index documents using an :py:class:`~lightning_ir.retrieve.base.indexer.Indexer`.
:param index_config: Configuration for the indexer
:type index_config: IndexConfig
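The docstring edits in this commit replace bare cross-references such as :py:class:`Indexer` with fully qualified targets. The full dotted path lets Sphinx resolve the reference from any page regardless of the current module context, while the leading ~ trims the rendered link text to the last component, so :py:class:`~lightning_ir.retrieve.base.indexer.Indexer` still displays simply as Indexer.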
@@ -131,7 +133,7 @@ def setup(self, trainer: Trainer, pl_module: BiEncoderModule, stage: str) -> Non
self._remove_overwrite_datasets(trainer, pl_module, stage)

def _get_save_path(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Path:
-return self.get_index_dir(pl_module, trainer.datamodule.inference_datasets[dataset_idx])
+return self._get_index_dir(pl_module, trainer.datamodule.inference_datasets[dataset_idx])

def _get_indexer(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_idx: int) -> Indexer:
dataloaders = trainer.test_dataloaders
@@ -160,7 +162,7 @@ def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
:param pl_module: LightningIR BiEncoderModule
:type pl_module: BiEncoderModule
:raises ValueError: If no test_dataloaders are found
-:raises ValueError: If not all test datasets are :py:class:`DocDataset`
+:raises ValueError: If not all test datasets are :py:class:`~lightning_ir.data.dataset.DocDataset`
"""
dataloaders = trainer.test_dataloaders
if dataloaders is None:
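The remainder of this validation is truncated; per the raises documented above, it checks that test dataloaders exist and that every test dataset is a DocDataset. A hedged reconstruction (the exact error strings are assumptions):

if dataloaders is None:
    raise ValueError("No test_dataloaders found")
if not all(isinstance(loader.dataset, DocDataset) for loader in dataloaders):
    raise ValueError("All test datasets must be DocDataset instances")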
@@ -213,8 +215,8 @@ def on_test_batch_end(
:param dataloader_idx: Index of the dataloader, defaults to 0
:type dataloader_idx: int, optional
"""
-batch = self.gather(pl_module, batch)
-outputs = self.gather(pl_module, outputs)
+batch = self._gather(pl_module, batch)
+outputs = self._gather(pl_module, outputs)

if not trainer.is_global_zero:
return
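The gather-then-guard sequence is the standard distributed-write idiom: every rank must enter the collective _gather calls (an all-gather joined by only one rank would hang), and only rank zero proceeds to consume the gathered results. Schematically:

batch = self._gather(pl_module, batch)      # collective: all ranks participate
outputs = self._gather(pl_module, outputs)  # collective: all ranks participate
if not trainer.is_global_zero:
    return                                  # non-zero ranks stop here
# ... rank zero alone aggregates/writes the gathered results (truncated above)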
@@ -342,8 +344,8 @@ def on_test_batch_end(
:type dataloader_idx: int, optional
"""
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
-batch = self.gather(pl_module, batch)
-outputs = self.gather(pl_module, outputs)
+batch = self._gather(pl_module, batch)
+outputs = self._gather(pl_module, outputs)
if not trainer.is_global_zero:
return

@@ -382,7 +384,7 @@ def __init__(
) -> None:
"""Callback to which uses index to retrieve documents efficiently.
-:param search_config: Configuration of the :py:class:`~Searcher`
+:param search_config: Configuration of the :py:class:`~lightning_ir.retrieve.base.searcher.Searcher`
:type search_config: SearchConfig
:param index_dir: Directory where indexes are stored, defaults to None
:type index_dir: Path | str | None, optional
@@ -410,7 +412,7 @@ def _get_searcher(self, trainer: Trainer, pl_module: BiEncoderModule, dataset_id
raise ValueError("No test_dataloaders found")
dataset = dataloaders[dataset_idx].dataset

-index_dir = self.get_index_dir(pl_module, dataset)
+index_dir = self._get_index_dir(pl_module, dataset)

searcher = self.search_config.search_class(index_dir, self.search_config, pl_module, self.use_gpu)
return searcher
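_get_searcher follows a small factory pattern: the SearchConfig carries a search_class, which is instantiated with the resolved index directory, the config itself, the module, and the GPU flag. A sketch of the pattern with illustrative stand-ins (neither class below is part of the library):

from pathlib import Path

class DummySearcher:  # stand-in for a concrete Searcher implementation
    def __init__(self, index_dir, config, module, use_gpu: bool) -> None:
        self.index_dir = Path(index_dir)
        self.use_gpu = use_gpu

class DummyConfig:  # stand-in for a SearchConfig subclass
    search_class = DummySearcher

config = DummyConfig()
searcher = config.search_class(Path("indexes/example"), config, None, use_gpu=False)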
@@ -432,7 +434,7 @@ def on_test_start(self, trainer: Trainer, pl_module: BiEncoderModule) -> None:
:param pl_module: LightningIR BiEncoderModule
:type pl_module: BiEncoderModule
:raises ValueError: If no test_dataloaders are found
-:raises ValueError: If not all datasets are :py:class:`~QueryDataset`
+:raises ValueError: If not all datasets are :py:class:`~lightning_ir.data.dataset.QueryDataset`
"""
dataloaders = trainer.test_dataloaders
if dataloaders is None:
