Skip to content

Commit

Permalink
Improve unit test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
inkcherry committed Jan 17, 2025
1 parent 23bd0fc commit 358f395
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/unit/model_parallelism/test_autotp_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,13 +351,14 @@ def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, gro
return model, base_model


@pytest.mark.parametrize("zero_stage", [0, 1])
@pytest.mark.parametrize("tp_size", [2, 4])
class TestSave(DistributedTest):

world_size = 4
reuse_dist_env = True

def test_save_original_weight(self, tp_size: int):
def test_save_original_weight(self, tp_size: int, zero_stage: int):
hidden_dim = 64
set_autotp_mode(training=True)
config_dict = {
Expand All @@ -373,7 +374,7 @@ def test_save_original_weight(self, tp_size: int):
"autotp_size": tp_size
},
"zero_optimization": {
"stage": 0,
"stage": zero_stage,
}
}
if preferred_dtype() is torch.float16:
Expand Down Expand Up @@ -415,7 +416,7 @@ def compare_state_dicts(state_dict1, state_dict2):
else:
assert tp_state_dict is None, f"noly rank0 should have the state_dict"

def test_ckpt_save(self, tmpdir, tp_size: int):
def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int):
hidden_dim = 64
set_autotp_mode(training=True)
config_dict = {
Expand All @@ -428,7 +429,7 @@ def test_ckpt_save(self, tmpdir, tp_size: int):
}
},
"zero_optimization": {
"stage": 0,
"stage": zero_stage,
},
"tensor_parallel": {
"autotp_size": tp_size
Expand Down

0 comments on commit 358f395

Please sign in to comment.