Avoid parameters function in op backward pass
Signed-off-by: Tim Moon <[email protected]>
timmoon10 committed Jan 11, 2025
1 parent a65ad37 · commit ce8f571
Showing 1 changed file with 7 additions and 5 deletions.
transformer_engine/pytorch/ops/fuser.py (7 additions, 5 deletions)
@@ -192,7 +192,9 @@ def forward(
         func_ctx.backward_ops = backward_ops
         func_ctx.basic_ops = basic_ops
         func_ctx.basic_op_ctxs = basic_op_ctxs
-        func_ctx.num_params = num_params
+        func_ctx.basic_op_num_params = [
+            sum(1 for _ in op.parameters()) for op in basic_ops
+        ]
         func_ctx.num_extra_inputs = num_extra_inputs
         func_ctx.num_extra_outputs = len(extra_outputs_flat)
         func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
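
The forward pass now counts each basic op's parameters up front and caches the counts on the autograd context, so the backward pass never has to call `parameters()` itself. A minimal sketch of that counting idiom, outside Transformer Engine (`ToyOp` is a hypothetical stand-in for a basic op; only the generator-count pattern comes from the diff):

import torch

class ToyOp(torch.nn.Module):
    """Hypothetical op with two parameters, standing in for a basic op."""
    def __init__(self, size: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.bias = torch.nn.Parameter(torch.zeros(size))

basic_ops = [ToyOp(4), ToyOp(8)]

# Count each op's parameters once, at forward time, and stash the counts.
# Summing over the generator avoids materializing the parameter list.
basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops]
assert basic_op_num_params == [2, 2]
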
@@ -258,14 +260,14 @@ def backward(
         # Flatten list of parameter gradients
         grad_params_flat = []
         for idx, dparams in enumerate(grad_params):
-            params = list(basic_ops[idx].parameters())
+            num_params = func_ctx.basic_op_num_params[idx]
             if dparams is None:
-                dparams = [None for _ in range(len(params))]
+                dparams = [None for _ in range(num_params)]
             else:
                 dparams = list(dparams)
-                if len(dparams) != len(params):
+                if len(dparams) != num_params:
                     raise RuntimeError(
-                        f"Expected op {idx} to generate {len(params)} param grads, "
+                        f"Expected op {idx} to generate {num_params} param grads, "
                         f"but got {len(dparams)}"
                     )
             grad_params_flat.extend(dparams)
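
With the counts cached at forward time, the backward pass can size and validate each op's parameter gradients without traversing the modules again. A standalone sketch of this flattening logic under that assumption (`flatten_grad_params` is a hypothetical helper, not part of the Transformer Engine API; the cached counts are passed in explicitly):

from typing import Iterable, Optional, Sequence

def flatten_grad_params(
    grad_params: Sequence[Optional[Iterable]],
    basic_op_num_params: Sequence[int],
) -> list:
    grad_params_flat = []
    for idx, dparams in enumerate(grad_params):
        # Use the count recorded at forward time instead of calling
        # basic_ops[idx].parameters() again in the backward pass.
        num_params = basic_op_num_params[idx]
        if dparams is None:
            # Op produced no parameter grads: pad with Nones.
            dparams = [None for _ in range(num_params)]
        else:
            dparams = list(dparams)
            if len(dparams) != num_params:
                raise RuntimeError(
                    f"Expected op {idx} to generate {num_params} param grads, "
                    f"but got {len(dparams)}"
                )
        grad_params_flat.extend(dparams)
    return grad_params_flat

# Example: op 0 produced two grads, op 1 produced none.
print(flatten_grad_params([("gw", "gb"), None], [2, 2]))
# ['gw', 'gb', None, None]
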
