From 0b1a50d600db5ab1b419034d6deaf8afb4353159 Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Sun, 5 Jan 2025 03:43:52 +0000 Subject: [PATCH] update --- tests/test_save_checkpoint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_save_checkpoint.py b/tests/test_save_checkpoint.py index 373ac1c..002fb74 100644 --- a/tests/test_save_checkpoint.py +++ b/tests/test_save_checkpoint.py @@ -2,10 +2,10 @@ import shutil import pytest +import torch import torch.distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel from fastvideo.utils.checkpoint import save_checkpoint from fastvideo.utils.fsdp_util import get_dit_fsdp_kwargs @@ -23,7 +23,11 @@ def setup_distributed(): dist.destroy_process_group() +@pytest.mark.skipif(not torch.cuda.is_available(), + reason="Requires at least 2 GPUs to run NCCL tests") def test_save_and_remove_checkpoint(): + from fastvideo.models.mochi_hf.modeling_mochi import \ + MochiTransformer3DModel transformer = MochiTransformer3DModel(num_layers=0) fsdp_kwargs, _ = get_dit_fsdp_kwargs(transformer, "none") transformer = FSDP(transformer, **fsdp_kwargs)