Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rlsu9 committed Dec 31, 2024
1 parent 68a38cc commit b2a75a5
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
11 changes: 6 additions & 5 deletions fastvideo/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ def reshard_fsdp(model):


def get_norm(model_pred, norms, gradient_accumulation_steps):
fro_norm = (torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore
gradient_accumulation_steps)
fro_norm = (
torch.linalg.matrix_norm(model_pred, ord="fro") / # codespell:ignore
gradient_accumulation_steps)
largest_singular_value = (torch.linalg.matrix_norm(model_pred, ord=2) /
gradient_accumulation_steps)
absolute_mean = torch.mean(
Expand All @@ -65,7 +66,7 @@ def get_norm(model_pred, norms, gradient_accumulation_steps):
dist.all_reduce(fro_norm, op=dist.ReduceOp.AVG)
dist.all_reduce(largest_singular_value, op=dist.ReduceOp.AVG)
dist.all_reduce(absolute_mean, op=dist.ReduceOp.AVG)
norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore
norms["fro"] += torch.mean(fro_norm).item() # codespell:ignore
norms["largest singular value"] += torch.mean(
largest_singular_value).item()
norms["absolute mean"] += absolute_mean.item()
Expand Down Expand Up @@ -100,7 +101,7 @@ def distill_one_step(
total_loss = 0.0
optimizer.zero_grad()
model_pred_norm = {
"fro": 0.0, # codespell:ignore
"fro": 0.0, # codespell:ignore
"largest singular value": 0.0,
"absolute mean": 0.0,
"absolute max": 0.0,
Expand Down Expand Up @@ -555,7 +556,7 @@ def get_num_phases(multi_phased_distill_schedule, step):
"grad_norm":
grad_norm,
"pred_fro_norm":
pred_norm["fro"], # codespell:ignore
pred_norm["fro"], # codespell:ignore
"pred_largest_singular_value":
pred_norm["largest singular value"],
"pred_absolute_mean":
Expand Down
8 changes: 4 additions & 4 deletions fastvideo/models/hunyuan/modules/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,11 +608,11 @@ def forward(
txt = encoder_hidden_states[:, 1:]
text_states_2 = encoder_hidden_states[:, 0, :self.config.
text_states_dim_2]
_, _, ot, oh, ow = x.shape # codespell:ignore
_, _, ot, oh, ow = x.shape # codespell:ignore
tt, th, tw = (
ot // self.patch_size[0], # codespell:ignore
oh // self.patch_size[1], # codespell:ignore
ow // self.patch_size[2], # codespell:ignore
ot // self.patch_size[0], # codespell:ignore
oh // self.patch_size[1], # codespell:ignore
ow // self.patch_size[2], # codespell:ignore
)
original_tt = nccl_info.sp_size * tt
freqs_cos, freqs_sin = self.get_rotary_pos_embed((original_tt, th, tw))
Expand Down
6 changes: 4 additions & 2 deletions fastvideo/models/mochi_hf/modeling_mochi.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,10 @@ def _create_rope(self, freqs: torch.Tensor,
pos: torch.Tensor) -> torch.Tensor:
with torch.autocast(freqs.device.type, enabled=False):
# Always run ROPE freqs computation in FP32
freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), # codespell:ignore
freqs.to(torch.float32))
freqs = torch.einsum(
"nd,dhf->nhf",
pos.to(torch.float32), # codespell:ignore
freqs.to(torch.float32))
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
Expand Down
2 changes: 1 addition & 1 deletion fastvideo/utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
if is_torch_npu_available():
info["PyTorch version"] += " (NPU)"
info["NPU type"] = torch.npu.get_device_name()
info["CANN version"] = torch.version.cann # codespell:ignore
info["CANN version"] = torch.version.cann # codespell:ignore

try:
import bitsandbytes
Expand Down

0 comments on commit b2a75a5

Please sign in to comment.