From ceb9f876b7805bbce65b582564802104814abe12 Mon Sep 17 00:00:00 2001 From: Jan-Matthis Lueckmann Date: Wed, 18 Dec 2024 02:11:18 -0800 Subject: [PATCH] Remove TfGrain. PiperOrigin-RevId: 707459348 --- connectomics/jax/checkpoint.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/connectomics/jax/checkpoint.py b/connectomics/jax/checkpoint.py index c0adc79..5ef121c 100644 --- a/connectomics/jax/checkpoint.py +++ b/connectomics/jax/checkpoint.py @@ -22,7 +22,6 @@ from etils import epath import flax import grain.python as grain -import grain.tensorflow as tfgrain from orbax import checkpoint as ocp import tensorflow as tf @@ -118,14 +117,3 @@ def restore_checkpoint( return manager.restore( manager.latest_step() if step is None else step, args=ocp.args.Composite(**restore_args_dict)) - - -TfGrainCheckpointHandler = tfgrain.OrbaxCheckpointHandler - - -@ocp.args.register_with_handler( # pytype:disable=wrong-arg-types - TfGrainCheckpointHandler, for_save=True, for_restore=True -) -@dataclasses.dataclass -class TfGrainCheckpointArgs(ocp.args.CheckpointArgs): - item: Any