From ef931bc66b399d159c44adf04a395aab30ec5b52 Mon Sep 17 00:00:00 2001
From: NouamaneTazi
Date: Fri, 22 Nov 2024 10:11:13 +0000
Subject: [PATCH] step can be on cpu or gpu

---
 tests/test_serialize.py | 35 ++++++++++++++++++++++++++++++++++-
 1 file changed, 34 insertions(+), 1 deletion(-)

diff --git a/tests/test_serialize.py b/tests/test_serialize.py
index 329ff279..520c6517 100644
--- a/tests/test_serialize.py
+++ b/tests/test_serialize.py
@@ -139,12 +139,45 @@ def _test_save_and_load_optimizer(parallel_context: ParallelContext, test_contex
     else:
         assert not match, "Newly initialised optimizer should not match."
 
-    load_optimizer(optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder)
+    load_optimizer(
+        optimizer=new_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location=None
+    )
 
     # Assert the optimizer states are exactly the same after loading.
     match, msg = is_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())
     assert match, msg
 
+    # Test loading optimizer states to CPU
+    cpu_optimizer = NamedOptimizer(
+        named_params_or_groups=model.named_parameters(),
+        optimizer_builder=lambda params: torch.optim.AdamW(params),
+    )
+
+    # Load optimizer states to CPU
+    load_optimizer(
+        optimizer=cpu_optimizer, parallel_context=parallel_context, root_folder=store_folder, map_location="cpu"
+    )
+
+    # Get state dicts
+    gpu_state = optimizer.state_dict()
+    cpu_state = cpu_optimizer.state_dict()
+
+    # Check that states match except for device
+    for param_id in gpu_state["state"]:
+        for key, gpu_value in gpu_state["state"][param_id].items():
+            cpu_value = cpu_state["state"][param_id][key]
+            if isinstance(gpu_value, torch.Tensor):
+                assert torch.equal(gpu_value.cpu(), cpu_value), f"Values don't match for param {param_id}, key {key}"
+                if key != "step":  # Skip device checks for 'step' key
+                    assert (
+                        cpu_value.device.type == "cpu"
+                    ), f"CPU optimizer state should be on CPU for param {param_id}, key {key}"
+                    assert (
+                        gpu_value.device.type == "cuda"
+                    ), f"GPU optimizer state should be on CUDA for param {param_id}, key {key}"
+            else:
+                assert gpu_value == cpu_value, f"Non-tensor values don't match for param {param_id}, key {key}"
+
     parallel_context.destroy()
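
Note (not part of the patch): the special handling of the "step" key above exists because, in recent PyTorch versions, torch.optim.AdamW stores the per-parameter "step" counter as a scalar tensor whose device depends on how the optimizer was constructed; in the default path it is a CPU tensor even when the parameters and the exp_avg/exp_avg_sq buffers live on CUDA, while the fused/capturable variants keep it on the parameter's device. The snippet below is a minimal illustrative sketch of that behaviour, assuming a CUDA device and a PyTorch build that supports fused AdamW; it is not part of tests/test_serialize.py.

    import torch

    param = torch.nn.Parameter(torch.randn(4, device="cuda"))

    # Default AdamW typically keeps 'step' as a CPU scalar tensor;
    # fused=True (CUDA-only) keeps 'step' on the parameter's device.
    for kwargs in ({}, {"fused": True}):
        opt = torch.optim.AdamW([param], **kwargs)
        param.grad = torch.ones_like(param)
        opt.step()
        state = opt.state[param]
        print(kwargs, {k: v.device for k, v in state.items() if isinstance(v, torch.Tensor)})

This is why the test only enforces device placement for the non-"step" entries: after load_optimizer(..., map_location="cpu") everything is expected on CPU, but the original optimizer's "step" tensor may legitimately be on either device.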