Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C/JAX] Comm+GEMM Overlap API for TE/JAX #1337

Draft
wants to merge 41 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
ad6bf2a
added XLA custom op defs for TE GEMM
denera Oct 24, 2024
c9774d8
fixed batching rules to accommodated batched RHS operand for GEMM
denera Nov 14, 2024
e523018
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2024
2c3dbf1
re-applied bug fixes to working older version, updated backward pass,…
denera Nov 15, 2024
448eaa9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2024
cb6ae3c
batched operands for GEMM custom op seem to be working now
denera Nov 18, 2024
6f67355
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2024
4b2b2d4
fixed batch size 1 issue and enabled FSDP sharding for RHS operand
denera Nov 19, 2024
2b2753e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2024
969f597
fixed FSDP+TP w/ DP=1 and TP+DP, but FSDP+TP w/ DP>1 still crashes
denera Nov 21, 2024
ce86dcb
fixed logic to remove FSDP sharding
denera Nov 21, 2024
b215f20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2024
cbab16c
retained FSDP dims and pushed FSDP all-gather of weight array to outs…
denera Nov 21, 2024
0ea55c0
Added useful warning about DGRAD sharding not matching sequence/conte…
denera Nov 21, 2024
2acb92f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2024
b07bb2d
documentation fixes
denera Nov 21, 2024
765b844
added unit test, both AG+GEMM and GEMM+AR passing with FSDP+TP, DP+TP…
denera Nov 27, 2024
2ce4377
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 27, 2024
f68d71e
restored old test_custom_call_compute.py to remove erroneous changes
denera Dec 5, 2024
6b322bb
added XLA custom ops and C++ infrastructure for comm+GEMM overlap in …
denera Nov 14, 2024
b306608
AG+GEMM overlap working
denera Dec 3, 2024
aa16726
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2024
a569e3b
added comm+GEMM overlap example script
denera Dec 3, 2024
69db12e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 3, 2024
ec2d5ae
RS overlap also works
denera Dec 3, 2024
8fe3942
added missing copy of AG+GEMM input into comm buffer
denera Dec 3, 2024
adf4046
updated FWD/BWD wrappers for non-FP8 and FP8 gemm
denera Dec 5, 2024
c4c608b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2024
4707df3
added more documentation to the TE/JAX comm+GEMM overlap example
denera Dec 5, 2024
18a6249
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2024
b144941
fixed RS overlap in the example
denera Dec 9, 2024
76b26bc
Merge branch 'jax-collective-gemm-with-overlap' of github.com:denera/…
denera Dec 9, 2024
6ad5651
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 9, 2024
5aceb02
updated comm overlap JAX example with numerical correctness check
denera Jan 14, 2025
f2b2a5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2025
1948254
changed commandline size controls to directly modify sequence length,…
denera Jan 15, 2025
bf0c88d
Merge branch 'jax-collective-gemm-with-overlap' of github.com:denera/…
denera Jan 15, 2025
b2720cb
fixed incorrect chunking of cuBLAS workspace
denera Jan 17, 2025
b7e034e
syntactic cleanup for workspace size correction in TP overlap
denera Jan 17, 2025
5a3f4f3
converted extra output in Comm+GEMM overlap to optional for AG overlaps
denera Jan 25, 2025
43a38cf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "3rdparty/cudnn-frontend"]
path = 3rdparty/cudnn-frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "3rdparty/dlpack"]
path = 3rdparty/dlpack
url = [email protected]:dmlc/dlpack.git
1 change: 1 addition & 0 deletions 3rdparty/dlpack
Submodule dlpack added at bbd2f4
20 changes: 20 additions & 0 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""JAX related extensions."""
import os
from pathlib import Path
from typing import Optional

import setuptools
from glob import glob
Expand Down Expand Up @@ -36,6 +37,7 @@ def setup_jax_extension(
csrc_source_files,
csrc_header_files,
common_header_files,
third_party_packages,
) -> setuptools.Extension:
"""Setup PyBind11 extension for JAX support"""
# Source files
Expand All @@ -55,12 +57,28 @@ def setup_jax_extension(
common_header_files / "common" / "include",
csrc_header_files,
xla_home,
third_party_packages / "dlpack" / "include",
]

# Compile flags
cxx_flags = ["-O3"]
nvcc_flags = ["-O3"]

# Userbuffers MPI dependence
libraries = []
library_dirs = []
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
mpi_home = os.getenv("MPI_HOME")
assert mpi_home is not None, "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
mpi_home = Path(mpi_home)
libraries.append("mpi")
library_dirs.append(mpi_home / "lib")

include_dirs.append(mpi_home / "include")

cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")

# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension

Expand All @@ -79,5 +97,7 @@ def _add_cflags(self, flags: List[str]) -> None:
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
library_dirs=[str(path) for path in library_dirs],
libraries=libraries,
extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
)
181 changes: 181 additions & 0 deletions examples/jax/comm_gemm_overlap/comm_gemm_overlap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Comm+GEMM Overlap with TE/JAX"""

