From b44040814beafcce3833dd6df3ca97efe9e660c3 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 18 Dec 2024 17:32:56 +0000 Subject: [PATCH] fix nan in fwd pass --- ...ooth_but_uclipping_and_adam_eps1.0e-8.yaml | 1756 +++++++++++++++++ src/nanotron/config/config.py | 4 + src/nanotron/constants.py | 1 + src/nanotron/fp8/functional.py | 8 +- src/nanotron/fp8/linear.py | 18 +- src/nanotron/fp8/utils.py | 22 +- src/nanotron/helpers.py | 16 + src/nanotron/models/llama.py | 39 + src/nanotron/optim/gradient_accumulator.py | 3 +- src/nanotron/trainer.py | 2 +- 10 files changed, 1853 insertions(+), 16 deletions(-) create mode 100644 examples/exp660a0_1b_fp8_like_exp661ad0_with_fp8opt_and_smolds_and_tp1_and_dp80_and_mbs16_and_gbs_1.3m_and_100b_tokens_and_80k_steps_and_gclip1.0_and_2ndbfl16_and_and_siluact_fa2_and_e4m3_wacts_and_e5m2_grads_and_no_smooth_but_uclipping_and_adam_eps1.0e-8.yaml diff --git a/examples/exp660a0_1b_fp8_like_exp661ad0_with_fp8opt_and_smolds_and_tp1_and_dp80_and_mbs16_and_gbs_1.3m_and_100b_tokens_and_80k_steps_and_gclip1.0_and_2ndbfl16_and_and_siluact_fa2_and_e4m3_wacts_and_e5m2_grads_and_no_smooth_but_uclipping_and_adam_eps1.0e-8.yaml b/examples/exp660a0_1b_fp8_like_exp661ad0_with_fp8opt_and_smolds_and_tp1_and_dp80_and_mbs16_and_gbs_1.3m_and_100b_tokens_and_80k_steps_and_gclip1.0_and_2ndbfl16_and_and_siluact_fa2_and_e4m3_wacts_and_e5m2_grads_and_no_smooth_but_uclipping_and_adam_eps1.0e-8.yaml new file mode 100644 index 00000000..8f3f9c8f --- /dev/null +++ b/examples/exp660a0_1b_fp8_like_exp661ad0_with_fp8opt_and_smolds_and_tp1_and_dp80_and_mbs16_and_gbs_1.3m_and_100b_tokens_and_80k_steps_and_gclip1.0_and_2ndbfl16_and_and_siluact_fa2_and_e4m3_wacts_and_e5m2_grads_and_no_smooth_but_uclipping_and_adam_eps1.0e-8.yaml @@ -0,0 +1,1756 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: /fsx/phuc/new_workspace/experiments/fp8_for_nanotron/exp660a0_1b_fp8_like_exp661ad0_with_fp8opt_and_smolds_and_tp1_and_dp80_and_mbs16_and_gbs_1.3m_and_100b_tokens_and_80k_steps_and_gclip1.0_and_2ndbfl16_and_and_siluact_fa2_and_e4m3_wacts_and_e5m2_grads_and_no_smooth_but_uclipping_and_adam_eps1.0e-8/checkpoints/ + checkpoints_path_is_shared_file_system: false + # resume_checkpoint_path: /fsx/phuc/new_workspace/experiments/fp8_for_nanotron/exp608ba1_100m_fp8_like_exp602ah01_with_fp8_optim_and_adam_epsilon_1.0e-5_and_smol_ds_but_tp8_and_dp10_and_mbs_2560_and_gbs_1.6m_and_100b_tokens_and_70k_steps.6m_and_100b_tokens_and_70k_steps/checkpoints/ + save_initial_state: false + +# NOTE: the old one +# data_stages: +# - data: +# dataset: +# dataset_overwrite_cache: false +# dataset_processing_num_proc_per_process: 1 +# hf_dataset_config_name: null +# hf_dataset_or_datasets: roneneldan/TinyStories +# hf_dataset_splits: train +# text_column_name: text +# num_loading_workers: 1 +# seed: 42 +# name: Stable Training Stage +# start_training_step: 1 + +# data: +# dataset: +# dataloader_type: single +# dataset_max_tokens: null +# dataset_weights: +# - 0.5 +# - 0.4 +# - 0.1 +# datasets: +# - filename_pattern: .*.ds +# folder: /fsx/loubna/tokenized_for_exps/fw_edu/fineweb-edu-full-cosmo2_merged +# skip_tokens: 0 +# - filename_pattern: .*.ds +# folder: /fsx/loubna/tokenized_for_exps/fw_edu/dclm-3T-cosmo2_merged +# skip_tokens: 0 +# - filename_pattern: .*.ds +# folder: /fsx/loubna/tokenized_for_exps/fw_edu/starcoderdata-full-cosmo_merged +# skip_tokens: 0 +# pad_samples_to_global_batch_size: false +# skip_in_stream: true +# num_loading_workers: 0 +# seed: 42 + +data_stages: +- data: + dataset: + 
dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + # hf_dataset_or_datasets: stas/openwebtext-10k + # hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small + # hf_dataset_splits: train + # text_column_name: prompt + num_loading_workers: 0 + seed: 42 + name: Stable Training Stage + start_training_step: 1 + +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: fp8_for_nanotron + run: exp660a0_1b_fp8_like_exp661ad0_with_fp8opt_and_smolds_and_tp1_and_dp80_and_mbs16_and_gbs_1.3m_and_100b_tokens_and_80k_steps_and_gclip1.0_and_2ndbfl16_and_and_siluact_fa2_and_e4m3_wacts_and_e5m2_grads_and_no_smooth_but_uclipping_and_adam_eps1.0e-8 + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info + # monitor_model_states: false + # monitor_model_states_using_hooks: false +model: + ddp_bucket_cap_mb: 25 + dtype: float8 + init_method: + # std: 0.25 # sqrt(1/16) + # std: 0.125 # sqrt(1/64) + # std: 0.04419417382415922 # sqrt(1/512) + std: 0.02209708691207961 # sqrt(1/2048) + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + initializer_range: 0.02 + + hidden_size: 2048 + intermediate_size: 8192 + num_hidden_layers: 14 + + is_llama_config: true + max_position_embeddings: 1024 + num_attention_heads: 16 + num_key_value_heads: 16 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: false + use_cache: true + vocab_size: 49152 + +optimizer: + accumulate_grad_in_fp32: true + learning_rate_scheduler: + learning_rate: 0.0006 + lr_decay_starting_step: null + lr_decay_steps: null + lr_decay_style: cosine + lr_warmup_steps: 16_000 # 10% warm up of total training steps + lr_warmup_style: linear + min_decay_lr: 0.00006 + + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.1 + zero_stage: 0 + clip_grad: 1.0 + # update_clipping: true + +parallelism: + # large batch training + dp: 2 + tp: 1 + # dp: 2 + # tp: 2 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp_linear_async_communication: false + tp_mode: ALL_REDUCE + +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: lvwerra/the-tokenizer-v1 + tokenizer_revision: null +tokens: + # NOTE: micro_batch_size * sequence_length * batch_accumulation_per_replica + # = 128 * 256 * 1 = 16384 + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + # large batch training + micro_batch_size: 2 # 16 * 1024 * 8 = 130k tokens per batch + # micro_batch_size: 16 + sequence_length: 1024 + train_steps: 10 + val_check_interval: -1 + +fp8: + resid_dtype: float32 + accum_dtype: bfloat16 + model: + - module_name: model.decoder.1.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: 
model.decoder.1.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.1.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.1.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + # NOTE: layer 2 + - module_name: model.decoder.2.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.2.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.2.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.2.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + 
margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + # NOTE: layer 3 + - module_name: model.decoder.3.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.3.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.3.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.3.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + # NOTE: layer 4 + - module_name: model.decoder.4.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.4.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.4.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 
16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.4.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + # NOTE: layer 5 + - module_name: model.decoder.5.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.5.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.5.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.5.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + # NOTE: layer 6 + - module_name: model.decoder.6.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + 
weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.6.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.6.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.6.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + # NOTE: layer 7 + - module_name: model.decoder.7.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.7.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.7.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.7.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: 
fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + # NOTE: layer 8 + - module_name: model.decoder.8.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.8.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.8.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.8.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + # NOTE: layer 9 + - module_name: model.decoder.9.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.9.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: 
false + - module_name: model.decoder.9.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.9.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + # NOTE: layer 10 + - module_name: model.decoder.10.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.10.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.10.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.10.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + # NOTE: layer 11 + - module_name: model.decoder.11.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 
0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.11.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.11.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.11.mlp.down_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + # NOTE: layer 12 + - module_name: model.decoder.12.attn.qkv_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.12.attn.o_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.12.mlp.gate_up_proj + accum_dtype: bfloat16 + input: + dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + - module_name: model.decoder.12.mlp.down_proj + accum_dtype: bfloat16 + input: 
+ dtype: fp8e4m3 + margin: 0 + interval: 16 + weight: + dtype: fp8e4m3 + margin: 0 + interval: 1 + bias: bfloat16 + input_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + weight_grad: + dtype: fp8e5m2 + margin: 0 + interval: 1 + output_grad: + dtype: fp8e5m2 + margin: 0 + interval: 16 + split_accumulator: + output: true + input_grad: true + weight_grad: true + accumulate: + output: true + input_grad: true + weight_grad: true + smooth_quant: false + + optim: + master_weight_dtype: kfloat16 + accum_dtype: float32 + exp_avg_dtype: fp8e4m3 + exp_avg_sq_dtype: bfloat16 + + clipped_softmax: false + clipped_softmax_zeta: 1.3 + clipped_softmax_gamma: -0.03 + + layer_scale: false + layer_scale_init: zeros + + qk_norm_before_pos: false + smooth_quant: false + update_clipping: true + + skip_param_update_if_nan: true + is_directly_keep_accum_grad_of_fp8: false # original config kept true + +# s3_upload: +# remove_after_upload: true +# s5cmd_concurrency: 5 +# s5cmd_numworkers: 16 +# s5cmd_path: /fsx/nouamane/miniconda/envs/2-1-cu121/bin/s5cmd +# upload_s3_path: s3://phuc-experiments/fp8_for_nanotron/exp660a0_1b_fp8_like_exp661ad0_with_fp8opt_and_smolds_and_tp1_and_dp80_and_mbs16_and_gbs_1.3m_and_100b_tokens_and_80k_steps_and_gclip1.0_and_2ndbfl16_and_and_siluact_fa2_and_e4m3_wacts_and_e5m2_grads_and_no_smooth_but_uclipping_and_adam_eps1.0e-8 + +# experiment_logger: +# # id: exp614ba4_100m_fp8_like_exp602ah01_with_fp8_optim_and_adam_epsilon_1.0e-7_and_smol_ds_but_tp8_and_dp10_and_mbs_2560_and_gbs_1.6m_and_100b_tokens_and_70k_steps +# tensorboard_logger: +# flush_secs: 30 +# tensorboard_dir: /fsx/phuc/new_workspace/experiments/fp8_for_nanotron/exp660a0_1b_fp8_like_exp661ad0_with_fp8opt_and_smolds_and_tp1_and_dp80_and_mbs16_and_gbs_1.3m_and_100b_tokens_and_80k_steps_and_gclip1.0_and_2ndbfl16_and_and_siluact_fa2_and_e4m3_wacts_and_e5m2_grads_and_no_smooth_but_uclipping_and_adam_eps1.0e-8/logs/tb_logs +# wandb_logger: +# wandb_entity: neuralink +# wandb_project: fp8_for_nanotron diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 784082fd..c5bd2d99 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -444,6 +444,9 @@ def get_config_from_dict( for k, v in config_dict.items() if v is not None } + + from nanotron.fp8.dtypes import DTypes + return from_dict( data_class=config_class, data=config_dict, @@ -455,6 +458,7 @@ def get_config_from_dict( TensorParallelLinearMode: lambda x: TensorParallelLinearMode[x.upper()], RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()], SamplerType: lambda x: SamplerType[x.upper()], + DTypes: lambda x: DTypes[x.upper()], # Add this line, }, # strict_unions_match=True, strict=True, diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 306baecf..e0603164 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -22,3 +22,4 @@ # TODO(xrsrke): refactor CPU_WEIGHTS = {} +ACCUM_GRADS = {} diff --git a/src/nanotron/fp8/functional.py b/src/nanotron/fp8/functional.py index 22f61654..e695020a 100644 --- a/src/nanotron/fp8/functional.py +++ b/src/nanotron/fp8/functional.py @@ -5,6 +5,7 @@ from nanotron.fp8.linear import FP8LinearMeta from nanotron.fp8.recipe import FP8LinearRecipe from nanotron.fp8.tensor import FP8Tensor +from nanotron.fp8.utils import is_overflow_underflow_nan from nanotron.parallel.parameters import NanotronParameter @@ -74,8 +75,13 @@ def linear( # because weight and bias's requires_grad are set to False # so that we can compute the gradients using the fp8 
kernels by ourselves phony = torch.empty(0, device=input.device, requires_grad=True) - output = torch.empty(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype) + # NOTE: interesting that if i initialize the output buffer as torch.empty + # it leads to nan matmul, so i do torch.zeros instead + # output = torch.empty(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype) + output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype) output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, recipe, name) + if is_overflow_underflow_nan(output) is True: + assert 1 == 1 # TODO(xrsrke): add support for adding bias in fp8 # TODO(xrsrke): support return an fp8 tensor as output diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py index 82dadf76..b44e5883 100644 --- a/src/nanotron/fp8/linear.py +++ b/src/nanotron/fp8/linear.py @@ -173,6 +173,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ from nanotron import constants from nanotron.config.fp8_config import FP8Args + from nanotron.fp8.utils import is_overflow_underflow_nan # pydevd.settrace(suspend=False, trace_only_current_thread=True) if ( @@ -200,6 +201,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ fp8_input = cast(FP8Tensor, fp8_input) fp8_weight = cast(FP8Tensor, fp8_weight) + assert is_overflow_underflow_nan(grad_output) is False, f"name: {ctx.name}" + ctx.metadatas = cast(FP8LinearMeta, ctx.metadatas) if ctx.metadatas.input_grad is None: fp8_grad_output = FP8Tensor( @@ -214,7 +217,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ if ctx.is_input_require_grad: transposed_fp8_weight = fp8_weight.transpose_fp8() - grad_input_temp = torch.empty( + # NOTE: same reason as output buffer in .forward + grad_input_temp = torch.zeros( fp8_grad_output.shape[0], transposed_fp8_weight.shape[0], device="cuda", @@ -232,11 +236,14 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ else: grad_input = None + assert is_overflow_underflow_nan(grad_input) is False + # TODO(xrsrke): fuse cast and transpose transposed_fp8_grad_output = fp8_grad_output.transpose_fp8() transposed_fp8_input = fp8_input.transpose_fp8() - grad_weight_temp = torch.empty( + # NOTE: same reason as output buffer in .forward + grad_weight_temp = torch.zeros( transposed_fp8_input.shape[0], transposed_fp8_grad_output.shape[0], device="cuda", @@ -250,6 +257,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ accumulate=recipe.accumulate.weight_grad, accum_qtype=recipe.accum_dtype, ) + assert is_overflow_underflow_nan(grad_weight) is False if ctx.is_input_require_grad: assert grad_input.dtype == recipe.accum_dtype @@ -272,8 +280,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ # File "/fsx/phuc/temp/temp3_env_for_fp8/env/lib/python3.10/site-packages/torch/_tensor.py", line 1386, in __torch_function__ # ret = func(*args, **kwargs) # RuntimeError: attempting to assign a gradient with dtype 'c10::BFloat16' to a tensor with dtype 'unsigned char'. 
Please ensure that the gradient and the tensor have the same dtype - fp8_weight.__accum_grad = grad_weight - assert fp8_weight.__accum_grad.dtype in [torch.float16, torch.bfloat16, torch.float32] + # fp8_weight.__accum_grad = grad_weight + # assert fp8_weight.__accum_grad.dtype in [torch.float16, torch.bfloat16, torch.float32] # constants.ACCUM_GRADS[ctx.name] = grad_weight set_accum_grad(ctx.name, grad_weight) else: @@ -295,4 +303,4 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ # NOTE: sanity check assert isinstance(fp8_weight_param.grad, FP8Tensor) - return grad_input, fp8_weight_grad, None, None, None, None, None + return grad_input, None, None, None, None, None, None diff --git a/src/nanotron/fp8/utils.py b/src/nanotron/fp8/utils.py index 47653a89..25154660 100644 --- a/src/nanotron/fp8/utils.py +++ b/src/nanotron/fp8/utils.py @@ -239,18 +239,23 @@ def find_fp8_config_by_module_name(target_module_name: str, config: FP8Args) -> if config.model is not None: for layer_args in config.model: - if layer_args.module_name == target_module_name: + if layer_args.module_name == target_module_name.replace("pp_block.", "").replace("module.", ""): return layer_args # elif config.is_quant_all_except_first_and_last: else: def match_layer_pattern(name, layer_idxs): + # patterns = [ + # "model.decoder.{}.pp_block.attn.qkv_proj", + # "model.decoder.{}.pp_block.attn.o_proj", + # "model.decoder.{}.pp_block.mlp.down_proj", + # "model.decoder.{}.pp_block.mlp.gate_up_proj", + # ] patterns = [ - "model.decoder.{}.pp_block.attn.qkv_proj", - "model.decoder.{}.pp_block.attn.o_proj", - "model.decoder.{}.pp_block.mlp.down_proj", - # "model.decoder.{}.mlp.up_proj", - "model.decoder.{}.pp_block.mlp.gate_up_proj", + "model.decoder.{}.attn.qkv_proj", + "model.decoder.{}.attn.o_proj", + "model.decoder.{}.mlp.down_proj", + "model.decoder.{}.mlp.gate_up_proj", ] for idx in layer_idxs: @@ -267,12 +272,13 @@ def match_layer_pattern(name, layer_idxs): # assert config.fp8_linear_config_temp is not None quant_layer_idxs = list(range(1, num_layers - 1)) - if match_layer_pattern(target_module_name, quant_layer_idxs) is True: + # NOTE: remove ".pp_block" from module name + if match_layer_pattern(target_module_name.replace(".pp_block", ""), quant_layer_idxs) is True: from copy import deepcopy # config_temp = deepcopy(config.fp8_linear_config_temp) config_temp = deepcopy(FP8LM_LINEAR_RECIPE) - # config_temp.module_name = target_module_name + config_temp.module_name = target_module_name return config_temp # else: # from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8 diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 87a70c53..f168f934 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -737,3 +737,19 @@ def get_consumed_train_samples_of_a_data_stage_from_ckp( (s.consumed_train_samples for s in metadata.data_stages if s.start_training_step == start_training_step), None, ) + + +def get_accum_grad(param_name): + from nanotron import constants + + assert "bias" not in param_name + # return constants.ACCUM_GRADS[param_name.replace("weight", "")] + return constants.ACCUM_GRADS[param_name.replace(".weight", "").replace(".pp_block", "")] + + +def set_accum_grad(param_name, value): + from nanotron import constants + + assert "bias" not in param_name + # constants.ACCUM_GRADS[param_name.replace("weight", "")] = value + constants.ACCUM_GRADS[param_name.replace(".weight", "").replace(".pp_block", "")] = value diff --git a/src/nanotron/models/llama.py 
b/src/nanotron/models/llama.py index 35d174a4..dc254432 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -24,6 +24,7 @@ from nanotron import logging from nanotron.config import Config, LlamaConfig, ParallelismArgs from nanotron.config.models_config import RandomInit, SpectralMupInit +from nanotron.fp8.utils import is_overflow_underflow_nan from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel @@ -431,6 +432,7 @@ def __init__( parallel_config=parallel_config, layer_idx=layer_idx, ) + self.layer_idx = layer_idx self.prefill_kv_len = ( config.max_position_embeddings @@ -452,6 +454,8 @@ def forward( ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape + assert is_overflow_underflow_nan(qkv_states) is False, f"layer_idx: {self.layer_idx}" + if self.is_gqa: query_states, key_states, value_states = torch.split( qkv_states, @@ -661,6 +665,10 @@ def forward( key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) # [batch_size, seq_length, 2, num_heads, d_qk] key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous() + + assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(key_value_states) is False, f"layer_idx: {self.layer_idx}" + query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states) # [batch_size, seq_length, num_heads, d_qk] key_states, value_states = torch.split(key_value_states, 1, dim=2) @@ -685,10 +693,19 @@ def forward( # NOTE: even though in some cases, we accumulate fp8 gemm in bfloat16, # but since the layer norm are in float32, the resulting output will be in float32 # and flash attention don't support float32 qkv, so we have to cast it back to bfloat16 + + assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(key_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(value_states) is False, f"layer_idx: {self.layer_idx}" + query_states = query_states.to(torch.bfloat16) key_states = key_states.to(torch.bfloat16) value_states = value_states.to(torch.bfloat16) + assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(key_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(value_states) is False, f"layer_idx: {self.layer_idx}" + attention_output = self.attention( query_states=query_states, key_states=key_states, @@ -700,6 +717,14 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) + from nanotron import constants + + if attention_output.dtype != constants.CONFIG.fp8.resid_dtype: + assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}" + attention_output = attention_output.to(constants.CONFIG.fp8.resid_dtype) + assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}" + + assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}" output = self.o_proj(attention_output) return {"hidden_states": output, "sequence_mask": sequence_mask} @@ -730,6 +755,7 @@ def __init__( self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, 
layer_idx=layer_idx) self.recompute_layer = parallel_config.recompute_layer + self.layer_idx = layer_idx def _core_forward( self, @@ -738,15 +764,22 @@ def _core_forward( ) -> List[Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" hidden_states = hidden_states + residual + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" hidden_states = hidden_states + residual + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" return hidden_states, output["sequence_mask"] @@ -920,14 +953,20 @@ def forward_with_hidden_states( "hidden_states": output["input_embeds"], "sequence_mask": input_mask, } + assert is_overflow_underflow_nan(hidden_encoder_states["hidden_states"]) is False + for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + assert is_overflow_underflow_nan(hidden_encoder_states["hidden_states"]) is False hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + assert is_overflow_underflow_nan(hidden_states) is False sharded_logits = self.lm_head(x=hidden_states)["logits"] + assert is_overflow_underflow_nan(sharded_logits) is False fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + assert is_overflow_underflow_nan(fp32_sharded_logits) is False return fp32_sharded_logits, hidden_states diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 5147fb70..30d1506e 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -294,7 +294,7 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: from nanotron.fp8.utils import is_overflow_underflow_nan - assert is_overflow_underflow_nan(grad) is False + assert is_overflow_underflow_nan(grad) is False, f"name: {name}" fp32_grad = self.get_grad_buffer(name=name) @@ -324,6 +324,7 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: else: grad = fp32_grad fp32_param.grad = grad + assert is_overflow_underflow_nan(fp32_param.grad) is False @contextmanager def no_sync(self): diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index d99b2692..b89c9ba4 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -210,7 +210,7 @@ def __init__( # from nanotron import constants for n, p in self.model.named_parameters(): if hasattr(p, "_is_future_fp8") and p._is_future_fp8 is True: - constants.CPU_WEIGHTS[n] = p.data.cpu().clone() + constants.CPU_WEIGHTS[n.replace("module.", "")] = p.data.cpu().clone() # NOTE: sanity check all hash are different param_hash = []
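
The pattern this patch introduces is: allocate FP8 GEMM output and gradient buffers with torch.zeros instead of torch.empty (the kernel accumulates into the provided buffer, so uninitialized memory was surfacing as NaN in the forward pass), and assert is_overflow_underflow_nan(...) is False after each step of the forward and backward pass so the first non-finite tensor is attributed to a specific layer. A minimal sketch of that pattern follows; a plain matmul stands in for the FP8 kernel, and checked_linear with its signature is illustrative only, not nanotron's API.

import torch

def is_overflow_underflow_nan(tensor: torch.Tensor) -> bool:
    # True if the tensor contains NaN or +/-inf anywhere.
    return not torch.isfinite(tensor).all().item()

def checked_linear(x: torch.Tensor, w: torch.Tensor,
                   accum_dtype=torch.bfloat16, name="linear") -> torch.Tensor:
    # Zero-initialize the accumulation buffer: the patch replaces torch.empty
    # with torch.zeros here because the GEMM accumulates into this buffer and
    # uninitialized memory showed up as NaN in the matmul output.
    out = torch.zeros(x.shape[0], w.shape[0], dtype=accum_dtype, device=x.device)
    # Plain matmul stands in for _FP8Matmul.apply, only to keep the sketch runnable.
    out += x.to(accum_dtype) @ w.to(accum_dtype).t()
    # Fail fast with the module name so a NaN is traced back to one layer,
    # mirroring the assertions added in llama.py and linear.py.
    assert is_overflow_underflow_nan(out) is False, f"non-finite output in {name}"
    return out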