From c25572fd7f9a8cc728a5f1dcf549eb51ef7ada34 Mon Sep 17 00:00:00 2001
From: Ahmed Ahmed
Date: Fri, 17 Jan 2025 12:46:34 -0800
Subject: [PATCH] fix

---
 src/levanter/data/text.py | 368 ++++++++------------------------------
 1 file changed, 74 insertions(+), 294 deletions(-)

diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py
index 18a0ff62b..710e2d4ef 100644
--- a/src/levanter/data/text.py
+++ b/src/levanter/data/text.py
@@ -991,9 +991,11 @@ def preprocess_chat_example_for_packing(
     batch, tokenizer: PreTrainedTokenizerBase, should_append_eos: bool
-) -> list[PromptCompletion]:
-    """Preprocesses chat examples into PromptCompletion objects for packing."""
-    outputs = []
+) -> dict:
+    """Preprocesses chat examples into a cacheable format for packing."""
+    input_ids_list = []
+    prompt_lengths = []
+
     for example in batch:
         # Tokenize input (prompt) separately to get length
         input_ids = tokenizer(example["input"], truncation=True)["input_ids"]
@@ -1006,293 +1008,66 @@ def preprocess_chat_example_for_packing(
         full_sequence = example["input"] + target
         full_ids = tokenizer(full_sequence, truncation=True)["input_ids"]

-        outputs.append(PromptCompletion(
-            ids=full_ids,
-            prompt_length=len(input_ids)
-        ))
-    return outputs
-
-
-@dataclass
-class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
-    """This class supports loading data both from HF Datasets and from a raw dataset of jsonl urls"""
-
-    cache_dir: Optional[str] = "cache/"
-
-    def train_set(
-        self,
-        seq_len: int,
-        monitors: Union[bool, List[MetricsMonitor]] = True,
-        *,
-        key: Optional[PRNGKeyArray] = None,
-        epochs: Optional[int] = None,
-    ) -> AsyncDataset[np.ndarray]:
-
-        ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors)
-
-        # add epoch flag here.
-        if ds is None:
-            raise ValueError("No training set!")
-
-        if epochs:
-            logger.info("Wrapping dataset in epoch dataset")
-            ds = EpochDataset(ds, max_epochs=epochs)
-
-        if self.shuffle is True:
-            ds = ds.shuffle(key)
-        elif isinstance(self.shuffle, int) and self.shuffle > 0:
-            ds = ds.era_shuffle(self.shuffle, key=key)
-
-        return ds  # type: ignore
-
-    def validation_set(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    ) -> Optional[TokenSeqDataset]:
-        return self.token_seq_dataset("validation", seq_len, monitors)
-
-    def validation_sets(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    ) -> Mapping[str, AsyncDataset[np.ndarray]]:
-        validation_set = self.validation_set(seq_len, monitors)
-        if validation_set is not None:
-            return {"": validation_set}
-        else:
-            return {}
-
-    @property
-    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
-        return {"": self}
-
-    @cached_property
-    def _has_validation_set(self):
-        if len(self.validation_urls) > 0:
-            return True
-
-        if self.id is not None:
-            dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, split="validation")
-            try:
-                next(iter(dataset))
-                return True
-            except StopIteration:
-                return False
-
-        return False
-
-    def token_seq_dataset(
-        self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    ) -> Optional[TokenSeqDataset]:
-        cache = self.build_or_load_cache(split, monitors=monitors)
-        if cache is None:
-            return None
-        return TokenSeqDataset(cache, seq_len)
-
-    def build_or_load_cache(
-        self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None
-    ) -> Optional[TreeCache[BatchEncoding]]:
-        if self.cache_dir is None:
-            raise ValueError("cache_dir cannot be None")
-
-        split_cache_dir = os.path.join(self.cache_dir, split)
-        name = logger_name or os.path.basename(self.cache_dir)
-
-        try:
-            # TODO: pass in options
-            return TreeCache.load(split_cache_dir, exemplar={"input_ids": np.zeros(0, dtype=np.int32)})
-        except FileNotFoundError:
-            pass
-
-        source = self.get_shard_source(split)
-        if source is None:
-            logger.info(f"No data for {split}")
-            return None
-
-        logger.info(f"Building cache for {split}...")
-
-        if monitors is True:
-            monitors = [
-                LoggingMetricsMonitor(prefix=f"preprocessing/{name}/{split}", commit=False),
-                LoggerMetricsMonitor(f"preprocessing.{name}.{split}"),
-            ]
-        elif monitors is False:
-            monitors = []
-
-        bt = BatchTokenizer(self.the_tokenizer, enforce_bos=True, enforce_eos=self.enforce_eos)
-
-        return build_or_load_cache(
-            split_cache_dir,
-            source,
-            bt,
-            monitors=monitors,
-            await_finished=False,
-            options=self.cache_options,
-            split=split,
-        )
-
-
-class PassthroughTokenizer(PreTrainedTokenizer):
-    def __init__(self, vocab_size, **kwargs):
-        self._vocab = {i: i for i in range(vocab_size)}
-        self._vocab_size = vocab_size
-        super().__init__(**kwargs)
-
-    @property
-    def vocab_size(self) -> int:
-        return self._vocab_size
-
-    def get_vocab(self):
-        return self._vocab
-
-    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
-        return ()
-
-    def _tokenize(self, text, **kwargs):
-        tokens = np.fromstring(text, dtype=int, sep=" ")
-        return tokens
-
-    def _convert_token_to_id(self, token: str) -> int:
-        return int(token)
-
-    def _convert_id_to_token(self, index: int) -> str:
-        return str(index)
-
-
-@dataclass
-class LMMixtureDatasetConfig(LMTaskConfig):
-    """This class represents a mixture of datasets with their associated weights."""
-
-    cache_dir: Optional[str] = "cache/"
-
-    # data source configs and weights
-    configs: Dict[str, LMDatasetSourceConfig] = field(default_factory=dict)
-    """ configuration of each dataset source (urls, hf dataset id, etc.) """
-    train_weights: Dict[str, float] = field(default_factory=dict)
-    """ weights for each dataset source. They will be normalized to sum to 1. """
-    stop_strategy: str = field(default=StopStrategy.RESTART_STRATEGY)
-    mixture_block_size: int = 2048
-    """ block size for the mixture dataset."""
-
-    def __post_init__(self):
-        if len(self.configs) == 0:
-            raise ValueError("At least one dataset must be provided")
-
-        if set(self.configs.keys()) != set(self.train_weights.keys()):
-            raise ValueError(
-                f"The keys in configs and weights must be the same;got {self.configs.keys()} and"
-                f" {self.train_weights.keys()}"
-            )
-
-    def train_set(
-        self,
-        seq_len: int,
-        monitors: Union[bool, List[MetricsMonitor]] = True,
-        *,
-        key: Optional[PRNGKeyArray],
-        epochs: Optional[int] = None,
-    ) -> AsyncDataset[np.ndarray]:
-        doc_caches = self.build_caches("train", monitors=monitors)
-        token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()}
-
-        if epochs:
-            raise ValueError("Epochs are not supported for mixture datasets")
-
-        if key is None:
-            key = jax.random.PRNGKey(0)
-
-        mix_key, shuffle_key = jax.random.split(key)
-
-        # We shuffle the components and not the overall mixture because this lets us preserve
-        # the "stable batch" property of the mixture dataset.
-        def shuffle_ds(ds, key):
-            if self.shuffle is True:
-                ds = ds.shuffle(key)
-            elif isinstance(self.shuffle, int):
-                ds = ds.era_shuffle(self.shuffle, key=key)
-
-            return ds
-
-        if self.shuffle:
-            out_token_datasets = {}
-            key_iter = key_iterator(shuffle_key)
-            for name, ds in token_datasets.items():
-                out_token_datasets[name] = shuffle_ds(ds, next(key_iter))
-            token_datasets = out_token_datasets
-
-        mixture = MixtureDataset(
-            datasets=token_datasets,
-            weights=self.train_weights,
-            stop_strategy=self.stop_strategy,
-            key=mix_key,
-            block_size=2048,
-        )
-
-        return mixture
+        input_ids_list.append(np.array(full_ids, dtype=np.int32))
+        prompt_lengths.append(len(input_ids))
+
+    # Return a dictionary of numpy arrays that can be cached
+    return {
+        "input_ids": input_ids_list,
+        "prompt_length": np.array(prompt_lengths, dtype=np.int32)
+    }

-    def training_sets(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    ) -> Mapping[str, TokenSeqDataset]:
-        doc_caches = self.build_caches("train", monitors=monitors)
-        token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()}
-        return token_datasets

-    def validation_sets(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
-    ) -> Mapping[str, AsyncDataset[np.ndarray]]:
-        doc_caches = self.build_caches("validation", monitors=monitors)
-        token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()}
-        return token_datasets
-
-    def build_caches(
-        self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True
-    ) -> Dict[str, TreeCache[dict]]:
-        # this is a bit gross, but we want to forward all "Task" config fields to the LMDatasetConfig for building.
-        # We do this by just grabbing all the fields from the LMDatasetConfig and forwarding them to the
-        # LMDatasetConfig.build_or_load_cache method. We exclude the cache_dir field.
-        task_config_fields = set(x.name for x in dataclasses.fields(LMTaskConfig))
-        task_config_dict = {k: v for k, v in self.__dict__.items() if k in task_config_fields and k != "cache_dir"}
-
-        caches = {}
-        for name, source_config in self.configs.items():
-            weight = self.train_weights.get(name, 0)
-
-            if weight == 0 and split == "train":
-                continue
+def mk_chat_sft_packed_dataset(
+    config: ChatUrlDataSourceConfig,
+    tokenizer: PreTrainedTokenizerBase,
+    Pos: hax.Axis,
+    *,
+    max_segments_per_example: int = 4,
+) -> AsyncDataset[LmExample]:
+    """Creates a packed dataset from chat data for more efficient training."""
+    source = config.get_shard_source("train")
+    if source is None:
+        raise ValueError("No training data source found")

-            source_config_dict = dict(**source_config.__dict__)
-
-            if source_config.cache_dir is None:
-                # replace with the main cache dir/{name}
-                if self.cache_dir is None:
-                    raise ValueError(
-                        "If the 'main' cache_dir is None, then all component cache_dirs must be non-None, but"
-                        f"{name}'s cache_dir is None."
-                    )
-                cache_dir = os.path.join(self.cache_dir, name)
-                source_config_dict["cache_dir"] = cache_dir
-
-            dataset = LMDatasetConfig(
-                **source_config_dict,
-                **task_config_dict,
-            )
-            cache = dataset.build_or_load_cache(split, monitors)
-            # drop the data source and corresponding weight if the cache is not built
-            if cache is None:
-                logger.warning(f"Skipping {name} for split {split} because no source was provided")
-            else:
-                caches[name] = cache
+    # Check if we need to manually append EOS
+    input_ids = tokenizer("hi there")["input_ids"]
+    should_append_eos = input_ids[-1] != tokenizer.eos_token_id

-        # in practice it works best if we block on validation caches
-        if split == "validation":
-            for cache in caches.values():
-                cache.await_finished()
+    # First process into cacheable format
+    dataset = source.map_batches(
+        lambda ex: preprocess_chat_example_for_packing(ex, tokenizer, should_append_eos),
+        batch_size=128,
+        num_cpus=num_cpus_used_by_tokenizer(tokenizer),
+        output_exemplar={
+            "input_ids": np.zeros(0, dtype=np.int32),
+            "prompt_length": np.zeros(0, dtype=np.int32)
+        }
+    )

-        else:
-            logger.info(f"Not waiting for {split} caches to finish building")
+    # Cache the processed data
+    cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(
+        config.cache_dir,
+        await_finished=True
+    )

-        return caches
+    # Convert cached dictionaries to PromptCompletions and pack them
+    def prepare_and_pack(examples: list[dict]) -> list[LmExample]:
+        completions = [
+            PromptCompletion(
+                ids=ex["input_ids"].tolist(),
+                prompt_length=int(ex["prompt_length"])
+            ) for ex in examples
+        ]
+        return list(pack_prompt_completions(
+            Pos=Pos,
+            sequences=completions,
+            pad_token=tokenizer.pad_token_id,
+            max_segments_per_example=max_segments_per_example,
+        ))

-    @property
-    def sources(self) -> Mapping[str, LMDatasetSourceConfig]:
-        return self.configs
+    # Pack the examples
+    return cached_dataset.map_batches(prepare_and_pack)


 def datasource_from_chat_jsonl(
@@ -1352,7 +1127,7 @@ def mk_chat_sft_packed_dataset(
     tokenizer: PreTrainedTokenizerBase,
     Pos: hax.Axis,
     *,
-    max_segments_per_example: int = 4,  # How many sequences to pack into one example
+    max_segments_per_example: int = 4,
 ) -> AsyncDataset[LmExample]:
     """Creates a packed dataset from chat data for more efficient training."""
     source = config.get_shard_source("train")
@@ -1362,27 +1137,32 @@
     # Check if we need to manually append EOS
     input_ids = tokenizer("hi there")["input_ids"]
     should_append_eos = input_ids[-1] != tokenizer.eos_token_id
-    logger.info(f"Manual EOS Needed: {should_append_eos}")

-    # Ensure padding token is set
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    # First process into PromptCompletion objects
+    # First process into cacheable format
     dataset = source.map_batches(
         lambda ex: preprocess_chat_example_for_packing(ex, tokenizer, should_append_eos),
         batch_size=128,
         num_cpus=num_cpus_used_by_tokenizer(tokenizer),
+        output_exemplar={
+            "input_ids": np.zeros(0, dtype=np.int32),
+            "prompt_length": np.zeros(0, dtype=np.int32)
+        }
     )

     # Cache the processed data
-    cached_dataset: AsyncDataset[list[PromptCompletion]] = dataset.build_or_load_cache(
+    cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(
         config.cache_dir,
         await_finished=True
     )

-    # Function to pack completions into LmExamples
-    def pack_batch(completions: list[PromptCompletion]) -> list[LmExample]:
+    # Convert cached dictionaries to PromptCompletions and pack them
+    def prepare_and_pack(examples: list[dict]) -> list[LmExample]:
+        completions = [
+            PromptCompletion(
+                ids=ex["input_ids"].tolist(),
+                prompt_length=int(ex["prompt_length"])
+            ) for ex in examples
+        ]
         return list(pack_prompt_completions(
             Pos=Pos,
             sequences=completions,
@@ -1391,4 +1171,4 @@ def pack_batch(completions: list[PromptCompletion]) -> list[LmExample]:
         ))

     # Pack the examples
-    return cached_dataset.map_batches(pack_batch)
+    return cached_dataset.map_batches(prepare_and_pack)
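
Usage sketch (illustrative, not part of the patch): one way the packed chat SFT dataset built by mk_chat_sft_packed_dataset might be wired up. The tokenizer name, axis size, and ChatUrlDataSourceConfig field values below are hypothetical; the pad-token guard reflects that the patch removes the automatic pad_token fallback while still passing tokenizer.pad_token_id to pack_prompt_completions.

    import haliax as hax
    from transformers import AutoTokenizer

    from levanter.data.text import ChatUrlDataSourceConfig, mk_chat_sft_packed_dataset

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical tokenizer choice
    if tokenizer.pad_token is None:
        # the patch no longer sets this automatically, but packing pads with tokenizer.pad_token_id
        tokenizer.pad_token = tokenizer.eos_token

    Pos = hax.Axis("position", 2048)  # hypothetical sequence length
    config = ChatUrlDataSourceConfig(  # hypothetical field values
        cache_dir="cache/chat_sft",
        train_urls=["gs://my-bucket/chat/train-*.jsonl.gz"],
    )
    packed_ds = mk_chat_sft_packed_dataset(config, tokenizer, Pos, max_segments_per_example=4)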