From 28770a5749d33bcd284422a8241f9d6001662ed4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?=
Date: Mon, 26 Aug 2024 00:42:22 +0000
Subject: [PATCH] fancy++ launcher

---
 launcher.py                   | 142 +++++++++++++++++++++++-----------
 src/nanotron/config/config.py |  65 ++++++++++++----
 2 files changed, 148 insertions(+), 59 deletions(-)

diff --git a/launcher.py b/launcher.py
index 7b192bb4..3d78d698 100644
--- a/launcher.py
+++ b/launcher.py
@@ -197,7 +197,7 @@ def launch_slurm_job(launch_file_contents, *args):
     )
 
     parallelism = ParallelismArgs(
-        dp=16,
+        dp=8,
         pp=1,
         tp=1,
         pp_engine="1f1b",
@@ -267,16 +267,16 @@ def launch_slurm_job(launch_file_contents, *args):
     )
 
     tokenizer = TokenizerArgs(
-        tokenizer_name_or_path="lvwerra/the-tokenizer-v1",
+        tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer",
     )
 
-    s3_upload = S3UploadArgs(
-        upload_s3_path=f"s3://elie-exp/debug_nanotron/test/",
-        remove_after_upload=True,
-        s5cmd_numworkers=16,
-        s5cmd_concurrency=5,
-        s5cmd_path=os.path.join(slurm.conda_env_path, "bin/s5cmd"),
-    )
+    # s3_upload = S3UploadArgs(
+    #     upload_s3_path=f"s3://elie-exp/debug_nanotron/test/",
+    #     remove_after_upload=True,
+    #     s5cmd_numworkers=16,
+    #     s5cmd_concurrency=5,
+    #     s5cmd_path=os.path.join(slurm.conda_env_path, "bin/s5cmd"),
+    # )
 
     data_stages=[
         DatasetStageArgs(
@@ -302,62 +302,76 @@ def launch_slurm_job(launch_file_contents, *args):
         tokens=tokens,
         optimizer=optimizer,
         data_stages=data_stages,
-        s3_upload=s3_upload,
+        # s3_upload=s3_upload,
         lighteval=lighteval,
-        slurm=slurm,
+        # slurm=slurm,
     )
 
     print(f"""
 🏋️ Model Parameters:
-┌───────────────────────┬───────────────────────────┐
-│ Total Parameters      │ {num_params:>25} │
-│ Layers                │ {model_config.num_hidden_layers:>25d} │
-│ Attention Heads       │ {model_config.num_attention_heads:>25d} │
-│ Hidden Size           │ {model_config.hidden_size:>25d} │
-│ Intermediate Size     │ {model_config.intermediate_size:>25d} │
-│ Context Length        │ {model_config.max_position_embeddings:>25d} │
-│ Tokenizer             │ {tokenizer.tokenizer_name_or_path[:25]:>25} │
-│ Vocab Size            │ {model_config.vocab_size:>25d} │
-└───────────────────────┴───────────────────────────┘
+┌───────────────────────┬────────────────────────┐
+│ Total Parameters      │ {num_params:>22} │
+│ Layers                │ {model_config.num_hidden_layers:>22d} │
+│ Attention Heads       │ {model_config.num_attention_heads:>22d} │
+│ Hidden Size           │ {model_config.hidden_size:>22d} │
+│ Intermediate Size     │ {model_config.intermediate_size:>22d} │
+│ Context Length        │ {model_config.max_position_embeddings:>22d} │
+│ Tokenizer             │ {tokenizer.tokenizer_name_or_path[:22]:>22} │
+│ Vocab Size            │ {model_config.vocab_size:>22d} │
+└───────────────────────┴────────────────────────┘
 """)
 
+    num_nodes = slurm.nodes if args.slurm else torch.cuda.device_count()
     print(f"""
 🤖 Parallelism Configuration:
-┌───────────────────────┬───────────────────┐
-│ Nodes                 │ {slurm.nodes:>17d} │
-│ Total GPUs            │ {parallelism.dp*parallelism.pp*parallelism.tp:>17d} │
-│ Data Parallel (DP)    │ {parallelism.dp:>17d} │
-│ Pipeline Parallel (PP)│ {parallelism.pp:>17d} │
-│ Tensor Parallel (TP)  │ {parallelism.tp:>17d} │
-└───────────────────────┴───────────────────┘
+┌───────────────────────┬────────────────────────┐
+│ Nodes                 │ {num_nodes:>22d} │
+│ Total GPUs            │ {parallelism.dp*parallelism.pp*parallelism.tp:>22d} │
+│ Data Parallel (DP)    │ {parallelism.dp:>22d} │
+│ Pipeline Parallel (PP)│ {parallelism.pp:>22d} │
+│ Tensor Parallel (TP)  │ {parallelism.tp:>22d} │
+└───────────────────────┴────────────────────────┘
 """)
 
     print(f"""
 📙 Training Configuration:
-┌───────────────────────┬───────────────────┐
-│ Total Tokens          │ {total_tokens_billions:>16.2f}B │
-│ Global Batch Size     │ {GBS:>17,d} │
-│ Batch Size (per GPU)  │ {BS:>17,d} │
-└───────────────────────┴───────────────────┘
+┌───────────────────────┬────────────────────────┐
+│ Total Tokens          │ {total_tokens_billions:>21.2f}B │
+│ Global Batch Size     │ {GBS:>22,d} │
+│ Batch Size (per GPU)  │ {BS:>22,d} │
+└───────────────────────┴────────────────────────┘
 """)
 
     print(f"""
 📊 Learning Rate Schedule:
-┌───────────────────────┬───────────────────┐
-│ Initial LR            │ {lr_initial:>17.2e} │
-│ Warmup Style          │ {learning_rate_scheduler.lr_warmup_style[:17]:>17} │
-│ Warmup Steps          │ {lr_warmup_steps:>17d} │
-│ Decay Style           │ {lr_decay_style[:17]:>17} │
-│ Decay Start Step      │ {lr_decay_start:>17d} │
-│ Decay Steps           │ {lr_decay_steps:>17d} │
-│ Final LR              │ {lr_min:>17.2e} │
-└───────────────────────┴───────────────────┘
+┌───────────────────────┬────────────────────────┐
+│ Initial LR            │ {lr_initial:>22.2e} │
+│ Warmup Style          │ {learning_rate_scheduler.lr_warmup_style[:22]:>22} │
+│ Warmup Steps          │ {lr_warmup_steps:>22d} │
+│ Decay Style           │ {lr_decay_style[:22]:>22} │
+│ Decay Start Step      │ {lr_decay_start:>22d} │
+│ Decay Steps           │ {lr_decay_steps:>22d} │
+│ Final LR              │ {lr_min:>22.2e} │
+└───────────────────────┴────────────────────────┘
+""")
+    print(f"""
+🔧 Optimization Configuration:
+┌───────────────────────┬────────────────────────┐
+│ Optimizer             │ {optimizer.optimizer_factory.__class__.__name__:>22} │
+│ Weight Decay          │ {optimizer.weight_decay:>22.2e} │
+│ Gradient Clipping     │ {optimizer.clip_grad:>22.2f} │
+│ Adam Epsilon          │ {optimizer.optimizer_factory.adam_eps:>22.2e} │
+│ Adam Beta1            │ {optimizer.optimizer_factory.adam_beta1:>22.2f} │
+│ Adam Beta2            │ {optimizer.optimizer_factory.adam_beta2:>22.2f} │
+│ ZeRO Stage            │ {optimizer.zero_stage:>22d} │
+│ FP32 Grad Accumulation│ {str(optimizer.accumulate_grad_in_fp32):>22} │
+└───────────────────────┴────────────────────────┘
 """)
 
-    if slurm is not None:
+    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    if args.slurm:
         dir = os.path.dirname(__file__)
-        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
         os.makedirs(config.slurm.config_logs_path, exist_ok=True)
         config_path_yaml = f"{config.slurm.config_logs_path}/{timestamp}.yaml"
         config.save_as_yaml(config_path_yaml)
@@ -457,4 +471,42 @@ def format_sbatch_option(option, value):
 
     echo "END TIME: $(date)"
     """
-        print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}")
\ No newline at end of file
+        print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}")
+    else:
+        # Check if running on an interactive node
+        try:
+            gpu_count = torch.cuda.device_count()
+            is_interactive = gpu_count > 0
+        except Exception:
+            is_interactive = False
+
+        if is_interactive:
+            print("Running on an interactive node with GPUs.")
+
+            # Check if the parallelism configuration matches the available GPUs
+            total_gpus = gpu_count
+            config_gpus = parallelism.dp * parallelism.tp * parallelism.pp
+
+            if total_gpus != config_gpus:
+                raise ValueError(f"The parallelism configuration (dp={parallelism.dp}, tp={parallelism.tp}, pp={parallelism.pp}) "
+                                 f"doesn't match the number of available GPUs ({total_gpus}). "
+                                 f"Please adjust your configuration to match the available resources.")
+
+            # Save config
+            timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+            os.makedirs("/fsx/elie_bakouch/nanotron/config_logs", exist_ok=True)
+            config_path_yaml = f"/fsx/elie_bakouch/nanotron/config_logs/{timestamp}.yaml"
+            config.save_as_yaml(config_path_yaml)
+
+            # Prepare command
+            trainer_python_file = "/fsx/elie_bakouch/nanotron/run_train.py"
+            cmd = f"{trainer_python_file} --config-file {config_path_yaml}"
+
+            # Launch job
+            launch_cmd = f"torchrun --nproc_per_node {gpu_count} {cmd}"
+            print(f"Launching interactive job with command: {launch_cmd}")
+
+            # Execute the command
+            subprocess.run(launch_cmd, shell=True, check=True)
+        else:
+            print("Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs.")
\ No newline at end of file
diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index 26bd1546..126a4d23 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, fields
 from pathlib import Path
 from datasets.download.streaming_download_manager import xPath
-from typing import List, Optional, Type, Union
+from typing import List, Optional, Type, Union, Dict
 
 import dacite
 import torch
@@ -93,25 +93,62 @@ def __post_init__(self):
 
 @dataclass
 class SlurmArgs:
+    """
+    Arguments for configuring SLURM job submission.
+
+    Attributes:
+        gpu_partition (str): SLURM partition (queue) for GPU jobs.
+        job_name (str): Name of the SLURM job.
+        nodes (int): Number of nodes to allocate for the job.
+        logs_path (str): Base directory for storing log files.
+        conda_path (str): Path to the Conda installation script.
+        conda_env_path (str): Path to the Conda environment to be used.
+        n_tasks_per_node (int): Number of tasks to run per node. Default is 1.
+        cpus_per_task (int): Number of CPUs to allocate per task. Default is 32.
+        gpu_per_node (int): Number of GPUs to allocate per node. Default is 8.
+        array (Optional[str]): Job array specification, allowing multiple similar jobs to be submitted as a group.
+        qos (Optional[str]): Quality of Service, used to define job priority or resource limits.
+        mail_type (Optional[str]): Specifies when to send email notifications about the job (e.g., BEGIN, END, FAIL). Defaults to FAIL when mail_user is set.
+        mail_user (Optional[str]): Email address to receive job notifications.
+        exclude_nodes (Optional[List[str]]): List of nodes to exclude from job allocation.
+        time (Optional[str]): Maximum time limit for the job.
+        mem (Optional[str]): Memory requirement for the job.
+        constraint (Optional[str]): Specifies node features required for the job.
+        account (Optional[str]): Account to charge for the job's resource usage.
+        reservation (Optional[str]): Name of a reservation to use for the job.
+        begin (Optional[str]): Earliest time the job can start.
+        torchrun_args (Optional[Dict[str, str]]): Additional arguments for the torchrun command.
+        slurm_logs_path (Optional[str]): Specific path for SLURM output logs.
+        config_logs_path (Optional[str]): Path for storing configuration logs.
+    """
+
+    gpu_partition: str
     job_name: str
     nodes: int
-    logs_path: Path
-    # TODO: @elibak: Add a way to handle different virtual environments (conda, venv, uv, etc) For now, we assume conda and user can modify the slurm template if they use something else.
+    logs_path: str
     conda_path: str
-    conda_env_path : str
-    gpu_partition: Optional[str] = None
-    n_tasks_per_node: Optional[int] = 1
-    cpus_per_task: Optional[int] = 32
-    gpu_per_node: Optional[int] = 8
-    mail: Optional[str] = None
-    qos: Optional[str] = "high"
-    array: Optional[str] = "1-1%1"
+    conda_env_path: str
+    n_tasks_per_node: int = 1
+    cpus_per_task: int = 32
+    gpu_per_node: int = 8
+    array: Optional[str] = None
+    qos: Optional[str] = None
+    mail_user: Optional[str] = None
+    mail_type: Optional[str] = None
+    exclude_nodes: Optional[List[str]] = None
+    time: Optional[str] = None
+    mem: Optional[str] = None
+    constraint: Optional[str] = None
+    account: Optional[str] = None
+    reservation: Optional[str] = None
+    begin: Optional[str] = None
+    torchrun_args: Optional[Dict[str, str]] = None
     slurm_logs_path: Optional[str] = None
-    evals_logs_path: Optional[str] = None
     config_logs_path: Optional[str] = None
-
-
+
+    def __post_init__(self):
+        if self.mail_type is None and self.mail_user is not None:
+            self.mail_type = "FAIL"
 
 @dataclass
 class S3UploadArgs:
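
Usage sketch (illustrative, not part of the patch): with the reworked SlurmArgs, only the first six fields are required, and __post_init__ backfills mail_type when just mail_user is given. The partition name, paths, and address below are placeholders:

    from nanotron.config.config import SlurmArgs

    slurm = SlurmArgs(
        gpu_partition="gpu",                               # placeholder partition name
        job_name="smollm-pretrain",                        # placeholder job name
        nodes=2,
        logs_path="logs",
        conda_path="~/miniconda3/etc/profile.d/conda.sh",  # per the docstring: the Conda installation script
        conda_env_path="~/miniconda3/envs/nanotron",
        mail_user="you@example.com",
    )
    assert slurm.mail_type == "FAIL"  # backfilled by __post_init__ because mail_user was set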
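Launch-mode sketch (illustrative): without --slurm, the launcher starts one process per visible GPU through a single torchrun call, so dp * tp * pp must equal torch.cuda.device_count(); with this patch's defaults (dp=8, tp=1, pp=1) that means exactly 8 GPUs on the node:

    import torch

    dp, tp, pp = 8, 1, 1                   # parallelism defaults set by this patch
    required = dp * tp * pp                # world size a single torchrun call must provide
    available = torch.cuda.device_count()
    if available != required:
        raise ValueError(f"need {required} GPUs, found {available}")
    # the launcher then runs: torchrun --nproc_per_node 8 run_train.py --config-file <timestamp>.yaml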