From 0e1d9faed1ef8d341614c31b2fa7694b4a9f39a5 Mon Sep 17 00:00:00 2001
From: Phuong Nguyen
Date: Thu, 12 Dec 2024 08:00:46 -0500
Subject: [PATCH 1/2] [JAX] Bug fix for distributed normalization (#1366)

* fix ctx.avals_out indexing for workspace
* add cudnn init to prepare phase of norm custom calls
* add thread_local for norm registry instance

---------

Signed-off-by: Phuong Nguyen
---
 .../common/normalization/common.h           |  3 +--
 .../jax/cpp_extensions/normalization.py     | 12 +++++-----
 .../jax/csrc/extensions/pybind.cpp          | 24 ++++++++++++++-----
 3 files changed, 25 insertions(+), 14 deletions(-)

diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h
index 8a8df63ba4..d1d56d5cc9 100644
--- a/transformer_engine/common/normalization/common.h
+++ b/transformer_engine/common/normalization/common.h
@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
 
 class NormalizationPlanRegistry {
  public:
-  // TODO thread-safe
   static NormalizationPlanRegistry& getInstance() {
-    static NormalizationPlanRegistry instance;
+    static thread_local NormalizationPlanRegistry instance;
     return instance;
   }
 
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 0b7df0b5a8..69d7962b62 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, output_type),
@@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
 
         sm_margin = get_backward_sm_margin()
 
-        wkspace_aval = ctx.avals_out[-4:]
+        wkspace_aval = ctx.avals_out[-1]
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, x_type.element_type),
@@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
         hidden_size = reduce(operator.mul, g_shape)
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-3:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(x_shape, x_type.element_type),
@@ -1088,7 +1088,7 @@ def lowering(
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, ir_out_dtype),
@@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, ir_out_dtype),
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index 9b5c156e5d..a319b74d76 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -83,12 +83,24 @@ pybind11::dict Registrations() {
       EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
 
   // Normalization
-  dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
-  dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
-  dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
-  dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
-  dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
-  dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
+  dict["te_layernorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
+  dict["te_layernorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
+  dict["te_layernorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
+  dict["te_rmsnorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
+  dict["te_rmsnorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
+  dict["te_rmsnorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));
 
   // Attention
   pybind11::dict fused_attn_forward_ffi;

From e7bfc0c547d63332e4f8d65e606dc69f4c22ffbe Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Thu, 12 Dec 2024 14:16:09 -0800
Subject: [PATCH 2/2] Add user to CI (#1371)

Add Jeremy to ci users

Signed-off-by: Kirthi Shankar Sivamani
---
 .github/workflows/trigger-ci.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml
index 586abd0541..86d22b7944 100644
--- a/.github/workflows/trigger-ci.yml
+++ b/.github/workflows/trigger-ci.yml
@@ -42,6 +42,7 @@ jobs:
         || github.actor == 'kocchop'
         || github.actor == 'youngeunkwon0405'
         || github.actor == 'KshitijLakhani'
+        || github.actor == 'jberchtold-nvidia'
       )
     steps:
       - name: Check if comment is issued by authorized person
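
Note on the `thread_local` change in `transformer_engine/common/normalization/common.h`: the registry keeps its Meyers-singleton shape, but each thread now gets its own lazily constructed instance, which resolves the removed `// TODO thread-safe` comment without introducing a lock. The standalone sketch below is illustration only (the `Registry` class and `worker` lambda are invented stand-ins, not TransformerEngine code); it shows that concurrent callers touch distinct per-thread instances, with the trade-off that any cached state is per-thread rather than shared.

```cpp
// Minimal sketch of a thread_local Meyers singleton (illustration only, not
// the TransformerEngine class): every thread lazily constructs its own
// registry, so concurrent access needs no locking.
#include <cstdio>
#include <thread>

class Registry {
 public:
  static Registry& getInstance() {
    static thread_local Registry instance;  // one instance per thread
    return instance;
  }
  int count = 0;

 private:
  Registry() = default;
};

int main() {
  auto worker = [](const char* name) {
    Registry& r = Registry::getInstance();
    ++r.count;  // mutates this thread's private instance only
    std::printf("%s: registry at %p, count=%d\n", name, static_cast<void*>(&r), r.count);
  };
  std::thread t1(worker, "thread-1");
  std::thread t2(worker, "thread-2");
  t1.join();
  t2.join();
  worker("main");  // the main thread gets yet another instance
  return 0;
}
```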
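
Note on the `pybind.cpp` change: each normalization FFI target is now exported as a small dict that pairs a `prepare` handler (`CudnnHandleInitHandler`) with the `execute` handler, so cuDNN handle initialization can run in the custom call's prepare phase instead of during the first execute. The snippet below is a hypothetical, stripped-down sketch of that dict-of-stages layout; the module name, the `registrations()` function, and the stub handlers are invented for illustration, whereas the real code wraps XLA FFI handlers in capsules via `EncapsulateFFI`.

```cpp
// Hypothetical sketch (not the actual TE module): building a per-target
// {"prepare": ..., "execute": ...} registration dict with pybind11.
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Stand-ins for the real FFI handler capsules (e.g. CudnnHandleInitHandler);
// they exist only so this sketch is self-contained.
static const char* PrepareStub() { return "prepare: create cuDNN handle"; }
static const char* ExecuteStub() { return "execute: run LayerNorm forward"; }

PYBIND11_MODULE(norm_registrations_sketch, m) {
  m.def("registrations", []() {
    py::dict dict;
    // py::dict supports keyword-style construction through py::arg, the same
    // pattern the patch uses to pair a prepare handler with each execute handler.
    dict["te_layernorm_forward_ffi"] =
        py::dict(py::arg("prepare") = py::cpp_function(PrepareStub),
                 py::arg("execute") = py::cpp_function(ExecuteStub));
    return dict;
  });
}
```

From Python, `registrations()["te_layernorm_forward_ffi"]["prepare"]` would then yield the prepare-stage callable, mirroring how a per-stage registration dict can be consumed when FFI targets are registered.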