diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d003c02c85e..e4b9efffc7c 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4029,7 +4029,9 @@ def from_pretrained(
                         sub_config = getattr(config, sub_config_key)
                         sub_config.torch_dtype = torch_dtype
             elif isinstance(torch_dtype, torch.dtype):
-                pass
+                for sub_config_key in config.sub_configs.keys():
+                    sub_config = getattr(config, sub_config_key)
+                    sub_config.torch_dtype = torch_dtype
             elif isinstance(torch_dtype, dict):
                 for key, curr_dtype in torch_dtype.items():
                     if hasattr(config, key):
diff --git a/src/transformers/models/dbrx/configuration_dbrx.py b/src/transformers/models/dbrx/configuration_dbrx.py
index 7935b1d1beb..72df1fe335b 100644
--- a/src/transformers/models/dbrx/configuration_dbrx.py
+++ b/src/transformers/models/dbrx/configuration_dbrx.py
@@ -57,7 +57,7 @@ def __init__(
         self.kv_n_heads = kv_n_heads
         self.rope_theta = rope_theta
 
-        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
@@ -109,7 +109,7 @@ def __init__(
         self.moe_loss_weight = moe_loss_weight
         self.moe_normalize_expert_weights = moe_normalize_expert_weights
 
-        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash"]:
+        for k in ["model_type", "attn_implementation", "transformers_version", "_commit_hash", "torch_dtype"]:
             if k in kwargs:
                 kwargs.pop(k)
         if len(kwargs) != 0:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 965d7593693..aaa9d720a13 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -331,6 +331,12 @@ def check_save_load(out1, out2):
                 with torch.no_grad():
                     second = model(**self._prepare_for_class(inputs_dict, model_class))[0]
 
+                # Save and load a second time because `from_pretrained` adds a bunch of new config fields,
+                # so we need to make sure those fields can be loaded back after saving.
+                # Simply initializing the model as `model(config)` doesn't add those fields.
+                model.save_pretrained(tmpdirname)
+                model = model_class.from_pretrained(tmpdirname)
+
                 if isinstance(first, tuple) and isinstance(second, tuple):
                     for tensor1, tensor2 in zip(first, second):
                         check_save_load(tensor1, tensor2)
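
For context, here is a minimal, self-contained sketch (not part of the patch) of the propagation logic the first hunk adds. `ToyConfig` and `ToySubConfig` are hypothetical stand-ins for `PretrainedConfig` and its sub-configs; the only assumption borrowed from `modeling_utils.py` is that `sub_configs` maps attribute names to sub-config objects. Before this change, the `elif isinstance(torch_dtype, torch.dtype)` branch was a no-op (`pass`), so passing e.g. `torch_dtype=torch.float16` never reached the sub-configs.

```python
# Toy sketch of the dtype-propagation fix; stand-in classes, not the real
# transformers ones. Assumes `sub_configs` maps attribute names to sub-config
# objects, mirroring how `from_pretrained` iterates over them.
import torch


class ToySubConfig:
    """Stand-in for a composite model's sub-config (e.g. a text or vision config)."""

    def __init__(self):
        self.torch_dtype = None


class ToyConfig:
    """Stand-in for a `PretrainedConfig` with two sub-configs."""

    sub_configs = {"text_config": ToySubConfig, "vision_config": ToySubConfig}

    def __init__(self):
        self.torch_dtype = None
        self.text_config = ToySubConfig()
        self.vision_config = ToySubConfig()


def propagate_torch_dtype(config, torch_dtype):
    # Mirrors the fixed branch: a plain `torch.dtype` is now copied into every
    # sub-config instead of being silently dropped by `pass`.
    if isinstance(torch_dtype, torch.dtype):
        config.torch_dtype = torch_dtype
        for sub_config_key in config.sub_configs.keys():
            sub_config = getattr(config, sub_config_key)
            sub_config.torch_dtype = torch_dtype


config = ToyConfig()
propagate_torch_dtype(config, torch.float16)
assert config.torch_dtype == torch.float16
assert config.text_config.torch_dtype == torch.float16
assert config.vision_config.torch_dtype == torch.float16
```

Because the propagated `torch_dtype` now gets serialized into nested configs, it also arrives as a kwarg when those configs are re-instantiated from disk. That appears to be what the DBRX hunks address: its nested configs raise on any leftover kwargs, so `torch_dtype` is added to the list of keys popped before that check, and the test hunk exercises the round-trip by saving and reloading a model produced by `from_pretrained`.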