Commit
Merge branch 'main' into swa_padding_brcm
Showing 13 changed files with 392 additions and 43 deletions.
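Two of the changed files are shown in full below. The first is a toy training script, run_fsdp2_model.py (the name comes from the test file further down, which launches it). It builds a two-layer te.Linear network, shards it with PyTorch's composable FSDP2 fully_shard() API over a 1-D (FSDP) or 2-D (HSDP) DeviceMesh, and runs a short training loop, optionally storing primary weights in FP8 via fp8_model_init().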
New file: run_fsdp2_model.py (+181 lines)
#!/usr/bin/python3

# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import sys
import argparse
from contextlib import nullcontext

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh


class SimpleNet(nn.Module):
    """Two te.Linear layers with a ReLU in between."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = te.Linear(input_size, hidden_size)
        self.fc2 = te.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def save_custom_attrs(module):
    """Snapshot the custom attributes (e.g. TE metadata) attached to each parameter."""
    custom_attrs = {}
    for name, param in module.named_parameters():
        attrs = vars(param)
        custom_attrs[name] = {k: v for k, v in attrs.items()}
    return custom_attrs


def restore_custom_attrs(module, custom_attrs):
    """Re-attach the snapshotted attributes after fully_shard() swaps the parameters."""
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)


def _parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for training")
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of training iterations"
    )
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    # One value ("shard") selects FSDP; two values ("replicate", "shard") select HSDP.
    parser.add_argument(
        "--sharding-dims",
        type=int,
        nargs="+",
        help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
    )
    args = parser.parse_args(argv, namespace)
    if args.sharding_dims:
        assert len(args.sharding_dims) <= 2
    return args


# Module types to wrap individually with fully_shard().
sub_modules_to_wrap = [te.Linear]


def _train(args):
    assert "TORCHELASTIC_RUN_ID" in os.environ
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert LOCAL_SIZE == WORLD_SIZE

    # Set device and initialize RNG states
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Initialize torch.distributed global process group and get DP/TP groups
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    device = torch.device(f"cuda:{LOCAL_RANK}")

    # FP8 configuration. NOTE: the recipe is defined for reference; this toy
    # script never enters te.fp8_autocast(), so it exercises FP8 parameter
    # storage (via --fp8-init below), not FP8 compute.
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

    # Build model context: default is a plain nullcontext; with --fp8-init the
    # model is constructed under fp8_model_init() so that primary weights are
    # stored in FP8.
    build_model_context = nullcontext
    build_model_context_args = {}
    if args.fp8_init:
        from transformer_engine.pytorch import fp8_model_init

        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True

    # Build the model with the specified context
    with build_model_context(**build_model_context_args):
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)

    # Move the model to the correct device
    model.to(device)

    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
    # Create a DeviceMesh for fully_shard
    world_size = int(WORLD_SIZE)
    device_ids = list(range(world_size))
    if LOCAL_RANK == 0:
        print(f"sharding-dims:{args.sharding_dims}")
    # Set up the sharding mesh for FSDP/HSDP
    if args.sharding_dims is None:  # FSDP
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 1:  # FSDP with an explicit shard dimension
        assert args.sharding_dims[0] == device_ids[-1] + 1
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 2:  # HSDP
        assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
        mesh = init_device_mesh(
            "cuda",
            (args.sharding_dims[0], args.sharding_dims[1]),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        assert False, "--sharding-dims accepts at most two values"

    # Apply FSDP/HSDP: shard each te.Linear individually, then the root module,
    # and re-attach any custom parameter attributes that sharding replaced.
    custom_attrs = save_custom_attrs(model)
    for sub_module in model.modules():
        if any(
            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
        ):
            fully_shard(sub_module, mesh=mesh)
    fully_shard(model, mesh=mesh)
    restore_custom_attrs(model, custom_attrs)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Simple training loop on random data
    for iteration in range(args.iter):
        # Zero the parameter gradients
        optimizer.zero_grad()
        input_data = torch.randn(args.batch_size, args.input_size).to(device)
        output = model(input_data)
        target = torch.randn(args.batch_size, args.output_size).to(device)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

    dist.destroy_process_group()
    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Done...")
    return 0


if __name__ == "__main__":
    sys.exit(_train(_parse_args()))
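As a usage sketch, the script is meant to be launched with torchrun, one process per GPU; the test file below builds exactly this command. The 4-GPU HSDP invocation here is an assumption for illustration (drop --fp8-init on hardware without FP8 support):

torchrun --nproc_per_node=4 run_fsdp2_model.py --sharding-dims 2 2 --fp8-init

With two values, --sharding-dims takes the HSDP path (replicate x shard must equal the world size); a single value or omitting the flag selects plain 1-D FSDP.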
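The second new file is the pytest wrapper that drives the script above. It skips the whole module on machines with fewer than four GPUs, an odd GPU count, or PyTorch older than 2.4, and otherwise shells out to torchrun with one process per visible GPU.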
New file: FSDP2 pytest wrapper (+67 lines)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import subprocess
from pathlib import Path

import pytest
import torch
from packaging.version import Version as PkgVersion

from transformer_engine.pytorch.fp8 import FP8GlobalStateManager


def get_torch_version():
    """Return the installed PyTorch version as a comparable PkgVersion."""
    return PkgVersion(str(torch.__version__))


if torch.cuda.device_count() < 4:
    pytest.skip("FSDP2 test requires at least 4 GPUs.", allow_module_level=True)

if torch.cuda.device_count() % 2 != 0:
    pytest.skip("Number of devices must be divisible by 2.", allow_module_level=True)

if get_torch_version() < PkgVersion("2.4"):
    pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.", allow_module_level=True)

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(fp8_init, sharding_dims):
    test_path = TEST_ROOT / "run_fsdp2_model.py"
    test_cmd = LAUNCH_CMD + [str(test_path)]

    if fp8_init:
        test_cmd += ["--fp8-init"]
    if len(sharding_dims) == 1:
        test_cmd += ["--sharding-dims", str(sharding_dims[0])]
    elif len(sharding_dims) == 2:
        test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
    else:
        assert False, "sharding_dims must have one or two entries"
    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
    if result.returncode != 0:
        raise AssertionError(result.stderr.decode())


all_boolean = [True, False]
sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]


@pytest.mark.parametrize("sharding_dims", sharding_dims)
@pytest.mark.parametrize("fp8_init", all_boolean)
def test_distributed(fp8_init, sharding_dims):
    if fp8_init and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    _run_test(fp8_init, sharding_dims)
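The two parametrize decorators produce four configurations: FP8 initialization on or off, crossed with 1-D FSDP over all NUM_PROCS ranks or 2 x (NUM_PROCS // 2) HSDP. The FP8 cases are skipped when FP8GlobalStateManager reports no FP8-capable hardware.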