From 5eb25497eef1216b5ed380735fcf504fb3dc535c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Fri, 23 Aug 2024 03:32:57 +0000 Subject: [PATCH 01/43] adding slurm config as an argument to better generate slurm for launch and eval --- src/nanotron/config/config.py | 38 +++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 0744dd69..44ec8d2b 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -91,6 +91,23 @@ def __post_init__(self): if self.hf_dataset_splits is None: self.hf_dataset_splits = "train" +@dataclass +class SlurmArgs: + gpu_partition: str + job_name: str + nodes: int + logs_path: Path + n_tasks_per_node: Optional[int] = 1 + cpus_per_task: Optional[int] = 32 + n_gpu: Optional[int] = 8 + email: Optional[str] = None + qos: Optional[str] + array: Optional[str] + slurm_logs_path: Optional[str] = None + evals_logs_path: Optional[str] = None + config_logs_path: Optional[str] = None + + @dataclass class S3UploadArgs: @@ -356,6 +373,7 @@ class Config: profiler: Optional[ProfilerArgs] = None lighteval: Optional[LightEvalConfig] = None s3_upload : Optional[S3UploadArgs] = None + slurm: Optional[SlurmArgs] = None @classmethod def create_empty(cls): @@ -399,6 +417,26 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" + if self.slurm is not None: + job_folder = os.path.join(self.logs_path, self.job_name) + os.makedirs(job_folder, exist_ok=True) + + subfolders = ['configs', 'evals', 'slurm'] + logs_paths = {} + + for subfolder in subfolders: + specific_path = getattr(self.slurm, f"{subfolder}_logs_path", None) + if specific_path is None: + folder_path = os.path.join(job_folder, subfolder) + else: + folder_path = specific_path + os.makedirs(folder_path, exist_ok=True) + logs_paths[subfolder] = folder_path + + self.slurm.config_logs_path = logs_paths['configs'] + self.slurm.evals_logs_path = logs_paths['evals'] + self.slurm.slurm_logs_path = logs_paths['slurm'] + # # if lighteval, we need tokenizer to be defined # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None From 3875c602473acf4260f9e720a3ed013edbb3454a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sat, 24 Aug 2024 19:25:26 +0000 Subject: [PATCH 02/43] working version of lighteval after s3 on 1 node --- run_train.py | 4 +- src/nanotron/config/config.py | 31 +- src/nanotron/helpers.py | 1 - src/nanotron/lighteval/__init__.py | 3 + src/nanotron/lighteval/evaluation_tasks.py | 654 ++++++++++++++++++++ src/nanotron/lighteval/one_job_runner.py | 148 +++++ src/nanotron/lighteval/run_eval.slurm.jinja | 73 +++ src/nanotron/lighteval/run_evals.py | 35 ++ src/nanotron/trainer.py | 10 +- 9 files changed, 947 insertions(+), 12 deletions(-) create mode 100644 src/nanotron/lighteval/__init__.py create mode 100644 src/nanotron/lighteval/evaluation_tasks.py create mode 100644 src/nanotron/lighteval/one_job_runner.py create mode 100644 src/nanotron/lighteval/run_eval.slurm.jinja create mode 100644 src/nanotron/lighteval/run_evals.py diff --git a/run_train.py b/run_train.py index 021d955d..09c27d35 100644 --- a/run_train.py +++ b/run_train.py @@ -37,7 +37,9 @@ except ImportError: hf_hub_version = None tf_version = None - + +import torch._dynamo +torch._dynamo.config.suppress_errors = True logger = logging.get_logger(__name__) diff --git 
a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 44ec8d2b..26bd1546 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -93,19 +93,23 @@ def __post_init__(self): @dataclass class SlurmArgs: - gpu_partition: str job_name: str nodes: int logs_path: Path + # TODO: @elibak: Add a way to handle different virtual environments (conda, venv, uv, etc) For now, we assume conda and user can modify the slurm template if they use something else. + conda_path: str + conda_env_path : str + gpu_partition: Optional[str] = None n_tasks_per_node: Optional[int] = 1 cpus_per_task: Optional[int] = 32 - n_gpu: Optional[int] = 8 - email: Optional[str] = None - qos: Optional[str] - array: Optional[str] + gpu_per_node: Optional[int] = 8 + mail: Optional[str] = None + qos: Optional[str] = "high" + array: Optional[str] = "1-1%1" slurm_logs_path: Optional[str] = None evals_logs_path: Optional[str] = None config_logs_path: Optional[str] = None + @@ -203,6 +207,8 @@ class GeneralArgs: """ project: str + repo_id: Optional[str] = None + temp_dir: Optional[str] = None run: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None @@ -418,10 +424,13 @@ def __post_init__(self): ), "The stages are not sorted by start_training_step in increasing order" if self.slurm is not None: - job_folder = os.path.join(self.logs_path, self.job_name) + job_folder = os.path.join(self.slurm.logs_path, self.slurm.job_name) os.makedirs(job_folder, exist_ok=True) - subfolders = ['configs', 'evals', 'slurm'] + subfolders = ['configs', 'slurm'] + if self.lighteval is not None and self.s3_upload is not None: + subfolders.append('evals') + logs_paths = {} for subfolder in subfolders: @@ -433,9 +442,15 @@ def __post_init__(self): os.makedirs(folder_path, exist_ok=True) logs_paths[subfolder] = folder_path + if subfolder == 'evals': + for evals_subfolder in ['launch-config', 'logs']: + evals_subfolder_path = os.path.join(folder_path, evals_subfolder) + os.makedirs(evals_subfolder_path, exist_ok=True) + self.slurm.config_logs_path = logs_paths['configs'] - self.slurm.evals_logs_path = logs_paths['evals'] self.slurm.slurm_logs_path = logs_paths['slurm'] + if self.lighteval is not None: + self.slurm.evals_logs_path = logs_paths['evals'] # # if lighteval, we need tokenizer to be defined # if self.checkpoints.lighteval is not None: diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 892ac03c..761fffc2 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -52,7 +52,6 @@ def _vocab_size_with_padding(orig_vocab_size: int, pg_size: int, make_vocab_size multiple = make_vocab_size_divisible_by * pg_size after = int(ceil(orig_vocab_size / multiple) * multiple) - print("hello") if after != orig_vocab_size: print("i'm in") log_rank( diff --git a/src/nanotron/lighteval/__init__.py b/src/nanotron/lighteval/__init__.py new file mode 100644 index 00000000..d7ea002c --- /dev/null +++ b/src/nanotron/lighteval/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa: F401 + +from .one_job_runner import LightEvalRunner diff --git a/src/nanotron/lighteval/evaluation_tasks.py b/src/nanotron/lighteval/evaluation_tasks.py new file mode 100644 index 00000000..2c0d9449 --- /dev/null +++ b/src/nanotron/lighteval/evaluation_tasks.py @@ -0,0 +1,654 @@ +# ruff: noqa: F405, F403, F401 +""" +Custom evaluation tasks for lighteval + +This file generally create just a TASKS_TABLE and TASKS_GROUPS which are then imported by LightEval. 
+Edit this file to add your own task if needed +""" +import re +from dataclasses import asdict +from typing import Dict, List, Tuple + +from lighteval.metrics import Metrics +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc +from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES + +_TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] +_TASKS: List[LightevalTaskConfig] = [] + +## COMMON_SENSE_REASONING_TASKS ## +COMMON_SENSE_REASONING_TASKS = [ + LightevalTaskConfig( + name="hellaswag", + prompt_function="hellaswag_prompt", + hf_repo="hellaswag", + hf_subset="default", + metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + ), + LightevalTaskConfig( + name="winogrande", + prompt_function="winogrande", + hf_repo="winogrande", + hf_subset="winogrande_xl", + metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + ), + LightevalTaskConfig( + name="piqa", + prompt_function="piqa_harness", + hf_repo="piqa", + hf_subset="plain_text", + metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + ), + LightevalTaskConfig( + name="siqa", + prompt_function="siqa_prompt", + hf_repo="lighteval/siqa", + hf_subset="default", + hf_avail_splits=["train", "validation"], + metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + ), + LightevalTaskConfig( + name="openbookqa", + prompt_function="openbookqa", + hf_repo="openbookqa", + hf_subset="main", + metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + ), + LightevalTaskConfig( + name="arc:easy", + prompt_function="arc", + hf_repo="ai2_arc", + hf_subset="ARC-Easy", + evaluation_splits=["test"], + generation_size=1, + metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + ), + LightevalTaskConfig( + name="arc:challenge", + prompt_function="arc", + hf_repo="ai2_arc", + hf_subset="ARC-Challenge", + evaluation_splits=["test"], + generation_size=1, + metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + ), + LightevalTaskConfig( + name="commonsense_qa", + prompt_function="commonsense_qa_prompt", + hf_repo="commonsense_qa", + hf_subset="default", + metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + ), +] + + +def commonsense_qa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"], + choices=[f" {c}" for c in line["choices"]["text"]], + gold_index=LETTER_INDICES.index(line["answerKey"].strip()), + instruction="", + ) + + +def siqa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["context"] + " " + line["question"], + choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]], + gold_index=int(line["label"]) - 1, + instruction="", + ) + + +def hellaswag_prompt(line, task_name: str = None): + def preprocess(text): + """Comes from AiHarness""" + # text = text.strip() + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + text = text.replace(" [title]", ". 
") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} " + return Doc( + task_name=task_name, + query=preprocess(line["activity_label"] + ": " + ctx), + choices=[" " + preprocess(ending) for ending in line["endings"]], + gold_index=int(line["label"]) if line["label"] != "" else -1, # -1 for test + # "metric": "choices_loglikelihood", + ) + + +# 0 short for common sense +COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS] +_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING) +_TASKS += COMMON_SENSE_REASONING_TASKS + +## WORLD_KNOWLEDGE_TASKS ## + +WORLD_KNOWLEDGE_TASKS = [ + LightevalTaskConfig( + name="trivia_qa", + prompt_function="triviaqa", + hf_repo="trivia_qa", + hf_subset="rc.nocontext", + metric=[Metrics.quasi_exact_match], + generation_size=20, + stop_sequence=["\n", ".", ","], + ), + LightevalTaskConfig( + name="natural_questions", + prompt_function="natural_questions_prompt", + hf_repo="lighteval/natural_questions_clean", + hf_subset="default", + metric=[Metrics.quasi_exact_match], + generation_size=20, + stop_sequence=["\n", ".", ","], + ), +] + + +def natural_questions_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"] + "?\nAnswer: ", + choices=[line["short_answers"]], + gold_index=0, + instruction="", + ) + + +WORLD_KNOWLEDGE_STRING = [(t, f"custom|{t.name}|5|1") for t in WORLD_KNOWLEDGE_TASKS] +# WORLD_KNOWLEDGE_STRING = {t: f'custom|{t.name}|0|1' for t in WORLD_KNOWLEDGE_TASKS} +_TASKS_STRINGS.extend(WORLD_KNOWLEDGE_STRING) +_TASKS += WORLD_KNOWLEDGE_TASKS + +## Reading comprehension ## + +READING_COMP_TASKS = [ + LightevalTaskConfig( + name="super_glue:boolq", + prompt_function="boolq_prompt", + hf_repo="super_glue", + hf_subset="boolq", + metric=["target_perplexity"], + ), + LightevalTaskConfig( + name="quac", + prompt_function="quac", + hf_repo="lighteval/quac_helm", + hf_subset="deault", + metric=[Metrics.quasi_exact_match], + generation_size=20, + stop_sequence=["\n", ".", ","], + ), +] + + +def boolq_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:", + choices=[" No", " Yes"], # Only gold + gold_index=int(line["label"]), + ) + + +READING_COMP_STRING = [(t, f"custom|{t.name}|0|1") for t in READING_COMP_TASKS] +_TASKS_STRINGS.extend(READING_COMP_STRING) +_TASKS += READING_COMP_TASKS + + +## MATH ## +class CustomMathEvaluationTask(LightevalTaskConfig): + """Custom class for math tasks with all the defaults set""" + + def __init__( + self, + name, + prompt_function="math", + hf_repo="lighteval/MATH", + hf_subset=None, + metric=[Metrics.quasi_exact_match_math], + hf_avail_splits=None, + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + suite=["custom"], + generation_size=40, + stop_sequence=None, + output_regex=None, + frozen=False, + ): + super().__init__( + name=name, + prompt_function=prompt_function, + hf_repo=hf_repo, + hf_subset=hf_subset, + metric=metric, + hf_avail_splits=hf_avail_splits, + evaluation_splits=evaluation_splits, + few_shots_split=few_shots_split, + few_shots_select=few_shots_select, + suite=suite, + generation_size=generation_size, + stop_sequence=stop_sequence, + output_regex=output_regex, + frozen=frozen, + ) + + +MATH_TASKS = [ + CustomMathEvaluationTask(name="math:algebra", hf_subset="algebra"), + 
CustomMathEvaluationTask(name="math:counting_and_probability", hf_subset="counting_and_probability"), + CustomMathEvaluationTask(name="math:geometry", hf_subset="geometry"), + CustomMathEvaluationTask(name="math:intermediate_algebra", hf_subset="intermediate_algebra"), + CustomMathEvaluationTask(name="math:number_theory", hf_subset="number_theory"), + CustomMathEvaluationTask(name="math:prealgebra", hf_subset="prealgebra"), + CustomMathEvaluationTask(name="math:precalculus", hf_subset="precalculus"), +] +GSM8K = LightevalTaskConfig( + name="gsm8k", + prompt_function="gsm8k", + hf_repo="gsm8k", + hf_subset="main", + hf_avail_splits=["train", "test"], + evaluation_splits=["test"], + metric=[Metrics.perfect_exact_match], + generation_size=10, + stop_sequence=["\n"], +) + + +MATH_STRING = [(t, f"custom|{t.name}|4|1") for t in MATH_TASKS] +GSM8K_STRING = [(GSM8K, f"custom|{GSM8K.name}|8|1")] +_TASKS_STRINGS.extend(MATH_STRING) +_TASKS_STRINGS.extend(GSM8K_STRING) +_TASKS += MATH_TASKS + [GSM8K] + + +## MMLU ## +class CustomMMLUEvaluationTask(LightevalTaskConfig): + def __init__( + self, + name, + prompt_function="mmlu_prompt", + hf_repo="lighteval/mmlu", + hf_subset=None, + # metric=[Metrics.loglikelihood_acc_single_token], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], + hf_avail_splits=None, + evaluation_splits=["test"], + few_shots_split="dev", + few_shots_select=None, + suite=None, + generation_size=-1, + stop_sequence=None, + output_regex=None, + frozen=False, + ): + super().__init__( + name=name, + prompt_function=prompt_function, + hf_repo=hf_repo, + hf_subset=hf_subset, + metric=metric, + hf_avail_splits=hf_avail_splits, + evaluation_splits=evaluation_splits, + few_shots_split=few_shots_split, + few_shots_select=few_shots_select, + suite=suite, + generation_size=generation_size, + stop_sequence=stop_sequence, + output_regex=output_regex, + frozen=frozen, + ) + + +MMLU_TASKS = [ + CustomMMLUEvaluationTask(name="mmlu:abstract_algebra", hf_subset="abstract_algebra"), + CustomMMLUEvaluationTask(name="mmlu:anatomy", hf_subset="anatomy"), + CustomMMLUEvaluationTask(name="mmlu:astronomy", hf_subset="astronomy"), + CustomMMLUEvaluationTask(name="mmlu:business_ethics", hf_subset="business_ethics"), + CustomMMLUEvaluationTask(name="mmlu:clinical_knowledge", hf_subset="clinical_knowledge"), + CustomMMLUEvaluationTask(name="mmlu:college_biology", hf_subset="college_biology"), + CustomMMLUEvaluationTask(name="mmlu:college_chemistry", hf_subset="college_chemistry"), + CustomMMLUEvaluationTask(name="mmlu:college_computer_science", hf_subset="college_computer_science"), + CustomMMLUEvaluationTask(name="mmlu:college_mathematics", hf_subset="college_mathematics"), + CustomMMLUEvaluationTask(name="mmlu:college_medicine", hf_subset="college_medicine"), + CustomMMLUEvaluationTask(name="mmlu:college_physics", hf_subset="college_physics"), + CustomMMLUEvaluationTask(name="mmlu:computer_security", hf_subset="computer_security"), + CustomMMLUEvaluationTask(name="mmlu:conceptual_physics", hf_subset="conceptual_physics"), + CustomMMLUEvaluationTask(name="mmlu:econometrics", hf_subset="econometrics"), + CustomMMLUEvaluationTask(name="mmlu:electrical_engineering", hf_subset="electrical_engineering"), + CustomMMLUEvaluationTask(name="mmlu:elementary_mathematics", hf_subset="elementary_mathematics"), + CustomMMLUEvaluationTask(name="mmlu:formal_logic", hf_subset="formal_logic"), + CustomMMLUEvaluationTask(name="mmlu:global_facts", hf_subset="global_facts"), + 
CustomMMLUEvaluationTask(name="mmlu:high_school_biology", hf_subset="high_school_biology"), + CustomMMLUEvaluationTask(name="mmlu:high_school_chemistry", hf_subset="high_school_chemistry"), + CustomMMLUEvaluationTask(name="mmlu:high_school_computer_science", hf_subset="high_school_computer_science"), + CustomMMLUEvaluationTask(name="mmlu:high_school_european_history", hf_subset="high_school_european_history"), + CustomMMLUEvaluationTask(name="mmlu:high_school_geography", hf_subset="high_school_geography"), + CustomMMLUEvaluationTask( + name="mmlu:high_school_government_and_politics", hf_subset="high_school_government_and_politics" + ), + CustomMMLUEvaluationTask(name="mmlu:high_school_macroeconomics", hf_subset="high_school_macroeconomics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_mathematics", hf_subset="high_school_mathematics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_microeconomics", hf_subset="high_school_microeconomics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_physics", hf_subset="high_school_physics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_psychology", hf_subset="high_school_psychology"), + CustomMMLUEvaluationTask(name="mmlu:high_school_statistics", hf_subset="high_school_statistics"), + CustomMMLUEvaluationTask(name="mmlu:high_school_us_history", hf_subset="high_school_us_history"), + CustomMMLUEvaluationTask(name="mmlu:high_school_world_history", hf_subset="high_school_world_history"), + CustomMMLUEvaluationTask(name="mmlu:human_aging", hf_subset="human_aging"), + CustomMMLUEvaluationTask(name="mmlu:human_sexuality", hf_subset="human_sexuality"), + CustomMMLUEvaluationTask(name="mmlu:international_law", hf_subset="international_law"), + CustomMMLUEvaluationTask(name="mmlu:jurisprudence", hf_subset="jurisprudence"), + CustomMMLUEvaluationTask(name="mmlu:logical_fallacies", hf_subset="logical_fallacies"), + CustomMMLUEvaluationTask(name="mmlu:machine_learning", hf_subset="machine_learning"), + CustomMMLUEvaluationTask(name="mmlu:management", hf_subset="management"), + CustomMMLUEvaluationTask(name="mmlu:marketing", hf_subset="marketing"), + CustomMMLUEvaluationTask(name="mmlu:medical_genetics", hf_subset="medical_genetics"), + CustomMMLUEvaluationTask(name="mmlu:miscellaneous", hf_subset="miscellaneous"), + CustomMMLUEvaluationTask(name="mmlu:moral_disputes", hf_subset="moral_disputes"), + CustomMMLUEvaluationTask(name="mmlu:moral_scenarios", hf_subset="moral_scenarios"), + CustomMMLUEvaluationTask(name="mmlu:nutrition", hf_subset="nutrition"), + CustomMMLUEvaluationTask(name="mmlu:philosophy", hf_subset="philosophy"), + CustomMMLUEvaluationTask(name="mmlu:prehistory", hf_subset="prehistory"), + CustomMMLUEvaluationTask(name="mmlu:professional_accounting", hf_subset="professional_accounting"), + CustomMMLUEvaluationTask(name="mmlu:professional_law", hf_subset="professional_law"), + CustomMMLUEvaluationTask(name="mmlu:professional_medicine", hf_subset="professional_medicine"), + CustomMMLUEvaluationTask(name="mmlu:professional_psychology", hf_subset="professional_psychology"), + CustomMMLUEvaluationTask(name="mmlu:public_relations", hf_subset="public_relations"), + CustomMMLUEvaluationTask(name="mmlu:security_studies", hf_subset="security_studies"), + CustomMMLUEvaluationTask(name="mmlu:sociology", hf_subset="sociology"), + CustomMMLUEvaluationTask(name="mmlu:us_foreign_policy", hf_subset="us_foreign_policy"), + CustomMMLUEvaluationTask(name="mmlu:virology", hf_subset="virology"), + CustomMMLUEvaluationTask(name="mmlu:world_religions", 
hf_subset="world_religions"), +] + + +def mmlu_harness(line, task_name: str = None): + topic = line["subject"] + prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" + prompt += line["question"] + "\n" + prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) + prompt += "Answer:" + + gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] + "__few_shots" in line and line["__few_shots"] is True # We are adding few shots + + return Doc( + task_name=task_name, + query=prompt, + choices=[" A", " B", " C", " D"], + target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix], + gold_index=gold_ix, + instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", + ) + + +def mmlu_prompt(line, task_name: str = None): + """MMLU prompt without letters""" + topic = line["subject"] + prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: " + prompt += line["question"] + "\nAnswer:" + + return Doc( + task_name=task_name, + query=prompt, + choices=[f" {c}" for c in line["choices"]], + gold_index=line["answer"], + instruction=f"The following are questions about {topic.replace('_', ' ')}.\n", + ) + + +# MMLU_STRING = {t: f'custom|{t.name}|5|1' for t in MMLU_TASKS} +MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS] +_TASKS_STRINGS.extend(MMLU_STRING) +_TASKS += MMLU_TASKS + +## BBH ## + + +class CustomBBHEvaluationTask(LightevalTaskConfig): + def __init__( + self, + name, + prompt_function="bbh_prompt", + hf_repo="lighteval/big_bench_hard", + hf_subset=None, + metric=[Metrics.exact_match], + hf_avail_splits=["train"], + evaluation_splits=["train"], + few_shots_split="train", + few_shots_select=None, + suite=None, + generation_size=4, + stop_sequence=None, + output_regex=None, + frozen=False, + ): + super().__init__( + name=name, + prompt_function=prompt_function, + hf_repo=hf_repo, + hf_subset=hf_subset, + metric=metric, + hf_avail_splits=hf_avail_splits, + evaluation_splits=evaluation_splits, + few_shots_split=few_shots_split, + few_shots_select=few_shots_select, + suite=suite, + generation_size=generation_size, + stop_sequence=stop_sequence, + output_regex=output_regex, + frozen=frozen, + ) + + +BBH_TASKS = [ + CustomBBHEvaluationTask(name="bbh:boolean_expressions", hf_subset="boolean_expressions"), + CustomBBHEvaluationTask(name="bbh:causal_judgement", hf_subset="causal_judgement"), + CustomBBHEvaluationTask(name="bbh:date_understanding", hf_subset="date_understanding"), + CustomBBHEvaluationTask(name="bbh:disambiguation_qa", hf_subset="disambiguation_qa"), + CustomBBHEvaluationTask(name="bbh:dyck_languages", hf_subset="dyck_languages"), + CustomBBHEvaluationTask(name="bbh:formal_fallacies", hf_subset="formal_fallacies"), + CustomBBHEvaluationTask(name="bbh:geometric_shapes", hf_subset="geometric_shapes"), + CustomBBHEvaluationTask(name="bbh:hyperbaton", hf_subset="hyperbaton"), + CustomBBHEvaluationTask(name="bbh:logical_deduction_five_objects", hf_subset="logical_deduction_five_objects"), + CustomBBHEvaluationTask(name="bbh:logical_deduction_seven_objects", hf_subset="logical_deduction_seven_objects"), + CustomBBHEvaluationTask(name="bbh:logical_deduction_three_objects", hf_subset="logical_deduction_three_objects"), + CustomBBHEvaluationTask(name="bbh:movie_recommendation", hf_subset="movie_recommendation"), + CustomBBHEvaluationTask(name="bbh:multistep_arithmetic_two", 
hf_subset="multistep_arithmetic_two"), + CustomBBHEvaluationTask(name="bbh:navigate", hf_subset="navigate"), + CustomBBHEvaluationTask(name="bbh:object_counting", hf_subset="object_counting"), + CustomBBHEvaluationTask(name="bbh:penguins_in_a_table", hf_subset="penguins_in_a_table"), + CustomBBHEvaluationTask(name="bbh:reasoning_about_colored_objects", hf_subset="reasoning_about_colored_objects"), + CustomBBHEvaluationTask(name="bbh:ruin_names", hf_subset="ruin_names"), + CustomBBHEvaluationTask( + name="bbh:salient_translation_error_detection", hf_subset="salient_translation_error_detection" + ), + CustomBBHEvaluationTask(name="bbh:snarks", hf_subset="snarks"), + CustomBBHEvaluationTask(name="bbh:sports_understanding", hf_subset="sports_understanding"), + CustomBBHEvaluationTask(name="bbh:temporal_sequences", hf_subset="temporal_sequences"), + CustomBBHEvaluationTask( + name="bbh:tracking_shuffled_objects_five_objects", hf_subset="tracking_shuffled_objects_five_objects" + ), + CustomBBHEvaluationTask( + name="bbh:tracking_shuffled_objects_seven_objects", hf_subset="tracking_shuffled_objects_seven_objects" + ), + CustomBBHEvaluationTask( + name="bbh:tracking_shuffled_objects_three_objects", hf_subset="tracking_shuffled_objects_three_objects" + ), + CustomBBHEvaluationTask(name="bbh:web_of_lies", hf_subset="web_of_lies"), + CustomBBHEvaluationTask(name="bbh:word_sorting", hf_subset="word_sorting"), +] + + +def bbh_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["input"] + "\nAnswer: ", + choices=[line["target"]], + gold_index=0, + ) + + +# BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS} +BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS] +_TASKS_STRINGS.extend(BBH_STRING) +_TASKS += BBH_TASKS + + +## AGI eval ## +class CustomAGIEvalEvaluationTask(LightevalTaskConfig): + def __init__( + self, + name, + prompt_function="agi_eval_prompt_no_letters", + hf_repo="lighteval/agi_eval_en", + hf_subset=None, + # metric=[Metrics.loglikelihood_acc_single_token], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], + hf_avail_splits=["train", "validation"], + evaluation_splits=["train"], + few_shots_split="validation", + few_shots_select=None, + suite=None, + generation_size=-1, + stop_sequence=None, + output_regex=None, + frozen=False, + ): + super().__init__( + name=name, + prompt_function=prompt_function, + hf_repo=hf_repo, + hf_subset=hf_subset, + metric=metric, + hf_avail_splits=hf_avail_splits, + evaluation_splits=evaluation_splits, + few_shots_split=few_shots_split, + few_shots_select=few_shots_select, + suite=suite, + generation_size=generation_size, + stop_sequence=stop_sequence, + output_regex=output_regex, + frozen=frozen, + ) + + +AGIEVAL_TASKS = [ + CustomAGIEvalEvaluationTask(name="agi_eval:aqua_rat", hf_subset="aqua_rat"), + CustomAGIEvalEvaluationTask(name="agi_eval:logiqa-en", hf_subset="logiqa-en"), + CustomAGIEvalEvaluationTask(name="agi_eval:lsat-ar", hf_subset="lsat-ar"), + CustomAGIEvalEvaluationTask(name="agi_eval:lsat-lr", hf_subset="lsat-lr"), + CustomAGIEvalEvaluationTask(name="agi_eval:lsat-rc", hf_subset="lsat-rc"), + CustomAGIEvalEvaluationTask( + name="agi_eval:math", + hf_subset="math", + prompt_function="agi_eval_math_prompt", + metric=[Metrics.exact_match, Metrics.quasi_exact_match], + generation_size=40, + ), + CustomAGIEvalEvaluationTask(name="agi_eval:sat-en", hf_subset="sat-en"), + CustomAGIEvalEvaluationTask(name="agi_eval:sat-math", hf_subset="sat-math"), +] + + +def 
agi_eval_math_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"], + choices=[line["answer"]], + gold_index=0, + instruction="", + ) + + +def agi_eval_prompt(line, task_name: str = None): + cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]] + prompt = "The following are multiple choice questions (with answers).\n\n" + prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n" + prompt += "Answer: " + + choices = LETTER_INDICES[: len(line["options"])] + + output = Doc( + query=prompt, + instruction="The following are multiple choice questions (with answers).\n\n", + choices=None, # updated below + gold_index=None, # updated below + ) + + if line["label"]: + output.choices = choices + output.gold_index = LETTER_INDICES.index(line["label"].strip()) + else: + output.choices = [line["answer"]] + output.gold_index = 0 + + return output + + +def agi_eval_prompt_no_letters(line, task_name: str = None): + cleaned_options = [ + " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "") + for o in line["options"] + ] + + output = Doc( + query=line["question"], + choices=cleaned_options, + gold_index=LETTER_INDICES.index(line["label"].strip()), + instruction="", + ) + + return output + + +# AGIEVAL_STRING = {t: f'custom|{t.name}|5|1' for t in AGIEVAL_TASKS} +AGIEVAL_STRING = [(t, f"custom|{t.name}|0|1") for t in AGIEVAL_TASKS] +_TASKS_STRINGS.extend(AGIEVAL_STRING) +_TASKS += AGIEVAL_TASKS + + +OPEN_LLM_LEADERBOARD_STRING = [ + "custom|arc:challenge|25|1", + "custom|hellaswag|10|1", + "lighteval|truthfulqa:mc|0|1", + "custom|winogrande|5|1", + "lighteval|gsm8k|5|1", +] + [f"custom|{t.name}|5|1" for t in MMLU_TASKS] + + +## HUMAN EVAL ## +#TODO @eliebak add human eval again +# human_eval = LightevalTaskConfig( +# name="human_eval", +# prompt_function="human_eval", +# hf_repo="lighteval/human_eval", +# metric=["human_eval_pass_at_1"], +# ), + + +EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING]) + +# Convert to dict for lighteval +TASKS_TABLE = [task.as_dict() for task in _TASKS] +# You can have a few pre-organised groups of tasks +# TODO @eliebak add math and code here +TASKS_GROUPS = { + "all": ",".join(t[1] for t in _TASKS_STRINGS), + "early-signal": EARLY_SIGNAL_TASKS, + "open-llm-leaderboard": ",".join(OPEN_LLM_LEADERBOARD_STRING), +} + +if __name__ == "__main__": + print(t["name"] for t in TASKS_TABLE) + print(len(TASKS_TABLE)) diff --git a/src/nanotron/lighteval/one_job_runner.py b/src/nanotron/lighteval/one_job_runner.py new file mode 100644 index 00000000..30e7eaf8 --- /dev/null +++ b/src/nanotron/lighteval/one_job_runner.py @@ -0,0 +1,148 @@ +""" Mostly complete a SLURM template with a link to a single checkpoint on s3 and launch it +""" +import datetime +import os +import re +import subprocess +from typing import List, Optional, Tuple, Union + +import jinja2 +from nanotron import logging +from nanotron.logging import log_rank +from nanotron.parallel import ParallelContext + +from nanotron.config import Config, LightEvalConfig + +logger = logging.get_logger(__name__) + + +class LightEvalRunner: + def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = None): + self.config = config + self.lighteval_config = config.lighteval + self.parallel_context = parallel_context + + def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: + """Run light evaluation on 
uploaded files.""" + logger.warning(f"Lighteval Runner got {len(uploaded_files)} files. Checking configs.") + config_files = [ + f for f in uploaded_files if "config.py" in f["destination"] or "config.yaml" in f["destination"] + ] + # Sanity check on the config files len (we want only one) + if len(config_files) == 0: + log_rank( + "No config files founds in uploaded checkpoints. Not running evaluation.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + if len(config_files) > 1: + log_rank( + "Found multiple config files in uploaded checkpoints.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return + checkpoint_path = config_files[0]["destination"].replace("config.yaml", "") + + slurm_job_id, slurm_log = run_slurm_one_job( + config = self.config, + slurm_template=self.lighteval_config.slurm_template, + model_checkpoint_path=checkpoint_path, + ) + + return slurm_job_id, slurm_log + + +def run_slurm_one_job( + config: Config, + model_checkpoint_path: str, + slurm_template: str, + slurm_name: Optional[str] = "eval", + slurm_kwargs: Optional[dict] = None, #add slurm_kwargs and modify the jinja template in case you need to adapt it to your slurm cluster. +): + """Launch a single job on Slurm with the given mapping + Args: + slurm_config: Slurm configuration + mapping: Mapping to use for the job script (see SLURM_ONE_JOB_MAPPING) + """ + + eval_launch_script_path=os.path.join(config.slurm.evals_logs_path, "launch-config") + eval_logs_path= os.path.join(config.slurm.evals_logs_path, "logs") + + environment = jinja2.Environment( + comment_start_string="{=", + comment_end_string="=}", + ) + + with open(slurm_template, "r") as f: + SLURM_JOBS_ARRAY_TEMPLATE = environment.from_string(f.read()) + + launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render( + model_checkpoint_path=model_checkpoint_path, + job_name=f"{slurm_name}-eval", + n_tasks_per_node=config.slurm.n_tasks_per_node, + partition=config.slurm.gpu_partition, + gpu_per_node=config.slurm.gpu_per_node, + cpus_per_task=config.slurm.cpus_per_task, + eval_path=eval_logs_path, + mail=config.slurm.mail, + conda_path=config.slurm.conda_path, + conda_env_path=config.slurm.conda_env_path, + local_path=config.checkpoints.checkpoints_path, + **(slurm_kwargs if slurm_kwargs else {}), + ) + + match = re.match(r"#SBATCH --output=(.*)", launch_string) + slurm_output_path = match.group(1) if match else "" + + current_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + launch_script_path = os.path.join(eval_launch_script_path, f"launch_script-{current_time}.slurm") + + # make sure the folder exists before write + # Extract the folder path from launch_script_path + folder_path = os.path.dirname(launch_script_path) + + # Check if the folder exists. If not, create it. 
+ if not os.path.exists(folder_path): + os.makedirs(folder_path) + + with open(launch_script_path, "w") as f: + f.write(launch_string) + + logger.warning(f'Launching Slurm job {slurm_name} with launch script "{launch_script_path}"') + + # Preserve important environment variables + env = { + 'PATH': os.environ['PATH'], + 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), + 'HOME': os.path.expanduser("~"), + } + + try: + # Use subprocess.run instead of check_output for better error handling + result = subprocess.run( + ["sbatch", launch_script_path], + env=env, + check=True, + capture_output=True, + text=True + ) + output = result.stdout + job_ids = output.split()[-1] + output_log = ( + slurm_output_path.replace("%x", slurm_name).replace("%j", job_ids).replace("%n", "0").replace("%t", "0") + ) + logger.warning(f'Slurm job launched successfully with id={job_ids}, logging outputs at "{output_log}"') + except subprocess.CalledProcessError as e: + logger.error(f"Error while launching Slurm job: {e}") + logger.error(f"Command output: {e.output}") + logger.error(f"Command stderr: {e.stderr}") + job_ids = None + output_log = None + + return job_ids, output_log diff --git a/src/nanotron/lighteval/run_eval.slurm.jinja b/src/nanotron/lighteval/run_eval.slurm.jinja new file mode 100644 index 00000000..6b2e1245 --- /dev/null +++ b/src/nanotron/lighteval/run_eval.slurm.jinja @@ -0,0 +1,73 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }}-eval +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --partition={{ partition }} +#SBATCH --gres=gpu:{{ gpu_per_node }} +#SBATCH --cpus-per-task={{ cpus_per_task}} +#SBATCH --output={{ eval_path }}/eval-%x-%n-%j +#SBATCH --error={{ eval_path }}/eval-%x-%n-%j +#SBATCH --qos=high +#SBATCH --dependency=singleton +#SBATCH --mail-type=FAIL +#SBATCH --mail-user={{ mail }} + +########################################### +source ~/.bashrc +source {{ conda_path }} +conda activate {{ conda_env_path}} #Modify this line if you use something different than conda + +LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} + +# [END] ADAPT TO YOUR ENVIRONMENT +########################################### + + +set -x -e +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + # Attempt to read the token from the cache + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." 
+ exit 1 + fi +fi + +export TMPDIR=/scratch +export CUBLAS_WORKSPACE_CONFIG=":4096:8" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +echo go $COUNT_NODE +echo $HOSTNAMES + +# Copying checkpoint from s3 to the node on node +mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER +s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER + +torch_dist_args="--nproc_per_node 8 \ + --nnodes $COUNT_NODE \ + --max_restarts 0 \ + --tee 3 \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: " + +launch_args="$torch_dist_args \ + /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ + " + +srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/src/nanotron/lighteval/run_evals.py b/src/nanotron/lighteval/run_evals.py new file mode 100644 index 00000000..f0f112ec --- /dev/null +++ b/src/nanotron/lighteval/run_evals.py @@ -0,0 +1,35 @@ +# flake8: noqa: C901 +import argparse + +from lighteval.main import main + +from nanotron.config import Config + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint-config-path", + type=str, + required=True, + help="Path to the Nanotron checkpoint YAML or python config file, potentially on S3", + ) + parser.add_argument( + "--lighteval-override", + type=str, + help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config", + ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Cache directory", + ) + + return parser + + +if __name__ == "__main__": + parser = get_parser() + args, unknowns = parser.parse_known_args() + main(args.checkpoint_config_path, args.lighteval_override, args.cache_dir, config_cls=Config) \ No newline at end of file diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 0a7f57f1..471a2018 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -20,13 +20,13 @@ ) from nanotron.s3_checkpoints import S3Mover, check_path_is_local -from nanotron.utils import check_path_is_s3 import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader from nanotron import distributed as dist from nanotron import logging +from nanotron.lighteval import LightEvalRunner from nanotron.config import ( Config, DatasetStageArgs, @@ -278,6 +278,12 @@ def post_init(self): ) else: self.s3_mover = None + if self.config.lighteval is not None and dist.get_rank(self.parallel_context.world_pg) == 0: + # We only start evaluation runs once on the first node + if self.s3_mover is None: + raise ValueError("lighteval requires s3 upload of checkpoints to be enabled") + self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint def pre_training(self, *args, **kwargs): self._print_training_plan() @@ -860,7 +866,7 @@ def pre_save_checkpoint(self) -> Path: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs - self.log_object({"job_id": slurm_job_id, "log": slurm_log}, "slurm_eval") + log_rank(f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", logger=logger, level=logging.INFO, rank=0) def post_save_checkpoint(self): # Upload to S3 From 
e609d1c665062992c936e8c5a7cd7f559b0647b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sat, 24 Aug 2024 19:25:51 +0000 Subject: [PATCH 03/43] add first version of launcher (still ugly) --- launcher.py | 377 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 377 insertions(+) create mode 100644 launcher.py diff --git a/launcher.py b/launcher.py new file mode 100644 index 00000000..03a29754 --- /dev/null +++ b/launcher.py @@ -0,0 +1,377 @@ +import os +import subprocess +import tempfile +from datetime import datetime +import math +import torch + +import argparse + +from nanotron.logging import human_format +from nanotron.models.llama import LlamaConfig + +from nanotron.config import ( + Config, + DataArgs, + NanosetDatasetsArgs, + S3UploadArgs, + SlurmArgs, + CheckpointsArgs, + GeneralArgs, + LightEvalConfig, + LightEvalLoggingArgs, + LightEvalTasksArgs, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + AdamWOptimizerArgs, + ParallelismArgs, + RandomInit, + TokenizerArgs, + TokensArgs, + LightEvalWandbLoggerConfig, + DatasetStageArgs, +) + +def launch_slurm_job(launch_file_contents, *args): + """ + Small helper function to save a sbatch script and call it. + Args: + launch_file_contents: Contents of the sbatch script + *args: any other arguments to pass to the sbatch command + + Returns: the id of the launched slurm job + + """ + with tempfile.NamedTemporaryFile("w") as f: + f.write(launch_file_contents) + f.flush() + return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1] + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("project", help="project name", type=str) + parser.add_argument("--slurm", help="use slurm", action="store_true") + parser.add_argument("--name", help="run name", type=str, default=None) + parser.add_argument("--seed", help="seed", type=int, default=8) + parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="normal") + args = parser.parse_args() + + PROJECT = args.project + if args.name is not None: + RUN = f"{PROJECT}-{args.name}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + else: + RUN = f"{PROJECT}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + + ## FOR SANITY CHECK LATER + # from dataclasses import fields, is_dataclass + + # def print_differences(target, updates): + # if not is_dataclass(target) or not is_dataclass(updates): + # raise ValueError("Both target and updates should be dataclass instances") + + # for field in fields(target): + # update_value = getattr(updates, field.name) + + # if update_value is not None: + # if is_dataclass(update_value): + # print_differences(getattr(target, field.name), update_value) + # else: + # target_value = getattr(target, field.name) + # if update_value != target_value: + # if update_value.__class__.__module__ != "builtins": + # continue + # print(f"{field.name}: {target_value} -> {update_value}") + + + general = GeneralArgs( + project=PROJECT, + run=RUN, + repo_id="HuggingFaceSmol/test-nanotron", + seed=args.seed, + temp_dir="/scratch", + ) + if args.slurm: + slurm = SlurmArgs( + gpu_partition="hopper-prod", + job_name=f"{PROJECT}-{args.name}", + nodes=2, + logs_path=f"/fsx/elie_bakouch/nanotron/debug", + conda_path="/fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh", + conda_env_path="/fsx/elie_bakouch/miniconda3/envs/smollm", + ) + + model_config = LlamaConfig( + bos_token_id=0, + eos_token_id=0, + hidden_act="silu", + hidden_size=576, + 
initializer_range=0.02, + intermediate_size=1536, + max_position_embeddings=2048, + num_attention_heads=9, + num_hidden_layers=30, + num_key_value_heads=3, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + tie_word_embeddings=True, + use_cache=True, + vocab_size=49152, + ) + if model_config.tie_word_embeddings ==True: + tie_word_embeddings_multiplier = 1 + else: + tie_word_embeddings_multiplier = 2 + + num_params = human_format( + model_config.vocab_size * model_config.hidden_size * tie_word_embeddings_multiplier + + model_config.num_hidden_layers + * ( + 3 * model_config.hidden_size * model_config.intermediate_size + + 4 * model_config.hidden_size * model_config.hidden_size + ) + ).replace(".", "p") + + print(f"🏋️ Model has {num_params} parameters") + + # Do we have a SLURM task ID? + # You can SLURM_ARRAY_TASK_ID to run multiple runs with predefined HP + task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", -1)) + job_id = os.environ.get("SLURM_JOB_ID", "") + + + + lighteval = LightEvalConfig( + tasks=LightEvalTasksArgs( + tasks="early-signal", # "generatives", "all" + custom_tasks="nanotron.lighteval.evaluation_tasks", + max_samples=1000, # Cap very large evals or for debugging + dataset_loading_processes=8, + ), + parallelism=ParallelismArgs( + dp=8, + pp=1, + tp=1, + pp_engine="1f1b", + tp_mode="ALL_REDUCE", + # recompute_granularity="selective", + tp_linear_async_communication=False, + ), + batch_size=16, + wandb=LightEvalWandbLoggerConfig( + wandb_project=PROJECT, + wandb_entity="eliebak", + wandb_run_name=f"{RUN}_evals", + ), + logging=LightEvalLoggingArgs( + local_output_path=f"{general.temp_dir}/lighteval/{RUN}", + push_details_to_hub=False, + push_results_to_hub=True, + push_results_to_tensorboard=True, + #hub_repo_details=REPO_ID, + hub_repo_results=general.repo_id, + hub_repo_tensorboard=general.repo_id, + tensorboard_metric_prefix="e", + ), + slurm_template="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_eval.slurm.jinja", + ) + + + checkpoints = CheckpointsArgs( + checkpoints_path=f"checkpoints/{RUN}", + checkpoints_path_is_shared_file_system=False, + resume_checkpoint_path=None, + checkpoint_interval=20, + save_initial_state=False, + ) + + parallelism = ParallelismArgs( + dp=16, + pp=1, + tp=1, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, + ) + #Add sanity check for the number of GPUs and the number of nodes ? 
+ print(f"🤖 {slurm.nodes} Nodes | {parallelism.dp*parallelism.pp*parallelism.tp} GPUs | 3D Config : DP {parallelism.dp} / PP {parallelism.pp} / TP {parallelism.tp}") + + tokens = TokensArgs( + batch_accumulation_per_replica=8, + micro_batch_size=16, + sequence_length=2048, + train_steps=100, + val_check_interval=-1, + ) + + model = ModelArgs( + model_config=model_config, + make_vocab_size_divisible_by=1, + init_method=RandomInit( + std=math.sqrt(model_config.hidden_size), + ), + dtype=torch.bfloat16, + ) + + logging = LoggingArgs( + # 'debug', 'info', 'warning', 'error', 'critical' and 'passive' + log_level="info", + log_level_replica="info", + iteration_step_info_interval=1, + ) + + learning_rate_scheduler = LRSchedulerArgs( + learning_rate=1e-4, #llama one + lr_warmup_steps=10, + lr_warmup_style="linear", + lr_decay_style="linear", + lr_decay_steps = 20, + lr_decay_starting_step= 80, + min_decay_lr=0, + ) + + optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=True, + learning_rate_scheduler=learning_rate_scheduler, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), + ) + + tokenizer = TokenizerArgs( + tokenizer_name_or_path="lvwerra/the-tokenizer-v1", + ) + + s3_upload = S3UploadArgs( + upload_s3_path=f"s3://elie-exp/debug_nanotron/test/", + remove_after_upload=True, + s5cmd_numworkers=16, + s5cmd_concurrency=5, + s5cmd_path=os.path.join(slurm.conda_env_path, "bin/s5cmd"), + ) + + data_stages=[ + DatasetStageArgs( + data=DataArgs( + dataset=NanosetDatasetsArgs( + dataset_folder={ + "/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2":0.7, + "/fsx/elie_bakouch/nanotron/datasets/fineweb-edu-dedup":0.3, + }, + ), + seed=general.seed, + ), + name="training stage", + start_training_step=1, + ), + ] + + config = Config( + general=general, + checkpoints=checkpoints, + parallelism=parallelism, + model=model, + tokenizer=tokenizer, + logging=logging, + tokens=tokens, + optimizer=optimizer, + data_stages=data_stages, + s3_upload=s3_upload, + lighteval=lighteval, + slurm=slurm, + ) + if slurm is not None: + dir = os.path.dirname(__file__) + + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + os.makedirs(config.slurm.config_logs_path, exist_ok=True) + config_path_yaml = f"{config.slurm.config_logs_path}/{timestamp}.yaml" + config.save_as_yaml(config_path_yaml) + + os.makedirs(f"{config.slurm.slurm_logs_path}/", exist_ok=True) + + sbatch_script = f"""#!/bin/bash +#SBATCH --job-name={slurm.job_name} +#SBATCH --nodes={slurm.nodes} +#SBATCH --ntasks-per-node={slurm.n_tasks_per_node} # crucial - only 1 task per dist per node! 
+#SBATCH --cpus-per-task={slurm.cpus_per_task} +#SBATCH --gres=gpu:{slurm.gpu_per_node} +#SBATCH --partition={slurm.gpu_partition} +#SBATCH --output={slurm.slurm_logs_path}/train-{timestamp}-%x-%j.out +#SBATCH --array={slurm.array} +#SBATCH --qos={slurm.qos} +#SBATCH --begin=now+0minutes +#SBATCH --mail-type=ALL +#SBATCH --mail-user={slurm.mail} +#SBATCH --requeue + + +TRAINER_PYTHON_FILE=/fsx/elie_bakouch/nanotron/run_train.py +nvidia-smi +set -x -e +source ~/.bashrc +source /fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh +conda activate {config.slurm.conda_env_path} #Modify this line if you use something different than conda + +module load cuda/12.1 + +echo "START TIME: $(date)" + +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" +start=$(date +%s) +echo "$(date -d @${{start}} "+%Y-%m-%d %H:%M:%S"): ${{SLURM_JOB_NAME}} start id=${{SLURM_JOB_ID}}\n" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +export TMPDIR=/scratch +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +echo go $COUNT_NODE +echo $HOSTNAMES + +##### MOVE TO YAML ###### + +CMD=" \ + $TRAINER_PYTHON_FILE \ + --config-file {config_path_yaml} + " + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node {config.slurm.gpu_per_node} \ + --nnodes $COUNT_NODE \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ + --max_restarts 0 \ + --tee 3 \ + " + +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." 
+sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + + +echo "END TIME: $(date)" +""" + print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}") \ No newline at end of file From a5c0cc23dfa26b73f65e6ac3b218ac5fdeabc9d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sun, 25 Aug 2024 01:11:25 +0000 Subject: [PATCH 04/43] not yet functional, lighteval stuff to figure out --- run_train.py | 4 +--- src/nanotron/lighteval/evaluation_tasks.py | 6 +++--- src/nanotron/lighteval/run_evals.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/run_train.py b/run_train.py index 09c27d35..021d955d 100644 --- a/run_train.py +++ b/run_train.py @@ -37,9 +37,7 @@ except ImportError: hf_hub_version = None tf_version = None - -import torch._dynamo -torch._dynamo.config.suppress_errors = True + logger = logging.get_logger(__name__) diff --git a/src/nanotron/lighteval/evaluation_tasks.py b/src/nanotron/lighteval/evaluation_tasks.py index 2c0d9449..88bdd0b1 100644 --- a/src/nanotron/lighteval/evaluation_tasks.py +++ b/src/nanotron/lighteval/evaluation_tasks.py @@ -9,10 +9,10 @@ from dataclasses import asdict from typing import Dict, List, Tuple -from lighteval.metrics import Metrics +from lighteval.metrics.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc -from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES +from lighteval.tasks.default_prompts import LETTER_INDICES _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] _TASKS: List[LightevalTaskConfig] = [] @@ -640,7 +640,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None): EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING]) # Convert to dict for lighteval -TASKS_TABLE = [task.as_dict() for task in _TASKS] +TASKS_TABLE = [asdict(task) for task in _TASKS] # You can have a few pre-organised groups of tasks # TODO @eliebak add math and code here TASKS_GROUPS = { diff --git a/src/nanotron/lighteval/run_evals.py b/src/nanotron/lighteval/run_evals.py index f0f112ec..e2d84d1b 100644 --- a/src/nanotron/lighteval/run_evals.py +++ b/src/nanotron/lighteval/run_evals.py @@ -1,7 +1,7 @@ # flake8: noqa: C901 import argparse -from lighteval.main import main +from lighteval.main_nanotron import main from nanotron.config import Config @@ -32,4 +32,4 @@ def get_parser(): if __name__ == "__main__": parser = get_parser() args, unknowns = parser.parse_known_args() - main(args.checkpoint_config_path, args.lighteval_override, args.cache_dir, config_cls=Config) \ No newline at end of file + main(args.checkpoint_config_path, args.lighteval_override, args.cache_dir) \ No newline at end of file From 6b58c25a26425dc9de75d2f289a08a63dd28e103 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sun, 25 Aug 2024 22:42:53 +0000 Subject: [PATCH 05/43] remove torch.compile() bc it's not working (might be a me pb) --- src/nanotron/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 49ea86e6..ca6c2441 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -165,7 +165,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication and tp_mode is 
TensorParallelLinearMode.REDUCE_SCATTER, ) - self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) + self.split_silu_mul = GLUActivation(config.hidden_act) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) From 34b50a6cb3403c172d25493dabece7997d918df3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sun, 25 Aug 2024 23:20:48 +0000 Subject: [PATCH 06/43] update launcher.py --- launcher.py | 73 ++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 17 deletions(-) diff --git a/launcher.py b/launcher.py index 03a29754..e0e9c4cb 100644 --- a/launcher.py +++ b/launcher.py @@ -135,7 +135,7 @@ def launch_slurm_job(launch_file_contents, *args): ) ).replace(".", "p") - print(f"🏋️ Model has {num_params} parameters") + print(f"🏋️ Model has {num_params} parameters") # Do we have a SLURM task ID? # You can SLURM_ARRAY_TASK_ID to run multiple runs with predefined HP @@ -206,6 +206,14 @@ def launch_slurm_job(launch_file_contents, *args): train_steps=100, val_check_interval=-1, ) + BS = tokens.micro_batch_size*tokens.batch_accumulation_per_replica*tokens.sequence_length + GBS = BS * parallelism.dp + + total_tokens = tokens.train_steps * GBS + total_tokens_billions = total_tokens / 1e9 + print(f"📙 Number of tokens: {total_tokens_billions:.2f} billion") + + model = ModelArgs( model_config=model_config, @@ -232,7 +240,23 @@ def launch_slurm_job(launch_file_contents, *args): lr_decay_starting_step= 80, min_decay_lr=0, ) - + # Calculate and print learning rate and global batch size information + lr_initial = learning_rate_scheduler.learning_rate + lr_min = learning_rate_scheduler.min_decay_lr + lr_warmup_steps = learning_rate_scheduler.lr_warmup_steps + lr_decay_steps = learning_rate_scheduler.lr_decay_steps + lr_decay_start = learning_rate_scheduler.lr_decay_starting_step + lr_decay_style = learning_rate_scheduler.lr_decay_style + + print(f"📊 Learning Rate Schedule:") + print(f" Initial LR: {lr_initial:.2e}") + print(f" Warmup: {learning_rate_scheduler.lr_warmup_style} increase over {lr_warmup_steps} steps") + if lr_decay_start != lr_warmup_steps: + print(f" Constant LR until step {lr_decay_start}") + print(f" {lr_decay_style.capitalize()} decay from step {lr_decay_start} to {lr_decay_start + lr_decay_steps}") + print(f" Final LR: {lr_min:.2e}") + + print(f"🚚 Global Batch Size: {GBS:,} tokens") optimizer = OptimizerArgs( zero_stage=0, weight_decay=0.01, @@ -262,13 +286,11 @@ def launch_slurm_job(launch_file_contents, *args): data_stages=[ DatasetStageArgs( data=DataArgs( - dataset=NanosetDatasetsArgs( - dataset_folder={ - "/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2":0.7, - "/fsx/elie_bakouch/nanotron/datasets/fineweb-edu-dedup":0.3, - }, - ), - seed=general.seed, + dataset=NanosetDatasetsArgs( + dataset_folder="/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2", + ), + num_loading_workers=0, + seed=general.seed, ), name="training stage", start_training_step=1, @@ -299,7 +321,6 @@ def launch_slurm_job(launch_file_contents, *args): os.makedirs(f"{config.slurm.slurm_logs_path}/", exist_ok=True) - sbatch_script = f"""#!/bin/bash #SBATCH --job-name={slurm.job_name} #SBATCH --nodes={slurm.nodes} #SBATCH --ntasks-per-node={slurm.n_tasks_per_node} # crucial - only 1 task per dist per node! 
@@ -313,26 +334,39 @@ def launch_slurm_job(launch_file_contents, *args): #SBATCH --mail-type=ALL #SBATCH --mail-user={slurm.mail} #SBATCH --requeue - + sbatch_script = f"""#!/bin/bash +#SBATCH --job-name=test +#SBATCH --nodes=2 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! +#SBATCH --cpus-per-task=32 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH --output=/fsx/elie_bakouch/nanotron/debug/main/train-{timestamp}-%x-%j.out +#SBATCH --qos=high +#SBATCH --begin=now+0minutes +#SBATCH --mail-type=ALL +set -x -e TRAINER_PYTHON_FILE=/fsx/elie_bakouch/nanotron/run_train.py nvidia-smi -set -x -e source ~/.bashrc source /fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh -conda activate {config.slurm.conda_env_path} #Modify this line if you use something different than conda +conda activate /fsx/elie_bakouch/miniconda3/envs/smollm #Modify this line if you use something different than conda -module load cuda/12.1 - -echo "START TIME: $(date)" #Show some environment variables echo python3 version = `python3 --version` echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +echo "START TIME: $(date)" +secs_to_human(){{ + echo "$(( ${{1}} / 3600 )):$(( (${{1}} / 60) % 60 )):$(( ${{1}} % 60 ))" +}} start=$(date +%s) echo "$(date -d @${{start}} "+%Y-%m-%d %H:%M:%S"): ${{SLURM_JOB_NAME}} start id=${{SLURM_JOB_ID}}\n" + # SLURM stuff export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) @@ -342,6 +376,8 @@ def launch_slurm_job(launch_file_contents, *args): export TMPDIR=/scratch export CUDA_DEVICE_MAX_CONNECTIONS="1" +module load cuda/12.1 + echo go $COUNT_NODE echo $HOSTNAMES @@ -353,8 +389,11 @@ def launch_slurm_job(launch_file_contents, *args): " export LAUNCHER="python -u -m torch.distributed.run \ - --nproc_per_node {config.slurm.gpu_per_node} \ + --nproc_per_node 8 \ --nnodes $COUNT_NODE \ + --rdzv-backend etcd-v2 \ + --rdzv-endpoint etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379 \ + --rdzv-id $SLURM_JOB_ID \ --node_rank $SLURM_PROCID \ --role $SLURMD_NODENAME: \ --max_restarts 0 \ From 7a105be7e2b780c0e2b09337091347a2e92c7cc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 26 Aug 2024 00:20:42 +0000 Subject: [PATCH 07/43] fancy launcher --- launcher.py | 136 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 90 insertions(+), 46 deletions(-) diff --git a/launcher.py b/launcher.py index e0e9c4cb..7b192bb4 100644 --- a/launcher.py +++ b/launcher.py @@ -101,6 +101,16 @@ def launch_slurm_job(launch_file_contents, *args): logs_path=f"/fsx/elie_bakouch/nanotron/debug", conda_path="/fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh", conda_env_path="/fsx/elie_bakouch/miniconda3/envs/smollm", + exclude_nodes=["ip-26-0-161-138", "ip-26-0-161-178"], + torchrun_args={ + "rdzv_backend": "etcd-v2", + "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", + "rdzv_id": "$SLURM_JOB_ID" + }, + qos="normal", + mail_type="FAIL", + mail_user="bakouch.elie@gmail.com", + begin="now+0minutes" ) model_config = LlamaConfig( @@ -135,8 +145,6 @@ def launch_slurm_job(launch_file_contents, *args): ) ).replace(".", "p") - print(f"🏋️ Model has {num_params} parameters") - # Do we have a SLURM task ID? 
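The torchrun_args mapping added just above is later joined into extra flags for the torchrun launch command; a minimal sketch of that rendering, with a placeholder etcd endpoint standing in for the real cluster address:

torchrun_args = {
    "rdzv_backend": "etcd-v2",
    "rdzv_endpoint": "etcd.example.internal:2379",  # placeholder endpoint, not the real cluster
    "rdzv_id": "$SLURM_JOB_ID",                     # left for the shell to expand inside the sbatch script
}
extra_flags = " ".join(f"--{k} {v}" for k, v in torchrun_args.items())
# -> "--rdzv_backend etcd-v2 --rdzv_endpoint etcd.example.internal:2379 --rdzv_id $SLURM_JOB_ID"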
# You can SLURM_ARRAY_TASK_ID to run multiple runs with predefined HP task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", -1)) @@ -197,7 +205,6 @@ def launch_slurm_job(launch_file_contents, *args): tp_linear_async_communication=True, ) #Add sanity check for the number of GPUs and the number of nodes ? - print(f"🤖 {slurm.nodes} Nodes | {parallelism.dp*parallelism.pp*parallelism.tp} GPUs | 3D Config : DP {parallelism.dp} / PP {parallelism.pp} / TP {parallelism.tp}") tokens = TokensArgs( batch_accumulation_per_replica=8, @@ -211,9 +218,6 @@ def launch_slurm_job(launch_file_contents, *args): total_tokens = tokens.train_steps * GBS total_tokens_billions = total_tokens / 1e9 - print(f"📙 Number of tokens: {total_tokens_billions:.2f} billion") - - model = ModelArgs( model_config=model_config, @@ -248,15 +252,6 @@ def launch_slurm_job(launch_file_contents, *args): lr_decay_start = learning_rate_scheduler.lr_decay_starting_step lr_decay_style = learning_rate_scheduler.lr_decay_style - print(f"📊 Learning Rate Schedule:") - print(f" Initial LR: {lr_initial:.2e}") - print(f" Warmup: {learning_rate_scheduler.lr_warmup_style} increase over {lr_warmup_steps} steps") - if lr_decay_start != lr_warmup_steps: - print(f" Constant LR until step {lr_decay_start}") - print(f" {lr_decay_style.capitalize()} decay from step {lr_decay_start} to {lr_decay_start + lr_decay_steps}") - print(f" Final LR: {lr_min:.2e}") - - print(f"🚚 Global Batch Size: {GBS:,} tokens") optimizer = OptimizerArgs( zero_stage=0, weight_decay=0.01, @@ -311,6 +306,54 @@ def launch_slurm_job(launch_file_contents, *args): lighteval=lighteval, slurm=slurm, ) + + print(f""" +🏋️ Model Parameters: +┌───────────────────────┬───────────────────────────┐ +│ Total Parameters │ {num_params:>25} │ +│ Layers │ {model_config.num_hidden_layers:>25d} │ +│ Attention Heads │ {model_config.num_attention_heads:>25d} │ +│ Hidden Size │ {model_config.hidden_size:>25d} │ +│ Intermediate Size │ {model_config.intermediate_size:>25d} │ +│ Context Length │ {model_config.max_position_embeddings:>25d} │ +│ Tokenizer │ {tokenizer.tokenizer_name_or_path[:25]:>25} │ +│ Vocab Size │ {model_config.vocab_size:>25d} │ +└───────────────────────┴───────────────────────────┘ +""") + + print(f""" +🤖 Parallelism Configuration: +┌───────────────────────┬───────────────────┐ +│ Nodes │ {slurm.nodes:>17d} │ +│ Total GPUs │ {parallelism.dp*parallelism.pp*parallelism.tp:>17d} │ +│ Data Parallel (DP) │ {parallelism.dp:>17d} │ +│ Pipeline Parallel (PP)│ {parallelism.pp:>17d} │ +│ Tensor Parallel (TP) │ {parallelism.tp:>17d} │ +└───────────────────────┴───────────────────┘ +""") + + print(f""" +📙 Training Configuration: +┌───────────────────────┬───────────────────┐ +│ Total Tokens │ {total_tokens_billions:>16.2f}B │ +│ Global Batch Size │ {GBS:>17,d} │ +│ Batch Size (per GPU) │ {BS:>17,d} │ +└───────────────────────┴───────────────────┘ +""") + + print(f""" +📊 Learning Rate Schedule: +┌───────────────────────┬───────────────────┐ +│ Initial LR │ {lr_initial:>17.2e} │ +│ Warmup Style │ {learning_rate_scheduler.lr_warmup_style[:17]:>17} │ +│ Warmup Steps │ {lr_warmup_steps:>17d} │ +│ Decay Style │ {lr_decay_style[:17]:>17} │ +│ Decay Start Step │ {lr_decay_start:>17d} │ +│ Decay Steps │ {lr_decay_steps:>17d} │ +│ Final LR │ {lr_min:>17.2e} │ +└───────────────────────┴───────────────────┘ +""") + if slurm is not None: dir = os.path.dirname(__file__) @@ -321,37 +364,40 @@ def launch_slurm_job(launch_file_contents, *args): os.makedirs(f"{config.slurm.slurm_logs_path}/", exist_ok=True) -#SBATCH 
--job-name={slurm.job_name} -#SBATCH --nodes={slurm.nodes} -#SBATCH --ntasks-per-node={slurm.n_tasks_per_node} # crucial - only 1 task per dist per node! -#SBATCH --cpus-per-task={slurm.cpus_per_task} -#SBATCH --gres=gpu:{slurm.gpu_per_node} -#SBATCH --partition={slurm.gpu_partition} -#SBATCH --output={slurm.slurm_logs_path}/train-{timestamp}-%x-%j.out -#SBATCH --array={slurm.array} -#SBATCH --qos={slurm.qos} -#SBATCH --begin=now+0minutes -#SBATCH --mail-type=ALL -#SBATCH --mail-user={slurm.mail} -#SBATCH --requeue + def format_sbatch_option(option, value): + return f"#SBATCH --{option}={value}" if value is not None else "" + + torchrun_args = "" + if hasattr(slurm, 'torchrun_args') and slurm.torchrun_args: + torchrun_args = " ".join([f"--{k} {v}" for k, v in slurm.torchrun_args.items()]) + sbatch_script = f"""#!/bin/bash -#SBATCH --job-name=test -#SBATCH --nodes=2 -#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! -#SBATCH --cpus-per-task=32 -#SBATCH --gres=gpu:8 -#SBATCH --partition=hopper-prod -#SBATCH --output=/fsx/elie_bakouch/nanotron/debug/main/train-{timestamp}-%x-%j.out -#SBATCH --qos=high -#SBATCH --begin=now+0minutes -#SBATCH --mail-type=ALL +{format_sbatch_option("job-name", slurm.job_name)} +{format_sbatch_option("nodes", slurm.nodes)} +{format_sbatch_option("ntasks-per-node", slurm.n_tasks_per_node)} +{format_sbatch_option("cpus-per-task", slurm.cpus_per_task)} +{format_sbatch_option("gres", f"gpu:{slurm.gpu_per_node}")} +{format_sbatch_option("partition", slurm.gpu_partition)} +{format_sbatch_option("output", f"{slurm.slurm_logs_path}/train-{timestamp}-%x-%j.out")} +{format_sbatch_option("array", slurm.array)} +{format_sbatch_option("qos", slurm.qos)} +{format_sbatch_option("mail-type", slurm.mail_type)} +{format_sbatch_option("mail-user", slurm.mail_user)} +{format_sbatch_option("exclude", ",".join(slurm.exclude_nodes) if slurm.exclude_nodes else None)} +{format_sbatch_option("time", slurm.time)} +{format_sbatch_option("mem", slurm.mem)} +{format_sbatch_option("constraint", slurm.constraint)} +{format_sbatch_option("account", slurm.account)} +{format_sbatch_option("reservation", slurm.reservation)} +{format_sbatch_option("begin", slurm.begin)} + set -x -e TRAINER_PYTHON_FILE=/fsx/elie_bakouch/nanotron/run_train.py nvidia-smi source ~/.bashrc source /fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh -conda activate /fsx/elie_bakouch/miniconda3/envs/smollm #Modify this line if you use something different than conda +conda activate {slurm.conda_env_path} #Modify this line if you use something different than conda #Show some environment variables @@ -387,13 +433,10 @@ def launch_slurm_job(launch_file_contents, *args): $TRAINER_PYTHON_FILE \ --config-file {config_path_yaml} " - -export LAUNCHER="python -u -m torch.distributed.run \ - --nproc_per_node 8 \ +export LAUNCHER="torchrun \ + --nproc_per_node {slurm.gpu_per_node} \ --nnodes $COUNT_NODE \ - --rdzv-backend etcd-v2 \ - --rdzv-endpoint etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379 \ - --rdzv-id $SLURM_JOB_ID \ + {torchrun_args} \ --node_rank $SLURM_PROCID \ --role $SLURMD_NODENAME: \ --max_restarts 0 \ @@ -412,5 +455,6 @@ def launch_slurm_job(launch_file_contents, *args): echo "END TIME: $(date)" -""" + """ + print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}") \ No newline at end of file From 28770a5749d33bcd284422a8241f9d6001662ed4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 26 Aug 2024 00:42:22 +0000 Subject: [PATCH 08/43] fancy++ 
launcher --- launcher.py | 142 +++++++++++++++++++++++----------- src/nanotron/config/config.py | 65 ++++++++++++---- 2 files changed, 148 insertions(+), 59 deletions(-) diff --git a/launcher.py b/launcher.py index 7b192bb4..3d78d698 100644 --- a/launcher.py +++ b/launcher.py @@ -197,7 +197,7 @@ def launch_slurm_job(launch_file_contents, *args): ) parallelism = ParallelismArgs( - dp=16, + dp=8, pp=1, tp=1, pp_engine="1f1b", @@ -267,16 +267,16 @@ def launch_slurm_job(launch_file_contents, *args): ) tokenizer = TokenizerArgs( - tokenizer_name_or_path="lvwerra/the-tokenizer-v1", + tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer", ) - s3_upload = S3UploadArgs( - upload_s3_path=f"s3://elie-exp/debug_nanotron/test/", - remove_after_upload=True, - s5cmd_numworkers=16, - s5cmd_concurrency=5, - s5cmd_path=os.path.join(slurm.conda_env_path, "bin/s5cmd"), - ) + # s3_upload = S3UploadArgs( + # upload_s3_path=f"s3://elie-exp/debug_nanotron/test/", + # remove_after_upload=True, + # s5cmd_numworkers=16, + # s5cmd_concurrency=5, + # s5cmd_path=os.path.join(slurm.conda_env_path, "bin/s5cmd"), + # ) data_stages=[ DatasetStageArgs( @@ -302,62 +302,76 @@ def launch_slurm_job(launch_file_contents, *args): tokens=tokens, optimizer=optimizer, data_stages=data_stages, - s3_upload=s3_upload, + # s3_upload=s3_upload, lighteval=lighteval, - slurm=slurm, + # slurm=slurm, ) print(f""" 🏋️ Model Parameters: -┌───────────────────────┬───────────────────────────┐ -│ Total Parameters │ {num_params:>25} │ -│ Layers │ {model_config.num_hidden_layers:>25d} │ -│ Attention Heads │ {model_config.num_attention_heads:>25d} │ -│ Hidden Size │ {model_config.hidden_size:>25d} │ -│ Intermediate Size │ {model_config.intermediate_size:>25d} │ -│ Context Length │ {model_config.max_position_embeddings:>25d} │ -│ Tokenizer │ {tokenizer.tokenizer_name_or_path[:25]:>25} │ -│ Vocab Size │ {model_config.vocab_size:>25d} │ -└───────────────────────┴───────────────────────────┘ +┌───────────────────────┬────────────────────────┐ +│ Total Parameters │ {num_params:>22} │ +│ Layers │ {model_config.num_hidden_layers:>22d} │ +│ Attention Heads │ {model_config.num_attention_heads:>22d} │ +│ Hidden Size │ {model_config.hidden_size:>22d} │ +│ Intermediate Size │ {model_config.intermediate_size:>22d} │ +│ Context Length │ {model_config.max_position_embeddings:>22d} │ +│ Tokenizer │ {tokenizer.tokenizer_name_or_path[:22]:>22} │ +│ Vocab Size │ {model_config.vocab_size:>22d} │ +└───────────────────────┴────────────────────────┘ """) + num_nodes = slurm.nodes if args.slurm else torch.cuda.device_count() print(f""" 🤖 Parallelism Configuration: -┌───────────────────────┬───────────────────┐ -│ Nodes │ {slurm.nodes:>17d} │ -│ Total GPUs │ {parallelism.dp*parallelism.pp*parallelism.tp:>17d} │ -│ Data Parallel (DP) │ {parallelism.dp:>17d} │ -│ Pipeline Parallel (PP)│ {parallelism.pp:>17d} │ -│ Tensor Parallel (TP) │ {parallelism.tp:>17d} │ -└───────────────────────┴───────────────────┘ +┌───────────────────────┬────────────────────────┐ +│ Nodes │ {num_nodes:>22d} │ +│ Total GPUs │ {parallelism.dp*parallelism.pp*parallelism.tp:>22d} │ +│ Data Parallel (DP) │ {parallelism.dp:>22d} │ +│ Pipeline Parallel (PP)│ {parallelism.pp:>22d} │ +│ Tensor Parallel (TP) │ {parallelism.tp:>22d} │ +└───────────────────────┴────────────────────────┘ """) print(f""" 📙 Training Configuration: -┌───────────────────────┬───────────────────┐ -│ Total Tokens │ {total_tokens_billions:>16.2f}B │ -│ Global Batch Size │ {GBS:>17,d} │ -│ Batch Size (per GPU) │ {BS:>17,d} │ 
-└───────────────────────┴───────────────────┘ +┌───────────────────────┬────────────────────────┐ +│ Total Tokens │ {total_tokens_billions:>21.2f}B │ +│ Global Batch Size │ {GBS:>22,d} │ +│ Batch Size (per GPU) │ {BS:>22,d} │ +└───────────────────────┴────────────────────────┘ """) print(f""" 📊 Learning Rate Schedule: -┌───────────────────────┬───────────────────┐ -│ Initial LR │ {lr_initial:>17.2e} │ -│ Warmup Style │ {learning_rate_scheduler.lr_warmup_style[:17]:>17} │ -│ Warmup Steps │ {lr_warmup_steps:>17d} │ -│ Decay Style │ {lr_decay_style[:17]:>17} │ -│ Decay Start Step │ {lr_decay_start:>17d} │ -│ Decay Steps │ {lr_decay_steps:>17d} │ -│ Final LR │ {lr_min:>17.2e} │ -└───────────────────────┴───────────────────┘ +┌───────────────────────┬────────────────────────┐ +│ Initial LR │ {lr_initial:>22.2e} │ +│ Warmup Style │ {learning_rate_scheduler.lr_warmup_style[:22]:>22} │ +│ Warmup Steps │ {lr_warmup_steps:>22d} │ +│ Decay Style │ {lr_decay_style[:22]:>22} │ +│ Decay Start Step │ {lr_decay_start:>22d} │ +│ Decay Steps │ {lr_decay_steps:>22d} │ +│ Final LR │ {lr_min:>22.2e} │ +└───────────────────────┴────────────────────────┘ +""") + print(f""" +🔧 Optimization Configuration: +┌───────────────────────┬────────────────────────┐ +│ Optimizer │ {optimizer.optimizer_factory.__class__.__name__:>22} │ +│ Weight Decay │ {optimizer.weight_decay:>22.2e} │ +│ Gradient Clipping │ {optimizer.clip_grad:>22.2f} │ +│ Adam Epsilon │ {optimizer.optimizer_factory.adam_eps:>22.2e} │ +│ Adam Beta1 │ {optimizer.optimizer_factory.adam_beta1:>22.2f} │ +│ Adam Beta2 │ {optimizer.optimizer_factory.adam_beta2:>22.2f} │ +│ ZeRO Stage │ {optimizer.zero_stage:>22d} │ +│ FP32 Grad Accumulation│ {str(optimizer.accumulate_grad_in_fp32):>22} │ +└───────────────────────┴────────────────────────┘ """) - if slurm is not None: + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + if args.slurm: dir = os.path.dirname(__file__) - timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") os.makedirs(config.slurm.config_logs_path, exist_ok=True) config_path_yaml = f"{config.slurm.config_logs_path}/{timestamp}.yaml" config.save_as_yaml(config_path_yaml) @@ -457,4 +471,42 @@ def format_sbatch_option(option, value): echo "END TIME: $(date)" """ - print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}") \ No newline at end of file + print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}") + else: + # Check if running on an interactive node + try: + gpu_count = torch.cuda.device_count() + is_interactive = gpu_count > 0 + except: + is_interactive = False + + if is_interactive: + print("Running on an interactive node with GPUs.") + + # Check if the parallelism configuration matches the available GPUs + total_gpus = gpu_count + config_gpus = parallelism.dp * parallelism.tp * parallelism.pp + + if total_gpus != config_gpus: + raise ValueError(f"The parallelism configuration (dp={parallelism.dp}, tp={parallelism.tp}, pp={parallelism.pp}) " + f"doesn't match the number of available GPUs ({total_gpus}). 
" + f"Please adjust your configuration to match the available resources.") + + # Save config + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + os.makedirs("/fsx/elie_bakouch/nanotron/config_logs", exist_ok=True) + config_path_yaml = f"/fsx/elie_bakouch/nanotron/config_logs/{timestamp}.yaml" + config.save_as_yaml(config_path_yaml) + + # Prepare command + trainer_python_file = "/fsx/elie_bakouch/nanotron/run_train.py" + cmd = f"{trainer_python_file} --config-file {config_path_yaml}" + + # Launch job + launch_cmd = f"torchrun --nproc_per_node {gpu_count} {cmd}" + print(f"Launching interactive job with command: {launch_cmd}") + + # Execute the command + subprocess.run(launch_cmd, shell=True, check=True) + else: + print("Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs.") \ No newline at end of file diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 26bd1546..126a4d23 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -3,7 +3,7 @@ from dataclasses import dataclass, fields from pathlib import Path from datasets.download.streaming_download_manager import xPath -from typing import List, Optional, Type, Union +from typing import List, Optional, Type, Union, Dict import dacite import torch @@ -93,25 +93,62 @@ def __post_init__(self): @dataclass class SlurmArgs: + """ + Arguments for configuring SLURM job submission. + + Attributes: + gpu_partition (str): SLURM partition (queue) for GPU jobs. + job_name (str): Name of the SLURM job. + nodes (int): Number of nodes to allocate for the job. + logs_path (str): Base directory for storing log files. + conda_path (str): Path to the Conda installation script. + conda_env_path (str): Path to the Conda environment to be used. + n_tasks_per_node (int): Number of tasks to run per node. Default is 1. + cpus_per_task (int): Number of CPUs to allocate per task. Default is 32. + gpu_per_node (int): Number of GPUs to allocate per node. Default is 8. + array (Optional[str]): Job array specification, allowing multiple similar jobs to be submitted as a group. + qos (Optional[str]): Quality of Service, used to define job priority or resource limits. + mail_type (Optional[str]): Specifies when to send email notifications about the job (e.g., BEGIN, END, FAIL). Default is FAIL. + mail_user (Optional[str]): Email address to receive job notifications. + exclude_nodes (Optional[List[str]]): List of nodes to exclude from job allocation. + time (Optional[str]): Maximum time limit for the job. + mem (Optional[str]): Memory requirement for the job. + constraint (Optional[str]): Specifies node features required for the job. + account (Optional[str]): Account to charge for the job's resource usage. + reservation (Optional[str]): Name of a reservation to use for the job. + begin (Optional[str]): Earliest time the job can start. + torchrun_args (Optional[Dict[str, str]]): Additional arguments for torchrun command. + slurm_logs_path (Optional[str]): Specific path for SLURM output logs. + config_logs_path (Optional[str]): Path for storing configuration logs. + """ + + gpu_partition: str job_name: str nodes: int - logs_path: Path - # TODO: @elibak: Add a way to handle different virtual environments (conda, venv, uv, etc) For now, we assume conda and user can modify the slurm template if they use something else. 
+ logs_path: str conda_path: str - conda_env_path : str - gpu_partition: Optional[str] = None - n_tasks_per_node: Optional[int] = 1 - cpus_per_task: Optional[int] = 32 - gpu_per_node: Optional[int] = 8 - mail: Optional[str] = None - qos: Optional[str] = "high" - array: Optional[str] = "1-1%1" + conda_env_path: str + n_tasks_per_node: int = 1 + cpus_per_task: int = 32 + gpu_per_node: int = 8 + array: Optional[str] = None + qos: Optional[str] = None + mail_user: Optional[str] = None + mail_type: Optional[str] = None + exclude_nodes: Optional[List[str]] = None + time: Optional[str] = None + mem: Optional[str] = None + constraint: Optional[str] = None + account: Optional[str] = None + reservation: Optional[str] = None + begin: Optional[str] = None + torchrun_args: Optional[Dict[str, str]] = None slurm_logs_path: Optional[str] = None - evals_logs_path: Optional[str] = None config_logs_path: Optional[str] = None - - + def __post_init__(self): + if self.mail_type is None and self.mail_user is not None: + self.mail_type = "FAIL" @dataclass class S3UploadArgs: From 76520898548003cb91d007ea36844af8e89bcc17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 26 Aug 2024 01:04:13 +0000 Subject: [PATCH 09/43] add the possibility to override config yeahhh --- launcher.py | 82 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 47 insertions(+), 35 deletions(-) diff --git a/launcher.py b/launcher.py index 3d78d698..9ad544b3 100644 --- a/launcher.py +++ b/launcher.py @@ -6,6 +6,7 @@ import torch import argparse +from typing import Any, Dict from nanotron.logging import human_format from nanotron.models.llama import LlamaConfig @@ -49,6 +50,14 @@ def launch_slurm_job(launch_file_contents, *args): f.flush() return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1] +def set_nested_attribute(obj, path, value): + parts = path.split('.') + for part in parts[:-1]: + if not hasattr(obj, part): + setattr(obj, part, type('', (), {})()) # Create empty object if attribute doesn't exist + obj = getattr(obj, part) + setattr(obj, parts[-1], value) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -57,6 +66,8 @@ def launch_slurm_job(launch_file_contents, *args): parser.add_argument("--name", help="run name", type=str, default=None) parser.add_argument("--seed", help="seed", type=int, default=8) parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="normal") + parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", + help="Override config values. 
Use dot notation for nested keys.") args = parser.parse_args() PROJECT = args.project @@ -65,27 +76,6 @@ def launch_slurm_job(launch_file_contents, *args): else: RUN = f"{PROJECT}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" - ## FOR SANITY CHECK LATER - # from dataclasses import fields, is_dataclass - - # def print_differences(target, updates): - # if not is_dataclass(target) or not is_dataclass(updates): - # raise ValueError("Both target and updates should be dataclass instances") - - # for field in fields(target): - # update_value = getattr(updates, field.name) - - # if update_value is not None: - # if is_dataclass(update_value): - # print_differences(getattr(target, field.name), update_value) - # else: - # target_value = getattr(target, field.name) - # if update_value != target_value: - # if update_value.__class__.__module__ != "builtins": - # continue - # print(f"{field.name}: {target_value} -> {update_value}") - - general = GeneralArgs( project=PROJECT, run=RUN, @@ -213,11 +203,6 @@ def launch_slurm_job(launch_file_contents, *args): train_steps=100, val_check_interval=-1, ) - BS = tokens.micro_batch_size*tokens.batch_accumulation_per_replica*tokens.sequence_length - GBS = BS * parallelism.dp - - total_tokens = tokens.train_steps * GBS - total_tokens_billions = total_tokens / 1e9 model = ModelArgs( model_config=model_config, @@ -244,13 +229,7 @@ def launch_slurm_job(launch_file_contents, *args): lr_decay_starting_step= 80, min_decay_lr=0, ) - # Calculate and print learning rate and global batch size information - lr_initial = learning_rate_scheduler.learning_rate - lr_min = learning_rate_scheduler.min_decay_lr - lr_warmup_steps = learning_rate_scheduler.lr_warmup_steps - lr_decay_steps = learning_rate_scheduler.lr_decay_steps - lr_decay_start = learning_rate_scheduler.lr_decay_starting_step - lr_decay_style = learning_rate_scheduler.lr_decay_style + optimizer = OptimizerArgs( zero_stage=0, @@ -307,6 +286,39 @@ def launch_slurm_job(launch_file_contents, *args): # slurm=slurm, ) + # Parse and apply overrides + if args.override: + for item in args.override: + if '=' not in item: + raise ValueError(f"Invalid override format: {item}. 
Use KEY=VALUE.") + key, value = item.split('=', 1) + try: + # Try to evaluate the value as a Python literal + value = eval(value) + except: + # If eval fails, treat it as a string + pass + + set_nested_attribute(config, key, value) + + print("Applied overrides:") + for item in args.override: + print(f" {item}") + + # Calculate and print learning rate and global batch size information + lr_initial = learning_rate_scheduler.learning_rate + lr_min = learning_rate_scheduler.min_decay_lr + lr_warmup_steps = learning_rate_scheduler.lr_warmup_steps + lr_decay_steps = learning_rate_scheduler.lr_decay_steps + lr_decay_start = learning_rate_scheduler.lr_decay_starting_step + lr_decay_style = learning_rate_scheduler.lr_decay_style + + BS = tokens.micro_batch_size*tokens.batch_accumulation_per_replica*tokens.sequence_length + GBS = BS * parallelism.dp + + total_tokens = tokens.train_steps * GBS + total_tokens_billions = total_tokens / 1e9 + print(f""" 🏋️ Model Parameters: ┌───────────────────────┬────────────────────────┐ @@ -329,7 +341,7 @@ def launch_slurm_job(launch_file_contents, *args): │ Total GPUs │ {parallelism.dp*parallelism.pp*parallelism.tp:>22d} │ │ Data Parallel (DP) │ {parallelism.dp:>22d} │ │ Pipeline Parallel (PP)│ {parallelism.pp:>22d} │ -│ Tensor Parallel (TP) │ {parallelism.tp:>22d} │ +│ Tensor Parallel (TP) │ {parallelism.tp:>22d} └───────────────────────┴────────────────────────┘ """) @@ -352,7 +364,7 @@ def launch_slurm_job(launch_file_contents, *args): │ Decay Start Step │ {lr_decay_start:>22d} │ │ Decay Steps │ {lr_decay_steps:>22d} │ │ Final LR │ {lr_min:>22.2e} │ -└───────────────────────┴────────────────────────┘ +└──────────────────────┴────────────────────────┘ """) print(f""" 🔧 Optimization Configuration: From 4e2d7d9a65dd0d98ce24fdf954f32e90a1d19026 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 26 Aug 2024 01:18:03 +0000 Subject: [PATCH 10/43] don't run lighteval runner if no s3 uploader AND slurm (might want change this condition just to slurm in the future and support local lighteval) --- src/nanotron/trainer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 471a2018..2e97f561 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -280,10 +280,11 @@ def post_init(self): self.s3_mover = None if self.config.lighteval is not None and dist.get_rank(self.parallel_context.world_pg) == 0: # We only start evaluation runs once on the first node - if self.s3_mover is None: - raise ValueError("lighteval requires s3 upload of checkpoints to be enabled") - self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) - self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint + if self.s3_mover is not None and self.slurm is not None: + self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint + else: + log_rank("LightEval is enabled but s3 upload is not enabled, skipping evaluation", logger=logger, level=logging.INFO, rank=0) def pre_training(self, *args, **kwargs): self._print_training_plan() From e7f0437bdf58c8d6dab97c744040365173906c44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 26 Aug 2024 01:19:07 +0000 Subject: [PATCH 11/43] add CUDA__DEVICE_MAX_CONNECTIONS=1 in interactive mode --- launcher.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/launcher.py b/launcher.py index 9ad544b3..7832ac51 100644 --- a/launcher.py +++ b/launcher.py @@ -515,7 +515,7 @@ def format_sbatch_option(option, value): cmd = f"{trainer_python_file} --config-file {config_path_yaml}" # Launch job - launch_cmd = f"torchrun --nproc_per_node {gpu_count} {cmd}" + launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {gpu_count} {cmd}" print(f"Launching interactive job with command: {launch_cmd}") # Execute the command From bb45352dbcf43cf7ddcbe65ad5452213cac1c301 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 26 Aug 2024 21:20:14 +0000 Subject: [PATCH 12/43] add create_config, moove log_path to general --- create_config.py | 274 ++++++++++++++++++++++++++ launcher.py | 359 +++++++--------------------------- src/nanotron/config/config.py | 54 +++-- src/nanotron/trainer.py | 3 +- 4 files changed, 370 insertions(+), 320 deletions(-) create mode 100644 create_config.py diff --git a/create_config.py b/create_config.py new file mode 100644 index 00000000..55dcf85d --- /dev/null +++ b/create_config.py @@ -0,0 +1,274 @@ +import os +import subprocess +import tempfile +from datetime import datetime +import math +import torch + +import argparse +from typing import Any, Dict + +from nanotron.logging import human_format +from nanotron.models.llama import LlamaConfig + +from nanotron.config import ( + Config, + DataArgs, + NanosetDatasetsArgs, + S3UploadArgs, + SlurmArgs, + CheckpointsArgs, + GeneralArgs, + LightEvalConfig, + LightEvalLoggingArgs, + LightEvalTasksArgs, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + AdamWOptimizerArgs, + ParallelismArgs, + RandomInit, + TokenizerArgs, + TokensArgs, + LightEvalWandbLoggerConfig, + DatasetStageArgs, +) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("project", help="project name", type=str) + parser.add_argument("--slurm", help="use slurm", action="store_true") + parser.add_argument("--name", help="run name", type=str, default=None) + parser.add_argument("--seed", help="seed", type=int, default=8) + parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="normal") + parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", + help="Override config values. 
Use dot notation for nested keys.") + parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately") + args = parser.parse_args() + + if args.name is not None: + run = f"{args.project}-{args.name}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + else: + run = f"{args.project}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + + general = GeneralArgs( + project=args.project, + run=run, + repo_id="HuggingFaceSmol/test-nanotron", + logs_path="/fsx/elie_bakouch/nanotron/debug", + seed=args.seed, + temp_dir="/scratch", + ) + + if args.slurm: + job_name=f"{args.project}-{args.name}" if args.name is not None else f"{args.project}" + slurm = SlurmArgs( + gpu_partition="hopper-prod", + job_name=f"{args.project}-{args.name}", + nodes=2, + conda_path="/fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh", + conda_env_path="/fsx/elie_bakouch/miniconda3/envs/smollm", + exclude_nodes=["ip-26-0-161-138", "ip-26-0-161-178"], + torchrun_args={ + "rdzv_backend": "etcd-v2", + "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", + "rdzv_id": "$SLURM_JOB_ID" + }, + qos="normal", + begin="now+0minutes" + ) + else: + slurm = None + + model_config = LlamaConfig( + bos_token_id=0, + eos_token_id=0, + hidden_act="silu", + hidden_size=576, + initializer_range=0.02, + intermediate_size=1536, + max_position_embeddings=2048, + num_attention_heads=9, + num_hidden_layers=30, + num_key_value_heads=3, + pretraining_tp=1, + rms_norm_eps=1e-05, + rope_scaling=None, + tie_word_embeddings=True, + use_cache=True, + vocab_size=49152, + ) + + lighteval = LightEvalConfig( + tasks=LightEvalTasksArgs( + tasks="early-signal", # "generatives", "all" + custom_tasks="nanotron.lighteval.evaluation_tasks", + max_samples=1000, # Cap very large evals or for debugging + dataset_loading_processes=8, + ), + parallelism=ParallelismArgs( + dp=8, + pp=1, + tp=1, + pp_engine="1f1b", + tp_mode="ALL_REDUCE", + # recompute_granularity="selective", + tp_linear_async_communication=False, + ), + batch_size=16, + wandb=LightEvalWandbLoggerConfig( + wandb_project=args.project, + wandb_entity="eliebak", + wandb_run_name=f"{run}_evals", + ), + logging=LightEvalLoggingArgs( + local_output_path=f"{general.temp_dir}/lighteval/{run}", + push_details_to_hub=False, + push_results_to_hub=True, + push_results_to_tensorboard=True, + #hub_repo_details=REPO_ID, + hub_repo_results=general.repo_id, + hub_repo_tensorboard=general.repo_id, + tensorboard_metric_prefix="e", + ), + slurm_template="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_eval.slurm.jinja", + ) + + + checkpoints = CheckpointsArgs( + checkpoints_path=f"checkpoints/{run}", + checkpoints_path_is_shared_file_system=False, + resume_checkpoint_path=None, + checkpoint_interval=20, + save_initial_state=False, + ) + + parallelism = ParallelismArgs( + dp=8, + pp=1, + tp=1, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, + ) + #Add sanity check for the number of GPUs and the number of nodes ? 
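# A sketch of what that sanity check might look like, assuming the slurm/parallelism objects
# defined above and SlurmArgs' gpu_per_node default of 8 (illustrative only, kept commented out):
#     expected_gpus = slurm.nodes * slurm.gpu_per_node if args.slurm else torch.cuda.device_count()
#     actual_gpus = parallelism.dp * parallelism.pp * parallelism.tp
#     assert actual_gpus == expected_gpus, (
#         f"dp*pp*tp = {actual_gpus} does not match the {expected_gpus} GPUs being requested"
#     )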
+ + tokens = TokensArgs( + batch_accumulation_per_replica=8, + micro_batch_size=16, + sequence_length=2048, + train_steps=100, + val_check_interval=-1, + ) + + model = ModelArgs( + model_config=model_config, + make_vocab_size_divisible_by=1, + init_method=RandomInit( + std=math.sqrt(model_config.hidden_size), + ), + dtype=torch.bfloat16, + ) + + logging = LoggingArgs( + # 'debug', 'info', 'warning', 'error', 'critical' and 'passive' + log_level="info", + log_level_replica="info", + iteration_step_info_interval=1, + ) + + learning_rate_scheduler = LRSchedulerArgs( + learning_rate=1e-4, #llama one + lr_warmup_steps=10, + lr_warmup_style="linear", + lr_decay_style="linear", + lr_decay_steps = 20, + lr_decay_starting_step= 80, + min_decay_lr=0, + ) + + + optimizer = OptimizerArgs( + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=True, + learning_rate_scheduler=learning_rate_scheduler, + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, + adam_beta1=0.9, + adam_beta2=0.95, + torch_adam_is_fused=True, + ), + ) + + tokenizer = TokenizerArgs( + tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer", + ) + + s3_upload = S3UploadArgs( + upload_s3_path=f"s3://elie-exp/debug_nanotron/test/", + remove_after_upload=True, + s5cmd_numworkers=16, + s5cmd_concurrency=5, + s5cmd_path="/fsx/elie_bakouch/miniconda3/envs/smollm/bin/s5cmd", + ) + + data_stages=[ + DatasetStageArgs( + data=DataArgs( + dataset=NanosetDatasetsArgs( + dataset_folder="/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2", + ), + num_loading_workers=0, + seed=general.seed, + ), + name="training stage", + start_training_step=1, + ), + ] + + config = Config( + general=general, + checkpoints=checkpoints, + parallelism=parallelism, + model=model, + tokenizer=tokenizer, + logging=logging, + tokens=tokens, + optimizer=optimizer, + data_stages=data_stages, + s3_upload=s3_upload, + lighteval=lighteval, + slurm=slurm, + ) + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + dir = os.path.dirname(__file__) + os.makedirs(config.general.config_folder_path, exist_ok=True) + config_path_yaml = f"{config.general.config_folder_path}/{timestamp}.yaml" + config.save_as_yaml(config_path_yaml) + + os.makedirs(f"{config.general.slurm_logs_path}/", exist_ok=True) + + print(f"Configuration saved to: {config_path_yaml}") + + if args.launch: + launcher_path = os.path.join(dir, "launcher.py") + launch_command = [ + "python", launcher_path, + config_path_yaml, + ] + + if args.override: + launch_command.extend(["--override"] + args.override) + + print(f"Launching configuration with command: {' '.join(launch_command)}") + subprocess.run(launch_command, check=True) + else: + print("To launch this configuration, run:") + print(f"python {os.path.join(dir, 'launcher.py')} {config_path_yaml} " + f"--override general.config_path={config_path_yaml}") + + if args.override: + print(f" {' '.join(args.override)}") \ No newline at end of file diff --git a/launcher.py b/launcher.py index 7832ac51..48b5f7ed 100644 --- a/launcher.py +++ b/launcher.py @@ -13,26 +13,7 @@ from nanotron.config import ( Config, - DataArgs, - NanosetDatasetsArgs, - S3UploadArgs, - SlurmArgs, - CheckpointsArgs, - GeneralArgs, - LightEvalConfig, - LightEvalLoggingArgs, - LightEvalTasksArgs, - LoggingArgs, - LRSchedulerArgs, - ModelArgs, - OptimizerArgs, - AdamWOptimizerArgs, - ParallelismArgs, - RandomInit, - TokenizerArgs, - TokensArgs, - LightEvalWandbLoggerConfig, - DatasetStageArgs, + get_config_from_file, ) def launch_slurm_job(launch_file_contents, 
*args): @@ -61,232 +42,28 @@ def set_nested_attribute(obj, path, value): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("project", help="project name", type=str) - parser.add_argument("--slurm", help="use slurm", action="store_true") - parser.add_argument("--name", help="run name", type=str, default=None) - parser.add_argument("--seed", help="seed", type=int, default=8) - parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="normal") + parser.add_argument("config_path", help="path to the configuration file", type=str) parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", help="Override config values. Use dot notation for nested keys.") args = parser.parse_args() - PROJECT = args.project - if args.name is not None: - RUN = f"{PROJECT}-{args.name}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" - else: - RUN = f"{PROJECT}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" - - general = GeneralArgs( - project=PROJECT, - run=RUN, - repo_id="HuggingFaceSmol/test-nanotron", - seed=args.seed, - temp_dir="/scratch", - ) - if args.slurm: - slurm = SlurmArgs( - gpu_partition="hopper-prod", - job_name=f"{PROJECT}-{args.name}", - nodes=2, - logs_path=f"/fsx/elie_bakouch/nanotron/debug", - conda_path="/fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh", - conda_env_path="/fsx/elie_bakouch/miniconda3/envs/smollm", - exclude_nodes=["ip-26-0-161-138", "ip-26-0-161-178"], - torchrun_args={ - "rdzv_backend": "etcd-v2", - "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", - "rdzv_id": "$SLURM_JOB_ID" - }, - qos="normal", - mail_type="FAIL", - mail_user="bakouch.elie@gmail.com", - begin="now+0minutes" - ) + # Load the configuration using get_config_from_file + config = get_config_from_file(args.config_path, config_class=Config) - model_config = LlamaConfig( - bos_token_id=0, - eos_token_id=0, - hidden_act="silu", - hidden_size=576, - initializer_range=0.02, - intermediate_size=1536, - max_position_embeddings=2048, - num_attention_heads=9, - num_hidden_layers=30, - num_key_value_heads=3, - pretraining_tp=1, - rms_norm_eps=1e-05, - rope_scaling=None, - tie_word_embeddings=True, - use_cache=True, - vocab_size=49152, - ) - if model_config.tie_word_embeddings ==True: + if config.model.model_config.tie_word_embeddings ==True: tie_word_embeddings_multiplier = 1 else: tie_word_embeddings_multiplier = 2 num_params = human_format( - model_config.vocab_size * model_config.hidden_size * tie_word_embeddings_multiplier - + model_config.num_hidden_layers + config.model.model_config.vocab_size * config.model.model_config.hidden_size * tie_word_embeddings_multiplier + + config.model.model_config.num_hidden_layers * ( - 3 * model_config.hidden_size * model_config.intermediate_size - + 4 * model_config.hidden_size * model_config.hidden_size + 3 * config.model.model_config.hidden_size * config.model.model_config.intermediate_size + + 4 * config.model.model_config.hidden_size * config.model.model_config.hidden_size ) ).replace(".", "p") - - # Do we have a SLURM task ID? 
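With this refactor the launcher consumes a YAML written by create_config.py instead of building the Config inline, and each dotted --override key is resolved attribute by attribute through set_nested_attribute. A minimal sketch of what an override does once the file is loaded (config_path here stands in for the saved YAML's path):

config = get_config_from_file(config_path, config_class=Config)
set_nested_attribute(config, "tokens.train_steps", 200)   # same effect as --override tokens.train_steps=200
set_nested_attribute(config, "parallelism.dp", 4)          # CLI values are eval()'d first, so "4" arrives as an int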
- # You can SLURM_ARRAY_TASK_ID to run multiple runs with predefined HP - task_id = int(os.environ.get("SLURM_ARRAY_TASK_ID", -1)) - job_id = os.environ.get("SLURM_JOB_ID", "") - - - - lighteval = LightEvalConfig( - tasks=LightEvalTasksArgs( - tasks="early-signal", # "generatives", "all" - custom_tasks="nanotron.lighteval.evaluation_tasks", - max_samples=1000, # Cap very large evals or for debugging - dataset_loading_processes=8, - ), - parallelism=ParallelismArgs( - dp=8, - pp=1, - tp=1, - pp_engine="1f1b", - tp_mode="ALL_REDUCE", - # recompute_granularity="selective", - tp_linear_async_communication=False, - ), - batch_size=16, - wandb=LightEvalWandbLoggerConfig( - wandb_project=PROJECT, - wandb_entity="eliebak", - wandb_run_name=f"{RUN}_evals", - ), - logging=LightEvalLoggingArgs( - local_output_path=f"{general.temp_dir}/lighteval/{RUN}", - push_details_to_hub=False, - push_results_to_hub=True, - push_results_to_tensorboard=True, - #hub_repo_details=REPO_ID, - hub_repo_results=general.repo_id, - hub_repo_tensorboard=general.repo_id, - tensorboard_metric_prefix="e", - ), - slurm_template="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_eval.slurm.jinja", - ) - - - checkpoints = CheckpointsArgs( - checkpoints_path=f"checkpoints/{RUN}", - checkpoints_path_is_shared_file_system=False, - resume_checkpoint_path=None, - checkpoint_interval=20, - save_initial_state=False, - ) - - parallelism = ParallelismArgs( - dp=8, - pp=1, - tp=1, - pp_engine="1f1b", - tp_mode="REDUCE_SCATTER", - tp_linear_async_communication=True, - ) - #Add sanity check for the number of GPUs and the number of nodes ? - - tokens = TokensArgs( - batch_accumulation_per_replica=8, - micro_batch_size=16, - sequence_length=2048, - train_steps=100, - val_check_interval=-1, - ) - - model = ModelArgs( - model_config=model_config, - make_vocab_size_divisible_by=1, - init_method=RandomInit( - std=math.sqrt(model_config.hidden_size), - ), - dtype=torch.bfloat16, - ) - - logging = LoggingArgs( - # 'debug', 'info', 'warning', 'error', 'critical' and 'passive' - log_level="info", - log_level_replica="info", - iteration_step_info_interval=1, - ) - - learning_rate_scheduler = LRSchedulerArgs( - learning_rate=1e-4, #llama one - lr_warmup_steps=10, - lr_warmup_style="linear", - lr_decay_style="linear", - lr_decay_steps = 20, - lr_decay_starting_step= 80, - min_decay_lr=0, - ) - - - optimizer = OptimizerArgs( - zero_stage=0, - weight_decay=0.01, - clip_grad=1.0, - accumulate_grad_in_fp32=True, - learning_rate_scheduler=learning_rate_scheduler, - optimizer_factory=AdamWOptimizerArgs( - adam_eps=1e-08, - adam_beta1=0.9, - adam_beta2=0.95, - torch_adam_is_fused=True, - ), - ) - - tokenizer = TokenizerArgs( - tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer", - ) - - # s3_upload = S3UploadArgs( - # upload_s3_path=f"s3://elie-exp/debug_nanotron/test/", - # remove_after_upload=True, - # s5cmd_numworkers=16, - # s5cmd_concurrency=5, - # s5cmd_path=os.path.join(slurm.conda_env_path, "bin/s5cmd"), - # ) - - data_stages=[ - DatasetStageArgs( - data=DataArgs( - dataset=NanosetDatasetsArgs( - dataset_folder="/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2", - ), - num_loading_workers=0, - seed=general.seed, - ), - name="training stage", - start_training_step=1, - ), - ] - - config = Config( - general=general, - checkpoints=checkpoints, - parallelism=parallelism, - model=model, - tokenizer=tokenizer, - logging=logging, - tokens=tokens, - optimizer=optimizer, - data_stages=data_stages, - # s3_upload=s3_upload, - lighteval=lighteval, - # 
slurm=slurm, - ) - - # Parse and apply overrides + # Apply overrides if args.override: for item in args.override: if '=' not in item: @@ -306,42 +83,42 @@ def set_nested_attribute(obj, path, value): print(f" {item}") # Calculate and print learning rate and global batch size information - lr_initial = learning_rate_scheduler.learning_rate - lr_min = learning_rate_scheduler.min_decay_lr - lr_warmup_steps = learning_rate_scheduler.lr_warmup_steps - lr_decay_steps = learning_rate_scheduler.lr_decay_steps - lr_decay_start = learning_rate_scheduler.lr_decay_starting_step - lr_decay_style = learning_rate_scheduler.lr_decay_style - - BS = tokens.micro_batch_size*tokens.batch_accumulation_per_replica*tokens.sequence_length - GBS = BS * parallelism.dp + lr_initial = config.optimizer.learning_rate_scheduler.learning_rate + lr_min = config.optimizer.learning_rate_scheduler.min_decay_lr + lr_warmup_steps = config.optimizer.learning_rate_scheduler.lr_warmup_steps + lr_decay_steps = config.optimizer.learning_rate_scheduler.lr_decay_steps + lr_decay_start = config.optimizer.learning_rate_scheduler.lr_decay_starting_step + lr_decay_style = config.optimizer.learning_rate_scheduler.lr_decay_style + + BS = config.tokens.micro_batch_size*config.tokens.batch_accumulation_per_replica*config.tokens.sequence_length + GBS = BS * config.parallelism.dp - total_tokens = tokens.train_steps * GBS + total_tokens = config.tokens.train_steps * GBS total_tokens_billions = total_tokens / 1e9 print(f""" 🏋️ Model Parameters: ┌───────────────────────┬────────────────────────┐ │ Total Parameters │ {num_params:>22} │ -│ Layers │ {model_config.num_hidden_layers:>22d} │ -│ Attention Heads │ {model_config.num_attention_heads:>22d} │ -│ Hidden Size │ {model_config.hidden_size:>22d} │ -│ Intermediate Size │ {model_config.intermediate_size:>22d} │ -│ Context Length │ {model_config.max_position_embeddings:>22d} │ -│ Tokenizer │ {tokenizer.tokenizer_name_or_path[:22]:>22} │ -│ Vocab Size │ {model_config.vocab_size:>22d} │ +│ Layers │ {config.model.model_config.num_hidden_layers:>22d} │ +│ Attention Heads │ {config.model.model_config.num_attention_heads:>22d} │ +│ Hidden Size │ {config.model.model_config.hidden_size:>22d} │ +│ Intermediate Size │ {config.model.model_config.intermediate_size:>22d} │ +│ Context Length │ {config.model.model_config.max_position_embeddings:>22d} │ +│ Tokenizer │ {config.tokenizer.tokenizer_name_or_path[:22]:>22} │ +│ Vocab Size │ {config.model.model_config.vocab_size:>22d} │ └───────────────────────┴────────────────────────┘ """) - num_nodes = slurm.nodes if args.slurm else torch.cuda.device_count() + num_nodes = config.slurm.nodes if config.slurm else torch.cuda.device_count() print(f""" 🤖 Parallelism Configuration: ┌───────────────────────┬────────────────────────┐ │ Nodes │ {num_nodes:>22d} │ -│ Total GPUs │ {parallelism.dp*parallelism.pp*parallelism.tp:>22d} │ -│ Data Parallel (DP) │ {parallelism.dp:>22d} │ -│ Pipeline Parallel (PP)│ {parallelism.pp:>22d} │ -│ Tensor Parallel (TP) │ {parallelism.tp:>22d} +│ Total GPUs │ {config.parallelism.dp*config.parallelism.pp*config.parallelism.tp:>22d} │ +│ Data Parallel (DP) │ {config.parallelism.dp:>22d} │ +│ Pipeline Parallel (PP)│ {config.parallelism.pp:>22d} │ +│ Tensor Parallel (TP) │ {config.parallelism.tp:>22d} │ └───────────────────────┴────────────────────────┘ """) @@ -358,30 +135,30 @@ def set_nested_attribute(obj, path, value): 📊 Learning Rate Schedule: ┌───────────────────────┬────────────────────────┐ │ Initial LR │ {lr_initial:>22.2e} │ -│ Warmup 
Style │ {learning_rate_scheduler.lr_warmup_style[:22]:>22} │ +│ Warmup Style │ {config.optimizer.learning_rate_scheduler.lr_warmup_style[:22]:>22} │ │ Warmup Steps │ {lr_warmup_steps:>22d} │ │ Decay Style │ {lr_decay_style[:22]:>22} │ │ Decay Start Step │ {lr_decay_start:>22d} │ │ Decay Steps │ {lr_decay_steps:>22d} │ │ Final LR │ {lr_min:>22.2e} │ -└──────────────────────┴────────────────────────┘ +└───────────────────────┴────────────────────────┘ """) print(f""" 🔧 Optimization Configuration: ┌───────────────────────┬────────────────────────┐ -│ Optimizer │ {optimizer.optimizer_factory.__class__.__name__:>22} │ -│ Weight Decay │ {optimizer.weight_decay:>22.2e} │ -│ Gradient Clipping │ {optimizer.clip_grad:>22.2f} │ -│ Adam Epsilon │ {optimizer.optimizer_factory.adam_eps:>22.2e} │ -│ Adam Beta1 │ {optimizer.optimizer_factory.adam_beta1:>22.2f} │ -│ Adam Beta2 │ {optimizer.optimizer_factory.adam_beta2:>22.2f} │ -│ ZeRO Stage │ {optimizer.zero_stage:>22d} │ -│ FP32 Grad Accumulation│ {str(optimizer.accumulate_grad_in_fp32):>22} │ +│ Optimizer │ {config.optimizer.optimizer_factory.__class__.__name__:>22} │ +│ Weight Decay │ {config.optimizer.weight_decay:>22.2e} │ +│ Gradient Clipping │ {config.optimizer.clip_grad:>22.2f} │ +│ Adam Epsilon │ {config.optimizer.optimizer_factory.adam_eps:>22.2e} │ +│ Adam Beta1 │ {config.optimizer.optimizer_factory.adam_beta1:>22.2f} │ +│ Adam Beta2 │ {config.optimizer.optimizer_factory.adam_beta2:>22.2f} │ +│ ZeRO Stage │ {config.optimizer.zero_stage:>22d} │ +│ FP32 Grad Accumulation│ {str(config.optimizer.accumulate_grad_in_fp32):>22} │ └───────────────────────┴────────────────────────┘ """) timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if args.slurm: + if config.slurm: dir = os.path.dirname(__file__) os.makedirs(config.slurm.config_logs_path, exist_ok=True) @@ -394,28 +171,28 @@ def format_sbatch_option(option, value): return f"#SBATCH --{option}={value}" if value is not None else "" torchrun_args = "" - if hasattr(slurm, 'torchrun_args') and slurm.torchrun_args: - torchrun_args = " ".join([f"--{k} {v}" for k, v in slurm.torchrun_args.items()]) + if hasattr(config.slurm, 'torchrun_args') and config.slurm.torchrun_args: + torchrun_args = " ".join([f"--{k} {v}" for k, v in config.slurm.torchrun_args.items()]) sbatch_script = f"""#!/bin/bash -{format_sbatch_option("job-name", slurm.job_name)} -{format_sbatch_option("nodes", slurm.nodes)} -{format_sbatch_option("ntasks-per-node", slurm.n_tasks_per_node)} -{format_sbatch_option("cpus-per-task", slurm.cpus_per_task)} -{format_sbatch_option("gres", f"gpu:{slurm.gpu_per_node}")} -{format_sbatch_option("partition", slurm.gpu_partition)} -{format_sbatch_option("output", f"{slurm.slurm_logs_path}/train-{timestamp}-%x-%j.out")} -{format_sbatch_option("array", slurm.array)} -{format_sbatch_option("qos", slurm.qos)} -{format_sbatch_option("mail-type", slurm.mail_type)} -{format_sbatch_option("mail-user", slurm.mail_user)} -{format_sbatch_option("exclude", ",".join(slurm.exclude_nodes) if slurm.exclude_nodes else None)} -{format_sbatch_option("time", slurm.time)} -{format_sbatch_option("mem", slurm.mem)} -{format_sbatch_option("constraint", slurm.constraint)} -{format_sbatch_option("account", slurm.account)} -{format_sbatch_option("reservation", slurm.reservation)} -{format_sbatch_option("begin", slurm.begin)} +{format_sbatch_option("job-name", config.slurm.job_name)} +{format_sbatch_option("nodes", config.slurm.nodes)} +{format_sbatch_option("ntasks-per-node", config.slurm.n_tasks_per_node)} 
+{format_sbatch_option("cpus-per-task", config.slurm.cpus_per_task)} +{format_sbatch_option("gres", f"gpu:{config.slurm.gpu_per_node}")} +{format_sbatch_option("partition", config.slurm.gpu_partition)} +{format_sbatch_option("output", f"{config.slurm.slurm_logs_path}/train-{timestamp}-%x-%j.out")} +{format_sbatch_option("array", config.slurm.array)} +{format_sbatch_option("qos", config.slurm.qos)} +{format_sbatch_option("mail-type", config.slurm.mail_type)} +{format_sbatch_option("mail-user", config.slurm.mail_user)} +{format_sbatch_option("exclude", ",".join(config.slurm.exclude_nodes) if config.slurm.exclude_nodes else None)} +{format_sbatch_option("time", config.slurm.time)} +{format_sbatch_option("mem", config.slurm.mem)} +{format_sbatch_option("constraint", config.slurm.constraint)} +{format_sbatch_option("account", config.slurm.account)} +{format_sbatch_option("reservation", config.slurm.reservation)} +{format_sbatch_option("begin", config.slurm.begin)} set -x -e @@ -423,7 +200,7 @@ def format_sbatch_option(option, value): nvidia-smi source ~/.bashrc source /fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh -conda activate {slurm.conda_env_path} #Modify this line if you use something different than conda +conda activate {config.slurm.conda_env_path} #Modify this line if you use something different than conda #Show some environment variables @@ -460,7 +237,7 @@ def format_sbatch_option(option, value): --config-file {config_path_yaml} " export LAUNCHER="torchrun \ - --nproc_per_node {slurm.gpu_per_node} \ + --nproc_per_node {config.slurm.gpu_per_node} \ --nnodes $COUNT_NODE \ {torchrun_args} \ --node_rank $SLURM_PROCID \ @@ -497,10 +274,10 @@ def format_sbatch_option(option, value): # Check if the parallelism configuration matches the available GPUs total_gpus = gpu_count - config_gpus = parallelism.dp * parallelism.tp * parallelism.pp + config_gpus = config.parallelism.dp * config.parallelism.tp * config.parallelism.pp if total_gpus != config_gpus: - raise ValueError(f"The parallelism configuration (dp={parallelism.dp}, tp={parallelism.tp}, pp={parallelism.pp}) " + raise ValueError(f"The parallelism configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " f"doesn't match the number of available GPUs ({total_gpus}). " f"Please adjust your configuration to match the available resources.") @@ -512,7 +289,7 @@ def format_sbatch_option(option, value): # Prepare command trainer_python_file = "/fsx/elie_bakouch/nanotron/run_train.py" - cmd = f"{trainer_python_file} --config-file {config_path_yaml}" + cmd = f"{trainer_python_file} --config-file {args.config_path}" # Launch job launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {gpu_count} {cmd}" diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 126a4d23..32f89d43 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -100,7 +100,6 @@ class SlurmArgs: gpu_partition (str): SLURM partition (queue) for GPU jobs. job_name (str): Name of the SLURM job. nodes (int): Number of nodes to allocate for the job. - logs_path (str): Base directory for storing log files. conda_path (str): Path to the Conda installation script. conda_env_path (str): Path to the Conda environment to be used. n_tasks_per_node (int): Number of tasks to run per node. Default is 1. @@ -118,14 +117,11 @@ class SlurmArgs: reservation (Optional[str]): Name of a reservation to use for the job. begin (Optional[str]): Earliest time the job can start. 
torchrun_args (Optional[Dict[str, str]]): Additional arguments for torchrun command. - slurm_logs_path (Optional[str]): Specific path for SLURM output logs. - config_logs_path (Optional[str]): Path for storing configuration logs. """ gpu_partition: str job_name: str nodes: int - logs_path: str conda_path: str conda_env_path: str n_tasks_per_node: int = 1 @@ -143,8 +139,6 @@ class SlurmArgs: reservation: Optional[str] = None begin: Optional[str] = None torchrun_args: Optional[Dict[str, str]] = None - slurm_logs_path: Optional[str] = None - config_logs_path: Optional[str] = None def __post_init__(self): if self.mail_type is None and self.mail_user is not None: @@ -244,6 +238,9 @@ class GeneralArgs: """ project: str + logs_path: Optional[str] = "./logs" + slurm_logs_path: Optional[str] = None + config_logs_path: Optional[str] = None repo_id: Optional[str] = None temp_dir: Optional[str] = None run: Optional[str] = None @@ -252,6 +249,7 @@ class GeneralArgs: consumed_train_samples: Optional[int] = None benchmark_csv_path: Optional[Path] = None ignore_sanity_checks: bool = True + name: Optional[str] = None def __post_init__(self): if self.seed is None: @@ -265,6 +263,8 @@ def __post_init__(self): self.run = "%date_%jobid" self.run.replace("%date", datetime.datetime.now().strftime("%Y%m%d_%H%M%S")) self.run.replace("%jobid", os.environ.get("SLURM_JOB_ID", "local")) + if self.name is None: + self.name = f"{self.project}-{self.run}" @dataclass @@ -415,7 +415,7 @@ class Config: data_stages: Optional[List[DatasetStageArgs]] = None profiler: Optional[ProfilerArgs] = None lighteval: Optional[LightEvalConfig] = None - s3_upload : Optional[S3UploadArgs] = None + s3_upload: Optional[S3UploadArgs] = None slurm: Optional[SlurmArgs] = None @classmethod @@ -424,7 +424,6 @@ def create_empty(cls): return cls(**{f.name: None for f in cls_fields}) def __post_init__(self): - if self.s3_upload is not None: self.s3_upload.__post_init__() @@ -460,38 +459,37 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" - if self.slurm is not None: - job_folder = os.path.join(self.slurm.logs_path, self.slurm.job_name) - os.makedirs(job_folder, exist_ok=True) + log_folder = os.path.join(self.general.logs_path, self.general.name) + os.makedirs(log_folder, exist_ok=True) - subfolders = ['configs', 'slurm'] + # Create config folder for all jobs + config_folder = os.path.join(log_folder, 'configs') + os.makedirs(config_folder, exist_ok=True) + self.general.config_folder_path = config_folder + + if self.slurm is not None: + subfolders = ['slurm'] if self.lighteval is not None and self.s3_upload is not None: subfolders.append('evals') - logs_paths = {} - for subfolder in subfolders: - specific_path = getattr(self.slurm, f"{subfolder}_logs_path", None) - if specific_path is None: - folder_path = os.path.join(job_folder, subfolder) - else: - folder_path = specific_path + folder_path = os.path.join(log_folder, subfolder) os.makedirs(folder_path, exist_ok=True) - logs_paths[subfolder] = folder_path + setattr(self.general, f"{subfolder}_logs_path", folder_path) if subfolder == 'evals': for evals_subfolder in ['launch-config', 'logs']: evals_subfolder_path = os.path.join(folder_path, evals_subfolder) os.makedirs(evals_subfolder_path, exist_ok=True) - self.slurm.config_logs_path = logs_paths['configs'] - self.slurm.slurm_logs_path = logs_paths['slurm'] - if self.lighteval is not None: - self.slurm.evals_logs_path = logs_paths['evals'] - - # # if 
lighteval, we need tokenizer to be defined - # if self.checkpoints.lighteval is not None: - # assert self.tokenizer.tokenizer_name_or_path is not None + # Create launch-script folder + launch_script_folder = os.path.join(log_folder, 'launch-script') + os.makedirs(launch_script_folder, exist_ok=True) + self.general.launch_script_path = launch_script_folder + + # if lighteval, we need tokenizer to be defined + if self.lighteval is not None: + assert self.tokenizer.tokenizer_name_or_path is not None @property def global_batch_size(self): diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 2e97f561..d41bcf46 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -280,7 +280,8 @@ def post_init(self): self.s3_mover = None if self.config.lighteval is not None and dist.get_rank(self.parallel_context.world_pg) == 0: # We only start evaluation runs once on the first node - if self.s3_mover is not None and self.slurm is not None: + #TODO @eliebak add the support of lighteval locally + if self.s3_mover is not None and self.config.slurm is not None: self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint else: From e9d4a2e4bf4dd7babfe30cef7aa02342f202742e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 27 Aug 2024 19:08:19 +0000 Subject: [PATCH 13/43] fix launcher and create_config file, still need some improvement for the vf --- create_config.py | 8 ++++---- launcher.py | 19 +++++++++++++++---- src/nanotron/config/config.py | 10 ++++++++-- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/create_config.py b/create_config.py index 55dcf85d..c4c13b40 100644 --- a/create_config.py +++ b/create_config.py @@ -145,7 +145,7 @@ ) parallelism = ParallelismArgs( - dp=8, + dp=16, pp=1, tp=1, pp_engine="1f1b", @@ -208,7 +208,7 @@ ) s3_upload = S3UploadArgs( - upload_s3_path=f"s3://elie-exp/debug_nanotron/test/", + upload_s3_path=f"s3://elie-exp/debug_nanotron/test_eval/", remove_after_upload=True, s5cmd_numworkers=16, s5cmd_concurrency=5, @@ -245,8 +245,8 @@ ) timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") dir = os.path.dirname(__file__) - os.makedirs(config.general.config_folder_path, exist_ok=True) - config_path_yaml = f"{config.general.config_folder_path}/{timestamp}.yaml" + os.makedirs(config.general.config_logs_path, exist_ok=True) + config_path_yaml = f"{config.general.config_logs_path}/{timestamp}.yaml" config.save_as_yaml(config_path_yaml) os.makedirs(f"{config.general.slurm_logs_path}/", exist_ok=True) diff --git a/launcher.py b/launcher.py index 48b5f7ed..823553f8 100644 --- a/launcher.py +++ b/launcher.py @@ -161,11 +161,11 @@ def set_nested_attribute(obj, path, value): if config.slurm: dir = os.path.dirname(__file__) - os.makedirs(config.slurm.config_logs_path, exist_ok=True) - config_path_yaml = f"{config.slurm.config_logs_path}/{timestamp}.yaml" + os.makedirs(config.general.config_logs_path, exist_ok=True) + config_path_yaml = f"{config.general.config_logs_path}/{timestamp}.yaml" config.save_as_yaml(config_path_yaml) - os.makedirs(f"{config.slurm.slurm_logs_path}/", exist_ok=True) + os.makedirs(f"{config.general.slurm_logs_path}/", exist_ok=True) def format_sbatch_option(option, value): return f"#SBATCH --{option}={value}" if value is not None else "" @@ -181,7 +181,7 @@ def format_sbatch_option(option, value): {format_sbatch_option("cpus-per-task", config.slurm.cpus_per_task)} 
{format_sbatch_option("gres", f"gpu:{config.slurm.gpu_per_node}")} {format_sbatch_option("partition", config.slurm.gpu_partition)} -{format_sbatch_option("output", f"{config.slurm.slurm_logs_path}/train-{timestamp}-%x-%j.out")} +{format_sbatch_option("output", f"{config.general.slurm_logs_path}/train-{timestamp}-%x-%j.out")} {format_sbatch_option("array", config.slurm.array)} {format_sbatch_option("qos", config.slurm.qos)} {format_sbatch_option("mail-type", config.slurm.mail_type)} @@ -259,6 +259,17 @@ def format_sbatch_option(option, value): echo "END TIME: $(date)" """ + # Save the Slurm script + if config.general.launch_script_path: + os.makedirs(config.general.launch_script_path, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + script_filename = f"slurm_script_{timestamp}.slurm" + script_path = os.path.join(config.general.launch_script_path, script_filename) + + with open(script_path, 'w') as f: + f.write(sbatch_script) + + print(f"Slurm script saved to: {script_path}") print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}") else: diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 32f89d43..6cc7a509 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -239,8 +239,10 @@ class GeneralArgs: project: str logs_path: Optional[str] = "./logs" + launch_script_path: Optional[str] = None slurm_logs_path: Optional[str] = None config_logs_path: Optional[str] = None + evals_logs_path: Optional[str] = None repo_id: Optional[str] = None temp_dir: Optional[str] = None run: Optional[str] = None @@ -424,6 +426,9 @@ def create_empty(cls): return cls(**{f.name: None for f in cls_fields}) def __post_init__(self): + if hasattr(self, '_post_init_done'): + return + self._post_init_done = True if self.s3_upload is not None: self.s3_upload.__post_init__() @@ -459,19 +464,19 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" + log_folder = os.path.join(self.general.logs_path, self.general.name) os.makedirs(log_folder, exist_ok=True) # Create config folder for all jobs config_folder = os.path.join(log_folder, 'configs') os.makedirs(config_folder, exist_ok=True) - self.general.config_folder_path = config_folder + self.general.config_logs_path = config_folder if self.slurm is not None: subfolders = ['slurm'] if self.lighteval is not None and self.s3_upload is not None: subfolders.append('evals') - for subfolder in subfolders: folder_path = os.path.join(log_folder, subfolder) os.makedirs(folder_path, exist_ok=True) @@ -486,6 +491,7 @@ def __post_init__(self): launch_script_folder = os.path.join(log_folder, 'launch-script') os.makedirs(launch_script_folder, exist_ok=True) self.general.launch_script_path = launch_script_folder + # if lighteval, we need tokenizer to be defined if self.lighteval is not None: From 79ae2cbbf6e552af5bc553246863203aa86eef06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Fri, 30 Aug 2024 03:28:06 +0000 Subject: [PATCH 14/43] lot of changes, working on 1 node with s3, will test the rest soon it may work --- create_config.py | 64 +++++++-------- launcher.py | 28 +++---- src/nanotron/config/config.py | 17 ++-- src/nanotron/config/lighteval_config.py | 18 ++++- src/nanotron/lighteval/evaluation_tasks.py | 9 +-- src/nanotron/lighteval/one_job_runner.py | 79 ++++++++++++++----- src/nanotron/lighteval/run_eval.slurm.jinja | 17 ++-- .../lighteval/run_eval_no_s3.slurm.jinja | 62 
+++++++++++++++ src/nanotron/lighteval/run_evals.py | 6 +- src/nanotron/trainer.py | 30 +++++-- 10 files changed, 218 insertions(+), 112 deletions(-) create mode 100644 src/nanotron/lighteval/run_eval_no_s3.slurm.jinja diff --git a/create_config.py b/create_config.py index c4c13b40..82379463 100644 --- a/create_config.py +++ b/create_config.py @@ -31,52 +31,44 @@ RandomInit, TokenizerArgs, TokensArgs, - LightEvalWandbLoggerConfig, DatasetStageArgs, ) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("project", help="project name", type=str) + parser.add_argument("--project", help="project name", type=str, required=True) + parser.add_argument("--run", help="run name", type=str, required=True) parser.add_argument("--slurm", help="use slurm", action="store_true") - parser.add_argument("--name", help="run name", type=str, default=None) parser.add_argument("--seed", help="seed", type=int, default=8) - parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="normal") + parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="high") parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", help="Override config values. Use dot notation for nested keys.") parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately") args = parser.parse_args() - if args.name is not None: - run = f"{args.project}-{args.name}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" - else: - run = f"{args.project}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" general = GeneralArgs( project=args.project, - run=run, - repo_id="HuggingFaceSmol/test-nanotron", + run=args.run, logs_path="/fsx/elie_bakouch/nanotron/debug", seed=args.seed, temp_dir="/scratch", ) if args.slurm: - job_name=f"{args.project}-{args.name}" if args.name is not None else f"{args.project}" + job_name=f"{args.project}-{args.run}" slurm = SlurmArgs( gpu_partition="hopper-prod", - job_name=f"{args.project}-{args.name}", - nodes=2, - conda_path="/fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh", - conda_env_path="/fsx/elie_bakouch/miniconda3/envs/smollm", - exclude_nodes=["ip-26-0-161-138", "ip-26-0-161-178"], + job_name=job_name, + nodes=1, torchrun_args={ "rdzv_backend": "etcd-v2", "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", "rdzv_id": "$SLURM_JOB_ID" }, - qos="normal", - begin="now+0minutes" + qos="high", + begin="now+0minutes", + time="01:00:00", ) else: slurm = None @@ -117,27 +109,27 @@ tp_linear_async_communication=False, ), batch_size=16, - wandb=LightEvalWandbLoggerConfig( - wandb_project=args.project, - wandb_entity="eliebak", - wandb_run_name=f"{run}_evals", - ), logging=LightEvalLoggingArgs( - local_output_path=f"{general.temp_dir}/lighteval/{run}", - push_details_to_hub=False, + local_output_path=f"/fsx/elie_bakouch/lighteval-logs/{general.project}-{general.run}", + #local_output_path=PATH_TO_LOCAL_LOG, + private=True, + push_details_to_hub=True, push_results_to_hub=True, push_results_to_tensorboard=True, - #hub_repo_details=REPO_ID, - hub_repo_results=general.repo_id, - hub_repo_tensorboard=general.repo_id, - tensorboard_metric_prefix="e", + hf_user_or_org="eliebak", + #hf_user_or_org="USER_OR_ORG", + hub_repo_results="lighteval-results", + hub_repo_details="lighteval-details", + hub_repo_tensorboard="smollm-evals-visualization", + tensorboard_metric_prefix="eval", ), slurm_template="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_eval.slurm.jinja", ) checkpoints = 
CheckpointsArgs( - checkpoints_path=f"checkpoints/{run}", + checkpoints_path=f"/scratch/elie_bakouch/checkpoints/{general.project}-{general.run}", + #checkpoints_path="CHECKPOINTS_PATH", checkpoints_path_is_shared_file_system=False, resume_checkpoint_path=None, checkpoint_interval=20, @@ -145,14 +137,13 @@ ) parallelism = ParallelismArgs( - dp=16, + dp=8, pp=1, tp=1, pp_engine="1f1b", tp_mode="REDUCE_SCATTER", tp_linear_async_communication=True, ) - #Add sanity check for the number of GPUs and the number of nodes ? tokens = TokensArgs( batch_accumulation_per_replica=8, @@ -164,7 +155,6 @@ model = ModelArgs( model_config=model_config, - make_vocab_size_divisible_by=1, init_method=RandomInit( std=math.sqrt(model_config.hidden_size), ), @@ -179,7 +169,7 @@ ) learning_rate_scheduler = LRSchedulerArgs( - learning_rate=1e-4, #llama one + learning_rate=1e-4, lr_warmup_steps=10, lr_warmup_style="linear", lr_decay_style="linear", @@ -208,7 +198,7 @@ ) s3_upload = S3UploadArgs( - upload_s3_path=f"s3://elie-exp/debug_nanotron/test_eval/", + upload_s3_path=f"s3://elie-exp/debug_nanotron/eval-vf-hope/", remove_after_upload=True, s5cmd_numworkers=16, s5cmd_concurrency=5, @@ -246,12 +236,12 @@ timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") dir = os.path.dirname(__file__) os.makedirs(config.general.config_logs_path, exist_ok=True) - config_path_yaml = f"{config.general.config_logs_path}/{timestamp}.yaml" + config_path_yaml = f"{config.general.config_logs_path}/{timestamp}_create.yaml" config.save_as_yaml(config_path_yaml) os.makedirs(f"{config.general.slurm_logs_path}/", exist_ok=True) - print(f"Configuration saved to: {config_path_yaml}") + print(f"💾 Configuration saved to: {config_path_yaml}") if args.launch: launcher_path = os.path.join(dir, "launcher.py") diff --git a/launcher.py b/launcher.py index 823553f8..496deb8c 100644 --- a/launcher.py +++ b/launcher.py @@ -2,14 +2,11 @@ import subprocess import tempfile from datetime import datetime -import math import torch import argparse -from typing import Any, Dict from nanotron.logging import human_format -from nanotron.models.llama import LlamaConfig from nanotron.config import ( Config, @@ -162,7 +159,7 @@ def set_nested_attribute(obj, path, value): dir = os.path.dirname(__file__) os.makedirs(config.general.config_logs_path, exist_ok=True) - config_path_yaml = f"{config.general.config_logs_path}/{timestamp}.yaml" + config_path_yaml = f"{config.general.config_logs_path}/{timestamp}_launch.yaml" config.save_as_yaml(config_path_yaml) os.makedirs(f"{config.general.slurm_logs_path}/", exist_ok=True) @@ -181,8 +178,8 @@ def format_sbatch_option(option, value): {format_sbatch_option("cpus-per-task", config.slurm.cpus_per_task)} {format_sbatch_option("gres", f"gpu:{config.slurm.gpu_per_node}")} {format_sbatch_option("partition", config.slurm.gpu_partition)} -{format_sbatch_option("output", f"{config.general.slurm_logs_path}/train-{timestamp}-%x-%j.out")} -{format_sbatch_option("array", config.slurm.array)} +{format_sbatch_option("output", f"{config.general.slurm_logs_path}/train-{timestamp}-%j.out")} +{format_sbatch_option("error", f"{config.general.slurm_logs_path}/train-{timestamp}-%j.err")} {format_sbatch_option("qos", config.slurm.qos)} {format_sbatch_option("mail-type", config.slurm.mail_type)} {format_sbatch_option("mail-user", config.slurm.mail_user)} @@ -198,13 +195,11 @@ def format_sbatch_option(option, value): TRAINER_PYTHON_FILE=/fsx/elie_bakouch/nanotron/run_train.py nvidia-smi -source ~/.bashrc -source 
/fsx/elie_bakouch/miniconda3/etc/profile.d/conda.sh -conda activate {config.slurm.conda_env_path} #Modify this line if you use something different than conda #Show some environment variables echo python3 version = `python3 --version` +echo "Python path: $(which python3)" echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" @@ -232,9 +227,8 @@ def format_sbatch_option(option, value): ##### MOVE TO YAML ###### -CMD=" \ - $TRAINER_PYTHON_FILE \ - --config-file {config_path_yaml} +CMD=" $TRAINER_PYTHON_FILE \ + --config-file {config_path_yaml} \ " export LAUNCHER="torchrun \ --nproc_per_node {config.slurm.gpu_per_node} \ @@ -260,6 +254,7 @@ def format_sbatch_option(option, value): echo "END TIME: $(date)" """ # Save the Slurm script + print(f"🚀 Slurm job launched with id={launch_slurm_job(sbatch_script)}") if config.general.launch_script_path: os.makedirs(config.general.launch_script_path, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -269,9 +264,8 @@ def format_sbatch_option(option, value): with open(script_path, 'w') as f: f.write(sbatch_script) - print(f"Slurm script saved to: {script_path}") + print(f" 💾 Logs are saved to : {config.general.logs_path}") - print(f"Slurm job launched with id={launch_slurm_job(sbatch_script)}") else: # Check if running on an interactive node try: @@ -281,7 +275,7 @@ def format_sbatch_option(option, value): is_interactive = False if is_interactive: - print("Running on an interactive node with GPUs.") + print("💻 Running on an interactive node with GPUs.") # Check if the parallelism configuration matches the available GPUs total_gpus = gpu_count @@ -304,9 +298,9 @@ def format_sbatch_option(option, value): # Launch job launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {gpu_count} {cmd}" - print(f"Launching interactive job with command: {launch_cmd}") + print(f"🚀 Launching interactive job with command: {launch_cmd}") # Execute the command subprocess.run(launch_cmd, shell=True, check=True) else: - print("Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs.") \ No newline at end of file + print("❌ Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs.") \ No newline at end of file diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 6cc7a509..760f2bd3 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -100,8 +100,6 @@ class SlurmArgs: gpu_partition (str): SLURM partition (queue) for GPU jobs. job_name (str): Name of the SLURM job. nodes (int): Number of nodes to allocate for the job. - conda_path (str): Path to the Conda installation script. - conda_env_path (str): Path to the Conda environment to be used. n_tasks_per_node (int): Number of tasks to run per node. Default is 1. cpus_per_task (int): Number of CPUs to allocate per task. Default is 32. gpu_per_node (int): Number of GPUs to allocate per node. Default is 8. 
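For reference, the optional SlurmArgs fields documented above stay optional because the launcher's format_sbatch_option helper (visible in the launcher.py hunks earlier in this patch) only emits a directive when a value is set, so unset fields simply disappear from the generated sbatch script. A minimal sketch of that behavior, assuming the helper keeps the exact signature shown in the diff:

    def format_sbatch_option(option, value):
        # Emit an #SBATCH directive only when a value is provided.
        return f"#SBATCH --{option}={value}" if value is not None else ""

    print(format_sbatch_option("qos", "high"))      # -> "#SBATCH --qos=high"
    print(repr(format_sbatch_option("mem", None)))  # -> ''  (directive omitted)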
@@ -122,8 +120,6 @@ class SlurmArgs: gpu_partition: str job_name: str nodes: int - conda_path: str - conda_env_path: str n_tasks_per_node: int = 1 cpus_per_task: int = 32 gpu_per_node: int = 8 @@ -238,20 +234,18 @@ class GeneralArgs: """ project: str + run: str logs_path: Optional[str] = "./logs" launch_script_path: Optional[str] = None slurm_logs_path: Optional[str] = None config_logs_path: Optional[str] = None evals_logs_path: Optional[str] = None - repo_id: Optional[str] = None temp_dir: Optional[str] = None - run: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None consumed_train_samples: Optional[int] = None benchmark_csv_path: Optional[Path] = None ignore_sanity_checks: bool = True - name: Optional[str] = None def __post_init__(self): if self.seed is None: @@ -265,8 +259,6 @@ def __post_init__(self): self.run = "%date_%jobid" self.run.replace("%date", datetime.datetime.now().strftime("%Y%m%d_%H%M%S")) self.run.replace("%jobid", os.environ.get("SLURM_JOB_ID", "local")) - if self.name is None: - self.name = f"{self.project}-{self.run}" @dataclass @@ -429,6 +421,7 @@ def __post_init__(self): if hasattr(self, '_post_init_done'): return self._post_init_done = True + self.general.__post_init__() if self.s3_upload is not None: self.s3_upload.__post_init__() @@ -465,7 +458,10 @@ def __post_init__(self): ), "The stages are not sorted by start_training_step in increasing order" - log_folder = os.path.join(self.general.logs_path, self.general.name) + project_log_folder = Path(self.general.logs_path) + os.makedirs(project_log_folder, exist_ok=True) + + log_folder = os.path.join(project_log_folder, f"{self.general.run}-{self.general.project}") os.makedirs(log_folder, exist_ok=True) # Create config folder for all jobs @@ -497,6 +493,7 @@ def __post_init__(self): if self.lighteval is not None: assert self.tokenizer.tokenizer_name_or_path is not None + @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index b5f12059..ea3ba120 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -32,19 +32,29 @@ def __post_init__(self): @dataclass class LightEvalLoggingArgs: """Arguments related to logging for LightEval""" - local_output_path: Optional[Path] = None + private: Optional[bool] = True push_results_to_hub: Optional[bool] = None push_details_to_hub: Optional[bool] = None push_results_to_tensorboard: Optional[bool] = None - hub_repo_results: Optional[str] = None - hub_repo_details: Optional[str] = None + hf_user_or_org: Optional[str] = None + hub_repo_results: Optional[str] = None #path is hf_user_or_org/hub_repo_results + hub_repo_details: Optional[str] = None #path is hf_user_or_org/hub_repo_details hub_repo_tensorboard: Optional[str] = None tensorboard_metric_prefix: Optional[str] = None def __post_init__(self): if isinstance(self.local_output_path, str): self.local_output_path = Path(self.local_output_path) + if self.push_results_to_hub is not None and self.hf_user_or_org is None: + raise ValueError("hf_user_or_org must be specified if push_results_to_hub is set") + if self.push_details_to_hub is not None and self.hf_user_or_org is None: + raise ValueError("hf_user_or_org must be specified if push_details_to_hub is set") + if self.hf_user_or_org is not None: + if self.push_results_to_hub is not None and self.hub_repo_results is None: + 
self.hub_repo_results = "evals-results" + if self.push_details_to_hub is not None and self.hub_repo_details is None: + self.hub_repo_details = "evals-details" @dataclass @@ -83,7 +93,7 @@ class LightEvalConfig: slurm_template: Optional[str] = None slurm_script_dir: Optional[str] = None - + temp_dir: Optional[str] = None checkpoints_path: Optional[str] = None parallelism: Optional[ParallelismArgs] = None batch_size: Optional[int] = None diff --git a/src/nanotron/lighteval/evaluation_tasks.py b/src/nanotron/lighteval/evaluation_tasks.py index 88bdd0b1..ff4342bb 100644 --- a/src/nanotron/lighteval/evaluation_tasks.py +++ b/src/nanotron/lighteval/evaluation_tasks.py @@ -3,16 +3,15 @@ Custom evaluation tasks for lighteval This file generally create just a TASKS_TABLE and TASKS_GROUPS which are then imported by LightEval. -Edit this file to add your own task if needed """ import re from dataclasses import asdict from typing import Dict, List, Tuple -from lighteval.metrics.metrics import Metrics +from lighteval.metrics import Metrics from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc -from lighteval.tasks.default_prompts import LETTER_INDICES +from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] _TASKS: List[LightevalTaskConfig] = [] @@ -628,7 +627,6 @@ def agi_eval_prompt_no_letters(line, task_name: str = None): ## HUMAN EVAL ## -#TODO @eliebak add human eval again # human_eval = LightevalTaskConfig( # name="human_eval", # prompt_function="human_eval", @@ -640,9 +638,8 @@ def agi_eval_prompt_no_letters(line, task_name: str = None): EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING]) # Convert to dict for lighteval -TASKS_TABLE = [asdict(task) for task in _TASKS] +TASKS_TABLE = [task.as_dict() for task in _TASKS] # You can have a few pre-organised groups of tasks -# TODO @eliebak add math and code here TASKS_GROUPS = { "all": ",".join(t[1] for t in _TASKS_STRINGS), "early-signal": EARLY_SIGNAL_TASKS, diff --git a/src/nanotron/lighteval/one_job_runner.py b/src/nanotron/lighteval/one_job_runner.py index 30e7eaf8..8b00b1f7 100644 --- a/src/nanotron/lighteval/one_job_runner.py +++ b/src/nanotron/lighteval/one_job_runner.py @@ -11,7 +11,7 @@ from nanotron.logging import log_rank from nanotron.parallel import ParallelContext -from nanotron.config import Config, LightEvalConfig +from nanotron.config import Config logger = logging.get_logger(__name__) @@ -22,6 +22,30 @@ def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = self.lighteval_config = config.lighteval self.parallel_context = parallel_context + def eval_single_checkpoint_no_s3(self, checkpoints_folder, current_step) -> Tuple[str, str]: + current_checkpoint_folder = os.path.join(checkpoints_folder, str(current_step)) + checkpoint_path = os.path.join(current_checkpoint_folder, "config.yaml") + if not os.path.exists(checkpoint_path): + log_rank( + f"Checkpoint path does not exist: {checkpoint_path}. 
Unable to evaluate.", + logger=logger, + level=logging.ERROR, + group=self.parallel_context.dp_pg if self.parallel_context is not None else None, + rank=0, + ) + return None, None + + slurm_job_id, slurm_log = run_slurm_one_job( + config = self.config, + slurm_template=self.lighteval_config.slurm_template, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + checkpoint_local_path=current_checkpoint_folder, + s3=False, + ) + + return slurm_job_id, slurm_log + def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: """Run light evaluation on uploaded files.""" logger.warning(f"Lighteval Runner got {len(uploaded_files)} files. Checking configs.") @@ -53,6 +77,8 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: config = self.config, slurm_template=self.lighteval_config.slurm_template, model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + s3=True, ) return slurm_job_id, slurm_log @@ -62,6 +88,9 @@ def run_slurm_one_job( config: Config, model_checkpoint_path: str, slurm_template: str, + current_step: int, + s3: bool = True, + checkpoint_local_path: str = None, slurm_name: Optional[str] = "eval", slurm_kwargs: Optional[dict] = None, #add slurm_kwargs and modify the jinja template in case you need to adapt it to your slurm cluster. ): @@ -71,8 +100,11 @@ def run_slurm_one_job( mapping: Mapping to use for the job script (see SLURM_ONE_JOB_MAPPING) """ - eval_launch_script_path=os.path.join(config.slurm.evals_logs_path, "launch-config") - eval_logs_path= os.path.join(config.slurm.evals_logs_path, "logs") + eval_launch_script_path = os.path.join(config.general.evals_logs_path, "launch-config", str(current_step)) + eval_logs_path = os.path.join(config.general.evals_logs_path, "logs", str(current_step)) + + os.makedirs(eval_launch_script_path, exist_ok=True) + os.makedirs(eval_logs_path, exist_ok=True) environment = jinja2.Environment( comment_start_string="{=", @@ -81,21 +113,32 @@ def run_slurm_one_job( with open(slurm_template, "r") as f: SLURM_JOBS_ARRAY_TEMPLATE = environment.from_string(f.read()) - - launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render( - model_checkpoint_path=model_checkpoint_path, - job_name=f"{slurm_name}-eval", - n_tasks_per_node=config.slurm.n_tasks_per_node, - partition=config.slurm.gpu_partition, - gpu_per_node=config.slurm.gpu_per_node, - cpus_per_task=config.slurm.cpus_per_task, - eval_path=eval_logs_path, - mail=config.slurm.mail, - conda_path=config.slurm.conda_path, - conda_env_path=config.slurm.conda_env_path, - local_path=config.checkpoints.checkpoints_path, - **(slurm_kwargs if slurm_kwargs else {}), - ) + if s3: + launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render( + model_checkpoint_path=model_checkpoint_path, + job_name=f"{slurm_name}", + n_tasks_per_node=config.slurm.n_tasks_per_node, + partition=config.slurm.gpu_partition, + gpu_per_node=config.slurm.gpu_per_node, + cpus_per_task=config.slurm.cpus_per_task, + eval_path=eval_logs_path, + mail=config.slurm.mail_user, + local_path=config.lighteval.temp_dir, + **(slurm_kwargs if slurm_kwargs else {}), + ) + else: + launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render( + model_checkpoint_path=model_checkpoint_path, + job_name=f"{slurm_name}", + n_tasks_per_node=config.slurm.n_tasks_per_node, + partition=config.slurm.gpu_partition, + gpu_per_node=config.slurm.gpu_per_node, + cpus_per_task=config.slurm.cpus_per_task, + eval_path=eval_logs_path, + mail=config.slurm.mail_user, + 
ckpt_local_path=checkpoint_local_path, + **(slurm_kwargs if slurm_kwargs else {}), + ) match = re.match(r"#SBATCH --output=(.*)", launch_string) slurm_output_path = match.group(1) if match else "" diff --git a/src/nanotron/lighteval/run_eval.slurm.jinja b/src/nanotron/lighteval/run_eval.slurm.jinja index 6b2e1245..dc6a6ab6 100644 --- a/src/nanotron/lighteval/run_eval.slurm.jinja +++ b/src/nanotron/lighteval/run_eval.slurm.jinja @@ -1,28 +1,21 @@ #!/bin/bash -#SBATCH --job-name={{ job_name }}-eval +#SBATCH --job-name={{ job_name }} #SBATCH --nodes=1 #SBATCH --ntasks-per-node={{ n_tasks_per_node }} #SBATCH --partition={{ partition }} #SBATCH --gres=gpu:{{ gpu_per_node }} #SBATCH --cpus-per-task={{ cpus_per_task}} -#SBATCH --output={{ eval_path }}/eval-%x-%n-%j -#SBATCH --error={{ eval_path }}/eval-%x-%n-%j +#SBATCH --output={{ eval_path }}/%x-%n-%j.out +#SBATCH --error={{ eval_path }}/%x-%n-%j.err #SBATCH --qos=high #SBATCH --dependency=singleton #SBATCH --mail-type=FAIL #SBATCH --mail-user={{ mail }} +#SBATCH --time=01:00:00 -########################################### -source ~/.bashrc -source {{ conda_path }} -conda activate {{ conda_env_path}} #Modify this line if you use something different than conda LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} -# [END] ADAPT TO YOUR ENVIRONMENT -########################################### - - set -x -e echo "START TIME: $(date)" #Show some environment variables @@ -47,7 +40,6 @@ if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then fi fi -export TMPDIR=/scratch export CUBLAS_WORKSPACE_CONFIG=":4096:8" export CUDA_DEVICE_MAX_CONNECTIONS="1" @@ -68,6 +60,7 @@ torch_dist_args="--nproc_per_node 8 \ launch_args="$torch_dist_args \ /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ + --hf-user-or-org {{ hf_user_or_org }} \ " srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/src/nanotron/lighteval/run_eval_no_s3.slurm.jinja b/src/nanotron/lighteval/run_eval_no_s3.slurm.jinja new file mode 100644 index 00000000..a3e9bb28 --- /dev/null +++ b/src/nanotron/lighteval/run_eval_no_s3.slurm.jinja @@ -0,0 +1,62 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --partition={{ partition }} +#SBATCH --gres=gpu:{{ gpu_per_node }} +#SBATCH --cpus-per-task={{ cpus_per_task}} +#SBATCH --output={{ eval_path }}/%x-%n-%j.out +#SBATCH --error={{ eval_path }}/%x-%n-%j.err +#SBATCH --qos=high +#SBATCH --dependency=singleton +#SBATCH --mail-type=FAIL +#SBATCH --mail-user={{ mail }} +#SBATCH --time=01:00:00 + + +CHECKPOINT_FOLDER={{ ckpt_local_path }} + +set -x -e +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + # Attempt to read the token from the cache + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is 
not set and the token cache could not be read." + exit 1 + fi +fi + +export CUBLAS_WORKSPACE_CONFIG=":4096:8" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +echo go $COUNT_NODE +echo $HOSTNAMES + +torch_dist_args="--nproc_per_node 8 \ + --nnodes $COUNT_NODE \ + --max_restarts 0 \ + --tee 3 \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: " + +launch_args="$torch_dist_args \ + /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path ${CHECKPOINT_FOLDER}/config.yaml \ + --hf-user-or-org {{ hf_user_or_org }} \ + " + +srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/src/nanotron/lighteval/run_evals.py b/src/nanotron/lighteval/run_evals.py index e2d84d1b..5ee36f53 100644 --- a/src/nanotron/lighteval/run_evals.py +++ b/src/nanotron/lighteval/run_evals.py @@ -32,4 +32,8 @@ def get_parser(): if __name__ == "__main__": parser = get_parser() args, unknowns = parser.parse_known_args() - main(args.checkpoint_config_path, args.lighteval_override, args.cache_dir) \ No newline at end of file + main( + checkpoint_config_path=args.checkpoint_config_path, + lighteval_config_path=args.lighteval_override, + cache_dir=args.cache_dir, + ) \ No newline at end of file diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index d41bcf46..dde0d83c 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -269,7 +269,6 @@ def post_init(self): self.s3_mover = S3Mover( local_path=self.config.checkpoints.checkpoints_path, s3_path=self.config.s3_upload.upload_s3_path, - # duplicate_checkpoint_path=self.config.checkpoints.resume_checkpoint_path, remove_after_upload=self.config.s3_upload.remove_after_upload, s5cmd_numworkers=self.config.s3_upload.s5cmd_numworkers, s5cmd_concurrency=self.config.s3_upload.s5cmd_concurrency, @@ -278,14 +277,28 @@ def post_init(self): ) else: self.s3_mover = None + if self.config.lighteval is not None and dist.get_rank(self.parallel_context.world_pg) == 0: # We only start evaluation runs once on the first node - #TODO @eliebak add the support of lighteval locally - if self.s3_mover is not None and self.config.slurm is not None: - self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + if self.s3_mover is not None and self.config.slurm is not None: self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint + elif self.config.slurm is not None and self.s3_mover is None: + # Use the no_s3 version of the evaluation function + self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 else: - log_rank("LightEval is enabled but s3 upload is not enabled, skipping evaluation", logger=logger, level=logging.INFO, rank=0) + log_rank("LightEval is enabled but Slurm is not configured, skipping evaluation", logger=logger, level=logging.INFO, rank=0) + else: + self.post_checkpoint_callback = None + + def post_save_checkpoint(self): + # Upload to S3 + if self.s3_mover is not None: + self.s3_mover.start_uploading() + elif self.post_checkpoint_callback is not None: + # If we're not using S3, but we have a post-checkpoint callback, execute it + checkpoint_path = self.config.checkpoints.checkpoints_path / f"{self.iteration_step}" + self.post_checkpoint_callback(checkpoint_path) def pre_training(self, *args, **kwargs): self._print_training_plan() @@ -733,7 +746,7 @@ def _load_model_checkpoint(self, 
model: NanotronModel) -> NanotronModel: reloaded_from_checkpoint=True if not reloaded_from_checkpoint: # TODO @eliebak add s3 support also here - log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO) + log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0) if isinstance(self.config.model.init_method, ExistingCheckpointInit): # Initialize model from an pretrained model checkpoint (without optimizer, lr_scheduler...) self.param_shard_metadata = load_weights( @@ -874,7 +887,10 @@ def post_save_checkpoint(self): # Upload to S3 if self.s3_mover is not None: self.s3_mover.start_uploading() - + elif self.post_checkpoint_callback is not None: + # If we're not using S3, but we have a post-checkpoint callback, execute it + checkpoint_path = self.config.checkpoints.checkpoints_path / f"{self.iteration_step}" + self.post_checkpoint_callback(checkpoint_path) def save_checkpoint(self) -> Path: self.pre_save_checkpoint() From cfcbd7013e81c06cf8cd01afa190a2c3eefa8670 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sat, 31 Aug 2024 03:16:23 +0000 Subject: [PATCH 15/43] delete the SlurmArgs and add config to be more cluster agnostic + other big changes --- create_config.py | 80 +++----- launcher.py | 193 +++++++++--------- src/nanotron/config/config.py | 83 +------- src/nanotron/lighteval/one_job_runner.py | 53 +++-- src/nanotron/slurm/eval_slurm_config.json | 28 +++ src/nanotron/slurm/launch_slurm_config.json | 27 +++ .../slurm/launch_training.slurm.jinja | 92 +++++++++ src/nanotron/slurm/run_eval.slurm.jinja | 86 ++++++++ src/nanotron/slurm/run_eval_s3.slurm.jinja | 62 ++++++ src/nanotron/trainer.py | 4 +- 10 files changed, 447 insertions(+), 261 deletions(-) create mode 100644 src/nanotron/slurm/eval_slurm_config.json create mode 100644 src/nanotron/slurm/launch_slurm_config.json create mode 100644 src/nanotron/slurm/launch_training.slurm.jinja create mode 100644 src/nanotron/slurm/run_eval.slurm.jinja create mode 100644 src/nanotron/slurm/run_eval_s3.slurm.jinja diff --git a/create_config.py b/create_config.py index 82379463..2f7b75bf 100644 --- a/create_config.py +++ b/create_config.py @@ -1,14 +1,12 @@ import os +from pathlib import Path import subprocess -import tempfile from datetime import datetime import math import torch import argparse -from typing import Any, Dict -from nanotron.logging import human_format from nanotron.models.llama import LlamaConfig from nanotron.config import ( @@ -16,7 +14,6 @@ DataArgs, NanosetDatasetsArgs, S3UploadArgs, - SlurmArgs, CheckpointsArgs, GeneralArgs, LightEvalConfig, @@ -38,40 +35,23 @@ parser = argparse.ArgumentParser() parser.add_argument("--project", help="project name", type=str, required=True) parser.add_argument("--run", help="run name", type=str, required=True) - parser.add_argument("--slurm", help="use slurm", action="store_true") parser.add_argument("--seed", help="seed", type=int, default=8) parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="high") parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", help="Override config values. 
Use dot notation for nested keys.") parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately") + parser.add_argument("--slurm", help="use slurm", action="store_true") + parser.add_argument("--nodes", help="specify the number of nodes", type=int) args = parser.parse_args() general = GeneralArgs( project=args.project, run=args.run, - logs_path="/fsx/elie_bakouch/nanotron/debug", + logs_path="/fsx/elie_bakouch/nanotron/refactor-logs", seed=args.seed, temp_dir="/scratch", ) - - if args.slurm: - job_name=f"{args.project}-{args.run}" - slurm = SlurmArgs( - gpu_partition="hopper-prod", - job_name=job_name, - nodes=1, - torchrun_args={ - "rdzv_backend": "etcd-v2", - "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", - "rdzv_id": "$SLURM_JOB_ID" - }, - qos="high", - begin="now+0minutes", - time="01:00:00", - ) - else: - slurm = None model_config = LlamaConfig( bos_token_id=0, @@ -123,7 +103,7 @@ hub_repo_tensorboard="smollm-evals-visualization", tensorboard_metric_prefix="eval", ), - slurm_template="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_eval.slurm.jinja", + slurm_template="/fsx/elie_bakouch/nanotron/src/nanotron/slurm/run_eval.slurm.jinja", ) @@ -132,12 +112,12 @@ #checkpoints_path="CHECKPOINTS_PATH", checkpoints_path_is_shared_file_system=False, resume_checkpoint_path=None, - checkpoint_interval=20, + checkpoint_interval=500, save_initial_state=False, ) parallelism = ParallelismArgs( - dp=8, + dp=32, pp=1, tp=1, pp_engine="1f1b", @@ -146,17 +126,17 @@ ) tokens = TokensArgs( - batch_accumulation_per_replica=8, + batch_accumulation_per_replica=2, micro_batch_size=16, sequence_length=2048, - train_steps=100, + train_steps=1500, val_check_interval=-1, ) model = ModelArgs( model_config=model_config, init_method=RandomInit( - std=math.sqrt(model_config.hidden_size), + std=1/math.sqrt(model_config.hidden_size), ), dtype=torch.bfloat16, ) @@ -170,11 +150,11 @@ learning_rate_scheduler = LRSchedulerArgs( learning_rate=1e-4, - lr_warmup_steps=10, + lr_warmup_steps=100, lr_warmup_style="linear", lr_decay_style="linear", - lr_decay_steps = 20, - lr_decay_starting_step= 80, + lr_decay_steps = 200, + lr_decay_starting_step= 1300, min_decay_lr=0, ) @@ -198,7 +178,7 @@ ) s3_upload = S3UploadArgs( - upload_s3_path=f"s3://elie-exp/debug_nanotron/eval-vf-hope/", + upload_s3_path=f"s3://elie-exp/debug_nanotron/better_init", remove_after_upload=True, s5cmd_numworkers=16, s5cmd_concurrency=5, @@ -231,17 +211,22 @@ data_stages=data_stages, s3_upload=s3_upload, lighteval=lighteval, - slurm=slurm, ) timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - dir = os.path.dirname(__file__) - os.makedirs(config.general.config_logs_path, exist_ok=True) + dir = os.path.dirname(__file__) + + # Create the necessary directories + project_log_folder = Path(general.logs_path) + log_folder = project_log_folder / f"{general.run}-{general.project}" + config_folder = log_folder / 'configs' + config_folder.mkdir(parents=True, exist_ok=True) + + config.general.config_logs_path = str(config_folder) + config_path_yaml = f"{config.general.config_logs_path}/{timestamp}_create.yaml" config.save_as_yaml(config_path_yaml) - os.makedirs(f"{config.general.slurm_logs_path}/", exist_ok=True) - - print(f"💾 Configuration saved to: {config_path_yaml}") + print(f"💾 Configuration saved in: {config.general.config_logs_path}") if args.launch: launcher_path = os.path.join(dir, "launcher.py") @@ -250,15 +235,14 @@ config_path_yaml, ] - if args.override: - 
launch_command.extend(["--override"] + args.override) + if args.slurm: + launch_command.append("--slurm") + + if args.nodes: + launch_command.extend(["--nodes", str(args.nodes)]) - print(f"Launching configuration with command: {' '.join(launch_command)}") + print(f"🧪 Launching configuration with command: {' '.join(launch_command)}") subprocess.run(launch_command, check=True) else: print("To launch this configuration, run:") - print(f"python {os.path.join(dir, 'launcher.py')} {config_path_yaml} " - f"--override general.config_path={config_path_yaml}") - - if args.override: - print(f" {' '.join(args.override)}") \ No newline at end of file + print(f"python {os.path.join(dir, 'launcher.py')} {config_path_yaml}") \ No newline at end of file diff --git a/launcher.py b/launcher.py index 496deb8c..10c5f645 100644 --- a/launcher.py +++ b/launcher.py @@ -1,10 +1,12 @@ import os +from pathlib import Path import subprocess import tempfile from datetime import datetime import torch - import argparse +import json +from jinja2 import Template from nanotron.logging import human_format @@ -32,7 +34,7 @@ def set_nested_attribute(obj, path, value): parts = path.split('.') for part in parts[:-1]: if not hasattr(obj, part): - setattr(obj, part, type('', (), {})()) # Create empty object if attribute doesn't exist + setattr(obj, part, type('', (), {})()) obj = getattr(obj, part) setattr(obj, parts[-1], value) @@ -42,8 +44,14 @@ def set_nested_attribute(obj, path, value): parser.add_argument("config_path", help="path to the configuration file", type=str) parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", help="Override config values. Use dot notation for nested keys.") + parser.add_argument("--slurm", action="store_true", help="Launch the job on Slurm") + parser.add_argument("--nodes", type=int, help="Number of nodes to use for the job") args = parser.parse_args() + if args.slurm: + if args.nodes is None: + raise ValueError("When using Slurm (--slurm), you must specify the number of nodes (--nodes)") + # Load the configuration using get_config_from_file config = get_config_from_file(args.config_path, config_class=Config) @@ -107,9 +115,9 @@ def set_nested_attribute(obj, path, value): └───────────────────────┴────────────────────────┘ """) - num_nodes = config.slurm.nodes if config.slurm else torch.cuda.device_count() + num_nodes = args.nodes if args.slurm else 1 print(f""" -🤖 Parallelism Configuration: +🎛️ Parallelism Configuration: ┌───────────────────────┬────────────────────────┐ │ Nodes │ {num_nodes:>22d} │ │ Total GPUs │ {config.parallelism.dp*config.parallelism.pp*config.parallelism.tp:>22d} │ @@ -155,116 +163,97 @@ def set_nested_attribute(obj, path, value): """) timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - if config.slurm: - dir = os.path.dirname(__file__) + if args.slurm: - os.makedirs(config.general.config_logs_path, exist_ok=True) - config_path_yaml = f"{config.general.config_logs_path}/{timestamp}_launch.yaml" - config.save_as_yaml(config_path_yaml) - - os.makedirs(f"{config.general.slurm_logs_path}/", exist_ok=True) + nodes = args.nodes + + launch_slurm_config_path = os.path.join(os.path.dirname(__file__), "src/nanotron/slurm/launch_slurm_config.json") + eval_slurm_config_path = os.path.join(os.path.dirname(__file__), "src/nanotron/slurm/eval_slurm_config.json") + + with open(launch_slurm_config_path, 'r') as f: + launch_slurm_config = json.load(f) + + with open(eval_slurm_config_path, 'r') as f: + eval_slurm_config = json.load(f) + + + total_gpus = config.parallelism.dp 
* config.parallelism.pp * config.parallelism.tp + gpus_per_node = launch_slurm_config.get('gpus_per_node') + required_nodes = (total_gpus + gpus_per_node - 1) // gpus_per_node # Ceiling division + + if args.nodes != required_nodes: + raise ValueError(f"Number of nodes in config ({args.nodes}) does not match the required number of nodes ({required_nodes}) based on the parallelism configuration.") - def format_sbatch_option(option, value): - return f"#SBATCH --{option}={value}" if value is not None else "" + + # Create necessary folders + project_log_folder = Path(config.general.logs_path) + log_folder = project_log_folder / f"{config.general.run}-{config.general.project}" + subfolders = ['launch-script', 'slurm'] + if hasattr(config, 'lighteval') and config.lighteval is not None: + subfolders.append('evals') + + for subfolder in subfolders: + folder_path = os.path.join(log_folder, subfolder) + os.makedirs(folder_path, exist_ok=True) + if subfolder == 'launch-script': + config.general.launch_script_path = folder_path + elif subfolder == 'slurm': + config.general.slurm_logs_path = folder_path + elif subfolder == 'evals': + config.general.evals_logs_path = folder_path + for evals_subfolder in ['launch-config', 'logs']: + evals_subfolder_path = os.path.join(config.general.evals_logs_path, evals_subfolder) + os.makedirs(evals_subfolder_path, exist_ok=True) + torchrun_args = "" - if hasattr(config.slurm, 'torchrun_args') and config.slurm.torchrun_args: - torchrun_args = " ".join([f"--{k} {v}" for k, v in config.slurm.torchrun_args.items()]) - - sbatch_script = f"""#!/bin/bash -{format_sbatch_option("job-name", config.slurm.job_name)} -{format_sbatch_option("nodes", config.slurm.nodes)} -{format_sbatch_option("ntasks-per-node", config.slurm.n_tasks_per_node)} -{format_sbatch_option("cpus-per-task", config.slurm.cpus_per_task)} -{format_sbatch_option("gres", f"gpu:{config.slurm.gpu_per_node}")} -{format_sbatch_option("partition", config.slurm.gpu_partition)} -{format_sbatch_option("output", f"{config.general.slurm_logs_path}/train-{timestamp}-%j.out")} -{format_sbatch_option("error", f"{config.general.slurm_logs_path}/train-{timestamp}-%j.err")} -{format_sbatch_option("qos", config.slurm.qos)} -{format_sbatch_option("mail-type", config.slurm.mail_type)} -{format_sbatch_option("mail-user", config.slurm.mail_user)} -{format_sbatch_option("exclude", ",".join(config.slurm.exclude_nodes) if config.slurm.exclude_nodes else None)} -{format_sbatch_option("time", config.slurm.time)} -{format_sbatch_option("mem", config.slurm.mem)} -{format_sbatch_option("constraint", config.slurm.constraint)} -{format_sbatch_option("account", config.slurm.account)} -{format_sbatch_option("reservation", config.slurm.reservation)} -{format_sbatch_option("begin", config.slurm.begin)} - -set -x -e - -TRAINER_PYTHON_FILE=/fsx/elie_bakouch/nanotron/run_train.py -nvidia-smi - - -#Show some environment variables -echo python3 version = `python3 --version` -echo "Python path: $(which python3)" -echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" -echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" - -echo "START TIME: $(date)" -secs_to_human(){{ - echo "$(( ${{1}} / 3600 )):$(( (${{1}} / 60) % 60 )):$(( ${{1}} % 60 ))" -}} -start=$(date +%s) -echo "$(date -d @${{start}} "+%Y-%m-%d %H:%M:%S"): ${{SLURM_JOB_NAME}} start id=${{SLURM_JOB_ID}}\n" - - -# SLURM stuff -export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` -export MASTER_ADDR=$(scontrol show hostnames 
"$SLURM_JOB_NODELIST" | head -n 1) -export MASTER_PORT=6000 -export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` - -export TMPDIR=/scratch -export CUDA_DEVICE_MAX_CONNECTIONS="1" - -module load cuda/12.1 - -echo go $COUNT_NODE -echo $HOSTNAMES - -##### MOVE TO YAML ###### - -CMD=" $TRAINER_PYTHON_FILE \ - --config-file {config_path_yaml} \ - " -export LAUNCHER="torchrun \ - --nproc_per_node {config.slurm.gpu_per_node} \ - --nnodes $COUNT_NODE \ - {torchrun_args} \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: \ - --max_restarts 0 \ - --tee 3 \ - " - -# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub -random_milliseconds=$(( RANDOM % 1001 )) -sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") -echo "Sleeping for $sleep_time seconds..." -sleep $sleep_time - -launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" - -srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" - - -echo "END TIME: $(date)" - """ - # Save the Slurm script - print(f"🚀 Slurm job launched with id={launch_slurm_job(sbatch_script)}") + if 'torchrun_args' in launch_slurm_config and launch_slurm_config['torchrun_args']: + torchrun_args = " ".join([f"--{k} {v}" for k, v in launch_slurm_config['torchrun_args'].items()]) + + launch_slurm_config.update({ + "job_name": f"{config.general.project}-{config.general.run}", + "nodes": args.nodes, + "slurm_logs_path": config.general.slurm_logs_path, + "timestamp": timestamp, + "path_to_trainer_python_file": os.path.join(os.path.dirname(__file__), "run_train.py"), + "config_path_yaml": f"{config.general.config_logs_path}/{timestamp}_launch.yaml", + "torchrun_args": torchrun_args, + }) + + # Load Jinja2 template + template_path = os.path.join(os.path.dirname(__file__), "src/nanotron/slurm/launch_training.slurm.jinja") + with open(template_path, 'r') as f: + template = Template(f.read()) + + # Render the template + sbatch_script = template.render(**launch_slurm_config) + + config.general.launch_slurm_config = launch_slurm_config + config.general.eval_slurm_config = eval_slurm_config + + config.save_as_yaml(launch_slurm_config["config_path_yaml"]) + + # Launch the Slurm job + job_id = launch_slurm_job(sbatch_script) + print(f"🚀 Slurm job launched with id={job_id}") + + # Save the Slurm script if a path is provided if config.general.launch_script_path: os.makedirs(config.general.launch_script_path, exist_ok=True) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") script_filename = f"slurm_script_{timestamp}.slurm" script_path = os.path.join(config.general.launch_script_path, script_filename) with open(script_path, 'w') as f: f.write(sbatch_script) - print(f" 💾 Logs are saved to : {config.general.logs_path}") + print(f" 💾 Logs are saved to : {config.general.logs_path}/{config.general.run}-{config.general.project}") + print(f" 🤖 Slurm Configuration Details:") + + slurm_config_keys = ['qos', 'gpus_per_node', 'cpus_per_task', 'constraint', 'account', 'reservation'] + for key in slurm_config_keys: + if key in launch_slurm_config: + if launch_slurm_config[key] is not None: + print(f" {key}: {launch_slurm_config[key]}") else: # Check if running on an interactive node diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 760f2bd3..ad82553f 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -90,56 +90,6 @@ def __post_init__(self): self.text_column_name 
= "text" if self.hf_dataset_splits is None: self.hf_dataset_splits = "train" - -@dataclass -class SlurmArgs: - """ - Arguments for configuring SLURM job submission. - - Attributes: - gpu_partition (str): SLURM partition (queue) for GPU jobs. - job_name (str): Name of the SLURM job. - nodes (int): Number of nodes to allocate for the job. - n_tasks_per_node (int): Number of tasks to run per node. Default is 1. - cpus_per_task (int): Number of CPUs to allocate per task. Default is 32. - gpu_per_node (int): Number of GPUs to allocate per node. Default is 8. - array (Optional[str]): Job array specification, allowing multiple similar jobs to be submitted as a group. - qos (Optional[str]): Quality of Service, used to define job priority or resource limits. - mail_type (Optional[str]): Specifies when to send email notifications about the job (e.g., BEGIN, END, FAIL). Default is FAIL. - mail_user (Optional[str]): Email address to receive job notifications. - exclude_nodes (Optional[List[str]]): List of nodes to exclude from job allocation. - time (Optional[str]): Maximum time limit for the job. - mem (Optional[str]): Memory requirement for the job. - constraint (Optional[str]): Specifies node features required for the job. - account (Optional[str]): Account to charge for the job's resource usage. - reservation (Optional[str]): Name of a reservation to use for the job. - begin (Optional[str]): Earliest time the job can start. - torchrun_args (Optional[Dict[str, str]]): Additional arguments for torchrun command. - """ - - gpu_partition: str - job_name: str - nodes: int - n_tasks_per_node: int = 1 - cpus_per_task: int = 32 - gpu_per_node: int = 8 - array: Optional[str] = None - qos: Optional[str] = None - mail_user: Optional[str] = None - mail_type: Optional[str] = None - exclude_nodes: Optional[List[str]] = None - time: Optional[str] = None - mem: Optional[str] = None - constraint: Optional[str] = None - account: Optional[str] = None - reservation: Optional[str] = None - begin: Optional[str] = None - torchrun_args: Optional[Dict[str, str]] = None - - def __post_init__(self): - if self.mail_type is None and self.mail_user is not None: - self.mail_type = "FAIL" - @dataclass class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" @@ -236,6 +186,8 @@ class GeneralArgs: project: str run: str logs_path: Optional[str] = "./logs" + launch_slurm_config: Optional[dict] = None + eval_slurm_config: Optional[dict] = None launch_script_path: Optional[str] = None slurm_logs_path: Optional[str] = None config_logs_path: Optional[str] = None @@ -410,7 +362,6 @@ class Config: profiler: Optional[ProfilerArgs] = None lighteval: Optional[LightEvalConfig] = None s3_upload: Optional[S3UploadArgs] = None - slurm: Optional[SlurmArgs] = None @classmethod def create_empty(cls): @@ -457,36 +408,6 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" - - project_log_folder = Path(self.general.logs_path) - os.makedirs(project_log_folder, exist_ok=True) - - log_folder = os.path.join(project_log_folder, f"{self.general.run}-{self.general.project}") - os.makedirs(log_folder, exist_ok=True) - - # Create config folder for all jobs - config_folder = os.path.join(log_folder, 'configs') - os.makedirs(config_folder, exist_ok=True) - self.general.config_logs_path = config_folder - - if self.slurm is not None: - subfolders = ['slurm'] - if self.lighteval is not None and self.s3_upload is not None: - subfolders.append('evals') - for 
subfolder in subfolders: - folder_path = os.path.join(log_folder, subfolder) - os.makedirs(folder_path, exist_ok=True) - setattr(self.general, f"{subfolder}_logs_path", folder_path) - - if subfolder == 'evals': - for evals_subfolder in ['launch-config', 'logs']: - evals_subfolder_path = os.path.join(folder_path, evals_subfolder) - os.makedirs(evals_subfolder_path, exist_ok=True) - - # Create launch-script folder - launch_script_folder = os.path.join(log_folder, 'launch-script') - os.makedirs(launch_script_folder, exist_ok=True) - self.general.launch_script_path = launch_script_folder # if lighteval, we need tokenizer to be defined diff --git a/src/nanotron/lighteval/one_job_runner.py b/src/nanotron/lighteval/one_job_runner.py index 8b00b1f7..408dd739 100644 --- a/src/nanotron/lighteval/one_job_runner.py +++ b/src/nanotron/lighteval/one_job_runner.py @@ -5,13 +5,14 @@ import re import subprocess from typing import List, Optional, Tuple, Union +import copy import jinja2 from nanotron import logging from nanotron.logging import log_rank from nanotron.parallel import ParallelContext -from nanotron.config import Config +from nanotron.config import Config, LightEvalConfig logger = logging.get_logger(__name__) @@ -75,6 +76,7 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: slurm_job_id, slurm_log = run_slurm_one_job( config = self.config, + lighteval_config = self.lighteval_config, slurm_template=self.lighteval_config.slurm_template, model_checkpoint_path=checkpoint_path, current_step=self.config.general.step, @@ -86,13 +88,13 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: def run_slurm_one_job( config: Config, + lighteval_config: LightEvalConfig, model_checkpoint_path: str, slurm_template: str, current_step: int, s3: bool = True, checkpoint_local_path: str = None, slurm_name: Optional[str] = "eval", - slurm_kwargs: Optional[dict] = None, #add slurm_kwargs and modify the jinja template in case you need to adapt it to your slurm cluster. ): """Launch a single job on Slurm with the given mapping Args: @@ -113,32 +115,27 @@ def run_slurm_one_job( with open(slurm_template, "r") as f: SLURM_JOBS_ARRAY_TEMPLATE = environment.from_string(f.read()) - if s3: - launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render( - model_checkpoint_path=model_checkpoint_path, - job_name=f"{slurm_name}", - n_tasks_per_node=config.slurm.n_tasks_per_node, - partition=config.slurm.gpu_partition, - gpu_per_node=config.slurm.gpu_per_node, - cpus_per_task=config.slurm.cpus_per_task, - eval_path=eval_logs_path, - mail=config.slurm.mail_user, - local_path=config.lighteval.temp_dir, - **(slurm_kwargs if slurm_kwargs else {}), - ) - else: - launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render( - model_checkpoint_path=model_checkpoint_path, - job_name=f"{slurm_name}", - n_tasks_per_node=config.slurm.n_tasks_per_node, - partition=config.slurm.gpu_partition, - gpu_per_node=config.slurm.gpu_per_node, - cpus_per_task=config.slurm.cpus_per_task, - eval_path=eval_logs_path, - mail=config.slurm.mail_user, - ckpt_local_path=checkpoint_local_path, - **(slurm_kwargs if slurm_kwargs else {}), - ) + + #Not sure if this is the best way to do it. Maybe we need to add a copy or deepcopy somewhere. 
+ eval_slurm_config = config.general.eval_slurm_config + + # Update the config with additional required parameters + # Calculate the number of nodes based on parallelism config and gpus_per_node + total_gpus_needed = lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp + gpus_per_node = eval_slurm_config.get('gpus_per_node') + nodes = (total_gpus_needed + gpus_per_node - 1) // gpus_per_node # Ceiling division + + eval_slurm_config.update({ + 'nodes': nodes, # Assuming we want to run on a single node + 'job_name': f"eval-{current_step}", + 'eval_path': eval_logs_path, + 'local_path': config.lighteval.temp_dir if s3 else checkpoint_local_path, + 'hf_user_or_org': config.logging.hf_user_or_org if hasattr(config.logging, 'hf_user_or_org') else None, + "model_checkpoint_path": model_checkpoint_path, + }) + + + launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render(**eval_slurm_config) match = re.match(r"#SBATCH --output=(.*)", launch_string) slurm_output_path = match.group(1) if match else "" diff --git a/src/nanotron/slurm/eval_slurm_config.json b/src/nanotron/slurm/eval_slurm_config.json new file mode 100644 index 00000000..58259a79 --- /dev/null +++ b/src/nanotron/slurm/eval_slurm_config.json @@ -0,0 +1,28 @@ +{ + "job_name": "", + "n_tasks_per_node": 1, + "cpus_per_task": 32, + "gpus_per_node": 8, + "partition": "hopper-prod", + "qos": "high", + "mail_type": null, + "mail_user": null, + "exclude_nodes": null, + "time": "1:00:00", + "constraint": null, + "account": null, + "reservation": null, + "torchrun_args": { + "rdzv_backend": "etcd-v2", + "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", + "rdzv_id": "$SLURM_JOB_ID", + "node_rank": "$SLURM_PROCID", + "role": "$SLURMD_NODENAME", + "max_restarts": 0, + "tee": 3 + }, + "array": null, + "mem": null, + "begin": null + } + \ No newline at end of file diff --git a/src/nanotron/slurm/launch_slurm_config.json b/src/nanotron/slurm/launch_slurm_config.json new file mode 100644 index 00000000..ed216852 --- /dev/null +++ b/src/nanotron/slurm/launch_slurm_config.json @@ -0,0 +1,27 @@ +{ + "job_name": "", + "n_tasks_per_node": 1, + "cpus_per_task": 32, + "gpus_per_node": 8, + "partition": "hopper-prod", + "qos": "high", + "mail_type": null, + "mail_user": null, + "exclude_nodes": null, + "time": "1:00:00", + "constraint": null, + "account": null, + "reservation": null, + "torchrun_args": { + "rdzv_backend": "etcd-v2", + "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", + "rdzv_id": "$SLURM_JOB_ID", + "node_rank": "$SLURM_PROCID", + "role": "$SLURMD_NODENAME", + "max_restarts": 0, + "tee": 3 + }, + "array": null, + "mem": null, + "begin": null +} diff --git a/src/nanotron/slurm/launch_training.slurm.jinja b/src/nanotron/slurm/launch_training.slurm.jinja new file mode 100644 index 00000000..9ba308ea --- /dev/null +++ b/src/nanotron/slurm/launch_training.slurm.jinja @@ -0,0 +1,92 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ nodes }} +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} +#SBATCH --partition={{ partition }} +#SBATCH --output={{ slurm_logs_path }}/train-{{ timestamp }}-%j.out +#SBATCH --error={{ slurm_logs_path }}/train-{{ timestamp }}-%j.err +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH 
--mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} +{% endif %} + +set -e + +TRAINER_PYTHON_FILE={{ path_to_trainer_python_file }} +nvidia-smi + +# Show some environment variables +echo python3 version = `python3 --version` +echo "Python path: $(which python3)" +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +echo "START TIME: $(date)" +secs_to_human() { + echo "$(( ${1} / 3600 )):$(( (${1} / 60) % 60 )):$(( ${1} % 60 ))" +} +start=$(date +%s) +echo "$(date -d @${start} "+%Y-%m-%d %H:%M:%S"): ${SLURM_JOB_NAME} start id=${SLURM_JOB_ID}\n" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +export TMPDIR=/scratch +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +module load cuda/12.1 + +echo go $COUNT_NODE +echo $HOSTNAMES + +CMD=" $TRAINER_PYTHON_FILE \ + --config-file {{ config_path_yaml }} \ + " +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ + --nnodes $COUNT_NODE \ + {{ torchrun_args }} \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ + --max_restarts 0 \ + --tee 3 \ + " + +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." 
+sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +echo "END TIME: $(date)" \ No newline at end of file diff --git a/src/nanotron/slurm/run_eval.slurm.jinja b/src/nanotron/slurm/run_eval.slurm.jinja new file mode 100644 index 00000000..10b61cb4 --- /dev/null +++ b/src/nanotron/slurm/run_eval.slurm.jinja @@ -0,0 +1,86 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ nodes }} +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} +#SBATCH --partition={{ partition }} +#SBATCH --output={{ eval_path }}/%x-%n-%j.out +#SBATCH --error={{ eval_path }}/%x-%n-%j.err +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH --mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} +{% endif %} + +LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} + +set -e +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + # Attempt to read the token from the cache + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." 
+ exit 1 + fi +fi + +export CUBLAS_WORKSPACE_CONFIG=":4096:8" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +echo go $COUNT_NODE +echo $HOSTNAMES + +# Copying checkpoint from s3 to the node on node +mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER +s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER + +torch_dist_args="--nproc_per_node 8 \ + --nnodes $COUNT_NODE \ + --max_restarts 0 \ + --tee 3 \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: " + +launch_args="$torch_dist_args \ + /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ + --hf-user-or-org {{ hf_user_or_org }} \ + " + +srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/src/nanotron/slurm/run_eval_s3.slurm.jinja b/src/nanotron/slurm/run_eval_s3.slurm.jinja new file mode 100644 index 00000000..a3e9bb28 --- /dev/null +++ b/src/nanotron/slurm/run_eval_s3.slurm.jinja @@ -0,0 +1,62 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --partition={{ partition }} +#SBATCH --gres=gpu:{{ gpu_per_node }} +#SBATCH --cpus-per-task={{ cpus_per_task}} +#SBATCH --output={{ eval_path }}/%x-%n-%j.out +#SBATCH --error={{ eval_path }}/%x-%n-%j.err +#SBATCH --qos=high +#SBATCH --dependency=singleton +#SBATCH --mail-type=FAIL +#SBATCH --mail-user={{ mail }} +#SBATCH --time=01:00:00 + + +CHECKPOINT_FOLDER={{ ckpt_local_path }} + +set -x -e +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + # Attempt to read the token from the cache + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." 
+ exit 1 + fi +fi + +export CUBLAS_WORKSPACE_CONFIG=":4096:8" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +echo go $COUNT_NODE +echo $HOSTNAMES + +torch_dist_args="--nproc_per_node 8 \ + --nnodes $COUNT_NODE \ + --max_restarts 0 \ + --tee 3 \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: " + +launch_args="$torch_dist_args \ + /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path ${CHECKPOINT_FOLDER}/config.yaml \ + --hf-user-or-org {{ hf_user_or_org }} \ + " + +srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index dde0d83c..7941a27f 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -281,9 +281,9 @@ def post_init(self): if self.config.lighteval is not None and dist.get_rank(self.parallel_context.world_pg) == 0: # We only start evaluation runs once on the first node self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) - if self.s3_mover is not None and self.config.slurm is not None: + if self.s3_mover is not None and self.config.general.eval_slurm_config is not None: self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint - elif self.config.slurm is not None and self.s3_mover is None: + elif self.config.general.eval_slurm_config is not None and self.s3_mover is None: # Use the no_s3 version of the evaluation function self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 else: From 0d43a9564ddf1bba1dbbacb314ad65833618631c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sun, 1 Sep 2024 01:43:03 +0000 Subject: [PATCH 16/43] update wandb restart logic + logging the id and project to pass it tolighteval and log the eval to wandb --- src/nanotron/config/config.py | 2 ++ src/nanotron/trainer.py | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index ad82553f..fdc5bc48 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -198,6 +198,8 @@ class GeneralArgs: consumed_train_samples: Optional[int] = None benchmark_csv_path: Optional[Path] = None ignore_sanity_checks: bool = True + wandb_id: Optional[str] = None + wandb_project: Optional[str] = None def __post_init__(self): if self.seed is None: diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 7941a27f..d42671d1 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -313,13 +313,19 @@ def pre_training(self, *args, **kwargs): ) current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") - if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: + if dist.get_rank(self.parallel_context.world_pg) == 0 and wandb is not None: wandb.init( project=self.config.general.project, name=f"{current_time}_{self.config.general.run}", config={"nanotron_config": self.config.as_dict()}, ) - + # In case we resume from another run + initial_step = getattr(self.config.general, 'step', None) + if initial_step is not None: + wandb.run.step = initial_step + else: + wandb.run.step = 0 # Start from 0, will become 1 on first log + def post_train_step(self): # Update our background upload/removal of checkpoints @@ -877,6 +883,14 @@ def setup_log_writers( return loggerwriter def pre_save_checkpoint(self) -> Path: + if wandb is not None and 
dist.get_rank(self.parallel_context.dp_pg) == 0: + if self.config.general.wandb_id is None: + self.config.general.wandb_id = wandb.run.id + self.config.general.wandb_project = wandb.run.project + elif self.config.general.wandb_id is not None and self.config.general.wandb_id!= wandb.run.id: + log_rank("Update the wandb run due too resume from checkpoint", logger=logger, level=logging.WARNING, rank=0) + self.config.general.wandb_id = wandb.run.id + self.config.general.wandb_project = wandb.run.project if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: @@ -892,6 +906,8 @@ def post_save_checkpoint(self): checkpoint_path = self.config.checkpoints.checkpoints_path / f"{self.iteration_step}" self.post_checkpoint_callback(checkpoint_path) + + def save_checkpoint(self) -> Path: self.pre_save_checkpoint() From 207797ef6e5baef1b94bba5b5c695fc93542e4a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Sun, 1 Sep 2024 22:50:31 +0000 Subject: [PATCH 17/43] better wandb loggin, s3upload only for dl ckpt, correct Path and xPath, solve serializaton pb of xPath --- create_config.py | 37 +++++---- src/nanotron/config/config.py | 26 ++++--- src/nanotron/config/utils_config.py | 3 + src/nanotron/lighteval/one_job_runner.py | 32 ++++---- src/nanotron/lighteval/run_eval.slurm.jinja | 66 ---------------- .../lighteval/run_eval_no_s3.slurm.jinja | 62 --------------- src/nanotron/slurm/eval_slurm_config.json | 1 + src/nanotron/slurm/launch_slurm_config.json | 7 +- .../slurm/launch_training.slurm.jinja | 5 ++ src/nanotron/slurm/run_eval.slurm.jinja | 12 +-- src/nanotron/slurm/run_eval_s3.slurm.jinja | 53 ++++++++++--- src/nanotron/trainer.py | 75 ++++++++++--------- src/nanotron/utils.py | 5 +- 13 files changed, 158 insertions(+), 226 deletions(-) delete mode 100644 src/nanotron/lighteval/run_eval.slurm.jinja delete mode 100644 src/nanotron/lighteval/run_eval_no_s3.slurm.jinja diff --git a/create_config.py b/create_config.py index 2f7b75bf..fad942f0 100644 --- a/create_config.py +++ b/create_config.py @@ -44,7 +44,7 @@ parser.add_argument("--nodes", help="specify the number of nodes", type=int) args = parser.parse_args() - + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") general = GeneralArgs( project=args.project, run=args.run, @@ -90,7 +90,7 @@ ), batch_size=16, logging=LightEvalLoggingArgs( - local_output_path=f"/fsx/elie_bakouch/lighteval-logs/{general.project}-{general.run}", + local_output_path=f"/fsx/elie_bakouch/refactor-lighteval-logs/{general.project}-{general.run}", #local_output_path=PATH_TO_LOCAL_LOG, private=True, push_details_to_hub=True, @@ -108,16 +108,16 @@ checkpoints = CheckpointsArgs( - checkpoints_path=f"/scratch/elie_bakouch/checkpoints/{general.project}-{general.run}", + checkpoints_path=f"/fsx/elie_bakouch/refactor-checkpoints/{general.project}-{general.run}", #checkpoints_path="CHECKPOINTS_PATH", checkpoints_path_is_shared_file_system=False, resume_checkpoint_path=None, - checkpoint_interval=500, + checkpoint_interval=20, save_initial_state=False, ) parallelism = ParallelismArgs( - dp=32, + dp=8, pp=1, tp=1, pp_engine="1f1b", @@ -126,10 +126,10 @@ ) tokens = TokensArgs( - batch_accumulation_per_replica=2, + batch_accumulation_per_replica=8, micro_batch_size=16, sequence_length=2048, - train_steps=1500, + train_steps=100, val_check_interval=-1, ) @@ -150,11 +150,11 @@ learning_rate_scheduler = LRSchedulerArgs( learning_rate=1e-4, - 
lr_warmup_steps=100, + lr_warmup_steps=10, lr_warmup_style="linear", lr_decay_style="linear", - lr_decay_steps = 200, - lr_decay_starting_step= 1300, + lr_decay_steps = 20, + lr_decay_starting_step= 80, min_decay_lr=0, ) @@ -177,14 +177,14 @@ tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer", ) - s3_upload = S3UploadArgs( - upload_s3_path=f"s3://elie-exp/debug_nanotron/better_init", - remove_after_upload=True, - s5cmd_numworkers=16, - s5cmd_concurrency=5, - s5cmd_path="/fsx/elie_bakouch/miniconda3/envs/smollm/bin/s5cmd", - ) - + # s3_upload = S3UploadArgs( + # upload_s3_path=f"s3://elie-exp/debug_nanotron/{general.project}-{general.run}-{timestamp}", + # remove_after_upload=True, + # s5cmd_numworkers=16, + # s5cmd_concurrency=5, + # s5cmd_path="/fsx/elie_bakouch/miniconda3/envs/smollm/bin/s5cmd", + # ) + s3_upload = None data_stages=[ DatasetStageArgs( data=DataArgs( @@ -212,7 +212,6 @@ s3_upload=s3_upload, lighteval=lighteval, ) - timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") dir = os.path.dirname(__file__) # Create the necessary directories diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index fdc5bc48..36698ab7 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -1,6 +1,7 @@ import datetime import os from dataclasses import dataclass, fields +import pathlib from pathlib import Path from datasets.download.streaming_download_manager import xPath from typing import List, Optional, Type, Union, Dict @@ -20,6 +21,7 @@ cast_str_to_torch_dtype, serialize, ) +from nanotron.s3_checkpoints import check_path_is_local from nanotron.generation.sampler import SamplerType from nanotron.logging import get_logger from nanotron.parallel.pipeline_parallel.engine import PipelineEngine @@ -94,17 +96,17 @@ def __post_init__(self): class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" - upload_s3_path: xPath + upload_s3_path: Optional[str] = None #set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 remove_after_upload: bool s5cmd_numworkers: Optional[int] s5cmd_concurrency: Optional[int] - s5cmd_path: Optional[xPath] + s5cmd_path: Optional[str] def __post_init__(self): - if isinstance(self.upload_s3_path, str): + if isinstance(self.upload_s3_path, str) and self.upload_s3_path is not None: self.upload_s3_path = xPath(self.upload_s3_path) if isinstance(self.s5cmd_path, str): - self.s5cmd_path = xPath(self.s5cmd_path) + self.s5cmd_path = Path(self.s5cmd_path) @dataclass class NanosetDatasetsArgs: @@ -157,18 +159,21 @@ class CheckpointsArgs: """ - checkpoints_path: Path + checkpoints_path: str checkpoint_interval: int save_initial_state: Optional[bool] = False save_final_state: Optional[bool] = False - resume_checkpoint_path: Optional[xPath] = None + resume_checkpoint_path: Optional[str] = None checkpoints_path_is_shared_file_system: Optional[bool] = False def __post_init__(self): if isinstance(self.checkpoints_path, str): - self.checkpoints_path = xPath(self.checkpoints_path) + self.checkpoints_path = Path(self.checkpoints_path) if isinstance(self.resume_checkpoint_path, str): - self.resume_checkpoint_path = xPath(self.resume_checkpoint_path) + if check_path_is_local(self.resume_checkpoint_path): + self.resume_checkpoint_path = Path(self.resume_checkpoint_path) + else: + self.resume_checkpoint_path = xPath(self.resume_checkpoint_path) @dataclass @@ -421,7 +426,9 @@ def __post_init__(self): def global_batch_size(self): return self.tokens.micro_batch_size * 
self.tokens.batch_accumulation_per_replica * self.parallelism.dp + def save_as_yaml(self, file_path: str): + config_dict = serialize(self) file_path = str(file_path) with open(file_path, "w") as f: @@ -492,10 +499,11 @@ def get_config_from_file( skip_unused_config_keys: whether to skip unused first-nesting-level keys in the config file (for config with additional sections) skip_null_keys: whether to skip keys with value None at first and second nesting level """ - # Open the file and load the file + with open(config_path) as f: config_dict = yaml.load(f, Loader=SafeLoader) + config = get_config_from_dict( config_dict, config_class=config_class, diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index f4c07146..124516cd 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -1,6 +1,7 @@ from dataclasses import fields from enum import Enum, auto from pathlib import Path +from datasets.download.streaming_download_manager import xPath import torch @@ -31,6 +32,8 @@ def serialize(data) -> dict: value = getattr(data, field.name) if hasattr(value, "__dataclass_fields__"): result[field.name] = serialize(value) + elif isinstance(value, xPath): + result[field.name] = str(value) elif isinstance(value, Path): result[field.name] = str(value) elif isinstance(value, PipelineEngine): diff --git a/src/nanotron/lighteval/one_job_runner.py b/src/nanotron/lighteval/one_job_runner.py index 408dd739..e6ceafb9 100644 --- a/src/nanotron/lighteval/one_job_runner.py +++ b/src/nanotron/lighteval/one_job_runner.py @@ -23,9 +23,7 @@ def __init__(self, config: Config, parallel_context: Optional[ParallelContext] = self.lighteval_config = config.lighteval self.parallel_context = parallel_context - def eval_single_checkpoint_no_s3(self, checkpoints_folder, current_step) -> Tuple[str, str]: - current_checkpoint_folder = os.path.join(checkpoints_folder, str(current_step)) - checkpoint_path = os.path.join(current_checkpoint_folder, "config.yaml") + def eval_single_checkpoint_no_s3(self, checkpoint_path: str) -> Tuple[str, str]: if not os.path.exists(checkpoint_path): log_rank( f"Checkpoint path does not exist: {checkpoint_path}. 
Unable to evaluate.", @@ -38,10 +36,10 @@ def eval_single_checkpoint_no_s3(self, checkpoints_folder, current_step) -> Tupl slurm_job_id, slurm_log = run_slurm_one_job( config = self.config, + lighteval_config = self.lighteval_config, slurm_template=self.lighteval_config.slurm_template, model_checkpoint_path=checkpoint_path, current_step=self.config.general.step, - checkpoint_local_path=current_checkpoint_folder, s3=False, ) @@ -93,7 +91,6 @@ def run_slurm_one_job( slurm_template: str, current_step: int, s3: bool = True, - checkpoint_local_path: str = None, slurm_name: Optional[str] = "eval", ): """Launch a single job on Slurm with the given mapping @@ -125,14 +122,23 @@ def run_slurm_one_job( gpus_per_node = eval_slurm_config.get('gpus_per_node') nodes = (total_gpus_needed + gpus_per_node - 1) // gpus_per_node # Ceiling division - eval_slurm_config.update({ - 'nodes': nodes, # Assuming we want to run on a single node - 'job_name': f"eval-{current_step}", - 'eval_path': eval_logs_path, - 'local_path': config.lighteval.temp_dir if s3 else checkpoint_local_path, - 'hf_user_or_org': config.logging.hf_user_or_org if hasattr(config.logging, 'hf_user_or_org') else None, - "model_checkpoint_path": model_checkpoint_path, - }) + if s3: + eval_slurm_config.update({ + 'nodes': nodes, # Assuming we want to run on a single node + 'job_name': f"eval-{current_step}", + 'eval_path': eval_logs_path, + 'local_path': config.lighteval.temp_dir, + 'hf_user_or_org': config.logging.hf_user_or_org if hasattr(config.logging, 'hf_user_or_org') else None, + "model_checkpoint_path": model_checkpoint_path, + }) + else: + eval_slurm_config.update({ + 'nodes': nodes, # Assuming we want to run on a single node + 'job_name': f"eval-{current_step}", + 'eval_path': eval_logs_path, + 'hf_user_or_org': config.logging.hf_user_or_org if hasattr(config.logging, 'hf_user_or_org') else None, + "model_checkpoint_path": model_checkpoint_path, + }) launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render(**eval_slurm_config) diff --git a/src/nanotron/lighteval/run_eval.slurm.jinja b/src/nanotron/lighteval/run_eval.slurm.jinja deleted file mode 100644 index dc6a6ab6..00000000 --- a/src/nanotron/lighteval/run_eval.slurm.jinja +++ /dev/null @@ -1,66 +0,0 @@ -#!/bin/bash -#SBATCH --job-name={{ job_name }} -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node={{ n_tasks_per_node }} -#SBATCH --partition={{ partition }} -#SBATCH --gres=gpu:{{ gpu_per_node }} -#SBATCH --cpus-per-task={{ cpus_per_task}} -#SBATCH --output={{ eval_path }}/%x-%n-%j.out -#SBATCH --error={{ eval_path }}/%x-%n-%j.err -#SBATCH --qos=high -#SBATCH --dependency=singleton -#SBATCH --mail-type=FAIL -#SBATCH --mail-user={{ mail }} -#SBATCH --time=01:00:00 - - -LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} - -set -x -e -echo "START TIME: $(date)" -#Show some environment variables -echo python3 version = `python3 --version` -echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" -echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" - -# SLURM stuff -export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` -export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) -export MASTER_PORT=6000 -export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` - -# Hugging Face token -if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then - # Attempt to read the token from the cache - if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then - export HUGGING_FACE_HUB_TOKEN=$TOKEN - else - echo "Error: The environment 
variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." - exit 1 - fi -fi - -export CUBLAS_WORKSPACE_CONFIG=":4096:8" -export CUDA_DEVICE_MAX_CONNECTIONS="1" - -echo go $COUNT_NODE -echo $HOSTNAMES - -# Copying checkpoint from s3 to the node on node -mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER -s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER - -torch_dist_args="--nproc_per_node 8 \ - --nnodes $COUNT_NODE \ - --max_restarts 0 \ - --tee 3 \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: " - -launch_args="$torch_dist_args \ - /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ - --hf-user-or-org {{ hf_user_or_org }} \ - " - -srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/src/nanotron/lighteval/run_eval_no_s3.slurm.jinja b/src/nanotron/lighteval/run_eval_no_s3.slurm.jinja deleted file mode 100644 index a3e9bb28..00000000 --- a/src/nanotron/lighteval/run_eval_no_s3.slurm.jinja +++ /dev/null @@ -1,62 +0,0 @@ -#!/bin/bash -#SBATCH --job-name={{ job_name }} -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node={{ n_tasks_per_node }} -#SBATCH --partition={{ partition }} -#SBATCH --gres=gpu:{{ gpu_per_node }} -#SBATCH --cpus-per-task={{ cpus_per_task}} -#SBATCH --output={{ eval_path }}/%x-%n-%j.out -#SBATCH --error={{ eval_path }}/%x-%n-%j.err -#SBATCH --qos=high -#SBATCH --dependency=singleton -#SBATCH --mail-type=FAIL -#SBATCH --mail-user={{ mail }} -#SBATCH --time=01:00:00 - - -CHECKPOINT_FOLDER={{ ckpt_local_path }} - -set -x -e -echo "START TIME: $(date)" -#Show some environment variables -echo python3 version = `python3 --version` -echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" -echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" - -# SLURM stuff -export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` -export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) -export MASTER_PORT=6000 -export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` - -# Hugging Face token -if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then - # Attempt to read the token from the cache - if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then - export HUGGING_FACE_HUB_TOKEN=$TOKEN - else - echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." 
- exit 1 - fi -fi - -export CUBLAS_WORKSPACE_CONFIG=":4096:8" -export CUDA_DEVICE_MAX_CONNECTIONS="1" - -echo go $COUNT_NODE -echo $HOSTNAMES - -torch_dist_args="--nproc_per_node 8 \ - --nnodes $COUNT_NODE \ - --max_restarts 0 \ - --tee 3 \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: " - -launch_args="$torch_dist_args \ - /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path ${CHECKPOINT_FOLDER}/config.yaml \ - --hf-user-or-org {{ hf_user_or_org }} \ - " - -srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/src/nanotron/slurm/eval_slurm_config.json b/src/nanotron/slurm/eval_slurm_config.json index 58259a79..4cf78cf1 100644 --- a/src/nanotron/slurm/eval_slurm_config.json +++ b/src/nanotron/slurm/eval_slurm_config.json @@ -21,6 +21,7 @@ "max_restarts": 0, "tee": 3 }, + "hf_cache": "/fsx/elie_bakouch/.cache", "array": null, "mem": null, "begin": null diff --git a/src/nanotron/slurm/launch_slurm_config.json b/src/nanotron/slurm/launch_slurm_config.json index ed216852..e06cca36 100644 --- a/src/nanotron/slurm/launch_slurm_config.json +++ b/src/nanotron/slurm/launch_slurm_config.json @@ -1,14 +1,14 @@ { "job_name": "", "n_tasks_per_node": 1, - "cpus_per_task": 32, + "cpus_per_task": 60, "gpus_per_node": 8, "partition": "hopper-prod", "qos": "high", "mail_type": null, "mail_user": null, - "exclude_nodes": null, - "time": "1:00:00", + "exclude_nodes": ["ip-26-0-161-138"], + "time": "01:30:00", "constraint": null, "account": null, "reservation": null, @@ -21,6 +21,7 @@ "max_restarts": 0, "tee": 3 }, + "hf_cache": "/fsx/elie_bakouch/.cache", "array": null, "mem": null, "begin": null diff --git a/src/nanotron/slurm/launch_training.slurm.jinja b/src/nanotron/slurm/launch_training.slurm.jinja index 9ba308ea..7ede5993 100644 --- a/src/nanotron/slurm/launch_training.slurm.jinja +++ b/src/nanotron/slurm/launch_training.slurm.jinja @@ -61,6 +61,11 @@ export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` export TMPDIR=/scratch export CUDA_DEVICE_MAX_CONNECTIONS="1" +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + module load cuda/12.1 echo go $COUNT_NODE diff --git a/src/nanotron/slurm/run_eval.slurm.jinja b/src/nanotron/slurm/run_eval.slurm.jinja index 10b61cb4..8cd3ee5a 100644 --- a/src/nanotron/slurm/run_eval.slurm.jinja +++ b/src/nanotron/slurm/run_eval.slurm.jinja @@ -34,8 +34,6 @@ #SBATCH --reservation={{ reservation }} {% endif %} -LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} - set -e echo "START TIME: $(date)" #Show some environment variables @@ -63,12 +61,14 @@ fi export CUBLAS_WORKSPACE_CONFIG=":4096:8" export CUDA_DEVICE_MAX_CONNECTIONS="1" +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + echo go $COUNT_NODE echo $HOSTNAMES -# Copying checkpoint from s3 to the node on node -mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER -s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER torch_dist_args="--nproc_per_node 8 \ --nnodes $COUNT_NODE \ @@ -79,7 +79,7 @@ torch_dist_args="--nproc_per_node 8 \ launch_args="$torch_dist_args \ /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ + --checkpoint-config-path {{ model_checkpoint_path 
}}/config.yaml \ --hf-user-or-org {{ hf_user_or_org }} \ " diff --git a/src/nanotron/slurm/run_eval_s3.slurm.jinja b/src/nanotron/slurm/run_eval_s3.slurm.jinja index a3e9bb28..ee467274 100644 --- a/src/nanotron/slurm/run_eval_s3.slurm.jinja +++ b/src/nanotron/slurm/run_eval_s3.slurm.jinja @@ -1,22 +1,42 @@ #!/bin/bash #SBATCH --job-name={{ job_name }} -#SBATCH --nodes=1 +#SBATCH --nodes={{ nodes }} #SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} #SBATCH --partition={{ partition }} -#SBATCH --gres=gpu:{{ gpu_per_node }} -#SBATCH --cpus-per-task={{ cpus_per_task}} #SBATCH --output={{ eval_path }}/%x-%n-%j.out #SBATCH --error={{ eval_path }}/%x-%n-%j.err -#SBATCH --qos=high -#SBATCH --dependency=singleton -#SBATCH --mail-type=FAIL -#SBATCH --mail-user={{ mail }} -#SBATCH --time=01:00:00 +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH --mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} +{% endif %} +LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} -CHECKPOINT_FOLDER={{ ckpt_local_path }} - -set -x -e +set -e echo "START TIME: $(date)" #Show some environment variables echo python3 version = `python3 --version` @@ -43,9 +63,18 @@ fi export CUBLAS_WORKSPACE_CONFIG=":4096:8" export CUDA_DEVICE_MAX_CONNECTIONS="1" +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + echo go $COUNT_NODE echo $HOSTNAMES +# Copying checkpoint from s3 to the node on node +mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER +s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER + torch_dist_args="--nproc_per_node 8 \ --nnodes $COUNT_NODE \ --max_restarts 0 \ @@ -55,7 +84,7 @@ torch_dist_args="--nproc_per_node 8 \ launch_args="$torch_dist_args \ /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path ${CHECKPOINT_FOLDER}/config.yaml \ + --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ --hf-user-or-org {{ hf_user_or_org }} \ " diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index d42671d1..ebaff33e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -262,8 +262,8 @@ def pre_init(self): self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context) def post_init(self): - # S3 Mover and save initial state - if self.config.s3_upload is not None: + # S3 Mover and save initial state (only if we need to upload checkpoints on s3) + if self.config.s3_upload is not None and self.config.s3_upload.upload_s3_path is not None: # Only local rank 0 should upload dummy = bool(int(os.environ.get("LOCAL_RANK", None)) != 0) self.s3_mover = S3Mover( @@ -277,29 +277,22 @@ def post_init(self): ) else: self.s3_mover = None - - if self.config.lighteval is not None and dist.get_rank(self.parallel_context.world_pg) == 0: - # We only start evaluation runs once on the first node - self.lighteval_runner = 
LightEvalRunner(config=self.config, parallel_context=self.parallel_context) - if self.s3_mover is not None and self.config.general.eval_slurm_config is not None: - self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint - elif self.config.general.eval_slurm_config is not None and self.s3_mover is None: - # Use the no_s3 version of the evaluation function - self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 + if dist.get_rank(self.parallel_context.world_pg) == 0: + # check if slurm is configured + # TODO @eliebak rewrite the logic by self.slurm + there can be s3upload AND checkpoint path local which is not support for now + if self.config.lighteval is not None and self.config.general.eval_slurm_config is not None: + self.lighteval_runner = LightEvalRunner(config=self.config, parallel_context=self.parallel_context) + if self.s3_mover is not None: + self.s3_mover.post_upload_callback = self.lighteval_runner.eval_single_checkpoint + self.post_checkpoint_callback = None + else: + # Use the no_s3 version of the evaluation function + # TODO: make it one function + make it automatic to switch to the right jinja template + self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 else: - log_rank("LightEval is enabled but Slurm is not configured, skipping evaluation", logger=logger, level=logging.INFO, rank=0) + self.post_checkpoint_callback = None else: self.post_checkpoint_callback = None - - def post_save_checkpoint(self): - # Upload to S3 - if self.s3_mover is not None: - self.s3_mover.start_uploading() - elif self.post_checkpoint_callback is not None: - # If we're not using S3, but we have a post-checkpoint callback, execute it - checkpoint_path = self.config.checkpoints.checkpoints_path / f"{self.iteration_step}" - self.post_checkpoint_callback(checkpoint_path) - def pre_training(self, *args, **kwargs): self._print_training_plan() @@ -319,12 +312,22 @@ def pre_training(self, *args, **kwargs): name=f"{current_time}_{self.config.general.run}", config={"nanotron_config": self.config.as_dict()}, ) - # In case we resume from another run - initial_step = getattr(self.config.general, 'step', None) - if initial_step is not None: - wandb.run.step = initial_step - else: - wandb.run.step = 0 # Start from 0, will become 1 on first log + # Define tokens metric as x-axis for all metrics + wandb.define_metric("Tokens") + wandb.define_metric("*", step_metric="Tokens") + + # Handle resuming from a previous run + initial_step = getattr(self.config.general, 'step', 0) + if initial_step is None: + initial_step = 0 + + initial_tokens = initial_step * self.global_batch_size + + # Log initial tokens to set the starting point + wandb.log({"Tokens": initial_tokens}) + + print(f"Initial Tokens: {initial_tokens}") + def post_train_step(self): @@ -625,11 +628,11 @@ def train_step_logs( log_entries = [ # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), - LogItem( - "consumed_tokens", - self.metadata.consumed_train_samples * self.config.tokens.sequence_length, - "human_format", - ), # , "12d"), + # LogItem( + # "consumed_tokens", + # self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + # "human_format", + # ), # , "12d"), LogItem("elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), LogItem( @@ -668,6 +671,7 @@ def train_step_logs( { **{log_item.tag: log_item.scalar_value 
for log_item in log_entries}, "iteration_step": self.iteration_step, + "Tokens": self.metadata.consumed_train_samples * self.config.tokens.sequence_length, } ) @@ -901,9 +905,10 @@ def post_save_checkpoint(self): # Upload to S3 if self.s3_mover is not None: self.s3_mover.start_uploading() + elif self.post_checkpoint_callback is not None: - # If we're not using S3, but we have a post-checkpoint callback, execute it - checkpoint_path = self.config.checkpoints.checkpoints_path / f"{self.iteration_step}" + # If we're not using S3, but we have a post-checkpoint callback for evals + checkpoint_path = self.config.checkpoints.checkpoints_path / f"{self.config.general.step}" self.post_checkpoint_callback(checkpoint_path) diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index a99289da..9e23f381 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -6,6 +6,9 @@ import re from contextlib import ExitStack, contextmanager from typing import ContextManager, List, Optional +import json +import wandb +import os import torch from packaging import version @@ -165,4 +168,4 @@ def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int: def check_path_is_s3(path:str) -> bool: #TODO maybe replace by a better method ? s3_pattern = r'^s3://|^https?://[\w\-\.]+\.s3\.amazonaws\.com/|^https?://s3\.amazonaws\.com/[\w\-\.]+' - return bool(re.match(s3_pattern, path)) + return bool(re.match(s3_pattern, path)) \ No newline at end of file From 8ce8b181f329a96ca07bcd1d37be9bc76dbfbd66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 2 Sep 2024 01:33:19 +0000 Subject: [PATCH 18/43] fix some bug with the slurm related stuff --- create_config.py | 48 ++++----- launcher.py | 68 ++++++++----- src/nanotron/config/config.py | 14 +-- src/nanotron/lighteval/one_job_runner.py | 8 +- src/nanotron/serialize/main.py | 6 +- src/nanotron/slurm/eval_slurm_config.json | 29 ------ src/nanotron/slurm/launch_slurm_config.json | 28 ------ .../slurm/launch_training.slurm.jinja | 97 ------------------- src/nanotron/slurm/run_eval.slurm.jinja | 86 ---------------- src/nanotron/slurm/run_eval_s3.slurm.jinja | 91 ----------------- 10 files changed, 83 insertions(+), 392 deletions(-) delete mode 100644 src/nanotron/slurm/eval_slurm_config.json delete mode 100644 src/nanotron/slurm/launch_slurm_config.json delete mode 100644 src/nanotron/slurm/launch_training.slurm.jinja delete mode 100644 src/nanotron/slurm/run_eval.slurm.jinja delete mode 100644 src/nanotron/slurm/run_eval_s3.slurm.jinja diff --git a/create_config.py b/create_config.py index fad942f0..97206612 100644 --- a/create_config.py +++ b/create_config.py @@ -33,22 +33,21 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--project", help="project name", type=str, required=True) - parser.add_argument("--run", help="run name", type=str, required=True) + parser.add_argument("--save-path", help="path to save the configuration file", type=str, required=True) parser.add_argument("--seed", help="seed", type=int, default=8) parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="high") parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", help="Override config values. 
Use dot notation for nested keys.") parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately") + parser.add_argument("--run", help="name of the run", type=str) + parser.add_argument("--logs-path", help="path to the logs folder", type=str) parser.add_argument("--slurm", help="use slurm", action="store_true") parser.add_argument("--nodes", help="specify the number of nodes", type=int) args = parser.parse_args() timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") general = GeneralArgs( - project=args.project, - run=args.run, - logs_path="/fsx/elie_bakouch/nanotron/refactor-logs", + project="smollm", seed=args.seed, temp_dir="/scratch", ) @@ -103,7 +102,7 @@ hub_repo_tensorboard="smollm-evals-visualization", tensorboard_metric_prefix="eval", ), - slurm_template="/fsx/elie_bakouch/nanotron/src/nanotron/slurm/run_eval.slurm.jinja", + slurm_template="/fsx/elie_bakouch/nanotron/slurm/run_eval.slurm.jinja", ) @@ -184,7 +183,6 @@ # s5cmd_concurrency=5, # s5cmd_path="/fsx/elie_bakouch/miniconda3/envs/smollm/bin/s5cmd", # ) - s3_upload = None data_stages=[ DatasetStageArgs( data=DataArgs( @@ -209,37 +207,43 @@ tokens=tokens, optimizer=optimizer, data_stages=data_stages, - s3_upload=s3_upload, lighteval=lighteval, ) - dir = os.path.dirname(__file__) - - # Create the necessary directories - project_log_folder = Path(general.logs_path) - log_folder = project_log_folder / f"{general.run}-{general.project}" - config_folder = log_folder / 'configs' - config_folder.mkdir(parents=True, exist_ok=True) - config.general.config_logs_path = str(config_folder) + save_path= Path(args.save_path) + save_path.mkdir(parents=True, exist_ok=True) - config_path_yaml = f"{config.general.config_logs_path}/{timestamp}_create.yaml" + config_path_yaml = save_path / f"{args.run}-{timestamp}.yaml" config.save_as_yaml(config_path_yaml) - print(f"💾 Configuration saved in: {config.general.config_logs_path}") + print(f"💾 Configuration saved in: {str(save_path)}") if args.launch: - launcher_path = os.path.join(dir, "launcher.py") + # Change the launcher_path + # Sanity check for logs_path and run + if not args.logs_path: + raise ValueError("--logs_path must be defined. Please provide a path for the logs.") + if not args.run: + raise ValueError("--run must be defined. Please provide a name for the run.") + + launcher_path = Path("launcher.py") + if not launcher_path.exists(): + raise FileNotFoundError(f"Launcher not found at {launcher_path}. Please ensure the file exists or change the launcher path in the create_config.py file.") launch_command = [ - "python", launcher_path, - config_path_yaml, + "python", str(launcher_path), + "--config-path", str(config_path_yaml), ] - + launch_command.extend([ + "--logs-path", args.logs_path, + "--run", args.run + ]) if args.slurm: launch_command.append("--slurm") if args.nodes: launch_command.extend(["--nodes", str(args.nodes)]) + print(f"🧪 Launching configuration with command: {' '.join(launch_command)}") subprocess.run(launch_command, check=True) else: diff --git a/launcher.py b/launcher.py index 10c5f645..fcf8a8a6 100644 --- a/launcher.py +++ b/launcher.py @@ -15,6 +15,9 @@ get_config_from_file, ) +def count_subdirectories(path): + return sum(os.path.isdir(os.path.join(path, item)) for item in os.listdir(path)) + def launch_slurm_job(launch_file_contents, *args): """ Small helper function to save a sbatch script and call it. 
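
The body of launch_slurm_job lies outside this hunk; judging from the docstring above, it writes the rendered sbatch script to disk and submits it. A rough sketch of that shape, assuming the helper returns the Slurm job id (illustration only, not the repository's actual implementation):

    import subprocess
    import tempfile

    def launch_slurm_job(launch_file_contents: str, *args) -> str:
        """Save the rendered sbatch script to a temporary file, submit it, and return the job id."""
        with tempfile.NamedTemporaryFile("w", suffix=".slurm", delete=False) as f:
            f.write(launch_file_contents)
            f.flush()
            # --parsable makes sbatch print only the job id (optionally followed by ";cluster").
            output = subprocess.check_output(["sbatch", "--parsable", *args, f.name], text=True)
        return output.strip().split(";")[0]
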
@@ -41,7 +44,9 @@ def set_nested_attribute(obj, path, value): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("config_path", help="path to the configuration file", type=str) + parser.add_argument("--config-path", help="path to the configuration file", type=str,required=True) + parser.add_argument("--run", help="name of the run", type=str, required=True) + parser.add_argument("--logs-path", help="path to the logs folder", type=str, default=None) parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", help="Override config values. Use dot notation for nested keys.") parser.add_argument("--slurm", action="store_true", help="Launch the job on Slurm") @@ -55,6 +60,9 @@ def set_nested_attribute(obj, path, value): # Load the configuration using get_config_from_file config = get_config_from_file(args.config_path, config_class=Config) + if config.general.logs_path is None and args.logs_path is None: + raise ValueError("Please provide a logs path") + if config.model.model_config.tie_word_embeddings ==True: tie_word_embeddings_multiplier = 1 else: @@ -162,20 +170,32 @@ def set_nested_attribute(obj, path, value): └───────────────────────┴────────────────────────┘ """) + config.general.logs_path = args.logs_path + config.general.run = args.run + + path = Path(args.logs_path) / f"{args.run}" + path.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + run_number = count_subdirectories(f"{args.logs_path}/{args.run}") + 1 + timestamp_with_run = f"run{run_number:03d}_{timestamp}" + + config.general.config_logs_path = f"{config.general.logs_path}/{args.run}/{timestamp_with_run}/config" + Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) + + + #making sure the logs path folder exists + if args.slurm: nodes = args.nodes - launch_slurm_config_path = os.path.join(os.path.dirname(__file__), "src/nanotron/slurm/launch_slurm_config.json") - eval_slurm_config_path = os.path.join(os.path.dirname(__file__), "src/nanotron/slurm/eval_slurm_config.json") + launch_slurm_config_path = Path("./slurm/launch_slurm_config.json") + eval_slurm_config_path = Path("./slurm/eval_slurm_config.json") with open(launch_slurm_config_path, 'r') as f: launch_slurm_config = json.load(f) - with open(eval_slurm_config_path, 'r') as f: - eval_slurm_config = json.load(f) - total_gpus = config.parallelism.dp * config.parallelism.pp * config.parallelism.tp gpus_per_node = launch_slurm_config.get('gpus_per_node') @@ -187,8 +207,8 @@ def set_nested_attribute(obj, path, value): # Create necessary folders project_log_folder = Path(config.general.logs_path) - log_folder = project_log_folder / f"{config.general.run}-{config.general.project}" - subfolders = ['launch-script', 'slurm'] + log_folder = project_log_folder / f"{args.run}"/ f"{timestamp_with_run}" + subfolders = ['launch-script', 'slurm-logs'] if hasattr(config, 'lighteval') and config.lighteval is not None: subfolders.append('evals') @@ -197,7 +217,7 @@ def set_nested_attribute(obj, path, value): os.makedirs(folder_path, exist_ok=True) if subfolder == 'launch-script': config.general.launch_script_path = folder_path - elif subfolder == 'slurm': + elif subfolder == 'slurm-logs': config.general.slurm_logs_path = folder_path elif subfolder == 'evals': config.general.evals_logs_path = folder_path @@ -214,22 +234,26 @@ def set_nested_attribute(obj, path, value): "job_name": f"{config.general.project}-{config.general.run}", "nodes": args.nodes, "slurm_logs_path": 
config.general.slurm_logs_path, - "timestamp": timestamp, "path_to_trainer_python_file": os.path.join(os.path.dirname(__file__), "run_train.py"), - "config_path_yaml": f"{config.general.config_logs_path}/{timestamp}_launch.yaml", + "config_path_yaml": f"{config.general.config_logs_path}/launch.yaml", "torchrun_args": torchrun_args, }) # Load Jinja2 template - template_path = os.path.join(os.path.dirname(__file__), "src/nanotron/slurm/launch_training.slurm.jinja") + template_path = Path("slurm/launch_training.slurm.jinja") with open(template_path, 'r') as f: template = Template(f.read()) # Render the template sbatch_script = template.render(**launch_slurm_config) - - config.general.launch_slurm_config = launch_slurm_config - config.general.eval_slurm_config = eval_slurm_config + if launch_slurm_config_path.exists(): + config.general.launch_slurm_config = str(launch_slurm_config_path.resolve()) + else: + config.general.launch_slurm_config = None + if eval_slurm_config_path.exists(): + config.general.eval_slurm_config = str(eval_slurm_config_path.resolve()) + else: + config.general.eval_slurm_config = None config.save_as_yaml(launch_slurm_config["config_path_yaml"]) @@ -240,7 +264,7 @@ def set_nested_attribute(obj, path, value): # Save the Slurm script if a path is provided if config.general.launch_script_path: os.makedirs(config.general.launch_script_path, exist_ok=True) - script_filename = f"slurm_script_{timestamp}.slurm" + script_filename = f"slurm_launch_script.slurm" script_path = os.path.join(config.general.launch_script_path, script_filename) with open(script_path, 'w') as f: @@ -266,7 +290,6 @@ def set_nested_attribute(obj, path, value): if is_interactive: print("💻 Running on an interactive node with GPUs.") - # Check if the parallelism configuration matches the available GPUs total_gpus = gpu_count config_gpus = config.parallelism.dp * config.parallelism.tp * config.parallelism.pp @@ -275,21 +298,16 @@ def set_nested_attribute(obj, path, value): f"doesn't match the number of available GPUs ({total_gpus}). " f"Please adjust your configuration to match the available resources.") - # Save config - timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - os.makedirs("/fsx/elie_bakouch/nanotron/config_logs", exist_ok=True) - config_path_yaml = f"/fsx/elie_bakouch/nanotron/config_logs/{timestamp}.yaml" + config_path_yaml = f"{config.general.config_logs_path}/launch.yaml" + os.makedirs("config.general.config_logs_path", exist_ok=True) config.save_as_yaml(config_path_yaml) - # Prepare command - trainer_python_file = "/fsx/elie_bakouch/nanotron/run_train.py" + trainer_python_file = "run_train.py" cmd = f"{trainer_python_file} --config-file {args.config_path}" - # Launch job launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {gpu_count} {cmd}" print(f"🚀 Launching interactive job with command: {launch_cmd}") - # Execute the command subprocess.run(launch_cmd, shell=True, check=True) else: print("❌ Not running on a Slurm cluster or an interactive node with GPUs. 
Please submit a Slurm job or use an interactive node with GPUs.") \ No newline at end of file diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 36698ab7..6e831cee 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -96,11 +96,11 @@ def __post_init__(self): class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" - upload_s3_path: Optional[str] = None #set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 remove_after_upload: bool - s5cmd_numworkers: Optional[int] - s5cmd_concurrency: Optional[int] - s5cmd_path: Optional[str] + upload_s3_path: Optional[str] = None # set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 + s5cmd_numworkers: Optional[int] = None + s5cmd_concurrency: Optional[int] = None + s5cmd_path: Optional[str] = None def __post_init__(self): if isinstance(self.upload_s3_path, str) and self.upload_s3_path is not None: @@ -189,10 +189,10 @@ class GeneralArgs: """ project: str - run: str + run: Optional[str] = None logs_path: Optional[str] = "./logs" - launch_slurm_config: Optional[dict] = None - eval_slurm_config: Optional[dict] = None + launch_slurm_config: Optional[str] = None + eval_slurm_config: Optional[str] = None launch_script_path: Optional[str] = None slurm_logs_path: Optional[str] = None config_logs_path: Optional[str] = None diff --git a/src/nanotron/lighteval/one_job_runner.py b/src/nanotron/lighteval/one_job_runner.py index e6ceafb9..b56aafda 100644 --- a/src/nanotron/lighteval/one_job_runner.py +++ b/src/nanotron/lighteval/one_job_runner.py @@ -6,7 +6,7 @@ import subprocess from typing import List, Optional, Tuple, Union import copy - +import json import jinja2 from nanotron import logging from nanotron.logging import log_rank @@ -102,6 +102,9 @@ def run_slurm_one_job( eval_launch_script_path = os.path.join(config.general.evals_logs_path, "launch-config", str(current_step)) eval_logs_path = os.path.join(config.general.evals_logs_path, "logs", str(current_step)) + with open(config.general.eval_slurm_config, 'r') as f: + eval_slurm_config = json.load(f) + os.makedirs(eval_launch_script_path, exist_ok=True) os.makedirs(eval_logs_path, exist_ok=True) @@ -113,9 +116,6 @@ def run_slurm_one_job( with open(slurm_template, "r") as f: SLURM_JOBS_ARRAY_TEMPLATE = environment.from_string(f.read()) - #Not sure if this is the best way to do it. Maybe we need to add a copy or deepcopy somewhere. 
- eval_slurm_config = config.general.eval_slurm_config - # Update the config with additional required parameters # Calculate the number of nodes based on parallelism config and gpus_per_node total_gpus_needed = lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 3c93c69a..a5163ae2 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -256,7 +256,7 @@ def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Option load_from_candidate = config.checkpoints.resume_checkpoint_path if load_from_candidate is not None: if check_path_is_local(load_from_candidate): - latest_meta_path: xPath = config.checkpoints.resume_checkpoint_path / "latest.txt" + latest_meta_path: Path = config.checkpoints.resume_checkpoint_path / "latest.txt" if latest_meta_path.exists(): with fs_open(config.checkpoints.resume_checkpoint_path / "latest.txt", mode="r") as fi: # TODO @thomasw21: make a better structure system so that we get typing correct @@ -283,10 +283,10 @@ def parse_ckpt_path(config: Config, parallel_context: ParallelContext) -> Option rank=0, ) else: - # elif check_path_is_s3(str(load_from_candidate)): + #We assume that the checkpoint path is from s3 (maybe add more cases later ?) latest_meta_path = config.checkpoints.resume_checkpoint_path / "latest.txt" if latest_meta_path.exists(): - # if latest.txt exists, we assume that the checkpoint path is a path to a folder containing the checkpoint + with fs_open(latest_meta_path, mode="r") as fi: latest_iteration = int(fi.read()) s3_path = config.checkpoints.resume_checkpoint_path / str(latest_iteration) # load_path diff --git a/src/nanotron/slurm/eval_slurm_config.json b/src/nanotron/slurm/eval_slurm_config.json deleted file mode 100644 index 4cf78cf1..00000000 --- a/src/nanotron/slurm/eval_slurm_config.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "job_name": "", - "n_tasks_per_node": 1, - "cpus_per_task": 32, - "gpus_per_node": 8, - "partition": "hopper-prod", - "qos": "high", - "mail_type": null, - "mail_user": null, - "exclude_nodes": null, - "time": "1:00:00", - "constraint": null, - "account": null, - "reservation": null, - "torchrun_args": { - "rdzv_backend": "etcd-v2", - "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", - "rdzv_id": "$SLURM_JOB_ID", - "node_rank": "$SLURM_PROCID", - "role": "$SLURMD_NODENAME", - "max_restarts": 0, - "tee": 3 - }, - "hf_cache": "/fsx/elie_bakouch/.cache", - "array": null, - "mem": null, - "begin": null - } - \ No newline at end of file diff --git a/src/nanotron/slurm/launch_slurm_config.json b/src/nanotron/slurm/launch_slurm_config.json deleted file mode 100644 index e06cca36..00000000 --- a/src/nanotron/slurm/launch_slurm_config.json +++ /dev/null @@ -1,28 +0,0 @@ -{ - "job_name": "", - "n_tasks_per_node": 1, - "cpus_per_task": 60, - "gpus_per_node": 8, - "partition": "hopper-prod", - "qos": "high", - "mail_type": null, - "mail_user": null, - "exclude_nodes": ["ip-26-0-161-138"], - "time": "01:30:00", - "constraint": null, - "account": null, - "reservation": null, - "torchrun_args": { - "rdzv_backend": "etcd-v2", - "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", - "rdzv_id": "$SLURM_JOB_ID", - "node_rank": "$SLURM_PROCID", - "role": "$SLURMD_NODENAME", - "max_restarts": 0, - "tee": 3 - }, - "hf_cache": "/fsx/elie_bakouch/.cache", - "array": null, - "mem": null, - "begin": null -} diff --git 
a/src/nanotron/slurm/launch_training.slurm.jinja b/src/nanotron/slurm/launch_training.slurm.jinja deleted file mode 100644 index 7ede5993..00000000 --- a/src/nanotron/slurm/launch_training.slurm.jinja +++ /dev/null @@ -1,97 +0,0 @@ -#!/bin/bash -#SBATCH --job-name={{ job_name }} -#SBATCH --nodes={{ nodes }} -#SBATCH --ntasks-per-node={{ n_tasks_per_node }} -#SBATCH --gres=gpu:{{ gpus_per_node }} -{% if cpus_per_task %} -#SBATCH --cpus-per-task={{ cpus_per_task }} -{% endif %} -#SBATCH --partition={{ partition }} -#SBATCH --output={{ slurm_logs_path }}/train-{{ timestamp }}-%j.out -#SBATCH --error={{ slurm_logs_path }}/train-{{ timestamp }}-%j.err -{% if qos %} -#SBATCH --qos={{ qos }} -{% endif %} -{% if mail_type %} -#SBATCH --mail-type={{ mail_type }} -{% endif %} -{% if mail_user %} -#SBATCH --mail-user={{ mail_user }} -{% endif %} -{% if exclude_nodes %} -#SBATCH --exclude={{ exclude_nodes|join(',') }} -{% endif %} -{% if time %} -#SBATCH --time={{ time }} -{% endif %} -{% if constraint %} -#SBATCH --constraint={{ constraint }} -{% endif %} -{% if account %} -#SBATCH --account={{ account }} -{% endif %} -{% if reservation %} -#SBATCH --reservation={{ reservation }} -{% endif %} - -set -e - -TRAINER_PYTHON_FILE={{ path_to_trainer_python_file }} -nvidia-smi - -# Show some environment variables -echo python3 version = `python3 --version` -echo "Python path: $(which python3)" -echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" -echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" - -echo "START TIME: $(date)" -secs_to_human() { - echo "$(( ${1} / 3600 )):$(( (${1} / 60) % 60 )):$(( ${1} % 60 ))" -} -start=$(date +%s) -echo "$(date -d @${start} "+%Y-%m-%d %H:%M:%S"): ${SLURM_JOB_NAME} start id=${SLURM_JOB_ID}\n" - -# SLURM stuff -export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` -export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) -export MASTER_PORT=6000 -export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` - -export TMPDIR=/scratch -export CUDA_DEVICE_MAX_CONNECTIONS="1" - -export HUGGINGFACE_HUB_CACHE={{ hf_cache }} -export HF_DATASETS_CACHE={{ hf_cache }} -export HF_MODULES_CACHE={{ hf_cache }} -export HF_HOME={{ hf_cache }} - -module load cuda/12.1 - -echo go $COUNT_NODE -echo $HOSTNAMES - -CMD=" $TRAINER_PYTHON_FILE \ - --config-file {{ config_path_yaml }} \ - " -export LAUNCHER="torchrun \ - --nproc_per_node {{ gpus_per_node }} \ - --nnodes $COUNT_NODE \ - {{ torchrun_args }} \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: \ - --max_restarts 0 \ - --tee 3 \ - " - -# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub -random_milliseconds=$(( RANDOM % 1001 )) -sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") -echo "Sleeping for $sleep_time seconds..." 
-sleep $sleep_time - -launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" - -srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" - -echo "END TIME: $(date)" \ No newline at end of file diff --git a/src/nanotron/slurm/run_eval.slurm.jinja b/src/nanotron/slurm/run_eval.slurm.jinja deleted file mode 100644 index 8cd3ee5a..00000000 --- a/src/nanotron/slurm/run_eval.slurm.jinja +++ /dev/null @@ -1,86 +0,0 @@ -#!/bin/bash -#SBATCH --job-name={{ job_name }} -#SBATCH --nodes={{ nodes }} -#SBATCH --ntasks-per-node={{ n_tasks_per_node }} -#SBATCH --gres=gpu:{{ gpus_per_node }} -{% if cpus_per_task %} -#SBATCH --cpus-per-task={{ cpus_per_task }} -{% endif %} -#SBATCH --partition={{ partition }} -#SBATCH --output={{ eval_path }}/%x-%n-%j.out -#SBATCH --error={{ eval_path }}/%x-%n-%j.err -{% if qos %} -#SBATCH --qos={{ qos }} -{% endif %} -{% if mail_type %} -#SBATCH --mail-type={{ mail_type }} -{% endif %} -{% if mail_user %} -#SBATCH --mail-user={{ mail_user }} -{% endif %} -{% if exclude_nodes %} -#SBATCH --exclude={{ exclude_nodes|join(',') }} -{% endif %} -{% if time %} -#SBATCH --time={{ time }} -{% endif %} -{% if constraint %} -#SBATCH --constraint={{ constraint }} -{% endif %} -{% if account %} -#SBATCH --account={{ account }} -{% endif %} -{% if reservation %} -#SBATCH --reservation={{ reservation }} -{% endif %} - -set -e -echo "START TIME: $(date)" -#Show some environment variables -echo python3 version = `python3 --version` -echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" -echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" - -# SLURM stuff -export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` -export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) -export MASTER_PORT=6000 -export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` - -# Hugging Face token -if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then - # Attempt to read the token from the cache - if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then - export HUGGING_FACE_HUB_TOKEN=$TOKEN - else - echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." 
- exit 1 - fi -fi - -export CUBLAS_WORKSPACE_CONFIG=":4096:8" -export CUDA_DEVICE_MAX_CONNECTIONS="1" - -export HUGGINGFACE_HUB_CACHE={{ hf_cache }} -export HF_DATASETS_CACHE={{ hf_cache }} -export HF_MODULES_CACHE={{ hf_cache }} -export HF_HOME={{ hf_cache }} - -echo go $COUNT_NODE -echo $HOSTNAMES - - -torch_dist_args="--nproc_per_node 8 \ - --nnodes $COUNT_NODE \ - --max_restarts 0 \ - --tee 3 \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: " - -launch_args="$torch_dist_args \ - /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path {{ model_checkpoint_path }}/config.yaml \ - --hf-user-or-org {{ hf_user_or_org }} \ - " - -srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/src/nanotron/slurm/run_eval_s3.slurm.jinja b/src/nanotron/slurm/run_eval_s3.slurm.jinja deleted file mode 100644 index ee467274..00000000 --- a/src/nanotron/slurm/run_eval_s3.slurm.jinja +++ /dev/null @@ -1,91 +0,0 @@ -#!/bin/bash -#SBATCH --job-name={{ job_name }} -#SBATCH --nodes={{ nodes }} -#SBATCH --ntasks-per-node={{ n_tasks_per_node }} -#SBATCH --gres=gpu:{{ gpus_per_node }} -{% if cpus_per_task %} -#SBATCH --cpus-per-task={{ cpus_per_task }} -{% endif %} -#SBATCH --partition={{ partition }} -#SBATCH --output={{ eval_path }}/%x-%n-%j.out -#SBATCH --error={{ eval_path }}/%x-%n-%j.err -{% if qos %} -#SBATCH --qos={{ qos }} -{% endif %} -{% if mail_type %} -#SBATCH --mail-type={{ mail_type }} -{% endif %} -{% if mail_user %} -#SBATCH --mail-user={{ mail_user }} -{% endif %} -{% if exclude_nodes %} -#SBATCH --exclude={{ exclude_nodes|join(',') }} -{% endif %} -{% if time %} -#SBATCH --time={{ time }} -{% endif %} -{% if constraint %} -#SBATCH --constraint={{ constraint }} -{% endif %} -{% if account %} -#SBATCH --account={{ account }} -{% endif %} -{% if reservation %} -#SBATCH --reservation={{ reservation }} -{% endif %} - -LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} - -set -e -echo "START TIME: $(date)" -#Show some environment variables -echo python3 version = `python3 --version` -echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" -echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" - -# SLURM stuff -export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` -export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) -export MASTER_PORT=6000 -export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` - -# Hugging Face token -if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then - # Attempt to read the token from the cache - if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then - export HUGGING_FACE_HUB_TOKEN=$TOKEN - else - echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." 
- exit 1 - fi -fi - -export CUBLAS_WORKSPACE_CONFIG=":4096:8" -export CUDA_DEVICE_MAX_CONNECTIONS="1" - -export HUGGINGFACE_HUB_CACHE={{ hf_cache }} -export HF_DATASETS_CACHE={{ hf_cache }} -export HF_MODULES_CACHE={{ hf_cache }} -export HF_HOME={{ hf_cache }} - -echo go $COUNT_NODE -echo $HOSTNAMES - -# Copying checkpoint from s3 to the node on node -mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER -s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER - -torch_dist_args="--nproc_per_node 8 \ - --nnodes $COUNT_NODE \ - --max_restarts 0 \ - --tee 3 \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: " - -launch_args="$torch_dist_args \ - /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ - --hf-user-or-org {{ hf_user_or_org }} \ - " - -srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file From 6dd81b280796f6ff1ffc391b2d4d3b870506914a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 2 Sep 2024 01:33:37 +0000 Subject: [PATCH 19/43] add back slurm --- slurm/eval_slurm_config.json | 29 +++++++++ slurm/launch_slurm_config.json | 28 +++++++++ slurm/launch_training.slurm.jinja | 97 +++++++++++++++++++++++++++++++ slurm/run_eval.slurm.jinja | 86 +++++++++++++++++++++++++++ slurm/run_eval_s3.slurm.jinja | 91 +++++++++++++++++++++++++++++ 5 files changed, 331 insertions(+) create mode 100644 slurm/eval_slurm_config.json create mode 100644 slurm/launch_slurm_config.json create mode 100644 slurm/launch_training.slurm.jinja create mode 100644 slurm/run_eval.slurm.jinja create mode 100644 slurm/run_eval_s3.slurm.jinja diff --git a/slurm/eval_slurm_config.json b/slurm/eval_slurm_config.json new file mode 100644 index 00000000..4cf78cf1 --- /dev/null +++ b/slurm/eval_slurm_config.json @@ -0,0 +1,29 @@ +{ + "job_name": "", + "n_tasks_per_node": 1, + "cpus_per_task": 32, + "gpus_per_node": 8, + "partition": "hopper-prod", + "qos": "high", + "mail_type": null, + "mail_user": null, + "exclude_nodes": null, + "time": "1:00:00", + "constraint": null, + "account": null, + "reservation": null, + "torchrun_args": { + "rdzv_backend": "etcd-v2", + "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", + "rdzv_id": "$SLURM_JOB_ID", + "node_rank": "$SLURM_PROCID", + "role": "$SLURMD_NODENAME", + "max_restarts": 0, + "tee": 3 + }, + "hf_cache": "/fsx/elie_bakouch/.cache", + "array": null, + "mem": null, + "begin": null + } + \ No newline at end of file diff --git a/slurm/launch_slurm_config.json b/slurm/launch_slurm_config.json new file mode 100644 index 00000000..f727f7f7 --- /dev/null +++ b/slurm/launch_slurm_config.json @@ -0,0 +1,28 @@ +{ + "job_name": "", + "n_tasks_per_node": 1, + "cpus_per_task": 60, + "gpus_per_node": 8, + "partition": "hopper-prod", + "qos": "high", + "mail_type": null, + "mail_user": null, + "exclude_nodes": ["ip-26-0-161-138"], + "time": null, + "constraint": null, + "account": null, + "reservation": null, + "torchrun_args": { + "rdzv_backend": "etcd-v2", + "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", + "rdzv_id": "$SLURM_JOB_ID", + "node_rank": "$SLURM_PROCID", + "role": "$SLURMD_NODENAME", + "max_restarts": 0, + "tee": 3 + }, + "hf_cache": "/fsx/elie_bakouch/.cache", + "array": null, + "mem": null, + "begin": null +} diff --git a/slurm/launch_training.slurm.jinja b/slurm/launch_training.slurm.jinja new file mode 100644 index 
00000000..9d4f21bd --- /dev/null +++ b/slurm/launch_training.slurm.jinja @@ -0,0 +1,97 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ nodes }} +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} +#SBATCH --partition={{ partition }} +#SBATCH --output={{ slurm_logs_path }}/train-%j.out +#SBATCH --error={{ slurm_logs_path }}/train-%j.err +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH --mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} +{% endif %} + +set -e + +TRAINER_PYTHON_FILE={{ path_to_trainer_python_file }} +nvidia-smi + +# Show some environment variables +echo python3 version = `python3 --version` +echo "Python path: $(which python3)" +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +echo "START TIME: $(date)" +secs_to_human() { + echo "$(( ${1} / 3600 )):$(( (${1} / 60) % 60 )):$(( ${1} % 60 ))" +} +start=$(date +%s) +echo "$(date -d @${start} "+%Y-%m-%d %H:%M:%S"): ${SLURM_JOB_NAME} start id=${SLURM_JOB_ID}\n" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +export TMPDIR=/scratch +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + +module load cuda/12.1 + +echo go $COUNT_NODE +echo $HOSTNAMES + +CMD=" $TRAINER_PYTHON_FILE \ + --config-file {{ config_path_yaml }} \ + " +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ + --nnodes $COUNT_NODE \ + {{ torchrun_args }} \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ + --max_restarts 0 \ + --tee 3 \ + " + +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." 
+sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +echo "END TIME: $(date)" \ No newline at end of file diff --git a/slurm/run_eval.slurm.jinja b/slurm/run_eval.slurm.jinja new file mode 100644 index 00000000..8cd3ee5a --- /dev/null +++ b/slurm/run_eval.slurm.jinja @@ -0,0 +1,86 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ nodes }} +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} +#SBATCH --partition={{ partition }} +#SBATCH --output={{ eval_path }}/%x-%n-%j.out +#SBATCH --error={{ eval_path }}/%x-%n-%j.err +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH --mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} +{% endif %} + +set -e +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + # Attempt to read the token from the cache + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." 
+ exit 1 + fi +fi + +export CUBLAS_WORKSPACE_CONFIG=":4096:8" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + +echo go $COUNT_NODE +echo $HOSTNAMES + + +torch_dist_args="--nproc_per_node 8 \ + --nnodes $COUNT_NODE \ + --max_restarts 0 \ + --tee 3 \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: " + +launch_args="$torch_dist_args \ + /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path {{ model_checkpoint_path }}/config.yaml \ + --hf-user-or-org {{ hf_user_or_org }} \ + " + +srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file diff --git a/slurm/run_eval_s3.slurm.jinja b/slurm/run_eval_s3.slurm.jinja new file mode 100644 index 00000000..ee467274 --- /dev/null +++ b/slurm/run_eval_s3.slurm.jinja @@ -0,0 +1,91 @@ +#!/bin/bash +#SBATCH --job-name={{ job_name }} +#SBATCH --nodes={{ nodes }} +#SBATCH --ntasks-per-node={{ n_tasks_per_node }} +#SBATCH --gres=gpu:{{ gpus_per_node }} +{% if cpus_per_task %} +#SBATCH --cpus-per-task={{ cpus_per_task }} +{% endif %} +#SBATCH --partition={{ partition }} +#SBATCH --output={{ eval_path }}/%x-%n-%j.out +#SBATCH --error={{ eval_path }}/%x-%n-%j.err +{% if qos %} +#SBATCH --qos={{ qos }} +{% endif %} +{% if mail_type %} +#SBATCH --mail-type={{ mail_type }} +{% endif %} +{% if mail_user %} +#SBATCH --mail-user={{ mail_user }} +{% endif %} +{% if exclude_nodes %} +#SBATCH --exclude={{ exclude_nodes|join(',') }} +{% endif %} +{% if time %} +#SBATCH --time={{ time }} +{% endif %} +{% if constraint %} +#SBATCH --constraint={{ constraint }} +{% endif %} +{% if account %} +#SBATCH --account={{ account }} +{% endif %} +{% if reservation %} +#SBATCH --reservation={{ reservation }} +{% endif %} + +LOCAL_DOWNLOAD_CHECKPOINT_FOLDER={{ local_path }} + +set -e +echo "START TIME: $(date)" +#Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +# SLURM stuff +export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` +export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) +export MASTER_PORT=6000 +export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` + +# Hugging Face token +if [ -z "$HUGGING_FACE_HUB_TOKEN" ]; then + # Attempt to read the token from the cache + if TOKEN=$(cat ~/.cache/huggingface/token 2>/dev/null); then + export HUGGING_FACE_HUB_TOKEN=$TOKEN + else + echo "Error: The environment variable HUGGING_FACE_HUB_TOKEN is not set and the token cache could not be read." 
+ exit 1 + fi +fi + +export CUBLAS_WORKSPACE_CONFIG=":4096:8" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +export HUGGINGFACE_HUB_CACHE={{ hf_cache }} +export HF_DATASETS_CACHE={{ hf_cache }} +export HF_MODULES_CACHE={{ hf_cache }} +export HF_HOME={{ hf_cache }} + +echo go $COUNT_NODE +echo $HOSTNAMES + +# Copying checkpoint from s3 to the node on node +mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER +s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER + +torch_dist_args="--nproc_per_node 8 \ + --nnodes $COUNT_NODE \ + --max_restarts 0 \ + --tee 3 \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: " + +launch_args="$torch_dist_args \ + /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ + --hf-user-or-org {{ hf_user_or_org }} \ + " + +srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file From 4750736176db94c102854d7cadf5a42213bbd219 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Mon, 2 Sep 2024 03:03:52 +0000 Subject: [PATCH 20/43] fix some stuff + introduce --base-config --- create_config.py | 2 +- launcher.py | 18 +++++++++++++++- slurm/launch_slurm_config.json | 2 +- src/nanotron/config/config.py | 1 + src/nanotron/lighteval/evaluation_tasks.py | 25 +++++++++++++++++++++- src/nanotron/trainer.py | 9 +++++++- 6 files changed, 52 insertions(+), 5 deletions(-) diff --git a/create_config.py b/create_config.py index 97206612..02c8e7d7 100644 --- a/create_config.py +++ b/create_config.py @@ -110,7 +110,7 @@ checkpoints_path=f"/fsx/elie_bakouch/refactor-checkpoints/{general.project}-{general.run}", #checkpoints_path="CHECKPOINTS_PATH", checkpoints_path_is_shared_file_system=False, - resume_checkpoint_path=None, + resume_checkpoint_path="/fsx/elie_bakouch/refactor-checkpoints/smollm-%date_%jobid/60", checkpoint_interval=20, save_initial_state=False, ) diff --git a/launcher.py b/launcher.py index fcf8a8a6..0c8e1020 100644 --- a/launcher.py +++ b/launcher.py @@ -44,7 +44,8 @@ def set_nested_attribute(obj, path, value): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--config-path", help="path to the configuration file", type=str,required=True) + parser.add_argument("--config-path", help="path to the configuration file", type=str, default=None) + parser.add_argument("--base-config", help="base config to use", type=str, default=None) parser.add_argument("--run", help="name of the run", type=str, required=True) parser.add_argument("--logs-path", help="path to the logs folder", type=str, default=None) parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", @@ -53,6 +54,20 @@ def set_nested_attribute(obj, path, value): parser.add_argument("--nodes", type=int, help="Number of nodes to use for the job") args = parser.parse_args() + supported_base_configs = { + 'llama-1B': "path_to_the_config", + } + + if args.base_config is None and args.config_path is None: + raise ValueError("Please provide a base config or a config path") + + if args.base_config not in supported_base_configs.keys(): + raise ValueError(f"Base config {args.base_config} is not supported. Please choose one of the following: {supported_base_configs}") + + if args.config_path is not None and args.base_config is not None: + print("Both config_path and base_config are provided. 
Using config_path and ignoring base_config.") + args.base_config = None + if args.slurm: if args.nodes is None: raise ValueError("When using Slurm (--slurm), you must specify the number of nodes (--nodes)") @@ -179,6 +194,7 @@ def set_nested_attribute(obj, path, value): timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") run_number = count_subdirectories(f"{args.logs_path}/{args.run}") + 1 timestamp_with_run = f"run{run_number:03d}_{timestamp}" + config.general.timestamp_with_run = timestamp_with_run config.general.config_logs_path = f"{config.general.logs_path}/{args.run}/{timestamp_with_run}/config" Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) diff --git a/slurm/launch_slurm_config.json b/slurm/launch_slurm_config.json index f727f7f7..86b0a1f8 100644 --- a/slurm/launch_slurm_config.json +++ b/slurm/launch_slurm_config.json @@ -1,7 +1,7 @@ { "job_name": "", "n_tasks_per_node": 1, - "cpus_per_task": 60, + "cpus_per_task": 88, "gpus_per_node": 8, "partition": "hopper-prod", "qos": "high", diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 6e831cee..cfed3f1e 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -193,6 +193,7 @@ class GeneralArgs: logs_path: Optional[str] = "./logs" launch_slurm_config: Optional[str] = None eval_slurm_config: Optional[str] = None + timestamp_with_run: Optional[str] = None launch_script_path: Optional[str] = None slurm_logs_path: Optional[str] = None config_logs_path: Optional[str] = None diff --git a/src/nanotron/lighteval/evaluation_tasks.py b/src/nanotron/lighteval/evaluation_tasks.py index ff4342bb..a78fe486 100644 --- a/src/nanotron/lighteval/evaluation_tasks.py +++ b/src/nanotron/lighteval/evaluation_tasks.py @@ -16,6 +16,8 @@ _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] _TASKS: List[LightevalTaskConfig] = [] +trust_remote_code = True + ## COMMON_SENSE_REASONING_TASKS ## COMMON_SENSE_REASONING_TASKS = [ LightevalTaskConfig( @@ -24,6 +26,7 @@ hf_repo="hellaswag", hf_subset="default", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + trust_dataset=True, ), LightevalTaskConfig( name="winogrande", @@ -31,6 +34,7 @@ hf_repo="winogrande", hf_subset="winogrande_xl", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="piqa", @@ -38,6 +42,7 @@ hf_repo="piqa", hf_subset="plain_text", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="siqa", @@ -46,6 +51,7 @@ hf_subset="default", hf_avail_splits=["train", "validation"], metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="openbookqa", @@ -53,6 +59,7 @@ hf_repo="openbookqa", hf_subset="main", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="arc:easy", @@ -62,6 +69,7 @@ evaluation_splits=["test"], generation_size=1, metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="arc:challenge", @@ -71,6 +79,7 @@ evaluation_splits=["test"], generation_size=1, metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="commonsense_qa", @@ -78,6 +87,7 @@ hf_repo="commonsense_qa", hf_subset="default", metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + 
trust_dataset=trust_remote_code, ), ] @@ -138,6 +148,7 @@ def preprocess(text): metric=[Metrics.quasi_exact_match], generation_size=20, stop_sequence=["\n", ".", ","], + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="natural_questions", @@ -147,6 +158,7 @@ def preprocess(text): metric=[Metrics.quasi_exact_match], generation_size=20, stop_sequence=["\n", ".", ","], + trust_dataset=trust_remote_code, ), ] @@ -175,6 +187,7 @@ def natural_questions_prompt(line, task_name: str = None): hf_repo="super_glue", hf_subset="boolq", metric=["target_perplexity"], + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="quac", @@ -184,6 +197,7 @@ def natural_questions_prompt(line, task_name: str = None): metric=[Metrics.quasi_exact_match], generation_size=20, stop_sequence=["\n", ".", ","], + trust_dataset=trust_remote_code, ), ] @@ -221,7 +235,8 @@ def __init__( generation_size=40, stop_sequence=None, output_regex=None, - frozen=False, + frozen=False, + trust_dataset=trust_remote_code, ): super().__init__( name=name, @@ -238,6 +253,7 @@ def __init__( stop_sequence=stop_sequence, output_regex=output_regex, frozen=frozen, + trust_dataset=trust_dataset, ) @@ -260,6 +276,7 @@ def __init__( metric=[Metrics.perfect_exact_match], generation_size=10, stop_sequence=["\n"], + trust_dataset=trust_remote_code, ) @@ -289,6 +306,7 @@ def __init__( stop_sequence=None, output_regex=None, frozen=False, + trust_dataset=trust_remote_code, ): super().__init__( name=name, @@ -305,6 +323,7 @@ def __init__( stop_sequence=stop_sequence, output_regex=output_regex, frozen=frozen, + trust_dataset=trust_dataset, ) @@ -431,6 +450,7 @@ def __init__( stop_sequence=None, output_regex=None, frozen=False, + trust_dataset=trust_remote_code, ): super().__init__( name=name, @@ -447,6 +467,7 @@ def __init__( stop_sequence=stop_sequence, output_regex=output_regex, frozen=frozen, + trust_dataset=trust_dataset, ) @@ -523,6 +544,7 @@ def __init__( stop_sequence=None, output_regex=None, frozen=False, + trust_dataset=trust_remote_code, ): super().__init__( name=name, @@ -539,6 +561,7 @@ def __init__( stop_sequence=stop_sequence, output_regex=output_regex, frozen=frozen, + trust_dataset=trust_dataset, ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index ebaff33e..5b6bf32d 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -309,7 +309,7 @@ def pre_training(self, *args, **kwargs): if dist.get_rank(self.parallel_context.world_pg) == 0 and wandb is not None: wandb.init( project=self.config.general.project, - name=f"{current_time}_{self.config.general.run}", + name=f"{self.config.general.run}_{self.config.general.timestamp_with_run}", config={"nanotron_config": self.config.as_dict()}, ) # Define tokens metric as x-axis for all metrics @@ -338,6 +338,13 @@ def post_train_step(self): def post_training(self): if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) + def post_training(self): + if self.s3_mover is not None: + self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) + + if dist.get_rank(self.parallel_context.world_pg) == 0 and wandb is not None: + wandb.finish() + def _print_training_plan(self): if hasattr(self.config, "data_stages") and self.config.data_stages is not None: From 90860f521ac20b82a2daa682b0ddf32e2bfec57e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 3 Sep 2024 02:01:01 +0000 Subject: [PATCH 21/43] fix the computation calculation by adding GQA and 
layer norm at different places --- src/nanotron/models/llama.py | 55 ++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ca6c2441..9d8891a6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -818,8 +818,10 @@ def get_block_compute_costs(self): d_qkv = model_config.hidden_size // model_config.num_attention_heads block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size - + 3 * d_ff * model_config.hidden_size, + LlamaDecoderLayer: 2 * model_config.num_attention_heads * d_qkv * model_config.hidden_size # Q output projection + + 2 * model_config.num_key_value_heads * d_qkv * model_config.hidden_size # KV + + 3 * d_ff * model_config.hidden_size # for the MLP (3 because of the gated mechanism) + + 2 * model_config.hidden_size, # for the layernorm # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } @@ -1033,6 +1035,7 @@ def get_flops( 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head ) + ## qk logits decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len ## v logits @@ -1047,6 +1050,10 @@ def get_flops( ## 2nd layer decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + # Layer Norm (RMSNorm for LLaMA) + # There are typically 2 layer norms per transformer layer, plus one at the end + layer_norm_flops_fwd = (2 * num_layers + 1) * batch_size * seq_len * hidden_size * 2 # multiply by 2 for division and square root (square root take significatively more time ?) 
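
To sanity-check the GQA-aware accounting introduced here, a back-of-the-envelope comparison of the old and new per-layer attention weight cost, using made-up dimensions (illustrative only, not a real model config):

# Hypothetical dimensions, purely for illustration
hidden_size, num_heads, num_kv_heads, d_ff = 2048, 32, 8, 8192
d_qkv = hidden_size // num_heads  # 64

old_attn = 4 * num_heads * d_qkv * hidden_size              # previous q+k+v+out estimate
new_attn = (2 * num_heads * d_qkv * hidden_size             # Q and output projections
            + 2 * num_kv_heads * d_qkv * hidden_size)       # K and V, reduced under GQA
mlp = 3 * d_ff * hidden_size                                 # gated MLP, unchanged
layernorm = 2 * hidden_size                                  # newly counted

print(old_attn, new_attn)  # 16_777_216 vs 10_485_760: the K/V cost shrinks with fewer KV heads
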
+ decoder_flops_fwd = ( decoder_qkv_proj_flops_fwd + decoder_qk_logits_flops_fwd @@ -1054,6 +1061,7 @@ def get_flops( + decoder_attn_out_flops_fwd + decoder_ffn_1_flops_fwd + decoder_ffn_2_flops_fwd + + layer_norm_flops_fwd ) # lm head @@ -1066,3 +1074,46 @@ def get_flops( hardware_flops = model_flops # TODO: This is a placeholder for now return model_flops, hardware_flops + + + def get_llama_param_count(self): + # Embedding layer + embedding_params = self.vocab_size * self.hidden_size + + # Input RMS Norm + input_rms = self.num_hidden_layers * self.hidden_size + # Post attention RMS Norm + after_attention_rms = self.num_hidden_layers * self.hidden_size + + # Attention layers + attn_params = self.num_hidden_layers * ( + # Query projection + self.num_attention_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + + # Key and Value projections (different than query in case of GQA) + 2 * self.num_key_value_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + + # Output projection + self.num_attention_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + ) + + # MLP layers + mlp_params = self.num_hidden_layers * ( + # First linear layer (2 for gated) + 2* self.hidden_size * self.intermediate_size + + # Second linear layer + self.intermediate_size * self.hidden_size + ) + + + # Final RMS Norm + final_rms = self.hidden_size + + total_params = ( + embedding_params + + input_rms + + after_attention_rms + + attn_params + + mlp_params + + final_rms + ) + + return total_params \ No newline at end of file From 28b3847c7648d784cb85f2745b1c3de08978bea4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 3 Sep 2024 02:03:57 +0000 Subject: [PATCH 22/43] change the localisation of get_llama_param_count() --- src/nanotron/config/models_config.py | 42 ++++++++++++++++++++++++++ src/nanotron/models/llama.py | 45 +--------------------------- 2 files changed, 43 insertions(+), 44 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 57225243..fb80dc46 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -52,6 +52,48 @@ class LlamaConfig: use_cache: bool = True vocab_size: int = 32000 + def get_llama_param_count(self): + # Embedding layer + embedding_params = self.vocab_size * self.hidden_size + + # Input RMS Norm + input_rms = self.num_hidden_layers * self.hidden_size + # Post attention RMS Norm + after_attention_rms = self.num_hidden_layers * self.hidden_size + + # Attention layers + attn_params = self.num_hidden_layers * ( + # Query projection + self.num_attention_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + + # Key and Value projections (different than query in case of GQA) + 2 * self.num_key_value_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + + # Output projection + self.num_attention_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + ) + + # MLP layers + mlp_params = self.num_hidden_layers * ( + # First linear layer (2 for gated) + 2* self.hidden_size * self.intermediate_size + + # Second linear layer + self.intermediate_size * self.hidden_size + ) + + + # Final RMS Norm + final_rms = self.hidden_size + + total_params = ( + embedding_params + + input_rms + + after_attention_rms + + attn_params + + mlp_params + + final_rms + ) + + return total_params + def __post_init__(self): # NOTE: user don't set self._init_method, ModelArgs will 
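
Plugging invented dimensions into the parameter-count formula above gives a quick sanity check (the numbers are illustrative only and do not correspond to any particular released config; note that, like get_llama_param_count, this counts the embedding matrix once, i.e. the tied-embeddings case):

# Invented dimensions, for illustration only
vocab, hidden, layers, heads, kv_heads, inter = 32000, 2048, 24, 32, 8, 8192
d_head = hidden // heads

embedding = vocab * hidden                                   # 65_536_000
norms = layers * hidden * 2 + hidden                         # input + post-attn RMSNorm per layer, + final
attention = layers * (heads * d_head * hidden                # Q projection
                      + 2 * kv_heads * d_head * hidden       # K and V (smaller under GQA)
                      + heads * d_head * hidden)             # output projection
mlp = layers * (2 * hidden * inter + inter * hidden)         # gate + up, then down projection

print(embedding + norms + attention + mlp)                   # 1_525_254_144, i.e. roughly 1.5B parameters
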
set it # then we only pass LlamaConfig around diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 9d8891a6..9728f4ea 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -1073,47 +1073,4 @@ def get_flops( hardware_flops = model_flops # TODO: This is a placeholder for now - return model_flops, hardware_flops - - - def get_llama_param_count(self): - # Embedding layer - embedding_params = self.vocab_size * self.hidden_size - - # Input RMS Norm - input_rms = self.num_hidden_layers * self.hidden_size - # Post attention RMS Norm - after_attention_rms = self.num_hidden_layers * self.hidden_size - - # Attention layers - attn_params = self.num_hidden_layers * ( - # Query projection - self.num_attention_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + - # Key and Value projections (different than query in case of GQA) - 2 * self.num_key_value_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size + - # Output projection - self.num_attention_heads * (self.hidden_size // self.num_attention_heads) * self.hidden_size - ) - - # MLP layers - mlp_params = self.num_hidden_layers * ( - # First linear layer (2 for gated) - 2* self.hidden_size * self.intermediate_size + - # Second linear layer - self.intermediate_size * self.hidden_size - ) - - - # Final RMS Norm - final_rms = self.hidden_size - - total_params = ( - embedding_params + - input_rms + - after_attention_rms + - attn_params + - mlp_params + - final_rms - ) - - return total_params \ No newline at end of file + return model_flops, hardware_flops \ No newline at end of file From 157c2aefdb68766adfa72f567cff620af2ae8900 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 3 Sep 2024 02:08:23 +0000 Subject: [PATCH 23/43] change G to B i think it's better --- src/nanotron/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/logging.py b/src/nanotron/logging.py index 708393b5..cdbd0b78 100644 --- a/src/nanotron/logging.py +++ b/src/nanotron/logging.py @@ -236,7 +236,7 @@ def warn_once( def human_format(num: float, billions: bool = False, divide_by_1024: bool = False) -> str: if abs(num) < 1: return "{:.3g}".format(num) - SIZES = ["", "K", "M", "G", "T", "P", "E"] + SIZES = ["", "K", "M", "B", "T", "P", "E"] num = float("{:.3g}".format(num)) magnitude = 0 i = 0 From 6daa717266a17dd0da12dbe0b49a81b6a809c32a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 3 Sep 2024 05:01:51 +0000 Subject: [PATCH 24/43] last fix --- launcher.py | 30 ++++++++++-------------------- slurm/eval_slurm_config.json | 5 +---- slurm/launch_slurm_config.json | 5 +---- slurm/launch_training.slurm.jinja | 2 -- src/nanotron/config/config.py | 2 +- 5 files changed, 13 insertions(+), 31 deletions(-) diff --git a/launcher.py b/launcher.py index 0c8e1020..f88ba39b 100644 --- a/launcher.py +++ b/launcher.py @@ -55,8 +55,10 @@ def set_nested_attribute(obj, path, value): args = parser.parse_args() supported_base_configs = { - 'llama-1B': "path_to_the_config", - } + "smollm-1700M-8nodes": "examples/smollm/configs/yaml/smollm-1700M-8nodes.yaml", + "smollm-360M-4nodes": "examples/smollm/configs/yaml/smollm-360M-4nodes.yaml", + "smollm-135M-4nodes": "examples/smollm/configs/yaml/smollm-135M-4nodes.yaml", + } # add your base configs here {name: path} if args.base_config is None and args.config_path is None: raise ValueError("Please provide a base config or a config path") @@ -78,35 +80,23 @@ def 
set_nested_attribute(obj, path, value): if config.general.logs_path is None and args.logs_path is None: raise ValueError("Please provide a logs path") - if config.model.model_config.tie_word_embeddings ==True: - tie_word_embeddings_multiplier = 1 - else: - tie_word_embeddings_multiplier = 2 - num_params = human_format( - config.model.model_config.vocab_size * config.model.model_config.hidden_size * tie_word_embeddings_multiplier - + config.model.model_config.num_hidden_layers - * ( - 3 * config.model.model_config.hidden_size * config.model.model_config.intermediate_size - + 4 * config.model.model_config.hidden_size * config.model.model_config.hidden_size - ) - ).replace(".", "p") - # Apply overrides + config.model.model_config.get_llama_param_count() + ).replace(".", ",") + if args.override: for item in args.override: if '=' not in item: raise ValueError(f"Invalid override format: {item}. Use KEY=VALUE.") key, value = item.split('=', 1) try: - # Try to evaluate the value as a Python literal value = eval(value) except: - # If eval fails, treat it as a string pass set_nested_attribute(config, key, value) - print("Applied overrides:") + print("⇄ Applied overrides:") for item in args.override: print(f" {item}") @@ -122,7 +112,7 @@ def set_nested_attribute(obj, path, value): GBS = BS * config.parallelism.dp total_tokens = config.tokens.train_steps * GBS - total_tokens_billions = total_tokens / 1e9 + total_tokens_billions = human_format(total_tokens).replace(".", ",") print(f""" 🏋️ Model Parameters: @@ -153,7 +143,7 @@ def set_nested_attribute(obj, path, value): print(f""" 📙 Training Configuration: ┌───────────────────────┬────────────────────────┐ -│ Total Tokens │ {total_tokens_billions:>21.2f}B │ +│ Total Tokens │ {total_tokens_billions:>22} │ │ Global Batch Size │ {GBS:>22,d} │ │ Batch Size (per GPU) │ {BS:>22,d} │ └───────────────────────┴────────────────────────┘ diff --git a/slurm/eval_slurm_config.json b/slurm/eval_slurm_config.json index 4cf78cf1..51bd4912 100644 --- a/slurm/eval_slurm_config.json +++ b/slurm/eval_slurm_config.json @@ -13,15 +13,12 @@ "account": null, "reservation": null, "torchrun_args": { - "rdzv_backend": "etcd-v2", - "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", - "rdzv_id": "$SLURM_JOB_ID", "node_rank": "$SLURM_PROCID", "role": "$SLURMD_NODENAME", "max_restarts": 0, "tee": 3 }, - "hf_cache": "/fsx/elie_bakouch/.cache", + "hf_cache": "~/.cache", "array": null, "mem": null, "begin": null diff --git a/slurm/launch_slurm_config.json b/slurm/launch_slurm_config.json index 86b0a1f8..b82b4bc6 100644 --- a/slurm/launch_slurm_config.json +++ b/slurm/launch_slurm_config.json @@ -13,15 +13,12 @@ "account": null, "reservation": null, "torchrun_args": { - "rdzv_backend": "etcd-v2", - "rdzv_endpoint": "etcd.hpc-cluster-hopper.hpc.internal.huggingface.tech:2379", - "rdzv_id": "$SLURM_JOB_ID", "node_rank": "$SLURM_PROCID", "role": "$SLURMD_NODENAME", "max_restarts": 0, "tee": 3 }, - "hf_cache": "/fsx/elie_bakouch/.cache", + "hf_cache": "~/.cache", "array": null, "mem": null, "begin": null diff --git a/slurm/launch_training.slurm.jinja b/slurm/launch_training.slurm.jinja index 9d4f21bd..4e71e88a 100644 --- a/slurm/launch_training.slurm.jinja +++ b/slurm/launch_training.slurm.jinja @@ -66,8 +66,6 @@ export HF_DATASETS_CACHE={{ hf_cache }} export HF_MODULES_CACHE={{ hf_cache }} export HF_HOME={{ hf_cache }} -module load cuda/12.1 - echo go $COUNT_NODE echo $HOSTNAMES diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 
cfed3f1e..edeadcbd 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -190,7 +190,7 @@ class GeneralArgs: project: str run: Optional[str] = None - logs_path: Optional[str] = "./logs" + logs_path: Optional[str] = "logs" launch_slurm_config: Optional[str] = None eval_slurm_config: Optional[str] = None timestamp_with_run: Optional[str] = None From 17bfd5f70be087db76eaa139e19987bf9e8ff487 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 3 Sep 2024 05:11:11 +0000 Subject: [PATCH 25/43] create_config is smollm-135M toy example --- create_config.py | 140 ++++++++++++++++++++++++++++------------------- launcher.py | 1 + 2 files changed, 86 insertions(+), 55 deletions(-) diff --git a/create_config.py b/create_config.py index 02c8e7d7..3d2cbd57 100644 --- a/create_config.py +++ b/create_config.py @@ -32,12 +32,20 @@ ) if __name__ == "__main__": + ########################################### + ## ADAPT TO YOUR ENVIRONMENT (toy example of smollm-135M on 1 GPU) + + HF_USER_OR_ORG = "eliebak" + TRAIN_STEPS = 100 + CHECKPOINT_INTERVAL = 200 + SAVE_NAME="smollm-135M-1gpu-toy" + + + ########################################### + parser = argparse.ArgumentParser() - parser.add_argument("--save-path", help="path to save the configuration file", type=str, required=True) + parser.add_argument("--save-path", help="path to save the configuration file", type=str, default="yaml") parser.add_argument("--seed", help="seed", type=int, default=8) - parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="high") - parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", - help="Override config values. Use dot notation for nested keys.") parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately") parser.add_argument("--run", help="name of the run", type=str) parser.add_argument("--logs-path", help="path to the logs folder", type=str) @@ -48,8 +56,9 @@ timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") general = GeneralArgs( project="smollm", + run="toy-smollm", seed=args.seed, - temp_dir="/scratch", + temp_dir="temp", ) model_config = LlamaConfig( @@ -68,55 +77,56 @@ rope_scaling=None, tie_word_embeddings=True, use_cache=True, - vocab_size=49152, + vocab_size=49152, ) - lighteval = LightEvalConfig( - tasks=LightEvalTasksArgs( - tasks="early-signal", # "generatives", "all" - custom_tasks="nanotron.lighteval.evaluation_tasks", - max_samples=1000, # Cap very large evals or for debugging - dataset_loading_processes=8, - ), - parallelism=ParallelismArgs( - dp=8, - pp=1, - tp=1, - pp_engine="1f1b", - tp_mode="ALL_REDUCE", - # recompute_granularity="selective", - tp_linear_async_communication=False, - ), - batch_size=16, - logging=LightEvalLoggingArgs( - local_output_path=f"/fsx/elie_bakouch/refactor-lighteval-logs/{general.project}-{general.run}", - #local_output_path=PATH_TO_LOCAL_LOG, - private=True, - push_details_to_hub=True, - push_results_to_hub=True, - push_results_to_tensorboard=True, - hf_user_or_org="eliebak", - #hf_user_or_org="USER_OR_ORG", - hub_repo_results="lighteval-results", - hub_repo_details="lighteval-details", - hub_repo_tensorboard="smollm-evals-visualization", - tensorboard_metric_prefix="eval", - ), - slurm_template="/fsx/elie_bakouch/nanotron/slurm/run_eval.slurm.jinja", - ) + # lighteval = LightEvalConfig( + # tasks=LightEvalTasksArgs( + # tasks="early-signal", # "generatives", "all" + # custom_tasks="nanotron.lighteval.evaluation_tasks", + # max_samples=1000, 
+ # dataset_loading_processes=8, + # ), + # parallelism=ParallelismArgs( + # dp=8, + # pp=1, + # tp=1, + # pp_engine="1f1b", + # tp_mode="ALL_REDUCE", + # # recompute_granularity="selective", + # tp_linear_async_communication=False, + # ), + # batch_size=16, + # logging=LightEvalLoggingArgs( + # local_output_path="lighteval-logs", + # private=True, + # push_details_to_hub=True, + # push_results_to_hub=True, + # push_results_to_tensorboard=True, + # hf_user_or_org=HF_USER_OR_ORG, + # hub_repo_results="lighteval-results", + # hub_repo_details="lighteval-details", + # hub_repo_tensorboard="smollm-evals-visualization", + # tensorboard_metric_prefix="eval", + # ), + # slurm_template="slurm/run_eval.slurm.jinja", + # # slurm_template="slurm/run_eval_s3.slurm.jinja", if s3 + + # ) + + lighteval = None checkpoints = CheckpointsArgs( - checkpoints_path=f"/fsx/elie_bakouch/refactor-checkpoints/{general.project}-{general.run}", - #checkpoints_path="CHECKPOINTS_PATH", + checkpoints_path="checkpoints", checkpoints_path_is_shared_file_system=False, - resume_checkpoint_path="/fsx/elie_bakouch/refactor-checkpoints/smollm-%date_%jobid/60", - checkpoint_interval=20, + # resume_checkpoint_path="", + checkpoint_interval=CHECKPOINT_INTERVAL, save_initial_state=False, ) parallelism = ParallelismArgs( - dp=8, + dp=1, pp=1, tp=1, pp_engine="1f1b", @@ -126,9 +136,9 @@ tokens = TokensArgs( batch_accumulation_per_replica=8, - micro_batch_size=16, + micro_batch_size=8, sequence_length=2048, - train_steps=100, + train_steps=TRAIN_STEPS, val_check_interval=-1, ) @@ -148,12 +158,12 @@ ) learning_rate_scheduler = LRSchedulerArgs( - learning_rate=1e-4, + learning_rate=3e-3, lr_warmup_steps=10, lr_warmup_style="linear", - lr_decay_style="linear", + lr_decay_style="1-sqrt", lr_decay_steps = 20, - lr_decay_starting_step= 80, + lr_decay_starting_step=80 , min_decay_lr=0, ) @@ -176,18 +186,20 @@ tokenizer_name_or_path="HuggingFaceTB/cosmo2-tokenizer", ) + # Uncomment if you want to upload the checkpoints to s3 or load a ckpt from s3 # s3_upload = S3UploadArgs( - # upload_s3_path=f"s3://elie-exp/debug_nanotron/{general.project}-{general.run}-{timestamp}", + # upload_s3_path=f"S3_PATH", # remove_after_upload=True, # s5cmd_numworkers=16, # s5cmd_concurrency=5, - # s5cmd_path="/fsx/elie_bakouch/miniconda3/envs/smollm/bin/s5cmd", + # s5cmd_path="PATH_TO_S5CMD", # ) + data_stages=[ DatasetStageArgs( data=DataArgs( dataset=NanosetDatasetsArgs( - dataset_folder="/fsx/elie_bakouch/nanotron/datasets/cosmopedia-v2", + dataset_folder="datasets/cosmopedia-v2", ), num_loading_workers=0, seed=general.seed, @@ -195,6 +207,24 @@ name="training stage", start_training_step=1, ), + # You can add a decay stage here if you want to change the data mixture + # Example (weight are arbitrary here): + # DatasetStageArgs( + # data=DataArgs( + # dataset=NanosetDatasetsArgs( + # dataset_folder={ + # "datasets/fineweb-edu-dedup": 50, + # "datasets/cosmopedia-v2": 30, + # "datasets/python-edu": 10, + # "datasets/open-web-math": 10, + # } + # ), + # num_loading_workers=0, + # seed=general.seed, + # ), + # name="decay stage", + # start_training_step=optimizer.learning_rate_scheduler.lr_decay_starting_step, + # ), ] config = Config( @@ -213,13 +243,13 @@ save_path= Path(args.save_path) save_path.mkdir(parents=True, exist_ok=True) - config_path_yaml = save_path / f"{args.run}-{timestamp}.yaml" + config_path_yaml = save_path / f"{SAVE_NAME}.yaml" config.save_as_yaml(config_path_yaml) print(f"💾 Configuration saved in: {str(save_path)}") if args.launch: - # Change 
the launcher_path + # Sanity check for logs_path and run if not args.logs_path: raise ValueError("--logs_path must be defined. Please provide a path for the logs.") @@ -248,4 +278,4 @@ subprocess.run(launch_command, check=True) else: print("To launch this configuration, run:") - print(f"python {os.path.join(dir, 'launcher.py')} {config_path_yaml}") \ No newline at end of file + print(f"python 'launcher.py' configs/{str(config_path_yaml)}") \ No newline at end of file diff --git a/launcher.py b/launcher.py index f88ba39b..2f2d7e9b 100644 --- a/launcher.py +++ b/launcher.py @@ -58,6 +58,7 @@ def set_nested_attribute(obj, path, value): "smollm-1700M-8nodes": "examples/smollm/configs/yaml/smollm-1700M-8nodes.yaml", "smollm-360M-4nodes": "examples/smollm/configs/yaml/smollm-360M-4nodes.yaml", "smollm-135M-4nodes": "examples/smollm/configs/yaml/smollm-135M-4nodes.yaml", + "smollm-135M-1gpu": "examples/smollm/configs/yaml/smollm-135M-1gpu.yaml", } # add your base configs here {name: path} if args.base_config is None and args.config_path is None: From 43728d5ef2b6eff4c282c97c8b9e7c078193e854 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 3 Sep 2024 05:54:02 +0000 Subject: [PATCH 26/43] last fix --- create_config.py | 2 +- launcher.py | 29 ++++++++++++++++++++++------- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/create_config.py b/create_config.py index 3d2cbd57..a00301ac 100644 --- a/create_config.py +++ b/create_config.py @@ -47,8 +47,8 @@ parser.add_argument("--save-path", help="path to save the configuration file", type=str, default="yaml") parser.add_argument("--seed", help="seed", type=int, default=8) parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately") - parser.add_argument("--run", help="name of the run", type=str) parser.add_argument("--logs-path", help="path to the logs folder", type=str) + parser.add_argument("--run", help="name of the run", type=str) parser.add_argument("--slurm", help="use slurm", action="store_true") parser.add_argument("--nodes", help="specify the number of nodes", type=int) args = parser.parse_args() diff --git a/launcher.py b/launcher.py index 2f2d7e9b..d10d7870 100644 --- a/launcher.py +++ b/launcher.py @@ -47,7 +47,7 @@ def set_nested_attribute(obj, path, value): parser.add_argument("--config-path", help="path to the configuration file", type=str, default=None) parser.add_argument("--base-config", help="base config to use", type=str, default=None) parser.add_argument("--run", help="name of the run", type=str, required=True) - parser.add_argument("--logs-path", help="path to the logs folder", type=str, default=None) + parser.add_argument("--logs-path", help="path to the logs folder", type=str, default="logs") parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", help="Override config values. Use dot notation for nested keys.") parser.add_argument("--slurm", action="store_true", help="Launch the job on Slurm") @@ -64,13 +64,16 @@ def set_nested_attribute(obj, path, value): if args.base_config is None and args.config_path is None: raise ValueError("Please provide a base config or a config path") - if args.base_config not in supported_base_configs.keys(): - raise ValueError(f"Base config {args.base_config} is not supported. Please choose one of the following: {supported_base_configs}") if args.config_path is not None and args.base_config is not None: print("Both config_path and base_config are provided. 
Using config_path and ignoring base_config.") args.base_config = None + if args.base_config not in supported_base_configs.keys(): + raise ValueError(f"Base config {args.base_config} is not supported. Please choose one of the following: {supported_base_configs}") + else: + args.config_path = supported_base_configs[args.base_config] + if args.slurm: if args.nodes is None: raise ValueError("When using Slurm (--slurm), you must specify the number of nodes (--nodes)") @@ -197,8 +200,8 @@ def set_nested_attribute(obj, path, value): nodes = args.nodes - launch_slurm_config_path = Path("./slurm/launch_slurm_config.json") - eval_slurm_config_path = Path("./slurm/eval_slurm_config.json") + launch_slurm_config_path = Path("slurm/launch_slurm_config.json") + eval_slurm_config_path = Path("slurm/eval_slurm_config.json") with open(launch_slurm_config_path, 'r') as f: launch_slurm_config = json.load(f) @@ -276,8 +279,6 @@ def set_nested_attribute(obj, path, value): with open(script_path, 'w') as f: f.write(sbatch_script) - - print(f" 💾 Logs are saved to : {config.general.logs_path}/{config.general.run}-{config.general.project}") print(f" 🤖 Slurm Configuration Details:") slurm_config_keys = ['qos', 'gpus_per_node', 'cpus_per_task', 'constraint', 'account', 'reservation'] @@ -285,6 +286,20 @@ def set_nested_attribute(obj, path, value): if key in launch_slurm_config: if launch_slurm_config[key] is not None: print(f" {key}: {launch_slurm_config[key]}") + + print(" ") + print(" 📁 Log structure:") + print(f" {config.general.logs_path}/{config.general.run}/") + print(f" └── {timestamp_with_run}/") + print(" ├── config/") + print(" ├── launch-script/") + print(" ├── slurm-logs/") + if hasattr(config, 'lighteval') and config.lighteval is not None: + print(" └── evals/") + print(" ├── launch-config/") + print(" └── logs/") + else: + print(" └── (No evals folder)") else: # Check if running on an interactive node From 714644da13cc31730bde596d7aa31e3211eddaa3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Wed, 4 Sep 2024 06:41:40 +0000 Subject: [PATCH 27/43] update test and flavours --- .github/workflows/fa2_unit_tests.yaml | 3 +++ pyproject.toml | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/.github/workflows/fa2_unit_tests.yaml b/.github/workflows/fa2_unit_tests.yaml index 342be45e..341b651d 100644 --- a/.github/workflows/fa2_unit_tests.yaml +++ b/.github/workflows/fa2_unit_tests.yaml @@ -48,6 +48,9 @@ jobs: pip install -e . 
pip install -e .[dev] pip install -e .[test] + pip install -e .[nanosets] + pip install -e .[s3] + pip install -e .[lighteval] - name: Show installed libraries and their versions run: pip freeze | tee installed.txt diff --git a/pyproject.toml b/pyproject.toml index 6a0cfb83..33c5c61c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,16 @@ nanosets = [ "numba", ] +s3 = [ + "boto3", + "s3fs", + "s5cmd", +] + +lighteval = [ + "lighteval@git+https://github.com/eliebak/lighteval.git@current-nanotron", +] + [build-system] requires = [ "setuptools", From 930add6de328ead90a25190a0856fcbd488e1c88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Wed, 4 Sep 2024 06:53:00 +0000 Subject: [PATCH 28/43] forgot datasets --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 33c5c61c..66f15e7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "safetensors", "dacite", "tqdm", + "datasets", ] [tool.setuptools.packages.find] From fd213224f4c23dab6b491d9aa0ed38665084b5b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Wed, 4 Sep 2024 07:08:46 +0000 Subject: [PATCH 29/43] fix wandb import --- src/nanotron/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 9e23f381..3b98b7bc 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -7,7 +7,6 @@ from contextlib import ExitStack, contextmanager from typing import ContextManager, List, Optional import json -import wandb import os import torch From 03e0e82aaf285a09113ec015942ffb8a32095407 Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:43:56 +0200 Subject: [PATCH 30/43] no need to modify this --- .github/workflows/fa2_unit_tests.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/fa2_unit_tests.yaml b/.github/workflows/fa2_unit_tests.yaml index 341b651d..342be45e 100644 --- a/.github/workflows/fa2_unit_tests.yaml +++ b/.github/workflows/fa2_unit_tests.yaml @@ -48,9 +48,6 @@ jobs: pip install -e . 
pip install -e .[dev] pip install -e .[test] - pip install -e .[nanosets] - pip install -e .[s3] - pip install -e .[lighteval] - name: Show installed libraries and their versions run: pip freeze | tee installed.txt From ab1e3c91e39692e2dd23b16eb263aba0bbb88314 Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Tue, 10 Sep 2024 12:34:00 +0200 Subject: [PATCH 31/43] remove debugging print --- src/nanotron/helpers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 374214e3..73ca3484 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -53,7 +53,6 @@ def _vocab_size_with_padding(orig_vocab_size: int, pg_size: int, make_vocab_size multiple = make_vocab_size_divisible_by * pg_size after = int(ceil(orig_vocab_size / multiple) * multiple) if after != orig_vocab_size: - print("i'm in") log_rank( f"[Vocab Size Padding] Padded vocab (size: {orig_vocab_size}) with {after - orig_vocab_size} dummy tokens (new size: {after})", logger=logger, From 7649815e6ee9178935e7bcc67e173adf0a6ef00a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 10 Sep 2024 11:58:12 +0000 Subject: [PATCH 32/43] change the lighteval path to the main repo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ab62df83..802a30ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ s3 = [ ] lighteval = [ - "lighteval@git+https://github.com/eliebak/lighteval.git@current-nanotron", + "lighteval[nanotron]@git+https://github.com/huggingface/lighteval.git@nanotron-compatible", ] [build-system] requires = [ From 065d9b1e09dc253058a21f9ad8cd7e43a9b8e449 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 10 Sep 2024 13:25:32 +0000 Subject: [PATCH 33/43] fix the interactive cases if we request less gpus than available --- launcher.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/launcher.py b/launcher.py index d10d7870..d6b7be50 100644 --- a/launcher.py +++ b/launcher.py @@ -311,14 +311,18 @@ def set_nested_attribute(obj, path, value): if is_interactive: print("💻 Running on an interactive node with GPUs.") - - total_gpus = gpu_count - config_gpus = config.parallelism.dp * config.parallelism.tp * config.parallelism.pp - - if total_gpus != config_gpus: - raise ValueError(f"The parallelism configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " - f"doesn't match the number of available GPUs ({total_gpus}). " - f"Please adjust your configuration to match the available resources.") + gpu_config = config.parallelism.dp * config.parallelism.tp * config.parallelism.pp + if gpu_count < gpu_config: + raise ValueError(f"Error: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " + f"requires {gpu_config} GPUs, but only {gpu_count} are available.") + elif gpu_count == gpu_config: + print(f"🚀 Running on {gpu_count} GPUs, which matches your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp})") + total_gpus= gpu_count + elif gpu_count > gpu_config: + total_gpus= gpu_config + print(f"⚠️ Warning: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " + f"uses {total_gpus} GPUs, but {gpu_count} are available. 
" + f"You are not fully utilizing all available GPUs on this device.") config_path_yaml = f"{config.general.config_logs_path}/launch.yaml" os.makedirs("config.general.config_logs_path", exist_ok=True) @@ -327,7 +331,7 @@ def set_nested_attribute(obj, path, value): trainer_python_file = "run_train.py" cmd = f"{trainer_python_file} --config-file {args.config_path}" - launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {gpu_count} {cmd}" + launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {total_gpus} {cmd}" print(f"🚀 Launching interactive job with command: {launch_cmd}") subprocess.run(launch_cmd, shell=True, check=True) From a7804f50fc0205395dd0d1814abe518ea136a06e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 10 Sep 2024 13:29:39 +0000 Subject: [PATCH 34/43] remove the base-configs args --- launcher.py | 24 +++--------------------- 1 file changed, 3 insertions(+), 21 deletions(-) diff --git a/launcher.py b/launcher.py index d6b7be50..41214ed2 100644 --- a/launcher.py +++ b/launcher.py @@ -44,8 +44,7 @@ def set_nested_attribute(obj, path, value): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--config-path", help="path to the configuration file", type=str, default=None) - parser.add_argument("--base-config", help="base config to use", type=str, default=None) + parser.add_argument("--config-path", help="path to the configuration file", type=str, default=None, required=True) parser.add_argument("--run", help="name of the run", type=str, required=True) parser.add_argument("--logs-path", help="path to the logs folder", type=str, default="logs") parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", @@ -54,25 +53,8 @@ def set_nested_attribute(obj, path, value): parser.add_argument("--nodes", type=int, help="Number of nodes to use for the job") args = parser.parse_args() - supported_base_configs = { - "smollm-1700M-8nodes": "examples/smollm/configs/yaml/smollm-1700M-8nodes.yaml", - "smollm-360M-4nodes": "examples/smollm/configs/yaml/smollm-360M-4nodes.yaml", - "smollm-135M-4nodes": "examples/smollm/configs/yaml/smollm-135M-4nodes.yaml", - "smollm-135M-1gpu": "examples/smollm/configs/yaml/smollm-135M-1gpu.yaml", - } # add your base configs here {name: path} - - if args.base_config is None and args.config_path is None: - raise ValueError("Please provide a base config or a config path") - - - if args.config_path is not None and args.base_config is not None: - print("Both config_path and base_config are provided. Using config_path and ignoring base_config.") - args.base_config = None - - if args.base_config not in supported_base_configs.keys(): - raise ValueError(f"Base config {args.base_config} is not supported. 
Please choose one of the following: {supported_base_configs}") - else: - args.config_path = supported_base_configs[args.base_config] + if args.config_path is None: + raise ValueError("Please provide a config path") if args.slurm: if args.nodes is None: From efce15b17dff4f9ca8767a46a858a7c32e9a9e8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 10 Sep 2024 13:56:26 +0000 Subject: [PATCH 35/43] fix bs and gbs --- launcher.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/launcher.py b/launcher.py index 41214ed2..ea00c442 100644 --- a/launcher.py +++ b/launcher.py @@ -94,10 +94,15 @@ def set_nested_attribute(obj, path, value): lr_decay_start = config.optimizer.learning_rate_scheduler.lr_decay_starting_step lr_decay_style = config.optimizer.learning_rate_scheduler.lr_decay_style - BS = config.tokens.micro_batch_size*config.tokens.batch_accumulation_per_replica*config.tokens.sequence_length - GBS = BS * config.parallelism.dp - - total_tokens = config.tokens.train_steps * GBS + # Sample/Token per GPU (at once) + bs_gpu_sample = config.tokens.micro_batch_size + bs_gpu_token = bs_gpu_sample * config.tokens.sequence_length + + # Sample/Token in one step + gbs_sample = bs_gpu_sample * config.parallelism.dp*config.tokens.batch_accumulation_per_replica + gbs_token = gbs_sample * config.tokens.sequence_length + + total_tokens = config.tokens.train_steps * gbs_token total_tokens_billions = human_format(total_tokens).replace(".", ",") print(f""" @@ -130,8 +135,8 @@ def set_nested_attribute(obj, path, value): 📙 Training Configuration: ┌───────────────────────┬────────────────────────┐ │ Total Tokens │ {total_tokens_billions:>22} │ -│ Global Batch Size │ {GBS:>22,d} │ -│ Batch Size (per GPU) │ {BS:>22,d} │ +│ Batch Size (per GPU) │ {bs_gpu_token:>15,d} Tokens │ +│ Global Batch Size │ {gbs_token:>15,d} Tokens │ └───────────────────────┴────────────────────────┘ """) From 73da08639f73964e2961206daf0a64f093207e93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 10 Sep 2024 14:30:24 +0000 Subject: [PATCH 36/43] fix the logs structure --- create_config.py | 26 ++++++++++----- launcher.py | 44 ++++++++++++++++--------- src/nanotron/config/config.py | 4 +-- src/nanotron/config/lighteval_config.py | 2 +- 4 files changed, 49 insertions(+), 27 deletions(-) diff --git a/create_config.py b/create_config.py index a00301ac..31242c90 100644 --- a/create_config.py +++ b/create_config.py @@ -13,6 +13,7 @@ Config, DataArgs, NanosetDatasetsArgs, + PretrainDatasetsArgs, S3UploadArgs, CheckpointsArgs, GeneralArgs, @@ -80,7 +81,7 @@ vocab_size=49152, ) - + # Uncomment to evaluate the model on a set of tasks with lighteval during the training. 
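For reference, the batch-size accounting introduced in launcher.py above works out as follows for the toy values used in this file (micro_batch_size=8, batch_accumulation_per_replica=1, sequence_length=2048, train_steps=100), assuming dp=1; the numbers are illustrative only:

    bs_gpu_token = 8 * 2048             # 16,384 tokens per GPU per micro-batch
    gbs_sample = 8 * 1 * 1              # micro_batch_size * dp * grad accumulation = 8 samples per step
    gbs_token = gbs_sample * 2048       # 16,384 tokens per optimizer step
    total_tokens = 100 * gbs_token      # ~1.6M tokens over the whole toy run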
# lighteval = LightEvalConfig( # tasks=LightEvalTasksArgs( # tasks="early-signal", # "generatives", "all" @@ -110,6 +111,7 @@ # hub_repo_tensorboard="smollm-evals-visualization", # tensorboard_metric_prefix="eval", # ), + # temp_dir = "temp_dir", # slurm_template="slurm/run_eval.slurm.jinja", # # slurm_template="slurm/run_eval_s3.slurm.jinja", if s3 @@ -118,9 +120,9 @@ lighteval = None checkpoints = CheckpointsArgs( - checkpoints_path="checkpoints", + # checkpoints_path="checkpoints", checkpoints_path_is_shared_file_system=False, - # resume_checkpoint_path="", + # resume_checkpoint_path="local_path/to/checkpoint" or s3_path, checkpoint_interval=CHECKPOINT_INTERVAL, save_initial_state=False, ) @@ -161,7 +163,7 @@ learning_rate=3e-3, lr_warmup_steps=10, lr_warmup_style="linear", - lr_decay_style="1-sqrt", + lr_decay_style="linear", lr_decay_steps = 20, lr_decay_starting_step=80 , min_decay_lr=0, @@ -198,11 +200,19 @@ data_stages=[ DatasetStageArgs( data=DataArgs( - dataset=NanosetDatasetsArgs( - dataset_folder="datasets/cosmopedia-v2", + # 1. Un-tokenized dataset from HuggingFace + dataset=PretrainDatasetsArgs( + hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus", # feel free to replace it by a smaller one if you don't have enough memory + hf_dataset_splits="train", + hf_dataset_config_name="cosmopedia-v2", + text_column_name="text", ), - num_loading_workers=0, - seed=general.seed, + # 2. Pre-tokenized local dataset with Nanoset + # dataset=NanosetDatasetsArgs( + # dataset_folder="datasets/cosmopedia-v2", + # ), + # num_loading_workers=0, + # seed=general.seed, ), name="training stage", start_training_step=1, diff --git a/launcher.py b/launcher.py index ea00c442..49df6c1f 100644 --- a/launcher.py +++ b/launcher.py @@ -177,11 +177,13 @@ def set_nested_attribute(obj, path, value): timestamp_with_run = f"run{run_number:03d}_{timestamp}" config.general.timestamp_with_run = timestamp_with_run - config.general.config_logs_path = f"{config.general.logs_path}/{args.run}/{timestamp_with_run}/config" + config.general.config_logs_path = str(Path(config.general.logs_path) / args.run / timestamp_with_run / "config") Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) - - #making sure the logs path folder exists + if config.checkpoints.checkpoints_path is None: + config.checkpoints.checkpoints_path = str(Path(config.general.logs_path) / args.run / timestamp_with_run / "checkpoints") + Path(config.checkpoints.checkpoints_path).mkdir(parents=True, exist_ok=True) + if args.slurm: @@ -210,19 +212,24 @@ def set_nested_attribute(obj, path, value): subfolders.append('evals') for subfolder in subfolders: - folder_path = os.path.join(log_folder, subfolder) - os.makedirs(folder_path, exist_ok=True) + folder_path = str(log_folder / subfolder) + Path(folder_path).mkdir(parents=True, exist_ok=True) if subfolder == 'launch-script': config.general.launch_script_path = folder_path elif subfolder == 'slurm-logs': config.general.slurm_logs_path = folder_path elif subfolder == 'evals': config.general.evals_logs_path = folder_path - for evals_subfolder in ['launch-config', 'logs']: - evals_subfolder_path = os.path.join(config.general.evals_logs_path, evals_subfolder) - os.makedirs(evals_subfolder_path, exist_ok=True) - - + for evals_subfolder in ['launch-config', 'logs',"lighteval-logs"]: + if evals_subfolder == "lighteval-logs": + if config.lighteval.logging.local_output_path is None: + evals_subfolder_path = str(Path(config.general.evals_logs_path) / evals_subfolder) + 
Path(evals_subfolder_path).mkdir(parents=True, exist_ok=True) + config.lighteval.logging.local_output_path = evals_subfolder_path + else: + evals_subfolder_path = str(Path(config.general.evals_logs_path) / evals_subfolder) + Path(evals_subfolder_path).mkdir(parents=True, exist_ok=True) + torchrun_args = "" if 'torchrun_args' in launch_slurm_config and launch_slurm_config['torchrun_args']: torchrun_args = " ".join([f"--{k} {v}" for k, v in launch_slurm_config['torchrun_args'].items()]) @@ -252,7 +259,9 @@ def set_nested_attribute(obj, path, value): else: config.general.eval_slurm_config = None - config.save_as_yaml(launch_slurm_config["config_path_yaml"]) + config_path_yaml = str(Path(config.general.config_logs_path) / "launch.yaml") + Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) + config.save_as_yaml(config_path_yaml) # Launch the Slurm job job_id = launch_slurm_job(sbatch_script) @@ -260,8 +269,9 @@ def set_nested_attribute(obj, path, value): # Save the Slurm script if a path is provided if config.general.launch_script_path: - os.makedirs(config.general.launch_script_path, exist_ok=True) + Path(config.general.launch_script_path).mkdir(parents=True, exist_ok=True) script_filename = f"slurm_launch_script.slurm" + script_path = str(Path(config.general.launch_script_path) / script_filename) script_path = os.path.join(config.general.launch_script_path, script_filename) with open(script_path, 'w') as f: @@ -278,6 +288,8 @@ def set_nested_attribute(obj, path, value): print(" 📁 Log structure:") print(f" {config.general.logs_path}/{config.general.run}/") print(f" └── {timestamp_with_run}/") + if config.checkpoints.checkpoints_path == str(Path(config.general.logs_path) / args.run / timestamp_with_run / "checkpoints"): + print(" ├── checkpoints/") print(" ├── config/") print(" ├── launch-script/") print(" ├── slurm-logs/") @@ -285,8 +297,8 @@ def set_nested_attribute(obj, path, value): print(" └── evals/") print(" ├── launch-config/") print(" └── logs/") - else: - print(" └── (No evals folder)") + if config.lighteval.logging.local_output_path== str(Path(config.general.evals_logs_path) / "lighteval-logs"): + print(" └── lighteval-logs/") else: # Check if running on an interactive node @@ -311,8 +323,8 @@ def set_nested_attribute(obj, path, value): f"uses {total_gpus} GPUs, but {gpu_count} are available. 
" f"You are not fully utilizing all available GPUs on this device.") - config_path_yaml = f"{config.general.config_logs_path}/launch.yaml" - os.makedirs("config.general.config_logs_path", exist_ok=True) + config_path_yaml = str(Path(config.general.config_logs_path) / "launch.yaml") + os.makedirs(config.general.config_logs_path, exist_ok=True) config.save_as_yaml(config_path_yaml) trainer_python_file = "run_train.py" diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 8bdffb87..488ebf96 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -171,8 +171,8 @@ class CheckpointsArgs: resume_checkpoint_path: if you want to load from a specific checkpoint path """ - checkpoints_path: str checkpoint_interval: int + checkpoints_path: Optional[str] = None save_initial_state: Optional[bool] = False save_final_state: Optional[bool] = False resume_checkpoint_path: Optional[str] = None @@ -210,7 +210,7 @@ class GeneralArgs: slurm_logs_path: Optional[str] = None config_logs_path: Optional[str] = None evals_logs_path: Optional[str] = None - temp_dir: Optional[str] = None + temp_dir: Optional[str] = "temp_dir" seed: Optional[int] = None step: Optional[int] = None consumed_train_samples: Optional[int] = None diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index ea3ba120..fe11437d 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -93,7 +93,7 @@ class LightEvalConfig: slurm_template: Optional[str] = None slurm_script_dir: Optional[str] = None - temp_dir: Optional[str] = None + temp_dir: Optional[str] = "temp_dir" checkpoints_path: Optional[str] = None parallelism: Optional[ParallelismArgs] = None batch_size: Optional[int] = None From 67115a560fdf10fe4212419c25f0f5adc65ce44a Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Fri, 13 Sep 2024 18:20:57 +0200 Subject: [PATCH 37/43] remove layer norm flops --- src/nanotron/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 1491b677..0bf80427 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -920,7 +920,6 @@ def get_block_compute_costs(self): LlamaDecoderLayer: 2 * model_config.num_attention_heads * d_qkv * model_config.hidden_size # Q output projection + 2 * model_config.num_key_value_heads * d_qkv * model_config.hidden_size # KV + 3 * d_ff * model_config.hidden_size # for the MLP (3 because of the gated mechanism) - + 2 * model_config.hidden_size, # for the layernorm # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } @@ -1172,4 +1171,4 @@ def get_flops( hardware_flops = model_flops # TODO: This is a placeholder for now - return model_flops, hardware_flops \ No newline at end of file + return model_flops, hardware_flops From 6249264247a0d7a4e6e833ea0482525bcbd5418f Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Fri, 13 Sep 2024 18:22:14 +0200 Subject: [PATCH 38/43] forget comma --- src/nanotron/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 0bf80427..80522d62 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -919,7 +919,7 @@ def get_block_compute_costs(self): # CausalSelfAttention (qkv proj + attn out) + MLP LlamaDecoderLayer: 
2 * model_config.num_attention_heads * d_qkv * model_config.hidden_size # Q output projection + 2 * model_config.num_key_value_heads * d_qkv * model_config.hidden_size # KV - + 3 * d_ff * model_config.hidden_size # for the MLP (3 because of the gated mechanism) + + 3 * d_ff * model_config.hidden_size # for the MLP (3 because of the gated mechanism), # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } From 11d60c875202b4032037675e5dd87c5f64d6f28c Mon Sep 17 00:00:00 2001 From: elie <97572401+eliebak@users.noreply.github.com> Date: Fri, 13 Sep 2024 20:20:26 +0200 Subject: [PATCH 39/43] put the comma in the right place --- src/nanotron/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 80522d62..e7510e58 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -919,7 +919,7 @@ def get_block_compute_costs(self): # CausalSelfAttention (qkv proj + attn out) + MLP LlamaDecoderLayer: 2 * model_config.num_attention_heads * d_qkv * model_config.hidden_size # Q output projection + 2 * model_config.num_key_value_heads * d_qkv * model_config.hidden_size # KV - + 3 * d_ff * model_config.hidden_size # for the MLP (3 because of the gated mechanism), + + 3 * d_ff * model_config.hidden_size, # for the MLP (3 because of the gated mechanism) # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } From 5e8361c06f4652ec3d98d6ab7cb1760e75896cb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Fri, 20 Sep 2024 05:26:08 +0000 Subject: [PATCH 40/43] adapt it to the current lighteval main --- create_config.py | 108 ++---- launcher.py | 249 ++++++++------ pyproject.toml | 2 +- slurm/run_eval.slurm.jinja | 27 +- slurm/run_eval_s3.slurm.jinja | 27 +- src/nanotron/config/config.py | 49 ++- src/nanotron/config/lighteval_config.py | 61 +--- src/nanotron/config/utils_config.py | 2 +- src/nanotron/lighteval/evaluation_tasks.py | 361 ++++++++++----------- src/nanotron/lighteval/one_job_runner.py | 86 +++-- src/nanotron/lighteval/run_evals.py | 12 +- src/nanotron/trainer.py | 64 ++-- 12 files changed, 520 insertions(+), 528 deletions(-) diff --git a/create_config.py b/create_config.py index 31242c90..f09df36e 100644 --- a/create_config.py +++ b/create_config.py @@ -1,57 +1,42 @@ -import os -from pathlib import Path -import subprocess -from datetime import datetime -import math -import torch - import argparse +import math +from datetime import datetime +from pathlib import Path -from nanotron.models.llama import LlamaConfig - +import torch from nanotron.config import ( + AdamWOptimizerArgs, + CheckpointsArgs, Config, DataArgs, - NanosetDatasetsArgs, - PretrainDatasetsArgs, - S3UploadArgs, - CheckpointsArgs, + DatasetStageArgs, GeneralArgs, - LightEvalConfig, - LightEvalLoggingArgs, - LightEvalTasksArgs, LoggingArgs, LRSchedulerArgs, ModelArgs, OptimizerArgs, - AdamWOptimizerArgs, ParallelismArgs, + PretrainDatasetsArgs, RandomInit, TokenizerArgs, TokensArgs, - DatasetStageArgs, ) +from nanotron.models.llama import LlamaConfig if __name__ == "__main__": ########################################### ## ADAPT TO YOUR ENVIRONMENT (toy example of smollm-135M on 1 GPU) - HF_USER_OR_ORG = "eliebak" + HF_USER_OR_ORG = None TRAIN_STEPS = 100 CHECKPOINT_INTERVAL = 200 - SAVE_NAME="smollm-135M-1gpu-toy" - + SAVE_NAME = "smollm-135M-1gpu-toy" ########################################### 
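    # Typical usage (illustrative; adjust paths and names to your setup):
    #   python create_config.py --save-path configs
    #   python launcher.py --config-path configs/smollm-135M-1gpu-toy.yaml --run toy-run
    # --run / --project only need to be passed to the launcher if they are not
    # already set in the saved config.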
parser = argparse.ArgumentParser() parser.add_argument("--save-path", help="path to save the configuration file", type=str, default="yaml") parser.add_argument("--seed", help="seed", type=int, default=8) - parser.add_argument("--launch", action="store_true", help="Launch the configuration immediately") - parser.add_argument("--logs-path", help="path to the logs folder", type=str) - parser.add_argument("--run", help="name of the run", type=str) - parser.add_argument("--slurm", help="use slurm", action="store_true") - parser.add_argument("--nodes", help="specify the number of nodes", type=int) args = parser.parse_args() timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") @@ -78,7 +63,7 @@ rope_scaling=None, tie_word_embeddings=True, use_cache=True, - vocab_size=49152, + vocab_size=49152, ) # Uncomment to evaluate the model on a set of tasks with lighteval during the training. @@ -100,24 +85,16 @@ # ), # batch_size=16, # logging=LightEvalLoggingArgs( - # local_output_path="lighteval-logs", - # private=True, - # push_details_to_hub=True, - # push_results_to_hub=True, - # push_results_to_tensorboard=True, - # hf_user_or_org=HF_USER_OR_ORG, - # hub_repo_results="lighteval-results", - # hub_repo_details="lighteval-details", - # hub_repo_tensorboard="smollm-evals-visualization", + # output_dir=None, + # push_to_hub=True, + # push_to_tensorboard=True, + # public_run=False, + # results_org=HF_USER_OR_ORG, # tensorboard_metric_prefix="eval", # ), - # temp_dir = "temp_dir", - # slurm_template="slurm/run_eval.slurm.jinja", - # # slurm_template="slurm/run_eval_s3.slurm.jinja", if s3 - # ) - lighteval = None + # lighteval = None checkpoints = CheckpointsArgs( # checkpoints_path="checkpoints", @@ -137,7 +114,7 @@ ) tokens = TokensArgs( - batch_accumulation_per_replica=8, + batch_accumulation_per_replica=1, micro_batch_size=8, sequence_length=2048, train_steps=TRAIN_STEPS, @@ -147,7 +124,7 @@ model = ModelArgs( model_config=model_config, init_method=RandomInit( - std=1/math.sqrt(model_config.hidden_size), + std=1 / math.sqrt(model_config.hidden_size), ), dtype=torch.bfloat16, ) @@ -164,12 +141,11 @@ lr_warmup_steps=10, lr_warmup_style="linear", lr_decay_style="linear", - lr_decay_steps = 20, - lr_decay_starting_step=80 , + lr_decay_steps=20, + lr_decay_starting_step=80, min_decay_lr=0, ) - optimizer = OptimizerArgs( zero_stage=0, weight_decay=0.01, @@ -197,12 +173,12 @@ # s5cmd_path="PATH_TO_S5CMD", # ) - data_stages=[ + data_stages = [ DatasetStageArgs( data=DataArgs( # 1. Un-tokenized dataset from HuggingFace dataset=PretrainDatasetsArgs( - hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus", # feel free to replace it by a smaller one if you don't have enough memory + hf_dataset_or_datasets="HuggingFaceTB/smollm-corpus", # feel free to replace it by a smaller one if you don't have enough memory hf_dataset_splits="train", hf_dataset_config_name="cosmopedia-v2", text_column_name="text", @@ -250,42 +226,12 @@ lighteval=lighteval, ) - save_path= Path(args.save_path) + save_path = Path(args.save_path) save_path.mkdir(parents=True, exist_ok=True) config_path_yaml = save_path / f"{SAVE_NAME}.yaml" config.save_as_yaml(config_path_yaml) print(f"💾 Configuration saved in: {str(save_path)}") - - if args.launch: - - # Sanity check for logs_path and run - if not args.logs_path: - raise ValueError("--logs_path must be defined. Please provide a path for the logs.") - if not args.run: - raise ValueError("--run must be defined. 
Please provide a name for the run.") - - launcher_path = Path("launcher.py") - if not launcher_path.exists(): - raise FileNotFoundError(f"Launcher not found at {launcher_path}. Please ensure the file exists or change the launcher path in the create_config.py file.") - launch_command = [ - "python", str(launcher_path), - "--config-path", str(config_path_yaml), - ] - launch_command.extend([ - "--logs-path", args.logs_path, - "--run", args.run - ]) - if args.slurm: - launch_command.append("--slurm") - - if args.nodes: - launch_command.extend(["--nodes", str(args.nodes)]) - - - print(f"🧪 Launching configuration with command: {' '.join(launch_command)}") - subprocess.run(launch_command, check=True) - else: - print("To launch this configuration, run:") - print(f"python 'launcher.py' configs/{str(config_path_yaml)}") \ No newline at end of file + print("To launch this configuration, run:") + print(f"python launcher.py --config-path configs/{str(config_path_yaml)}") diff --git a/launcher.py b/launcher.py index 49df6c1f..c8dd4a75 100644 --- a/launcher.py +++ b/launcher.py @@ -1,23 +1,26 @@ +import argparse +import json import os -from pathlib import Path import subprocess import tempfile from datetime import datetime +from pathlib import Path + import torch -import argparse -import json from jinja2 import Template - -from nanotron.logging import human_format - from nanotron.config import ( Config, get_config_from_file, + save_as_yaml, ) +from nanotron.config.lighteval_config import LightEvalConfig +from nanotron.logging import human_format + def count_subdirectories(path): return sum(os.path.isdir(os.path.join(path, item)) for item in os.listdir(path)) + def launch_slurm_job(launch_file_contents, *args): """ Small helper function to save a sbatch script and call it. @@ -33,22 +36,26 @@ def launch_slurm_job(launch_file_contents, *args): f.flush() return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1] + def set_nested_attribute(obj, path, value): - parts = path.split('.') + parts = path.split(".") for part in parts[:-1]: if not hasattr(obj, part): - setattr(obj, part, type('', (), {})()) + setattr(obj, part, type("", (), {})()) obj = getattr(obj, part) setattr(obj, parts[-1], value) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config-path", help="path to the configuration file", type=str, default=None, required=True) - parser.add_argument("--run", help="name of the run", type=str, required=True) + parser.add_argument("--project", help="name of the project", type=str) + parser.add_argument("--run", help="name of the run", type=str) parser.add_argument("--logs-path", help="path to the logs folder", type=str, default="logs") - parser.add_argument("--override", nargs="+", metavar="KEY=VALUE", - help="Override config values. Use dot notation for nested keys.") + parser.add_argument( + "--override", nargs="+", metavar="KEY=VALUE", help="Override config values. Use dot notation for nested keys." 
+ ) parser.add_argument("--slurm", action="store_true", help="Launch the job on Slurm") parser.add_argument("--nodes", type=int, help="Number of nodes to use for the job") args = parser.parse_args() @@ -65,27 +72,34 @@ def set_nested_attribute(obj, path, value): if config.general.logs_path is None and args.logs_path is None: raise ValueError("Please provide a logs path") + if config.general.project is None and args.project is None: + raise ValueError("Please provide a project name") + elif args.project is not None: + config.general.project = args.project + + if config.general.run is None and args.run is None: + raise ValueError("Please provide a run name") + elif args.run is not None: + config.general.run = args.run - num_params = human_format( - config.model.model_config.get_llama_param_count() - ).replace(".", ",") + num_params = human_format(config.model.model_config.get_llama_param_count()).replace(".", ",") if args.override: for item in args.override: - if '=' not in item: + if "=" not in item: raise ValueError(f"Invalid override format: {item}. Use KEY=VALUE.") - key, value = item.split('=', 1) + key, value = item.split("=", 1) try: value = eval(value) except: pass - + set_nested_attribute(config, key, value) print("⇄ Applied overrides:") for item in args.override: print(f" {item}") - + # Calculate and print learning rate and global batch size information lr_initial = config.optimizer.learning_rate_scheduler.learning_rate lr_min = config.optimizer.learning_rate_scheduler.min_decay_lr @@ -99,13 +113,14 @@ def set_nested_attribute(obj, path, value): bs_gpu_token = bs_gpu_sample * config.tokens.sequence_length # Sample/Token in one step - gbs_sample = bs_gpu_sample * config.parallelism.dp*config.tokens.batch_accumulation_per_replica + gbs_sample = bs_gpu_sample * config.parallelism.dp * config.tokens.batch_accumulation_per_replica gbs_token = gbs_sample * config.tokens.sequence_length total_tokens = config.tokens.train_steps * gbs_token total_tokens_billions = human_format(total_tokens).replace(".", ",") - print(f""" + print( + f""" 🏋️ Model Parameters: ┌───────────────────────┬────────────────────────┐ │ Total Parameters │ {num_params:>22} │ @@ -117,30 +132,36 @@ def set_nested_attribute(obj, path, value): │ Tokenizer │ {config.tokenizer.tokenizer_name_or_path[:22]:>22} │ │ Vocab Size │ {config.model.model_config.vocab_size:>22d} │ └───────────────────────┴────────────────────────┘ -""") +""" + ) num_nodes = args.nodes if args.slurm else 1 - print(f""" + print( + f""" 🎛️ Parallelism Configuration: ┌───────────────────────┬────────────────────────┐ │ Nodes │ {num_nodes:>22d} │ │ Total GPUs │ {config.parallelism.dp*config.parallelism.pp*config.parallelism.tp:>22d} │ │ Data Parallel (DP) │ {config.parallelism.dp:>22d} │ │ Pipeline Parallel (PP)│ {config.parallelism.pp:>22d} │ -│ Tensor Parallel (TP) │ {config.parallelism.tp:>22d} │ +│ Tensor Parallel (TP) │ {config.parallelism.tp:>22d} │ └───────────────────────┴────────────────────────┘ -""") +""" + ) - print(f""" + print( + f""" 📙 Training Configuration: ┌───────────────────────┬────────────────────────┐ │ Total Tokens │ {total_tokens_billions:>22} │ │ Batch Size (per GPU) │ {bs_gpu_token:>15,d} Tokens │ │ Global Batch Size │ {gbs_token:>15,d} Tokens │ └───────────────────────┴────────────────────────┘ -""") +""" + ) - print(f""" + print( + f""" 📊 Learning Rate Schedule: ┌───────────────────────┬────────────────────────┐ │ Initial LR │ {lr_initial:>22.2e} │ @@ -151,8 +172,10 @@ def set_nested_attribute(obj, path, value): │ Decay Steps │ 
{lr_decay_steps:>22d} │ │ Final LR │ {lr_min:>22.2e} │ └───────────────────────┴────────────────────────┘ -""") - print(f""" +""" + ) + print( + f""" 🔧 Optimization Configuration: ┌───────────────────────┬────────────────────────┐ │ Optimizer │ {config.optimizer.optimizer_factory.__class__.__name__:>22} │ @@ -164,88 +187,109 @@ def set_nested_attribute(obj, path, value): │ ZeRO Stage │ {config.optimizer.zero_stage:>22d} │ │ FP32 Grad Accumulation│ {str(config.optimizer.accumulate_grad_in_fp32):>22} │ └───────────────────────┴────────────────────────┘ -""") +""" + ) config.general.logs_path = args.logs_path - config.general.run = args.run - path = Path(args.logs_path) / f"{args.run}" + path = Path(args.logs_path) / f"{config.general.run}" path.mkdir(parents=True, exist_ok=True) timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - run_number = count_subdirectories(f"{args.logs_path}/{args.run}") + 1 + run_number = count_subdirectories(f"{args.logs_path}/{config.general.run}") + 1 timestamp_with_run = f"run{run_number:03d}_{timestamp}" config.general.timestamp_with_run = timestamp_with_run - config.general.config_logs_path = str(Path(config.general.logs_path) / args.run / timestamp_with_run / "config") + config.general.config_logs_path = str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "config" + ) Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) if config.checkpoints.checkpoints_path is None: - config.checkpoints.checkpoints_path = str(Path(config.general.logs_path) / args.run / timestamp_with_run / "checkpoints") + config.checkpoints.checkpoints_path = str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "checkpoints" + ) Path(config.checkpoints.checkpoints_path).mkdir(parents=True, exist_ok=True) - if args.slurm: - + nodes = args.nodes launch_slurm_config_path = Path("slurm/launch_slurm_config.json") - eval_slurm_config_path = Path("slurm/eval_slurm_config.json") - - with open(launch_slurm_config_path, 'r') as f: + if config.lighteval is not None: + eval_slurm_config_path = Path("slurm/eval_slurm_config.json") + if eval_slurm_config_path.exists(): + config.general.eval_slurm_config = str(eval_slurm_config_path.resolve()) + else: + raise ValueError("Lighteval SLURM configuration is required but not provided.") + if config.general.is_s3_available: + config.general.eval_slurm_template = "slurm/run_eval_s3.slurm.jinja" + else: + config.general.eval_slurm_template = "slurm/run_eval.slurm.jinja" + + with open(launch_slurm_config_path, "r") as f: launch_slurm_config = json.load(f) - - + total_gpus = config.parallelism.dp * config.parallelism.pp * config.parallelism.tp - gpus_per_node = launch_slurm_config.get('gpus_per_node') - required_nodes = (total_gpus + gpus_per_node - 1) // gpus_per_node # Ceiling division + gpus_per_node = launch_slurm_config.get("gpus_per_node") + if total_gpus < gpus_per_node: + required_nodes = 1 + gpus_per_node = total_gpus + print( + "Warning: The total number of GPUs is less than the GPUs per node. You need to adjust to use all available GPUs." 
+ ) + else: + required_nodes = (total_gpus + gpus_per_node - 1) // gpus_per_node # Ceiling division if args.nodes != required_nodes: - raise ValueError(f"Number of nodes in config ({args.nodes}) does not match the required number of nodes ({required_nodes}) based on the parallelism configuration.") + raise ValueError( + f"Number of nodes in config ({args.nodes}) does not match the required number of nodes ({required_nodes}) based on the parallelism configuration." + ) - # Create necessary folders project_log_folder = Path(config.general.logs_path) - log_folder = project_log_folder / f"{args.run}"/ f"{timestamp_with_run}" - subfolders = ['launch-script', 'slurm-logs'] - if hasattr(config, 'lighteval') and config.lighteval is not None: - subfolders.append('evals') + log_folder = project_log_folder / f"{config.general.run}" / f"{timestamp_with_run}" + subfolders = ["launch-script", "slurm-logs"] + if hasattr(config, "lighteval") and config.lighteval is not None: + subfolders.append("evals") for subfolder in subfolders: folder_path = str(log_folder / subfolder) Path(folder_path).mkdir(parents=True, exist_ok=True) - if subfolder == 'launch-script': + if subfolder == "launch-script": config.general.launch_script_path = folder_path - elif subfolder == 'slurm-logs': + elif subfolder == "slurm-logs": config.general.slurm_logs_path = folder_path - elif subfolder == 'evals': + elif subfolder == "evals": config.general.evals_logs_path = folder_path - for evals_subfolder in ['launch-config', 'logs',"lighteval-logs"]: + for evals_subfolder in ["launch-config", "logs", "lighteval-logs"]: if evals_subfolder == "lighteval-logs": - if config.lighteval.logging.local_output_path is None: + if config.lighteval.logging.output_dir is None: evals_subfolder_path = str(Path(config.general.evals_logs_path) / evals_subfolder) Path(evals_subfolder_path).mkdir(parents=True, exist_ok=True) - config.lighteval.logging.local_output_path = evals_subfolder_path + config.lighteval.logging.output_dir = evals_subfolder_path else: evals_subfolder_path = str(Path(config.general.evals_logs_path) / evals_subfolder) Path(evals_subfolder_path).mkdir(parents=True, exist_ok=True) torchrun_args = "" - if 'torchrun_args' in launch_slurm_config and launch_slurm_config['torchrun_args']: - torchrun_args = " ".join([f"--{k} {v}" for k, v in launch_slurm_config['torchrun_args'].items()]) - - launch_slurm_config.update({ - "job_name": f"{config.general.project}-{config.general.run}", - "nodes": args.nodes, - "slurm_logs_path": config.general.slurm_logs_path, - "path_to_trainer_python_file": os.path.join(os.path.dirname(__file__), "run_train.py"), - "config_path_yaml": f"{config.general.config_logs_path}/launch.yaml", - "torchrun_args": torchrun_args, - }) + if "torchrun_args" in launch_slurm_config and launch_slurm_config["torchrun_args"]: + torchrun_args = " ".join([f"--{k} {v}" for k, v in launch_slurm_config["torchrun_args"].items()]) + + launch_slurm_config.update( + { + "job_name": f"{config.general.project}-{config.general.run}", + "nodes": args.nodes, + "slurm_logs_path": config.general.slurm_logs_path, + "path_to_trainer_python_file": os.path.join(os.path.dirname(__file__), "run_train.py"), + "config_path_yaml": f"{config.general.config_logs_path}/launch_config.yaml", + "torchrun_args": torchrun_args, + } + ) # Load Jinja2 template template_path = Path("slurm/launch_training.slurm.jinja") - with open(template_path, 'r') as f: + with open(template_path, "r") as f: template = Template(f.read()) # Render the template @@ -254,15 +298,18 @@ 
def set_nested_attribute(obj, path, value): config.general.launch_slurm_config = str(launch_slurm_config_path.resolve()) else: config.general.launch_slurm_config = None - if eval_slurm_config_path.exists(): - config.general.eval_slurm_config = str(eval_slurm_config_path.resolve()) - else: - config.general.eval_slurm_config = None - config_path_yaml = str(Path(config.general.config_logs_path) / "launch.yaml") + if config.lighteval is not None: + # Save the lighteval configuration + lighteval_config = config.lighteval + Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) + config.general.lighteval_config_path = str(Path(config.general.config_logs_path) / "lighteval_config.yaml") + save_as_yaml(lighteval_config, LightEvalConfig, config.general.lighteval_config_path) + + config_path_yaml = str(Path(config.general.config_logs_path) / "launch_config.yaml") Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) config.save_as_yaml(config_path_yaml) - + # Launch the Slurm job job_id = launch_slurm_job(sbatch_script) print(f"🚀 Slurm job launched with id={job_id}") @@ -270,34 +317,36 @@ def set_nested_attribute(obj, path, value): # Save the Slurm script if a path is provided if config.general.launch_script_path: Path(config.general.launch_script_path).mkdir(parents=True, exist_ok=True) - script_filename = f"slurm_launch_script.slurm" + script_filename = "slurm_launch_script.slurm" script_path = str(Path(config.general.launch_script_path) / script_filename) script_path = os.path.join(config.general.launch_script_path, script_filename) - - with open(script_path, 'w') as f: + + with open(script_path, "w") as f: f.write(sbatch_script) - print(f" 🤖 Slurm Configuration Details:") + print(" 🤖 Slurm Configuration Details:") - slurm_config_keys = ['qos', 'gpus_per_node', 'cpus_per_task', 'constraint', 'account', 'reservation'] + slurm_config_keys = ["qos", "gpus_per_node", "cpus_per_task", "constraint", "account", "reservation"] for key in slurm_config_keys: if key in launch_slurm_config: if launch_slurm_config[key] is not None: print(f" {key}: {launch_slurm_config[key]}") - + print(" ") print(" 📁 Log structure:") print(f" {config.general.logs_path}/{config.general.run}/") print(f" └── {timestamp_with_run}/") - if config.checkpoints.checkpoints_path == str(Path(config.general.logs_path) / args.run / timestamp_with_run / "checkpoints"): + if config.checkpoints.checkpoints_path == str( + Path(config.general.logs_path) / config.general.run / timestamp_with_run / "checkpoints" + ): print(" ├── checkpoints/") print(" ├── config/") print(" ├── launch-script/") print(" ├── slurm-logs/") - if hasattr(config, 'lighteval') and config.lighteval is not None: + if hasattr(config, "lighteval") and config.lighteval is not None: print(" └── evals/") print(" ├── launch-config/") print(" └── logs/") - if config.lighteval.logging.local_output_path== str(Path(config.general.evals_logs_path) / "lighteval-logs"): + if config.lighteval.logging.output_dir == str(Path(config.general.evals_logs_path) / "lighteval-logs"): print(" └── lighteval-logs/") else: @@ -312,27 +361,35 @@ def set_nested_attribute(obj, path, value): print("💻 Running on an interactive node with GPUs.") gpu_config = config.parallelism.dp * config.parallelism.tp * config.parallelism.pp if gpu_count < gpu_config: - raise ValueError(f"Error: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " - f"requires {gpu_config} GPUs, but only {gpu_count} are available.") + raise 
ValueError( + f"Error: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " + f"requires {gpu_config} GPUs, but only {gpu_count} are available." + ) elif gpu_count == gpu_config: - print(f"🚀 Running on {gpu_count} GPUs, which matches your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp})") - total_gpus= gpu_count + print( + f"🚀 Running on {gpu_count} GPUs, which matches your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp})" + ) + total_gpus = gpu_count elif gpu_count > gpu_config: - total_gpus= gpu_config - print(f"⚠️ Warning: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " - f"uses {total_gpus} GPUs, but {gpu_count} are available. " - f"You are not fully utilizing all available GPUs on this device.") - - config_path_yaml = str(Path(config.general.config_logs_path) / "launch.yaml") + total_gpus = gpu_config + print( + f"⚠️ Warning: Your configuration (dp={config.parallelism.dp}, tp={config.parallelism.tp}, pp={config.parallelism.pp}) " + f"uses {total_gpus} GPUs, but {gpu_count} are available. " + f"You are not fully utilizing all available GPUs on this device." + ) + + config_path_yaml = str(Path(config.general.config_logs_path) / "launch_config.yaml") os.makedirs(config.general.config_logs_path, exist_ok=True) config.save_as_yaml(config_path_yaml) trainer_python_file = "run_train.py" - cmd = f"{trainer_python_file} --config-file {args.config_path}" + cmd = f"{trainer_python_file} --config-file {config_path_yaml}" launch_cmd = f"CUDA_DEVICE_MAX_CONNECTIONS='1' torchrun --nproc_per_node {total_gpus} {cmd}" print(f"🚀 Launching interactive job with command: {launch_cmd}") - + subprocess.run(launch_cmd, shell=True, check=True) else: - print("❌ Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs.") \ No newline at end of file + print( + "❌ Not running on a Slurm cluster or an interactive node with GPUs. Please submit a Slurm job or use an interactive node with GPUs." 
+ ) diff --git a/pyproject.toml b/pyproject.toml index 802a30ab..dbde7f0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ s3 = [ ] lighteval = [ - "lighteval[nanotron]@git+https://github.com/huggingface/lighteval.git@nanotron-compatible", + "lighteval[nanotron]@git+https://github.com/huggingface/lighteval.git", ] [build-system] requires = [ diff --git a/slurm/run_eval.slurm.jinja b/slurm/run_eval.slurm.jinja index 8cd3ee5a..6444858a 100644 --- a/slurm/run_eval.slurm.jinja +++ b/slurm/run_eval.slurm.jinja @@ -70,17 +70,26 @@ echo go $COUNT_NODE echo $HOSTNAMES -torch_dist_args="--nproc_per_node 8 \ +CMD="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path {{ model_checkpoint_path }}/config.yaml \ + --lighteval-config-path {{ lighteval_config_path }} \ + " + +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ --nnodes $COUNT_NODE \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ --max_restarts 0 \ --tee 3 \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: " - -launch_args="$torch_dist_args \ - /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path {{ model_checkpoint_path }}/config.yaml \ - --hf-user-or-org {{ hf_user_or_org }} \ " -srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." +sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" diff --git a/slurm/run_eval_s3.slurm.jinja b/slurm/run_eval_s3.slurm.jinja index ee467274..04441638 100644 --- a/slurm/run_eval_s3.slurm.jinja +++ b/slurm/run_eval_s3.slurm.jinja @@ -75,17 +75,26 @@ echo $HOSTNAMES mkdir -p $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER s5cmd cp --exclude "optimizer/*" {{ model_checkpoint_path }}* $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER -torch_dist_args="--nproc_per_node 8 \ +CMD="/fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ + --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \ + --lighteval-config-path {{ lighteval_config_path }} \ + " + +export LAUNCHER="torchrun \ + --nproc_per_node {{ gpus_per_node }} \ --nnodes $COUNT_NODE \ + --node_rank $SLURM_PROCID \ + --role $SLURMD_NODENAME: \ --max_restarts 0 \ --tee 3 \ - --node_rank $SLURM_PROCID \ - --role $SLURMD_NODENAME: " - -launch_args="$torch_dist_args \ - /fsx/elie_bakouch/nanotron/src/nanotron/lighteval/run_evals.py \ - --checkpoint-config-path ${LOCAL_DOWNLOAD_CHECKPOINT_FOLDER}/config.yaml \ - --hf-user-or-org {{ hf_user_or_org }} \ " -srun -u bash -c "python3 -u -m torch.distributed.run ${launch_args}" \ No newline at end of file +# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub +random_milliseconds=$(( RANDOM % 1001 )) +sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000") +echo "Sleeping for $sleep_time seconds..." 
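+# For reference, with gpus_per_node=8 the srun line below expands to roughly
+# (checkpoint and config paths are illustrative):
+#   torchrun --nproc_per_node 8 --nnodes $COUNT_NODE ... run_evals.py \
+#     --checkpoint-config-path $LOCAL_DOWNLOAD_CHECKPOINT_FOLDER/config.yaml \
+#     --lighteval-config-path <logs_path>/<run>/<timestamp>/config/lighteval_config.yaml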
+sleep $sleep_time + +launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" + +srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD" diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 488ebf96..27105ee8 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -1,18 +1,16 @@ import datetime import os from dataclasses import dataclass, fields -import pathlib from pathlib import Path -from datasets.download.streaming_download_manager import xPath -from typing import List, Optional, Type, Union, Dict +from typing import List, Optional, Type, Union import dacite import torch import yaml from dacite import from_dict -from datasets.download.streaming_download_manager import xPath from yaml.loader import SafeLoader +from datasets.download.streaming_download_manager import xPath from nanotron.config.lighteval_config import LightEvalConfig from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs @@ -22,11 +20,11 @@ cast_str_to_torch_dtype, serialize, ) -from nanotron.s3_checkpoints import check_path_is_local from nanotron.generation.sampler import SamplerType from nanotron.logging import get_logger from nanotron.parallel.pipeline_parallel.engine import PipelineEngine from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.s3_checkpoints import check_path_is_local logger = get_logger(__name__) @@ -93,12 +91,16 @@ def __post_init__(self): self.text_column_name = "text" if self.hf_dataset_splits is None: self.hf_dataset_splits = "train" + + @dataclass class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" remove_after_upload: bool - upload_s3_path: Optional[str] = None # set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 + upload_s3_path: Optional[ + str + ] = None # set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 s5cmd_numworkers: Optional[int] = None s5cmd_concurrency: Optional[int] = None s5cmd_path: Optional[str] = None @@ -109,6 +111,7 @@ def __post_init__(self): if isinstance(self.s5cmd_path, str): self.s5cmd_path = Path(self.s5cmd_path) + @dataclass class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" @@ -200,17 +203,20 @@ class GeneralArgs: ignore_sanity_checks: Whether to ignore sanity checks """ - project: str + project: Optional[str] = None run: Optional[str] = None - logs_path: Optional[str] = "logs" + logs_path: Optional[str] = None launch_slurm_config: Optional[str] = None eval_slurm_config: Optional[str] = None + eval_slurm_template: Optional[str] = None + lighteval_config_path: Optional[str] = None + is_s3_available: Optional[bool] = None timestamp_with_run: Optional[str] = None launch_script_path: Optional[str] = None slurm_logs_path: Optional[str] = None config_logs_path: Optional[str] = None evals_logs_path: Optional[str] = None - temp_dir: Optional[str] = "temp_dir" + temp_dir: Optional[str] = None seed: Optional[int] = None step: Optional[int] = None consumed_train_samples: Optional[int] = None @@ -389,15 +395,18 @@ def create_empty(cls): return cls(**{f.name: None for f in cls_fields}) def __post_init__(self): - - if hasattr(self, '_post_init_done'): + + if hasattr(self, "_post_init_done"): return self._post_init_done = True 
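        # Note: the hasattr guard above presumably makes __post_init__ idempotent, so
        # rebuilding or re-validating the Config does not re-run the sanity checks below
        # a second time.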
self.general.__post_init__() if self.s3_upload is not None: self.s3_upload.__post_init__() + self.general.is_s3_available = True + else: + self.general.is_s3_available = False # Some final sanity checks across separate arguments sections: if self.profiler is not None and self.profiler.profiler_export_path is not None: assert self.tokens.train_steps < 10 @@ -430,18 +439,14 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" - - # if lighteval, we need tokenizer to be defined if self.lighteval is not None: assert self.tokenizer.tokenizer_name_or_path is not None - @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp - def save_as_yaml(self, file_path: str): config_dict = serialize(self) @@ -514,11 +519,10 @@ def get_config_from_file( skip_unused_config_keys: whether to skip unused first-nesting-level keys in the config file (for config with additional sections) skip_null_keys: whether to skip keys with value None at first and second nesting level """ - + with open(config_path) as f: config_dict = yaml.load(f, Loader=SafeLoader) - config = get_config_from_dict( config_dict, config_class=config_class, @@ -532,3 +536,14 @@ def get_config_from_file( ) config.model.model_config = model_config_class(**config.model.model_config) return config + + +def save_as_yaml(config, config_class, file_path: str): + + config_dict = serialize(config) + file_path = str(file_path) + with open(file_path, "w") as f: + yaml.dump(config_dict, f) + + # Sanity test config can be reloaded + _ = get_config_from_file(file_path, config_class=config_class) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index fe11437d..3808d60c 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from pathlib import Path from typing import Dict, Optional, Union from nanotron.config.parallelism_config import ParallelismArgs @@ -32,55 +31,28 @@ def __post_init__(self): @dataclass class LightEvalLoggingArgs: """Arguments related to logging for LightEval""" - local_output_path: Optional[Path] = None - private: Optional[bool] = True - push_results_to_hub: Optional[bool] = None - push_details_to_hub: Optional[bool] = None - push_results_to_tensorboard: Optional[bool] = None - hf_user_or_org: Optional[str] = None - hub_repo_results: Optional[str] = None #path is hf_user_or_org/hub_repo_results - hub_repo_details: Optional[str] = None #path is hf_user_or_org/hub_repo_details - hub_repo_tensorboard: Optional[str] = None - tensorboard_metric_prefix: Optional[str] = None - def __post_init__(self): - if isinstance(self.local_output_path, str): - self.local_output_path = Path(self.local_output_path) - if self.push_results_to_hub is not None and self.hf_user_or_org is None: - raise ValueError("hf_user_or_org must be specified if push_results_to_hub is set") - if self.push_details_to_hub is not None and self.hf_user_or_org is None: - raise ValueError("hf_user_or_org must be specified if push_details_to_hub is set") - if self.hf_user_or_org is not None: - if self.push_results_to_hub is not None and self.hub_repo_results is None: - self.hub_repo_results = "evals-results" - if self.push_details_to_hub is not None and self.hub_repo_details is None: - self.hub_repo_details = "evals-details" + output_dir: Optional[str] = None + save_details: 
bool = True + push_to_hub: bool = False + push_to_tensorboard: bool = False + public_run: bool = False + results_org: str | None = None + tensorboard_metric_prefix: str = "eval" @dataclass class LightEvalTasksArgs: """Arguments related to tasks for LightEval""" - tasks: Optional[str] = None + tasks: str custom_tasks: Optional[str] = None max_samples: Optional[int] = None num_fewshot_seeds: Optional[int] = None - dataset_loading_processes: Optional[int] = 8 + dataset_loading_processes: int = 8 multichoice_continuations_start_space: Optional[bool] = None - no_multichoice_continuations_start_space: Optional[bool] = None - - -@dataclass -class LightEvalWandbLoggerConfig: - """Arguments related to the local Wandb logger""" - - wandb_project: str = "" - wandb_entity: Optional[str] = None - wandb_run_name: Optional[str] = None - - def __post_init__(self): - assert self.wandb_project != "", "Please specify a wandb_project" + pair_wise_tokenization: bool = False @dataclass @@ -91,13 +63,8 @@ class LightEvalConfig: the saved config when running LightEval after training. """ - slurm_template: Optional[str] = None - slurm_script_dir: Optional[str] = None - temp_dir: Optional[str] = "temp_dir" - checkpoints_path: Optional[str] = None - parallelism: Optional[ParallelismArgs] = None - batch_size: Optional[int] = None + logging: LightEvalLoggingArgs + tasks: LightEvalTasksArgs + parallelism: ParallelismArgs + batch_size: int = 0 generation: Optional[Union[GenerationArgs, Dict[str, GenerationArgs]]] = None - tasks: Optional[LightEvalTasksArgs] = None - logging: Optional[LightEvalLoggingArgs] = None - wandb: Optional[LightEvalWandbLoggerConfig] = None diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index 124516cd..87d69585 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -1,10 +1,10 @@ from dataclasses import fields from enum import Enum, auto from pathlib import Path -from datasets.download.streaming_download_manager import xPath import torch +from datasets.download.streaming_download_manager import xPath from nanotron.generation.sampler import SamplerType from nanotron.parallel.pipeline_parallel.engine import ( AllForwardAllBackwardPipelineEngine, diff --git a/src/nanotron/lighteval/evaluation_tasks.py b/src/nanotron/lighteval/evaluation_tasks.py index a78fe486..2dd9820c 100644 --- a/src/nanotron/lighteval/evaluation_tasks.py +++ b/src/nanotron/lighteval/evaluation_tasks.py @@ -8,10 +8,11 @@ from dataclasses import asdict from typing import Dict, List, Tuple -from lighteval.metrics import Metrics +import lighteval.tasks.default_prompts as prompt +from lighteval.metrics.metrics import Metrics +from lighteval.tasks.default_prompts import LETTER_INDICES from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc -from lighteval.tasks.tasks_prompt_formatting import LETTER_INDICES _TASKS_STRINGS: List[Tuple[LightevalTaskConfig, str]] = [] _TASKS: List[LightevalTaskConfig] = [] @@ -19,130 +20,123 @@ trust_remote_code = True ## COMMON_SENSE_REASONING_TASKS ## + + +def commonsense_qa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["question"], + choices=[f" {c}" for c in line["choices"]["text"]], + gold_index=LETTER_INDICES.index(line["answerKey"].strip()), + instruction="", + ) + + +def siqa_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["context"] + " " + line["question"], + choices=[f" {c}" for c in 
[line["answerA"], line["answerB"], line["answerC"]]], + gold_index=int(line["label"]) - 1, + instruction="", + ) + + COMMON_SENSE_REASONING_TASKS = [ LightevalTaskConfig( name="hellaswag", - prompt_function="hellaswag_prompt", + prompt_function=prompt.hellaswag_harness, # Updated prompt function hf_repo="hellaswag", hf_subset="default", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], - trust_dataset=True, + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric + trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="winogrande", - prompt_function="winogrande", + prompt_function=prompt.winogrande, # Updated prompt function hf_repo="winogrande", hf_subset="winogrande_xl", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="piqa", - prompt_function="piqa_harness", + prompt_function=prompt.piqa_harness, # Updated prompt function hf_repo="piqa", hf_subset="plain_text", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="siqa", - prompt_function="siqa_prompt", + prompt_function=siqa_prompt, # Updated prompt function hf_repo="lighteval/siqa", hf_subset="default", hf_avail_splits=["train", "validation"], - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="openbookqa", - prompt_function="openbookqa", + prompt_function=prompt.openbookqa, # Updated prompt function hf_repo="openbookqa", hf_subset="main", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="arc:easy", - prompt_function="arc", + prompt_function=prompt.arc, # Updated prompt function hf_repo="ai2_arc", hf_subset="ARC-Easy", evaluation_splits=["test"], generation_size=1, - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="arc:challenge", - prompt_function="arc", + prompt_function=prompt.arc, # Updated prompt function hf_repo="ai2_arc", hf_subset="ARC-Challenge", evaluation_splits=["test"], generation_size=1, - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="commonsense_qa", - prompt_function="commonsense_qa_prompt", + prompt_function=commonsense_qa_prompt, # Updated prompt function hf_repo="commonsense_qa", hf_subset="default", - metric=["loglikelihood_acc", "loglikelihood_acc_norm_nospace"], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], # Updated metric trust_dataset=trust_remote_code, ), ] -def commonsense_qa_prompt(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["question"], - choices=[f" {c}" for c in line["choices"]["text"]], - gold_index=LETTER_INDICES.index(line["answerKey"].strip()), - instruction="", - ) - - -def siqa_prompt(line, task_name: str = None): 
- return Doc( - task_name=task_name, - query=line["context"] + " " + line["question"], - choices=[f" {c}" for c in [line["answerA"], line["answerB"], line["answerC"]]], - gold_index=int(line["label"]) - 1, - instruction="", - ) +# 0 short for common sense +COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS] +_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING) +_TASKS += COMMON_SENSE_REASONING_TASKS +## WORLD_KNOWLEDGE_TASKS ## -def hellaswag_prompt(line, task_name: str = None): - def preprocess(text): - """Comes from AiHarness""" - # text = text.strip() - # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. - text = text.replace(" [title]", ". ") - text = re.sub("\\[.*?\\]", "", text) - text = text.replace(" ", " ") - return text - ctx = f"{line['ctx_a']} {line['ctx_b'].capitalize()} " +def natural_questions_prompt(line, task_name: str = None): return Doc( task_name=task_name, - query=preprocess(line["activity_label"] + ": " + ctx), - choices=[" " + preprocess(ending) for ending in line["endings"]], - gold_index=int(line["label"]) if line["label"] != "" else -1, # -1 for test - # "metric": "choices_loglikelihood", + query=line["question"] + "?\nAnswer: ", + choices=[line["short_answers"]], + gold_index=0, + instruction="", ) -# 0 short for common sense -COMMON_SENSE_REASONING_STRING = [(t, f"custom|{t.name}|0|1") for t in COMMON_SENSE_REASONING_TASKS] -_TASKS_STRINGS.extend(COMMON_SENSE_REASONING_STRING) -_TASKS += COMMON_SENSE_REASONING_TASKS - -## WORLD_KNOWLEDGE_TASKS ## - WORLD_KNOWLEDGE_TASKS = [ LightevalTaskConfig( name="trivia_qa", - prompt_function="triviaqa", + prompt_function=prompt.triviaqa, hf_repo="trivia_qa", hf_subset="rc.nocontext", metric=[Metrics.quasi_exact_match], @@ -152,7 +146,7 @@ def preprocess(text): ), LightevalTaskConfig( name="natural_questions", - prompt_function="natural_questions_prompt", + prompt_function=natural_questions_prompt, hf_repo="lighteval/natural_questions_clean", hf_subset="default", metric=[Metrics.quasi_exact_match], @@ -163,35 +157,33 @@ def preprocess(text): ] -def natural_questions_prompt(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["question"] + "?\nAnswer: ", - choices=[line["short_answers"]], - gold_index=0, - instruction="", - ) - - WORLD_KNOWLEDGE_STRING = [(t, f"custom|{t.name}|5|1") for t in WORLD_KNOWLEDGE_TASKS] # WORLD_KNOWLEDGE_STRING = {t: f'custom|{t.name}|0|1' for t in WORLD_KNOWLEDGE_TASKS} _TASKS_STRINGS.extend(WORLD_KNOWLEDGE_STRING) _TASKS += WORLD_KNOWLEDGE_TASKS ## Reading comprehension ## +def boolq_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=f"{line['passage']}\nQuestion: {line['question'].capitalize()}?\nAnswer:", + choices=[" No", " Yes"], # Only gold + gold_index=int(line["label"]), + ) + READING_COMP_TASKS = [ LightevalTaskConfig( name="super_glue:boolq", - prompt_function="boolq_prompt", + prompt_function=boolq_prompt, hf_repo="super_glue", hf_subset="boolq", - metric=["target_perplexity"], + metric=[Metrics.target_perplexity], trust_dataset=trust_remote_code, ), LightevalTaskConfig( name="quac", - prompt_function="quac", + prompt_function=prompt.quac, hf_repo="lighteval/quac_helm", hf_subset="deault", metric=[Metrics.quasi_exact_match], @@ -202,15 +194,6 @@ def natural_questions_prompt(line, task_name: str = None): ] -def boolq_prompt(line, task_name: str = None): - return Doc( - task_name=task_name, - query=f"{line['passage']}\nQuestion: 
{line['question'].capitalize()}?\nAnswer:", - choices=[" No", " Yes"], # Only gold - gold_index=int(line["label"]), - ) - - READING_COMP_STRING = [(t, f"custom|{t.name}|0|1") for t in READING_COMP_TASKS] _TASKS_STRINGS.extend(READING_COMP_STRING) _TASKS += READING_COMP_TASKS @@ -223,7 +206,7 @@ class CustomMathEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function="math", + prompt_function=prompt.math, hf_repo="lighteval/MATH", hf_subset=None, metric=[Metrics.quasi_exact_match_math], @@ -235,7 +218,7 @@ def __init__( generation_size=40, stop_sequence=None, output_regex=None, - frozen=False, + frozen=False, trust_dataset=trust_remote_code, ): super().__init__( @@ -268,7 +251,7 @@ def __init__( ] GSM8K = LightevalTaskConfig( name="gsm8k", - prompt_function="gsm8k", + prompt_function=prompt.gsm8k, hf_repo="gsm8k", hf_subset="main", hf_avail_splits=["train", "test"], @@ -288,20 +271,55 @@ def __init__( ## MMLU ## +def mmlu_harness(line, task_name: str = None): + topic = line["subject"] + prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" + prompt += line["question"] + "\n" + prompt += "".join([f"{key}. {choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) + prompt += "Answer:" + + gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] + "__few_shots" in line and line["__few_shots"] is True # We are adding few shots + + return Doc( + task_name=task_name, + query=prompt, + choices=[" A", " B", " C", " D"], + target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix], + gold_index=gold_ix, + instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", + ) + + +def mmlu_prompt(line, task_name: str = None): + """MMLU prompt without letters""" + topic = line["subject"] + prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: " + prompt += line["question"] + "\nAnswer:" + + return Doc( + task_name=task_name, + query=prompt, + choices=[f" {c}" for c in line["choices"]], + gold_index=line["answer"], + instruction=f"The following are questions about {topic.replace('_', ' ')}.\n", + ) + + class CustomMMLUEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function="mmlu_prompt", + prompt_function=mmlu_prompt, hf_repo="lighteval/mmlu", hf_subset=None, # metric=[Metrics.loglikelihood_acc_single_token], - metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], hf_avail_splits=None, evaluation_splits=["test"], few_shots_split="dev", few_shots_select=None, - suite=None, + suite=["custom"], generation_size=-1, stop_sequence=None, output_regex=None, @@ -390,41 +408,6 @@ def __init__( ] -def mmlu_harness(line, task_name: str = None): - topic = line["subject"] - prompt = f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n" - prompt += line["question"] + "\n" - prompt += "".join([f"{key}. 
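# A small sketch of the gold-index normalisation used in mmlu_harness above: MMLU rows may
# carry the answer either as a letter ("C") or as an integer index (2), and both forms are
# reduced to an integer before the Doc is built. Values here are illustrative.
from lighteval.tasks.default_prompts import LETTER_INDICES


def _normalize_gold(answer):
    # "C" -> 2, 2 -> 2
    return LETTER_INDICES.index(answer) if isinstance(answer, str) else answer


assert _normalize_gold("C") == 2
assert _normalize_gold(2) == 2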
{choice}\n" for key, choice in zip(LETTER_INDICES, line["choices"])]) - prompt += "Answer:" - - gold_ix = LETTER_INDICES.index(line["answer"]) if isinstance(line["answer"], str) else line["answer"] - "__few_shots" in line and line["__few_shots"] is True # We are adding few shots - - return Doc( - task_name=task_name, - query=prompt, - choices=[" A", " B", " C", " D"], - target_for_fewshot_sorting=[" A", " B", " C", " D"][gold_ix], - gold_index=gold_ix, - instruction=f"The following are multiple choice questions (with answers) about {topic.replace('_', ' ')}.\n\n", - ) - - -def mmlu_prompt(line, task_name: str = None): - """MMLU prompt without letters""" - topic = line["subject"] - prompt = f"The following are questions about {topic.replace('_', ' ')}.\nQuestion: " - prompt += line["question"] + "\nAnswer:" - - return Doc( - task_name=task_name, - query=prompt, - choices=[f" {c}" for c in line["choices"]], - gold_index=line["answer"], - instruction=f"The following are questions about {topic.replace('_', ' ')}.\n", - ) - - # MMLU_STRING = {t: f'custom|{t.name}|5|1' for t in MMLU_TASKS} MMLU_STRING = [(t, f"custom|{t.name}|0|1") for t in MMLU_TASKS] _TASKS_STRINGS.extend(MMLU_STRING) @@ -433,11 +416,20 @@ def mmlu_prompt(line, task_name: str = None): ## BBH ## +def bbh_prompt(line, task_name: str = None): + return Doc( + task_name=task_name, + query=line["input"] + "\nAnswer: ", + choices=[line["target"]], + gold_index=0, + ) + + class CustomBBHEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function="bbh_prompt", + prompt_function=bbh_prompt, hf_repo="lighteval/big_bench_hard", hf_subset=None, metric=[Metrics.exact_match], @@ -445,7 +437,7 @@ def __init__( evaluation_splits=["train"], few_shots_split="train", few_shots_select=None, - suite=None, + suite=["custom"], generation_size=4, stop_sequence=None, output_regex=None, @@ -510,36 +502,80 @@ def __init__( ] -def bbh_prompt(line, task_name: str = None): +# BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS} +BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS] +_TASKS_STRINGS.extend(BBH_STRING) +_TASKS += BBH_TASKS + + +## AGI eval ## + + +def agi_eval_math_prompt(line, task_name: str = None): return Doc( task_name=task_name, - query=line["input"] + "\nAnswer: ", - choices=[line["target"]], + query=line["question"], + choices=[line["answer"]], gold_index=0, + instruction="", ) -# BBH_STRING = {t: f'custom|{t.name}|3|1' for t in BBH_TASKS} -BBH_STRING = [(t, f"custom|{t.name}|0|1") for t in BBH_TASKS] -_TASKS_STRINGS.extend(BBH_STRING) -_TASKS += BBH_TASKS +def agi_eval_prompt(line, task_name: str = None): + cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]] + prompt = "The following are multiple choice questions (with answers).\n\n" + prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n" + prompt += "Answer: " + + choices = LETTER_INDICES[: len(line["options"])] + + output = Doc( + query=prompt, + instruction="The following are multiple choice questions (with answers).\n\n", + choices=None, # updated below + gold_index=None, # updated below + ) + + if line["label"]: + output.choices = choices + output.gold_index = LETTER_INDICES.index(line["label"].strip()) + else: + output.choices = [line["answer"]] + output.gold_index = 0 + + return output + + +def agi_eval_prompt_no_letters(line, task_name: str = None): + cleaned_options = [ + " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "") + for o in 
line["options"] + ] + + output = Doc( + query=line["question"], + choices=cleaned_options, + gold_index=LETTER_INDICES.index(line["label"].strip()), + instruction="", + ) + + return output -## AGI eval ## class CustomAGIEvalEvaluationTask(LightevalTaskConfig): def __init__( self, name, - prompt_function="agi_eval_prompt_no_letters", + prompt_function=agi_eval_prompt_no_letters, hf_repo="lighteval/agi_eval_en", hf_subset=None, # metric=[Metrics.loglikelihood_acc_single_token], - metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm_nospace], + metric=[Metrics.loglikelihood_acc, Metrics.loglikelihood_acc_norm], hf_avail_splits=["train", "validation"], evaluation_splits=["train"], few_shots_split="validation", few_shots_select=None, - suite=None, + suite=["custom"], generation_size=-1, stop_sequence=None, output_regex=None, @@ -583,57 +619,6 @@ def __init__( ] -def agi_eval_math_prompt(line, task_name: str = None): - return Doc( - task_name=task_name, - query=line["question"], - choices=[line["answer"]], - gold_index=0, - instruction="", - ) - - -def agi_eval_prompt(line, task_name: str = None): - cleaned_options = [o.replace("(", "").replace(")", " ") for o in line["options"]] - prompt = "The following are multiple choice questions (with answers).\n\n" - prompt += line["question"] + "\n" + "\n".join(cleaned_options) + "\n" - prompt += "Answer: " - - choices = LETTER_INDICES[: len(line["options"])] - - output = Doc( - query=prompt, - instruction="The following are multiple choice questions (with answers).\n\n", - choices=None, # updated below - gold_index=None, # updated below - ) - - if line["label"]: - output.choices = choices - output.gold_index = LETTER_INDICES.index(line["label"].strip()) - else: - output.choices = [line["answer"]] - output.gold_index = 0 - - return output - - -def agi_eval_prompt_no_letters(line, task_name: str = None): - cleaned_options = [ - " " + o.replace("(A)", "").replace("(B)", "").replace("(C)", "").replace("(D)", "").replace("(E)", "") - for o in line["options"] - ] - - output = Doc( - query=line["question"], - choices=cleaned_options, - gold_index=LETTER_INDICES.index(line["label"].strip()), - instruction="", - ) - - return output - - # AGIEVAL_STRING = {t: f'custom|{t.name}|5|1' for t in AGIEVAL_TASKS} AGIEVAL_STRING = [(t, f"custom|{t.name}|0|1") for t in AGIEVAL_TASKS] _TASKS_STRINGS.extend(AGIEVAL_STRING) @@ -661,7 +646,7 @@ def agi_eval_prompt_no_letters(line, task_name: str = None): EARLY_SIGNAL_TASKS = ",".join([t[1] for t in COMMON_SENSE_REASONING_STRING] + [t[1] for t in MMLU_STRING]) # Convert to dict for lighteval -TASKS_TABLE = [task.as_dict() for task in _TASKS] +TASKS_TABLE = _TASKS # You can have a few pre-organised groups of tasks TASKS_GROUPS = { "all": ",".join(t[1] for t in _TASKS_STRINGS), diff --git a/src/nanotron/lighteval/one_job_runner.py b/src/nanotron/lighteval/one_job_runner.py index b56aafda..3321e7ce 100644 --- a/src/nanotron/lighteval/one_job_runner.py +++ b/src/nanotron/lighteval/one_job_runner.py @@ -1,19 +1,19 @@ """ Mostly complete a SLURM template with a link to a single checkpoint on s3 and launch it """ import datetime +import json import os import re import subprocess -from typing import List, Optional, Tuple, Union -import copy -import json +from typing import List, Optional, Tuple + import jinja2 + from nanotron import logging +from nanotron.config import Config, LightEvalConfig from nanotron.logging import log_rank from nanotron.parallel import ParallelContext -from nanotron.config import Config, 
LightEvalConfig - logger = logging.get_logger(__name__) @@ -35,12 +35,11 @@ def eval_single_checkpoint_no_s3(self, checkpoint_path: str) -> Tuple[str, str]: return None, None slurm_job_id, slurm_log = run_slurm_one_job( - config = self.config, - lighteval_config = self.lighteval_config, - slurm_template=self.lighteval_config.slurm_template, + config=self.config, + lighteval_config=self.lighteval_config, + slurm_template=self.config.general.eval_slurm_template, model_checkpoint_path=checkpoint_path, current_step=self.config.general.step, - s3=False, ) return slurm_job_id, slurm_log @@ -73,12 +72,11 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: checkpoint_path = config_files[0]["destination"].replace("config.yaml", "") slurm_job_id, slurm_log = run_slurm_one_job( - config = self.config, - lighteval_config = self.lighteval_config, - slurm_template=self.lighteval_config.slurm_template, + config=self.config, + lighteval_config=self.lighteval_config, + slurm_template=self.config.general.eval_slurm_template, model_checkpoint_path=checkpoint_path, current_step=self.config.general.step, - s3=True, ) return slurm_job_id, slurm_log @@ -90,7 +88,6 @@ def run_slurm_one_job( model_checkpoint_path: str, slurm_template: str, current_step: int, - s3: bool = True, slurm_name: Optional[str] = "eval", ): """Launch a single job on Slurm with the given mapping @@ -98,11 +95,11 @@ def run_slurm_one_job( slurm_config: Slurm configuration mapping: Mapping to use for the job script (see SLURM_ONE_JOB_MAPPING) """ - + s3 = config.general.is_s3_available eval_launch_script_path = os.path.join(config.general.evals_logs_path, "launch-config", str(current_step)) eval_logs_path = os.path.join(config.general.evals_logs_path, "logs", str(current_step)) - with open(config.general.eval_slurm_config, 'r') as f: + with open(config.general.eval_slurm_config, "r") as f: eval_slurm_config = json.load(f) os.makedirs(eval_launch_script_path, exist_ok=True) @@ -118,28 +115,33 @@ def run_slurm_one_job( # Update the config with additional required parameters # Calculate the number of nodes based on parallelism config and gpus_per_node - total_gpus_needed = lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp - gpus_per_node = eval_slurm_config.get('gpus_per_node') + total_gpus_needed = ( + lighteval_config.parallelism.dp * lighteval_config.parallelism.pp * lighteval_config.parallelism.tp + ) + gpus_per_node = eval_slurm_config.get("gpus_per_node") nodes = (total_gpus_needed + gpus_per_node - 1) // gpus_per_node # Ceiling division - + if s3: - eval_slurm_config.update({ - 'nodes': nodes, # Assuming we want to run on a single node - 'job_name': f"eval-{current_step}", - 'eval_path': eval_logs_path, - 'local_path': config.lighteval.temp_dir, - 'hf_user_or_org': config.logging.hf_user_or_org if hasattr(config.logging, 'hf_user_or_org') else None, - "model_checkpoint_path": model_checkpoint_path, - }) + eval_slurm_config.update( + { + "nodes": nodes, # Assuming we want to run on a single node + "job_name": f"eval-{current_step}", + "eval_path": eval_logs_path, + "local_path": f"{config.general.temp_dir}/eval_{config.general.timestamp_with_run}/{current_step}", + "model_checkpoint_path": model_checkpoint_path, + "lighteval_config_path": config.general.lighteval_config_path, + } + ) else: - eval_slurm_config.update({ - 'nodes': nodes, # Assuming we want to run on a single node - 'job_name': f"eval-{current_step}", - 'eval_path': eval_logs_path, - 
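# A condensed sketch of what run_slurm_one_job assembles, with a toy template standing in
# for the real run_eval.slurm.jinja (whose contents are not shown here); every path and
# value below is a placeholder. The node count is the ceiling of the GPUs required by the
# eval parallelism over the GPUs available per node.
import jinja2

dp, pp, tp = 4, 1, 2          # lighteval_config.parallelism
gpus_per_node = 8             # from the eval slurm JSON config
total_gpus_needed = dp * pp * tp
nodes = (total_gpus_needed + gpus_per_node - 1) // gpus_per_node  # ceiling division -> 1

toy_template = jinja2.Template(
    "#!/bin/bash\n"
    "#SBATCH --job-name={{ job_name }}\n"
    "#SBATCH --nodes={{ nodes }}\n"
    "#SBATCH --output={{ eval_path }}/%j.out\n"
)
launch_string = toy_template.render(
    job_name="eval-1000",
    nodes=nodes,
    eval_path="/logs/evals/logs/1000",  # placeholder for <evals_logs_path>/logs/<step>
)
print(launch_string)
# The real code writes the rendered script under <evals_logs_path>/launch-config/<step>/
# and submits it with `sbatch`, keeping PATH/LD_LIBRARY_PATH/HOME in the child environment.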
'hf_user_or_org': config.logging.hf_user_or_org if hasattr(config.logging, 'hf_user_or_org') else None, - "model_checkpoint_path": model_checkpoint_path, - }) - + eval_slurm_config.update( + { + "nodes": nodes, # Assuming we want to run on a single node + "job_name": f"eval-{current_step}", + "eval_path": eval_logs_path, + "model_checkpoint_path": model_checkpoint_path, + "lighteval_config_path": config.general.lighteval_config_path, + } + ) launch_string = SLURM_JOBS_ARRAY_TEMPLATE.render(**eval_slurm_config) @@ -164,20 +166,14 @@ def run_slurm_one_job( # Preserve important environment variables env = { - 'PATH': os.environ['PATH'], - 'LD_LIBRARY_PATH': os.environ.get('LD_LIBRARY_PATH', ''), - 'HOME': os.path.expanduser("~"), + "PATH": os.environ["PATH"], + "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""), + "HOME": os.path.expanduser("~"), } try: # Use subprocess.run instead of check_output for better error handling - result = subprocess.run( - ["sbatch", launch_script_path], - env=env, - check=True, - capture_output=True, - text=True - ) + result = subprocess.run(["sbatch", launch_script_path], env=env, check=True, capture_output=True, text=True) output = result.stdout job_ids = output.split()[-1] output_log = ( diff --git a/src/nanotron/lighteval/run_evals.py b/src/nanotron/lighteval/run_evals.py index 5ee36f53..1fd4b178 100644 --- a/src/nanotron/lighteval/run_evals.py +++ b/src/nanotron/lighteval/run_evals.py @@ -15,9 +15,10 @@ def get_parser(): help="Path to the Nanotron checkpoint YAML or python config file, potentially on S3", ) parser.add_argument( - "--lighteval-override", + "--lighteval-config-path", type=str, - help="Path to an optional YAML or python Lighteval config to override part of the checkpoint Lighteval config", + required=True, + help="Path to an optional YAML or python Lighteval config", ) parser.add_argument( "--cache-dir", @@ -25,7 +26,6 @@ def get_parser(): default=None, help="Cache directory", ) - return parser @@ -33,7 +33,7 @@ def get_parser(): parser = get_parser() args, unknowns = parser.parse_known_args() main( - checkpoint_config_path=args.checkpoint_config_path, - lighteval_config_path=args.lighteval_override, + checkpoint_config_path=args.checkpoint_config_path, + lighteval_config_path=args.lighteval_config_path, cache_dir=args.cache_dir, - ) \ No newline at end of file + ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 3a4ab60a..76b8fa4a 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -19,14 +19,12 @@ cast, ) -from nanotron.s3_checkpoints import S3Mover, check_path_is_local import torch from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DataLoader from nanotron import distributed as dist from nanotron import logging -from nanotron.lighteval import LightEvalRunner from nanotron.config import ( Config, DatasetStageArgs, @@ -48,6 +46,7 @@ log_throughput, lr_scheduler_builder, ) +from nanotron.lighteval import LightEvalRunner from nanotron.logging import ( LoggerWriter, LogItem, @@ -151,14 +150,12 @@ def __init__( data_parallel_size=self.config.parallelism.dp, expert_parallel_size=self.config.parallelism.expert_parallel_size, ) - + self.pre_init() # Set log levels set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging) - - # Log benchmark info if os.environ.get("NANOTRON_BENCHMARK", "0") == "1": log_throughput(self.config, self.parallel_context) @@ -288,12 +285,13 @@ def post_init(self): self.post_checkpoint_callback = None else: # 
Use the no_s3 version of the evaluation function - # TODO: make it one function + make it automatic to switch to the right jinja template + # TODO: make it one function + make it automatic to switch to the right jinja template self.post_checkpoint_callback = self.lighteval_runner.eval_single_checkpoint_no_s3 else: self.post_checkpoint_callback = None else: self.post_checkpoint_callback = None + def pre_training(self, *args, **kwargs): self._print_training_plan() @@ -306,7 +304,7 @@ def pre_training(self, *args, **kwargs): rank=0, ) - current_time = datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") + datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") if dist.get_rank(self.parallel_context.world_pg) == 0 and wandb is not None: wandb.init( project=self.config.general.project, @@ -316,20 +314,19 @@ def pre_training(self, *args, **kwargs): # Define tokens metric as x-axis for all metrics wandb.define_metric("Tokens") wandb.define_metric("*", step_metric="Tokens") - + # Handle resuming from a previous run - initial_step = getattr(self.config.general, 'step', 0) + initial_step = getattr(self.config.general, "step", 0) if initial_step is None: initial_step = 0 - + initial_tokens = initial_step * self.global_batch_size - + # Log initial tokens to set the starting point wandb.log({"Tokens": initial_tokens}) - + print(f"Initial Tokens: {initial_tokens}") - def post_train_step(self): # Update our background upload/removal of checkpoints if self.s3_mover is not None: @@ -338,12 +335,11 @@ def post_train_step(self): def post_training(self): if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) - + def post_training(self): if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) - def _print_training_plan(self): if hasattr(self.config, "data_stages") and self.config.data_stages is not None: stages_info = "".join( @@ -748,17 +744,21 @@ def _init_model_instance(self) -> NanotronModel: def _load_model_checkpoint(self, model: NanotronModel) -> NanotronModel: unwrapped_model = model.module if isinstance(model, DistributedDataParallel) else model - # Load or initialize model weights + # Load or initialize model weights reloaded_from_checkpoint = False if self.init_checkpoint_path is not None: - # Load from a pre existing checkpoint + # Load from a pre existing checkpoint if check_path_is_local(self.init_checkpoint_path): - # Reload from a training checkpoint - log_rank(f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0) + # Reload from a training checkpoint + log_rank( + f"Loading weights from {self.init_checkpoint_path}", logger=logger, level=logging.INFO, rank=0 + ) self.param_shard_metadata = load_weights( - model=unwrapped_model, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path + model=unwrapped_model, + parallel_context=self.parallel_context, + root_folder=self.init_checkpoint_path, ) - reloaded_from_checkpoint=True + reloaded_from_checkpoint = True if not reloaded_from_checkpoint: log_rank("No checkpoint path provided.", logger=logger, level=logging.INFO, rank=0) @@ -894,18 +894,25 @@ def setup_log_writers( def pre_save_checkpoint(self) -> Path: if wandb is not None and dist.get_rank(self.parallel_context.dp_pg) == 0: - if self.config.general.wandb_id is None: + if self.config.general.wandb_id is None: self.config.general.wandb_id = wandb.run.id self.config.general.wandb_project = wandb.run.project - elif 
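# Standalone sketch of the wandb pattern used in pre_training above: "Tokens" is declared
# as the x-axis for every other metric, and on resume the counter is seeded once so new
# points continue from the right position. Project, run name and the token arithmetic are
# placeholders.
import wandb

run = wandb.init(project="nanotron-demo", name="toy-run")
wandb.define_metric("Tokens")
wandb.define_metric("*", step_metric="Tokens")  # plot every metric against Tokens

initial_tokens = 1_000_000  # e.g. tokens already consumed by the resumed checkpoint
wandb.log({"Tokens": initial_tokens})

for step in range(1, 4):
    tokens = initial_tokens + step * 256 * 2048  # global batch size * sequence length, illustrative
    wandb.log({"lm_loss": 2.0 / step, "Tokens": tokens})

run.finish()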
self.config.general.wandb_id is not None and self.config.general.wandb_id!= wandb.run.id: - log_rank("Update the wandb run due too resume from checkpoint", logger=logger, level=logging.WARNING, rank=0) + elif self.config.general.wandb_id is not None and self.config.general.wandb_id != wandb.run.id: + log_rank( + "Update the wandb run due too resume from checkpoint", logger=logger, level=logging.WARNING, rank=0 + ) self.config.general.wandb_id = wandb.run.id self.config.general.wandb_project = wandb.run.project if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: slurm_job_id, slurm_log = self.s3_mover.post_upload_callback_outputs - log_rank(f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", logger=logger, level=logging.INFO, rank=0) + log_rank( + f"launching eval job: job_id={slurm_job_id} log at {slurm_log} slurm_eval", + logger=logger, + level=logging.INFO, + rank=0, + ) def post_save_checkpoint(self): # Upload to S3 @@ -914,14 +921,15 @@ def post_save_checkpoint(self): elif self.post_checkpoint_callback is not None: # If we're not using S3, but we have a post-checkpoint callback for evals - checkpoint_path = self.config.checkpoints.checkpoints_path / f"{self.config.general.step}" + checkpoint_path = Path(self.config.checkpoints.checkpoints_path) / f"{self.config.general.step}" self.post_checkpoint_callback(checkpoint_path) def save_checkpoint(self) -> Path: self.pre_save_checkpoint() - checkpoints_path = self.config.checkpoints.checkpoints_path - checkpoint_path = checkpoints_path / f"{self.iteration_step}" + print(f"config: {self.config}") + print(f"checkpoints_path: {checkpoints_path}") + checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}" if self.config.checkpoints.checkpoints_path_is_shared_file_system: should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0 else: From 43c833fdd3a1ebedb384b4f902526632888774c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Fri, 20 Sep 2024 06:11:30 +0000 Subject: [PATCH 41/43] remove print --- src/nanotron/trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 76b8fa4a..2e4be82b 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -336,10 +336,6 @@ def post_training(self): if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) - def post_training(self): - if self.s3_mover is not None: - self.s3_mover.distributed_wait_for_completion(group=self.parallel_context.world_pg) - def _print_training_plan(self): if hasattr(self.config, "data_stages") and self.config.data_stages is not None: stages_info = "".join( @@ -927,8 +923,6 @@ def post_save_checkpoint(self): def save_checkpoint(self) -> Path: self.pre_save_checkpoint() checkpoints_path = self.config.checkpoints.checkpoints_path - print(f"config: {self.config}") - print(f"checkpoints_path: {checkpoints_path}") checkpoint_path = Path(checkpoints_path) / f"{self.iteration_step}" if self.config.checkpoints.checkpoints_path_is_shared_file_system: should_mkdir = dist.get_rank(self.parallel_context.world_pg) == 0 From 3d7c98f329e42b84a0af5dc110bbcba7aee9be85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Tue, 24 Sep 2024 20:41:09 +0000 Subject: [PATCH 42/43] change after review --- create_config.py | 2 +- launcher.py | 10 +++++----- 
src/nanotron/config/config.py | 29 +++-------------------------- src/nanotron/serialize/main.py | 4 ---- src/nanotron/trainer.py | 14 -------------- src/nanotron/utils.py | 13 ++++++------- 6 files changed, 15 insertions(+), 57 deletions(-) diff --git a/create_config.py b/create_config.py index f09df36e..cd9a122b 100644 --- a/create_config.py +++ b/create_config.py @@ -94,7 +94,7 @@ # ), # ) - # lighteval = None + lighteval = None checkpoints = CheckpointsArgs( # checkpoints_path="checkpoints", diff --git a/launcher.py b/launcher.py index c8dd4a75..f00e20e9 100644 --- a/launcher.py +++ b/launcher.py @@ -13,7 +13,6 @@ get_config_from_file, save_as_yaml, ) -from nanotron.config.lighteval_config import LightEvalConfig from nanotron.logging import human_format @@ -91,8 +90,8 @@ def set_nested_attribute(obj, path, value): key, value = item.split("=", 1) try: value = eval(value) - except: - pass + except Exception as e: + print(f"Warning: Could not evaluate '{value}': {e}") set_nested_attribute(config, key, value) @@ -304,7 +303,7 @@ def set_nested_attribute(obj, path, value): lighteval_config = config.lighteval Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) config.general.lighteval_config_path = str(Path(config.general.config_logs_path) / "lighteval_config.yaml") - save_as_yaml(lighteval_config, LightEvalConfig, config.general.lighteval_config_path) + save_as_yaml(lighteval_config, config.general.lighteval_config_path) config_path_yaml = str(Path(config.general.config_logs_path) / "launch_config.yaml") Path(config.general.config_logs_path).mkdir(parents=True, exist_ok=True) @@ -354,7 +353,8 @@ def set_nested_attribute(obj, path, value): try: gpu_count = torch.cuda.device_count() is_interactive = gpu_count > 0 - except: + except Exception as e: + print(f"Warning: Could not get GPU count: {e}") is_interactive = False if is_interactive: diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 27105ee8..31f3ee4d 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -97,7 +97,7 @@ def __post_init__(self): class S3UploadArgs: """Arguments related to uploading checkpoints on s3""" - remove_after_upload: bool + remove_after_upload: Optional[bool] = True upload_s3_path: Optional[ str ] = None # set to None if we want to use S3UploadArgs to download checkpoints from s3 but not upload checkpoints on s3 @@ -112,23 +112,6 @@ def __post_init__(self): self.s5cmd_path = Path(self.s5cmd_path) -@dataclass -class S3UploadArgs: - """Arguments related to uploading checkpoints on s3""" - - upload_s3_path: xPath - remove_after_upload: bool - s5cmd_numworkers: Optional[int] - s5cmd_concurrency: Optional[int] - s5cmd_path: Optional[xPath] - - def __post_init__(self): - if isinstance(self.upload_s3_path, str): - self.upload_s3_path = xPath(self.upload_s3_path) - if isinstance(self.s5cmd_path, str): - self.s5cmd_path = xPath(self.s5cmd_path) - - @dataclass class NanosetDatasetsArgs: dataset_folder: Union[str, List[str]] @@ -222,8 +205,6 @@ class GeneralArgs: consumed_train_samples: Optional[int] = None benchmark_csv_path: Optional[Path] = None ignore_sanity_checks: bool = True - wandb_id: Optional[str] = None - wandb_project: Optional[str] = None def __post_init__(self): if self.seed is None: @@ -395,10 +376,6 @@ def create_empty(cls): return cls(**{f.name: None for f in cls_fields}) def __post_init__(self): - - if hasattr(self, "_post_init_done"): - return - self._post_init_done = True self.general.__post_init__() if self.s3_upload is not 
None: @@ -538,8 +515,8 @@ def get_config_from_file( return config -def save_as_yaml(config, config_class, file_path: str): - +def save_as_yaml(config: Union[Config, LightEvalConfig], file_path: str): + config_class = type(config) config_dict = serialize(config) file_path = str(file_path) with open(file_path, "w") as f: diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index b395d145..c6c5fb07 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -1,12 +1,8 @@ import os from pathlib import Path from typing import Optional, cast -from datasets.download.streaming_download_manager import xPath -import os -from nanotron.s3_checkpoints import S3Mover, check_path_is_local, fs_open import torch -from datasets.download.streaming_download_manager import xPath from torch import nn from torch.nn.parallel import DistributedDataParallel from torch.optim.lr_scheduler import LambdaLR diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 2e4be82b..04d89cf4 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -304,7 +304,6 @@ def pre_training(self, *args, **kwargs): rank=0, ) - datetime.datetime.now().strftime("%d/%m/%Y_%H:%M:%S") if dist.get_rank(self.parallel_context.world_pg) == 0 and wandb is not None: wandb.init( project=self.config.general.project, @@ -325,8 +324,6 @@ def pre_training(self, *args, **kwargs): # Log initial tokens to set the starting point wandb.log({"Tokens": initial_tokens}) - print(f"Initial Tokens: {initial_tokens}") - def post_train_step(self): # Update our background upload/removal of checkpoints if self.s3_mover is not None: @@ -888,17 +885,6 @@ def setup_log_writers( return loggerwriter def pre_save_checkpoint(self) -> Path: - - if wandb is not None and dist.get_rank(self.parallel_context.dp_pg) == 0: - if self.config.general.wandb_id is None: - self.config.general.wandb_id = wandb.run.id - self.config.general.wandb_project = wandb.run.project - elif self.config.general.wandb_id is not None and self.config.general.wandb_id != wandb.run.id: - log_rank( - "Update the wandb run due too resume from checkpoint", logger=logger, level=logging.WARNING, rank=0 - ) - self.config.general.wandb_id = wandb.run.id - self.config.general.wandb_project = wandb.run.project if self.s3_mover is not None: self.s3_mover.distributed_wait_for_completion(self.parallel_context.world_pg) if self.s3_mover.post_upload_callback_outputs is not None: diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 3b98b7bc..07cd4898 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -2,12 +2,10 @@ import inspect import os import random -import socket import re +import socket from contextlib import ExitStack, contextmanager from typing import ContextManager, List, Optional -import json -import os import torch from packaging import version @@ -164,7 +162,8 @@ def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int: except OSError: continue -def check_path_is_s3(path:str) -> bool: - #TODO maybe replace by a better method ? - s3_pattern = r'^s3://|^https?://[\w\-\.]+\.s3\.amazonaws\.com/|^https?://s3\.amazonaws\.com/[\w\-\.]+' - return bool(re.match(s3_pattern, path)) \ No newline at end of file + +def check_path_is_s3(path: str) -> bool: + # TODO maybe replace by a better method ? 
+ s3_pattern = r"^s3://|^https?://[\w\-\.]+\.s3\.amazonaws\.com/|^https?://s3\.amazonaws\.com/[\w\-\.]+" + return bool(re.match(s3_pattern, path)) From e74ffd113675d562bc8d2f883fc1a6824bc37dc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9Celiebak=E2=80=9D?= Date: Thu, 26 Sep 2024 14:18:14 +0000 Subject: [PATCH 43/43] uncomment logging item --- src/nanotron/trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 04d89cf4..62986fba 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -621,12 +621,12 @@ def train_step_logs( lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), - # LogItem( - # "consumed_tokens", - # self.metadata.consumed_train_samples * self.config.tokens.sequence_length, - # "human_format", - # ), # , "12d"), + # LogItem("consumed_samples", self.metadata.consumed_train_samples, "human_format"), # , "12d"), + LogItem( + "consumed_tokens", + self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + "human_format", + ), # , "12d"), LogItem("elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format"), # , ".1f"), LogItem("tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), LogItem(