Skip to content

Commit

Permalink
Add "latest_checkpoint" path to state_dict (#169)
Browse files Browse the repository at this point in the history
Doesn't quite mirror [this cog
PR](https://github.com/EmbarkStudios/cog/pull/731/files), but ensures we
have the path, as discussed with @tgolsson. Please let me know in case I
misunderstood something.
  • Loading branch information
jaxs-ribs authored Oct 23, 2023
1 parent 46f4165 commit 8b934eb
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
10 changes: 8 additions & 2 deletions emote/callbacks/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,25 @@ def begin_training(self):
os.makedirs(self._folder_path, exist_ok=True)

def end_cycle(self, bp_step, bp_samples):
    """Save a checkpoint for the cycle that just ended.

    The checkpoint is written to ``checkpoint_<index>.tar`` inside
    ``self._folder_path`` and contains the state dict of every callback in
    ``self._cbs`` (keyed by callback name) together with the training
    progress counters. After a successful save the checkpoint index is
    advanced by one.

    Args:
        bp_step: Step counter stored in the checkpoint as ``bp_step``.
        bp_samples: Sample counter stored in the checkpoint as ``bp_samples``.

    Returns:
        A dict with ``latest_checkpoint`` (path of the file just written)
        and ``checkpoint_index`` (the index used for that file, i.e. the
        value before it was incremented).
        NOTE(review): presumably the caller merges this dict into the
        trainer state -- confirm against the callback protocol.
    """
    # Snapshot the index up front so the file name, the saved state and the
    # returned value all agree, even though the counter is bumped below.
    checkpoint_index = self._checkpoint_index
    # The path is computed once, before building the state dict, because the
    # saved training state records the checkpoint's own location.
    final_path = os.path.join(self._folder_path, f"checkpoint_{checkpoint_index}.tar")

    state_dict = {
        "callback_state_dicts": {cb.name: cb.state_dict() for cb in self._cbs},
        "training_state": {
            "latest_checkpoint": final_path,
            "bp_step": bp_step,
            "bp_samples": bp_samples,
            "checkpoint_index": checkpoint_index,
        },
    }
    torch.save(state_dict, final_path)
    # Only advance once the save has gone through; a failed save raises above
    # and leaves the index untouched.
    self._checkpoint_index += 1

    return {
        "latest_checkpoint": final_path,
        "checkpoint_index": checkpoint_index,
    }


class CheckpointLoader(Callback):
"""CheckpointLoader loads a checkpoint like the one created by Checkpointer.
Expand Down
2 changes: 2 additions & 0 deletions tests/test_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def test_networks_checkpoint():
t1.train()
n2 = nn.Linear(1, 1)
test_data = torch.rand(5, 1)

assert "latest_checkpoint" in t1.state
assert not torch.allclose(n1(test_data), n2(test_data))

c2 = [
Expand Down

0 comments on commit 8b934eb

Please sign in to comment.