add parallel for vae decoding #134

Merged · 3 commits · Jan 8, 2025

@@ -373,9 +373,7 @@ def decode_latents(self, latents, enable_tiling=True):
         latents = 1 / self.vae.config.scaling_factor * latents
         if enable_tiling:
             self.vae.enable_tiling()
-            image = self.vae.decode(latents, return_dict=False)[0]
-        else:
-            image = self.vae.decode(latents, return_dict=False)[0]
+        image = self.vae.decode(latents, return_dict=False)[0]
         image = (image / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         if image.ndim == 4:
@@ -605,6 +603,7 @@ def __call__(
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         vae_ver: str = "88-4c-sd",
         enable_tiling: bool = False,
+        enable_vae_sp: bool = False,
         n_tokens: Optional[int] = None,
         embedded_guidance_scale: Optional[float] = None,
         **kwargs,
@@ -986,13 +985,11 @@ def __call__(
                 enabled=vae_autocast_enabled):
             if enable_tiling:
                 self.vae.enable_tiling()
-                image = self.vae.decode(latents,
-                                        return_dict=False,
-                                        generator=generator)[0]
-            else:
-                image = self.vae.decode(latents,
-                                        return_dict=False,
-                                        generator=generator)[0]
+            if enable_vae_sp:
+                self.vae.enable_parallel()
+            image = self.vae.decode(latents,
+                                    return_dict=False,
+                                    generator=generator)[0]

             if expand_temporal_dim or image.shape[2] == 1:
                 image = image.squeeze(2)
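Taken together, the two hunks above collapse the duplicated decode call into one path and gate the new parallel mode on `enable_vae_sp`. A minimal caller-side sketch (hypothetical: the pipeline object and prompt are placeholders; only the two flags come from this diff):

    # Hypothetical usage; `pipeline` is an instantiated HunyuanVideo pipeline.
    samples = pipeline(
        prompt="a cat walks on the grass",
        enable_tiling=True,   # prerequisite for the parallel path (see CLI check below)
        enable_vae_sp=True,   # new flag added by this PR
    )[0]
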
1 change: 1 addition & 0 deletions fastvideo/models/hunyuan/inference.py
@@ -523,6 +523,7 @@ def predict(
             is_progress_bar=True,
             vae_ver=self.args.vae,
             enable_tiling=self.args.vae_tiling,
+            enable_vae_sp=self.args.vae_sp,
         )[0]
         out_dict["samples"] = samples
         out_dict["prompts"] = prompt
165 changes: 165 additions & 0 deletions fastvideo/models/hunyuan/vae/autoencoder_kl_causal_3d.py
@@ -17,12 +17,16 @@
 #
 # ==============================================================================
 from dataclasses import dataclass
+from math import prod
 from typing import Dict, Optional, Tuple, Union

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config

+from fastvideo.utils.parallel_states import nccl_info
+
 try:
     # This diffusers is modified and packed in the mirror.
     from diffusers.loaders import FromOriginalVAEMixin
@@ -119,6 +123,7 @@ def __init__(
         self.use_slicing = False
         self.use_spatial_tiling = False
         self.use_temporal_tiling = False
+        self.use_parallel = False

         # only relevant if vae tiling is enabled
         self.tile_sample_min_tsize = sample_tsize
@@ -165,6 +170,12 @@ def disable_tiling(self):
         self.disable_spatial_tiling()
         self.disable_temporal_tiling()

+    def enable_parallel(self):
+        r"""
+        Enable sequence parallelism for the model, so that tiled VAE decoding runs in parallel across the sequence-parallel ranks.
+        """
+        self.use_parallel = True
+
     def enable_slicing(self):
         r"""
         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
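As the argument-parsing change below enforces, parallel decode is only used together with tiling. A minimal sketch of the intended call order (assuming the sequence-parallel process group behind `nccl_info` has already been initialized, e.g. by torchrun):

    vae.enable_tiling()    # the parallel path shards the tiled decode
    vae.enable_parallel()  # _decode() now routes to parallel_tiled_decode()
    frames = vae.decode(latents, return_dict=False)[0]
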
@@ -319,6 +330,9 @@ def _decode(
     ) -> Union[DecoderOutput, torch.FloatTensor]:
         assert len(z.shape) == 5, "The input tensor should have 5 dimensions."

+        if self.use_parallel:
+            return self.parallel_tiled_decode(z, return_dict=return_dict)
+
         if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
             return self.temporal_tiled_decode(z, return_dict=return_dict)

@@ -591,6 +605,157 @@ def temporal_tiled_decode(self,

        return DecoderOutput(sample=dec)

    def _parallel_data_generator(self, gathered_results,
                                 gathered_dim_metadata):
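        """Re-slice each rank's padded, flattened gather buffer back into
        per-tile tensors using the gathered shape metadata, yielding
        (tile, global_tile_index) pairs in global tile order."""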
        global_idx = 0
        for i, per_rank_metadata in enumerate(gathered_dim_metadata):
            _start_shape = 0
            for shape in per_rank_metadata:
                mul_shape = prod(shape)
                yield (gathered_results[i, _start_shape:_start_shape +
                                        mul_shape].reshape(shape), global_idx)
                _start_shape += mul_shape
                global_idx += 1

    def parallel_tiled_decode(self,
                              z: torch.FloatTensor,
                              return_dict: bool = True
                              ) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Parallel version of tiled_decode that distributes both temporal and
        spatial computation across GPUs
        """
        world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group
        B, C, T, H, W = z.shape

        # Calculate parameters
        t_overlap_size = int(self.tile_latent_min_tsize *
                             (1 - self.tile_overlap_factor))
        t_blend_extent = int(self.tile_sample_min_tsize *
                             self.tile_overlap_factor)
        t_limit = self.tile_sample_min_tsize - t_blend_extent

        s_overlap_size = int(self.tile_latent_min_size *
                             (1 - self.tile_overlap_factor))
        s_blend_extent = int(self.tile_sample_min_size *
                             self.tile_overlap_factor)
        s_row_limit = self.tile_sample_min_size - s_blend_extent

        # Calculate tile dimensions
        num_t_tiles = (T + t_overlap_size - 1) // t_overlap_size
        num_h_tiles = (H + s_overlap_size - 1) // s_overlap_size
        num_w_tiles = (W + s_overlap_size - 1) // s_overlap_size
        total_spatial_tiles = num_h_tiles * num_w_tiles
        total_tiles = num_t_tiles * total_spatial_tiles

        # Calculate tiles per rank (ceil division) and this rank's slice
        tiles_per_rank = (total_tiles + world_size - 1) // world_size
        start_tile_idx = rank * tiles_per_rank
        end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles)

        # NOTE: this contiguous split assumes total_tiles >= world_size; a rank
        # assigned no tiles would hit torch.cat on an empty list further down.
        local_results = []
        local_dim_metadata = []
        # Process assigned tiles
        for local_idx, global_idx in enumerate(
                range(start_tile_idx, end_tile_idx)):
            # Convert flat index to 3D (t, h, w) tile indices
            t_idx = global_idx // total_spatial_tiles
            spatial_idx = global_idx % total_spatial_tiles
            h_idx = spatial_idx // num_w_tiles
            w_idx = spatial_idx % num_w_tiles

            # Calculate latent-space positions
            t_start = t_idx * t_overlap_size
            h_start = h_idx * s_overlap_size
            w_start = w_idx * s_overlap_size

            # Extract tile (one extra leading latent frame, as in temporal_tiled_decode)
            tile = z[:, :, t_start:t_start + self.tile_latent_min_tsize + 1,
                     h_start:h_start + self.tile_latent_min_size,
                     w_start:w_start + self.tile_latent_min_size]

            # Process tile
            tile = self.post_quant_conv(tile)
            decoded = self.decoder(tile)

            if t_start > 0:
                decoded = decoded[:, :, 1:, :, :]

            # Store shape metadata and the decoded data (flattened)
            shape = decoded.shape
            decoded_flat = decoded.reshape(-1)
            local_results.append(decoded_flat)
            local_dim_metadata.append(shape)

        results = torch.cat(local_results, dim=0).contiguous()
        del local_results
        torch.cuda.empty_cache()
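        # dist.all_gather_into_tensor requires equal-sized tensors on every
        # rank, while each rank may have decoded a different number of tiles,
        # so first exchange buffer sizes and zero-pad everything to the max.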
        local_size = torch.tensor([results.size(0)],
                                  device=results.device,
                                  dtype=torch.int64)
        all_sizes = [
            torch.zeros(1, device=results.device, dtype=torch.int64)
            for _ in range(world_size)
        ]
        dist.all_gather(all_sizes, local_size)
        max_size = max(size.item() for size in all_sizes)
        # NB: padded_results uses the default dtype (float32), so
        # half-precision decodes are upcast before the gather.
        padded_results = torch.zeros(max_size, device=results.device)
        padded_results[:results.size(0)] = results
        del results
        torch.cuda.empty_cache()
        # Gather all results and the per-tile shape metadata
        gathered_dim_metadata = [None] * world_size
        gathered_results = torch.zeros_like(padded_results).repeat(
            world_size, *[1] * len(padded_results.shape)
        ).contiguous()  # ensure a contiguous buffer for all_gather_into_tensor
        dist.all_gather_into_tensor(gathered_results, padded_results)
        dist.all_gather_object(gathered_dim_metadata, local_dim_metadata)
        # Route each gathered tile back to its (t, h, w) grid position
        data = [[[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)]
                for _ in range(num_t_tiles)]
        for current_data, global_idx in self._parallel_data_generator(
                gathered_results, gathered_dim_metadata):
            t_idx = global_idx // total_spatial_tiles
            spatial_idx = global_idx % total_spatial_tiles
            h_idx = spatial_idx // num_w_tiles
            w_idx = spatial_idx % num_w_tiles
            data[t_idx][h_idx][w_idx] = current_data
        result_slices = []
        last_slice_data = None
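        # Merge per-timestep slices: blend each slice into the previous one
        # along T and trim the overlap; the first slice keeps one extra frame
        # (t_limit + 1) because its leading frame was not dropped after decode.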
        for i, tem_data in enumerate(data):
            slice_data = self._merge_spatial_tiles(tem_data, s_blend_extent,
                                                   s_row_limit)
            if i > 0:
                slice_data = self.blend_t(last_slice_data, slice_data,
                                          t_blend_extent)
                result_slices.append(slice_data[:, :, :t_limit, :, :])
            else:
                result_slices.append(slice_data[:, :, :t_limit + 1, :, :])
            last_slice_data = slice_data
        dec = torch.cat(result_slices, dim=2)

        if not return_dict:
            return (dec, )
        return DecoderOutput(sample=dec)

    def _merge_spatial_tiles(self, spatial_rows, blend_extent, row_limit):
        """Helper function to merge spatial tiles with blending"""
        result_rows = []
        for i, row in enumerate(spatial_rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(spatial_rows[i - 1][j], tile,
                                        blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))
        return torch.cat(result_rows, dim=-2)

    def forward(
        self,
        sample: torch.FloatTensor,
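To make the tile bookkeeping above concrete, here is a small self-contained sketch of the flat-index partitioning (illustrative numbers, not taken from the PR):

    from math import ceil

    # 2 temporal x 3 x 3 spatial tiles, distributed over 4 ranks
    num_t_tiles, num_h_tiles, num_w_tiles = 2, 3, 3
    total_spatial_tiles = num_h_tiles * num_w_tiles   # 9
    total_tiles = num_t_tiles * total_spatial_tiles   # 18
    world_size = 4
    tiles_per_rank = ceil(total_tiles / world_size)   # 5

    for rank in range(world_size):
        start = rank * tiles_per_rank
        end = min((rank + 1) * tiles_per_rank, total_tiles)
        assigned = [(g // total_spatial_tiles,                   # t_idx
                     (g % total_spatial_tiles) // num_w_tiles,   # h_idx
                     (g % total_spatial_tiles) % num_w_tiles)    # w_idx
                    for g in range(start, end)]
        print(rank, assigned)
    # rank 0 gets tiles 0-4, rank 1 gets 5-9, rank 2 gets 10-14, and rank 3
    # gets the remaining 15-17: the ceil split leaves the remainder on the
    # last rank, mirroring start_tile_idx/end_tile_idx above.

The same (t_idx, h_idx, w_idx) decomposition is applied twice in the PR: once when a rank decodes its assigned tiles and once when the gathered tiles are routed back into the 3-D grid, which is why _parallel_data_generator only needs to recover the flat global index.
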
6 changes: 6 additions & 0 deletions fastvideo/sample/sample_t2v_hunyuan.py
@@ -202,6 +202,7 @@ def main(args):
                         default="fp16",
                         choices=["fp32", "fp16", "bf16"])
     parser.add_argument("--vae-tiling", action="store_true", default=True)
+    parser.add_argument("--vae-sp", action="store_true", default=False)

     parser.add_argument("--text-encoder", type=str, default="llm")
     parser.add_argument(
@@ -234,4 +235,9 @@ def main(args):
     parser.add_argument("--text-len-2", type=int, default=77)

     args = parser.parse_args()
+    # vae sequence parallelism requires tiled decoding
+    if args.vae_sp and not args.vae_tiling:
+        raise ValueError(
+            "Currently enabling vae_sp requires enabling vae_tiling; please also pass --vae-tiling."
+        )
     main(args)
3 changes: 1 addition & 2 deletions scripts/inference/inference_diffusers_hunyuan.sh
@@ -16,5 +16,4 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 12345 \
     --seed 1024 \
     --output_path outputs_video/hunyuan_quant/nf4/ \
     --model_path $MODEL_BASE \
-    --quantization "nf4" \
-    --cpu_offload
+    --quantization "nf4"
5 changes: 3 additions & 2 deletions scripts/inference/inference_hunyuan.sh
@@ -14,6 +14,7 @@ torchrun --nnodes=1 --nproc_per_node=$num_gpus --master_port 29503 \
     --flow-reverse \
     --prompt ./assets/prompt.txt \
     --seed 1024 \
-    --output_path outputs_video/hunyuan/cfg6/ \
+    --output_path outputs_video/hunyuan/vae_sp/ \
     --model_path $MODEL_BASE \
-    --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt
+    --dit-weight ${MODEL_BASE}/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt \
+    --vae-sp