Skip to content

Commit

Permalink
Adds Ability to Sub-Sample Data for Data Constrained Scaling Law Expe…
Browse files Browse the repository at this point in the history
…riments (#872)

![image](https://github.com/user-attachments/assets/cc381e88-79b4-4810-bb66-351ddf7c3b04)


Allows mixture datasets to specify a target budget and an experiment
budget. This then computes what percentage of the data to sample overall
in order to enable data constrained experiments like the above figure.
  • Loading branch information
Helw150 authored Feb 1, 2025
1 parent 04a81ca commit 1d216d1
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 0 deletions.
54 changes: 54 additions & 0 deletions src/levanter/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDa
def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)

def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None):
    """Return a view of this dataset restricted to ``[start_index, end_index)``.

    Either bound may be ``None``: a missing ``start_index`` defaults to 0, and a
    missing ``end_index`` keeps the remainder of the dataset. Indices of the
    returned dataset are relative to the slice (0 maps to ``start_index``).
    """
    return SlicedAsyncDataset(self, start_index, end_index)

def shuffle(self, key: PRNGKey):
import levanter.data.permutation as permutation

Expand Down Expand Up @@ -375,6 +378,57 @@ def _call_fn(self, index, item):
return self.fn(item, *self._extra_args, **kwargs)


class SlicedAsyncDataset(AsyncDataset[U]):
    """A view over ``dataset`` restricted to the half-open range ``[start_index, end_index)``.

    Indices passed to :meth:`get_batch` are relative to the slice: index 0 maps to
    ``start_index`` in the underlying dataset. ``end_index`` is exclusive, so the
    slice's length is ``end_index - start_index`` when ``end_index`` is given.
    """

    def __init__(
        self,
        dataset: AsyncDataset[U],
        start_index: Optional[int] = None,
        end_index: Optional[int] = None,
    ):
        """
        Args:
            dataset: the underlying dataset to slice.
            start_index: first underlying index included in the view; defaults to 0.
            end_index: one past the last underlying index; ``None`` means unbounded.

        Raises:
            ValueError: if ``start_index`` is negative or exceeds ``end_index``.
        """
        super().__init__()
        if start_index is None:
            start_index = 0
        if start_index < 0:
            raise ValueError("Start index must be non-negative.")
        if end_index is not None and start_index > end_index:
            raise ValueError("End index must come after start index.")

        self.start_index = start_index
        self.end_index = end_index
        self.dataset = dataset
        # NOTE(review): reaches into the underlying dataset's private _min_known_len;
        # with a fixed end_index the slice's length is fully determined.
        self._min_known_len = dataset._min_known_len if end_index is None else (end_index - start_index)

    async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
        shifted_indices = [(index + self.start_index) for index in indices]
        max_index = max(shifted_indices)

        # end_index is exclusive, so a shifted index *equal* to it is already out
        # of range (was `>`, an off-by-one that let one extra element through).
        if self.end_index is not None and max_index >= self.end_index:
            raise ValueError("Requested indices beyond the end of the dataset")

        return await self.dataset.get_batch(shifted_indices)

    async def async_len(self) -> int:
        underlying_length = await self.dataset.async_len()
        if self.end_index is None:
            return underlying_length - self.start_index
        else:
            return self.end_index - self.start_index

    async def final_length_is_known(self) -> bool:
        # A fixed end_index pins the slice's length (end - start) regardless of
        # whether the underlying dataset's length is known; otherwise the slice's
        # length is known exactly when the underlying dataset's is.
        # (Was `underlying and end_index is not None`, which is wrong in both directions.)
        if self.end_index is not None:
            return True
        return await self.dataset.final_length_is_known()

    def is_finite(self) -> bool:
        # A bounded slice is finite even over an infinite dataset, and a slice of a
        # finite dataset is finite even without an end_index.
        return self.end_index is not None or self.dataset.is_finite()

    async def current_len(self) -> Optional[int]:
        underlying_length = await self.dataset.current_len()
        if self.end_index is not None:
            # Consistent with async_len: the slice reports its nominal length
            # (end - start) without clamping to the underlying dataset.
            return self.end_index - self.start_index
        elif underlying_length is not None:
            return underlying_length - self.start_index
        else:
            return underlying_length


class BatchMappedAsyncDataset(AsyncDataset[U]):
"""
A dataset that applies a function to each batch of items in the dataset.
Expand Down
22 changes: 22 additions & 0 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,11 @@ class LMMixtureDatasetConfig(LMTaskConfig):
""" Dataset mixing weights. Either a constant dict[name->weight] or list of (step, weights) tuples """

# What to do when one component dataset is exhausted (default: restart it).
stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)

# Simulated-epoching configuration: when both budgets are set, every component
# dataset is truncated to the fraction experiment_budget / target_budget of its
# full length, so a small run cycles through its (shrunken) data the same number
# of times a full target_budget run would. experiment_budget must not exceed
# target_budget.
target_budget: Optional[int] = None
experiment_budget: Optional[int] = None

mixture_block_size: int = 2048
""" Block size for deterministic mixing """

Expand Down Expand Up @@ -1226,6 +1231,23 @@ def shuffle_ds(ds, key):
out_token_datasets[name] = shuffle_ds(ds, next(key_iter))
token_datasets = out_token_datasets

if (
self.experiment_budget is not None and self.target_budget is not None
) and self.experiment_budget > self.target_budget:
raise ValueError(
f"Experiment budget should be smaller than target budget, got {self.experiment_budget} >"
f" {self.target_budget}"
)
if self.experiment_budget is not None and self.target_budget is not None:
simulated_data_ratio = self.experiment_budget / self.target_budget
sliced_token_datasets: Dict[str, TokenSeqDataset] = {}
for name, ds in token_datasets.items():
# Note(Will): This blocks on datasets being fully processed even for small simulated runs making simulating data size slightly latency inducing but I think that's ok
true_length_of_dataset = len(ds.as_sync_dataset())
simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio)
sliced_token_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset)
token_datasets = sliced_token_datasets

mixture = MixtureDataset(
datasets=token_datasets,
weights=self.train_weights,
Expand Down
30 changes: 30 additions & 0 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,36 @@ async def test_mixture_dataset_stop_strategy_restart():
await mixture_ds.async_len()


@pytest.mark.asyncio
async def test_mixture_dataset_simulated_data_size():
    """Slicing each component dataset limits which items the mixture can yield.

    With ``end_index=k`` every source is truncated to its first ``k`` items, so
    repeated sampling must only ever produce values from those prefixes.
    """
    weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3}

    async def sample_and_check(end_index, allowed_items):
        # Build a mixture over prefix-sliced copies of the standard test datasets.
        sliced = {name: ds.slice_dataset(end_index=end_index) for name, ds in datasets().items()}
        mixture = MixtureDataset(
            sliced,
            weights,
            block_size=10,
            key=key(),
            randomize_blocks=False,
            stop_strategy=StopStrategy.RESTART_STRATEGY,
        )
        for _ in range(10):
            batch = await mixture.get_batch([0, 1, 2])
            assert len(batch) == 3
            assert all(item in allowed_items for item in batch)

    await sample_and_check(1, {1, 10, 100})
    await sample_and_check(2, {1, 2, 10, 20, 100, 200})


@pytest.mark.asyncio
async def test_mixture_dataset_normalized_weights():
weights = {"ds1": 0, "ds2": 0.5, "ds3": 0.5}
Expand Down

0 comments on commit 1d216d1

Please sign in to comment.