Skip to content

Commit

Permalink
step can be on cpu or gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
NouamaneTazi committed Nov 22, 2024
1 parent 62fa626 commit ef931bc
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,45 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex
else:
assert not match, "Newly initialised optimizer should not match."

load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder)
load_optimizer(
optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location=None
)

# Assert the optimizer states are exactly the same after loading.
match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
assert match, msg

# Test loading optimizer states to CPU
cpu_optimizer = NamedOptimizer(
named_params_or_groups=model.named_parameters(),
optimizer_builder=lambda params: torch.optim.AdamW(params),
)

# Load optimizer states to CPU
load_optimizer(
optimizer=cpu_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location="cpu"
)

# Get state dicts
gpu_state = optimizer.state_dict()
cpu_state = cpu_optimizer.state_dict()

# Check that states match except for device
for param_id in gpu_state["state"]:
for key, gpu_value in gpu_state["state"][param_id].items():
cpu_value = cpu_state["state"][param_id][key]
if isinstance(gpu_value, torch.Tensor):
assert torch.equal(gpu_value.cpu(), cpu_value), f"Values don't match for param {param_id}, key {key}"
if key != "step": # Skip device checks for 'step' key
assert (
cpu_value.device.type == "cpu"
), f"CPU optimizer state should be on CPU for param {param_id}, key {key}"
assert (
gpu_value.device.type == "cuda"
), f"GPU optimizer state should be on CUDA for param {param_id}, key {key}"
else:
assert gpu_value == cpu_value, f"Non-tensor values don't match for param {param_id}, key {key}"

parallel_context.destroy()


Expand Down

0 comments on commit ef931bc

Please sign in to comment.