diff --git a/connectomics/jax/models/util.py b/connectomics/jax/models/util.py index 56505e7..00d90cb 100644 --- a/connectomics/jax/models/util.py +++ b/connectomics/jax/models/util.py @@ -71,8 +71,13 @@ def model_from_config( cfg_field = get_config_name(cfg_cls.__name__) logging.info('Using config settings from "%r"', cfg_field) - cfg = cfg_cls(**getattr(config, cfg_field)) - return model_cls(config=cfg, name=getattr(config, 'model_name', None)) + model_cfg = getattr(config, cfg_field) + # By converting the config to a FrozenConfigDict, we ensure that it is + # hashable. This is e.g. required for static arguments passed to jax.jit. + model_cfg = ml_collections.config_dict.FrozenConfigDict(model_cfg) + return model_cls( + config=cfg_cls(**model_cfg), name=getattr(config, 'model_name', None) + ) def model_from_name(