From ad6bf2ae16ed7d7c54ecf725c649d19e6eab5017 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 24 Oct 2024 22:38:05 +0000 Subject: [PATCH 01/39] added XLA custom op defs for TE GEMM Signed-off-by: Alp Dener Added XLA FFI custom op for TE GEMM Signed-off-by: Alp Dener finished GEMM custom op primitive and serial unit test Signed-off-by: Alp Dener fixed GEMM custom op batcher Signed-off-by: Alp Dener fixed output dtype error and contracting dimensions options Signed-off-by: Alp Dener AG overlap working but executes scatter to match outer LHS dim Signed-off-by: Alp Dener both all-gather and all-reduce are now working Signed-off-by: Alp Dener code style Signed-off-by: Alp Dener changed kwargs in abstract to be explicit Signed-off-by: Alp Dener added fwd/bwd implementation for non-fp8 gemm Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 55 ++ .../jax/cpp_extensions/__init__.py | 1 + transformer_engine/jax/cpp_extensions/gemm.py | 647 ++++++++++++++++++ transformer_engine/jax/cpp_extensions/misc.py | 7 + transformer_engine/jax/csrc/extensions.h | 39 ++ .../jax/csrc/extensions/gemm.cpp | 170 +++++ .../jax/csrc/extensions/packing.cpp | 11 + .../jax/csrc/extensions/pybind.cpp | 5 +- transformer_engine/jax/csrc/utils.h | 2 +- transformer_engine/jax/flax/module.py | 7 +- transformer_engine/jax/fp8.py | 7 +- transformer_engine/jax/gemm.py | 425 ++++++++++++ 12 files changed, 1370 insertions(+), 6 deletions(-) create mode 100644 transformer_engine/jax/cpp_extensions/gemm.py create mode 100644 transformer_engine/jax/csrc/extensions/gemm.cpp create mode 100644 transformer_engine/jax/gemm.py diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 20b16c2809..9bf3f9fa91 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -25,6 +25,7 @@ _jax_dbias_cast_transpose, ) from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8 +from transformer_engine.jax.gemm import fp8_gemm, gemm from transformer_engine.jax import cpp_extensions as tex @@ -415,6 +416,60 @@ def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_ ) +class TestGemm: + + @staticmethod + def _generate_inputs(b, m, n, k, dtype): + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 3) + a = jax.random.normal(subkeys[0], (b, m, k), dtype) + b = jax.random.normal(subkeys[1], (n, k), dtype) + bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16 + bias = jax.random.normal(subkeys[2], (n, ), bias_dtype) + return a, b, bias + + @staticmethod + def _generate_fp8_inputs(b, m, n, k, fp8_dtype): + a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16) + a_scale, b_scale = map( + lambda x: (jnp.max(jnp.abs(x)) / 127.).astype(jnp.float32), + [a, b] + ) + a_q, b_q = map( + lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype), + [(a, a_scale), (b, b_scale)] + ) + return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias + + @pytest.mark.parametrize("m,n,k", GEMM_CASES) + @pytest.mark.parametrize("use_bias", (False, True)) + @pytest.mark.parametrize("do_gelu", (False, True)) + def test_gemm(self, b, m, n, k, use_bias, do_gelu): + a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16) + + primitive_out = gemm(a, b, bias=bias if use_bias else None, layout='NT', do_gelu=do_gelu) + ref_out = jnp.dot(a, b) + if use_bias: + ref_out += bias + if do_gelu: + ref_out = jax.nn.gelu(ref_out) + + assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) + + @pytest.mark.skipif(not is_fp8_supported, reason=reason) + @pytest.mark.parametrize("m,n,k", GEMM_CASES) + @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE) + def test_fp8_gemm(self, m, n, k, fp8_dtype): + a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs( + m, n, k, fp8_dtype + ) + + primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16) + ref_out = jnp.dot(a, b) + + assert_allclose(primitive_out, ref_out, dtype=fp8_dtype) + + @pytest.fixture(name="random_inputs") def random_inputs_fixture(shape): key = jax.random.PRNGKey(0) diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 579daa8e41..1e5cc4c07e 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -4,6 +4,7 @@ """Python interface for c++ extensions""" from .activation import * from .attention import * +from .gemm import * from .normalization import * from .quantization import * from .softmax import * diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py new file mode 100644 index 0000000000..677fabca59 --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -0,0 +1,647 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for cuBlasLt GEMM""" +import warnings +import operator +from functools import reduce +from typing import Optional, Union, Tuple + +import jax +import jax.numpy as jnp +from jax import dtypes +from jax.interpreters import mlir +from jax.interpreters.mlir import ir +from jax.sharding import PartitionSpec, NamedSharding +from jax.extend import ffi +from jax.typing import ArrayLike + +from transformer_engine import transformer_engine_jax as tex +from .base import BasePrimitive, register_primitive +from .custom_call import custom_caller, CustomCallArgsWrapper +from .misc import ( + jax_dtype_to_te_dtype, + jax_dtype_is_fp8, + get_padded_spec, + is_ffi_enabled, +) +from ..sharding import ( + global_mesh_resource, + get_mesh_axis_size, + lax_paral_op, + all_reduce_max_along_all_axes_except_PP, +) + + +__all__ = [ + "fp8_gemm_impl", + "gemm_impl", +] + + +def get_cublas_workspace_size_bytes() -> None: + """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" + if tex.get_device_compute_capability() >= 90: + return 33_554_432 + return 4_194_304 + + +class CollectiveGemmPrimitive(BasePrimitive): + """ + cuBlasLt GEMM Primitive w/ support for distributed inputs + """ + + name = "te_gemm" + impl_static_args = (8, 9, 10, 11, 12, 13, 14) + multiple_results = True + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_aval, + gelu_input_aval, out_amax_aval, out_scale_aval, out_dtype, contracting_dims, + fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator): + """ + cuBlasLt GEMM abstract + """ + del grad, accumulate, use_split_accumulator + + # Validate operand dtypes + lhs_dtype = dtypes.canonicalize_dtype(lhs_aval.dtype) + rhs_dtype = dtypes.canonicalize_dtype(rhs_aval.dtype) + assert lhs_dtype == rhs_dtype, "Mismatched matrix dtypes for GEMM." + is_fp8 = False + if jax_dtype_is_fp8(lhs_dtype): + assert ( + lhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(lhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing LHS operand scale inverse in FP8 GEMM." + is_fp8 = True + if jax_dtype_is_fp8(rhs_dtype): + assert ( + rhs_scale_inv_aval.size == 1 + and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 + ), "Missing RHS operand scale inverse in FP8 GEMM." + + # Disallow batching for RHS + assert rhs_aval.ndim == 2, "GEMM does not support batching the RHS operand." + + # Validate operand layouts + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim) + ) + assert ( + lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] + ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." + + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + assert ( + not (lhs_trans and rhs_trans) + ), "GEMM does not support transposed LHS and transposed RHS at the same time." + if is_fp8: + assert lhs_trans, "FP8 GEMM does not support transposed LHS." + assert rhs_trans, "FP8 GEMM requires transposed RHS." + + # Validate output dtype + if jax_dtype_is_fp8(out_dtype): + assert ( + jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(rhs_dtype) + ), "FP8 GEMM output requires FP8 inputs." + assert ( + out_amax_aval.size == out_scale_aval.size == 1 + ), "Invalid/missing output amax and scale." + out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) + out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) + assert ( + out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 + ), "Invalid output amax or scale dtype." + else: + out_dtype = lhs_dtype + out_amax_updated_dtype = jnp.float32 + out_scale_updated_dtype = jnp.float32 + + # Infer output shape + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 + lhs_bdims = [dim for dim in range(lhs_aval.ndim) + if dim not in [lhs_outer_dim, lhs_inner_dim]] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + + # Validate bias/bias_grad shape against inferred output + bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype + if fuse_bias: + assert ( + bias_aval.size > 0 + and bias_aval.ndim == 1 + and bias_aval.shape[0] == out_shape[-1] + ), "Incorrect bias shape." + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + else: + assert bias_aval.size == 0, "Internal TE error." + + # Validate GELU input/output + if fuse_gelu: + assert ( + all([gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)]) + ), "Invalid GELU input shape." + assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." + else: + assert gelu_input_aval.size == 0, "Internal TE error." + + # Create abstract arrays for all outputs + out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) + out_amax_updated_aval = out_amax_aval.update(shape=out_amax_aval.shape, + dtype=out_amax_updated_dtype) + out_scale_updated_aval = out_scale_aval.update(shape=out_scale_aval.shape, + dtype=out_scale_updated_dtype) + pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_input_aval.shape, dtype=bias_dtype) + bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + workspace_aval = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), + dtype=jnp.uint8) + + return ( + out_aval, + out_amax_updated_aval, + out_scale_updated_aval, + pre_gelu_out_aval, + bias_grad_aval, + workspace_aval + ) + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + cuBlasLt GEMM outer abstract + """ + ( + out_aval, + out_amax_aval, + out_scale_aval, + pre_gelu_out_aval, + bias_grad_aval, + _ + ) = CollectiveGemmPrimitive.abstract(*args, **kwargs) + return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval + + @staticmethod + def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, + *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator): + """ + Fused attention fwd lowering rules + """ + lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim) + ) + lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + + operand_output_aliases = { + 4: 4, # bias <--> bias_grad + 5: 3, # gelu_input <--> pre_gelu_out + 6: 1, # out_amax <--> out_amax_updated + 7: 2, # out_scale <--> out_scale_updated + } + + if is_ffi_enabled(): + name = "te_gemm_ffi" + return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + lhs_trans=lhs_trans, + rhs_trans=rhs_trans, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + else: + operands = [ + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + ] + operand_shapes = map(lambda x: ir.RankedTensorType(x.type).shape, operands) + out_types = [ + ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_dtype(output.dtype)) + for output in ctx.avals_out + ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 + lhs_bdims = [dim for dim in range(lhs_aval.ndim) + if dim not in [lhs_outer_dim, lhs_inner_dim]] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + m = reduce(operator.mul, lhs_batch_shape, 1) * lhs_aval.shape[lhs_outer_dim] + k = rhs_aval.shape[rhs_inner_dim] + n = rhs_aval.shape[rhs_outer_dim] + workspace_size = get_cublas_workspace_size_bytes() + operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype) + bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) + opaque = tex.pack_gemm_descriptor(m, n, k, workspace_size, operand_dtype, + jax_dtype_to_te_dtype(out_dtype), bias_dtype, + lhs_trans, rhs_trans, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator) + + return custom_caller( + CollectiveGemmPrimitive.name, + args, + opaque, + has_side_effect=False, + operand_output_aliases=operand_output_aliases, + ) + + @staticmethod + def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, + out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator): + assert CollectiveGemmPrimitive.inner_primitive is not None + + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + _, + ) = CollectiveGemmPrimitive.inner_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator): + assert CollectiveGemmPrimitive.outer_primitive is not None + + lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale = batched_args + assert rhs.ndim == 2, "TE/JAX GEMM custom op does not support batching RHS operands." + + # Get contracting and batch dimensions out + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + + # FP8 GEMM only supports lhs_trans = False and rhs_trans = True so we may need to + # reorder the axes here to match + if jax_dtype_is_fp8(lhs.dtype): + lhs = jnp.transpose(lhs, (*lhs_bdims, lhs_outer_dim, lhs_inner_dim)) + lhs_trans = False + rhs = jnp.transpose(rhs, (rhs_outer_dim, rhs_inner_dim)) + rhs_trans = True + contracting_dims = (1, 1) + + # Collapse all non-contracting dimensions + batch_shape = [lhs.shape[dim] for dim in lhs_bdims] + batch_size = reduce(operator.mul, batch_shape, 1) + lhs_outer_size = lhs.shape[lhs_outer_dim] + lhs_shape_2d = ( + (lhs.shape[lhs_inner_dim], batch_size * lhs_outer_size) + if lhs_trans + else (batch_size * lhs_outer_size, lhs.shape[lhs_inner_dim]) + ) + lhs = jnp.reshape(lhs, lhs_shape_2d) + if fuse_gelu: + gelu_input = jnp.reshape( + gelu_input, (batch_size * lhs_outer_size, rhs.shape[rhs_outer_dim]) + ) + + outputs = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # Reshape output to recover original LHS batch shape + outputs[0] = jnp.reshape( + outputs[0], + (*batch_shape, lhs_outer_size, rhs.shape[rhs_outer_dim]) + ) + gelu_bdims = batch_dims[3] + if fuse_gelu: + outputs[3] = jnp.reshape(outputs[3], outputs[0].shape) + gelu_bdims = lhs_bdims + + return ( + outputs, + (lhs_bdims, batch_dims[1], batch_dims[2], gelu_bdims, batch_dims[4]) + ) + + @staticmethod + def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, + accumulate, use_split_accumulator, mesh, arg_infos, + result_infos): + del out_dtype, accumulate, use_split_accumulator, result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: + warnings.warn("Forcing the inner dimension of LHS to match the sharding of inner " + + "dimension of RHS. This can trigger additional communication if LHS is " + + "not already partitioned correctly.") + + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] + rhs_outer_spec = rhs_spec[rhs_outer_dim] + + if rhs_spec[rhs_inner_dim] is not None and rhs_outer_spec is not None: + raise RuntimeError("Both inner and outer dimensions of RHS cannot be sharded.") + + # Outer (sequence) dimension of the GEMM output is always unsharded + out_spec = [*batch_specs, None, rhs_outer_spec] + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded + gelu_spec = out_spec if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + # Bias gradient spec matches outer dimension of output if bias fusion is turned on + bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + + return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) + + @staticmethod + def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, + use_split_accumulator, mesh, arg_infos, result_infos): + del result_infos + lhs, _, rhs, *_ = arg_infos + lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) + + lhs_inner_dim, rhs_inner_dim = map( + lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, + contracting_dims, + (lhs.ndim, rhs.ndim) + ) + + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == 1 + lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 + rhs_outer_dim = 0 if rhs_trans else 1 + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] + batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] + rhs_outer_spec = rhs_spec[rhs_outer_dim] + + # Force all-gather the outer (sequence) dimension of the LHS operand + lhs_spec_new = [spec for spec in lhs_spec] + lhs_spec_new[lhs_outer_dim] = None + lhs_spec_new[lhs_inner_dim] = rhs_spec[rhs_inner_dim] + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) + + # RHS operand is unchanged, we already enforce that only one dimension can be sharded + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + + # Bias is sharded to match outer dimension spec of the RHS operand (also the output) + bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + + # FP8 metas are always unsharded + fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) + + # Outer (sequence) dimension of the GEMM output is always unsharded + out_spec = [*batch_specs, None, rhs_outer_spec] + out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) + + # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded + gelu_spec = out_spec if fuse_gelu else [None] + gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + + arg_shardings = (lhs_sharding, fp8_meta_sharding, rhs_sharding, fp8_meta_sharding, + bias_sharding, gelu_sharding, fp8_meta_sharding, fp8_meta_sharding) + out_shardings = (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, + bias_sharding) + + def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, + out_scale): + ( + out, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + ) = CollectiveGemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + # FP8 amax reduction + if jax_dtype_is_fp8(lhs.dtype): + out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) + + if rhs_spec[rhs_inner_dim] is not None: + # GEMM output needs to be all-reduced when the contracting dimension is sharded. + # If the layer is sequence-parallel, we also need to scatter the output, which we + # can combine into a reduce-scatter here. + out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, + mesh) + if fuse_gelu: + pre_gelu_out = lax_paral_op( + pre_gelu_out, jax.lax.psum, global_mesh_resource().cp_resource, mesh + ) + + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(CollectiveGemmPrimitive) + + +def fp8_gemm_impl( + lhs: ArrayLike, + lhs_scale_inv: ArrayLike, + rhs: ArrayLike, + rhs_scale_inv: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + out_amax: Optional[ArrayLike] = None, + out_scale: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + contracting_dims: Tuple[int, int] = (1, 1), + fuse_gelu: bool = False, + fuse_bias: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + if out_dtype is not None and jax_dtype_is_fp8(out_dtype): + assert out_amax is not None and out_scale is not None, "Missing output amax and scale." + else: + out_amax = jnp.zeros(0, dtype=jnp.float32) + out_scale = jnp.zeros(0, dtype=jnp.float32) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=jnp.bfloat16) + else: + assert ( + bias is not None + ), "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=bias.dtype) + elif gelu_input is None: + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) + + out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( + rhs, + rhs_scale_inv, + lhs, + lhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + contracting_dims=tuple(reversed(contracting_dims)), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=False, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + return out, out_amax, out_scale, pre_gelu_out + + +def gemm_impl( + lhs: ArrayLike, + rhs: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + fuse_bias: bool = False, + grad: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> Tuple[ArrayLike, ...]: + """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) + + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) + else: + assert ( + bias is not None + ), "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=lhs.dtype) + elif grad: + assert ( + gelu_input is not None + ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." + elif gelu_input is None: + lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 + rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + + out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + dummy_fp8_meta, + rhs, + dummy_fp8_meta, + bias, + gelu_input, + dummy_fp8_meta, + dummy_fp8_meta, + out_dtype=lhs.dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if grad: + return out, pre_gelu_out, bias_grad + else: + return out, pre_gelu_out diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 1f13484b98..15d7537fbd 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -81,6 +81,13 @@ def jax_dtype_to_te_dtype(jax_dtype): return converter.get(jax_dtype) +def jax_dtype_is_fp8(dtype): + """ + Check if the given jax.numpy.dtype is an FP8 dtype. + """ + return dtypes.canonicalize_dtype(dtype) in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + def get_padded_spec(arg_info): """ Get padded spec for partitioning from arguments' information diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 02e6aaf9d5..afac283a6f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -147,6 +147,31 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right); +struct CustomCallGemmDescriptor { + size_t batch; + size_t m; + size_t k; + size_t n; + size_t workspace_size; + DType operand_dtype; + DType bias_dtype; + DType out_dtype; + bool lhs_trans; + bool rhs_trans; + bool fuse_gelu; + bool fuse_bias; + bool grad; + bool accumulate; + bool use_split_accumulator; +}; + +pybind11::bytes PackCustomCallGemmDescriptor(size_t batch, size_t m, size_t n, size_t k, + size_t workspace_size, DType operand_dtype, + DType out_dtype, DType bias_dtype, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, + bool grad, bool accumulate, + bool use_split_accumulator); + // Transpose void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -308,6 +333,20 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); +// GEMM + +void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp new file mode 100644 index 0000000000..f60ae510df --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -0,0 +1,170 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/gemm.h" + +#include "common/util/cuda_runtime.h" +#include "common/util/system.h" +#include "extensions.h" + +namespace transformer_engine { + +namespace jax { + +void GemmImpl(cudaStream_t stream, void *lhs, const std::vector &lhs_shape, + float *lhs_scale_inv, bool lhs_trans, void *rhs, const std::vector &rhs_shape, + float *rhs_scale_inv, bool rhs_trans, DType operand_dtype, void *bias, + DType bias_dtype, void *out, float *out_amax, float *out_scale, DType out_dtype, + void *pre_gelu_out, void *workspace, size_t workspace_size, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) { + auto lhs_ = TensorWrapper(lhs, lhs_shape, operand_dtype, nullptr, nullptr, lhs_scale_inv); + auto rhs_ = TensorWrapper(rhs, rhs_shape, operand_dtype, nullptr, nullptr, rhs_scale_inv); + + std::vector out_shape(2, 0); + out_shape[0] = (lhs_trans) ? lhs_shape[1] : lhs_shape[0]; + out_shape[1] = (rhs_trans) ? rhs_shape[0] : rhs_shape[1]; + auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); + + void *bias_ptr = (fuse_bias) ? bias : nullptr; + std::vector bias_shape = (fuse_bias) ? std::vector{out_shape[1]} + : std::vector{0}; + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; + std::vector pre_gelu_shape = (fuse_gelu) ? out_shape : std::vector{0}; + auto pre_gelu_out_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, bias_dtype); + auto workspace_ = TensorWrapper(workspace, std::vector{workspace_size}, DType::kByte); + + // cuBLAS is column-major, so we swap LHS and RHS in the arguments + auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_out_.data(), + (rhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, (lhs_trans) ? CUBLAS_OP_T : CUBLAS_OP_N, + grad, workspace_.data(), accumulate, use_split_accumulator, num_math_sm, stream); +} + +void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + // Inputs + auto *lhs = buffers[0]; + auto *lhs_scale_inv = reinterpret_cast(buffers[1]); + auto *rhs = buffers[2]; + auto *rhs_scale_inv = reinterpret_cast(buffers[3]); + auto *bias = buffers[4]; + auto *gelu_input = buffers[5]; + auto *out_amax = reinterpret_cast(buffers[6]); + auto *out_scale = reinterpret_cast(buffers[7]); + + // Outputs + auto *out = buffers[8]; + auto *out_amax_updated = reinterpret_cast(buffers[9]); + auto *out_scale_updated = reinterpret_cast(buffers[10]); + auto *pre_gelu_out = buffers[11]; + auto *bias_grad = buffers[12]; + auto *workspace = buffers[13]; + + // Operand aliasing + NVTE_CHECK(bias == bias_grad, + "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input == pre_gelu_out, + "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_amax == out_amax_updated, + "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(out_scale == out_scale_updated, + "out_scale not bound to out_scale_updated in TE/JAX GEMM"); + + // GEMM sizing + const auto &desc = *UnpackOpaque(opaque, opaque_len); + std::vector lhs_shape = {(desc.lhs_trans) ? desc.k : desc.m, + (desc.lhs_trans) ? desc.m : desc.k}; + std::vector rhs_shape = {(desc.rhs_trans) ? desc.n : desc.k, + (desc.rhs_trans) ? desc.k : desc.n}; + + GemmImpl(stream, lhs, lhs_shape, lhs_scale_inv, desc.lhs_trans, rhs, rhs_shape, rhs_scale_inv, + desc.rhs_trans, desc.operand_dtype, bias, desc.bias_dtype, out, out_amax, out_scale, + desc.out_dtype, pre_gelu_out, workspace, desc.workspace_size, desc.fuse_gelu, + desc.fuse_bias, desc.grad, desc.accumulate, desc.use_split_accumulator); +} + +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator) { + // Inputs + auto lhs_ptr = lhs.untyped_data(); + auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); + auto rhs_ptr = rhs.untyped_data(); + auto rhs_scale_inv_ptr = reinterpret_cast(rhs_scale_inv.untyped_data()); + auto operand_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + auto bias_ptr = bias.untyped_data(); + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); + auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + + // Outputs + auto out_ptr = out->untyped_data(); + auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); + auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type()); + auto pre_gelu_out_ptr = pre_gelu_out->untyped_data(); + auto bias_grad_ptr = bias_grad->untyped_data(); + auto workspace_ptr = workspace->untyped_data(); + auto workspace_size = workspace->dimensions().back(); + + // Operand aliasing + NVTE_CHECK(bias_ptr == bias_grad_ptr, + "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr, + "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, + "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, + "out_scale not bound to out_scale_updated in TE/JAX GEMM"); + + // GEMM sizing + std::vector lhs_shape(lhs.dimensions().begin(), lhs.dimensions().end()); + std::vector rhs_shape(rhs.dimensions().begin(), rhs.dimensions().end()); + + // Swap A and B argument locations to match what the TE/common kernel expects + GemmImpl(stream, lhs_ptr, lhs_shape, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, + rhs_scale_inv_ptr, rhs_trans, operand_dtype, bias_ptr, bias_dtype, out_ptr, out_amax_ptr, + out_scale_ptr, out_dtype, pre_gelu_out_ptr, workspace_ptr, workspace_size, fuse_gelu, + fuse_bias, grad, accumulate, use_split_accumulator); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // out_amax + .Arg() // out_scale + .Ret() // out + .Ret() // out_amax_updated + .Ret() // out_scale_updated + .Ret() // pre_gelu_out + .Ret() // bias_grad + .Ret() // workspace + .Attr("lhs_trans") + .Attr("rhs_trans") + .Attr("fuse_gelu") + .Attr("fuse_bias") + .Attr("grad") + .Attr("accumulate") + .Attr("use_split_accumulator"), + FFI_CudaGraph_Traits); + +} // namespace jax + +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 298478603b..1a9ce987af 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -80,5 +80,16 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( deterministic, window_size_left, window_size_right}); } +pybind11::bytes PackCustomCallGemmDescriptor(size_t batch, size_t m, size_t n, size_t k, + size_t workspace_size, DType operand_dtype, + DType bias_dtype, DType out_dtype, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, + bool grad, bool accumulate, + bool use_split_accumulator) { + return PackOpaque(CustomCallGemmDescriptor{batch, m, n, k, workspace_size, operand_dtype, + bias_dtype, out_dtype, lhs_trans, rhs_trans, fuse_gelu, + fuse_bias, grad, accumulate, use_split_accumulator}); +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 9b5c156e5d..7b8ebdcdd2 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -51,6 +51,7 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); + dict["te_gemm"] = EncapsulateFunction(Gemm); // Transpose dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); @@ -101,6 +102,7 @@ pybind11::dict Registrations() { fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler); dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi; + dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); return dict; } @@ -114,10 +116,11 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("pack_norm_descriptor", &PackCustomCallNormDescriptor); m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); + m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); - m.def("get_device_compute_capability", &GetDeviceComputeCapability); + m.def("get_device_compute_capability", &GetDeviceComputeCapability, pybind11::arg("gpu_id") = -1); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dact_dbias_ct_workspace_sizes", &GetDActDBiasCastTransposeWorkspaceSizes); m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 32de33bac9..b328c6e278 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -23,7 +23,7 @@ namespace jax { int GetCudaRuntimeVersion(); size_t GetCudnnRuntimeVersion(); -int GetDeviceComputeCapability(int gpu_id); +int GetDeviceComputeCapability(int gpu_id = -1); void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8b13c47cd4..7312aa8295 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -334,6 +334,7 @@ def generate_fp8_meta_set(postfix: str) -> FP8MetaPackage: input_name_post_fix = f"_i_{postfix}" weight_name_post_fix = f"_w_{postfix}" grad_name_post_fix = f"_g_{postfix}" + output_name_post_fix = f"_o_{postfix}" def generate_a_set(target_postfix): amax = nn_partitioning.variable_with_axes( @@ -359,10 +360,10 @@ def generate_a_set(target_postfix): input_amax, input_scale = generate_a_set(input_name_post_fix) weight_amax, weight_scale = generate_a_set(weight_name_post_fix) grad_amax, grad_scale = generate_a_set(grad_name_post_fix) + output_amax, output_scale = generate_a_set(output_name_post_fix) - return FP8MetaPackage( - input_amax, input_scale, weight_amax, weight_scale, grad_amax, grad_scale - ) + return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax, + grad_scale, output_amax, output_scale) class DenseGeneral(TransformerEngineBase): diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 5df8ce4386..3d58c86e3e 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -86,10 +86,11 @@ class FP8MetaPackage: A container that contains all required meta data for FP8 """ - NUM_OF_META: int = 3 + NUM_OF_META: int = 4 INPUT_IDX: int = 0 WEIGHT_IDX: int = 1 GRAD_IDX: int = 2 + OUTPUT_IDX: int = 3 def __init__( self, @@ -99,6 +100,8 @@ def __init__( weight_scale: jnp.ndarray, grad_amax: jnp.ndarray, grad_scale: jnp.ndarray, + output_amax: jnp.ndarray, + output_scale: jnp.ndarray, ) -> None: self._amax_list = [None] * FP8MetaPackage.NUM_OF_META @@ -110,6 +113,8 @@ def __init__( self._scale_list[FP8MetaPackage.WEIGHT_IDX] = weight_scale self._amax_list[FP8MetaPackage.GRAD_IDX] = grad_amax self._scale_list[FP8MetaPackage.GRAD_IDX] = grad_scale + self._amax_list[FP8MetaPackage.OUTPUT_IDX] = output_amax + self._scale_list[FP8MetaPackage.OUTPUT_IDX] = output_scale @property def amax_list(self) -> List[jnp.ndarray]: diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py new file mode 100644 index 0000000000..ccd109e095 --- /dev/null +++ b/transformer_engine/jax/gemm.py @@ -0,0 +1,425 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +from functools import partial +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike +from jax.ad_checkpoint import checkpoint_name + +from .fp8 import FP8Helper, FP8MetaPackage +from .cpp_extensions import ( + gemm_impl, + fp8_gemm_impl, + cast_fp8, + cast_transpose, + dact_lu, + dbias_cast_transpose, + dact_lu_dbias_cast_transpose, +) + + + +__all__ = [ + "gemm", + "fp8_gemm", + "type_safe_gemm", +] + + +def gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + """Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions.""" + return _gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +def _gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Union[ArrayLike, None], + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> ArrayLike: + out, _ = _gemm_fwd_rule(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, + use_split_accumulator) + return out + + +def _gemm_fwd_rule( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> Tuple[ArrayLike, ...]: + fuse_bias = bias is not None + + out, pre_gelu_out = gemm_impl( + x, + kernel, + bias=bias, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + + ctx = ( + x, + kernel, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + ) + + return out, ctx + + +def _gemm_bwd_rule( + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + x, kernel, pre_gelu_out, fuse_bias = ctx + + x_t_contracting = 0 if contracting_dims[0] == 1 else 1 + wgrad, dgelu, bgrad = gemm_impl( + x, + grad, + gelu_input=pre_gelu_out, + contracting_dims=(x_t_contracting, 0), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + kernel_t_contracting = 1 if contracting_dims[1] == 0 else 0 + dgrad, *_ = gemm_impl( + dgelu if fuse_gelu else grad, + kernel, + gelu_input=pre_gelu_out, + contracting_dims=(1, kernel_t_contracting), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=True, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if not fuse_bias: + bgrad = None + + return dgrad, wgrad, bgrad + + +_gemm.defvjp(_gemm_fwd_rule, _gemm_bwd_rule) + + +def fp8_gemm( + x: ArrayLike, + kernel: ArrayLike, + fp8_meta: FP8MetaPackage, + bias: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + contracting_dims: Tuple[int, int] = (1, 1), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + return _fp8_gemm(x, kernel, bias, fp8_meta.amax_list, fp8_meta.scale_list, out_dtype, + contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + + +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) +def _fp8_gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> ArrayLike: + """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" + out, _ = _fp8_gemm_fwd_rule(x, kernel, bias, amax_list, scale_list, out_dtype, + contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + return out + + +def _fp8_gemm_fwd_rule( + x: ArrayLike, + kernel: ArrayLike, + bias: ArrayLike, + amax_list: ArrayLike, + scale_list: ArrayLike, + out_dtype: jnp.dtype, + contracting_dims: Tuple[int, int], + fuse_gelu: bool, + accumulate: bool, + use_split_accumulator: bool, +) -> Tuple[ArrayLike, ...]: + fuse_bias = bias is not None + + maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( + *amax_list, *scale_list, + ) + amax_list = maybe_fm32_to_fp32(*amax_list) + scale_list = maybe_fm32_to_fp32(*scale_list) + + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + fp8_dtype_list = [fwd_dtype, fwd_dtype, bwd_dtype, fwd_dtype] + scale_list, scale_inv_list = FP8MetaPackage.update_fp8_scale( + amax_list, scale_list, fp8_dtype_list + ) + amax_list = FP8MetaPackage.update_amax_list(amax_list) + + x_amax = amax_list[FP8MetaPackage.INPUT_IDX][0:1] + x_scale = scale_list[FP8MetaPackage.INPUT_IDX] + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if contracting_dims[0] == 0: + _, casted_x, updated_x_amax = cast_transpose( + x, + x_amax, + x_scale, + x_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_x, updated_x_amax = cast_fp8(x, x_amax, x_scale, x_scale_inv, fwd_dtype) + else: + if contracting_dims[0] == 0: + casted_x_t = x + casted_x = casted_x_t.transpose() + else: + casted_x = x + updated_x_amax = x_amax + + kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1] + kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX] + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + if kernel.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if contracting_dims[1] == 0: # need to transpose the kernel for FP8 GEMM + _, casted_kernel_t, updated_kernel_amax = cast_transpose( + kernel, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + casted_kernel_t, updated_kernel_amax = cast_fp8( + kernel, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + ) + else: + if contracting_dims[1] == 0: + casted_kernel = kernel + casted_kernel_t = casted_kernel.transpose() + else: + casted_kernel_t = kernel + updated_kernel_amax = kernel_amax + + out_amax = ( + amax_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out_scale = ( + scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + else None + ) + out, updated_out_amax, updated_out_scale, pre_gelu_out = fp8_gemm_impl( + casted_x, + x_scale_inv, + casted_kernel_t, + kernel_scale_inv, + bias=bias, + out_amax=out_amax, + out_scale=out_scale, + out_dtype=out_dtype, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator + ) + if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + updated_out_amax = None + updated_out_scale = None + + ctx = ( + casted_x, + casted_kernel_t, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + pre_gelu_out if fuse_gelu else None, + fuse_bias, + maybe_fp32_to_fm32 + ) + + return (out, updated_out_amax, updated_out_scale), ctx + + +def _fp8_gemm_bwd_rule( + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ctx, + grad, +): + ( + casted_x, + casted_kernel_t, + amax_list, + scale_list, + scale_inv_list, + updated_x_amax, + updated_kernel_amax, + pre_gelu_out, + fuse_bias, + maybe_fp32_to_fm32 + ) = ctx + + fwd_dtype = FP8Helper.FWD_DTYPE + bwd_dtype = FP8Helper.BWD_DTYPE + + grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] + grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] + grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] + if fuse_bias and not fuse_gelu: + # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. + _, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + # If both bias and GELU is fused into the forward pass, we will fuse dbias later with + # dGELU. No need to do it here. + _, casted_grad_t, updated_grad_amax = cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + bgrad = None + + + + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + wgrad, *_ = fp8_gemm_impl( + casted_x, + x_scale_inv, + casted_grad_t, + grad_scale_inv, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + if fuse_gelu and fuse_bias: + # Fuse dbias into this dGELU. + casted_dgelu, casted_dgelu_t, bgrad, updated_dgelu_amax = dact_lu_dbias_cast_transpose( + grad, + pre_gelu_out, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + activation_type=("gelu", ), + ) + elif fuse_gelu: + # No bias to fuse so we just do dGELU. + casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu", )) + bgrad = None + + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + dgrad, *_ = gemm_impl( + casted_dgelu if fuse_gelu else grad, + grad_scale_inv, + casted_kernel_t, + kernel_scale_inv, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + + amax_list[FP8MetaPackage.INPUT_IDX] = ( + amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) + ) + amax_list[FP8MetaPackage.WEIGHT_IDX] = ( + amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) + ) + + amax_list = maybe_fp32_to_fm32(*amax_list) + scale_list = maybe_fp32_to_fm32(*scale_list) + + return dgrad, wgrad, bgrad, amax_list, scale_list + + +_fp8_gemm.defvjp(_fp8_gemm_fwd_rule, _fp8_gemm_bwd_rule) + + +def type_safe_gemm( + x: ArrayLike, + kernel: ArrayLike, + bias: Optional[ArrayLike] = None, + fp8_meta: Optional[FP8MetaPackage] = None, + out_dtype: Optional[jnp.dtype] = None, + contracting_dims: Tuple[int, int] = (1, 0), + fuse_gelu: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, +) -> ArrayLike: + if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): + assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." + + if fp8_meta is not None: + return fp8_gemm(x, kernel, bias, fp8_meta, out_dtype, contracting_dims, fuse_gelu, + accumulate, use_split_accumulator) + else: + return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) From c9774d8c203d5b0f5769f47daf70e0c655d0d110 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 14 Nov 2024 17:59:20 +0000 Subject: [PATCH 02/39] fixed batching rules to accommodated batched RHS operand for GEMM Signed-off-by: Alp Dener --- .../common/util/pybind_helper.h | 138 ++++++++++-------- transformer_engine/jax/cpp_extensions/gemm.py | 133 ++++++----------- .../jax/csrc/extensions/pybind.cpp | 59 +------- 3 files changed, 123 insertions(+), 207 deletions(-) diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 432ac815ec..a36ff3f0f9 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -8,72 +8,88 @@ #define TRANSFORMER_ENGINE_COMMON_UTIL_PYBIND_HELPER_H_ #include +#include #include #include #include #include "cuda_runtime.h" -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ - py::call_guard()); +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType") \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type") \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type") \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Format") \ + .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ + .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); \ + pybind11::enum_(m, "NVTE_QKV_Layout") \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "NVTE_Activation_Type") \ + .value("GELU", NVTE_Activation_Type::GELU) \ + .value("GEGLU", NVTE_Activation_Type::GEGLU) \ + .value("SILU", NVTE_Activation_Type::SILU) \ + .value("SWIGLU", NVTE_Activation_Type::SWIGLU) \ + .value("RELU", NVTE_Activation_Type::RELU) \ + .value("REGLU", NVTE_Activation_Type::REGLU) \ + .value("QGELU", NVTE_Activation_Type::QGELU) \ + .value("QGEGLU", NVTE_Activation_Type::QGEGLU) \ + .value("SRELU", NVTE_Activation_Type::SRELU) \ + .value("SREGLU", NVTE_Activation_Type::SREGLU); \ + pybind11::enum_(m, "CommOverlapType") \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo") \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + pybind11::call_guard(), pybind11::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ + pybind11::call_guard()); #endif diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 677fabca59..ceafce46e1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,10 +24,10 @@ jax_dtype_is_fp8, get_padded_spec, is_ffi_enabled, + check_valid_batch_dims, ) from ..sharding import ( global_mesh_resource, - get_mesh_axis_size, lax_paral_op, all_reduce_max_along_all_axes_except_PP, ) @@ -83,9 +83,6 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 ), "Missing RHS operand scale inverse in FP8 GEMM." - # Disallow batching for RHS - assert rhs_aval.ndim == 2, "GEMM does not support batching the RHS operand." - # Validate operand layouts lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, @@ -97,12 +94,12 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 assert ( not (lhs_trans and rhs_trans) ), "GEMM does not support transposed LHS and transposed RHS at the same time." if is_fp8: - assert lhs_trans, "FP8 GEMM does not support transposed LHS." + assert not lhs_trans, "FP8 GEMM does not support transposed LHS." assert rhs_trans, "FP8 GEMM requires transposed RHS." # Validate output dtype @@ -124,11 +121,18 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av out_scale_updated_dtype = jnp.float32 # Infer output shape - rhs_outer_dim = 0 if rhs_trans else 1 lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 lhs_bdims = [dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + rhs_bdims = [dim for dim in range(rhs_aval.ndim) + if dim not in [rhs_outer_dim, rhs_inner_dim]] + rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) + assert ( + lhs_batch_size == rhs_batch_size + ), "LHS and RHS operands must have the same batched sizes." out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) # Validate bias/bias_grad shape against inferred output @@ -201,7 +205,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ (lhs_aval.ndim, rhs_aval.ndim) ) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 operand_output_aliases = { 4: 4, # bias <--> bias_grad @@ -248,12 +252,9 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - rhs_outer_dim = 0 if rhs_trans else 1 lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - lhs_bdims = [dim for dim in range(lhs_aval.ndim) - if dim not in [lhs_outer_dim, lhs_inner_dim]] - lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] - m = reduce(operator.mul, lhs_batch_shape, 1) * lhs_aval.shape[lhs_outer_dim] + rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] n = rhs_aval.shape[rhs_outer_dim] workspace_size = get_cublas_workspace_size_bytes() @@ -308,77 +309,32 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator): assert CollectiveGemmPrimitive.outer_primitive is not None + check_valid_batch_dims(batch_dims) + lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale = batched_args - assert rhs.ndim == 2, "TE/JAX GEMM custom op does not support batching RHS operands." - - # Get contracting and batch dimensions out - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs.ndim, rhs.ndim) - ) - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 - lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 - lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] - - # FP8 GEMM only supports lhs_trans = False and rhs_trans = True so we may need to - # reorder the axes here to match - if jax_dtype_is_fp8(lhs.dtype): - lhs = jnp.transpose(lhs, (*lhs_bdims, lhs_outer_dim, lhs_inner_dim)) - lhs_trans = False - rhs = jnp.transpose(rhs, (rhs_outer_dim, rhs_inner_dim)) - rhs_trans = True - contracting_dims = (1, 1) - - # Collapse all non-contracting dimensions - batch_shape = [lhs.shape[dim] for dim in lhs_bdims] - batch_size = reduce(operator.mul, batch_shape, 1) - lhs_outer_size = lhs.shape[lhs_outer_dim] - lhs_shape_2d = ( - (lhs.shape[lhs_inner_dim], batch_size * lhs_outer_size) - if lhs_trans - else (batch_size * lhs_outer_size, lhs.shape[lhs_inner_dim]) - ) - lhs = jnp.reshape(lhs, lhs_shape_2d) - if fuse_gelu: - gelu_input = jnp.reshape( - gelu_input, (batch_size * lhs_outer_size, rhs.shape[rhs_outer_dim]) - ) - - outputs = CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - lhs_scale_inv, - rhs, - rhs_scale_inv, - bias, - gelu_input, - out_amax, - out_scale, - out_dtype=out_dtype, - contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=grad, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - # Reshape output to recover original LHS batch shape - outputs[0] = jnp.reshape( - outputs[0], - (*batch_shape, lhs_outer_size, rhs.shape[rhs_outer_dim]) - ) - gelu_bdims = batch_dims[3] - if fuse_gelu: - outputs[3] = jnp.reshape(outputs[3], outputs[0].shape) - gelu_bdims = lhs_bdims + # FP8 GEMM only supports non-transposed LHS and transposed RHS + lhs, _, rhs, *_ = batched_args + lhs_trans = contracting_dims[0] != lhs.ndim - 1 + rhs_trans = contracting_dims[1] == rhs.ndim - 1 + lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs + rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs + contracting_dims = (1, 1) return ( - outputs, - (lhs_bdims, batch_dims[1], batch_dims[2], gelu_bdims, batch_dims[4]) + CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + batched_args[1], + rhs, + *batched_args[3:], + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ) + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) ) @staticmethod @@ -400,9 +356,9 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi + "not already partitioned correctly.") lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs.ndim - 1 lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 + rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -440,9 +396,9 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat ) lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == 1 + rhs_trans = rhs_inner_dim == rhs.ndim - 1 lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = 0 if rhs_trans else 1 + rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -558,7 +514,7 @@ def fp8_gemm_impl( gelu_input = jnp.zeros(0, dtype=bias.dtype) elif gelu_input is None: lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) @@ -599,7 +555,7 @@ def gemm_impl( dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 + rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: @@ -618,9 +574,6 @@ def gemm_impl( gelu_input is not None ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." elif gelu_input is None: - lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = 1 if contracting_dims[1] == 0 else 0 - out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 7b8ebdcdd2..ddf98d9d78 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include "common/util/pybind_helper.h" #include "extensions.h" namespace transformer_engine { @@ -107,6 +108,8 @@ pybind11::dict Registrations() { } PYBIND11_MODULE(transformer_engine_jax, m) { + NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("registrations", &Registrations); m.def("pack_common_descriptor", &PackCustomCallCommonDescriptor, pybind11::arg(), pybind11::arg(), pybind11::arg(), pybind11::arg("act_num") = 0); @@ -129,62 +132,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); - - pybind11::enum_(m, "DType", pybind11::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kInt64", DType::kInt64) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); - - pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) - .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) - .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); - - pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) - .value("GELU", NVTE_Activation_Type::GELU) - .value("GEGLU", NVTE_Activation_Type::GEGLU) - .value("SILU", NVTE_Activation_Type::SILU) - .value("SWIGLU", NVTE_Activation_Type::SWIGLU) - .value("RELU", NVTE_Activation_Type::RELU) - .value("REGLU", NVTE_Activation_Type::REGLU) - .value("QGELU", NVTE_Activation_Type::QGELU) - .value("QGEGLU", NVTE_Activation_Type::QGEGLU) - .value("SRELU", NVTE_Activation_Type::SRELU) - .value("SREGLU", NVTE_Activation_Type::SREGLU) - .export_values(); - - pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); } } // namespace jax From e523018a8f7e3de2e1e4ab2a989eb6e13ca4a9b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:14:24 +0000 Subject: [PATCH 03/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 15 +- transformer_engine/jax/cpp_extensions/gemm.py | 275 ++++++++++++------ .../jax/csrc/extensions/gemm.cpp | 16 +- transformer_engine/jax/flax/module.py | 12 +- transformer_engine/jax/gemm.py | 70 +++-- 5 files changed, 254 insertions(+), 134 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9bf3f9fa91..355f587265 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -425,19 +425,16 @@ def _generate_inputs(b, m, n, k, dtype): a = jax.random.normal(subkeys[0], (b, m, k), dtype) b = jax.random.normal(subkeys[1], (n, k), dtype) bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16 - bias = jax.random.normal(subkeys[2], (n, ), bias_dtype) + bias = jax.random.normal(subkeys[2], (n,), bias_dtype) return a, b, bias @staticmethod def _generate_fp8_inputs(b, m, n, k, fp8_dtype): a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16) - a_scale, b_scale = map( - lambda x: (jnp.max(jnp.abs(x)) / 127.).astype(jnp.float32), - [a, b] - ) + a_scale, b_scale = map(lambda x: (jnp.max(jnp.abs(x)) / 127.0).astype(jnp.float32), [a, b]) a_q, b_q = map( lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype), - [(a, a_scale), (b, b_scale)] + [(a, a_scale), (b, b_scale)], ) return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias @@ -447,7 +444,7 @@ def _generate_fp8_inputs(b, m, n, k, fp8_dtype): def test_gemm(self, b, m, n, k, use_bias, do_gelu): a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16) - primitive_out = gemm(a, b, bias=bias if use_bias else None, layout='NT', do_gelu=do_gelu) + primitive_out = gemm(a, b, bias=bias if use_bias else None, layout="NT", do_gelu=do_gelu) ref_out = jnp.dot(a, b) if use_bias: ref_out += bias @@ -460,9 +457,7 @@ def test_gemm(self, b, m, n, k, use_bias, do_gelu): @pytest.mark.parametrize("m,n,k", GEMM_CASES) @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE) def test_fp8_gemm(self, m, n, k, fp8_dtype): - a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs( - m, n, k, fp8_dtype - ) + a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(m, n, k, fp8_dtype) primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16) ref_out = jnp.dot(a, b) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ceafce46e1..2df05d6df4 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -58,9 +58,23 @@ class CollectiveGemmPrimitive(BasePrimitive): outer_primitive = None @staticmethod - def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_aval, - gelu_input_aval, out_amax_aval, out_scale_aval, out_dtype, contracting_dims, - fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator): + def abstract( + lhs_aval, + lhs_scale_inv_aval, + rhs_aval, + rhs_scale_inv_aval, + bias_aval, + gelu_input_aval, + out_amax_aval, + out_scale_aval, + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): """ cuBlasLt GEMM abstract """ @@ -87,7 +101,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim) + (lhs_aval.ndim, rhs_aval.ndim), ) assert ( lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] @@ -95,8 +109,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 - assert ( - not (lhs_trans and rhs_trans) + assert not ( + lhs_trans and rhs_trans ), "GEMM does not support transposed LHS and transposed RHS at the same time." if is_fp8: assert not lhs_trans, "FP8 GEMM does not support transposed LHS." @@ -104,8 +118,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av # Validate output dtype if jax_dtype_is_fp8(out_dtype): - assert ( - jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(rhs_dtype) + assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8( + rhs_dtype ), "FP8 GEMM output requires FP8 inputs." assert ( out_amax_aval.size == out_scale_aval.size == 1 @@ -122,13 +136,15 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av # Infer output shape lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - lhs_bdims = [dim for dim in range(lhs_aval.ndim) - if dim not in [lhs_outer_dim, lhs_inner_dim]] + lhs_bdims = [ + dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] + ] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 - rhs_bdims = [dim for dim in range(rhs_aval.ndim) - if dim not in [rhs_outer_dim, rhs_inner_dim]] + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) assert ( lhs_batch_size == rhs_batch_size @@ -139,9 +155,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype if fuse_bias: assert ( - bias_aval.size > 0 - and bias_aval.ndim == 1 - and bias_aval.shape[0] == out_shape[-1] + bias_aval.size > 0 and bias_aval.ndim == 1 and bias_aval.shape[0] == out_shape[-1] ), "Incorrect bias shape." bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) else: @@ -149,8 +163,8 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av # Validate GELU input/output if fuse_gelu: - assert ( - all([gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)]) + assert all( + [gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)] ), "Invalid GELU input shape." assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." else: @@ -158,14 +172,17 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av # Create abstract arrays for all outputs out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) - out_amax_updated_aval = out_amax_aval.update(shape=out_amax_aval.shape, - dtype=out_amax_updated_dtype) - out_scale_updated_aval = out_scale_aval.update(shape=out_scale_aval.shape, - dtype=out_scale_updated_dtype) + out_amax_updated_aval = out_amax_aval.update( + shape=out_amax_aval.shape, dtype=out_amax_updated_dtype + ) + out_scale_updated_aval = out_scale_aval.update( + shape=out_scale_aval.shape, dtype=out_scale_updated_dtype + ) pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_input_aval.shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - workspace_aval = jax.core.ShapedArray(shape=(get_cublas_workspace_size_bytes(), ), - dtype=jnp.uint8) + workspace_aval = jax.core.ShapedArray( + shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 + ) return ( out_aval, @@ -173,7 +190,7 @@ def abstract(lhs_aval, lhs_scale_inv_aval, rhs_aval, rhs_scale_inv_aval, bias_av out_scale_updated_aval, pre_gelu_out_aval, bias_grad_aval, - workspace_aval + workspace_aval, ) @staticmethod @@ -181,20 +198,31 @@ def outer_abstract(*args, **kwargs): """ cuBlasLt GEMM outer abstract """ - ( - out_aval, - out_amax_aval, - out_scale_aval, - pre_gelu_out_aval, - bias_grad_aval, - _ - ) = CollectiveGemmPrimitive.abstract(*args, **kwargs) + (out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval, _) = ( + CollectiveGemmPrimitive.abstract(*args, **kwargs) + ) return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval @staticmethod - def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, - *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, - use_split_accumulator): + def lowering( + ctx, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + *, + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): """ Fused attention fwd lowering rules """ @@ -202,7 +230,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim) + (lhs_aval.ndim, rhs_aval.ndim), ) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 @@ -232,7 +260,7 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ fuse_bias=fuse_bias, grad=grad, accumulate=accumulate, - use_split_accumulator=use_split_accumulator + use_split_accumulator=use_split_accumulator, ) else: operands = [ @@ -260,10 +288,22 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ workspace_size = get_cublas_workspace_size_bytes() operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype) bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) - opaque = tex.pack_gemm_descriptor(m, n, k, workspace_size, operand_dtype, - jax_dtype_to_te_dtype(out_dtype), bias_dtype, - lhs_trans, rhs_trans, fuse_gelu, fuse_bias, grad, - accumulate, use_split_accumulator) + opaque = tex.pack_gemm_descriptor( + m, + n, + k, + workspace_size, + operand_dtype, + jax_dtype_to_te_dtype(out_dtype), + bias_dtype, + lhs_trans, + rhs_trans, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ) return custom_caller( CollectiveGemmPrimitive.name, @@ -274,9 +314,23 @@ def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_ ) @staticmethod - def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale, - out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, - use_split_accumulator): + def impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): assert CollectiveGemmPrimitive.inner_primitive is not None ( @@ -306,13 +360,23 @@ def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @staticmethod - def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, - accumulate, use_split_accumulator): + def batcher( + batched_args, + batch_dims, + *, + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + ): assert CollectiveGemmPrimitive.outer_primitive is not None check_valid_batch_dims(batch_dims) lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - # FP8 GEMM only supports non-transposed LHS and transposed RHS + # FP8 GEMM only supports non-transposed LHS and transposed RHS lhs, _, rhs, *_ = batched_args lhs_trans = contracting_dims[0] != lhs.ndim - 1 rhs_trans = contracting_dims[1] == rhs.ndim - 1 @@ -320,27 +384,33 @@ def batcher(batched_args, batch_dims, *, out_dtype, contracting_dims, fuse_gelu, rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs contracting_dims = (1, 1) - return ( - CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - batched_args[1], - rhs, - *batched_args[3:], - out_dtype=out_dtype, - contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=grad, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) - ) + return CollectiveGemmPrimitive.outer_primitive.bind( + lhs, + batched_args[1], + rhs, + *batched_args[3:], + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + )(lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) @staticmethod - def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, - accumulate, use_split_accumulator, mesh, arg_infos, - result_infos): + def infer_sharding_from_operands( + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): del out_dtype, accumulate, use_split_accumulator, result_infos lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) @@ -348,12 +418,14 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, contracting_dims, - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: - warnings.warn("Forcing the inner dimension of LHS to match the sharding of inner " - + "dimension of RHS. This can trigger additional communication if LHS is " - + "not already partitioned correctly.") + warnings.warn( + "Forcing the inner dimension of LHS to match the sharding of inner " + + "dimension of RHS. This can trigger additional communication if LHS is " + + "not already partitioned correctly." + ) lhs_trans = lhs_inner_dim != lhs.ndim - 1 rhs_trans = rhs_inner_dim == rhs.ndim - 1 @@ -383,8 +455,18 @@ def infer_sharding_from_operands(out_dtype, contracting_dims, fuse_gelu, fuse_bi return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) @staticmethod - def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulate, - use_split_accumulator, mesh, arg_infos, result_infos): + def partition( + out_dtype, + contracting_dims, + fuse_gelu, + fuse_bias, + grad, + accumulate, + use_split_accumulator, + mesh, + arg_infos, + result_infos, + ): del result_infos lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) @@ -392,7 +474,7 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat lhs_inner_dim, rhs_inner_dim = map( lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, contracting_dims, - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) lhs_trans = lhs_inner_dim != lhs.ndim - 1 @@ -426,13 +508,27 @@ def partition(out_dtype, contracting_dims, fuse_gelu, fuse_bias, grad, accumulat gelu_spec = out_spec if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) - arg_shardings = (lhs_sharding, fp8_meta_sharding, rhs_sharding, fp8_meta_sharding, - bias_sharding, gelu_sharding, fp8_meta_sharding, fp8_meta_sharding) - out_shardings = (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, - bias_sharding) + arg_shardings = ( + lhs_sharding, + fp8_meta_sharding, + rhs_sharding, + fp8_meta_sharding, + bias_sharding, + gelu_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + ) + out_shardings = ( + out_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + gelu_sharding, + bias_sharding, + ) - def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, - out_scale): + def sharded_impl( + lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale + ): ( out, out_amax_updated, @@ -465,8 +561,7 @@ def sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_a # GEMM output needs to be all-reduced when the contracting dimension is sharded. # If the layer is sequence-parallel, we also need to scatter the output, which we # can combine into a reduce-scatter here. - out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, - mesh) + out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, mesh) if fuse_gelu: pre_gelu_out = lax_paral_op( pre_gelu_out, jax.lax.psum, global_mesh_resource().cp_resource, mesh @@ -485,10 +580,10 @@ def fp8_gemm_impl( lhs_scale_inv: ArrayLike, rhs: ArrayLike, rhs_scale_inv: ArrayLike, - bias: Optional[ArrayLike] = None, + bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, - out_amax: Optional[ArrayLike] = None, - out_scale: Optional[ArrayLike] = None, + out_amax: Optional[ArrayLike] = None, + out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, contracting_dims: Tuple[int, int] = (1, 1), fuse_gelu: bool = False, @@ -506,9 +601,7 @@ def fp8_gemm_impl( if not fuse_bias: bias = jnp.zeros(0, dtype=jnp.bfloat16) else: - assert ( - bias is not None - ), "Missing bias in forward GEMM when bias epilogue is enabled." + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=bias.dtype) @@ -542,8 +635,8 @@ def fp8_gemm_impl( def gemm_impl( lhs: ArrayLike, rhs: ArrayLike, - bias: Optional[ArrayLike] = None, - gelu_input: Optional[ArrayLike] = None, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, contracting_dims: Tuple[int, int] = (1, 0), fuse_gelu: bool = False, fuse_bias: bool = False, @@ -563,9 +656,7 @@ def gemm_impl( elif grad: bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) else: - assert ( - bias is not None - ), "Missing bias in forward GEMM when bias epilogue is enabled." + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=lhs.dtype) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f60ae510df..5dae9d6757 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -29,8 +29,8 @@ void GemmImpl(cudaStream_t stream, void *lhs, const std::vector &lhs_sha auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); void *bias_ptr = (fuse_bias) ? bias : nullptr; - std::vector bias_shape = (fuse_bias) ? std::vector{out_shape[1]} - : std::vector{0}; + std::vector bias_shape = + (fuse_bias) ? std::vector{out_shape[1]} : std::vector{0}; auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); void *pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; @@ -65,12 +65,9 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque auto *workspace = buffers[13]; // Operand aliasing - NVTE_CHECK(bias == bias_grad, - "bias not bound to bias_grad in TE/JAX GEMM"); - NVTE_CHECK(gelu_input == pre_gelu_out, - "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); - NVTE_CHECK(out_amax == out_amax_updated, - "out_amax not bound to out_amax_updated in TE/JAX GEMM"); + NVTE_CHECK(bias == bias_grad, "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(gelu_input == pre_gelu_out, "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_amax == out_amax_updated, "out_amax not bound to out_amax_updated in TE/JAX GEMM"); NVTE_CHECK(out_scale == out_scale_updated, "out_scale not bound to out_scale_updated in TE/JAX GEMM"); @@ -117,8 +114,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto workspace_size = workspace->dimensions().back(); // Operand aliasing - NVTE_CHECK(bias_ptr == bias_grad_ptr, - "bias not bound to bias_grad in TE/JAX GEMM"); + NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX GEMM"); NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr, "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 7312aa8295..abe23fdf8b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -362,8 +362,16 @@ def generate_a_set(target_postfix): grad_amax, grad_scale = generate_a_set(grad_name_post_fix) output_amax, output_scale = generate_a_set(output_name_post_fix) - return FP8MetaPackage(input_amax, input_scale, weight_amax, weight_scale, grad_amax, - grad_scale, output_amax, output_scale) + return FP8MetaPackage( + input_amax, + input_scale, + weight_amax, + weight_scale, + grad_amax, + grad_scale, + output_amax, + output_scale, + ) class DenseGeneral(TransformerEngineBase): diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index ccd109e095..79499725b7 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -21,7 +21,6 @@ ) - __all__ = [ "gemm", "fp8_gemm", @@ -52,8 +51,9 @@ def _gemm( accumulate: bool, use_split_accumulator: bool, ) -> ArrayLike: - out, _ = _gemm_fwd_rule(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, - use_split_accumulator) + out, _ = _gemm_fwd_rule( + x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator + ) return out @@ -76,7 +76,7 @@ def _gemm_fwd_rule( fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, accumulate=accumulate, - use_split_accumulator=use_split_accumulator + use_split_accumulator=use_split_accumulator, ) ctx = ( @@ -145,8 +145,18 @@ def fp8_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: - return _fp8_gemm(x, kernel, bias, fp8_meta.amax_list, fp8_meta.scale_list, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + return _fp8_gemm( + x, + kernel, + bias, + fp8_meta.amax_list, + fp8_meta.scale_list, + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ) @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) @@ -163,8 +173,18 @@ def _fp8_gemm( use_split_accumulator: bool, ) -> ArrayLike: """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" - out, _ = _fp8_gemm_fwd_rule(x, kernel, bias, amax_list, scale_list, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + out, _ = _fp8_gemm_fwd_rule( + x, + kernel, + bias, + amax_list, + scale_list, + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ) return out @@ -183,7 +203,8 @@ def _fp8_gemm_fwd_rule( fuse_bias = bias is not None maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( - *amax_list, *scale_list, + *amax_list, + *scale_list, ) amax_list = maybe_fm32_to_fp32(*amax_list) scale_list = maybe_fm32_to_fp32(*scale_list) @@ -272,7 +293,7 @@ def _fp8_gemm_fwd_rule( fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, accumulate=accumulate, - use_split_accumulator=use_split_accumulator + use_split_accumulator=use_split_accumulator, ) if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: updated_out_amax = None @@ -288,7 +309,7 @@ def _fp8_gemm_fwd_rule( updated_kernel_amax, pre_gelu_out if fuse_gelu else None, fuse_bias, - maybe_fp32_to_fm32 + maybe_fp32_to_fm32, ) return (out, updated_out_amax, updated_out_scale), ctx @@ -313,7 +334,7 @@ def _fp8_gemm_bwd_rule( updated_kernel_amax, pre_gelu_out, fuse_bias, - maybe_fp32_to_fm32 + maybe_fp32_to_fm32, ) = ctx fwd_dtype = FP8Helper.FWD_DTYPE @@ -347,8 +368,6 @@ def _fp8_gemm_bwd_rule( ) bgrad = None - - x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] wgrad, *_ = fp8_gemm_impl( casted_x, @@ -370,11 +389,11 @@ def _fp8_gemm_bwd_rule( bwd_dtype, static_axis_boundary=-1, transpose_axis_boundary=-1, - activation_type=("gelu", ), + activation_type=("gelu",), ) elif fuse_gelu: # No bias to fuse so we just do dGELU. - casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu", )) + casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) bgrad = None kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] @@ -414,12 +433,23 @@ def type_safe_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: - if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] - or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): + if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ + jnp.float8_e4m3fn, + jnp.float8_e5m2, + ]: assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." if fp8_meta is not None: - return fp8_gemm(x, kernel, bias, fp8_meta, out_dtype, contracting_dims, fuse_gelu, - accumulate, use_split_accumulator) + return fp8_gemm( + x, + kernel, + bias, + fp8_meta, + out_dtype, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + ) else: return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) From 2c3dbf1cf516d3dec5022b9b8304ee0d053170ba Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 15 Nov 2024 23:56:38 +0000 Subject: [PATCH 04/39] re-applied bug fixes to working older version, updated backward pass, passing test Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 93 +++---- transformer_engine/jax/gemm.py | 260 +++++++++--------- 2 files changed, 174 insertions(+), 179 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2df05d6df4..ee4c38d076 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1,7 +1,6 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""JAX/TE custom ops for cuBlasLt GEMM""" import warnings import operator from functools import reduce @@ -39,6 +38,10 @@ ] +def sanitize_dims(dim, ndims): + return (ndims + dim) if dim < 0 else dim + + def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if tex.get_device_compute_capability() >= 90: @@ -98,11 +101,8 @@ def abstract( ), "Missing RHS operand scale inverse in FP8 GEMM." # Validate operand layouts - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim), - ) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim)) assert ( lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." @@ -134,23 +134,31 @@ def abstract( out_amax_updated_dtype = jnp.float32 out_scale_updated_dtype = jnp.float32 - # Infer output shape + # Make sure leading dimensions of RHS is broadcast-compatible with LHS lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 + rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + lhs_bdims = [ dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] ] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) - rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 - rhs_bdims = [ - dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] - ] - rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) - assert ( - lhs_batch_size == rhs_batch_size - ), "LHS and RHS operands must have the same batched sizes." - out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + if rhs_aval.ndim > 2: + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] + rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + if rhs_batch_size > 1: + assert ( + lhs_batch_size == rhs_batch_size + ), ( + f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible " + + f"with the leading dimensions of LHS ({lhs_batch_shape=})." + ) + # Infer output shape + out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) # Validate bias/bias_grad shape against inferred output bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype if fuse_bias: @@ -227,11 +235,8 @@ def lowering( Fused attention fwd lowering rules """ lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim), - ) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, + (lhs_aval.ndim, rhs_aval.ndim)) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 @@ -376,19 +381,8 @@ def batcher( check_valid_batch_dims(batch_dims) lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - # FP8 GEMM only supports non-transposed LHS and transposed RHS - lhs, _, rhs, *_ = batched_args - lhs_trans = contracting_dims[0] != lhs.ndim - 1 - rhs_trans = contracting_dims[1] == rhs.ndim - 1 - lhs = jnp.matrix_transpose(lhs) if lhs_trans and jax_dtype_is_fp8(lhs.dtype) else lhs - rhs = jnp.matrix_transpose(rhs) if not rhs_trans and jax_dtype_is_fp8(rhs.dtype) else rhs - contracting_dims = (1, 1) - return CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - batched_args[1], - rhs, - *batched_args[3:], + *batched_args, out_dtype=out_dtype, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, @@ -415,11 +409,7 @@ def infer_sharding_from_operands( lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs.ndim, rhs.ndim), - ) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: warnings.warn( "Forcing the inner dimension of LHS to match the sharding of inner " @@ -471,11 +461,7 @@ def partition( lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) - lhs_inner_dim, rhs_inner_dim = map( - lambda inner_dim, ndims: (ndims - inner_dim) if inner_dim < 0 else inner_dim, - contracting_dims, - (lhs.ndim, rhs.ndim), - ) + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) lhs_trans = lhs_inner_dim != lhs.ndim - 1 rhs_trans = rhs_inner_dim == rhs.ndim - 1 @@ -578,14 +564,13 @@ def sharded_impl( def fp8_gemm_impl( lhs: ArrayLike, lhs_scale_inv: ArrayLike, - rhs: ArrayLike, + rhs_t: ArrayLike, rhs_scale_inv: ArrayLike, bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, out_amax: Optional[ArrayLike] = None, out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, - contracting_dims: Tuple[int, int] = (1, 1), fuse_gelu: bool = False, fuse_bias: bool = False, accumulate: bool = False, @@ -606,22 +591,20 @@ def fp8_gemm_impl( if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=bias.dtype) elif gelu_input is None: - lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 - out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + out_shape = (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2]) gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( - rhs, - rhs_scale_inv, lhs, lhs_scale_inv, + rhs_t, + rhs_scale_inv, bias, gelu_input, out_amax, out_scale, out_dtype=out_dtype, - contracting_dims=tuple(reversed(contracting_dims)), + contracting_dims=(-1, -1), fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, grad=False, @@ -645,10 +628,9 @@ def gemm_impl( use_split_accumulator: bool = False, ) -> Tuple[ArrayLike, ...]: """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" - dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) - - lhs_outer_dim = lhs.ndim - 1 if contracting_dims[0] == 1 else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if contracting_dims[1] == 0 else rhs.ndim - 1 + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim == lhs.ndim - 2 else lhs.ndim - 2 + rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: @@ -667,6 +649,7 @@ def gemm_impl( elif gelu_input is None: gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( lhs, dummy_fp8_meta, diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 79499725b7..e9e046d182 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -1,7 +1,8 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -from functools import partial +import operator +from functools import partial, reduce from typing import Optional, Tuple, Union import jax @@ -19,6 +20,7 @@ dbias_cast_transpose, dact_lu_dbias_cast_transpose, ) +from .cpp_extensions.gemm import sanitize_dims __all__ = [ @@ -98,27 +100,48 @@ def _gemm_bwd_rule( grad, ): x, kernel, pre_gelu_out, fuse_bias = ctx + x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - x_t_contracting = 0 if contracting_dims[0] == 1 else 1 - wgrad, dgelu, bgrad = gemm_impl( - x, + + kernel_t_contracting = ( + kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 + ) + # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) + dgrad, dgelu, _ = gemm_impl( grad, + kernel, gelu_input=pre_gelu_out, - contracting_dims=(x_t_contracting, 0), + contracting_dims=(-1, kernel_t_contracting), fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, + fuse_bias=False, grad=True, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) - kernel_t_contracting = 1 if contracting_dims[1] == 0 else 0 - dgrad, *_ = gemm_impl( - dgelu if fuse_gelu else grad, - kernel, + # Collapse batch x sequence dimensions for WGRAD + x_outer_dim = x.ndim - 2 if x_inner_dim == x.ndim - 1 else x.ndim - 1 + wgrad_rhs = dgelu if fuse_gelu else grad + if x.ndim > 2: + batch_size = reduce(operator.mul, x.shape[:-2], 1) + x = jax.lax.reshape( + jax.lax.transpose(x, (*list(range(x.ndim - 2)), x_outer_dim, x_inner_dim)), + (batch_size * x.shape[x_outer_dim], x.shape[x_inner_dim]), + ) + wgrad_rhs = jnp.reshape( + wgrad_rhs, shape=(batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1]) + ) + x_t_contracting = 0 + else: + x_t_contracting = x_outer_dim + + # WGRAD: ([B], M, K)^T x ([B], M, N) = ([B], K, N) + wgrad, _, bgrad = gemm_impl( + x, + wgrad_rhs, gelu_input=pre_gelu_out, - contracting_dims=(1, kernel_t_contracting), - fuse_gelu=fuse_gelu, + contracting_dims=(x_t_contracting, wgrad_rhs.ndim - 2), + fuse_gelu=False, fuse_bias=fuse_bias, grad=True, accumulate=accumulate, @@ -140,7 +163,6 @@ def fp8_gemm( fp8_meta: FP8MetaPackage, bias: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, - contracting_dims: Tuple[int, int] = (1, 1), fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, @@ -152,7 +174,6 @@ def fp8_gemm( fp8_meta.amax_list, fp8_meta.scale_list, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator, @@ -162,12 +183,11 @@ def fp8_gemm( @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) def _fp8_gemm( x: ArrayLike, - kernel: ArrayLike, + kernel_t: ArrayLike, bias: ArrayLike, amax_list: ArrayLike, scale_list: ArrayLike, out_dtype: jnp.dtype, - contracting_dims: Tuple[int, int], fuse_gelu: bool, accumulate: bool, use_split_accumulator: bool, @@ -175,12 +195,11 @@ def _fp8_gemm( """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" out, _ = _fp8_gemm_fwd_rule( x, - kernel, + kernel_t, bias, amax_list, scale_list, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator, @@ -190,12 +209,11 @@ def _fp8_gemm( def _fp8_gemm_fwd_rule( x: ArrayLike, - kernel: ArrayLike, + kernel_t: ArrayLike, bias: ArrayLike, amax_list: ArrayLike, scale_list: ArrayLike, out_dtype: jnp.dtype, - contracting_dims: Tuple[int, int], fuse_gelu: bool, accumulate: bool, use_split_accumulator: bool, @@ -221,54 +239,36 @@ def _fp8_gemm_fwd_rule( x_scale = scale_list[FP8MetaPackage.INPUT_IDX] x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] if x.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: - if contracting_dims[0] == 0: - _, casted_x, updated_x_amax = cast_transpose( - x, - x_amax, - x_scale, - x_scale_inv, - fwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - else: - casted_x, updated_x_amax = cast_fp8(x, x_amax, x_scale, x_scale_inv, fwd_dtype) + casted_x, casted_x_t, updated_x_amax = cast_transpose( + x, + x_amax, + x_scale, + x_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) else: - if contracting_dims[0] == 0: - casted_x_t = x - casted_x = casted_x_t.transpose() - else: - casted_x = x + casted_x = x + casted_x_t = jnp.matrix_transpose(x) updated_x_amax = x_amax kernel_amax = amax_list[FP8MetaPackage.WEIGHT_IDX][0:1] kernel_scale = scale_list[FP8MetaPackage.WEIGHT_IDX] kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - if kernel.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: - if contracting_dims[1] == 0: # need to transpose the kernel for FP8 GEMM - _, casted_kernel_t, updated_kernel_amax = cast_transpose( - kernel, - kernel_amax, - kernel_scale, - kernel_scale_inv, - fwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - else: - casted_kernel_t, updated_kernel_amax = cast_fp8( - kernel, - kernel_amax, - kernel_scale, - kernel_scale_inv, - fwd_dtype, - ) + if kernel_t.dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + casted_kernel_t, casted_kernel, updated_kernel_amax = cast_transpose( + kernel_t, + kernel_amax, + kernel_scale, + kernel_scale_inv, + fwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) else: - if contracting_dims[1] == 0: - casted_kernel = kernel - casted_kernel_t = casted_kernel.transpose() - else: - casted_kernel_t = kernel + casted_kernel = jnp.matrix_transpose(kernel_t) + casted_kernel_t = kernel_t updated_kernel_amax = kernel_amax out_amax = ( @@ -300,24 +300,24 @@ def _fp8_gemm_fwd_rule( updated_out_scale = None ctx = ( - casted_x, - casted_kernel_t, + casted_x_t, + casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax, updated_kernel_amax, + updated_out_amax, pre_gelu_out if fuse_gelu else None, fuse_bias, maybe_fp32_to_fm32, ) - return (out, updated_out_amax, updated_out_scale), ctx + return (out, updated_out_scale), ctx def _fp8_gemm_bwd_rule( out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator, @@ -325,83 +325,84 @@ def _fp8_gemm_bwd_rule( grad, ): ( - casted_x, - casted_kernel_t, + casted_x_t, + casted_kernel, amax_list, scale_list, scale_inv_list, updated_x_amax, updated_kernel_amax, + updated_out_amax, pre_gelu_out, fuse_bias, maybe_fp32_to_fm32, ) = ctx - fwd_dtype = FP8Helper.FWD_DTYPE bwd_dtype = FP8Helper.BWD_DTYPE grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] - if fuse_bias and not fuse_gelu: - # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. - _, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( - grad, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) + if fuse_gelu: + if fuse_bias: + # Fuse dbias into this dGELU. + casted_grad, casted_grad_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose( + grad, + pre_gelu_out, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + activation_type=("gelu",), + ) + else: + # No bias to fuse so we just do dGELU. + casted_grad, casted_grad_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) + bgrad = None else: - # If both bias and GELU is fused into the forward pass, we will fuse dbias later with - # dGELU. No need to do it here. - _, casted_grad_t, updated_grad_amax = cast_transpose( - grad, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - ) - bgrad = None + if fuse_bias: + # Since there is no GELU fusion, we need to fuse dbias into this cast_transpose. + casted_grad, casted_grad_t, bgrad, updated_grad_amax = dbias_cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + else: + # If both bias and GELU is fused into the forward pass, we will fuse dbias later with + # dGELU. No need to do it here. + casted_grad, casted_grad_t, updated_grad_amax = cast_transpose( + grad, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + ) + bgrad = None - x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] - wgrad, *_ = fp8_gemm_impl( - casted_x, - x_scale_inv, - casted_grad_t, + kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] + dgrad, *_ = fp8_gemm_impl( + casted_grad, grad_scale_inv, + casted_kernel, + kernel_scale_inv, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) - if fuse_gelu and fuse_bias: - # Fuse dbias into this dGELU. - casted_dgelu, casted_dgelu_t, bgrad, updated_dgelu_amax = dact_lu_dbias_cast_transpose( - grad, - pre_gelu_out, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - activation_type=("gelu",), - ) - elif fuse_gelu: - # No bias to fuse so we just do dGELU. - casted_dgelu, casted_dgelu_t, updated_dgelu_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) - bgrad = None - - kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - dgrad, *_ = gemm_impl( - casted_dgelu if fuse_gelu else grad, + x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] + wgrad, *_ = fp8_gemm_impl( + casted_x_t, + x_scale_inv, + casted_grad_t, grad_scale_inv, - casted_kernel_t, - kernel_scale_inv, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) @@ -412,6 +413,13 @@ def _fp8_gemm_bwd_rule( amax_list[FP8MetaPackage.WEIGHT_IDX] = ( amax_list[FP8MetaPackage.WEIGHT_IDX].at[0].set(updated_kernel_amax[0]) ) + amax_list[FP8MetaPackage.GRAD_IDX] = ( + amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) + ) + if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + amax_list[FP8MetaPackage.OUTPUT_IDX] = ( + amax_list[FP8MetaPackage.OUTPUT_IDX].at[0].set(updated_out_amax[0]) + ) amax_list = maybe_fp32_to_fm32(*amax_list) scale_list = maybe_fp32_to_fm32(*scale_list) @@ -433,20 +441,24 @@ def type_safe_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: - if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ - jnp.float8_e4m3fn, - jnp.float8_e5m2, - ]: + if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." if fp8_meta is not None: + x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) + assert ( + x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2 + ), ( + "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " + + "i.e. contracting_dims=(-1, -1)." + ) return fp8_gemm( x, kernel, bias, fp8_meta, out_dtype, - contracting_dims, fuse_gelu, accumulate, use_split_accumulator, From 448eaa99a3c3c93d8bcf2cb2d8ca6273f4f950d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Nov 2024 23:57:09 +0000 Subject: [PATCH 05/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 14 +++++++------- transformer_engine/jax/gemm.py | 11 +++++------ 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ee4c38d076..b935a5c2f7 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -101,8 +101,9 @@ def abstract( ), "Missing RHS operand scale inverse in FP8 GEMM." # Validate operand layouts - lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim)) + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) assert ( lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." @@ -150,9 +151,7 @@ def abstract( rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) if rhs_batch_size > 1: - assert ( - lhs_batch_size == rhs_batch_size - ), ( + assert lhs_batch_size == rhs_batch_size, ( f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible " + f"with the leading dimensions of LHS ({lhs_batch_shape=})." ) @@ -235,8 +234,9 @@ def lowering( Fused attention fwd lowering rules """ lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in - lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, - (lhs_aval.ndim, rhs_aval.ndim)) + lhs_inner_dim, rhs_inner_dim = map( + sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) + ) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index e9e046d182..3cab17b10b 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -102,7 +102,6 @@ def _gemm_bwd_rule( x, kernel, pre_gelu_out, fuse_bias = ctx x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - kernel_t_contracting = ( kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 ) @@ -441,15 +440,15 @@ def type_safe_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: - if (x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] - or kernel.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]): + if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ + jnp.float8_e4m3fn, + jnp.float8_e5m2, + ]: assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." if fp8_meta is not None: x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - assert ( - x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2 - ), ( + assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2, ( "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " + "i.e. contracting_dims=(-1, -1)." ) From cb6ae3cf7570285a13aae30b414a3a7ec19b4f6c Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 18 Nov 2024 22:31:35 +0000 Subject: [PATCH 06/39] batched operands for GEMM custom op seem to be working now Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 151 +++++++++++++----- transformer_engine/jax/gemm.py | 26 +-- 2 files changed, 119 insertions(+), 58 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b935a5c2f7..cf029d16db 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -136,8 +136,11 @@ def abstract( out_scale_updated_dtype = jnp.float32 # Make sure leading dimensions of RHS is broadcast-compatible with LHS - lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs_aval.ndim, rhs_aval.ndim) + ) lhs_bdims = [ dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] @@ -152,12 +155,17 @@ def abstract( rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) if rhs_batch_size > 1: assert lhs_batch_size == rhs_batch_size, ( - f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible " - + f"with the leading dimensions of LHS ({lhs_batch_shape=})." + f"Leading dimensins of RHS ({rhs_aval.shape=}) is not broadcast-compatible " + + f"with the leading dimensions of LHS ({lhs_aval.shape=})." ) - # Infer output shape + # Infer output shape: out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + if lhs_aval.ndim > 2 and rhs_aval.ndim > 2 and lhs_batch_size > 1: + # When both RHS and LHS are batched, the batch dimensions are collapsed into the + # contracting dimension. + out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + # Validate bias/bias_grad shape against inferred output bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype if fuse_bias: @@ -169,9 +177,16 @@ def abstract( assert bias_aval.size == 0, "Internal TE error." # Validate GELU input/output + gelu_shape = (0, ) if fuse_gelu: - assert all( - [gelu_input_aval.shape[i] == out_shape[i] for i in len(out_shape)] + gelu_shape = ( + (reduce(operator.mul, out_shape[:-1], 1), out_shape[-1]) + if len(out_shape) > 2 + else out_shape + ) + assert ( + gelu_input_aval.ndim == 2 + and all([gelu_input_aval.shape[i] == gelu_shape[i] for i in len(gelu_shape)]) ), "Invalid GELU input shape." assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." else: @@ -185,7 +200,7 @@ def abstract( out_scale_updated_aval = out_scale_aval.update( shape=out_scale_aval.shape, dtype=out_scale_updated_dtype ) - pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_input_aval.shape, dtype=bias_dtype) + pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) workspace_aval = jax.core.ShapedArray( shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 @@ -285,8 +300,11 @@ def lowering( ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - lhs_outer_dim = lhs_aval.ndim - 1 if lhs_trans else lhs_aval.ndim - 2 - rhs_outer_dim = rhs_aval.ndim - 2 if rhs_trans else rhs_aval.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs_aval.ndim, rhs_aval.ndim) + ) m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] n = rhs_aval.shape[rhs_outer_dim] @@ -338,6 +356,43 @@ def impl( ): assert CollectiveGemmPrimitive.inner_primitive is not None + lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) + lhs_trans = lhs_inner_dim != lhs.ndim - 1 + rhs_trans = rhs_inner_dim == rhs.ndim - 1 + + # Squeeze batch dimensions of size 1 without any modification. + squeeze_dims = [] + expand_out = False + if lhs.ndim > 2: + squeeze_dims = [dim for dim in range(lhs.ndim - 2) if lhs.shape[dim] == 1] + if len(squeeze_dims) > 0: + expand_out = True + lhs = jax.lax.squeeze(lhs, squeeze_dims) + contracting_dims = (lhs.ndim - 2 if lhs_trans else lhs.ndim - 1, + contracting_dims[1]) + if rhs.ndim > 2: + rhs_squeeze_dims = [dim for dim in range(rhs.ndim - 2) if rhs.shape[dim] == 1] + if len(squeeze_dims) > 0: + rhs = jax.lax.squeeze(rhs, rhs_squeeze_dims) + contracting_dims = (contracting_dims[0], + rhs.ndim - 1 if rhs_trans else rhs.ndim - 2) + + # Collapse batch dimensions that are larger thanm size 1. + # FWD: (B, M, K) x (K, N) = (B*M, K) x (K, N) = (B*M, N) + # DGRAD: (B, M, N) x (K, N)^T = (B*M, N) x (N, K) = (B*M, K) + # WGRAD: (B, M, K)^T x (B, M, N) = (K, B*M) x (B*M, N) = (K, N) + batch_shape = [lhs.shape[dim] for dim in range(lhs.ndim - 2)] + batch_size = reduce(operator.mul, batch_shape, 1) + reshape_output = not (lhs.ndim > 2 and rhs.ndim > 2) + if lhs.ndim > 2: + lhs_2d_shape = (batch_size * lhs.shape[-2], lhs.shape[-1]) + lhs = jax.lax.reshape(lhs, lhs_2d_shape) + contracting_dims = (0 if lhs_trans else 1, contracting_dims[1]) + if rhs.ndim > 2: + rhs_2d_shape = (reduce(operator.mul, rhs.shape[:-1], 1), rhs.shape[-1]) + rhs = jax.lax.reshape(rhs, rhs_2d_shape) + contracting_dims = (contracting_dims[0], 1 if rhs_trans else 0) + ( out, out_amax_updated, @@ -362,6 +417,15 @@ def impl( accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) + + # Recover batched dimensions in the output + if reshape_output: + out_batched_shape = (*batch_shape, int(out.shape[-2] / batch_size), out.shape[-1]) + out = jax.lax.reshape(out, out_batched_shape) + + if expand_out: + out = jax.lax.expand_dims(out, squeeze_dims) + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @staticmethod @@ -381,16 +445,19 @@ def batcher( check_valid_batch_dims(batch_dims) lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims - return CollectiveGemmPrimitive.outer_primitive.bind( - *batched_args, - out_dtype=out_dtype, - contracting_dims=contracting_dims, - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=grad, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - )(lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) + return ( + CollectiveGemmPrimitive.outer_primitive.bind( + *batched_args, + out_dtype=out_dtype, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=grad, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + ), + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) + ) @staticmethod def infer_sharding_from_operands( @@ -417,10 +484,12 @@ def infer_sharding_from_operands( + "not already partitioned correctly." ) - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == rhs.ndim - 1 - lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim) + ) + rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -430,18 +499,20 @@ def infer_sharding_from_operands( # Outer (sequence) dimension of the GEMM output is always unsharded out_spec = [*batch_specs, None, rhs_outer_spec] + batch_size = reduce(operator.mul, lhs.shape[:-2], 1) + if lhs.ndim > 2 and rhs.ndim > 2 and batch_size > 1: + out_spec = [None, rhs_outer_spec] out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # FP8 metas are always unsharded fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) - # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded - gelu_spec = out_spec if fuse_gelu else [None] + # Pre-GELU output matches output, if GELU fusion is turned on, otherwise unsharded + gelu_spec = [None, rhs_outer_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) # Bias gradient spec matches outer dimension of output if bias fusion is turned on bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) - return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) @staticmethod @@ -462,11 +533,11 @@ def partition( lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == rhs.ndim - 1 - lhs_outer_dim = lhs.ndim - 1 if lhs_trans else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if rhs_trans else rhs.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim) + ) lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] @@ -488,10 +559,13 @@ def partition( # Outer (sequence) dimension of the GEMM output is always unsharded out_spec = [*batch_specs, None, rhs_outer_spec] + batch_size = reduce(operator.mul, lhs.shape[:-2], 1) + if lhs.ndim > 2 and rhs.ndim > 2 and batch_size > 1: + out_spec = [None, rhs_outer_spec] out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded - gelu_spec = out_spec if fuse_gelu else [None] + gelu_spec = [None, rhs_outer_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) arg_shardings = ( @@ -547,10 +621,10 @@ def sharded_impl( # GEMM output needs to be all-reduced when the contracting dimension is sharded. # If the layer is sequence-parallel, we also need to scatter the output, which we # can combine into a reduce-scatter here. - out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().cp_resource, mesh) + out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().tp_resource, mesh) if fuse_gelu: pre_gelu_out = lax_paral_op( - pre_gelu_out, jax.lax.psum, global_mesh_resource().cp_resource, mesh + pre_gelu_out, jax.lax.psum, global_mesh_resource().tp_resource, mesh ) return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @@ -629,8 +703,11 @@ def gemm_impl( ) -> Tuple[ArrayLike, ...]: """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim == lhs.ndim - 2 else lhs.ndim - 2 - rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 + lhs_outer_dim, rhs_outer_dim = map( + lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim) + ) out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 3cab17b10b..01ee60f24b 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -101,16 +101,15 @@ def _gemm_bwd_rule( ): x, kernel, pre_gelu_out, fuse_bias = ctx x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) + x_outer_dim = x.ndim - 1 if x_inner_dim != x.ndim - 1 else x.ndim - 2 + kernel_outer_dim = kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 - kernel_t_contracting = ( - kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 - ) # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) dgrad, dgelu, _ = gemm_impl( grad, kernel, gelu_input=pre_gelu_out, - contracting_dims=(-1, kernel_t_contracting), + contracting_dims=(-1, kernel_outer_dim), fuse_gelu=fuse_gelu, fuse_bias=False, grad=True, @@ -118,28 +117,13 @@ def _gemm_bwd_rule( use_split_accumulator=use_split_accumulator, ) - # Collapse batch x sequence dimensions for WGRAD - x_outer_dim = x.ndim - 2 if x_inner_dim == x.ndim - 1 else x.ndim - 1 + # WGRAD: ([B], M, K)^T x ([B], M, N) = (K, N) wgrad_rhs = dgelu if fuse_gelu else grad - if x.ndim > 2: - batch_size = reduce(operator.mul, x.shape[:-2], 1) - x = jax.lax.reshape( - jax.lax.transpose(x, (*list(range(x.ndim - 2)), x_outer_dim, x_inner_dim)), - (batch_size * x.shape[x_outer_dim], x.shape[x_inner_dim]), - ) - wgrad_rhs = jnp.reshape( - wgrad_rhs, shape=(batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1]) - ) - x_t_contracting = 0 - else: - x_t_contracting = x_outer_dim - - # WGRAD: ([B], M, K)^T x ([B], M, N) = ([B], K, N) wgrad, _, bgrad = gemm_impl( x, wgrad_rhs, gelu_input=pre_gelu_out, - contracting_dims=(x_t_contracting, wgrad_rhs.ndim - 2), + contracting_dims=(x_outer_dim, wgrad_rhs.ndim - 2), fuse_gelu=False, fuse_bias=fuse_bias, grad=True, From 6f673559d250c9cf9c2713201da256b641cad279 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 22:32:02 +0000 Subject: [PATCH 07/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index cf029d16db..0948139dc9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -139,7 +139,7 @@ def abstract( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs_aval.ndim, rhs_aval.ndim) + (lhs_aval.ndim, rhs_aval.ndim), ) lhs_bdims = [ @@ -177,16 +177,15 @@ def abstract( assert bias_aval.size == 0, "Internal TE error." # Validate GELU input/output - gelu_shape = (0, ) + gelu_shape = (0,) if fuse_gelu: gelu_shape = ( (reduce(operator.mul, out_shape[:-1], 1), out_shape[-1]) if len(out_shape) > 2 else out_shape ) - assert ( - gelu_input_aval.ndim == 2 - and all([gelu_input_aval.shape[i] == gelu_shape[i] for i in len(gelu_shape)]) + assert gelu_input_aval.ndim == 2 and all( + [gelu_input_aval.shape[i] == gelu_shape[i] for i in len(gelu_shape)] ), "Invalid GELU input shape." assert gelu_input_aval.dtype == bias_dtype, "Invalid GELU dtype." else: @@ -303,7 +302,7 @@ def lowering( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs_aval.ndim, rhs_aval.ndim) + (lhs_aval.ndim, rhs_aval.ndim), ) m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] @@ -368,14 +367,18 @@ def impl( if len(squeeze_dims) > 0: expand_out = True lhs = jax.lax.squeeze(lhs, squeeze_dims) - contracting_dims = (lhs.ndim - 2 if lhs_trans else lhs.ndim - 1, - contracting_dims[1]) + contracting_dims = ( + lhs.ndim - 2 if lhs_trans else lhs.ndim - 1, + contracting_dims[1], + ) if rhs.ndim > 2: rhs_squeeze_dims = [dim for dim in range(rhs.ndim - 2) if rhs.shape[dim] == 1] if len(squeeze_dims) > 0: rhs = jax.lax.squeeze(rhs, rhs_squeeze_dims) - contracting_dims = (contracting_dims[0], - rhs.ndim - 1 if rhs_trans else rhs.ndim - 2) + contracting_dims = ( + contracting_dims[0], + rhs.ndim - 1 if rhs_trans else rhs.ndim - 2, + ) # Collapse batch dimensions that are larger thanm size 1. # FWD: (B, M, K) x (K, N) = (B*M, K) x (K, N) = (B*M, N) @@ -456,7 +459,7 @@ def batcher( accumulate=accumulate, use_split_accumulator=use_split_accumulator, ), - (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims) + (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims), ) @staticmethod @@ -487,7 +490,7 @@ def infer_sharding_from_operands( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] @@ -536,7 +539,7 @@ def partition( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] @@ -706,7 +709,7 @@ def gemm_impl( lhs_outer_dim, rhs_outer_dim = map( lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim) + (lhs.ndim, rhs.ndim), ) out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) From 4b2b2d44d735714ea9917fb00748c77e473fdafa Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 19 Nov 2024 17:57:33 +0000 Subject: [PATCH 08/39] fixed batch size 1 issue and enabled FSDP sharding for RHS operand Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 65 ++++++++----------- transformer_engine/jax/gemm.py | 18 +++-- 2 files changed, 39 insertions(+), 44 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0948139dc9..431dea6c1d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -161,7 +161,7 @@ def abstract( # Infer output shape: out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) - if lhs_aval.ndim > 2 and rhs_aval.ndim > 2 and lhs_batch_size > 1: + if lhs_aval.ndim > 2 and rhs_aval.ndim > 2: # When both RHS and LHS are batched, the batch dimensions are collapsed into the # contracting dimension. out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) @@ -359,27 +359,6 @@ def impl( lhs_trans = lhs_inner_dim != lhs.ndim - 1 rhs_trans = rhs_inner_dim == rhs.ndim - 1 - # Squeeze batch dimensions of size 1 without any modification. - squeeze_dims = [] - expand_out = False - if lhs.ndim > 2: - squeeze_dims = [dim for dim in range(lhs.ndim - 2) if lhs.shape[dim] == 1] - if len(squeeze_dims) > 0: - expand_out = True - lhs = jax.lax.squeeze(lhs, squeeze_dims) - contracting_dims = ( - lhs.ndim - 2 if lhs_trans else lhs.ndim - 1, - contracting_dims[1], - ) - if rhs.ndim > 2: - rhs_squeeze_dims = [dim for dim in range(rhs.ndim - 2) if rhs.shape[dim] == 1] - if len(squeeze_dims) > 0: - rhs = jax.lax.squeeze(rhs, rhs_squeeze_dims) - contracting_dims = ( - contracting_dims[0], - rhs.ndim - 1 if rhs_trans else rhs.ndim - 2, - ) - # Collapse batch dimensions that are larger thanm size 1. # FWD: (B, M, K) x (K, N) = (B*M, K) x (K, N) = (B*M, N) # DGRAD: (B, M, N) x (K, N)^T = (B*M, N) x (N, K) = (B*M, K) @@ -426,9 +405,6 @@ def impl( out_batched_shape = (*batch_shape, int(out.shape[-2] / batch_size), out.shape[-1]) out = jax.lax.reshape(out, out_batched_shape) - if expand_out: - out = jax.lax.expand_dims(out, squeeze_dims) - return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @staticmethod @@ -497,13 +473,9 @@ def infer_sharding_from_operands( batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] rhs_outer_spec = rhs_spec[rhs_outer_dim] - if rhs_spec[rhs_inner_dim] is not None and rhs_outer_spec is not None: - raise RuntimeError("Both inner and outer dimensions of RHS cannot be sharded.") - # Outer (sequence) dimension of the GEMM output is always unsharded out_spec = [*batch_specs, None, rhs_outer_spec] - batch_size = reduce(operator.mul, lhs.shape[:-2], 1) - if lhs.ndim > 2 and rhs.ndim > 2 and batch_size > 1: + if lhs.ndim > 2 and rhs.ndim > 2: out_spec = [None, rhs_outer_spec] out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) @@ -543,7 +515,6 @@ def partition( ) lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] - rhs_outer_spec = rhs_spec[rhs_outer_dim] # Force all-gather the outer (sequence) dimension of the LHS operand lhs_spec_new = [spec for spec in lhs_spec] @@ -551,8 +522,29 @@ def partition( lhs_spec_new[lhs_inner_dim] = rhs_spec[rhs_inner_dim] lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) + # If both dims of RHS is sharded (i.e. FSDP), determine if we do AG or AR based on LHS + # sharding. + rhs_spec_new = [spec for spec in rhs_spec] + if rhs_spec[rhs_inner_dim] is not None and rhs_spec[rhs_outer_dim] is not None: + if lhs_spec[lhs_inner_dim] is not None and lhs_spec[lhs_outer_dim] is not None: + # All dimensions of both LHS and RHS are sharded and the collective operation is + # ambiguous, we cannot infer sharding. + raise RuntimeError( + "Collective GEMM custom op cannot infer partitioning when both outer and " + + "contracting dimensions of both LHS and RHS operands are sharded." + ) + elif lhs_spec[lhs_inner_dim] is not None: + # All-reduce after GEMM, so unshard the outer dimension of RHS + rhs_spec_new[rhs_outer_dim] = None + else: + # We either do all-gather before GEMM, or LHS is already unsharded, so unshard + # the inner dimension of RHS to match + rhs_spec_new[rhs_inner_dim] = None + + rhs_outer_spec = rhs_spec_new[rhs_outer_dim] + # RHS operand is unchanged, we already enforce that only one dimension can be sharded - rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec)) + rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) # Bias is sharded to match outer dimension spec of the RHS operand (also the output) bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) @@ -562,8 +554,7 @@ def partition( # Outer (sequence) dimension of the GEMM output is always unsharded out_spec = [*batch_specs, None, rhs_outer_spec] - batch_size = reduce(operator.mul, lhs.shape[:-2], 1) - if lhs.ndim > 2 and rhs.ndim > 2 and batch_size > 1: + if lhs.ndim > 2 and rhs.ndim > 2: out_spec = [None, rhs_outer_spec] out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) @@ -620,10 +611,8 @@ def sharded_impl( if jax_dtype_is_fp8(lhs.dtype): out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) - if rhs_spec[rhs_inner_dim] is not None: - # GEMM output needs to be all-reduced when the contracting dimension is sharded. - # If the layer is sequence-parallel, we also need to scatter the output, which we - # can combine into a reduce-scatter here. + # GEMM output needs to be all-reduced when the contracting dimension is sharded. + if rhs_spec_new[rhs_inner_dim] is not None: out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().tp_resource, mesh) if fuse_gelu: pre_gelu_out = lax_paral_op( diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 01ee60f24b..3b562e4ffa 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -8,13 +8,11 @@ import jax import jax.numpy as jnp from jax.typing import ArrayLike -from jax.ad_checkpoint import checkpoint_name from .fp8 import FP8Helper, FP8MetaPackage from .cpp_extensions import ( gemm_impl, fp8_gemm_impl, - cast_fp8, cast_transpose, dact_lu, dbias_cast_transpose, @@ -68,6 +66,10 @@ def _gemm_fwd_rule( accumulate: bool, use_split_accumulator: bool, ) -> Tuple[ArrayLike, ...]: + assert kernel.ndim == 2, ( + "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + ) + fuse_bias = bias is not None out, pre_gelu_out = gemm_impl( @@ -142,7 +144,7 @@ def _gemm_bwd_rule( def fp8_gemm( x: ArrayLike, - kernel: ArrayLike, + kernel_t: ArrayLike, fp8_meta: FP8MetaPackage, bias: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, @@ -150,9 +152,10 @@ def fp8_gemm( accumulate: bool = False, use_split_accumulator: bool = False, ) -> ArrayLike: + """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" return _fp8_gemm( x, - kernel, + kernel_t, bias, fp8_meta.amax_list, fp8_meta.scale_list, @@ -175,7 +178,6 @@ def _fp8_gemm( accumulate: bool, use_split_accumulator: bool, ) -> ArrayLike: - """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" out, _ = _fp8_gemm_fwd_rule( x, kernel_t, @@ -201,6 +203,10 @@ def _fp8_gemm_fwd_rule( accumulate: bool, use_split_accumulator: bool, ) -> Tuple[ArrayLike, ...]: + assert kernel_t.ndim == 2, ( + "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." + ) + fuse_bias = bias is not None maybe_fm32_to_fp32, maybe_fp32_to_fm32 = FP8Helper.generate_fp8_meta_dtype_converter_pair( @@ -432,7 +438,7 @@ def type_safe_gemm( if fp8_meta is not None: x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 2, ( + assert x_inner_dim == x.ndim - 1 and kernel_inner_dim == kernel.ndim - 1, ( "FP8 GEMM requires non-transposed X (LHS) and transposed kernel (RHS), " + "i.e. contracting_dims=(-1, -1)." ) From 2b2753e2463ce788f5f7c582e898a304156b4f54 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:58:03 +0000 Subject: [PATCH 09/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/gemm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 3b562e4ffa..730d17846e 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -66,9 +66,9 @@ def _gemm_fwd_rule( accumulate: bool, use_split_accumulator: bool, ) -> Tuple[ArrayLike, ...]: - assert kernel.ndim == 2, ( - "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." - ) + assert ( + kernel.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." fuse_bias = bias is not None @@ -203,9 +203,9 @@ def _fp8_gemm_fwd_rule( accumulate: bool, use_split_accumulator: bool, ) -> Tuple[ArrayLike, ...]: - assert kernel_t.ndim == 2, ( - "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." - ) + assert ( + kernel_t.ndim == 2 + ), "TE/JAX Collective GEMM custom op does not support batched RHS operand in forward mode." fuse_bias = bias is not None From 969f597cb11fe9fd5b9780e57e818d402704fc0c Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 09:28:44 +0000 Subject: [PATCH 10/39] fixed FSDP+TP w/ DP=1 and TP+DP, but FSDP+TP w/ DP>1 still crashes Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 283 +++++++++++------- transformer_engine/jax/gemm.py | 29 +- 2 files changed, 205 insertions(+), 107 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 431dea6c1d..bf80941f85 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -4,7 +4,8 @@ import warnings import operator from functools import reduce -from typing import Optional, Union, Tuple +from typing import Optional, Tuple +from collections.abc import Iterable import jax import jax.numpy as jnp @@ -42,6 +43,34 @@ def sanitize_dims(dim, ndims): return (ndims + dim) if dim < 0 else dim +def mirror_dim(dim, ndims): + return ndims - 2 if dim == ndims - 1 else ndims - 1 + + +def remove_fsdp_specs(pspecs): + fsdp_resource = global_mesh_resource().fsdp_resource + new_pspecs = [] + for spec in pspecs: + if spec is None: + new_pspecs.append(None) + elif fsdp_resource not in spec: + new_pspecs.append(spec) + elif isinstance(spec, Iterable) and not isinstance(spec, str): + new_spec = [] + for s in spec: + if s != fsdp_resource: + new_spec.append(s) + if len(new_spec) > 1: + new_pspecs.append(new_spec) + elif len(new_spec) == 1: + new_pspecs.append(new_spec[0]) + else: + new_pspecs.append(None) + else: + new_pspecs.append(None) + return new_pspecs + + def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if tex.get_device_compute_capability() >= 90: @@ -55,7 +84,7 @@ class CollectiveGemmPrimitive(BasePrimitive): """ name = "te_gemm" - impl_static_args = (8, 9, 10, 11, 12, 13, 14) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15) multiple_results = True inner_primitive = None outer_primitive = None @@ -71,6 +100,7 @@ def abstract( out_amax_aval, out_scale_aval, out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -137,33 +167,40 @@ def abstract( # Make sure leading dimensions of RHS is broadcast-compatible with LHS lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs_aval.ndim, rhs_aval.ndim), ) - lhs_bdims = [ dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] ] lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) - if rhs_aval.ndim > 2: - rhs_bdims = [ - dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] - ] - rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] - rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) - if rhs_batch_size > 1: + + # Infer output shape + if batched_output: + assert lhs_aval.ndim > 2 and rhs_aval.ndim == 2, ( + "Batched output requires batched LHS and non-batched RHS operands." + ) + out_shape = ( + *lhs_batch_shape, + lhs_aval.shape[lhs_outer_dim], + rhs_aval.shape[rhs_outer_dim] + ) + else: + assert lhs_aval.ndim == rhs_aval.ndim, ( + "Non-batched output requires LHS and RHS operands with same number of dimensions." + ) + if lhs_aval.ndim > 2: + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] + rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) assert lhs_batch_size == rhs_batch_size, ( f"Leading dimensins of RHS ({rhs_aval.shape=}) is not broadcast-compatible " + f"with the leading dimensions of LHS ({lhs_aval.shape=})." ) - - # Infer output shape: - out_shape = (*lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) - if lhs_aval.ndim > 2 and rhs_aval.ndim > 2: - # When both RHS and LHS are batched, the batch dimensions are collapsed into the - # contracting dimension. out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) # Validate bias/bias_grad shape against inferred output @@ -237,6 +274,7 @@ def lowering( out_scale, *, out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -247,6 +285,7 @@ def lowering( """ Fused attention fwd lowering rules """ + del batched_output lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in lhs_inner_dim, rhs_inner_dim = map( sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) @@ -300,9 +339,9 @@ def lowering( args = CustomCallArgsWrapper(out_types, operands, operand_shapes) lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + mirror_dim, (lhs_inner_dim, rhs_inner_dim), - (lhs_aval.ndim, rhs_aval.ndim), + (lhs.ndim, rhs.ndim), ) m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] @@ -346,6 +385,7 @@ def impl( out_amax, out_scale, out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -356,25 +396,59 @@ def impl( assert CollectiveGemmPrimitive.inner_primitive is not None lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - lhs_trans = lhs_inner_dim != lhs.ndim - 1 - rhs_trans = rhs_inner_dim == rhs.ndim - 1 - - # Collapse batch dimensions that are larger thanm size 1. - # FWD: (B, M, K) x (K, N) = (B*M, K) x (K, N) = (B*M, N) - # DGRAD: (B, M, N) x (K, N)^T = (B*M, N) x (N, K) = (B*M, K) - # WGRAD: (B, M, K)^T x (B, M, N) = (K, B*M) x (B*M, N) = (K, N) - batch_shape = [lhs.shape[dim] for dim in range(lhs.ndim - 2)] - batch_size = reduce(operator.mul, batch_shape, 1) - reshape_output = not (lhs.ndim > 2 and rhs.ndim > 2) - if lhs.ndim > 2: - lhs_2d_shape = (batch_size * lhs.shape[-2], lhs.shape[-1]) - lhs = jax.lax.reshape(lhs, lhs_2d_shape) - contracting_dims = (0 if lhs_trans else 1, contracting_dims[1]) - if rhs.ndim > 2: - rhs_2d_shape = (reduce(operator.mul, rhs.shape[:-1], 1), rhs.shape[-1]) - rhs = jax.lax.reshape(rhs, rhs_2d_shape) - contracting_dims = (contracting_dims[0], 1 if rhs_trans else 0) + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim) + ) + + # Infer output shape and collapse batch dimensions + lhs_2d_shape = rhs_2d_shape = None + lhs_layout = rhs_layout = None + lhs_batch_dims = [ + dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim] + ] + lhs_batch_shape = [lhs.shape[dim] for dim in lhs_batch_dims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + contracting_dims_2d = list(contracting_dims).copy() + if batched_output: + # If output is batched, the LSH batch dimension collapses into the outer dimension + # and RHS cannot be batched + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) + lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) + contracting_dims_2d[0] = 1 + else: + # If the output is not batched, both LHS and RHS batch dimensions collapse into the + # contracting dimensions + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_inner_dim], lhs.shape[lhs_outer_dim]) + lhs_layout = (*lhs_batch_dims, lhs_inner_dim, lhs_outer_dim) + contracting_dims_2d[0] = 0 + + rhs_batch_dims = [ + dim for dim in range(rhs.ndim) if dim not in [rhs_inner_dim, rhs_outer_dim] + ] + rhs_batch_shape = [rhs.shape[dim] for dim in rhs_batch_dims] + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + rhs_2d_shape = (rhs_batch_size * rhs.shape[rhs_inner_dim], rhs.shape[rhs_outer_dim]) + rhs_layout = (*rhs_batch_dims, rhs_inner_dim, rhs_outer_dim) + contracting_dims_2d[1] = 0 + + # Reshape LHS and RHS into 2D and fix layouts for FP8 GEMM + if lhs_2d_shape is not None and lhs.ndim > 2: + lhs = jax.lax.reshape(lhs, lhs_2d_shape, dimensions=lhs_layout) + if jax_dtype_is_fp8(lhs.dtype): + lhs = jax.lax.transpose(lhs, (1, 0)) + contracting_dims_2d[0] = 1 + else: + contracting_dims_2d[0] = contracting_dims[0] + + if rhs_2d_shape is not None and rhs.ndim > 2: + rhs = jax.lax.reshape(rhs, rhs_2d_shape, dimensions=rhs_layout) + if jax_dtype_is_fp8(rhs.dtype): + rhs = jax.lax.transpose(rhs, (1, 0)) + contracting_dims_2d[1] = 1 + else: + contracting_dims_2d[1] = contracting_dims[1] + # Invoke GEMM with guaranteed 2D inputs, so batched_output=False ( out, out_amax_updated, @@ -392,7 +466,8 @@ def impl( out_amax, out_scale, out_dtype=out_dtype, - contracting_dims=contracting_dims, + batched_output=False, + contracting_dims=contracting_dims_2d, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, grad=grad, @@ -401,9 +476,9 @@ def impl( ) # Recover batched dimensions in the output - if reshape_output: - out_batched_shape = (*batch_shape, int(out.shape[-2] / batch_size), out.shape[-1]) - out = jax.lax.reshape(out, out_batched_shape) + if batched_output: + out_shape = (*lhs_batch_shape, out.shape[-2] // lhs_batch_size, out.shape[-1]) + out = jax.lax.reshape(out, out_shape) return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad @@ -413,6 +488,7 @@ def batcher( batch_dims, *, out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -428,6 +504,7 @@ def batcher( CollectiveGemmPrimitive.outer_primitive.bind( *batched_args, out_dtype=out_dtype, + batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, @@ -441,6 +518,7 @@ def batcher( @staticmethod def infer_sharding_from_operands( out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -456,34 +534,43 @@ def infer_sharding_from_operands( lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - if lhs_spec[lhs_inner_dim] != rhs_spec[rhs_inner_dim] and not grad: + lhs_outer_dim, rhs_outer_dim = map( + mirror_dim, + (lhs_inner_dim, rhs_inner_dim), + (lhs.ndim, rhs.ndim), + ) + + # Modify operand specs: + # - FSDP axes are all-gathered + # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded + # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension + lhs_spec_new = remove_fsdp_specs(lhs_spec) + rhs_spec_new = remove_fsdp_specs(rhs_spec) + if lhs_spec_new[lhs_inner_dim] != rhs_spec_new[rhs_inner_dim] and not grad: warnings.warn( "Forcing the inner dimension of LHS to match the sharding of inner " + "dimension of RHS. This can trigger additional communication if LHS is " + "not already partitioned correctly." ) - - lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, - (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim), - ) - rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 - lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] - batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] - rhs_outer_spec = rhs_spec[rhs_outer_dim] - - # Outer (sequence) dimension of the GEMM output is always unsharded - out_spec = [*batch_specs, None, rhs_outer_spec] - if lhs.ndim > 2 and rhs.ndim > 2: - out_spec = [None, rhs_outer_spec] + rhs_outer_spec = rhs_spec_new[rhs_outer_dim] + if rhs_outer_spec is not None: + lhs_spec_new[lhs_outer_dim] = None + lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] + + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + lhs_outer_spec = lhs_spec_new[lhs_outer_dim] + out_spec = [lhs_outer_spec, rhs_outer_spec] + if batched_output: + out_spec = batch_spec + out_spec out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # FP8 metas are always unsharded fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) - # Pre-GELU output matches output, if GELU fusion is turned on, otherwise unsharded - gelu_spec = [None, rhs_outer_spec] if fuse_gelu else [None] + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [lhs_outer_spec, rhs_outer_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) # Bias gradient spec matches outer dimension of output if bias fusion is turned on @@ -493,6 +580,7 @@ def infer_sharding_from_operands( @staticmethod def partition( out_dtype, + batched_output, contracting_dims, fuse_gelu, fuse_bias, @@ -509,41 +597,22 @@ def partition( lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim), ) - lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim]] - batch_specs = [lhs_spec[bdim] for bdim in lhs_bdims] - - # Force all-gather the outer (sequence) dimension of the LHS operand - lhs_spec_new = [spec for spec in lhs_spec] - lhs_spec_new[lhs_outer_dim] = None - lhs_spec_new[lhs_inner_dim] = rhs_spec[rhs_inner_dim] - lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) - - # If both dims of RHS is sharded (i.e. FSDP), determine if we do AG or AR based on LHS - # sharding. - rhs_spec_new = [spec for spec in rhs_spec] - if rhs_spec[rhs_inner_dim] is not None and rhs_spec[rhs_outer_dim] is not None: - if lhs_spec[lhs_inner_dim] is not None and lhs_spec[lhs_outer_dim] is not None: - # All dimensions of both LHS and RHS are sharded and the collective operation is - # ambiguous, we cannot infer sharding. - raise RuntimeError( - "Collective GEMM custom op cannot infer partitioning when both outer and " - + "contracting dimensions of both LHS and RHS operands are sharded." - ) - elif lhs_spec[lhs_inner_dim] is not None: - # All-reduce after GEMM, so unshard the outer dimension of RHS - rhs_spec_new[rhs_outer_dim] = None - else: - # We either do all-gather before GEMM, or LHS is already unsharded, so unshard - # the inner dimension of RHS to match - rhs_spec_new[rhs_inner_dim] = None + # Modify operand specs: + # - FSDP axes are all-gathered + # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded + # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension + lhs_spec_new = remove_fsdp_specs(lhs_spec) + rhs_spec_new = remove_fsdp_specs(rhs_spec) rhs_outer_spec = rhs_spec_new[rhs_outer_dim] - - # RHS operand is unchanged, we already enforce that only one dimension can be sharded + if rhs_outer_spec is not None: + lhs_spec_new[lhs_outer_dim] = None + lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) # Bias is sharded to match outer dimension spec of the RHS operand (also the output) @@ -552,14 +621,17 @@ def partition( # FP8 metas are always unsharded fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) - # Outer (sequence) dimension of the GEMM output is always unsharded - out_spec = [*batch_specs, None, rhs_outer_spec] - if lhs.ndim > 2 and rhs.ndim > 2: - out_spec = [None, rhs_outer_spec] + # Output sharding is conditional on output shape + lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] + lhs_outer_spec = lhs_spec_new[lhs_outer_dim] + out_spec = [lhs_outer_spec, rhs_outer_spec] + if batched_output: + out_spec = batch_spec + out_spec out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) - # Pre-GELU output matches output spec if GELU fusion is turned on, otherwise unsharded - gelu_spec = [None, rhs_outer_spec] if fuse_gelu else [None] + # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded + gelu_spec = [lhs_outer_spec, rhs_outer_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) arg_shardings = ( @@ -599,6 +671,7 @@ def sharded_impl( out_amax, out_scale, out_dtype=out_dtype, + batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, @@ -637,6 +710,7 @@ def fp8_gemm_impl( out_amax: Optional[ArrayLike] = None, out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, + batched_output: bool = False, fuse_gelu: bool = False, fuse_bias: bool = False, accumulate: bool = False, @@ -657,8 +731,8 @@ def fp8_gemm_impl( if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=bias.dtype) elif gelu_input is None: - out_shape = (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2]) - gelu_input = jnp.zeros(out_shape, dtype=bias.dtype) + gelu_shape = (reduce(operator.mul, lhs.shape[:-1]), rhs_t.shape[-1]) + gelu_input = jnp.zeros(gelu_shape, dtype=bias.dtype) out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( lhs, @@ -670,6 +744,7 @@ def fp8_gemm_impl( out_amax, out_scale, out_dtype=out_dtype, + batched_output=batched_output, contracting_dims=(-1, -1), fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, @@ -686,7 +761,8 @@ def gemm_impl( rhs: ArrayLike, bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, - contracting_dims: Tuple[int, int] = (1, 0), + batched_output: bool = False, + contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, fuse_bias: bool = False, grad: bool = False, @@ -696,16 +772,15 @@ def gemm_impl( """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) lhs_outer_dim, rhs_outer_dim = map( - lambda inner_dim, ndim: ndim - 2 if inner_dim == ndim - 1 else ndim - 1, + mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs.ndim, rhs.ndim), ) - out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: bias = jnp.zeros(0, dtype=lhs.dtype) elif grad: - bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) + bias = jnp.zeros(rhs.shape[rhs_outer_dim], dtype=lhs.dtype) else: assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." @@ -716,7 +791,10 @@ def gemm_impl( gelu_input is not None ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." elif gelu_input is None: - gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] + batch_size = reduce(operator.mul, [lhs.shape[dim] for dim in bdims], 1) + gelu_shape = (batch_size * lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + gelu_input = jnp.zeros(gelu_shape, dtype=lhs.dtypes) dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( @@ -729,6 +807,7 @@ def gemm_impl( dummy_fp8_meta, dummy_fp8_meta, out_dtype=lhs.dtype, + batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 730d17846e..18d1f76da7 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -18,7 +18,7 @@ dbias_cast_transpose, dact_lu_dbias_cast_transpose, ) -from .cpp_extensions.gemm import sanitize_dims +from .cpp_extensions.gemm import sanitize_dims, mirror_dim __all__ = [ @@ -72,10 +72,13 @@ def _gemm_fwd_rule( fuse_bias = bias is not None + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) out, pre_gelu_out = gemm_impl( x, kernel, bias=bias, + batched_output=(x.ndim > 2), contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, @@ -103,14 +106,22 @@ def _gemm_bwd_rule( ): x, kernel, pre_gelu_out, fuse_bias = ctx x_inner_dim, kernel_inner_dim = map(sanitize_dims, contracting_dims, (x.ndim, kernel.ndim)) - x_outer_dim = x.ndim - 1 if x_inner_dim != x.ndim - 1 else x.ndim - 2 - kernel_outer_dim = kernel.ndim - 2 if kernel_inner_dim == kernel.ndim - 1 else kernel.ndim - 1 + x_outer_dim, kernel_outer_dim = map( + mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) + ) + + # FWD MODE: + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) - # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) + # DGRAD: + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T --(AR)--> ([B], M, K) + # GEMM+AR: ([B], M, N) x (K/P, N)^T --------> ([B], M, K/P) dgrad, dgelu, _ = gemm_impl( grad, kernel, gelu_input=pre_gelu_out, + batched_output=(x.ndim > 2), contracting_dims=(-1, kernel_outer_dim), fuse_gelu=fuse_gelu, fuse_bias=False, @@ -119,12 +130,15 @@ def _gemm_bwd_rule( use_split_accumulator=use_split_accumulator, ) - # WGRAD: ([B], M, K)^T x ([B], M, N) = (K, N) + # WGRAD: + # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # GEMM+AR: ([B], M, K/P)^T x ([B], M, N) ----> (K/P, N) wgrad_rhs = dgelu if fuse_gelu else grad wgrad, _, bgrad = gemm_impl( x, wgrad_rhs, gelu_input=pre_gelu_out, + batched_output=False, contracting_dims=(x_outer_dim, wgrad_rhs.ndim - 2), fuse_gelu=False, fuse_bias=fuse_bias, @@ -279,6 +293,7 @@ def _fp8_gemm_fwd_rule( out_amax=out_amax, out_scale=out_scale, out_dtype=out_dtype, + batched_output=(x.ndim > 2), fuse_gelu=fuse_gelu, fuse_bias=fuse_bias, accumulate=accumulate, @@ -300,6 +315,7 @@ def _fp8_gemm_fwd_rule( pre_gelu_out if fuse_gelu else None, fuse_bias, maybe_fp32_to_fm32, + (x.ndim > 2), ) return (out, updated_out_scale), ctx @@ -325,6 +341,7 @@ def _fp8_gemm_bwd_rule( pre_gelu_out, fuse_bias, maybe_fp32_to_fm32, + batched_input, ) = ctx bwd_dtype = FP8Helper.BWD_DTYPE @@ -382,6 +399,7 @@ def _fp8_gemm_bwd_rule( grad_scale_inv, casted_kernel, kernel_scale_inv, + batched_output=batched_input, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) @@ -392,6 +410,7 @@ def _fp8_gemm_bwd_rule( x_scale_inv, casted_grad_t, grad_scale_inv, + out_shape=False, accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) From ce86dcb9c5d55c409ac92f9d8bafb0b7f01bc042 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 11:38:22 +0000 Subject: [PATCH 11/39] fixed logic to remove FSDP sharding Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index bf80941f85..d54009e60b 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -49,25 +49,44 @@ def mirror_dim(dim, ndims): def remove_fsdp_specs(pspecs): fsdp_resource = global_mesh_resource().fsdp_resource + if fsdp_resource is None: + return list(pspecs).copy() + new_pspecs = [] for spec in pspecs: if spec is None: new_pspecs.append(None) - elif fsdp_resource not in spec: - new_pspecs.append(spec) + elif isinstance(spec, Iterable) and not isinstance(spec, str): new_spec = [] for s in spec: - if s != fsdp_resource: + if s == fsdp_resource: + new_spec.append(None) + else: new_spec.append(s) + if len(new_spec) > 1: new_pspecs.append(new_spec) elif len(new_spec) == 1: new_pspecs.append(new_spec[0]) else: new_pspecs.append(None) + + elif isinstance(spec, str): + if spec == fsdp_resource: + new_pspecs.append(None) + else: + new_pspecs.append(spec) + else: - new_pspecs.append(None) + new_pspecs.append(spec) + + assert len(new_pspecs) == len(pspecs), ( + "Length of partition specs changed when removing FSDP sharding!\n" + + f"Original: {pspecs}\n" + + f"Filtered: {new_pspecs}\n" + ) + return new_pspecs From b215f207bd78acfd672264f7e52880a0a8137598 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 11:38:49 +0000 Subject: [PATCH 12/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d54009e60b..3c4bf15d00 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -198,18 +198,18 @@ def abstract( # Infer output shape if batched_output: - assert lhs_aval.ndim > 2 and rhs_aval.ndim == 2, ( - "Batched output requires batched LHS and non-batched RHS operands." - ) + assert ( + lhs_aval.ndim > 2 and rhs_aval.ndim == 2 + ), "Batched output requires batched LHS and non-batched RHS operands." out_shape = ( *lhs_batch_shape, lhs_aval.shape[lhs_outer_dim], - rhs_aval.shape[rhs_outer_dim] + rhs_aval.shape[rhs_outer_dim], ) else: - assert lhs_aval.ndim == rhs_aval.ndim, ( - "Non-batched output requires LHS and RHS operands with same number of dimensions." - ) + assert ( + lhs_aval.ndim == rhs_aval.ndim + ), "Non-batched output requires LHS and RHS operands with same number of dimensions." if lhs_aval.ndim > 2: rhs_bdims = [ dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] From cbab16c03109cf5b802b93adc03841828df332dd Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 19:03:52 +0000 Subject: [PATCH 13/39] retained FSDP dims and pushed FSDP all-gather of weight array to outside the custom op Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 52 ++----------------- transformer_engine/jax/gemm.py | 1 + 2 files changed, 6 insertions(+), 47 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 3c4bf15d00..353f2d2509 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -47,49 +47,6 @@ def mirror_dim(dim, ndims): return ndims - 2 if dim == ndims - 1 else ndims - 1 -def remove_fsdp_specs(pspecs): - fsdp_resource = global_mesh_resource().fsdp_resource - if fsdp_resource is None: - return list(pspecs).copy() - - new_pspecs = [] - for spec in pspecs: - if spec is None: - new_pspecs.append(None) - - elif isinstance(spec, Iterable) and not isinstance(spec, str): - new_spec = [] - for s in spec: - if s == fsdp_resource: - new_spec.append(None) - else: - new_spec.append(s) - - if len(new_spec) > 1: - new_pspecs.append(new_spec) - elif len(new_spec) == 1: - new_pspecs.append(new_spec[0]) - else: - new_pspecs.append(None) - - elif isinstance(spec, str): - if spec == fsdp_resource: - new_pspecs.append(None) - else: - new_pspecs.append(spec) - - else: - new_pspecs.append(spec) - - assert len(new_pspecs) == len(pspecs), ( - "Length of partition specs changed when removing FSDP sharding!\n" - + f"Original: {pspecs}\n" - + f"Filtered: {new_pspecs}\n" - ) - - return new_pspecs - - def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if tex.get_device_compute_capability() >= 90: @@ -563,8 +520,8 @@ def infer_sharding_from_operands( # - FSDP axes are all-gathered # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension - lhs_spec_new = remove_fsdp_specs(lhs_spec) - rhs_spec_new = remove_fsdp_specs(rhs_spec) + lhs_spec_new = [spec for spec in lhs_spec] + rhs_spec_new = [spec for spec in rhs_spec] if lhs_spec_new[lhs_inner_dim] != rhs_spec_new[rhs_inner_dim] and not grad: warnings.warn( "Forcing the inner dimension of LHS to match the sharding of inner " @@ -594,6 +551,7 @@ def infer_sharding_from_operands( # Bias gradient spec matches outer dimension of output if bias fusion is turned on bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) @staticmethod @@ -625,8 +583,8 @@ def partition( # - FSDP axes are all-gathered # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension - lhs_spec_new = remove_fsdp_specs(lhs_spec) - rhs_spec_new = remove_fsdp_specs(rhs_spec) + lhs_spec_new = [spec for spec in lhs_spec] + rhs_spec_new = [spec for spec in rhs_spec] rhs_outer_spec = rhs_spec_new[rhs_outer_dim] if rhs_outer_spec is not None: lhs_spec_new[lhs_outer_dim] = None diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 18d1f76da7..464ccb12f9 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -8,6 +8,7 @@ import jax import jax.numpy as jnp from jax.typing import ArrayLike +from jax.sharding import NamedSharding, PartitionSpec from .fp8 import FP8Helper, FP8MetaPackage from .cpp_extensions import ( From 0ea55c0eed1c5551a8b8872ff095d70d9e5d1625 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 19:46:02 +0000 Subject: [PATCH 14/39] Added useful warning about DGRAD sharding not matching sequence/context-parallel LHS operands Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 353f2d2509..823e9f7ea1 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -53,7 +53,6 @@ def get_cublas_workspace_size_bytes() -> None: return 33_554_432 return 4_194_304 - class CollectiveGemmPrimitive(BasePrimitive): """ cuBlasLt GEMM Primitive w/ support for distributed inputs @@ -385,15 +384,9 @@ def impl( lhs_batch_shape = [lhs.shape[dim] for dim in lhs_batch_dims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) contracting_dims_2d = list(contracting_dims).copy() - if batched_output: - # If output is batched, the LSH batch dimension collapses into the outer dimension - # and RHS cannot be batched - lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) - lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) - contracting_dims_2d[0] = 1 - else: - # If the output is not batched, both LHS and RHS batch dimensions collapse into the - # contracting dimensions + if lhs.ndim > 2 and rhs.ndim > 2: + # If both LHS and RHS are batched, the batch dimensions collapse into the + # contracting dimensions for both operands lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_inner_dim], lhs.shape[lhs_outer_dim]) lhs_layout = (*lhs_batch_dims, lhs_inner_dim, lhs_outer_dim) contracting_dims_2d[0] = 0 @@ -406,6 +399,11 @@ def impl( rhs_2d_shape = (rhs_batch_size * rhs.shape[rhs_inner_dim], rhs.shape[rhs_outer_dim]) rhs_layout = (*rhs_batch_dims, rhs_inner_dim, rhs_outer_dim) contracting_dims_2d[1] = 0 + elif lhs.ndim > 2: + # If only the LHS is batched,the batch dimension collapses into the outer dimension + lhs_2d_shape = (lhs_batch_size * lhs.shape[lhs_outer_dim], lhs.shape[lhs_inner_dim]) + lhs_layout = (*lhs_batch_dims, lhs_outer_dim, lhs_inner_dim) + contracting_dims_2d[0] = 1 # Reshape LHS and RHS into 2D and fix layouts for FP8 GEMM if lhs_2d_shape is not None and lhs.ndim > 2: @@ -524,12 +522,17 @@ def infer_sharding_from_operands( rhs_spec_new = [spec for spec in rhs_spec] if lhs_spec_new[lhs_inner_dim] != rhs_spec_new[rhs_inner_dim] and not grad: warnings.warn( - "Forcing the inner dimension of LHS to match the sharding of inner " - + "dimension of RHS. This can trigger additional communication if LHS is " - + "not already partitioned correctly." + "Forcing LHS sharding in the contracting dimension to match RHS. This can trigger " + + "additional communication if LHS is not already partitioned correctly." ) rhs_outer_spec = rhs_spec_new[rhs_outer_dim] if rhs_outer_spec is not None: + warnings.warn( + "Forcing the outer dimension of LHS (sequence/context dim) to be all- gathered. " + + "This may trigger additional communication if LHS is not already partitioned " + + "correctly. Additionally, the DGRAD output in the backward pass will not match " + + "the sharding of a sequence/context-parallel LHS operand." + ) lhs_spec_new[lhs_outer_dim] = None lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] @@ -661,8 +664,8 @@ def sharded_impl( if jax_dtype_is_fp8(lhs.dtype): out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) - # GEMM output needs to be all-reduced when the contracting dimension is sharded. if rhs_spec_new[rhs_inner_dim] is not None: + # GEMM output needs to be all-reduced when the contracting dimension is sharded. out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().tp_resource, mesh) if fuse_gelu: pre_gelu_out = lax_paral_op( From 2acb92f49b4687fde25f803f3115b693b900569b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 19:46:35 +0000 Subject: [PATCH 15/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 823e9f7ea1..31a8760564 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -53,6 +53,7 @@ def get_cublas_workspace_size_bytes() -> None: return 33_554_432 return 4_194_304 + class CollectiveGemmPrimitive(BasePrimitive): """ cuBlasLt GEMM Primitive w/ support for distributed inputs From b07bb2db5726d45dba28b7207bbc2051f166d8c4 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 21 Nov 2024 19:47:44 +0000 Subject: [PATCH 16/39] documentation fixes Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 31a8760564..0f567eecef 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -516,7 +516,6 @@ def infer_sharding_from_operands( ) # Modify operand specs: - # - FSDP axes are all-gathered # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension lhs_spec_new = [spec for spec in lhs_spec] @@ -584,7 +583,6 @@ def partition( ) # Modify operand specs: - # - FSDP axes are all-gathered # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension lhs_spec_new = [spec for spec in lhs_spec] From 765b844525e42d2def624bce7430f798828874d9 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 27 Nov 2024 21:54:39 +0000 Subject: [PATCH 17/39] added unit test, both AG+GEMM and GEMM+AR passing with FSDP+TP, DP+TP and TP-only meshes Signed-off-by: Alp Dener --- tests/jax/test_distributed_gemm.py | 311 ++++++++++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 107 +++--- transformer_engine/jax/gemm.py | 31 +- 3 files changed, 400 insertions(+), 49 deletions(-) create mode 100644 tests/jax/test_distributed_gemm.py diff --git a/tests/jax/test_distributed_gemm.py b/tests/jax/test_distributed_gemm.py new file mode 100644 index 0000000000..f1e3c58c4a --- /dev/null +++ b/tests/jax/test_distributed_gemm.py @@ -0,0 +1,311 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +import pytest +from functools import partial +from collections.abc import Iterable + +import numpy as np + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils + +import transformer_engine.jax as te +from transformer_engine.jax.gemm import gemm + +from utils import assert_allclose + + +jax.config.update('jax_enable_compilation_cache', False) + + +# AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P) +# - DGRAD: (4, 32, 256/P) x (128, 256/P)^T --(AR)--> (4, 32, 128) +# - WGRAD: (4, 32/P, 128)^T --(AG)--> (4, 32, 128)^T x (4, 32, 256/P) --------> (128, 256/P) + +# GEMM+AR: (4, 32, 256/P) x (256/P, 128) --(AR)--> (4, 32, 128) +# - DGRAD: (4, 32, 128) x (256/P, 128)^T ------> (4, 32, 256/P) +# - WGRAD: (4, 32, 256/P)^T --(AG)--> (4, 32, 256)^T x (4, 32, 128) --------> (256, 128) + +BATCH = 4 +BASE_SIZE = 16 +SEQ_LEN = BASE_SIZE * 8 +HIDDEN_SIZE = BASE_SIZE * 6 +FFN_HIDDEN_SIZE = BASE_SIZE * 16 + +COMM_TYPES = ["ALL_GATHER", "ALL_REDUCE"] +MESH_TYPES = ["FSDP_TP", "DP_TP", "TP"] +NUM_DEVICES = 4 + +is_fp8_supported, no_fp8_reason = te.fp8.is_fp8_available() + + +def _get_mesh(parallel_dist): + jax.clear_caches() + + batched = False + fsdp = False + mesh_shape = dict(tp=NUM_DEVICES) + resources = dict(cp_resource='tp', tp_resource='tp') + if parallel_dist in ["DP_TP", "FSDP_TP"]: + batched = True + mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=NUM_DEVICES//2)) + resources.update(dict(dp_resource='dp')) + if parallel_dist == "FSDP_TP": + fsdp = True + mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=1, zp=NUM_DEVICES//2)) + resources.update(dict(fsdp_resource='zp')) + mesh_resource = te.MeshResource(**resources) + + devices = mesh_utils.create_device_mesh( + (NUM_DEVICES, ), devices=jax.devices()[:NUM_DEVICES] + ) + + mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) + + return mesh, mesh_resource, batched, fsdp + + +def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bwd=False): + fp8_gemm = dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] + + # Operand and output shapes + lhs_shape = ( + [SEQ_LEN, HIDDEN_SIZE] + if fwd_comm_type == "ALL_GATHER" + else [SEQ_LEN, FFN_HIDDEN_SIZE] + ) + rhs_shape = ( + [HIDDEN_SIZE, FFN_HIDDEN_SIZE] + if fwd_comm_type == "ALL_GATHER" + else [FFN_HIDDEN_SIZE, HIDDEN_SIZE] + ) + out_shape = [lhs_shape[0], rhs_shape[1]] + + if batched: + lhs_shape = [BATCH] + lhs_shape + out_shape = [BATCH] + out_shape + + # Operand and output partition specs + lhs_spec = ( + [mesh_resource.tp_resource, None] + if fwd_comm_type == "ALL_GATHER" + else [None, mesh_resource.tp_resource] + ) + rhs_spec = ( + [None, mesh_resource.tp_resource] + if fwd_comm_type == "ALL_GATHER" + else [mesh_resource.tp_resource, None] + ) + out_spec = [None, rhs_spec[-1]] + + # Modify RHS operand for FP8 + fsdp_gathered_rhs_spec = rhs_spec.copy() + if fp8_gemm: + rhs_shape = list(reversed(rhs_shape)) + rhs_spec = list(reversed(rhs_spec)) + fsdp_gathered_rhs_spec = list(reversed(fsdp_gathered_rhs_spec)) + + # Add batch dimensions and specs + if batched: + if fsdp: + lhs_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + lhs_spec + rhs_spec = [mesh_resource.fsdp_resource if spec is None else spec for spec in rhs_spec] + out_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + out_spec + else: + lhs_spec = [mesh_resource.dp_resource] + lhs_spec + out_spec = [mesh_resource.dp_resource] + out_spec + + # Allocate global operands on device + key = jax.random.PRNGKey(42) + split_keys = jax.random.split(key, 3 if fwd_bwd else 2) + mu = 0.0 + sigma = 0.023 + shapes = (lhs_shape, rhs_shape) + if fwd_bwd: + shapes += (out_shape, ) + global_operands = list( + map( + lambda key, shape: jax.device_put( + mu + (sigma * jax.random.normal(key, shape, dtype=dtype)), + NamedSharding(mesh, PartitionSpec(None)) + ), + split_keys, + shapes, + ) + ) + + # Allocate sharded operands on device + partition_axes = (lhs_spec, rhs_spec) + if fwd_bwd: + partition_axes += (out_spec, ) + local_operands = list( + map( + lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))), + global_operands, + partition_axes, + ) + ) + + # Tranpose global RHS back to non-transpoosed orientation if it was originally allocated + # for FP8 GEMM + if fp8_gemm: + rhs_global = jnp.matrix_transpose(global_operands[1]) + global_operands = (global_operands[0], rhs_global, *global_operands[2:]) + + return ( + local_operands, + global_operands, + (out_shape, out_spec), + fsdp_gathered_rhs_spec, + ) + + +def _check_output(mesh, expected_out_shape, expected_out_specs, *tensors, fwd_bwd=False): + num_operands = 3 if fwd_bwd else 2 + ref_operands = tensors[:num_operands] + test_outputs = tensors[num_operands:] + + # Check number of dimensions + assert test_outputs[0].ndim == len(expected_out_shape), ( + f"Output has different number of dimensions ({test_outputs[0].ndim}) than expected " + + f"({len(expected_out_shape)})" + ) + + # Pad test output spec for unsharded dimensions + test_spec = te.sharding.get_padded_spec(test_outputs[0].sharding.spec, test_outputs[0].ndim) + + for i in range(test_outputs[0].ndim): + # Check shape + assert test_outputs[0].shape[i] == expected_out_shape[i], ( + f"Output with shape {test_outputs[0].shape} does not match expected shape " + + f"{expected_out_shape} in dimension index {i}." + ) + + # Check shardings (with padded output spec) + spec_mismatch = False + if isinstance(expected_out_specs[i], str): + if test_spec[i] != expected_out_specs[i]: + spec_mismatch = True + elif isinstance(expected_out_specs[i], Iterable): + if not isinstance(test_spec[i], type(expected_out_specs[i])): + if test_spec[i] not in expected_out_specs[i]: + spec_mismatch = True + elif len(test_spec[i]) != len(expected_out_specs[i]): + spec_mismatch = True + else: + for j in range(len(expected_out_specs[i])): + if test_spec[i][j] != expected_out_specs[i][j]: + spec_mismatch = True + break + elif expected_out_specs[i] == None: + if test_spec[i] != None: + spec_mismatch = True + else: + raise RuntimeError("Internal TE error: Unrecognized reference partition spec type.") + if spec_mismatch: + raise AssertionError( + f"Output sharding {test_spec} does not match expected sharding " + + f"{expected_out_specs} in dimension index {i}." + ) + + def _native_gemm_fwd_bwd(lhs, rhs, grad): + fwd_out, vjp_fn = jax.vjp(jnp.dot, lhs, rhs) + lhs_grad, rhs_grad = vjp_fn(grad) + return fwd_out, lhs_grad, rhs_grad + + ref_fn = jax.jit(_native_gemm_fwd_bwd if fwd_bwd else jnp.dot) + + out_names = ["output"] + ref_outputs = ref_fn(*ref_operands) + if not fwd_bwd: + ref_outputs = [ref_outputs] + else: + out_names += ["dgrad", "wgrad"] + + for i, (test_out, ref_out) in enumerate(zip(test_outputs, ref_outputs)): + test_out_global = jax.lax.with_sharding_constraint( + test_out, NamedSharding(mesh, PartitionSpec(None)) + ) + try: + assert_allclose(ref_out, test_out_global) + except AssertionError as err: + raise AssertionError(f"Numerical mismatch in {out_names[i]}:\n" + str(err)) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_impl(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs( + mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp + ) + + @jax.jit + def _test_fn(lhs, rhs): + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + return te.cpp_extensions.gemm_impl(lhs, rhs_no_fsdp, batched_output=batched) + + with te.sharding.global_shard_guard(mesh_resource): + output, *_ = _test_fn(*local_operands) + + _check_output(mesh, *output_info, *global_operands, output) + + +@pytest.mark.parametrize("comm_type", COMM_TYPES) +@pytest.mark.parametrize("mesh_type", MESH_TYPES) +def test_gemm_fwd_bwd(comm_type, mesh_type): + mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) + + ( + local_operands, + global_operands, + output_info, + fsdp_gathered_rhs_spec, + ) = _get_inputs( + mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True + ) + + @jax.jit + def _test_fn(lhs, rhs, grad): + # Gather weights in FSDP axis + rhs_no_fsdp = jax.lax.with_sharding_constraint( + rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) + ) + + # FWD pass + fwd_out, vjp_fn = jax.vjp(gemm, lhs, rhs_no_fsdp) + + # BWD pass + lhs_grad, rhs_grad = vjp_fn(grad) + + return fwd_out, lhs_grad, rhs_grad + + print( + f"INPUTS: {local_operands[0].shape} x {local_operands[1].shape}\n" + + f" LHS sharding: {local_operands[0].sharding.spec}\n" + + f" RHS sharding: {local_operands[1].sharding.spec}\n" + ) + + with te.sharding.global_shard_guard(mesh_resource): + output, dgrad, wgrad = _test_fn(*local_operands) + + print( + f"{'AG + GEMM' if comm_type == 'AG' else 'GEMM + AR'} output: " + + f"{output.shape} | {output.sharding.spec}\n" + + f"DGRAD: {dgrad.shape} | {dgrad.sharding.spec}\n" + + f"WGRAD: {wgrad.shape} | {wgrad.sharding.spec}\n" + ) + + _check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True) + diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0f567eecef..30ff0ca54a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -516,31 +516,53 @@ def infer_sharding_from_operands( ) # Modify operand specs: - # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded - # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension - lhs_spec_new = [spec for spec in lhs_spec] - rhs_spec_new = [spec for spec in rhs_spec] - if lhs_spec_new[lhs_inner_dim] != rhs_spec_new[rhs_inner_dim] and not grad: - warnings.warn( - "Forcing LHS sharding in the contracting dimension to match RHS. This can trigger " - + "additional communication if LHS is not already partitioned correctly." + # - If contracting dimensions of both operands are sharded, force them to match. + # - If contracting dimensions of both operands are sharded, all-gather outer dimensions. + # - If contracting dimension of only one operand is sharded, all-gather the sharded + # operand. + # - Never scatter any operand. + lhs_spec_new = list(lhs_spec).copy() + rhs_spec_new = list(rhs_spec).copy() + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + assert lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim], ( + "Contracting dimensions of LHS and RHS operands must have the same sharding." ) - rhs_outer_spec = rhs_spec_new[rhs_outer_dim] - if rhs_outer_spec is not None: - warnings.warn( - "Forcing the outer dimension of LHS (sequence/context dim) to be all- gathered. " - + "This may trigger additional communication if LHS is not already partitioned " - + "correctly. Additionally, the DGRAD output in the backward pass will not match " - + "the sharding of a sequence/context-parallel LHS operand." - ) - lhs_spec_new[lhs_outer_dim] = None - lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] + if lhs_spec_new[lhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the LHS operand must be all-gathered when both contracting " + + "dimensions are sharded. This will cause additional communication overhead." + ) + + if rhs_spec_new[rhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the RHS operand must be all-gathered when both contracting " + + "dimensions are sharded. This will cause additional communication overhead." + ) + rhs_spec_new[rhs_outer_dim] = None + else: + if lhs_spec_new[lhs_inner_dim] is None and rhs_spec_new[rhs_inner_dim] is not None: + warnings.warn( + "Contracting dimension of the RHS operand must be all-gathered when the " + + "contracting dimension of the LHS operand is unsharded. This will cause " + + "additional communication overhead." + ) + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is None: + if not grad: + # This is expected for sequence/context-parallel gradient in BWD (DGRAD) GEMM. + warnings.warn( + "Contracting dimension of the LHS operand must be all-gathered when the " + + "contracting dimension of the RHS operand is unsharded. This will cause " + + "additional communication overhead." + ) + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] # Output sharding is conditional on output shape lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] - lhs_outer_spec = lhs_spec_new[lhs_outer_dim] - out_spec = [lhs_outer_spec, rhs_outer_spec] + out_spec = [None, out_col_spec] if batched_output: out_spec = batch_spec + out_spec out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) @@ -549,11 +571,11 @@ def infer_sharding_from_operands( fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded - gelu_spec = [lhs_outer_spec, rhs_outer_spec] if fuse_gelu else [None] + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) # Bias gradient spec matches outer dimension of output if bias fusion is turned on - bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) @@ -583,19 +605,27 @@ def partition( ) # Modify operand specs: - # - LHS operand outer dimension is all-gathered if RHS operand outer dimension is sharded - # - LHS operand contracting dimension sharding is forced to match RHS contracting dimension - lhs_spec_new = [spec for spec in lhs_spec] - rhs_spec_new = [spec for spec in rhs_spec] - rhs_outer_spec = rhs_spec_new[rhs_outer_dim] - if rhs_outer_spec is not None: - lhs_spec_new[lhs_outer_dim] = None - lhs_spec_new[lhs_inner_dim] = rhs_spec_new[rhs_inner_dim] + # - Always all-gather the outer dimension of LHS. + # - If contracting dimensions of both operands are sharded, all-gather RHS outer dimension. + # - If contracting dimension of only one operand is sharded, all-gather the sharded + # operand. + # - Never scatter any operand. + lhs_spec_new = list(lhs_spec).copy() + rhs_spec_new = list(rhs_spec).copy() + reduce_output = False + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + rhs_spec_new[rhs_outer_dim] = None + reduce_output = True + else: + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + out_col_spec = rhs_spec_new[rhs_outer_dim] lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) # Bias is sharded to match outer dimension spec of the RHS operand (also the output) - bias_sharding = NamedSharding(mesh, PartitionSpec(rhs_outer_spec if fuse_bias else None)) + bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) # FP8 metas are always unsharded fp8_meta_sharding = NamedSharding(mesh, PartitionSpec(None)) @@ -603,14 +633,13 @@ def partition( # Output sharding is conditional on output shape lhs_bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] batch_spec = [lhs_spec_new[dim] for dim in lhs_bdims] - lhs_outer_spec = lhs_spec_new[lhs_outer_dim] - out_spec = [lhs_outer_spec, rhs_outer_spec] + out_spec = [None, out_col_spec] if batched_output: out_spec = batch_spec + out_spec out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec)) # Pre-GELU output is always 2D if GELU fusion is turned on, otherwise unsharded - gelu_spec = [lhs_outer_spec, rhs_outer_spec] if fuse_gelu else [None] + gelu_spec = [None, out_col_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) arg_shardings = ( @@ -663,13 +692,11 @@ def sharded_impl( if jax_dtype_is_fp8(lhs.dtype): out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) - if rhs_spec_new[rhs_inner_dim] is not None: - # GEMM output needs to be all-reduced when the contracting dimension is sharded. - out = lax_paral_op(out, jax.lax.psum, global_mesh_resource().tp_resource, mesh) + # All-reduce sum GEMM output when contracting dimensions are sharded + if reduce_output: + out = jax.lax.psum(out, global_mesh_resource().tp_resource) if fuse_gelu: - pre_gelu_out = lax_paral_op( - pre_gelu_out, jax.lax.psum, global_mesh_resource().tp_resource, mesh - ) + pre_gelu_out = jax.lax.psum(pre_gelu_out, global_mesh_resource().tp_resource) return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 464ccb12f9..4cf09a204f 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -33,7 +33,7 @@ def gemm( x: ArrayLike, kernel: ArrayLike, bias: Optional[ArrayLike] = None, - contracting_dims: Tuple[int, int] = (1, 0), + contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, @@ -73,8 +73,11 @@ def _gemm_fwd_rule( fuse_bias = bias is not None - # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) - # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) out, pre_gelu_out = gemm_impl( x, kernel, @@ -112,12 +115,18 @@ def _gemm_bwd_rule( ) # FWD MODE: - # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --------> ([B], M, N/P) - # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) + # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # + # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) + # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) # DGRAD: - # AG+GEMM: ([B], M, N/P) x (K, N/P)^T --(AR)--> ([B], M, K) - # GEMM+AR: ([B], M, N) x (K/P, N)^T --------> ([B], M, K/P) + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ----(AR)----> ([B], M, K) + # (DP, None, TP) x (None, TP)^T --(AR)--> (DP, None, None) + # + # GEMM+AR: ([B], M, N) x (K/P, N)^T ------> ([B], M, K/P) + # (DP, None, None) x (TP, None)^T --> (DP, None, TP) dgrad, dgelu, _ = gemm_impl( grad, kernel, @@ -133,7 +142,11 @@ def _gemm_bwd_rule( # WGRAD: # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) - # GEMM+AR: ([B], M, K/P)^T x ([B], M, N) ----> (K/P, N) + # (DP, 'tp', None)^T --(AG)-->(DP, None, None)^T x (DP, None, 'tp') --> (None, 'tp') + # + # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K/P, N) + # (DP, None, 'tp')^T --(AG)--> (DP, None, None)^T x (DP, None, None) ----> (None, None) + # Make XLA scatter output in first dim. wgrad_rhs = dgelu if fuse_gelu else grad wgrad, _, bgrad = gemm_impl( x, @@ -445,7 +458,7 @@ def type_safe_gemm( bias: Optional[ArrayLike] = None, fp8_meta: Optional[FP8MetaPackage] = None, out_dtype: Optional[jnp.dtype] = None, - contracting_dims: Tuple[int, int] = (1, 0), + contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, From 2ce4377702d20d48564383647caede1f2dcf1e6e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 21:55:29 +0000 Subject: [PATCH 18/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_distributed_gemm.py | 35 +++++++------------ transformer_engine/jax/cpp_extensions/gemm.py | 6 ++-- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/tests/jax/test_distributed_gemm.py b/tests/jax/test_distributed_gemm.py index f1e3c58c4a..b246999d8a 100644 --- a/tests/jax/test_distributed_gemm.py +++ b/tests/jax/test_distributed_gemm.py @@ -18,7 +18,7 @@ from utils import assert_allclose -jax.config.update('jax_enable_compilation_cache', False) +jax.config.update("jax_enable_compilation_cache", False) # AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P) @@ -48,20 +48,18 @@ def _get_mesh(parallel_dist): batched = False fsdp = False mesh_shape = dict(tp=NUM_DEVICES) - resources = dict(cp_resource='tp', tp_resource='tp') + resources = dict(cp_resource="tp", tp_resource="tp") if parallel_dist in ["DP_TP", "FSDP_TP"]: batched = True - mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=NUM_DEVICES//2)) - resources.update(dict(dp_resource='dp')) + mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=NUM_DEVICES // 2)) + resources.update(dict(dp_resource="dp")) if parallel_dist == "FSDP_TP": fsdp = True - mesh_shape.update(dict(tp=NUM_DEVICES//2, dp=1, zp=NUM_DEVICES//2)) - resources.update(dict(fsdp_resource='zp')) + mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=1, zp=NUM_DEVICES // 2)) + resources.update(dict(fsdp_resource="zp")) mesh_resource = te.MeshResource(**resources) - devices = mesh_utils.create_device_mesh( - (NUM_DEVICES, ), devices=jax.devices()[:NUM_DEVICES] - ) + devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES]) mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) @@ -73,9 +71,7 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw # Operand and output shapes lhs_shape = ( - [SEQ_LEN, HIDDEN_SIZE] - if fwd_comm_type == "ALL_GATHER" - else [SEQ_LEN, FFN_HIDDEN_SIZE] + [SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE] ) rhs_shape = ( [HIDDEN_SIZE, FFN_HIDDEN_SIZE] @@ -125,12 +121,12 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw sigma = 0.023 shapes = (lhs_shape, rhs_shape) if fwd_bwd: - shapes += (out_shape, ) + shapes += (out_shape,) global_operands = list( map( lambda key, shape: jax.device_put( mu + (sigma * jax.random.normal(key, shape, dtype=dtype)), - NamedSharding(mesh, PartitionSpec(None)) + NamedSharding(mesh, PartitionSpec(None)), ), split_keys, shapes, @@ -140,7 +136,7 @@ def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bw # Allocate sharded operands on device partition_axes = (lhs_spec, rhs_spec) if fwd_bwd: - partition_axes += (out_spec, ) + partition_axes += (out_spec,) local_operands = list( map( lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))), @@ -245,9 +241,7 @@ def test_gemm_impl(comm_type, mesh_type): global_operands, output_info, fsdp_gathered_rhs_spec, - ) = _get_inputs( - mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp - ) + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp) @jax.jit def _test_fn(lhs, rhs): @@ -272,9 +266,7 @@ def test_gemm_fwd_bwd(comm_type, mesh_type): global_operands, output_info, fsdp_gathered_rhs_spec, - ) = _get_inputs( - mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True - ) + ) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True) @jax.jit def _test_fn(lhs, rhs, grad): @@ -308,4 +300,3 @@ def _test_fn(lhs, rhs, grad): ) _check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True) - diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 30ff0ca54a..250e8e0c29 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -525,9 +525,9 @@ def infer_sharding_from_operands( rhs_spec_new = list(rhs_spec).copy() lhs_spec_new[lhs_outer_dim] = None if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: - assert lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim], ( - "Contracting dimensions of LHS and RHS operands must have the same sharding." - ) + assert ( + lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim] + ), "Contracting dimensions of LHS and RHS operands must have the same sharding." if lhs_spec_new[lhs_outer_dim] is not None: warnings.warn( "Outer dimension of the LHS operand must be all-gathered when both contracting " From f68d71edc56980932b4a4a07ab7d26c44fdaa4e7 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 5 Dec 2024 21:29:27 +0000 Subject: [PATCH 19/39] restored old test_custom_call_compute.py to remove erroneous changes Signed-off-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 50 --------------------------- 1 file changed, 50 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 355f587265..20b16c2809 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -25,7 +25,6 @@ _jax_dbias_cast_transpose, ) from transformer_engine.jax.cpp_extensions.quantization import _jax_cast_fp8 -from transformer_engine.jax.gemm import fp8_gemm, gemm from transformer_engine.jax import cpp_extensions as tex @@ -416,55 +415,6 @@ def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_ ) -class TestGemm: - - @staticmethod - def _generate_inputs(b, m, n, k, dtype): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 3) - a = jax.random.normal(subkeys[0], (b, m, k), dtype) - b = jax.random.normal(subkeys[1], (n, k), dtype) - bias_dtype = dtype if dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2] else jnp.bfloat16 - bias = jax.random.normal(subkeys[2], (n,), bias_dtype) - return a, b, bias - - @staticmethod - def _generate_fp8_inputs(b, m, n, k, fp8_dtype): - a, b, bias = TestGemm._generate_inputs(b, m, n, k, jnp.bfloat16) - a_scale, b_scale = map(lambda x: (jnp.max(jnp.abs(x)) / 127.0).astype(jnp.float32), [a, b]) - a_q, b_q = map( - lambda x, x_scale: jnp.round(x / x_scale).astype(fp8_dtype), - [(a, a_scale), (b, b_scale)], - ) - return a, a_q, jnp.reciprocal(a_scale), b, b_q, jnp.reciprocal(b_scale), bias - - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - @pytest.mark.parametrize("use_bias", (False, True)) - @pytest.mark.parametrize("do_gelu", (False, True)) - def test_gemm(self, b, m, n, k, use_bias, do_gelu): - a, b, bias = self._generate_inputs(b, m, n, k, jnp.bfloat16) - - primitive_out = gemm(a, b, bias=bias if use_bias else None, layout="NT", do_gelu=do_gelu) - ref_out = jnp.dot(a, b) - if use_bias: - ref_out += bias - if do_gelu: - ref_out = jax.nn.gelu(ref_out) - - assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16) - - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - @pytest.mark.parametrize("fp8_dtype", FP8_COMPUTE_TYPE) - def test_fp8_gemm(self, m, n, k, fp8_dtype): - a, a_q, a_scale_inv, b, b_q, b_scale_inv, _ = self._generate_fp8_inputs(m, n, k, fp8_dtype) - - primitive_out = fp8_gemm(a_q, a_scale_inv, b_q, b_scale_inv, out_dtype=jnp.bfloat16) - ref_out = jnp.dot(a, b) - - assert_allclose(primitive_out, ref_out, dtype=fp8_dtype) - - @pytest.fixture(name="random_inputs") def random_inputs_fixture(shape): key = jax.random.PRNGKey(0) From 6b322bb163c2de7d53cd69cb9306c5f0567fcdf6 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 14 Nov 2024 09:23:13 +0000 Subject: [PATCH 20/39] added XLA custom ops and C++ infrastructure for comm+GEMM overlap in TE/JAX Signed-off-by: Alp Dener comm+GEMM overlap API for TE/JAX compiles, untested, but did not break collective GEMM op Signed-off-by: Alp Dener [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci fixed static args Signed-off-by: Alp Dener [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .gitmodules | 3 + 3rdparty/dlpack | 1 + build_tools/jax.py | 20 + setup.py | 1 + .../comm_gemm_overlap/comm_gemm_overlap.cpp | 72 ++- .../transformer_engine/comm_gemm_overlap.h | 143 +++-- .../transformer_engine/transformer_engine.h | 2 +- .../common/transformer_engine.cpp | 2 +- .../common/util/dlpack_helper.h | 188 ++++++ .../common/util/pybind_helper.h | 18 +- transformer_engine/jax/cpp_extensions/gemm.py | 566 ++++++++++++----- transformer_engine/jax/csrc/extensions.h | 100 ++- .../jax/csrc/extensions/comm_gemm_overlap.cpp | 291 +++++++++ .../jax/csrc/extensions/gemm.cpp | 11 +- .../jax/csrc/extensions/packing.cpp | 34 +- .../jax/csrc/extensions/pybind.cpp | 14 + transformer_engine/jax/gemm.py | 575 +++++++++++++++++- transformer_engine/pytorch/csrc/extensions.h | 3 +- .../csrc/extensions/comm_gemm_overlap.cpp | 14 +- .../pytorch/csrc/extensions/pybind.cpp | 5 +- transformer_engine/pytorch/module/base.py | 13 +- 21 files changed, 1815 insertions(+), 261 deletions(-) create mode 160000 3rdparty/dlpack create mode 100644 transformer_engine/common/util/dlpack_helper.h create mode 100644 transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp diff --git a/.gitmodules b/.gitmodules index 21492db5ef..7fc91b1f54 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git +[submodule "3rdparty/dlpack"] + path = 3rdparty/dlpack + url = git@github.com:dmlc/dlpack.git diff --git a/3rdparty/dlpack b/3rdparty/dlpack new file mode 160000 index 0000000000..bbd2f4d324 --- /dev/null +++ b/3rdparty/dlpack @@ -0,0 +1 @@ +Subproject commit bbd2f4d32427e548797929af08cfe2a9cbb3cf12 diff --git a/build_tools/jax.py b/build_tools/jax.py index f829230f50..bb4da4e5ed 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -5,6 +5,7 @@ """JAX related extensions.""" import os from pathlib import Path +from typing import Optional import setuptools from glob import glob @@ -36,6 +37,7 @@ def setup_jax_extension( csrc_source_files, csrc_header_files, common_header_files, + third_party_packages, ) -> setuptools.Extension: """Setup PyBind11 extension for JAX support""" # Source files @@ -55,12 +57,28 @@ def setup_jax_extension( common_header_files / "common" / "include", csrc_header_files, xla_home, + third_party_packages / "dlpack" / "include", ] # Compile flags cxx_flags = ["-O3"] nvcc_flags = ["-O3"] + # Userbuffers MPI dependence + libraries = [] + library_dirs = [] + if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): + mpi_home = os.getenv("MPI_HOME") + assert mpi_home is not None, "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" + mpi_home = Path(mpi_home) + libraries.append("mpi") + library_dirs.append(mpi_home / "lib") + + include_dirs.append(mpi_home / "include") + + cxx_flags.append("-DNVTE_UB_WITH_MPI") + nvcc_flags.append("-DNVTE_UB_WITH_MPI") + # Define TE/JAX as a Pybind11Extension from pybind11.setup_helpers import Pybind11Extension @@ -79,5 +97,7 @@ def _add_cflags(self, flags: List[str]) -> None: "transformer_engine_jax", sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], + library_dirs=[str(path) for path in library_dirs], + libraries=libraries, extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags}, ) diff --git a/setup.py b/setup.py index 3bb2fe6b95..a702399bc9 100644 --- a/setup.py +++ b/setup.py @@ -164,6 +164,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: "transformer_engine/jax/csrc", current_file_path / "transformer_engine" / "jax" / "csrc", current_file_path / "transformer_engine", + current_file_path / "3rdparty", ) ) if "paddle" in frameworks: diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index c6f0f870ff..810eeb2ebe 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -139,11 +139,12 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm) + bool atomic_gemm, bool overlap_first_gemm) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, false, atomic_gemm) { _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); + _overlap_first_gemm = overlap_first_gemm; NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", "or 2 (multi-atomic)."); @@ -164,6 +165,36 @@ CommOverlapBase::~CommOverlapBase() { cudaStreamDestroy(_stream_comm); } +TensorWrapper CommOverlapBase::get_ubuf_output(CommOverlapType comm_type) { + char *output_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == CommOverlapType::RS) + output_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + size_t output_c_dim0 = + (comm_type == CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + size_t output_c_dim1 = _ubuf.size(1); + return TensorWrapper(reinterpret_cast(output_ptr), {output_c_dim0, output_c_dim1}, + _ubuf.dtype()); +} + +void CommOverlapBase::copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, + CommOverlapType comm_type) { + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == CommOverlapType::AG) { + if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("Input and buffer sizes do not match!"); + } + ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + } else { + if (input.numel() != (int64_t)_ubuf.numel() || + input.element_size() != (int64_t)_ubuf.element_size()) { + NVTE_ERROR("Input and buffer sizes do not match!"); + } + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.dptr(), input.numel() * input.element_size(), + cudaMemcpyDeviceToDevice, stream)); +} + /* ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf @@ -225,8 +256,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -325,8 +355,7 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; @@ -358,7 +387,7 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap assert(pre_gelu_out.numel() == 0); - if (gemm_overlap) { + if (_overlap_first_gemm) { auto input_a_chunk = TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); auto output_chunk = @@ -565,6 +594,37 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { cudaStreamDestroy(_stream_send); } +TensorWrapper CommOverlapP2PBase::get_ubuf_output(CommOverlapType comm_type) { + char *output_ptr = reinterpret_cast(_ubuf.dptr()); + if (comm_type == CommOverlapType::RS) + output_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); + size_t output_c_dim0 = + (comm_type == CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; + size_t output_c_dim1 = _ubuf.size(1); + return TensorWrapper(reinterpret_cast(output_ptr), {output_c_dim0, output_c_dim1}, + _ubuf.dtype()); +} + +void CommOverlapP2PBase::copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, + CommOverlapType comm_type) { + if (comm_type == CommOverlapType::RS) { + // Copy input to the target ubuf chunk by rank offset + if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) { + NVTE_ERROR("Input and buffer sizes do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.dptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + stream)); + } else { + if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) { + NVTE_ERROR("Input and buffer sizes do not match!"); + } + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.dptr(), + input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, + stream)); + } +} + /* ** Split AllGather + AtomicGEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 1d5d192a39..16e4ccf16a 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -17,6 +17,9 @@ #define NVTE_COMM_OVERLAP_MAX_STREAMS 3 +#define NOT_IMPLEMENTED_ERROR() NVTE_ERROR("Operation is not implemented.") + +#define NOT_SUPPORTED_ERROR() NVTE_ERROR("Operation not supported.") namespace transformer_engine { /* \brief Check if Userbufers bootstraps with direct calls to MPI collectives. @@ -26,9 +29,9 @@ namespace transformer_engine { */ bool ubuf_built_with_mpi(); -enum class CommOverlapType { RS = 0, AG = 1 }; +enum class CommOverlapType : int32_t { RS = 0, AG = 1 }; -enum class CommOverlapAlgo { +enum class CommOverlapAlgo : int32_t { BULK_OVERLAP_AG = 0, BULK_OVERLAP_RS = 1, SPLIT_PIPELINED_AG_P2P = 2, @@ -77,16 +80,64 @@ class CommOverlapCore { _ubuf_scale_inv_initialized = true; } - bool is_atomic_gemm() { return _atomic_gemm; } + virtual TensorWrapper get_ubuf_output(CommOverlapType comm_type) { NOT_IMPLEMENTED_ERROR(); } + + virtual void copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, + CommOverlapType comm_type) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual bool is_atomic_gemm() { return _atomic_gemm; } + + virtual bool is_p2p_overlap() { return _is_p2p; } + + virtual bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + + virtual void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } - bool is_p2p_overlap() { return _is_p2p; } + virtual void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } - bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + virtual void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } + + virtual void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NOT_IMPLEMENTED_ERROR(); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { protected: int _rs_kernel_type; + bool _overlap_first_gemm; cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; @@ -95,36 +146,47 @@ class CommOverlapBase : public CommOverlapCore { int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, + bool overlap_first_gemm = false); virtual ~CommOverlapBase(); - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ + TensorWrapper get_ubuf_output(CommOverlapType comm_type); + + void copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, CommOverlapType comm_type); + void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main); - /* - ** Split FPROP GEMM + ReduceScatter - */ void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - TensorWrapper &rs_output, cudaStream_t stream_main); + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main); - /* - ** Split FPROP GEMM + ReduceScatter - */ void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main); + + void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NOT_SUPPORTED_ERROR(); + } + + void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NOT_SUPPORTED_ERROR(); + } }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { @@ -155,44 +217,39 @@ class CommOverlapP2PBase : public CommOverlapCore { virtual ~CommOverlapP2PBase(); - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG - ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. - */ - void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + TensorWrapper get_ubuf_output(CommOverlapType comm_type); + + void copy_into_ubuf(cudaStream_t stream, TensorWrapper &input, CommOverlapType comm_type); + + void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main) { + NOT_SUPPORTED_ERROR(); + } + + void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main); - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG - ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. - */ - void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main); - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, + bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main); - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, + bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main); }; // CommOverlapP2PBase diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index d302518235..6fdc93098f 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -393,7 +393,7 @@ class TensorWrapper { return nvte_tensor_scale_inv(tensor_); } - private: + protected: /*! \brief Wrapped NVTETensor. */ NVTETensor tensor_ = nullptr; }; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1a3b49f9fa..b92a993d49 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -93,7 +93,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { return ret; } -size_t nvte_tensor_ndim(const NVTETensor tensor) { +size_t nvte_tensor_ndims(const NVTETensor tensor) { const auto &t = *reinterpret_cast(tensor); return t.data.shape.size(); } diff --git a/transformer_engine/common/util/dlpack_helper.h b/transformer_engine/common/util/dlpack_helper.h new file mode 100644 index 0000000000..cd8210e37a --- /dev/null +++ b/transformer_engine/common/util/dlpack_helper.h @@ -0,0 +1,188 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_DLPACK_HELPER_H +#define TRANSFORMER_ENGINE_COMMON_UTIL_DLPACK_HELPER_H + +#include +#include +#include + +#include "cuda_runtime.h" +#include "logging.h" + +namespace transformer_engine { + +DLDataType nvte_dtype_to_dldtype(DType dtype) { + DLDataType dldtype; + dldtype.lanes = 1; + switch (dtype) { + case DType::kInt64: + dldtype.bits = 64; + dldtype.code = DLDataTypeCode::kDLInt; + break; + + case DType::kInt32: + dldtype.bits = 32; + dldtype.code = DLDataTypeCode::kDLInt; + break; + + case DType::kByte: + dldtype.bits = 8; + dldtype.code = DLDataTypeCode::kDLUInt; + break; + + case DType::kFloat32: + dldtype.bits = 32; + dldtype.code = DLDataTypeCode::kDLFloat; + break; + + case DType::kFloat16: + dldtype.bits = 16; + dldtype.code = DLDataTypeCode::kDLFloat; + break; + + case DType::kBFloat16: + dldtype.bits = 16; + dldtype.code = DLDataTypeCode::kDLBfloat; + break; + + case DType::kFloat8E4M3: + dldtype.bits = 8; + dldtype.code = DLDataTypeCode::kDLFloat; + break; + + case DType::kFloat8E5M2: + dldtype.bits = 8; + dldtype.code = DLDataTypeCode::kDLFloat; + break; + + default: + NVTE_ERROR("Unrecognized transformer_engine::DType."); + } + return dldtype; +} + +DType dldtype_to_nvte_dtype(const DLDataType &dldtype, bool grad) { + NVTE_CHECK(dldtype.lanes == 1, "Unsupported number of lanes in DLDataType: ", dldtype.lanes); + + switch (dldtype.code) { + case DLDataTypeCode::kDLInt: + switch (dldtype.bits) { + case 64: + return DType::kInt64; + + case 32: + return DType::kInt32; + + default: + NVTE_ERROR("Unsupported bits in integer DLDataType: ", dldtype.bits); + } + + case DLDataTypeCode::kDLFloat: + switch (dldtype.bits) { + case 32: + return DType::kFloat32; + + case 16: + return DType::kFloat16; + + case 8: + if (grad) { + return DType::kFloat8E5M2; + } else { + return DType::kFloat8E4M3; + } + + default: + NVTE_ERROR("Unsupported bits in float DLDataType: ", dldtype.bits); + } + + case DLDataTypeCode::kDLBfloat: + if (dldtype.bits == 16) { + return DType::kBFloat16; + } else { + NVTE_ERROR("Unsupported bits in bfloat DLDataType: ", dldtype.bits); + } + + case DLDataTypeCode::kDLBool: + case DLDataTypeCode::kDLUInt: + if (dldtype.bits == 8) { + return DType::kByte; + } else { + NVTE_ERROR("Unsupported bits in unsigned int DLDataType: ", dldtype.bits); + } + + default: + NVTE_ERROR("Unsupported DLDataType."); + } +} + +class DLPackWrapper : public TensorWrapper { + protected: + DLManagedTensor managed_tensor; + + public: + // Inherit TensorWrapper constructors + using TensorWrapper::TensorWrapper; + + // Construct a new DLPackWrapper from existing TensorWrapper + DLPackWrapper(TensorWrapper &&other) : TensorWrapper(std::move(other)) {} + + // New constructor from PyObject + DLPackWrapper(pybind11::object obj, bool grad = false) { + NVTE_CHECK(PyCapsule_CheckExact(obj.ptr()), "Expected DLPack capsule"); + + DLManagedTensor *dlMTensor = (DLManagedTensor *)PyCapsule_GetPointer(obj.ptr(), "dltensor"); + NVTE_CHECK(dlMTensor, "Invalid DLPack capsule."); + + DLTensor *dlTensor = &dlMTensor->dl_tensor; + NVTE_CHECK(dlTensor->device.device_type == DLDeviceType::kDLCUDA, + "DLPack tensor is not on a CUDA device."); + NVTE_CHECK(dlTensor->device.device_id == cuda::current_device(), + "DLPack tensor resides on a different device."); + + if (dlTensor->strides) { + for (int idx = dlTensor->ndim - 1; idx >= 0; ++idx) { + NVTE_CHECK(dlTensor->strides[idx] == 1, + "DLPack tensors with non-standard strides are not supported."); + } + } + + NVTEShape shape; + shape.data = reinterpret_cast(dlTensor->shape); + shape.ndim = static_cast(dlTensor->ndim); + this->tensor_ = nvte_create_tensor( + dlTensor->data, shape, static_cast(dldtype_to_nvte_dtype(dlTensor->dtype, grad)), + nullptr, nullptr, nullptr); + } + + pybind11::object capsule() { + DLDevice tensor_context; + tensor_context.device_type = DLDeviceType::kDLCUDA; + tensor_context.device_id = cuda::current_device(); + + DLTensor dlTensor; + dlTensor.data = dptr(); + dlTensor.device = tensor_context; + dlTensor.ndim = ndim(); + dlTensor.dtype = nvte_dtype_to_dldtype(dtype()); + dlTensor.shape = reinterpret_cast(const_cast(shape().data)); + dlTensor.strides = nullptr; + dlTensor.byte_offset = 0; + + managed_tensor.dl_tensor = dlTensor; + managed_tensor.manager_ctx = nullptr; + managed_tensor.deleter = [](DLManagedTensor *) {}; + + return pybind11::reinterpret_steal( + PyCapsule_New(&managed_tensor, "dltensor", nullptr)); + } +}; + +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index a36ff3f0f9..6fa9574f63 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -28,7 +28,8 @@ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI) \ + .export_values(); \ pybind11::enum_(m, "NVTE_Mask_Type") \ .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ @@ -36,11 +37,13 @@ .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) \ + .export_values(); \ pybind11::enum_(m, "NVTE_QKV_Format") \ .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ - .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD); \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \ + .export_values(); \ pybind11::enum_(m, "NVTE_QKV_Layout") \ .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ @@ -56,12 +59,14 @@ .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \ + .export_values(); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) \ + .export_values(); \ pybind11::enum_(m, "NVTE_Activation_Type") \ .value("GELU", NVTE_Activation_Type::GELU) \ .value("GEGLU", NVTE_Activation_Type::GEGLU) \ @@ -72,7 +77,8 @@ .value("QGELU", NVTE_Activation_Type::QGELU) \ .value("QGEGLU", NVTE_Activation_Type::QGEGLU) \ .value("SRELU", NVTE_Activation_Type::SRELU) \ - .value("SREGLU", NVTE_Activation_Type::SREGLU); \ + .value("SREGLU", NVTE_Activation_Type::SREGLU) \ + .export_values(); \ pybind11::enum_(m, "CommOverlapType") \ .value("RS", transformer_engine::CommOverlapType::RS) \ .value("AG", transformer_engine::CommOverlapType::AG); \ diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 250e8e0c29..2ff98c20d9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -3,9 +3,8 @@ # See LICENSE for license information. import warnings import operator -from functools import reduce +from functools import reduce, partial from typing import Optional, Tuple -from collections.abc import Iterable import jax import jax.numpy as jnp @@ -30,6 +29,7 @@ global_mesh_resource, lax_paral_op, all_reduce_max_along_all_axes_except_PP, + get_mesh_axis_size, ) @@ -38,6 +38,14 @@ "gemm_impl", ] +_COMM_GEMM_OVERLAP_LAYERS = ["qkv", "proj", "fc1", "fc2"] +_COMM_GEMM_OVERLAP_NAMES = ( + [layer + "_fprop" for layer in _COMM_GEMM_OVERLAP_LAYERS] + + [layer + "_dgrad" for layer in _COMM_GEMM_OVERLAP_LAYERS] + + [layer + "_wgrad" for layer in _COMM_GEMM_OVERLAP_LAYERS if layer != "fc2"] + + ["generic_ag", "generic_rs"] +) + def sanitize_dims(dim, ndims): return (ndims + dim) if dim < 0 else dim @@ -60,7 +68,7 @@ class CollectiveGemmPrimitive(BasePrimitive): """ name = "te_gemm" - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16) multiple_results = True inner_primitive = None outer_primitive = None @@ -83,12 +91,18 @@ def abstract( grad, accumulate, use_split_accumulator, + comm_overlap_config, ): """ cuBlasLt GEMM abstract """ del grad, accumulate, use_split_accumulator + assert tex.ubuf_built_with_mpi(), ( + "Comm+GEMM overlap in TE/JAX requires Transformer Engine to be compiled with " + + "`NVTE_UB_WITH_MPI=1` and `MPI_HOME=/path/to/mpi` options." + ) + # Validate operand dtypes lhs_dtype = dtypes.canonicalize_dtype(lhs_aval.dtype) rhs_dtype = dtypes.canonicalize_dtype(rhs_aval.dtype) @@ -106,13 +120,13 @@ def abstract( and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 ), "Missing RHS operand scale inverse in FP8 GEMM." - # Validate operand layouts + # Validate operand layouts, adjusted for comm-overlap if necessary lhs_inner_dim, rhs_inner_dim = map( sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) ) assert ( lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] - ), f"Incompatible operand sizes: {lhs_aval.shape} x {rhs_aval.shape}." + ), f"Incompatible contracting dimensions: {lhs_aval.shape} x {rhs_aval.shape}." lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 @@ -153,6 +167,18 @@ def abstract( lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) + if rhs_aval.ndim > 2: + rhs_bdims = [ + dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] + ] + rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] + rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) + if rhs_batch_size > 1: + assert lhs_batch_size == rhs_batch_size, ( + f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible " + + f"with the leading dimensions of LHS ({lhs_batch_shape=})." + ) + # Infer output shape if batched_output: assert ( @@ -204,6 +230,43 @@ def abstract( else: assert gelu_input_aval.size == 0, "Internal TE error." + # Adjust output sizes for comm-overlap + extra_out_shape = (0,) + extra_out_dtype = jnp.bfloat16 + if comm_overlap_config is not None: + comm_overlap_type = comm_overlap_config.get("comm_type", None) + assert comm_overlap_type is not None, "Missing comm type for comm+GEMM overlap." + comm_overlap_name = comm_overlap_config.get("name", None) + assert ( + comm_overlap_name in _COMM_GEMM_OVERLAP_NAMES + ), f"Unrecognized comm+GEMM overlap name: {comm_overlap_name=}" + + mesh = comm_overlap_config.get("mesh", None) + tp_resource = comm_overlap_config.get("tp_resource", global_mesh_resource().tp_resource) + tp_size = get_mesh_axis_size(tp_resource, mesh=mesh) + + match comm_overlap_type: + case tex.CommOverlapType.AG: + # Extra output is all-gathered LHS copy + extra_out_shape = list(lhs_aval.shape).copy() + extra_out_shape[lhs_outer_dim] *= tp_size + extra_out_dtype = lhs_dtype + + case tex.CommOverlapType.RS: + # FP8 GEMM output for RS overlap is always FP8 + if jax_dtype_is_fp8(lhs_dtype): + assert jax_dtype_is_fp8( + out_dtype + ), "FP8 GEMM with reduce-scatter overlap requires FP8 output." + # Extra output is reduce-scattered GEMM output + extra_out_shape = list(out_shape).copy() + extra_out_shape[-2] /= tp_size + + case _: + raise RuntimeError( + f"Unrecognized comm type for comm+GEMM overlap: {comm_overlap_type=}" + ) + # Create abstract arrays for all outputs out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) out_amax_updated_aval = out_amax_aval.update( @@ -214,6 +277,7 @@ def abstract( ) pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) + extra_out_aval = jax.core.ShapedArray(shape=extra_out_shape, dtype=extra_out_dtype) workspace_aval = jax.core.ShapedArray( shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 ) @@ -224,6 +288,7 @@ def abstract( out_scale_updated_aval, pre_gelu_out_aval, bias_grad_aval, + extra_out_aval, # global LHS for AG overlap, or sharded output for RS overlap workspace_aval, ) @@ -232,10 +297,23 @@ def outer_abstract(*args, **kwargs): """ cuBlasLt GEMM outer abstract """ - (out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval, _) = ( - CollectiveGemmPrimitive.abstract(*args, **kwargs) + ( + out_aval, + out_amax_aval, + out_scale_aval, + pre_gelu_out_aval, + bias_grad_aval, + extra_out_aval, + *_, + ) = CollectiveGemmPrimitive.abstract(*args, **kwargs) + return ( + out_aval, + out_amax_aval, + out_scale_aval, + pre_gelu_out_aval, + bias_grad_aval, + extra_out_aval, ) - return out_aval, out_amax_aval, out_scale_aval, pre_gelu_out_aval, bias_grad_aval @staticmethod def lowering( @@ -257,6 +335,7 @@ def lowering( grad, accumulate, use_split_accumulator, + comm_overlap_config, ): """ Fused attention fwd lowering rules @@ -278,7 +357,7 @@ def lowering( if is_ffi_enabled(): name = "te_gemm_ffi" - return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( + ffi_args = ( ctx, lhs, lhs_scale_inv, @@ -288,6 +367,8 @@ def lowering( gelu_input, out_amax, out_scale, + ) + ffi_kwargs = dict( lhs_trans=lhs_trans, rhs_trans=rhs_trans, fuse_gelu=fuse_gelu, @@ -296,6 +377,15 @@ def lowering( accumulate=accumulate, use_split_accumulator=use_split_accumulator, ) + + if comm_overlap_config is not None: + name = "te_comm_gemm_overlap_ffi" + ffi_kwargs["comm_type"] = int(comm_overlap_config["comm_type"]) + ffi_kwargs["name"] = comm_overlap_config["name"] + + return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( + *ffi_args, **ffi_kwargs + ) else: operands = [ lhs, @@ -325,7 +415,9 @@ def lowering( workspace_size = get_cublas_workspace_size_bytes() operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype) bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) - opaque = tex.pack_gemm_descriptor( + + descriptor_packer_fn = tex.pack_gemm_decriptor + descriptor_args = ( m, n, k, @@ -342,6 +434,16 @@ def lowering( use_split_accumulator, ) + comm_overlap_type = comm_overlap_config.get("comm_type", None) + if comm_overlap_type is not None: + name = "te_comm_gemm_overlap" + descriptor_packer_fn = tex.pack_overlap_descriptor + descriptor_args += ( + comm_overlap_type, + comm_overlap_config.get("name", None), + ) + opaque = descriptor_packer_fn(*descriptor_args) + return custom_caller( CollectiveGemmPrimitive.name, args, @@ -368,6 +470,7 @@ def impl( grad, accumulate, use_split_accumulator, + comm_overlap_config, ): assert CollectiveGemmPrimitive.inner_primitive is not None @@ -430,6 +533,7 @@ def impl( out_scale_updated, pre_gelu_out, bias_grad, + extra_out, _, ) = CollectiveGemmPrimitive.inner_primitive.bind( lhs, @@ -448,6 +552,7 @@ def impl( grad=grad, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, ) # Recover batched dimensions in the output @@ -455,7 +560,7 @@ def impl( out_shape = (*lhs_batch_shape, out.shape[-2] // lhs_batch_size, out.shape[-1]) out = jax.lax.reshape(out, out_shape) - return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad, extra_out @staticmethod def batcher( @@ -470,6 +575,7 @@ def batcher( grad, accumulate, use_split_accumulator, + comm_overlap_config, ): assert CollectiveGemmPrimitive.outer_primitive is not None check_valid_batch_dims(batch_dims) @@ -500,6 +606,7 @@ def infer_sharding_from_operands( grad, accumulate, use_split_accumulator, + comm_overlap_config, mesh, arg_infos, result_infos, @@ -515,48 +622,59 @@ def infer_sharding_from_operands( (lhs.ndim, rhs.ndim), ) - # Modify operand specs: - # - If contracting dimensions of both operands are sharded, force them to match. - # - If contracting dimensions of both operands are sharded, all-gather outer dimensions. - # - If contracting dimension of only one operand is sharded, all-gather the sharded - # operand. - # - Never scatter any operand. - lhs_spec_new = list(lhs_spec).copy() - rhs_spec_new = list(rhs_spec).copy() - lhs_spec_new[lhs_outer_dim] = None - if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: - assert ( - lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim] - ), "Contracting dimensions of LHS and RHS operands must have the same sharding." - if lhs_spec_new[lhs_outer_dim] is not None: - warnings.warn( - "Outer dimension of the LHS operand must be all-gathered when both contracting " - + "dimensions are sharded. This will cause additional communication overhead." - ) + # Modify operand specs + lhs_spec_new = [spec for spec in lhs_spec] + rhs_spec_new = [spec for spec in rhs_spec] + reduce_output = False + if comm_overlap_config is None: + # When comm overlap is not enabled: + # - Always all-gather the outer dimension of LHS. + # - If contracting dims of both operands are sharded, all-gather RHS outer dim. + # - If contracting dim of only one operand is sharded, all-gather the sharded operand. + # - Never scatter any operand. + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + assert ( + lhs_spec_new[lhs_inner_dim] == rhs_spec_new[rhs_inner_dim] + ), "Contracting dimensions of LHS and RHS operands must have the same sharding." + if lhs_spec_new[lhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the LHS operand must be all-gathered when both " + + "contracting dimensions are sharded. This will cause additional " + + "communication overhead." + ) - if rhs_spec_new[rhs_outer_dim] is not None: - warnings.warn( - "Outer dimension of the RHS operand must be all-gathered when both contracting " - + "dimensions are sharded. This will cause additional communication overhead." - ) - rhs_spec_new[rhs_outer_dim] = None - else: - if lhs_spec_new[lhs_inner_dim] is None and rhs_spec_new[rhs_inner_dim] is not None: - warnings.warn( - "Contracting dimension of the RHS operand must be all-gathered when the " - + "contracting dimension of the LHS operand is unsharded. This will cause " - + "additional communication overhead." - ) - if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is None: - if not grad: - # This is expected for sequence/context-parallel gradient in BWD (DGRAD) GEMM. + if rhs_spec_new[rhs_outer_dim] is not None: + warnings.warn( + "Outer dimension of the RHS operand must be all-gathered when both " + + "contracting dimensions are sharded. This will cause additional " + + "communication overhead." + ) + rhs_spec_new[rhs_outer_dim] = None + reduce_output = True + else: + if lhs_spec_new[lhs_inner_dim] is None and rhs_spec_new[rhs_inner_dim] is not None: warnings.warn( - "Contracting dimension of the LHS operand must be all-gathered when the " - + "contracting dimension of the RHS operand is unsharded. This will cause " + "Contracting dimension of the RHS operand must be all-gathered when the " + + "contracting dimension of the LHS operand is unsharded. This will cause " + "additional communication overhead." ) - lhs_spec_new[lhs_inner_dim] = None - rhs_spec_new[rhs_inner_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is None: + if not grad: + # This is expected for sequence/context-parallel gradient in BWD (DGRAD) GEMM. + warnings.warn( + "Contracting dimension of the LHS operand must be all-gathered when " + + "the contracting dimension of the RHS operand is unsharded. This " + + "will cause additional communication overhead." + ) + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None + else: + # When comm overlap is enabled, make sure both contracting dims are unsharded if one + # of them is unsharded. + if lhs_spec_new[lhs_inner_dim] is None or rhs_spec_new[rhs_inner_dim] is None: + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None out_col_spec = rhs_spec_new[rhs_outer_dim] # Output sharding is conditional on output shape @@ -577,7 +695,50 @@ def infer_sharding_from_operands( # Bias gradient spec matches outer dimension of output if bias fusion is turned on bias_sharding = NamedSharding(mesh, PartitionSpec(out_col_spec if fuse_bias else None)) - return (out_sharding, fp8_meta_sharding, fp8_meta_sharding, gelu_sharding, bias_sharding) + # Validate operand sharding for comm+GEMM overlap and adust extra output sharding + extra_out_spec = [None] + if comm_overlap_config is not None: + mesh = comm_overlap_config.get("mesh", None) + tp_resource = comm_overlap_config.get("tp_resource", global_mesh_resource().tp_resource) + match comm_overlap_config.get("comm_type", None): + case tex.CommOverlapType.AG: + # AG overlap requires the outer dimension of LHS to be sharded + # over the TP resource + assert lhs_spec[lhs_outer_dim] == tp_resource, ( + "AG+GEMM overlap requires the outer (sequence) dimension of the LHS " + + f"operand to be sharded over the TP resource (mesh axis: {tp_resource=})." + ) + extra_out_spec = list(lhs_spec).copy() + extra_out_spec[lhs_outer_dim] = None + + case tex.CommOverlapType.RS: + # RS overlap requires the contracting dimensions of both LHS and RHS to be + # sharded over the TP resource, and the outer dimension of LHS to be unsharded + assert lhs_spec[lhs_outer_dim] is None, ( + "GEMM+RS overlap requires the outer (sequence) dimension of the LHS " + + "operand to be un-sharded." + ) + assert lhs_spec[lhs_inner_dim] == tp_resource, ( + "GEMM+RS overlap requires the contracting dimension of the LHS operand " + + f"to be sharded over the TP resource (mesh axis: {tp_resource=})." + ) + assert rhs_spec[rhs_inner_dim] == tp_resource, ( + "GEMM+RS overlap requires the contracting dimension of the RHS operand " + + f"to be sharded over the TP resource (mesh axis: {tp_resource=})." + ) + extra_out_spec = out_spec.copy() + extra_out_spec[-2] = tp_resource + + extra_out_sharding = NamedSharding(mesh, PartitionSpec(*extra_out_spec)) + + return ( + out_sharding, + fp8_meta_sharding, + fp8_meta_sharding, + gelu_sharding, + bias_sharding, + extra_out_sharding, + ) @staticmethod def partition( @@ -589,6 +750,7 @@ def partition( grad, accumulate, use_split_accumulator, + comm_overlap_config, mesh, arg_infos, result_infos, @@ -604,23 +766,31 @@ def partition( (lhs.ndim, rhs.ndim), ) - # Modify operand specs: - # - Always all-gather the outer dimension of LHS. - # - If contracting dimensions of both operands are sharded, all-gather RHS outer dimension. - # - If contracting dimension of only one operand is sharded, all-gather the sharded - # operand. - # - Never scatter any operand. - lhs_spec_new = list(lhs_spec).copy() - rhs_spec_new = list(rhs_spec).copy() + # Modify operand specs + lhs_spec_new = [spec for spec in lhs_spec] + rhs_spec_new = [spec for spec in rhs_spec] reduce_output = False - lhs_spec_new[lhs_outer_dim] = None - if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: - rhs_spec_new[rhs_outer_dim] = None - reduce_output = True + if comm_overlap_config is None: + # When comm overlap is not enabled: + # - Always all-gather the outer dimension of LHS. + # - If contracting dims of both operands are sharded, all-gather RHS outer dim. + # - If contracting dim of only one operand is sharded, all-gather the sharded operand. + # - Never scatter any operand. + lhs_spec_new[lhs_outer_dim] = None + if lhs_spec_new[lhs_inner_dim] is not None and rhs_spec_new[rhs_inner_dim] is not None: + rhs_spec_new[rhs_outer_dim] = None + reduce_output = True + else: + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None else: - lhs_spec_new[lhs_inner_dim] = None - rhs_spec_new[rhs_inner_dim] = None + # When comm overlap is enabled, make sure both contracting dims are unsharded if one + # of them is unsharded. + if lhs_spec_new[lhs_inner_dim] is None or rhs_spec_new[rhs_inner_dim] is None: + lhs_spec_new[lhs_inner_dim] = None + rhs_spec_new[rhs_inner_dim] = None out_col_spec = rhs_spec_new[rhs_outer_dim] + lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_spec_new)) @@ -642,6 +812,22 @@ def partition( gelu_spec = [None, out_col_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) + # Adjust extra output sharding for comm+GEMM overlap + extra_out_spec = [None] + if comm_overlap_config is not None: + mesh = comm_overlap_config.get("mesh", None) + tp_resource = comm_overlap_config.get("tp_resource", global_mesh_resource().tp_resource) + match comm_overlap_config.get("comm_type", None): + case tex.CommOverlapType.AG: + extra_out_spec = list(lhs_spec).copy() + extra_out_spec[lhs_outer_dim] = None + + case tex.CommOverlapType.RS: + extra_out_spec = out_spec.copy() + extra_out_spec[-2] = tp_resource + + extra_out_sharding = NamedSharding(mesh, PartitionSpec(*extra_out_spec)) + arg_shardings = ( lhs_sharding, fp8_meta_sharding, @@ -658,6 +844,7 @@ def partition( fp8_meta_sharding, gelu_sharding, bias_sharding, + extra_out_sharding, ) def sharded_impl( @@ -669,6 +856,7 @@ def sharded_impl( out_scale_updated, pre_gelu_out, bias_grad, + extra_out, ) = CollectiveGemmPrimitive.impl( lhs, lhs_scale_inv, @@ -686,6 +874,7 @@ def sharded_impl( grad=grad, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, ) # FP8 amax reduction @@ -693,12 +882,15 @@ def sharded_impl( out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) # All-reduce sum GEMM output when contracting dimensions are sharded - if reduce_output: - out = jax.lax.psum(out, global_mesh_resource().tp_resource) - if fuse_gelu: - pre_gelu_out = jax.lax.psum(pre_gelu_out, global_mesh_resource().tp_resource) + if comm_overlap_config is None: + if reduce_output: + out = jax.lax.psum(out, global_mesh_resource().tp_resource) + if fuse_gelu: + pre_gelu_out = jax.lax.psum( + pre_gelu_out, global_mesh_resource().tp_resource + ) - return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad + return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad, extra_out return mesh, sharded_impl, out_shardings, arg_shardings @@ -706,62 +898,6 @@ def sharded_impl( register_primitive(CollectiveGemmPrimitive) -def fp8_gemm_impl( - lhs: ArrayLike, - lhs_scale_inv: ArrayLike, - rhs_t: ArrayLike, - rhs_scale_inv: ArrayLike, - bias: Optional[ArrayLike] = None, - gelu_input: Optional[ArrayLike] = None, - out_amax: Optional[ArrayLike] = None, - out_scale: Optional[ArrayLike] = None, - out_dtype: jnp.dtype = jnp.bfloat16, - batched_output: bool = False, - fuse_gelu: bool = False, - fuse_bias: bool = False, - accumulate: bool = False, - use_split_accumulator: bool = False, -) -> Tuple[ArrayLike, ...]: - """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" - if out_dtype is not None and jax_dtype_is_fp8(out_dtype): - assert out_amax is not None and out_scale is not None, "Missing output amax and scale." - else: - out_amax = jnp.zeros(0, dtype=jnp.float32) - out_scale = jnp.zeros(0, dtype=jnp.float32) - - if not fuse_bias: - bias = jnp.zeros(0, dtype=jnp.bfloat16) - else: - assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." - - if not fuse_gelu: - gelu_input = jnp.zeros(0, dtype=bias.dtype) - elif gelu_input is None: - gelu_shape = (reduce(operator.mul, lhs.shape[:-1]), rhs_t.shape[-1]) - gelu_input = jnp.zeros(gelu_shape, dtype=bias.dtype) - - out, out_amax, out_scale, pre_gelu_out, _ = CollectiveGemmPrimitive.outer_primitive.bind( - lhs, - lhs_scale_inv, - rhs_t, - rhs_scale_inv, - bias, - gelu_input, - out_amax, - out_scale, - out_dtype=out_dtype, - batched_output=batched_output, - contracting_dims=(-1, -1), - fuse_gelu=fuse_gelu, - fuse_bias=fuse_bias, - grad=False, - accumulate=accumulate, - use_split_accumulator=use_split_accumulator, - ) - - return out, out_amax, out_scale, pre_gelu_out - - def gemm_impl( lhs: ArrayLike, rhs: ArrayLike, @@ -774,19 +910,19 @@ def gemm_impl( grad: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, + comm_overlap_config: Optional[dict] = None, ) -> Tuple[ArrayLike, ...]: """Non-FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) - lhs_outer_dim, rhs_outer_dim = map( - mirror_dim, - (lhs_inner_dim, rhs_inner_dim), - (lhs.ndim, rhs.ndim), - ) + lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim != lhs.ndim - 1 else lhs.ndim - 2 + rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 + out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) if not fuse_bias: bias = jnp.zeros(0, dtype=lhs.dtype) elif grad: - bias = jnp.zeros(rhs.shape[rhs_outer_dim], dtype=lhs.dtype) + bias = jnp.zeros(out_shape[-1], dtype=lhs.dtype) else: assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." @@ -797,13 +933,16 @@ def gemm_impl( gelu_input is not None ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." elif gelu_input is None: - bdims = [dim for dim in range(lhs.ndim) if dim not in [lhs_inner_dim, lhs_outer_dim]] - batch_size = reduce(operator.mul, [lhs.shape[dim] for dim in bdims], 1) - gelu_shape = (batch_size * lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) - gelu_input = jnp.zeros(gelu_shape, dtype=lhs.dtypes) - - dummy_fp8_meta = jnp.zeros(0, dtype=jnp.float32) - out, _, _, pre_gelu_out, bias_grad = CollectiveGemmPrimitive.outer_primitive.bind( + gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + + ( + out, + _, # out_amax in FP8 GEMM + _, # out_scale in FP8 GEMM + pre_gelu_out, + bias_grad, + extra_out, + ) = CollectiveGemmPrimitive.outer_primitive.bind( lhs, dummy_fp8_meta, rhs, @@ -820,9 +959,156 @@ def gemm_impl( grad=grad, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, ) if grad: - return out, pre_gelu_out, bias_grad + return out, pre_gelu_out, bias_grad, extra_out + else: + return out, pre_gelu_out, extra_out + + +def fp8_gemm_impl( + lhs: ArrayLike, + lhs_scale_inv: ArrayLike, + rhs_t: ArrayLike, + rhs_scale_inv: ArrayLike, + bias: Optional[ArrayLike] = None, + gelu_input: Optional[ArrayLike] = None, + out_amax: Optional[ArrayLike] = None, + out_scale: Optional[ArrayLike] = None, + out_dtype: jnp.dtype = jnp.bfloat16, + batched_output: bool = False, + fuse_gelu: bool = False, + fuse_bias: bool = False, + accumulate: bool = False, + use_split_accumulator: bool = False, + comm_overlap_config: Optional[dict] = None, +) -> Tuple[ArrayLike, ...]: + """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" + if out_dtype is not None and jax_dtype_is_fp8(out_dtype): + assert out_amax is not None and out_scale is not None, "Missing output amax and scale." else: - return out, pre_gelu_out + out_amax = jnp.zeros(0, dtype=jnp.float32) + out_scale = jnp.zeros(0, dtype=jnp.float32) + + if not fuse_bias: + bias = jnp.zeros(0, dtype=jnp.bfloat16) + else: + assert bias is not None, "Missing bias in forward GEMM when bias epilogue is enabled." + + if not fuse_gelu: + gelu_input = jnp.zeros(0, dtype=bias.dtype) + elif gelu_input is None: + gelu_shape = (reduce(operator.mul, lhs.shape[:-1]), rhs_t.shape[-1]) + gelu_input = jnp.zeros(gelu_shape, dtype=bias.dtype) + + (out, out_amax, out_scale, pre_gelu_out, _, extra_out) = ( # bias_grad in non-FP8 GEMM + CollectiveGemmPrimitive.outer_primitive.bind( + rhs_t, + rhs_scale_inv, + lhs, + lhs_scale_inv, + bias, + gelu_input, + out_amax, + out_scale, + out_dtype=out_dtype, + batched_output=batched_output, + contracting_dims=(-1, -1), + fuse_gelu=fuse_gelu, + fuse_bias=fuse_bias, + grad=False, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + ) + ) + + return out, out_amax, out_scale, pre_gelu_out, extra_out + + +class CopyIntoOverlapBufferPrimitive(BasePrimitive): + """ + Copy JAX array data into comm+GEMM overlap buffer + """ + + name = "te_copy_into_overlap_buffer" + impl_static_args = (1, 2) + multiple_results = False + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(inp_aval, name, comm_type): + assert name in _COMM_GEMM_OVERLAP_NAMES, f"Unrecognized comm+GEMM overlap name: {name=}" + assert comm_type in [ + tex.CommOverlapType.AG, + tex.CommOverlapType.RS, + ], "Invalid comm+GEMM overlap type." + assert inp_aval.size > 0, "Cannot copy a zero-size array into overlap buffer." + assert inp_aval.ndim == 2, "Cannot copy more than 2 dimensions!" + return jax.core.ShapedArray(shape=(0,), dtype=dtypes.canonicalize_dtype(inp_aval.dtype)) + + @staticmethod + def lowering(ctx, inp, *, name, comm_type): + if is_ffi_enabled(): + name = "te_copy_into_overlap_buffer_ffi" + return ffi.ffi_lowering(name)( + ctx, + inp, + name=name, + comm_type=int(comm_type), + ) + else: + operands = [inp] + operand_shapes = [ir.RankedTensorType(inp.type).shape] + out_types = [] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = tex.pack_buffer_descriptor( + name, inp.shape, jax_dtype_to_te_dtype(inp.dtype), comm_type + ) + return custom_caller(CopyIntoOverlapBufferPrimitive.name, args, opaque, False) + + @staticmethod + def impl(inp, name, comm_type): + assert CopyIntoOverlapBufferPrimitive.inner_primitive is not None + return CopyIntoOverlapBufferPrimitive.inner_primitive.bind( + inp, name=name, comm_type=comm_type + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, name, comm_type): + assert CopyIntoOverlapBufferPrimitive.inner_primitive is not None + check_valid_batch_dims(batch_dims) + return ( + CopyIntoOverlapBufferPrimitive.inner_primitive.bind( + *batched_args, name=name, comm_type=comm_type + ), + None, + ) + + @staticmethod + def infer_sharding_from_operands(name, comm_type, mesh, arg_infos, result_infos): + del name, comm_type, arg_infos, result_infos + return NamedSharding(mesh, PartitionSpec(None)) + + @staticmethod + def partition(name, comm_type, mesh, arg_infos, result_infos): + del name, comm_type, result_infos + inp_spec = arg_infos[0] + arg_shardings = (NamedSharding(mesh, PartitionSpec(*inp_spec)),) + out_sharding = NamedSharding(mesh, PartitionSpec(None)) + return ( + mesh, + partial(CopyIntoOverlapBufferPrimitive.impl, name=name, comm_type=comm_type), + out_sharding, + arg_shardings, + ) + + +register_primitive(CopyIntoOverlapBufferPrimitive) + + +def copy_into_overlap_buffer(inp: ArrayLike, name: str, comm_type: tex.CommOverlapType) -> None: + _ = CollectiveGemmPrimitive.outer_primitive.bind(inp, name=name, comm_type=comm_type) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index afac283a6f..d123d9b5b4 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -4,8 +4,8 @@ * See LICENSE for license information. ************************************************************************/ -#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ -#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ +#ifndef TRANSFORMER_ENGINE_JAX_CSRC_EXTENSIONS_H_ +#define TRANSFORMER_ENGINE_JAX_CSRC_EXTENSIONS_H_ #include #include @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -148,7 +149,6 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( bool deterministic, int64_t window_size_left, int64_t window_size_right); struct CustomCallGemmDescriptor { - size_t batch; size_t m; size_t k; size_t n; @@ -165,13 +165,50 @@ struct CustomCallGemmDescriptor { bool use_split_accumulator; }; -pybind11::bytes PackCustomCallGemmDescriptor(size_t batch, size_t m, size_t n, size_t k, - size_t workspace_size, DType operand_dtype, - DType out_dtype, DType bias_dtype, bool lhs_trans, - bool rhs_trans, bool fuse_gelu, bool fuse_bias, - bool grad, bool accumulate, +pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, size_t workspace_size, + DType operand_dtype, DType out_dtype, DType bias_dtype, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator); +struct CustomCallBufferDescriptor { + const std::string name; + const size_t *shape; + const size_t ndim; + DType dtype; + CommOverlapType comm_type; +}; + +pybind11::bytes PackCustomCallBufferDescriptor(const std::string &name, + const std::vector &shape, DType dtype, + CommOverlapType comm_type); + +struct CustomCallOverlapDescriptor { + size_t m; + size_t k; + size_t n; + size_t workspace_size; + DType operand_dtype; + DType bias_dtype; + DType out_dtype; + bool lhs_trans; + bool rhs_trans; + bool fuse_gelu; + bool fuse_bias; + bool grad; + bool accumulate; + bool use_split_accumulator; + CommOverlapType comm_type; + const std::string name; +}; + +pybind11::bytes PackCustomCallOverlapDescriptor(size_t m, size_t k, size_t n, size_t workspace_size, + DType operand_dtype, DType bias_dtype, + DType out_dtype, bool lhs_trans, bool rhs_trans, + bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, const std::string &name); + // Transpose void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -341,13 +378,52 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, Result_Type out_amax_updated, Result_Type out_scale_updated, - Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, - bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, - bool accumulate, bool use_split_accumulator); + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type dummy_out, + Result_Type workspace, bool lhs_trans, bool rhs_trans, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); +// Comm+GEMM Overlap + +void BootstrapCommGemmOverlap(const std::string &name, const std::string &method, + const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, int tp_size, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, + int set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm); + +void DestroyCommGemmOverlap(const std::string &name); + +void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, + bool grad = false); + +bool OverlapBufferIsFp8(const std::string &name); + +pybind11::object GetOverlapBuffer(const std::string &name, CommOverlapType comm_type); + +void CopyIntoOverlapBuffer(cudaStream_t, void **buffers, const char *opaque, size_t opaque_len); + +Error_Type CopyIntoOverlapBufferFFI(cudaStream_t stream, Buffer_Type input, std::string_view name, + int32_t comm_type_flag); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(CopyIntoOverlapBufferHandler); + +void CommGemmOverlap(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); + +Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, + Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Buffer_Type out_amax, Buffer_Type out_scale, + Result_Type out, Result_Type out_amax_new, Result_Type out_scale_new, + Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type extra_out, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator, int32_t comm_type_flag, + std::string_view name); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(CommGemmOverlapHandler); + } // namespace jax } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ +#endif // TRANSFORMER_ENGINE_JAX_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp new file mode 100644 index 0000000000..df1f4bdc23 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp @@ -0,0 +1,291 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common/util/dlpack_helper.h" +#include "extensions.h" + +void _dummy_allgather(void *global, size_t globalbytes, void *local, size_t localbytes, + ExtComm comm) {}; + +void _dummy_barrier(ExtComm comm) {}; + +namespace transformer_engine { + +namespace jax { + +static std::unordered_map _overlaps; + +void BootstrapCommGemmOverlap(const std::string &name, const std::string &method, + const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, int tp_size, int num_splits, + int num_max_streams, int comm_cga_size, int num_comm_sm, + int set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm) { +#ifndef NVTE_UB_WITH_MPI + NVTE_ERROR( + std::string("Comm+GEMM overlap in TE/JAX requires bootstrapping Userbuffers with MPI. ") + + std::string("Please compile TE with `NVTE_UB_WITH_MPI=1`.")); +#endif + + // Initialize overlap object -- this allocates the comm buffer + NVTE_CHECK(_overlaps.find(name) == _overlaps.end(), name, " is already initialized!"); + if (method == "ring-exchange") { + _overlaps[name] = reinterpret_cast(new CommOverlapP2PBase( + buffer_shape, buffer_dtype, -1, -1, -1, -1, -1, -1, tp_size, &_dummy_allgather, + &_dummy_barrier, comm_type, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, + use_ce, atomic_gemm, aggregate)); + } else { + _overlaps[name] = reinterpret_cast(new CommOverlapBase( + buffer_shape, buffer_dtype, -1, -1, -1, -1, -1, -1, tp_size, &_dummy_allgather, + &_dummy_barrier, num_splits, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, + atomic_gemm, pipeline_rs_overlap_first_gemm)); + } +}; + +void DestroyCommGemmOverlap(const std::string &name) { + auto overlap = _overlaps.find(name); + if (overlap != _overlaps.end()) { + delete overlap->second; + _overlaps.erase(overlap); + } +}; + +void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, bool grad) { + auto scale_inv_tensor = DLPackWrapper(scale_inv, grad); + _overlaps[name]->set_ubuf_scale_inv(reinterpret_cast(scale_inv_tensor.dptr())); +} + +bool OverlapBufferIsFp8(const std::string &name) { return _overlaps[name]->is_fp8_ubuf(); } + +pybind11::object GetOverlapBuffer(const std::string &name, CommOverlapType comm_type) { + DLPackWrapper output = std::move(_overlaps[name]->get_ubuf_output(comm_type)); + auto capsule = output.capsule(); + return capsule; +}; + +void CopyIntoOverlapBufferImpl(cudaStream_t stream, void *input_ptr, + const std::vector &shape, DType dtype, + const std::string &name, CommOverlapType comm_type) { + auto input = TensorWrapper(input_ptr, shape, dtype); + _overlaps[name]->copy_into_ubuf(stream, input, comm_type); +} + +void CopyIntoOverlapBuffer(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len) { + auto input_ptr = buffers[0]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + + CopyIntoOverlapBufferImpl(stream, input_ptr, + std::vector(desc.shape, desc.shape + desc.ndim), desc.dtype, + desc.name, desc.comm_type); +} + +Error_Type CopyIntoOverlapBufferFFI(cudaStream_t stream, Buffer_Type input, std::string_view name, + int32_t comm_type_flag) { + auto input_ptr = input.untyped_data(); + auto shape = std::vector(input.dimensions().begin(), input.dimensions().end()); + auto dtype = convert_ffi_datatype_to_te_dtype(input.element_type()); + + CopyIntoOverlapBufferImpl(stream, input_ptr, shape, dtype, static_cast(name), + static_cast(comm_type_flag)); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CopyIntoOverlapBufferHandler, CopyIntoOverlapBufferFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Attr("name") + .Attr("comm_type_flag"), + FFI_CudaGraph_Traits); + +void CommGemmOverlapImpl(void *lhs, const std::vector &lhs_shape, DType lhs_dtype, + float *lhs_scale_inv, bool lhs_trans, void *rhs, + const std::vector &rhs_shape, DType rhs_dtype, + float *rhs_scale_inv, bool rhs_trans, void *out, + const std::vector &out_shape, DType out_dtype, float *out_amax, + float *out_scale, void *bias, DType bias_dtype, void *pre_gelu_out, + void *extra_out, const std::vector &extra_out_shape, + void *workspace, size_t workspace_size, bool fuse_gelu, bool fuse_bias, + bool grad, bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, const std::string &name, cudaStream_t stream) { + auto lhs_ = TensorWrapper(lhs, lhs_shape, lhs_dtype, nullptr, nullptr, lhs_scale_inv); + auto rhs_ = TensorWrapper(rhs, rhs_shape, rhs_dtype, nullptr, nullptr, rhs_scale_inv); + auto out_ = TensorWrapper(out, out_shape, out_dtype, out_amax, out_scale, nullptr); + + auto bias_ptr = (fuse_bias) ? bias : nullptr; + auto bias_shape = (fuse_bias) ? std::vector(out_shape.back()) : std::vector{0}; + auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + + auto pre_gelu_ptr = (fuse_gelu) ? pre_gelu_out : nullptr; + auto pre_gelu_shape = (fuse_gelu) ? out_shape : std::vector{0}; + auto pre_gelu_out_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, bias_dtype); + + auto workspace_ = TensorWrapper(workspace, std::vector{workspace_size}, DType::kByte); + + auto extra_out_ = + TensorWrapper(extra_out, extra_out_shape, lhs_dtype, nullptr, nullptr, lhs_scale_inv); + + auto overlap = _overlaps[name]; + if (comm_type == CommOverlapType::AG) { + // AG overlap is only ring-exchange + if (overlap->is_atomic_gemm()) { + overlap->atomic_gemm_overlap_ag(rhs_, rhs_trans, lhs_, lhs_trans, out_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + extra_out_, stream); + } else { + overlap->split_overlap_ag(rhs_, rhs_trans, lhs_, lhs_trans, out_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, extra_out_, + stream); + } + } else if (comm_type == CommOverlapType::RS) { + if (overlap->is_atomic_gemm()) { + overlap->atomic_gemm_overlap_rs(rhs_, rhs_trans, lhs_, lhs_trans, out_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, + extra_out_, stream); + } else { + overlap->split_overlap_rs(rhs_, rhs_trans, lhs_, lhs_trans, out_, bias_, pre_gelu_out_, + workspace_, grad, accumulate, use_split_accumulator, extra_out_, + stream); + } + } +} + +void CommGemmOverlap(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { + // Inputs + auto lhs = buffers[0]; + auto lhs_scale_inv = reinterpret_cast(buffers[1]); + auto rhs = buffers[2]; + auto rhs_scale_inv = reinterpret_cast(buffers[3]); + auto bias = buffers[4]; + auto gelu_input = buffers[5]; + auto out_amax = reinterpret_cast(buffers[6]); + auto out_scale = reinterpret_cast(buffers[7]); + + // Outputs + auto out = buffers[8]; + auto out_amax_new = reinterpret_cast(buffers[9]); + auto out_scale_new = reinterpret_cast(buffers[10]); + auto pre_gelu_out = buffers[11]; + auto bias_grad = buffers[12]; + auto extra_out = buffers[13]; + auto workspace = buffers[14]; + + // Check operand-output aliases + NVTE_CHECK(bias == bias_grad, "bias not bound to bias_grad in AG+GEMM overlap."); + NVTE_CHECK(gelu_input == pre_gelu_out, + "gelu_input not bound to pre_gelu_out in AG+GEMM overlap."); + NVTE_CHECK(out_amax == out_amax_new, "out_amax not bound to out_amax_new in AG+GEMM overlap."); + NVTE_CHECK(out_scale == out_scale_new, + "out_scale not bound to out_scale_new in AG+GEMM overlap."); + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + + auto lhs_shape = + (desc.lhs_trans) ? std::vector{desc.k, desc.m} : std::vector{desc.m, desc.k}; + auto rhs_shape = + (desc.rhs_trans) ? std::vector{desc.n, desc.k} : std::vector{desc.k, desc.n}; + auto out_shape = std::vector{desc.m, desc.n}; + + CommGemmOverlapImpl(lhs, lhs_shape, desc.operand_dtype, lhs_scale_inv, desc.lhs_trans, rhs, + rhs_shape, desc.operand_dtype, rhs_scale_inv, desc.rhs_trans, out, out_shape, + desc.out_dtype, out_amax, out_scale, bias, desc.bias_dtype, pre_gelu_out, + extra_out, lhs_shape, workspace, desc.workspace_size, desc.fuse_gelu, + desc.fuse_bias, desc.grad, desc.accumulate, desc.use_split_accumulator, + desc.comm_type, desc.name, stream); +} + +Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, + Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Buffer_Type out_amax, Buffer_Type out_scale, + Result_Type out, Result_Type out_amax_new, Result_Type out_scale_new, + Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type extra_out, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator, int32_t comm_type_flag, + std::string_view name) { + // Inputs + auto lhs_ptr = lhs.untyped_data(); + auto lhs_shape = std::vector(lhs.dimensions().begin(), lhs.dimensions().end()); + auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); + auto rhs_ptr = rhs.untyped_data(); + auto rhs_shape = std::vector(rhs.dimensions().begin(), rhs.dimensions().end()); + auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs.element_type()); + auto rhs_scale_inv_ptr = reinterpret_cast(rhs_scale_inv.untyped_data()); + auto bias_ptr = bias.untyped_data(); + auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); + auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); + auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + + // Outputs + auto out_ptr = out->untyped_data(); + auto out_shape = std::vector(out->dimensions().begin(), out->dimensions().end()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type()); + auto out_amax_new_ptr = reinterpret_cast(out_amax_new->untyped_data()); + auto out_scale_new_ptr = reinterpret_cast(out_scale_new->untyped_data()); + auto pre_gelu_ptr = pre_gelu_out->untyped_data(); + auto bias_grad_ptr = bias_grad->untyped_data(); + auto extra_out_ptr = extra_out->untyped_data(); + auto extra_out_shape = + std::vector(extra_out->dimensions().begin(), extra_out->dimensions().end()); + auto workspace_ptr = workspace->untyped_data(); + auto workspace_size = workspace->element_count(); + + // Check operand-output aliases + NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in AG+GEMM overlap."); + NVTE_CHECK(gelu_input_ptr == pre_gelu_ptr, + "gelu_input not bound to pre_gelu_out in AG+GEMM overlap."); + NVTE_CHECK(out_amax_ptr == out_amax_new_ptr, + "out_amax not bound to out_amax_new in AG+GEMM overlap."); + NVTE_CHECK(out_scale_ptr == out_scale_new_ptr, + "out_scale not bound to out_scale_new in AG+GEMM overlap."); + + CommGemmOverlapImpl( + lhs_ptr, lhs_shape, lhs_dtype, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, rhs_dtype, + rhs_scale_inv_ptr, rhs_trans, out_ptr, out_shape, out_dtype, out_amax_ptr, out_scale_ptr, + bias_ptr, bias_dtype, pre_gelu_ptr, extra_out_ptr, extra_out_shape, workspace_ptr, + workspace_size, fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator, + static_cast(comm_type_flag), static_cast(name), stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CommGemmOverlapHandler, CommGemmOverlapFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // out_amax + .Arg() // out_scale + .Ret() // out + .Ret() // out_amax_new + .Ret() // out_scale_new + .Ret() // pre_gelu_out + .Ret() // bias_grad + .Ret() // extra_out + .Ret() // workspace + .Attr("lhs_trans") + .Attr("rhs_trans") + .Attr("fuse_gelu") + .Attr("fuse_bias") + .Attr("grad") + .Attr("accumulate") + .Attr("use_split_accumulator") + .Attr("comm_type_flag") + .Attr("name"), + FFI_CudaGraph_Traits); + +} // namespace jax + +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 5dae9d6757..14148ecbd0 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -62,7 +62,8 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque auto *out_scale_updated = reinterpret_cast(buffers[10]); auto *pre_gelu_out = buffers[11]; auto *bias_grad = buffers[12]; - auto *workspace = buffers[13]; + // buffers[13] is the extra output for comm+GEMM overlap, not used here + auto *workspace = buffers[14]; // Operand aliasing NVTE_CHECK(bias == bias_grad, "bias not bound to bias_grad in TE/JAX GEMM"); @@ -88,9 +89,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, Result_Type out_amax_updated, Result_Type out_scale_updated, - Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type workspace, - bool lhs_trans, bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, - bool accumulate, bool use_split_accumulator) { + Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type dummy_out, + Result_Type workspace, bool lhs_trans, bool rhs_trans, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) { // Inputs auto lhs_ptr = lhs.untyped_data(); auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); @@ -110,6 +111,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type()); auto pre_gelu_out_ptr = pre_gelu_out->untyped_data(); auto bias_grad_ptr = bias_grad->untyped_data(); + // dummy_out is the extra output for comm+GEMM overlap, not used here auto workspace_ptr = workspace->untyped_data(); auto workspace_size = workspace->dimensions().back(); @@ -151,6 +153,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Ret() // out_scale_updated .Ret() // pre_gelu_out .Ret() // bias_grad + .Ret() // dummy_out .Ret() // workspace .Attr("lhs_trans") .Attr("rhs_trans") diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 1a9ce987af..31a53529e3 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -80,15 +80,33 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( deterministic, window_size_left, window_size_right}); } -pybind11::bytes PackCustomCallGemmDescriptor(size_t batch, size_t m, size_t n, size_t k, - size_t workspace_size, DType operand_dtype, - DType bias_dtype, DType out_dtype, bool lhs_trans, - bool rhs_trans, bool fuse_gelu, bool fuse_bias, - bool grad, bool accumulate, +pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, size_t workspace_size, + DType operand_dtype, DType bias_dtype, DType out_dtype, + bool lhs_trans, bool rhs_trans, bool fuse_gelu, + bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) { - return PackOpaque(CustomCallGemmDescriptor{batch, m, n, k, workspace_size, operand_dtype, - bias_dtype, out_dtype, lhs_trans, rhs_trans, fuse_gelu, - fuse_bias, grad, accumulate, use_split_accumulator}); + return PackOpaque(CustomCallGemmDescriptor{m, n, k, workspace_size, operand_dtype, bias_dtype, + out_dtype, lhs_trans, rhs_trans, fuse_gelu, fuse_bias, + grad, accumulate, use_split_accumulator}); +} + +pybind11::bytes PackCustomCallBufferDescriptor(const std::string &name, + const std::vector &shape, DType dtype, + CommOverlapType comm_type) { + return PackOpaque( + {name, shape.data(), shape.size(), dtype, comm_type}); +} + +pybind11::bytes PackCustomCallOverlapDescriptor(size_t m, size_t k, size_t n, size_t workspace_size, + DType operand_dtype, DType bias_dtype, + DType out_dtype, bool lhs_trans, bool rhs_trans, + bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator, + CommOverlapType comm_type, + const std::string &name) { + return PackOpaque( + {m, n, k, workspace_size, operand_dtype, bias_dtype, out_dtype, lhs_trans, rhs_trans, + fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator, comm_type, name}); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index ddf98d9d78..2bf13a600d 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -8,6 +8,7 @@ #include "extensions.h" namespace transformer_engine { + namespace jax { template @@ -53,6 +54,8 @@ pybind11::dict Registrations() { dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); dict["te_gemm"] = EncapsulateFunction(Gemm); + dict["te_copy_into_overlap_buffer"] = EncapsulateFunction(CopyIntoOverlapBuffer); + dict["te_comm_gemm_overlap"] = EncapsulateFunction(CommGemmOverlap); // Transpose dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); @@ -104,6 +107,8 @@ pybind11::dict Registrations() { dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi; dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); + dict["te_copy_into_overlap_buffer_ffi"] = EncapsulateFFI(CopyIntoOverlapBufferHandler); + dict["te_comm_gemm_overlap_ffi"] = EncapsulateFFI(CommGemmOverlapHandler); return dict; } @@ -120,6 +125,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor); + m.def("pack_buffer_descriptor", &PackCustomCallBufferDescriptor); + m.def("pack_overlap_descriptor", &PackCustomCallOverlapDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); @@ -132,7 +139,14 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); + m.def("bootstrap_comm_gemm_overlap", &BootstrapCommGemmOverlap); + m.def("destroy_comm_gemm_overlaps", &DestroyCommGemmOverlap); + m.def("set_buffer_scale_inv", &SetOverlapBufferScaleInverse, pybind11::arg(), pybind11::arg(), + pybind11::arg("grad") = false); + m.def("get_overlap_buffer", &GetOverlapBuffer); + m.def("overlap_buffer_is_fp8", &OverlapBufferIsFp8); } } // namespace jax + } // namespace transformer_engine diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 4cf09a204f..e463f0ace2 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -1,15 +1,18 @@ # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +import os +import warnings import operator from functools import partial, reduce -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Sequence import jax import jax.numpy as jnp from jax.typing import ArrayLike from jax.sharding import NamedSharding, PartitionSpec +from transformer_engine import transformer_engine_jax as tex from .fp8 import FP8Helper, FP8MetaPackage from .cpp_extensions import ( gemm_impl, @@ -19,15 +22,23 @@ dbias_cast_transpose, dact_lu_dbias_cast_transpose, ) -from .cpp_extensions.gemm import sanitize_dims, mirror_dim + +from .cpp_extensions.gemm import sanitize_dims, mirror_dim, copy_into_overlap_buffer +from .cpp_extensions.misc import jax_dtype_is_fp8, jax_dtype_to_te_dtype +from .sharding import get_mesh_axis_size, global_mesh_resource __all__ = [ "gemm", "fp8_gemm", "type_safe_gemm", + "initialize_comm_gemm_overlaps", + "destroy_comm_gemm_overlap", ] +_NUM_MAX_UB_STREAMS = 3 +_ACTIVE_COMM_GEMM_OVERLAPS = dict() + def gemm( x: ArrayLike, @@ -37,12 +48,70 @@ def gemm( fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, + comm_overlap_name: Optional[str] = None, + ag_overlap_skip_copy: bool = False, ) -> ArrayLike: - """Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions.""" - return _gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + """ + Non-FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions. + + Parameters + ---------- + x : ArrayLike + LHS operand, sized ([B], M, K) when not transposed. + kernel : ArrayLike + RHS operand, sized (K, N) when not transposed. + bias : Optional[ArrayLike], default = `None` + Optional bias term to add onto the (LHS x RHS) result. + contracting_dims : Tuple[int, int], default = `(-1, 0)` + Contracting dimensions of LHS and RHS, respectively, in the matrix-multiplication. + The default (-1, 0) describes the fully non-transposed 'NN' layout where LHS contracts in + the last dimension, and RHS contracts in the first dimension. + fuse_gelu : bool, default = `False` + Enable the GELU epilogue for GEMM. This applies GELU after the bias-addition if the bias + term is not `None`. + accumulate : bool, default = `False` + use_split_accumulator : bool, default = `False` + comm_overlap_name : Optional[str], default = `None` + Name of the comm+GEMM overlap layer that this GEMM is associated with. Comm+GEMM overlap + must be initialized with `te.jax.gemm.initialize_comm_gemm_overlaps()` before this + GEMM call, and the configuration dictionary used in the initialization must include + the name passed into this function. + ag_overlap_skip_copy: bool = `False` + All-gather overlap requires the LHS operand to be copied into the communication buffer. + If the communication buffer already has the necessary data, setting this flag will + avoid an unnecessary memcpy operation. + """ + comm_overlap_config = None + if comm_overlap_name is not None: + comm_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(comm_overlap_name, None) + if comm_overlap_config is None: + warnings.warn( + f"Comm+GEMM overlap for {comm_overlap_name} has not been initialized! " + + "Sharded operands will trigger XLA collectives instead." + ) + + elif ( + not ag_overlap_skip_copy + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.AG + ): + if sanitize_dims(contracting_dims[0], x.ndim) != x.ndim - 1: + x = jnp.matrix_transpose(x) + copy_into_overlap_buffer(x, comm_overlap_name, tex.CommOverlapType.RS) + + return _gemm( + x, + kernel, + bias, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_config, + ) -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) def _gemm( x: ArrayLike, kernel: ArrayLike, @@ -51,9 +120,17 @@ def _gemm( fuse_gelu: bool, accumulate: bool, use_split_accumulator: bool, + comm_overlap_config: dict, ) -> ArrayLike: out, _ = _gemm_fwd_rule( - x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator + x, + kernel, + bias, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_config, ) return out @@ -66,6 +143,7 @@ def _gemm_fwd_rule( fuse_gelu: bool, accumulate: bool, use_split_accumulator: bool, + comm_overlap_config: dict, ) -> Tuple[ArrayLike, ...]: assert ( kernel.ndim == 2 @@ -78,7 +156,7 @@ def _gemm_fwd_rule( # # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) - out, pre_gelu_out = gemm_impl( + out, pre_gelu_out, extra_out = gemm_impl( x, kernel, bias=bias, @@ -88,16 +166,29 @@ def _gemm_fwd_rule( fuse_bias=fuse_bias, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, ) + # Update returned and saved tensors based on comm+GEMM overlap + saved_x = x + final_out = out + if comm_overlap_config is not None: + match comm_overlap_config.get("comm_type", None): + case tex.CommOverlapType.AG: + # AG overlap puts the all-gathered global LHS (X) into extra_out + saved_x = extra_out + case tex.CommOverlapType.RS: + # RS overlap puts the reduce-scattered sharded output into extra_out + final_out = extra_out + ctx = ( - x, + saved_x, kernel, pre_gelu_out if fuse_gelu else None, fuse_bias, ) - return out, ctx + return final_out, ctx def _gemm_bwd_rule( @@ -105,6 +196,7 @@ def _gemm_bwd_rule( fuse_gelu, accumulate, use_split_accumulator, + comm_overlap_config, ctx, grad, ): @@ -114,6 +206,11 @@ def _gemm_bwd_rule( mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) ) + dgrad_overlap_config = None + if comm_overlap_config is not None: + dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" + dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) + # FWD MODE: # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) @@ -127,7 +224,7 @@ def _gemm_bwd_rule( # # GEMM+AR: ([B], M, N) x (K/P, N)^T ------> ([B], M, K/P) # (DP, None, None) x (TP, None)^T --> (DP, None, TP) - dgrad, dgelu, _ = gemm_impl( + dgrad, dgelu, _, dgrad_extra_out = gemm_impl( grad, kernel, gelu_input=pre_gelu_out, @@ -138,17 +235,42 @@ def _gemm_bwd_rule( grad=True, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=dgrad_overlap_config, ) + # If dgrad overlapped reduce-scatter, set it to the RS output + if dgrad_overlap_config is not None: + if ( + dgrad_overlap_config["method"] != "bulk" + and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + dgrad = dgrad_extra_out + + # Collapse batch dimension for wgrad + wgrad_rhs = dgelu if fuse_gelu else grad + if x.ndim > 2: + # If x was originally transposed, we need to transpose it back in order to collapse + # the batch dims correctly. + if x_inner_dim == x.ndim - 2: + x = jnp.matrix_transpose(x) + batch_size = reduce(operator.mul, x.shape[:-2], 1) + x = jnp.reshape(x, (batch_size * x.shape[-2], x.shape[-1])) + wgrad_rhs = jnp.reshape(wgrad_rhs, (batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1])) + + # Recover comm+GEMM overlap config for wgrad + wgrad_overlap_config = None + if comm_overlap_config is not None: + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + # WGRAD: # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # (DP, 'tp', None)^T --(AG)-->(DP, None, None)^T x (DP, None, 'tp') --> (None, 'tp') # # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K/P, N) # (DP, None, 'tp')^T --(AG)--> (DP, None, None)^T x (DP, None, None) ----> (None, None) - # Make XLA scatter output in first dim. wgrad_rhs = dgelu if fuse_gelu else grad - wgrad, _, bgrad = gemm_impl( + wgrad, _, bgrad, wgrad_extra_out = gemm_impl( x, wgrad_rhs, gelu_input=pre_gelu_out, @@ -159,8 +281,17 @@ def _gemm_bwd_rule( grad=True, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=wgrad_overlap_config, ) + # If wgrad overlapped reduce-scatter, set it to the RS output + if wgrad_overlap_config is not None: + if ( + wgrad_overlap_config["method"] != "bulk" + and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + wgrad = wgrad_extra_out + if not fuse_bias: bgrad = None @@ -179,8 +310,60 @@ def fp8_gemm( fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, + comm_overlap_name: Optional[str] = None, + ag_overlap_skip_copy: bool = False, ) -> ArrayLike: - """Non-FP8 `nvte_cublas_gemm()` with optional GELU and bias-add fusions.""" + """ + FP8 collective/distributed `nvte_cublas_gemm()` with GELU and bias-add fusions. + + FP8 GEMM requires the LHS operand to be non-transposed, and the RHS operand to be transposed, + such that the contracting dimensions are always the last dimension for both operands. + + Parameters + ---------- + x : ArrayLike + Non-transposed LHS operand, sized ([B], M, K). + kernel_t : ArrayLike + Transposed RHS operand, sized (N, K). + fp8_meta : transformer_engine.jax.fp8.FP8MetaPackage + FP8MetaPackage object carrying amax, scale and scale_inv information for the GEMM operands. + bias : Optional[ArrayLike], default = `None` + Optional bias term to add onto the (LHS x RHS) result. + out_dtype : jnp.dtype, default = `jnp.bfloat16` + Data type of the FP8 GEMM output. If chosen as an FP8 dtype (i.e. `jnp.float8_e4m3fn` or + `jnp.float8_e5m2`), the `fp8_meta` must also contain amax and scale information for the + GEMM output. + fuse_gelu : bool, default = `False` + Enable the GELU epilogue for GEMM. This applies GELU after the bias-addition if the bias + term is not `None`. + accumulate : bool, default = `False` + use_split_accumulator : bool, default = `False` + comm_overlap_name : Optional[str], default = `None` + Name of the comm+GEMM overlap layer that this GEMM is associated with. Comm+GEMM overlap + must be initialized with `te.jax.gemm.initialize_comm_gemm_overlaps()` before this + GEMM call, and the configuration dictionary used in the initialization must include + the name passed into this function. + ag_overlap_skip_copy: bool = `False` + All-gather overlap requires the LHS operand to be copied into the communication buffer. + If the communication buffer already has the necessary data, setting this flag will + avoid an unnecessary memcpy operation. + """ + comm_overlap_config = None + if comm_overlap_name is not None: + comm_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(comm_overlap_name, None) + if comm_overlap_config is None: + warnings.warn( + f"Comm+GEMM overlap for {comm_overlap_name} has not been initialized! " + + "Sharded operands will trigger XLA collectives instead." + ) + + elif ( + not ag_overlap_skip_copy + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.AG + ): + copy_into_overlap_buffer(x, comm_overlap_name, tex.CommOverlapType.RS) + return _fp8_gemm( x, kernel_t, @@ -191,6 +374,7 @@ def fp8_gemm( fuse_gelu, accumulate, use_split_accumulator, + comm_overlap_config, ) @@ -205,6 +389,7 @@ def _fp8_gemm( fuse_gelu: bool, accumulate: bool, use_split_accumulator: bool, + comm_overlap_config: dict, ) -> ArrayLike: out, _ = _fp8_gemm_fwd_rule( x, @@ -216,6 +401,7 @@ def _fp8_gemm( fuse_gelu, accumulate, use_split_accumulator, + comm_overlap_config, ) return out @@ -230,6 +416,7 @@ def _fp8_gemm_fwd_rule( fuse_gelu: bool, accumulate: bool, use_split_accumulator: bool, + comm_overlap_config: dict, ) -> Tuple[ArrayLike, ...]: assert ( kernel_t.ndim == 2 @@ -298,7 +485,26 @@ def _fp8_gemm_fwd_rule( if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] else None ) - out, updated_out_amax, updated_out_scale, pre_gelu_out = fp8_gemm_impl( + + # Set scale_inv for comm overlap buffer + buffer_scale_inv = None + if comm_overlap_config is not None: + overlap_name = comm_overlap_config["name"] + + if comm_overlap_config["method"] != "bulk" and tex.overlap_buffer_is_fp8(overlap_name): + match comm_overlap_config["comm_type"]: + case tex.CommOverlapType.AG: + buffer_scale_inv = x_scale_inv + + case tex.CommOverlapType.RS: + buffer_scale_inv = jnp.reciprocal(out_scale) + + tex.set_overlap_buffer_scale_inverse( + overlap_name, + jax.dlpack.to_dlpack(buffer_scale_inv), + ) + + out, updated_out_amax, updated_out_scale, pre_gelu_out, extra_out = fp8_gemm_impl( casted_x, x_scale_inv, casted_kernel_t, @@ -312,12 +518,26 @@ def _fp8_gemm_fwd_rule( fuse_bias=fuse_bias, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, ) - if out_dtype not in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if not jax_dtype_is_fp8(out_dtype): updated_out_amax = None updated_out_scale = None + # Update returned and saved arrays based on comm+GEMM overlap config + final_out = out + saved_casted_x = casted_x + if comm_overlap_config is not None: + match comm_overlap_config.get("comm_type", None): + case tex.CommOverlapType.AG: + # AG overlap puts all-gathered global LHS (X) array into extra_out + saved_casted_x = extra_out + case tex.CommOverlapType.RS: + # RS overlap puts the reduce-scattered sharded output into extra_out + final_out = extra_out + ctx = ( + saved_casted_x, casted_x_t, casted_kernel, amax_list, @@ -332,7 +552,7 @@ def _fp8_gemm_fwd_rule( (x.ndim > 2), ) - return (out, updated_out_scale), ctx + return (final_out, updated_out_amax, updated_out_scale), ctx def _fp8_gemm_bwd_rule( @@ -340,6 +560,7 @@ def _fp8_gemm_bwd_rule( fuse_gelu, accumulate, use_split_accumulator, + comm_overlap_config, ctx, grad, ): @@ -407,28 +628,128 @@ def _fp8_gemm_bwd_rule( ) bgrad = None + # Recover dgrad comm+GEMM overlap config + dgrad_overlap_config = None + if comm_overlap_config is not None: + dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" + dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) + + # Set scale_inv for comm overlap buffer + dgrad_out_dtype = jnp.bfloat16 + dgrad_amax = None + dgrad_scale = None + if ( + dgrad_overlap_config is not None + and dgrad_overlap_config["method"] != "bulk" + and tex.overlap_buffer_is_fp8(dgrad_overlap_name) + ): + dgrad_out_dtype = bwd_dtype + dgrad_amax = grad_amax + dgrad_scale = grad_scale + tex.set_overlap_buffer_scale_inverse( + dgrad_overlap_name, + jax.dlpack.to_dlpack(grad_scale_inv), + ) + + # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] - dgrad, *_ = fp8_gemm_impl( + dgrad, *_, dgrad_extra_out = fp8_gemm_impl( casted_grad, grad_scale_inv, casted_kernel, kernel_scale_inv, + None, + None, + dgrad_amax, + dgrad_scale, + out_dtype=dgrad_out_dtype, batched_output=batched_input, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=dgrad_overlap_config, ) + # If dgrad overlapped reduce-scatter, set it to the RS output + if ( + dgrad_overlap_config is not None + and dgrad_overlap_config["method"] != "bulk" + and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + dgrad = dgrad_extra_out + + if fuse_gelu and fuse_bias: + # Fuse bgrad with dGELU. + _, casted_dgelu_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose( + grad, + pre_gelu_out, + grad_amax, + grad_scale, + grad_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1, + activation_type=("gelu",), + ) + elif fuse_gelu: + # No bias grad to fuse so we just do dGELU. + _, casted_dgelu_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) + bgrad = None + + # Recover wgrad config + wgrad_overlap_config = None + if comm_overlap_config is not None: + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + + # Set scale_inv for comm overlap buffer + wgrad_out_dtype = jnp.bfloat16 + wgrad_amax = None + wgrad_scale = None + if ( + wgrad_overlap_config is not None + and wgrad_overlap_config["method"] != "bulk" + and tex.overlap_buffer_is_fp8(wgrad_overlap_name) + ): + match wgrad_overlap_config["comm_type"]: + case tex.CommOverlapType.AG: + buffer_scale_inv = x_scale_inv + case tex.CommOverlapType.RS: + buffer_scale_inv = grad_scale_inv + wgrad_out_dtype = bwd_dtype + wgrad_amax = grad_amax + wgrad_scale = grad_scale + tex.set_overlap_buffer_scale_inverse( + dgrad_overlap_name, + jax.dlpack.to_dlpack(buffer_scale_inv), + ) + + # WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K) + wgrad_rhs_t = casted_dgelu_t if fuse_gelu else casted_grad_t x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] - wgrad, *_ = fp8_gemm_impl( + wgrad, *_, wgrad_extra_out = fp8_gemm_impl( casted_x_t, x_scale_inv, - casted_grad_t, + wgrad_rhs_t, grad_scale_inv, - out_shape=False, + None, + None, + wgrad_amax, + wgrad_scale, + out_dtype=wgrad_out_dtype, + batched_output=False, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=wgrad_overlap_config, ) + # If wgrad overlapped reduce-scatter, set it to the RS output + if ( + wgrad_overlap_config is not None + and wgrad_overlap_config["method"] != "bulk" + and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): + wgrad = wgrad_extra_out + amax_list[FP8MetaPackage.INPUT_IDX] = ( amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) ) @@ -462,11 +783,9 @@ def type_safe_gemm( fuse_gelu: bool = False, accumulate: bool = False, use_split_accumulator: bool = False, + comm_overlap_name: Optional[str] = None, ) -> ArrayLike: - if x.dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] or kernel.dtype in [ - jnp.float8_e4m3fn, - jnp.float8_e5m2, - ]: + if jax_dtype_is_fp8(x.dtype) or jax_dtype_is_fp8(kernel.dtype): assert fp8_meta is not None, "GEMM operands have FP8 dtypes but FP8MetaPackage is None." if fp8_meta is not None: @@ -484,6 +803,212 @@ def type_safe_gemm( fuse_gelu, accumulate, use_split_accumulator, + comm_overlap_name, ) else: - return gemm(x, kernel, bias, contracting_dims, fuse_gelu, accumulate, use_split_accumulator) + return gemm( + x, + kernel, + bias, + contracting_dims, + fuse_gelu, + accumulate, + use_split_accumulator, + comm_overlap_name, + ) + + +def initialize_comm_gemm_overlaps( + buffer_shape: Sequence[int], + buffer_dtype: jnp.dtype, + mesh: Optional[jax.sharding.Mesh] = None, + tp_resource: Optional[str] = None, + use_fp8: bool = False, + overlap_configs: Optional[dict] = None, +) -> None: + assert tex.ubuf_built_with_mpi(), ( + "Comm+GEMM overlap in TE/JAX requires Transformer Engine to be compiled with " + + "`NVTE_UB_WITH_MPI=1` and `MPI_HOME=/path/to/mpi` options." + ) + if not tex.device_supports_multicast(): + assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( + "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " + + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." + ) + + # Get # of devices in the mesh axis for comm+GEMM overlap + tp_resource = global_mesh_resource().tp_resource if tp_resource is None else tp_resource + tp_size = get_mesh_axis_size(tp_resource, mesh=mesh) + + # Layers that support comm+GEMM overlap + layers_all_gather_overlap = [ + "generic_ag", + "qkv_fprop", + "qkv_dgrad", + "proj_dgrad", + "fc1_fprop", + "fc1_dgrad", + "fc2_dgrad", + ] + layers_reduce_scatter_overlap = [ + "generic_rs", + "proj_fprop", + "fc2_fprop", + "qkv_wgrad", + "fc1_wgrad", + ] + dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] + + # Default overlap methods for layers + methods = { + "ring_exchange": [ + "generic_ag", + "generic_rs", + "qkv_fprop", + "fc1_fprop", + "proj_dgrad", + "fc2_dgrad", + ], + "pipeline": ["proj_fprop", "fc2_fprop"], + "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + } + + # AG-RS overlap pairs of layers forming a tensor-parallel block + ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} + rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} + global layers_atomic_ring_exchange + layers_atomic_ring_exchange = [] + + def get_method(name): + for method, names in methods.items(): + if name in names: + return method + raise KeyError(f"Given layer name {name} does not exist.") + + def get_default_config(name): + method = get_method(name) + default_cfg = { + "method": method, + "comm_type": ( + tex.CommOverlapType.AG if name in layers_all_gather_overlap else tex.CommOverlap.RS + ), + "num_sm": 1 if method == "ring_exchange" else 16, + "cga_size": 1 if method == "ring_exchange" else 2, + "set_sm_margin": False, + "num_splits": 4 if method == "pipeline" else tp_size, + "aggregate": False, + "atomic_gemm": False, + "pipeline_rs_overlap_first_gemm": False, + "use_ce": True, + "fp8_buf": name in layers_all_gather_overlap, + } + return default_cfg + + def add_new_comm_gemm_overlap( + name: str, + method: str, + shape: Sequence[int], + dtype: jnp.dtype, + comm_type: tex.CommOverlapType, + num_sm: int = 16, + cga_size: int = 2, + set_sm_margin: bool = False, + num_splits: int = 4, + aggregate: bool = False, + atomic_gemm: bool = False, + pipeline_rs_overlap_first_gemm: bool = False, + use_ce: bool = True, + fp8_buf: bool = False, + ) -> None: + assert ( + name not in _ACTIVE_COMM_GEMM_OVERLAPS + ), "Duplicate initialization for `{name}` overlap!" + + if atomic_gemm: + warnings.warn( + "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." + ) + assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." + if method == "bulk": + warnings.warn( + f"At {name}, atoimic GEMM not is supported for a bulk overlap." + "Defaulting to `atomic_gemm=False`." + ) + atomic_gemm = False + if method == "pipeline" and comm_type == tex.CommOverlapType.AG: + raise ValueError( + f"At {name}, `pipeline` overlap method is not supported for AllGather." + ) + # Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`. + # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality. + global layers_atomic_ring_exchange + if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs: + layers_atomic_ring_exchange += [name, ag_rs_pairs[name]] + if name in rs_ag_pairs: + assert_message = ( + f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk " + "outputs, and RS-GEMM overlap un-suffle them. When one of the GEMM-AG and " + "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses " + "`atomic gemm` and `ring_exhcnage`, its pair must use the same overlap config " + "for functionality." + ) + if name in layers_atomic_ring_exchange: + assert atomic_gemm and method == "ring_exchange", assert_message + else: + if atomic_gemm and method == "ring_exchange": + assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message + + dtype = jnp.uint8 if (use_fp8 and fp8_buf) else dtype + tex.bootstrap_comm_gemm_overlap( + name, + method, + shape, + jax_dtype_to_te_dtype(dtype), + comm_type, + tp_size, + num_splits, + _NUM_MAX_UB_STREAMS, + cga_size, + num_sm, + set_sm_margin, + use_ce, + atomic_gemm, + aggregate, + pipeline_rs_overlap_first_gemm, + ) + + if overlap_configs is not None: + for name in dgrad_reduce_scatter_overlap: + if ( + name in overlap_configs + and "method" in overlap_configs[name] + and overlap_configs[name]["method"] != "bulk" + ): + wgrad_name = name.replace("dgrad", "wgrad") + assert wgrad_name not in overlap_configs + layers_reduce_scatter_overlap.remove(wgrad_name) + layers_all_gather_overlap.remove(name) + layers_reduce_scatter_overlap.append(name) + methods["bulk"].remove(name) + methods["bulk"].remove(wgrad_name) + new_method = overlap_configs[name]["method"] + methods[new_method].append(name) + + global _ACTIVE_COMM_GEMM_OVERLAPS + for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: + if overlap_configs is not None and name in overlap_configs: + fp8_buf = (name in layers_all_gather_overlap) or ( + overlap_configs[name].get("fp8_buf", False) and name in methods["pipeline"] + ) + default_config = get_default_config(name) + final_config = default_config.update(overlap_configs[name]) + final_config["fp8_buf"] = fp8_buf + add_new_comm_gemm_overlap(name, buffer_shape, buffer_dtype, **final_config) + _ACTIVE_COMM_GEMM_OVERLAPS.update({name: final_config}) + + +def destroy_comm_gemm_overlaps(): + for name in _ACTIVE_COMM_GEMM_OVERLAPS: + tex.destroy_comm_gemm_overlap(name) + _ACTIVE_COMM_GEMM_OVERLAPS.pop(name) + _ACTIVE_COMM_GEMM_OVERLAPS = dict() diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 3b49ece4a3..d906bba98f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -553,7 +553,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false, + bool overlap_first_gemm = false); void set_ubuf_scale_inv(torch::Tensor scale_inv) { assert(scale_inv.numel()); diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index d212d13516..587e3115b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -186,13 +186,13 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool atomic_gemm) - : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, - helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), - std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, - num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { + bool set_sm_margin, bool atomic_gemm, bool overlap_first_gemm) + : te::CommOverlapBase( + buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, + helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, + std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, num_max_streams, + comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm, overlap_first_gemm) { // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to // for PyTorch to factor externally allocated memory into its memory pool and garbage collection // threshold calculation. diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 8856553c54..9841b5d640 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -263,12 +263,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::class_(m, "CommOverlap") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, - int, int, bool, bool>(), + int, int, bool, bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, - py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) + py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false, + py::arg("overlap_first_gemm") = false) .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) .def("split_overlap_rs", &CommOverlap::split_overlap_rs, py::call_guard()) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d115efedaa..164d371985 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -306,6 +306,7 @@ def get_default_config(name): "num_splits": 4 if method == "pipeline" else tp_size, "aggregate": False, "atomic_gemm": False, + "pipeline_rs_overlap_first_gemm": False, "use_ce": True, "fp8_buf": name in layers_all_gather_overlap, } @@ -314,13 +315,14 @@ def get_default_config(name): def add_ub( name: str, method: str, - is_reduce_scatter: int, + is_reduce_scatter: bool, num_sm: int = 16, cga_size: int = 2, - set_sm_margin: int = 0, - num_splits: int = 0, - aggregate: int = 0, - atomic_gemm: int = 0, + set_sm_margin: bool = False, + num_splits: int = 4, + aggregate: bool = False, + atomic_gemm: bool = False, + pipeline_rs_overlap_first_gemm: bool = False, use_ce: bool = True, fp8_buf: bool = False, ) -> None: @@ -386,6 +388,7 @@ def add_ub( num_comm_sm=num_sm, set_sm_margin=set_sm_margin, atomic_gemm=atomic_gemm, + overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) _ub_communicators[name] = ub_obj From b306608da620a95e53a2dc2dca8e7063bc950277 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 3 Dec 2024 14:08:27 +0000 Subject: [PATCH 21/39] AG+GEMM overlap working Signed-off-by: Alp Dener --- .../transformer_engine/comm_gemm_overlap.h | 4 +- .../common/util/pybind_helper.h | 1 + transformer_engine/jax/cpp_extensions/gemm.py | 762 +++++++++++------- transformer_engine/jax/csrc/extensions.h | 109 +-- .../jax/csrc/extensions/comm_gemm_overlap.cpp | 233 +++--- .../jax/csrc/extensions/gemm.cpp | 50 +- .../jax/csrc/extensions/packing.cpp | 19 - .../jax/csrc/extensions/pybind.cpp | 19 +- transformer_engine/jax/gemm.py | 314 ++++---- 9 files changed, 873 insertions(+), 638 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 16e4ccf16a..0605825c82 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -29,9 +29,9 @@ namespace transformer_engine { */ bool ubuf_built_with_mpi(); -enum class CommOverlapType : int32_t { RS = 0, AG = 1 }; +enum class CommOverlapType : int { RS = 0, AG = 1 }; -enum class CommOverlapAlgo : int32_t { +enum class CommOverlapAlgo : int { BULK_OVERLAP_AG = 0, BULK_OVERLAP_RS = 1, SPLIT_PIPELINED_AG_P2P = 2, diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 6fa9574f63..9091e7e364 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -19,6 +19,7 @@ pybind11::enum_(m, "DType") \ .value("kByte", transformer_engine::DType::kByte) \ .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kInt64", transformer_engine::DType::kInt64) \ .value("kFloat32", transformer_engine::DType::kFloat32) \ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2ff98c20d9..59bf28434d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -27,15 +27,15 @@ ) from ..sharding import ( global_mesh_resource, - lax_paral_op, all_reduce_max_along_all_axes_except_PP, - get_mesh_axis_size, ) __all__ = [ "fp8_gemm_impl", "gemm_impl", + "copy_into_overlap_buffer", + "bootstrap_comm_gemm_overlap", ] _COMM_GEMM_OVERLAP_LAYERS = ["qkv", "proj", "fc1", "fc2"] @@ -43,7 +43,7 @@ [layer + "_fprop" for layer in _COMM_GEMM_OVERLAP_LAYERS] + [layer + "_dgrad" for layer in _COMM_GEMM_OVERLAP_LAYERS] + [layer + "_wgrad" for layer in _COMM_GEMM_OVERLAP_LAYERS if layer != "fc2"] - + ["generic_ag", "generic_rs"] + + ["ag_gemm", "gemm_rs"] ) @@ -68,7 +68,7 @@ class CollectiveGemmPrimitive(BasePrimitive): """ name = "te_gemm" - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16) + impl_static_args = (10, 11, 12, 13, 14, 15, 16, 17, 18) multiple_results = True inner_primitive = None outer_primitive = None @@ -81,9 +81,10 @@ def abstract( rhs_scale_inv_aval, bias_aval, gelu_input_aval, + out_aval, out_amax_aval, out_scale_aval, - out_dtype, + extra_out_aval, batched_output, contracting_dims, fuse_gelu, @@ -92,16 +93,19 @@ def abstract( accumulate, use_split_accumulator, comm_overlap_config, + sharded_abstract, ): """ cuBlasLt GEMM abstract """ - del grad, accumulate, use_split_accumulator + if comm_overlap_config is not None: + assert tex.ubuf_built_with_mpi(), ( + "Comm+GEMM overlap in TE/JAX requires Transformer Engine to be compiled with " + + "`NVTE_UB_WITH_MPI=1` and `MPI_HOME=/path/to/mpi` options." + ) + assert is_ffi_enabled(), "Comm+GEMM overlap is supported only via XLA FFI." - assert tex.ubuf_built_with_mpi(), ( - "Comm+GEMM overlap in TE/JAX requires Transformer Engine to be compiled with " - + "`NVTE_UB_WITH_MPI=1` and `MPI_HOME=/path/to/mpi` options." - ) + del grad, accumulate, use_split_accumulator # Validate operand dtypes lhs_dtype = dtypes.canonicalize_dtype(lhs_aval.dtype) @@ -120,13 +124,14 @@ def abstract( and dtypes.canonicalize_dtype(rhs_scale_inv_aval.dtype) == jnp.float32 ), "Missing RHS operand scale inverse in FP8 GEMM." - # Validate operand layouts, adjusted for comm-overlap if necessary + # Validate operand layouts lhs_inner_dim, rhs_inner_dim = map( sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) ) - assert ( - lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim] - ), f"Incompatible contracting dimensions: {lhs_aval.shape} x {rhs_aval.shape}." + assert lhs_aval.shape[lhs_inner_dim] == rhs_aval.shape[rhs_inner_dim], ( + "Incompatible operand sizes: " + + f"{lhs_aval.shape} @ idx {lhs_inner_dim} X {rhs_aval.shape} @ idx {rhs_inner_dim}." + ) lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 @@ -137,80 +142,125 @@ def abstract( assert not lhs_trans, "FP8 GEMM does not support transposed LHS." assert rhs_trans, "FP8 GEMM requires transposed RHS." - # Validate output dtype - if jax_dtype_is_fp8(out_dtype): - assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8( - rhs_dtype - ), "FP8 GEMM output requires FP8 inputs." - assert ( - out_amax_aval.size == out_scale_aval.size == 1 - ), "Invalid/missing output amax and scale." - out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) - out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) - assert ( - out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 - ), "Invalid output amax or scale dtype." - else: - out_dtype = lhs_dtype - out_amax_updated_dtype = jnp.float32 - out_scale_updated_dtype = jnp.float32 - # Make sure leading dimensions of RHS is broadcast-compatible with LHS lhs_outer_dim, rhs_outer_dim = map( mirror_dim, (lhs_inner_dim, rhs_inner_dim), (lhs_aval.ndim, rhs_aval.ndim), ) - lhs_bdims = [ - dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] - ] - lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] - lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) - - if rhs_aval.ndim > 2: + if lhs_aval.ndim > 2 and rhs_aval.ndim > 2: + assert not batched_output, ( + "Batched output requires batched LHS and non-batched RHS operands." + ) + lhs_bdims = [ + dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] + ] + lhs_batch_shape = [lhs_aval.shape[dim] for dim in lhs_bdims] + lhs_batch_size = reduce(operator.mul, lhs_batch_shape, 1) rhs_bdims = [ dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] ] rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] - rhs_batch_size = reduce(operator.mul, rhs_bdims, 1) - if rhs_batch_size > 1: - assert lhs_batch_size == rhs_batch_size, ( - f"Leading dimensins of RHS ({rhs_batch_shape=}) is not broadcast-compatible " - + f"with the leading dimensions of LHS ({lhs_batch_shape=})." - ) + rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) + assert lhs_batch_size == rhs_batch_size, ( + "Leading dimensions of LHS and RHS are not broadcast-compatible: " + + f"{lhs_aval.shape} @ idx {lhs_inner_dim} X {rhs_aval.shape} @ idx {rhs_inner_dim}" + ) - # Infer output shape - if batched_output: - assert ( - lhs_aval.ndim > 2 and rhs_aval.ndim == 2 - ), "Batched output requires batched LHS and non-batched RHS operands." - out_shape = ( - *lhs_batch_shape, - lhs_aval.shape[lhs_outer_dim], - rhs_aval.shape[rhs_outer_dim], + # Validate output dtypes + out_dtype = dtypes.canonicalize_dtype(out_aval.dtype) + if jax_dtype_is_fp8(out_dtype): + assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(rhs_dtype), ( + "FP8 GEMM output requires FP8 inputs." + ) + assert out_amax_aval.size == out_scale_aval.size == 1, ( + "Invalid/missing output amax and scale." + ) + out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) + out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) + assert out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32, ( + "Invalid output amax or scale dtype." ) else: - assert ( - lhs_aval.ndim == rhs_aval.ndim - ), "Non-batched output requires LHS and RHS operands with same number of dimensions." - if lhs_aval.ndim > 2: - rhs_bdims = [ - dim for dim in range(rhs_aval.ndim) if dim not in [rhs_outer_dim, rhs_inner_dim] - ] - rhs_batch_shape = [rhs_aval.shape[dim] for dim in rhs_bdims] - rhs_batch_size = reduce(operator.mul, rhs_batch_shape, 1) - assert lhs_batch_size == rhs_batch_size, ( - f"Leading dimensins of RHS ({rhs_aval.shape=}) is not broadcast-compatible " - + f"with the leading dimensions of LHS ({lhs_aval.shape=})." - ) - out_shape = (lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim]) + assert out_dtype == lhs_dtype, ( + "Output buffer has incorrect dtype: " + + f"expected {lhs_dtype} but found {out_dtype}" + ) + out_amax_updated_dtype = jnp.float32 + out_scale_updated_dtype = jnp.float32 + + # Validate output buffers + out_shape = out_aval.shape + expected_out_shape = [ + *lhs_aval.shape[:-2], + lhs_aval.shape[lhs_outer_dim], + rhs_aval.shape[rhs_outer_dim] + ] + extra_out_shape = extra_out_aval.shape + expected_extra_out_shape = [0] + extra_out_dtype = dtypes.canonicalize_dtype(extra_out_aval.dtype) + expected_extra_out_dtype = jnp.bfloat16 + if batched_output: + assert out_aval.ndim > 2, "Batched output buffer is missing batch dimensions." + else: + expected_out_shape = [reduce(operator.mul, expected_out_shape[:-1], 1), + expected_out_shape[-1]] - # Validate bias/bias_grad shape against inferred output + if (comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk"): + comm_type = comm_overlap_config.get("comm_type", None) + assert comm_type is not None, "Missing comm type for comm+GEMM overlap." + + tp_size = comm_overlap_config.get("tp_size", 1) + assert tp_size > 1, ( + "Comm+GEMM overlap requires tensor-parallel mesh axis size greater than 1." + ) + + if comm_type == tex.CommOverlapType.AG: + expected_extra_out_shape = list(lhs_aval.shape).copy() + elif comm_type == tex.CommOverlapType.RS: + expected_extra_out_shape = list(expected_out_shape).copy() + expected_extra_out_dtype = lhs_dtype + + if sharded_abstract: + if comm_type == tex.CommOverlapType.AG: + expected_out_shape[-2] *= tp_size + expected_extra_out_shape[-2] *= tp_size + else: + expected_extra_out_shape[-2] = expected_extra_out_shape[-2] // tp_size + + assert out_aval.ndim == len(expected_out_shape), ( + "Output buffer has incorrect number of dimensions: " + + f"expected {len(expected_out_shape)} but found {out_aval.ndim}" + ) + assert all([out_aval.shape[i] == expected_out_shape[i] for i in range(out_aval.ndim)]), ( + "Output buffer has incorrect shape: " + + f"expected {expected_out_shape=} but found {out_aval.shape=}" + ) + + assert extra_out_dtype == expected_extra_out_dtype, ( + "Extra output has incorrect dtype: " + + f"expected {expected_extra_out_dtype} but found {extra_out_dtype}" + ) + assert extra_out_aval.ndim == len(expected_extra_out_shape), ( + "Extra output buffer has incorrect number of dimensions: " + + f"expected {len(expected_extra_out_shape)} but found {extra_out_aval.ndim}" + ) + assert all([extra_out_aval.shape[i] == expected_extra_out_shape[i] + for i in range(extra_out_aval.ndim)]), ( + "Extra output buffer has incorrect shape: " + + f"expected {expected_extra_out_shape=} but found {extra_out_aval.shape=}" + ) + + # Validate bias/bias_grad shape against output bufer bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype if fuse_bias: assert ( bias_aval.size > 0 and bias_aval.ndim == 1 and bias_aval.shape[0] == out_shape[-1] - ), "Incorrect bias shape." + ), ( + "Incorrect bias shape: " + + f"expected ({out_shape[-1]}, ) but found ({bias_aval.shape[0]}, )" + ) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) else: assert bias_aval.size == 0, "Internal TE error." @@ -230,45 +280,8 @@ def abstract( else: assert gelu_input_aval.size == 0, "Internal TE error." - # Adjust output sizes for comm-overlap - extra_out_shape = (0,) - extra_out_dtype = jnp.bfloat16 - if comm_overlap_config is not None: - comm_overlap_type = comm_overlap_config.get("comm_type", None) - assert comm_overlap_type is not None, "Missing comm type for comm+GEMM overlap." - comm_overlap_name = comm_overlap_config.get("name", None) - assert ( - comm_overlap_name in _COMM_GEMM_OVERLAP_NAMES - ), f"Unrecognized comm+GEMM overlap name: {comm_overlap_name=}" - - mesh = comm_overlap_config.get("mesh", None) - tp_resource = comm_overlap_config.get("tp_resource", global_mesh_resource().tp_resource) - tp_size = get_mesh_axis_size(tp_resource, mesh=mesh) - - match comm_overlap_type: - case tex.CommOverlapType.AG: - # Extra output is all-gathered LHS copy - extra_out_shape = list(lhs_aval.shape).copy() - extra_out_shape[lhs_outer_dim] *= tp_size - extra_out_dtype = lhs_dtype - - case tex.CommOverlapType.RS: - # FP8 GEMM output for RS overlap is always FP8 - if jax_dtype_is_fp8(lhs_dtype): - assert jax_dtype_is_fp8( - out_dtype - ), "FP8 GEMM with reduce-scatter overlap requires FP8 output." - # Extra output is reduce-scattered GEMM output - extra_out_shape = list(out_shape).copy() - extra_out_shape[-2] /= tp_size - - case _: - raise RuntimeError( - f"Unrecognized comm type for comm+GEMM overlap: {comm_overlap_type=}" - ) - # Create abstract arrays for all outputs - out_aval = lhs_aval.update(shape=out_shape, dtype=out_dtype) + out_updated_aval = out_aval.update(shape=out_shape, dtype=out_dtype) out_amax_updated_aval = out_amax_aval.update( shape=out_amax_aval.shape, dtype=out_amax_updated_dtype ) @@ -277,18 +290,18 @@ def abstract( ) pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - extra_out_aval = jax.core.ShapedArray(shape=extra_out_shape, dtype=extra_out_dtype) + extra_out_updated_aval = extra_out_aval.update(shape=extra_out_shape, dtype=extra_out_dtype) workspace_aval = jax.core.ShapedArray( shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 ) return ( - out_aval, + out_updated_aval, out_amax_updated_aval, out_scale_updated_aval, pre_gelu_out_aval, bias_grad_aval, - extra_out_aval, # global LHS for AG overlap, or sharded output for RS overlap + extra_out_updated_aval, # global LHS for AG overlap, or sharded output for RS overlap workspace_aval, ) @@ -324,10 +337,11 @@ def lowering( rhs_scale_inv, bias, gelu_input, + out, out_amax, out_scale, + extra_out, *, - out_dtype, batched_output, contracting_dims, fuse_gelu, @@ -336,11 +350,12 @@ def lowering( accumulate, use_split_accumulator, comm_overlap_config, + sharded_abstract ): """ Fused attention fwd lowering rules """ - del batched_output + del batched_output, sharded_abstract lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in lhs_inner_dim, rhs_inner_dim = map( sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) @@ -348,26 +363,31 @@ def lowering( lhs_trans = lhs_inner_dim != lhs_aval.ndim - 1 rhs_trans = rhs_inner_dim == rhs_aval.ndim - 1 + operands = [ + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, + extra_out, + ] + operand_output_aliases = { 4: 4, # bias <--> bias_grad 5: 3, # gelu_input <--> pre_gelu_out - 6: 1, # out_amax <--> out_amax_updated - 7: 2, # out_scale <--> out_scale_updated + 6: 0, # out <--> out_updated + 7: 1, # out_amax <--> out_amax_updated + 8: 2, # out_scale <--> out_scale_updated + 9: 5, # extra_out <--> extra_out_updated } if is_ffi_enabled(): name = "te_gemm_ffi" - ffi_args = ( - ctx, - lhs, - lhs_scale_inv, - rhs, - rhs_scale_inv, - bias, - gelu_input, - out_amax, - out_scale, - ) + ffi_args = (ctx, *operands) ffi_kwargs = dict( lhs_trans=lhs_trans, rhs_trans=rhs_trans, @@ -380,23 +400,14 @@ def lowering( if comm_overlap_config is not None: name = "te_comm_gemm_overlap_ffi" - ffi_kwargs["comm_type"] = int(comm_overlap_config["comm_type"]) + ffi_kwargs["comm_type_flag"] = int(comm_overlap_config["comm_type"]) ffi_kwargs["name"] = comm_overlap_config["name"] return ffi.ffi_lowering(name, operand_output_aliases=operand_output_aliases)( *ffi_args, **ffi_kwargs ) + else: - operands = [ - lhs, - lhs_scale_inv, - rhs, - rhs_scale_inv, - bias, - gelu_input, - out_amax, - out_scale, - ] operand_shapes = map(lambda x: ir.RankedTensorType(x.type).shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_dtype(output.dtype)) @@ -423,7 +434,7 @@ def lowering( k, workspace_size, operand_dtype, - jax_dtype_to_te_dtype(out_dtype), + jax_dtype_to_te_dtype(dtypes.canonicalize_dtype(ctx.avals_out[0].dtype)), bias_dtype, lhs_trans, rhs_trans, @@ -434,14 +445,6 @@ def lowering( use_split_accumulator, ) - comm_overlap_type = comm_overlap_config.get("comm_type", None) - if comm_overlap_type is not None: - name = "te_comm_gemm_overlap" - descriptor_packer_fn = tex.pack_overlap_descriptor - descriptor_args += ( - comm_overlap_type, - comm_overlap_config.get("name", None), - ) opaque = descriptor_packer_fn(*descriptor_args) return custom_caller( @@ -460,9 +463,10 @@ def impl( rhs_scale_inv, bias, gelu_input, + out, out_amax, out_scale, - out_dtype, + extra_out, batched_output, contracting_dims, fuse_gelu, @@ -471,6 +475,7 @@ def impl( accumulate, use_split_accumulator, comm_overlap_config, + sharded_abstract, ): assert CollectiveGemmPrimitive.inner_primitive is not None @@ -526,14 +531,33 @@ def impl( else: contracting_dims_2d[1] = contracting_dims[1] + # Reshape output and extra output buffers into 2D as well + if out.ndim > 2: + out = jax.lax.reshape(out, (reduce(operator.mul, out.shape[:-1], 1), out.shape[-1])) + if extra_out.size > 0 and extra_out.ndim > 2: + extra_out = jax.lax.reshape( + extra_out, (reduce(operator.mul, extra_out.shape[:-1], 1), extra_out.shape[-1]) + ) + + batched_extra_out = False + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": + comm_type = comm_overlap_config["comm_type"] + if comm_type == tex.CommOverlapType.AG: + # Extra output is global LHS, we can collapse but need to recover batches later + batched_extra_out = len(lhs_batch_dims) > 0 + elif comm_type == tex.CommOverlapType.RS: + # Extra output is scattered GEMM output, so we recover batches only if the output is + # batched + batched_extra_out = batched_output + # Invoke GEMM with guaranteed 2D inputs, so batched_output=False ( - out, + out_updated, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad, - extra_out, + extra_out_updated, _, ) = CollectiveGemmPrimitive.inner_primitive.bind( lhs, @@ -542,9 +566,10 @@ def impl( rhs_scale_inv, bias, gelu_input, + out, out_amax, out_scale, - out_dtype=out_dtype, + extra_out, batched_output=False, contracting_dims=contracting_dims_2d, fuse_gelu=fuse_gelu, @@ -553,21 +578,40 @@ def impl( accumulate=accumulate, use_split_accumulator=use_split_accumulator, comm_overlap_config=comm_overlap_config, + sharded_abstract=sharded_abstract, ) # Recover batched dimensions in the output if batched_output: - out_shape = (*lhs_batch_shape, out.shape[-2] // lhs_batch_size, out.shape[-1]) - out = jax.lax.reshape(out, out_shape) + out_shape = ( + *lhs_batch_shape, + out_updated.shape[-2] // lhs_batch_size, + out_updated.shape[-1] + ) + out_updated = jax.lax.reshape(out_updated, out_shape) - return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad, extra_out + if batched_extra_out: + extra_out_shape = ( + *lhs_batch_shape, + extra_out_updated.shape[-2] // lhs_batch_size, + extra_out_updated.shape[-1] + ) + extra_out_updated = jax.lax.reshape(extra_out_updated, extra_out_shape) + + return ( + out_updated, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + extra_out_updated, + ) @staticmethod def batcher( batched_args, batch_dims, *, - out_dtype, batched_output, contracting_dims, fuse_gelu, @@ -576,15 +620,23 @@ def batcher( accumulate, use_split_accumulator, comm_overlap_config, + sharded_abstract, ): assert CollectiveGemmPrimitive.outer_primitive is not None check_valid_batch_dims(batch_dims) - lhs_bdims, *_, bias_bdims, gelu_input_bdims, out_amax_bdims, out_scale_bdims = batch_dims + ( + *_, + bias_bdims, + gelu_input_bdims, + out_bdims, + out_amax_bdims, + out_scale_bdims, + extra_out_bdims, + ) = batch_dims return ( CollectiveGemmPrimitive.outer_primitive.bind( *batched_args, - out_dtype=out_dtype, batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, @@ -592,13 +644,21 @@ def batcher( grad=grad, accumulate=accumulate, use_split_accumulator=use_split_accumulator, + comm_overlap_config=comm_overlap_config, + sharded_abstract=sharded_abstract, + ), + ( + out_bdims, + out_amax_bdims, + out_scale_bdims, + gelu_input_bdims, + bias_bdims, + extra_out_bdims ), - (lhs_bdims, out_amax_bdims, out_scale_bdims, gelu_input_bdims, bias_bdims), ) @staticmethod def infer_sharding_from_operands( - out_dtype, batched_output, contracting_dims, fuse_gelu, @@ -607,11 +667,12 @@ def infer_sharding_from_operands( accumulate, use_split_accumulator, comm_overlap_config, + sharded_abstract, mesh, arg_infos, result_infos, ): - del out_dtype, accumulate, use_split_accumulator, result_infos + del accumulate, use_split_accumulator, sharded_abstract, result_infos lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) @@ -625,7 +686,6 @@ def infer_sharding_from_operands( # Modify operand specs lhs_spec_new = [spec for spec in lhs_spec] rhs_spec_new = [spec for spec in rhs_spec] - reduce_output = False if comm_overlap_config is None: # When comm overlap is not enabled: # - Always all-gather the outer dimension of LHS. @@ -651,7 +711,6 @@ def infer_sharding_from_operands( + "communication overhead." ) rhs_spec_new[rhs_outer_dim] = None - reduce_output = True else: if lhs_spec_new[lhs_inner_dim] is None and rhs_spec_new[rhs_inner_dim] is not None: warnings.warn( @@ -669,12 +728,6 @@ def infer_sharding_from_operands( ) lhs_spec_new[lhs_inner_dim] = None rhs_spec_new[rhs_inner_dim] = None - else: - # When comm overlap is enabled, make sure both contracting dims are unsharded if one - # of them is unsharded. - if lhs_spec_new[lhs_inner_dim] is None or rhs_spec_new[rhs_inner_dim] is None: - lhs_spec_new[lhs_inner_dim] = None - rhs_spec_new[rhs_inner_dim] = None out_col_spec = rhs_spec_new[rhs_outer_dim] # Output sharding is conditional on output shape @@ -698,37 +751,48 @@ def infer_sharding_from_operands( # Validate operand sharding for comm+GEMM overlap and adust extra output sharding extra_out_spec = [None] if comm_overlap_config is not None: - mesh = comm_overlap_config.get("mesh", None) + comm_type = comm_overlap_config.get("comm_type", None) tp_resource = comm_overlap_config.get("tp_resource", global_mesh_resource().tp_resource) - match comm_overlap_config.get("comm_type", None): - case tex.CommOverlapType.AG: - # AG overlap requires the outer dimension of LHS to be sharded - # over the TP resource - assert lhs_spec[lhs_outer_dim] == tp_resource, ( - "AG+GEMM overlap requires the outer (sequence) dimension of the LHS " - + f"operand to be sharded over the TP resource (mesh axis: {tp_resource=})." - ) - extra_out_spec = list(lhs_spec).copy() - extra_out_spec[lhs_outer_dim] = None - - case tex.CommOverlapType.RS: - # RS overlap requires the contracting dimensions of both LHS and RHS to be - # sharded over the TP resource, and the outer dimension of LHS to be unsharded - assert lhs_spec[lhs_outer_dim] is None, ( - "GEMM+RS overlap requires the outer (sequence) dimension of the LHS " - + "operand to be un-sharded." - ) - assert lhs_spec[lhs_inner_dim] == tp_resource, ( - "GEMM+RS overlap requires the contracting dimension of the LHS operand " - + f"to be sharded over the TP resource (mesh axis: {tp_resource=})." - ) - assert rhs_spec[rhs_inner_dim] == tp_resource, ( - "GEMM+RS overlap requires the contracting dimension of the RHS operand " - + f"to be sharded over the TP resource (mesh axis: {tp_resource=})." - ) - extra_out_spec = out_spec.copy() - extra_out_spec[-2] = tp_resource - + if comm_type == tex.CommOverlapType.AG: + # AG overlap requires the outer dimension of LHS to be sharded + # over the TP resource + assert lhs_spec[lhs_outer_dim] == tp_resource, ( + "AG+GEMM overlap requires the outer (sequence) dimension of the LHS " + + f"operand to be sharded over the TP resource '{tp_resource=}'." + ) + assert lhs_spec[lhs_inner_dim] is None, ( + "AG+GEMM overlap requires the contracting dimension of the LHS operand " + + "to be unsharded." + ) + assert rhs_spec[rhs_inner_dim] is None, ( + "AG+GEMM overlap requires the contracting dimension of the RHS operand " + + "to be unsharded." + ) + extra_out_spec = list(lhs_spec).copy() + extra_out_spec[lhs_outer_dim] = None + + elif comm_type == tex.CommOverlapType.RS: + # RS overlap requires the contracting dimensions of both LHS and RHS to be + # sharded over the TP resource, and the outer dimensions of LHS and RHS to be + # unsharded. + assert lhs_spec[lhs_outer_dim] is None, ( + "GEMM+RS overlap requires the outer (sequence) dimension of the LHS " + + "operand to be unsharded." + ) + assert lhs_spec[lhs_inner_dim] == tp_resource, ( + "GEMM+RS overlap requires the contracting dimension of the LHS operand " + + f"to be sharded over the TP resource '{tp_resource=}'." + ) + assert rhs_spec[rhs_inner_dim] == tp_resource, ( + "GEMM+RS overlap requires the contracting dimension of the RHS operand " + + f"to be sharded over the TP resource '{tp_resource=}'." + ) + assert rhs_spec[rhs_outer_dim] is None, ( + "GEMM+RS overlap requires the outer dimension of the RHS operand to be " + + "unsharded." + ) + extra_out_spec = list(out_spec).copy() + extra_out_spec[-2] = tp_resource extra_out_sharding = NamedSharding(mesh, PartitionSpec(*extra_out_spec)) return ( @@ -742,7 +806,6 @@ def infer_sharding_from_operands( @staticmethod def partition( - out_dtype, batched_output, contracting_dims, fuse_gelu, @@ -751,11 +814,12 @@ def partition( accumulate, use_split_accumulator, comm_overlap_config, + sharded_abstract, mesh, arg_infos, result_infos, ): - del result_infos + del sharded_abstract, result_infos lhs, _, rhs, *_ = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) @@ -783,12 +847,6 @@ def partition( else: lhs_spec_new[lhs_inner_dim] = None rhs_spec_new[rhs_inner_dim] = None - else: - # When comm overlap is enabled, make sure both contracting dims are unsharded if one - # of them is unsharded. - if lhs_spec_new[lhs_inner_dim] is None or rhs_spec_new[rhs_inner_dim] is None: - lhs_spec_new[lhs_inner_dim] = None - rhs_spec_new[rhs_inner_dim] = None out_col_spec = rhs_spec_new[rhs_outer_dim] lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_spec_new)) @@ -812,20 +870,17 @@ def partition( gelu_spec = [None, out_col_spec] if fuse_gelu else [None] gelu_sharding = NamedSharding(mesh, PartitionSpec(*gelu_spec)) - # Adjust extra output sharding for comm+GEMM overlap + # Extra output sharding for comm+GEMM overlap extra_out_spec = [None] if comm_overlap_config is not None: - mesh = comm_overlap_config.get("mesh", None) - tp_resource = comm_overlap_config.get("tp_resource", global_mesh_resource().tp_resource) - match comm_overlap_config.get("comm_type", None): - case tex.CommOverlapType.AG: - extra_out_spec = list(lhs_spec).copy() - extra_out_spec[lhs_outer_dim] = None - - case tex.CommOverlapType.RS: - extra_out_spec = out_spec.copy() - extra_out_spec[-2] = tp_resource - + comm_type = comm_overlap_config.get("comm_type", None) + if comm_type == tex.CommOverlapType.AG: + extra_out_spec = list(lhs_spec).copy() + extra_out_spec[lhs_outer_dim] = None + elif comm_type == tex.CommOverlapType.RS: + extra_out_spec = list(out_spec).copy() + extra_out_spec[-2] = comm_overlap_config.get("tp_resource", + global_mesh_resource().tp_resource) extra_out_sharding = NamedSharding(mesh, PartitionSpec(*extra_out_spec)) arg_shardings = ( @@ -835,8 +890,10 @@ def partition( fp8_meta_sharding, bias_sharding, gelu_sharding, + out_sharding, fp8_meta_sharding, fp8_meta_sharding, + extra_out_sharding, ) out_shardings = ( out_sharding, @@ -848,15 +905,16 @@ def partition( ) def sharded_impl( - lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_amax, out_scale + lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out, out_amax, out_scale, + extra_out, ): ( - out, + out_updated, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad, - extra_out, + extra_out_updated, ) = CollectiveGemmPrimitive.impl( lhs, lhs_scale_inv, @@ -864,9 +922,10 @@ def sharded_impl( rhs_scale_inv, bias, gelu_input, + out, out_amax, out_scale, - out_dtype=out_dtype, + extra_out, batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, @@ -875,6 +934,7 @@ def sharded_impl( accumulate=accumulate, use_split_accumulator=use_split_accumulator, comm_overlap_config=comm_overlap_config, + sharded_abstract=True, ) # FP8 amax reduction @@ -882,15 +942,19 @@ def sharded_impl( out_amax_updated = all_reduce_max_along_all_axes_except_PP(out_amax_updated, mesh) # All-reduce sum GEMM output when contracting dimensions are sharded - if comm_overlap_config is None: - if reduce_output: - out = jax.lax.psum(out, global_mesh_resource().tp_resource) - if fuse_gelu: - pre_gelu_out = jax.lax.psum( - pre_gelu_out, global_mesh_resource().tp_resource - ) + if comm_overlap_config is None and reduce_output: + out_updated = jax.lax.psum(out_updated, global_mesh_resource().tp_resource) + if fuse_gelu: + pre_gelu_out = jax.lax.psum(pre_gelu_out, global_mesh_resource().tp_resource) - return out, out_amax_updated, out_scale_updated, pre_gelu_out, bias_grad, extra_out + return ( + out_updated, + out_amax_updated, + out_scale_updated, + pre_gelu_out, + bias_grad, + extra_out_updated + ) return mesh, sharded_impl, out_shardings, arg_shardings @@ -903,6 +967,8 @@ def gemm_impl( rhs: ArrayLike, bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, + out: Optional[ArrayLike] = None, + extra_out: Optional[ArrayLike] = None, batched_output: bool = False, contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, @@ -917,7 +983,24 @@ def gemm_impl( lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) lhs_outer_dim = lhs.ndim - 1 if lhs_inner_dim != lhs.ndim - 1 else lhs.ndim - 2 rhs_outer_dim = rhs.ndim - 2 if rhs_inner_dim == rhs.ndim - 1 else rhs.ndim - 1 - out_shape = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + + out_shape_batched = (*lhs.shape[:-2], lhs.shape[lhs_outer_dim], rhs.shape[rhs_outer_dim]) + out_shape_2d = (reduce(operator.mul, out_shape_batched[:-1], 1), out_shape_batched[-1]) + out_shape = out_shape_batched if batched_output else out_shape_2d + + if out is None: + out = jnp.zeros(out_shape, dtype=lhs.dtype) + + if extra_out is None: + extra_out_shape = 0 + if (comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk"): + comm_type = comm_overlap_config["comm_type"] + if comm_type == tex.CommOverlapType.AG: + extra_out_shape = list(lhs.shape).copy() + elif comm_type == tex.CommOverlapType.RS: + extra_out_shape = list(out_shape).copy() + extra_out = jnp.zeros(extra_out_shape, dtype=lhs.dtype) if not fuse_bias: bias = jnp.zeros(0, dtype=lhs.dtype) @@ -929,11 +1012,11 @@ def gemm_impl( if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=lhs.dtype) elif grad: - assert ( - gelu_input is not None - ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." + assert gelu_input is not None, ( + "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." + ) elif gelu_input is None: - gelu_input = jnp.zeros(out_shape, dtype=lhs.dtypes) + gelu_input = jnp.zeros(out_shape_2d, dtype=lhs.dtype) ( out, @@ -949,9 +1032,10 @@ def gemm_impl( dummy_fp8_meta, bias, gelu_input, + out, dummy_fp8_meta, dummy_fp8_meta, - out_dtype=lhs.dtype, + extra_out, batched_output=batched_output, contracting_dims=contracting_dims, fuse_gelu=fuse_gelu, @@ -960,6 +1044,7 @@ def gemm_impl( accumulate=accumulate, use_split_accumulator=use_split_accumulator, comm_overlap_config=comm_overlap_config, + sharded_abstract=False, ) if grad: @@ -975,6 +1060,7 @@ def fp8_gemm_impl( rhs_scale_inv: ArrayLike, bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, + out: Optional[ArrayLike] = None, out_amax: Optional[ArrayLike] = None, out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, @@ -986,7 +1072,27 @@ def fp8_gemm_impl( comm_overlap_config: Optional[dict] = None, ) -> Tuple[ArrayLike, ...]: """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" - if out_dtype is not None and jax_dtype_is_fp8(out_dtype): + out_shape_batched = (*lhs.shape[:-2], lhs.shape[-1], rhs_t.shape[-1]) + out_shape_2d = (reduce(operator.mul, out_shape_batched[:-1], 1), out_shape_batched[-1]) + out_shape = out_shape_batched if batched_output else out_shape_2d + + if out is None: + out = jnp.zeros(out_shape, dtype=out_dtype) + else: + out_dtype = out.dtype + + if extra_out is None: + extra_out_shape = 0 + if (comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk"): + comm_type = comm_overlap_config["comm_type"] + if comm_type == tex.CommOverlapType.AG: + extra_out_shape = list(lhs.shape).copy() + elif comm_type == tex.CommOverlapType.RS: + extra_out_shape = list(out_shape).copy() + extra_out = jnp.zeros(extra_out_shape, dtype=jnp.bfloat16) + + if jax_dtype_is_fp8(out_dtype): assert out_amax is not None and out_scale is not None, "Missing output amax and scale." else: out_amax = jnp.zeros(0, dtype=jnp.float32) @@ -1000,8 +1106,7 @@ def fp8_gemm_impl( if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=bias.dtype) elif gelu_input is None: - gelu_shape = (reduce(operator.mul, lhs.shape[:-1]), rhs_t.shape[-1]) - gelu_input = jnp.zeros(gelu_shape, dtype=bias.dtype) + gelu_input = jnp.zeros(out_shape_2d, dtype=bias.dtype) (out, out_amax, out_scale, pre_gelu_out, _, extra_out) = ( # bias_grad in non-FP8 GEMM CollectiveGemmPrimitive.outer_primitive.bind( @@ -1011,9 +1116,10 @@ def fp8_gemm_impl( lhs_scale_inv, bias, gelu_input, + out, out_amax, out_scale, - out_dtype=out_dtype, + extra_out, batched_output=batched_output, contracting_dims=(-1, -1), fuse_gelu=fuse_gelu, @@ -1022,86 +1128,190 @@ def fp8_gemm_impl( accumulate=accumulate, use_split_accumulator=use_split_accumulator, comm_overlap_config=comm_overlap_config, + sharded_abstract=False, ) ) return out, out_amax, out_scale, pre_gelu_out, extra_out +class BootstrapCommGemmOverlapPrimitive(BasePrimitive): + """ + Initialize Comm+GEMM overlap communicators and buffers + """ + + name = "te_bootstrap_comm_gemm_overlap_ffi" + impl_static_args = (1,) + multiple_results = False + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(buffer_aval, myrank, numranks, comm_overlap_config): + del myrank, numranks + assert is_ffi_enabled(), "Comm+GEMM overlap is supported only via XLA FFI." + overlap_name = comm_overlap_config.get("name", None) + assert overlap_name in _COMM_GEMM_OVERLAP_NAMES, ( + f"Unrecognized comm+GEMM overlap name: {overlap_name=}" + ) + assert buffer_aval.size > 0, "Cannot initialize a zero-size communication buffer." + return jax.core.ShapedArray(shape=(0,), dtype=dtypes.canonicalize_dtype(buffer_aval.dtype)) + + @staticmethod + def lowering(ctx, buffer, *, myrank, numranks, comm_overlap_config): + return ffi.ffi_lowering(BootstrapCommGemmOverlapPrimitive.name)( + ctx, + buffer, + name=comm_overlap_config["name"], + method=comm_overlap_config["method"], + myrank=myrank, + numranks=numranks, + tp_size=comm_overlap_config["tp_size"], + num_splits=comm_overlap_config["num_splits"], + num_max_streams=comm_overlap_config["num_max_streams"], + cga_size=comm_overlap_config["cga_size"], + num_comm_sm=comm_overlap_config["num_sm"], + set_sm_margin=comm_overlap_config["set_sm_margin"], + use_ce=comm_overlap_config["use_ce"], + atomic_gemm=comm_overlap_config["atomic_gemm"], + aggregate=comm_overlap_config["aggregate"], + pipeline_rs_overlap_first_gemm=comm_overlap_config["pipeline_rs_overlap_first_gemm"], + ) + + @staticmethod + def impl(buffer, myrank, numranks, comm_overlap_config): + assert BootstrapCommGemmOverlapPrimitive.inner_primitive is not None + buffer = jax.lax.reshape( + buffer, (reduce(operator.mul, buffer.shape[:-1], 1), buffer.shape[-1]) + ) + return BootstrapCommGemmOverlapPrimitive.inner_primitive.bind( + buffer, myrank=myrank, numranks=numranks, comm_overlap_config=comm_overlap_config, + ) + + @staticmethod + def batcher(batched_args, batch_dims, *, myrank, numranks, comm_overlap_config): + assert BootstrapCommGemmOverlapPrimitive.inner_primitive is not None + check_valid_batch_dims(batch_dims) + return ( + BootstrapCommGemmOverlapPrimitive.inner_primitive.bind( + *batched_args, myrank=myrank, numranks=numranks, comm_overlap_config=comm_overlap_config + ), + None, + ) + + @staticmethod + def infer_sharding_from_operands(myrank, numranks, comm_overlap_config, mesh, arg_infos, + result_infos): + del myrank, numranks, comm_overlap_config, result_infos + buffer_spec = get_padded_spec(arg_infos[0]) + assert all([spec is None for spec in buffer_spec]), "Sample buffer must be unsharded." + return NamedSharding(mesh, PartitionSpec(None)) + + @staticmethod + def partition(myrank, numranks, comm_overlap_config, mesh, arg_infos, result_infos): + del arg_infos, result_infos + arg_shardings = (NamedSharding(mesh, PartitionSpec(None)),) + out_sharding = NamedSharding(mesh, PartitionSpec(None)) + return ( + mesh, + partial(BootstrapCommGemmOverlapPrimitive.impl, + myrank=myrank, + numranks=numranks, + comm_overlap_config=comm_overlap_config), + out_sharding, + arg_shardings, + ) + + +register_primitive(BootstrapCommGemmOverlapPrimitive) + + +def bootstrap_comm_gemm_overlap( + buffer: ArrayLike, + myrank: int, + numranks: int, + comm_overlap_config: dict +): + _ = BootstrapCommGemmOverlapPrimitive.outer_primitive.bind( + buffer, + myrank=myrank, + numranks=numranks, + comm_overlap_config=comm_overlap_config + ) + + class CopyIntoOverlapBufferPrimitive(BasePrimitive): """ Copy JAX array data into comm+GEMM overlap buffer """ - name = "te_copy_into_overlap_buffer" + name = "te_copy_into_overlap_buffer_ffi" impl_static_args = (1, 2) multiple_results = False inner_primitive = None outer_primitive = None @staticmethod - def abstract(inp_aval, name, comm_type): + def abstract(inp_aval, name, sharded): + del sharded + assert is_ffi_enabled(), "Comm+GEMM overlap is supported only via XLA FFI." assert name in _COMM_GEMM_OVERLAP_NAMES, f"Unrecognized comm+GEMM overlap name: {name=}" - assert comm_type in [ - tex.CommOverlapType.AG, - tex.CommOverlapType.RS, - ], "Invalid comm+GEMM overlap type." assert inp_aval.size > 0, "Cannot copy a zero-size array into overlap buffer." - assert inp_aval.ndim == 2, "Cannot copy more than 2 dimensions!" return jax.core.ShapedArray(shape=(0,), dtype=dtypes.canonicalize_dtype(inp_aval.dtype)) @staticmethod - def lowering(ctx, inp, *, name, comm_type): - if is_ffi_enabled(): - name = "te_copy_into_overlap_buffer_ffi" - return ffi.ffi_lowering(name)( - ctx, - inp, - name=name, - comm_type=int(comm_type), - ) - else: - operands = [inp] - operand_shapes = [ir.RankedTensorType(inp.type).shape] - out_types = [] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - opaque = tex.pack_buffer_descriptor( - name, inp.shape, jax_dtype_to_te_dtype(inp.dtype), comm_type - ) - return custom_caller(CopyIntoOverlapBufferPrimitive.name, args, opaque, False) + def lowering(ctx, inp, *, name, sharded): + return ffi.ffi_lowering(name)( + ctx, + inp, + name=name, + sharded=sharded, + ) @staticmethod - def impl(inp, name, comm_type): + def impl(inp, name, sharded): assert CopyIntoOverlapBufferPrimitive.inner_primitive is not None + inp_2d = jax.lax.reshape(inp, (reduce(operator.mul, inp.shape[:-1], 1), inp.shape[-1])) return CopyIntoOverlapBufferPrimitive.inner_primitive.bind( - inp, name=name, comm_type=comm_type + inp_2d, name=name, sharded=sharded ) @staticmethod - def batcher(batched_args, batch_dims, *, name, comm_type): + def batcher(batched_args, batch_dims, *, name, sharded): assert CopyIntoOverlapBufferPrimitive.inner_primitive is not None check_valid_batch_dims(batch_dims) return ( CopyIntoOverlapBufferPrimitive.inner_primitive.bind( - *batched_args, name=name, comm_type=comm_type + *batched_args, name=name, sharded=sharded ), None, ) @staticmethod - def infer_sharding_from_operands(name, comm_type, mesh, arg_infos, result_infos): - del name, comm_type, arg_infos, result_infos + def infer_sharding_from_operands(name, sharded, mesh, arg_infos, result_infos): + del name, result_infos + inp_spec = get_padded_spec(arg_infos[0]) + if sharded: + assert inp_spec[-2] is not None, ( + "Leading dimension of input tensor must be sharded in order to copy into a " + + "sharded communication tensor (e.g. preparing for bulk all-gather overlap)." + ) + else: + assert inp_spec[-2] is None, ( + "Leading dimension of input tensor cannot be sharded when copying into an " + + "unsharded communication tensor (e.g. preparing for bulk reduce-scatter overlap)." + ) return NamedSharding(mesh, PartitionSpec(None)) @staticmethod - def partition(name, comm_type, mesh, arg_infos, result_infos): - del name, comm_type, result_infos - inp_spec = arg_infos[0] + def partition(name, sharded, mesh, arg_infos, result_infos): + del name, sharded, result_infos + inp_spec = get_padded_spec(arg_infos[0]) arg_shardings = (NamedSharding(mesh, PartitionSpec(*inp_spec)),) out_sharding = NamedSharding(mesh, PartitionSpec(None)) return ( mesh, - partial(CopyIntoOverlapBufferPrimitive.impl, name=name, comm_type=comm_type), + partial(CopyIntoOverlapBufferPrimitive.impl, name=name, sharded=sharded), out_sharding, arg_shardings, ) @@ -1110,5 +1320,5 @@ def partition(name, comm_type, mesh, arg_infos, result_infos): register_primitive(CopyIntoOverlapBufferPrimitive) -def copy_into_overlap_buffer(inp: ArrayLike, name: str, comm_type: tex.CommOverlapType) -> None: - _ = CollectiveGemmPrimitive.outer_primitive.bind(inp, name=name, comm_type=comm_type) +def copy_into_overlap_buffer(inp: ArrayLike, name: str, sharded: bool) -> None: + _ = CopyIntoOverlapBufferPrimitive.outer_primitive.bind(inp, name=name, sharded=sharded) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index d123d9b5b4..fd0786a040 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -171,44 +171,6 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, size_ bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator); -struct CustomCallBufferDescriptor { - const std::string name; - const size_t *shape; - const size_t ndim; - DType dtype; - CommOverlapType comm_type; -}; - -pybind11::bytes PackCustomCallBufferDescriptor(const std::string &name, - const std::vector &shape, DType dtype, - CommOverlapType comm_type); - -struct CustomCallOverlapDescriptor { - size_t m; - size_t k; - size_t n; - size_t workspace_size; - DType operand_dtype; - DType bias_dtype; - DType out_dtype; - bool lhs_trans; - bool rhs_trans; - bool fuse_gelu; - bool fuse_bias; - bool grad; - bool accumulate; - bool use_split_accumulator; - CommOverlapType comm_type; - const std::string name; -}; - -pybind11::bytes PackCustomCallOverlapDescriptor(size_t m, size_t k, size_t n, size_t workspace_size, - DType operand_dtype, DType bias_dtype, - DType out_dtype, bool lhs_trans, bool rhs_trans, - bool fuse_gelu, bool fuse_bias, bool grad, - bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, const std::string &name); - // Transpose void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); @@ -372,54 +334,63 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); // GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasltHandleInitHandler); + void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, - Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, - Result_Type out_amax_updated, Result_Type out_scale_updated, - Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type dummy_out, - Result_Type workspace, bool lhs_trans, bool rhs_trans, bool fuse_gelu, - bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator); +Error_Type GemmFFI( + cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out, + Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type dummy_in, Result_Type out_updated, + Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out, + Result_Type bias_grad, Result_Type dummy_out, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); // Comm+GEMM Overlap -void BootstrapCommGemmOverlap(const std::string &name, const std::string &method, - const std::vector &buffer_shape, DType buffer_dtype, - CommOverlapType comm_type, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, - int set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate, - bool pipeline_rs_overlap_first_gemm); +bool OverlapBufferIsFp8(const std::string &name); -void DestroyCommGemmOverlap(const std::string &name); +pybind11::object GetOverlapBuffer(const std::string &name, bool sharded); -void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, - bool grad = false); +void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, bool grad); -bool OverlapBufferIsFp8(const std::string &name); +void BootstrapCommGemmOverlap( + const std::vector &buffer_shape, DType buffer_dtype, const std::string &name, + const std::string &method, CommOverlapType comm_type, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, int64_t cga_size, + int64_t num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm); + +Error_Type BootstrapCommGemmOverlapFFI( + cudaStream_t, Buffer_Type sample_buffer, std::string_view name, std::string_view method, + int64_t comm_type_flag, int64_t myrank, int64_t numranks, int64_t tp_size, int64_t num_splits, + int64_t num_max_streams, int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, bool pipeline_rs_overlap_first_gemm); + +XLA_FFI_DECLARE_HANDLER_SYMBOL(BootstrapCommGemmOverlapHandler); + +void DestroyCommGemmOverlap(const std::string &name); -pybind11::object GetOverlapBuffer(const std::string &name, CommOverlapType comm_type); +Error_Type DestroyCommGemmOverlapFFI(cudaStream_t stream, std::string_view name); -void CopyIntoOverlapBuffer(cudaStream_t, void **buffers, const char *opaque, size_t opaque_len); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DestroyCommGemmOverlapHandler); Error_Type CopyIntoOverlapBufferFFI(cudaStream_t stream, Buffer_Type input, std::string_view name, - int32_t comm_type_flag); + bool sharded); XLA_FFI_DECLARE_HANDLER_SYMBOL(CopyIntoOverlapBufferHandler); -void CommGemmOverlap(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); - -Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, - Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, - Buffer_Type gelu_input, Buffer_Type out_amax, Buffer_Type out_scale, - Result_Type out, Result_Type out_amax_new, Result_Type out_scale_new, - Result_Type pre_gelu_out, Result_Type bias_grad, - Result_Type extra_out, Result_Type workspace, bool lhs_trans, - bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, - bool accumulate, bool use_split_accumulator, int32_t comm_type_flag, - std::string_view name); +Error_Type CommGemmOverlapFFI( + cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out, + Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated, + Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out, + Result_Type bias_grad, Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator, int64_t comm_type_flag, std::string_view name); XLA_FFI_DECLARE_HANDLER_SYMBOL(CommGemmOverlapHandler); diff --git a/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp index df1f4bdc23..d6f5daaa80 100644 --- a/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp @@ -16,14 +16,38 @@ namespace transformer_engine { namespace jax { +Error_Type CublasltHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets, + Dictionary attrs) { + cublasLtHandle_t handle; + NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasltHandleInitHandler, CublasltHandleInitFFI, + FFI::Bind().RemainingArgs().RemainingRets().Attrs()); + static std::unordered_map _overlaps; -void BootstrapCommGemmOverlap(const std::string &name, const std::string &method, - const std::vector &buffer_shape, DType buffer_dtype, - CommOverlapType comm_type, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, - int set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate, - bool pipeline_rs_overlap_first_gemm) { +void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, bool grad) { + auto scale_inv_tensor = DLPackWrapper(scale_inv, grad); + _overlaps[name]->set_ubuf_scale_inv(reinterpret_cast(scale_inv_tensor.dptr())); +} + +bool OverlapBufferIsFp8(const std::string &name) { return _overlaps[name]->is_fp8_ubuf(); } + +pybind11::object GetOverlapBuffer(const std::string &name, bool sharded) { + auto comm_type = (sharded) ? CommOverlapType::RS : CommOverlapType::AG; + DLPackWrapper output = std::move(_overlaps[name]->get_ubuf_output(comm_type)); + auto capsule = output.capsule(); + return capsule; +}; + +void BootstrapCommGemmOverlap( + const std::vector &buffer_shape, DType buffer_dtype, const std::string &name, + const std::string &method, CommOverlapType comm_type, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, int64_t comm_cga_size, + int64_t num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm) { #ifndef NVTE_UB_WITH_MPI NVTE_ERROR( std::string("Comm+GEMM overlap in TE/JAX requires bootstrapping Userbuffers with MPI. ") + @@ -32,19 +56,56 @@ void BootstrapCommGemmOverlap(const std::string &name, const std::string &method // Initialize overlap object -- this allocates the comm buffer NVTE_CHECK(_overlaps.find(name) == _overlaps.end(), name, " is already initialized!"); - if (method == "ring-exchange") { - _overlaps[name] = reinterpret_cast(new CommOverlapP2PBase( - buffer_shape, buffer_dtype, -1, -1, -1, -1, -1, -1, tp_size, &_dummy_allgather, + if (method == "ring_exchange") { + _overlaps[name] = new CommOverlapP2PBase( + buffer_shape, buffer_dtype, myrank, numranks, -1, -1, -1, -1, tp_size, &_dummy_allgather, &_dummy_barrier, comm_type, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, - use_ce, atomic_gemm, aggregate)); + use_ce, atomic_gemm, aggregate); } else { - _overlaps[name] = reinterpret_cast(new CommOverlapBase( - buffer_shape, buffer_dtype, -1, -1, -1, -1, -1, -1, tp_size, &_dummy_allgather, + _overlaps[name] = new CommOverlapBase( + buffer_shape, buffer_dtype, myrank, numranks, -1, -1, -1, -1, tp_size, &_dummy_allgather, &_dummy_barrier, num_splits, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, - atomic_gemm, pipeline_rs_overlap_first_gemm)); + atomic_gemm, pipeline_rs_overlap_first_gemm); } }; +Error_Type BootstrapCommGemmOverlapFFI( + cudaStream_t, Buffer_Type sample_buffer, std::string_view name, std::string_view method, + int64_t comm_type_flag, int64_t myrank, int64_t numranks, int64_t tp_size, int64_t num_splits, + int64_t num_max_streams, int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, bool pipeline_rs_overlap_first_gemm) { + auto buffer_shape = std::vector(sample_buffer.dimensions().begin(), + sample_buffer.dimensions().end()); + auto buffer_dtype = convert_ffi_datatype_to_te_dtype(sample_buffer.element_type()); + BootstrapCommGemmOverlap( + buffer_shape, buffer_dtype, static_cast(name), static_cast(method), + static_cast(comm_type_flag), myrank, numranks, tp_size, num_splits, + num_max_streams, cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate, + pipeline_rs_overlap_first_gemm); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(BootstrapCommGemmOverlapHandler, BootstrapCommGemmOverlapFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // sample_buffer + .Attr("name") + .Attr("method") + .Attr("comm_type_flag") + .Attr("myrank") + .Attr("numranks") + .Attr("tp_size") + .Attr("num_splits") + .Attr("num_max_streams") + .Attr("cga_size") + .Attr("num_comm_sm") + .Attr("set_sm_margin") + .Attr("use_ce") + .Attr("atomic_gemm") + .Attr("aggregate") + .Attr("pipeline_rs_overlap_first_gemm"), + FFI_CudaGraph_Traits); + void DestroyCommGemmOverlap(const std::string &name) { auto overlap = _overlaps.find(name); if (overlap != _overlaps.end()) { @@ -53,45 +114,33 @@ void DestroyCommGemmOverlap(const std::string &name) { } }; -void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, bool grad) { - auto scale_inv_tensor = DLPackWrapper(scale_inv, grad); - _overlaps[name]->set_ubuf_scale_inv(reinterpret_cast(scale_inv_tensor.dptr())); +Error_Type DestroyCommGemmOverlapFFI(cudaStream_t stream, std::string_view name) { + DestroyCommGemmOverlap(static_cast(name)); + return ffi_with_cuda_error_check(); } -bool OverlapBufferIsFp8(const std::string &name) { return _overlaps[name]->is_fp8_ubuf(); } - -pybind11::object GetOverlapBuffer(const std::string &name, CommOverlapType comm_type) { - DLPackWrapper output = std::move(_overlaps[name]->get_ubuf_output(comm_type)); - auto capsule = output.capsule(); - return capsule; -}; +XLA_FFI_DEFINE_HANDLER_SYMBOL(DestroyComMGemmOverlapHandler, DestroyCommGemmOverlapFFI, + FFI::Bind() + .Ctx() + .Attr("name"), + FFI_CudaGraph_Traits); void CopyIntoOverlapBufferImpl(cudaStream_t stream, void *input_ptr, const std::vector &shape, DType dtype, - const std::string &name, CommOverlapType comm_type) { + const std::string &name, bool sharded) { auto input = TensorWrapper(input_ptr, shape, dtype); + auto comm_type = (sharded) ? CommOverlapType::RS : CommOverlapType::AG; _overlaps[name]->copy_into_ubuf(stream, input, comm_type); } -void CopyIntoOverlapBuffer(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - auto input_ptr = buffers[0]; - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - - CopyIntoOverlapBufferImpl(stream, input_ptr, - std::vector(desc.shape, desc.shape + desc.ndim), desc.dtype, - desc.name, desc.comm_type); -} - Error_Type CopyIntoOverlapBufferFFI(cudaStream_t stream, Buffer_Type input, std::string_view name, - int32_t comm_type_flag) { + bool sharded) { auto input_ptr = input.untyped_data(); auto shape = std::vector(input.dimensions().begin(), input.dimensions().end()); auto dtype = convert_ffi_datatype_to_te_dtype(input.element_type()); CopyIntoOverlapBufferImpl(stream, input_ptr, shape, dtype, static_cast(name), - static_cast(comm_type_flag)); + sharded); return ffi_with_cuda_error_check(); } @@ -101,7 +150,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CopyIntoOverlapBufferHandler, CopyIntoOverlapBuffe .Ctx() // stream .Arg() // input .Attr("name") - .Attr("comm_type_flag"), + .Attr("sharded"), FFI_CudaGraph_Traits); void CommGemmOverlapImpl(void *lhs, const std::vector &lhs_shape, DType lhs_dtype, @@ -156,59 +205,14 @@ void CommGemmOverlapImpl(void *lhs, const std::vector &lhs_shape, DType } } -void CommGemmOverlap(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { - // Inputs - auto lhs = buffers[0]; - auto lhs_scale_inv = reinterpret_cast(buffers[1]); - auto rhs = buffers[2]; - auto rhs_scale_inv = reinterpret_cast(buffers[3]); - auto bias = buffers[4]; - auto gelu_input = buffers[5]; - auto out_amax = reinterpret_cast(buffers[6]); - auto out_scale = reinterpret_cast(buffers[7]); - - // Outputs - auto out = buffers[8]; - auto out_amax_new = reinterpret_cast(buffers[9]); - auto out_scale_new = reinterpret_cast(buffers[10]); - auto pre_gelu_out = buffers[11]; - auto bias_grad = buffers[12]; - auto extra_out = buffers[13]; - auto workspace = buffers[14]; - - // Check operand-output aliases - NVTE_CHECK(bias == bias_grad, "bias not bound to bias_grad in AG+GEMM overlap."); - NVTE_CHECK(gelu_input == pre_gelu_out, - "gelu_input not bound to pre_gelu_out in AG+GEMM overlap."); - NVTE_CHECK(out_amax == out_amax_new, "out_amax not bound to out_amax_new in AG+GEMM overlap."); - NVTE_CHECK(out_scale == out_scale_new, - "out_scale not bound to out_scale_new in AG+GEMM overlap."); - - const auto &desc = *UnpackOpaque(opaque, opaque_len); - - auto lhs_shape = - (desc.lhs_trans) ? std::vector{desc.k, desc.m} : std::vector{desc.m, desc.k}; - auto rhs_shape = - (desc.rhs_trans) ? std::vector{desc.n, desc.k} : std::vector{desc.k, desc.n}; - auto out_shape = std::vector{desc.m, desc.n}; - - CommGemmOverlapImpl(lhs, lhs_shape, desc.operand_dtype, lhs_scale_inv, desc.lhs_trans, rhs, - rhs_shape, desc.operand_dtype, rhs_scale_inv, desc.rhs_trans, out, out_shape, - desc.out_dtype, out_amax, out_scale, bias, desc.bias_dtype, pre_gelu_out, - extra_out, lhs_shape, workspace, desc.workspace_size, desc.fuse_gelu, - desc.fuse_bias, desc.grad, desc.accumulate, desc.use_split_accumulator, - desc.comm_type, desc.name, stream); -} - -Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, - Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, - Buffer_Type gelu_input, Buffer_Type out_amax, Buffer_Type out_scale, - Result_Type out, Result_Type out_amax_new, Result_Type out_scale_new, - Result_Type pre_gelu_out, Result_Type bias_grad, - Result_Type extra_out, Result_Type workspace, bool lhs_trans, - bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, - bool accumulate, bool use_split_accumulator, int32_t comm_type_flag, - std::string_view name) { +Error_Type CommGemmOverlapFFI( + cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out, + Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated, + Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out, + Result_Type bias_grad, Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator, int64_t comm_type_flag, std::string_view name) { // Inputs auto lhs_ptr = lhs.untyped_data(); auto lhs_shape = std::vector(lhs.dimensions().begin(), lhs.dimensions().end()); @@ -221,31 +225,38 @@ Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type auto bias_ptr = bias.untyped_data(); auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_ptr = out.untyped_data(); auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + auto extra_out_ptr = extra_out.untyped_data(); // Outputs - auto out_ptr = out->untyped_data(); - auto out_shape = std::vector(out->dimensions().begin(), out->dimensions().end()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type()); - auto out_amax_new_ptr = reinterpret_cast(out_amax_new->untyped_data()); - auto out_scale_new_ptr = reinterpret_cast(out_scale_new->untyped_data()); + auto out_updated_ptr = out_updated->untyped_data(); + auto out_shape = std::vector(out_updated->dimensions().begin(), + out_updated->dimensions().end()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out_updated->element_type()); + auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); + auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); auto pre_gelu_ptr = pre_gelu_out->untyped_data(); auto bias_grad_ptr = bias_grad->untyped_data(); - auto extra_out_ptr = extra_out->untyped_data(); - auto extra_out_shape = - std::vector(extra_out->dimensions().begin(), extra_out->dimensions().end()); + auto extra_out_updated_ptr = extra_out_updated->untyped_data(); + auto extra_out_shape = std::vector(extra_out_updated->dimensions().begin(), + extra_out_updated->dimensions().end()); auto workspace_ptr = workspace->untyped_data(); auto workspace_size = workspace->element_count(); // Check operand-output aliases - NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in AG+GEMM overlap."); + NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX comm+GEMM overlap."); NVTE_CHECK(gelu_input_ptr == pre_gelu_ptr, - "gelu_input not bound to pre_gelu_out in AG+GEMM overlap."); - NVTE_CHECK(out_amax_ptr == out_amax_new_ptr, - "out_amax not bound to out_amax_new in AG+GEMM overlap."); - NVTE_CHECK(out_scale_ptr == out_scale_new_ptr, - "out_scale not bound to out_scale_new in AG+GEMM overlap."); + "gelu_input not bound to pre_gelu_out in TE/JAX comm+GEMM overlap."); + NVTE_CHECK(out_ptr == out_updated_ptr, + "out not bound to out_updated in TE/JAX comm+GEMM overlap."); + NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, + "out_amax not bound to out_amax_updated in TE/JAX comm+GEMM overlap."); + NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, + "out_scale not bound to out_scale_updated in TE/JAX comm+GEMM overlap."); + NVTE_CHECK(extra_out_ptr == extra_out_updated_ptr, + "extra_out not bound to extra_out_updated in TE/JAX comm+GEMM overlap."); CommGemmOverlapImpl( lhs_ptr, lhs_shape, lhs_dtype, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, rhs_dtype, @@ -266,14 +277,16 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CommGemmOverlapHandler, CommGemmOverlapFFI, .Arg() // rhs_scale_inv .Arg() // bias .Arg() // gelu_input + .Arg() // out .Arg() // out_amax .Arg() // out_scale - .Ret() // out - .Ret() // out_amax_new - .Ret() // out_scale_new + .Arg() // extra_out + .Ret() // out_updated + .Ret() // out_amax_updated + .Ret() // out_scale_updated .Ret() // pre_gelu_out .Ret() // bias_grad - .Ret() // extra_out + .Ret() // extra_out_updated .Ret() // workspace .Attr("lhs_trans") .Attr("rhs_trans") @@ -282,7 +295,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CommGemmOverlapHandler, CommGemmOverlapFFI, .Attr("grad") .Attr("accumulate") .Attr("use_split_accumulator") - .Attr("comm_type_flag") + .Attr("comm_type_flag") .Attr("name"), FFI_CudaGraph_Traits); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 14148ecbd0..8f6f907268 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -53,21 +53,24 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque auto *rhs_scale_inv = reinterpret_cast(buffers[3]); auto *bias = buffers[4]; auto *gelu_input = buffers[5]; - auto *out_amax = reinterpret_cast(buffers[6]); - auto *out_scale = reinterpret_cast(buffers[7]); + auto *out = buffers[6]; + auto *out_amax = reinterpret_cast(buffers[7]); + auto *out_scale = reinterpret_cast(buffers[8]); + // buffers[9] is the extra output bufer for comm+GEMM overlap, not used here // Outputs - auto *out = buffers[8]; - auto *out_amax_updated = reinterpret_cast(buffers[9]); - auto *out_scale_updated = reinterpret_cast(buffers[10]); - auto *pre_gelu_out = buffers[11]; - auto *bias_grad = buffers[12]; - // buffers[13] is the extra output for comm+GEMM overlap, not used here - auto *workspace = buffers[14]; + auto *out_updated = buffers[10]; + auto *out_amax_updated = reinterpret_cast(buffers[11]); + auto *out_scale_updated = reinterpret_cast(buffers[12]); + auto *pre_gelu_out = buffers[13]; + auto *bias_grad = buffers[14]; + // buffers[15] is the updated extra output for comm+GEMM overlap, not used here + auto *workspace = buffers[16]; // Operand aliasing NVTE_CHECK(bias == bias_grad, "bias not bound to bias_grad in TE/JAX GEMM"); NVTE_CHECK(gelu_input == pre_gelu_out, "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out == out_updated, "out not bound to out_updated in TE/JAX GEMM"); NVTE_CHECK(out_amax == out_amax_updated, "out_amax not bound to out_amax_updated in TE/JAX GEMM"); NVTE_CHECK(out_scale == out_scale_updated, "out_scale not bound to out_scale_updated in TE/JAX GEMM"); @@ -85,13 +88,15 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque desc.fuse_bias, desc.grad, desc.accumulate, desc.use_split_accumulator); } -Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, - Buffer_Type out_amax, Buffer_Type out_scale, Result_Type out, - Result_Type out_amax_updated, Result_Type out_scale_updated, - Result_Type pre_gelu_out, Result_Type bias_grad, Result_Type dummy_out, - Result_Type workspace, bool lhs_trans, bool rhs_trans, bool fuse_gelu, - bool fuse_bias, bool grad, bool accumulate, bool use_split_accumulator) { +Error_Type GemmFFI( + cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out, + Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type dummy_in, Result_Type out_updated, + Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out, + Result_Type bias_grad, Result_Type dummy_out, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator +) { // Inputs auto lhs_ptr = lhs.untyped_data(); auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); @@ -101,17 +106,19 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto bias_ptr = bias.untyped_data(); auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); auto gelu_input_ptr = gelu_input.untyped_data(); + auto out_ptr = out.untyped_data(); auto out_amax_ptr = reinterpret_cast(out_amax.untyped_data()); auto out_scale_ptr = reinterpret_cast(out_scale.untyped_data()); + // dummy_in is the extra output buffer for comm+GEMM overlap, not used here // Outputs - auto out_ptr = out->untyped_data(); + auto out_updated_ptr = out_updated->untyped_data(); auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(out->element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(out_updated->element_type()); auto pre_gelu_out_ptr = pre_gelu_out->untyped_data(); auto bias_grad_ptr = bias_grad->untyped_data(); - // dummy_out is the extra output for comm+GEMM overlap, not used here + // dummy_out is the updated extra output for comm+GEMM overlap, not used here auto workspace_ptr = workspace->untyped_data(); auto workspace_size = workspace->dimensions().back(); @@ -119,6 +126,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i NVTE_CHECK(bias_ptr == bias_grad_ptr, "bias not bound to bias_grad in TE/JAX GEMM"); NVTE_CHECK(gelu_input_ptr == pre_gelu_out_ptr, "gelu_input not bound to pre_gelu_out in TE/JAX GEMM"); + NVTE_CHECK(out_ptr == out_updated_ptr, "out not bound to out_updated in TE/JAX GEMM"); NVTE_CHECK(out_amax_ptr == out_amax_updated_ptr, "out_amax not bound to out_amax_updated in TE/JAX GEMM"); NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, @@ -146,9 +154,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Arg() // rhs_scale_inv .Arg() // bias .Arg() // gelu_input + .Arg() // out .Arg() // out_amax .Arg() // out_scale - .Ret() // out + .Arg() // dummy_in + .Ret() // out_updated .Ret() // out_amax_updated .Ret() // out_scale_updated .Ret() // pre_gelu_out diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 31a53529e3..dd4070af41 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -90,24 +90,5 @@ pybind11::bytes PackCustomCallGemmDescriptor(size_t m, size_t n, size_t k, size_ grad, accumulate, use_split_accumulator}); } -pybind11::bytes PackCustomCallBufferDescriptor(const std::string &name, - const std::vector &shape, DType dtype, - CommOverlapType comm_type) { - return PackOpaque( - {name, shape.data(), shape.size(), dtype, comm_type}); -} - -pybind11::bytes PackCustomCallOverlapDescriptor(size_t m, size_t k, size_t n, size_t workspace_size, - DType operand_dtype, DType bias_dtype, - DType out_dtype, bool lhs_trans, bool rhs_trans, - bool fuse_gelu, bool fuse_bias, bool grad, - bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, - const std::string &name) { - return PackOpaque( - {m, n, k, workspace_size, operand_dtype, bias_dtype, out_dtype, lhs_trans, rhs_trans, - fuse_gelu, fuse_bias, grad, accumulate, use_split_accumulator, comm_type, name}); -} - } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 2bf13a600d..c61e9c8127 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -54,8 +54,6 @@ pybind11::dict Registrations() { dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); dict["te_gemm"] = EncapsulateFunction(Gemm); - dict["te_copy_into_overlap_buffer"] = EncapsulateFunction(CopyIntoOverlapBuffer); - dict["te_comm_gemm_overlap"] = EncapsulateFunction(CommGemmOverlap); // Transpose dict["te_transpose_ffi"] = EncapsulateFFI(TransposeHandler); @@ -106,9 +104,18 @@ pybind11::dict Registrations() { fused_attn_backward_ffi["execute"] = EncapsulateFFI(FusedAttnBackwardHandler); dict["te_fused_attn_backward_ffi"] = fused_attn_backward_ffi; - dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); + pybind11::dict gemm_ffi; + gemm_ffi["prepare"] = EncapsulateFFI(CublasltHandleInitHandler); + gemm_ffi["execute"] = EncapsulateFFI(GemmHandler); + dict["te_gemm_ffi"] = gemm_ffi; + + dict["te_bootstrap_comm_gemm_overlap_ffi"] = EncapsulateFFI(BootstrapCommGemmOverlapHandler); dict["te_copy_into_overlap_buffer_ffi"] = EncapsulateFFI(CopyIntoOverlapBufferHandler); - dict["te_comm_gemm_overlap_ffi"] = EncapsulateFFI(CommGemmOverlapHandler); + + pybind11::dict comm_gemm_overlap_ffi; + comm_gemm_overlap_ffi["prepare"] = EncapsulateFFI(CublasltHandleInitHandler); + comm_gemm_overlap_ffi["execute"] = EncapsulateFFI(CommGemmOverlapHandler); + dict["te_comm_gemm_overlap_ffi"] = comm_gemm_overlap_ffi; return dict; } @@ -125,8 +132,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("pack_softmax_descriptor", &PackCustomCallSoftmaxDescriptor); m.def("pack_fused_attn_descriptor", &PackCustomCallFusedAttnDescriptor); m.def("pack_gemm_descriptor", &PackCustomCallGemmDescriptor); - m.def("pack_buffer_descriptor", &PackCustomCallBufferDescriptor); - m.def("pack_overlap_descriptor", &PackCustomCallOverlapDescriptor); m.def("get_fused_attn_backend", &GetFusedAttnBackend); m.def("get_cuda_version", &GetCudaRuntimeVersion); m.def("get_cudnn_version", &GetCudnnRuntimeVersion); @@ -140,7 +145,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("bootstrap_comm_gemm_overlap", &BootstrapCommGemmOverlap); - m.def("destroy_comm_gemm_overlaps", &DestroyCommGemmOverlap); + m.def("destroy_comm_gemm_overlap", &DestroyCommGemmOverlap); m.def("set_buffer_scale_inv", &SetOverlapBufferScaleInverse, pybind11::arg(), pybind11::arg(), pybind11::arg("grad") = false); m.def("get_overlap_buffer", &GetOverlapBuffer); diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index e463f0ace2..9b9afd56ca 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -10,7 +10,6 @@ import jax import jax.numpy as jnp from jax.typing import ArrayLike -from jax.sharding import NamedSharding, PartitionSpec from transformer_engine import transformer_engine_jax as tex from .fp8 import FP8Helper, FP8MetaPackage @@ -34,6 +33,7 @@ "type_safe_gemm", "initialize_comm_gemm_overlaps", "destroy_comm_gemm_overlap", + "get_comm_gemm_overlap_config", ] _NUM_MAX_UB_STREAMS = 3 @@ -83,7 +83,13 @@ def gemm( """ comm_overlap_config = None if comm_overlap_name is not None: - comm_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(comm_overlap_name, None) + global _ACTIVE_COMM_GEMM_OVERLAPS + comm_overlap_layer = ( + comm_overlap_name + "_fprop" + if comm_overlap_name not in ["ag_gemm", "gemm_rs"] + else comm_overlap_name + ) + comm_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(comm_overlap_layer, None) if comm_overlap_config is None: warnings.warn( f"Comm+GEMM overlap for {comm_overlap_name} has not been initialized! " @@ -97,7 +103,7 @@ def gemm( ): if sanitize_dims(contracting_dims[0], x.ndim) != x.ndim - 1: x = jnp.matrix_transpose(x) - copy_into_overlap_buffer(x, comm_overlap_name, tex.CommOverlapType.RS) + copy_into_overlap_buffer(x, comm_overlap_name, True) return _gemm( x, @@ -151,11 +157,11 @@ def _gemm_fwd_rule( fuse_bias = bias is not None - # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) - # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) + # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) --> ([B], M, N/P) # # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) - # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) + # + # GEMM+RS: ([B], M, K/P) x (K/P, N) --(RS)--> ([B], M/P, N) out, pre_gelu_out, extra_out = gemm_impl( x, kernel, @@ -169,20 +175,15 @@ def _gemm_fwd_rule( comm_overlap_config=comm_overlap_config, ) - # Update returned and saved tensors based on comm+GEMM overlap - saved_x = x final_out = out - if comm_overlap_config is not None: - match comm_overlap_config.get("comm_type", None): - case tex.CommOverlapType.AG: - # AG overlap puts the all-gathered global LHS (X) into extra_out - saved_x = extra_out - case tex.CommOverlapType.RS: - # RS overlap puts the reduce-scattered sharded output into extra_out - final_out = extra_out + if (comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS): + # Non-bulk RS overlap output is in extra output, not usual output + final_out = extra_out ctx = ( - saved_x, + x, kernel, pre_gelu_out if fuse_gelu else None, fuse_bias, @@ -207,26 +208,47 @@ def _gemm_bwd_rule( ) dgrad_overlap_config = None + wgrad_overlap_config = None + dgrad_pre_rs = None if comm_overlap_config is not None: dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) + if (dgrad_overlap_config["method"] == "bulk" + and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.AG): + # If DGRAD is bulk overlap, copy input X into comm buffer to be all-gathered in + # preparation for WGRAD. + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + assert wgrad_overlap_config is not None, "Internal TE error!" + copy_into_overlap_buffer(x, dgrad_overlap_name, True) + + # Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the + # bulk RS overlap without an extra memcpy + dgrad_pre_rs = tex.get_overlap_buffer(wgrad_overlap_name, False) # FWD MODE: # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) - # (DP, TP, None) --(AG)--> (DP, None, None) x (None, TP) --> (DP, None, TP) # # GEMM+AR: ([B], M, K/P) x (K/P, N) --(AR)--> ([B], M, N) - # (DP, None, TP) x (TP, None) --(AR)--> (DP, None, None) + # + # GEMM+RS: ([B], M, K/P) x (K/P, N) --(RS)--> ([B], M/P, N) - # DGRAD: - # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ----(AR)----> ([B], M, K) - # (DP, None, TP) x (None, TP)^T --(AR)--> (DP, None, None) + # DGRAD w/o Overlap: + # AG+GEMM: ([B], M, N/P) x (K, N/P)^T ---(AR)---> ([B], M, K) + # + # GEMM+AR: ([B], M, N) x (K/P, N)^T ----> ([B], M, K/P) + # + # DGRAD w/ Overlap: + # AG+GEMM w/ DGRAD+RS Overlap: ([B], M, N/P) x (K, N/P)^T ---(RS)---> ([B], M/P, K) # - # GEMM+AR: ([B], M, N) x (K/P, N)^T ------> ([B], M, K/P) - # (DP, None, None) x (TP, None)^T --> (DP, None, TP) + # AG+GEMM w/ Bulk AG Overlap: ([B], M, N/P) x (K, N/P)^T -----> ([B], M, K) (deferred RS) + # ([B], M, K/P) --(Bulk AG)--> ([B], M, K) (needed in WGRAD) + # + # GEMM+RS: ([B], M/P, N) --(AG)--> ([B], M, N) x (K/P, N)^T ----> ([B], M, K/P) dgrad, dgelu, _, dgrad_extra_out = gemm_impl( grad, kernel, + out=dgrad_pre_rs, gelu_input=pre_gelu_out, batched_output=(x.ndim > 2), contracting_dims=(-1, kernel_outer_dim), @@ -238,38 +260,25 @@ def _gemm_bwd_rule( comm_overlap_config=dgrad_overlap_config, ) - # If dgrad overlapped reduce-scatter, set it to the RS output - if dgrad_overlap_config is not None: - if ( - dgrad_overlap_config["method"] != "bulk" - and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS - ): - dgrad = dgrad_extra_out - - # Collapse batch dimension for wgrad - wgrad_rhs = dgelu if fuse_gelu else grad - if x.ndim > 2: - # If x was originally transposed, we need to transpose it back in order to collapse - # the batch dims correctly. - if x_inner_dim == x.ndim - 2: - x = jnp.matrix_transpose(x) - batch_size = reduce(operator.mul, x.shape[:-2], 1) - x = jnp.reshape(x, (batch_size * x.shape[-2], x.shape[-1])) - wgrad_rhs = jnp.reshape(wgrad_rhs, (batch_size * wgrad_rhs.shape[-2], wgrad_rhs.shape[-1])) - - # Recover comm+GEMM overlap config for wgrad - wgrad_overlap_config = None - if comm_overlap_config is not None: - wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" - wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + if (dgrad_overlap_config is not None + and dgrad_overlap_config["method"] != "bulk" + and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS): + # Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor + dgrad = dgrad_extra_out - # WGRAD: + # WGRAD w/o Overlap: # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) - # (DP, 'tp', None)^T --(AG)-->(DP, None, None)^T x (DP, None, 'tp') --> (None, 'tp') # - # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K/P, N) - # (DP, None, 'tp')^T --(AG)--> (DP, None, None)^T x (DP, None, None) ----> (None, None) - wgrad_rhs = dgelu if fuse_gelu else grad + # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K, N) + # + # WGRAD w/ Overlap: + # AG+GEMM w/ DGRAD+RS Overlap: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # + # AG+GEMM w/ Bulk Overlaps: ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # ([B], M, K) --(Bulk RS)--> ([B], M/P, K) (finalize DGRAD) + # + # GEMM+RS: ([B], M, K/P)^T x ([B], M, N) --> (K/P, N) (re-use all-gathered GRAD from DGRAD) + wgrad_rhs = dgelu if fuse_gelu else (grad if comm_overlap_config is None else dgrad_extra_out) wgrad, _, bgrad, wgrad_extra_out = gemm_impl( x, wgrad_rhs, @@ -284,13 +293,9 @@ def _gemm_bwd_rule( comm_overlap_config=wgrad_overlap_config, ) - # If wgrad overlapped reduce-scatter, set it to the RS output if wgrad_overlap_config is not None: - if ( - wgrad_overlap_config["method"] != "bulk" - and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS - ): - wgrad = wgrad_extra_out + # DGRAD was reduce-scattered during WGRAD GEMM, so set DGRAD to WGRAD extra output here + dgrad = wgrad_extra_out if not fuse_bias: bgrad = None @@ -362,7 +367,7 @@ def fp8_gemm( and comm_overlap_config["method"] != "bulk" and comm_overlap_config["comm_type"] == tex.CommOverlapType.AG ): - copy_into_overlap_buffer(x, comm_overlap_name, tex.CommOverlapType.RS) + copy_into_overlap_buffer(x, comm_overlap_name, True) return _fp8_gemm( x, @@ -526,18 +531,12 @@ def _fp8_gemm_fwd_rule( # Update returned and saved arrays based on comm+GEMM overlap config final_out = out - saved_casted_x = casted_x if comm_overlap_config is not None: - match comm_overlap_config.get("comm_type", None): - case tex.CommOverlapType.AG: - # AG overlap puts all-gathered global LHS (X) array into extra_out - saved_casted_x = extra_out - case tex.CommOverlapType.RS: - # RS overlap puts the reduce-scattered sharded output into extra_out - final_out = extra_out + if comm_overlap_config["comm_type"] == tex.CommOverlapType.RS: + # RS overlap puts the reduce-scattered sharded output into extra_out + final_out = extra_out ctx = ( - saved_casted_x, casted_x_t, casted_kernel, amax_list, @@ -820,29 +819,59 @@ def type_safe_gemm( def initialize_comm_gemm_overlaps( buffer_shape: Sequence[int], - buffer_dtype: jnp.dtype, - mesh: Optional[jax.sharding.Mesh] = None, - tp_resource: Optional[str] = None, - use_fp8: bool = False, - overlap_configs: Optional[dict] = None, + mesh: jax.sharding.Mesh, + myrank: int, + numranks: int, + **kwargs: Optional[dict], ) -> None: + """ + Initialize Comm+GEMM overlap communicators and buffers. + + .. warning:: + Communication buffer allocations for this functionality are outside the XLA memory pool + and can cause OOM errors if XLA's memory margin is not reduced. + + Parameters + ---------- + buffer_shape : Sequence[int] + Shape of the communication buffer. This should be sized to match the global shape of the + input/activation tensor. + mesh : jax.sharding.Mesh + JAX Mesh with a `tp_resource` axis. + myrank: int + Global rank of the calling process. + numranks: int + Global number of processes. + tp_resource : Optional[str] = None + Tensor-parallel mesh axis name. If not given, defaults to the TP resource in the global + te.sharding.MeshResource context. + tp_size : Optional[int] = None + Size of the tensor-parallel axis in the mesh. If not given, defaults to the size of the + tensor-parallel axis in `jax.interpreters.pxla.thread_resources`. + use_fp8 : bool = False + Flag for allocating an FP8 communication buffer. This is not supported for reduce-scatter + overlaps with the `pipeline` method. + overlap_configs: Optional[dict] = None, + Dictionary of configs for comm+GEMM overlaps by layer name. + """ assert tex.ubuf_built_with_mpi(), ( "Comm+GEMM overlap in TE/JAX requires Transformer Engine to be compiled with " - + "`NVTE_UB_WITH_MPI=1` and `MPI_HOME=/path/to/mpi` options." + + "`NVTE_UB_WITH_MPI=1` and `MPI_HOME=/path/to/mpi` variables." ) if not tex.device_supports_multicast(): assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " - + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." + + "CUDA Multicast. Launch with UB_SKIPMC=1 to try CUDA IPC instead." ) - - # Get # of devices in the mesh axis for comm+GEMM overlap - tp_resource = global_mesh_resource().tp_resource if tp_resource is None else tp_resource - tp_size = get_mesh_axis_size(tp_resource, mesh=mesh) + # Extract kwargs + tp_resource = kwargs.get("tp_resource", global_mesh_resource().tp_resource) + tp_size = kwargs.get("tp_size", get_mesh_axis_size(tp_resource, mesh=mesh)) + use_fp8 = kwargs.get("use_fp8", False) + overlap_configs = kwargs.get("overlap_configs", None) # Layers that support comm+GEMM overlap layers_all_gather_overlap = [ - "generic_ag", + "ag_gemm", "qkv_fprop", "qkv_dgrad", "proj_dgrad", @@ -851,7 +880,7 @@ def initialize_comm_gemm_overlaps( "fc2_dgrad", ] layers_reduce_scatter_overlap = [ - "generic_rs", + "gemm_rs", "proj_fprop", "fc2_fprop", "qkv_wgrad", @@ -862,8 +891,8 @@ def initialize_comm_gemm_overlaps( # Default overlap methods for layers methods = { "ring_exchange": [ - "generic_ag", - "generic_rs", + "ag_gemm", + "gemm_rs", "qkv_fprop", "fc1_fprop", "proj_dgrad", @@ -874,7 +903,10 @@ def initialize_comm_gemm_overlaps( } # AG-RS overlap pairs of layers forming a tensor-parallel block - ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} + ag_rs_pairs = { + "qkv_fprop": "proj_fprop", + "fc1_fprop": "fc2_fprop", + } rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] @@ -888,11 +920,16 @@ def get_method(name): def get_default_config(name): method = get_method(name) default_cfg = { + "mesh": mesh, + "tp_resource": tp_resource, + "tp_size": tp_size, + "name": name, "method": method, "comm_type": ( tex.CommOverlapType.AG if name in layers_all_gather_overlap else tex.CommOverlap.RS ), "num_sm": 1 if method == "ring_exchange" else 16, + "num_max_streams": _NUM_MAX_UB_STREAMS, "cga_size": 1 if method == "ring_exchange" else 2, "set_sm_margin": False, "num_splits": 4 if method == "pipeline" else tp_size, @@ -905,76 +942,75 @@ def get_default_config(name): return default_cfg def add_new_comm_gemm_overlap( - name: str, - method: str, shape: Sequence[int], - dtype: jnp.dtype, - comm_type: tex.CommOverlapType, - num_sm: int = 16, - cga_size: int = 2, - set_sm_margin: bool = False, - num_splits: int = 4, - aggregate: bool = False, - atomic_gemm: bool = False, - pipeline_rs_overlap_first_gemm: bool = False, - use_ce: bool = True, - fp8_buf: bool = False, + kwargs: dict, ) -> None: + overlap_name = kwargs["name"] assert ( - name not in _ACTIVE_COMM_GEMM_OVERLAPS - ), "Duplicate initialization for `{name}` overlap!" + overlap_name not in _ACTIVE_COMM_GEMM_OVERLAPS + ), f"Duplicate initialization for `{overlap_name}` overlap!" - if atomic_gemm: + overlap_method = kwargs["method"] + overlap_atomic_gemm = kwargs["atomic_gemm"] + if overlap_atomic_gemm: warnings.warn( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." - if method == "bulk": + if overlap_method == "bulk": warnings.warn( - f"At {name}, atoimic GEMM not is supported for a bulk overlap." + f"At {overlap_name}, atoimic GEMM not is supported for a bulk overlap." "Defaulting to `atomic_gemm=False`." ) - atomic_gemm = False - if method == "pipeline" and comm_type == tex.CommOverlapType.AG: + overlap_atomic_gemm = False + kwargs["atomic_gemm"] = overlap_atomic_gemm + if overlap_method == "pipeline" and kwargs["comm_type"] == tex.CommOverlapType.AG: raise ValueError( - f"At {name}, `pipeline` overlap method is not supported for AllGather." + f"At {overlap_name}, `pipeline` overlap method is not supported for AllGather." ) # Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`. # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality. global layers_atomic_ring_exchange - if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs: - layers_atomic_ring_exchange += [name, ag_rs_pairs[name]] - if name in rs_ag_pairs: + if (overlap_atomic_gemm + and overlap_method == "ring_exchange" + and overlap_name in ag_rs_pairs): + layers_atomic_ring_exchange += [overlap_name, ag_rs_pairs[overlap_name]] + if overlap_name in rs_ag_pairs: assert_message = ( - f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk " - "outputs, and RS-GEMM overlap un-suffle them. When one of the GEMM-AG and " + f"At {overlap_name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM " + "chunk outputs, and RS-GEMM overlap un-suffle them. When one of the GEMM-AG and " "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses " "`atomic gemm` and `ring_exhcnage`, its pair must use the same overlap config " "for functionality." ) - if name in layers_atomic_ring_exchange: - assert atomic_gemm and method == "ring_exchange", assert_message + if overlap_name in layers_atomic_ring_exchange: + assert overlap_atomic_gemm and overlap_method == "ring_exchange", assert_message else: - if atomic_gemm and method == "ring_exchange": - assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message + if overlap_atomic_gemm and overlap_method == "ring_exchange": + assert ( + rs_ag_pairs[overlap_name] in layers_atomic_ring_exchange + ), assert_message - dtype = jnp.uint8 if (use_fp8 and fp8_buf) else dtype + # Reduce buffer shape to 2D here in case the user initialized with batch dims + buffer_shape = (reduce(operator.mul, shape[:-1], 1), shape[-1]) tex.bootstrap_comm_gemm_overlap( - name, - method, - shape, - jax_dtype_to_te_dtype(dtype), - comm_type, + buffer_shape, + jax_dtype_to_te_dtype(jnp.uint8 if (use_fp8 and fp8_buf) else jnp.bfloat16), + overlap_name, + overlap_method, + kwargs["comm_type"], + myrank, + numranks, tp_size, - num_splits, + kwargs["num_splits"], _NUM_MAX_UB_STREAMS, - cga_size, - num_sm, - set_sm_margin, - use_ce, - atomic_gemm, - aggregate, - pipeline_rs_overlap_first_gemm, + kwargs["cga_size"], + kwargs["num_sm"], + kwargs["set_sm_margin"], + kwargs["use_ce"], + overlap_atomic_gemm, + kwargs["aggregate"], + kwargs["pipeline_rs_overlap_first_gemm"], ) if overlap_configs is not None: @@ -998,17 +1034,25 @@ def add_new_comm_gemm_overlap( for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: if overlap_configs is not None and name in overlap_configs: fp8_buf = (name in layers_all_gather_overlap) or ( - overlap_configs[name].get("fp8_buf", False) and name in methods["pipeline"] + overlap_configs[name].get("fp8_buf", False) and name not in methods["pipeline"] ) - default_config = get_default_config(name) - final_config = default_config.update(overlap_configs[name]) + final_config = get_default_config(name) + final_config.update(overlap_configs[name]) final_config["fp8_buf"] = fp8_buf - add_new_comm_gemm_overlap(name, buffer_shape, buffer_dtype, **final_config) - _ACTIVE_COMM_GEMM_OVERLAPS.update({name: final_config}) + add_new_comm_gemm_overlap(buffer_shape, final_config) + _ACTIVE_COMM_GEMM_OVERLAPS[name] = final_config def destroy_comm_gemm_overlaps(): + global _ACTIVE_COMM_GEMM_OVERLAPS for name in _ACTIVE_COMM_GEMM_OVERLAPS: tex.destroy_comm_gemm_overlap(name) - _ACTIVE_COMM_GEMM_OVERLAPS.pop(name) _ACTIVE_COMM_GEMM_OVERLAPS = dict() + + +def get_comm_overlap_config(name): + global _ACTIVE_COMM_GEMM_OVERLAPS + assert name in _ACTIVE_COMM_GEMM_OVERLAPS, ( + f"Comm+GEMM overlap for '{name}' has not been initialized!" + ) + return _ACTIVE_COMM_GEMM_OVERLAPS[name] From aa16726307aadbf64193796f77eb113778f5d0a3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:09:06 +0000 Subject: [PATCH 22/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 127 ++++++++++-------- transformer_engine/jax/csrc/extensions.h | 61 +++++---- .../jax/csrc/extensions/comm_gemm_overlap.cpp | 75 ++++++----- .../jax/csrc/extensions/gemm.cpp | 17 ++- transformer_engine/jax/gemm.py | 34 +++-- 5 files changed, 171 insertions(+), 143 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 59bf28434d..b43c644a51 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -149,9 +149,9 @@ def abstract( (lhs_aval.ndim, rhs_aval.ndim), ) if lhs_aval.ndim > 2 and rhs_aval.ndim > 2: - assert not batched_output, ( - "Batched output requires batched LHS and non-batched RHS operands." - ) + assert ( + not batched_output + ), "Batched output requires batched LHS and non-batched RHS operands." lhs_bdims = [ dim for dim in range(lhs_aval.ndim) if dim not in [lhs_outer_dim, lhs_inner_dim] ] @@ -170,17 +170,17 @@ def abstract( # Validate output dtypes out_dtype = dtypes.canonicalize_dtype(out_aval.dtype) if jax_dtype_is_fp8(out_dtype): - assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8(rhs_dtype), ( - "FP8 GEMM output requires FP8 inputs." - ) - assert out_amax_aval.size == out_scale_aval.size == 1, ( - "Invalid/missing output amax and scale." - ) + assert jax_dtype_is_fp8(lhs_dtype) and jax_dtype_is_fp8( + rhs_dtype + ), "FP8 GEMM output requires FP8 inputs." + assert ( + out_amax_aval.size == out_scale_aval.size == 1 + ), "Invalid/missing output amax and scale." out_amax_updated_dtype = dtypes.canonicalize_dtype(out_amax_aval.dtype) out_scale_updated_dtype = dtypes.canonicalize_dtype(out_scale_aval.dtype) - assert out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32, ( - "Invalid output amax or scale dtype." - ) + assert ( + out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 + ), "Invalid output amax or scale dtype." else: assert out_dtype == lhs_dtype, ( "Output buffer has incorrect dtype: " @@ -194,7 +194,7 @@ def abstract( expected_out_shape = [ *lhs_aval.shape[:-2], lhs_aval.shape[lhs_outer_dim], - rhs_aval.shape[rhs_outer_dim] + rhs_aval.shape[rhs_outer_dim], ] extra_out_shape = extra_out_aval.shape expected_extra_out_shape = [0] @@ -203,18 +203,19 @@ def abstract( if batched_output: assert out_aval.ndim > 2, "Batched output buffer is missing batch dimensions." else: - expected_out_shape = [reduce(operator.mul, expected_out_shape[:-1], 1), - expected_out_shape[-1]] + expected_out_shape = [ + reduce(operator.mul, expected_out_shape[:-1], 1), + expected_out_shape[-1], + ] - if (comm_overlap_config is not None - and comm_overlap_config["method"] != "bulk"): + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": comm_type = comm_overlap_config.get("comm_type", None) assert comm_type is not None, "Missing comm type for comm+GEMM overlap." tp_size = comm_overlap_config.get("tp_size", 1) - assert tp_size > 1, ( - "Comm+GEMM overlap requires tensor-parallel mesh axis size greater than 1." - ) + assert ( + tp_size > 1 + ), "Comm+GEMM overlap requires tensor-parallel mesh axis size greater than 1." if comm_type == tex.CommOverlapType.AG: expected_extra_out_shape = list(lhs_aval.shape).copy() @@ -246,8 +247,12 @@ def abstract( "Extra output buffer has incorrect number of dimensions: " + f"expected {len(expected_extra_out_shape)} but found {extra_out_aval.ndim}" ) - assert all([extra_out_aval.shape[i] == expected_extra_out_shape[i] - for i in range(extra_out_aval.ndim)]), ( + assert all( + [ + extra_out_aval.shape[i] == expected_extra_out_shape[i] + for i in range(extra_out_aval.ndim) + ] + ), ( "Extra output buffer has incorrect shape: " + f"expected {expected_extra_out_shape=} but found {extra_out_aval.shape=}" ) @@ -350,7 +355,7 @@ def lowering( accumulate, use_split_accumulator, comm_overlap_config, - sharded_abstract + sharded_abstract, ): """ Fused attention fwd lowering rules @@ -586,7 +591,7 @@ def impl( out_shape = ( *lhs_batch_shape, out_updated.shape[-2] // lhs_batch_size, - out_updated.shape[-1] + out_updated.shape[-1], ) out_updated = jax.lax.reshape(out_updated, out_shape) @@ -594,7 +599,7 @@ def impl( extra_out_shape = ( *lhs_batch_shape, extra_out_updated.shape[-2] // lhs_batch_size, - extra_out_updated.shape[-1] + extra_out_updated.shape[-1], ) extra_out_updated = jax.lax.reshape(extra_out_updated, extra_out_shape) @@ -653,7 +658,7 @@ def batcher( out_scale_bdims, gelu_input_bdims, bias_bdims, - extra_out_bdims + extra_out_bdims, ), ) @@ -879,8 +884,9 @@ def partition( extra_out_spec[lhs_outer_dim] = None elif comm_type == tex.CommOverlapType.RS: extra_out_spec = list(out_spec).copy() - extra_out_spec[-2] = comm_overlap_config.get("tp_resource", - global_mesh_resource().tp_resource) + extra_out_spec[-2] = comm_overlap_config.get( + "tp_resource", global_mesh_resource().tp_resource + ) extra_out_sharding = NamedSharding(mesh, PartitionSpec(*extra_out_spec)) arg_shardings = ( @@ -905,7 +911,15 @@ def partition( ) def sharded_impl( - lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out, out_amax, out_scale, + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out, + out_amax, + out_scale, extra_out, ): ( @@ -953,7 +967,7 @@ def sharded_impl( out_scale_updated, pre_gelu_out, bias_grad, - extra_out_updated + extra_out_updated, ) return mesh, sharded_impl, out_shardings, arg_shardings @@ -993,8 +1007,7 @@ def gemm_impl( if extra_out is None: extra_out_shape = 0 - if (comm_overlap_config is not None - and comm_overlap_config["method"] != "bulk"): + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": comm_type = comm_overlap_config["comm_type"] if comm_type == tex.CommOverlapType.AG: extra_out_shape = list(lhs.shape).copy() @@ -1012,9 +1025,9 @@ def gemm_impl( if not fuse_gelu: gelu_input = jnp.zeros(0, dtype=lhs.dtype) elif grad: - assert gelu_input is not None, ( - "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." - ) + assert ( + gelu_input is not None + ), "Backward GEMM with dGELU epilogue requires pre-GELU output from forward GEMM." elif gelu_input is None: gelu_input = jnp.zeros(out_shape_2d, dtype=lhs.dtype) @@ -1083,8 +1096,7 @@ def fp8_gemm_impl( if extra_out is None: extra_out_shape = 0 - if (comm_overlap_config is not None - and comm_overlap_config["method"] != "bulk"): + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": comm_type = comm_overlap_config["comm_type"] if comm_type == tex.CommOverlapType.AG: extra_out_shape = list(lhs.shape).copy() @@ -1151,9 +1163,9 @@ def abstract(buffer_aval, myrank, numranks, comm_overlap_config): del myrank, numranks assert is_ffi_enabled(), "Comm+GEMM overlap is supported only via XLA FFI." overlap_name = comm_overlap_config.get("name", None) - assert overlap_name in _COMM_GEMM_OVERLAP_NAMES, ( - f"Unrecognized comm+GEMM overlap name: {overlap_name=}" - ) + assert ( + overlap_name in _COMM_GEMM_OVERLAP_NAMES + ), f"Unrecognized comm+GEMM overlap name: {overlap_name=}" assert buffer_aval.size > 0, "Cannot initialize a zero-size communication buffer." return jax.core.ShapedArray(shape=(0,), dtype=dtypes.canonicalize_dtype(buffer_aval.dtype)) @@ -1185,7 +1197,10 @@ def impl(buffer, myrank, numranks, comm_overlap_config): buffer, (reduce(operator.mul, buffer.shape[:-1], 1), buffer.shape[-1]) ) return BootstrapCommGemmOverlapPrimitive.inner_primitive.bind( - buffer, myrank=myrank, numranks=numranks, comm_overlap_config=comm_overlap_config, + buffer, + myrank=myrank, + numranks=numranks, + comm_overlap_config=comm_overlap_config, ) @staticmethod @@ -1194,14 +1209,18 @@ def batcher(batched_args, batch_dims, *, myrank, numranks, comm_overlap_config): check_valid_batch_dims(batch_dims) return ( BootstrapCommGemmOverlapPrimitive.inner_primitive.bind( - *batched_args, myrank=myrank, numranks=numranks, comm_overlap_config=comm_overlap_config + *batched_args, + myrank=myrank, + numranks=numranks, + comm_overlap_config=comm_overlap_config, ), None, ) @staticmethod - def infer_sharding_from_operands(myrank, numranks, comm_overlap_config, mesh, arg_infos, - result_infos): + def infer_sharding_from_operands( + myrank, numranks, comm_overlap_config, mesh, arg_infos, result_infos + ): del myrank, numranks, comm_overlap_config, result_infos buffer_spec = get_padded_spec(arg_infos[0]) assert all([spec is None for spec in buffer_spec]), "Sample buffer must be unsharded." @@ -1214,10 +1233,12 @@ def partition(myrank, numranks, comm_overlap_config, mesh, arg_infos, result_inf out_sharding = NamedSharding(mesh, PartitionSpec(None)) return ( mesh, - partial(BootstrapCommGemmOverlapPrimitive.impl, - myrank=myrank, - numranks=numranks, - comm_overlap_config=comm_overlap_config), + partial( + BootstrapCommGemmOverlapPrimitive.impl, + myrank=myrank, + numranks=numranks, + comm_overlap_config=comm_overlap_config, + ), out_sharding, arg_shardings, ) @@ -1227,16 +1248,10 @@ def partition(myrank, numranks, comm_overlap_config, mesh, arg_infos, result_inf def bootstrap_comm_gemm_overlap( - buffer: ArrayLike, - myrank: int, - numranks: int, - comm_overlap_config: dict + buffer: ArrayLike, myrank: int, numranks: int, comm_overlap_config: dict ): _ = BootstrapCommGemmOverlapPrimitive.outer_primitive.bind( - buffer, - myrank=myrank, - numranks=numranks, - comm_overlap_config=comm_overlap_config + buffer, myrank=myrank, numranks=numranks, comm_overlap_config=comm_overlap_config ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index fd0786a040..6bc6d02173 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -338,14 +338,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasltHandleInitHandler); void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); -Error_Type GemmFFI( - cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out, - Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type dummy_in, Result_Type out_updated, - Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out, - Result_Type bias_grad, Result_Type dummy_out, Result_Type workspace, bool lhs_trans, - bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, - bool use_split_accumulator); +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out, Buffer_Type out_amax, Buffer_Type out_scale, + Buffer_Type dummy_in, Result_Type out_updated, Result_Type out_amax_updated, + Result_Type out_scale_updated, Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type dummy_out, Result_Type workspace, bool lhs_trans, bool rhs_trans, + bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); @@ -357,18 +357,21 @@ pybind11::object GetOverlapBuffer(const std::string &name, bool sharded); void SetOverlapBufferScaleInverse(const std::string &name, pybind11::object scale_inv, bool grad); -void BootstrapCommGemmOverlap( - const std::vector &buffer_shape, DType buffer_dtype, const std::string &name, - const std::string &method, CommOverlapType comm_type, int64_t myrank, int64_t numranks, - int64_t tp_size, int64_t num_splits, int64_t num_max_streams, int64_t cga_size, - int64_t num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate, - bool pipeline_rs_overlap_first_gemm); - -Error_Type BootstrapCommGemmOverlapFFI( - cudaStream_t, Buffer_Type sample_buffer, std::string_view name, std::string_view method, - int64_t comm_type_flag, int64_t myrank, int64_t numranks, int64_t tp_size, int64_t num_splits, - int64_t num_max_streams, int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm, bool aggregate, bool pipeline_rs_overlap_first_gemm); +void BootstrapCommGemmOverlap(const std::vector &buffer_shape, DType buffer_dtype, + const std::string &name, const std::string &method, + CommOverlapType comm_type, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, + int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm); + +Error_Type BootstrapCommGemmOverlapFFI(cudaStream_t, Buffer_Type sample_buffer, + std::string_view name, std::string_view method, + int64_t comm_type_flag, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, + int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm); XLA_FFI_DECLARE_HANDLER_SYMBOL(BootstrapCommGemmOverlapHandler); @@ -383,14 +386,16 @@ Error_Type CopyIntoOverlapBufferFFI(cudaStream_t stream, Buffer_Type input, std: XLA_FFI_DECLARE_HANDLER_SYMBOL(CopyIntoOverlapBufferHandler); -Error_Type CommGemmOverlapFFI( - cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out, - Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated, - Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out, - Result_Type bias_grad, Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans, - bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, - bool use_split_accumulator, int64_t comm_type_flag, std::string_view name); +Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, + Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Buffer_Type out, Buffer_Type out_amax, + Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator, int64_t comm_type_flag, + std::string_view name); XLA_FFI_DECLARE_HANDLER_SYMBOL(CommGemmOverlapHandler); diff --git a/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp index d6f5daaa80..533fdc3e83 100644 --- a/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp @@ -17,7 +17,7 @@ namespace transformer_engine { namespace jax { Error_Type CublasltHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets, - Dictionary attrs) { + Dictionary attrs) { cublasLtHandle_t handle; NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); return ffi_with_cuda_error_check(); @@ -42,12 +42,13 @@ pybind11::object GetOverlapBuffer(const std::string &name, bool sharded) { return capsule; }; -void BootstrapCommGemmOverlap( - const std::vector &buffer_shape, DType buffer_dtype, const std::string &name, - const std::string &method, CommOverlapType comm_type, int64_t myrank, int64_t numranks, - int64_t tp_size, int64_t num_splits, int64_t num_max_streams, int64_t comm_cga_size, - int64_t num_comm_sm, bool set_sm_margin, bool use_ce, bool atomic_gemm, bool aggregate, - bool pipeline_rs_overlap_first_gemm) { +void BootstrapCommGemmOverlap(const std::vector &buffer_shape, DType buffer_dtype, + const std::string &name, const std::string &method, + CommOverlapType comm_type, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, + int64_t comm_cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm) { #ifndef NVTE_UB_WITH_MPI NVTE_ERROR( std::string("Comm+GEMM overlap in TE/JAX requires bootstrapping Userbuffers with MPI. ") + @@ -57,10 +58,10 @@ void BootstrapCommGemmOverlap( // Initialize overlap object -- this allocates the comm buffer NVTE_CHECK(_overlaps.find(name) == _overlaps.end(), name, " is already initialized!"); if (method == "ring_exchange") { - _overlaps[name] = new CommOverlapP2PBase( - buffer_shape, buffer_dtype, myrank, numranks, -1, -1, -1, -1, tp_size, &_dummy_allgather, - &_dummy_barrier, comm_type, num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, - use_ce, atomic_gemm, aggregate); + _overlaps[name] = new CommOverlapP2PBase(buffer_shape, buffer_dtype, myrank, numranks, -1, -1, + -1, -1, tp_size, &_dummy_allgather, &_dummy_barrier, + comm_type, num_max_streams, comm_cga_size, num_comm_sm, + set_sm_margin, use_ce, atomic_gemm, aggregate); } else { _overlaps[name] = new CommOverlapBase( buffer_shape, buffer_dtype, myrank, numranks, -1, -1, -1, -1, tp_size, &_dummy_allgather, @@ -69,19 +70,21 @@ void BootstrapCommGemmOverlap( } }; -Error_Type BootstrapCommGemmOverlapFFI( - cudaStream_t, Buffer_Type sample_buffer, std::string_view name, std::string_view method, - int64_t comm_type_flag, int64_t myrank, int64_t numranks, int64_t tp_size, int64_t num_splits, - int64_t num_max_streams, int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm, bool aggregate, bool pipeline_rs_overlap_first_gemm) { - auto buffer_shape = std::vector(sample_buffer.dimensions().begin(), - sample_buffer.dimensions().end()); +Error_Type BootstrapCommGemmOverlapFFI(cudaStream_t, Buffer_Type sample_buffer, + std::string_view name, std::string_view method, + int64_t comm_type_flag, int64_t myrank, int64_t numranks, + int64_t tp_size, int64_t num_splits, int64_t num_max_streams, + int64_t cga_size, int64_t num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm, bool aggregate, + bool pipeline_rs_overlap_first_gemm) { + auto buffer_shape = + std::vector(sample_buffer.dimensions().begin(), sample_buffer.dimensions().end()); auto buffer_dtype = convert_ffi_datatype_to_te_dtype(sample_buffer.element_type()); - BootstrapCommGemmOverlap( - buffer_shape, buffer_dtype, static_cast(name), static_cast(method), - static_cast(comm_type_flag), myrank, numranks, tp_size, num_splits, - num_max_streams, cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate, - pipeline_rs_overlap_first_gemm); + BootstrapCommGemmOverlap(buffer_shape, buffer_dtype, static_cast(name), + static_cast(method), + static_cast(comm_type_flag), myrank, numranks, tp_size, + num_splits, num_max_streams, cga_size, num_comm_sm, set_sm_margin, + use_ce, atomic_gemm, aggregate, pipeline_rs_overlap_first_gemm); return ffi_with_cuda_error_check(); } @@ -120,9 +123,7 @@ Error_Type DestroyCommGemmOverlapFFI(cudaStream_t stream, std::string_view name) } XLA_FFI_DEFINE_HANDLER_SYMBOL(DestroyComMGemmOverlapHandler, DestroyCommGemmOverlapFFI, - FFI::Bind() - .Ctx() - .Attr("name"), + FFI::Bind().Ctx().Attr("name"), FFI_CudaGraph_Traits); void CopyIntoOverlapBufferImpl(cudaStream_t stream, void *input_ptr, @@ -205,14 +206,16 @@ void CommGemmOverlapImpl(void *lhs, const std::vector &lhs_shape, DType } } -Error_Type CommGemmOverlapFFI( - cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out, - Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated, - Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out, - Result_Type bias_grad, Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans, - bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, - bool use_split_accumulator, int64_t comm_type_flag, std::string_view name) { +Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, + Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Buffer_Type out, Buffer_Type out_amax, + Buffer_Type out_scale, Buffer_Type extra_out, Result_Type out_updated, + Result_Type out_amax_updated, Result_Type out_scale_updated, + Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type extra_out_updated, Result_Type workspace, bool lhs_trans, + bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, + bool accumulate, bool use_split_accumulator, int64_t comm_type_flag, + std::string_view name) { // Inputs auto lhs_ptr = lhs.untyped_data(); auto lhs_shape = std::vector(lhs.dimensions().begin(), lhs.dimensions().end()); @@ -232,8 +235,8 @@ Error_Type CommGemmOverlapFFI( // Outputs auto out_updated_ptr = out_updated->untyped_data(); - auto out_shape = std::vector(out_updated->dimensions().begin(), - out_updated->dimensions().end()); + auto out_shape = + std::vector(out_updated->dimensions().begin(), out_updated->dimensions().end()); auto out_dtype = convert_ffi_datatype_to_te_dtype(out_updated->element_type()); auto out_amax_updated_ptr = reinterpret_cast(out_amax_updated->untyped_data()); auto out_scale_updated_ptr = reinterpret_cast(out_scale_updated->untyped_data()); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 8f6f907268..44a2d55f8e 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -88,15 +88,14 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque desc.fuse_bias, desc.grad, desc.accumulate, desc.use_split_accumulator); } -Error_Type GemmFFI( - cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, - Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type out, - Buffer_Type out_amax, Buffer_Type out_scale, Buffer_Type dummy_in, Result_Type out_updated, - Result_Type out_amax_updated, Result_Type out_scale_updated, Result_Type pre_gelu_out, - Result_Type bias_grad, Result_Type dummy_out, Result_Type workspace, bool lhs_trans, - bool rhs_trans, bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, - bool use_split_accumulator -) { +Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, + Buffer_Type out, Buffer_Type out_amax, Buffer_Type out_scale, + Buffer_Type dummy_in, Result_Type out_updated, Result_Type out_amax_updated, + Result_Type out_scale_updated, Result_Type pre_gelu_out, Result_Type bias_grad, + Result_Type dummy_out, Result_Type workspace, bool lhs_trans, bool rhs_trans, + bool fuse_gelu, bool fuse_bias, bool grad, bool accumulate, + bool use_split_accumulator) { // Inputs auto lhs_ptr = lhs.untyped_data(); auto lhs_scale_inv_ptr = reinterpret_cast(lhs_scale_inv.untyped_data()); diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 9b9afd56ca..37d6e5328b 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -176,9 +176,11 @@ def _gemm_fwd_rule( ) final_out = out - if (comm_overlap_config is not None + if ( + comm_overlap_config is not None and comm_overlap_config["method"] != "bulk" - and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS): + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): # Non-bulk RS overlap output is in extra output, not usual output final_out = extra_out @@ -213,8 +215,10 @@ def _gemm_bwd_rule( if comm_overlap_config is not None: dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) - if (dgrad_overlap_config["method"] == "bulk" - and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.AG): + if ( + dgrad_overlap_config["method"] == "bulk" + and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.AG + ): # If DGRAD is bulk overlap, copy input X into comm buffer to be all-gathered in # preparation for WGRAD. wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" @@ -260,9 +264,11 @@ def _gemm_bwd_rule( comm_overlap_config=dgrad_overlap_config, ) - if (dgrad_overlap_config is not None + if ( + dgrad_overlap_config is not None and dgrad_overlap_config["method"] != "bulk" - and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS): + and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): # Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor dgrad = dgrad_extra_out @@ -971,9 +977,11 @@ def add_new_comm_gemm_overlap( # Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`. # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality. global layers_atomic_ring_exchange - if (overlap_atomic_gemm + if ( + overlap_atomic_gemm and overlap_method == "ring_exchange" - and overlap_name in ag_rs_pairs): + and overlap_name in ag_rs_pairs + ): layers_atomic_ring_exchange += [overlap_name, ag_rs_pairs[overlap_name]] if overlap_name in rs_ag_pairs: assert_message = ( @@ -987,9 +995,7 @@ def add_new_comm_gemm_overlap( assert overlap_atomic_gemm and overlap_method == "ring_exchange", assert_message else: if overlap_atomic_gemm and overlap_method == "ring_exchange": - assert ( - rs_ag_pairs[overlap_name] in layers_atomic_ring_exchange - ), assert_message + assert rs_ag_pairs[overlap_name] in layers_atomic_ring_exchange, assert_message # Reduce buffer shape to 2D here in case the user initialized with batch dims buffer_shape = (reduce(operator.mul, shape[:-1], 1), shape[-1]) @@ -1052,7 +1058,7 @@ def destroy_comm_gemm_overlaps(): def get_comm_overlap_config(name): global _ACTIVE_COMM_GEMM_OVERLAPS - assert name in _ACTIVE_COMM_GEMM_OVERLAPS, ( - f"Comm+GEMM overlap for '{name}' has not been initialized!" - ) + assert ( + name in _ACTIVE_COMM_GEMM_OVERLAPS + ), f"Comm+GEMM overlap for '{name}' has not been initialized!" return _ACTIVE_COMM_GEMM_OVERLAPS[name] From a569e3b44b6e9f26823ceeb5608e0bd820c20841 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 3 Dec 2024 14:35:13 +0000 Subject: [PATCH 23/39] added comm+GEMM overlap example script Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.py | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 examples/jax/comm_gemm_overlap/comm_gemm_overlap.py diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py new file mode 100644 index 0000000000..eb8d09d7f6 --- /dev/null +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Comm+GEMM Overlap with TE/JAX""" + +import argparse + +from mpi4py import MPI + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils +import numpy as np +import transformer_engine.jax as te +from transformer_engine.jax.cpp_extensions import gemm_impl +from transformer_engine.jax.gemm import ( + initialize_comm_gemm_overlaps, + destroy_comm_gemm_overlaps, + get_comm_overlap_config, +) + +jax.clear_caches() + +# This script needs to be launched via `mpirun` with 1 process per GPU +myrank = MPI.COMM_WORLD.Get_rank() +numranks = MPI.COMM_WORLD.Get_size() +jax.distributed.initialize(cluster_detection_method='mpi4py') + +parser = argparse.ArgumentParser() +parser.add_argument('-dp', '--dp-size', type=int, default=1) +parser.add_argument('-zp', '--fsdp-size', type=int, default=2) +parser.add_argument('-tp', '--tp-size', type=int, default=4) +parser.add_argument('-np', '--num-gpus', type=int, default=8) +parser.add_argument('--base-size', type=int, default=16) +parser.add_argument('--batch-size', type=int, default=4) +parser.add_argument('--no-batch', action="store_true") +parser.add_argument('--no-fsdp', action="store_true") +parser.add_argument('--comm-type', type=str.upper, default="AG", choices=["AG", "RS"]) +args = parser.parse_args() + +# GEMM problem sizing +dtype = jnp.bfloat16 +seq_length = args.base_size * 8 +hidden_size = args.base_size * 6 +ffn_hidden_size = args.base_size * 16 + +# Operand shapes +lhs_shape = ( + [seq_length, hidden_size] + if args.comm_type == "AG" + else [seq_length, ffn_hidden_size] +) +rhs_shape = ( + [hidden_size, ffn_hidden_size] + if args.comm_type == "AG" + else [ffn_hidden_size, hidden_size] +) + +# Operand partitioning +batched = not args.no_batch +fsdp = not args.no_fsdp +if batched: + lhs_shape = [args.batch_size] + lhs_shape + if fsdp: + mesh_shape = {'dp': args.dp_size, 'zp': args.fsdp_size, 'tp': args.tp_size} + mesh_resource = te.MeshResource(dp_resource='dp', tp_resource='tp', cp_resource='tp', + fsdp_resource='zp') + if args.comm_type == "AG": + input_specs = [('dp', 'zp'), 'tp', None] + weight_specs = ['zp', 'tp'] + weight_no_fsdp = [None, 'tp'] + elif args.comm_type == "RS": + input_specs = [('dp', 'zp'), None, 'tp'] + weight_specs = ['tp', 'zp'] + weight_no_fsdp = ['tp', None] + else: + mesh_shape = {'dp': args.dp_size, 'tp': args.tp_size} + mesh_resource = te.MeshResource(dp_resource='dp', tp_resource='tp', cp_resource='tp',) + if args.comm_type == "AG": + input_specs = ['dp', 'tp', None] + weight_specs = [None, 'tp'] + elif args.comm_type == "RS": + input_specs = ['dp', None, 'tp'] + weight_specs = ['tp', None] + weight_no_fsdp = weight_specs +else: + mesh_shape = {'tp': args.tp_size} + mesh_resource = te.MeshResource(tp_resource='tp', cp_resource='cp') + if args.comm_type == "AG": + input_specs = ['tp', None] + weight_specs = [None, 'tp'] + elif args.comm_type == "RS": + input_specs = [None, 'tp'] + weight_specs = ['tp', None] + weight_no_fsdp = weight_specs + +# Mesh setup and sharding definitions +devices = mesh_utils.create_device_mesh((args.num_gpus, ), devices=jax.devices()[:args.num_gpus]) +mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) +input_sharding = NamedSharding(mesh, PartitionSpec(*input_specs)) +weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_specs)) +weight_no_fsdp_sharding = NamedSharding(mesh, PartitionSpec(*weight_no_fsdp)) + +# Operand initialization +key = jax.random.PRNGKey(0) +key1, key2 = jax.random.split(key, 2) +lhs = jax.device_put(jax.random.normal(key1, lhs_shape, dtype=dtype), input_sharding) +rhs = jax.device_put(jax.random.normal(key2, rhs_shape, dtype=dtype), weight_sharding) + +# Name of comm+GEMM overlap layer +overlap_name = "ag_gemm" if args.comm_type == "AG" else "gemm_rs" + +# Bootstrap Userbuffers communicators and communication buffers +initialize_comm_gemm_overlaps( + lhs_shape, + mesh, + myrank, + numranks, + tp_resource='tp', + overlap_configs={overlap_name : dict()}, +) + +if myrank == 0: + print( + f"{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n" + + f"{myrank}: LHS sharding: {lhs.sharding}\n" + + f"{myrank}: RHS sharding: {rhs.sharding}\n", + flush=True + ) + +@jax.jit +def te_gemm(A, B): + return gemm_impl(A, jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), + batched_output=True, + comm_overlap_config=get_comm_overlap_config(overlap_name)) + +with te.sharding.global_shard_guard(mesh_resource): + output, _, extra_out = te_gemm(lhs, rhs) + +if myrank == 0: + print( + f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUTS:\n" + + f"{myrank}: GEMM output: {output.shape} | {output.sharding}\n" + + f"{myrank}: {'Gathered LHS' if args.comm_type == 'AG' else 'Scattered output:'}: " + + f"{extra_out.shape} | {extra_out.sharding}\n", + flush=True + ) + +destroy_comm_gemm_overlaps() + + From 69db12ea42e0e3729b42ba874cae91aaf52f4b1d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:35:44 +0000 Subject: [PATCH 24/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comm_gemm_overlap/comm_gemm_overlap.py | 100 +++++++++--------- 1 file changed, 51 insertions(+), 49 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index eb8d09d7f6..551fdaa0b4 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -25,18 +25,18 @@ # This script needs to be launched via `mpirun` with 1 process per GPU myrank = MPI.COMM_WORLD.Get_rank() numranks = MPI.COMM_WORLD.Get_size() -jax.distributed.initialize(cluster_detection_method='mpi4py') +jax.distributed.initialize(cluster_detection_method="mpi4py") parser = argparse.ArgumentParser() -parser.add_argument('-dp', '--dp-size', type=int, default=1) -parser.add_argument('-zp', '--fsdp-size', type=int, default=2) -parser.add_argument('-tp', '--tp-size', type=int, default=4) -parser.add_argument('-np', '--num-gpus', type=int, default=8) -parser.add_argument('--base-size', type=int, default=16) -parser.add_argument('--batch-size', type=int, default=4) -parser.add_argument('--no-batch', action="store_true") -parser.add_argument('--no-fsdp', action="store_true") -parser.add_argument('--comm-type', type=str.upper, default="AG", choices=["AG", "RS"]) +parser.add_argument("-dp", "--dp-size", type=int, default=1) +parser.add_argument("-zp", "--fsdp-size", type=int, default=2) +parser.add_argument("-tp", "--tp-size", type=int, default=4) +parser.add_argument("-np", "--num-gpus", type=int, default=8) +parser.add_argument("--base-size", type=int, default=16) +parser.add_argument("--batch-size", type=int, default=4) +parser.add_argument("--no-batch", action="store_true") +parser.add_argument("--no-fsdp", action="store_true") +parser.add_argument("--comm-type", type=str.upper, default="AG", choices=["AG", "RS"]) args = parser.parse_args() # GEMM problem sizing @@ -46,15 +46,9 @@ ffn_hidden_size = args.base_size * 16 # Operand shapes -lhs_shape = ( - [seq_length, hidden_size] - if args.comm_type == "AG" - else [seq_length, ffn_hidden_size] -) +lhs_shape = [seq_length, hidden_size] if args.comm_type == "AG" else [seq_length, ffn_hidden_size] rhs_shape = ( - [hidden_size, ffn_hidden_size] - if args.comm_type == "AG" - else [ffn_hidden_size, hidden_size] + [hidden_size, ffn_hidden_size] if args.comm_type == "AG" else [ffn_hidden_size, hidden_size] ) # Operand partitioning @@ -63,40 +57,45 @@ if batched: lhs_shape = [args.batch_size] + lhs_shape if fsdp: - mesh_shape = {'dp': args.dp_size, 'zp': args.fsdp_size, 'tp': args.tp_size} - mesh_resource = te.MeshResource(dp_resource='dp', tp_resource='tp', cp_resource='tp', - fsdp_resource='zp') + mesh_shape = {"dp": args.dp_size, "zp": args.fsdp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource( + dp_resource="dp", tp_resource="tp", cp_resource="tp", fsdp_resource="zp" + ) if args.comm_type == "AG": - input_specs = [('dp', 'zp'), 'tp', None] - weight_specs = ['zp', 'tp'] - weight_no_fsdp = [None, 'tp'] + input_specs = [("dp", "zp"), "tp", None] + weight_specs = ["zp", "tp"] + weight_no_fsdp = [None, "tp"] elif args.comm_type == "RS": - input_specs = [('dp', 'zp'), None, 'tp'] - weight_specs = ['tp', 'zp'] - weight_no_fsdp = ['tp', None] + input_specs = [("dp", "zp"), None, "tp"] + weight_specs = ["tp", "zp"] + weight_no_fsdp = ["tp", None] else: - mesh_shape = {'dp': args.dp_size, 'tp': args.tp_size} - mesh_resource = te.MeshResource(dp_resource='dp', tp_resource='tp', cp_resource='tp',) + mesh_shape = {"dp": args.dp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource( + dp_resource="dp", + tp_resource="tp", + cp_resource="tp", + ) if args.comm_type == "AG": - input_specs = ['dp', 'tp', None] - weight_specs = [None, 'tp'] + input_specs = ["dp", "tp", None] + weight_specs = [None, "tp"] elif args.comm_type == "RS": - input_specs = ['dp', None, 'tp'] - weight_specs = ['tp', None] + input_specs = ["dp", None, "tp"] + weight_specs = ["tp", None] weight_no_fsdp = weight_specs else: - mesh_shape = {'tp': args.tp_size} - mesh_resource = te.MeshResource(tp_resource='tp', cp_resource='cp') + mesh_shape = {"tp": args.tp_size} + mesh_resource = te.MeshResource(tp_resource="tp", cp_resource="cp") if args.comm_type == "AG": - input_specs = ['tp', None] - weight_specs = [None, 'tp'] + input_specs = ["tp", None] + weight_specs = [None, "tp"] elif args.comm_type == "RS": - input_specs = [None, 'tp'] - weight_specs = ['tp', None] + input_specs = [None, "tp"] + weight_specs = ["tp", None] weight_no_fsdp = weight_specs # Mesh setup and sharding definitions -devices = mesh_utils.create_device_mesh((args.num_gpus, ), devices=jax.devices()[:args.num_gpus]) +devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus]) mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) input_sharding = NamedSharding(mesh, PartitionSpec(*input_specs)) weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_specs)) @@ -117,8 +116,8 @@ mesh, myrank, numranks, - tp_resource='tp', - overlap_configs={overlap_name : dict()}, + tp_resource="tp", + overlap_configs={overlap_name: dict()}, ) if myrank == 0: @@ -126,14 +125,19 @@ f"{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n" + f"{myrank}: LHS sharding: {lhs.sharding}\n" + f"{myrank}: RHS sharding: {rhs.sharding}\n", - flush=True + flush=True, ) + @jax.jit def te_gemm(A, B): - return gemm_impl(A, jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), - batched_output=True, - comm_overlap_config=get_comm_overlap_config(overlap_name)) + return gemm_impl( + A, + jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), + batched_output=True, + comm_overlap_config=get_comm_overlap_config(overlap_name), + ) + with te.sharding.global_shard_guard(mesh_resource): output, _, extra_out = te_gemm(lhs, rhs) @@ -144,9 +148,7 @@ def te_gemm(A, B): + f"{myrank}: GEMM output: {output.shape} | {output.sharding}\n" + f"{myrank}: {'Gathered LHS' if args.comm_type == 'AG' else 'Scattered output:'}: " + f"{extra_out.shape} | {extra_out.sharding}\n", - flush=True + flush=True, ) destroy_comm_gemm_overlaps() - - From ec2d5aecd45926959e90a5cd6a20bb2202beecdd Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 3 Dec 2024 14:38:57 +0000 Subject: [PATCH 25/39] RS overlap also works Signed-off-by: Alp Dener --- examples/jax/comm_gemm_overlap/comm_gemm_overlap.py | 3 ++- transformer_engine/jax/gemm.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index 551fdaa0b4..3637abbd50 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -4,6 +4,7 @@ """Comm+GEMM Overlap with TE/JAX""" import argparse +import numpy as np from mpi4py import MPI @@ -11,7 +12,7 @@ import jax.numpy as jnp from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.experimental import mesh_utils -import numpy as np + import transformer_engine.jax as te from transformer_engine.jax.cpp_extensions import gemm_impl from transformer_engine.jax.gemm import ( diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 37d6e5328b..59d1045080 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -932,7 +932,9 @@ def get_default_config(name): "name": name, "method": method, "comm_type": ( - tex.CommOverlapType.AG if name in layers_all_gather_overlap else tex.CommOverlap.RS + tex.CommOverlapType.AG + if name in layers_all_gather_overlap + else tex.CommOverlapType.RS ), "num_sm": 1 if method == "ring_exchange" else 16, "num_max_streams": _NUM_MAX_UB_STREAMS, From 8fe3942635c311b80de2910cebd906b250236b90 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 3 Dec 2024 14:55:18 +0000 Subject: [PATCH 26/39] added missing copy of AG+GEMM input into comm buffer Signed-off-by: Alp Dener --- .../jax/comm_gemm_overlap/comm_gemm_overlap.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index 3637abbd50..8920b1a37d 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -14,7 +14,7 @@ from jax.experimental import mesh_utils import transformer_engine.jax as te -from transformer_engine.jax.cpp_extensions import gemm_impl +from transformer_engine.jax.cpp_extensions import gemm_impl, copy_into_overlap_buffer from transformer_engine.jax.gemm import ( initialize_comm_gemm_overlaps, destroy_comm_gemm_overlaps, @@ -124,14 +124,15 @@ if myrank == 0: print( f"{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n" - + f"{myrank}: LHS sharding: {lhs.sharding}\n" - + f"{myrank}: RHS sharding: {rhs.sharding}\n", + + f"{myrank}: LHS sharding: {lhs.sharding.spec}\n" + + f"{myrank}: RHS sharding: {rhs.sharding.spec}\n", flush=True, ) @jax.jit def te_gemm(A, B): + copy_into_overlap_buffer(A, overlap_name, True) return gemm_impl( A, jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), @@ -145,10 +146,9 @@ def te_gemm(A, B): if myrank == 0: print( - f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUTS:\n" - + f"{myrank}: GEMM output: {output.shape} | {output.sharding}\n" - + f"{myrank}: {'Gathered LHS' if args.comm_type == 'AG' else 'Scattered output:'}: " - + f"{extra_out.shape} | {extra_out.sharding}\n", + f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT " + + f"{output.shape}\n" + + f"{myrank}: Sharding: {output.sharding.spec}\n", flush=True, ) From adf4046f7fd92e29ec2834778821acec177fe949 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 5 Dec 2024 19:54:23 +0000 Subject: [PATCH 27/39] updated FWD/BWD wrappers for non-FP8 and FP8 gemm Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 1 + transformer_engine/jax/gemm.py | 224 +++++++++--------- 2 files changed, 108 insertions(+), 117 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b43c644a51..66eea09cb2 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1074,6 +1074,7 @@ def fp8_gemm_impl( bias: Optional[ArrayLike] = None, gelu_input: Optional[ArrayLike] = None, out: Optional[ArrayLike] = None, + extra_out: Optional[ArrayLike] = None, out_amax: Optional[ArrayLike] = None, out_scale: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 59d1045080..1a275ceed7 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -209,27 +209,30 @@ def _gemm_bwd_rule( mirror_dim, (x_inner_dim, kernel_inner_dim), (x.ndim, kernel.ndim) ) + # Recover DGRAD and WGRAD comm+GEMM overlap configs + dgrad_overlap_name = None dgrad_overlap_config = None + wgrad_overlap_name = None wgrad_overlap_config = None - dgrad_pre_rs = None if comm_overlap_config is not None: dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) - if ( - dgrad_overlap_config["method"] == "bulk" - and dgrad_overlap_config["comm_type"] == tex.CommOverlapType.AG - ): - # If DGRAD is bulk overlap, copy input X into comm buffer to be all-gathered in - # preparation for WGRAD. - wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" - wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) - assert wgrad_overlap_config is not None, "Internal TE error!" - copy_into_overlap_buffer(x, dgrad_overlap_name, True) + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + dgrad_pre_rs = None + if dgrad_overlap_config is not None: + if dgrad_overlap_config["method"] == "bulk": # Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the - # bulk RS overlap without an extra memcpy + # bulk RS overlap without an extra memcpy. + assert wgrad_overlap_config is not None, ( + f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" + ) dgrad_pre_rs = tex.get_overlap_buffer(wgrad_overlap_name, False) + # Copy transposed input into the DGRAD overlap buffer for bulk AG. + copy_into_overlap_buffer(jnp.matrix_transpose(x), dgrad_overlap_name, True) + # FWD MODE: # AG+GEMM: ([B], M/P, K) --(AG)--> ([B], M, K) x (K, N/P) ------> ([B], M, N/P) # @@ -246,7 +249,7 @@ def _gemm_bwd_rule( # AG+GEMM w/ DGRAD+RS Overlap: ([B], M, N/P) x (K, N/P)^T ---(RS)---> ([B], M/P, K) # # AG+GEMM w/ Bulk AG Overlap: ([B], M, N/P) x (K, N/P)^T -----> ([B], M, K) (deferred RS) - # ([B], M, K/P) --(Bulk AG)--> ([B], M, K) (needed in WGRAD) + # ([B], M, K/P)^T --(Bulk AG)--> ([B], M, K)^T (needed in WGRAD) # # GEMM+RS: ([B], M/P, N) --(AG)--> ([B], M, N) x (K/P, N)^T ----> ([B], M, K/P) dgrad, dgelu, _, dgrad_extra_out = gemm_impl( @@ -272,13 +275,14 @@ def _gemm_bwd_rule( # Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor dgrad = dgrad_extra_out + # WGRAD w/o Overlap: # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # # GEMM+AR: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N) ---------> (K, N) # # WGRAD w/ Overlap: - # AG+GEMM w/ DGRAD+RS Overlap: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) + # AG+GEMM w/ DGRAD+RS Overlap: ([B], M, K/P)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # # AG+GEMM w/ Bulk Overlaps: ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # ([B], M, K) --(Bulk RS)--> ([B], M/P, K) (finalize DGRAD) @@ -299,7 +303,11 @@ def _gemm_bwd_rule( comm_overlap_config=wgrad_overlap_config, ) - if wgrad_overlap_config is not None: + if ( + wgrad_overlap_config is not None + and wgrad_overlap_config["method"] == "bulk" + and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): # DGRAD was reduce-scattered during WGRAD GEMM, so set DGRAD to WGRAD extra output here dgrad = wgrad_extra_out @@ -317,6 +325,7 @@ def fp8_gemm( kernel_t: ArrayLike, fp8_meta: FP8MetaPackage, bias: Optional[ArrayLike] = None, + out: Optional[ArrayLike] = None, out_dtype: jnp.dtype = jnp.bfloat16, fuse_gelu: bool = False, accumulate: bool = False, @@ -340,10 +349,12 @@ def fp8_gemm( FP8MetaPackage object carrying amax, scale and scale_inv information for the GEMM operands. bias : Optional[ArrayLike], default = `None` Optional bias term to add onto the (LHS x RHS) result. + out: Optional[ArrayLike], default = `None` + Optional empty buffer for FP8 GEMM output. out_dtype : jnp.dtype, default = `jnp.bfloat16` Data type of the FP8 GEMM output. If chosen as an FP8 dtype (i.e. `jnp.float8_e4m3fn` or `jnp.float8_e5m2`), the `fp8_meta` must also contain amax and scale information for the - GEMM output. + GEMM output. This option is overridden by the data type of the `out` buffer, if given. fuse_gelu : bool, default = `False` Enable the GELU epilogue for GEMM. This applies GELU after the bias-addition if the bias term is not `None`. @@ -389,13 +400,14 @@ def fp8_gemm( ) -@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10)) def _fp8_gemm( x: ArrayLike, kernel_t: ArrayLike, bias: ArrayLike, amax_list: ArrayLike, scale_list: ArrayLike, + out: ArrayLike, out_dtype: jnp.dtype, fuse_gelu: bool, accumulate: bool, @@ -501,14 +513,14 @@ def _fp8_gemm_fwd_rule( buffer_scale_inv = None if comm_overlap_config is not None: overlap_name = comm_overlap_config["name"] - if comm_overlap_config["method"] != "bulk" and tex.overlap_buffer_is_fp8(overlap_name): - match comm_overlap_config["comm_type"]: - case tex.CommOverlapType.AG: - buffer_scale_inv = x_scale_inv + if comm_overlap_config["comm_type"] == tex.CommOverlapType.AG: + buffer_scale_inv = x_scale_inv - case tex.CommOverlapType.RS: - buffer_scale_inv = jnp.reciprocal(out_scale) + elif comm_overlap_config["comm_type"] == tex.CommOverlapType.RS: + out_dtype = fwd_dtype + out_scale = scale_list[FP8MetaPackage.OUTPUT_IDX][0:1] + buffer_scale_inv = jnp.reciprocal(out_scale) tex.set_overlap_buffer_scale_inverse( overlap_name, @@ -531,9 +543,6 @@ def _fp8_gemm_fwd_rule( use_split_accumulator=use_split_accumulator, comm_overlap_config=comm_overlap_config, ) - if not jax_dtype_is_fp8(out_dtype): - updated_out_amax = None - updated_out_scale = None # Update returned and saved arrays based on comm+GEMM overlap config final_out = out @@ -542,6 +551,10 @@ def _fp8_gemm_fwd_rule( # RS overlap puts the reduce-scattered sharded output into extra_out final_out = extra_out + if not jax_dtype_is_fp8(final_out): + updated_out_amax = None + updated_out_scale = None + ctx = ( casted_x_t, casted_kernel, @@ -583,9 +596,21 @@ def _fp8_gemm_bwd_rule( maybe_fp32_to_fm32, batched_input, ) = ctx - + del out_dtype bwd_dtype = FP8Helper.BWD_DTYPE + # Recover DGRAD and WGRAD comm+GEMM overlap configs + dgrad_overlap_name = None + dgrad_overlap_config = None + wgrad_overlap_name = None + wgrad_overlap_config = None + if comm_overlap_config is not None: + dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" + dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) + wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" + wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + + # Cast-transpose grad with potential fusions grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1] grad_scale = scale_list[FP8MetaPackage.GRAD_IDX] grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_ID] @@ -633,28 +658,29 @@ def _fp8_gemm_bwd_rule( ) bgrad = None - # Recover dgrad comm+GEMM overlap config - dgrad_overlap_config = None - if comm_overlap_config is not None: - dgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_dgrad" - dgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(dgrad_overlap_name, None) - # Set scale_inv for comm overlap buffer - dgrad_out_dtype = jnp.bfloat16 dgrad_amax = None dgrad_scale = None - if ( - dgrad_overlap_config is not None - and dgrad_overlap_config["method"] != "bulk" - and tex.overlap_buffer_is_fp8(dgrad_overlap_name) - ): - dgrad_out_dtype = bwd_dtype - dgrad_amax = grad_amax - dgrad_scale = grad_scale - tex.set_overlap_buffer_scale_inverse( - dgrad_overlap_name, - jax.dlpack.to_dlpack(grad_scale_inv), - ) + if dgrad_overlap_config is not None: + if dgrad_overlap_config["method"] == "bulk": + assert wgrad_overlap_config is not None, ( + f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" + ) + # Set WGRAD buffer as output of DGRAD in order to avoid a memcpy for bulk RS overlap + dgrad_pre_rs = jax.dlpack.from_dlpack( + tex.get_overlap_buffer(wgrad_overlap_name, False) + ) + # Copy input into overlap buffer for all-gather + copy_into_overlap_buffer(casted_x_t, dgrad_overlap_name, True) + + elif tex.overlap_buffer_is_fp8(dgrad_overlap_name): + # Non-bulk RS DGRAD overlap needs output amax and scale if buffer type is FP8 + dgrad_amax = grad_amax + dgrad_scale = grad_scale + tex.set_overlap_buffer_scale_inverse( + dgrad_overlap_name, + jax.dlpack.to_dlpack(grad_scale_inv), + ) # DGRAD: ([B], M, N) x (K, N)^T = ([B], M, K) kernel_scale_inv = scale_inv_list[FP8MetaPackage.WEIGHT_IDX] @@ -663,11 +689,9 @@ def _fp8_gemm_bwd_rule( grad_scale_inv, casted_kernel, kernel_scale_inv, - None, - None, - dgrad_amax, - dgrad_scale, - out_dtype=dgrad_out_dtype, + out=dgrad_pre_rs, + out_amax=dgrad_amax, + out_scale=dgrad_scale, batched_output=batched_input, accumulate=accumulate, use_split_accumulator=use_split_accumulator, @@ -682,65 +706,29 @@ def _fp8_gemm_bwd_rule( ): dgrad = dgrad_extra_out - if fuse_gelu and fuse_bias: - # Fuse bgrad with dGELU. - _, casted_dgelu_t, bgrad, updated_grad_amax = dact_lu_dbias_cast_transpose( - grad, - pre_gelu_out, - grad_amax, - grad_scale, - grad_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1, - activation_type=("gelu",), - ) - elif fuse_gelu: - # No bias grad to fuse so we just do dGELU. - _, casted_dgelu_t, updated_grad_amax = dact_lu(grad, pre_gelu_out, ("gelu",)) - bgrad = None - - # Recover wgrad config - wgrad_overlap_config = None - if comm_overlap_config is not None: - wgrad_overlap_name = comm_overlap_config["name"].rstrip("_fprop") + "_wgrad" - wgrad_overlap_config = _ACTIVE_COMM_GEMM_OVERLAPS.get(wgrad_overlap_name, None) + # Prepare comm+GEMM overlap for WGRAD + if wgrad_overlap_config is not None: + if wgrad_overlap_config["method"] == "bulk": + # Get all-gathered input from DGRAD bulk overlap + casted_x_t = jax.dlpack.from_dlpack( + tex.get_overlap_buffer(dgrad_overlap_name, False) + ) - # Set scale_inv for comm overlap buffer - wgrad_out_dtype = jnp.bfloat16 - wgrad_amax = None - wgrad_scale = None - if ( - wgrad_overlap_config is not None - and wgrad_overlap_config["method"] != "bulk" - and tex.overlap_buffer_is_fp8(wgrad_overlap_name) - ): - match wgrad_overlap_config["comm_type"]: - case tex.CommOverlapType.AG: - buffer_scale_inv = x_scale_inv - case tex.CommOverlapType.RS: - buffer_scale_inv = grad_scale_inv - wgrad_out_dtype = bwd_dtype - wgrad_amax = grad_amax - wgrad_scale = grad_scale - tex.set_overlap_buffer_scale_inverse( - dgrad_overlap_name, - jax.dlpack.to_dlpack(buffer_scale_inv), - ) + elif tex.overlap_buffer_is_fp8(wgrad_overlap_name): + # Set FP8 scale inverse for non-bulk AG overlap + tex.set_overlap_buffer_scale_inverse( + wgrad_overlap_name, + jax.dlpack.to_dlpack(x_scale_inv) + ) # WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K) - wgrad_rhs_t = casted_dgelu_t if fuse_gelu else casted_grad_t x_scale_inv = scale_inv_list[FP8MetaPackage.INPUT_IDX] wgrad, *_, wgrad_extra_out = fp8_gemm_impl( casted_x_t, x_scale_inv, - wgrad_rhs_t, + casted_grad_t, grad_scale_inv, - None, - None, - wgrad_amax, - wgrad_scale, - out_dtype=wgrad_out_dtype, + out_dtype=jnp.bfloat16, batched_output=False, accumulate=accumulate, use_split_accumulator=use_split_accumulator, @@ -753,7 +741,7 @@ def _fp8_gemm_bwd_rule( and wgrad_overlap_config["method"] != "bulk" and wgrad_overlap_config["comm_type"] == tex.CommOverlapType.RS ): - wgrad = wgrad_extra_out + dgrad = wgrad_extra_out amax_list[FP8MetaPackage.INPUT_IDX] = ( amax_list[FP8MetaPackage.INPUT_IDX].at[0].set(updated_x_amax[0]) @@ -764,7 +752,7 @@ def _fp8_gemm_bwd_rule( amax_list[FP8MetaPackage.GRAD_IDX] = ( amax_list[FP8MetaPackage.GRAD_IDX].at[0].set(updated_grad_amax[0]) ) - if out_dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2]: + if updated_out_amax is not None: amax_list[FP8MetaPackage.OUTPUT_IDX] = ( amax_list[FP8MetaPackage.OUTPUT_IDX].at[0].set(updated_out_amax[0]) ) @@ -782,8 +770,9 @@ def type_safe_gemm( x: ArrayLike, kernel: ArrayLike, bias: Optional[ArrayLike] = None, - fp8_meta: Optional[FP8MetaPackage] = None, + out: Optional[ArrayLike] = None, out_dtype: Optional[jnp.dtype] = None, + fp8_meta: Optional[FP8MetaPackage] = None, contracting_dims: Tuple[int, int] = (-1, -2), fuse_gelu: bool = False, accumulate: bool = False, @@ -802,24 +791,25 @@ def type_safe_gemm( return fp8_gemm( x, kernel, - bias, fp8_meta, - out_dtype, - fuse_gelu, - accumulate, - use_split_accumulator, - comm_overlap_name, + bias=bias, + out=out, + out_dtype=out_dtype, + fuse_gelu=fuse_gelu, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_name=comm_overlap_name, ) else: return gemm( x, kernel, - bias, - contracting_dims, - fuse_gelu, - accumulate, - use_split_accumulator, - comm_overlap_name, + bias=bias, + contracting_dims=contracting_dims, + fuse_gelu=fuse_gelu, + accumulate=accumulate, + use_split_accumulator=use_split_accumulator, + comm_overlap_name=comm_overlap_name, ) From c4c608b54540c7cfc48f867e3e01bab15c4593cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 19:54:48 +0000 Subject: [PATCH 28/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/gemm.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 1a275ceed7..7024dcb9fe 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -225,9 +225,9 @@ def _gemm_bwd_rule( if dgrad_overlap_config["method"] == "bulk": # Set DGRAD output buffer to the comm buffer of WGRAD GEMM in order to do the # bulk RS overlap without an extra memcpy. - assert wgrad_overlap_config is not None, ( - f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" - ) + assert ( + wgrad_overlap_config is not None + ), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" dgrad_pre_rs = tex.get_overlap_buffer(wgrad_overlap_name, False) # Copy transposed input into the DGRAD overlap buffer for bulk AG. @@ -275,7 +275,6 @@ def _gemm_bwd_rule( # Otherwise, if DGRAD overlap is RS overlap, DGRAD output is the extra output tensor dgrad = dgrad_extra_out - # WGRAD w/o Overlap: # AG+GEMM: ([B], M/P, K)^T --(AG)--> ([B], M, K)^T x ([B], M, N/P) --> (K, N/P) # @@ -663,13 +662,11 @@ def _fp8_gemm_bwd_rule( dgrad_scale = None if dgrad_overlap_config is not None: if dgrad_overlap_config["method"] == "bulk": - assert wgrad_overlap_config is not None, ( - f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" - ) + assert ( + wgrad_overlap_config is not None + ), f"Missing comm+GEMM overlap config for {wgrad_overlap_name}!" # Set WGRAD buffer as output of DGRAD in order to avoid a memcpy for bulk RS overlap - dgrad_pre_rs = jax.dlpack.from_dlpack( - tex.get_overlap_buffer(wgrad_overlap_name, False) - ) + dgrad_pre_rs = jax.dlpack.from_dlpack(tex.get_overlap_buffer(wgrad_overlap_name, False)) # Copy input into overlap buffer for all-gather copy_into_overlap_buffer(casted_x_t, dgrad_overlap_name, True) @@ -710,15 +707,12 @@ def _fp8_gemm_bwd_rule( if wgrad_overlap_config is not None: if wgrad_overlap_config["method"] == "bulk": # Get all-gathered input from DGRAD bulk overlap - casted_x_t = jax.dlpack.from_dlpack( - tex.get_overlap_buffer(dgrad_overlap_name, False) - ) + casted_x_t = jax.dlpack.from_dlpack(tex.get_overlap_buffer(dgrad_overlap_name, False)) elif tex.overlap_buffer_is_fp8(wgrad_overlap_name): # Set FP8 scale inverse for non-bulk AG overlap tex.set_overlap_buffer_scale_inverse( - wgrad_overlap_name, - jax.dlpack.to_dlpack(x_scale_inv) + wgrad_overlap_name, jax.dlpack.to_dlpack(x_scale_inv) ) # WGRAD: ([B], N, M) x ([B], K, M)^T = (N, K) From 4707df3bfaa8bc1c6a1de1caec1fa4ee19dc0902 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 5 Dec 2024 22:41:27 +0000 Subject: [PATCH 29/39] added more documentation to the TE/JAX comm+GEMM overlap example Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.py | 26 ++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index 8920b1a37d..e3d72a9849 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -14,6 +14,7 @@ from jax.experimental import mesh_utils import transformer_engine.jax as te +from transformer_engine import transformer_engine_jax as tex from transformer_engine.jax.cpp_extensions import gemm_impl, copy_into_overlap_buffer from transformer_engine.jax.gemm import ( initialize_comm_gemm_overlaps, @@ -118,7 +119,18 @@ myrank, numranks, tp_resource="tp", - overlap_configs={overlap_name: dict()}, + overlap_configs={ + overlap_name: { + "method": "ring_exchange", # "pipeline" for collective kernels instead of send/recv + "comm_type": tex.CommOverlapType if args.comm_type == "AG" else tex.CommOverlapType.RS, + "num_splits": args.tp_size, # independent of TP size for "pipeline" + "cga_size": 1, # default is 2 for "pipeline" + "num_sm": 1, # ignored for "ring_exchange", must be tuned for "pipeline" + "set_sm_margin": False, # set to True for "pipeline" + "atomic_gemm": False, # more performant when not using CUDA Graphs + "use_ce": True, # ignored (always False) for "pipeline" method + } + }, ) if myrank == 0: @@ -132,11 +144,16 @@ @jax.jit def te_gemm(A, B): + # LHS needs to be copied into the comm. buffer before GEMM. This can usually be circumvented by + # extracting the comm. buffer as a JAX array via + # `buffer = jax.dlpack.from_dlpack(tex.get_overlap_buffer(overlap_name: str, sharded: bool))` + # and directly writing the result of a preceding operation into it (e.g.. LayerNorm output + # written directly into the communication buffer before AG+GEMM in a QKV projection) copy_into_overlap_buffer(A, overlap_name, True) return gemm_impl( A, - jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), - batched_output=True, + jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), # all-gather FSDP weights + batched_output=True, # internal option, will be hidden by the FWD/BWD wrapper comm_overlap_config=get_comm_overlap_config(overlap_name), ) @@ -144,6 +161,9 @@ def te_gemm(A, B): with te.sharding.global_shard_guard(mesh_resource): output, _, extra_out = te_gemm(lhs, rhs) +if args.comm_type == "RS": + output = extra_out + if myrank == 0: print( f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT " From 18a62496be25d9d0a0039c5bf3e2c5acaff671f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 22:43:33 +0000 Subject: [PATCH 30/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/comm_gemm_overlap/comm_gemm_overlap.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index e3d72a9849..b968c4ef62 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -121,14 +121,14 @@ tp_resource="tp", overlap_configs={ overlap_name: { - "method": "ring_exchange", # "pipeline" for collective kernels instead of send/recv + "method": "ring_exchange", # "pipeline" for collective kernels instead of send/recv "comm_type": tex.CommOverlapType if args.comm_type == "AG" else tex.CommOverlapType.RS, - "num_splits": args.tp_size, # independent of TP size for "pipeline" - "cga_size": 1, # default is 2 for "pipeline" - "num_sm": 1, # ignored for "ring_exchange", must be tuned for "pipeline" - "set_sm_margin": False, # set to True for "pipeline" - "atomic_gemm": False, # more performant when not using CUDA Graphs - "use_ce": True, # ignored (always False) for "pipeline" method + "num_splits": args.tp_size, # independent of TP size for "pipeline" + "cga_size": 1, # default is 2 for "pipeline" + "num_sm": 1, # ignored for "ring_exchange", must be tuned for "pipeline" + "set_sm_margin": False, # set to True for "pipeline" + "atomic_gemm": False, # more performant when not using CUDA Graphs + "use_ce": True, # ignored (always False) for "pipeline" method } }, ) @@ -152,8 +152,8 @@ def te_gemm(A, B): copy_into_overlap_buffer(A, overlap_name, True) return gemm_impl( A, - jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), # all-gather FSDP weights - batched_output=True, # internal option, will be hidden by the FWD/BWD wrapper + jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), # all-gather FSDP weights + batched_output=True, # internal option, will be hidden by the FWD/BWD wrapper comm_overlap_config=get_comm_overlap_config(overlap_name), ) From b1449417044e132c7a80b13a6756581964f836df Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 9 Dec 2024 16:48:41 +0000 Subject: [PATCH 31/39] fixed RS overlap in the example Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.py | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index e3d72a9849..a283ca62a2 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -21,6 +21,7 @@ destroy_comm_gemm_overlaps, get_comm_overlap_config, ) +from transformer_engine.jax.sharding import get_padded_spec jax.clear_caches() @@ -122,7 +123,11 @@ overlap_configs={ overlap_name: { "method": "ring_exchange", # "pipeline" for collective kernels instead of send/recv - "comm_type": tex.CommOverlapType if args.comm_type == "AG" else tex.CommOverlapType.RS, + "comm_type": ( + tex.CommOverlapType.AG + if args.comm_type == "AG" + else tex.CommOverlapType.RS + ), "num_splits": args.tp_size, # independent of TP size for "pipeline" "cga_size": 1, # default is 2 for "pipeline" "num_sm": 1, # ignored for "ring_exchange", must be tuned for "pipeline" @@ -144,31 +149,34 @@ @jax.jit def te_gemm(A, B): - # LHS needs to be copied into the comm. buffer before GEMM. This can usually be circumvented by - # extracting the comm. buffer as a JAX array via + # For AG overlap, LHS needs to be copied into the comm. buffer before GEMM. This can usually + # be circumvented by extracting the comm. buffer as a JAX array via # `buffer = jax.dlpack.from_dlpack(tex.get_overlap_buffer(overlap_name: str, sharded: bool))` # and directly writing the result of a preceding operation into it (e.g.. LayerNorm output # written directly into the communication buffer before AG+GEMM in a QKV projection) - copy_into_overlap_buffer(A, overlap_name, True) + if args.comm_type == "AG": + copy_into_overlap_buffer(A, overlap_name, True) + return_idx = 0 + else: + # For RS overlap, the scattered output is in the `extra_out` array. + return_idx = -1 + return gemm_impl( A, jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), # all-gather FSDP weights batched_output=True, # internal option, will be hidden by the FWD/BWD wrapper comm_overlap_config=get_comm_overlap_config(overlap_name), - ) + )[return_idx] with te.sharding.global_shard_guard(mesh_resource): - output, _, extra_out = te_gemm(lhs, rhs) - -if args.comm_type == "RS": - output = extra_out + output = te_gemm(lhs, rhs) if myrank == 0: print( f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT " + f"{output.shape}\n" - + f"{myrank}: Sharding: {output.sharding.spec}\n", + + f"{myrank}: Sharding: {get_padded_spec(output.sharding.spec, output.ndim)}\n", flush=True, ) From 6ad56517df0d19017c38efe9ad5374dfe000b948 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:50:20 +0000 Subject: [PATCH 32/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../comm_gemm_overlap/comm_gemm_overlap.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index 0bfcbb5830..77266539e1 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -122,20 +122,18 @@ tp_resource="tp", overlap_configs={ overlap_name: { - "method": "ring_exchange", # "pipeline" for collective kernels instead of send/recv + "method": "ring_exchange", # "pipeline" for collective kernels instead of send/recv "comm_type": ( - tex.CommOverlapType.AG - if args.comm_type == "AG" - else tex.CommOverlapType.RS + tex.CommOverlapType.AG if args.comm_type == "AG" else tex.CommOverlapType.RS ), - "num_splits": args.tp_size, # independent of TP size for "pipeline" - "cga_size": 1, # default is 2 for "pipeline" - "num_sm": 1, # ignored for "ring_exchange", must be tuned for "pipeline" - "set_sm_margin": False, # set to True for "pipeline" - "atomic_gemm": False, # more performant when not using CUDA Graphs - "use_ce": True, # ignored (always False) for "pipeline" method + "num_splits": args.tp_size, # independent of TP size for "pipeline" + "cga_size": 1, # default is 2 for "pipeline" + "num_sm": 1, # ignored for "ring_exchange", must be tuned for "pipeline" + "set_sm_margin": False, # set to True for "pipeline" + "atomic_gemm": False, # more performant when not using CUDA Graphs + "use_ce": True, # ignored (always False) for "pipeline" method }, - } + }, ) if myrank == 0: From 5aceb02cc764478ee8695bc79902c11dbd3b7ca5 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 14 Jan 2025 21:35:43 +0000 Subject: [PATCH 33/39] updated comm overlap JAX example with numerical correctness check Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.py | 90 +++++++++++++++---- transformer_engine/jax/cpp_extensions/gemm.py | 6 +- 2 files changed, 77 insertions(+), 19 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index 77266539e1..cde89b95f7 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -5,7 +5,7 @@ import argparse import numpy as np - +from functools import partial from mpi4py import MPI import jax @@ -40,13 +40,15 @@ parser.add_argument("--no-batch", action="store_true") parser.add_argument("--no-fsdp", action="store_true") parser.add_argument("--comm-type", type=str.upper, default="AG", choices=["AG", "RS"]) +parser.add_argument("--check-result", action="store_true") +parser.add_argument("--std", type=float, default=0.023) args = parser.parse_args() # GEMM problem sizing dtype = jnp.bfloat16 -seq_length = args.base_size * 8 -hidden_size = args.base_size * 6 -ffn_hidden_size = args.base_size * 16 +seq_length = 2 # args.base_size * 8 +hidden_size = 4 # args.base_size * 6 +ffn_hidden_size = 6 # args.base_size * 16 # Operand shapes lhs_shape = [seq_length, hidden_size] if args.comm_type == "AG" else [seq_length, ffn_hidden_size] @@ -87,19 +89,31 @@ weight_specs = ["tp", None] weight_no_fsdp = weight_specs else: - mesh_shape = {"tp": args.tp_size} - mesh_resource = te.MeshResource(tp_resource="tp", cp_resource="cp") - if args.comm_type == "AG": - input_specs = ["tp", None] - weight_specs = [None, "tp"] - elif args.comm_type == "RS": - input_specs = [None, "tp"] - weight_specs = ["tp", None] - weight_no_fsdp = weight_specs + if fsdp: + mesh_shape = {"zp": args.fsdp_size, "tp": args.tp_size} + mesh_resource = te.MeshResource(fsdp_resource="zp", tp_resource="tp", cp_resource="cp") + if args.comm_type == "AG": + input_specs = ["tp", None] + weight_specs = ["zp", "tp"] + elif args.comm_type == "RS": + input_specs = [None, "tp"] + weight_specs = ["tp", "zp"] + weight_no_fsdp = ["tp", None] + else: + mesh_shape = {"tp": args.tp_size} + mesh_resource = te.MeshResource(tp_resource="tp", cp_resource="cp") + if args.comm_type == "AG": + input_specs = ["tp", None] + weight_specs = [None, "tp"] + elif args.comm_type == "RS": + input_specs = [None, "tp"] + weight_specs = ["tp", None] + weight_no_fsdp = weight_specs # Mesh setup and sharding definitions devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus]) mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) +no_sharding = NamedSharding(mesh, PartitionSpec(None)) input_sharding = NamedSharding(mesh, PartitionSpec(*input_specs)) weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_specs)) weight_no_fsdp_sharding = NamedSharding(mesh, PartitionSpec(*weight_no_fsdp)) @@ -107,8 +121,10 @@ # Operand initialization key = jax.random.PRNGKey(0) key1, key2 = jax.random.split(key, 2) -lhs = jax.device_put(jax.random.normal(key1, lhs_shape, dtype=dtype), input_sharding) -rhs = jax.device_put(jax.random.normal(key2, rhs_shape, dtype=dtype), weight_sharding) +lhs_data = jax.random.normal(key1, lhs_shape, dtype=dtype) +rhs_data = jax.random.normal(key2, rhs_shape, dtype=dtype) +lhs = jax.device_put(lhs_data, input_sharding) +rhs = jax.device_put(rhs_data, weight_sharding) # Name of comm+GEMM overlap layer overlap_name = "ag_gemm" if args.comm_type == "AG" else "gemm_rs" @@ -162,7 +178,7 @@ def te_gemm(A, B): return gemm_impl( A, jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), # all-gather FSDP weights - batched_output=True, # internal option, will be hidden by the FWD/BWD wrapper + batched_output=not args.no_batch, # internal option, will be hidden by the FWD/BWD wrapper comm_overlap_config=get_comm_overlap_config(overlap_name), )[return_idx] @@ -178,4 +194,46 @@ def te_gemm(A, B): flush=True, ) +if args.check_result: + ref_global = jnp.matmul(jax.device_put(lhs_data, no_sharding), + jax.device_put(rhs_data, no_sharding)) + if myrank == 0: + print(f"{myrank}: Global reference: {ref_global}\n", flush=True) + + output_global = jax.lax.with_sharding_constraint(output, no_sharding) + if myrank == 0: + print(f"{myrank}: Global output: {output_global}\n", flush=True) + + diff = jnp.abs(ref_global - output_global).flatten() + if myrank == 0: + print(f"{myrank}: Global difference: {diff}\n", flush=True) + + m = jnp.argmax(diff).item() + abs_err = diff[m].item() + rel_err = abs_err / max(abs(ref_global.flatten()[m]), 1e-5) + + rtol = 0.02 + atol = 0.001 + numerics_failed = False + if rel_err > rtol and abs_err > atol: + numerics_failed = True + numerics_info = ( + "NUMERICAL CHECK FAILED: " + + f"Outputs not close enough at index {m} " + + f"with {output.flatten()[m].item()} vs {ref_global.flatten()[m].item()} | " + + f"rel. error = {rel_err} (tol = {rtol}) | " + + f"abs. error = {abs_err} (tol = {atol})" + ) + else: + numerics_info = "NUMERICAL CHECK PASSED: " + if rel_err <= rtol: + numerics_info += f"rel. error = {rel_err} (tol = {rtol})" + ( + " | " if abs_err < atol else "" + ) + if abs_err <= atol: + numerics_info += f"abs. error = {abs_err} (tol = {atol})" + + if myrank == 0: + print(numerics_info + "\n", end="", flush=True) + destroy_comm_gemm_overlaps() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 66eea09cb2..6bab28e4b3 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1086,7 +1086,7 @@ def fp8_gemm_impl( comm_overlap_config: Optional[dict] = None, ) -> Tuple[ArrayLike, ...]: """FP8 mat-mul with `nvte_cublas_gemm()` custom op.""" - out_shape_batched = (*lhs.shape[:-2], lhs.shape[-1], rhs_t.shape[-1]) + out_shape_batched = (*lhs.shape[:-2], lhs.shape[-2], rhs_t.shape[-2]) out_shape_2d = (reduce(operator.mul, out_shape_batched[:-1], 1), out_shape_batched[-1]) out_shape = out_shape_batched if batched_output else out_shape_2d @@ -1123,10 +1123,10 @@ def fp8_gemm_impl( (out, out_amax, out_scale, pre_gelu_out, _, extra_out) = ( # bias_grad in non-FP8 GEMM CollectiveGemmPrimitive.outer_primitive.bind( - rhs_t, - rhs_scale_inv, lhs, lhs_scale_inv, + rhs_t, + rhs_scale_inv, bias, gelu_input, out, From f2b2a5bbe85852be286ea3f78540c464570a7f2a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Jan 2025 21:36:34 +0000 Subject: [PATCH 34/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/jax/comm_gemm_overlap/comm_gemm_overlap.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index cde89b95f7..009b69f2f2 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -195,8 +195,9 @@ def te_gemm(A, B): ) if args.check_result: - ref_global = jnp.matmul(jax.device_put(lhs_data, no_sharding), - jax.device_put(rhs_data, no_sharding)) + ref_global = jnp.matmul( + jax.device_put(lhs_data, no_sharding), jax.device_put(rhs_data, no_sharding) + ) if myrank == 0: print(f"{myrank}: Global reference: {ref_global}\n", flush=True) From 19482547a27479de15768bb6560fa0b0af857b78 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 15 Jan 2025 22:48:50 +0000 Subject: [PATCH 35/39] changed commandline size controls to directly modify sequence length, num heads, head dim and activation size Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index cde89b95f7..f1e07dd78f 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -35,25 +35,29 @@ parser.add_argument("-zp", "--fsdp-size", type=int, default=2) parser.add_argument("-tp", "--tp-size", type=int, default=4) parser.add_argument("-np", "--num-gpus", type=int, default=8) -parser.add_argument("--base-size", type=int, default=16) parser.add_argument("--batch-size", type=int, default=4) +parser.add_argument("--seq-length", type=int, default=8192) +parser.add_argument("--num-heads", type=int, default=128) +parser.add_argument("--head-dim", type=int, default=128) +parser.add_argument("--activation-size", type=int, default=53248) parser.add_argument("--no-batch", action="store_true") parser.add_argument("--no-fsdp", action="store_true") parser.add_argument("--comm-type", type=str.upper, default="AG", choices=["AG", "RS"]) parser.add_argument("--check-result", action="store_true") -parser.add_argument("--std", type=float, default=0.023) args = parser.parse_args() -# GEMM problem sizing -dtype = jnp.bfloat16 -seq_length = 2 # args.base_size * 8 -hidden_size = 4 # args.base_size * 6 -ffn_hidden_size = 6 # args.base_size * 16 - # Operand shapes -lhs_shape = [seq_length, hidden_size] if args.comm_type == "AG" else [seq_length, ffn_hidden_size] +dtype = jnp.bfloat16 +hidden_size = args.num_heads * args.head_dim +lhs_shape = ( + [args.seq_length, hidden_size] + if args.comm_type == "AG" + else [args.seq_length, args.activation_size] +) rhs_shape = ( - [hidden_size, ffn_hidden_size] if args.comm_type == "AG" else [ffn_hidden_size, hidden_size] + [hidden_size, args.activation_size] + if args.comm_type == "AG" + else [args.activation_size, hidden_size] ) # Operand partitioning From b2720cb445112da6d2ef9982cbdb9e51bb26883a Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 17 Jan 2025 03:42:37 +0000 Subject: [PATCH 36/39] fixed incorrect chunking of cuBLAS workspace Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.py | 12 ++--- transformer_engine/jax/cpp_extensions/gemm.py | 52 ++++++++++++------- 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index e9d280621e..2bc3651e51 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -35,10 +35,9 @@ parser.add_argument("-zp", "--fsdp-size", type=int, default=2) parser.add_argument("-tp", "--tp-size", type=int, default=4) parser.add_argument("-np", "--num-gpus", type=int, default=8) -parser.add_argument("--batch-size", type=int, default=4) +parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--seq-length", type=int, default=8192) -parser.add_argument("--num-heads", type=int, default=128) -parser.add_argument("--head-dim", type=int, default=128) +parser.add_argument("--hidden-size", type=int, default=16384) parser.add_argument("--activation-size", type=int, default=53248) parser.add_argument("--no-batch", action="store_true") parser.add_argument("--no-fsdp", action="store_true") @@ -48,16 +47,15 @@ # Operand shapes dtype = jnp.bfloat16 -hidden_size = args.num_heads * args.head_dim lhs_shape = ( - [args.seq_length, hidden_size] + [args.seq_length, args.hidden_size] if args.comm_type == "AG" else [args.seq_length, args.activation_size] ) rhs_shape = ( - [hidden_size, args.activation_size] + [args.hidden_size, args.activation_size] if args.comm_type == "AG" - else [args.activation_size, hidden_size] + else [args.activation_size, args.hidden_size] ) # Operand partitioning diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 6bab28e4b3..5aeef243c5 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -182,10 +182,11 @@ def abstract( out_amax_updated_dtype == out_scale_updated_dtype == jnp.float32 ), "Invalid output amax or scale dtype." else: - assert out_dtype == lhs_dtype, ( - "Output buffer has incorrect dtype: " - + f"expected {lhs_dtype} but found {out_dtype}" - ) + if not jax_dtype_is_fp8(lhs_dtype): + assert out_dtype == lhs_dtype, ( + "Output buffer has incorrect dtype: " + + f"expected {lhs_dtype} but found {out_dtype}" + ) out_amax_updated_dtype = jnp.float32 out_scale_updated_dtype = jnp.float32 @@ -208,7 +209,8 @@ def abstract( expected_out_shape[-1], ] - if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": + workspace_size = get_cublas_workspace_size_bytes() + if comm_overlap_config is not None: comm_type = comm_overlap_config.get("comm_type", None) assert comm_type is not None, "Missing comm type for comm+GEMM overlap." @@ -217,18 +219,26 @@ def abstract( tp_size > 1 ), "Comm+GEMM overlap requires tensor-parallel mesh axis size greater than 1." - if comm_type == tex.CommOverlapType.AG: - expected_extra_out_shape = list(lhs_aval.shape).copy() - elif comm_type == tex.CommOverlapType.RS: - expected_extra_out_shape = list(expected_out_shape).copy() - expected_extra_out_dtype = lhs_dtype + if comm_overlap_config["method"] != "bulk": + # Increase workspace size to ensure every GEMM chunk has an independent workspace + # of the appropriate size + if comm_overlap_config["method"] == "pipeline": + workspace_size *= comm_overlap_config.get("num_splits", 4) + elif comm_overlap_config["method"] == "ring_exchange": + workspace_size *= tp_size - if sharded_abstract: if comm_type == tex.CommOverlapType.AG: - expected_out_shape[-2] *= tp_size - expected_extra_out_shape[-2] *= tp_size - else: - expected_extra_out_shape[-2] = expected_extra_out_shape[-2] // tp_size + expected_extra_out_shape = list(lhs_aval.shape).copy() + elif comm_type == tex.CommOverlapType.RS: + expected_extra_out_shape = list(expected_out_shape).copy() + expected_extra_out_dtype = lhs_dtype + + if sharded_abstract: + if comm_type == tex.CommOverlapType.AG: + expected_out_shape[-2] *= tp_size + expected_extra_out_shape[-2] *= tp_size + else: + expected_extra_out_shape[-2] = expected_extra_out_shape[-2] // tp_size assert out_aval.ndim == len(expected_out_shape), ( "Output buffer has incorrect number of dimensions: " @@ -296,9 +306,7 @@ def abstract( pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) extra_out_updated_aval = extra_out_aval.update(shape=extra_out_shape, dtype=extra_out_dtype) - workspace_aval = jax.core.ShapedArray( - shape=(get_cublas_workspace_size_bytes(),), dtype=jnp.uint8 - ) + workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return ( out_updated_aval, @@ -428,10 +436,16 @@ def lowering( m = lhs_aval.shape[lhs_outer_dim] k = rhs_aval.shape[rhs_inner_dim] n = rhs_aval.shape[rhs_outer_dim] - workspace_size = get_cublas_workspace_size_bytes() operand_dtype = jax_dtype_to_te_dtype(lhs_aval.dtype) bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) + workspace_size = get_cublas_workspace_size_bytes() + if comm_overlap_config is not None: + if comm_overlap_config["method"] == "pipeline": + workspace_size *= comm_overlap_config["num_splits"] + elif comm_overlap_config["method"] == "ring_exchange": + workspace_size *= comm_overlap_config["tp_size"] + descriptor_packer_fn = tex.pack_gemm_decriptor descriptor_args = ( m, From b7e034e12e445ab502258798ce798f8f7168abd6 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 17 Jan 2025 21:44:56 +0000 Subject: [PATCH 37/39] syntactic cleanup for workspace size correction in TP overlap Signed-off-by: Alp Dener --- transformer_engine/jax/cpp_extensions/gemm.py | 27 ++++++++++++------- transformer_engine/jax/gemm.py | 7 ++--- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 5aeef243c5..3816d52745 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -36,8 +36,12 @@ "gemm_impl", "copy_into_overlap_buffer", "bootstrap_comm_gemm_overlap", + "get_num_max_compute_streams", + "set_num_max_compute_streams", ] + +_NUM_MAX_COMPUTE_STREAMS = 3 _COMM_GEMM_OVERLAP_LAYERS = ["qkv", "proj", "fc1", "fc2"] _COMM_GEMM_OVERLAP_NAMES = ( [layer + "_fprop" for layer in _COMM_GEMM_OVERLAP_LAYERS] @@ -62,6 +66,17 @@ def get_cublas_workspace_size_bytes() -> None: return 4_194_304 +def get_num_max_compute_streams() -> int: + """Return the maximum number of compute streams that Comm+GEMM overlap can utilize.""" + return _NUM_MAX_COMPUTE_STREAMS + + +def set_num_max_compute_streams(new_max: int) -> None: + """Change the maximum number of compute streams that Comm+GEMM overlap can utilize.""" + global _NUM_MAX_COMPUTE_STREAMS + _NUM_MAX_COMPUTE_STREAMS = new_max + + class CollectiveGemmPrimitive(BasePrimitive): """ cuBlasLt GEMM Primitive w/ support for distributed inputs @@ -222,10 +237,7 @@ def abstract( if comm_overlap_config["method"] != "bulk": # Increase workspace size to ensure every GEMM chunk has an independent workspace # of the appropriate size - if comm_overlap_config["method"] == "pipeline": - workspace_size *= comm_overlap_config.get("num_splits", 4) - elif comm_overlap_config["method"] == "ring_exchange": - workspace_size *= tp_size + workspace_size *= _NUM_MAX_COMPUTE_STREAMS if comm_type == tex.CommOverlapType.AG: expected_extra_out_shape = list(lhs_aval.shape).copy() @@ -440,11 +452,8 @@ def lowering( bias_dtype = jax_dtype_to_te_dtype(bias_aval.dtype) workspace_size = get_cublas_workspace_size_bytes() - if comm_overlap_config is not None: - if comm_overlap_config["method"] == "pipeline": - workspace_size *= comm_overlap_config["num_splits"] - elif comm_overlap_config["method"] == "ring_exchange": - workspace_size *= comm_overlap_config["tp_size"] + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": + workspace_size *= get_num_max_compute_streams() descriptor_packer_fn = tex.pack_gemm_decriptor descriptor_args = ( diff --git a/transformer_engine/jax/gemm.py b/transformer_engine/jax/gemm.py index 7024dcb9fe..06cd52e97f 100644 --- a/transformer_engine/jax/gemm.py +++ b/transformer_engine/jax/gemm.py @@ -20,6 +20,7 @@ dact_lu, dbias_cast_transpose, dact_lu_dbias_cast_transpose, + get_num_max_compute_streams, ) from .cpp_extensions.gemm import sanitize_dims, mirror_dim, copy_into_overlap_buffer @@ -36,7 +37,7 @@ "get_comm_gemm_overlap_config", ] -_NUM_MAX_UB_STREAMS = 3 + _ACTIVE_COMM_GEMM_OVERLAPS = dict() @@ -921,7 +922,7 @@ def get_default_config(name): else tex.CommOverlapType.RS ), "num_sm": 1 if method == "ring_exchange" else 16, - "num_max_streams": _NUM_MAX_UB_STREAMS, + "num_max_streams": get_num_max_compute_streams(), "cga_size": 1 if method == "ring_exchange" else 2, "set_sm_margin": False, "num_splits": 4 if method == "pipeline" else tp_size, @@ -995,7 +996,7 @@ def add_new_comm_gemm_overlap( numranks, tp_size, kwargs["num_splits"], - _NUM_MAX_UB_STREAMS, + get_num_max_compute_streams(), kwargs["cga_size"], kwargs["num_sm"], kwargs["set_sm_margin"], From 5a3f4f32ae89854d89d2c266ee7d0ecfc108663b Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Sat, 25 Jan 2025 05:56:44 +0000 Subject: [PATCH 38/39] converted extra output in Comm+GEMM overlap to optional for AG overlaps Signed-off-by: Alp Dener --- .../comm_gemm_overlap/comm_gemm_overlap.py | 1 - transformer_engine/jax/cpp_extensions/gemm.py | 106 +++++++++--------- .../jax/csrc/extensions/comm_gemm_overlap.cpp | 6 +- 3 files changed, 59 insertions(+), 54 deletions(-) diff --git a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py index 2bc3651e51..8dc3035fbf 100644 --- a/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py +++ b/examples/jax/comm_gemm_overlap/comm_gemm_overlap.py @@ -5,7 +5,6 @@ import argparse import numpy as np -from functools import partial from mpi4py import MPI import jax diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 3816d52745..0d8ac0e51f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -212,10 +212,6 @@ def abstract( lhs_aval.shape[lhs_outer_dim], rhs_aval.shape[rhs_outer_dim], ] - extra_out_shape = extra_out_aval.shape - expected_extra_out_shape = [0] - extra_out_dtype = dtypes.canonicalize_dtype(extra_out_aval.dtype) - expected_extra_out_dtype = jnp.bfloat16 if batched_output: assert out_aval.ndim > 2, "Batched output buffer is missing batch dimensions." else: @@ -224,6 +220,8 @@ def abstract( expected_out_shape[-1], ] + expected_extra_out_shape = [0] + expected_extra_out_dtype = jnp.bfloat16 workspace_size = get_cublas_workspace_size_bytes() if comm_overlap_config is not None: comm_type = comm_overlap_config.get("comm_type", None) @@ -239,17 +237,21 @@ def abstract( # of the appropriate size workspace_size *= _NUM_MAX_COMPUTE_STREAMS - if comm_type == tex.CommOverlapType.AG: + if comm_type == tex.CommOverlapType.AG and extra_out_aval.size > 0: expected_extra_out_shape = list(lhs_aval.shape).copy() + expected_extra_out_dtype = lhs_dtype elif comm_type == tex.CommOverlapType.RS: + assert extra_out_aval.size > 0, ( + "GEMM+RS overlap requires extra output buffer." + ) expected_extra_out_shape = list(expected_out_shape).copy() - expected_extra_out_dtype = lhs_dtype if sharded_abstract: if comm_type == tex.CommOverlapType.AG: expected_out_shape[-2] *= tp_size - expected_extra_out_shape[-2] *= tp_size - else: + if extra_out_aval.size > 0: + expected_extra_out_shape[-2] *= tp_size + elif comm_type == tex.CommOverlapType.RS: expected_extra_out_shape[-2] = expected_extra_out_shape[-2] // tp_size assert out_aval.ndim == len(expected_out_shape), ( @@ -261,23 +263,25 @@ def abstract( + f"expected {expected_out_shape=} but found {out_aval.shape=}" ) - assert extra_out_dtype == expected_extra_out_dtype, ( - "Extra output has incorrect dtype: " - + f"expected {expected_extra_out_dtype} but found {extra_out_dtype}" - ) - assert extra_out_aval.ndim == len(expected_extra_out_shape), ( - "Extra output buffer has incorrect number of dimensions: " - + f"expected {len(expected_extra_out_shape)} but found {extra_out_aval.ndim}" - ) - assert all( - [ - extra_out_aval.shape[i] == expected_extra_out_shape[i] - for i in range(extra_out_aval.ndim) - ] - ), ( - "Extra output buffer has incorrect shape: " - + f"expected {expected_extra_out_shape=} but found {extra_out_aval.shape=}" - ) + if extra_out_aval.size > 0: + extra_out_dtype = dtypes.canonicalize_dtype(extra_out_aval.dtype) + assert extra_out_dtype == expected_extra_out_dtype, ( + "Extra output has incorrect dtype: " + + f"expected {expected_extra_out_dtype} but found {extra_out_dtype}" + ) + assert extra_out_aval.ndim == len(expected_extra_out_shape), ( + "Extra output buffer has incorrect number of dimensions: " + + f"expected {len(expected_extra_out_shape)} but found {extra_out_aval.ndim}" + ) + assert all( + [ + extra_out_aval.shape[i] == expected_extra_out_shape[i] + for i in range(extra_out_aval.ndim) + ] + ), ( + "Extra output buffer has incorrect shape: " + + f"expected {expected_extra_out_shape=} but found {extra_out_aval.shape=}" + ) # Validate bias/bias_grad shape against output bufer bias_dtype = jnp.bfloat16 if jax_dtype_is_fp8(out_dtype) else out_dtype @@ -317,7 +321,8 @@ def abstract( ) pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - extra_out_updated_aval = extra_out_aval.update(shape=extra_out_shape, dtype=extra_out_dtype) + extra_out_updated_aval = extra_out_aval.update(shape=expected_extra_out_shape, + dtype=expected_extra_out_dtype) workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return ( @@ -381,7 +386,7 @@ def lowering( Fused attention fwd lowering rules """ del batched_output, sharded_abstract - lhs_aval, _, rhs_aval, _, bias_aval, *_ = ctx.avals_in + lhs_aval, _, rhs_aval, _, bias_aval, *_, extra_out_aval = ctx.avals_in lhs_inner_dim, rhs_inner_dim = map( sanitize_dims, contracting_dims, (lhs_aval.ndim, rhs_aval.ndim) ) @@ -407,8 +412,9 @@ def lowering( 6: 0, # out <--> out_updated 7: 1, # out_amax <--> out_amax_updated 8: 2, # out_scale <--> out_scale_updated - 9: 5, # extra_out <--> extra_out_updated } + if extra_out_aval.size > 0: + operand_output_aliases[9] = 5 # extra_out <--> extra_out_updated if is_ffi_enabled(): name = "te_gemm_ffi" @@ -570,7 +576,7 @@ def impl( batched_extra_out = False if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": comm_type = comm_overlap_config["comm_type"] - if comm_type == tex.CommOverlapType.AG: + if comm_type == tex.CommOverlapType.AG and extra_out.size > 0: # Extra output is global LHS, we can collapse but need to recover batches later batched_extra_out = len(lhs_batch_dims) > 0 elif comm_type == tex.CommOverlapType.RS: @@ -701,7 +707,7 @@ def infer_sharding_from_operands( result_infos, ): del accumulate, use_split_accumulator, sharded_abstract, result_infos - lhs, _, rhs, *_ = arg_infos + lhs, _, rhs, *_, extra_out = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) @@ -778,7 +784,7 @@ def infer_sharding_from_operands( # Validate operand sharding for comm+GEMM overlap and adust extra output sharding extra_out_spec = [None] - if comm_overlap_config is not None: + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": comm_type = comm_overlap_config.get("comm_type", None) tp_resource = comm_overlap_config.get("tp_resource", global_mesh_resource().tp_resource) if comm_type == tex.CommOverlapType.AG: @@ -796,8 +802,9 @@ def infer_sharding_from_operands( "AG+GEMM overlap requires the contracting dimension of the RHS operand " + "to be unsharded." ) - extra_out_spec = list(lhs_spec).copy() - extra_out_spec[lhs_outer_dim] = None + if extra_out.size > 0: + extra_out_spec = list(lhs_spec).copy() + extra_out_spec[lhs_outer_dim] = None elif comm_type == tex.CommOverlapType.RS: # RS overlap requires the contracting dimensions of both LHS and RHS to be @@ -821,6 +828,7 @@ def infer_sharding_from_operands( ) extra_out_spec = list(out_spec).copy() extra_out_spec[-2] = tp_resource + extra_out_sharding = NamedSharding(mesh, PartitionSpec(*extra_out_spec)) return ( @@ -848,7 +856,7 @@ def partition( result_infos, ): del sharded_abstract, result_infos - lhs, _, rhs, *_ = arg_infos + lhs, _, rhs, *_, extra_out = arg_infos lhs_spec, rhs_spec = map(get_padded_spec, [lhs, rhs]) lhs_inner_dim, rhs_inner_dim = map(sanitize_dims, contracting_dims, (lhs.ndim, rhs.ndim)) @@ -900,9 +908,9 @@ def partition( # Extra output sharding for comm+GEMM overlap extra_out_spec = [None] - if comm_overlap_config is not None: + if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": comm_type = comm_overlap_config.get("comm_type", None) - if comm_type == tex.CommOverlapType.AG: + if comm_type == tex.CommOverlapType.AG and extra_out.size > 0: extra_out_spec = list(lhs_spec).copy() extra_out_spec[lhs_outer_dim] = None elif comm_type == tex.CommOverlapType.RS: @@ -1029,14 +1037,12 @@ def gemm_impl( out = jnp.zeros(out_shape, dtype=lhs.dtype) if extra_out is None: - extra_out_shape = 0 - if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": - comm_type = comm_overlap_config["comm_type"] - if comm_type == tex.CommOverlapType.AG: - extra_out_shape = list(lhs.shape).copy() - elif comm_type == tex.CommOverlapType.RS: - extra_out_shape = list(out_shape).copy() - extra_out = jnp.zeros(extra_out_shape, dtype=lhs.dtype) + extra_out_shape = (0,) + if (comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS): + extra_out_shape = list(out_shape).copy() + extra_out = jnp.zeros(extra_out_shape, dtype=jnp.bfloat16) if not fuse_bias: bias = jnp.zeros(0, dtype=lhs.dtype) @@ -1119,13 +1125,11 @@ def fp8_gemm_impl( out_dtype = out.dtype if extra_out is None: - extra_out_shape = 0 - if comm_overlap_config is not None and comm_overlap_config["method"] != "bulk": - comm_type = comm_overlap_config["comm_type"] - if comm_type == tex.CommOverlapType.AG: - extra_out_shape = list(lhs.shape).copy() - elif comm_type == tex.CommOverlapType.RS: - extra_out_shape = list(out_shape).copy() + extra_out_shape = (0,) + if (comm_overlap_config is not None + and comm_overlap_config["method"] != "bulk" + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS): + extra_out_shape = list(out_shape).copy() extra_out = jnp.zeros(extra_out_shape, dtype=jnp.bfloat16) if jax_dtype_is_fp8(out_dtype): diff --git a/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp index 533fdc3e83..02b415b321 100644 --- a/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/jax/csrc/extensions/comm_gemm_overlap.cpp @@ -258,8 +258,10 @@ Error_Type CommGemmOverlapFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type "out_amax not bound to out_amax_updated in TE/JAX comm+GEMM overlap."); NVTE_CHECK(out_scale_ptr == out_scale_updated_ptr, "out_scale not bound to out_scale_updated in TE/JAX comm+GEMM overlap."); - NVTE_CHECK(extra_out_ptr == extra_out_updated_ptr, - "extra_out not bound to extra_out_updated in TE/JAX comm+GEMM overlap."); + if (extra_out.element_count() > 0) { + NVTE_CHECK(extra_out_ptr == extra_out_updated_ptr, + "extra_out not bound to extra_out_updated in TE/JAX comm+GEMM overlap."); + } CommGemmOverlapImpl( lhs_ptr, lhs_shape, lhs_dtype, lhs_scale_inv_ptr, lhs_trans, rhs_ptr, rhs_shape, rhs_dtype, From 43a38cff1c47a56627c4e9cd6139a6b973c2c04a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 25 Jan 2025 05:57:14 +0000 Subject: [PATCH 39/39] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 0d8ac0e51f..b60fd1c74f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -241,9 +241,7 @@ def abstract( expected_extra_out_shape = list(lhs_aval.shape).copy() expected_extra_out_dtype = lhs_dtype elif comm_type == tex.CommOverlapType.RS: - assert extra_out_aval.size > 0, ( - "GEMM+RS overlap requires extra output buffer." - ) + assert extra_out_aval.size > 0, "GEMM+RS overlap requires extra output buffer." expected_extra_out_shape = list(expected_out_shape).copy() if sharded_abstract: @@ -321,8 +319,9 @@ def abstract( ) pre_gelu_out_aval = gelu_input_aval.update(shape=gelu_shape, dtype=bias_dtype) bias_grad_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - extra_out_updated_aval = extra_out_aval.update(shape=expected_extra_out_shape, - dtype=expected_extra_out_dtype) + extra_out_updated_aval = extra_out_aval.update( + shape=expected_extra_out_shape, dtype=expected_extra_out_dtype + ) workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return ( @@ -1038,9 +1037,11 @@ def gemm_impl( if extra_out is None: extra_out_shape = (0,) - if (comm_overlap_config is not None + if ( + comm_overlap_config is not None and comm_overlap_config["method"] != "bulk" - and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS): + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): extra_out_shape = list(out_shape).copy() extra_out = jnp.zeros(extra_out_shape, dtype=jnp.bfloat16) @@ -1126,9 +1127,11 @@ def fp8_gemm_impl( if extra_out is None: extra_out_shape = (0,) - if (comm_overlap_config is not None + if ( + comm_overlap_config is not None and comm_overlap_config["method"] != "bulk" - and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS): + and comm_overlap_config["comm_type"] == tex.CommOverlapType.RS + ): extra_out_shape = list(out_shape).copy() extra_out = jnp.zeros(extra_out_shape, dtype=jnp.bfloat16)