Skip to content

Commit

Permalink
Convert model config to frozen dict.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696209041
  • Loading branch information
jan-matthis authored and copybara-github committed Nov 13, 2024
1 parent affd226 commit 9137510
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions connectomics/jax/models/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ def model_from_config(
cfg_field = get_config_name(cfg_cls.__name__)

logging.info('Using config settings from "%r"', cfg_field)
cfg = cfg_cls(**getattr(config, cfg_field))
return model_cls(config=cfg, name=getattr(config, 'model_name', None))
model_cfg = getattr(config, cfg_field)
# By converting the config to a FrozenConfigDict, we ensure that it is
# hashable. This is e.g. required for static arguments passed to jax.jit.
model_cfg = ml_collections.config_dict.FrozenConfigDict(model_cfg)
return model_cls(
config=cfg_cls(**model_cfg), name=getattr(config, 'model_name', None)
)


def model_from_name(
Expand Down

0 comments on commit 9137510

Please sign in to comment.