Commit e5e2045

Merge branch 'main' of github.com:huggingface/nanotron into xrsrke/setup_cicd

xrsrke committed Jan 31, 2024
2 parents 128eea5 + 682efcc commit e5e2045
Showing 5 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/config_tiny_llama.py
@@ -88,7 +88,7 @@
 os.makedirs(checkpoints_path, exist_ok=True)
 
 config = Config(
-    general=GeneralArgs(project="debug", run="tiny_llama", seed=seed),
+    general=GeneralArgs(project="debug", run="tiny_llama_%date_%jobid", seed=seed),
     checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
     parallelism=parallelism,
     model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
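
The example config's run name now carries `%date` and `%jobid` placeholders, which `GeneralArgs.__post_init__` is meant to expand when the config is loaded (see the `src/nanotron/config/config.py` hunk below). A minimal stand-alone sketch of the intended expansion, reusing the date format and `SLURM_JOB_ID` fallback from that hunk:

    import datetime
    import os

    # Hypothetical illustration of the placeholder expansion done at config-load time.
    run = "tiny_llama_%date_%jobid"
    run = run.replace("%date", datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
    run = run.replace("%jobid", os.environ.get("SLURM_JOB_ID", "local"))
    print(run)  # e.g. tiny_llama_20240131_143000_local when not running under SLURM
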
2 changes: 1 addition & 1 deletion run_train.py
@@ -108,7 +108,7 @@ def get_dataloader(trainer: DistributedTrainer):
     # Check if we have enough samples for train_steps
     assert (
         trainer.config.tokens.train_steps - trainer.start_iteration_step
-    ) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size() < len(dataloader), (
+    ) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size() <= len(dataloader), (
         f"Dataset is too small for steps ({len(dataloader)} < {(trainer.config.tokens.train_steps - trainer.start_iteration_step) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size()}), "
         f"Try train_steps<={len(dataloader) * trainer.parallel_context.dp_pg.size() // trainer.global_batch_size + trainer.start_iteration_step}"
     )
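
The bound in the assertion changes from a strict `<` to `<=`, so a dataset that provides exactly the required number of batches per data-parallel rank is no longer rejected. A small worked example with hypothetical numbers:

    # Hypothetical values to illustrate the boundary case.
    train_steps, start_iteration_step = 100, 0
    global_batch_size, dp_size = 512, 4
    needed_batches = (train_steps - start_iteration_step) * global_batch_size // dp_size  # 12800
    dataloader_len = 12800  # the dataset yields exactly enough batches per rank

    assert needed_batches <= dataloader_len  # passes; the previous strict `<` failed on this exact fit
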
4 changes: 3 additions & 1 deletion src/nanotron/config/config.py
@@ -156,7 +156,9 @@ def __post_init__(self):
             ), f"Please set NANOTRON_BENCHMARK to 1 when using benchmark_csv_path. Got {os.environ.get('NANOTRON_BENCHMARK', None)}"
 
         if self.run is None:
-            self.run = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+            self.run = "%date_%jobid"
+        self.run.replace("%date", datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
+        self.run.replace("%jobid", os.environ.get("SLURM_JOB_ID", "local"))
 
 
 @dataclass
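
Note that `str.replace` returns a new string rather than modifying `self.run` in place, so as written the two `replace` calls above discard their results. A minimal sketch of the presumably intended substitution, with the results reassigned:

    if self.run is None:
        self.run = "%date_%jobid"
    self.run = self.run.replace("%date", datetime.datetime.now().strftime("%Y%m%d_%H%M%S"))
    self.run = self.run.replace("%jobid", os.environ.get("SLURM_JOB_ID", "local"))
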
4 changes: 2 additions & 2 deletions src/nanotron/config/models_config.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
 
 @dataclass
@@ -116,4 +116,4 @@ def n_inner(self):
         return self.intermediate_size
 
 
-NanotronConfigs = Union[LlamaConfig, Starcoder2Config]
+NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any]
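
Since `Any` is now one of the union members, any model config object is assignable to `NanotronConfigs`, which pairs with the new `model_config_class` argument in the trainer change below. A minimal sketch, assuming a hypothetical user-defined config dataclass:

    from dataclasses import dataclass

    from nanotron.config.models_config import NanotronConfigs

    @dataclass
    class MyTinyConfig:  # hypothetical custom model config, not part of nanotron
        hidden_size: int = 128
        num_hidden_layers: int = 2

    cfg: NanotronConfigs = MyTinyConfig()  # type-checks now that the union includes Any
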
4 changes: 3 additions & 1 deletion src/nanotron/trainer.py
@@ -90,6 +90,7 @@ def __init__(
         self,
         config_or_config_file: Union[Config, str],
         config_class: Type[Config] = Config,
+        model_config_class: Optional[Type] = None,
         model_class: Type[NanotronModel] = None,
     ):
         """
@@ -98,12 +99,13 @@ def __init__(
         Args:
             config_or_config_file: Either a `Config` object or a path to a YAML file containing the config.
             config_class: The `Config` class to use.
+            model_config_class: The `ModelConfig` class to use (for example `LlamaConfig`). Defaults to `None` which will use the model config class defined in the config.
             model_class: The `NanotronModel` class to use (for example `LlamaForTraining`). Defaults to `None` which will use the model class defined in the config.
         """
 
         super().__init__()
         self.config = get_config_from_file(config_or_config_file, config_class=config_class)
-        self.model_config = self.config.model.model_config
+        self.model_config = model_config_class(**self.config.model.model_config)
         if model_class is not None:
             CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class
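
A minimal usage sketch of the new `model_config_class` hook, assuming the hypothetical `MyTinyConfig` from above plus a matching `MyTinyModelForTraining` subclass of `NanotronModel` (also hypothetical). Note that, at least as shown in this hunk, the default `model_config_class=None` would not be callable, so callers presumably pass a concrete config class:

    from nanotron.trainer import DistributedTrainer

    # Hypothetical classes and config path; the model.model_config section of the
    # YAML must provide keyword arguments matching the fields of MyTinyConfig.
    trainer = DistributedTrainer(
        "examples/my_custom_config.yaml",
        model_config_class=MyTinyConfig,
        model_class=MyTinyModelForTraining,
    )
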