import argparse
import numpy as np

from mpi4py import MPI

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental import mesh_utils

import transformer_engine.jax as te
from transformer_engine import transformer_engine_jax as tex
from transformer_engine.jax.cpp_extensions import gemm_impl, copy_into_overlap_buffer
from transformer_engine.jax.gemm import (
initialize_comm_gemm_overlaps,
destroy_comm_gemm_overlaps,
get_comm_overlap_config,
)
from transformer_engine.jax.sharding import get_padded_spec

jax.clear_caches()

# This script needs to be launched via `mpirun` with 1 process per GPU
myrank = MPI.COMM_WORLD.Get_rank()
numranks = MPI.COMM_WORLD.Get_size()
jax.distributed.initialize(cluster_detection_method="mpi4py")

parser = argparse.ArgumentParser()
parser.add_argument("-dp", "--dp-size", type=int, default=1)
parser.add_argument("-zp", "--fsdp-size", type=int, default=2)
parser.add_argument("-tp", "--tp-size", type=int, default=4)
parser.add_argument("-np", "--num-gpus", type=int, default=8)
parser.add_argument("--base-size", type=int, default=16)
parser.add_argument("--batch-size", type=int, default=4)
parser.add_argument("--no-batch", action="store_true")
parser.add_argument("--no-fsdp", action="store_true")
parser.add_argument("--comm-type", type=str.upper, default="AG", choices=["AG", "RS"])
args = parser.parse_args()

# GEMM problem sizing
dtype = jnp.bfloat16
seq_length = args.base_size * 8
hidden_size = args.base_size * 6
ffn_hidden_size = args.base_size * 16

# Operand shapes
lhs_shape = [seq_length, hidden_size] if args.comm_type == "AG" else [seq_length, ffn_hidden_size]
rhs_shape = (
[hidden_size, ffn_hidden_size] if args.comm_type == "AG" else [ffn_hidden_size, hidden_size]
)

# Operand partitioning
batched = not args.no_batch
fsdp = not args.no_fsdp
if batched:
lhs_shape = [args.batch_size] + lhs_shape
if fsdp:
mesh_shape = {"dp": args.dp_size, "zp": args.fsdp_size, "tp": args.tp_size}
mesh_resource = te.MeshResource(
dp_resource="dp", tp_resource="tp", cp_resource="tp", fsdp_resource="zp"
)
if args.comm_type == "AG":
input_specs = [("dp", "zp"), "tp", None]
weight_specs = ["zp", "tp"]
weight_no_fsdp = [None, "tp"]
elif args.comm_type == "RS":
input_specs = [("dp", "zp"), None, "tp"]
weight_specs = ["tp", "zp"]
weight_no_fsdp = ["tp", None]
else:
mesh_shape = {"dp": args.dp_size, "tp": args.tp_size}
mesh_resource = te.MeshResource(
dp_resource="dp",
tp_resource="tp",
cp_resource="tp",
)
if args.comm_type == "AG":
input_specs = ["dp", "tp", None]
weight_specs = [None, "tp"]
elif args.comm_type == "RS":
input_specs = ["dp", None, "tp"]
weight_specs = ["tp", None]
weight_no_fsdp = weight_specs
else:
mesh_shape = {"tp": args.tp_size}
mesh_resource = te.MeshResource(tp_resource="tp", cp_resource="cp")
if args.comm_type == "AG":
input_specs = ["tp", None]
weight_specs = [None, "tp"]
elif args.comm_type == "RS":
input_specs = [None, "tp"]
weight_specs = ["tp", None]
weight_no_fsdp = weight_specs

