From e87f8239550cf5eee1a942f63e8b36c63f996ad0 Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Thu, 12 Dec 2024 07:37:43 +0000 Subject: [PATCH 01/23] add cogvideox support for gaudi. --- examples/text-to-video/cogvideox_generate.py | 75 +++ .../pipelines/cogvideox/cogvideoX_gaudi.py | 273 +++++++++ .../cogvideox/pipeline_cogvideox_gaudi.py | 536 ++++++++++++++++++ 3 files changed, 884 insertions(+) create mode 100644 examples/text-to-video/cogvideox_generate.py create mode 100644 optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py create mode 100644 optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py diff --git a/examples/text-to-video/cogvideox_generate.py b/examples/text-to-video/cogvideox_generate.py new file mode 100644 index 0000000000..8848f92654 --- /dev/null +++ b/examples/text-to-video/cogvideox_generate.py @@ -0,0 +1,75 @@ +import argparse +import logging +import sys +from pathlib import Path + +import torch +from pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline +#from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video + +from optimum.habana.transformers.gaudi_configuration import GaudiConfig +from optimum.habana.utils import set_seed +logger = logging.getLogger(__name__) + +prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." + +#prompt = "A 360-degree panoramic view of a lush mountain valley with a flowing river, birds flying across the sky, and a soft orange-pink sunrise." +#prompt = "Spiderman is surfing, Darth Vader is also surfing and following Spiderman" +#prompt = "An astronaut riding a horse" +#prompt = "A drone shot flying above vibrant red and orange foliage with occasional sunlight beams piercing through the canopy." +#prompt = "Skyscrapers with glowing neon signs, flying cars zipping between buildings, and a massive digital billboard displaying a news broadcast." +#prompt = "Bright, surreal waves of color blending and transforming into abstract shapes in rhythm with gentle ambient music." +#prompt = "A first-person view of a runner jumping between rooftops, flipping over obstacles, and climbing walls." 
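+# The commented prompts above are alternative prompts kept for manual testing; only the
+# panda prompt defined first is used. The rest of the script builds a GaudiConfig with
+# bf16 autocast, loads the pipeline with HPU graphs enabled, turns on VAE tiling/slicing
+# to bound decoder memory, generates 49 frames, and exports the video at 8 fps.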
+ +gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True} +gaudi_config_kwargs["use_torch_autocast"] = True + +gaudi_config = GaudiConfig(**gaudi_config_kwargs) +logger.info(f"Gaudi Config: {gaudi_config}") + + +kwargs = { + "use_habana": True, + "use_hpu_graphs": True, + "gaudi_config": gaudi_config, +} +kwargs["torch_dtype"] = torch.bfloat16 + + +print('now to load model.....') +model_path = "/mnt/disk2/libo/hf_models/CogVideoX-2b/" +#model_path = "/mnt/disk2/libo/hf_models/CogVideoX-5b/" +pipe = GaudiCogVideoXPipeline.from_pretrained( + model_path, + **kwargs +) +print('pipe line load done!') + +pipe.vae.enable_tiling() +pipe.vae.enable_slicing() + +print('now to generate video.') +video = pipe( + prompt=prompt, + num_videos_per_prompt=1, + num_inference_steps=50, + num_frames=49, + guidance_scale=6, + generator=torch.Generator(device="cpu").manual_seed(42), +).frames[0] + +print('generate video done!') + +export_to_video(video, "panda_gaudi.mp4", fps=8) +#export_to_video(video, "output_gaudi.mp4", fps=8) +#export_to_video(video, "Spiderman_gaudi.mp4", fps=8) +#export_to_video(video, "astronaut_gaudi.mp4", fps=8) +#export_to_video(video, "drone_gaudi.mp4", fps=8) +#export_to_video(video, "Skyscrapers_gaudi.mp4", fps=8) +#export_to_video(video, "waves_gaudi.mp4", fps=8) + + + + + diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py new file mode 100644 index 0000000000..89e10248e1 --- /dev/null +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -0,0 +1,273 @@ +from typing import Any, Callable, Dict, List, Optional, Union, Tuple +import torch +import torch.nn as nn + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) + + +from diffusers.models.attention import Attention +class CogVideoXAttnProcessorGaudi: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. 
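+    This Gaudi variant computes attention with the Habana FusedSDPA kernel (using its
+    "fast" softmax mode) in place of torch's scaled_dot_product_attention.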
+ """ + + def __init__(self): + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + hidden_states = self.fused_scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_casual=False, scale=None, softmax_mode='fast' + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + +import torch.nn.functional as F +from diffusers.models import attention_processor +attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi + +from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXSafeConv3d +from diffusers.models.autoencoders.vae import DecoderOutput + +class CogVideoXCausalConv3dGaudi(nn.Module): + r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + + Args: + in_channels (`int`): Number of channels in the input tensor. + out_channels (`int`): Number of output channels produced by the convolution. + kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. + stride (`int`, defaults to `1`): Stride of the convolution. + dilation (`int`, defaults to `1`): Dilation rate of the convolution. + pad_mode (`str`, defaults to `"constant"`): Padding mode. 
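+
+    Note: unlike the stock diffusers layer, this Gaudi variant keeps the temporal
+    convolution cache alive across calls and updates it in place (copy_) when the cached
+    shape matches, instead of clearing it after every forward pass.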
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: int = 1, + dilation: int = 1, + pad_mode: str = "constant", + ): + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + + self.temporal_dim = 2 + self.time_kernel_size = time_kernel_size + + stride = (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = CogVideoXSafeConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + self.conv_cache = None + + def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor: + kernel_size = self.time_kernel_size + if kernel_size > 1: + cached_inputs = ( + [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) + ) + inputs = torch.cat(cached_inputs + [inputs], dim=2) + return inputs + + def _clear_fake_context_parallel_cache(self): + del self.conv_cache + self.conv_cache = None + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + inputs = self.fake_context_parallel_forward(inputs) + + #self._clear_fake_context_parallel_cache() + # Note: we could move these to the cpu for a lower maximum memory usage but its only a few + # hundred megabytes and so let's not do it for now + #self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) + + output = self.conv(inputs_pad) + if self.time_kernel_size>1: + if self.conv_cache is not None and self.conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: + self.conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) + else: + self.conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() + return output + +from diffusers.models.autoencoders import autoencoder_kl_cogvideox +autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi + +from diffusers.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +class AutoencoderKLCogVideoXGaudi(AutoencoderKLCogVideoX): + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). 
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + print('run gaudi tiled decode!') + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = num_frames // frame_batch_size + time = [] + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ].clone() + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile = self.decoder(tile) + time.append(tile.clone()) + self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + +from diffusers.models.autoencoders import autoencoder_kl_cogvideox +autoencoder_kl_cogvideox.AutoencoderKLCogVideoX=AutoencoderKLCogVideoXGaudi + + diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py new file mode 100644 index 0000000000..f5b7d7b9ca --- /dev/null +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -0,0 +1,536 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
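+#
+# Gaudi (HPU) port of the diffusers CogVideoX text-to-video pipeline. It adds HPU graph
+# capture/replay for the transformer, CPU-side creation of the initial latent noise, and
+# bf16 autocast driven by the GaudiConfig.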
+ +import inspect +from dataclasses import dataclass +from math import ceil +from typing import Any, Callable, Dict, List, Optional, Union + +import cogvideoX_gaudi + +import numpy as np +import PIL.Image +import torch +from diffusers import CogVideoXPipeline +from transformers import T5EncoderModel, T5Tokenizer +from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils.torch_utils import randn_tensor +from diffusers.utils import BaseOutput +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback + +from diffusers.utils import logging + +from optimum.habana.transformers.gaudi_configuration import GaudiConfig +from optimum.habana.diffusers.pipelines.pipeline_utils import GaudiDiffusionPipeline +import time as tm_perf + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class time_box_t(): + def __init__(self): + self.t0=None + + def start(self): + self.t0 = tm_perf.perf_counter() + + def show_time(self, desc): + torch.hpu.synchronize() + t1 = tm_perf.perf_counter() + duration = t1-self.t0 + self.t0 = t1 + print(f'{desc} duration:{duration:.3f}s') + +@dataclass +class GaudiTextToVideoSDPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class GaudiCogVideoXPipeline(GaudiDiffusionPipeline, CogVideoXPipeline): + r""" + Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py#L84 + """ + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + use_habana: bool = False, + use_hpu_graphs: bool = False, + gaudi_config: Union[str, GaudiConfig] = None, + bf16_full_eval: bool = False, + ): + print(f'GaudiCogVideoXPipeline init use_habana:{use_habana} use_hpu_graphs:{use_hpu_graphs}') + GaudiDiffusionPipeline.__init__( + self, + use_habana, + use_hpu_graphs, + gaudi_config, + bf16_full_eval, + ) + CogVideoXPipeline.__init__( + self, + tokenizer, + text_encoder, + vae, + transformer, + scheduler, + ) + self.to(self._device) + + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def enable_model_cpu_offload(self, *args, **kwargs): + if self.use_habana: + raise NotImplementedError("enable_model_cpu_offload() is not implemented for HPU") + else: + return super().enable_model_cpu_offload(*args, **kwargs) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + # torch.randn is broken on HPU so running it on CPU + rand_device = "cpu" if device.type == "hpu" else device + rand_device = torch.device(rand_device) + latents = randn_tensor(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast): + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + time_box = time_box_t() + time_box.start() + # 0. 
Default height and width to unet + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + time_box.show_time('prepare latents') + + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + outputs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer_hpu( + latent_model_input=latent_model_input, + prompt_embeds=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + ) + + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.use_hpu_graphs: + self.htcore.mark_step() + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if not self.use_hpu_graphs: + self.htcore.mark_step() + time_box.show_time('transformer_hpu') + + #HabanaProfile.stop() + if not output_type == "latent": + #print('baymax now to decode latents') + #latents = latents.to('cpu') + video = self.decode_latents(latents) + time_box.show_time('decode latents') + #print('baymax decode latents done!') + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + time_box.show_time('postprocess_video') + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return GaudiTextToVideoSDPipelineOutput(frames=video) + + @torch.no_grad() + def transformer_hpu(self, latent_model_input, prompt_embeds, timestep, image_rotary_emb): + if self.use_hpu_graphs: + return self.capture_replay(latent_model_input, prompt_embeds, timestep, image_rotary_emb) + else: + return self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + @torch.no_grad() + def capture_replay(self, latent_model_input, prompt_embeds, timestep, image_rotary_emb): + inputs = 
[latent_model_input.clone(), prompt_embeds.clone(), timestep.clone(), image_rotary_emb, False] + h = self.ht.hpu.graphs.input_hash(inputs) + cached = self.cache.get(h) + + if cached is None: + # Capture the graph and cache it + with self.ht.hpu.stream(self.hpu_stream): + graph = self.ht.hpu.HPUGraph() + graph.capture_begin() + outputs = self.transformer( + hidden_states = inputs[0], + encoder_hidden_states = inputs[1], + timestep=inputs[2], + image_rotary_emb=inputs[3], + return_dict=inputs[4] + )[0] + graph.capture_end() + graph_inputs = inputs + graph_outputs = outputs + self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph) + return outputs + + # Replay the cached graph with updated inputs + self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs) + cached.graph.replay() + self.ht.core.hpu.default_stream().synchronize() + + return cached.graph_outputs From fd556caf8206d01d5333c00cb80d0cf103e6dbdd Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Thu, 12 Dec 2024 09:16:37 +0000 Subject: [PATCH 02/23] update README for cogvideX --- examples/text-to-video/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/text-to-video/README.md b/examples/text-to-video/README.md index 1df4e44e59..061c5c7928 100644 --- a/examples/text-to-video/README.md +++ b/examples/text-to-video/README.md @@ -39,3 +39,12 @@ python3 text_to_video_generation.py \ Models that have been validated: - [ali-vilab/text-to-video-ms-1.7b](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b) + +CogvideoX test: +```bash +python3 cogvideo_generate.py \ + --model_name_or_path CogVideoX-2b \ + --output_name gaudi_output.mp4 +``` + + From 6a8e73d71b27f724d4e61bdc95cf5cd2fa1b19f3 Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Thu, 12 Dec 2024 09:26:59 +0000 Subject: [PATCH 03/23] import cogvideo module from optimumu lib --- examples/text-to-video/cogvideox_generate.py | 131 ++++++++++--------- 1 file changed, 71 insertions(+), 60 deletions(-) diff --git a/examples/text-to-video/cogvideox_generate.py b/examples/text-to-video/cogvideox_generate.py index 8848f92654..39811c2267 100644 --- a/examples/text-to-video/cogvideox_generate.py +++ b/examples/text-to-video/cogvideox_generate.py @@ -4,72 +4,83 @@ from pathlib import Path import torch -from pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline -#from diffusers import CogVideoXPipeline +from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline from diffusers.utils import export_to_video from optimum.habana.transformers.gaudi_configuration import GaudiConfig from optimum.habana.utils import set_seed logger = logging.getLogger(__name__) -prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." - -#prompt = "A 360-degree panoramic view of a lush mountain valley with a flowing river, birds flying across the sky, and a soft orange-pink sunrise." 
-#prompt = "Spiderman is surfing, Darth Vader is also surfing and following Spiderman" -#prompt = "An astronaut riding a horse" -#prompt = "A drone shot flying above vibrant red and orange foliage with occasional sunlight beams piercing through the canopy." -#prompt = "Skyscrapers with glowing neon signs, flying cars zipping between buildings, and a massive digital billboard displaying a news broadcast." -#prompt = "Bright, surreal waves of color blending and transforming into abstract shapes in rhythm with gentle ambient music." -#prompt = "A first-person view of a runner jumping between rooftops, flipping over obstacles, and climbing walls." - -gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True} -gaudi_config_kwargs["use_torch_autocast"] = True - -gaudi_config = GaudiConfig(**gaudi_config_kwargs) -logger.info(f"Gaudi Config: {gaudi_config}") - - -kwargs = { - "use_habana": True, - "use_hpu_graphs": True, - "gaudi_config": gaudi_config, -} -kwargs["torch_dtype"] = torch.bfloat16 - - -print('now to load model.....') -model_path = "/mnt/disk2/libo/hf_models/CogVideoX-2b/" -#model_path = "/mnt/disk2/libo/hf_models/CogVideoX-5b/" -pipe = GaudiCogVideoXPipeline.from_pretrained( - model_path, - **kwargs -) -print('pipe line load done!') - -pipe.vae.enable_tiling() -pipe.vae.enable_slicing() - -print('now to generate video.') -video = pipe( - prompt=prompt, - num_videos_per_prompt=1, - num_inference_steps=50, - num_frames=49, - guidance_scale=6, - generator=torch.Generator(device="cpu").manual_seed(42), -).frames[0] - -print('generate video done!') - -export_to_video(video, "panda_gaudi.mp4", fps=8) -#export_to_video(video, "output_gaudi.mp4", fps=8) -#export_to_video(video, "Spiderman_gaudi.mp4", fps=8) -#export_to_video(video, "astronaut_gaudi.mp4", fps=8) -#export_to_video(video, "drone_gaudi.mp4", fps=8) -#export_to_video(video, "Skyscrapers_gaudi.mp4", fps=8) -#export_to_video(video, "waves_gaudi.mp4", fps=8) - - +def main(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + + parser.add_argument( + "--model_name_or_path", + default="/mnt/disk2/libo/hf_models/CogVideoX-2b/", + type=str, + help="Path to pre-trained model", + ) + # Pipeline arguments + parser.add_argument( + "--prompts", + type=str, + nargs="*", + default="A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. 
The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance.", + help="The prompt or prompts to guide the video generation.", + ) + parser.add_argument( + "--output_name", + default="panda_gaudi.mp4", + type=str, + help="Path to pre-trained model", + ) + + args = parser.parse_args() + + gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True} + gaudi_config_kwargs["use_torch_autocast"] = True + + gaudi_config = GaudiConfig(**gaudi_config_kwargs) + logger.info(f"Gaudi Config: {gaudi_config}") + + + kwargs = { + "use_habana": True, + "use_hpu_graphs": True, + "gaudi_config": gaudi_config, + } + kwargs["torch_dtype"] = torch.bfloat16 + + + print('now to load model.....') + pipe = GaudiCogVideoXPipeline.from_pretrained( + args.model_name_or_path, + **kwargs + ) + print('pipe line load done!') + + pipe.vae.enable_tiling() + pipe.vae.enable_slicing() + + print('now to generate video.') + video = pipe( + prompt=args.prompts, + num_videos_per_prompt=1, + num_inference_steps=50, + num_frames=49, + guidance_scale=6, + generator=torch.Generator(device="cpu").manual_seed(42), + ).frames[0] + + print('generate video done!') + + export_to_video(video, args.output_name, fps=8) + + + +if __name__ == "__main__": + main() From 7092a515aa41f642a2621022a5527e4ad97f8c95 Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Thu, 12 Dec 2024 09:54:54 +0000 Subject: [PATCH 04/23] refine test examples --- examples/text-to-video/cogvideox_generate.py | 33 +++++++++---------- .../pipelines/cogvideox/cogvideoX_gaudi.py | 15 ++++++++- .../cogvideox/pipeline_cogvideox_gaudi.py | 31 +++++++---------- 3 files changed, 42 insertions(+), 37 deletions(-) diff --git a/examples/text-to-video/cogvideox_generate.py b/examples/text-to-video/cogvideox_generate.py index 39811c2267..26e4b74c4f 100644 --- a/examples/text-to-video/cogvideox_generate.py +++ b/examples/text-to-video/cogvideox_generate.py @@ -1,24 +1,23 @@ import argparse import logging -import sys -from pathlib import Path import torch -from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline from diffusers.utils import export_to_video +from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline from optimum.habana.transformers.gaudi_configuration import GaudiConfig -from optimum.habana.utils import set_seed + + logger = logging.getLogger(__name__) def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - + parser.add_argument( "--model_name_or_path", - default="/mnt/disk2/libo/hf_models/CogVideoX-2b/", + default="CogVideoX-2b", type=str, help="Path to pre-trained model", ) @@ -38,32 +37,32 @@ def main(): ) args = parser.parse_args() - + gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True} gaudi_config_kwargs["use_torch_autocast"] = True - + gaudi_config = GaudiConfig(**gaudi_config_kwargs) logger.info(f"Gaudi Config: {gaudi_config}") - - + + kwargs = { "use_habana": True, "use_hpu_graphs": True, "gaudi_config": gaudi_config, } kwargs["torch_dtype"] = torch.bfloat16 - - + + print('now to load model.....') pipe = GaudiCogVideoXPipeline.from_pretrained( args.model_name_or_path, **kwargs ) print('pipe line load done!') - + pipe.vae.enable_tiling() pipe.vae.enable_slicing() - + print('now to generate video.') video = pipe( prompt=args.prompts, @@ -73,9 +72,9 @@ def main(): guidance_scale=6, 
generator=torch.Generator(device="cpu").manual_seed(42), ).frames[0] - + print('generate video done!') - + export_to_video(video, args.output_name, fps=8) @@ -83,4 +82,4 @@ def main(): if __name__ == "__main__": main() - + diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index 89e10248e1..e9b598a629 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -1,7 +1,9 @@ -from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from typing import Optional, Tuple, Union + import torch import torch.nn as nn + try: from habana_frameworks.torch.hpex.kernels import FusedSDPA except ImportError: @@ -19,6 +21,8 @@ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, sof from diffusers.models.attention import Attention + + class CogVideoXAttnProcessorGaudi: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -90,11 +94,14 @@ def __call__( import torch.nn.functional as F from diffusers.models import attention_processor + + attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXSafeConv3d from diffusers.models.autoencoders.vae import DecoderOutput + class CogVideoXCausalConv3dGaudi(nn.Module): r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. @@ -181,9 +188,13 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return output from diffusers.models.autoencoders import autoencoder_kl_cogvideox + + autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi from diffusers.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX + + class AutoencoderKLCogVideoXGaudi(AutoencoderKLCogVideoX): def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: r""" @@ -268,6 +279,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod return DecoderOutput(sample=dec) from diffusers.models.autoencoders import autoencoder_kl_cogvideox + + autoencoder_kl_cogvideox.AutoencoderKLCogVideoX=AutoencoderKLCogVideoXGaudi diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index f5b7d7b9ca..7812543a7c 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -13,28 +13,21 @@ # limitations under the License. 
import inspect +import time as tm_perf from dataclasses import dataclass -from math import ceil -from typing import Any, Callable, Dict, List, Optional, Union - -import cogvideoX_gaudi +from typing import Callable, Dict, List, Optional, Union -import numpy as np -import PIL.Image import torch from diffusers import CogVideoXPipeline -from transformers import T5EncoderModel, T5Tokenizer +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import BaseOutput, logging from diffusers.utils.torch_utils import randn_tensor -from diffusers.utils import BaseOutput -from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback - -from diffusers.utils import logging +from transformers import T5EncoderModel, T5Tokenizer -from optimum.habana.transformers.gaudi_configuration import GaudiConfig from optimum.habana.diffusers.pipelines.pipeline_utils import GaudiDiffusionPipeline -import time as tm_perf +from optimum.habana.transformers.gaudi_configuration import GaudiConfig logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -323,7 +316,7 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - time_box = time_box_t() + time_box = time_box_t() time_box.start() # 0. Default height and width to unet height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial @@ -358,7 +351,7 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # 3. Encode input prompt + # 3. 
Encode input prompt prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, negative_prompt, @@ -516,10 +509,10 @@ def capture_replay(self, latent_model_input, prompt_embeds, timestep, image_rota graph = self.ht.hpu.HPUGraph() graph.capture_begin() outputs = self.transformer( - hidden_states = inputs[0], - encoder_hidden_states = inputs[1], - timestep=inputs[2], - image_rotary_emb=inputs[3], + hidden_states = inputs[0], + encoder_hidden_states = inputs[1], + timestep=inputs[2], + image_rotary_emb=inputs[3], return_dict=inputs[4] )[0] graph.capture_end() From 13b86c8cf1bae2ea8f08119fc07581ef0d0b4f72 Mon Sep 17 00:00:00 2001 From: Zhiwei35 Date: Tue, 17 Dec 2024 11:43:51 +0800 Subject: [PATCH 05/23] fix module import defect --- examples/text-to-video/cogvideox_generate.py | 1 + .../habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/examples/text-to-video/cogvideox_generate.py b/examples/text-to-video/cogvideox_generate.py index 26e4b74c4f..4d77c01174 100644 --- a/examples/text-to-video/cogvideox_generate.py +++ b/examples/text-to-video/cogvideox_generate.py @@ -4,6 +4,7 @@ import torch from diffusers.utils import export_to_video +from optimum.habana.diffusers.pipelines.cogvideox.cogvideoX_gaudi import adapt_cogvideo_to_gaudi from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline from optimum.habana.transformers.gaudi_configuration import GaudiConfig diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index e9b598a629..37c1df5d44 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -283,4 +283,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod autoencoder_kl_cogvideox.AutoencoderKLCogVideoX=AutoencoderKLCogVideoXGaudi +import diffusers +def adapt_cogvideo_to_gaudi(): + diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi + diffusers.models.autoencoders.autoencoder_kl_cogvideox.AutoencoderKLCogVideoX = AutoencoderKLCogVideoXGaudi + diffusers.models.attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi + From d125fe699d106d433ebe23dd017943a7c1286fb3 Mon Sep 17 00:00:00 2001 From: Zhiwei35 Date: Tue, 17 Dec 2024 11:45:15 +0800 Subject: [PATCH 06/23] update module import method --- optimum/habana/diffusers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py index 86b6477c0b..234233065f 100644 --- a/optimum/habana/diffusers/__init__.py +++ b/optimum/habana/diffusers/__init__.py @@ -1,3 +1,4 @@ +from .pipelines.cogvideox.cogvideoX_gaudi import adapt_cogvideo_to_gaudi from .pipelines.auto_pipeline import AutoPipelineForInpainting, AutoPipelineForText2Image from .pipelines.controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline from .pipelines.controlnet.pipeline_stable_video_diffusion_controlnet import ( From 4698ded7961c08cb5e1fd3e8a4fc1fcd430ba742 Mon Sep 17 00:00:00 2001 From: Zhiwei35 Date: Tue, 17 Dec 2024 16:33:27 +0800 Subject: [PATCH 07/23] upgrade for diffusers version 0.31.0 --- .../pipelines/cogvideox/cogvideoX_gaudi.py | 39 ++++++++----------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py 
b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index 37c1df5d44..8aa487b65b 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -153,39 +153,30 @@ def __init__( dilation=dilation, ) - self.conv_cache = None - def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor: + def fake_context_parallel_forward( + self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None + ) -> torch.Tensor: kernel_size = self.time_kernel_size if kernel_size > 1: - cached_inputs = ( - [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) - ) + cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) inputs = torch.cat(cached_inputs + [inputs], dim=2) return inputs - def _clear_fake_context_parallel_cache(self): - del self.conv_cache - self.conv_cache = None - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - inputs = self.fake_context_parallel_forward(inputs) - - #self._clear_fake_context_parallel_cache() - # Note: we could move these to the cpu for a lower maximum memory usage but its only a few - # hundred megabytes and so let's not do it for now - #self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: + inputs = self.fake_context_parallel_forward(inputs, conv_cache) + #conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) output = self.conv(inputs_pad) if self.time_kernel_size>1: - if self.conv_cache is not None and self.conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: - self.conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) + if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: + conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) else: - self.conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() - return output + conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() + return output, conv_cache from diffusers.models.autoencoders import autoencoder_kl_cogvideox @@ -237,8 +228,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod for i in range(0, height, overlap_height): row = [] for j in range(0, width, overlap_width): - num_batches = num_frames // frame_batch_size + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None time = [] + for k in range(num_batches): remaining_frames = num_frames % frame_batch_size start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) @@ -252,9 +245,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod ].clone() if self.post_quant_conv is not None: tile = self.post_quant_conv(tile) - tile = self.decoder(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) time.append(tile.clone()) - self._clear_fake_context_parallel_cache() + row.append(torch.cat(time, dim=2)) rows.append(row) From 21caddcd392a46996809e2ab5bea7ae0735c4940 Mon Sep 17 00:00:00 2001 From: ranzhejiang Date: Wed, 18 Dec 2024 18:12:36 +0800 Subject: [PATCH 08/23] add cogVideo test case. 
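
Adds a unit test that builds GaudiCogVideoXPipeline from small dummy components
(T5 text encoder, CogVideoX transformer, CogVideoX VAE and DDIM scheduler) and
checks that a 49-frame video is produced; also drops a few leftover debug
comments from the pipeline.
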
--- .../cogvideox/pipeline_cogvideox_gaudi.py | 3 - tests/test_diffusers.py | 136 ++++++++++++++++++ 2 files changed, 136 insertions(+), 3 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index 7812543a7c..afb8165340 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -466,11 +466,8 @@ def __call__( #HabanaProfile.stop() if not output_type == "latent": - #print('baymax now to decode latents') - #latents = latents.to('cpu') video = self.decode_latents(latents) time_box.show_time('decode latents') - #print('baymax decode latents done!') video = self.video_processor.postprocess_video(video=video, output_type=output_type) time_box.show_time('postprocess_video') else: diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index 97bbb7632d..fd580bbcbe 100755 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -42,6 +42,7 @@ AutoencoderKL, AutoencoderKLTemporalDecoder, AutoencoderTiny, + AutoencoderKLCogVideoX, ControlNetModel, DiffusionPipeline, DPMSolverMultistepScheduler, @@ -59,6 +60,8 @@ UNet3DConditionModel, UNetSpatioTemporalConditionModel, UniPCMultistepScheduler, + CogVideoXTransformer3DModel, + CogVideoXDDIMScheduler, ) from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel @@ -89,6 +92,8 @@ DPTFeatureExtractor, DPTForDepthEstimation, T5EncoderModel, + T5Tokenizer, + T5Config, ) from transformers.testing_utils import parse_flag_from_env, slow @@ -117,6 +122,7 @@ GaudiStableVideoDiffusionControlNetPipeline, GaudiStableVideoDiffusionPipeline, GaudiTextToVideoSDPipeline, + GaudiCogVideoXPipeline, ) from optimum.habana.diffusers.models import ( ControlNetSDVModel, @@ -3767,6 +3773,136 @@ def test_deterministic_image_generation_no_throughput_regression_bf16(self): self.assertGreaterEqual(outputs.throughput, 0.95 * DETERMINISTIC_IMAGE_GENERATION_THROUGHPUT) +class GaudiCogVideoXPipelineTester(TestCase): + """ + Tests the TextToVideoSDPipeline for Gaudi. 
+ Adapted from https://github.com/huggingface/diffusers/blob/v0.24.0-release/tests/pipelines/text_to_video_synthesis/test_text_to_video.py + """ + + def get_dummy_components(self): + tokenizer = T5Tokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + set_seed(0) + text_encoder_cfg = T5Config(vocab_size = 32128, + d_kv = 64, + d_ff = 10240, + num_layers = 8, + num_decoder_layers=8, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + initializer_factor=1.0, + feed_forward_proj='gated-gelu', + is_encoder_decoder=True, + pad_token_id=0, + eos_token_id=1, + torch_dtype = torch.bfloat16, + d_model = 4096) + text_encoder = T5EncoderModel(text_encoder_cfg).bfloat16() + + set_seed(0) + transformer = CogVideoXTransformer3DModel( + num_attention_heads=30, + attention_head_dim=64, + in_channels=16, + out_channels=16, + flip_sin_to_cos=True, + freq_shift=0, + time_embed_dim=512, + text_embed_dim=4096, + num_layers=8, + dropout=0.0, + attention_bias=True, + sample_width=90, + sample_height=60, + sample_frames=49, + patch_size=2, + temporal_compression_ratio=4, + max_text_seq_length=226, + activation_fn="gelu-approximate", + timestep_activation_fn="silu", + norm_elementwise_affine=True, + norm_eps=1e-5, + spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, + ).bfloat16() + + scheduler = CogVideoXDDIMScheduler( + num_train_timesteps=1000, + beta_start = 0.00085, + beta_end = 0.0120, + beta_schedule = "scaled_linear", + clip_sample=False, + set_alpha_to_one = True, + steps_offset=0, + prediction_type = "v_prediction", + clip_sample_range = 1.0, + sample_max_value = 1.0, + timestep_spacing = "trailing", + rescale_betas_zero_snr = True, + snr_shift_scale=1.0, + ) + + + set_seed(0) + vae = AutoencoderKLCogVideoX(in_channels=3, + out_channels = 3, + down_block_types = [ + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D" + ], + block_out_channels = [128,256,256,512], + latent_channels=16, + layers_per_block=1, + act_fn="silu", + norm_eps=1e-6, + norm_num_groups=32, + temporal_compression_ratio=4, + sample_height=480, + sample_width=720, + scaling_factor=1.15258426, + ).bfloat16() + + + vae.enable_slicing() + vae.enable_tiling() + + components = { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "transformer": transformer, + "scheduler": scheduler, + "vae": vae, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + prompts = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance." 
+ return prompts + + def test_cogvideoX_default_case(self): + gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True} + gaudi_config_kwargs["use_torch_autocast"] = True + gaudi_config = GaudiConfig(**gaudi_config_kwargs) + + components = self.get_dummy_components() + components["use_habana"] = True + components["use_hpu_graphs"] = True + components["gaudi_config"] = gaudi_config + + cogVideoX_pipe = GaudiCogVideoXPipeline(**components) + video = pipe( + prompt=prompts, + num_videos_per_prompt=1, + num_inference_steps=5, + num_frames=49, + guidance_scale=6, + generator=torch.Generator(device="cpu").manual_seed(42), + ).frames[0] + + assert video is not None + assert 49 == len(video) class GaudiTextToVideoSDPipelineTester(TestCase): """ From feff2a366af064304fa1a2f226957766a1fce1b9 Mon Sep 17 00:00:00 2001 From: libo7x Date: Wed, 18 Dec 2024 18:42:29 +0800 Subject: [PATCH 09/23] refine model default path --- examples/text-to-video/cogvideox_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/text-to-video/cogvideox_generate.py b/examples/text-to-video/cogvideox_generate.py index 4d77c01174..4b95c0a8ee 100644 --- a/examples/text-to-video/cogvideox_generate.py +++ b/examples/text-to-video/cogvideox_generate.py @@ -18,7 +18,7 @@ def main(): parser.add_argument( "--model_name_or_path", - default="CogVideoX-2b", + default="THUDM/CogVideoX-2b", type=str, help="Path to pre-trained model", ) From ae05af94907e80f5c83fc769d5d8c6eb23cae146 Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Thu, 19 Dec 2024 08:49:48 +0000 Subject: [PATCH 10/23] add required python lib for cogvideo --- examples/text-to-video/requirements.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/text-to-video/requirements.txt b/examples/text-to-video/requirements.txt index 6ab6d0d570..f3e192bbdc 100644 --- a/examples/text-to-video/requirements.txt +++ b/examples/text-to-video/requirements.txt @@ -1 +1,5 @@ opencv-python-headless +sentencepiece +imageio +imageio-ffmpeg + From 12badb89586148947065a214b6cd8860d88392d6 Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Mon, 13 Jan 2025 10:53:22 +0800 Subject: [PATCH 11/23] refine README.MD --- examples/text-to-video/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text-to-video/README.md b/examples/text-to-video/README.md index 061c5c7928..a7ab947b24 100644 --- a/examples/text-to-video/README.md +++ b/examples/text-to-video/README.md @@ -42,8 +42,8 @@ Models that have been validated: CogvideoX test: ```bash -python3 cogvideo_generate.py \ - --model_name_or_path CogVideoX-2b \ +python3 cogvideox_generate.py \ + --model_name_or_path THUDM/CogVideoX-2b \ --output_name gaudi_output.mp4 ``` From 7df1a6c0c1c31c730cda7115dfecc925b6f4f802 Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Wed, 15 Jan 2025 06:46:04 +0000 Subject: [PATCH 12/23] use gaudi implementation of apply rotary embedding. 
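One thing worth flagging in the new test case above (tests/test_diffusers.py): test_cogvideoX_default_case builds cogVideoX_pipe from the dummy components but then calls pipe(prompt=prompts, ...), and neither pipe nor prompts is bound in that method, so the test would raise a NameError as written; the class docstring also still refers to the TextToVideoSDPipeline. Assuming the intent is to run the freshly built pipeline on the dummy prompt from get_dummy_inputs, the tail of the test would presumably look like this sketch:

    cogVideoX_pipe = GaudiCogVideoXPipeline(**components)
    prompt = self.get_dummy_inputs(device="cpu")
    video = cogVideoX_pipe(
        prompt=prompt,
        num_videos_per_prompt=1,
        num_inference_steps=5,
        num_frames=49,
        guidance_scale=6,
        generator=torch.Generator(device="cpu").manual_seed(42),
    ).frames[0]

    assert video is not None
    assert len(video) == 49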
--- .../pipelines/cogvideox/cogvideoX_gaudi.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index 8aa487b65b..5d73bfbe9b 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -22,6 +22,23 @@ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, sof from diffusers.models.attention import Attention +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Adapted from: https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/models/embeddings.py#L697 + """ + cos_, sin_ = freqs_cis # [S, D] + + cos = cos_[None, None] + sin = sin_[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + x = torch.ops.hpu.rotary_pos_embedding(x, sin, cos, None, 0, 1) + + return x + class CogVideoXAttnProcessorGaudi: r""" @@ -70,8 +87,6 @@ def __call__( # Apply RoPE if needed if image_rotary_emb is not None: - from .embeddings import apply_rotary_emb - query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) From 6919313fa76009d9638c1550aa790a868978980c Mon Sep 17 00:00:00 2001 From: "tony.lin@intel.com" Date: Thu, 23 Jan 2025 15:23:58 +0800 Subject: [PATCH 13/23] fix htcore defect --- .../pipelines/cogvideox/pipeline_cogvideox_gaudi.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index afb8165340..8e4c3c7e40 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -28,6 +28,7 @@ from optimum.habana.diffusers.pipelines.pipeline_utils import GaudiDiffusionPipeline from optimum.habana.transformers.gaudi_configuration import GaudiConfig +import habana_frameworks.torch.core as htcore logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -444,7 +445,7 @@ def __call__( latents = latents.to(prompt_embeds.dtype) if not self.use_hpu_graphs: - self.htcore.mark_step() + htcore.mark_step() # call the callback, if provided if callback_on_step_end is not None: @@ -461,7 +462,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if not self.use_hpu_graphs: - self.htcore.mark_step() + htcore.mark_step() time_box.show_time('transformer_hpu') #HabanaProfile.stop() From c15aa511ea32e1758769c5a654c33687dd2ac361 Mon Sep 17 00:00:00 2001 From: "tony.lin@intel.com" Date: Thu, 23 Jan 2025 15:25:01 +0800 Subject: [PATCH 14/23] fix can't find htcore defect. 
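The two patches above replace the upstream apply_rotary_emb import with a Gaudi-specific version built on torch.ops.hpu.rotary_pos_embedding, and route mark_step() through a module-level htcore import instead of self.htcore. For context, the rotation the HPU kernel is assumed to fuse is the standard pairwise RoPE update; a plain-PyTorch reference, adapted from the diffusers use_real code path and not part of the patch, is:

import torch

def rope_reference(x: torch.Tensor, freqs_cis):
    """x: [batch, heads, seq, dim]; freqs_cis: (cos, sin), each of shape [seq, dim]."""
    cos, sin = freqs_cis
    cos = cos[None, None].to(x.device)   # broadcast over batch and heads
    sin = sin[None, None].to(x.device)
    # Rotate each adjacent pair (x1, x2) -> (x1*cos - x2*sin, x2*cos + x1*sin)
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rot = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)
    return (x.float() * cos + x_rot.float() * sin).to(x.dtype)

Doing this as a single fused device op avoids the per-call slicing, stacking and dtype round-trips shown above, which is presumably the motivation for the swap.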
--- .../diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index 8e4c3c7e40..400a16f66f 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -444,7 +444,7 @@ def __call__( ) latents = latents.to(prompt_embeds.dtype) - if not self.use_hpu_graphs: + if self.use_hpu_graphs: htcore.mark_step() # call the callback, if provided @@ -461,7 +461,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if not self.use_hpu_graphs: + if self.use_hpu_graphs: htcore.mark_step() time_box.show_time('transformer_hpu') From 687caf9ff135b6d74a548e7f8014a142e452714e Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Thu, 23 Jan 2025 09:51:07 +0000 Subject: [PATCH 15/23] support for G3 on graph optimization --- .../pipelines/cogvideox/cogvideoX_gaudi.py | 174 +++++++++++++++++- 1 file changed, 172 insertions(+), 2 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index 5d73bfbe9b..210ac631b4 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -291,10 +291,180 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod autoencoder_kl_cogvideox.AutoencoderKLCogVideoX=AutoencoderKLCogVideoXGaudi -import diffusers +from diffusers.utils import USE_PEFT_BACKEND +from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel +import habana_frameworks.torch.core as htcore + +class CogVideoXTransformer3DModelGaudi(CogVideoXTransformer3DModel): + def __init__( + self, + num_attention_heads: int = 30, + attention_head_dim: int = 64, + in_channels: int = 16, + out_channels: Optional[int] = 16, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + time_embed_dim: int = 512, + text_embed_dim: int = 4096, + num_layers: int = 30, + dropout: float = 0.0, + attention_bias: bool = True, + sample_width: int = 90, + sample_height: int = 60, + sample_frames: int = 49, + patch_size: int = 2, + temporal_compression_ratio: int = 4, + max_text_seq_length: int = 226, + activation_fn: str = "gelu-approximate", + timestep_activation_fn: str = "silu", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + spatial_interpolation_scale: float = 1.875, + temporal_interpolation_scale: float = 1.0, + use_rotary_positional_embeddings: bool = False, + use_learned_positional_embeddings: bool = False, + ): + super().__init__( + num_attention_heads, + attention_head_dim, + in_channels, + out_channels, + flip_sin_to_cos, + freq_shift, + time_embed_dim, + text_embed_dim, + num_layers, + dropout, + attention_bias, + sample_width, + sample_height, + sample_frames, + patch_size, + temporal_compression_ratio, + max_text_seq_length, + activation_fn, + timestep_activation_fn, + norm_elementwise_affine, + norm_eps, + spatial_interpolation_scale, + temporal_interpolation_scale, + use_rotary_positional_embeddings, + 
use_learned_positional_embeddings, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + print(f'baymax run gaudi CogVideoXTransformer3DModel forward!') + + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + htcore.mark_step() + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. 
Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + +from diffusers.models.transformers import cogvideox_transformer_3d +cogvideox_transformer_3d.CogVideoXTransformer3DModel = CogVideoXTransformer3DModelGaudi + + def adapt_cogvideo_to_gaudi(): + import diffusers diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi diffusers.models.autoencoders.autoencoder_kl_cogvideox.AutoencoderKLCogVideoX = AutoencoderKLCogVideoXGaudi diffusers.models.attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi + diffusers.models.transformers.cogvideox_transformer_3d.CogVideoXTransformer3DModel = CogVideoXTransformer3DModelGaudi From 339e31f9cf649c302f61db63e21ec69f16b749df Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Thu, 23 Jan 2025 09:52:44 +0000 Subject: [PATCH 16/23] clear debug code, --- optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index 210ac631b4..4c0e64423a 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -397,8 +397,6 @@ def forward( encoder_hidden_states = hidden_states[:, :text_seq_length] hidden_states = hidden_states[:, text_seq_length:] - print(f'baymax run gaudi CogVideoXTransformer3DModel forward!') - # 3. Transformer blocks for i, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: From 4ab7ebeba95bac476da943a28a0a9517225be1d7 Mon Sep 17 00:00:00 2001 From: ranzhejiang Date: Sun, 26 Jan 2025 14:38:52 +0800 Subject: [PATCH 17/23] set transformer gaudi fowrad in pipelines. --- .../pipelines/cogvideox/cogvideoX_gaudi.py | 171 +----------------- .../cogvideox/pipeline_cogvideox_gaudi.py | 108 ++++++++++- 2 files changed, 110 insertions(+), 169 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index 4c0e64423a..64395ea473 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -117,6 +117,7 @@ def __call__( from diffusers.models.autoencoders.vae import DecoderOutput +import habana_frameworks.torch.core as htcore class CogVideoXCausalConv3dGaudi(nn.Module): r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. 
@@ -191,6 +192,7 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) else: conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() + htcore.mark_step() return output, conv_cache from diffusers.models.autoencoders import autoencoder_kl_cogvideox @@ -291,178 +293,11 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod autoencoder_kl_cogvideox.AutoencoderKLCogVideoX=AutoencoderKLCogVideoXGaudi -from diffusers.utils import USE_PEFT_BACKEND -from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXTransformer3DModel -import habana_frameworks.torch.core as htcore - -class CogVideoXTransformer3DModelGaudi(CogVideoXTransformer3DModel): - def __init__( - self, - num_attention_heads: int = 30, - attention_head_dim: int = 64, - in_channels: int = 16, - out_channels: Optional[int] = 16, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - time_embed_dim: int = 512, - text_embed_dim: int = 4096, - num_layers: int = 30, - dropout: float = 0.0, - attention_bias: bool = True, - sample_width: int = 90, - sample_height: int = 60, - sample_frames: int = 49, - patch_size: int = 2, - temporal_compression_ratio: int = 4, - max_text_seq_length: int = 226, - activation_fn: str = "gelu-approximate", - timestep_activation_fn: str = "silu", - norm_elementwise_affine: bool = True, - norm_eps: float = 1e-5, - spatial_interpolation_scale: float = 1.875, - temporal_interpolation_scale: float = 1.0, - use_rotary_positional_embeddings: bool = False, - use_learned_positional_embeddings: bool = False, - ): - super().__init__( - num_attention_heads, - attention_head_dim, - in_channels, - out_channels, - flip_sin_to_cos, - freq_shift, - time_embed_dim, - text_embed_dim, - num_layers, - dropout, - attention_bias, - sample_width, - sample_height, - sample_frames, - patch_size, - temporal_compression_ratio, - max_text_seq_length, - activation_fn, - timestep_activation_fn, - norm_elementwise_affine, - norm_eps, - spatial_interpolation_scale, - temporal_interpolation_scale, - use_rotary_positional_embeddings, - use_learned_positional_embeddings, - ) - - def forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - timestep_cond: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ): - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - - batch_size, num_frames, channels, height, width = hidden_states.shape - - # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - - # 2. 
Patch embedding - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - hidden_states = self.embedding_dropout(hidden_states) - - text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - # 3. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - ) - htcore.mark_step() - - if not self.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = self.norm_final(hidden_states) - else: - # CogVideoX-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = self.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] - - # 4. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 5. Unpatchify - # Note: we use `-1` instead of `channels`: - # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) - # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - -from diffusers.models.transformers import cogvideox_transformer_3d -cogvideox_transformer_3d.CogVideoXTransformer3DModel = CogVideoXTransformer3DModelGaudi - - def adapt_cogvideo_to_gaudi(): import diffusers diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi diffusers.models.autoencoders.autoencoder_kl_cogvideox.AutoencoderKLCogVideoX = AutoencoderKLCogVideoXGaudi diffusers.models.attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi - diffusers.models.transformers.cogvideox_transformer_3d.CogVideoXTransformer3DModel = CogVideoXTransformer3DModelGaudi + #diffusers.models.transformers.cogvideox_transformer_3d.CogVideoXTransformer3DModel = CogVideoXTransformer3DModelGaudi diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index 400a16f66f..cb6fa07a30 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -15,7 +15,7 @@ import inspect import time as tm_perf from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Dict, Optional, 
Tuple, Union, Callable, List, import torch from diffusers import CogVideoXPipeline @@ -121,6 +121,109 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps +from diffusers.utils import USE_PEFT_BACKEND +def gaudi_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + print(f'baymax debug run gaudi transformer forward!') + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + htcore.mark_step() + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. 
Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) class GaudiCogVideoXPipeline(GaudiDiffusionPipeline, CogVideoXPipeline): r""" @@ -156,6 +259,7 @@ def __init__( scheduler, ) self.to(self._device) + self.transformer.forward = gaudi_forward from habana_frameworks.torch.hpu import wrap_in_hpu_graph self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) @@ -488,6 +592,7 @@ def transformer_hpu(self, latent_model_input, prompt_embeds, timestep, image_rot return self.capture_replay(latent_model_input, prompt_embeds, timestep, image_rotary_emb) else: return self.transformer( + self.transformer, hidden_states=latent_model_input, encoder_hidden_states=prompt_embeds, timestep=timestep, @@ -507,6 +612,7 @@ def capture_replay(self, latent_model_input, prompt_embeds, timestep, image_rota graph = self.ht.hpu.HPUGraph() graph.capture_begin() outputs = self.transformer( + self.transformer, hidden_states = inputs[0], encoder_hidden_states = inputs[1], timestep=inputs[2], From c3e253ee91f3241b646803936f967b8a753793ea Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Sun, 26 Jan 2025 10:16:31 +0000 Subject: [PATCH 18/23] set autoencoder tiled decode gaudi wit setattr. --- .../pipelines/cogvideox/cogvideoX_gaudi.py | 95 +----------------- .../cogvideox/pipeline_cogvideox_gaudi.py | 96 ++++++++++++++++++- 2 files changed, 95 insertions(+), 96 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index 64395ea473..faab010bcd 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -117,7 +117,6 @@ def __call__( from diffusers.models.autoencoders.vae import DecoderOutput -import habana_frameworks.torch.core as htcore class CogVideoXCausalConv3dGaudi(nn.Module): r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. @@ -192,7 +191,6 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) else: conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() - htcore.mark_step() return output, conv_cache from diffusers.models.autoencoders import autoencoder_kl_cogvideox @@ -202,101 +200,10 @@ def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = Non from diffusers.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX - -class AutoencoderKLCogVideoXGaudi(AutoencoderKLCogVideoX): - def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - r""" - Decode a batch of images using a tiled decoder. - - Args: - z (`torch.Tensor`): Input batch of latent vectors. 
- return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. - """ - # Rough memory assessment: - # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. - # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. - # - Assume fp16 (2 bytes per value). - # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB - # - # Memory assessment when using tiling: - # - Assume everything as above but now HxW is 240x360 by tiling in half - # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB - - print('run gaudi tiled decode!') - batch_size, num_channels, num_frames, height, width = z.shape - - overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) - overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) - blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) - blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) - row_limit_height = self.tile_sample_min_height - blend_extent_height - row_limit_width = self.tile_sample_min_width - blend_extent_width - frame_batch_size = self.num_latent_frames_batch_size - - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - rows = [] - for i in range(0, height, overlap_height): - row = [] - for j in range(0, width, overlap_width): - num_batches = max(num_frames // frame_batch_size, 1) - conv_cache = None - time = [] - - for k in range(num_batches): - remaining_frames = num_frames % frame_batch_size - start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) - end_frame = frame_batch_size * (k + 1) + remaining_frames - tile = z[ - :, - :, - start_frame:end_frame, - i : i + self.tile_latent_min_height, - j : j + self.tile_latent_min_width, - ].clone() - if self.post_quant_conv is not None: - tile = self.post_quant_conv(tile) - tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) - time.append(tile.clone()) - - row.append(torch.cat(time, dim=2)) - rows.append(row) - - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent_width) - result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) - result_rows.append(torch.cat(result_row, dim=4)) - - dec = torch.cat(result_rows, dim=3) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - -from diffusers.models.autoencoders import autoencoder_kl_cogvideox - - -autoencoder_kl_cogvideox.AutoencoderKLCogVideoX=AutoencoderKLCogVideoXGaudi - def adapt_cogvideo_to_gaudi(): import diffusers diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi - diffusers.models.autoencoders.autoencoder_kl_cogvideox.AutoencoderKLCogVideoX = AutoencoderKLCogVideoXGaudi + #diffusers.models.autoencoders.autoencoder_kl_cogvideox.AutoencoderKLCogVideoX = AutoencoderKLCogVideoXGaudi diffusers.models.attention_processor.CogVideoXAttnProcessor2_0 = 
CogVideoXAttnProcessorGaudi #diffusers.models.transformers.cogvideox_transformer_3d.CogVideoXTransformer3DModel = CogVideoXTransformer3DModelGaudi diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index cb6fa07a30..fc8b897fea 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -15,7 +15,8 @@ import inspect import time as tm_perf from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union, Callable, List, +from typing import Callable, Dict, List, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import torch from diffusers import CogVideoXPipeline @@ -24,6 +25,7 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from diffusers.utils import BaseOutput, logging from diffusers.utils.torch_utils import randn_tensor +from diffusers.models.autoencoders.vae import DecoderOutput from transformers import T5EncoderModel, T5Tokenizer from optimum.habana.diffusers.pipelines.pipeline_utils import GaudiDiffusionPipeline @@ -225,6 +227,93 @@ def custom_forward(*inputs): return (output,) return Transformer2DModelOutput(sample=output) +def tiled_decode_gaudi(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). + # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + print('run gaudi pipelined tiled decode!') + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
+ rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ].clone() + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile.clone()) + htcore.mark_step() + + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + +setattr(AutoencoderKLCogVideoX, 'tiled_decode', tiled_decode_gaudi) + class GaudiCogVideoXPipeline(GaudiDiffusionPipeline, CogVideoXPipeline): r""" Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py#L84 @@ -260,9 +349,12 @@ def __init__( ) self.to(self._device) self.transformer.forward = gaudi_forward + print(f'vae decode name:{self.vae.tiled_decode.__name__}') from habana_frameworks.torch.hpu import wrap_in_hpu_graph - self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) + #self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) + self.vae.tiled_decode = tiled_decode_gaudi + print(f' vae decode name:{self.vae.tiled_decode.__name__} tiled_decode_gaudi:{tiled_decode_gaudi.__name__}') @property def guidance_scale(self): From bb036d6299c2f0e8cccde7e9a34ccab33c07ce4b Mon Sep 17 00:00:00 2001 From: root Date: Fri, 7 Feb 2025 16:46:45 +0800 Subject: [PATCH 19/23] move cogvideox conv3d to gaudi pipeline. --- .../pipelines/cogvideox/cogvideoX_gaudi.py | 92 ------------ .../cogvideox/pipeline_cogvideox_gaudi.py | 131 +++++++++++++++++- 2 files changed, 128 insertions(+), 95 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index faab010bcd..08c8cdd874 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -107,104 +107,12 @@ def __call__( ) return hidden_states, encoder_hidden_states -import torch.nn.functional as F from diffusers.models import attention_processor - attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi -from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXSafeConv3d -from diffusers.models.autoencoders.vae import DecoderOutput - - -class CogVideoXCausalConv3dGaudi(nn.Module): - r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. - - Args: - in_channels (`int`): Number of channels in the input tensor. 
- out_channels (`int`): Number of output channels produced by the convolution. - kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. - stride (`int`, defaults to `1`): Stride of the convolution. - dilation (`int`, defaults to `1`): Dilation rate of the convolution. - pad_mode (`str`, defaults to `"constant"`): Padding mode. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int, int]], - stride: int = 1, - dilation: int = 1, - pad_mode: str = "constant", - ): - super().__init__() - - if isinstance(kernel_size, int): - kernel_size = (kernel_size,) * 3 - - time_kernel_size, height_kernel_size, width_kernel_size = kernel_size - - self.pad_mode = pad_mode - time_pad = dilation * (time_kernel_size - 1) + (1 - stride) - height_pad = height_kernel_size // 2 - width_pad = width_kernel_size // 2 - - self.height_pad = height_pad - self.width_pad = width_pad - self.time_pad = time_pad - self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) - - self.temporal_dim = 2 - self.time_kernel_size = time_kernel_size - - stride = (stride, 1, 1) - dilation = (dilation, 1, 1) - self.conv = CogVideoXSafeConv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - ) - - - def fake_context_parallel_forward( - self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None - ) -> torch.Tensor: - kernel_size = self.time_kernel_size - if kernel_size > 1: - cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1) - inputs = torch.cat(cached_inputs + [inputs], dim=2) - return inputs - - def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: - inputs = self.fake_context_parallel_forward(inputs, conv_cache) - #conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() - - padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) - - output = self.conv(inputs_pad) - if self.time_kernel_size>1: - if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: - conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) - else: - conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() - return output, conv_cache - -from diffusers.models.autoencoders import autoencoder_kl_cogvideox - - -autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi - -from diffusers.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX - def adapt_cogvideo_to_gaudi(): import diffusers - diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi - #diffusers.models.autoencoders.autoencoder_kl_cogvideox.AutoencoderKLCogVideoX = AutoencoderKLCogVideoXGaudi diffusers.models.attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi - #diffusers.models.transformers.cogvideox_transformer_3d.CogVideoXTransformer3DModel = CogVideoXTransformer3DModelGaudi diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index fc8b897fea..ca679aa146 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -50,6 +50,112 @@ def 
show_time(self, desc): self.t0 = t1 print(f'{desc} duration:{duration:.3f}s') +#try: +# from habana_frameworks.torch.hpex.kernels import FusedSDPA +#except ImportError: +# print("Not using HPU fused scaled dot-product attention kernel.") +# FusedSDPA = None +# +## FusedScaledDotProductAttention +#class ModuleFusedSDPA(torch.nn.Module): +# def __init__(self, fusedSDPA): +# super().__init__() +# self._hpu_kernel_fsdpa = fusedSDPA +# +# def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): +# return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) +# +# +#from diffusers.models.attention import Attention +# +#def apply_rotary_emb( +# x: torch.Tensor, +# freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], +#) -> Tuple[torch.Tensor, torch.Tensor]: +# """ +# Adapted from: https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/models/embeddings.py#L697 +# """ +# cos_, sin_ = freqs_cis # [S, D] +# +# cos = cos_[None, None] +# sin = sin_[None, None] +# cos, sin = cos.to(x.device), sin.to(x.device) +# +# x = torch.ops.hpu.rotary_pos_embedding(x, sin, cos, None, 0, 1) +# +# return x +# +#class CogVideoXAttnProcessorGaudi: +# r""" +# Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on +# query and key vectors, but does not include spatial normalization. +# """ +# +# def __init__(self): +# self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None +# +# def __call__( +# self, +# attn: Attention, +# hidden_states: torch.Tensor, +# encoder_hidden_states: torch.Tensor, +# attention_mask: Optional[torch.Tensor] = None, +# image_rotary_emb: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# print(f'run gaudi transformer attention_processor with fused SDPA!') +# text_seq_length = encoder_hidden_states.size(1) +# +# hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) +# +# batch_size, sequence_length, _ = ( +# hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape +# ) +# +# if attention_mask is not None: +# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) +# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) +# +# query = attn.to_q(hidden_states) +# key = attn.to_k(hidden_states) +# value = attn.to_v(hidden_states) +# +# inner_dim = key.shape[-1] +# head_dim = inner_dim // attn.heads +# +# query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) +# key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) +# value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) +# +# if attn.norm_q is not None: +# query = attn.norm_q(query) +# if attn.norm_k is not None: +# key = attn.norm_k(key) +# +# # Apply RoPE if needed +# if image_rotary_emb is not None: +# query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) +# if not attn.is_cross_attention: +# key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) +# +# hidden_states = self.fused_scaled_dot_product_attention( +# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_casual=False, scale=None, softmax_mode='fast' +# ) +# +# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) +# +# # linear proj +# hidden_states = attn.to_out[0](hidden_states) +# # dropout +# 
hidden_states = attn.to_out[1](hidden_states) +# +# encoder_hidden_states, hidden_states = hidden_states.split( +# [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 +# ) +# return hidden_states, encoder_hidden_states +# +#from diffusers.models import attention_processor +#setattr(attention_processor, 'CogVideoXAttnProcessor2_0', CogVideoXAttnProcessorGaudi) + @dataclass class GaudiTextToVideoSDPipelineOutput(BaseOutput): r""" @@ -314,6 +420,26 @@ def tiled_decode_gaudi(self, z: torch.Tensor, return_dict: bool = True) -> Union setattr(AutoencoderKLCogVideoX, 'tiled_decode', tiled_decode_gaudi) +import torch.nn.functional as F +def CogVideoXCausalConv3dforwardGaudi(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: + #print('run gaudi CogVideoXCausalConv3d forward!') + inputs = self.fake_context_parallel_forward(inputs, conv_cache) + #conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) + + output = self.conv(inputs_pad) + if self.time_kernel_size>1: + if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: + conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) + else: + conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() + return output, conv_cache + +from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d +setattr(CogVideoXCausalConv3d, 'forward', CogVideoXCausalConv3dforwardGaudi) + class GaudiCogVideoXPipeline(GaudiDiffusionPipeline, CogVideoXPipeline): r""" Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py#L84 @@ -352,9 +478,8 @@ def __init__( print(f'vae decode name:{self.vae.tiled_decode.__name__}') from habana_frameworks.torch.hpu import wrap_in_hpu_graph - #self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) - self.vae.tiled_decode = tiled_decode_gaudi - print(f' vae decode name:{self.vae.tiled_decode.__name__} tiled_decode_gaudi:{tiled_decode_gaudi.__name__}') + self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) + #print(f' vae decode name:{self.vae.tiled_decode.__name__} tiled_decode_gaudi:{tiled_decode_gaudi.__name__}') @property def guidance_scale(self): From 65fb0ed60b89f1211c6a4e3e9923f3ad81827a43 Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Sat, 8 Feb 2025 10:14:15 +0000 Subject: [PATCH 20/23] remove import gaudi function in __init__.py --- optimum/habana/diffusers/__init__.py | 1 - .../pipelines/cogvideox/cogvideoX_gaudi.py | 216 +++++++++++- .../cogvideox/pipeline_cogvideox_gaudi.py | 326 +----------------- 3 files changed, 217 insertions(+), 326 deletions(-) diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py index 234233065f..86b6477c0b 100644 --- a/optimum/habana/diffusers/__init__.py +++ b/optimum/habana/diffusers/__init__.py @@ -1,4 +1,3 @@ -from .pipelines.cogvideox.cogvideoX_gaudi import adapt_cogvideo_to_gaudi from .pipelines.auto_pipeline import AutoPipelineForInpainting, AutoPipelineForText2Image from .pipelines.controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline from .pipelines.controlnet.pipeline_stable_video_diffusion_controlnet import ( diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py index 
08c8cdd874..5cbfc6427c 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py @@ -2,6 +2,10 @@ import torch import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.attention import Attention +from diffusers.models.autoencoders.vae import DecoderOutput +from diffusers.utils import USE_PEFT_BACKEND try: @@ -20,7 +24,6 @@ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, sof return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) -from diffusers.models.attention import Attention def apply_rotary_emb( x: torch.Tensor, @@ -39,7 +42,6 @@ def apply_rotary_emb( return x - class CogVideoXAttnProcessorGaudi: r""" Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on @@ -57,6 +59,7 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, ) -> torch.Tensor: + print(f'run gaudi transformer attention_processor with fused SDPA!') text_seq_length = encoder_hidden_states.size(1) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -107,12 +110,209 @@ def __call__( ) return hidden_states, encoder_hidden_states -from diffusers.models import attention_processor - -attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi +def cogvideoXTransformerForwardGaudi( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + print(f'baymax debug run gaudi transformer forward!') + # 3. 
Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + htcore.mark_step() + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + +import habana_frameworks.torch.core as htcore +def tiled_decode_gaudi(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. -def adapt_cogvideo_to_gaudi(): - import diffusers - diffusers.models.attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). 
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + print('run gaudi pipelined tiled decode!') + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ].clone() + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile.clone()) + htcore.mark_step() + + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + +def CogVideoXCausalConv3dforwardGaudi(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: + #print('run gaudi CogVideoXCausalConv3d forward!') + inputs = self.fake_context_parallel_forward(inputs, conv_cache) + #conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) + + output = self.conv(inputs_pad) + if self.time_kernel_size>1: + if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: + conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) + else: + conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() + return output, conv_cache diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index ca679aa146..1cc13382a4 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ 
b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -25,9 +25,10 @@ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from diffusers.utils import BaseOutput, logging from diffusers.utils.torch_utils import randn_tensor -from diffusers.models.autoencoders.vae import DecoderOutput from transformers import T5EncoderModel, T5Tokenizer +from optimum.habana.diffusers.pipelines.cogvideox.cogvideoX_gaudi import CogVideoXAttnProcessorGaudi, cogvideoXTransformerForwardGaudi +from optimum.habana.diffusers.pipelines.cogvideox.cogvideoX_gaudi import tiled_decode_gaudi, CogVideoXCausalConv3dforwardGaudi from optimum.habana.diffusers.pipelines.pipeline_utils import GaudiDiffusionPipeline from optimum.habana.transformers.gaudi_configuration import GaudiConfig import habana_frameworks.torch.core as htcore @@ -35,6 +36,9 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d +setattr(CogVideoXCausalConv3d, 'forward', CogVideoXCausalConv3dforwardGaudi) +setattr(AutoencoderKLCogVideoX, 'tiled_decode', tiled_decode_gaudi) class time_box_t(): def __init__(self): @@ -50,112 +54,6 @@ def show_time(self, desc): self.t0 = t1 print(f'{desc} duration:{duration:.3f}s') -#try: -# from habana_frameworks.torch.hpex.kernels import FusedSDPA -#except ImportError: -# print("Not using HPU fused scaled dot-product attention kernel.") -# FusedSDPA = None -# -## FusedScaledDotProductAttention -#class ModuleFusedSDPA(torch.nn.Module): -# def __init__(self, fusedSDPA): -# super().__init__() -# self._hpu_kernel_fsdpa = fusedSDPA -# -# def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): -# return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) -# -# -#from diffusers.models.attention import Attention -# -#def apply_rotary_emb( -# x: torch.Tensor, -# freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], -#) -> Tuple[torch.Tensor, torch.Tensor]: -# """ -# Adapted from: https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/models/embeddings.py#L697 -# """ -# cos_, sin_ = freqs_cis # [S, D] -# -# cos = cos_[None, None] -# sin = sin_[None, None] -# cos, sin = cos.to(x.device), sin.to(x.device) -# -# x = torch.ops.hpu.rotary_pos_embedding(x, sin, cos, None, 0, 1) -# -# return x -# -#class CogVideoXAttnProcessorGaudi: -# r""" -# Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on -# query and key vectors, but does not include spatial normalization. 
-# """ -# -# def __init__(self): -# self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None -# -# def __call__( -# self, -# attn: Attention, -# hidden_states: torch.Tensor, -# encoder_hidden_states: torch.Tensor, -# attention_mask: Optional[torch.Tensor] = None, -# image_rotary_emb: Optional[torch.Tensor] = None, -# ) -> torch.Tensor: -# print(f'run gaudi transformer attention_processor with fused SDPA!') -# text_seq_length = encoder_hidden_states.size(1) -# -# hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) -# -# batch_size, sequence_length, _ = ( -# hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape -# ) -# -# if attention_mask is not None: -# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) -# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) -# -# query = attn.to_q(hidden_states) -# key = attn.to_k(hidden_states) -# value = attn.to_v(hidden_states) -# -# inner_dim = key.shape[-1] -# head_dim = inner_dim // attn.heads -# -# query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) -# key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) -# value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) -# -# if attn.norm_q is not None: -# query = attn.norm_q(query) -# if attn.norm_k is not None: -# key = attn.norm_k(key) -# -# # Apply RoPE if needed -# if image_rotary_emb is not None: -# query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) -# if not attn.is_cross_attention: -# key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) -# -# hidden_states = self.fused_scaled_dot_product_attention( -# query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_casual=False, scale=None, softmax_mode='fast' -# ) -# -# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) -# -# # linear proj -# hidden_states = attn.to_out[0](hidden_states) -# # dropout -# hidden_states = attn.to_out[1](hidden_states) -# -# encoder_hidden_states, hidden_states = hidden_states.split( -# [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 -# ) -# return hidden_states, encoder_hidden_states -# -#from diffusers.models import attention_processor -#setattr(attention_processor, 'CogVideoXAttnProcessor2_0', CogVideoXAttnProcessorGaudi) - @dataclass class GaudiTextToVideoSDPipelineOutput(BaseOutput): r""" @@ -229,216 +127,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps -from diffusers.utils import USE_PEFT_BACKEND -def gaudi_forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - timestep_cond: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, -): - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT 
backend is ineffective." - ) - - batch_size, num_frames, channels, height, width = hidden_states.shape - - # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - - # 2. Patch embedding - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - hidden_states = self.embedding_dropout(hidden_states) - - text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - print(f'baymax debug run gaudi transformer forward!') - # 3. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - ) - htcore.mark_step() - - if not self.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = self.norm_final(hidden_states) - else: - # CogVideoX-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = self.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] - - # 4. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 5. Unpatchify - # Note: we use `-1` instead of `channels`: - # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) - # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - -def tiled_decode_gaudi(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - r""" - Decode a batch of images using a tiled decoder. - Args: - z (`torch.Tensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. - """ - # Rough memory assessment: - # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. - # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. 
- # - Assume fp16 (2 bytes per value). - # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB - # - # Memory assessment when using tiling: - # - Assume everything as above but now HxW is 240x360 by tiling in half - # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB - - print('run gaudi pipelined tiled decode!') - batch_size, num_channels, num_frames, height, width = z.shape - - overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) - overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) - blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) - blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) - row_limit_height = self.tile_sample_min_height - blend_extent_height - row_limit_width = self.tile_sample_min_width - blend_extent_width - frame_batch_size = self.num_latent_frames_batch_size - - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - rows = [] - for i in range(0, height, overlap_height): - row = [] - for j in range(0, width, overlap_width): - num_batches = max(num_frames // frame_batch_size, 1) - conv_cache = None - time = [] - - for k in range(num_batches): - remaining_frames = num_frames % frame_batch_size - start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) - end_frame = frame_batch_size * (k + 1) + remaining_frames - tile = z[ - :, - :, - start_frame:end_frame, - i : i + self.tile_latent_min_height, - j : j + self.tile_latent_min_width, - ].clone() - if self.post_quant_conv is not None: - tile = self.post_quant_conv(tile) - tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) - time.append(tile.clone()) - htcore.mark_step() - - row.append(torch.cat(time, dim=2)) - rows.append(row) - - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent_width) - result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) - result_rows.append(torch.cat(result_row, dim=4)) - - dec = torch.cat(result_rows, dim=3) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - -setattr(AutoencoderKLCogVideoX, 'tiled_decode', tiled_decode_gaudi) - -import torch.nn.functional as F -def CogVideoXCausalConv3dforwardGaudi(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: - #print('run gaudi CogVideoXCausalConv3d forward!') - inputs = self.fake_context_parallel_forward(inputs, conv_cache) - #conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() - - padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) - - output = self.conv(inputs_pad) - if self.time_kernel_size>1: - if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: - conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) - else: - conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() - return output, conv_cache - -from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d -setattr(CogVideoXCausalConv3d, 'forward', 
CogVideoXCausalConv3dforwardGaudi) class GaudiCogVideoXPipeline(GaudiDiffusionPipeline, CogVideoXPipeline): r""" @@ -474,8 +163,11 @@ def __init__( scheduler, ) self.to(self._device) - self.transformer.forward = gaudi_forward + self.transformer.forward = cogvideoXTransformerForwardGaudi print(f'vae decode name:{self.vae.tiled_decode.__name__}') + for block in self.transformer.transformer_blocks: + block.attn1.set_processor(CogVideoXAttnProcessorGaudi()) + print(f'set gaudi attention Processor done!') from habana_frameworks.torch.hpu import wrap_in_hpu_graph self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) From cbf7ee1ef2227b41354e481f5745525a31d3be7f Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Tue, 11 Feb 2025 09:30:35 +0000 Subject: [PATCH 21/23] mv gaudi func to cogvideo pipelines. --- .../pipelines/cogvideox/cogvideoX_gaudi.py | 318 ----------------- .../cogvideox/pipeline_cogvideox_gaudi.py | 323 +++++++++++++++++- 2 files changed, 313 insertions(+), 328 deletions(-) delete mode 100644 optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py diff --git a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py deleted file mode 100644 index 5cbfc6427c..0000000000 --- a/optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py +++ /dev/null @@ -1,318 +0,0 @@ -from typing import Any, Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from diffusers.models.attention import Attention -from diffusers.models.autoencoders.vae import DecoderOutput -from diffusers.utils import USE_PEFT_BACKEND - - -try: - from habana_frameworks.torch.hpex.kernels import FusedSDPA -except ImportError: - print("Not using HPU fused scaled dot-product attention kernel.") - FusedSDPA = None - -# FusedScaledDotProductAttention -class ModuleFusedSDPA(torch.nn.Module): - def __init__(self, fusedSDPA): - super().__init__() - self._hpu_kernel_fsdpa = fusedSDPA - - def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): - return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) - - - -def apply_rotary_emb( - x: torch.Tensor, - freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Adapted from: https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/models/embeddings.py#L697 - """ - cos_, sin_ = freqs_cis # [S, D] - - cos = cos_[None, None] - sin = sin_[None, None] - cos, sin = cos.to(x.device), sin.to(x.device) - - x = torch.ops.hpu.rotary_pos_embedding(x, sin, cos, None, 0, 1) - - return x - -class CogVideoXAttnProcessorGaudi: - r""" - Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on - query and key vectors, but does not include spatial normalization. 
- """ - - def __init__(self): - self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - print(f'run gaudi transformer attention_processor with fused SDPA!') - text_seq_length = encoder_hidden_states.size(1) - - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if image_rotary_emb is not None: - query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) - if not attn.is_cross_attention: - key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - - hidden_states = self.fused_scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_casual=False, scale=None, softmax_mode='fast' - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - encoder_hidden_states, hidden_states = hidden_states.split( - [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 - ) - return hidden_states, encoder_hidden_states - -def cogvideoXTransformerForwardGaudi( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - timestep: Union[int, float, torch.LongTensor], - timestep_cond: Optional[torch.Tensor] = None, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - attention_kwargs: Optional[Dict[str, Any]] = None, - return_dict: bool = True, -): - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) - - batch_size, num_frames, channels, height, width = hidden_states.shape - - # 1. Time embedding - timesteps = timestep - t_emb = self.time_proj(timesteps) - - # timesteps does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. 
- t_emb = t_emb.to(dtype=hidden_states.dtype) - emb = self.time_embedding(t_emb, timestep_cond) - - # 2. Patch embedding - hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) - hidden_states = self.embedding_dropout(hidden_states) - - text_seq_length = encoder_hidden_states.shape[1] - encoder_hidden_states = hidden_states[:, :text_seq_length] - hidden_states = hidden_states[:, text_seq_length:] - - print(f'baymax debug run gaudi transformer forward!') - # 3. Transformer blocks - for i, block in enumerate(self.transformer_blocks): - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(block), - hidden_states, - encoder_hidden_states, - emb, - image_rotary_emb, - **ckpt_kwargs, - ) - else: - hidden_states, encoder_hidden_states = block( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - temb=emb, - image_rotary_emb=image_rotary_emb, - ) - htcore.mark_step() - - if not self.config.use_rotary_positional_embeddings: - # CogVideoX-2B - hidden_states = self.norm_final(hidden_states) - else: - # CogVideoX-5B - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - hidden_states = self.norm_final(hidden_states) - hidden_states = hidden_states[:, text_seq_length:] - - # 4. Final block - hidden_states = self.norm_out(hidden_states, temb=emb) - hidden_states = self.proj_out(hidden_states) - - # 5. Unpatchify - # Note: we use `-1` instead of `channels`: - # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) - # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) - p = self.config.patch_size - output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) - output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (output,) - return Transformer2DModelOutput(sample=output) - -import habana_frameworks.torch.core as htcore -def tiled_decode_gaudi(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - r""" - Decode a batch of images using a tiled decoder. - - Args: - z (`torch.Tensor`): Input batch of latent vectors. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. - - Returns: - [`~models.vae.DecoderOutput`] or `tuple`: - If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is - returned. - """ - # Rough memory assessment: - # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. - # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. - # - Assume fp16 (2 bytes per value). 
- # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB - # - # Memory assessment when using tiling: - # - Assume everything as above but now HxW is 240x360 by tiling in half - # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB - - print('run gaudi pipelined tiled decode!') - batch_size, num_channels, num_frames, height, width = z.shape - - overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) - overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) - blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) - blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) - row_limit_height = self.tile_sample_min_height - blend_extent_height - row_limit_width = self.tile_sample_min_width - blend_extent_width - frame_batch_size = self.num_latent_frames_batch_size - - # Split z into overlapping tiles and decode them separately. - # The tiles have an overlap to avoid seams between tiles. - rows = [] - for i in range(0, height, overlap_height): - row = [] - for j in range(0, width, overlap_width): - num_batches = max(num_frames // frame_batch_size, 1) - conv_cache = None - time = [] - - for k in range(num_batches): - remaining_frames = num_frames % frame_batch_size - start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) - end_frame = frame_batch_size * (k + 1) + remaining_frames - tile = z[ - :, - :, - start_frame:end_frame, - i : i + self.tile_latent_min_height, - j : j + self.tile_latent_min_width, - ].clone() - if self.post_quant_conv is not None: - tile = self.post_quant_conv(tile) - tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) - time.append(tile.clone()) - htcore.mark_step() - - row.append(torch.cat(time, dim=2)) - rows.append(row) - - result_rows = [] - for i, row in enumerate(rows): - result_row = [] - for j, tile in enumerate(row): - # blend the above tile and the left tile - # to the current tile and add the current tile to the result row - if i > 0: - tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) - if j > 0: - tile = self.blend_h(row[j - 1], tile, blend_extent_width) - result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) - result_rows.append(torch.cat(result_row, dim=4)) - - dec = torch.cat(result_rows, dim=3) - - if not return_dict: - return (dec,) - - return DecoderOutput(sample=dec) - - -def CogVideoXCausalConv3dforwardGaudi(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: - #print('run gaudi CogVideoXCausalConv3d forward!') - inputs = self.fake_context_parallel_forward(inputs, conv_cache) - #conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() - - padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) - inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) - - output = self.conv(inputs_pad) - if self.time_kernel_size>1: - if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: - conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) - else: - conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() - return output, conv_cache - diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py index 1cc13382a4..64258149e5 100644 --- a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py +++ 
b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -19,23 +19,330 @@ from typing import Any, Dict, Optional, Tuple, Union import torch +import torch.nn as nn +import torch.nn.functional as F from diffusers import CogVideoXPipeline from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from diffusers.utils import BaseOutput, logging from diffusers.utils.torch_utils import randn_tensor +from diffusers.models.attention import Attention +from diffusers.models.autoencoders.vae import DecoderOutput +from diffusers.utils import USE_PEFT_BACKEND from transformers import T5EncoderModel, T5Tokenizer -from optimum.habana.diffusers.pipelines.cogvideox.cogvideoX_gaudi import CogVideoXAttnProcessorGaudi, cogvideoXTransformerForwardGaudi -from optimum.habana.diffusers.pipelines.cogvideox.cogvideoX_gaudi import tiled_decode_gaudi, CogVideoXCausalConv3dforwardGaudi from optimum.habana.diffusers.pipelines.pipeline_utils import GaudiDiffusionPipeline from optimum.habana.transformers.gaudi_configuration import GaudiConfig -import habana_frameworks.torch.core as htcore logger = logging.get_logger(__name__) # pylint: disable=invalid-name +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Adapted from: https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/models/embeddings.py#L697 + """ + cos_, sin_ = freqs_cis # [S, D] + + cos = cos_[None, None] + sin = sin_[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + x = torch.ops.hpu.rotary_pos_embedding(x, sin, cos, None, 0, 1) + + return x + +class CogVideoXAttnProcessorGaudi: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. 
+ """ + + def __init__(self): + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + hidden_states = self.fused_scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_casual=False, scale=None, softmax_mode='fast' + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + +def cogvideoXTransformerForwardGaudi( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
+ t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + import habana_frameworks.torch.core as htcore + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + htcore.mark_step() + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + +def tiled_decode_gaudi(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). 
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + print('run gaudi pipelined tiled decode!') + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + import habana_frameworks.torch.core as htcore + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ].clone() + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile.clone()) + htcore.mark_step() + + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + +def CogVideoXCausalConv3dforwardGaudi(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor: + #print('run gaudi CogVideoXCausalConv3d forward!') + inputs = self.fake_context_parallel_forward(inputs, conv_cache) + #conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) + + output = self.conv(inputs_pad) + if self.time_kernel_size>1: + if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape: + conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:]) + else: + conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() + return output, conv_cache + from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d setattr(CogVideoXCausalConv3d, 'forward', CogVideoXCausalConv3dforwardGaudi) setattr(AutoencoderKLCogVideoX, 'tiled_decode', tiled_decode_gaudi) @@ -68,7 
+375,6 @@ class GaudiTextToVideoSDPipelineOutput(BaseOutput): frames: torch.Tensor -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, @@ -146,7 +452,6 @@ def __init__( gaudi_config: Union[str, GaudiConfig] = None, bf16_full_eval: bool = False, ): - print(f'GaudiCogVideoXPipeline init use_habana:{use_habana} use_hpu_graphs:{use_hpu_graphs}') GaudiDiffusionPipeline.__init__( self, use_habana, @@ -164,14 +469,11 @@ def __init__( ) self.to(self._device) self.transformer.forward = cogvideoXTransformerForwardGaudi - print(f'vae decode name:{self.vae.tiled_decode.__name__}') for block in self.transformer.transformer_blocks: block.attn1.set_processor(CogVideoXAttnProcessorGaudi()) - print(f'set gaudi attention Processor done!') from habana_frameworks.torch.hpu import wrap_in_hpu_graph self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) - #print(f' vae decode name:{self.vae.tiled_decode.__name__} tiled_decode_gaudi:{tiled_decode_gaudi.__name__}') @property def guidance_scale(self): @@ -410,6 +712,7 @@ def __call__( # 7. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) outputs = [] + import habana_frameworks.torch.core as htcore with self.progress_bar(total=num_inference_steps) as progress_bar: # for DPM-solver++ old_pred_original_sample = None @@ -457,7 +760,7 @@ def __call__( ) latents = latents.to(prompt_embeds.dtype) - if self.use_hpu_graphs: + if not self.use_hpu_graphs: htcore.mark_step() # call the callback, if provided @@ -474,7 +777,7 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if self.use_hpu_graphs: + if not self.use_hpu_graphs: htcore.mark_step() time_box.show_time('transformer_hpu') From 7148011f1d59292ce37b2c2b1db6c384225ec5a5 Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Tue, 11 Feb 2025 10:31:47 +0000 Subject: [PATCH 22/23] refine cogvideox examples. 
--- examples/text-to-video/README.md | 5 +- examples/text-to-video/cogvideox_generate.py | 86 ----------------- .../text-to-video/text_to_video_generation.py | 95 ++++++++++++------- 3 files changed, 66 insertions(+), 120 deletions(-) delete mode 100644 examples/text-to-video/cogvideox_generate.py diff --git a/examples/text-to-video/README.md b/examples/text-to-video/README.md index a7ab947b24..49905cb5b8 100644 --- a/examples/text-to-video/README.md +++ b/examples/text-to-video/README.md @@ -42,9 +42,10 @@ Models that have been validated: CogvideoX test: ```bash -python3 cogvideox_generate.py \ +python3 text_to_video_generation.py \ --model_name_or_path THUDM/CogVideoX-2b \ - --output_name gaudi_output.mp4 + --pipeline_type 'cogvideox' \ + --video_save_dir 'cogvideo_out' \ ``` diff --git a/examples/text-to-video/cogvideox_generate.py b/examples/text-to-video/cogvideox_generate.py deleted file mode 100644 index 4b95c0a8ee..0000000000 --- a/examples/text-to-video/cogvideox_generate.py +++ /dev/null @@ -1,86 +0,0 @@ -import argparse -import logging - -import torch -from diffusers.utils import export_to_video - -from optimum.habana.diffusers.pipelines.cogvideox.cogvideoX_gaudi import adapt_cogvideo_to_gaudi -from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline -from optimum.habana.transformers.gaudi_configuration import GaudiConfig - - -logger = logging.getLogger(__name__) - - -def main(): - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - - parser.add_argument( - "--model_name_or_path", - default="THUDM/CogVideoX-2b", - type=str, - help="Path to pre-trained model", - ) - # Pipeline arguments - parser.add_argument( - "--prompts", - type=str, - nargs="*", - default="A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. 
The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance.", - help="The prompt or prompts to guide the video generation.", - ) - parser.add_argument( - "--output_name", - default="panda_gaudi.mp4", - type=str, - help="Path to pre-trained model", - ) - - args = parser.parse_args() - - gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True} - gaudi_config_kwargs["use_torch_autocast"] = True - - gaudi_config = GaudiConfig(**gaudi_config_kwargs) - logger.info(f"Gaudi Config: {gaudi_config}") - - - kwargs = { - "use_habana": True, - "use_hpu_graphs": True, - "gaudi_config": gaudi_config, - } - kwargs["torch_dtype"] = torch.bfloat16 - - - print('now to load model.....') - pipe = GaudiCogVideoXPipeline.from_pretrained( - args.model_name_or_path, - **kwargs - ) - print('pipe line load done!') - - pipe.vae.enable_tiling() - pipe.vae.enable_slicing() - - print('now to generate video.') - video = pipe( - prompt=args.prompts, - num_videos_per_prompt=1, - num_inference_steps=50, - num_frames=49, - guidance_scale=6, - generator=torch.Generator(device="cpu").manual_seed(42), - ).frames[0] - - print('generate video done!') - - export_to_video(video, args.output_name, fps=8) - - - -if __name__ == "__main__": - main() - - diff --git a/examples/text-to-video/text_to_video_generation.py b/examples/text-to-video/text_to_video_generation.py index 8813e321cf..220722224b 100755 --- a/examples/text-to-video/text_to_video_generation.py +++ b/examples/text-to-video/text_to_video_generation.py @@ -23,7 +23,9 @@ import torch from diffusers.utils.export_utils import export_to_video +from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline from optimum.habana.diffusers import GaudiTextToVideoSDPipeline +#from optimum.habana.diffusers import GaudiCogVideoXPipeline from optimum.habana.transformers.gaudi_configuration import GaudiConfig from optimum.habana.utils import set_seed @@ -60,6 +62,13 @@ def main(): default="Spiderman is surfing", help="The prompt or prompts to guide the video generation.", ) + parser.add_argument( + "--pipeline_type", + type=str, + nargs="*", + default="sdp", + help="pipeline type:sdp or cogvideoX", + ) parser.add_argument( "--num_videos_per_prompt", type=int, default=1, help="The number of videos to generate per prompt." 
) @@ -178,38 +187,60 @@ def main(): kwargs["torch_dtype"] = torch.float32 # Generate images - pipeline: GaudiTextToVideoSDPipeline = GaudiTextToVideoSDPipeline.from_pretrained( - args.model_name_or_path, **kwargs - ) - set_seed(args.seed) - outputs = pipeline( - prompt=args.prompts, - num_videos_per_prompt=args.num_videos_per_prompt, - batch_size=args.batch_size, - num_inference_steps=args.num_inference_steps, - guidance_scale=args.guidance_scale, - negative_prompt=args.negative_prompts, - eta=args.eta, - output_type="pil" if args.output_type == "mp4" else args.output_type, # Naming inconsistency in base class - **kwargs_call, - ) - - # Save the pipeline in the specified directory if not None - if args.pipeline_save_dir is not None: - pipeline.save_pretrained(args.pipeline_save_dir) - - # Save images in the specified directory if not None and if they are in PIL format - if args.video_save_dir is not None: - if args.output_type == "mp4": - video_save_dir = Path(args.video_save_dir) - video_save_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"Saving images in {video_save_dir.resolve()}...") - - for i, video in enumerate(outputs.videos): - filename = video_save_dir / f"video_{i + 1}.mp4" - export_to_video(video, str(filename.resolve())) - else: - logger.warning("--output_type should be equal to 'mp4' to save images in --video_save_dir.") + if args.pipeline_type[0] == 'sdp': + pipeline: GaudiTextToVideoSDPipeline = GaudiTextToVideoSDPipeline.from_pretrained( + args.model_name_or_path, **kwargs + ) + set_seed(args.seed) + outputs = pipeline( + prompt=args.prompts, + num_videos_per_prompt=args.num_videos_per_prompt, + batch_size=args.batch_size, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + negative_prompt=args.negative_prompts, + eta=args.eta, + output_type="pil" if args.output_type == "mp4" else args.output_type, # Naming inconsistency in base class + **kwargs_call, + ) + # Save the pipeline in the specified directory if not None + if args.pipeline_save_dir is not None: + pipeline.save_pretrained(args.pipeline_save_dir) + + # Save images in the specified directory if not None and if they are in PIL format + if args.video_save_dir is not None: + if args.output_type == "mp4": + video_save_dir = Path(args.video_save_dir) + video_save_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving images in {video_save_dir.resolve()}...") + + for i, video in enumerate(outputs.videos): + filename = video_save_dir / f"video_{i + 1}.mp4" + export_to_video(video, str(filename.resolve())) + else: + logger.warning("--output_type should be equal to 'mp4' to save images in --video_save_dir.") + + elif args.pipeline_type[0] == 'cogvideox': + pipeline: GaudiCogVideoXPipeline= GaudiCogVideoXPipeline.from_pretrained( + args.model_name_or_path, **kwargs + ) + pipeline.vae.enable_tiling() + pipeline.vae.enable_slicing() + video = pipeline( + prompt=args.prompts, + num_videos_per_prompt=1, + num_inference_steps=50, + num_frames=49, + guidance_scale=6, + generator=torch.Generator(device="cpu").manual_seed(42), + ).frames[0] + video_save_dir = Path(args.video_save_dir) + video_save_dir.mkdir(parents=True, exist_ok=True) + filename = video_save_dir / f"cogvideoX_out.mp4" + export_to_video(video, str(filename.resolve()), fps=8) + else: + logger.error(f"unsupported pipe line:{args.pipeline_type}") + if __name__ == "__main__": From 1316cf2f29755fccf47bd51232bf0e432ef8610e Mon Sep 17 00:00:00 2001 From: nc-BobLee Date: Tue, 11 Feb 2025 10:40:10 +0000 Subject: [PATCH 23/23] 
import cogvideo pipeline from OH diffusers. --- examples/text-to-video/text_to_video_generation.py | 3 +-- optimum/habana/diffusers/__init__.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text-to-video/text_to_video_generation.py b/examples/text-to-video/text_to_video_generation.py index 220722224b..26c83bdf98 100755 --- a/examples/text-to-video/text_to_video_generation.py +++ b/examples/text-to-video/text_to_video_generation.py @@ -23,9 +23,8 @@ import torch from diffusers.utils.export_utils import export_to_video -from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline from optimum.habana.diffusers import GaudiTextToVideoSDPipeline -#from optimum.habana.diffusers import GaudiCogVideoXPipeline +from optimum.habana.diffusers import GaudiCogVideoXPipeline from optimum.habana.transformers.gaudi_configuration import GaudiConfig from optimum.habana.utils import set_seed diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py index 86b6477c0b..086257a8f8 100644 --- a/optimum/habana/diffusers/__init__.py +++ b/optimum/habana/diffusers/__init__.py @@ -1,4 +1,5 @@ from .pipelines.auto_pipeline import AutoPipelineForInpainting, AutoPipelineForText2Image +from .pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline from .pipelines.controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline from .pipelines.controlnet.pipeline_stable_video_diffusion_controlnet import ( GaudiStableVideoDiffusionControlNetPipeline,
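
After the final patch, `GaudiCogVideoXPipeline` is importable directly from `optimum.habana.diffusers` and the CogVideoX path is driven through `text_to_video_generation.py` with `--pipeline_type 'cogvideox'`. The following is a minimal end-to-end sketch that mirrors the example script reworked in this series; the model ID, prompt, and output filename are placeholders, and the `GaudiConfig` flags follow the values used in the example rather than required settings.

```python
# Minimal usage sketch for the CogVideoX Gaudi pipeline added by this patch series.
# Model ID, prompt, and output path are placeholders; adjust for your environment.
import torch
from diffusers.utils import export_to_video

from optimum.habana.diffusers import GaudiCogVideoXPipeline
from optimum.habana.transformers.gaudi_configuration import GaudiConfig

# Same Gaudi configuration flags as the example script in this series.
gaudi_config = GaudiConfig(use_fused_adam=True, use_fused_clip_norm=True, use_torch_autocast=True)

pipe = GaudiCogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b",
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config=gaudi_config,
    torch_dtype=torch.bfloat16,
)

# Tiled/sliced VAE decoding keeps HPU memory in check, as in the example.
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

video = pipe(
    prompt="A panda strums a tiny guitar in a serene bamboo forest.",
    num_videos_per_prompt=1,
    num_inference_steps=50,
    num_frames=49,
    guidance_scale=6,
    generator=torch.Generator(device="cpu").manual_seed(42),
).frames[0]

export_to_video(video, "cogvideox_out.mp4", fps=8)
```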