
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 11, 2024
1 parent 7ecfe04 commit e4cf960
Showing 2 changed files with 45 additions and 23 deletions.
55 changes: 36 additions & 19 deletions tests/pytorch/distributed/run_fsdp2_model.py
@@ -20,6 +20,7 @@
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext


class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNet, self).__init__()
@@ -30,39 +31,51 @@ def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x



def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs


def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)


def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
-    parser.add_argument('--input-size', type=int, default=2048, help='Input size for the model')
-    parser.add_argument('--hidden-size', type=int, default=2048, help='Hidden layer size')
-    parser.add_argument('--output-size', type=int, default=2048, help='Output size for the model')
-    parser.add_argument('--batch-size', type=int, default=2048, help='Output size for the model')
+    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
+    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
+    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
+    parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
-    parser.add_argument('--iter', type=int, default=10, help='Number of iterations for forward pass')
+    parser.add_argument(
+        "--iter", type=int, default=10, help="Number of iterations for forward pass"
+    )
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
# Adding hsdp_dim as a list argument, comma-separated
-    parser.add_argument('--sharding-dims', type=int, nargs='+', help='FSDP/HSDP sharding dimensions ("replicate", "shard")')
+    parser.add_argument(
+        "--sharding-dims",
+        type=int,
+        nargs="+",
+        help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
+    )
args = parser.parse_args(argv, namespace)
if args.sharding_dims:
assert len(args.sharding_dims) <= 2
return args

-sub_modules_to_wrap=[te.Linear]

+sub_modules_to_wrap = [te.Linear]


def _train(args):
assert "TORCHELASTIC_RUN_ID" in os.environ
@@ -76,7 +89,7 @@ def _train(args):
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Initialize torch.distributed global process group and get DP/TP groups
dist_init_kwargs = {
"backend": "nccl",
@@ -86,7 +99,7 @@ def _train(args):
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
-    device = torch.device(f'cuda:{LOCAL_RANK}')
+    device = torch.device(f"cuda:{LOCAL_RANK}")

# FP8 Configuration
fp8_format = Format.HYBRID
@@ -98,6 +111,7 @@ def _train(args):
build_model_context_args = {}

from transformer_engine.pytorch import fp8_model_init

build_model_context = fp8_model_init
build_model_context_args["enabled"] = True

@@ -112,29 +126,32 @@

if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
-    # Creating a DeviceMesh for fully_shard
+    # Creating a DeviceMesh for fully_shard
world_size = int(WORLD_SIZE)
device_ids = list(range(world_size))
if LOCAL_RANK == 0:
print(f"sharding-dims:{args.sharding_dims}")
# Setup the sharding mesh for FSDP/HSDP
-    if args.sharding_dims == None: # FSDP
+    if args.sharding_dims == None:  # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 1:
-        assert args.sharding_dims[0] == device_ids[-1] +1
+        assert args.sharding_dims[0] == device_ids[-1] + 1
mesh = DeviceMesh("cuda", device_ids)
-    elif len(args.sharding_dims) == 2: # HSDP
-        assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] +1
-        mesh = init_device_mesh("cuda", (args.sharding_dims[0], args.sharding_dims[1]), mesh_dim_names=("replicate", "shard"))
+    elif len(args.sharding_dims) == 2:  # HSDP
+        assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
+        mesh = init_device_mesh(
+            "cuda",
+            (args.sharding_dims[0], args.sharding_dims[1]),
+            mesh_dim_names=("replicate", "shard"),
+        )
else:
assert False

# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model)
for sub_module in model.modules():
if any(
-            isinstance(sub_module, sub_module_to_wrap)
-            for sub_module_to_wrap in sub_modules_to_wrap
+            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, mesh=mesh)
fully_shard(model, mesh=mesh)
@@ -153,7 +170,7 @@ def _train(args):
optimizer.step()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

dist.destroy_process_group()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Done...")
13 changes: 9 additions & 4 deletions tests/pytorch/distributed/test_torch_fsdp2.py
@@ -10,13 +10,18 @@
import torch
from packaging.version import Version as PkgVersion


def get_torch_version():
"""Get pytorch version from __version__"""

def get_torch_version_str():
import torch

return str(torch.__version__)

return PkgVersion(get_torch_version_str())


if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs.")

@@ -51,12 +56,12 @@ def _run_test(fp_init, sharding_dims):


all_boolean = [True, False]
-sharding_dims =[[NUM_PROCS], [2, NUM_PROCS//2]]
+sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]


@pytest.mark.parametrize("sharding_dims", sharding_dims)
@pytest.mark.parametrize("fp8_init", all_boolean)
-def test_distributed(fp8_init,sharding_dims):
+def test_distributed(fp8_init, sharding_dims):
if fp8_init and not fp8_available:
pytest.skip(reason_for_no_fp8)
-    _run_test(fp8_init,sharding_dims)
+    _run_test(fp8_init, sharding_dims)
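For reference, the two stacked parametrize decorators above expand to four test cases when NUM_PROCS is 4 (an assumed value here; NUM_PROCS is defined outside the hunks shown). A small illustrative sketch of the resulting grid:

from itertools import product

NUM_PROCS = 4  # assumed for illustration; the real definition is not shown in this diff
all_boolean = [True, False]
sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]

# itertools.product enumerates the same four combinations the stacked decorators generate.
for fp8_init, dims in product(all_boolean, sharding_dims):
    print(f"fp8_init={fp8_init}, sharding_dims={dims}")
# fp8_init=True, sharding_dims=[4]
# fp8_init=True, sharding_dims=[2, 2]
# fp8_init=False, sharding_dims=[4]
# fp8_init=False, sharding_dims=[2, 2]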

