Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
wtomin committed Feb 11, 2025
1 parent 7035e2a commit 15f06c2
Show file tree
Hide file tree
Showing 26 changed files with 127 additions and 59 deletions.
9 changes: 0 additions & 9 deletions examples/hunyuanvideo/hyvideo/eval/script/cal_lpips.sh

This file was deleted.

9 changes: 0 additions & 9 deletions examples/hunyuanvideo/hyvideo/eval/script/cal_psnr.sh

This file was deleted.

9 changes: 0 additions & 9 deletions examples/hunyuanvideo/hyvideo/eval/script/cal_ssim.sh

This file was deleted.

4 changes: 2 additions & 2 deletions examples/hunyuanvideo/hyvideo/utils/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __init__(
real_data_file_path=None,
sample_rate=1,
crop_size=None,
resolution=128,
short_size=128,
output_columns=["real", "generated"],
) -> None:
super().__init__()
Expand All @@ -179,7 +179,7 @@ def __init__(
self.num_frames = num_frames
self.sample_rate = sample_rate
self.crop_size = crop_size
self.short_size = resolution
self.short_size = short_size
self.output_columns = output_columns
self.real_video_dir = real_video_dir

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@
sys.path.insert(0, mindone_lib_path)
sys.path.append(".")

from hyvideo.eval.cal_lpips import calculate_lpips
from hyvideo.eval.cal_psnr import calculate_psnr
from hyvideo.utils.dataset_utils import VideoPairDataset, create_dataloader

from .cal_lpips import calculate_lpips
from .cal_psnr import calculate_psnr

flolpips_isavailable = False
calculate_flolpips = None
from hyvideo.eval.cal_ssim import calculate_ssim
Expand Down Expand Up @@ -96,27 +97,64 @@ def calculate_common_metric(args, dataloader, dataset_size):

def main():
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument("--batch_size", type=int, default=2, help="Batch size to use")
parser.add_argument("--real_video_dir", type=str, help=("the path of real videos`"))
parser.add_argument("--real_data_file_path", type=str, default=None, help=("the path of real videos csv file`"))
parser.add_argument("--generated_video_dir", type=str, help=("the path of generated videos`"))
parser.add_argument("--batch-size", type=int, default=2, help="Batch size to use")
parser.add_argument("--real-video-dir", type=str, help="The path of real videos")
parser.add_argument("--real-data-file-path", type=str, default=None, help="The path of real videos CSV file")
parser.add_argument("--generated-video-dir", type=str, help="The path of generated videos")
parser.add_argument("--device", type=str, default=None, help="Device to use. Like GPU or Ascend")
parser.add_argument(
"--num_workers",
"--num-workers",
type=int,
default=8,
help=("Number of processes to use for data loading. " "Defaults to `min(8, num_cpus)`"),
help="Number of processes to use for data loading. Defaults to `min(8, num_cpus)`",
)
parser.add_argument("--sample-fps", type=int, default=30, help="Sampling frames per second for the video.")
parser.add_argument(
"--short-size",
type=int,
default=256,
help=(
"Resize the video frames to this size. If provided, the smaller dimension will be resized to this value "
"while maintaining the aspect ratio. "
),
)
parser.add_argument(
"--height",
type=int,
default=256,
help="Height to crop the video frames.",
)

parser.add_argument(
"--width",
type=int,
default=256,
help="Width to crop the video frames.",
)
parser.add_argument("--num-frames", type=int, default=100, help="Number of frames to sample from the video.")
parser.add_argument("--sample-rate", type=int, default=1, help="Sampling rate for video frames.")
parser.add_argument("--subset-size", type=int, default=None, help="Subset size to evaluate.")
parser.add_argument(
"--metric",
type=str,
default="fvd",
choices=["fvd", "psnr", "ssim", "lpips", "flolpips"],
help="Metric to calculate.",
)
parser.add_argument(
"--fvd-method",
type=str,
default="styleganv",
choices=["styleganv", "videogpt"],
help="Method to use for FVD calculation.",
)
parser.add_argument("--sample_fps", type=int, default=30)
parser.add_argument("--resolution", type=int, default=336)
parser.add_argument("--crop_size", type=int, default=None)
parser.add_argument("--num_frames", type=int, default=100)
parser.add_argument("--sample_rate", type=int, default=1)
parser.add_argument("--subset_size", type=int, default=None)
parser.add_argument("--metric", type=str, default="fvd", choices=["fvd", "psnr", "ssim", "lpips", "flolpips"])
parser.add_argument("--fvd_method", type=str, default="styleganv", choices=["styleganv", "videogpt"])

args = parser.parse_args()
# Check if short_size is less than the minimum of height and width
if args.short_size < min(args.height, args.width):
raise ValueError(
f"short_size ({args.short_size}) cannot be less than the minimum of height ({args.height}) and width ({args.width})."
)

if args.num_workers is None:
try:
Expand All @@ -137,8 +175,8 @@ def main():
num_frames=args.num_frames,
real_data_file_path=args.real_data_file_path,
sample_rate=args.sample_rate,
crop_size=args.crop_size,
resolution=args.resolution,
crop_size=(args.height, args.width),
short_size=args.short_size,
)

