From 0345857b0d1a4bc8f839eb9d206e0a5e7c0148e7 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Sun, 26 Jan 2025 14:28:54 -0800 Subject: [PATCH 01/16] try streaming --- config/debug_pack_sft.yaml | 39 +++++++++++++ src/levanter/data/text.py | 4 +- src/levanter/main/sft.py | 113 +++++++++++++++++++++++++++++++++++-- 3 files changed, 150 insertions(+), 6 deletions(-) create mode 100644 config/debug_pack_sft.yaml diff --git a/config/debug_pack_sft.yaml b/config/debug_pack_sft.yaml new file mode 100644 index 000000000..d5cb11503 --- /dev/null +++ b/config/debug_pack_sft.yaml @@ -0,0 +1,39 @@ +dataset_type: chat_jsonl +chat_train_urls: + - "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz" +supervised_data: + cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer_retrypack-bca8bd/" + +tokenizer: "meta-llama/Meta-Llama-3.1-8B" +model: # 7B class model + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: True + flash_attention_block_size: 512 + use_bias: false + use_layer_norm_weight: false +trainer: + tracker: + type: wandb + project: "marin" + tags: ["dolma", "olmo", "llama"] + + mp: p=f32,c=bfloat16 + train_batch_size: 256 + num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 + steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 4E-4 + weight_decay: 0.1 + min_lr_ratio: 0.1 + warmup: 5000 + +epoch: 0 diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 13c7ea44b..066912cee 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -1004,9 +1004,9 @@ def mk_chat_sft_dataset( # Ensure padding token is set (needed by _prepare_supervised_example) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - + return cached_dataset # Reuse the supervised prepare function directly - return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) + # return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) @dataclass diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index 3f8329a2b..15f42da2c 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -3,7 +3,7 @@ import os from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, Union +from typing import List, Optional, Union, Iterator import jax.random as jrandom import transformers @@ -24,10 +24,18 @@ mk_supervised_dataset, ) from levanter.models.llama import LlamaConfig -from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss +from levanter.models.lm_model import LmConfig, LmHeadModel, LmExample, compute_next_token_loss from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig +from levanter.data.loader import stack_tree +from levanter.data.packing import PromptCompletion, pack_prompt_completions +from levanter.utils.background_iterable import BackgroundIterator +from levanter.utils.hf_utils import HfTokenizer +from levanter.data import batched +from levanter.utils.jax_utils import broadcast_shard, use_cpu_device +from levanter.data.dataset import AsyncDataset +import asyncio logger = logging.getLogger(__name__) @@ -202,7 +210,29 @@ def train(config: SFTConfig): callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, 
flops_per_example), every=1 ) - loader = trainer.data_loader(train_dataset, trainer.TrainBatch) + # reshuffle the examples before packing! + + # to implement seeking + # check the step number in the trainer state if it's not zero + # then next the iterator until we get there, then continue training. + # batch size will be backed in from config + + # change iterate tokenized requests to take a dict rather than a list + # of where the first element is prompt ands econd is response + + # then pass into tierate tokenizer requests, go to pack requests + # and then you have the correct loader, just pass to trainer.train() + + # TODO figure out if there's a better heuristic for max segements to pack per example? + prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos) + + packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4) + packed_iterator = stack_batches(packed_iterator, trainer.TrainBatch) + # TODO what's a good number for max_capacity? + packed_loader = BackgroundIterator(packed_iterator, max_capacity=256) + + # to be moved + #loader = trainer.data_loader(train_dataset, trainer.TrainBatch) if config.hf_save_path is not None: # bit gross to reach this far into the config, but it's fine @@ -216,8 +246,83 @@ def train(config: SFTConfig): every=config.hf_save_steps, ) - trainer.train(state, loader) + trainer.train(state, packed_loader) + + +# async def get_dataset_length(cached_dataset: AsyncDataset) -> int: +# """Helper function to get dataset length asynchronously""" +# return await cached_dataset.async_len() + +def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axis) -> Iterator[PromptCompletion]: + """ + Creates an iterator that yields PromptCompletion objects from a cached dataset. + + Args: + cached_dataset: The AsyncDataset containing preprocessed examples + Pos: The position axis defining maximum sequence length + + Returns: + An iterator yielding PromptCompletion objects + """ + # AsyncDataset already has a current_len method that returns current length or None + # We can use wait_until_len_at_least which will wait until the dataset has at least + # the requested length or the final length is known + length = asyncio.run(cached_dataset.wait_until_len_at_least(0)) + + if length is None: + raise ValueError("Dataset length cannot be None") + + for i in range(length): + example = asyncio.run(cached_dataset.getitem_async(i)) + + if int(example["sources_len"]) > Pos.size - 1: + continue + + ids = example["input_ids"].tolist() + if len(ids) > Pos.size: + ids = ids[:Pos.size] + + if len(ids) <= example["sources_len"]: + continue + + try: + yield PromptCompletion( + ids=ids, + prompt_length=int(example["sources_len"]), + segment_id=i + ) + except ValueError: + continue + +def _pack_requests( + prompt_completion_iterator: Iterator[PromptCompletion], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int +) -> Iterator[LmExample]: + # TODO: use a better packing algorithm? + yield from pack_prompt_completions( + Pos, + prompt_completion_iterator, + max_segments_per_example=max_pack_size, + pad_token=tokenizer.pad_token_id, + max_buffered_examples=16, + ) +def stack_batches(self, example_iterator, TrainBatch): + """ + Stack examples from an iterator into a batch. + + Args: + TrainBatch: The batch axis. + example_iterator: An iterator of examples. + + Returns: + A batch of examples. 
+ """ + with use_cpu_device(): + for batch in batched(example_iterator, TrainBatch.size): + if len(batch) < TrainBatch.size: + dummy_instance = self._make_dummy_instance(batch) + batch.extend([dummy_instance] * (TrainBatch.size - len(batch))) + yield stack_tree(TrainBatch, batch) def add_special_tokens(tokenizer, use_unk_instead_of_adding=False): special_tokens_dict = dict() From 97d35d5f66d861cee7dd1184b09885696ca79e81 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Sun, 26 Jan 2025 14:34:14 -0800 Subject: [PATCH 02/16] add updated config --- config/debug_sft.yaml | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 config/debug_sft.yaml diff --git a/config/debug_sft.yaml b/config/debug_sft.yaml new file mode 100644 index 000000000..fdbcbaf82 --- /dev/null +++ b/config/debug_sft.yaml @@ -0,0 +1,41 @@ +dataset_type: chat_jsonl +chat_train_urls: + - "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz" +supervised_data: +# cache_dir before trying sequence packing + cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer-7b19dc" + #cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer_retrypack-bca8bd/" + +tokenizer: "meta-llama/Meta-Llama-3.1-8B" +model: # 7B class model + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: True + flash_attention_block_size: 512 + use_bias: false + use_layer_norm_weight: false +trainer: + tracker: + type: wandb + project: "marin" + tags: ["dolma", "olmo", "llama"] + + mp: p=f32,c=bfloat16 + train_batch_size: 256 + num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 + steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 4E-4 + weight_decay: 0.1 + min_lr_ratio: 0.1 + warmup: 5000 + +epoch: 0 From d2e6b3bef7d5acbc98aadf1800c504b5efcf40f1 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Sun, 26 Jan 2025 14:37:10 -0800 Subject: [PATCH 03/16] add updated config --- src/levanter/main/sft.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index 15f42da2c..59e4d6a32 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -224,11 +224,15 @@ def train(config: SFTConfig): # and then you have the correct loader, just pass to trainer.train() # TODO figure out if there's a better heuristic for max segements to pack per example? + logger.info("Creating prompt completion iterator") prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos) + logger.info("Packing prompt completions") packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4) + logger.info("Stacking batches to train batch") packed_iterator = stack_batches(packed_iterator, trainer.TrainBatch) # TODO what's a good number for max_capacity? 
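Aside on the packing step wired up above: pack_prompt_completions takes tokenized prompt/completion pairs and packs several of them into each fixed-length example, tracking a per-token loss mask (completion tokens only) and per-token segment ids. Below is a minimal, self-contained sketch of that behaviour, not part of the patch, using only the PromptCompletion and pack_prompt_completions signatures visible in this diff; the axis size, token ids, and pad id are made-up illustration values.

import haliax as hax
from levanter.data.packing import PromptCompletion, pack_prompt_completions

Pos = hax.Axis("position", 32)   # tiny axis, just for illustration
pad_id = 0                       # assumed pad token id

# two short "tokenized" prompt/completion pairs; ids and lengths are invented
sequences = [
    PromptCompletion(ids=[5, 6, 7, 8, 9], prompt_length=3, segment_id=0),
    PromptCompletion(ids=[11, 12, 13, 14], prompt_length=2, segment_id=1),
]

# both fit into one 32-token row, so this yields a single packed LmExample whose
# attention mask carries segment ids and whose loss mask covers only completion tokens
for packed in pack_prompt_completions(
    Pos,
    sequences,
    max_segments_per_example=4,
    pad_token=pad_id,
    max_buffered_examples=16,
):
    print(packed)

Keeping segment ids on every packed row is what later lets per-sequence losses (and the dummy-batch padding sketched further below) distinguish real tokens from packing filler.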
+ logger.info("Creating data loader") packed_loader = BackgroundIterator(packed_iterator, max_capacity=256) # to be moved From 0d7acc7a2c74f8ba9ee8d3c1a4ead2c31f356d63 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Sun, 26 Jan 2025 14:54:26 -0800 Subject: [PATCH 04/16] bug fix --- src/levanter/main/sft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index 59e4d6a32..4bcae2ce0 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -230,7 +230,7 @@ def train(config: SFTConfig): logger.info("Packing prompt completions") packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4) logger.info("Stacking batches to train batch") - packed_iterator = stack_batches(packed_iterator, trainer.TrainBatch) + packed_iterator = stack_batches(example_iterator=packed_iterator, TrainBatch=trainer.TrainBatch) # TODO what's a good number for max_capacity? logger.info("Creating data loader") packed_loader = BackgroundIterator(packed_iterator, max_capacity=256) @@ -310,7 +310,7 @@ def _pack_requests( max_buffered_examples=16, ) -def stack_batches(self, example_iterator, TrainBatch): +def stack_batches(example_iterator, TrainBatch): """ Stack examples from an iterator into a batch. From efbcc60a92a13e20ee610102a3e57689b717dcb2 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 27 Jan 2025 14:35:41 -0800 Subject: [PATCH 05/16] WIP buggy, MFU stuck at 1 --- src/levanter/main/sft.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index 4bcae2ce0..f63b8a4e9 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -186,7 +186,6 @@ def train(config: SFTConfig): # some axes we need Pos = config.model.Pos - # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of # tokens: gpt-2 has 50257, for example. So we round up. @@ -210,12 +209,12 @@ def train(config: SFTConfig): callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1 ) - # reshuffle the examples before packing! + # TODO: reshuffle the examples before packing! + # Get current step from trainer state + current_step = int(state.step) - # to implement seeking - # check the step number in the trainer state if it's not zero - # then next the iterator until we get there, then continue training. 
- # batch size will be backed in from config + + # change iterate tokenized requests to take a dict rather than a list # of where the first element is prompt ands econd is response @@ -227,6 +226,22 @@ def train(config: SFTConfig): logger.info("Creating prompt completion iterator") prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos) + if current_step > 0: + logger.info(f"Resuming training from step {current_step}") + # Calculate how many examples to skip based on batch size + examples_to_skip = current_step * trainer.config.train_batch_size + + # Skip through the iterator until we reach the right position + for _ in range(examples_to_skip): + try: + next(prompt_completion_iterator) + except StopIteration: + logger.warning("Ran out of examples while seeking - restarting from beginning") + # Recreate iterator and continue skipping + prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos) + else: + logger.info("Starting SFT from scratch") + logger.info("Packing prompt completions") packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4) logger.info("Stacking batches to train batch") @@ -279,20 +294,21 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi for i in range(length): example = asyncio.run(cached_dataset.getitem_async(i)) - if int(example["sources_len"]) > Pos.size - 1: + sources_len = example["sources_len"].item() + if sources_len > Pos.size - 1: continue ids = example["input_ids"].tolist() if len(ids) > Pos.size: ids = ids[:Pos.size] - if len(ids) <= example["sources_len"]: + if len(ids) <= sources_len: continue try: yield PromptCompletion( ids=ids, - prompt_length=int(example["sources_len"]), + prompt_length=sources_len, segment_id=i ) except ValueError: From 01c4a2879145aeecea0ef7e0ffb4014caa5030ac Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Thu, 30 Jan 2025 12:16:11 -0800 Subject: [PATCH 06/16] sequence packing working, open thoughts --- config/llama3_openthoughts_sft.yaml | 46 +++++++++++++++++ src/levanter/main/sft.py | 76 +++++++++++++++++------------ 2 files changed, 92 insertions(+), 30 deletions(-) create mode 100644 config/llama3_openthoughts_sft.yaml diff --git a/config/llama3_openthoughts_sft.yaml b/config/llama3_openthoughts_sft.yaml new file mode 100644 index 000000000..9d3c74acd --- /dev/null +++ b/config/llama3_openthoughts_sft.yaml @@ -0,0 +1,46 @@ +dataset_type: chat_jsonl +chat_train_urls: + - "gs://marin-us-central2/documents/open-thoughts--OpenThoughts-114k-216e29/data/**/*.jsonl.gz" +supervised_data: +# cache_dir before trying sequence packing + cache_dir: "gs://marin-us-central2/tokenized/openthoughts_llama3_tokenizer-9edd80" + #cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer_retrypack-bca8bd/" + +max_seq_len: 4096 +tokenizer: "meta-llama/Meta-Llama-3.1-8B" +model: # 7B class model + type: llama + seq_len: 4096 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: True + flash_attention_block_size: 512 + use_bias: false + use_layer_norm_weight: false + initializer_range: 0.02 +trainer: + seed: 1 + tracker: + type: wandb + project: "marin" + tags: ["dolma", "olmo", "llama"] + + mp: p=f32,c=bfloat16 + # same as 606 sft in marin + train_batch_size: 128 + num_train_steps: 7335 # 3,000,000,000,000 / 4,000,000 = 750,000 + steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + 
learning_rate: 5e-6 + weight_decay: 0.0 + min_lr_ratio: 0.0 + lr_schedule: "linear" + warmup: 0.03 + +epoch: 0 diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index f63b8a4e9..f040f0989 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -6,16 +6,18 @@ from typing import List, Optional, Union, Iterator import jax.random as jrandom +import jax.numpy as jnp import transformers import haliax as hax from haliax import Axis from haliax.partitioning import round_axis_for_partitioning +from optax.tree_utils import tree_zeros_like import levanter from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback -from levanter.data import PermutationDataset +from levanter.data import PermutationDataset, batched from levanter.data.text import ( ChatUrlDataSourceConfig, EpochDataset, @@ -23,6 +25,7 @@ mk_chat_sft_dataset, mk_supervised_dataset, ) +from levanter.models.attention import AttentionMask from levanter.models.llama import LlamaConfig from levanter.models.lm_model import LmConfig, LmHeadModel, LmExample, compute_next_token_loss from levanter.optim import AdamConfig, OptimizerConfig @@ -31,7 +34,6 @@ from levanter.data.packing import PromptCompletion, pack_prompt_completions from levanter.utils.background_iterable import BackgroundIterator from levanter.utils.hf_utils import HfTokenizer -from levanter.data import batched from levanter.utils.jax_utils import broadcast_shard, use_cpu_device from levanter.data.dataset import AsyncDataset @@ -208,8 +210,6 @@ def train(config: SFTConfig): trainer.add_hook( callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1 ) - - # TODO: reshuffle the examples before packing! # Get current step from trainer state current_step = int(state.step) @@ -245,10 +245,10 @@ def train(config: SFTConfig): logger.info("Packing prompt completions") packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4) logger.info("Stacking batches to train batch") - packed_iterator = stack_batches(example_iterator=packed_iterator, TrainBatch=trainer.TrainBatch) + packed_iterator = stack_batches(example_iterator=packed_iterator, Pos=Pos, TrainBatch=trainer.TrainBatch) # TODO what's a good number for max_capacity? 
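A second editorial sketch, this time of the batch-padding idea that stack_batches and _make_dummy_instance implement further down in this commit: when the final batch comes up short, it is padded to the full train batch size with zeroed examples whose segment ids are all -1, so the filler rows do not affect the loss. The helper name below is illustrative only; the calls mirror the LmExample / AttentionMask / tree_zeros_like usage shown in this patch.

import dataclasses

import jax.numpy as jnp
from optax.tree_utils import tree_zeros_like

import haliax as hax
from levanter.models.attention import AttentionMask
from levanter.models.lm_model import LmExample


def pad_batch_with_dummies(batch: list, Pos: hax.Axis, batch_size: int) -> list:
    """Extend a short batch to `batch_size` with zeroed examples marked segment id -1."""
    if len(batch) >= batch_size:
        return batch
    dummy: LmExample = tree_zeros_like(batch[0])          # zero tokens and zero loss mask
    dummy_segments = hax.full(Pos, -1, dtype=jnp.int32)   # -1 marks "not a real segment"
    dummy = dataclasses.replace(dummy, attn_mask=AttentionMask.causal().with_segment_ids(dummy_segments))
    return batch + [dummy] * (batch_size - len(batch))

Because the dummy rows carry a zero loss mask and the -1 segment marker, stacking them into the last batch keeps shapes static without changing the gradient.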
logger.info("Creating data loader") - packed_loader = BackgroundIterator(packed_iterator, max_capacity=256) + packed_loader = BackgroundIterator(packed_iterator, max_capacity=512) # to be moved #loader = trainer.data_loader(train_dataset, trainer.TrainBatch) @@ -291,28 +291,31 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi if length is None: raise ValueError("Dataset length cannot be None") - for i in range(length): - example = asyncio.run(cached_dataset.getitem_async(i)) - - sources_len = example["sources_len"].item() - if sources_len > Pos.size - 1: - continue - - ids = example["input_ids"].tolist() - if len(ids) > Pos.size: - ids = ids[:Pos.size] - - if len(ids) <= sources_len: - continue - - try: - yield PromptCompletion( - ids=ids, - prompt_length=sources_len, - segment_id=i - ) - except ValueError: - continue + # TODO play around with batch size + for batch_indicies in batched(range(length), 128): + examples = asyncio.run(cached_dataset.get_batch(batch_indicies)) + + for i in range(len(examples)): + example = examples[i] + sources_len = example["sources_len"].item() + if sources_len > Pos.size - 1: + continue + + ids = example["input_ids"].tolist() + if len(ids) > Pos.size: + ids = ids[:Pos.size] + + if len(ids) <= sources_len: + continue + + try: + yield PromptCompletion( + ids=ids, + prompt_length=sources_len, + segment_id=batch_indicies[i] + ) + except ValueError: + continue def _pack_requests( prompt_completion_iterator: Iterator[PromptCompletion], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int @@ -326,12 +329,25 @@ def _pack_requests( max_buffered_examples=16, ) -def stack_batches(example_iterator, TrainBatch): +""" +Helper function to create a dummy instance with the same shape as the batch. +When we reach the end of the dataset but we want a full batch, +will give a batch of zeros with -1 segment mask so it doesn't affect loss +""" +def _make_dummy_instance(batch, Pos): + dummy_instance: LmExample = tree_zeros_like(batch[0]) + dummy_segment_mask = hax.full(Pos, -1, dtype=jnp.int32) + dummy_attn = AttentionMask.causal().with_segment_ids(dummy_segment_mask) + dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) + return dummy_instance + +def stack_batches(example_iterator, Pos, TrainBatch): """ Stack examples from an iterator into a batch. Args: TrainBatch: The batch axis. + Pos: The position axis. example_iterator: An iterator of examples. 
Returns: @@ -340,7 +356,7 @@ def stack_batches(example_iterator, TrainBatch): with use_cpu_device(): for batch in batched(example_iterator, TrainBatch.size): if len(batch) < TrainBatch.size: - dummy_instance = self._make_dummy_instance(batch) + dummy_instance = _make_dummy_instance(batch, Pos) batch.extend([dummy_instance] * (TrainBatch.size - len(batch))) yield stack_tree(TrainBatch, batch) From 3ab98859d112233e68a5e39b8e509f1d357131af Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Fri, 31 Jan 2025 22:35:15 -0800 Subject: [PATCH 07/16] try to hunt down MFU --- config/debug_sft.yaml | 19 ++++++++++++------- src/levanter/data/packing.py | 18 ++++++++++++++++-- src/levanter/main/sft.py | 35 ++++++++++++++++++++++++----------- 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/config/debug_sft.yaml b/config/debug_sft.yaml index fdbcbaf82..a77aede1c 100644 --- a/config/debug_sft.yaml +++ b/config/debug_sft.yaml @@ -6,10 +6,11 @@ supervised_data: cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer-7b19dc" #cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer_retrypack-bca8bd/" +max_seq_len: 4096 tokenizer: "meta-llama/Meta-Llama-3.1-8B" model: # 7B class model type: llama - seq_len: 2048 + seq_len: 4096 hidden_dim: 4096 intermediate_dim: 11008 num_layers: 32 @@ -19,23 +20,27 @@ model: # 7B class model flash_attention_block_size: 512 use_bias: false use_layer_norm_weight: false + initializer_range: 0.02 trainer: + seed: 1 tracker: type: wandb project: "marin" tags: ["dolma", "olmo", "llama"] mp: p=f32,c=bfloat16 - train_batch_size: 256 - num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 + # same as 606 sft in marin + train_batch_size: 128 + num_train_steps: 7335 # 3,000,000,000,000 / 4,000,000 = 750,000 steps_per_eval: 1000 tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" batch_axis: "batch" optimizer: - learning_rate: 4E-4 - weight_decay: 0.1 - min_lr_ratio: 0.1 - warmup: 5000 + learning_rate: 5e-6 + weight_decay: 0.0 + min_lr_ratio: 0.0 + lr_schedule: "linear" + warmup: 0.03 epoch: 0 diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index a049de56c..8cfa64886 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -18,6 +18,7 @@ from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmExample from levanter.utils.jax_utils import local_cpu_mesh +import time # cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623 @@ -114,17 +115,30 @@ def pack_prompt_completions( packers = [SequencePacker(Pos, max_segments_per_example, pad_token)] + # put timer in here for sequence in sequences: + start_time = time.perf_counter + # put timer here, keep in mind the loop is an iterator so we want to + # separate the time to get examples, and how long it takes to pack loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 loss_mask[-1] = 0 assert np.any(loss_mask) + # time how long to pack, and subtract + start_for_pack_yield = time.perf_counter() for packer in packers: if packer.can_pack(sequence.ids): + add_example = time.perf_counter() packer.add_example(sequence.ids, loss_mask, sequence.segment_id) - + add_example_end = time.perf_counter() + print(f" time to add example to segment apacker{ add_example_end - add_example:.6f} seconds", flush=True) if packer.num_segments == max_segments_per_example: - yield packer.pack() + ot = packer.pack() + 
end_time = time.perf_counter() + time_to_yield = end_time - start_for_pack_yield + print(f"total time to for loop until we yielded a packed example {time_to_yield}", flush=True) + print(f"total time for iterator {start_time - time_to_yield}", flush=True) + yield ot packers.remove(packer) break else: diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index f040f0989..0c1eb2173 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -13,6 +13,7 @@ from haliax import Axis from haliax.partitioning import round_axis_for_partitioning from optax.tree_utils import tree_zeros_like +import time import levanter from levanter import callbacks @@ -248,7 +249,7 @@ def train(config: SFTConfig): packed_iterator = stack_batches(example_iterator=packed_iterator, Pos=Pos, TrainBatch=trainer.TrainBatch) # TODO what's a good number for max_capacity? logger.info("Creating data loader") - packed_loader = BackgroundIterator(packed_iterator, max_capacity=512) + packed_loader = BackgroundIterator(packed_iterator, max_capacity=1024) # to be moved #loader = trainer.data_loader(train_dataset, trainer.TrainBatch) @@ -267,11 +268,6 @@ def train(config: SFTConfig): trainer.train(state, packed_loader) - -# async def get_dataset_length(cached_dataset: AsyncDataset) -> int: -# """Helper function to get dataset length asynchronously""" -# return await cached_dataset.async_len() - def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axis) -> Iterator[PromptCompletion]: """ Creates an iterator that yields PromptCompletion objects from a cached dataset. @@ -284,16 +280,19 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi An iterator yielding PromptCompletion objects """ # AsyncDataset already has a current_len method that returns current length or None - # We can use wait_until_len_at_least which will wait until the dataset has at least - # the requested length or the final length is known - length = asyncio.run(cached_dataset.wait_until_len_at_least(0)) + length = asyncio.run(cached_dataset.async_len()) if length is None: raise ValueError("Dataset length cannot be None") # TODO play around with batch size - for batch_indicies in batched(range(length), 128): + for batch_indicies in batched(range(length), 4096): + # put timer here + start_time = time.perf_counter() examples = asyncio.run(cached_dataset.get_batch(batch_indicies)) + end_time = time.perf_counter() + elapsed_time = end_time - start_time + print(f"Elapsed time for get batches: {elapsed_time:.6f} seconds", flush=True) for i in range(len(examples)): example = examples[i] @@ -353,12 +352,26 @@ def stack_batches(example_iterator, Pos, TrainBatch): Returns: A batch of examples. 
""" + # add timer here as well and profile with use_cpu_device(): + batch_count = 0 for batch in batched(example_iterator, TrainBatch.size): + batch_count += 1 + start_time_loop = time.perf_counter() if len(batch) < TrainBatch.size: dummy_instance = _make_dummy_instance(batch, Pos) batch.extend([dummy_instance] * (TrainBatch.size - len(batch))) - yield stack_tree(TrainBatch, batch) + # Start timing before calling stack_tree + start_time = time.perf_counter() + result = stack_tree(TrainBatch, batch) # Capture the result + stack_time = time.perf_counter() - start_time # Calculate elapsed time + + print(f"Stack tree execution time: {stack_time:.6f} seconds", flush=True) + yield result # Yield the computed result + end_time_loop = time.perf_counter() + loop_time = end_time_loop - start_time_loop + print(f"Loop takes {loop_time}") + print(f"Iterator time is {loop_time - stack_time}") def add_special_tokens(tokenizer, use_unk_instead_of_adding=False): special_tokens_dict = dict() From 05fbc294577b89968ac208ec6145e837c0703f73 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Sat, 1 Feb 2025 13:47:28 -0800 Subject: [PATCH 08/16] somehow mfu is fine now?? --- src/levanter/data/packing.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 8cfa64886..e0b3c6d43 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -113,11 +113,12 @@ def pack_prompt_completions( Packs a list of prompt completions into LmExamples using the SequencePacker """ + start_packer = time.perf_counter() packers = [SequencePacker(Pos, max_segments_per_example, pad_token)] - + print(f"time to create packer is {time.perf_counter() - start_packer:.6f} seconds", flush=True) # put timer in here for sequence in sequences: - start_time = time.perf_counter + start_time = time.perf_counter() # put timer here, keep in mind the loop is an iterator so we want to # separate the time to get examples, and how long it takes to pack loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 @@ -131,27 +132,39 @@ def pack_prompt_completions( add_example = time.perf_counter() packer.add_example(sequence.ids, loss_mask, sequence.segment_id) add_example_end = time.perf_counter() - print(f" time to add example to segment apacker{ add_example_end - add_example:.6f} seconds", flush=True) + time_to_add_example = add_example_end - add_example + print(f" time to add example to segment packer { time_to_add_example:.6f} seconds", flush=True) + end_iter_plus_example = time.perf_counter() + time_to_reach_example_total = end_iter_plus_example - start_time + print(f"time to get sequence, so time to get here minus time to add example {time_to_reach_example_total - time_to_add_example:.6f} seconds", flush=True) if packer.num_segments == max_segments_per_example: ot = packer.pack() end_time = time.perf_counter() time_to_yield = end_time - start_for_pack_yield - print(f"total time to for loop until we yielded a packed example {time_to_yield}", flush=True) - print(f"total time for iterator {start_time - time_to_yield}", flush=True) + print(f"MAX SEG total time to for loop until we yielded a packed example {time_to_yield}", flush=True) + print(f"MAX SEG total time for iterator {start_time - time_to_yield}", flush=True) yield ot packers.remove(packer) break else: # no packer could fit the example, create a new one + start_new_packer = time.perf_counter() packer = SequencePacker(Pos, max_segments_per_example, pad_token) 
packer.add_example(sequence.ids, loss_mask, sequence.segment_id) packers.append(packer) + print(f"time to create new packer is {time.perf_counter() - start_new_packer:.6f}", flush=True) while len(packers) >= max_buffered_examples: - yield packers.pop(0).pack() + start_return_packed_example_full_buffer = time.perf_counter() + max_example = packers.pop(0).pack() + print(f"time to create lm example when max buffered examples is {time.perf_counter() - start_return_packed_example_full_buffer:.6f}", flush=True) + yield max_example for packer in packers: - yield packer.pack() + start_return_packed_example = time.perf_counter() + example = packer.pack() + print(f"time to create lm example from packer is {time.perf_counter() - start_return_packed_example:.6f}", flush=True) + yield example def per_segment_loss( From 75bcaee92fba19ac6305323375d4ea313e54fce8 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Sat, 1 Feb 2025 13:56:58 -0800 Subject: [PATCH 09/16] remove prints for now --- src/levanter/data/packing.py | 42 ++++++++++++++++++------------------ src/levanter/main/sft.py | 26 +++++++++++----------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index e0b3c6d43..4393fc9c8 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -113,12 +113,12 @@ def pack_prompt_completions( Packs a list of prompt completions into LmExamples using the SequencePacker """ - start_packer = time.perf_counter() + # start_packer = time.perf_counter() packers = [SequencePacker(Pos, max_segments_per_example, pad_token)] - print(f"time to create packer is {time.perf_counter() - start_packer:.6f} seconds", flush=True) + # print(f"time to create packer is {time.perf_counter() - start_packer:.6f} seconds", flush=True) # put timer in here for sequence in sequences: - start_time = time.perf_counter() + # start_time = time.perf_counter() # put timer here, keep in mind the loop is an iterator so we want to # separate the time to get examples, and how long it takes to pack loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 @@ -126,44 +126,44 @@ def pack_prompt_completions( assert np.any(loss_mask) # time how long to pack, and subtract - start_for_pack_yield = time.perf_counter() + #start_for_pack_yield = time.perf_counter() for packer in packers: if packer.can_pack(sequence.ids): - add_example = time.perf_counter() + # add_example = time.perf_counter() packer.add_example(sequence.ids, loss_mask, sequence.segment_id) - add_example_end = time.perf_counter() - time_to_add_example = add_example_end - add_example - print(f" time to add example to segment packer { time_to_add_example:.6f} seconds", flush=True) - end_iter_plus_example = time.perf_counter() - time_to_reach_example_total = end_iter_plus_example - start_time - print(f"time to get sequence, so time to get here minus time to add example {time_to_reach_example_total - time_to_add_example:.6f} seconds", flush=True) + # add_example_end = time.perf_counter() + # time_to_add_example = add_example_end - add_example + # print(f" time to add example to segment packer { time_to_add_example:.6f} seconds", flush=True) + # end_iter_plus_example = time.perf_counter() + # time_to_reach_example_total = end_iter_plus_example - start_time + # print(f"time to get sequence, so time to get here minus time to add example {time_to_reach_example_total - time_to_add_example:.6f} seconds", flush=True) if packer.num_segments == max_segments_per_example: ot = packer.pack() - 
end_time = time.perf_counter() - time_to_yield = end_time - start_for_pack_yield - print(f"MAX SEG total time to for loop until we yielded a packed example {time_to_yield}", flush=True) - print(f"MAX SEG total time for iterator {start_time - time_to_yield}", flush=True) + # end_time = time.perf_counter() + # time_to_yield = end_time - start_for_pack_yield + # print(f"MAX SEG total time to for loop until we yielded a packed example {time_to_yield}", flush=True) + # print(f"MAX SEG total time for iterator {start_time - time_to_yield}", flush=True) yield ot packers.remove(packer) break else: # no packer could fit the example, create a new one - start_new_packer = time.perf_counter() + #start_new_packer = time.perf_counter() packer = SequencePacker(Pos, max_segments_per_example, pad_token) packer.add_example(sequence.ids, loss_mask, sequence.segment_id) packers.append(packer) - print(f"time to create new packer is {time.perf_counter() - start_new_packer:.6f}", flush=True) + #print(f"time to create new packer is {time.perf_counter() - start_new_packer:.6f}", flush=True) while len(packers) >= max_buffered_examples: - start_return_packed_example_full_buffer = time.perf_counter() + #start_return_packed_example_full_buffer = time.perf_counter() max_example = packers.pop(0).pack() - print(f"time to create lm example when max buffered examples is {time.perf_counter() - start_return_packed_example_full_buffer:.6f}", flush=True) + #print(f"time to create lm example when max buffered examples is {time.perf_counter() - start_return_packed_example_full_buffer:.6f}", flush=True) yield max_example for packer in packers: - start_return_packed_example = time.perf_counter() + #start_return_packed_example = time.perf_counter() example = packer.pack() - print(f"time to create lm example from packer is {time.perf_counter() - start_return_packed_example:.6f}", flush=True) + #print(f"time to create lm example from packer is {time.perf_counter() - start_return_packed_example:.6f}", flush=True) yield example diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index 0c1eb2173..96c3fef07 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -288,11 +288,11 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi # TODO play around with batch size for batch_indicies in batched(range(length), 4096): # put timer here - start_time = time.perf_counter() + # start_time = time.perf_counter() examples = asyncio.run(cached_dataset.get_batch(batch_indicies)) - end_time = time.perf_counter() - elapsed_time = end_time - start_time - print(f"Elapsed time for get batches: {elapsed_time:.6f} seconds", flush=True) + # end_time = time.perf_counter() + # elapsed_time = end_time - start_time + # print(f"Elapsed time for get batches: {elapsed_time:.6f} seconds", flush=True) for i in range(len(examples)): example = examples[i] @@ -357,21 +357,21 @@ def stack_batches(example_iterator, Pos, TrainBatch): batch_count = 0 for batch in batched(example_iterator, TrainBatch.size): batch_count += 1 - start_time_loop = time.perf_counter() + #start_time_loop = time.perf_counter() if len(batch) < TrainBatch.size: dummy_instance = _make_dummy_instance(batch, Pos) batch.extend([dummy_instance] * (TrainBatch.size - len(batch))) - # Start timing before calling stack_tree - start_time = time.perf_counter() + # # Start timing before calling stack_tree + # start_time = time.perf_counter() result = stack_tree(TrainBatch, batch) # Capture the result - stack_time = time.perf_counter() - start_time # 
Calculate elapsed time + # stack_time = time.perf_counter() - start_time # Calculate elapsed time - print(f"Stack tree execution time: {stack_time:.6f} seconds", flush=True) + # print(f"Stack tree execution time: {stack_time:.6f} seconds", flush=True) yield result # Yield the computed result - end_time_loop = time.perf_counter() - loop_time = end_time_loop - start_time_loop - print(f"Loop takes {loop_time}") - print(f"Iterator time is {loop_time - stack_time}") + # end_time_loop = time.perf_counter() + # loop_time = end_time_loop - start_time_loop + # print(f"Loop takes {loop_time}") + # print(f"Iterator time is {loop_time - stack_time}") def add_special_tokens(tokenizer, use_unk_instead_of_adding=False): special_tokens_dict = dict() From 0536c63ed35ce4679780bc5456b08ba96043e692 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Mon, 3 Feb 2025 13:43:20 -0800 Subject: [PATCH 10/16] add configs --- config/{debug_sft.yaml => llama3.1_tulu3_sft.yaml} | 10 +++++++--- config/llama3_openthoughts_sft.yaml | 11 ++++++++--- ...llama_sft_hf_ckpt.yaml => llama3_sft_hf_ckpt.yaml} | 7 ++++--- 3 files changed, 19 insertions(+), 9 deletions(-) rename config/{debug_sft.yaml => llama3.1_tulu3_sft.yaml} (81%) rename config/{llama_sft_hf_ckpt.yaml => llama3_sft_hf_ckpt.yaml} (70%) diff --git a/config/debug_sft.yaml b/config/llama3.1_tulu3_sft.yaml similarity index 81% rename from config/debug_sft.yaml rename to config/llama3.1_tulu3_sft.yaml index a77aede1c..16b98ee51 100644 --- a/config/debug_sft.yaml +++ b/config/llama3.1_tulu3_sft.yaml @@ -12,10 +12,10 @@ model: # 7B class model type: llama seq_len: 4096 hidden_dim: 4096 - intermediate_dim: 11008 + intermediate_dim: 14336 num_layers: 32 num_heads: 32 - num_kv_heads: 32 + num_kv_heads: 8 use_flash_attention: True flash_attention_block_size: 512 use_bias: false @@ -31,7 +31,8 @@ trainer: mp: p=f32,c=bfloat16 # same as 606 sft in marin train_batch_size: 128 - num_train_steps: 7335 # 3,000,000,000,000 / 4,000,000 = 750,000 + # number of steps until we hit stop iteration + num_train_steps: 1791 # 3,000,000,000,000 / 4,000,000 = 750,000 steps_per_eval: 1000 tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" @@ -43,4 +44,7 @@ optimizer: lr_schedule: "linear" warmup: 0.03 +hf_save_steps: 1790 +hf_save_path: "gs://levanter-checkpoints/marin/llama_3.1_tulusft/" + epoch: 0 diff --git a/config/llama3_openthoughts_sft.yaml b/config/llama3_openthoughts_sft.yaml index 9d3c74acd..945ac65ac 100644 --- a/config/llama3_openthoughts_sft.yaml +++ b/config/llama3_openthoughts_sft.yaml @@ -12,10 +12,10 @@ model: # 7B class model type: llama seq_len: 4096 hidden_dim: 4096 - intermediate_dim: 11008 + intermediate_dim: 14336 num_layers: 32 num_heads: 32 - num_kv_heads: 32 + num_kv_heads: 8 use_flash_attention: True flash_attention_block_size: 512 use_bias: false @@ -31,7 +31,8 @@ trainer: mp: p=f32,c=bfloat16 # same as 606 sft in marin train_batch_size: 128 - num_train_steps: 7335 # 3,000,000,000,000 / 4,000,000 = 750,000 + # number of steps until we hit stop iteration + num_train_steps: 802 steps_per_eval: 1000 tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" @@ -43,4 +44,8 @@ optimizer: lr_schedule: "linear" warmup: 0.03 + +hf_save_steps: 801 +hf_save_path: "gs://levanter-checkpoints/marin/tulusft_openthoughtsft/" + epoch: 0 diff --git a/config/llama_sft_hf_ckpt.yaml b/config/llama3_sft_hf_ckpt.yaml similarity index 70% rename from config/llama_sft_hf_ckpt.yaml rename to config/llama3_sft_hf_ckpt.yaml index a5742486c..e112bc2d1 100644 --- 
a/config/llama_sft_hf_ckpt.yaml +++ b/config/llama3_sft_hf_ckpt.yaml @@ -1,13 +1,14 @@ # Model configuration model: type: llama - seq_len: 2048 + seq_len: 4096 hidden_dim: 4096 - intermediate_dim: 11008 + intermediate_dim: 14336 num_layers: 32 num_heads: 32 - num_kv_heads: 32 + num_kv_heads: 8 use_flash_attention: true flash_attention_block_size: 512 use_bias: false use_layer_norm_weight: false + initializer_range: 0.02 From 2edc6581068ca7d82f46f191f5ce42713db0a8e2 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 4 Feb 2025 17:44:03 -0800 Subject: [PATCH 11/16] fix llama3 config --- config/llama3.1_tulu3_sft.yaml | 7 +++++-- config/llama3_openthoughts_sft.yaml | 6 ++++-- config/llama3_sft_hf_ckpt.yaml | 7 ++++++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/config/llama3.1_tulu3_sft.yaml b/config/llama3.1_tulu3_sft.yaml index 16b98ee51..499fe2091 100644 --- a/config/llama3.1_tulu3_sft.yaml +++ b/config/llama3.1_tulu3_sft.yaml @@ -8,7 +8,7 @@ supervised_data: max_seq_len: 4096 tokenizer: "meta-llama/Meta-Llama-3.1-8B" -model: # 7B class model +model: # 8B llama3 class model type: llama seq_len: 4096 hidden_dim: 4096 @@ -19,8 +19,11 @@ model: # 7B class model use_flash_attention: True flash_attention_block_size: 512 use_bias: false - use_layer_norm_weight: false + use_layer_norm_weight: true initializer_range: 0.02 + rope: + type: "llama3" + trainer: seed: 1 tracker: diff --git a/config/llama3_openthoughts_sft.yaml b/config/llama3_openthoughts_sft.yaml index 945ac65ac..bff3f1435 100644 --- a/config/llama3_openthoughts_sft.yaml +++ b/config/llama3_openthoughts_sft.yaml @@ -8,7 +8,7 @@ supervised_data: max_seq_len: 4096 tokenizer: "meta-llama/Meta-Llama-3.1-8B" -model: # 7B class model +model: # 8B llama3 class model type: llama seq_len: 4096 hidden_dim: 4096 @@ -19,8 +19,10 @@ model: # 7B class model use_flash_attention: True flash_attention_block_size: 512 use_bias: false - use_layer_norm_weight: false + use_layer_norm_weight: true initializer_range: 0.02 + rope: + type: "llama3" trainer: seed: 1 tracker: diff --git a/config/llama3_sft_hf_ckpt.yaml b/config/llama3_sft_hf_ckpt.yaml index e112bc2d1..950ca964b 100644 --- a/config/llama3_sft_hf_ckpt.yaml +++ b/config/llama3_sft_hf_ckpt.yaml @@ -10,5 +10,10 @@ model: use_flash_attention: true flash_attention_block_size: 512 use_bias: false - use_layer_norm_weight: false + use_layer_norm_weight: true initializer_range: 0.02 + rope: + type: "llama3" + +# need to set this! 
+tokenizer: "meta-llama/Meta-Llama-3.1-8B" \ No newline at end of file From f5cba5dbe589ca14e3d8be0401d561b600a807cd Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Tue, 4 Feb 2025 17:55:57 -0800 Subject: [PATCH 12/16] finalize packing --- config/data/dolma_llama.yaml | 202 ++++++++++++++++++ config/data/dolma_llama_euwest.yaml | 200 +++++++++++++++++ config/llama3.1_tulu3_sft.yaml | 4 +- config/llama3_openthoughts_sft.yaml | 4 +- config/llama3_sft_hf_ckpt.yaml | 4 +- config/llama_7b_with_olmo_config_euwest4.yaml | 39 ++++ ...=> llama_7b_with_olmo_config_uswest4.yaml} | 28 +-- src/levanter/data/packing.py | 36 +--- src/levanter/data/text.py | 35 ++- src/levanter/main/sft.py | 129 +++++------ 10 files changed, 548 insertions(+), 133 deletions(-) create mode 100644 config/data/dolma_llama.yaml create mode 100644 config/data/dolma_llama_euwest.yaml create mode 100644 config/llama_7b_with_olmo_config_euwest4.yaml rename config/{debug_pack_sft.yaml => llama_7b_with_olmo_config_uswest4.yaml} (56%) diff --git a/config/data/dolma_llama.yaml b/config/data/dolma_llama.yaml new file mode 100644 index 000000000..681c5f22b --- /dev/null +++ b/config/data/dolma_llama.yaml @@ -0,0 +1,202 @@ +cache_dir: null +cache_options: + batch_size: 128 + num_shard_groups: 128 + prefetch_per_group: 4 + shard_order_randomization_key: 0 + target_size_per_flush: 512MB +configs: + dolma/algebraic-stack: + cache_dir: gs://marin-us-west4/tokenized/dolma/algebraic-stack-cc00cf + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/algebraic-stack-train-{0000..0015}.json.gz + validation_urls: [] + dolma/arxiv: + cache_dir: gs://marin-us-west4/tokenized/dolma/arxiv-07a51f + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/arxiv-{0000..0099}.json.gz + validation_urls: [] + dolma/c4: + cache_dir: gs://marin-us-west4/tokenized/dolma/c4-e0e5ec + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/c4-{0000..0170}.json.gz + validation_urls: [] + dolma/cc: + cache_dir: gs://marin-us-west4/tokenized/dolma/cc-74b017 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_head-{0000..0274}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0000..0238}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0240..0379}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0000..0152}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0154..0444}.json.gz + validation_urls: [] + dolma/cc-news: + cache_dir: gs://marin-us-west4/tokenized/dolma/cc-news-625d3e + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_head-{0000..0004}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_middle-{0000..0002}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_tail-0000.json.gz + validation_urls: [] + dolma/falcon: + cache_dir: gs://marin-us-west4/tokenized/dolma/falcon-da8fd0 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/falcon-{0000..0499}.json.gz + validation_urls: [] + dolma/flan: + cache_dir: gs://marin-us-west4/tokenized/dolma/flan-a99cb2 + id: null + 
name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/tulu_flan-{0000..0065}.json.gz + validation_urls: [] + dolma/gutenberg: + cache_dir: gs://marin-us-west4/tokenized/dolma/gutenberg-f9eb99 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/books-{0000..0002}.json.gz + validation_urls: [] + dolma/megawika: + cache_dir: gs://marin-us-west4/tokenized/dolma/megawika-34abf2 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/megawika-{0000..0261}.json.gz + validation_urls: [] + dolma/open-web-math: + cache_dir: gs://marin-us-west4/tokenized/dolma/open-web-math-79823d + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/open-web-math-train-{0000..0012}.json.gz + validation_urls: [] + dolma/pes2o: + cache_dir: gs://marin-us-west4/tokenized/dolma/pes2o-538363 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/pes2o-{0000..0025}.json.gz + validation_urls: [] + dolma/reddit: + cache_dir: gs://marin-us-west4/tokenized/dolma/reddit-62a64a + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/reddit-{0000..0077}.json.gz + validation_urls: [] + dolma/stackexchange: + cache_dir: gs://marin-us-west4/tokenized/dolma/stackexchange-adfc49 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/stackexchange-{0000..0025}.json.gz + validation_urls: [] + dolma/starcoder: + cache_dir: gs://marin-us-west4/tokenized/dolma/starcoder-8b6089 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/starcoder-{0000..0048}.json.gz + validation_urls: [] + dolma/wiki: + cache_dir: gs://marin-us-west4/tokenized/dolma/wiki-212315 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/wiki-{0000..0001}.json.gz + validation_urls: [] +enforce_eos: true +ignore_token_id: null +mixture_block_size: 2048 +shuffle: true +stop_strategy: restart +tokenizer: nvidia/Llama-3.1-Nemotron-70B-Instruct-HF +train_weights: + dolma/algebraic-stack: 12.6 + dolma/arxiv: 28.0 + dolma/c4: 124.95 + dolma/cc: 597.75 + dolma/cc-news: 14.3 + dolma/falcon: 456.4 + dolma/flan: 16.5 + dolma/gutenberg: 5.3 + dolma/megawika: 4.6 + dolma/open-web-math: 12.6 + dolma/pes2o: 57.2 + dolma/reddit: 79.9 + dolma/stackexchange: 19.6 + dolma/starcoder: 263.8 + dolma/wiki: 7.4 +vocab_size: null \ No newline at end of file diff --git a/config/data/dolma_llama_euwest.yaml b/config/data/dolma_llama_euwest.yaml new file mode 100644 index 000000000..32876031b --- /dev/null +++ b/config/data/dolma_llama_euwest.yaml @@ -0,0 +1,200 @@ +cache_dir: null +cache_options: + batch_size: 128 + num_shard_groups: 128 + target_size_per_flush: 512MB +configs: + dolma/algebraic-stack: + cache_dir: gs://marin-eu-west4/tokenized/dolma/algebraic-stack-cc00cf + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - 
gs://marin-us-central2/raw/dolma/v1.7/algebraic-stack-train-{0000..0015}.json.gz + validation_urls: [] + dolma/arxiv: + cache_dir: gs://marin-eu-west4/tokenized/dolma/arxiv-07a51f + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/arxiv-{0000..0099}.json.gz + validation_urls: [] + dolma/c4: + cache_dir: gs://marin-eu-west4/tokenized/dolma/c4-e0e5ec + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/c4-{0000..0170}.json.gz + validation_urls: [] + dolma/cc: + cache_dir: gs://marin-eu-west4/tokenized/dolma/cc-74b017 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_head-{0000..0274}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0000..0238}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_middle-{0240..0379}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0000..0152}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_en_tail-{0154..0444}.json.gz + validation_urls: [] + dolma/cc-news: + cache_dir: gs://marin-eu-west4/tokenized/dolma/cc-news-625d3e + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_head-{0000..0004}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_middle-{0000..0002}.json.gz + - gs://marin-us-central2/raw/dolma/v1.7/cc_news_tail-0000.json.gz + validation_urls: [] + dolma/falcon: + cache_dir: gs://marin-eu-west4/tokenized/dolma/falcon-da8fd0 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/falcon-{0000..0499}.json.gz + validation_urls: [] + dolma/flan: + cache_dir: gs://marin-eu-west4/tokenized/dolma/flan-a99cb2 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/tulu_flan-{0000..0065}.json.gz + validation_urls: [] + dolma/gutenberg: + cache_dir: gs://marin-eu-west4/tokenized/dolma/gutenberg-f9eb99 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/books-{0000..0002}.json.gz + validation_urls: [] + dolma/megawika: + cache_dir: gs://marin-eu-west4/tokenized/dolma/megawika-34abf2 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/megawika-{0000..0261}.json.gz + validation_urls: [] + dolma/open-web-math: + cache_dir: gs://marin-eu-west4/tokenized/dolma/open-web-math-79823d + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/open-web-math-train-{0000..0012}.json.gz + validation_urls: [] + dolma/pes2o: + cache_dir: gs://marin-eu-west4/tokenized/dolma/pes2o-538363 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/pes2o-{0000..0025}.json.gz + validation_urls: [] + dolma/reddit: + cache_dir: gs://marin-eu-west4/tokenized/dolma/reddit-62a64a + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/reddit-{0000..0077}.json.gz + validation_urls: [] + 
dolma/stackexchange: + cache_dir: gs://marin-eu-west4/tokenized/dolma/stackexchange-adfc49 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/stackexchange-{0000..0025}.json.gz + validation_urls: [] + dolma/starcoder: + cache_dir: gs://marin-eu-west4/tokenized/dolma/starcoder-8b6089 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/starcoder-{0000..0048}.json.gz + validation_urls: [] + dolma/wiki: + cache_dir: gs://marin-eu-west4/tokenized/dolma/wiki-212315 + id: null + name: null + plaintext: false + stream: true + tags: [] + text_key: text + train_urls: + - gs://marin-us-central2/raw/dolma/v1.7/wiki-{0000..0001}.json.gz + validation_urls: [] +enforce_eos: true +ignore_token_id: null +mixture_block_size: 2048 +shuffle: true +stop_strategy: restart +tokenizer: nvidia/Llama-3.1-Nemotron-70B-Instruct-HF +train_weights: + dolma/algebraic-stack: 12.6 + dolma/arxiv: 28.0 + dolma/c4: 124.95 + dolma/cc: 597.75 + dolma/cc-news: 14.3 + dolma/falcon: 456.4 + dolma/flan: 16.5 + dolma/gutenberg: 5.3 + dolma/megawika: 4.6 + dolma/open-web-math: 12.6 + dolma/pes2o: 57.2 + dolma/reddit: 79.9 + dolma/stackexchange: 19.6 + dolma/starcoder: 263.8 + dolma/wiki: 7.4 +vocab_size: null \ No newline at end of file diff --git a/config/llama3.1_tulu3_sft.yaml b/config/llama3.1_tulu3_sft.yaml index 499fe2091..19c7c5184 100644 --- a/config/llama3.1_tulu3_sft.yaml +++ b/config/llama3.1_tulu3_sft.yaml @@ -5,7 +5,7 @@ supervised_data: # cache_dir before trying sequence packing cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer-7b19dc" #cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer_retrypack-bca8bd/" - + max_seq_len: 4096 tokenizer: "meta-llama/Meta-Llama-3.1-8B" model: # 8B llama3 class model @@ -21,7 +21,7 @@ model: # 8B llama3 class model use_bias: false use_layer_norm_weight: true initializer_range: 0.02 - rope: + rope: type: "llama3" trainer: diff --git a/config/llama3_openthoughts_sft.yaml b/config/llama3_openthoughts_sft.yaml index bff3f1435..1bd94f879 100644 --- a/config/llama3_openthoughts_sft.yaml +++ b/config/llama3_openthoughts_sft.yaml @@ -5,7 +5,7 @@ supervised_data: # cache_dir before trying sequence packing cache_dir: "gs://marin-us-central2/tokenized/openthoughts_llama3_tokenizer-9edd80" #cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer_retrypack-bca8bd/" - + max_seq_len: 4096 tokenizer: "meta-llama/Meta-Llama-3.1-8B" model: # 8B llama3 class model @@ -21,7 +21,7 @@ model: # 8B llama3 class model use_bias: false use_layer_norm_weight: true initializer_range: 0.02 - rope: + rope: type: "llama3" trainer: seed: 1 diff --git a/config/llama3_sft_hf_ckpt.yaml b/config/llama3_sft_hf_ckpt.yaml index 950ca964b..fa30ba228 100644 --- a/config/llama3_sft_hf_ckpt.yaml +++ b/config/llama3_sft_hf_ckpt.yaml @@ -12,8 +12,8 @@ model: use_bias: false use_layer_norm_weight: true initializer_range: 0.02 - rope: + rope: type: "llama3" # need to set this! 
-tokenizer: "meta-llama/Meta-Llama-3.1-8B" \ No newline at end of file +tokenizer: "meta-llama/Meta-Llama-3.1-8B" diff --git a/config/llama_7b_with_olmo_config_euwest4.yaml b/config/llama_7b_with_olmo_config_euwest4.yaml new file mode 100644 index 000000000..1f83e293b --- /dev/null +++ b/config/llama_7b_with_olmo_config_euwest4.yaml @@ -0,0 +1,39 @@ +data: !include data/dolma_llama_euwest.yaml +model: # 7B class model + type: llama + seq_len: 2048 + hidden_dim: 4096 + intermediate_dim: 11008 + num_layers: 32 + num_heads: 32 + num_kv_heads: 32 + use_flash_attention: True + flash_attention_block_size: 1024 +trainer: + tracker: + type: wandb + project: "marin" + tags: ["dolma", "olmo", "llama"] + checkpointer: + keep: + - every: 1 + until: 2 + - every: 5 + until: 30 + - every: 50 + until: 1000 + - every: 1000 + until: 40000 +python -m levanter.main.export_lm_to_hf --checkpoint_path "gs://marin-ckpt-eu-w4/checkpoints/olmo7b_seed0_datafix0/fgtbtvho/step-25" --output_dir "gs://marin-ckpt-eu-w4/checkpoints/olmo7b_seed0_datafix0/hf_100M_tokens" + mp: p=f32,c=bfloat16 + train_batch_size: 2048 + num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 + steps_per_eval: 1000 + tensor_parallel_axes: ["mlp", "heads"] + fsdp_axis: "embed" + batch_axis: "batch" +optimizer: + learning_rate: 4E-4 + weight_decay: 0.1 + min_lr_ratio: 0.1 + warmup: 0.01 diff --git a/config/debug_pack_sft.yaml b/config/llama_7b_with_olmo_config_uswest4.yaml similarity index 56% rename from config/debug_pack_sft.yaml rename to config/llama_7b_with_olmo_config_uswest4.yaml index d5cb11503..76e0ff1df 100644 --- a/config/debug_pack_sft.yaml +++ b/config/llama_7b_with_olmo_config_uswest4.yaml @@ -1,10 +1,4 @@ -dataset_type: chat_jsonl -chat_train_urls: - - "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz" -supervised_data: - cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer_retrypack-bca8bd/" - -tokenizer: "meta-llama/Meta-Llama-3.1-8B" +data: !include data/dolma_llama.yaml model: # 7B class model type: llama seq_len: 2048 @@ -14,17 +8,25 @@ model: # 7B class model num_heads: 32 num_kv_heads: 32 use_flash_attention: True - flash_attention_block_size: 512 - use_bias: false - use_layer_norm_weight: false + flash_attention_block_size: 1024 trainer: tracker: type: wandb project: "marin" tags: ["dolma", "olmo", "llama"] + checkpointer: + keep: + - every: 1 + until: 2 + - every: 5 + until: 30 + - every: 50 + until: 1000 + - every: 1000 + until: 40000 mp: p=f32,c=bfloat16 - train_batch_size: 256 + train_batch_size: 2048 num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 steps_per_eval: 1000 tensor_parallel_axes: ["mlp", "heads"] @@ -34,6 +36,4 @@ optimizer: learning_rate: 4E-4 weight_decay: 0.1 min_lr_ratio: 0.1 - warmup: 5000 - -epoch: 0 + warmup: 0.01 diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 4393fc9c8..12e988d22 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -18,7 +18,6 @@ from levanter.models.attention import AttentionMask from levanter.models.lm_model import LmExample from levanter.utils.jax_utils import local_cpu_mesh -import time # cf https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/data_generators/generator_utils.py#L623 @@ -112,59 +111,30 @@ def pack_prompt_completions( """ Packs a list of prompt completions into LmExamples using the SequencePacker """ - - # start_packer = time.perf_counter() packers = 
[SequencePacker(Pos, max_segments_per_example, pad_token)] - # print(f"time to create packer is {time.perf_counter() - start_packer:.6f} seconds", flush=True) - # put timer in here for sequence in sequences: - # start_time = time.perf_counter() - # put timer here, keep in mind the loop is an iterator so we want to - # separate the time to get examples, and how long it takes to pack loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 loss_mask[-1] = 0 assert np.any(loss_mask) - # time how long to pack, and subtract - #start_for_pack_yield = time.perf_counter() for packer in packers: if packer.can_pack(sequence.ids): - # add_example = time.perf_counter() packer.add_example(sequence.ids, loss_mask, sequence.segment_id) - # add_example_end = time.perf_counter() - # time_to_add_example = add_example_end - add_example - # print(f" time to add example to segment packer { time_to_add_example:.6f} seconds", flush=True) - # end_iter_plus_example = time.perf_counter() - # time_to_reach_example_total = end_iter_plus_example - start_time - # print(f"time to get sequence, so time to get here minus time to add example {time_to_reach_example_total - time_to_add_example:.6f} seconds", flush=True) if packer.num_segments == max_segments_per_example: - ot = packer.pack() - # end_time = time.perf_counter() - # time_to_yield = end_time - start_for_pack_yield - # print(f"MAX SEG total time to for loop until we yielded a packed example {time_to_yield}", flush=True) - # print(f"MAX SEG total time for iterator {start_time - time_to_yield}", flush=True) - yield ot + yield packer.pack() packers.remove(packer) break else: # no packer could fit the example, create a new one - #start_new_packer = time.perf_counter() packer = SequencePacker(Pos, max_segments_per_example, pad_token) packer.add_example(sequence.ids, loss_mask, sequence.segment_id) packers.append(packer) - #print(f"time to create new packer is {time.perf_counter() - start_new_packer:.6f}", flush=True) while len(packers) >= max_buffered_examples: - #start_return_packed_example_full_buffer = time.perf_counter() - max_example = packers.pop(0).pack() - #print(f"time to create lm example when max buffered examples is {time.perf_counter() - start_return_packed_example_full_buffer:.6f}", flush=True) - yield max_example + yield packers.pop(0).pack() for packer in packers: - #start_return_packed_example = time.perf_counter() - example = packer.pack() - #print(f"time to create lm example from packer is {time.perf_counter() - start_return_packed_example:.6f}", flush=True) - yield example + yield packer.pack() def per_segment_loss( diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 066912cee..b48d558de 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -975,6 +975,38 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase, should_ap } +def mk_cached_sft_dataset( + config: ChatUrlDataSourceConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis +) -> AsyncDataset[dict]: + """Creates a dataset from JSONL files containing chat format data for SFT.""" + source = config.get_shard_source("train") + if source is None: + raise ValueError("No training data source found") + + # Set up example structure matching supervised case + output_exemplar = {"input_ids": np.zeros((0,), dtype=np.int32), "sources_len": np.zeros((0,), dtype=np.int32)} + + input_ids = tokenizer("hi there")["input_ids"] + should_append_eos = input_ids[-1] != tokenizer.eos_token_id + logger.info(f"Manual EOS Needed: 
{should_append_eos}") + + # Process the dataset + dataset = source.map_batches( + lambda ex: preprocess_chat_example(ex, tokenizer, should_append_eos), + batch_size=128, + num_cpus=num_cpus_used_by_tokenizer(tokenizer), + output_exemplar=output_exemplar, + ) + + # Cache the processed data + cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) + + # Ensure padding token is set (needed by _prepare_supervised_example) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + return cached_dataset + + def mk_chat_sft_dataset( config: ChatUrlDataSourceConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis ) -> AsyncDataset[LmExample]: @@ -1004,9 +1036,8 @@ def mk_chat_sft_dataset( # Ensure padding token is set (needed by _prepare_supervised_example) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - return cached_dataset # Reuse the supervised prepare function directly - # return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) + return cached_dataset.map_batches(lambda ex: _prepare_supervised_examples(ex, tokenizer, Pos)) @dataclass diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index 96c3fef07..1154b0c8d 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -1,44 +1,43 @@ +import asyncio import dataclasses import logging import os from dataclasses import dataclass, field from enum import Enum -from typing import List, Optional, Union, Iterator +from typing import Iterator, List, Optional, Union -import jax.random as jrandom import jax.numpy as jnp +import jax.random as jrandom import transformers +from optax.tree_utils import tree_zeros_like import haliax as hax from haliax import Axis from haliax.partitioning import round_axis_for_partitioning -from optax.tree_utils import tree_zeros_like -import time import levanter from levanter import callbacks from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback from levanter.data import PermutationDataset, batched +from levanter.data.dataset import AsyncDataset +from levanter.data.loader import stack_tree +from levanter.data.packing import PromptCompletion, pack_prompt_completions from levanter.data.text import ( ChatUrlDataSourceConfig, EpochDataset, SupervisedSourceConfig, - mk_chat_sft_dataset, + mk_cached_sft_dataset, mk_supervised_dataset, ) from levanter.models.attention import AttentionMask from levanter.models.llama import LlamaConfig -from levanter.models.lm_model import LmConfig, LmHeadModel, LmExample, compute_next_token_loss +from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel, compute_next_token_loss from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig -from levanter.data.loader import stack_tree -from levanter.data.packing import PromptCompletion, pack_prompt_completions from levanter.utils.background_iterable import BackgroundIterator from levanter.utils.hf_utils import HfTokenizer -from levanter.utils.jax_utils import broadcast_shard, use_cpu_device -from levanter.data.dataset import AsyncDataset +from levanter.utils.jax_utils import use_cpu_device -import asyncio logger = logging.getLogger(__name__) @@ -155,7 +154,7 @@ def train(config: SFTConfig): input_role=config.input_role, output_role=config.output_role, ) - train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos) + train_dataset = 
mk_cached_sft_dataset(chat_config, tokenizer, model_config.Pos) else: assert config.supervised_data is not None if isinstance(config.supervised_data, dict): @@ -214,16 +213,6 @@ def train(config: SFTConfig): # Get current step from trainer state current_step = int(state.step) - - - - # change iterate tokenized requests to take a dict rather than a list - # of where the first element is prompt ands econd is response - - # then pass into tierate tokenizer requests, go to pack requests - # and then you have the correct loader, just pass to trainer.train() - - # TODO figure out if there's a better heuristic for max segements to pack per example? logger.info("Creating prompt completion iterator") prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos) @@ -231,7 +220,7 @@ def train(config: SFTConfig): logger.info(f"Resuming training from step {current_step}") # Calculate how many examples to skip based on batch size examples_to_skip = current_step * trainer.config.train_batch_size - + # Skip through the iterator until we reach the right position for _ in range(examples_to_skip): try: @@ -251,9 +240,6 @@ def train(config: SFTConfig): logger.info("Creating data loader") packed_loader = BackgroundIterator(packed_iterator, max_capacity=1024) - # to be moved - #loader = trainer.data_loader(train_dataset, trainer.TrainBatch) - if config.hf_save_path is not None: # bit gross to reach this far into the config, but it's fine if config.trainer.checkpointer.append_run_id_to_base_path: @@ -268,31 +254,26 @@ def train(config: SFTConfig): trainer.train(state, packed_loader) + def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axis) -> Iterator[PromptCompletion]: """ Creates an iterator that yields PromptCompletion objects from a cached dataset. - + Args: cached_dataset: The AsyncDataset containing preprocessed examples Pos: The position axis defining maximum sequence length - + Returns: An iterator yielding PromptCompletion objects """ # AsyncDataset already has a current_len method that returns current length or None length = asyncio.run(cached_dataset.async_len()) - + if length is None: raise ValueError("Dataset length cannot be None") - - # TODO play around with batch size + for batch_indicies in batched(range(length), 4096): - # put timer here - # start_time = time.perf_counter() examples = asyncio.run(cached_dataset.get_batch(batch_indicies)) - # end_time = time.perf_counter() - # elapsed_time = end_time - start_time - # print(f"Elapsed time for get batches: {elapsed_time:.6f} seconds", flush=True) for i in range(len(examples)): example = examples[i] @@ -302,20 +283,17 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi ids = example["input_ids"].tolist() if len(ids) > Pos.size: - ids = ids[:Pos.size] + ids = ids[: Pos.size] if len(ids) <= sources_len: continue try: - yield PromptCompletion( - ids=ids, - prompt_length=sources_len, - segment_id=batch_indicies[i] - ) + yield PromptCompletion(ids=ids, prompt_length=sources_len, segment_id=batch_indicies[i]) except ValueError: continue + def _pack_requests( prompt_completion_iterator: Iterator[PromptCompletion], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int ) -> Iterator[LmExample]: @@ -328,50 +306,45 @@ def _pack_requests( max_buffered_examples=16, ) + """ Helper function to create a dummy instance with the same shape as the batch. 
When we reach the end of the dataset but we want a full batch, will give a batch of zeros with -1 segment mask so it doesn't affect loss """ + + def _make_dummy_instance(batch, Pos): - dummy_instance: LmExample = tree_zeros_like(batch[0]) - dummy_segment_mask = hax.full(Pos, -1, dtype=jnp.int32) - dummy_attn = AttentionMask.causal().with_segment_ids(dummy_segment_mask) - dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) - return dummy_instance + dummy_instance: LmExample = tree_zeros_like(batch[0]) + dummy_segment_mask = hax.full(Pos, -1, dtype=jnp.int32) + dummy_attn = AttentionMask.causal().with_segment_ids(dummy_segment_mask) + dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) + return dummy_instance + def stack_batches(example_iterator, Pos, TrainBatch): - """ - Stack examples from an iterator into a batch. - - Args: - TrainBatch: The batch axis. - Pos: The position axis. - example_iterator: An iterator of examples. - - Returns: - A batch of examples. - """ - # add timer here as well and profile - with use_cpu_device(): - batch_count = 0 - for batch in batched(example_iterator, TrainBatch.size): - batch_count += 1 - #start_time_loop = time.perf_counter() - if len(batch) < TrainBatch.size: - dummy_instance = _make_dummy_instance(batch, Pos) - batch.extend([dummy_instance] * (TrainBatch.size - len(batch))) - # # Start timing before calling stack_tree - # start_time = time.perf_counter() - result = stack_tree(TrainBatch, batch) # Capture the result - # stack_time = time.perf_counter() - start_time # Calculate elapsed time - - # print(f"Stack tree execution time: {stack_time:.6f} seconds", flush=True) - yield result # Yield the computed result - # end_time_loop = time.perf_counter() - # loop_time = end_time_loop - start_time_loop - # print(f"Loop takes {loop_time}") - # print(f"Iterator time is {loop_time - stack_time}") + """ + Stack examples from an iterator into a batch. + + Args: + TrainBatch: The batch axis. + Pos: The position axis. + example_iterator: An iterator of examples. + + Returns: + A batch of examples. 
+ """ + # add timer here as well and profile + with use_cpu_device(): + batch_count = 0 + for batch in batched(example_iterator, TrainBatch.size): + batch_count += 1 + if len(batch) < TrainBatch.size: + dummy_instance = _make_dummy_instance(batch, Pos) + batch.extend([dummy_instance] * (TrainBatch.size - len(batch))) + result = stack_tree(TrainBatch, batch) # Capture the result + yield result # Yield the computed result + def add_special_tokens(tokenizer, use_unk_instead_of_adding=False): special_tokens_dict = dict() From 1fdeec537ae790f7a94360e58c0b943b439a6e69 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 5 Feb 2025 14:02:42 -0800 Subject: [PATCH 13/16] update configs --- config/data/dolma_llama.yaml | 2 +- config/data/dolma_llama_euwest.yaml | 2 +- config/llama3.1_tulu3_sft.yaml | 13 ++++++++++--- config/llama_7b_with_olmo_config_euwest4.yaml | 1 - 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/config/data/dolma_llama.yaml b/config/data/dolma_llama.yaml index 681c5f22b..3fe11e49f 100644 --- a/config/data/dolma_llama.yaml +++ b/config/data/dolma_llama.yaml @@ -199,4 +199,4 @@ train_weights: dolma/stackexchange: 19.6 dolma/starcoder: 263.8 dolma/wiki: 7.4 -vocab_size: null \ No newline at end of file +vocab_size: null diff --git a/config/data/dolma_llama_euwest.yaml b/config/data/dolma_llama_euwest.yaml index 32876031b..897b315c0 100644 --- a/config/data/dolma_llama_euwest.yaml +++ b/config/data/dolma_llama_euwest.yaml @@ -197,4 +197,4 @@ train_weights: dolma/stackexchange: 19.6 dolma/starcoder: 263.8 dolma/wiki: 7.4 -vocab_size: null \ No newline at end of file +vocab_size: null diff --git a/config/llama3.1_tulu3_sft.yaml b/config/llama3.1_tulu3_sft.yaml index 19c7c5184..588b48412 100644 --- a/config/llama3.1_tulu3_sft.yaml +++ b/config/llama3.1_tulu3_sft.yaml @@ -30,16 +30,21 @@ trainer: type: wandb project: "marin" tags: ["dolma", "olmo", "llama"] + wandb: + project: "marin" + name: "llama3.1_tulu_sft_packed" mp: p=f32,c=bfloat16 # same as 606 sft in marin train_batch_size: 128 # number of steps until we hit stop iteration - num_train_steps: 1791 # 3,000,000,000,000 / 4,000,000 = 750,000 + num_train_steps: 2574 # 3,000,000,000,000 / 4,000,000 = 750,000 steps_per_eval: 1000 tensor_parallel_axes: ["mlp", "heads"] fsdp_axis: "embed" batch_axis: "batch" + checkpointer: + base_path: "gs://levanter-checkpoints/marin/llama_3.1_tulusft/" optimizer: learning_rate: 5e-6 weight_decay: 0.0 @@ -47,7 +52,9 @@ optimizer: lr_schedule: "linear" warmup: 0.03 -hf_save_steps: 1790 -hf_save_path: "gs://levanter-checkpoints/marin/llama_3.1_tulusft/" +hf_save_steps: 500 +hf_save_path: "gs://levanter-checkpoints/marin/llama_3.1_tulusft/hf/" +initialize_from_hf: True +model_name_or_path: "meta-llama/Llama-3.1-8B" epoch: 0 diff --git a/config/llama_7b_with_olmo_config_euwest4.yaml b/config/llama_7b_with_olmo_config_euwest4.yaml index 1f83e293b..876618fd6 100644 --- a/config/llama_7b_with_olmo_config_euwest4.yaml +++ b/config/llama_7b_with_olmo_config_euwest4.yaml @@ -24,7 +24,6 @@ trainer: until: 1000 - every: 1000 until: 40000 -python -m levanter.main.export_lm_to_hf --checkpoint_path "gs://marin-ckpt-eu-w4/checkpoints/olmo7b_seed0_datafix0/fgtbtvho/step-25" --output_dir "gs://marin-ckpt-eu-w4/checkpoints/olmo7b_seed0_datafix0/hf_100M_tokens" mp: p=f32,c=bfloat16 train_batch_size: 2048 num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000 From 361b68004c2c5e14462a8ead9322451334c33a91 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 5 Feb 2025 14:15:37 -0800 Subject: 
[PATCH 14/16] david's suggestions --- src/levanter/data/packing.py | 3 +++ src/levanter/data/text.py | 4 ---- src/levanter/main/sft.py | 35 ++++++++++++++++------------------- 3 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/levanter/data/packing.py b/src/levanter/data/packing.py index 12e988d22..a049de56c 100644 --- a/src/levanter/data/packing.py +++ b/src/levanter/data/packing.py @@ -111,7 +111,9 @@ def pack_prompt_completions( """ Packs a list of prompt completions into LmExamples using the SequencePacker """ + packers = [SequencePacker(Pos, max_segments_per_example, pad_token)] + for sequence in sequences: loss_mask = np.arange(len(sequence.ids)) >= sequence.prompt_length - 1 loss_mask[-1] = 0 @@ -120,6 +122,7 @@ def pack_prompt_completions( for packer in packers: if packer.can_pack(sequence.ids): packer.add_example(sequence.ids, loss_mask, sequence.segment_id) + if packer.num_segments == max_segments_per_example: yield packer.pack() packers.remove(packer) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index c35bf2f52..90a4027a7 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -1000,10 +1000,6 @@ def mk_cached_sft_dataset( # Cache the processed data cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True) - - # Ensure padding token is set (needed by _prepare_supervised_example) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token return cached_dataset diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index 1154b0c8d..9503cb046 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -35,7 +35,6 @@ from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig from levanter.utils.background_iterable import BackgroundIterator -from levanter.utils.hf_utils import HfTokenizer from levanter.utils.jax_utils import use_cpu_device @@ -233,7 +232,13 @@ def train(config: SFTConfig): logger.info("Starting SFT from scratch") logger.info("Packing prompt completions") - packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4) + packed_iterator = pack_prompt_completions( + Pos, + prompt_completion_iterator, + max_segments_per_example=4, + pad_token=tokenizer.pad_token_id, + max_buffered_examples=16, + ) logger.info("Stacking batches to train batch") packed_iterator = stack_batches(example_iterator=packed_iterator, Pos=Pos, TrainBatch=trainer.TrainBatch) # TODO what's a good number for max_capacity? @@ -272,8 +277,8 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi if length is None: raise ValueError("Dataset length cannot be None") - for batch_indicies in batched(range(length), 4096): - examples = asyncio.run(cached_dataset.get_batch(batch_indicies)) + for indicies in batched(range(length), 4096): + examples = asyncio.run(cached_dataset.get_batch(indicies)) for i in range(len(examples)): example = examples[i] @@ -289,24 +294,16 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi continue try: - yield PromptCompletion(ids=ids, prompt_length=sources_len, segment_id=batch_indicies[i]) - except ValueError: + yield PromptCompletion(ids=ids, prompt_length=sources_len, segment_id=indicies[i]) + except ValueError as e: + # Likely error: PromptCompletion may raise a ValueError if the token list is empty or if its length is not greater than the prompt_length. 
+ logger.error( + f"Error creating PromptCompletion (ids length: {len(ids)}, sources_len: {sources_len}, segment id:" + f" {indicies[i]}): {e}" + ) continue -def _pack_requests( - prompt_completion_iterator: Iterator[PromptCompletion], tokenizer: HfTokenizer, Pos: hax.Axis, max_pack_size: int -) -> Iterator[LmExample]: - # TODO: use a better packing algorithm? - yield from pack_prompt_completions( - Pos, - prompt_completion_iterator, - max_segments_per_example=max_pack_size, - pad_token=tokenizer.pad_token_id, - max_buffered_examples=16, - ) - - """ Helper function to create a dummy instance with the same shape as the batch. When we reach the end of the dataset but we want a full batch, From 1f1efa4765668e9d39640e152fb837fd5f2df9ef Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 5 Feb 2025 14:46:35 -0800 Subject: [PATCH 15/16] david's suggested fix --- src/levanter/data/loader.py | 41 +++++++++++++++++++++++++++++- src/levanter/eval_harness.py | 31 ++--------------------- src/levanter/main/sft.py | 49 +++--------------------------------- 3 files changed, 45 insertions(+), 76 deletions(-) diff --git a/src/levanter/data/loader.py b/src/levanter/data/loader.py index 5db9b96b9..4930ce41c 100644 --- a/src/levanter/data/loader.py +++ b/src/levanter/data/loader.py @@ -1,3 +1,4 @@ +import dataclasses import logging import time from collections import defaultdict @@ -11,6 +12,7 @@ from jax.experimental import multihost_utils from jax.sharding import Mesh, PartitionSpec from jaxtyping import PyTree +from optax.tree_utils import tree_zeros_like import haliax as hax from haliax import is_named_array @@ -19,9 +21,11 @@ from levanter.data.dataset import AsyncDataset from levanter.data.utils import batched +from levanter.models.attention import AttentionMask +from levanter.models.lm_model import LmExample from levanter.shapes import NamedShapeSpec, ShapeSpec, to_raw_shape from levanter.utils.background_iterable import BackgroundIterator -from levanter.utils.jax_utils import local_cpu_mesh +from levanter.utils.jax_utils import local_cpu_mesh, use_cpu_device from levanter.utils.thread_utils import AsyncIteratorWrapper, blocking_wait @@ -247,6 +251,41 @@ def _pspec_for(self, shape_spec: ShapeSpec | NamedShapeSpec) -> PartitionSpec: return hax.partitioning.pspec_for_axis(shape_spec.shape, self.dl.axis_resources) # type: ignore +def _make_dummy_instance(batch, Pos): + """ + Creates a dummy instance matching the shape of the provided batch. + If the dataset is exhausted and a full batch is needed, this function returns a dummy instance + with all elements set to zero and a segment mask filled with -1. This design ensures that the dummy + instance does not contribute to the loss during training. + """ + dummy_instance: LmExample = tree_zeros_like(batch[0]) + dummy_segment_mask = hax.full(Pos, -1, dtype=jnp.int32) + dummy_attn = AttentionMask.causal().with_segment_ids(dummy_segment_mask) + dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) + return dummy_instance + + +def stack_batches(example_iterator, Pos, Batch): + """ + Stack examples from an iterator into a batch. + + Args: + Batch: The batch axis. + Pos: The position axis. + example_iterator: An iterator of examples. + + Returns: + A batch of examples. 
+ """ + # add timer here as well and profile + with use_cpu_device(): + for batch in batched(example_iterator, Batch.size): + if len(batch) < Batch.size: + dummy_instance = _make_dummy_instance(batch, Pos) + batch.extend([dummy_instance] * (Batch.size - len(batch))) + yield stack_tree(Batch, batch) + + def _batchified_shape(Batch, leaf: hax.NamedArray | Array) -> ShapeSpec | NamedShapeSpec: if is_named_array(leaf): return NamedShapeSpec((Batch,) + leaf.axes, leaf.dtype) diff --git a/src/levanter/eval_harness.py b/src/levanter/eval_harness.py index a4aa19f4d..5fd2da271 100644 --- a/src/levanter/eval_harness.py +++ b/src/levanter/eval_harness.py @@ -31,16 +31,13 @@ import jmp import numpy as np from jax.sharding import PartitionSpec -from optax.tree_utils import tree_zeros_like import haliax from haliax import NamedArray import levanter.tracker from levanter.compat.hf_checkpoints import HFCheckpointConverter, load_tokenizer -from levanter.data.loader import stack_tree from levanter.data.packing import PromptCompletion, pack_prompt_completions, per_segment_correct, per_segment_loss -from levanter.models.attention import AttentionMask from levanter.models.gpt2 import Gpt2Config from levanter.models.loss import next_token_loss from levanter.utils.background_iterable import BackgroundIterator @@ -65,6 +62,7 @@ import levanter.config from levanter.checkpoint import load_checkpoint from levanter.data import batched +from levanter.data.loader import stack_batches from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel from levanter.trainer import StepInfo, TrainerConfig from levanter.utils.jax_utils import broadcast_shard, use_cpu_device @@ -256,7 +254,7 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: self.tokenizer.pad_token_id = self.tokenizer.eos_token_id packed_iterator = _pack_requests(requests, self.tokenizer, self.EvalPos, self.leader.max_packed_segments) - packed_iterator = self.stack_batches(packed_iterator, self.EvalBatch) + packed_iterator = stack_batches(packed_iterator, self.EvalPos, self.EvalBatch) packed_iterator = BackgroundIterator(packed_iterator, max_capacity=1024) result_probs = np.zeros(len(requests)) @@ -309,31 +307,6 @@ def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]: return result - def stack_batches(self, example_iterator, EvalBatch): - """ - Stack examples from an iterator into a batch. - - Args: - EvalBatch: The batch axis. - example_iterator: An iterator of examples. - - Returns: - A batch of examples. 
- """ - with use_cpu_device(): - for batch in batched(example_iterator, EvalBatch.size): - if len(batch) < EvalBatch.size: - dummy_instance = self._make_dummy_instance(batch) - batch.extend([dummy_instance] * (EvalBatch.size - len(batch))) - yield stack_tree(EvalBatch, batch) - - def _make_dummy_instance(self, batch): - dummy_instance: LmExample = tree_zeros_like(batch[0]) - dummy_segment_mask = hax.full(self.EvalPos, -1, dtype=jnp.int32) - dummy_attn = AttentionMask.causal().with_segment_ids(dummy_segment_mask) - dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) - return dummy_instance - def loglikelihood_rolling(self, requests) -> List[Tuple[float]]: raise NotImplementedError() diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py index 9503cb046..cb5a12926 100644 --- a/src/levanter/main/sft.py +++ b/src/levanter/main/sft.py @@ -6,10 +6,8 @@ from enum import Enum from typing import Iterator, List, Optional, Union -import jax.numpy as jnp import jax.random as jrandom import transformers -from optax.tree_utils import tree_zeros_like import haliax as hax from haliax import Axis @@ -20,7 +18,7 @@ from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback from levanter.data import PermutationDataset, batched from levanter.data.dataset import AsyncDataset -from levanter.data.loader import stack_tree +from levanter.data.loader import stack_batches from levanter.data.packing import PromptCompletion, pack_prompt_completions from levanter.data.text import ( ChatUrlDataSourceConfig, @@ -29,13 +27,11 @@ mk_cached_sft_dataset, mk_supervised_dataset, ) -from levanter.models.attention import AttentionMask from levanter.models.llama import LlamaConfig -from levanter.models.lm_model import LmConfig, LmExample, LmHeadModel, compute_next_token_loss +from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig from levanter.utils.background_iterable import BackgroundIterator -from levanter.utils.jax_utils import use_cpu_device logger = logging.getLogger(__name__) @@ -240,7 +236,7 @@ def train(config: SFTConfig): max_buffered_examples=16, ) logger.info("Stacking batches to train batch") - packed_iterator = stack_batches(example_iterator=packed_iterator, Pos=Pos, TrainBatch=trainer.TrainBatch) + packed_iterator = stack_batches(example_iterator=packed_iterator, Pos=Pos, Batch=trainer.TrainBatch) # TODO what's a good number for max_capacity? logger.info("Creating data loader") packed_loader = BackgroundIterator(packed_iterator, max_capacity=1024) @@ -304,45 +300,6 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi continue -""" -Helper function to create a dummy instance with the same shape as the batch. -When we reach the end of the dataset but we want a full batch, -will give a batch of zeros with -1 segment mask so it doesn't affect loss -""" - - -def _make_dummy_instance(batch, Pos): - dummy_instance: LmExample = tree_zeros_like(batch[0]) - dummy_segment_mask = hax.full(Pos, -1, dtype=jnp.int32) - dummy_attn = AttentionMask.causal().with_segment_ids(dummy_segment_mask) - dummy_instance = dataclasses.replace(dummy_instance, attn_mask=dummy_attn) - return dummy_instance - - -def stack_batches(example_iterator, Pos, TrainBatch): - """ - Stack examples from an iterator into a batch. - - Args: - TrainBatch: The batch axis. - Pos: The position axis. 
- example_iterator: An iterator of examples. - - Returns: - A batch of examples. - """ - # add timer here as well and profile - with use_cpu_device(): - batch_count = 0 - for batch in batched(example_iterator, TrainBatch.size): - batch_count += 1 - if len(batch) < TrainBatch.size: - dummy_instance = _make_dummy_instance(batch, Pos) - batch.extend([dummy_instance] * (TrainBatch.size - len(batch))) - result = stack_tree(TrainBatch, batch) # Capture the result - yield result # Yield the computed result - - def add_special_tokens(tokenizer, use_unk_instead_of_adding=False): special_tokens_dict = dict() if use_unk_instead_of_adding: From 683d39c9f14cb80dfa58a44e09fbc10b8c462081 Mon Sep 17 00:00:00 2001 From: Ahmed Ahmed Date: Wed, 5 Feb 2025 16:46:45 -0800 Subject: [PATCH 16/16] fix configs --- config/data/dolma_llama.yaml | 2 -- config/{llama3_sft_hf_ckpt.yaml => sft_hf_llama3_ckpt.yaml} | 0 config/{llama3.1_tulu3_sft.yaml => sft_llama3.1_tulu3.yaml} | 0 ...lama3_openthoughts_sft.yaml => sft_llama3_openthoughts.yaml} | 0 4 files changed, 2 deletions(-) rename config/{llama3_sft_hf_ckpt.yaml => sft_hf_llama3_ckpt.yaml} (100%) rename config/{llama3.1_tulu3_sft.yaml => sft_llama3.1_tulu3.yaml} (100%) rename config/{llama3_openthoughts_sft.yaml => sft_llama3_openthoughts.yaml} (100%) diff --git a/config/data/dolma_llama.yaml b/config/data/dolma_llama.yaml index 3fe11e49f..fc93bc830 100644 --- a/config/data/dolma_llama.yaml +++ b/config/data/dolma_llama.yaml @@ -2,8 +2,6 @@ cache_dir: null cache_options: batch_size: 128 num_shard_groups: 128 - prefetch_per_group: 4 - shard_order_randomization_key: 0 target_size_per_flush: 512MB configs: dolma/algebraic-stack: diff --git a/config/llama3_sft_hf_ckpt.yaml b/config/sft_hf_llama3_ckpt.yaml similarity index 100% rename from config/llama3_sft_hf_ckpt.yaml rename to config/sft_hf_llama3_ckpt.yaml diff --git a/config/llama3.1_tulu3_sft.yaml b/config/sft_llama3.1_tulu3.yaml similarity index 100% rename from config/llama3.1_tulu3_sft.yaml rename to config/sft_llama3.1_tulu3.yaml diff --git a/config/llama3_openthoughts_sft.yaml b/config/sft_llama3_openthoughts.yaml similarity index 100% rename from config/llama3_openthoughts_sft.yaml rename to config/sft_llama3_openthoughts.yaml
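
Taken together, the series replaces the generic trainer.data_loader path with an explicitly packed SFT data path: prompt/completion pairs are streamed out of the tokenized cache, greedily packed into fixed-length examples, stacked into train batches on CPU (with loss-neutral dummy instances padding a final partial batch), and prefetched on a background thread. The sketch below shows that wiring after PATCH 15; it is a minimal illustration rather than code from the patches themselves — the wrapper name make_packed_loader is hypothetical, and the constants simply mirror the values used in sft.py above.

    # Minimal sketch (assumed wiring, not part of the patches) of the packed SFT data path.
    import haliax as hax

    from levanter.data.loader import stack_batches                    # added in PATCH 15
    from levanter.data.packing import pack_prompt_completions
    from levanter.main.sft import create_prompt_completion_iterator   # added earlier in the series
    from levanter.utils.background_iterable import BackgroundIterator


    def make_packed_loader(cached_dataset, tokenizer, Pos: hax.Axis, TrainBatch: hax.Axis):
        """Hypothetical helper: build a packed, batched, prefetched SFT loader."""
        # 1) Stream PromptCompletion objects out of the tokenized cache, skipping
        #    examples whose prompt alone would fill the sequence.
        prompt_completions = create_prompt_completion_iterator(cached_dataset, Pos)

        # 2) Greedily pack several prompt/completion segments into each Pos-length example.
        packed = pack_prompt_completions(
            Pos,
            prompt_completions,
            max_segments_per_example=4,  # value used in sft.py; a better heuristic is still an open TODO
            pad_token=tokenizer.pad_token_id,
            max_buffered_examples=16,
        )

        # 3) Stack packed LmExamples into TrainBatch-sized batches on CPU; a short final
        #    batch is padded with dummy instances whose segment ids are all -1 so they
        #    do not contribute to the loss.
        batches = stack_batches(example_iterator=packed, Pos=Pos, Batch=TrainBatch)

        # 4) Prefetch on a background thread so packing overlaps with training steps.
        return BackgroundIterator(batches, max_capacity=1024)

The iterator built this way is what sft.py ultimately hands to trainer.train(state, packed_loader); after PATCH 15, eval_harness.py reuses the same stack_batches helper for packed log-likelihood evaluation.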
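
A worked example of the loss-mask convention that pack_prompt_completions applies to each PromptCompletion may also be useful; the token ids below are made up purely for illustration, and the skip conditions in create_prompt_completion_iterator (prompt too long, or no completion tokens left after truncation) are what guarantee the mask is never empty.

    import numpy as np

    # Hypothetical packed segment: 3 prompt tokens followed by 3 completion tokens.
    ids = [10, 11, 12, 20, 21, 22]
    prompt_length = 3

    # Position i predicts token i + 1, so training starts one position before the
    # first completion token, and the final position (which has no next token
    # within this segment) is always masked out.
    loss_mask = np.arange(len(ids)) >= prompt_length - 1
    loss_mask[-1] = 0
    # loss_mask == [False, False, True, True, True, False]
    assert np.any(loss_mask)  # the same sanity check pack_prompt_completions performs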