From 6ac7d73369c1069e9887bd74cd4abc60aaabf06c Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Fri, 13 Dec 2024 18:45:24 +0000 Subject: [PATCH] stress test --- examples/config_tiny_llama_bench.yaml | 94 +++++++ run_multinode.sh | 9 +- scaling_benchmarks.py | 336 ++++++++++++++++++++------ src/nanotron/distributed.py | 2 +- src/nanotron/helpers.py | 22 +- src/nanotron/serialize/__init__.py | 1 + src/nanotron/trainer.py | 29 ++- test_2nodes.sh | 63 +++++ 8 files changed, 470 insertions(+), 86 deletions(-) create mode 100644 examples/config_tiny_llama_bench.yaml create mode 100755 test_2nodes.sh diff --git a/examples/config_tiny_llama_bench.yaml b/examples/config_tiny_llama_bench.yaml new file mode 100644 index 00000000..af4a835a --- /dev/null +++ b/examples/config_tiny_llama_bench.yaml @@ -0,0 +1,94 @@ +# /fsx/nouamane/miniconda/envs/2-1-cu121/bin/torchrun --nproc_per_node=8 run_train.py --config-file examples/config_tiny_llama.yaml +# NANOTRON_BENCHMARK=1 CUDA_DEVICE_MAX_CONNECTIONS=1 /fsx/nouamane/miniconda/envs/2-1-cu121/bin/torchrun --nproc_per_node=8 run_train.py --config-file examples/config_tiny_llama.yaml +checkpoints: + checkpoint_interval: 10000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: null # Custom dataloader will be used + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: bench.csv + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: dp2_tp8_seq64k + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.02 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 0 + eos_token_id: 0 + hidden_act: silu + hidden_size: 3072 + initializer_range: 0.02 + intermediate_size: 8192 + is_llama_config: true + max_position_embeddings: 2048 + num_attention_heads: 24 + num_hidden_layers: 28 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 8 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null + # profiler_export_path: ./tb_logs +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 1 + sequence_length: 2048 + train_steps: 100 + val_check_interval: 100 diff --git a/run_multinode.sh b/run_multinode.sh index cc64a10a..86f98a15 100644 --- a/run_multinode.sh +++ b/run_multinode.sh @@ -1,9 +1,10 @@ #!/bin/bash #SBATCH --job-name=smolm2-bench # Job name -#SBATCH --time=00:15:00 +#SBATCH --time=00:02:00 #SBATCH --partition=hopper-prod #SBATCH --qos=high +#SBATCH --reservation=huggingface_37 
#SBATCH -o /fsx/nouamane/projects/nanotron/logs/%j-%x.out @@ -41,6 +42,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 # Nanotron specific export NANOTRON_BENCHMARK=1 + # Print some debugging information echo "Master node: $MASTER_NODE" echo "All nodes: $NODELIST" @@ -53,5 +55,6 @@ srun torchrun \ --rdzv_id=$SLURM_JOB_ID \ --rdzv_backend=c10d \ --rdzv_endpoint=$MASTER_NODE:$MASTER_PORT \ - run_train.py \ - --config-file examples/config_tiny_llama.yaml + stress_test.py \ + # run_train.py \ + # --config-file examples/config_tiny_llama.yaml diff --git a/scaling_benchmarks.py b/scaling_benchmarks.py index 1e37a1ec..8ac9a9b7 100644 --- a/scaling_benchmarks.py +++ b/scaling_benchmarks.py @@ -4,22 +4,22 @@ import math import os +import pandas as pd import yaml from nanotron.logging import human_format +from tqdm import tqdm -VOCAB_SIZE = 32768 -NUM_KEY_VALUE_HEADS = None -TIE_WORD_EMBEDDINGS = True -ZERO_STAGE = 0 -# TP_MODE = "REDUCE_SCATTER" # "REDUCE_SCATTER" "ALL_REDUCE" -TP_MODE = "ALL_REDUCE" # "REDUCE_SCATTER" "ALL_REDUCE" -PROFILE = True +ACCUMULATE_GRAD_IN_FP32 = True +NUM_KEY_VALUE_HEADS = 8 -def estimate_num_params(layers, hidden_size, heads, intermediate_size, tie_word_embeddings): - # params = 2*V*h + l(3*h*H + 4*h*h) = (2)Vh + 16lh^2 - vocab = VOCAB_SIZE * hidden_size if tie_word_embeddings else 2 * VOCAB_SIZE * hidden_size - return vocab + layers * (3 * hidden_size * intermediate_size + 4 * hidden_size * hidden_size) +def estimate_num_params(layers, hidden_size, heads, intermediate_size, tie_word_embeddings, vocab, kv_heads=None): + # params = 2*V*h + l(3*h*H + (2 + 2*q/kv_ratio)*h*h) + # For GQA with 8 KV heads and 32 attention heads (4x ratio), it's: 2*V*h + l(3*h*H + (2 + 2/4)*h*h) + vocab = vocab * hidden_size if tie_word_embeddings else 2 * vocab * hidden_size + kv_ratio = kv_heads / heads if kv_heads is not None else 1 + qkv_params = (2 + 2 * kv_ratio) * hidden_size * hidden_size # Account for GQA + return vocab + layers * (3 * hidden_size * intermediate_size + qkv_params) def create_config( @@ -30,12 +30,15 @@ def create_config( seq_len: int, micro_batch_size: int = 1, base_config_path: str = "examples/config_tiny_llama_bench.yaml", - zero_stage: int = ZERO_STAGE, + zero_stage: int = 0, num_layers: int = 24, hidden_size: int = 2048, num_attention_heads: int = 16, intermediate_size=None, - tp_mode: str = TP_MODE, + tp_mode: str = "REDUCE_SCATTER", + vocab_size: int = 32768, + profile: bool = False, + benchmark_csv_path: str = "benchmark/results/bench_final.csv", ) -> dict: """Create a config with the specified parallelism settings.""" # Load base config @@ -66,40 +69,43 @@ def create_config( config["model"]["model_config"]["intermediate_size"] = ( intermediate_size if intermediate_size is not None else 4 * hidden_size ) - config["model"]["model_config"]["tie_word_embeddings"] = TIE_WORD_EMBEDDINGS + config["model"]["model_config"]["tie_word_embeddings"] = ( + True if intermediate_size < 10_000 else False + ) # model < 4B - # Set vocab_size to 32k to reduce memory usage - config["model"]["model_config"]["vocab_size"] = VOCAB_SIZE + # Set vocab_size + config["model"]["model_config"]["vocab_size"] = vocab_size - # modify zero stage + # Set zero stage config["optimizer"]["zero_stage"] = zero_stage - # modify tp mode + # Set tp mode config["parallelism"]["tp_mode"] = tp_mode config["parallelism"]["tp_linear_async_communication"] = True if tp_mode == "REDUCE_SCATTER" else False - N = human_format( - estimate_num_params( - num_layers, - hidden_size, - num_attention_heads, - 
config["model"]["model_config"]["intermediate_size"], - config["model"]["model_config"]["tie_word_embeddings"], - ) + num_params = estimate_num_params( + num_layers, + hidden_size, + num_attention_heads, + config["model"]["model_config"]["intermediate_size"], + config["model"]["model_config"]["tie_word_embeddings"], + vocab_size, ) + N = human_format(num_params) # Update run name to reflect configuration config["general"][ "run" - ] = f"{N}_dp{dp}_tp{tp}_pp{pp}_acc{batch_accum}_mbs{micro_batch_size}_seq{seq_len}_zero{zero_stage}_tpmode{tp_mode[:3]}_l{num_layers}_h{hidden_size}_heads{num_attention_heads}" + ] = f"{N}_dp{dp}_tp{tp}_pp{pp}_acc{batch_accum}_mbs{micro_batch_size}_seq{seq_len}_zero{zero_stage}_tpmode{tp_mode[:3]}_vocab{vocab_size//1000}k" # Update benchmark CSV path - config["general"]["benchmark_csv_path"] = "bench_tp.csv" + config["general"]["benchmark_csv_path"] = benchmark_csv_path - if PROFILE: + if profile: config["profiler"] = {} config["profiler"]["profiler_export_path"] = "./tb_logs" config["tokens"]["train_steps"] = 10 + config["general"]["run"] += "_prof" return config @@ -109,7 +115,7 @@ def generate_slurm_script( dp: int, tp: int, pp: int, - time: str = "00:15:00", + time: str = "00:02:00", partition: str = "hopper-prod", base_script_path: str = "run_multinode.sh", ) -> str: @@ -130,7 +136,7 @@ def generate_slurm_script( # Replace SLURM parameters replacements = { "--nodes=2": f"--nodes={num_nodes}", - "--time=00:15:00": f"--time={time}", + "--time=00:02:00": f"--time={time}", "--partition=hopper-prod": f"--partition={partition}", "--job-name=smolm2-bench": f"--job-name=bench_{config['general']['run']}", "examples/config_tiny_llama.yaml": f"benchmark/configs/config_{config['general']['run']}.yaml", @@ -144,6 +150,69 @@ def generate_slurm_script( return script +def check_params(model_configs): + for model_name, (num_layers, hidden_size, num_heads, intermediate_size) in model_configs.items(): + print(f"{model_name} model parameters:") + tie = True if intermediate_size < 10_000 else False + print( + f" Embedding params: {human_format(estimate_num_params(num_layers, hidden_size, num_heads, intermediate_size, tie, 131072, 8))}" + ) + print() + + exit() + + +def save_experiment_configs(configs, output_path): + """Save core experiment configurations for tracking""" + records = [] + + for config in configs: + # Calculate total params + tie_word_embeddings = True if config["model"]["model_config"]["intermediate_size"] < 10_000 else False + estimate_num_params( + config["model"]["model_config"]["num_hidden_layers"], + config["model"]["model_config"]["hidden_size"], + config["model"]["model_config"]["num_attention_heads"], + config["model"]["model_config"]["intermediate_size"], + tie_word_embeddings, + config["model"]["model_config"]["vocab_size"], + NUM_KEY_VALUE_HEADS, + ) + record = { + "name": config["general"]["run"], + "nodes": config["parallelism"]["dp"] * config["parallelism"]["tp"] * config["parallelism"]["pp"] / 8, + "seq_len": config["tokens"]["sequence_length"], + "mbs": config["tokens"]["micro_batch_size"], + "batch_accum": config["tokens"]["batch_accumulation_per_replica"], + "gbs": config["tokens"]["sequence_length"] + * config["tokens"]["micro_batch_size"] + * config["tokens"]["batch_accumulation_per_replica"] + * config["parallelism"]["dp"], + "dp": config["parallelism"]["dp"], + "pp": config["parallelism"]["pp"], + "tp": config["parallelism"]["tp"], + "tp_mode": f"TensorParallelLinearMode.{config['parallelism']['tp_mode']}", + "hidden_size": 
config["model"]["model_config"]["hidden_size"], + "num_layers": config["model"]["model_config"]["num_hidden_layers"], + "num_heads": config["model"]["model_config"]["num_attention_heads"], + "vocab_size": config["model"]["model_config"]["vocab_size"], + "zero_stage": config["optimizer"]["zero_stage"], + } + records.append(record) + + # Save to CSV + if os.path.exists(output_path): + # Read existing data and append new records + existing_df = pd.read_csv(output_path) + df = pd.DataFrame(records) + df = pd.concat([existing_df, df], ignore_index=True) + else: + df = pd.DataFrame(records) + + df.to_csv(output_path, index=False) + print(f"Saved {len(records)} experiment configurations to {output_path}") + + def main(): parser = argparse.ArgumentParser(description="Run scaling benchmarks with different parallelism configurations") parser.add_argument( @@ -153,13 +222,29 @@ def main(): "--scripts-dir", type=str, default="benchmark/scripts", help="Directory to store generated SLURM scripts" ) parser.add_argument("--partition", type=str, default="hopper-prod", help="SLURM partition to use") - parser.add_argument("--time", type=str, default="00:15:00", help="Time limit for each job") + parser.add_argument("--time", type=str, default="00:10:00", help="Time limit for each job") parser.add_argument( - "--base-config", type=str, default="examples/config_tiny_llama.yaml", help="Base configuration file to use" + "--base-config", + type=str, + default="examples/config_tiny_llama_bench.yaml", + help="Base configuration file to use", ) parser.add_argument("--base-script", type=str, default="run_multinode.sh", help="Base SLURM script to use") + parser.add_argument( + "--pending-csv", + type=str, + default="benchmark/results/pending_experiments_stress.csv", + help="CSV file to store pending experiments", + ) + parser.add_argument( + "--benchmark-csv", + type=str, + default="benchmark/results/bench_final_stress.csv", + help="CSV file to store benchmark results", + ) parser.add_argument("--run", action="store_true", help="Automatically submit all generated SLURM scripts") parser.add_argument("--debug", action="store_true", help="Debug mode") + parser.add_argument("--profile", action="store_true", help="Enable profiling") args = parser.parse_args() # Validate input files exist @@ -174,17 +259,12 @@ def main(): # Define model configurations model_configs = { - # params = 2*V*h + l(3*h*H + 4*h*h) = (2)Vh + 16lh^2 # (layers, hidden_size, heads, intermediate_size) - # "138M": (12, 768, 12, 3072), - # "200M": (12, 1024, 16, 4096), - # "500M": (12, 1536, 16, 6144), - # "1000M": (15, 2048, 16, 8192), - # "1700M": (24, 2048, 16, 8192), # (layers, hidden_size, heads, intermediate_size) - # "4300M": (28, 3072, 20, 12288), - # "8700M": (32, 4096, 32, 16384), - # "11B": (42, 4096, 32, 16384), - "3500M": (28, 3072, 24, 8192) + # "1B": (16, 2048, 32, 8192), # 1.2G + # "3B": (28, 3072, 24, 8192), # 3.2G + "8B": (32, 4096, 32, 14336), # 8.0G + # "70B": (80, 8192, 64, 28672), # 70G + # "405B": (126, 16384, 128, 53248), # 406G } # Define configurations to test @@ -192,41 +272,125 @@ def main(): # For each model size, test different GPU configurations for model_name, (num_layers, hidden_size, num_heads, intermediate_size) in model_configs.items(): - # Test each model with different GPU counts while maintaining 4M tokens/step - model_configs = [ + vocab_size = 32768 + zero_stage = 0 + tp_mode = "REDUCE_SCATTER" + configs = [ # 64 nodes max + # 2k, 8k, 32k + # GBS: 1M, 4M # Format: (dp, tp, pp, batch_accum, seq_len, mbs, ...) 
- # (1, 1, 1, 8, 2048, 1, num_layers, hidden_size, num_heads, intermediate_size), - # (1, 2, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size), - # (1, 4, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size), - # (1, 8, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size), - # find best tput on 16 nodes with 4GBS - (1, 8, 1, 1, 4096, 8, num_layers, hidden_size, num_heads, intermediate_size), # test max MBS - # (8, 1, 1, 1, 4096, 1, num_layers, hidden_size, num_heads, intermediate_size), # test max MBS - # (1, 8, 1, 1, 4096, 64, num_layers, hidden_size, num_heads, intermediate_size), # test max MBS - # (16, 8, 1, 1, 4096, 16, num_layers, hidden_size, num_heads, intermediate_size), # ideal run i guess - # (32, 4, 1, 1, 4096, 8, num_layers, hidden_size, num_heads, intermediate_size), # TP=4 - # (64, 2, 1, 1, 4096, 4, num_layers, hidden_size, num_heads, intermediate_size), # TP=2 - # (128, 1, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size), # TP=1 - # find best tput on 8 nodes with 1GBS - # (8, 8, 1, 1, 4096, 32, num_layers, hidden_size, num_heads, intermediate_size), - # (8, 8, 1, 2, 4096, 16, num_layers, hidden_size, num_heads, intermediate_size), - # (16, 4, 1, 2, 4096, 8, num_layers, hidden_size, num_heads, intermediate_size), - # (32, 2, 1, 2, 4096, 4, num_layers, hidden_size, num_heads, intermediate_size), - # (64, 1, 1, 2, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size), - # same for 4 nodes - # (4, 8, 1, 1, 4096, 16, num_layers, hidden_size, num_heads, intermediate_size), - # (8, 4, 1, 1, 4096, 8, num_layers, hidden_size, num_heads, intermediate_size), - # (16, 2, 1, 1, 4096, 4, num_layers, hidden_size, num_heads, intermediate_size), - # (32, 1, 1, 1, 4096, 2, num_layers, hidden_size, num_heads, intermediate_size), + # Using SP what's the biggest seqlen we can fit? 
+ # (1, 8, 1, 1, 2048, 1, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # (1, 8, 1, 1, 2048, 2, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # (1, 8, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # (1, 8, 1, 1, 2048, 32, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # best run + ( + 2, + 8, + 1, + 1, + 2048, + 512, + num_layers, + hidden_size, + num_heads, + intermediate_size, + vocab_size, + zero_stage, + tp_mode, + ), + # test zero + # (3, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 0, tp_mode), + # (3, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 1, tp_mode), + # (24, 1, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 0, tp_mode), + # (24, 1, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 1, tp_mode), + # test tp mode + # (1, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, "ALL_REDUCE"), + # test pp + # (1, 1, 8, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # (1, 8, 2, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # (1, 1, 8, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # (1, 2, 8, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # (1, 2, 64, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # (1, 2, 16, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), ] - configurations.extend(model_configs) + configurations.extend(configs) + + # Duplicate configurations 100 times + # configurations = configurations * 5000 + + # Method 2: Parameter combinations + PARALLEL_CONFIGS = [ + (dp, tp, pp) + for dp in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + for tp in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + for pp in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + ] # Max 64 nodes + # Sort PARALLEL_CONFIGS by total GPU count (dp*tp*pp) ascending + PARALLEL_CONFIGS = sorted(PARALLEL_CONFIGS, key=lambda x: x[0] * x[1] * x[2]) + + # for pp, tp, dp in PARALLEL_CONFIGS: + # for model_name, (num_layers, hidden_size, num_heads, intermediate_size) in model_configs.items(): + # for seq_len in SEQUENCE_LENGTHS: + # for mbs in MBS: + # for batch_accum in GRAD_ACCUM_STEPS: + # for vocab_size in VOCAB_SIZES: + # for zero_stage in ZERO_STAGES: + # for tp_mode in TP_MODES: + # # Optional: Add conditions to filter out unwanted combinations + # total_gpus = dp * tp * pp + # if total_gpus < 8 or total_gpus/8 > 64: # max 64 nodes + # continue + + # tokens_per_step = dp * mbs * batch_accum * seq_len + # if not tokens_per_step in [512*2048, 2048*2048]: + # continue + + # # if dp=1 skip zero stage 1 + # if dp == 1 and zero_stage == 1: + # continue + + # # if tp=1 skip tp_mode=ALL_REDUCE + # if tp == 1 and tp_mode == "ALL_REDUCE": + # continue + + # configurations.append(( + # dp, tp, pp, + # batch_accum, seq_len, mbs, + # num_layers, hidden_size, num_heads, intermediate_size, + # vocab_size, zero_stage, tp_mode + # )) + # time += total_gpus * 1.5 / 8 / 64 # 1.5 minutes per config + + # print(len(configurations)) + # 
each config takes 1.5 minutes to run, print how many days + # print(f"{time / 60 / 24:.2f} days ({time/60:.2f} hours)") + # exit() if args.debug: print("Debug mode: only running 1 configuration") configurations = configurations[:1] + # run first 100 configurations + # configurations = configurations[:120+5000] + # Validate configurations - for dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size in configurations: + for ( + dp, + tp, + pp, + batch_accum, + seq_len, + mbs, + num_layers, + hidden_size, + num_heads, + intermediate_size, + vocab_size, + zero_stage, + tp_mode, + ) in configurations: total_gpus = dp * tp * pp if total_gpus > 512: print( @@ -234,12 +398,27 @@ def main(): ) # Calculate tokens per step to verify batch size - tokens_per_step = dp * tp * pp * mbs * batch_accum * seq_len - print(f"Model {hidden_size}H_{num_layers}L: {total_gpus} GPUs, " f"{tokens_per_step:,} GBS") + # tokens_per_step = human_format(dp * mbs * batch_accum * seq_len) + # print(f"Model {hidden_size}H_{num_layers}L: {total_gpus} GPUs, " f"{tokens_per_step} GBS") # Generate configs and scripts generated_scripts = [] - for dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size in configurations: + configs = [] + for ( + dp, + tp, + pp, + batch_accum, + seq_len, + mbs, + num_layers, + hidden_size, + num_heads, + intermediate_size, + vocab_size, + zero_stage, + tp_mode, + ) in tqdm(configurations, desc="Generating configs and scripts"): try: # Create config config = create_config( @@ -254,6 +433,11 @@ def main(): hidden_size=hidden_size, num_attention_heads=num_heads, intermediate_size=intermediate_size, + vocab_size=vocab_size, + zero_stage=zero_stage, + tp_mode=tp_mode, + profile=args.profile, + benchmark_csv_path=args.benchmark_csv, ) # Save config @@ -273,18 +457,20 @@ def main(): # Make script executable os.chmod(script_path, 0o755) - print(f"Successfully generated config and script for {config_path}") generated_scripts.append(script_path) + configs.append(config) except Exception as e: print(f"Error processing configuration (dp={dp}, tp={tp}, pp={pp}): {str(e)}") + save_experiment_configs(configs, args.pending_csv) + # Submit jobs if requested if args.run: import subprocess print("\nSubmitting jobs...") - for script_path in generated_scripts: + for script_path in tqdm(generated_scripts, desc="Submitting jobs"): try: result = subprocess.run(["sbatch", script_path], check=True, capture_output=True, text=True) print(f"Submitted {script_path}: {result.stdout.strip()}") diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py index 0156b1bb..20540324 100644 --- a/src/nanotron/distributed.py +++ b/src/nanotron/distributed.py @@ -13,7 +13,7 @@ torch_version_above_1_13 = version.parse(torch.__version__) >= version.parse("1.13.0") Work = dist.Work if torch_version_above_1_13 else dist._Work -default_pg_timeout = datetime.timedelta(minutes=10) +default_pg_timeout = datetime.timedelta(minutes=30) def new_group( # pylint: disable=function-redefined diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index f61333fd..c3f32a8a 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -595,8 +595,6 @@ def create_table_log( num_params, slurm_job_id, ): - print("num_params") - print(num_params) return [ LogItem("job_id", slurm_job_id, "s"), LogItem("name", config.general.run, "s"), @@ -680,12 +678,22 @@ def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id): # Use fcntl for file locking 
max_attempts = 10 attempt = 0 + log_rank( + f"Attempting to write benchmark results to CSV file: {csv_filename}", + logger=logger, + level=logging.INFO, + ) while attempt < max_attempts: try: # Open file in append mode (will create if doesn't exist) with open(csv_filename, mode="a+", newline="") as f: # Get exclusive lock fcntl.flock(f.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + log_rank( + f"Acquired lock for CSV file: {csv_filename}", + logger=logger, + level=logging.INFO, + ) try: # Check if file is empty/new f.seek(0) @@ -704,8 +712,18 @@ def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id): finally: # Release lock fcntl.flock(f.fileno(), fcntl.LOCK_UN) + log_rank( + f"Successfully wrote to CSV file: {csv_filename}. Releasing lock...", + logger=logger, + level=logging.INFO, + ) except BlockingIOError: # Another process has the lock, wait and retry + log_rank( + f"Another process has the lock for CSV file: {csv_filename}, waiting and retrying attempt {attempt + 1} of {max_attempts}...", + logger=logger, + level=logging.INFO, + ) attempt += 1 time.sleep(0.1) # Wait 100ms before retrying except IOError as e: diff --git a/src/nanotron/serialize/__init__.py b/src/nanotron/serialize/__init__.py index 7fc7b0a9..d6756804 100644 --- a/src/nanotron/serialize/__init__.py +++ b/src/nanotron/serialize/__init__.py @@ -3,3 +3,4 @@ from nanotron.serialize.optimizer import * from nanotron.serialize.random import * from nanotron.serialize.weights import * +from nanotron.serialize.metadata import * diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index a9d8876a..aae4d3fc 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -2,6 +2,7 @@ import gc import json import os +import random import shutil import time from dataclasses import asdict @@ -156,10 +157,6 @@ def __init__( # Set log levels set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging) - # Log benchmark info - # if os.environ.get("NANOTRON_BENCHMARK", "0") == "1": - # log_throughput(self.config, self.parallel_context) - ######################################## ## Setting up our model, optimizers, schedulers, etc. 
######################################## @@ -259,9 +256,17 @@ def __init__( def pre_init(self): self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context) - # Calculate cluster bandwidth + # TODO: fix in case of dp=3 tp=8 self.BANDWIDTHS = measure_bandwidth(self.parallel_context) + # self.BANDWIDTHS = { + # "all_reduce": 0.0, + # "reduce_scatter": 0.0, + # "all_gather": 0.0, + # "all_reduce_intranode": 0.0, + # "reduce_scatter_intranode": 0.0, + # "all_gather_intranode": 0.0, + # } def post_init(self): # S3 Mover and save initial state @@ -441,6 +446,12 @@ def train( self.last_iter_step = self.config.tokens.train_steps prof = get_profiler(config=self.config) + + # Random wait between 0-5 seconds + wait_time = random.Random().randint(13, 17) # Use a new generator instance each time + log_rank(f"Waiting for {wait_time} seconds", logger=logger, level=logging.INFO, rank=0) + time.sleep(wait_time) + # free memory gc.collect() torch.cuda.empty_cache() @@ -1076,6 +1087,14 @@ def measure_bandwidth(parallel_context: ParallelContext): import torch import torch.distributed as dist + # Log that we're measuring bandwidth + log_rank( + "Measuring inter-GPU and intra-node bandwidth...", + logger=logger, + level=logging.INFO, + rank=0, + ) + # Size of data to transfer (256MB in elements) size = 256 * 1024 * 1024 # Number of elements tensor = torch.ones(size).cuda() diff --git a/test_2nodes.sh b/test_2nodes.sh new file mode 100755 index 00000000..d5b5108a --- /dev/null +++ b/test_2nodes.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +#SBATCH --job-name=bench_stress_test # Job name +#SBATCH --time=00:01:01 +#SBATCH --partition=hopper-prod +#SBATCH --qos=high + +#SBATCH -o /fsx/nouamane/projects/nanotron/logs/%j-%x.out + +#SBATCH --nodes=1 # Number of nodes (modify as needed) +#SBATCH --ntasks-per-node=1 # Number of tasks per node +#SBATCH --cpus-per-task=60 # CPU cores per task +#SBATCH --gres=gpu:8 # Number of GPUs per node +#SBATCH --exclusive # Exclusive use of nodes + +set -x -e + +# Load any necessary modules for your system +source /etc/profile.d/modules.sh # for some reason module isn't loaded +module load cuda/12.1 + +# Activate your conda environment if needed +source /fsx/nouamane/miniconda/bin/activate +conda activate 2-1-cu121 +export PATH=/fsx/nouamane/miniconda/envs/2-1-cu121/bin:$PATH + +# Get the node names from SLURM +export NODELIST=`scontrol show hostnames $SLURM_JOB_NODELIST` +export MASTER_NODE=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n1` +export MASTER_PORT=12356 + +# Calculate total number of processes +export NNODES=$SLURM_NNODES +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES * $GPUS_PER_NODE)) + +# Set some environment variables for better distributed training +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_DEBUG=INFO + +# Nanotron specific +export NANOTRON_BENCHMARK=1 + +# Print GPU topology information +echo "=== GPU Topology ===" +nvidia-smi topo -m +echo "==================" + + +# Print some debugging information +echo "Master node: $MASTER_NODE" +echo "All nodes: $NODELIST" +echo "World size: $WORLD_SIZE" + +# Launch the training script using srun +srun torchrun \ + --nnodes=$NNODES \ + --nproc_per_node=$GPUS_PER_NODE \ + --rdzv_id=$SLURM_JOB_ID \ + --rdzv_backend=c10d \ + --rdzv_endpoint=$MASTER_NODE:$MASTER_PORT \ + /fsx/nouamane/projects/nanotron/run_train.py \ + --config-file benchmark/configs/config_1.14G_dp4_tp2_pp1_acc256_mbs2_seq2048_zero1_tpmodeRED_vocab32k.yaml
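
For reference, the parameter estimate used throughout scaling_benchmarks.py can be checked in isolation. The following minimal sketch is not part of the patch: it mirrors the GQA-aware estimate_num_params above and evaluates it for the hard-coded "8B" shape (32 layers, 4096 hidden, 32 heads, 14336 intermediate), with the 8 KV heads and 131072 vocabulary size that check_params passes; the result lands at roughly 8.05B, consistent with the "# 8.0G" annotation in the model table.

# Standalone sanity check of the parameter-count formula from scaling_benchmarks.py.
# Shape, KV-head count, and vocab size follow the "8B" entry and check_params() in the
# patch; this is an illustrative sketch, not part of the benchmark harness itself.

def estimate_num_params(layers, hidden_size, heads, intermediate_size, tie_word_embeddings, vocab, kv_heads=None):
    # Embeddings: one matrix if tied, two (input + output) otherwise.
    embed = vocab * hidden_size if tie_word_embeddings else 2 * vocab * hidden_size
    # GQA: K and V projections shrink by kv_heads / heads relative to Q and O.
    kv_ratio = kv_heads / heads if kv_heads is not None else 1
    qkv_params = (2 + 2 * kv_ratio) * hidden_size * hidden_size  # Q, O + scaled K, V
    mlp_params = 3 * hidden_size * intermediate_size             # gate, up, down projections
    return embed + layers * (mlp_params + qkv_params)

if __name__ == "__main__":
    # "8B": (32, 4096, 32, 14336), 8 KV heads, untied embeddings (intermediate >= 10_000), 131072 vocab
    n = estimate_num_params(32, 4096, 32, 14336, False, 131072, kv_heads=8)
    print(f"{n / 1e9:.2f}B parameters")  # ~8.05B, matching the "# 8.0G" comment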
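
Likewise, the global-batch-size bookkeeping that save_experiment_configs records as "gbs" can be sanity-checked for the single uncommented configuration (dp=2, tp=8, pp=1, batch_accum=1, seq_len=2048, mbs=512). The numbers below are derived directly from that tuple; this is a sketch for orientation, not code from the patch.

# Tokens per optimizer step and GPU count for the active "best run" configuration,
# following the gbs formula used in save_experiment_configs().
dp, tp, pp = 2, 8, 1
batch_accum, seq_len, mbs = 1, 2048, 512

total_gpus = dp * tp * pp                           # 16 GPUs -> 2 nodes of 8
tokens_per_step = seq_len * mbs * batch_accum * dp  # 2_097_152 (~2M tokens/step)
print(total_gpus, tokens_per_step)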