From 531891ba20f36410621052cc1e08093251c064b0 Mon Sep 17 00:00:00 2001 From: nieyuzhou Date: Wed, 8 Jan 2025 01:55:55 +0800 Subject: [PATCH 1/3] add parallel for vae decoding --- .../pipelines/pipeline_hunyuan_video.py | 17 +- fastvideo/models/hunyuan/inference.py | 1 + .../hunyuan/vae/autoencoder_kl_causal_3d.py | 152 ++++++++++++++++++ fastvideo/sample/sample_t2v_hunyuan.py | 4 + .../inference/inference_diffusers_hunyuan.sh | 3 +- 5 files changed, 166 insertions(+), 11 deletions(-) diff --git a/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py b/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py index 86adda3..d7e90f6 100644 --- a/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py +++ b/fastvideo/models/hunyuan/diffusion/pipelines/pipeline_hunyuan_video.py @@ -373,9 +373,7 @@ def decode_latents(self, latents, enable_tiling=True): latents = 1 / self.vae.config.scaling_factor * latents if enable_tiling: self.vae.enable_tiling() - image = self.vae.decode(latents, return_dict=False)[0] - else: - image = self.vae.decode(latents, return_dict=False)[0] + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 if image.ndim == 4: @@ -605,6 +603,7 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], vae_ver: str = "88-4c-sd", enable_tiling: bool = False, + enable_vae_sp: bool = False, n_tokens: Optional[int] = None, embedded_guidance_scale: Optional[float] = None, **kwargs, @@ -986,13 +985,11 @@ def __call__( enabled=vae_autocast_enabled): if enable_tiling: self.vae.enable_tiling() - image = self.vae.decode(latents, - return_dict=False, - generator=generator)[0] - else: - image = self.vae.decode(latents, - return_dict=False, - generator=generator)[0] + if enable_vae_sp: + self.vae.enable_parallel() + image = self.vae.decode(latents, + return_dict=False, + generator=generator)[0] if expand_temporal_dim or image.shape[2] == 1: image = image.squeeze(2) diff --git a/fastvideo/models/hunyuan/inference.py b/fastvideo/models/hunyuan/inference.py index 17723fa..560664f 100644 --- a/fastvideo/models/hunyuan/inference.py +++ b/fastvideo/models/hunyuan/inference.py @@ -523,6 +523,7 @@ def predict( is_progress_bar=True, vae_ver=self.args.vae, enable_tiling=self.args.vae_tiling, + enable_vae_sp=self.args.vae_sp, )[0] out_dict["samples"] = samples out_dict["prompts"] = prompt diff --git a/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py b/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py index 3d3b530..3ee6e20 100644 --- a/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py +++ b/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py @@ -22,6 +22,11 @@ import torch import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config +from math import prod + +from fastvideo.utils.parallel_states import (get_sequence_parallel_state, + nccl_info) +import torch.distributed as dist try: # This diffusers is modified and packed in the mirror. 
@@ -119,6 +124,7 @@ def __init__( self.use_slicing = False self.use_spatial_tiling = False self.use_temporal_tiling = False + self.use_parallel = False # only relevant if vae tiling is enabled self.tile_sample_min_tsize = sample_tsize @@ -165,6 +171,12 @@ def disable_tiling(self): self.disable_spatial_tiling() self.disable_temporal_tiling() + def enable_parallel(self): + r""" + Enable sequence parallelism for the model. This will allow the vae to decode (with tiling) in parallel. + """ + self.use_parallel = True + def enable_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -319,6 +331,9 @@ def _decode( ) -> Union[DecoderOutput, torch.FloatTensor]: assert len(z.shape) == 5, "The input tensor should have 5 dimensions." + if self.use_parallel: + return self.parallel_tiled_decode(z, return_dict=return_dict) + if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize: return self.temporal_tiled_decode(z, return_dict=return_dict) @@ -591,6 +606,143 @@ def temporal_tiled_decode(self, return DecoderOutput(sample=dec) + def _parallel_data_generator(self, gathered_results, gathered_dim_metadata): + global_idx = 0 + for i, per_rank_metadata in enumerate(gathered_dim_metadata): + _start_shape = 0 + for shape in per_rank_metadata: + mul_shape = prod(shape) + yield ( + gathered_results[i, _start_shape:_start_shape + mul_shape].reshape(shape), + global_idx + ) + _start_shape += mul_shape + global_idx += 1 + + + def parallel_tiled_decode(self, + z: torch.FloatTensor, + return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs + """ + world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group + B, C, T, H, W = z.shape + + # Calculate parameters + t_overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) + t_blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) + t_limit = self.tile_sample_min_tsize - t_blend_extent + + s_overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + s_blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + s_row_limit = self.tile_sample_min_size - s_blend_extent + + # Calculate tile dimensions + num_t_tiles = (T + t_overlap_size - 1) // t_overlap_size + num_h_tiles = (H + s_overlap_size - 1) // s_overlap_size + num_w_tiles = (W + s_overlap_size - 1) // s_overlap_size + total_spatial_tiles = num_h_tiles * num_w_tiles + total_tiles = num_t_tiles * total_spatial_tiles + + # Calculate tiles per rank and padding + tiles_per_rank = (total_tiles + world_size - 1) // world_size + start_tile_idx = rank * tiles_per_rank + end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles) + + local_results = [] + local_dim_metadata = [] + # Process assigned tiles + for local_idx, global_idx in enumerate(range(start_tile_idx, end_tile_idx)): + # Convert flat index to 3D indices + t_idx = global_idx // total_spatial_tiles + spatial_idx = global_idx % total_spatial_tiles + h_idx = spatial_idx // num_w_tiles + w_idx = spatial_idx % num_w_tiles + + # Calculate positions + t_start = t_idx * t_overlap_size + h_start = h_idx * s_overlap_size + w_start = w_idx * s_overlap_size + + # Extract and process tile + tile = z[:, :, + t_start:t_start + self.tile_latent_min_tsize + 1, + h_start:h_start + self.tile_latent_min_size, + w_start:w_start + self.tile_latent_min_size] + + # Process tile + tile = 
self.post_quant_conv(tile) + decoded = self.decoder(tile) + + if t_start > 0: + decoded = decoded[:, :, 1:, :, :] + + # Store metadata + shape = decoded.shape + # Store decoded data (flattened) + decoded_flat = decoded.reshape(-1) + local_results.append(decoded_flat) + local_dim_metadata.append(shape) + + results = torch.cat(local_results, dim=0).contiguous() + del local_results + torch.cuda.empty_cache() + # first gather size to pad the results + local_size = torch.tensor([results.size(0)], device=results.device, dtype = torch.int64) + all_sizes = [torch.zeros(1, device=results.device, dtype = torch.int64) for _ in range(world_size)] + dist.all_gather(all_sizes, local_size) + max_size = max(size.item() for size in all_sizes) + padded_results = torch.zeros(max_size, device=results.device) + padded_results[:results.size(0)] = results + del results + torch.cuda.empty_cache() + # Gather all results + gathered_dim_metadata = [None] * world_size + gathered_results = torch.zeros_like(padded_results).repeat(world_size, *[1] * len(padded_results.shape)).contiguous() # use contiguous to make sure it won't copy data in the following operations + dist.all_gather_into_tensor(gathered_results, padded_results) + dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) + # Process gathered results + data = [[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] for _ in range(num_t_tiles)] + for current_data, global_idx in self._parallel_data_generator(gathered_results, gathered_dim_metadata): + t_idx = global_idx // total_spatial_tiles + spatial_idx = global_idx % total_spatial_tiles + h_idx = spatial_idx // num_w_tiles + w_idx = spatial_idx % num_w_tiles + data[t_idx][h_idx][w_idx] = current_data + # Merge results + result_slices = [] + last_slice_data = None + for i, tem_data in enumerate(data): + slice_data = self._merge_spatial_tiles(tem_data, s_blend_extent, s_row_limit) + if i > 0: + slice_data = self.blend_t(last_slice_data, slice_data, t_blend_extent) + result_slices.append(slice_data[:, :, :t_limit, :, :]) + else: + result_slices.append(slice_data[:, :, :t_limit + 1, :, :]) + last_slice_data = slice_data + dec = torch.cat(result_slices, dim=2) + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + + def _merge_spatial_tiles(self, spatial_rows, blend_extent, row_limit): + """Helper function to merge spatial tiles with blending""" + result_rows = [] + for i, row in enumerate(spatial_rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(spatial_rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=-1)) + return torch.cat(result_rows, dim=-2) + def forward( self, sample: torch.FloatTensor, diff --git a/fastvideo/sample/sample_t2v_hunyuan.py b/fastvideo/sample/sample_t2v_hunyuan.py index b5506f9..f6da866 100644 --- a/fastvideo/sample/sample_t2v_hunyuan.py +++ b/fastvideo/sample/sample_t2v_hunyuan.py @@ -202,6 +202,7 @@ def main(args): default="fp16", choices=["fp32", "fp16", "bf16"]) parser.add_argument("--vae-tiling", action="store_true", default=True) + parser.add_argument("--vae-sp", action="store_true", default=False) parser.add_argument("--text-encoder", type=str, default="llm") parser.add_argument( @@ -234,4 +235,7 @@ def main(args): parser.add_argument("--text-len-2", type=int, default=77) args = parser.parse_args() + # process for vae sequence parallel + if 
args.vae_sp and not args.vae_tiling: + raise ValueError("Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True.") main(args) diff --git a/scripts/inference/inference_diffusers_hunyuan.sh b/scripts/inference/inference_diffusers_hunyuan.sh index 376d225..0bc4b9b 100644 --- a/scripts/inference/inference_diffusers_hunyuan.sh +++ b/scripts/inference/inference_diffusers_hunyuan.sh @@ -17,4 +17,5 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \ --output_path outputs_video/hunyuan_quant/nf4/ \ --model_path $MODEL_BASE \ --quantization "nf4" \ - --cpu_offload \ No newline at end of file + --cpu_offload \ + --vae-sp \ No newline at end of file From 553cee363ad3157fefc4e1ceb9873b0c351475a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Wed, 8 Jan 2025 01:05:10 +0000 Subject: [PATCH 2/3] correct script --- scripts/inference/inference_diffusers_hunyuan.sh | 4 +--- scripts/inference/inference_hunyuan.sh | 5 +++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/scripts/inference/inference_diffusers_hunyuan.sh b/scripts/inference/inference_diffusers_hunyuan.sh index 0bc4b9b..6fc2256 100644 --- a/scripts/inference/inference_diffusers_hunyuan.sh +++ b/scripts/inference/inference_diffusers_hunyuan.sh @@ -16,6 +16,4 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \ --seed 1024 \ --output_path outputs_video/hunyuan_quant/nf4/ \ --model_path $MODEL_BASE \ - --quantization "nf4" \ - --cpu_offload \ - --vae-sp \ No newline at end of file + --quantization "nf4" \ No newline at end of file diff --git a/scripts/inference/inference_hunyuan.sh b/scripts/inference/inference_hunyuan.sh index 0340431..48ee9e3 100644 --- a/scripts/inference/inference_hunyuan.sh +++ b/scripts/inference/inference_hunyuan.sh @@ -14,6 +14,7 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \ --flow-reverse \ --prompt ./assets/prompt.txt \ --seed 1024 \ - --output_path outputs_video/hunyuan/cfg6/ \ + --output_path outputs_video/hunyuan/vae_sp/ \ --model_path $MODEL_BASE \ - --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt + --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \ + --vae-sp From f53208c67ebe2583fdc98a1c2b9708e43b6e8bca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CBrianChen1129=E2=80=9D?= Date: Wed, 8 Jan 2025 01:07:05 +0000 Subject: [PATCH 3/3] lint --- .../hunyuan/vae/autoencoder_kl_causal_3d.py | 71 +++++++++++-------- fastvideo/sample/sample_t2v_hunyuan.py | 4 +- 2 files changed, 45 insertions(+), 30 deletions(-) diff --git a/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py b/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py index 3ee6e20..1461fb1 100644 --- a/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py +++ b/fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py @@ -17,16 +17,15 @@ # # ============================================================================== from dataclasses import dataclass +from math import prod from typing import Dict, Optional, Tuple, Union import torch +import torch.distributed as dist import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config -from math import prod -from fastvideo.utils.parallel_states import (get_sequence_parallel_state, - nccl_info) -import torch.distributed as dist +from fastvideo.utils.parallel_states import nccl_info try: # This diffusers is modified and packed in the mirror. 
@@ -606,20 +605,18 @@ def temporal_tiled_decode(self, return DecoderOutput(sample=dec) - def _parallel_data_generator(self, gathered_results, gathered_dim_metadata): + def _parallel_data_generator(self, gathered_results, + gathered_dim_metadata): global_idx = 0 for i, per_rank_metadata in enumerate(gathered_dim_metadata): _start_shape = 0 for shape in per_rank_metadata: mul_shape = prod(shape) - yield ( - gathered_results[i, _start_shape:_start_shape + mul_shape].reshape(shape), - global_idx - ) + yield (gathered_results[i, _start_shape:_start_shape + + mul_shape].reshape(shape), global_idx) _start_shape += mul_shape global_idx += 1 - def parallel_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True @@ -631,12 +628,16 @@ def parallel_tiled_decode(self, B, C, T, H, W = z.shape # Calculate parameters - t_overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor)) - t_blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor) + t_overlap_size = int(self.tile_latent_min_tsize * + (1 - self.tile_overlap_factor)) + t_blend_extent = int(self.tile_sample_min_tsize * + self.tile_overlap_factor) t_limit = self.tile_sample_min_tsize - t_blend_extent - s_overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) - s_blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + s_overlap_size = int(self.tile_latent_min_size * + (1 - self.tile_overlap_factor)) + s_blend_extent = int(self.tile_sample_min_size * + self.tile_overlap_factor) s_row_limit = self.tile_sample_min_size - s_blend_extent # Calculate tile dimensions @@ -654,7 +655,8 @@ def parallel_tiled_decode(self, local_results = [] local_dim_metadata = [] # Process assigned tiles - for local_idx, global_idx in enumerate(range(start_tile_idx, end_tile_idx)): + for local_idx, global_idx in enumerate( + range(start_tile_idx, end_tile_idx)): # Convert flat index to 3D indices t_idx = global_idx // total_spatial_tiles spatial_idx = global_idx % total_spatial_tiles @@ -667,10 +669,9 @@ def parallel_tiled_decode(self, w_start = w_idx * s_overlap_size # Extract and process tile - tile = z[:, :, - t_start:t_start + self.tile_latent_min_tsize + 1, - h_start:h_start + self.tile_latent_min_size, - w_start:w_start + self.tile_latent_min_size] + tile = z[:, :, t_start:t_start + self.tile_latent_min_tsize + 1, + h_start:h_start + self.tile_latent_min_size, + w_start:w_start + self.tile_latent_min_size] # Process tile tile = self.post_quant_conv(tile) @@ -690,8 +691,13 @@ def parallel_tiled_decode(self, del local_results torch.cuda.empty_cache() # first gather size to pad the results - local_size = torch.tensor([results.size(0)], device=results.device, dtype = torch.int64) - all_sizes = [torch.zeros(1, device=results.device, dtype = torch.int64) for _ in range(world_size)] + local_size = torch.tensor([results.size(0)], + device=results.device, + dtype=torch.int64) + all_sizes = [ + torch.zeros(1, device=results.device, dtype=torch.int64) + for _ in range(world_size) + ] dist.all_gather(all_sizes, local_size) max_size = max(size.item() for size in all_sizes) padded_results = torch.zeros(max_size, device=results.device) @@ -700,12 +706,17 @@ def parallel_tiled_decode(self, torch.cuda.empty_cache() # Gather all results gathered_dim_metadata = [None] * world_size - gathered_results = torch.zeros_like(padded_results).repeat(world_size, *[1] * len(padded_results.shape)).contiguous() # use contiguous to make sure it won't copy data in the following operations + gathered_results = 
torch.zeros_like(padded_results).repeat( + world_size, *[1] * len(padded_results.shape) + ).contiguous( + ) # use contiguous to make sure it won't copy data in the following operations dist.all_gather_into_tensor(gathered_results, padded_results) dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) # Process gathered results - data = [[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] for _ in range(num_t_tiles)] - for current_data, global_idx in self._parallel_data_generator(gathered_results, gathered_dim_metadata): + data = [[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] + for _ in range(num_t_tiles)] + for current_data, global_idx in self._parallel_data_generator( + gathered_results, gathered_dim_metadata): t_idx = global_idx // total_spatial_tiles spatial_idx = global_idx % total_spatial_tiles h_idx = spatial_idx // num_w_tiles @@ -715,9 +726,11 @@ def parallel_tiled_decode(self, result_slices = [] last_slice_data = None for i, tem_data in enumerate(data): - slice_data = self._merge_spatial_tiles(tem_data, s_blend_extent, s_row_limit) + slice_data = self._merge_spatial_tiles(tem_data, s_blend_extent, + s_row_limit) if i > 0: - slice_data = self.blend_t(last_slice_data, slice_data, t_blend_extent) + slice_data = self.blend_t(last_slice_data, slice_data, + t_blend_extent) result_slices.append(slice_data[:, :, :t_limit, :, :]) else: result_slices.append(slice_data[:, :, :t_limit + 1, :, :]) @@ -725,10 +738,9 @@ def parallel_tiled_decode(self, dec = torch.cat(result_slices, dim=2) if not return_dict: - return (dec,) + return (dec, ) return DecoderOutput(sample=dec) - def _merge_spatial_tiles(self, spatial_rows, blend_extent, row_limit): """Helper function to merge spatial tiles with blending""" result_rows = [] @@ -736,7 +748,8 @@ def _merge_spatial_tiles(self, spatial_rows, blend_extent, row_limit): result_row = [] for j, tile in enumerate(row): if i > 0: - tile = self.blend_v(spatial_rows[i - 1][j], tile, blend_extent) + tile = self.blend_v(spatial_rows[i - 1][j], tile, + blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :, :row_limit, :row_limit]) diff --git a/fastvideo/sample/sample_t2v_hunyuan.py b/fastvideo/sample/sample_t2v_hunyuan.py index f6da866..aecb3af 100644 --- a/fastvideo/sample/sample_t2v_hunyuan.py +++ b/fastvideo/sample/sample_t2v_hunyuan.py @@ -237,5 +237,7 @@ def main(args): args = parser.parse_args() # process for vae sequence parallel if args.vae_sp and not args.vae_tiling: - raise ValueError("Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True.") + raise ValueError( + "Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True." + ) main(args)
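
A minimal, standalone sketch (not part of the patch itself) of the tile bookkeeping that parallel_tiled_decode relies on: the flattened (t, h, w) tile grid is split into contiguous per-rank slices, and a flat tile index is mapped back to 3D tile coordinates exactly as in the loop over range(start_tile_idx, end_tile_idx). The numeric tile sizes and overlap factor below are illustrative stand-ins; in the model they come from the VAE attributes tile_latent_min_tsize, tile_latent_min_size and tile_overlap_factor. As the sampler change above enforces, --vae-sp is only valid together with --vae-tiling.

    # Sketch of the per-rank tile partition used by parallel_tiled_decode.
    # Values for tile sizes / overlap are illustrative, not taken from a real config.
    def plan_tiles(T, H, W, world_size,
                   tile_latent_min_tsize=16,
                   tile_latent_min_size=32,
                   tile_overlap_factor=0.25):
        # Stride between tile origins along time and space (same formula as the patch).
        t_overlap = int(tile_latent_min_tsize * (1 - tile_overlap_factor))
        s_overlap = int(tile_latent_min_size * (1 - tile_overlap_factor))

        # Number of tiles along each axis, and the flattened grid size.
        num_t = (T + t_overlap - 1) // t_overlap
        num_h = (H + s_overlap - 1) // s_overlap
        num_w = (W + s_overlap - 1) // s_overlap
        total_spatial = num_h * num_w
        total = num_t * total_spatial

        # Each rank takes a contiguous slice of the flat tile index range.
        tiles_per_rank = (total + world_size - 1) // world_size
        plan = {}
        for rank in range(world_size):
            start = rank * tiles_per_rank
            end = min((rank + 1) * tiles_per_rank, total)
            tiles = []
            for global_idx in range(start, end):
                # Flat index -> (t_idx, h_idx, w_idx), matching the patch.
                t_idx = global_idx // total_spatial
                spatial_idx = global_idx % total_spatial
                h_idx = spatial_idx // num_w
                w_idx = spatial_idx % num_w
                tiles.append((t_idx, h_idx, w_idx))
            plan[rank] = tiles
        return plan

    # Example: a 33x40x64 latent split across 4 ranks; the last rank gets fewer tiles.
    for rank, tiles in plan_tiles(T=33, H=40, W=64, world_size=4).items():
        print(rank, len(tiles), tiles[:3])

After each rank decodes its slice, the patch flattens the decoded tensors, pads them to a common length, all-gathers both the data and the per-tile shapes, and then stitches the full grid back together with the same blend_t / blend_v / blend_h logic as the serial tiled decode.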