# Mesh setup and sharding definitions
devices = mesh_utils.create_device_mesh((args.num_gpus,), devices=jax.devices()[: args.num_gpus])
mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys()))
input_sharding = NamedSharding(mesh, PartitionSpec(*input_specs))
weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_specs))
weight_no_fsdp_sharding = NamedSharding(mesh, PartitionSpec(*weight_no_fsdp))

# Operand initialization
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key, 2)
lhs = jax.device_put(jax.random.normal(key1, lhs_shape, dtype=dtype), input_sharding)
rhs = jax.device_put(jax.random.normal(key2, rhs_shape, dtype=dtype), weight_sharding)

# Name of comm+GEMM overlap layer
overlap_name = "ag_gemm" if args.comm_type == "AG" else "gemm_rs"

# Bootstrap Userbuffers communicators and communication buffers
initialize_comm_gemm_overlaps(
lhs_shape,
mesh,
myrank,
numranks,
tp_resource="tp",
overlap_configs={
overlap_name: {
"method": "ring_exchange", # "pipeline" for collective kernels instead of send/recv
"comm_type": (
tex.CommOverlapType.AG if args.comm_type == "AG" else tex.CommOverlapType.RS
),
"num_splits": args.tp_size, # independent of TP size for "pipeline"
"cga_size": 1, # default is 2 for "pipeline"
"num_sm": 1, # ignored for "ring_exchange", must be tuned for "pipeline"
"set_sm_margin": False, # set to True for "pipeline"
"atomic_gemm": False, # more performant when not using CUDA Graphs
"use_ce": True, # ignored (always False) for "pipeline" method
},
},
)

if myrank == 0:
print(
f"{myrank}: INPUTS {lhs.shape} x {rhs.shape}\n"
+ f"{myrank}: LHS sharding: {lhs.sharding.spec}\n"
+ f"{myrank}: RHS sharding: {rhs.sharding.spec}\n",
flush=True,
)


@jax.jit
def te_gemm(A, B):
# For AG overlap, LHS needs to be copied into the comm. buffer before GEMM. This can usually
# be circumvented by extracting the comm. buffer as a JAX array via
# `buffer = jax.dlpack.from_dlpack(tex.get_overlap_buffer(overlap_name: str, sharded: bool))`
# and directly writing the result of a preceding operation into it (e.g.. LayerNorm output
# written directly into the communication buffer before AG+GEMM in a QKV projection)
if args.comm_type == "AG":
copy_into_overlap_buffer(A, overlap_name, True)
return_idx = 0
else:
# For RS overlap, the scattered output is in the `extra_out` array.
return_idx = -1

return gemm_impl(
A,
jax.lax.with_sharding_constraint(B, weight_no_fsdp_sharding), # all-gather FSDP weights
batched_output=True, # internal option, will be hidden by the FWD/BWD wrapper
comm_overlap_config=get_comm_overlap_config(overlap_name),
)[return_idx]


with te.sharding.global_shard_guard(mesh_resource):
output = te_gemm(lhs, rhs)

if myrank == 0:
print(
f"{myrank}: {'AG -> GEMM' if args.comm_type == 'AG' else 'GEMM -> RS'} OUTPUT "
+ f"{output.shape}\n"
+ f"{myrank}: Sharding: {get_padded_spec(output.sharding.spec, output.ndim)}\n",
flush=True,
)

destroy_comm_gemm_overlaps()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
"transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine",
current_file_path / "3rdparty",
)
)
if "paddle" in frameworks:
Expand Down
Loading