add dataset registration callback
fschlatt committed Nov 15, 2024
1 parent 91d4403 commit a33ecaf
Showing 6 changed files with 127 additions and 38 deletions.
2 changes: 1 addition & 1 deletion lightning_ir/data/__init__.py
@@ -10,8 +10,8 @@
"QueryDataset",
"QuerySample",
"RankBatch",
"RunDataset",
"RankSample",
"RunDataset",
"SearchBatch",
"TrainBatch",
"TupleDataset",
31 changes: 22 additions & 9 deletions lightning_ir/data/dataset.py
@@ -20,17 +20,22 @@
class IRDataset:
def __init__(self, dataset: str) -> None:
super().__init__()
if dataset in self.DASHED_DATASET_MAP:
dataset = self.DASHED_DATASET_MAP[dataset]
self.dataset = dataset
try:
self.ir_dataset = ir_datasets.load(dataset)
except KeyError:
self.ir_dataset = None
self._dataset = dataset
self._queries = None
self._docs = None
self._qrels = None

@property
def dataset(self) -> str:
return self.DASHED_DATASET_MAP.get(self._dataset, self._dataset)

@property
def ir_dataset(self) -> ir_datasets.Dataset | None:
try:
return ir_datasets.load(self.dataset)
except KeyError:
return None

@property
def DASHED_DATASET_MAP(self) -> Dict[str, str]:
return {dataset.replace("/", "-"): dataset for dataset in ir_datasets.registry._registered}
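For illustration, a minimal sketch of the refactored lazy behavior (assuming the registered id ``msmarco-passage/trec-dl-2019``; any registered ir_datasets id resolves the same way):

from lightning_ir.data.dataset import IRDataset

# The dashed alias is resolved lazily via the ``dataset`` property, and
# ``ir_dataset`` is loaded on access instead of eagerly in ``__init__``.
ds = IRDataset("msmarco-passage-trec-dl-2019")
print(ds.dataset)     # "msmarco-passage/trec-dl-2019"
print(ds.ir_dataset)  # ir_datasets.Dataset, or None for unknown ids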
@@ -224,6 +229,7 @@ def __init__(
self.sampling_strategy = sampling_strategy
self.targets = targets
self.normalize_targets = normalize_targets
self.add_non_retrieved_docs = add_non_retrieved_docs

if self.sampling_strategy == "top" and self.sample_size > self.depth:
warnings.warn(
@@ -232,6 +238,11 @@
"in the run file, but that are present in the qrels."
)

self.run: pd.DataFrame | None = None

def _setup(self):
if self.run is not None:
return
self.run = self.load_run()
self.run = self.run.drop_duplicates(["query_id", "doc_id"])

@@ -243,7 +254,7 @@ def __init__(
self.run = self.run[self.run["query_id"].isin(query_ids)]
# outer join if docs are from ir_datasets else only keep docs in run
how = "left"
if self._docs is None and add_non_retrieved_docs:
if self._docs is None and self.add_non_retrieved_docs:
how = "outer"
self.run = self.run.merge(
self.qrels.loc[pd.IndexSlice[query_ids, :]].add_prefix("relevance_", axis=1),
@@ -372,7 +383,7 @@ def load_run(self) -> pd.DataFrame:
def qrels(self) -> pd.DataFrame | None:
if self._qrels is not None:
return self._qrels
if "relevance" in self.run:
if self.run is not None and "relevance" in self.run:
qrels = self.run[["query_id", "doc_id", "relevance"]].copy()
if "iteration" in self.run:
qrels["iteration"] = self.run["iteration"]
@@ -387,9 +398,11 @@ def qrels(self) -> pd.DataFrame | None:
return super().qrels

def __len__(self) -> int:
self._setup()
return len(self.query_ids)

def __getitem__(self, idx: int) -> RankSample:
self._setup()
query_id = str(self.query_ids[idx])
group = self.run_groups.get_group(query_id).copy()
query = self.queries[query_id]
38 changes: 19 additions & 19 deletions lightning_ir/data/ir_datasets_utils.py
@@ -25,8 +25,8 @@
}


def load_constituent(
constituent: str | None,
def _load_constituent(
constituent: Path | str | None,
constituent_type: Literal["docs", "queries", "qrels", "scoreddocs", "docpairs"],
**kwargs,
) -> Any:
@@ -45,23 +45,23 @@ def load_constituent(
return ConstituentType(Cache(None, constituent_path), **kwargs)


def register_local(
def _register_local_dataset(
dataset_id: str,
docs: str | None = None,
queries: str | None = None,
qrels: str | None = None,
docpairs: str | None = None,
scoreddocs: str | None = None,
docs: Path | str | None = None,
queries: Path | str | None = None,
qrels: Path | str | None = None,
docpairs: Path | str | None = None,
scoreddocs: Path | str | None = None,
qrels_defs: Dict[int, str] | None = None,
):
if dataset_id in ir_datasets.registry._registered:
return

docs = load_constituent(docs, "docs")
queries = load_constituent(queries, "queries")
qrels = load_constituent(qrels, "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {})
docpairs = load_constituent(docpairs, "docpairs")
scoreddocs = load_constituent(scoreddocs, "scoreddocs")
docs = _load_constituent(docs, "docs")
queries = _load_constituent(queries, "queries")
qrels = _load_constituent(qrels, "qrels", qrels_defs=qrels_defs if qrels_defs is not None else {})
docpairs = _load_constituent(docpairs, "docpairs")
scoreddocs = _load_constituent(scoreddocs, "scoreddocs")

ir_datasets.registry.register(dataset_id, Dataset(docs, queries, qrels, docpairs, scoreddocs))
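For context, a hedged usage sketch of the renamed helper (the dataset id and file paths are hypothetical; supported file formats are listed in the new callback's docstring below):

import ir_datasets
from lightning_ir.data.ir_datasets_utils import _register_local_dataset

# Registration is idempotent: an already-registered id returns early.
_register_local_dataset(
    dataset_id="my-local-dataset",  # hypothetical id
    docs="corpus/docs.tsv",         # hypothetical paths
    queries="corpus/queries.tsv",
    qrels="corpus/qrels.tsv",
)
dataset = ir_datasets.load("my-local-dataset")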

@@ -107,7 +107,7 @@ def docpairs_cls(self):
return ScoredDocTuple


def register_kd_docpairs():
def _register_kd_docpairs():
base_id = "msmarco-passage"
split_id = "train"
file_id = "kd-docpairs"
@@ -124,7 +124,7 @@ def register_kd_docpairs():
register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)


def register_colbert_docpairs():
def _register_colbert_docpairs():
base_id = "msmarco-passage"
split_id = "train"
file_id = "colbert-docpairs"
@@ -140,7 +140,7 @@ def register_colbert_docpairs():
register_msmarco(base_id, split_id, file_id, cache_path, dlc_contents, file_name, ScoredDocTuples)


def register_rank_distillm():
def _register_rank_distillm():
base_id = "msmarco-passage"
split_id = "train"
file_id = "rank-distillm/rankzephyr"
@@ -215,6 +215,6 @@ def register_msmarco(
ir_datasets.registry.register(dataset_id, Dataset(dataset))


register_kd_docpairs()
register_colbert_docpairs()
register_rank_distillm()
_register_kd_docpairs()
_register_colbert_docpairs()
_register_rank_distillm()
57 changes: 55 additions & 2 deletions lightning_ir/lightning_utils/callbacks.py
@@ -12,6 +12,7 @@

from ..data import RankBatch, SearchBatch
from ..data.dataset import RUN_HEADER, DocDataset, QueryDataset, RunDataset
from ..data.ir_datasets_utils import _register_local_dataset
from ..retrieve import IndexConfig, Indexer, SearchConfig, Searcher

if TYPE_CHECKING:
@@ -284,12 +285,12 @@ def write_on_batch_end(
run_df["system"] = pl_module.model.__class__.__name__
run_df = run_df[RUN_HEADER]

self.run_dfs.append(run_df)

if batch_idx == trainer.num_test_batches[dataloader_idx] - 1:
self.write_run_dfs(trainer, pl_module, dataloader_idx)
self.run_dfs = []

self.run_dfs.append(run_df)


class SearchCallback(RankCallback, _IndexDirMixin):
def __init__(
@@ -347,3 +348,55 @@ def rank(

class ReRankCallback(RankCallback):
pass


class RegisterLocalDatasetCallback(Callback):

def __init__(
self,
dataset_id: str,
docs: str | None = None,
queries: str | None = None,
qrels: str | None = None,
docpairs: str | None = None,
scoreddocs: str | None = None,
qrels_defs: Dict[int, str] | None = None,
):
"""Registers a local dataset with ``ir_datasets``. After registering the dataset, it can be loaded using
``ir_datasets.load(dataset_id)``. Currently, the following (optionally gzipped) file types are supported:
- ``.tsv``, ``.json``, or ``.jsonl`` for documents and queries
- ``.tsv`` or ``.qrels`` for qrels
- ``.tsv`` for training n-tuples
- ``.tsv`` or ``.run`` for scored documents / run files
:param dataset_id: Dataset id
:type dataset_id: str
:param docs: Path to documents file or valid ir_datasets id from which documents should be taken, defaults to None
:type docs: str | None, optional
:param queries: Path to queries file or valid ir_datasets id from which queries should be taken, defaults to None
:type queries: str | None, optional
:param qrels: Path to qrels file or valid ir_datasets id from which qrels will be taken, defaults to None
:type qrels: str | None, optional
:param docpairs: Path to training n-tuple file or valid ir_datasets id from which training tuples will be taken,
defaults to None
:type docpairs: str | None, optional
:param scoreddocs: Path to run file or valid ir_datasets id from which scored documents will be taken,
defaults to None
:type scoreddocs: str | None, optional
:param qrels_defs: Optional dictionary describing the relevance levels of the qrels, defaults to None
:type qrels_defs: Dict[int, str] | None, optional
"""
super().__init__()
self.dataset_id = dataset_id
self.docs = docs
self.queries = queries
self.qrels = qrels
self.docpairs = docpairs
self.scoreddocs = scoreddocs
self.qrels_defs = qrels_defs

def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Literal["fit", "validate", "test"]) -> None:
_register_local_dataset(
self.dataset_id, self.docs, self.queries, self.qrels, self.docpairs, self.scoreddocs, self.qrels_defs
)
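A minimal usage sketch of the new callback (placeholder id and paths; mirrors the new test below):

callback = RegisterLocalDatasetCallback(
    dataset_id="my-dataset",     # placeholder id
    docs="corpus/docs.tsv",      # placeholder paths
    queries="corpus/queries.tsv",
    qrels="corpus/qrels.tsv",
)
trainer = LightningIRTrainer(logger=False, callbacks=[callback])
# ``setup`` runs at the start of fit/validate/test and registers the id,
# so e.g. RunDataset("my-dataset") can resolve it through ir_datasets.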
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -18,15 +18,15 @@
RankNet,
RunDataset,
)
from lightning_ir.data.ir_datasets_utils import register_local
from lightning_ir.data.ir_datasets_utils import _register_local_dataset

DATA_DIR = Path(__file__).parent / "data"
CORPUS_DIR = DATA_DIR / "corpus"
RUNS_DIR = DATA_DIR / "runs"

CONFIGS = Union[BiEncoderConfig, CrossEncoderConfig]

register_local(
_register_local_dataset(
dataset_id="lightning-ir",
docs=str(CORPUS_DIR / "docs.tsv"),
queries=str(CORPUS_DIR / "queries.tsv"),
33 changes: 28 additions & 5 deletions tests/test_callbacks.py
@@ -1,12 +1,18 @@
from pathlib import Path
from typing import Sequence

import ir_datasets
import pandas as pd
import pytest
from _pytest.fixtures import SubRequest

from lightning_ir import BiEncoderModule, LightningIRDataModule, LightningIRModule, LightningIRTrainer, RunDataset
from lightning_ir.lightning_utils.callbacks import IndexCallback, ReRankCallback, SearchCallback
from lightning_ir.lightning_utils.callbacks import (
IndexCallback,
RegisterLocalDatasetCallback,
ReRankCallback,
SearchCallback,
)
from lightning_ir.retrieve import (
FaissFlatIndexConfig,
FaissSearchConfig,
@@ -16,7 +22,7 @@
)
from lightning_ir.retrieve.indexer import IndexConfig

from .conftest import DATA_DIR
from .conftest import CORPUS_DIR, DATA_DIR


@pytest.fixture(
@@ -42,7 +48,7 @@ def test_index_callback(
# devices: int,
):
index_dir = tmp_path / "index"
index_callback = IndexCallback(index_dir, index_config)
index_callback = IndexCallback(index_config=index_config, index_dir=index_dir)

trainer = LightningIRTrainer(
# devices=devices,
@@ -82,7 +88,7 @@ def get_index(
if index_dir.exists():
return index_dir / "lightning-ir"

index_callback = IndexCallback(index_dir, index_config)
index_callback = IndexCallback(index_config=index_config, index_dir=index_dir)

trainer = LightningIRTrainer(
logger=False,
@@ -111,7 +117,7 @@ def test_search_callback(
):
index_dir = get_index(bi_encoder_module, doc_datamodule, search_config)
save_dir = tmp_path / "runs"
search_callback = SearchCallback(index_dir, search_config, save_dir)
search_callback = SearchCallback(search_config=search_config, index_dir=index_dir, save_dir=save_dir)

trainer = LightningIRTrainer(
logger=False,
@@ -152,3 +158,20 @@ def test_rerank_callback(tmp_path: Path, module: LightningIRModule, inference_da
names=["query_id", "Q0", "doc_id", "rank", "score", "system"],
)
assert run_df["query_id"].nunique() == len(dataset)


def test_register_local_dataset_callback(module: LightningIRModule):
callback = RegisterLocalDatasetCallback(
dataset_id="test",
docs=str(CORPUS_DIR / "docs.tsv"),
queries=str(CORPUS_DIR / "queries.tsv"),
qrels=str(CORPUS_DIR / "qrels.tsv"),
docpairs=str(CORPUS_DIR / "docpairs.tsv"),
)
datamodule = LightningIRDataModule(train_dataset=RunDataset("test"))

trainer = LightningIRTrainer(logger=False, enable_checkpointing=False, callbacks=[callback])

trainer.test(module, datamodule)

assert ir_datasets.registry._registered.get("test") is not None
