diff --git a/connectomics/jax/checkpoint.py b/connectomics/jax/checkpoint.py index c0adc79..5ef121c 100644 --- a/connectomics/jax/checkpoint.py +++ b/connectomics/jax/checkpoint.py @@ -22,7 +22,6 @@ from etils import epath import flax import grain.python as grain -import grain.tensorflow as tfgrain from orbax import checkpoint as ocp import tensorflow as tf @@ -118,14 +117,3 @@ def restore_checkpoint( return manager.restore( manager.latest_step() if step is None else step, args=ocp.args.Composite(**restore_args_dict)) - - -TfGrainCheckpointHandler = tfgrain.OrbaxCheckpointHandler - - -@ocp.args.register_with_handler( # pytype:disable=wrong-arg-types - TfGrainCheckpointHandler, for_save=True, for_restore=True -) -@dataclasses.dataclass -class TfGrainCheckpointArgs(ocp.args.CheckpointArgs): - item: Any