Skip to content

Commit

Permalink
Merge branch 'main' into swa_padding_brcm
Browse files Browse the repository at this point in the history
  • Loading branch information
cyanguwa authored Dec 17, 2024
2 parents 6f677da + f4f35c2 commit 6a8e073
Show file tree
Hide file tree
Showing 13 changed files with 392 additions and 43 deletions.
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ jobs:
|| github.actor == 'kocchop'
|| github.actor == 'youngeunkwon0405'
|| github.actor == 'KshitijLakhani'
|| github.actor == 'jberchtold-nvidia'
)
steps:
- name: Check if comment is issued by authorized person
Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/comm_gemm_overlap/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Forward and backward passes with layer weights distributed over all GPUs in a single node.

```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py

# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7]
Expand Down Expand Up @@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicatin the model across
groups in a single node.

```bash
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2
$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2

# Sample output on 8x H100s:
# [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3]
Expand Down
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
181 changes: 181 additions & 0 deletions tests/pytorch/distributed/run_fsdp2_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
#!/usr/bin/python3

# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import sys
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext


class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNet, self).__init__()
self.fc1 = te.Linear(input_size, hidden_size)
self.fc2 = te.Linear(hidden_size, output_size)

def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x


def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs


def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)


def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
# Adding hsdp_dim as a list argument, comma-separated
parser.add_argument(
"--sharding-dims",
type=int,
nargs="+",
help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
)
args = parser.parse_args(argv, namespace)
if args.sharding_dims:
assert len(args.sharding_dims) <= 2
return args


sub_modules_to_wrap = [te.Linear]


def _train(args):
assert "TORCHELASTIC_RUN_ID" in os.environ
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert LOCAL_SIZE == WORLD_SIZE

# Set device and initialize RNG states
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Initialize torch.distributed global process group and get DP/TP groups
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
device = torch.device(f"cuda:{LOCAL_RANK}")

# FP8 Configuration
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

if not args.fp8_init:
# Build model context (FP8 init)
build_model_context = nullcontext
build_model_context_args = {}

from transformer_engine.pytorch import fp8_model_init

build_model_context = fp8_model_init
build_model_context_args["enabled"] = True

# Build the model with the specified context
with build_model_context(**build_model_context_args):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
else:
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
# Move the model to the correct device

model.to(device)

if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
# Creating a DeviceMesh for fully_shard
world_size = int(WORLD_SIZE)
device_ids = list(range(world_size))
if LOCAL_RANK == 0:
print(f"sharding-dims:{args.sharding_dims}")
# Setup the sharding mesh for FSDP/HSDP
if args.sharding_dims == None: # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 1:
assert args.sharding_dims[0] == device_ids[-1] + 1
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 2: # HSDP
assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
mesh = init_device_mesh(
"cuda",
(args.sharding_dims[0], args.sharding_dims[1]),
mesh_dim_names=("replicate", "shard"),
)
else:
assert False

# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model)
for sub_module in model.modules():
if any(
isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, mesh=mesh)
fully_shard(model, mesh=mesh)
restore_custom_attrs(model, custom_attrs)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for iteration in range(args.iter):
# Zero the parameter gradients
optimizer.zero_grad()
input_data = torch.randn(args.batch_size, args.input_size).to(device)
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

dist.destroy_process_group()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Done...")
return 0


if __name__ == "__main__":
sys.exit(_train(_parse_args()))
67 changes: 67 additions & 0 deletions tests/pytorch/distributed/test_torch_fsdp2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import torch
from packaging.version import Version as PkgVersion


def get_torch_version():
"""Get pytorch version from __version__"""

def get_torch_version_str():
import torch

return str(torch.__version__)

return PkgVersion(get_torch_version_str())


if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs.")

if torch.cuda.device_count() % 2 != 0:
pytest.skip("Number of device should be divided by 2.")

if not get_torch_version() >= PkgVersion("2.4"):
pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.")

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(fp_init, sharding_dims):
test_path = TEST_ROOT / "run_fsdp2_model.py"
test_cmd = LAUNCH_CMD + [str(test_path)]

if fp_init:
test_cmd += ["--fp8-init"]
if len(sharding_dims) == 1:
test_cmd += ["--sharding-dims", str(sharding_dims[0])]
elif len(sharding_dims) == 2:
test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
else:
assert False
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
if result.returncode != 0:
raise AssertionError(result.stderr.decode())


all_boolean = [True, False]
sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]


@pytest.mark.parametrize("sharding_dims", sharding_dims)
@pytest.mark.parametrize("fp8_init", all_boolean)
def test_distributed(fp8_init, sharding_dims):
if fp8_init and not fp8_available:
pytest.skip(reason_for_no_fp8)
_run_test(fp8_init, sharding_dims)
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(

if (is_ragged && cudnn_runtime_version >= 90600) {
sdpa_backward_options.set_max_total_seq_len_q(s_q);
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
}

fe::DiagonalAlignment_t const &diagonal_alignment =
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/common/normalization/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase {

class NormalizationPlanRegistry {
public:
// TODO thread-safe
static NormalizationPlanRegistry& getInstance() {
static NormalizationPlanRegistry instance;
static thread_local NormalizationPlanRegistry instance;
return instance;
}

Expand Down
6 changes: 3 additions & 3 deletions transformer_engine/jax/cpp_extensions/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax.extend import ffi
Expand Down Expand Up @@ -98,7 +98,7 @@ def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument
assert x_shape[-2] == 2 or x_shape[-2] == 1
hidden_size = x_shape[-1]
batch_shapes = x_shape[:-2]
out_aval = core.raise_to_shaped(x_aval)
out_aval = x_aval
out_shape = (batch_shapes) + (hidden_size,)
out_aval = out_aval.update(shape=out_shape, dtype=dtype)

Expand Down Expand Up @@ -225,7 +225,7 @@ def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument
i_hidden_size = dz_aval.shape[-1]
g_hidden_size = x_aval.shape[-1]
assert i_hidden_size == g_hidden_size
out_aval = core.raise_to_shaped(x_aval)
out_aval = x_aval

return out_aval

Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from abc import ABCMeta, abstractmethod
from functools import partial

from jax import core
from jax.extend import core
from jax.interpreters import xla, mlir
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
Expand Down
Loading

0 comments on commit 6a8e073

Please sign in to comment.