
Commit

Merge branch 'main' into fix_load
cyanguwa authored Dec 13, 2024
2 parents d58ad6a + e7bfc0c commit 00a7ca9
Showing 4 changed files with 26 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
@@ -42,6 +42,7 @@ jobs:
         || github.actor == 'kocchop'
         || github.actor == 'youngeunkwon0405'
         || github.actor == 'KshitijLakhani'
+        || github.actor == 'jberchtold-nvidia'
       )
     steps:
       - name: Check if comment is issued by authorized person
3 changes: 1 addition & 2 deletions transformer_engine/common/normalization/common.h
@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
 
 class NormalizationPlanRegistry {
  public:
-  // TODO thread-safe
   static NormalizationPlanRegistry& getInstance() {
-    static NormalizationPlanRegistry instance;
+    static thread_local NormalizationPlanRegistry instance;
     return instance;
   }
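The removed `// TODO thread-safe` comment explains this hunk: the Meyers singleton above was shared by every thread with no locking, so concurrent `getInstance()` callers could race on the registry's contents. Declaring the static `thread_local` gives each thread its own registry instead, trading some duplicated plan caching for lock-free access. A minimal Python analogue of the same pattern, using `threading.local` (the class and function names here only mirror the C++ for illustration):

```python
import threading

# Per-thread storage: each thread that calls get_instance() sees its own slot.
_tls = threading.local()

class NormalizationPlanRegistry:
    """Stand-in for the C++ registry; caches plans keyed by problem config."""
    def __init__(self):
        self.plans = {}

def get_instance() -> NormalizationPlanRegistry:
    # Lazily create one registry per thread, mirroring
    # `static thread_local NormalizationPlanRegistry instance;`.
    # No lock is needed because no state is shared across threads.
    if not hasattr(_tls, "registry"):
        _tls.registry = NormalizationPlanRegistry()
    return _tls.registry
```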
12 changes: 6 additions & 6 deletions transformer_engine/jax/cpp_extensions/normalization.py
@@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, output_type),
@@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
 
         sm_margin = get_backward_sm_margin()
 
-        wkspace_aval = ctx.avals_out[-4:]
+        wkspace_aval = ctx.avals_out[-1]
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, x_type.element_type),
@@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
         hidden_size = reduce(operator.mul, g_shape)
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-3:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(x_shape, x_type.element_type),
@@ -1088,7 +1088,7 @@ def lowering(
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, ir_out_dtype),
@@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, ir_out_dtype),
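All six hunks in this file apply the same fix: only the last entry of `ctx.avals_out` is treated as the workspace abstract value. A slice such as `ctx.avals_out[-2:]` yields a sequence of avals rather than the single aval the descriptor-packing code expects; the varying old widths (`[-2:]`, `[-3:]`, `[-4:]`) presumably date from when these primitives emitted several trailing workspace/barrier outputs. A toy sketch of the indexing difference (the list of strings is a stand-in for the real JAX avals):

```python
# Stand-ins for the abstract output values of a LayerNorm forward lowering:
# primary outputs first, the single workspace aval last.
avals_out = ["out_aval", "mu_aval", "rsigma_aval", "wkspace_aval"]

before = avals_out[-2:]  # ['rsigma_aval', 'wkspace_aval'] -- a 2-element slice
after = avals_out[-1]    # 'wkspace_aval' -- the single workspace aval

assert isinstance(before, list)  # downstream code would have to unpack this
assert after == "wkspace_aval"   # whereas one trailing aval is what is expected
```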
24 changes: 18 additions & 6 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -83,12 +83,24 @@ pybind11::dict Registrations() {
       EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
 
   // Normalization
-  dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
-  dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
-  dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
-  dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
-  dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
-  dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
+  dict["te_layernorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
+  dict["te_layernorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
+  dict["te_layernorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
+  dict["te_rmsnorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
+  dict["te_rmsnorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
+  dict["te_rmsnorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));
 
   // Attention
   pybind11::dict fused_attn_forward_ffi;
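Each normalization target is now registered as a two-stage XLA FFI handler: the `prepare` stage (`CudnnHandleInitHandler`) runs before the `execute` stage that launches the actual kernel, so cuDNN handle setup stays out of the hot path. A hedged sketch of how such a stage dictionary might be consumed on the Python side: `jax.extend.ffi.register_ffi_target` is real JAX API, but passing a dict of stage capsules and the `transformer_engine_jax` import path are assumptions based on XLA's multi-stage FFI convention, not taken from this diff.

```python
import transformer_engine_jax  # assumed to expose the Registrations() dict built above
from jax.extend import ffi     # JAX's FFI registration helper

for name, target in transformer_engine_jax.Registrations().items():
    # `target` is either a bare capsule or a dict of stage capsules such as
    # {"prepare": <capsule>, "execute": <capsule>}. In the dict form, XLA
    # runs the "prepare" stage (cuDNN handle setup here) before any
    # "execute" call launches the kernel.
    ffi.register_ffi_target(name, target, platform="CUDA")
```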
