train 256 script and config
wtomin committed Feb 12, 2025
1 parent 07e9c81 commit fb3349e
Showing 4 changed files with 171 additions and 46 deletions.
115 changes: 115 additions & 0 deletions examples/hunyuanvideo/configs/train/stage1_t2i_256px.yaml
@@ -0,0 +1,115 @@
env:
  mode: 0
  jit_level: O0
  seed: 42
  distributed: False
  debug: False

model:
  name: "HYVideo-T/2-cfgdistill"
  in_channels: 16
  pretrained_model_path:
  zero_stage: 3
  text_states_dim: 4096
  text_states_dim_2: 768
  enable_ms_amp: True
  amp_level: O2
  factor_kwargs:
    dtype: bf16
    use_conv2d_patchify: True
    attn_mode: flash
    use_recompute: True
    num_no_recompute: 0

vae:
  vae_type: "884-16c-hy"
  vae_precision: fp16
  vae_tiling: True

dataset:
  csv_path: CSV_PATH
  video_folder: VIDEO_FOLDER
  text_emb_folder:
    ul2: UL2_FOLDER
    byt5: BYT5_FOLDER
  empty_text_emb:
    ul2: EMPTY_TEXT_EMB
    byt5: EMPTY_TEXT_EMB
  deterministic_sample: False
  text_drop_prob: 0.2
  target_size: [ 256, 455 ]
  apply_transforms_dataset: True
  output_columns: [ "video", "ul2_caption", "byt5_caption" ]

dataloader:
  batch_size: 70
  shuffle: True
  num_workers_dataset: 4

train:
  steps: 30000
  output_path: ../../output/stage1_t2i_256px # the path is relative to this config

  sequence_parallel:
    shards: 1

  lr_scheduler:
    name: constant
    lr: 1.0e-4
    warmup_steps: 1000

  lr_reduce_on_plateau:
    factor: 0.5
    patience: 50 # in the number of validation steps, i.e., valid.frequency * patience steps
    mode: min
    min_delta: 0.01
    min_lr: 1.0e-6

  optimizer:
    name: adamw_re
    eps: 1e-15
    betas: [ 0.9, 0.999 ]
    weight_decay: 0.1

  loss_scaler:
    class_path: mindspore.nn.FixedLossScaleUpdateCell # or DynamicLossScaleUpdateCell in FP16
    init_args:
      loss_scale_value: 1

  ema:
    ema_decay: 0.9999
    offloading: True

  settings:
    zero_stage: 0
    gradient_accumulation_steps: 1
    clip_grad: True
    clip_norm: 1.0

  save:
    ckpt_save_policy: top_k
    monitor_metric: eval_loss_smoothed
    ckpt_save_interval: &save_interval 500
    ckpt_max_keep: 10
    log_interval: 1
    save_ema_only: False
    record_lr: False

valid:
  sampling_steps: 10
  frequency: *save_interval # train.save.ckpt_save_interval should be divisible by the frequency

  dataset:
    csv_path: CSV_PATH
    video_folder: VIDEO_FOLDER
    text_emb_folder:
      ul2: UL2_FOLDER
      byt5: BYT5_FOLDER
    target_size: [ 256, 256 ]
    apply_transforms_dataset: True
    output_columns: [ "video", "ul2_caption", "byt5_caption" ]

  dataloader:
    batch_size: 50
    shuffle: False
    num_workers_dataset: 4
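
A quick way to sanity-check the checkpoint/validation wiring in this config is to load it and confirm that the *save_interval alias picks up the anchored ckpt_save_interval value. The snippet below is a minimal sketch, not part of the commit; it assumes PyYAML is installed and that it is run from the repository root.

# Sketch: verify the YAML anchor/alias wiring of the new stage-1 config.
import yaml

with open("examples/hunyuanvideo/configs/train/stage1_t2i_256px.yaml") as f:
    cfg = yaml.safe_load(f)

save_interval = cfg["train"]["save"]["ckpt_save_interval"]
valid_freq = cfg["valid"]["frequency"]

# The alias resolves to the anchored value, so checkpoint saving and
# validation run on the same 500-step cadence by default.
assert valid_freq == save_interval
assert save_interval % valid_freq == 0  # ckpt_save_interval divisible by frequency
print(save_interval, valid_freq)  # 500 500
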
4 changes: 4 additions & 0 deletions examples/hunyuanvideo/hyvideo/modules/__init__.py
@@ -1,3 +1,5 @@
from hyvideo.constants import PRECISION_TO_TYPE

import mindspore as ms
from mindspore.communication.management import GlobalComm

@@ -48,6 +50,8 @@ def load_model(

    # half model parameter
    dtype = factor_kwargs["dtype"]
    if isinstance(dtype, str):
        dtype = PRECISION_TO_TYPE[dtype]
    if dtype != ms.float32:
        set_model_param_dtype(model, dtype=dtype)

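For context on the two added lines above: factor_kwargs["dtype"] can now arrive either as a string from a YAML config (e.g. "bf16") or as an actual MindSpore dtype, and the isinstance guard normalizes the former before the float32 comparison. The following is an illustrative sketch only; the PRECISION_TO_TYPE mapping is written out locally here and is an assumption about what hyvideo.constants provides.

# Illustrative sketch (not the repo's code) of the dtype normalization above.
import mindspore as ms

# Assumed shape of hyvideo.constants.PRECISION_TO_TYPE: precision string -> MindSpore dtype.
PRECISION_TO_TYPE = {"fp32": ms.float32, "fp16": ms.float16, "bf16": ms.bfloat16}

def normalize_dtype(dtype):
    # Accept either a precision string (as read from the YAML config) or a ms dtype.
    if isinstance(dtype, str):
        dtype = PRECISION_TO_TYPE[dtype]
    return dtype

print(normalize_dtype("bf16") == ms.bfloat16)     # True: string from config is mapped
print(normalize_dtype(ms.float32) == ms.float32)  # True: dtypes pass through unchanged
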
64 changes: 47 additions & 17 deletions examples/hunyuanvideo/hyvideo/utils/model_utils.py
@@ -1,12 +1,15 @@
import logging
from typing import Dict, Literal, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

from hyvideo.constants import PRECISION_TO_TYPE
from hyvideo.modules import load_model
from jsonargparse.typing import Path_fr

import mindspore as ms
from mindspore import _no_grad, jit_class, nn
from mindspore import _no_grad, amp, jit_class, nn

from mindone.trainers.train_step import TrainOneStepWrapper
from mindone.utils.amp import auto_mixed_precision
from mindone.utils.params import load_param_into_net_with_filter

__all__ = ["MODEL_DTYPE", "no_grad", "init_model", "resume_train_net"]
@@ -58,31 +61,58 @@ def __exit__(self, *args):


def init_model(
    name: Literal["llama-1B", "llama-5B", "llama-30B"],
    name: str = "",
    in_channels: int = 16,
    out_channels: int = 16,
    pretrained_model_path: Optional[Path_fr] = None,
    zero_stage: Optional[int] = None,
    text_states_dim: int = 4096,
    text_states_dim_2: int = 768,
    resume: bool = False,
    enable_flash_attention: bool = True,
    recompute_every_nth_block: Optional[int] = None,
    not_recompute_fa: bool = False,
    dtype: Literal["fp32", "fp16", "bf16"] = "fp32",
    factor_kwargs: dict = {},
    use_fp8: bool = False,
    enable_ms_amp: bool = True,
    amp_level: str = "O2",
):
    # attn_implementation = "flash_attention" if enable_flash_attention else "eager"
    # model = MODEL_SPEC[name](
    # in_channels=in_channels,
    # attn_implementation=attn_implementation,
    # recompute_every_nth_block=recompute_every_nth_block,
    # not_recompute_fa=not_recompute_fa,
    # dtype=MODEL_DTYPE[dtype],
    # )

    model = None
    dtype = factor_kwargs["dtype"]
    dtype = PRECISION_TO_TYPE[dtype]
    model = load_model(
        name=name,
        zero_stage=zero_stage,
        text_states_dim=text_states_dim,
        text_states_dim_2=text_states_dim_2,
        in_channels=in_channels,
        out_channels=out_channels,
        factor_kwargs=factor_kwargs,
    )

    if resume:
        logger.info("Resume training checkpoint provided, skipping weight loading.")
    elif pretrained_model_path:
        load_ckpt_params(model, pretrained_model_path.absolute)
    else:
        logger.info(f"Initialize {name} model randomly.")

    if use_fp8:
        raise NotImplementedError("fp8 is not supported yet.")

    if enable_ms_amp and dtype != ms.float32:
        logger.warning(f"Use MS auto mixed precision, amp_level: {amp_level}")
        if amp_level == "auto":
            amp.auto_mixed_precision(model, amp_level=amp_level, dtype=dtype)
        else:
            from hyvideo.modules.embed_layers import SinusoidalEmbedding
            from hyvideo.modules.norm_layers import FP32LayerNorm, LayerNorm, RMSNorm

            whitelist_ops = [
                LayerNorm,
                RMSNorm,
                FP32LayerNorm,
                SinusoidalEmbedding,
            ]
            logger.info("custom fp32 cell for dit: ", whitelist_ops)
            model = auto_mixed_precision(model, amp_level=amp_level, dtype=dtype, custom_fp32_cells=whitelist_ops)

    return model


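As a usage illustration of the reworked init_model above, the call below mirrors the model section of stage1_t2i_256px.yaml. It is a sketch only: it assumes the examples/hunyuanvideo directory is importable (train.py arranges this via sys.path), and it omits pretrained_model_path, so the weights are randomly initialized.

# Sketch: wiring the config's "model" block into the new init_model signature.
from hyvideo.utils import init_model

model = init_model(
    name="HYVideo-T/2-cfgdistill",
    in_channels=16,
    zero_stage=3,
    text_states_dim=4096,
    text_states_dim_2=768,
    enable_ms_amp=True,   # below-fp32 dtypes get wrapped with auto mixed precision
    amp_level="O2",       # anything other than "auto" uses the custom fp32 whitelist above
    factor_kwargs={
        "dtype": "bf16",  # converted via PRECISION_TO_TYPE inside init_model/load_model
        "use_conv2d_patchify": True,
        "attn_mode": "flash",
        "use_recompute": True,
        "num_no_recompute": 0,
    },
)
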
34 changes: 5 additions & 29 deletions examples/hunyuanvideo/scripts/train.py
@@ -16,21 +16,18 @@
sys.path.append(mindone_lib_path)
sys.path.append(os.path.join(__dir__, ".."))
from hyvideo.acceleration import create_parallel_group
from hyvideo.constants import PRECISION_TO_TYPE # , PRECISIONS, PROMPT_TEMPLATE, VAE_PATH
from hyvideo.dataset import ImageVideoDataset, bucket_split_function
from hyvideo.diffusion.pipelines import DiffusionWithLoss
from hyvideo.diffusion.schedulers import RFlowEvalLoss, RFlowLossWrapper
from hyvideo.utils import EMA, init_model, resume_train_net
from hyvideo.utils.callbacks import PerfRecorderCallback, ReduceLROnPlateauByStep, ValidationCallback
from hyvideo.vae import AutoencoderKLCausal3D, load_vae
from hyvideo.vae.unet_causal_3d_blocks import GroupNorm, MSInterpolate, MSPad

from mindone.data import create_dataloader
from mindone.trainers import create_optimizer, create_scheduler
from mindone.trainers.callback import EvalSaveCallback, OverflowMonitor, StopAtStepCallback
from mindone.trainers.zero import prepare_train_network
from mindone.utils import count_params, init_train_env, set_logger
from mindone.utils.amp import auto_mixed_precision

logger = logging.getLogger(__name__)

@@ -99,37 +96,16 @@ def main(args):
    ):
        logger.info("Initializing vae...")
        vae, _, s_ratio, t_ratio = load_vae(
            args.vae,
            args.vae.vae_type,
            logger=logger,
            vae_precision=args.vae.vae_precision,
        )
        # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
        # vae_dtype = PRECISION_TO_TYPE(args.vae.vae_precision)

        if args.vae_tiling:
        if args.vae.vae_tiling:
            vae.enable_tiling()

        if args.vae_precision in ["fp16", "bf16"]:
            amp_level = "O2"
            vae_dtype = PRECISION_TO_TYPE[args.vae_precision]
            if vae_dtype == ms.float16:
                custom_fp32_cells = [GroupNorm] if args.vae_keep_gn_fp32 else []
            else:
                custom_fp32_cells = [MSPad, MSInterpolate]

            vae = auto_mixed_precision(vae, amp_level, vae_dtype, custom_fp32_cells=custom_fp32_cells)
            logger.info(
                f"Set mixed precision to {amp_level} with dtype={args.vae_precision}, custom fp32_cells {custom_fp32_cells}"
            )
        elif args.vae_precision == "fp32":
            vae_dtype = PRECISION_TO_TYPE[args.vae_precision]
        else:
            raise ValueError(f"Unsupported precision {args.vae_precision}")

        if args.model.in_channels != vae.out_channels:
            logger.warning(
                f"The number of model input channels ({args.model.in_channels}) doesn't match the number of vae output"
                f" channels ({vae.out_channels}). Setting it to {vae.out_channels}."
            )
            args.model.in_channels = vae.out_channels
    else:
        logger.info("vae latent folder provided. Skipping vae initialization.")
        vae = None
@@ -299,7 +275,7 @@ def main(args):


if __name__ == "__main__":
    parser = ArgumentParser(description="Movie Gen training script.")
    parser = ArgumentParser(description="Hunyuan Video training script.")
    parser.add_argument(
        "-c",
        "--config",
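A note on the args.vae.* accesses introduced above: the training script builds its CLI with jsonargparse, so a vae: group in the YAML config surfaces as a nested namespace rather than flat vae_* flags. The sketch below is hypothetical (it is not the script's actual parser setup) and only illustrates that access pattern; argument names and defaults are copied from the config.

# Hypothetical sketch of the nested-namespace pattern behind args.vae.vae_type etc.
from jsonargparse import ActionConfigFile, ArgumentParser

parser = ArgumentParser(description="Hunyuan Video training script.")
parser.add_argument("-c", "--config", action=ActionConfigFile)
parser.add_argument("--vae.vae_type", type=str, default="884-16c-hy")
parser.add_argument("--vae.vae_precision", type=str, default="fp16")
parser.add_argument("--vae.vae_tiling", type=bool, default=True)

args = parser.parse_args([])
# Dotted argument names parse into a nested namespace:
print(args.vae.vae_type, args.vae.vae_precision, args.vae.vae_tiling)  # 884-16c-hy fp16 True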
