diff --git a/connectomics/jax/models/util.py b/connectomics/jax/models/util.py index 00d90cb..a142ffc 100644 --- a/connectomics/jax/models/util.py +++ b/connectomics/jax/models/util.py @@ -16,10 +16,12 @@ import collections.abc import inspect +import json import re from typing import Any, Type from absl import logging +from connectomics.common import file from connectomics.common import import_util # pylint:disable=unused-import from connectomics.jax.models import convstack @@ -181,3 +183,15 @@ def model_from_dict_config( default_packages, **config[cfg_field], ) + + +def save_config(config: ml_collections.ConfigDict, path: file.PathLike): + """Saves model config to a file.""" + with file.Open(path, 'w') as f: + f.write(config.to_json_best_effort() + '\n') + + +def load_config(path: file.PathLike) -> ml_collections.ConfigDict: + """Loads a model config from a file.""" + with file.Open(path, 'r') as f: + return ml_collections.ConfigDict(json.loads(f.read()))