Skip to content

Commit

Permalink
Clear JAX state sharding after fit, evaluate and predict. (#20865)
Browse files Browse the repository at this point in the history
The state sharding is leaked at the end of `fit`, `evaluate` and `predict`. The stale values are never reused when `fit`, `evaluate` or `predict` is called again, so they are now cleared.
  • Loading branch information
hertschuh authored Feb 6, 2025
1 parent 93b393f commit c04cf9d
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions keras/src/backend/jax/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def fit(
del self._eval_epoch_iterator
callbacks.on_train_end(logs=training_logs)
self._jax_state = None
self._clear_jax_state_sharding()
return self.history

@traceback_utils.filter_traceback
Expand Down Expand Up @@ -601,6 +602,9 @@ def evaluate(
logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)
self._jax_state = None
if not use_cached_eval_dataset:
# Only clear sharding if evaluate is not called from `fit`.
self._clear_jax_state_sharding()
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
Expand Down Expand Up @@ -696,6 +700,7 @@ def append_to_outputs(batch_outputs, outputs):
self.jax_state_sync()
callbacks.on_predict_end()
self._jax_state = None
self._clear_jax_state_sharding()
return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)

def train_on_batch(
Expand Down Expand Up @@ -873,6 +878,12 @@ def _record_training_state_sharding_spec(self):
v.value.sharding for v in self.metrics_variables
]

def _clear_jax_state_sharding(self):
self._trainable_variable_shardings = None
self._non_trainable_variable_shardings = None
self._optimizer_variable_shardings = None
self._metrics_variable_shardings = None

def _enforce_jax_state_sharding(
self,
trainable_variables=None,
Expand Down

0 comments on commit c04cf9d

Please sign in to comment.