Skip to content

Commit

Permalink
Remove TfGrain.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707422138
  • Loading branch information
jan-matthis authored and copybara-github committed Dec 18, 2024
1 parent f087c2c commit c38934e
Showing 1 changed file with 0 additions and 12 deletions.
12 changes: 0 additions & 12 deletions connectomics/jax/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from etils import epath
import flax
import grain.python as grain
import grain.tensorflow as tfgrain
from orbax import checkpoint as ocp
import tensorflow as tf

Expand Down Expand Up @@ -118,14 +117,3 @@ def restore_checkpoint(
return manager.restore(
manager.latest_step() if step is None else step,
args=ocp.args.Composite(**restore_args_dict))


TfGrainCheckpointHandler = tfgrain.OrbaxCheckpointHandler


@ocp.args.register_with_handler( # pytype:disable=wrong-arg-types
TfGrainCheckpointHandler, for_save=True, for_restore=True
)
@dataclasses.dataclass
class TfGrainCheckpointArgs(ocp.args.CheckpointArgs):
item: Any

0 comments on commit c38934e

Please sign in to comment.