diff --git a/examples/hunyuanvideo/hyvideo/eval/script/cal_lpips.sh b/examples/hunyuanvideo/hyvideo/eval/script/cal_lpips.sh deleted file mode 100644 index 1cccf58d17..0000000000 --- a/examples/hunyuanvideo/hyvideo/eval/script/cal_lpips.sh +++ /dev/null @@ -1,9 +0,0 @@ -python hyvideo/eval/eval_common_metrics.py \ - --real_video_dir datasets/MCL_JCV/ \ - --generated_video_dir datasets/MCL_JCV_generated/ \ - --batch_size 10 \ - --num_frames 33 \ - --crop_size 360 \ - --resolution 640 \ - --device 'Ascend' \ - --metric 'lpips' diff --git a/examples/hunyuanvideo/hyvideo/eval/script/cal_psnr.sh b/examples/hunyuanvideo/hyvideo/eval/script/cal_psnr.sh deleted file mode 100644 index 6a4168da31..0000000000 --- a/examples/hunyuanvideo/hyvideo/eval/script/cal_psnr.sh +++ /dev/null @@ -1,9 +0,0 @@ -python hyvideo/eval/eval_common_metrics.py \ - --real_video_dir datasets/MCL_JCV/ \ - --generated_video_dir datasets/MCL_JCV_generated/ \ - --batch_size 10 \ - --num_frames 33 \ - --crop_size 360 \ - --resolution 640 \ - --device 'Ascend' \ - --metric 'psnr' diff --git a/examples/hunyuanvideo/hyvideo/eval/script/cal_ssim.sh b/examples/hunyuanvideo/hyvideo/eval/script/cal_ssim.sh deleted file mode 100644 index df7c7d7abe..0000000000 --- a/examples/hunyuanvideo/hyvideo/eval/script/cal_ssim.sh +++ /dev/null @@ -1,9 +0,0 @@ -python hyvideo/eval/eval_common_metrics.py \ - --real_video_dir datasets/MCL_JCV/ \ - --generated_video_dir datasets/MCL_JCV_generated/ \ - --batch_size 10 \ - --num_frames 33 \ - --crop_size 360 \ - --resolution 640 \ - --device 'Ascend' \ - --metric 'ssim' diff --git a/examples/hunyuanvideo/hyvideo/utils/dataset_utils.py b/examples/hunyuanvideo/hyvideo/utils/dataset_utils.py index 3e63f3e6d5..d39a15639d 100644 --- a/examples/hunyuanvideo/hyvideo/utils/dataset_utils.py +++ b/examples/hunyuanvideo/hyvideo/utils/dataset_utils.py @@ -161,7 +161,7 @@ def __init__( real_data_file_path=None, sample_rate=1, crop_size=None, - resolution=128, + short_size=128, output_columns=["real", "generated"], ) -> None: super().__init__() @@ -179,7 +179,7 @@ def __init__( self.num_frames = num_frames self.sample_rate = sample_rate self.crop_size = crop_size - self.short_size = resolution + self.short_size = short_size self.output_columns = output_columns self.real_video_dir = real_video_dir diff --git a/examples/hunyuanvideo/hyvideo/train/commons.py b/examples/hunyuanvideo/scripts/commons.py similarity index 100% rename from examples/hunyuanvideo/hyvideo/train/commons.py rename to examples/hunyuanvideo/scripts/commons.py diff --git a/examples/hunyuanvideo/hyvideo/eval/cal_fvd.py b/examples/hunyuanvideo/scripts/eval/cal_fvd.py similarity index 100% rename from examples/hunyuanvideo/hyvideo/eval/cal_fvd.py rename to examples/hunyuanvideo/scripts/eval/cal_fvd.py diff --git a/examples/hunyuanvideo/hyvideo/eval/cal_lpips.py b/examples/hunyuanvideo/scripts/eval/cal_lpips.py similarity index 100% rename from examples/hunyuanvideo/hyvideo/eval/cal_lpips.py rename to examples/hunyuanvideo/scripts/eval/cal_lpips.py diff --git a/examples/hunyuanvideo/hyvideo/eval/cal_psnr.py b/examples/hunyuanvideo/scripts/eval/cal_psnr.py similarity index 100% rename from examples/hunyuanvideo/hyvideo/eval/cal_psnr.py rename to examples/hunyuanvideo/scripts/eval/cal_psnr.py diff --git a/examples/hunyuanvideo/hyvideo/eval/cal_ssim.py b/examples/hunyuanvideo/scripts/eval/cal_ssim.py similarity index 100% rename from examples/hunyuanvideo/hyvideo/eval/cal_ssim.py rename to examples/hunyuanvideo/scripts/eval/cal_ssim.py diff --git a/examples/hunyuanvideo/hyvideo/eval/eval_clip_score.py b/examples/hunyuanvideo/scripts/eval/eval_clip_score.py similarity index 100% rename from examples/hunyuanvideo/hyvideo/eval/eval_clip_score.py rename to examples/hunyuanvideo/scripts/eval/eval_clip_score.py diff --git a/examples/hunyuanvideo/hyvideo/eval/eval_common_metrics.py b/examples/hunyuanvideo/scripts/eval/eval_common_metrics.py similarity index 67% rename from examples/hunyuanvideo/hyvideo/eval/eval_common_metrics.py rename to examples/hunyuanvideo/scripts/eval/eval_common_metrics.py index 1874bfac01..8875e5f9e5 100644 --- a/examples/hunyuanvideo/hyvideo/eval/eval_common_metrics.py +++ b/examples/hunyuanvideo/scripts/eval/eval_common_metrics.py @@ -39,10 +39,11 @@ sys.path.insert(0, mindone_lib_path) sys.path.append(".") -from hyvideo.eval.cal_lpips import calculate_lpips -from hyvideo.eval.cal_psnr import calculate_psnr from hyvideo.utils.dataset_utils import VideoPairDataset, create_dataloader +from .cal_lpips import calculate_lpips +from .cal_psnr import calculate_psnr + flolpips_isavailable = False calculate_flolpips = None from hyvideo.eval.cal_ssim import calculate_ssim @@ -96,27 +97,64 @@ def calculate_common_metric(args, dataloader, dataset_size): def main(): parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) - parser.add_argument("--batch_size", type=int, default=2, help="Batch size to use") - parser.add_argument("--real_video_dir", type=str, help=("the path of real videos`")) - parser.add_argument("--real_data_file_path", type=str, default=None, help=("the path of real videos csv file`")) - parser.add_argument("--generated_video_dir", type=str, help=("the path of generated videos`")) + parser.add_argument("--batch-size", type=int, default=2, help="Batch size to use") + parser.add_argument("--real-video-dir", type=str, help="The path of real videos") + parser.add_argument("--real-data-file-path", type=str, default=None, help="The path of real videos CSV file") + parser.add_argument("--generated-video-dir", type=str, help="The path of generated videos") parser.add_argument("--device", type=str, default=None, help="Device to use. Like GPU or Ascend") parser.add_argument( - "--num_workers", + "--num-workers", type=int, default=8, - help=("Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`"), + help="Number of processes to use for data loading. Defaults to `min(8, num_cpus)`", + ) + parser.add_argument("--sample-fps", type=int, default=30, help="Sampling frames per second for the video.") + parser.add_argument( + "--short-size", + type=int, + default=256, + help=( + "Resize the video frames to this size. If provided, the smaller dimension will be resized to this value " + "while maintaining the aspect ratio. " + ), + ) + parser.add_argument( + "--height", + type=int, + default=256, + help="Height to crop the video frames.", + ) + + parser.add_argument( + "--width", + type=int, + default=256, + help="Width to crop the video frames.", + ) + parser.add_argument("--num-frames", type=int, default=100, help="Number of frames to sample from the video.") + parser.add_argument("--sample-rate", type=int, default=1, help="Sampling rate for video frames.") + parser.add_argument("--subset-size", type=int, default=None, help="Subset size to evaluate.") + parser.add_argument( + "--metric", + type=str, + default="fvd", + choices=["fvd", "psnr", "ssim", "lpips", "flolpips"], + help="Metric to calculate.", + ) + parser.add_argument( + "--fvd-method", + type=str, + default="styleganv", + choices=["styleganv", "videogpt"], + help="Method to use for FVD calculation.", ) - parser.add_argument("--sample_fps", type=int, default=30) - parser.add_argument("--resolution", type=int, default=336) - parser.add_argument("--crop_size", type=int, default=None) - parser.add_argument("--num_frames", type=int, default=100) - parser.add_argument("--sample_rate", type=int, default=1) - parser.add_argument("--subset_size", type=int, default=None) - parser.add_argument("--metric", type=str, default="fvd", choices=["fvd", "psnr", "ssim", "lpips", "flolpips"]) - parser.add_argument("--fvd_method", type=str, default="styleganv", choices=["styleganv", "videogpt"]) args = parser.parse_args() + # Check if short_size is less than the minimum of height and width + if args.short_size < min(args.height, args.width): + raise ValueError( + f"short_size ({args.short_size}) cannot be less than the minimum of height ({args.height}) and width ({args.width})." + ) if args.num_workers is None: try: @@ -137,8 +175,8 @@ def main(): num_frames=args.num_frames, real_data_file_path=args.real_data_file_path, sample_rate=args.sample_rate, - crop_size=args.crop_size, - resolution=args.resolution, + crop_size=(args.height, args.width), + short_size=args.short_size, ) dataloader = create_dataloader( diff --git a/examples/hunyuanvideo/scripts/eval/script/cal_lpips.sh b/examples/hunyuanvideo/scripts/eval/script/cal_lpips.sh new file mode 100644 index 0000000000..251a9bbf30 --- /dev/null +++ b/examples/hunyuanvideo/scripts/eval/script/cal_lpips.sh @@ -0,0 +1,10 @@ +python scripts/eval/eval_common_metrics.py \ + --real-video-dir datasets/MCL_JCV/ \ + --generated-video-dir datasets/MCL_JCV_generated/ \ + --batch-size 10 \ + --num-frames 33 \ + --height 360 \ + --width 640 \ + --short-size 360 \ + --device 'Ascend' \ + --metric 'lpips' diff --git a/examples/hunyuanvideo/scripts/eval/script/cal_psnr.sh b/examples/hunyuanvideo/scripts/eval/script/cal_psnr.sh new file mode 100644 index 0000000000..f61b25826c --- /dev/null +++ b/examples/hunyuanvideo/scripts/eval/script/cal_psnr.sh @@ -0,0 +1,10 @@ +python scripts/eval/eval_common_metrics.py \ + --real-video-dir datasets/MCL_JCV/ \ + --generated-video-dir datasets/MCL_JCV_generated/ \ + --batch-size 10 \ + --num-frames 33 \ + --height 360 \ + --width 640 \ + --short-size 360 \ + --device 'Ascend' \ + --metric 'psnr' diff --git a/examples/hunyuanvideo/scripts/eval/script/cal_ssim.sh b/examples/hunyuanvideo/scripts/eval/script/cal_ssim.sh new file mode 100644 index 0000000000..f8c3238fc2 --- /dev/null +++ b/examples/hunyuanvideo/scripts/eval/script/cal_ssim.sh @@ -0,0 +1,10 @@ +python scripts/eval/eval_common_metrics.py \ + --real-video-dir datasets/MCL_JCV/ \ + --generated-video-dir datasets/MCL_JCV_generated/ \ + --batch-size 10 \ + --num-frames 33 \ + --height 360 \ + --width 640 \ + --short-size 360 \ + --device 'Ascend' \ + --metric 'ssim' diff --git a/examples/hunyuanvideo/scripts/run_t2v_sample.sh b/examples/hunyuanvideo/scripts/hyvideo/run_t2v_sample.sh similarity index 96% rename from examples/hunyuanvideo/scripts/run_t2v_sample.sh rename to examples/hunyuanvideo/scripts/hyvideo/run_t2v_sample.sh index e1b871c699..d4e2f6dda7 100644 --- a/examples/hunyuanvideo/scripts/run_t2v_sample.sh +++ b/examples/hunyuanvideo/scripts/hyvideo/run_t2v_sample.sh @@ -1,7 +1,7 @@ # plot memory usage and compile info # export MS_DEV_RUNTIME_CONF="memory_statistics:True,compile_statistics:True" -python sample_video.py \ +python scripts/sample_video.py \ --ms-mode 0 \ --video-size 544 960 \ --video-length 129 \ diff --git a/examples/hunyuanvideo/scripts/run_t2v_sample_ddp.sh b/examples/hunyuanvideo/scripts/hyvideo/run_t2v_sample_ddp.sh similarity index 94% rename from examples/hunyuanvideo/scripts/run_t2v_sample_ddp.sh rename to examples/hunyuanvideo/scripts/hyvideo/run_t2v_sample_ddp.sh index 6cc4ced156..2603b5a3eb 100644 --- a/examples/hunyuanvideo/scripts/run_t2v_sample_ddp.sh +++ b/examples/hunyuanvideo/scripts/hyvideo/run_t2v_sample_ddp.sh @@ -1,6 +1,6 @@ msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/parallel_logs/" \ - sample_video.py \ + scripts/sample_video.py \ --video-size 256 256 \ --video-length 29 \ --infer-steps 50 \ diff --git a/examples/hunyuanvideo/scripts/train_t2v_debug.sh b/examples/hunyuanvideo/scripts/hyvideo/train_t2v_debug.sh similarity index 96% rename from examples/hunyuanvideo/scripts/train_t2v_debug.sh rename to examples/hunyuanvideo/scripts/hyvideo/train_t2v_debug.sh index 310935598f..c81b3aef5b 100644 --- a/examples/hunyuanvideo/scripts/train_t2v_debug.sh +++ b/examples/hunyuanvideo/scripts/hyvideo/train_t2v_debug.sh @@ -1,4 +1,4 @@ -python hyvideo/train/train_t2v.py \ +python scripts/train_t2v.py \ --model "HYVideo-T/2-depth1" \ --cache_dir "./ckpts" \ --dataset t2v \ diff --git a/examples/hunyuanvideo/scripts/train_t2v_zero3.sh b/examples/hunyuanvideo/scripts/hyvideo/train_t2v_zero3.sh similarity index 97% rename from examples/hunyuanvideo/scripts/train_t2v_zero3.sh rename to examples/hunyuanvideo/scripts/hyvideo/train_t2v_zero3.sh index de9836705e..ed68ab0894 100644 --- a/examples/hunyuanvideo/scripts/train_t2v_zero3.sh +++ b/examples/hunyuanvideo/scripts/hyvideo/train_t2v_zero3.sh @@ -1,7 +1,7 @@ NUM_FRAME=29 export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=8000 --log_dir="t2v-video3d-${NUM_FRAME}x256p_zero3/parallel_logs" \ - hyvideo/train/train_t2v.py \ + scripts/train_t2v.py \ --model "HYVideo-T/2-cfgdistill" \ --cache_dir "./ckpts" \ --dataset t2v \ diff --git a/examples/hunyuanvideo/hyvideo/run_text_encoder.py b/examples/hunyuanvideo/scripts/run_text_encoder.py similarity index 97% rename from examples/hunyuanvideo/hyvideo/run_text_encoder.py rename to examples/hunyuanvideo/scripts/run_text_encoder.py index 1c25d7b08a..bd4b68729f 100644 --- a/examples/hunyuanvideo/hyvideo/run_text_encoder.py +++ b/examples/hunyuanvideo/scripts/run_text_encoder.py @@ -12,12 +12,12 @@ sys.path.insert(0, mindone_lib_path) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) -from constants import PRECISIONS, PROMPT_TEMPLATE -from dataset.text_dataset import create_dataloader -from dataset.transform import text_preprocessing -from text_encoder import TextEncoder -from utils.message_utils import print_banner -from utils.ms_utils import init_env +from hyvideo.constants import PRECISIONS, PROMPT_TEMPLATE +from hyvideo.dataset.text_dataset import create_dataloader +from hyvideo.dataset.transform import text_preprocessing +from hyvideo.text_encoder import TextEncoder +from hyvideo.utils.message_utils import print_banner +from hyvideo.utils.ms_utils import init_env from mindone.utils.config import str2bool from mindone.utils.logger import set_logger diff --git a/examples/hunyuanvideo/hyvideo/run_vae.py b/examples/hunyuanvideo/scripts/run_vae.py similarity index 100% rename from examples/hunyuanvideo/hyvideo/run_vae.py rename to examples/hunyuanvideo/scripts/run_vae.py diff --git a/examples/hunyuanvideo/sample_video.py b/examples/hunyuanvideo/scripts/sample_video.py similarity index 100% rename from examples/hunyuanvideo/sample_video.py rename to examples/hunyuanvideo/scripts/sample_video.py diff --git a/examples/hunyuanvideo/scripts/run_text_encoder.sh b/examples/hunyuanvideo/scripts/text_encoder/run_text_encoder.sh similarity index 100% rename from examples/hunyuanvideo/scripts/run_text_encoder.sh rename to examples/hunyuanvideo/scripts/text_encoder/run_text_encoder.sh diff --git a/examples/hunyuanvideo/hyvideo/train/train_t2v.py b/examples/hunyuanvideo/scripts/train_t2v.py similarity index 99% rename from examples/hunyuanvideo/hyvideo/train/train_t2v.py rename to examples/hunyuanvideo/scripts/train_t2v.py index cd14eb4700..faa0889995 100644 --- a/examples/hunyuanvideo/hyvideo/train/train_t2v.py +++ b/examples/hunyuanvideo/scripts/train_t2v.py @@ -20,7 +20,6 @@ from hyvideo.diffusion.net_with_loss import DiffusionWithLoss from hyvideo.diffusion.schedulers.rectified_flow_trainer import RFlowEvalLoss, RFlowLossWrapper from hyvideo.modules.models import HUNYUAN_VIDEO_CONFIG, HYVideoDiffusionTransformer -from hyvideo.train.commons import create_loss_scaler, parse_args from hyvideo.utils.callbacks import EMAEvalSwapCallback, PerfRecorderCallback from hyvideo.utils.dataset_utils import Collate, LengthGroupedSampler from hyvideo.utils.ema import EMA @@ -43,6 +42,8 @@ from mindone.utils.logger import set_logger from mindone.utils.params import count_params +from .commons import create_loss_scaler, parse_args + logger = logging.getLogger(__name__) @@ -206,7 +207,7 @@ def main(args): # Load text encoder if not args.text_embed_cache: print_banner("text encoder init") - from hyvideo.run_text_encoder import build_model + from .run_text_encoder import build_model text_encoder_1, text_encoder_2 = build_model(args, logger) text_encoder_1_dtype, text_encoder_2_dtype = args.text_encoder_precision, args.text_encoder_precision_2 diff --git a/examples/hunyuanvideo/scripts/vae/recon_image.sh b/examples/hunyuanvideo/scripts/vae/recon_image.sh new file mode 100644 index 0000000000..c142dfa3f6 --- /dev/null +++ b/examples/hunyuanvideo/scripts/vae/recon_image.sh @@ -0,0 +1,7 @@ +python scripts/run_vae.py \ + --input-type image \ + --image-path "path/to/input_image.jpg" \ + --output-path "path/to/output_directory" \ + --rec-path "reconstructed_image.jpg" \ + --height 336 \ + --width 336 \ diff --git a/examples/hunyuanvideo/scripts/vae/recon_video.sh b/examples/hunyuanvideo/scripts/vae/recon_video.sh new file mode 100644 index 0000000000..67a9c8afe9 --- /dev/null +++ b/examples/hunyuanvideo/scripts/vae/recon_video.sh @@ -0,0 +1,10 @@ +python scripts/run_vae.py \ + --input-type video \ + --video-path "path/to/input_video.mp4" \ + --output-path "path/to/output_directory" \ + --rec-path "reconstructed_video.mp4" \ + --height 336 \ + --width 336 \ + --num-frames 65 \ + --sample-rate 1 \ + --fps 30 \ diff --git a/examples/hunyuanvideo/scripts/vae/recon_video_folder.sh b/examples/hunyuanvideo/scripts/vae/recon_video_folder.sh new file mode 100644 index 0000000000..0798f7d18e --- /dev/null +++ b/examples/hunyuanvideo/scripts/vae/recon_video_folder.sh @@ -0,0 +1,9 @@ +python run_vae.py \ + --input-type folder \ + --real-video-dir datasets/MCL_JCV/ \ + --generated-video-dir "path/to/generated_videos" \ + --output-path datasets/MCL_JCV_generated/ \ + --height 336 \ + --width 336 \ + --num-frames 65 \ + --sample-rate 1 \