update load_model
wtomin committed Feb 12, 2025
1 parent e3eeb82 commit 07e9c81
Showing 5 changed files with 38 additions and 18 deletions.
5 changes: 4 additions & 1 deletion examples/hunyuanvideo/hyvideo/inference.py
@@ -77,7 +77,10 @@ def from_pretrained(cls, pretrained_model_path, args, **kwargs):
        dtype = factor_kwargs["dtype"]
        rank_id = kwargs.get("rank_id", 0)
        model = load_model(
            args,
            name=args.model,
            zero_stage=args.zero_stage,
            text_states_dim=args.text_states_dim,
            text_states_dim_2=args.text_states_dim_2,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
27 changes: 19 additions & 8 deletions examples/hunyuanvideo/hyvideo/modules/__init__.py
@@ -7,31 +7,42 @@
from .models import HUNYUAN_VIDEO_CONFIG, HYVideoDiffusionTransformer


def load_model(args, in_channels, out_channels, factor_kwargs):
def load_model(
    name,
    in_channels,
    out_channels,
    factor_kwargs,
    zero_stage=None,
    text_states_dim: int = 4096,
    text_states_dim_2: int = 768,
):
    """load hunyuan video model
    Args:
        args (dict): model args
        in_channels (int): input channels number
        out_channels (int): output channels number
        zero_stage (int, optional): zero stage. Defaults to None.
        text_states_dim (int, optional): text states dim of text encoder 1. Defaults to 4096.
        text_states_dim_2 (int, optional): text states dim of text encoder 2. Defaults to 768.
        factor_kwargs (dict): factor kwargs
    Returns:
        model (nn.Module): The hunyuan video model
    """
    if args.model in HUNYUAN_VIDEO_CONFIG.keys():
    if name in HUNYUAN_VIDEO_CONFIG.keys():
        model = HYVideoDiffusionTransformer(
            args,
            text_states_dim=text_states_dim,
            text_states_dim_2=text_states_dim_2,
            in_channels=in_channels,
            out_channels=out_channels,
            **HUNYUAN_VIDEO_CONFIG[args.model],
            **HUNYUAN_VIDEO_CONFIG[name],
            **factor_kwargs,
        )
        if args.zero_stage is not None:
            assert args.zero_stage in [0, 1, 2, 3], "zero_stage should be in [0, 1, 2, 3]"
        if zero_stage is not None:
            assert zero_stage in [0, 1, 2, 3], "zero_stage should be in [0, 1, 2, 3]"
            model = prepare_network(
                model,
                zero_stage=args.zero_stage,
                zero_stage=zero_stage,
                op_group=GlobalComm.WORLD_COMM_GROUP,
            )

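For reference, a minimal usage sketch of the refactored signature (illustrative only, not part of this commit): the caller no longer needs an argparse namespace, just the explicit keywords. The channel count, dtype, and attention mode below are assumed placeholders; any key of HUNYUAN_VIDEO_CONFIG works as the model name.

import mindspore as ms

from hyvideo.modules import HUNYUAN_VIDEO_CONFIG, load_model

model_name = next(iter(HUNYUAN_VIDEO_CONFIG))  # any registered architecture key
model = load_model(
    name=model_name,
    in_channels=16,  # assumed VAE latent channel count
    out_channels=16,
    factor_kwargs={"dtype": ms.float16, "attn_mode": "vanilla"},
    zero_stage=None,  # set to 0/1/2/3 only when ZeRO parallelism is initialized
    text_states_dim=4096,  # text encoder 1 embedding dim (default)
    text_states_dim_2=768,  # text encoder 2 embedding dim (default)
)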
16 changes: 9 additions & 7 deletions examples/hunyuanvideo/hyvideo/modules/models.py
@@ -1,6 +1,6 @@
import logging
import os
from typing import Any, List, Optional, Tuple
from typing import List, Optional, Tuple

import mindspore as ms
from mindspore import mint, nn, ops
@@ -380,8 +380,10 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    text_states_dim: int
        The text embedding dim of text encoder 1
    text_states_dim_2: int
        The text embedding dim of text encoder 2
    patch_size: list
        The size of the patch.
    in_channels: int
@@ -425,7 +427,8 @@ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        args: Any,
        text_states_dim: int = 4096,
        text_states_dim_2: int = 768,
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,  # Should be VAE.config.latent_channels.
        out_channels: int = None,
@@ -468,9 +471,8 @@ def __init__(
        self.use_attention_mask = use_attention_mask
        self.text_projection = text_projection

        # TODO: no need to use args, just parse these two params
        self.text_states_dim = args.text_states_dim
        self.text_states_dim_2 = args.text_states_dim_2
        self.text_states_dim = text_states_dim
        self.text_states_dim_2 = text_states_dim_2

        self.param_dtype = dtype

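Because __init__ is wrapped in @register_to_config, the two text-embedding dims are now captured as explicit, named config entries rather than read off an args namespace. A minimal construction sketch under the same assumptions as above (placeholder channel count and dtype; any HUNYUAN_VIDEO_CONFIG key supplies the architecture hyperparameters):

import mindspore as ms

from hyvideo.modules.models import HUNYUAN_VIDEO_CONFIG, HYVideoDiffusionTransformer

cfg_name = next(iter(HUNYUAN_VIDEO_CONFIG))  # any registered architecture key
net = HYVideoDiffusionTransformer(
    text_states_dim=4096,  # text encoder 1 embedding dim
    text_states_dim_2=768,  # text encoder 2 embedding dim
    in_channels=16,  # assumed VAE latent channel count
    out_channels=16,
    **HUNYUAN_VIDEO_CONFIG[cfg_name],  # depth, heads, hidden size, ...
    dtype=ms.float16,  # what load_model forwards through factor_kwargs
    attn_mode="vanilla",
)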
3 changes: 2 additions & 1 deletion examples/hunyuanvideo/tests/test_dit_block.py
@@ -269,7 +269,8 @@ def test_hyvtransformer(pt_ckpt=None, pt_np=None, debug=True, dtype=ms.float32,
    # model
    factor_kwargs = {"dtype": dtype}
    net = HYVideoDiffusionTransformer(
        args,
        text_states_dim=args.text_states_dim,
        text_states_dim_2=args.text_states_dim_2,
        in_channels=C,
        use_conv2d_patchify=True,
        attn_mode="vanilla",
5 changes: 4 additions & 1 deletion examples/hunyuanvideo/tests/test_pipeline.py
@@ -48,7 +48,10 @@ def test_pipeline():
    factor_kwargs = {"dtype": dtype, "attn_mode": "vanilla"}
    print("creating model")
    model = load_model(
        args,
        name=args.model,
        zero_stage=args.zero_stage,
        text_states_dim=args.text_states_dim,
        text_states_dim_2=args.text_states_dim_2,
        in_channels=args.latent_channels,
        out_channels=args.latent_channels,
        factor_kwargs=factor_kwargs,
