From b2a75a5db0eb381c728529dd6a8ff0a0240a41cb Mon Sep 17 00:00:00 2001 From: rlsu9 Date: Tue, 31 Dec 2024 22:31:34 +0000 Subject: [PATCH] update --- fastvideo/distill.py | 11 ++++++----- fastvideo/models/hunyuan/modules/models.py | 8 ++++---- fastvideo/models/mochi_hf/modeling_mochi.py | 6 ++++-- fastvideo/utils/env_utils.py | 2 +- 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/fastvideo/distill.py b/fastvideo/distill.py index 3c59a97..656658a 100644 --- a/fastvideo/distill.py +++ b/fastvideo/distill.py @@ -54,8 +54,9 @@ def reshard_fsdp(model): def get_norm(model_pred, norms, gradient_accumulation_steps): - fro_norm = (torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore - gradient_accumulation_steps) + fro_norm = ( + torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore + gradient_accumulation_steps) largest_singular_value = (torch.linalg.matrix_norm(model_pred, ord=2) / gradient_accumulation_steps) absolute_mean = torch.mean( @@ -65,7 +66,7 @@ def get_norm(model_pred, norms, gradient_accumulation_steps): dist.all_reduce(fro_norm, op=dist.ReduceOp.AVG) dist.all_reduce(largest_singular_value, op=dist.ReduceOp.AVG) dist.all_reduce(absolute_mean, op=dist.ReduceOp.AVG) - norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore + norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore norms["largest singular value"] += torch.mean( largest_singular_value).item() norms["absolute mean"] += absolute_mean.item() @@ -100,7 +101,7 @@ def distill_one_step( total_loss = 0.0 optimizer.zero_grad() model_pred_norm = { - "fro": 0.0, # codespell:ignore + "fro": 0.0, # codespell:ignore "largest singular value": 0.0, "absolute mean": 0.0, "absolute max": 0.0, @@ -555,7 +556,7 @@ def get_num_phases(multi_phased_distill_schedule, step): "grad_norm": grad_norm, "pred_fro_norm": - pred_norm["fro"], # codespell:ignore + pred_norm["fro"], # codespell:ignore "pred_largest_singular_value": pred_norm["largest singular value"], "pred_absolute_mean": diff --git a/fastvideo/models/hunyuan/modules/models.py b/fastvideo/models/hunyuan/modules/models.py index 068503f..759897e 100644 --- a/fastvideo/models/hunyuan/modules/models.py +++ b/fastvideo/models/hunyuan/modules/models.py @@ -608,11 +608,11 @@ def forward( txt = encoder_hidden_states[:, 1:] text_states_2 = encoder_hidden_states[:, 0, :self.config. text_states_dim_2] - _, _, ot, oh, ow = x.shape # codespell:ignore + _, _, ot, oh, ow = x.shape # codespell:ignore tt, th, tw = ( - ot // self.patch_size[0], # codespell:ignore - oh // self.patch_size[1], # codespell:ignore - ow // self.patch_size[2], # codespell:ignore + ot // self.patch_size[0], # codespell:ignore + oh // self.patch_size[1], # codespell:ignore + ow // self.patch_size[2], # codespell:ignore ) original_tt = nccl_info.sp_size * tt freqs_cos, freqs_sin = self.get_rotary_pos_embed((original_tt, th, tw)) diff --git a/fastvideo/models/mochi_hf/modeling_mochi.py b/fastvideo/models/mochi_hf/modeling_mochi.py index 9ec884f..8e73c4b 100644 --- a/fastvideo/models/mochi_hf/modeling_mochi.py +++ b/fastvideo/models/mochi_hf/modeling_mochi.py @@ -488,8 +488,10 @@ def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: with torch.autocast(freqs.device.type, enabled=False): # Always run ROPE freqs computation in FP32 - freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), # codespell:ignore - freqs.to(torch.float32)) + freqs = torch.einsum( + "nd,dhf->nhf", + pos.to(torch.float32), # codespell:ignore + freqs.to(torch.float32)) freqs_cos = torch.cos(freqs) freqs_sin = torch.sin(freqs) return freqs_cos, freqs_sin diff --git a/fastvideo/utils/env_utils.py b/fastvideo/utils/env_utils.py index b700253..206c4aa 100644 --- a/fastvideo/utils/env_utils.py +++ b/fastvideo/utils/env_utils.py @@ -26,7 +26,7 @@ if is_torch_npu_available(): info["PyTorch version"] += " (NPU)" info["NPU type"] = torch.npu.get_device_name() - info["CANN version"] = torch.version.cann # codespell:ignore + info["CANN version"] = torch.version.cann # codespell:ignore try: import bitsandbytes