diff --git a/tests/test_save_checkpoint.py b/tests/test_save_checkpoint.py
index 002fb74..790ff12 100644
--- a/tests/test_save_checkpoint.py
+++ b/tests/test_save_checkpoint.py
@@ -6,9 +6,6 @@
 import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-from fastvideo.utils.checkpoint import save_checkpoint
-from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs
-
 
 @pytest.fixture(scope="module", autouse=True)
 def setup_distributed():
@@ -28,6 +25,9 @@ def setup_distributed():
 def test_save_and_remove_checkpoint():
     from fastvideo.models.mochi_hf.modeling_mochi import \
         MochiTransformer3DModel
+    from fastvideo.utils.checkpoint import save_checkpoint
+    from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs
+
     transformer = MochiTransformer3DModel(num_layers=0)
     fsdp_kwargs, _ = get_dit_fsdp_kwargs(transformer, "none")
     transformer = FSDP(transformer, **fsdp_kwargs)