diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index b95e1ce7..0ea24c97 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -90,6 +90,7 @@ def __init__( self, config_or_config_file: Union[Config, str], config_class: Type[Config] = Config, + model_config_class: Optional[Type] = None, model_class: Type[NanotronModel] = None, ): """ @@ -98,12 +99,13 @@ def __init__( Args: config_or_config_file: Either a `Config` object or a path to a YAML file containing the config. config_class: The `Config` class to use. + model_config_class: The `ModelConfig` class to use (for example `LlamaConfig`). Defaults to `None` which will use the model config class defined in the config. model_class: The `NanotronModel` class to use (for example `LlamaForTraining`). Defaults to `None` which will use the model class defined in the config. """ super().__init__() self.config = get_config_from_file(config_or_config_file, config_class=config_class) - self.model_config = self.config.model.model_config + self.model_config = self.config.model.model_config if model_config_class is None else model_config_class(**self.config.model.model_config) if model_class is not None: CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class