add cogvideox support for gaudi. #1600

Open
wants to merge 23 commits into base: main
Changes from 11 commits
Commits (23)
e87f823
add cogvideox support for gaudi.
nc-BobLee Dec 12, 2024
fd556ca
update README for cogvideX
nc-BobLee Dec 12, 2024
6a8e73d
import cogvideo module from optimum lib
nc-BobLee Dec 12, 2024
7092a51
refine test examples
nc-BobLee Dec 12, 2024
13b86c8
fix module import defect
Zhiwei35 Dec 17, 2024
d125fe6
update module import method
Zhiwei35 Dec 17, 2024
4698ded
upgrade for diffusers version 0.31.0
Zhiwei35 Dec 17, 2024
33a35da
Merge branch 'huggingface:main' into cogvideox_dev
nc-BobLee Dec 17, 2024
21caddc
add cogVideo test case.
ranzhejiang Dec 18, 2024
feff2a3
refine model default path
nc-BobLee Dec 18, 2024
ae05af9
add required python lib for cogvideo
nc-BobLee Dec 19, 2024
12badb8
refine README.MD
nc-BobLee Jan 13, 2025
731bd91
Merge branch 'huggingface:main' into cogvideox_dev
nc-BobLee Jan 15, 2025
7df1a6c
use gaudi implementation of apply rotary embedding.
nc-BobLee Jan 15, 2025
6919313
fix htcore defect
gyou2021 Jan 23, 2025
c15aa51
fix can't find htcore defect.
gyou2021 Jan 23, 2025
687caf9
support for G3 on graph optimization
nc-BobLee Jan 23, 2025
339e31f
clear debug code,
nc-BobLee Jan 23, 2025
54fa10c
Merge branch 'huggingface:main' into cogvideox_dev
nc-BobLee Jan 24, 2025
4ab7ebe
set transformer gaudi forward in pipelines.
ranzhejiang Jan 26, 2025
c3e253e
set autoencoder tiled decode gaudi with setattr.
nc-BobLee Jan 26, 2025
bb036d6
move cogvideox conv3d to gaudi pipeline.
Feb 7, 2025
65fb0ed
remove import gaudi function in __init__.py
nc-BobLee Feb 8, 2025
9 changes: 9 additions & 0 deletions examples/text-to-video/README.md
@@ -39,3 +39,12 @@ python3 text_to_video_generation.py \

Models that have been validated:
- [ali-vilab/text-to-video-ms-1.7b](https://huggingface.co/ali-vilab/text-to-video-ms-1.7b)

CogVideoX test:
```bash
python3 cogvideox_generate.py \
    --model_name_or_path CogVideoX-2b \
    --output_name gaudi_output.mp4
```


86 changes: 86 additions & 0 deletions examples/text-to-video/cogvideox_generate.py
@@ -0,0 +1,86 @@
import argparse
Contributor:

Please remove this file and use the text_to_video_generation.py script for this sample (the pipeline needs to be selected based on the model, similarly to what is done in text_to_image_generation.py).

import logging

import torch
from diffusers.utils import export_to_video

from optimum.habana.diffusers.pipelines.cogvideox.cogvideoX_gaudi import adapt_cogvideo_to_gaudi
from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline
from optimum.habana.transformers.gaudi_configuration import GaudiConfig


logger = logging.getLogger(__name__)


def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)


parser.add_argument(
"--model_name_or_path",
default="THUDM/CogVideoX-2b",
type=str,
help="Path to pre-trained model",
)
# Pipeline arguments
parser.add_argument(
"--prompts",
type=str,
nargs="*",
default="A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance.",
help="The prompt or prompts to guide the video generation.",
)
parser.add_argument(
"--output_name",
default="panda_gaudi.mp4",
type=str,
help="Path to pre-trained model",
)

args = parser.parse_args()

gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True}
gaudi_config_kwargs["use_torch_autocast"] = True

gaudi_config = GaudiConfig(**gaudi_config_kwargs)
logger.info(f"Gaudi Config: {gaudi_config}")


kwargs = {
"use_habana": True,
"use_hpu_graphs": True,
"gaudi_config": gaudi_config,
}
kwargs["torch_dtype"] = torch.bfloat16


print("Loading model...")
pipe = GaudiCogVideoXPipeline.from_pretrained(
args.model_name_or_path,
**kwargs
)
print("Pipeline loaded.")

pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

print("Generating video...")
video = pipe(
prompt=args.prompts,
num_videos_per_prompt=1,
num_inference_steps=50,
num_frames=49,
guidance_scale=6,
generator=torch.Generator(device="cpu").manual_seed(42),
).frames[0]

print("Video generation done.")

export_to_video(video, args.output_name, fps=8)



if __name__ == "__main__":
main()


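As a reference for the reviewer's suggestion above, here is a minimal, hypothetical sketch of how text_to_video_generation.py could select the pipeline class based on the model name; the name-based check and the GaudiTextToVideoSDPipeline fallback are assumptions, not part of this PR.

```python
# Hypothetical sketch only: model-based pipeline selection for text_to_video_generation.py.
# The name-based check and the GaudiTextToVideoSDPipeline fallback are assumptions.
import torch

from optimum.habana.diffusers import GaudiTextToVideoSDPipeline
from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline
from optimum.habana.transformers.gaudi_configuration import GaudiConfig


def load_video_pipeline(model_name_or_path: str, gaudi_config: GaudiConfig):
    """Return the Gaudi pipeline matching the requested model (sketch only)."""
    kwargs = {
        "use_habana": True,
        "use_hpu_graphs": True,
        "gaudi_config": gaudi_config,
        "torch_dtype": torch.bfloat16,
    }
    if "cogvideox" in model_name_or_path.lower():
        # CogVideoX checkpoints (e.g. THUDM/CogVideoX-2b) use the pipeline added by this PR.
        return GaudiCogVideoXPipeline.from_pretrained(model_name_or_path, **kwargs)
    # Other validated models (e.g. ali-vilab/text-to-video-ms-1.7b) keep the existing pipeline.
    return GaudiTextToVideoSDPipeline.from_pretrained(model_name_or_path, **kwargs)
```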
4 changes: 4 additions & 0 deletions examples/text-to-video/requirements.txt
@@ -1 +1,5 @@
opencv-python-headless
sentencepiece
imageio
imageio-ffmpeg

1 change: 1 addition & 0 deletions optimum/habana/diffusers/__init__.py
@@ -1,3 +1,4 @@
from .pipelines.cogvideox.cogvideoX_gaudi import adapt_cogvideo_to_gaudi
Contributor:

remove from here

from .pipelines.auto_pipeline import AutoPipelineForInpainting, AutoPipelineForText2Image
from .pipelines.controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline
from .pipelines.controlnet.pipeline_stable_video_diffusion_controlnet import (
285 changes: 285 additions & 0 deletions optimum/habana/diffusers/pipelines/cogvideox/cogvideoX_gaudi.py
@@ -0,0 +1,285 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)


from diffusers.models.attention import Attention


class CogVideoXAttnProcessorGaudi:
Contributor:

Adapted from? Please put a reference here.

r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
"""

def __init__(self):
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)

hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb

query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

hidden_states = self.fused_scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False, scale=None, softmax_mode='fast'
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states

import torch.nn.functional as F
from diffusers.models import attention_processor


attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi

from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXSafeConv3d
from diffusers.models.autoencoders.vae import DecoderOutput


class CogVideoXCausalConv3dGaudi(nn.Module):
Contributor:

Adapted from? Please put a reference here.

r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.

Args:
in_channels (`int`): Number of channels in the input tensor.
out_channels (`int`): Number of output channels produced by the convolution.
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
stride (`int`, defaults to `1`): Stride of the convolution.
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
pad_mode (`str`, defaults to `"constant"`): Padding mode.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: int = 1,
dilation: int = 1,
pad_mode: str = "constant",
):
super().__init__()

if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3

time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

self.pad_mode = pad_mode
time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
height_pad = height_kernel_size // 2
width_pad = width_kernel_size // 2

self.height_pad = height_pad
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

self.temporal_dim = 2
self.time_kernel_size = time_kernel_size

stride = (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = CogVideoXSafeConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
)


def fake_context_parallel_forward(
self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
) -> torch.Tensor:
kernel_size = self.time_kernel_size
if kernel_size > 1:
cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs

def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs, conv_cache)
#conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()

padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0)

output = self.conv(inputs_pad)
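# Hedged note: when a compatible conv_cache tensor already exists it is updated in place,
# presumably to keep tensor addresses stable for HPU graph replay; otherwise a fresh
# cache is cloned from the input below.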
if self.time_kernel_size > 1:
if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1:].shape:
conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1:])
else:
conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone()
return output, conv_cache

from diffusers.models.autoencoders import autoencoder_kl_cogvideox


autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi

from diffusers.models.autoencoders.autoencoder_kl_cogvideox import AutoencoderKLCogVideoX


class AutoencoderKLCogVideoXGaudi(AutoencoderKLCogVideoX):
Contributor:

Adapted from? Please put a reference here.

def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.

Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
# Rough memory assessment:
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
# - Assume fp16 (2 bytes per value).
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
#
# Memory assessment when using tiling:
# - Assume everything as above but now HxW is 240x360 by tiling in half
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB

print('run gaudi tiled decode!')
batch_size, num_channels, num_frames, height, width = z.shape

overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size

# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
time = []

for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
].clone()
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
time.append(tile.clone())

row.append(torch.cat(time, dim=2))
rows.append(row)

result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))

dec = torch.cat(result_rows, dim=3)

if not return_dict:
return (dec,)

return DecoderOutput(sample=dec)

from diffusers.models.autoencoders import autoencoder_kl_cogvideox


autoencoder_kl_cogvideox.AutoencoderKLCogVideoX = AutoencoderKLCogVideoXGaudi

import diffusers
def adapt_cogvideo_to_gaudi():
diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoXCausalConv3d = CogVideoXCausalConv3dGaudi
diffusers.models.autoencoders.autoencoder_kl_cogvideox.AutoencoderKLCogVideoX = AutoencoderKLCogVideoXGaudi
diffusers.models.attention_processor.CogVideoXAttnProcessor2_0 = CogVideoXAttnProcessorGaudi
Contributor:

These are good clues for where to place the definitions from this file, under the optimum/habana/diffusers/models path.
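For context, a minimal usage sketch of the monkey-patching entry point defined above. Whether adapt_cogvideo_to_gaudi() needs to be called explicitly before building the pipeline is an assumption; the example script in this PR only imports it.

```python
# Hedged usage sketch: apply the Gaudi monkey patches, then build the pipeline.
# Calling adapt_cogvideo_to_gaudi() explicitly is an assumption.
import torch

from optimum.habana.diffusers.pipelines.cogvideox.cogvideoX_gaudi import adapt_cogvideo_to_gaudi
from optimum.habana.diffusers.pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline
from optimum.habana.transformers.gaudi_configuration import GaudiConfig

adapt_cogvideo_to_gaudi()  # patch diffusers' CogVideoX classes with the Gaudi variants

pipe = GaudiCogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b",
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config=GaudiConfig(use_torch_autocast=True),
    torch_dtype=torch.bfloat16,
)
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()
```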
