diff --git a/lightning_ir/data/dataset.py b/lightning_ir/data/dataset.py
index 296c73b..c06eadf 100644
--- a/lightning_ir/data/dataset.py
+++ b/lightning_ir/data/dataset.py
@@ -81,9 +81,6 @@ def dataset_id(self) -> str:
     def docs_dataset_id(self) -> str:
         return ir_datasets.docs_parent_id(self.dataset_id)
 
-    def setup(self, stage: Literal["fit", "validate", "test"] | None) -> "IRDataset":
-        return self
-
 
 class DataParallelIterableDataset(IterableDataset):
     # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
@@ -193,6 +190,8 @@ def sample(
         sample_size: int,
         sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
     ) -> pd.DataFrame:
+        if sample_size == -1:
+            return df
         if hasattr(Sampler, sampling_strategy):
             return getattr(Sampler, sampling_strategy)(df, sample_size)
         raise ValueError("Invalid sampling strategy.")
@@ -202,9 +201,9 @@ class RunDataset(IRDataset, Dataset):
     def __init__(
         self,
         run_path_or_id: Path | str,
-        depth: int,
-        sample_size: int,
-        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
+        depth: int = -1,
+        sample_size: int = -1,
+        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
         targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
         normalize_targets: bool = False,
     ) -> None:
@@ -215,14 +214,14 @@ def __init__(
         else:
             dataset = str(run_path_or_id)
         super().__init__(dataset)
+        if depth != -1 and sample_size == -1:
+            sample_size = depth
         self.depth = depth
         self.sample_size = sample_size
         self.sampling_strategy = sampling_strategy
         self.targets = targets
         self.normalize_targets = normalize_targets
 
-        self.run: pd.DataFrame
-
         if self.sampling_strategy == "top" and self.sample_size > self.depth:
             warnings.warn(
                 "Sample size is greater than depth and top sampling strategy is used. "
@@ -230,16 +229,6 @@ def __init__(
                 "in the run file, but that are present in the qrels."
             )
 
-    def setup(self, stage: Literal["fit", "validate", "test"] | None = None) -> "RunDataset":
-        super().setup(stage)
-        if stage == "fit":
-            if self.targets is None:
-                raise ValueError("Targets are required for training.")
-        if stage == "test":
-            if self.targets is not None:
-                warnings.warn("Targets are ignored in predict stage.")
-                self.targets = None
-
         self.run = self.load_run()
         self.run = self.run.drop_duplicates(["query_id", "doc_id"])
 
@@ -257,7 +246,7 @@ def setup(self, stage: Literal["fit", "validate", "test"] | None = None) -> "Run
             ),  # outer join if docs are from ir_datasets else only keep docs in run
         )
 
-        if stage == "fit":
+        if sample_size != -1:
             num_docs_per_query = self.run.groupby("query_id").transform("size")
             self.run = self.run[num_docs_per_query >= self.sample_size]
 
@@ -265,11 +254,9 @@
         self.run_groups = self.run.groupby("query_id")
         self.query_ids = list(self.run_groups.groups.keys())
 
-        if self.run["rank"].max() < self.depth:
+        if self.depth != -1 and self.run["rank"].max() < self.depth:
             warnings.warn("Depth is greater than the maximum rank in the run file.")
 
-        return self
-
     @staticmethod
     def load_csv(path: Path) -> pd.DataFrame:
         return pd.read_csv(
@@ -311,17 +298,8 @@ def load_json(path: Path) -> pd.DataFrame:
         )
         return run
 
-    def load_run(self) -> pd.DataFrame:
+    def _get_run_path(self) -> Path | None:
         run_path = self.run_path
-
-        suffix_load_map = {
-            ".tsv": self.load_csv,
-            ".run": self.load_csv,
-            ".csv": self.load_csv,
-            ".parquet": self.load_parquet,
-            ".json": self.load_json,
-            ".jsonl": self.load_json,
-        }
         if run_path is None:
             if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                 raise ValueError("Run file or dataset with scoreddocs required.")
@@ -329,14 +307,9 @@
                 run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
             except NotImplementedError:
                 pass
-        if run_path is not None and run_path.suffixes[0] in suffix_load_map:
-            run = suffix_load_map[run_path.suffixes[0]](run_path)
-        elif self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
-            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
-            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
-            run = run.sort_values(["query_id", "rank"])
-        else:
-            raise ValueError("Invalid run file format.")
+        return run_path
+
+    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
         run = run.rename(
             {"qid": "query_id", "docid": "doc_id", "docno": "doc_id"},
             axis=1,
@@ -355,6 +328,40 @@
         run = run.astype(dtypes)
         return run
 
+    def load_run(self) -> pd.DataFrame:
+
+        suffix_load_map = {
+            ".tsv": self.load_csv,
+            ".run": self.load_csv,
+            ".csv": self.load_csv,
+            ".parquet": self.load_parquet,
+            ".json": self.load_json,
+            ".jsonl": self.load_json,
+        }
+        run = None
+
+        # try loading run from file
+        run_path = self._get_run_path()
+        if run_path is not None:
+            load_func = suffix_load_map.get(run_path.suffixes[0], None)
+            if load_func is not None:
+                try:
+                    run = load_func(run_path)
+                except Exception:
+                    pass
+
+        # try loading run from ir_datasets
+        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
+            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
+            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
+            run = run.sort_values(["query_id", "rank"])
+
+        if run is None:
+            raise ValueError("Invalid run file format.")
+
+        run = self._clean_run(run)
+        return run
+
     @property
     def qrels(self) -> pd.DataFrame | None:
         if self._qrels is not None:
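Usage note: after this patch the `setup()` stage hook is gone and the run is loaded eagerly in `__init__`. `depth` and `sample_size` now default to the sentinel `-1` ("keep everything"): `Sampler.sample` returns the frame untouched when `sample_size == -1`, the per-query size filter only applies when a real `sample_size` is set, and `sample_size` inherits `depth` when only `depth` is given. A minimal sketch of the resulting constructor behavior; the signature and semantics are taken from the patch above, while the run-file path is hypothetical:

```python
from lightning_ir.data.dataset import RunDataset

# Keep every ranked document: depth=-1, sample_size=-1 and
# sampling_strategy="top" are now the defaults.
full_run = RunDataset("runs/bm25.run")  # hypothetical run file

# Only depth given: sample_size falls back to depth, so the
# top 100 documents per query are kept.
top100 = RunDataset("runs/bm25.run", depth=100)

# Training-style dataset: queries with fewer than 8 documents are
# dropped (num_docs_per_query >= sample_size), then the top 8
# documents per query are sampled.
train_ds = RunDataset(
    "runs/bm25.run",
    depth=100,
    sample_size=8,
    sampling_strategy="top",
    targets="relevance",
)
```

`load_run()` is likewise restructured into a fallback chain: try the run file resolved by `_get_run_path()` via its suffix (`.tsv`/`.run`/`.csv`, `.parquet`, `.json`/`.jsonl`), silently fall back to the dataset's scoreddocs when the file is missing or unparsable, and raise `ValueError("Invalid run file format.")` only when neither source yields a run; `_clean_run` then normalizes column names and dtypes.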