Skip to content

Commit

Permalink
Merge pull request #208 from AleHD/mem_fix_async
Browse files Browse the repository at this point in the history
Memory optimization in async tp-linear
  • Loading branch information
3outeille authored Aug 5, 2024
2 parents 4eb520f + 0adb368 commit 03d67f2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 41 deletions.
80 changes: 42 additions & 38 deletions src/nanotron/parallel/tensor_parallel/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function):

@staticmethod
@assert_cuda_max_connections_set_to_1
def forward(ctx, tensor, weight, bias, group, tp_mode):
def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
ctx.use_bias = bias is not None
ctx.tp_mode = tp_mode
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.tensor_shape = tensor.size()

if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
gathered_tensor = tensor
Expand All @@ -140,22 +142,27 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
ctx.save_for_backward(tensor, weight)
# ctx.save_for_backward(tensor, weight)

# TODO @thomasw21: gather along another dimension
sharded_batch_size, *intermediate_size, hidden_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
gathered_batch_size = sharded_batch_size * group.size()

gathered_tensor = torch.empty(
gathered_batch_size,
*intermediate_size,
hidden_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
if tp_recompute_allgather:
gathered_tensor = MemoryBuffer().get(
"allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
)
else:
gathered_tensor = torch.empty(
gathered_batch_size,
*intermediate_size,
hidden_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)

handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

Expand Down Expand Up @@ -203,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):

# Wait communication
handle.wait()
if tp_recompute_allgather:
ctx.save_for_backward(tensor, weight)
else:
ctx.save_for_backward(gathered_tensor, weight)

# Compute all the other shards that are obtained from AllGather
# weights: w0 w1 w2 w3
Expand Down Expand Up @@ -260,8 +271,8 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias
tp_mode = ctx.tp_mode

handle: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
handle1: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather:
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = tensor.shape
if group is None:
Expand All @@ -272,14 +283,10 @@ def backward(ctx, grad_output):
else:
unsharded_batch_size = sharded_batch_size * group.size()

unsharded_tensor = torch.empty(
unsharded_batch_size,
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
unsharded_tensor = MemoryBuffer().get(
"allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the tensor gradient computation
total_tensor = unsharded_tensor
Expand All @@ -288,9 +295,6 @@ def backward(ctx, grad_output):

grad_tensor = grad_output.matmul(weight)

if handle is not None:
handle.wait()

# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
Expand All @@ -302,37 +306,41 @@ def backward(ctx, grad_output):
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)

handle: Optional[dist.Work] = None
handle2: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
if group.size() == 1:
sub_grad_tensor = grad_tensor
else:
sub_grad_tensor = torch.empty(
tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
)
# reduce_scatter
handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_tensor, group=group, async_op=True)
handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
else:
raise ValueError()

grad_bias = grad_output.sum(dim=0) if use_bias else None

if handle1 is not None:
handle1.wait()

# TODO @thomasw21: This sounds like we don't have the optimal physical layout
grad_weight = grad_output.t().matmul(total_tensor)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if handle is not None:
handle.wait()
if handle2 is not None:
handle2.wait()

if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return sub_grad_tensor, grad_weight, grad_bias, None, None
return sub_grad_tensor, grad_weight, grad_bias, None, None, None
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
return grad_tensor, grad_weight, grad_bias, None, None
return grad_tensor, grad_weight, grad_bias, None, None, None
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")

Expand Down Expand Up @@ -430,7 +438,7 @@ def column_linear(
tp_recompute_allgather: bool = True,
):
if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)

if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
Expand Down Expand Up @@ -480,12 +488,8 @@ def backward(ctx, grad_output):
else:
unsharded_batch_size = sharded_batch_size * group.size()

total_grad_output = torch.empty(
unsharded_batch_size,
*rest_size,
device=grad_output.device,
dtype=grad_output.dtype,
requires_grad=False,
total_grad_output = MemoryBuffer().get(
"allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)

# Doing gather + slicing during the NeMo forward pass can make this tensor
Expand Down
23 changes: 20 additions & 3 deletions tests/test_tensor_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,32 @@ def _test_column_linear(
@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
@rerun_if_address_is_in_use()
def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool):
def test_row_linear(
tp: int,
dp: int,
pp: int,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
pytest.skip("ALL_REDUCE mode does not support async communication")
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather")

init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication)
init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(
tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
)


def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool):
def _test_row_linear(
parallel_context: ParallelContext,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if async_communication:
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
out_features = 3
Expand Down

0 comments on commit 03d67f2

Please sign in to comment.