Showing 21 changed files with 723 additions and 2,010 deletions.
2 changes: 2 additions & 0 deletions
examples/hunyuanvideo/hyvideo/acceleration/__init__.py
@@ -0,0 +1,2 @@
from .communications import *
from .parallel_states import *
71 changes: 71 additions & 0 deletions
examples/hunyuanvideo/hyvideo/acceleration/communications.py
@@ -0,0 +1,71 @@
from typing import Callable, Literal, Tuple

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.communication import GlobalComm, get_group_size, get_rank

__all__ = ["SplitFowardGatherBackward", "GatherFowardSplitBackward"]


def _split(x: Tensor, dim: int, rank: int, world_size: int) -> Tensor:
    # Keep only this rank's chunk along `dim`; x.shape[dim] must be divisible by world_size.
    dim_size = x.shape[dim]
    tensor_list = x.split(dim_size // world_size, axis=dim)
    x = tensor_list[rank]
    return x


def _communicate_along_dim(x: Tensor, dim: int, func: Callable[[Tensor], Tensor]) -> Tensor:
    # AllGather concatenates along axis 0, so move `dim` to the front first.
    x = x.swapaxes(0, dim)
    x = func(x)
    x = x.swapaxes(dim, 0)
    return x


class SplitFowardGatherBackward(nn.Cell):
    """Splits the input along `dim` in the forward pass; all-gathers the gradient in the backward pass."""

    def __init__(
        self, dim: int = 0, grad_scale: Literal["up", "down"] = "down", group: str = GlobalComm.WORLD_COMM_GROUP
    ) -> None:
        super().__init__()
        self.dim = dim
        self.rank = get_rank(group)
        self.world_size = get_group_size(group)
        self.gather = ops.AllGather(group=group)

        if grad_scale == "up":
            self.scale = self.world_size
        else:
            self.scale = 1 / self.world_size

    def construct(self, x: Tensor) -> Tensor:
        return _split(x, self.dim, self.rank, self.world_size)

    def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]:
        dout = dout * self.scale
        dout = _communicate_along_dim(dout, self.dim, self.gather)
        return (dout,)


class GatherFowardSplitBackward(nn.Cell):
    """All-gathers the input along `dim` in the forward pass; splits the gradient in the backward pass."""

    def __init__(
        self, dim: int = 0, grad_scale: Literal["up", "down"] = "up", group: str = GlobalComm.WORLD_COMM_GROUP
    ) -> None:
        super().__init__()
        self.dim = dim
        self.rank = get_rank(group)
        self.world_size = get_group_size(group)
        self.gather = ops.AllGather(group=group)

        if grad_scale == "up":
            self.scale = self.world_size
        else:
            self.scale = 1 / self.world_size

    def construct(self, x: Tensor) -> Tensor:
        x = _communicate_along_dim(x, self.dim, self.gather)
        return x

    def bprop(self, x: Tensor, out: Tensor, dout: Tensor) -> Tuple[Tensor]:
        dout = dout * self.scale
        dout = _split(dout, self.dim, self.rank, self.world_size)
        return (dout,)
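
A minimal usage sketch (not part of the commit), assuming a (batch, sequence, hidden) input and a multi-device launch, e.g. via msrun: shard the sequence axis before a block, compute on the local slice, then gather the full sequence back.

import mindspore as ms
from mindspore.communication import init

init()  # assumes the script is launched across N devices
split_in = SplitFowardGatherBackward(dim=1, grad_scale="down")
gather_out = GatherFowardSplitBackward(dim=1, grad_scale="up")

x = ms.ops.ones((2, 16, 64), ms.float32)  # (batch, sequence, hidden); 16 must be divisible by N
x = split_in(x)    # forward: keep this rank's sequence slice; backward: all-gather + down-scale grads
# ... sequence-parallel computation on the local shard ...
x = gather_out(x)  # forward: all-gather the full sequence; backward: re-split + up-scale grads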
35 changes: 35 additions & 0 deletions
examples/hunyuanvideo/hyvideo/acceleration/parallel_states.py
@@ -0,0 +1,35 @@
from typing import Optional

from mindspore.communication import create_group, get_group_size, get_rank

__all__ = ["set_sequence_parallel_group", "get_sequence_parallel_group", "create_parallel_group"]

_GLOBAL_PARALLEL_GROUPS = dict()


def set_sequence_parallel_group(group: str) -> None:
    _GLOBAL_PARALLEL_GROUPS["sequence"] = group


def get_sequence_parallel_group() -> Optional[str]:
    return _GLOBAL_PARALLEL_GROUPS.get("sequence", None)


def create_parallel_group(shards: int) -> None:
    if shards <= 1:
        raise ValueError(
            f"The number of sequence parallel shards must be larger than 1 to enable sequence parallelism, but got {shards}."
        )

    device_num = get_group_size()
    if device_num % shards != 0:
        raise ValueError(
            f"Total number of devices ({device_num}) must be divisible by the number of sequence parallel shards ({shards})."
        )

    # Consecutive ranks form one sequence-parallel group, e.g. with shards=4:
    # ranks 0-3 -> "sp_group_0", ranks 4-7 -> "sp_group_1", and so on.
    rank_id = get_rank()
    sp_group_id = rank_id // shards
    sp_group_rank_ids = list(range(sp_group_id * shards, (sp_group_id + 1) * shards))
    sp_group_name = f"sp_group_{sp_group_id}"
    create_group(sp_group_name, sp_group_rank_ids)
    set_sequence_parallel_group(sp_group_name)
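
A hedged sketch (not part of the commit) of the intended call order, assuming 8 devices and 4-way sequence parallelism, which yields two independent groups:

from mindspore.communication import init

init()
create_parallel_group(shards=4)           # ranks 0-3 -> "sp_group_0", ranks 4-7 -> "sp_group_1"
sp_group = get_sequence_parallel_group()  # name of the group this rank belongs to
split_in = SplitFowardGatherBackward(dim=1, group=sp_group)  # cell from communications.py above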
This file was deleted.
2 changes: 2 additions & 80 deletions
@@ -1,80 +1,2 @@
-import logging
-from functools import partial
-
-import cv2
-from albumentations import Compose, Lambda, Resize, ToFloat
-from hyvideo.text_encoder import load_tokenizer
-
-from .t2v_datasets import T2V_dataset
-from .transform import TemporalRandomCrop, center_crop_th_tw, maxhxw_resize, spatial_stride_crop_video
-
-logger = logging.getLogger(__name__)
-
-
-def getdataset(args, dataset_file):
-    temporal_sample = TemporalRandomCrop(args.num_frames)  # 16 x
-    norm_fun = lambda x: 2.0 * x - 1.0
-
-    def norm_func_albumentation(image, **kwargs):
-        return norm_fun(image)
-
-    mapping = {"bilinear": cv2.INTER_LINEAR, "bicubic": cv2.INTER_CUBIC}
-    targets = {"image{}".format(i): "image" for i in range(args.num_frames)}
-
-    if args.force_resolution:
-        assert (args.max_height is not None) and (
-            args.max_width is not None
-        ), "set max_height and max_width for fixed resolution"
-        resize = [
-            Lambda(
-                name="crop_centercrop",
-                image=partial(center_crop_th_tw, th=args.max_height, tw=args.max_width, top_crop=False),
-                p=1.0,
-            ),
-            Resize(args.max_height, args.max_width, interpolation=mapping["bilinear"]),
-        ]
-    else:  # dynamic resolution
-        assert args.max_hxw is not None, "set max_hxw for dynamic resolution"
-        resize = [
-            Lambda(
-                name="maxhxw_resize",
-                image=partial(maxhxw_resize, max_hxw=args.max_hxw, interpolation_mode=mapping["bilinear"]),
-                p=1.0,
-            ),
-            Lambda(
-                name="spatial_stride_crop",
-                image=partial(spatial_stride_crop_video, stride=args.hw_stride),  # default stride=32
-                p=1.0,
-            ),
-        ]
-
-    transform = Compose(
-        [*resize, ToFloat(255.0), Lambda(name="ae_norm", image=norm_func_albumentation, p=1.0)],
-        additional_targets=targets,
-    )
-
-    tokenizer_1, _ = load_tokenizer(
-        tokenizer_type=args.tokenizer,
-        tokenizer_path=args.tokenizer_path if args.tokenizer_path is not None else args.text_encoder_path,
-        padding_side="right",
-        logger=logger,
-    )
-    tokenizer_2, _ = load_tokenizer(
-        tokenizer_type=args.tokenizer_2,
-        tokenizer_path=args.tokenizer_path_2 if args.tokenizer_path_2 is not None else args.text_encoder_path_2,
-        padding_side="right",
-        logger=logger,
-    )
-    if args.dataset == "t2v":
-        return T2V_dataset(
-            args,
-            transform=transform,
-            temporal_sample=temporal_sample,
-            tokenizer_1=tokenizer_1,
-            tokenizer_2=tokenizer_2,
-            return_text_emb=args.text_embed_cache,
-        )
-    elif args.dataset == "inpaint" or args.dataset == "i2v":
-        raise NotImplementedError
-
-    raise NotImplementedError(args.dataset)
+from .buckets import bucket_split_function
+from .dataset import ImageVideoDataset
13 changes: 13 additions & 0 deletions
examples/hunyuanvideo/hyvideo/dataset/buckets.py
@@ -0,0 +1,13 @@
from typing import Callable, List, Tuple

import numpy as np


def bucket_split_function(
    image_batch_size: int, video_batch_size: int
) -> Tuple[Callable[[np.ndarray], int], List[int], List[int]]:
    # A single-frame sample hashes to 0 (image bucket); a multi-frame sample
    # hashes to 1 (video bucket). Each bucket gets its own batch size.
    return (
        lambda x: int(x.shape[0] > 1),  # image or video
        [1],  # two buckets for now: images and videos of fixed length
        [image_batch_size, video_batch_size],
    )
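
A hedged sketch (not part of the commit) of how this triple could feed MindSpore's bucket_batch_by_length; the GeneratorDataset source and the "video" column name are assumptions for illustration:

import mindspore.dataset as ds

length_fn, boundaries, batch_sizes = bucket_split_function(image_batch_size=16, video_batch_size=2)
dataset = ds.GeneratorDataset(my_source, column_names=["video"])  # my_source is hypothetical
dataset = dataset.bucket_batch_by_length(
    ["video"],
    bucket_boundaries=boundaries,    # [1]: images hash to 0 (below), videos to 1 (at or above)
    bucket_batch_sizes=batch_sizes,  # [16, 2]: per-bucket batch sizes
    element_length_function=length_fn,
    drop_remainder=True,
)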