diff --git a/t5x/checkpoint_utils.py b/t5x/checkpoint_utils.py index d817b61af..f6f611a8e 100644 --- a/t5x/checkpoint_utils.py +++ b/t5x/checkpoint_utils.py @@ -280,7 +280,9 @@ def get_restore_parameters( restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure) flat_param_infos = {} is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory) - ts_context = ocp.type_handlers.get_ts_context() + ts_context = ocp.serialization.ts_utils.get_ts_context( + use_ocdbt=is_ocdbt_checkpoint + ) def _get_param_info( name: str,