
Commit ba3337c

remove setup + clean loading
fschlatt committed Aug 29, 2024
1 parent 6c4841f commit ba3337c
Showing 1 changed file with 47 additions and 40 deletions.
87 changes: 47 additions & 40 deletions lightning_ir/data/dataset.py
@@ -81,9 +81,6 @@ def dataset_id(self) -> str:
     def docs_dataset_id(self) -> str:
         return ir_datasets.docs_parent_id(self.dataset_id)
 
-    def setup(self, stage: Literal["fit", "validate", "test"] | None) -> "IRDataset":
-        return self
-
 
 class DataParallelIterableDataset(IterableDataset):
     # https://github.com/Lightning-AI/pytorch-lightning/issues/15734
@@ -193,6 +190,8 @@ def sample(
         sample_size: int,
         sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
     ) -> pd.DataFrame:
+        if sample_size == -1:
+            return df
         if hasattr(Sampler, sampling_strategy):
             return getattr(Sampler, sampling_strategy)(df, sample_size)
         raise ValueError("Invalid sampling strategy.")
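For orientation, a minimal sketch of the new early return in Sampler.sample, assuming sample is a staticmethod taking the DataFrame as its first argument and that Sampler is importable from lightning_ir.data.dataset (both assumptions, not confirmed by this diff):

import pandas as pd

from lightning_ir.data.dataset import Sampler  # assumed import path

run = pd.DataFrame({"query_id": ["q1"] * 3, "doc_id": ["d1", "d2", "d3"], "rank": [1, 2, 3]})

# a sample_size of -1 now short-circuits and returns the frame unchanged
sampled = Sampler.sample(run, sample_size=-1, sampling_strategy="top")
assert sampled.equals(run)

# any other sample_size still dispatches via getattr, e.g. to Sampler.top
top_two = Sampler.sample(run, sample_size=2, sampling_strategy="top")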
@@ -202,9 +201,9 @@ class RunDataset(IRDataset, Dataset):
     def __init__(
         self,
         run_path_or_id: Path | str,
-        depth: int,
-        sample_size: int,
-        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"],
+        depth: int = -1,
+        sample_size: int = -1,
+        sampling_strategy: Literal["single_relevant", "top", "random", "log_random", "top_and_random"] = "top",
         targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
         normalize_targets: bool = False,
     ) -> None:
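A hedged usage sketch of the relaxed constructor (the import path and dataset id below are illustrative, not taken from this diff): with the new defaults only run_path_or_id is required; depth and sample_size fall back to -1 (no limit, no sampling) and sampling_strategy to "top".

from lightning_ir.data.dataset import RunDataset  # assumed import path

# previously depth, sample_size, and sampling_strategy were all mandatory
dataset = RunDataset("msmarco-passage/trec-dl-2019/judged")

# equivalent fully explicit call under the new defaults
dataset = RunDataset(
    "msmarco-passage/trec-dl-2019/judged",
    depth=-1,
    sample_size=-1,
    sampling_strategy="top",
)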
@@ -215,31 +214,21 @@ def __init__(
         else:
             dataset = str(run_path_or_id)
         super().__init__(dataset)
+        if depth != -1 and sample_size == -1:
+            sample_size = depth
         self.depth = depth
         self.sample_size = sample_size
         self.sampling_strategy = sampling_strategy
         self.targets = targets
         self.normalize_targets = normalize_targets
 
-        self.run: pd.DataFrame
-
         if self.sampling_strategy == "top" and self.sample_size > self.depth:
             warnings.warn(
                 "Sample size is greater than depth and top sampling strategy is used. "
                 "This can cause documents to be sampled that are not contained "
                 "in the run file, but that are present in the qrels."
             )
 
-    def setup(self, stage: Literal["fit", "validate", "test"] | None = None) -> "RunDataset":
-        super().setup(stage)
-        if stage == "fit":
-            if self.targets is None:
-                raise ValueError("Targets are required for training.")
-        if stage == "test":
-            if self.targets is not None:
-                warnings.warn("Targets are ignored in predict stage.")
-                self.targets = None
-
         self.run = self.load_run()
         self.run = self.run.drop_duplicates(["query_id", "doc_id"])
@@ -257,19 +246,17 @@ def setup(self, stage: Literal["fit", "validate", "test"] | None = None) -> "RunDataset":
             ),  # outer join if docs are from ir_datasets else only keep docs in run
         )
 
-        if stage == "fit":
+        if sample_size != -1:
             num_docs_per_query = self.run.groupby("query_id").transform("size")
             self.run = self.run[num_docs_per_query >= self.sample_size]
 
         self.run = self.run.sort_values(["query_id", "rank"])
         self.run_groups = self.run.groupby("query_id")
         self.query_ids = list(self.run_groups.groups.keys())
 
-        if self.run["rank"].max() < self.depth:
+        if self.depth != -1 and self.run["rank"].max() < self.depth:
             warnings.warn("Depth is greater than the maximum rank in the run file.")
 
-        return self
-
     @staticmethod
     def load_csv(path: Path) -> pd.DataFrame:
         return pd.read_csv(
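A side note on the per-query filter above: transform("size") broadcasts each group's row count back onto its rows, so queries with fewer than sample_size documents are dropped wholesale rather than padded. A minimal pandas sketch of the same idea (grouping on a single column so the mask is a Series):

import pandas as pd

run = pd.DataFrame({"query_id": ["q1", "q1", "q2"], "doc_id": ["d1", "d2", "d3"]})
num_docs_per_query = run.groupby("query_id")["doc_id"].transform("size")  # [2, 2, 1]
filtered = run[num_docs_per_query >= 2]  # q2 has only one document and is removed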
@@ -311,32 +298,18 @@ def load_json(path: Path) -> pd.DataFrame:
         )
         return run
 
-    def load_run(self) -> pd.DataFrame:
+    def _get_run_path(self) -> Path | None:
         run_path = self.run_path
 
-        suffix_load_map = {
-            ".tsv": self.load_csv,
-            ".run": self.load_csv,
-            ".csv": self.load_csv,
-            ".parquet": self.load_parquet,
-            ".json": self.load_json,
-            ".jsonl": self.load_json,
-        }
         if run_path is None:
             if self.ir_dataset is None or not self.ir_dataset.has_scoreddocs():
                 raise ValueError("Run file or dataset with scoreddocs required.")
             try:
                 run_path = self.ir_dataset.scoreddocs_handler().scoreddocs_path()
             except NotImplementedError:
                 pass
-        if run_path is not None and run_path.suffixes[0] in suffix_load_map:
-            run = suffix_load_map[run_path.suffixes[0]](run_path)
-        elif self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
-            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
-            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
-            run = run.sort_values(["query_id", "rank"])
-        else:
-            raise ValueError("Invalid run file format.")
+        return run_path
 
+    def _clean_run(self, run: pd.DataFrame) -> pd.DataFrame:
         run = run.rename(
             {"qid": "query_id", "docid": "doc_id", "docno": "doc_id"},
             axis=1,
@@ -355,6 +328,40 @@ def load_run(self) -> pd.DataFrame:
         run = run.astype(dtypes)
         return run
 
+    def load_run(self) -> pd.DataFrame:
+
+        suffix_load_map = {
+            ".tsv": self.load_csv,
+            ".run": self.load_csv,
+            ".csv": self.load_csv,
+            ".parquet": self.load_parquet,
+            ".json": self.load_json,
+            ".jsonl": self.load_json,
+        }
+        run = None
+
+        # try loading run from file
+        run_path = self._get_run_path()
+        if run_path is not None:
+            load_func = suffix_load_map.get(run_path.suffixes[0], None)
+            if load_func is not None:
+                try:
+                    run = load_func(run_path)
+                except Exception:
+                    pass
+
+        # try loading run from ir_datasets
+        if run is None and self.ir_dataset is not None and self.ir_dataset.has_scoreddocs():
+            run = pd.DataFrame(self.ir_dataset.scoreddocs_iter())
+            run["rank"] = run.groupby("query_id")["score"].rank("first", ascending=False)
+            run = run.sort_values(["query_id", "rank"])
+
+        if run is None:
+            raise ValueError("Invalid run file format.")
+
+        run = self._clean_run(run)
+        return run
+
     @property
     def qrels(self) -> pd.DataFrame | None:
         if self._qrels is not None:
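One detail in the dispatch above: load_run keys suffix_load_map on run_path.suffixes[0] rather than run_path.suffix, so a multi-suffix file such as run.tsv.gz resolves by its first suffix. A quick pathlib illustration:

from pathlib import Path

print(Path("bm25.run").suffixes[0])    # ".run"   -> load_csv
print(Path("run.tsv.gz").suffixes[0])  # ".tsv"   -> load_csv, despite the .gz
print(Path("run.jsonl").suffixes[0])   # ".jsonl" -> load_json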