Skip to content

Commit

Permalink
Adds Ability to Sub-Sample Data for Data-Constrained Scaling Law Experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
Helw150 committed Jan 30, 2025
1 parent 04a81ca commit 4fb5cd7
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/levanter/data/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
randomize_blocks: bool = True,
key: PRNGKeyArray | int,
stop_strategy: str = StopStrategy.RESTART_STRATEGY,
simulated_data_ratio: float = 1,
):
super().__init__()
if isinstance(weights, dict):
Expand Down Expand Up @@ -99,6 +100,9 @@ def __init__(
raise NotImplementedError("Only restart strategy is supported for now.")

self.stop_strategy = stop_strategy
if simulated_data_ratio > 1:
raise ValueError(f"Simulated data ratio must be at most 1, got {simulated_data_ratio}")
self.simulated_data_ratio = simulated_data_ratio

# Initialize stage-related counts and IDs
(
Expand Down Expand Up @@ -275,7 +279,13 @@ async def _remap_indices(self, ds, indices_into_ds):
if self.stop_strategy == StopStrategy.RESTART_STRATEGY:
if ds.is_finite():
max_elem = max(indices_into_ds)
length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1)
# Remap Indices Earlier when simulating epoching for a larger budget
if self.simulated_data_ratio < 1:
# Note(Will): Computing the true dataset length blocks until the dataset is fully processed, even for small simulated runs. This adds some latency when simulating data size, but that trade-off seems acceptable.
true_length_of_dataset = await ds.async_len()
length_of_dataset = int(true_length_of_dataset * self.simulated_data_ratio)
else:
length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1)
indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds]

return indices_into_ds
Expand Down
11 changes: 11 additions & 0 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,11 @@ class LMMixtureDatasetConfig(LMTaskConfig):
""" Dataset mixing weights. Either a constant dict[name->weight] or list of (step, weights) tuples """

stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)

# Configuration for Simulated Epoching
target_budget: Optional[int] = None
experiment_budget: Optional[int] = None

mixture_block_size: int = 2048
""" Block size for deterministic mixing """

Expand Down Expand Up @@ -1226,12 +1231,18 @@ def shuffle_ds(ds, key):
out_token_datasets[name] = shuffle_ds(ds, next(key_iter))
token_datasets = out_token_datasets

if self.experiment_budget is not None and self.target_budget is not None:
simulated_data_ratio = self.experiment_budget / self.target_budget
else:
simulated_data_ratio = 1

mixture = MixtureDataset(
datasets=token_datasets,
weights=self.train_weights,
stop_strategy=self.stop_strategy,
key=mix_key,
block_size=self.mixture_block_size,
simulated_data_ratio=simulated_data_ratio,
)

return mixture
Expand Down
18 changes: 18 additions & 0 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ async def test_mixture_dataset_stop_strategy_restart():
await mixture_ds.async_len()


@pytest.mark.asyncio
async def test_mixture_dataset_simulated_data_size():
    """Batches drawn with a sub-1.0 simulated_data_ratio still contain only valid dataset items."""
    weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3}
    mixture_ds = MixtureDataset(
        datasets(),
        weights,
        block_size=10,
        key=key(),
        randomize_blocks=False,
        stop_strategy=StopStrategy.RESTART_STRATEGY,
        simulated_data_ratio=0.2,
    )
    # Each source dataset contributes a distinct sentinel value; every fetched
    # item must be one of them even when indices are remapped by the ratio.
    valid_items = (1, 10, 100)
    for _ in range(10):
        fetched = await mixture_ds.get_batch([0, 1, 2])
        assert len(fetched) == 3
        assert all(value in valid_items for value in fetched)


@pytest.mark.asyncio
async def test_mixture_dataset_normalized_weights():
weights = {"ds1": 0, "ds2": 0.5, "ds3": 0.5}
Expand Down

0 comments on commit 4fb5cd7

Please sign in to comment.