fixed backward pass for te.Linear column-parallel with TP overlap, updated unit tests

Signed-off-by: Alp Dener <[email protected]>
denera committed Nov 20, 2024
1 parent 9555554 commit 7cb1954
Showing 3 changed files with 203 additions and 133 deletions.
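
For context, the column-parallel te.Linear path overlaps tensor-parallel communication with the backward-pass GEMMs: either bulk all-gather/reduce-scatter overlaps with DGRAD and WGRAD, or a direct reduce-scatter overlap with DGRAD. Below is a minimal sketch of that configuration, assuming a userbuffers-enabled Transformer Engine build and an already-initialized NCCL process group; the sizes, the tp_group/sequence_parallel arguments, and the variable names are placeholders for illustration, not code from this commit.

    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te

    # Placeholder sizes for illustration only.
    seq_length, batch_size, hidden_size = 2048, 2, 768
    tp_size = dist.get_world_size()
    tp_group = dist.new_group(backend="nccl")

    # Register userbuffers workspaces for the global input shape (same call the test script uses).
    te.module.base.initialize_ub(
        [seq_length * batch_size, hidden_size],
        tp_size,
        use_fp8=False,
        dtype=torch.bfloat16,
    )

    # Column-parallel projection ("qkv"): bulk-overlap the all-gather with DGRAD and the
    # reduce-scatter with WGRAD in the backward pass, the path this commit fixes.
    qkv_proj = te.Linear(
        hidden_size,
        3 * hidden_size,
        parallel_mode="column",
        tp_group=tp_group,
        tp_size=tp_size,
        sequence_parallel=True,  # assumes the input is sharded along the sequence dimension
        ub_bulk_dgrad=True,
        ub_bulk_wgrad=True,
        ub_name="qkv",
    )

Passing ub_overlap_rs_dgrad=True instead of the two bulk flags (what the new --overlap-rs-dgrad option below selects) overlaps the reduce-scatter of the input gradient directly with the DGRAD GEMM.
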
58 changes: 39 additions & 19 deletions tests/pytorch/distributed/run_layer_with_overlap.py
@@ -51,18 +51,23 @@ def _get_layer_args(config, tp_group, tp_size, reference=False):
kwargs["ub_overlap_ag"] = not reference

if config.layer_type is te.Linear:
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["parallel_mode"] = config.parallel_mode
kwargs["ub_overlap_rs_fprop"] = not reference and config.parallel_mode == "row"
kwargs["ub_overlap_ag_dgrad"] = not reference and config.parallel_mode == "row"
kwargs["ub_overlap_ag_fprop"] = not reference and config.parallel_mode == "column"
kwargs["ub_overlap_rs_dgrad"] = not reference and config.parallel_mode == "column"
kwargs["ub_name"] = "proj" if config.parallel_mode == "row" else "qkv"
if config.linear_parallel_mode == "row":
input_shape[2] = hidden_size // tp_size
args.append(hidden_size)
kwargs["ub_overlap_rs"] = not reference
elif config.linear_parallel_mode == "column":
input_shape[0] = config.seq_length // tp_size
args.append(3 * hidden_size)
kwargs["ub_overlap_rs"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["parallel_mode"] = config.linear_parallel_mode
kwargs["ub_name"] = "proj" if config.linear_parallel_mode == "row" else "qkv"
else:
input_shape[0] = config.seq_length // tp_size
kwargs["ub_bulk_wgrad"] = not reference
kwargs["ub_bulk_dgrad"] = not reference
kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference
kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference
if config.layer_type is te.LayerNormLinear:
args.append(3 * hidden_size)
kwargs["parallel_mode"] = "column"
@@ -135,6 +140,12 @@ def _parse_args(argv=None, namespace=None):
        choices=["row", "column"],
        help="Parallel mode for te.Linear."
    )
+    parser.add_argument(
+        "--overlap-rs-dgrad",
+        action="store_true",
+        default=False,
+        help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps."
+    )
    parser.add_argument(
        "--debug",
        action="store_true",
@@ -240,12 +251,19 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
    dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

    # Intialize userbuffers
+    ub_cfgs = None
+    if opts.overlap_rs_dgrad:
+        ub_cfgs = {
+            "proj_dgrad": {"method": "ring_exchange"},
+            "qkv_dgrad": {"method": "ring_exchange"},
+        }
    te.module.base.initialize_ub(
        [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
        WORLD_SIZE,
        use_fp8=opts.fp8,
        dtype=torch.bfloat16,
        bootstrap_backend=opts.bootstrap_backend,
+        ub_cfgs=ub_cfgs,
    )

    # Initialize the Transformer Engine layer with overlap
@@ -324,27 +342,29 @@ def run_fwd_bwd(model, x):
            ref_grads.append(ref_param.grad)

    # Make sure we have the same number of gradients
-    numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
+    num_grads_failed = torch.tensor([0], dtype=torch.uint8, device="cuda")
    if len(test_grads) != len(ref_grads):
-        numerics_failed[0] = 1
+        num_grads_failed[0] = 1
        numerics_info = (
            "NUMERICAL CHECK FAILED: Incorrect number of gradients, "
            + f"expected {len(ref_grads)} but got {len(test_grads)}."
        )
        dist_print(numerics_info, src=WORLD_RANK, error=True)
-        dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
+        dist.all_reduce(num_grads_failed, dist.ReduceOp.MAX, nccl_world)

    # Now validate accuracy
-    if not bool(numerics_failed.item()):
+    numerics_failed = torch.zeros(len(test_grads), dtype=torch.uint8, device="cuda")
+    if not bool(num_grads_failed.item()):
        for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)):
            rtol = 0.125 if opts.fp8 else 0.025
            atol = 0.0625 if opts.fp8 else 0.00125
            grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol)
            dist_print(grad_info, src=WORLD_RANK, error=grad_failed)
-            numerics_failed[0] = int(grad_failed)
-            dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world)
-            if bool(numerics_failed.item()):
-                break
+            numerics_failed[i] = int(grad_failed)
+        return_code = torch.max(numerics_failed)
+        dist.all_reduce(return_code, dist.ReduceOp.MAX, nccl_world)
+    else:
+        return_code = num_grads_failed

    te.module.base.destroy_ub()
    dist_print("Destroying Userbuffers objects...", debug=True)
@@ -354,7 +374,7 @@ def run_fwd_bwd(model, x):
    if opts.debug and WORLD_RANK == 0:
        print("Exiting...\n", end="", flush=True)

-    return numerics_failed[0].item()
+    return return_code.item()


if __name__ == "__main__":
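The updated script can also be exercised outside of pytest. A rough sketch of a manual launch, mirroring how the wrapper in test_comm_gemm_overlap.py below builds its command; the torchrun arguments and the two-GPU count are assumptions, not taken from this commit, and --overlap-rs-dgrad is the optional flag added above.

    import os
    import subprocess

    # Hypothetical 2-GPU launch; the pytest wrapper builds an equivalent command via LAUNCH_CMD.
    launch_cmd = ["torchrun", "--nproc-per-node=2"]
    test_cmd = launch_cmd + [
        "tests/pytorch/distributed/run_layer_with_overlap.py",
        "--layer-type=Linear",
        "--linear-parallel-mode=column",
        "--overlap-rs-dgrad",  # switch the DGRAD reduce-scatter to the ring_exchange overlap
    ]
    result = subprocess.run(test_cmd, env=dict(os.environ), capture_output=True, check=False)
    assert result.returncode == 0, result.stderr.decode()
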
14 changes: 8 additions & 6 deletions tests/pytorch/distributed/test_comm_gemm_overlap.py
@@ -88,7 +88,7 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg
        raise AssertionError(result.stderr.decode())


-def _run_layer_with_overlap(layer_type, fp8, fp8_init):
+def _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
    test_path = TEST_ROOT / "run_layer_with_overlap.py"
    test_cmd = LAUNCH_CMD + [
        str(test_path),
@@ -99,6 +99,8 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init):
        f"--head-dim={HEAD_DIM}",
        f"--layer-type={layer_type}",
    ]
+    if layer_type == te.Linear.__name__:
+        test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}")

    if fp8:
        if not fp8_available:
@@ -229,12 +231,12 @@ def test_bulk_overlaps(comm_type, fp8):
@pytest.mark.parametrize(
    "layer_type,linear_parallel_mode",
    (
-        list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))]))
-        + [(te.Linear, "row"), (te.Linear, "column")]
+        [(te.Linear.__name__, "row"), (te.Linear.__name__, "column")]
+        + list(zip([layer.__name__ for layer in TE_LAYERS], [None for _ in range(len(TE_LAYERS))]))
    ),
    ids=(
-        [(" " + layer.__name__ + " ") for layer in TE_LAYERS]
-        + [" te.Linear (row-parallel) ", " te.Linear (column-parallel) "]
+        [f" {te.Linear.__name__} (row-parallel) ", f" {te.Linear.__name__} (column-parallel) "]
+        + [(" " + layer.__name__ + " ") for layer in TE_LAYERS]
    ),
)
@pytest.mark.parametrize(
@@ -254,4 +256,4 @@ def test_layers_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init):
    """
    Test Transformer Engine layers with comm+GEMM overlap.
    """
-    _run_layer_with_overlap(layer_type, fp8, fp8_init)
+    _run_layer_with_overlap(layer_type, linear_parallel_mode, fp8, fp8_init)
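
With the reordered parametrization above, the te.Linear cases come first and carry explicit row/column-parallel ids, so the new path can be selected by keyword. A small sketch using pytest's Python entry point; the -k expression is an assumption based on the ids defined above.

    import pytest

    # Run only the column-parallel te.Linear overlap cases from this test module.
    pytest.main([
        "-v",
        "tests/pytorch/distributed/test_comm_gemm_overlap.py",
        "-k", "column",
    ])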