From 358f3950895d2e47f253ee3dca83a507e0a5c8d6 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Fri, 17 Jan 2025 20:05:17 +0800 Subject: [PATCH] improve ut coverage --- tests/unit/model_parallelism/test_autotp_training.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py index ba9d43edfb6d..5f363e976481 100644 --- a/tests/unit/model_parallelism/test_autotp_training.py +++ b/tests/unit/model_parallelism/test_autotp_training.py @@ -351,13 +351,14 @@ def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, gro return model, base_model +@pytest.mark.parametrize("zero_stage", [0, 1]) @pytest.mark.parametrize("tp_size", [2, 4]) class TestSave(DistributedTest): world_size = 4 reuse_dist_env = True - def test_save_original_weight(self, tp_size: int): + def test_save_original_weight(self, tp_size: int, zero_stage: int): hidden_dim = 64 set_autotp_mode(training=True) config_dict = { @@ -373,7 +374,7 @@ def test_save_original_weight(self, tp_size: int): "autotp_size": tp_size }, "zero_optimization": { - "stage": 0, + "stage": zero_stage, } } if preferred_dtype() is torch.float16: @@ -415,7 +416,7 @@ def compare_state_dicts(state_dict1, state_dict2): else: assert tp_state_dict is None, f"noly rank0 should have the state_dict" - def test_ckpt_save(self, tmpdir, tp_size: int): + def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int): hidden_dim = 64 set_autotp_mode(training=True) config_dict = { @@ -428,7 +429,7 @@ def test_ckpt_save(self, tmpdir, tp_size: int): } }, "zero_optimization": { - "stage": 0, + "stage": zero_stage, }, "tensor_parallel": { "autotp_size": tp_size