Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rlsu9 committed Jan 5, 2025
1 parent d400968 commit 0b1a50d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tests/test_save_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import shutil

import pytest
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel
from fastvideo.utils.checkpoint import save_checkpoint
from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs

Expand All @@ -23,7 +23,11 @@ def setup_distributed():
dist.destroy_process_group()


@pytest.mark.skipif(not torch.cuda.is_available(),
reason="Requires at least 2 GPUs to run NCCL tests")
def test_save_and_remove_checkpoint():
from fastvideo.models.mochi_hf.modeling_mochi import \
MochiTransformer3DModel
transformer = MochiTransformer3DModel(num_layers=0)
fsdp_kwargs, _ = get_dit_fsdp_kwargs(transformer, "none")
transformer = FSDP(transformer, **fsdp_kwargs)
Expand Down

0 comments on commit 0b1a50d

Please sign in to comment.