Commit: PR and Sanity
Helw150 committed Feb 1, 2025
1 parent 65fbbdb · commit ae0305d
Showing 2 changed files with 14 additions and 10 deletions.
src/levanter/data/dataset.py: 20 changes (11 additions, 9 deletions)
@@ -122,9 +122,7 @@ def map(self, fn: MapFunction[U], *extra_args, **extra_kwargs) -> "MappedAsyncDataset[U]":
     def map_batches(self, fn: MapFunction[Sequence[U]], *extra_args, **extra_kwargs) -> "BatchMappedAsyncDataset[U]":
         return BatchMappedAsyncDataset(self, fn, *extra_args, **extra_kwargs)
 
-    def slice_dataset(
-        self, start_index: Optional[int] = None, end_index: Optional[int] = None
-    ) -> "SlicedAsyncDataset[U]":
+    def slice_dataset(self, start_index: Optional[int] = None, end_index: Optional[int] = None):
         return SlicedAsyncDataset(self, start_index, end_index)
 
     def shuffle(self, key: PRNGKey):
@@ -383,27 +381,29 @@ def _call_fn(self, index, item):
 class SlicedAsyncDataset(AsyncDataset[U]):
     def __init__(
         self,
-        dataset: AsyncDataset[T],
+        dataset: AsyncDataset[U],
         start_index: Optional[int] = None,
         end_index: Optional[int] = None,
     ):
         super().__init__()
         if start_index is None:
             start_index = 0
+        if end_index is not None and start_index > end_index:
+            raise ValueError("End index must come after start index.")
 
         self.start_index = start_index
         self.end_index = end_index
         self.dataset = dataset
-        self._min_known_len = dataset._min_known_len if start_index is not None else (end_index - start_index)
+        self._min_known_len = dataset._min_known_len if end_index is None else (end_index - start_index)
 
-    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
+    async def get_batch(self, indices: Sequence[int]) -> Sequence[U]:
         shifted_indices = [(index + self.start_index) for index in indices]
         max_index = max(shifted_indices)
 
         if self.end_index is not None and max_index > self.end_index:
             raise ValueError("Requested indices beyond the end of the dataset")
 
-        return await self.dataset.get_batch(indices)
+        return await self.dataset.get_batch(shifted_indices)
 
     async def async_len(self) -> int:
         underlying_length = await self.dataset.async_len()
@@ -421,10 +421,12 @@ def is_finite(self) -> bool:
 
     async def current_len(self) -> Optional[int]:
         underlying_length = await self.dataset.current_len()
-        if self.end_index is None:
-            return underlying_length
-        return self.end_index - self.start_index
+        if self.end_index is not None:
+            return self.end_index - self.start_index
+        elif underlying_length is not None:
+            return underlying_length - self.start_index
+        else:
+            return underlying_length
 
 
 class BatchMappedAsyncDataset(AsyncDataset[U]):
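The substance of the dataset.py change: SlicedAsyncDataset.get_batch computed shifted_indices for its bounds check but then fetched the original, unshifted indices, so any slice with a nonzero start_index returned items from the wrong positions. The commit also adds a start/end sanity check in __init__, and current_len now subtracts start_index from the underlying length, so a slice with start_index=10 and end_index=30 reports a length of 30 - 10 = 20 and serves items 10 through 29 of the underlying dataset. Below is a minimal synchronous sketch of the corrected logic; the list-backed ToySlice is illustrative only, not levanter code:

from typing import List, Optional, Sequence


class ToySlice:
    """Illustrative list-backed stand-in for SlicedAsyncDataset (not levanter code)."""

    def __init__(self, data: Sequence[int], start_index: int = 0, end_index: Optional[int] = None):
        # Same sanity check the commit adds to SlicedAsyncDataset.__init__.
        if end_index is not None and start_index > end_index:
            raise ValueError("End index must come after start index.")
        self.data = data
        self.start_index = start_index
        self.end_index = end_index

    def get_batch(self, indices: Sequence[int]) -> List[int]:
        # Shift requested indices into the underlying dataset's coordinate system.
        shifted_indices = [i + self.start_index for i in indices]
        if self.end_index is not None and max(shifted_indices) > self.end_index:
            raise ValueError("Requested indices beyond the end of the dataset")
        # The core fix: fetch the *shifted* indices, not the originals.
        return [self.data[i] for i in shifted_indices]

    def current_len(self) -> int:
        # Mirrors the corrected current_len: lengths are relative to start_index.
        if self.end_index is not None:
            return self.end_index - self.start_index
        return len(self.data) - self.start_index


sliced = ToySlice(list(range(100)), start_index=10, end_index=30)
assert sliced.current_len() == 20
assert sliced.get_batch([0, 5]) == [10, 15]  # drawn from the slice, not positions 0 and 5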
src/levanter/data/text.py: 4 changes (3 additions, 1 deletion)
@@ -1240,11 +1240,13 @@ def shuffle_ds(ds, key):
         )
         if self.experiment_budget is not None and self.target_budget is not None:
             simulated_data_ratio = self.experiment_budget / self.target_budget
+            sliced_token_datasets: Dict[str, TokenSeqDataset] = {}
             for name, ds in token_datasets.items():
                 # Note(Will): This blocks on datasets being fully processed even for small simulated runs, making simulating data size slightly latency-inducing, but I think that's ok
                 true_length_of_dataset = len(ds.as_sync_dataset())
                 simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio)
-                token_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset)
+                sliced_token_datasets[name] = ds.slice_dataset(end_index=simulated_length_of_dataset)
+            token_datasets = sliced_token_datasets
 
         mixture = MixtureDataset(
             datasets=token_datasets,
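The text.py change collects the sliced datasets in a fresh dict and rebinds token_datasets only after the loop, rather than writing new values into the mapping while iterating over it. Below is a small sketch of the budget-scaling pattern with plain lists standing in for token datasets; the dataset names and budget numbers are invented for illustration:

from typing import Dict, List

# Invented stand-ins: each "dataset" is just a list of token sequences.
token_datasets: Dict[str, List[str]] = {
    "wiki": [f"wiki-{i}" for i in range(1000)],
    "code": [f"code-{i}" for i in range(400)],
}

experiment_budget = 1e8  # tokens available for the small experiment
target_budget = 1e9      # tokens the full target run would see
simulated_data_ratio = experiment_budget / target_budget  # 0.1

# Build a new mapping instead of reassigning entries of the dict being
# iterated, mirroring the change above.
sliced_token_datasets: Dict[str, List[str]] = {}
for name, ds in token_datasets.items():
    true_length_of_dataset = len(ds)
    simulated_length_of_dataset = int(true_length_of_dataset * simulated_data_ratio)
    # Stand-in for ds.slice_dataset(end_index=simulated_length_of_dataset).
    sliced_token_datasets[name] = ds[:simulated_length_of_dataset]
token_datasets = sliced_token_datasets

assert len(token_datasets["wiki"]) == 100
assert len(token_datasets["code"]) == 40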
