diff --git a/src/levanter/data/mixture.py b/src/levanter/data/mixture.py index 188e5e426..57fd792bb 100644 --- a/src/levanter/data/mixture.py +++ b/src/levanter/data/mixture.py @@ -52,6 +52,7 @@ def __init__( randomize_blocks: bool = True, key: PRNGKeyArray | int, stop_strategy: str = StopStrategy.RESTART_STRATEGY, + simulated_data_ratio: float = 1, ): super().__init__() if isinstance(weights, dict): @@ -99,6 +100,9 @@ def __init__( raise NotImplementedError("Only restart strategy is supported for now.") self.stop_strategy = stop_strategy + if simulated_data_ratio > 1: + raise ValueError(f"Simulated data ratio must be at most 1, got {simulated_data_ratio}") + self.simulated_data_ratio = simulated_data_ratio # Initialize stage-related counts and IDs ( @@ -275,7 +279,13 @@ async def _remap_indices(self, ds, indices_into_ds): if self.stop_strategy == StopStrategy.RESTART_STRATEGY: if ds.is_finite(): max_elem = max(indices_into_ds) - length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1) + # Remap indices earlier when simulating epoching for a larger budget + if self.simulated_data_ratio < 1: + # Note(Will): This blocks until datasets are fully processed, which adds some latency to small simulated runs, but I think that's ok + true_length_of_dataset = await ds.async_len() + length_of_dataset = int(true_length_of_dataset * self.simulated_data_ratio) + else: + length_of_dataset = await ds.wait_until_len_at_least(max_elem + 1) indices_into_ds = [idx % length_of_dataset for idx in indices_into_ds] return indices_into_ds diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 8af05698c..355de9631 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -1169,6 +1169,11 @@ class LMMixtureDatasetConfig(LMTaskConfig): """ Dataset mixing weights. 
Either a constant dict[name->weight] or list of (step, weights) tuples """ stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY) + + # Configuration for Simulated Epoching + target_budget: Optional[int] = None + experiment_budget: Optional[int] = None + mixture_block_size: int = 2048 """ Block size for deterministic mixing """ @@ -1226,12 +1231,18 @@ def shuffle_ds(ds, key): out_token_datasets[name] = shuffle_ds(ds, next(key_iter)) token_datasets = out_token_datasets + if self.experiment_budget is not None and self.target_budget is not None: + simulated_data_ratio = self.experiment_budget / self.target_budget + else: + simulated_data_ratio = 1 + mixture = MixtureDataset( datasets=token_datasets, weights=self.train_weights, stop_strategy=self.stop_strategy, key=mix_key, block_size=self.mixture_block_size, + simulated_data_ratio=simulated_data_ratio, ) return mixture diff --git a/tests/test_mixture.py b/tests/test_mixture.py index 8ae6dbb1b..450082153 100644 --- a/tests/test_mixture.py +++ b/tests/test_mixture.py @@ -73,6 +73,24 @@ async def test_mixture_dataset_stop_strategy_restart(): await mixture_ds.async_len() +@pytest.mark.asyncio +async def test_mixture_dataset_simulated_data_size(): + weights = {"ds1": 1 / 3, "ds2": 1 / 3, "ds3": 1 / 3} + mixture_ds = MixtureDataset( + datasets(), + weights, + block_size=10, + key=key(), + randomize_blocks=False, + stop_strategy=StopStrategy.RESTART_STRATEGY, + simulated_data_ratio=0.2, + ) + for _ in range(10): + batch = await mixture_ds.get_batch([0, 1, 2]) + assert len(batch) == 3 + assert all(item in [1, 10, 100] for item in batch) + + @pytest.mark.asyncio async def test_mixture_dataset_normalized_weights(): weights = {"ds1": 0, "ds2": 0.5, "ds3": 0.5}