Skip to content

Commit

Permalink
Merge pull request huggingface#51 from huggingface/nouamane/custom-models
Browse files Browse the repository at this point in the history

Fix support for custom modeling
  • Loading branch information
NouamaneTazi authored Jan 31, 2024
2 parents b2e1b25 + 6d2b81c commit 682efcc
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/nanotron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
self,
config_or_config_file: Union[Config, str],
config_class: Type[Config] = Config,
model_config_class: Optional[Type] = None,
model_class: Type[NanotronModel] = None,
):
"""
Expand All @@ -98,12 +99,13 @@ def __init__(
Args:
config_or_config_file: Either a `Config` object or a path to a YAML file containing the config.
config_class: The `Config` class to use.
model_config_class: The `ModelConfig` class to use (for example `LlamaConfig`). Defaults to `None` which will use the model config class defined in the config.
model_class: The `NanotronModel` class to use (for example `LlamaForTraining`). Defaults to `None` which will use the model class defined in the config.
"""

super().__init__()
self.config = get_config_from_file(config_or_config_file, config_class=config_class)
self.model_config = self.config.model.model_config
self.model_config = model_config_class(**self.config.model.model_config)
if model_class is not None:
CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class

Expand Down

0 comments on commit 682efcc

Please sign in to comment.