From 56c2719499ca48b0a56d8f9b8a77378dcd2bbc9a Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Wed, 29 Jan 2025 07:52:16 +0000 Subject: [PATCH 1/6] Initial commit --- tests/python/test_dtensor.py | 118 +++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) diff --git a/tests/python/test_dtensor.py b/tests/python/test_dtensor.py index de52067b500..2a3168dc7d0 100644 --- a/tests/python/test_dtensor.py +++ b/tests/python/test_dtensor.py @@ -116,3 +116,121 @@ def define_fusion(fd: FusionDefinition): torch.testing.assert_close(out_dtensor.to_local(), in_dtensor.to_local() + 1) assert out_dtensor.device_mesh == in_dtensor.device_mesh assert out_dtensor.placements == in_dtensor.placements + + +@pytest.mark.mpi +def test_linear(setup_process_group): + class FusionDefintionArguments: + def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): + self.d = num_devices + self.b = batch + self.s = sequence + self.e = hidden + + class LinearForwardDefinition(FusionDefintionArguments): + def __call__(self, fd: FusionDefinition): + inp = fd.define_tensor([self.b, self.s, self.e]) + weight = fd.define_tensor( + [self.d, self.e, self.e], contiguity=[True, True, True] + ) + bias = fd.define_tensor([self.d, self.e], contiguity=[True, True]) + out = fd.ops.linear(inp, weight, bias) + fd.add_output(out) + + class LinearBackwardDefinition(FusionDefintionArguments): + def __call__(self, fd: FusionDefinition): + x = fd.define_tensor([self.b, self.s, self.e]) + x = fd.ops.reshape(x, [self.b * self.s, self.e]) + w = fd.define_tensor([self.d, self.e, self.e], contiguity=True) + grad = fd.define_tensor([self.d, self.b, self.s, self.e], contiguity=True) + grad = fd.ops.reshape(grad, [self.d, self.b * self.s, self.e]) + + grad_x_partials = fd.ops.matmul(grad, w) + grad_x = fd.ops.sum(grad_x_partials, [0]) # all reduce + grad_t = fd.ops.permute(grad, [0, 2, 1]) + grad_w = fd.ops.matmul(grad_t, x) + grad_b = fd.ops.sum(grad, [1]) + + grad_x = fd.ops.reshape(grad_x, [self.b, self.s, self.e]) + fd.add_output(grad_x) + fd.add_output(grad_w) + fd.add_output(grad_b) + + class LinearFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + input: DTensor, + weight: DTensor, + bias: DTensor, + ): + b, s, e = input._local_tensor.shape + d = weight.device_mesh.size() + op = FusionDefinitionWrapper(LinearForwardDefinition(d, b, s, e)) + outputs = op([input, weight, bias]) + ctx.save_for_backward(input, weight) + return outputs[0] + + @staticmethod + def backward(ctx, grad_output: DTensor): + d, b, s, e = grad_output.shape + op = FusionDefinitionWrapper(LinearBackwardDefinition(d, b, s, e)) + input, weight = ctx.saved_tensors + outputs = op([input, weight, grad_output]) + return outputs[0], outputs[1], outputs[2] + + world_size = dist.get_world_size() + rank = dist.get_rank() + torch.cuda.set_device(rank) + + mesh = dist.device_mesh.init_device_mesh("cuda", [world_size]) + + d = world_size + b, s, e = 2, 1024, 768 + inp_tensor = torch.randn(b, s, e, device="cuda", requires_grad=True) + weight_tensor = torch.randn(world_size, e, e, device="cuda", requires_grad=True) + bias_tensor = torch.randn(world_size, e, device="cuda", requires_grad=True) + + inp_dtensor = dist.tensor.distribute_tensor(inp_tensor, mesh, [Replicate()]) + weight_dtensor = dist.tensor.distribute_tensor(weight_tensor, mesh, [Shard(0)]) + bias_dtensor = dist.tensor.distribute_tensor(bias_tensor, mesh, [Shard(0)]) + + # expected forward + unsharded_out_tensor = torch.nn.functional.linear( + inp_tensor, 
weight_tensor.view([d * e, e]), bias_tensor.view([d * e]) + ) + expected_out_tensor = unsharded_out_tensor.view([b, s, d, e]).permute(2, 0, 1, 3)[ + rank : rank + 1 + ] + + # multidevice forward + out_dtensor = LinearFunction.apply(inp_dtensor, weight_dtensor, bias_dtensor) + + # expected backward + (expected_grad_x, expected_grad_w, expected_grad_b) + = torch.autograd.grad( + unsharded_out_tensor, + (inp_tensor, weight_tensor, bias_tensor), + torch.ones_like(unsharded_out_tensor) + ) + + # multidevice backward + (grad_x, grad_w, grad_b) + = torch.autograd.grad( + out_dtensor, + (inp_dtensor, weight_dtensor, bias_dtensor), + torch.ones_like(out_dtensor) + ) + + torch.testing.assert_close( + out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3 + ) + torch.testing.assert_close( + expected_grad_x, grad_x.to_local(), rtol=1.3e-6, atol=1e-3 + ) + torch.testing.assert_close( + expected_grad_w[rank : rank + 1], grad_w.to_local(), rtol=1.3e-6, atol=1e-3 + ) + torch.testing.assert_close( + expected_grad_b[rank : rank + 1], grad_b.to_local(), rtol=1.3e-6, atol=1e-3 + ) From 893506362581e926c314d69576fbf17283f4b3c5 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Wed, 29 Jan 2025 00:13:48 -0800 Subject: [PATCH 2/6] lint --- tests/python/test_dtensor.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/tests/python/test_dtensor.py b/tests/python/test_dtensor.py index 2a3168dc7d0..6d976e0ee74 100644 --- a/tests/python/test_dtensor.py +++ b/tests/python/test_dtensor.py @@ -126,7 +126,7 @@ def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): self.b = batch self.s = sequence self.e = hidden - + class LinearForwardDefinition(FusionDefintionArguments): def __call__(self, fd: FusionDefinition): inp = fd.define_tensor([self.b, self.s, self.e]) @@ -146,7 +146,7 @@ def __call__(self, fd: FusionDefinition): grad = fd.ops.reshape(grad, [self.d, self.b * self.s, self.e]) grad_x_partials = fd.ops.matmul(grad, w) - grad_x = fd.ops.sum(grad_x_partials, [0]) # all reduce + grad_x = fd.ops.sum(grad_x_partials, [0]) # all reduce grad_t = fd.ops.permute(grad, [0, 2, 1]) grad_w = fd.ops.matmul(grad_t, x) grad_b = fd.ops.sum(grad, [1]) @@ -178,7 +178,7 @@ def backward(ctx, grad_output: DTensor): input, weight = ctx.saved_tensors outputs = op([input, weight, grad_output]) return outputs[0], outputs[1], outputs[2] - + world_size = dist.get_world_size() rank = dist.get_rank() torch.cuda.set_device(rank) @@ -207,20 +207,18 @@ def backward(ctx, grad_output: DTensor): out_dtensor = LinearFunction.apply(inp_dtensor, weight_dtensor, bias_dtensor) # expected backward - (expected_grad_x, expected_grad_w, expected_grad_b) - = torch.autograd.grad( - unsharded_out_tensor, - (inp_tensor, weight_tensor, bias_tensor), - torch.ones_like(unsharded_out_tensor) - ) + (expected_grad_x, expected_grad_w, expected_grad_b) = torch.autograd.grad( + unsharded_out_tensor, + (inp_tensor, weight_tensor, bias_tensor), + torch.ones_like(unsharded_out_tensor), + ) # multidevice backward - (grad_x, grad_w, grad_b) - = torch.autograd.grad( - out_dtensor, - (inp_dtensor, weight_dtensor, bias_dtensor), - torch.ones_like(out_dtensor) - ) + (grad_x, grad_w, grad_b) = torch.autograd.grad( + out_dtensor, + (inp_dtensor, weight_dtensor, bias_dtensor), + torch.ones_like(out_dtensor), + ) torch.testing.assert_close( out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3 From 8d11f75aedabd1bfedee3057a996e8369c768c43 Mon Sep 17 00:00:00 2001 From: Syed Tousif 
Ahmed Date: Mon, 10 Feb 2025 17:10:59 -0800 Subject: [PATCH 3/6] Addresses review --- tests/python/test_dtensor.py | 102 ++++++++++++++++++----------------- 1 file changed, 52 insertions(+), 50 deletions(-) diff --git a/tests/python/test_dtensor.py b/tests/python/test_dtensor.py index 6d976e0ee74..a8fc9e64492 100644 --- a/tests/python/test_dtensor.py +++ b/tests/python/test_dtensor.py @@ -8,6 +8,7 @@ import torch import torch.distributed as dist from collections.abc import Iterable +from functools import partial from nvfuser import DataType, FusionDefinition from torch.distributed.tensor import DTensor from torch.distributed.tensor.placement_types import Placement, Shard, Replicate @@ -120,41 +121,42 @@ def define_fusion(fd: FusionDefinition): @pytest.mark.mpi def test_linear(setup_process_group): - class FusionDefintionArguments: + from dataclasses import dataclass + + @dataclass + class LinearConfig: def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): self.d = num_devices self.b = batch self.s = sequence self.e = hidden - class LinearForwardDefinition(FusionDefintionArguments): - def __call__(self, fd: FusionDefinition): - inp = fd.define_tensor([self.b, self.s, self.e]) - weight = fd.define_tensor( - [self.d, self.e, self.e], contiguity=[True, True, True] - ) - bias = fd.define_tensor([self.d, self.e], contiguity=[True, True]) - out = fd.ops.linear(inp, weight, bias) - fd.add_output(out) - - class LinearBackwardDefinition(FusionDefintionArguments): - def __call__(self, fd: FusionDefinition): - x = fd.define_tensor([self.b, self.s, self.e]) - x = fd.ops.reshape(x, [self.b * self.s, self.e]) - w = fd.define_tensor([self.d, self.e, self.e], contiguity=True) - grad = fd.define_tensor([self.d, self.b, self.s, self.e], contiguity=True) - grad = fd.ops.reshape(grad, [self.d, self.b * self.s, self.e]) - - grad_x_partials = fd.ops.matmul(grad, w) - grad_x = fd.ops.sum(grad_x_partials, [0]) # all reduce - grad_t = fd.ops.permute(grad, [0, 2, 1]) - grad_w = fd.ops.matmul(grad_t, x) - grad_b = fd.ops.sum(grad, [1]) - - grad_x = fd.ops.reshape(grad_x, [self.b, self.s, self.e]) - fd.add_output(grad_x) - fd.add_output(grad_w) - fd.add_output(grad_b) + def define_linear_forward(config: LinearConfig, fd: FusionDefinition) -> None: + d, b, s, e = config.d, config.b, config.s, config.e + inp = fd.define_tensor([b, s, e]) + weight = fd.define_tensor([d, e, e], contiguity=[True, True, True]) + bias = fd.define_tensor([d, e], contiguity=[True, True]) + out = fd.ops.linear(inp, weight, bias) + fd.add_output(out) + + def define_linear_backward(config: LinearConfig, fd: FusionDefinition) -> None: + d, b, s, e = config.d, config.b, config.s, config.e + x = fd.define_tensor([b, s, e]) + x = fd.ops.reshape(x, [b * s, e]) + w = fd.define_tensor([d, e, e], contiguity=True) + grad = fd.define_tensor([d, b, s, e], contiguity=True) + grad = fd.ops.reshape(grad, [d, b * s, e]) + + grad_x_partials = fd.ops.matmul(grad, w) + grad_x = fd.ops.sum(grad_x_partials, [0]) # all reduce + grad_t = fd.ops.permute(grad, [0, 2, 1]) + grad_w = fd.ops.matmul(grad_t, x) + grad_b = fd.ops.sum(grad, [1]) + + grad_x = fd.ops.reshape(grad_x, [b, s, e]) + fd.add_output(grad_x) + fd.add_output(grad_w) + fd.add_output(grad_b) class LinearFunction(torch.autograd.Function): @staticmethod @@ -166,7 +168,9 @@ def forward( ): b, s, e = input._local_tensor.shape d = weight.device_mesh.size() - op = FusionDefinitionWrapper(LinearForwardDefinition(d, b, s, e)) + op = FusionDefinitionWrapper( + partial(define_linear_forward, 
LinearConfig(d, b, s, e)) + ) outputs = op([input, weight, bias]) ctx.save_for_backward(input, weight) return outputs[0] @@ -174,22 +178,23 @@ def forward( @staticmethod def backward(ctx, grad_output: DTensor): d, b, s, e = grad_output.shape - op = FusionDefinitionWrapper(LinearBackwardDefinition(d, b, s, e)) + op = FusionDefinitionWrapper( + partial(define_linear_backward, LinearConfig(d, b, s, e)) + ) input, weight = ctx.saved_tensors outputs = op([input, weight, grad_output]) - return outputs[0], outputs[1], outputs[2] + assert len(outputs) == 3 + return (*outputs,) - world_size = dist.get_world_size() + d, b, s, e = dist.get_world_size(), 2, 1024, 768 rank = dist.get_rank() torch.cuda.set_device(rank) - mesh = dist.device_mesh.init_device_mesh("cuda", [world_size]) + mesh = dist.device_mesh.init_device_mesh("cuda", [d]) - d = world_size - b, s, e = 2, 1024, 768 inp_tensor = torch.randn(b, s, e, device="cuda", requires_grad=True) - weight_tensor = torch.randn(world_size, e, e, device="cuda", requires_grad=True) - bias_tensor = torch.randn(world_size, e, device="cuda", requires_grad=True) + weight_tensor = torch.randn(d, e, e, device="cuda", requires_grad=True) + bias_tensor = torch.randn(d, e, device="cuda", requires_grad=True) inp_dtensor = dist.tensor.distribute_tensor(inp_tensor, mesh, [Replicate()]) weight_dtensor = dist.tensor.distribute_tensor(weight_tensor, mesh, [Shard(0)]) @@ -220,15 +225,12 @@ def backward(ctx, grad_output: DTensor): torch.ones_like(out_dtensor), ) - torch.testing.assert_close( - out_dtensor.to_local(), expected_out_tensor, rtol=1.3e-6, atol=1e-3 - ) - torch.testing.assert_close( - expected_grad_x, grad_x.to_local(), rtol=1.3e-6, atol=1e-3 - ) - torch.testing.assert_close( - expected_grad_w[rank : rank + 1], grad_w.to_local(), rtol=1.3e-6, atol=1e-3 - ) - torch.testing.assert_close( - expected_grad_b[rank : rank + 1], grad_b.to_local(), rtol=1.3e-6, atol=1e-3 - ) + def assert_close(expected_tensor, dtensor): + torch.testing.assert_close( + expected_tensor, dtensor.to_local(), rtol=1.3e-6, atol=1e-3 + ) + + assert_close(expected_out_tensor, out_dtensor.to_local()) + assert_close(expected_grad_x, grad_x.to_local()) + assert_close(expected_grad_w[rank : rank + 1], grad_w.to_local()) + assert_close(expected_grad_b[rank : rank + 1], grad_b.to_local()) From 4c8b6fb608546d2539af221d32f27854f5338dbd Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Mon, 10 Feb 2025 17:15:30 -0800 Subject: [PATCH 4/6] Addresses review --- tests/python/test_dtensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_dtensor.py b/tests/python/test_dtensor.py index a8fc9e64492..8eaaa60df37 100644 --- a/tests/python/test_dtensor.py +++ b/tests/python/test_dtensor.py @@ -8,6 +8,7 @@ import torch import torch.distributed as dist from collections.abc import Iterable +from dataclasses import dataclass from functools import partial from nvfuser import DataType, FusionDefinition from torch.distributed.tensor import DTensor @@ -121,7 +122,6 @@ def define_fusion(fd: FusionDefinition): @pytest.mark.mpi def test_linear(setup_process_group): - from dataclasses import dataclass @dataclass class LinearConfig: From 9c42d15e85a9b663b3722bf549104e550f9da027 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Mon, 10 Feb 2025 17:17:57 -0800 Subject: [PATCH 5/6] Lint --- tests/python/test_dtensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/test_dtensor.py b/tests/python/test_dtensor.py index 8eaaa60df37..3df6aa05595 100644 --- 
a/tests/python/test_dtensor.py +++ b/tests/python/test_dtensor.py @@ -122,7 +122,6 @@ def define_fusion(fd: FusionDefinition): @pytest.mark.mpi def test_linear(setup_process_group): - @dataclass class LinearConfig: def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): From 07fcc380c630d358028cbc5167513582d9aaaef0 Mon Sep 17 00:00:00 2001 From: Syed Tousif Ahmed Date: Tue, 11 Feb 2025 15:46:09 -0800 Subject: [PATCH 6/6] Fixes typo --- tests/python/test_dtensor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/test_dtensor.py b/tests/python/test_dtensor.py index 3df6aa05595..8131cc15230 100644 --- a/tests/python/test_dtensor.py +++ b/tests/python/test_dtensor.py @@ -229,7 +229,7 @@ def assert_close(expected_tensor, dtensor): expected_tensor, dtensor.to_local(), rtol=1.3e-6, atol=1e-3 ) - assert_close(expected_out_tensor, out_dtensor.to_local()) - assert_close(expected_grad_x, grad_x.to_local()) - assert_close(expected_grad_w[rank : rank + 1], grad_w.to_local()) - assert_close(expected_grad_b[rank : rank + 1], grad_b.to_local()) + assert_close(expected_out_tensor, out_dtensor) + assert_close(expected_grad_x, grad_x) + assert_close(expected_grad_w[rank : rank + 1], grad_w) + assert_close(expected_grad_b[rank : rank + 1], grad_b)
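
The forward check in test_linear relies on a specific weight layout: each rank holds an [e, e] shard of the full [d*e, e] weight (the output features are split across ranks), so the unsharded output [b, s, d*e] is viewed as [b, s, d, e] and permuted to [d, b, s, e] before slicing out the local rank's piece. A minimal single-process sketch of that layout arithmetic, using plain PyTorch only (no DTensor, no nvFuser) and illustrative sizes rather than the test's 2/1024/768:

    import torch

    d, b, s, e = 2, 2, 8, 4  # illustrative sizes, not the test's

    inp = torch.randn(b, s, e)
    weight = torch.randn(d, e, e)  # per-rank [e, e] shards stacked on dim 0
    bias = torch.randn(d, e)

    # Unsharded linear: concatenate the shards into a [d*e, e] weight and [d*e] bias.
    out = torch.nn.functional.linear(
        inp, weight.view(d * e, e), bias.view(d * e)
    )  # [b, s, d*e]

    # Recover the per-rank slices: [b, s, d*e] -> [b, s, d, e] -> [d, b, s, e].
    per_rank = out.view(b, s, d, e).permute(2, 0, 1, 3)

    # Each slice equals the linear computed with that rank's shard alone.
    for r in range(d):
        torch.testing.assert_close(
            per_rank[r], torch.nn.functional.linear(inp, weight[r], bias[r])
        )

The backward definition follows the same layout: each rank's contribution to grad_x is a partial product that has to be summed across ranks (the fd.ops.sum(grad_x_partials, [0]) marked "all reduce" in the fusion), while grad_w and grad_b are computed entirely from the local shard's slice of the upstream gradient. A plain-PyTorch sketch of that decomposition, continuing the illustrative sizes above and cross-checked against autograd on the unsharded linear:

    grad_out = torch.ones(b, s, d * e)  # matches the test's torch.ones_like upstream gradient
    grad_slices = grad_out.view(b, s, d, e).permute(2, 0, 1, 3)  # [d, b, s, e]

    # grad_x: every shard contributes a partial [b, s, e]; the sum over shards is the all-reduce.
    grad_x = sum(
        grad_slices[r].reshape(b * s, e) @ weight[r] for r in range(d)
    ).view(b, s, e)

    # grad_w and grad_b need only the local slice of the upstream gradient.
    grad_w = torch.stack(
        [grad_slices[r].reshape(b * s, e).T @ inp.reshape(b * s, e) for r in range(d)]
    )  # [d, e, e]
    grad_b = grad_slices.sum(dim=(1, 2))  # [d, e]

    # Reference gradients from autograd on the unsharded linear.
    inp_ref = inp.clone().requires_grad_()
    w_ref = weight.clone().requires_grad_()
    b_ref = bias.clone().requires_grad_()
    out_ref = torch.nn.functional.linear(
        inp_ref, w_ref.view(d * e, e), b_ref.view(d * e)
    )
    out_ref.backward(torch.ones_like(out_ref))
    torch.testing.assert_close(grad_x, inp_ref.grad)
    torch.testing.assert_close(grad_w, w_ref.grad)
    torch.testing.assert_close(grad_b, b_ref.grad)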