Skip to content

Commit

Permalink
softmax custom calls with correct encapsulates
Browse files Browse the repository at this point in the history
Signed-off-by: Phuong Nguyen <[email protected]>
  • Loading branch information
phu0ngng committed Dec 13, 2024
1 parent e7bfc0c commit c54b7cc
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions transformer_engine/jax/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,25 @@ pybind11::dict Registrations() {
dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
dict["te_dact_lu_dbias_cast_transpose_ffi"] =
EncapsulateFunction(DActLuDBiasCastTransposeHandler);
EncapsulateFFI(DActLuDBiasCastTransposeHandler);
dict["te_dgated_act_lu_cast_transpose_ffi"] =
EncapsulateFunction(DGatedActLuCastTransposeHandler);
EncapsulateFFI(DGatedActLuCastTransposeHandler);

// Quantization
dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);

// Softmax
dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
dict["te_scaled_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
dict["te_scaled_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Normalization
dict["te_layernorm_forward_ffi"] =
Expand Down

0 comments on commit c54b7cc

Please sign in to comment.