Skip to content

Commit

Permalink
dataset updates from moviegen
Browse files Browse the repository at this point in the history
  • Loading branch information
wtomin committed Feb 11, 2025
1 parent 15f2ca9 commit 906ec5d
Show file tree
Hide file tree
Showing 21 changed files with 723 additions and 2,010 deletions.
2 changes: 2 additions & 0 deletions examples/hunyuanvideo/hyvideo/acceleration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .communications import *
from .parallel_states import *
71 changes: 71 additions & 0 deletions examples/hunyuanvideo/hyvideo/acceleration/communications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Callable, Literal, Tuple

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.communication import GlobalComm, get_group_size, get_rank

__all__ = ["SplitFowardGatherBackward", "GatherFowardSplitBackward"]


def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor:
dim_size = x.shape[dim]
tensor_list = x.split(dim_size // world_size, axis=dim)
x = tensor_list[rank]
return x


def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor:
x = x.swapaxes(0, dim)
x = func(x)
x = x.swapaxes(dim, 0)
return x


class SplitFowardGatherBackward(nn.Cell):
def __init__(
self, dim: int = 0, grad_scale: Literal["up", "down"] = "down", group: str = GlobalComm.WORLD_COMM_GROUP
) -> None:
super().__init__()
self.dim = dim
self.rank = get_rank(group)
self.world_size = get_group_size(group)
self.gather = ops.AllGather(group=group)

if grad_scale == "up":
self.scale = self.world_size
else:
self.scale = 1 / self.world_size

def construct(self, x: Tensor) -> Tensor:
return _split(x, self.dim, self.rank, self.world_size)

def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]:
dout = dout * self.scale
dout = _communicate_along_dim(dout, self.dim, self.gather)
return (dout,)


class GatherFowardSplitBackward(nn.Cell):
def __init__(
self, dim: int = 0, grad_scale: Literal["up", "down"] = "up", group: str = GlobalComm.WORLD_COMM_GROUP
) -> None:
super().__init__()
self.dim = dim
self.rank = get_rank(group)
self.world_size = get_group_size(group)
self.gather = ops.AllGather(group=group)

if grad_scale == "up":
self.scale = self.world_size
else:
self.scale = 1 / self.world_size

def construct(self, x: Tensor) -> Tensor:
x = _communicate_along_dim(x, self.dim, self.gather)
return x

def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]:
dout = dout * self.scale
dout = _split(dout, self.dim, self.rank, self.world_size)
return (dout,)
35 changes: 35 additions & 0 deletions examples/hunyuanvideo/hyvideo/acceleration/parallel_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Optional

from mindspore.communication import create_group, get_group_size, get_rank

__all__ = ["set_sequence_parallel_group", "get_sequence_parallel_group", "create_parallel_group"]

_GLOBAL_PARALLEL_GROUPS = dict()


def set_sequence_parallel_group(group: str) -> None:
_GLOBAL_PARALLEL_GROUPS["sequence"] = group


def get_sequence_parallel_group() -> Optional[str]:
return _GLOBAL_PARALLEL_GROUPS.get("sequence", None)


def create_parallel_group(shards: int) -> None:
if shards <= 1:
raise ValueError(
f"The number of sequence parallel shards must be larger than 1 to enable sequence parallel, but got {shards}."
)

device_num = get_group_size()
if device_num % shards != 0:
raise ValueError(
f"Total number of devices ({device_num}) must be divisible by the number of sequence parallel shards ({shards})."
)

rank_id = get_rank()
sp_group_id = rank_id // shards
sp_group_rank_ids = list(range(sp_group_id * shards, (sp_group_id + 1) * shards))
sp_group_name = f"sp_group_{sp_group_id}"
create_group(sp_group_name, sp_group_rank_ids)
set_sequence_parallel_group(sp_group_name)
38 changes: 0 additions & 38 deletions examples/hunyuanvideo/hyvideo/dataset/README.md

This file was deleted.

82 changes: 2 additions & 80 deletions examples/hunyuanvideo/hyvideo/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,2 @@
import logging
from functools import partial

import cv2
from albumentations import Compose, Lambda, Resize, ToFloat
from hyvideo.text_encoder import load_tokenizer

from .t2v_datasets import T2V_dataset
from .transform import TemporalRandomCrop, center_crop_th_tw, maxhxw_resize, spatial_stride_crop_video

logger = logging.getLogger(__name__)


def getdataset(args, dataset_file):
temporal_sample = TemporalRandomCrop(args.num_frames) # 16 x
norm_fun = lambda x: 2.0 * x - 1.0

def norm_func_albumentation(image, **kwargs):
return norm_fun(image)

mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC}
targets = {"image{}".format(i): "image" for i in range(args.num_frames)}

if args.force_resolution:
assert (args.max_height is not None) and (
args.max_width is not None
), "set max_height and max_width for fixed resolution"
resize = [
Lambda(
name="crop_centercrop",
image=partial(center_crop_th_tw, th=args.max_height, tw=args.max_width, top_crop=False),
p=1.0,
),
Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]),
]
else: # dynamic resolution
assert args.max_hxw is not None, "set max_hxw for dynamic resolution"
resize = [
Lambda(
name="maxhxw_resize",
image=partial(maxhxw_resize, max_hxw=args.max_hxw, interpolation_mode=mapping["bilinear"]),
p=1.0,
),
Lambda(
name="spatial_stride_crop",
image=partial(spatial_stride_crop_video, stride=args.hw_stride), # default stride=32
p=1.0,
),
]

transform = Compose(
[*resize, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)],
additional_targets=targets,
)

tokenizer_1, _ = load_tokenizer(
tokenizer_type=args.tokenizer,
tokenizer_path=args.tokenizer_path if args.tokenizer_path is not None else args.text_encoder_path,
padding_side="right",
logger=logger,
)
tokenizer_2, _ = load_tokenizer(
tokenizer_type=args.tokenizer_2,
tokenizer_path=args.tokenizer_path_2 if args.tokenizer_path_2 is not None else args.text_encoder_path_2,
padding_side="right",
logger=logger,
)
if args.dataset == "t2v":
return T2V_dataset(
args,
transform=transform,
temporal_sample=temporal_sample,
tokenizer_1=tokenizer_1,
tokenizer_2=tokenizer_2,
return_text_emb=args.text_embed_cache,
)
elif args.dataset == "inpaint" or args.dataset == "i2v":
raise NotImplementedError

raise NotImplementedError(args.dataset)
from .buckets import bucket_split_function
from .dataset import ImageVideoDataset
13 changes: 13 additions & 0 deletions examples/hunyuanvideo/hyvideo/dataset/buckets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from typing import Callable, List, Tuple

import numpy as np


def bucket_split_function(
image_batch_size: int, video_batch_size: int
) -> Tuple[Callable[[np.ndarray], int], List[int], List[int]]:
return (
lambda x: int(x.shape[0] > 1), # image or video
[1], # 2 buckets for now: image and videos of fixed length
[image_batch_size, video_batch_size],
)
Loading

0 comments on commit 906ec5d

Please sign in to comment.