Skip to content

Commit

Permalink
add K4 video-to-audio pipeline (#181)
Browse files Browse the repository at this point in the history
- add pipeline and optimizations for running v2a on a 24 GB VRAM GPU
- rework procedure of applying quantization to K4 pipelines
- add multiple workarounds to prevent errors
- add a tab for uploading and converting video to video+sound
- add a 'send to v2a' button to the t2v tab
  • Loading branch information
seruva19 committed Dec 23, 2024
1 parent 4da685d commit 5783ea5
Show file tree
Hide file tree
Showing 36 changed files with 2,214 additions and 134 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
.idea
.local
/venv
main.*
**/__pycache__
**/__triton__
**/.installed
/flagged
/configs/*.yaml
Expand Down
3 changes: 2 additions & 1 deletion client/css.0.base.css
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ body.unselectable * {
}

.options-optimizations>div.wrap {
display: block !important;
/* display: block !important; */
margin-left: -13px !important;
}

.options-optimizations>div.wrap>label {
background: none !important;
border: none !important;
box-shadow: none !important;
width: 30%;
}

.block-resizable-anchor {
Expand Down
3 changes: 2 additions & 1 deletion client/dist/bundle.css
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ body.unselectable * {
}

.options-optimizations>div.wrap {
display: block !important;
/* display: block !important; */
margin-left: -13px !important;
}

.options-optimizations>div.wrap>label {
background: none !important;
border: none !important;
box-shadow: none !important;
width: 30%;
}

.block-resizable-anchor {
Expand Down
2 changes: 1 addition & 1 deletion client/dist/bundle.js
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@
return `
<span style='font-weight: bold'>${key}</span>
<span>was successfully changed to</span>
<span style='font-weight: bold'>${value}</span>
<span style='font-weight: bold; word-wrap: break-word;'>${value}</span>
<span style='color: red'>${requiresRestart ? ' restart required' : ''}</span>
`
}).join('<br>')
Expand Down
2 changes: 1 addition & 1 deletion client/js.2.options.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
return `
<span style='font-weight: bold'>${key}</span>
<span>was successfully changed to</span>
<span style='font-weight: bold'>${value}</span>
<span style='font-weight: bold; word-wrap: break-word;'>${value}</span>
<span style='color: red'>${requiresRestart ? ' restart required' : ''}</span>
`
}).join('<br>')
Expand Down
4 changes: 2 additions & 2 deletions configs/kubin.default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ native:
available_text_encoders: default;google/flan-ul2;pszemraj/flan-ul2-text-encoder;JulesGo/t5-v1_1-xxl-fp8
text_encoder: pszemraj/flan-ul2-text-encoder
use_kandinsky31_flash: false
available_optimization_flags: kd21_flash_attention;kd30_low_vram;kd31_low_vram;kd40_flash_attention;kd40_sage_attention;kd40_fp8_tenc_ao_quantization;kd40_fp8_vae_ao_quantization;kd40_fp8_dit_ao_quantization;kd40_fp8_tenc_oq_quantization;kd40_fp8_vae_oq_quantization;kd40_fp8_dit_oq_quantization;kd40_vae_tiling;kd40_vae_slicing;kd40_model_offload;kd40_save_quantized_weights
optimization_flags: kd30_low_vram;kd31_low_vram;kd40_flash_attention;kd40_fp8_tenc_ao_quantization;kd40_fp8_vae_ao_quantization;kd40_fp8_dit_ao_quantization;kd40_vae_tiling;kd40_vae_slicing
available_optimization_flags: kd21_flash_attention;kd30_low_vram;kd31_low_vram;kd40_flash_attention;kd40_sage_attention;kd40_t2v_tenc_int8_ao_quantization;kd40_t2v_vae_int8_ao_quantization;kd40_t2v_dit_int8_ao_quantization;kd40_t2v_tenc_int8_oq_quantization;kd40_t2v_vae_int8_oq_quantization;kd40_t2v_dit_int8_oq_quantization;kd40_v2a_mm_int8_bnb_quantization;kd40_v2a_mm_nf4_bnb_quantization;kd40_v2a_vae_int8_bnb_quantization;kd40_v2a_vae_nf4_bnb_quantization;kd40_v2a_unet_int8_bnb_quantization;kd40_v2a_unet_nf4_bnb_quantization;kd40_vae_tiling;kd40_vae_slicing;kd40_model_offload;kd40_save_quantized_weights
optimization_flags: kd30_low_vram;kd31_low_vram;kd40_flash_attention;kd40_t2v_tenc_int8_ao_quantization;kd40_t2v_vae_int8_ao_quantization;kd40_t2v_dit_int8_ao_quantization;kd40_v2a_vae_int8_bnb_quantization;kd40_v2a_unet_int8_bnb_quantization;kd40_vae_tiling;kd40_vae_slicing;kd40_v2a_mm_nf4_bnb_quantization

