From c35b3518f1ad4da5b7a97453cb37da8d5c1097ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 20:39:01 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/gemm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index d7b1fac4f9..2367e9219b 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -151,7 +151,6 @@ def _gemm_fwd_rule( fuse_bias = bias is not None - # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) out, pre_gelu_out, extra_out = gemm_impl( @@ -204,7 +203,6 @@ def _gemm_bwd_rule( mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) ) - dgrad_overlap_config = None if comm_overlap_config is not None: dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad"