From 478984afc4dfb456fe4fb4d01f74be9e9191aafb Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 20 Nov 2024 12:23:49 +0000 Subject: [PATCH] remove fp8 tp from llama's modeling code, fix no grad in param, remove rms norm due to illegal memory --- examples/config_tiny_fp8_llama.yaml | 109 ++++++++++++++++++ examples/config_tiny_llama.yaml | 52 ++++----- src/nanotron/models/llama.py | 46 ++++---- src/nanotron/nn/layer_norm.py | 55 ++++++++- src/nanotron/optim/gradient_accumulator.py | 9 +- .../parallel/tensor_parallel/functional.py | 13 +-- src/nanotron/scaling/parametrization.py | 1 + 7 files changed, 226 insertions(+), 59 deletions(-) create mode 100644 examples/config_tiny_fp8_llama.yaml diff --git a/examples/config_tiny_fp8_llama.yaml b/examples/config_tiny_fp8_llama.yaml new file mode 100644 index 00000000..58645e2d --- /dev/null +++ b/examples/config_tiny_fp8_llama.yaml @@ -0,0 +1,109 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 16 + initializer_range: 0.02 + intermediate_size: 64 + is_llama_config: true + max_position_embeddings: 256 + num_attention_heads: 4 + num_hidden_layers: 2 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 2 + expert_parallel_size: 1 + pp: 2 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 256 + train_steps: 15 + val_check_interval: -1 diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml index 0fd639d5..e1fac82e 100644 --- a/examples/config_tiny_llama.yaml +++ b/examples/config_tiny_llama.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 10 + checkpoint_interval: 10000 checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null @@ -10,25 +10,25 @@ data_stages: dataset_overwrite_cache: false dataset_processing_num_proc_per_process: 1 hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_or_datasets: roneneldan/TinyStories hf_dataset_splits: train text_column_name: text num_loading_workers: 1 seed: 42 name: Stable Training Stage start_training_step: 1 -- data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 - name: Annealing Phase - start_training_step: 10 +# - data: +# dataset: +# dataset_overwrite_cache: false +# dataset_processing_num_proc_per_process: 1 +# hf_dataset_config_name: null +# hf_dataset_or_datasets: stas/openwebtext-10k +# hf_dataset_splits: train +# text_column_name: text +# num_loading_workers: 1 +# seed: 42 +# name: Annealing Phase +# start_training_step: 10 general: benchmark_csv_path: null consumed_train_samples: null @@ -44,7 +44,7 @@ logging: log_level_replica: info model: ddp_bucket_cap_mb: 25 - dtype: float8 + dtype: bfloat16 init_method: std: 0.025 make_vocab_size_divisible_by: 1 @@ -52,13 +52,13 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 16 + hidden_size: 1024 initializer_range: 0.02 - intermediate_size: 64 + intermediate_size: 4096 is_llama_config: true - max_position_embeddings: 256 + max_position_embeddings: 1024 num_attention_heads: 4 - num_hidden_layers: 2 + num_hidden_layers: 6 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -66,7 +66,7 @@ model: rope_scaling: null tie_word_embeddings: true use_cache: true - vocab_size: 256 + vocab_size: 1024 optimizer: accumulate_grad_in_fp32: true clip_grad: 1.0 @@ -87,13 +87,13 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 2 + dp: 1 expert_parallel_size: 1 - pp: 2 + pp: 1 pp_engine: 1f1b tp: 2 - tp_linear_async_communication: true - tp_mode: REDUCE_SCATTER + tp_linear_async_communication: false + tp_mode: ALL_REDUCE profiler: null tokenizer: tokenizer_max_length: null @@ -104,6 +104,6 @@ tokens: limit_test_batches: 0 limit_val_batches: 0 micro_batch_size: 2 - sequence_length: 256 - train_steps: 15 + sequence_length: 1024 + train_steps: 1500 val_check_interval: -1 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index e6e74ecb..6ebbeb31 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -28,17 +28,16 @@ from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.nn.activations import ACT2FN -from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( - FP8TensorParallelColumnLinear, - FP8TensorParallelRowLinear, + TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelLinearMode, + TensorParallelRowLinear, ) from nanotron.random import RandomStates from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator @@ -222,7 +221,7 @@ def __init__( config.intermediate_size, # shape of up_linear ) # self.gate_up_proj = TensorParallelColumnLinear( - self.gate_up_proj = FP8TensorParallelColumnLinear( + self.gate_up_proj = TensorParallelColumnLinear( config.hidden_size, 2 * config.intermediate_size, pg=tp_pg, @@ -230,18 +229,16 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, - name=f"model.decoder.{layer_idx}.mlp.gate_up_proj", + # name=f"model.decoder.{layer_idx}.mlp.gate_up_proj", # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) - # self.down_proj = TensorParallelRowLinear( - self.down_proj = FP8TensorParallelRowLinear( + self.down_proj = TensorParallelRowLinear( config.intermediate_size, config.hidden_size, pg=tp_pg, mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, - name=f"model.decoder.{layer_idx}.mlp.down_proj", ) self.split_silu_mul = GLUActivation(config.hidden_act) @@ -386,8 +383,8 @@ def __init__( config.num_key_value_heads * self.d_qk, # shape of k config.num_key_value_heads * self.d_qk, # shape of v ) - # self.qkv_proj = TensorParallelColumnLinear( - self.qkv_proj = FP8TensorParallelColumnLinear( + # self.qkv_proj = FP8TensorParallelColumnLinear( + self.qkv_proj = TensorParallelColumnLinear( self.d_model, config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk, pg=tp_pg, @@ -395,7 +392,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, - name=f"model.decoder.{layer_idx}.attention.qkv_proj", + # name=f"model.decoder.{layer_idx}.attention.qkv_proj", # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. @@ -418,15 +415,14 @@ def __init__( dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved ) - # self.o_proj = TensorParallelRowLinear( - self.o_proj = FP8TensorParallelRowLinear( + self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, self.d_model, pg=tp_pg, mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - name=f"model.decoder.{layer_idx}.attention.o_proj", + # name=f"model.decoder.{layer_idx}.attention.o_proj", ) self.attention = CoreAttention( @@ -710,7 +706,10 @@ def __init__( layer_idx: int, ): super().__init__() - self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # NOTE: i got an illegal memory access was encountered when using TritonRMSNorm + # even downgrad flash_attn to 2.4.2 + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.attn = CausalSelfAttention( config=config, parallel_config=parallel_config, @@ -718,7 +717,8 @@ def __init__( layer_idx=layer_idx, ) - self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx) self.recompute_layer = parallel_config.recompute_layer @@ -856,8 +856,10 @@ def __init__( self.final_layer_norm = PipelineBlock( p2p=self.p2p, - module_builder=TritonRMSNorm, - module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + # module_builder=TritonRMSNorm, + # module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_builder=nn.LayerNorm, + module_kwargs={"normalized_shape": config.hidden_size, "eps": config.rms_norm_eps}, module_input_keys={"input"}, module_output_keys={"hidden_states"}, ) # TODO @@ -865,8 +867,8 @@ def __init__( self.lm_head = PipelineBlock( p2p=self.p2p, # Understand that this means that we return sharded logits that are going to need to be gathered - # module_builder=TensorParallelColumnLinear, - module_builder=FP8TensorParallelColumnLinear, + # module_builder=FP8TensorParallelColumnLinear, + module_builder=TensorParallelColumnLinear, module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, @@ -930,8 +932,8 @@ def get_block_compute_costs(self): LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + 3 * d_ff * model_config.hidden_size, # This is the last lm_head - # TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, - FP8TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + # FP8TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } return block_compute_costs diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py index ef3b4c50..7a8fcaad 100644 --- a/src/nanotron/nn/layer_norm.py +++ b/src/nanotron/nn/layer_norm.py @@ -39,9 +39,58 @@ def reset_parameters(self): def forward( self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False ): - from flash_attn.ops.triton.layer_norm import layer_norm_fn + # NOTE: fa=2.6.3 + # got the following errors: + # Traceback (most recent call last): + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl + # return self._call_impl(*args, **kwargs) + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl + # return forward_call(*args, **kwargs) + # File "/fsx/phuc/temp/fp8_for_nanotron/nanotron/src/nanotron/nn/layer_norm.py", line 44, in forward + # return layer_norm_fn( + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py", line 875, in layer_norm_fn + # return LayerNormFn.apply( + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply + # return super().apply(*args, **kwargs) # type: ignore[misc] + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py", line 748, in forward + # y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py", line 335, in _layer_norm_fwd + # _layer_norm_fwd_1pass_kernel[(M,)]( + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in + # return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in run + # timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in + # timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 133, in _bench + # return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench + # torch.cuda.synchronize() + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/torch/cuda/__init__.py", line 783, in synchronize + # return torch._C._cuda_synchronize() + # RuntimeError: CUDA error: an illegal memory access was encountered + # CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. + # For debugging consider passing CUDA_LAUNCH_BLOCKING=1. + # Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. - return layer_norm_fn( + # from flash_attn.ops.triton.layer_norm import layer_norm_fn + # return layer_norm_fn( + # input, + # self.weight, + # None, + # residual=residual, + # eps=self.eps, + # dropout_p=dropout_p, + # prenorm=prenorm, + # residual_in_fp32=residual_in_fp32, + # is_rms_norm=True, + # return_dropout_mask=return_dropout_mask, + # ) + + # NOTE: fa=2.4.2 + from flash_attn.ops.triton.layernorm import rms_norm_fn + + return rms_norm_fn( input, self.weight, None, @@ -50,6 +99,6 @@ def forward( dropout_p=dropout_p, prenorm=prenorm, residual_in_fp32=residual_in_fp32, - is_rms_norm=True, + # is_rms_norm=True, # NOTE: fa=2.4.2 don't use this? wtf dao return_dropout_mask=return_dropout_mask, ) diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 2e940744..74165ca2 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -211,7 +211,14 @@ def backward(self, loss: torch.Tensor): def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: """Accumulate grad in fp32 and set the fp32 grad to the fp32 grad buffer, so that optimizer can update fp32 weights afterwards""" - assert half_param.grad is not None, f"Expected param {name} to have gradient." + if name == "model.decoder.4.pp_block.attn.qkv_proj.weight": + assert 1 == 1 + + try: + assert half_param.grad is not None, f"Expected param {name} to have gradient." + except AssertionError: + assert 1 == 1 + fp32_grad = self.get_grad_buffer(name=name) if self._is_accumulation_sync_step is False: diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 22e8b72e..6d56ed7f 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -21,7 +21,6 @@ import nanotron.distributed as dist from nanotron.fp8.linear import FP8LinearMeta from nanotron.fp8.recipe import FP8LinearRecipe -from nanotron.parallel.parameters import get_data_from_param from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, @@ -443,10 +442,10 @@ def column_linear( name: Optional[str] = None, recipe: Optional[FP8LinearRecipe] = None, ): - weight = get_data_from_param(weight) + # weight = get_data_from_param(weight) - if bias is not None: - bias = get_data_from_param(bias) + # if bias is not None: + # bias = get_data_from_param(bias) if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) @@ -632,9 +631,9 @@ def row_linear( recipe: Optional[FP8LinearRecipe] = None, name: Optional[str] = None, ): - weight = get_data_from_param(weight) - if bias is not None: - bias = get_data_from_param(bias) + # weight = get_data_from_param(weight) + # if bias is not None: + # bias = get_data_from_param(bias) if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..a8f5f93d 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -37,6 +37,7 @@ def __init__(self, config: ModelArgs): TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, + nn.LayerNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, }