Skip to content

Commit

Permalink
add log random sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Jul 20, 2024
1 parent 8afe7bd commit ca1e8d0
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions lightning_ir/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __init__(
run_path_or_id: Path | str,
depth: int,
sample_size: int,
sampling_strategy: Literal["single_relevant", "top", "random"],
sampling_strategy: Literal["single_relevant", "top", "random", "log_random"],
targets: Literal["relevance", "subtopic_relevance", "rank", "score"] | None = None,
normalize_targets: bool = False,
) -> None:
Expand Down Expand Up @@ -345,8 +345,12 @@ def __getitem__(self, idx: int) -> RunSample:
group = pd.concat([relevant, non_relevant])
elif self.sampling_strategy == "top":
group = group.head(self.sample_size)
elif self.sampling_strategy == "random":
group = group.sample(self.sample_size)
elif "random" in self.sampling_strategy:
weights = None
if self.sampling_strategy == "log_random":
weights = 1 / np.log1p(group["rank"])
weights[weights.isna()] = weights.min()
group = group.sample(self.sample_size, weights=weights)
else:
raise ValueError("Invalid sampling strategy.")

Expand Down

0 comments on commit ca1e8d0

Please sign in to comment.