diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml
index 586abd0541..86d22b7944 100644
--- a/.github/workflows/trigger-ci.yml
+++ b/.github/workflows/trigger-ci.yml
@@ -42,6 +42,7 @@ jobs:
           || github.actor == 'kocchop'
           || github.actor == 'youngeunkwon0405'
           || github.actor == 'KshitijLakhani'
+          || github.actor == 'jberchtold-nvidia'
         )
     steps:
       - name: Check if comment is issued by authorized person
diff --git a/examples/pytorch/comm_gemm_overlap/README.md b/examples/pytorch/comm_gemm_overlap/README.md
index bb3ba209ed..fc8458844b 100644
--- a/examples/pytorch/comm_gemm_overlap/README.md
+++ b/examples/pytorch/comm_gemm_overlap/README.md
@@ -16,7 +16,7 @@ Forward and backward passes with layer weights distributed over all GPUs in a
 single node.
 
 ```bash
-$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py
+$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py
 
 # Sample output on 8x H100s:
 # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
@@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicatin the model across
 groups in a single node.
 
 ```bash
-$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2
+$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2
 
 # Sample output on 8x H100s:
 # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh
index 9a11ccc008..4e52153db9 100644
--- a/qa/L1_pytorch_distributed_unittest/test.sh
+++ b/qa/L1_pytorch_distributed_unittest/test.sh
@@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py
new file mode 100644
index 0000000000..0f00a6717b
--- /dev/null
+++ b/tests/pytorch/distributed/run_fsdp2_model.py
@@ -0,0 +1,181 @@
+#!/usr/bin/python3
+
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import sys
+import argparse
+
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Format, DelayedScaling
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import nn, optim
+from torch.distributed import DeviceMesh
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import init_device_mesh
+from contextlib import nullcontext
+
+
+class SimpleNet(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(SimpleNet, self).__init__()
+        self.fc1 = te.Linear(input_size, hidden_size)
+        self.fc2 = te.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+def save_custom_attrs(module):
+    custom_attrs = {}
+    for name, param in module.named_parameters():
+        attrs = vars(param)
+        custom_attrs[name] = {k: v for k, v in attrs.items()}
+    return custom_attrs
+
+
+def restore_custom_attrs(module, custom_attrs):
+    for name, param in module.named_parameters():
+        if name in custom_attrs:
+            for attr_name, attr_value in custom_attrs[name].items():
+                setattr(param, attr_name, attr_value)
+
+
+def _parse_args(argv=None, namespace=None):
+    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
+    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
+    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
+    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
+    parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for the model")
+    parser.add_argument(
+        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
+ ) + parser.add_argument( + "--iter", type=int, default=10, help="Number of iterations for forward pass" + ) + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") + # Adding hsdp_dim as a list argument, comma-separated + parser.add_argument( + "--sharding-dims", + type=int, + nargs="+", + help='FSDP/HSDP sharding dimensions ("replicate", "shard")', + ) + args = parser.parse_args(argv, namespace) + if args.sharding_dims: + assert len(args.sharding_dims) <= 2 + return args + + +sub_modules_to_wrap = [te.Linear] + + +def _train(args): + assert "TORCHELASTIC_RUN_ID" in os.environ + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + assert LOCAL_SIZE == WORLD_SIZE + + # Set device and initialize RNG states + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + # Initialize torch.distributed global process group and get DP/TP groups + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + device = torch.device(f"cuda:{LOCAL_RANK}") + + # FP8 Configuration + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + + if not args.fp8_init: + # Build model context (FP8 init) + build_model_context = nullcontext + build_model_context_args = {} + + from transformer_engine.pytorch import fp8_model_init + + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + # Build the model with the specified context + with build_model_context(**build_model_context_args): + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + else: + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + # Move the model to the correct device + + model.to(device) + + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...") + # Creating a DeviceMesh for fully_shard + world_size = int(WORLD_SIZE) + device_ids = list(range(world_size)) + if LOCAL_RANK == 0: + print(f"sharding-dims:{args.sharding_dims}") + # Setup the sharding mesh for FSDP/HSDP + if args.sharding_dims == None: # FSDP + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 1: + assert args.sharding_dims[0] == device_ids[-1] + 1 + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 2: # HSDP + assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1 + mesh = init_device_mesh( + "cuda", + (args.sharding_dims[0], args.sharding_dims[1]), + mesh_dim_names=("replicate", "shard"), + ) + else: + assert False + + # Apply FSDP/HSDP + custom_attrs = save_custom_attrs(model) + for sub_module in model.modules(): + if any( + isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap + ): + fully_shard(sub_module, mesh=mesh) + fully_shard(model, mesh=mesh) + restore_custom_attrs(model, custom_attrs) + + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + for iteration in range(args.iter): + # Zero the parameter gradients + optimizer.zero_grad() + input_data = torch.randn(args.batch_size, args.input_size).to(device) + output = model(input_data) + target = torch.randn(args.batch_size, args.output_size).to(device) + loss = F.mse_loss(output, 
+        loss.backward()
+        optimizer.step()
+        if LOCAL_RANK == 0:
+            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
+
+    dist.destroy_process_group()
+    if LOCAL_RANK == 0:
+        print(f"Rank {LOCAL_RANK}: Done...")
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(_train(_parse_args()))
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
new file mode 100644
index 0000000000..3c9197c322
--- /dev/null
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import pytest
+import subprocess
+from pathlib import Path
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+import torch
+from packaging.version import Version as PkgVersion
+
+
+def get_torch_version():
+    """Get pytorch version from __version__"""
+
+    def get_torch_version_str():
+        import torch
+
+        return str(torch.__version__)
+
+    return PkgVersion(get_torch_version_str())
+
+
+if torch.cuda.device_count() < 4:
+    pytest.skip("FSDP2 test requires at least 4 GPUs.", allow_module_level=True)
+
+if torch.cuda.device_count() % 2 != 0:
+    pytest.skip("Number of devices must be divisible by 2.", allow_module_level=True)
+
+if get_torch_version() < PkgVersion("2.4"):
+    pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.", allow_module_level=True)
+
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
+TEST_ROOT = Path(__file__).parent.resolve()
+NUM_PROCS: int = torch.cuda.device_count()
+LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
+
+
+def _run_test(fp8_init, sharding_dims):
+    test_path = TEST_ROOT / "run_fsdp2_model.py"
+    test_cmd = LAUNCH_CMD + [str(test_path)]
+
+    if fp8_init:
+        test_cmd += ["--fp8-init"]
+    if len(sharding_dims) == 1:
+        test_cmd += ["--sharding-dims", str(sharding_dims[0])]
+    elif len(sharding_dims) == 2:
+        test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
+    else:
+        assert False
+    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
+    if result.returncode != 0:
+        raise AssertionError(result.stderr.decode())
+
+
+all_boolean = [True, False]
+sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]
+
+
+@pytest.mark.parametrize("sharding_dims", sharding_dims)
+@pytest.mark.parametrize("fp8_init", all_boolean)
+def test_distributed(fp8_init, sharding_dims):
+    if fp8_init and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    _run_test(fp8_init, sharding_dims)
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 57cc9358e2..4279eae635 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -679,6 +679,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
 
     if (is_ragged && cudnn_runtime_version >= 90600) {
       sdpa_backward_options.set_max_total_seq_len_q(s_q);
+      sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
     }
 
     fe::DiagonalAlignment_t const &diagonal_alignment =
diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h
index 8a8df63ba4..d1d56d5cc9 100644
--- a/transformer_engine/common/normalization/common.h
+++ b/transformer_engine/common/normalization/common.h
@@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {
 
 class NormalizationPlanRegistry {
  public:
-  // TODO thread-safe
   static NormalizationPlanRegistry& getInstance() {
-    static NormalizationPlanRegistry instance;
+    static thread_local NormalizationPlanRegistry instance;
     return instance;
   }
 
diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py
index 44b396ad55..7f09e6f900 100644
--- a/transformer_engine/jax/cpp_extensions/activation.py
+++ b/transformer_engine/jax/cpp_extensions/activation.py
@@ -8,7 +8,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -98,7 +98,7 @@ def abstract(x_aval, *, act_enum):  # pylint: disable=unused-argument
         assert x_shape[-2] == 2 or x_shape[-2] == 1
         hidden_size = x_shape[-1]
         batch_shapes = x_shape[:-2]
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         out_shape = (batch_shapes) + (hidden_size,)
         out_aval = out_aval.update(shape=out_shape, dtype=dtype)
 
@@ -225,7 +225,7 @@ def abstract(dz_aval, x_aval, *, act_enum):  # pylint: disable=unused-argument
         i_hidden_size = dz_aval.shape[-1]
         g_hidden_size = x_aval.shape[-1]
         assert i_hidden_size == g_hidden_size
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
 
         return out_aval
 
diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py
index 3d88c1f078..3715e6f20c 100644
--- a/transformer_engine/jax/cpp_extensions/base.py
+++ b/transformer_engine/jax/cpp_extensions/base.py
@@ -7,7 +7,7 @@
 from abc import ABCMeta, abstractmethod
 from functools import partial
 
-from jax import core
+from jax.extend import core
 from jax.interpreters import xla, mlir
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py
index 0b7df0b5a8..8ad7ee4fcb 100644
--- a/transformer_engine/jax/cpp_extensions/normalization.py
+++ b/transformer_engine/jax/cpp_extensions/normalization.py
@@ -9,7 +9,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters import mlir
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
@@ -74,7 +74,7 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
 
         mu_rsigama_dtype = jnp.float32
 
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
 
         assert gamma_aval.size == beta_aval.size
@@ -147,7 +147,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, output_type),
@@ -361,8 +361,8 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs):
         assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
         assert mu_dtype == rsigma_dtype == jnp.float32
 
-        dx_aval = core.raise_to_shaped(dz_aval)
-        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
+        dx_aval = dz_aval
+        dgamma_aval = dbeta_aval = gamma_aval
 
         (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
             x_aval.size // gamma_aval.size,  # batch size
@@ -441,7 +441,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon):
 
         sm_margin = get_backward_sm_margin()
 
-        wkspace_aval = ctx.avals_out[-4:]
+        wkspace_aval = ctx.avals_out[-1]
         opaque = transformer_engine_jax.pack_norm_descriptor(
             batch_size,
             hidden_size,
@@ -589,7 +589,7 @@ def abstract(x_aval, gamma_aval, **kwargs):
 
         rsigama_dtype = jnp.float32
 
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
 
         hidden_size = gamma_aval.size
@@ -650,7 +650,7 @@ def lowering(ctx, x, gamma, *, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, x_type.element_type),
@@ -783,8 +783,8 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
         assert rsigma_aval.shape == x_aval.shape[:-1]
         assert rsigma_dtype == jnp.float32
 
-        dx_aval = core.raise_to_shaped(dz_aval)
-        dgamma_aval = core.raise_to_shaped(gamma_aval)
+        dx_aval = dz_aval
+        dgamma_aval = gamma_aval
 
         (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
             x_aval.size // gamma_aval.size,  # batch size
@@ -841,7 +841,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
         hidden_size = reduce(operator.mul, g_shape)
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-3:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(x_shape, x_type.element_type),
@@ -1088,7 +1088,7 @@ def lowering(
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, ir_out_dtype),
@@ -1394,7 +1394,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
         batch_shape = out_shape[:-1]
         batch_size = reduce(operator.mul, x_shape) // hidden_size
 
-        wkspace_aval = ctx.avals_out[-2:]
+        wkspace_aval = ctx.avals_out[-1]
 
         out_types = [
             ir.RankedTensorType.get(out_shape, ir_out_dtype),
diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py
index a12943f4c2..67053ecd8e 100644
--- a/transformer_engine/jax/cpp_extensions/softmax.py
+++ b/transformer_engine/jax/cpp_extensions/softmax.py
@@ -9,7 +9,7 @@
 
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -126,7 +126,7 @@ def forward_abstract(logits_aval, scale_factor):
         assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
         assert q_seqlen > 1
 
-        out_aval = core.raise_to_shaped(logits_aval)
+        out_aval = logits_aval
         return out_aval
 
     @staticmethod
@@ -237,7 +237,7 @@ def backward_abstract(
 
         assert dz_aval.shape == softmax_out_aval.shape
 
-        dx_aval = core.raise_to_shaped(dz_aval)
+        dx_aval = dz_aval
         return dx_aval
 
     @staticmethod
@@ -578,7 +578,7 @@ def abstract(logits_aval, mask_aval, scale_factor):  # pylint: disable=unused-ar
         assert mask_shape[-2] == q_seqlen
         assert mask_shape[-1] == k_seqlen
 
-        out_aval = core.raise_to_shaped(logits_aval)
+        out_aval = logits_aval
         return out_aval
 
     @staticmethod
diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp
index 9b5c156e5d..a986b91b30 100644
--- a/transformer_engine/jax/csrc/extensions/pybind.cpp
+++ b/transformer_engine/jax/csrc/extensions/pybind.cpp
@@ -61,34 +61,43 @@ pybind11::dict Registrations() {
   dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
   dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
   dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
-  dict["te_dact_lu_dbias_cast_transpose_ffi"] =
-      EncapsulateFunction(DActLuDBiasCastTransposeHandler);
-  dict["te_dgated_act_lu_cast_transpose_ffi"] =
-      EncapsulateFunction(DGatedActLuCastTransposeHandler);
+  dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler);
+  dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler);
 
   // Quantization
   dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
   dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
 
   // Softmax
-  dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
-  dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
-  dict["te_scaled_masked_softmax_forward_ffi"] =
-      EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
+  dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
+  dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
+  dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
   dict["te_scaled_masked_softmax_backward_ffi"] =
-      EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
+      EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
   dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
-      EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
+      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
   dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
-      EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
+      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
 
   // Normalization
-  dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler);
-  dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler);
-  dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler);
-  dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler);
-  dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler);
-  dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler);
+  dict["te_layernorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler));
+  dict["te_layernorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler));
+  dict["te_layernorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler));
+  dict["te_rmsnorm_forward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler));
+  dict["te_rmsnorm_forward_fp8_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler));
+  dict["te_rmsnorm_backward_ffi"] =
+      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
+                     pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler));
 
   // Attention
   pybind11::dict fused_attn_forward_ffi;
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 7ace68a222..414e819f53 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -24,6 +24,19 @@
 aten = torch.ops.aten
 updated_fp8_params = {}
+_ops_to_preserve_subclass_in_fsdp2 = {
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.slice.Tensor,
+    torch.ops.aten.copy_.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten._to_copy.default,
+    torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.clone.default,
+}
+
 
 
 def _make_fp8_attr_property_funcs(name: str) -> Any:
     """Make accessors for an FP8 attribute
@@ -430,6 +443,37 @@ def __new__(
 
         return self
 
+    def fsdp_pre_all_gather(self, mesh):  # pylint: disable=unused-argument
+        """
+        A hook used by torch FSDP2, called before all-gather.
+        Returns: (all-gather inputs), (metadata)
+        Ref: https://github.com/pytorch/pytorch/pull/122908
+
+        """
+
+        return (self._data,), (self,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,  # pylint: disable=unused-argument
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        """
+        A hook used by torch FSDP2, called after all-gather.
+        Returns: (Float8Tensor wrapping the all-gathered data), (tensors to free after forward)
+        Ref: https://github.com/pytorch/pytorch/pull/122908
+
+        """
+        (data,) = all_gather_outputs
+        (sample,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            return None
+        return Float8Tensor.make_like(sample, data=data), all_gather_outputs
+
     @classmethod
     def make_like(
         cls,
@@ -902,7 +946,53 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
             )
             return Float8Tensor.make_like(tensor, data=data_view)
 
-        # Default case
+        # Related to FSDP2
+        if func == aten.split.Tensor:
+            tensor = args[0]
+            data = tensor._data
+            func_out = data.__torch_dispatch__(
+                func,
+                types,
+                [data] + list(args[1:]),
+                kwargs,
+            )
+            return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out]
+        if func == aten.new_zeros.default:
+            tensor = args[0]
+            data = tensor._data
+            func_out = data.__torch_dispatch__(
+                func,
+                types,
+                [data] + list(args[1:]),
+                kwargs,
+            )
+            return Float8Tensor.make_like(tensor, data=func_out)
+        if func == torch.ops.aten.as_strided.default:
+            tensor = args[0]
+            data = tensor._data
+            func_out = data.__torch_dispatch__(
+                func,
+                types,
+                [data] + list(args[1:]),
+                kwargs,
+            )
+            return Float8Tensor.make_like(tensor, data=func_out)
+        if func == torch.ops.aten.detach.default:
+            return cls.detach(args[0])
+        if func == torch.ops.aten.clone.default:
+            return cls.clone(args[0])
+        if func == torch.ops.aten.copy_.default:
+            # Implementation in the superclass (QuantizedTensor) returns a proper output
+            pass
+        elif func in _ops_to_preserve_subclass_in_fsdp2:
+            # Ops in _ops_to_preserve_subclass_in_fsdp2 are recommended to return an instance of the same class to work correctly with torch FSDP2
+            warnings.warn(
+                f"A function call({func}) in {cls} may not return {cls} tensor as an output. It"
+                " might cause an error in torch FSDP2!"
+            )
+        else:
+            pass
+
         return super().__torch_dispatch__(func, types, args, kwargs)
 
     @classmethod
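
Reviewer note (illustrative sketch, not part of the patch): the snippet below summarizes how the pieces added here fit together, mirroring run_fsdp2_model.py above. fully_shard() is applied to each te.Linear and then to the root module over a DeviceMesh, and the forward/backward pass runs under fp8_autocast. The two-layer Sequential, hidden size, and batch size are placeholder assumptions, and the script is assumed to be launched with torchrun (one process per GPU).

# fsdp2_te_sketch.py -- illustrative only; assumes `torchrun --nproc_per_node=<num_gpus> fsdp2_te_sketch.py`
import os

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))

hidden = 1024  # placeholder size, divisible by 16 so FP8 GEMMs are valid
model = torch.nn.Sequential(te.Linear(hidden, hidden), te.Linear(hidden, hidden)).cuda()

# 1-D mesh -> plain FSDP; a 2-D mesh with ("replicate", "shard") dims would give HSDP instead.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
for block in model:
    fully_shard(block, mesh=mesh)  # shard each te.Linear's parameters separately
fully_shard(model, mesh=mesh)  # then wrap the root module

recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
inp = torch.randn(32, hidden, device="cuda")  # placeholder batch
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = model(inp)
out.sum().backward()

dist.destroy_process_group()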