From ce8f5715d5cf4b96c5eb5e7ce8452293fa7a7b44 Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Sat, 11 Jan 2025 03:03:03 +0000
Subject: [PATCH] Avoid `parameters` function in op backward pass

Signed-off-by: Tim Moon
---
 transformer_engine/pytorch/ops/fuser.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py
index dc96c12523..725926e185 100644
--- a/transformer_engine/pytorch/ops/fuser.py
+++ b/transformer_engine/pytorch/ops/fuser.py
@@ -192,7 +192,9 @@ def forward(
         func_ctx.backward_ops = backward_ops
         func_ctx.basic_ops = basic_ops
         func_ctx.basic_op_ctxs = basic_op_ctxs
-        func_ctx.num_params = num_params
+        func_ctx.basic_op_num_params = [
+            sum(1 for _ in op.parameters()) for op in basic_ops
+        ]
         func_ctx.num_extra_inputs = num_extra_inputs
         func_ctx.num_extra_outputs = len(extra_outputs_flat)
         func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
@@ -258,14 +260,14 @@ def backward(
         # Flatten list of parameter gradients
         grad_params_flat = []
         for idx, dparams in enumerate(grad_params):
-            params = list(basic_ops[idx].parameters())
+            num_params = func_ctx.basic_op_num_params[idx]
             if dparams is None:
-                dparams = [None for _ in range(len(params))]
+                dparams = [None for _ in range(num_params)]
             else:
                 dparams = list(dparams)
-            if len(dparams) != len(params):
+            if len(dparams) != num_params:
                 raise RuntimeError(
-                    f"Expected op {idx} to generate {len(params)} param grads, "
+                    f"Expected op {idx} to generate {num_params} param grads, "
                     f"but got {len(dparams)}"
                 )
             grad_params_flat.extend(dparams)
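
Note (illustration only, not part of the patch): the change above caches each basic op's
parameter count on the autograd context during forward() and reuses that count in
backward(), so backward() never iterates op.parameters(). The standalone sketch below
shows the same pattern in miniature; ToyOp and ToyFuserFunction are hypothetical names
used only for this example and do not appear in Transformer Engine.

import torch
import torch.nn as nn


class ToyOp(nn.Module):
    """Toy op with a single parameter (stand-in for a basic op)."""

    def __init__(self, size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(size))


class ToyFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, op, *params):
        # Cache the parameter count on the context in forward(), so
        # backward() never needs to call op.parameters().
        ctx.num_params = sum(1 for _ in op.parameters())
        (weight,) = params
        ctx.save_for_backward(x, weight)
        return x * weight

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_input = grad_output * weight
        # Build the flat list of parameter grads and check its length
        # against the cached count, mirroring the patched fuser logic.
        grad_params = [grad_output * x]
        if len(grad_params) != ctx.num_params:
            raise RuntimeError(
                f"Expected {ctx.num_params} param grads, got {len(grad_params)}"
            )
        # One grad per forward input: x, op (non-tensor -> None), *params.
        return (grad_input, None, *grad_params)


if __name__ == "__main__":
    op = ToyOp(4)
    x = torch.randn(4, requires_grad=True)
    y = ToyFuserFunction.apply(x, op, *op.parameters())
    y.sum().backward()
    print(x.grad, op.weight.grad)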