From a33ecaf1dc93a2f694b5e9b8d8fceab7a33f0b10 Mon Sep 17 00:00:00 2001
From: Ferdinand Schlatt
Date: Fri, 15 Nov 2024 15:50:17 +0100
Subject: [PATCH] add dataset registration callback

---
 lightning_ir/data/__init__.py             |  2 +-
 lightning_ir/data/dataset.py              | 31 ++++++++----
 lightning_ir/data/ir_datasets_utils.py    | 38 +++++++--------
 lightning_ir/lightning_utils/callbacks.py | 57 ++++++++++++++++++++++-
 tests/conftest.py                         |  4 +-
 tests/test_callbacks.py                   | 33 +++++++++++--
 6 files changed, 127 insertions(+), 38 deletions(-)

diff --git a/lightning_ir/data/__init__.py b/lightning_ir/data/__init__.py
index fc7c3e3..c5d56e1 100644
--- a/lightning_ir/data/__init__.py
+++ b/lightning_ir/data/__init__.py
@@ -10,8 +10,8 @@
     "QueryDataset",
     "QuerySample",
     "RankBatch",
-    "RunDataset",
     "RankSample",
+    "RunDataset",
     "SearchBatch",
     "TrainBatch",
     "TupleDataset",
diff --git a/lightning_ir/data/dataset.py b/lightning_ir/data/dataset.py
index ac37f6b..2bf92cf 100644
--- a/lightning_ir/data/dataset.py
+++ b/lightning_ir/data/dataset.py
@@ -20,17 +20,22 @@ class IRDataset:
 
     def __init__(self, dataset: str) -> None:
         super().__init__()
-        if dataset in self.DASHED_DATASET_MAP:
-            dataset = self.DASHED_DATASET_MAP[dataset]
-        self.dataset = dataset
-        try:
-            self.ir_dataset = ir_datasets.load(dataset)
-        except KeyError:
-            self.ir_dataset = None
+        self._dataset = dataset
         self._queries = None
         self._docs = None
         self._qrels = None
 
+    @property
+    def dataset(self) -> str:
+        return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)
+
+    @property
+    def ir_dataset(self) -> ir_datasets.Dataset | None:
+        try:
+            return ir_datasets.load(self.dataset)
+        except KeyError:
+            return None
+
     @property
     def DASHED_DATASET_MAP(self) -> Dict[str, str]:
         return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}
@@ -224,6 +229,7 @@ def __init__(
         self.sampling_strategy = sampling_strategy
         self.targets = targets
         self.normalize_targets = normalize_targets
+        self.add_non_retrieved_docs = add_non_retrieved_docs
 
         if self.sampling_strategy == "top" and self.sample_size > self.depth:
             warnings.warn(
                 "in the run file, but that are present in the qrels."
             )
 
+        self.run: pd.DataFrame | None = None
+
+    def _setup(self):
+        if self.run is not None:
+            return
         self.run = self.load_run()
         self.run = self.run.drop_duplicates(["query_id", "doc_id"])
@@ -243,7 +254,7 @@ def __init__(
         self.run = self.run[self.run["query_id"].isin(query_ids)]
         # outer join if docs are from ir_datasets else only keep docs in run
         how = "left"
-        if self._docs is None and add_non_retrieved_docs:
+        if self._docs is None and self.add_non_retrieved_docs:
             how = "outer"
         self.run = self.run.merge(
             self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix("relevance_", axis=1),
@@ -372,7 +383,7 @@ def load_run(self) -> pd.DataFrame:
     def qrels(self) -> pd.DataFrame | None:
         if self._qrels is not None:
             return self._qrels
-        if "relevance" in self.run:
+        if self.run is not None and "relevance" in self.run:
             qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
             if "iteration" in self.run:
                 qrels["iteration"] = self.run["iteration"]
@@ -387,9 +398,11 @@ def qrels(self) -> pd.DataFrame | None:
         return super().qrels
 
     def __len__(self) -> int:
+        self._setup()
         return len(self.query_ids)
 
     def __getitem__(self, idx: int) -> RankSample:
+        self._setup()
         query_id = str(self.query_ids[idx])
         group = self.run_groups.get_group(query_id).copy()
         query = self.queries[query_id]
diff --git a/lightning_ir/data/ir_datasets_utils.py b/lightning_ir/data/ir_datasets_utils.py
index 17f3ade..1f35341 100644
--- a/lightning_ir/data/ir_datasets_utils.py
+++ b/lightning_ir/data/ir_datasets_utils.py
@@ -25,8 +25,8 @@
 }
 
 
-def load_constituent(
-    constituent: str | None,
+def _load_constituent(
+    constituent: Path | str | None,
     constituent_type: Literal["docs", "queries", "qrels", "scoreddocs", "docpairs"],
     **kwargs,
 ) -> Any:
@@ -45,23 +45,23 @@ def load_constituent(
     return ConstituentType(Cache(None, constituent_path), **kwargs)
 
 
-def register_local(
+def _register_local_dataset(
     dataset_id: str,
-    docs: str | None = None,
-    queries: str | None = None,
-    qrels: str | None = None,
-    docpairs: str | None = None,
-    scoreddocs: str | None = None,
+    docs: Path | str | None = None,
+    queries: Path | str | None = None,
+    qrels: Path | str | None = None,
+    docpairs: Path | str | None = None,
+    scoreddocs: Path | str | None = None,
     qrels_defs: Dict[int, str] | None = None,
 ):
     if dataset_id in ir_datasets.registry._registered:
         return
-    docs = load_constituent(docs, "docs")
-    queries = load_constituent(queries, "queries")
-    qrels = load_constituent(qrels, "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {})
-    docpairs = load_constituent(docpairs, "docpairs")
-    scoreddocs = load_constituent(scoreddocs, "scoreddocs")
+    docs = _load_constituent(docs, "docs")
+    queries = _load_constituent(queries, "queries")
+    qrels = _load_constituent(qrels, "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {})
+    docpairs = _load_constituent(docpairs, "docpairs")
+    scoreddocs = _load_constituent(scoreddocs, "scoreddocs")
     ir_datasets.registry.register(dataset_id, Dataset(docs, queries, qrels, docpairs, scoreddocs))
@@ -107,7 +107,7 @@ def docpairs_cls(self):
         return ScoredDocTuple
 
 
-def register_kd_docpairs():
+def _register_kd_docpairs():
     base_id = "msmarco-passage"
     split_id = "train"
     file_id = "kd-docpairs"
@@ -124,7 +124,7 @@ def register_kd_docpairs():
     register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)
 
 
-def register_colbert_docpairs():
+def _register_colbert_docpairs():
     base_id = "msmarco-passage"
     split_id = "train"
     file_id = "colbert-docpairs"
@@ -140,7 +140,7 @@ def register_colbert_docpairs():
     register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)
 
 
-def register_rank_distillm():
+def _register_rank_distillm():
     base_id = "msmarco-passage"
     split_id = "train"
     file_id = "rank-distillm/rankzephyr"
@@ -215,6 +215,6 @@ def register_msmarco(
     ir_datasets.registry.register(dataset_id, Dataset(dataset))
 
 
-register_kd_docpairs()
-register_colbert_docpairs()
-register_rank_distillm()
+_register_kd_docpairs()
+_register_colbert_docpairs()
+_register_rank_distillm()
diff --git a/lightning_ir/lightning_utils/callbacks.py b/lightning_ir/lightning_utils/callbacks.py
index 0069433..238f75d 100644
--- a/lightning_ir/lightning_utils/callbacks.py
+++ b/lightning_ir/lightning_utils/callbacks.py
@@ -12,6 +12,7 @@
 
 from ..data import RankBatch, SearchBatch
 from ..data.dataset import RUN_HEADER, DocDataset, QueryDataset, RunDataset
+from ..data.ir_datasets_utils import _register_local_dataset
 from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher
 
 if TYPE_CHECKING:
@@ -284,12 +285,12 @@ def write_on_batch_end(
         run_df["system"] = pl_module.model.__class__.__name__
         run_df = run_df[RUN_HEADER]
 
+        self.run_dfs.append(run_df)
+
         if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
             self.write_run_dfs(trainer, pl_module, dataloader_idx)
             self.run_dfs = []
 
-        self.run_dfs.append(run_df)
-
 
 class SearchCallback(RankCallback, _IndexDirMixin):
     def __init__(
@@ -347,3 +348,55 @@ def rank(
 
 class ReRankCallback(RankCallback):
     pass
+
+
+class RegisterLocalDatasetCallback(Callback):
+
+    def __init__(
+        self,
+        dataset_id: str,
+        docs: str | None = None,
+        queries: str | None = None,
+        qrels: str | None = None,
+        docpairs: str | None = None,
+        scoreddocs: str | None = None,
+        qrels_defs: Dict[int, str] | None = None,
+    ):
+        """Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
+        ``ir_datasets.load(dataset_id)``.
+        Currently, the following (optionally gzipped) file types are supported:
+
+        - ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
+        - ``.tsv`` or ``.qrels`` for qrels
+        - ``.tsv`` for training n-tuples
+        - ``.tsv`` or ``.run`` for scored documents / run files
+
+        :param dataset_id: Dataset id
+        :type dataset_id: str
+        :param docs: Path to documents file or valid ir_datasets id from which documents should be taken, defaults to None
+        :type docs: str | None, optional
+        :param queries: Path to queries file or valid ir_datasets id from which queries should be taken, defaults to None
+        :type queries: str | None, optional
+        :param qrels: Path to qrels file or valid ir_datasets id from which qrels will be taken, defaults to None
+        :type qrels: str | None, optional
+        :param docpairs: Path to training n-tuple file or valid ir_datasets id from which training tuples will be taken,
+            defaults to None
+        :type docpairs: str | None, optional
+        :param scoreddocs: Path to run file or valid ir_datasets id from which scored documents will be taken,
+            defaults to None
+        :type scoreddocs: str | None, optional
+        :param qrels_defs: Optional dictionary describing the relevance levels of the qrels, defaults to None
+        :type qrels_defs: Dict[int, str] | None, optional
+        """
+        super().__init__()
+        self.dataset_id = dataset_id
+        self.docs = docs
+        self.queries = queries
+        self.qrels = qrels
+        self.docpairs = docpairs
+        self.scoreddocs = scoreddocs
+        self.qrels_defs = qrels_defs
+
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Literal["fit", "validate", "test"]) -> None:
+        _register_local_dataset(
+            self.dataset_id, self.docs, self.queries, self.qrels, self.docpairs, self.scoreddocs, self.qrels_defs
+        )
diff --git a/tests/conftest.py b/tests/conftest.py
index dc8171d..5cf1f2b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,7 +18,7 @@
     RankNet,
     RunDataset,
 )
-from lightning_ir.data.ir_datasets_utils import register_local
+from lightning_ir.data.ir_datasets_utils import _register_local_dataset
 
 DATA_DIR = Path(__file__).parent / "data"
 CORPUS_DIR = DATA_DIR / "corpus"
@@ -26,7 +26,7 @@
 
 CONFIGS = Union[BiEncoderConfig, CrossEncoderConfig]
 
-register_local(
+_register_local_dataset(
     dataset_id="lightning-ir",
     docs=str(CORPUS_DIR / "docs.tsv"),
     queries=str(CORPUS_DIR / "queries.tsv"),
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 53b5340..96717a9 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -1,12 +1,18 @@
 from pathlib import Path
 from typing import Sequence
 
+import ir_datasets
 import pandas as pd
 import pytest
 from _pytest.fixtures import SubRequest
 
 from lightning_ir import BiEncoderModule, LightningIRDataModule, LightningIRModule, LightningIRTrainer, RunDataset
-from lightning_ir.lightning_utils.callbacks import IndexCallback, ReRankCallback, SearchCallback
+from lightning_ir.lightning_utils.callbacks import (
+    IndexCallback,
+    RegisterLocalDatasetCallback,
+    ReRankCallback,
+    SearchCallback,
+)
 from lightning_ir.retrieve import (
     FaissFlatIndexConfig,
     FaissSearchConfig,
@@ -16,7 +22,7 @@
 )
 from lightning_ir.retrieve.indexer import IndexConfig
 
-from .conftest import DATA_DIR
+from .conftest import CORPUS_DIR, DATA_DIR
 
 
 @pytest.fixture(
@@ -42,7 +48,7 @@ def test_index_callback(
     # devices: int,
 ):
     index_dir = tmp_path / "index"
-    index_callback = IndexCallback(index_dir, index_config)
+    index_callback = IndexCallback(index_config=index_config, index_dir=index_dir)
 
     trainer = LightningIRTrainer(
         # devices=devices,
@@ -82,7 +88,7 @@ def get_index(
     if index_dir.exists():
         return index_dir / "lightning-ir"
 
-    index_callback = IndexCallback(index_dir, index_config)
+    index_callback = IndexCallback(index_config=index_config, index_dir=index_dir)
 
     trainer = LightningIRTrainer(
         logger=False,
@@ -111,7 +117,7 @@ def test_search_callback(
 ):
     index_dir = get_index(bi_encoder_module, doc_datamodule, search_config)
     save_dir = tmp_path / "runs"
-    search_callback = SearchCallback(index_dir, search_config, save_dir)
+    search_callback = SearchCallback(search_config=search_config, index_dir=index_dir, save_dir=save_dir)
 
     trainer = LightningIRTrainer(
         logger=False,
@@ -152,3 +158,20 @@ def test_rerank_callback(tmp_path: Path, module: LightningIRModule, inference_da
         names=["query_id", "Q0", "doc_id", "rank", "score", "system"],
     )
     assert run_df["query_id"].nunique() == len(dataset)
+
+
+def test_register_local_dataset_callback(module: LightningIRModule):
+    callback = RegisterLocalDatasetCallback(
+        dataset_id="test",
+        docs=str(CORPUS_DIR / "docs.tsv"),
+        queries=str(CORPUS_DIR / "queries.tsv"),
+        qrels=str(CORPUS_DIR / "qrels.tsv"),
+        docpairs=str(CORPUS_DIR / "docpairs.tsv"),
+    )
+    datamodule = LightningIRDataModule(train_dataset=RunDataset("test"))
+
+    trainer = LightningIRTrainer(logger=False, enable_checkpointing=False, callbacks=[callback])
+
+    trainer.test(module, datamodule)
+
+    assert ir_datasets.registry._registered.get("test") is not None
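
Usage sketch (not part of the patch above): a minimal example of how the new RegisterLocalDatasetCallback might be wired into a LightningIRTrainer outside of the test suite, mirroring the added test. The dataset id "my-dataset", the local file paths, the placeholder backbone "bert-base-uncased", and the LightningIRDataModule arguments (inference_datasets, inference_batch_size) are assumptions for illustration; only RegisterLocalDatasetCallback, RunDataset, LightningIRDataModule, LightningIRTrainer, and the setup-time registration are taken from the diff.

import ir_datasets

from lightning_ir import BiEncoderConfig, BiEncoderModule, LightningIRDataModule, LightningIRTrainer, RunDataset
from lightning_ir.lightning_utils.callbacks import RegisterLocalDatasetCallback

# Register local (optionally gzipped) TSV files under a custom dataset id.
# The id and all file paths below are hypothetical placeholders.
register_callback = RegisterLocalDatasetCallback(
    dataset_id="my-dataset",
    docs="data/docs.tsv",
    queries="data/queries.tsv",
    qrels="data/qrels.tsv",
    scoreddocs="data/bm25.run",
)

# The callback registers the dataset in its setup hook, so it must be attached to the
# trainer that consumes the dataset id.
module = BiEncoderModule(model_name_or_path="bert-base-uncased", config=BiEncoderConfig())  # placeholder backbone
datamodule = LightningIRDataModule(inference_datasets=[RunDataset("my-dataset")], inference_batch_size=4)
trainer = LightningIRTrainer(logger=False, enable_checkpointing=False, callbacks=[register_callback])
trainer.test(module, datamodule)

# After setup has run, the id is also resolvable by ir_datasets directly.
assert ir_datasets.registry._registered.get("my-dataset") is not None

Because registration only happens in the callback's setup hook, constructing RunDataset("my-dataset") ahead of time works thanks to the lazy loading introduced in this patch: the dataset and ir_dataset properties and the deferred _setup() call mean the run file is not touched until the first __len__ or __getitem__ access.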