From 8b934eb90ef7dd2e9992b3596024415b730a40b2 Mon Sep 17 00:00:00 2001 From: luc <59030697+jaxs-ribs@users.noreply.github.com> Date: Mon, 23 Oct 2023 14:38:14 +0200 Subject: [PATCH] Add "latest_checkpoint" path to state_dict (#169) Doesn't quite mirror [this cog PR](https://github.com/EmbarkStudios/cog/pull/731/files), but ensures we have the path, as discussed with @tgolsson. Please let me know in case I misunderstood something. --- emote/callbacks/checkpointing.py | 10 ++++++++-- tests/test_checkpoints.py | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/emote/callbacks/checkpointing.py b/emote/callbacks/checkpointing.py index 5c9eba08..3b33f51b 100644 --- a/emote/callbacks/checkpointing.py +++ b/emote/callbacks/checkpointing.py @@ -59,19 +59,25 @@ def begin_training(self): os.makedirs(self._folder_path, exist_ok=True) def end_cycle(self, bp_step, bp_samples): + name = f"checkpoint_{self._checkpoint_index}.tar" + final_path = os.path.join(self._folder_path, name) state_dict = { "callback_state_dicts": {cb.name: cb.state_dict() for cb in self._cbs}, "training_state": { + "latest_checkpoint": final_path, "bp_step": bp_step, "bp_samples": bp_samples, "checkpoint_index": self._checkpoint_index, }, } - name = f"checkpoint_{self._checkpoint_index}.tar" - final_path = os.path.join(self._folder_path, name) torch.save(state_dict, final_path) self._checkpoint_index += 1 + return { + "latest_checkpoint": state_dict["training_state"]["latest_checkpoint"], + "checkpoint_index": state_dict["training_state"]["checkpoint_index"], + } + class CheckpointLoader(Callback): """CheckpointLoader loads a checkpoint like the one created by Checkpointer. diff --git a/tests/test_checkpoints.py b/tests/test_checkpoints.py index fd3cade7..72dcd40f 100644 --- a/tests/test_checkpoints.py +++ b/tests/test_checkpoints.py @@ -70,6 +70,8 @@ def test_networks_checkpoint(): t1.train() n2 = nn.Linear(1, 1) test_data = torch.rand(5, 1) + + assert "latest_checkpoint" in t1.state assert not torch.allclose(n1(test_data), n2(test_data)) c2 = [