From 913751023c94a74685fe3a225aefc4126534920e Mon Sep 17 00:00:00 2001 From: Jan-Matthis Lueckmann Date: Wed, 13 Nov 2024 11:27:24 -0800 Subject: [PATCH] Convert model config to frozen dict. PiperOrigin-RevId: 696209041 --- connectomics/jax/models/util.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/connectomics/jax/models/util.py b/connectomics/jax/models/util.py index 56505e7..00d90cb 100644 --- a/connectomics/jax/models/util.py +++ b/connectomics/jax/models/util.py @@ -71,8 +71,13 @@ def model_from_config( cfg_field = get_config_name(cfg_cls.__name__) logging.info('Using config settings from "%r"', cfg_field) - cfg = cfg_cls(**getattr(config, cfg_field)) - return model_cls(config=cfg, name=getattr(config, 'model_name', None)) + model_cfg = getattr(config, cfg_field) + # By converting the config to a FrozenConfigDict, we ensure that it is + # hashable. This is e.g. required for static arguments passed to jax.jit. + model_cfg = ml_collections.config_dict.FrozenConfigDict(model_cfg) + return model_cls( + config=cfg_cls(**model_cfg), name=getattr(config, 'model_name', None) + ) def model_from_name(