dataloader = create_dataloader(
Expand Down
10 changes: 10 additions & 0 deletions examples/hunyuanvideo/scripts/eval/script/cal_lpips.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Run scripts/eval/eval_common_metrics.py with --metric 'lpips', comparing the
# videos in datasets/MCL_JCV/ against those in datasets/MCL_JCV_generated/.
# Frames are resized to a short side of 360 and cropped to 360x640 (height x width).
# NOTE(review): all paths are relative — presumably meant to be launched from
# examples/hunyuanvideo; confirm.
python scripts/eval/eval_common_metrics.py \
--real-video-dir datasets/MCL_JCV/ \
--generated-video-dir datasets/MCL_JCV_generated/ \
--batch-size 10 \
--num-frames 33 \
--height 360 \
--width 640 \
--short-size 360 \
--device 'Ascend' \
--metric 'lpips'
10 changes: 10 additions & 0 deletions examples/hunyuanvideo/scripts/eval/script/cal_psnr.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Run scripts/eval/eval_common_metrics.py with --metric 'psnr', comparing the
# videos in datasets/MCL_JCV/ against those in datasets/MCL_JCV_generated/.
# Frames are resized to a short side of 360 and cropped to 360x640 (height x width).
# NOTE(review): all paths are relative — presumably meant to be launched from
# examples/hunyuanvideo; confirm.
python scripts/eval/eval_common_metrics.py \
--real-video-dir datasets/MCL_JCV/ \
--generated-video-dir datasets/MCL_JCV_generated/ \
--batch-size 10 \
--num-frames 33 \
--height 360 \
--width 640 \
--short-size 360 \
--device 'Ascend' \
--metric 'psnr'
10 changes: 10 additions & 0 deletions examples/hunyuanvideo/scripts/eval/script/cal_ssim.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Run scripts/eval/eval_common_metrics.py with --metric 'ssim', comparing the
# videos in datasets/MCL_JCV/ against those in datasets/MCL_JCV_generated/.
# Frames are resized to a short side of 360 and cropped to 360x640 (height x width).
# NOTE(review): all paths are relative — presumably meant to be launched from
# examples/hunyuanvideo; confirm.
python scripts/eval/eval_common_metrics.py \
--real-video-dir datasets/MCL_JCV/ \
--generated-video-dir datasets/MCL_JCV_generated/ \
--batch-size 10 \
--num-frames 33 \
--height 360 \
--width 640 \
--short-size 360 \
--device 'Ascend' \
--metric 'ssim'
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# plot memory usage and compile info
# export MS_DEV_RUNTIME_CONF="memory_statistics:True,compile_statistics:True"

python sample_video.py \
python scripts/sample_video.py \
--ms-mode 0 \
--video-size 544 960 \
--video-length 129 \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=9000 --log_dir="./sample_videos/parallel_logs/" \
sample_video.py \
scripts/sample_video.py \
--video-size 256 256 \
--video-length 29 \
--infer-steps 50 \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
python hyvideo/train/train_t2v.py \
python scripts/train_t2v.py \
--model "HYVideo-T/2-depth1" \
--cache_dir "./ckpts" \
--dataset t2v \
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
NUM_FRAME=29
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
msrun --bind_core=True --worker_num=8 --local_worker_num=8 --master_port=8000 --log_dir="t2v-video3d-${NUM_FRAME}x256p_zero3/parallel_logs" \
hyvideo/train/train_t2v.py \
scripts/train_t2v.py \
--model "HYVideo-T/2-cfgdistill" \
--cache_dir "./ckpts" \
--dataset t2v \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
sys.path.insert(0, mindone_lib_path)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

from constants import PRECISIONS, PROMPT_TEMPLATE
from dataset.text_dataset import create_dataloader
from dataset.transform import text_preprocessing
from text_encoder import TextEncoder
from utils.message_utils import print_banner
from utils.ms_utils import init_env
from hyvideo.constants import PRECISIONS, PROMPT_TEMPLATE
from hyvideo.dataset.text_dataset import create_dataloader
from hyvideo.dataset.transform import text_preprocessing
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.message_utils import print_banner
from hyvideo.utils.ms_utils import init_env

from mindone.utils.config import str2bool
from mindone.utils.logger import set_logger
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from hyvideo.diffusion.net_with_loss import DiffusionWithLoss
from hyvideo.diffusion.schedulers.rectified_flow_trainer import RFlowEvalLoss, RFlowLossWrapper
from hyvideo.modules.models import HUNYUAN_VIDEO_CONFIG, HYVideoDiffusionTransformer
from hyvideo.train.commons import create_loss_scaler, parse_args
from hyvideo.utils.callbacks import EMAEvalSwapCallback, PerfRecorderCallback
from hyvideo.utils.dataset_utils import Collate, LengthGroupedSampler
from hyvideo.utils.ema import EMA
Expand All @@ -43,6 +42,8 @@
from mindone.utils.logger import set_logger
from mindone.utils.params import count_params

from .commons import create_loss_scaler, parse_args

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -206,7 +207,7 @@ def main(args):
# Load text encoder
if not args.text_embed_cache:
print_banner("text encoder init")
from hyvideo.run_text_encoder import build_model
from .run_text_encoder import build_model

text_encoder_1, text_encoder_2 = build_model(args, logger)
text_encoder_1_dtype, text_encoder_2_dtype = args.text_encoder_precision, args.text_encoder_precision_2
Expand Down
7 changes: 7 additions & 0 deletions examples/hunyuanvideo/scripts/vae/recon_image.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Run scripts/run_vae.py in image mode (--input-type image) on a single input
# image at 336x336; the reconstruction is named by --rec-path and placed under
# --output-path. Replace the "path/to/..." placeholders before running.
# The last argument deliberately has no trailing backslash, so nothing appended
# after this command can be swallowed into its argument list.
python scripts/run_vae.py \
    --input-type image \
    --image-path "path/to/input_image.jpg" \
    --output-path "path/to/output_directory" \
    --rec-path "reconstructed_image.jpg" \
    --height 336 \
    --width 336
10 changes: 10 additions & 0 deletions examples/hunyuanvideo/scripts/vae/recon_video.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Run scripts/run_vae.py in video mode (--input-type video) on a single input
# video: 65 frames at sample rate 1, 336x336, written at 30 fps. The
# reconstruction is named by --rec-path and placed under --output-path.
# Replace the "path/to/..." placeholders before running.
# The last argument deliberately has no trailing backslash, so nothing appended
# after this command can be swallowed into its argument list.
python scripts/run_vae.py \
    --input-type video \
    --video-path "path/to/input_video.mp4" \
    --output-path "path/to/output_directory" \
    --rec-path "reconstructed_video.mp4" \
    --height 336 \
    --width 336 \
    --num-frames 65 \
    --sample-rate 1 \
    --fps 30
9 changes: 9 additions & 0 deletions examples/hunyuanvideo/scripts/vae/recon_video_folder.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Run scripts/run_vae.py in folder mode (--input-type folder) over the videos in
# datasets/MCL_JCV/: 65 frames at sample rate 1, 336x336.
# The entry point is addressed as scripts/run_vae.py, matching the sibling
# recon_image.sh / recon_video.sh scripts and the project-root-relative dataset
# paths used below.
# NOTE(review): both --generated-video-dir and --output-path are passed; confirm
# which one run_vae.py uses as the destination in folder mode.
# The last argument deliberately has no trailing backslash, so nothing appended
# after this command can be swallowed into its argument list.
python scripts/run_vae.py \
    --input-type folder \
    --real-video-dir datasets/MCL_JCV/ \
    --generated-video-dir "path/to/generated_videos" \
    --output-path datasets/MCL_JCV_generated/ \
    --height 336 \
    --width 336 \
    --num-frames 65 \
    --sample-rate 1

0 comments on commit 15f06c2

Please sign in to comment.