diff --git a/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py b/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py index 86adda3..d7e90f6 100644 --- a/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py +++ b/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py @@ -373,9 +373,7 @@ def decode_latents(self, latents, enable_tiling=True): latents = 1 / self.vae.config.scaling_factor * latents if enable_tiling: self.vae.enable_tiling() - image = self.vae.decode(latents, return_dict=False)[0] - else: - image = self.vae.decode(latents, return_dict=False)[0] + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 if image.ndim == 4: @@ -605,6 +603,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], vae_ver: str = "88-4c-sd", enable_tiling: bool = False, + enable_vae_sp: bool = False, n_tokens: Optional[int] = None, embedded_guidance_scale: Optional[float] = None, **kwargs, @@ -986,13 +985,11 @@ def __call__( enabled=vae_autocast_enabled): if enable_tiling: self.vae.enable_tiling() - image = self.vae.decode(latents, - return_dict=False, - generator=generator)[0] - else: - image = self.vae.decode(latents, - return_dict=False, - generator=generator)[0] + if enable_vae_sp: + self.vae.enable_parallel() + image = self.vae.decode(latents, + return_dict=False, + generator=generator)[0] if expand_temporal_dim or image.shape[2] == 1: image = image.squeeze(2) diff --git a/fastvideo/models/hunyuan/inference.py b/fastvideo/models/hunyuan/inference.py index 17723fa..560664f 100644 --- a/fastvideo/models/hunyuan/inference.py +++ b/fastvideo/models/hunyuan/inference.py @@ -523,6 +523,7 @@ def predict( is_progress_bar=True, vae_ver=self.args.vae, enable_tiling=self.args.vae_tiling, + enable_vae_sp=self.args.vae_sp, )[0] out_dict["samples"] = samples out_dict["prompts"] = prompt diff --git a/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py b/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py index 3d3b530..1461fb1 100644 --- a/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py +++ b/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py @@ -17,12 +17,16 @@ # # ============================================================================== from dataclasses import dataclass +from math import prod from typing import Dict, Optional, Tuple, Union import torch +import torch.distributed as dist import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config +from fastvideo.utils.parallel_states import nccl_info + try: # This diffusers is modified and packed in the mirror. from diffusers.loaders import FromOriginalVAEMixin @@ -119,6 +123,7 @@ def __init__( self.use_slicing = False self.use_spatial_tiling = False self.use_temporal_tiling = False + self.use_parallel = False # only relevant if vae tiling is enabled self.tile_sample_min_tsize = sample_tsize @@ -165,6 +170,12 @@ def disable_tiling(self): self.disable_spatial_tiling() self.disable_temporal_tiling() + def enable_parallel(self): + r""" + Enable sequence parallelism for the model. This will allow the vae to decode (with tiling) in parallel. + """ + self.use_parallel = True + def enable_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -319,6 +330,9 @@ def _decode( ) -> Union[DecoderOutput, torch.FloatTensor]: assert len(z.shape) == 5, "The input tensor should have 5 dimensions." + if self.use_parallel: + return self.parallel_tiled_decode(z, return_dict=return_dict) + if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: return self.temporal_tiled_decode(z, return_dict=return_dict) @@ -591,6 +605,157 @@ def temporal_tiled_decode(self, return DecoderOutput(sample=dec) + def _parallel_data_generator(self, gathered_results, + gathered_dim_metadata): + global_idx = 0 + for i, per_rank_metadata in enumerate(gathered_dim_metadata): + _start_shape = 0 + for shape in per_rank_metadata: + mul_shape = prod(shape) + yield (gathered_results[i, _start_shape:_start_shape + + mul_shape].reshape(shape), global_idx) + _start_shape += mul_shape + global_idx += 1 + + def parallel_tiled_decode(self, + z: torch.FloatTensor, + return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs + """ + world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group + B, C, T, H, W = z.shape + + # Calculate parameters + t_overlap_size = int(self.tile_latent_min_tsize * + (1 - self.tile_overlap_factor)) + t_blend_extent = int(self.tile_sample_min_tsize * + self.tile_overlap_factor) + t_limit = self.tile_sample_min_tsize - t_blend_extent + + s_overlap_size = int(self.tile_latent_min_size * + (1 - self.tile_overlap_factor)) + s_blend_extent = int(self.tile_sample_min_size * + self.tile_overlap_factor) + s_row_limit = self.tile_sample_min_size - s_blend_extent + + # Calculate tile dimensions + num_t_tiles = (T + t_overlap_size - 1) // t_overlap_size + num_h_tiles = (H + s_overlap_size - 1) // s_overlap_size + num_w_tiles = (W + s_overlap_size - 1) // s_overlap_size + total_spatial_tiles = num_h_tiles * num_w_tiles + total_tiles = num_t_tiles * total_spatial_tiles + + # Calculate tiles per rank and padding + tiles_per_rank = (total_tiles + world_size - 1) // world_size + start_tile_idx = rank * tiles_per_rank + end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles) + + local_results = [] + local_dim_metadata = [] + # Process assigned tiles + for local_idx, global_idx in enumerate( + range(start_tile_idx, end_tile_idx)): + # Convert flat index to 3D indices + t_idx = global_idx // total_spatial_tiles + spatial_idx = global_idx % total_spatial_tiles + h_idx = spatial_idx // num_w_tiles + w_idx = spatial_idx % num_w_tiles + + # Calculate positions + t_start = t_idx * t_overlap_size + h_start = h_idx * s_overlap_size + w_start = w_idx * s_overlap_size + + # Extract and process tile + tile = z[:, :, t_start:t_start + self.tile_latent_min_tsize + 1, + h_start:h_start + self.tile_latent_min_size, + w_start:w_start + self.tile_latent_min_size] + + # Process tile + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + + if t_start > 0: + decoded = decoded[:, :, 1:, :, :] + + # Store metadata + shape = decoded.shape + # Store decoded data (flattened) + decoded_flat = decoded.reshape(-1) + local_results.append(decoded_flat) + local_dim_metadata.append(shape) + + results = torch.cat(local_results, dim=0).contiguous() + del local_results + torch.cuda.empty_cache() + # first gather size to pad the results + local_size = torch.tensor([results.size(0)], + device=results.device, + dtype=torch.int64) + all_sizes = [ + torch.zeros(1, device=results.device, dtype=torch.int64) + for _ in range(world_size) + ] + dist.all_gather(all_sizes, local_size) + max_size = max(size.item() for size in all_sizes) + padded_results = torch.zeros(max_size, device=results.device) + padded_results[:results.size(0)] = results + del results + torch.cuda.empty_cache() + # Gather all results + gathered_dim_metadata = [None] * world_size + gathered_results = torch.zeros_like(padded_results).repeat( + world_size, *[1] * len(padded_results.shape) + ).contiguous( + ) # use contiguous to make sure it won't copy data in the following operations + dist.all_gather_into_tensor(gathered_results, padded_results) + dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) + # Process gathered results + data = [[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] + for _ in range(num_t_tiles)] + for current_data, global_idx in self._parallel_data_generator( + gathered_results, gathered_dim_metadata): + t_idx = global_idx // total_spatial_tiles + spatial_idx = global_idx % total_spatial_tiles + h_idx = spatial_idx // num_w_tiles + w_idx = spatial_idx % num_w_tiles + data[t_idx][h_idx][w_idx] = current_data + # Merge results + result_slices = [] + last_slice_data = None + for i, tem_data in enumerate(data): + slice_data = self._merge_spatial_tiles(tem_data, s_blend_extent, + s_row_limit) + if i > 0: + slice_data = self.blend_t(last_slice_data, slice_data, + t_blend_extent) + result_slices.append(slice_data[:, :, :t_limit, :, :]) + else: + result_slices.append(slice_data[:, :, :t_limit + 1, :, :]) + last_slice_data = slice_data + dec = torch.cat(result_slices, dim=2) + + if not return_dict: + return (dec, ) + return DecoderOutput(sample=dec) + + def _merge_spatial_tiles(self, spatial_rows, blend_extent, row_limit): + """Helper function to merge spatial tiles with blending""" + result_rows = [] + for i, row in enumerate(spatial_rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(spatial_rows[i - 1][j], tile, + blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + return torch.cat(result_rows, dim=-2) + def forward( self, sample: torch.FloatTensor, diff --git a/fastvideo/sample/sample_t2v_hunyuan.py b/fastvideo/sample/sample_t2v_hunyuan.py index b5506f9..aecb3af 100644 --- a/fastvideo/sample/sample_t2v_hunyuan.py +++ b/fastvideo/sample/sample_t2v_hunyuan.py @@ -202,6 +202,7 @@ def main(args): default="fp16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--vae-tiling", action="store_true", default=True) + parser.add_argument("--vae-sp", action="store_true", default=False) parser.add_argument("--text-encoder", type=str, default="llm") parser.add_argument( @@ -234,4 +235,9 @@ def main(args): parser.add_argument("--text-len-2", type=int, default=77) args = parser.parse_args() + # process for vae sequence parallel + if args.vae_sp and not args.vae_tiling: + raise ValueError( + "Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True." + ) main(args) diff --git a/scripts/inference/inference_diffusers_hunyuan.sh b/scripts/inference/inference_diffusers_hunyuan.sh index 376d225..6fc2256 100644 --- a/scripts/inference/inference_diffusers_hunyuan.sh +++ b/scripts/inference/inference_diffusers_hunyuan.sh @@ -16,5 +16,4 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \ --seed 1024 \ --output_path outputs_video/hunyuan_quant/nf4/ \ --model_path $MODEL_BASE \ - --quantization "nf4" \ - --cpu_offload \ No newline at end of file + --quantization "nf4" \ No newline at end of file diff --git a/scripts/inference/inference_hunyuan.sh b/scripts/inference/inference_hunyuan.sh index 0340431..48ee9e3 100644 --- a/scripts/inference/inference_hunyuan.sh +++ b/scripts/inference/inference_hunyuan.sh @@ -14,6 +14,7 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --flow-reverse \ --prompt ./assets/prompt.txt \ --seed 1024 \ - --output_path outputs_video/hunyuan/cfg6/ \ + --output_path outputs_video/hunyuan/vae_sp/ \ --model_path $MODEL_BASE \ - --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt + --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \ + --vae-sp