diffusers:
half_precision_weights: true
Expand Down
21 changes: 4 additions & 17 deletions src/kubin.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,13 @@
from patches import patch

patch()

from arguments import parse_arguments
from env import Kubin
from utils.platform import is_windows
from web_gui import gradio_ui
from pathlib import Path
import os

os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import gradio.analytics

# One-shot guard so the notice below is printed at most once per process.
check_executed = False


def custom_version_check():
    """Replacement for gradio.analytics.version_check: print a one-time notice instead of checking for updates."""
    global check_executed
    if check_executed:
        return
    check_executed = True
    print(
        "fyi: kubin uses an old version of Gradio (3.50.2), which is now considered deprecated for security reasons.\nhowever, the author is too stubborn to upgrade (https://github.com/seruva19/kubin/blob/main/DOCS.md#gradio-4)."
    )


gradio.analytics.version_check = custom_version_check

kubin = Kubin()
args = parse_arguments()
Expand Down
1 change: 0 additions & 1 deletion src/models/model_40/kandinsky_4/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def parallelize(model, tp_mesh):


class TransformerBlock(nn.Module):

def __init__(self, model_dim, time_dim, ff_dim, attention_type, head_dim=64):
super().__init__()
self.visual_modulation = Modulation(time_dim, model_dim)
Expand Down
2 changes: 0 additions & 2 deletions src/models/model_40/kandinsky_4/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@ def forward(


class FeedForward(nn.Module):

def __init__(self, dim, ff_dim):
super().__init__()
self.in_layer = nn.Linear(dim, ff_dim, bias=True)
Expand All @@ -445,7 +444,6 @@ def forward(self, x):


class OutLayer(nn.Module):

def __init__(self, model_dim, time_dim, visual_dim, patch_size):
super().__init__()
self.patch_size = patch_size
Expand Down
117 changes: 45 additions & 72 deletions src/models/model_40/kandinsky_4/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,30 +121,21 @@ def get_T2V_pipeline(
dit = dit.to(dtype=torch.bfloat16, device=device_map["dit"])
# dit = dit.to(dtype=torch.float8_e4m3fn, device=device_map["dit"])

if environment.use_dit_fp8_quantization:
path, ext = os.path.splitext(conf.dit.checkpoint_path)
if environment.use_torchao_quantization:
k_log("quantizing DiT [torchao-int8]...")
dit = quantize_with_torch_ao(
dit,
True,
(
f"{path}.q8_tao{ext}"
if environment.use_save_quantized_weights
else None
),
)
else:
k_log("quantizing DiT [optimum-quanto]...")
dit = quantize_with_optimum_quanto(
dit,
True,
(
f"{path}.q8_oq{ext}"
if environment.use_save_quantized_weights
else None
),
)
path, ext = os.path.splitext(conf.dit.checkpoint_path)
if environment.use_t2v_tenc_int8_ao_quantization:
k_log("quantizing DiT [torchao-int8]...")
dit = quantize_with_torch_ao(
dit,
True,
(f"{path}.q8_tao{ext}" if environment.use_save_quantized_weights else None),
)
elif environment.use_t2v_dit_int8_oq_quantization:
k_log("quantizing DiT [optimum-quanto]...")
dit = quantize_with_optimum_quanto(
dit,
True,
(f"{path}.q8_oq{ext}" if environment.use_save_quantized_weights else None),
)

noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(conf.dit.scheduler)

Expand All @@ -154,30 +145,21 @@ def get_T2V_pipeline(
k_log("loading text embedder...")
text_embedder = get_text_embedder(conf)

if environment.use_textencoder_fp8_quantization:
path, ext = os.path.splitext(conf.text_embedder.params.checkpoint_path)
if environment.use_torchao_quantization:
k_log("quantizing embedder [torchao-int8]...")
text_embedder.llm = quantize_with_torch_ao(
text_embedder.llm,
True,
(
f"{path}.q8_tao{ext}"
if environment.use_save_quantized_weights
else None
),
)
else:
k_log("quantizing embedder [optimum-quanto]...")
text_embedder.llm = quantize_with_optimum_quanto(
text_embedder.llm,
True,
(
f"{path}.q8_oq{ext}"
if environment.use_save_quantized_weights
else None
),
)
path, ext = os.path.splitext(conf.text_embedder.params.checkpoint_path)
if environment.use_t2v_tenc_int8_ao_quantization:
k_log("quantizing embedder [torchao-int8]...")
text_embedder.llm = quantize_with_torch_ao(
text_embedder.llm,
True,
(f"{path}.q8_tao{ext}" if environment.use_save_quantized_weights else None),
)
elif environment.use_t2v_tenc_int8_oq_quantization:
k_log("quantizing embedder [optimum-quanto]...")
text_embedder.llm = quantize_with_optimum_quanto(
text_embedder.llm,
True,
(f"{path}.q8_oq{ext}" if environment.use_save_quantized_weights else None),
)
else:
text_embedder = text_embedder.freeze()

Expand Down Expand Up @@ -206,30 +188,21 @@ def get_T2V_pipeline(
if local_rank == 0:
vae = vae.to(device_map["vae"], dtype=torch.bfloat16)

if environment.use_vae_fp8_quantization:
path, ext = os.path.splitext(conf.vae.checkpoint_path)
if environment.use_torchao_quantization:
k_log("quantizing vae [torchao-int8]...")
vae = quantize_with_torch_ao(
vae,
True,
(
f"{path}.q8_tao{ext}"
if environment.use_save_quantized_weights
else None
),
)
else:
k_log("quantizing vae [optimum-quanto]...")
vae = quantize_with_optimum_quanto(
vae,
True,
(
f"{path}.q8_oq{ext}"
if environment.use_save_quantized_weights
else None
),
)
path, ext = os.path.splitext(conf.vae.checkpoint_path)
if environment.use_t2v_vae_int8_ao_quantization:
k_log("quantizing vae [torchao-int8]...")
vae = quantize_with_torch_ao(
vae,
True,
(f"{path}.q8_tao{ext}" if environment.use_save_quantized_weights else None),
)
elif environment.use_t2v_vae_int8_oq_quantization:
k_log("quantizing vae [optimum-quanto]...")
vae = quantize_with_optimum_quanto(
vae,
True,
(f"{path}.q8_oq{ext}" if environment.use_save_quantized_weights else None),
)

return Kandinsky4T2VPipeline(
environment=environment,
Expand Down
16 changes: 16 additions & 0 deletions src/models/model_40/kandinsky_4/t2v_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ def __init__(
],
}

self.progress_fn = lambda progress, desc: None

def register_progress_bar(self, progress_fn=None):
    """Install *progress_fn* as the progress callback; passing None keeps the current callback."""
    if progress_fn is None:
        progress_fn = self.progress_fn
    self.progress_fn = progress_fn

def update_progress(self, step, total_steps):
    """Report generation progress through the registered callback.

    Tries the keyword-style signature ``progress_fn(fraction, desc=...)``
    first; if the callback does not accept that signature, falls back to
    the plain positional form ``progress_fn(step, total_steps)``.

    Args:
        step: Current step number (assumed 0 <= step <= total_steps).
        total_steps: Total number of steps; must be non-zero.
    """
    if not hasattr(self, "progress_fn"):
        # No callback registered (e.g. instance built without __init__) -- nothing to do.
        return
    try:
        self.progress_fn(
            step / total_steps, desc=f"Generating {step}/{total_steps}"
        )
    except TypeError:
        # Only a signature mismatch triggers the legacy fallback; any other
        # exception raised inside a valid callback propagates instead of
        # being masked by a confusing second call (the original bare
        # ``except:`` swallowed everything, including KeyboardInterrupt).
        self.progress_fn(step, total_steps)

def __call__(
self,
text: str,
Expand Down
Loading

0 comments on commit 5783ea5

Please sign in to comment.