diff --git a/README.md b/README.md index 6893d7d57..07a4147d5 100644 --- a/README.md +++ b/README.md @@ -51,3 +51,11 @@ The `*_total_*_size_dist/` statistics the `op_type`'s input/output_size distribu - Notice: the [aten ir interface is in there](https://pytorch.org/docs/stable/torch.compiler_ir.html) [The `profile/` is the tools provided by pytorch](https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html), you can open it by the url: chrome://tracing + +# Run transformer models +To run a transformer model with the ttnn backend, run: +``` +PYTHONPATH=${TT_METAL_HOME}:$(pwd) python3 tools/run_transformers.py --model "phiyodr/bert-large-finetuned-squad2" --backend torch_ttnn +``` + +You can also substitute the backend with `torch_stat` to run a reference comparison. diff --git a/tests/test_cse.py b/tests/test_cse.py index b6707c088..9737e34e9 100644 --- a/tests/test_cse.py +++ b/tests/test_cse.py @@ -19,11 +19,11 @@ def input_shapes(self): class TestModules(unittest.TestCase): def setUp(self): # Open device 0 - self.device: ttnn.Device = ttnn.open(0) + self.device: ttnn.Device = ttnn.open_device(device_id=0) def tearDown(self): # Close the device - ttnn.close(self.device) + ttnn.close_device(self.device) def test_add(self): m = AddModule() @@ -32,19 +32,21 @@ def test_add(self): result_before = m.forward(*inputs) option = torch_ttnn.TorchTtnnOption(device=self.device) # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) result_after = m.forward(*inputs) self.assertEqual(1, len(option._out_fx_graphs)) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertEqual(nodes[3].target, ttnn.add) - self.assertEqual(nodes[3].args[0].target, ttnn.to_device) - self.assertEqual(nodes[3].args[0].args[0].target, ttnn.from_torch) - self.assertEqual(nodes[3].args[1].target, ttnn.to_device) - self.assertEqual(nodes[3].args[1].args[0].target, ttnn.from_torch) - self.assertEqual(nodes[4].target, ttnn.from_device) - self.assertEqual(nodes[5].target, ttnn.to_layout) - self.assertEqual(nodes[6].target, ttnn.to_torch) + self.assertEqual(nodes[4].target, ttnn.add) + self.assertEqual(nodes[4].args[0].target, ttnn.to_device) + self.assertEqual(nodes[4].args[0].args[0].target, ttnn.to_layout) + self.assertEqual(nodes[4].args[0].args[0].args[0].target, ttnn.from_torch) + self.assertEqual(nodes[4].args[1].target, ttnn.to_device) + self.assertEqual(nodes[4].args[1].args[0].target, ttnn.to_layout) + self.assertEqual(nodes[4].args[1].args[0].args[0].target, ttnn.from_torch) + self.assertEqual(nodes[5].target, ttnn.from_device) + self.assertEqual(nodes[6].target, ttnn.to_layout) + self.assertEqual(nodes[7].target, ttnn.to_torch) # Check inference result self.assertTrue(torch.allclose(result_before, result_after)) diff --git a/tests/test_fall_back.py b/tests/test_fall_back.py index f3b7897be..c1613b845 100644 --- a/tests/test_fall_back.py +++ b/tests/test_fall_back.py @@ -3,6 +3,8 @@ import unittest from torch_ttnn import ttnn +from torch_ttnn.utils import check_with_pcc + class MixModule(torch.nn.Module): def __init__(self): @@ -23,11 +25,11 @@ def input_shapes(self): class TestModules(unittest.TestCase): def setUp(self): # Open device 0 - self.device: ttnn.Device = ttnn.open(0) + self.device: ttnn.Device = ttnn.open_device(device_id=0) def tearDown(self): # Close the
device - ttnn.close(self.device) + ttnn.close_device(self.device) def test_fall_back(self): m = MixModule() @@ -37,19 +39,29 @@ def test_fall_back(self): option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) result_after = m.forward(*inputs) self.assertEqual(1, len(option._out_fx_graphs)) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertEqual(nodes[3].target, ttnn.from_torch) + self.assertEqual(nodes[2].target, ttnn.from_torch) + self.assertEqual(nodes[3].target, ttnn.to_layout) self.assertEqual(nodes[4].target, ttnn.to_device) - self.assertEqual(nodes[5].target, ttnn.add) - self.assertEqual(nodes[6].target, ttnn.matmul) - self.assertEqual(nodes[7].target, ttnn.from_device) - self.assertEqual(nodes[8].target, ttnn.to_layout) - self.assertEqual(nodes[9].target, ttnn.to_torch) + self.assertEqual(nodes[5].target, ttnn.reciprocal) + self.assertEqual(nodes[6].target, ttnn.from_torch) + self.assertEqual(nodes[7].target, ttnn.to_layout) + self.assertEqual(nodes[8].target, ttnn.to_device) + self.assertEqual(nodes[9].target, ttnn.mul) + self.assertEqual(nodes[10].target, ttnn.add) + self.assertEqual(nodes[11].target, ttnn.matmul) + self.assertEqual(nodes[12].target, ttnn.reciprocal) + self.assertEqual(nodes[13].target, ttnn.mul) + self.assertEqual(nodes[14].target, ttnn.reciprocal) + self.assertEqual(nodes[15].target, ttnn.mul) + self.assertEqual(nodes[16].target, ttnn.from_device) + self.assertEqual(nodes[17].target, ttnn.to_layout) + self.assertEqual(nodes[18].target, ttnn.to_torch) # Check inference result - self.assertTrue(torch.allclose(result_before, result_after)) + self.assertTrue(check_with_pcc(result_before, result_after)) diff --git a/tests/test_if.py b/tests/test_if.py index 367c39f76..ea05dd9b7 100644 --- a/tests/test_if.py +++ b/tests/test_if.py @@ -22,11 +22,11 @@ def input_shapes(self): class TestModules(unittest.TestCase): def setUp(self): # Open device 0 - self.device: ttnn.Device = ttnn.open(0) + self.device: ttnn.Device = ttnn.open_device(device_id=0) def tearDown(self): # Close the device - ttnn.close(self.device) + ttnn.close_device(self.device) def test_if(self): m = IfModule() @@ -36,7 +36,7 @@ def test_if(self): result_before_else = m.forward(*inputs_else) option = torch_ttnn.TorchTtnnOption(device=self.device) # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) result_after_then = m.forward(*inputs_then) result_after_else = m.forward(*inputs_else) @@ -49,21 +49,23 @@ def test_if(self): self.assertEqual(nodes_0[1].target, torch.ops.aten.sum.default) self.assertEqual(nodes_0[2].target, torch.ops.aten.gt.Scalar) nodes_1 = list(option._out_fx_graphs[1].nodes) - self.assertEqual(len(nodes_1), 8) + self.assertEqual(len(nodes_1), 9) self.assertEqual(nodes_1[1].target, ttnn.from_torch) - self.assertEqual(nodes_1[2].target, ttnn.to_device) - self.assertEqual(nodes_1[3].target, ttnn.add) - self.assertEqual(nodes_1[4].target, ttnn.from_device) - self.assertEqual(nodes_1[5].target, ttnn.to_layout) - self.assertEqual(nodes_1[6].target, ttnn.to_torch) + self.assertEqual(nodes_1[2].target, 
ttnn.to_layout) + self.assertEqual(nodes_1[3].target, ttnn.to_device) + self.assertEqual(nodes_1[4].target, ttnn.add) + self.assertEqual(nodes_1[5].target, ttnn.from_device) + self.assertEqual(nodes_1[6].target, ttnn.to_layout) + self.assertEqual(nodes_1[7].target, ttnn.to_torch) nodes_2 = list(option._out_fx_graphs[2].nodes) - self.assertEqual(len(nodes_2), 8) + self.assertEqual(len(nodes_2), 9) self.assertEqual(nodes_2[1].target, ttnn.from_torch) - self.assertEqual(nodes_2[2].target, ttnn.to_device) - self.assertEqual(nodes_2[3].target, ttnn.matmul) - self.assertEqual(nodes_2[4].target, ttnn.from_device) - self.assertEqual(nodes_2[5].target, ttnn.to_layout) - self.assertEqual(nodes_2[6].target, ttnn.to_torch) + self.assertEqual(nodes_2[2].target, ttnn.to_layout) + self.assertEqual(nodes_2[3].target, ttnn.to_device) + self.assertEqual(nodes_2[4].target, ttnn.matmul) + self.assertEqual(nodes_2[5].target, ttnn.from_device) + self.assertEqual(nodes_2[6].target, ttnn.to_layout) + self.assertEqual(nodes_2[7].target, ttnn.to_torch) # Check inference result self.assertTrue(torch.allclose(result_before_then, result_after_then)) diff --git a/tests/test_more_ops.py b/tests/test_more_ops.py index cc9dc4275..90df3658c 100644 --- a/tests/test_more_ops.py +++ b/tests/test_more_ops.py @@ -3,6 +3,8 @@ import unittest from torch_ttnn import ttnn +from torch_ttnn.utils import check_with_pcc + class SubModule(torch.nn.Module): def __init__(self): @@ -26,204 +28,1790 @@ def input_shapes(self): return [(4, 4), (4, 4)] -class SoftmaxModule(torch.nn.Module): - def __init__(self): - super().__init__() +class SoftmaxModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, axis): + return torch.softmax(x, axis) + + def input_shapes(self): + return [(1, 1, 64, 32)] + + +class TanhModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.tanh(x) + + def input_shapes(self): + return [(4, 4)] + + +class ReshapeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, new_shape): + return torch.reshape(x, new_shape) + + def input_shapes(self): + return [(32, 2 * 32)] + + def output_shapes(self): + return [(2 * 32, 32)] + + +class ReshapeNegativeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, new_shape): + return torch.reshape(x, new_shape) + + def input_shapes(self): + return [(32, 2 * 32)] + + def output_shapes(self): + return [(-1,)] + + +class Reshape4DModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, new_shape): + return torch.reshape(x, new_shape) + + def input_shapes(self): + return [(64, 32, 16, 32)] + + def output_shapes(self): + return [(16, 32, 64, 32)] + + +class Reshape4DNegativeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, new_shape): + return torch.reshape(x, new_shape) + + def input_shapes(self): + return [(1, 4, 64, 32)] + + def output_shapes(self): + return [(1, -1, 2, 32)] + + +class PermuteModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, order): + return torch.permute(x, order) + + def input_shapes(self): + return [(4, 4)] + + +class ReluModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.relu(x) + + def input_shapes(self): + return [(4, 4)] + + +class AddMmModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, 
input, mat1, mat2): + return torch.addmm(input, mat1, mat2) + + def input_shapes(self): + return [(4, 4), (4, 4), (4, 4)] + + +class DivModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, numerator, denominator): + return torch.div(numerator, denominator) + + def input_shapes(self): + return [(4, 4), (4, 4)] + + +class DivScalarDenomModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, numerator, denominator): + return torch.div(numerator, denominator) + + def input_shapes(self): + return [(4, 4)] + + +class GeluModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.gelu(input) + + def input_shapes(self): + return [(4, 4)] + + +class RSubModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.rsub(x, y) + + def input_shapes(self): + return [(4, 4), (4, 4)] + + +class RSubScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, scalar): + return torch.rsub(x, scalar) + + def input_shapes(self): + return [(4, 4)] + + +class EmbeddingModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, weight): + embedding = torch.nn.Embedding.from_pretrained(weight) + return embedding(input) + + def input_shapes(self): + return [((1, 2, 4, 5), (4, 3, 2, 9)), (10, 4)] + + +class EmbeddingTileLayoutModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, weights): + return torch.nn.functional.embedding(input, weights) + + def input_shapes(self): + # from test_bloom_embedding at tt-metal/tests/ttnn/unit_tests/operations/test_embedding.py + batch_size = 8 + sentence_size = 384 + vocabulary_size = 250880 + hidden_embedding_dim = 1024 + return [ + (0, vocabulary_size - 1, (batch_size, sentence_size)), + ((vocabulary_size, hidden_embedding_dim)), + ] + + +class CloneFromNodeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + a = input + input + return torch.clone(a) + + def input_shapes(self): + return [(4, 4)] + + +class CloneFromArgModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.clone(input) + + def input_shapes(self): + return [(4, 4)] + + +class LayerNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, embedding, weight, bias): + return torch.nn.functional.layer_norm( + embedding, normalized_shape=[embedding.shape[-1]], weight=weight, bias=bias + ) + + def input_shapes(self): + batch, sentence_length, embedding_dim = 2, 32, 64 + # [embedding, weight, bias] + return [ + (batch, sentence_length, embedding_dim), + (embedding_dim), + (embedding_dim), + ] + + +class LayerNormWithOtherOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, embedding, weight, bias): + layer_norm = torch.nn.functional.layer_norm( + embedding, normalized_shape=[embedding.shape[-1]], weight=weight, bias=bias + ) + return layer_norm + layer_norm + + def input_shapes(self): + batch, sentence_length, embedding_dim = 2, 32, 64 + # [embedding, weight, bias] + return [ + (batch, sentence_length, embedding_dim), + (embedding_dim), + (embedding_dim), + ] + + +class NegModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.neg(input) + + def input_shapes(self): + return [(4)] + + 
+class OnesModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, shape): + return torch.ones(shape) + + def input_shapes(self): + return [(32, 32)] + + +class TrilModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.tril(input) + + def input_shapes(self): + return [(4, 4)] + + +class ArangeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, end): + # start = 0, step = 1 + return torch.arange(end) + + def input_shapes(self): + return [100] + + +class ArangeStartModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, start, end): + # step = 1 + return torch.arange(start, end) + + def input_shapes(self): + # ttnn.arange does not support star values less than 2? + return [2, 100] + + +class ArangeStartStepModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, start, end, step): + return torch.arange(start, end, step) + + def input_shapes(self): + # ttnn.arange does not support star values less than 2? + return [4, 100, 3] + + +class EqTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, tensor1, tensor2): + return torch.eq(tensor1, tensor2) + + def input_shapes(self): + return [(4, 4), (4, 4)] + + +class EqScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, tensor, scalar): + return torch.eq(tensor, scalar) + + def input_shapes(self): + return [(64, 128)] + + +class LogicalNotModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.logical_not(input) + + def input_shapes(self): + return [(4, 4)] + + +class ZerosLikeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.zeros_like(input) + + def input_shapes(self): + return [(4, 4)] + + +class MeanDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, dim, keepdim=False): + return torch.mean(input, dim, keepdim) + + def input_shapes(self): + return [(1, 32, 32), -1] + + +class PowTensorScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + # torch.fx lowers this into aten.pow.Tensor_Scalar + square = torch.square(input) + return square + + def input_shapes(self): + return [(4, 4)] + + +class RsqrtModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.rsqrt(input) + + def input_shapes(self): + return [(4, 4)] + + +class SiluModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.silu(input) + + def input_shapes(self): + return [(4, 4)] + + +class AdaptiveAvgPool2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch._adaptive_avg_pool2d(input, (1, 1)) + + def input_shapes(self): + return [(1, 2048, 7, 7)] + + +class ClampModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, min, max): + return torch.clamp(input, min=min, max=max) + + def input_shapes(self): + return [(4, 4)] + + +class SqueezeDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, dim): + return torch.squeeze(input, dim) + + def input_shapes(self): + return [(1, 32, 16)] + + +class FullModule(torch.nn.Module): + def 
__init__(self): + super().__init__() + + def forward(self, size, fill_value): + return torch.full(size, fill_value) + + def input_shapes(self): + return [(64, 128)] + + +class LtTensorModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, tensor): + return torch.lt(input, tensor) + + def input_shapes(self): + return [(4, 4), (4, 4)] + + +class LtScalarModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, scalar): + return torch.lt(input, scalar) + + def input_shapes(self): + return [(64, 128)] + + +class BaddbmmModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input, batch1, batch2, beta=1, alpha=1): + if beta == 1: + return torch.baddbmm(input, batch1, batch2, alpha=alpha) + elif alpha == 1: + return torch.baddbmm(input, batch1, batch2, beta=beta) + elif beta == 1 and alpha == 1: + return torch.baddbmm(input, batch1, batch2) + else: + return torch.baddbmm(input, batch1, batch2, beta=beta, alpha=alpha) + + def input_shapes(self): + # input, batch1, batch2 + return [(10, 64, 128), (10, 64, 32), (10, 32, 128)] + + +class CosModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.cos(x) + + def input_shapes(self): + return [(4, 4)] + + +class SigmoidModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sigmoid(x) + + def input_shapes(self): + return [(4, 4)] + + +class TestModules(unittest.TestCase): + def setUp(self): + # Open device 0 + self.device: ttnn.Device = ttnn.open_device(device_id=0) + + def tearDown(self): + # Close the device + ttnn.close_device(self.device) + + def test_sub(self): + m = SubModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[8].target == ttnn.sub) + self.assertTrue(nodes[8].args[0].target == ttnn.to_device) + self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].target == ttnn.from_device) + self.assertTrue(nodes[10].target == ttnn.to_layout) + self.assertTrue(nodes[11].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_mul(self): + m = MulModule() + input_shapes = m.input_shapes() + inputs = [ + torch.randint(1, 5, shape).type(torch.bfloat16) for shape in input_shapes + ] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[8].target == ttnn.mul) + 
self.assertTrue(nodes[8].args[0].target == ttnn.to_device) + self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].target == ttnn.from_device) + self.assertTrue(nodes[10].target == ttnn.to_layout) + self.assertTrue(nodes[11].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_softmax(self): + m = SoftmaxModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + axis = -1 + result_before = m.forward(*inputs, axis) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs, axis) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.softmax) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2)) + + def test_tanh(self): + m = TanhModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.tanh) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2)) + + @unittest.skip( + "NOTE(kevinwuTT) ttnn.reshape conversion needs to be reworked to support the many restrictions." 
+ ) + def test_reshape(self): + m = ReshapeModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + new_shape = m.output_shapes() + result_before = m.forward(*inputs, *new_shape) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs, *new_shape) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[2].target == ttnn.reshape) + self.assertTrue(nodes[2].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[3].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + @unittest.skip( + "NOTE(kevinwuTT) ttnn.reshape conversion needs to be reworked to support the many restrictions." + ) + def test_reshape_negative(self): + m = ReshapeNegativeModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + new_shape = m.output_shapes() + result_before = m.forward(*inputs, *new_shape) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs, *new_shape) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[2].target == ttnn.reshape) + self.assertTrue(nodes[2].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[3].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + @unittest.skip( + "NOTE(kevinwuTT) ttnn.reshape conversion needs to be reworked to support the many restrictions." + ) + def test_reshape_4d(self): + m = Reshape4DModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + new_shape = m.output_shapes() + result_before = m.forward(*inputs, *new_shape) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs, *new_shape) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[3].target == ttnn.reshape) + self.assertTrue(nodes[3].args[0].target == ttnn.to_device) + self.assertTrue(nodes[3].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[4].target == ttnn.from_device) + self.assertTrue(nodes[5].target == ttnn.to_layout) + self.assertTrue(nodes[6].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + @unittest.skip( + "NOTE(kevinwuTT) ttnn.reshape conversion needs to be reworked to support the many restrictions." 
+ ) + def test_reshape_4d_negative(self): + m = Reshape4DNegativeModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + new_shape = m.output_shapes() + result_before = m.forward(*inputs, *new_shape) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs, *new_shape) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[3].target == ttnn.reshape) + self.assertTrue(nodes[3].args[0].target == ttnn.to_device) + self.assertTrue(nodes[3].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[4].target == ttnn.from_device) + self.assertTrue(nodes[5].target == ttnn.to_layout) + self.assertTrue(nodes[6].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + @unittest.skip( + "NOTE(yoco) This test failed because currently the ttnn.permute does nothing. Seems like the ttnn.permute is not implemented yet." + ) + def test_permute(self): + m = PermuteModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + order = (1, 0) + result_before = m.forward(*inputs, order) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs, order) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.permute) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_relu(self): + m = ReluModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.relu) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + 
self.assertTrue(torch.allclose(result_before, result_after)) + + def test_addmm(self): + m = AddMmModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[9].target == ttnn.matmul) + self.assertTrue(nodes[9].args[0].target == ttnn.to_device) + self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].args[1].target == ttnn.to_device) + self.assertTrue(nodes[9].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[13].target == ttnn.add) + self.assertTrue(nodes[13].args[0].target == ttnn.to_device) + self.assertTrue(nodes[13].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[13].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[13].args[1].target == ttnn.matmul) + self.assertTrue(nodes[14].target == ttnn.from_device) + self.assertTrue(nodes[15].target == ttnn.to_layout) + self.assertTrue(nodes[16].target == ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after)) + + def test_div(self): + m = DivModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[5].target == ttnn.reciprocal) + self.assertTrue(nodes[5].args[0].target == ttnn.to_device) + self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].target == ttnn.mul) + self.assertTrue(nodes[9].args[0].target == ttnn.to_device) + self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].args[1].target == ttnn.reciprocal) + self.assertTrue(nodes[10].target == ttnn.from_device) + self.assertTrue(nodes[11].target == ttnn.to_layout) + self.assertTrue(nodes[12].target == ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after)) + + def test_div_scalar_denom(self): + m = DivScalarDenomModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs, 5.0) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the 
compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs, 5.0) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[5].target == ttnn.reciprocal) + self.assertTrue(nodes[5].args[0].target == ttnn.to_device) + self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].target == ttnn.mul) + self.assertTrue(nodes[9].args[0].target == ttnn.to_device) + self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].args[1].target == ttnn.reciprocal) + self.assertTrue(nodes[10].target == ttnn.from_device) + self.assertTrue(nodes[11].target == ttnn.to_layout) + self.assertTrue(nodes[12].target == ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after)) + + def test_gelu(self): + m = GeluModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.gelu) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + print(result_before, "\n", result_after) + self.assertTrue(check_with_pcc(result_before, result_after)) + + def test_rsub(self): + m = RSubModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[8].target == ttnn.sub) + self.assertTrue(nodes[8].args[0].target == ttnn.to_device) + self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].target == ttnn.from_device) + self.assertTrue(nodes[10].target == ttnn.to_layout) + self.assertTrue(nodes[11].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_rsub_scalar(self): + m = RSubScalarModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape 
in input_shapes] + result_before = m.forward(*inputs, 5) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs, 5) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + # self.aseertTrue(nodes[1].target == ttnn.full) + self.assertTrue(nodes[8].target == ttnn.sub) + self.assertTrue(nodes[8].args[0].target == ttnn.to_device) + self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].target == ttnn.from_device) + self.assertTrue(nodes[10].target == ttnn.to_layout) + self.assertTrue(nodes[11].target == ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after, 0.9998)) + + @unittest.skip( + "NOTE(kevinwuTT) Re-enable after conversion to ttnn.embedding with both ROW_MAJOR_LAYOUT and TILE_LAYOUT" + ) + def test_embedding(self): + m = EmbeddingModule() + input_shapes = m.input_shapes() + input = torch.tensor(input_shapes[0]) + weight = torch.rand(input_shapes[1], dtype=torch.bfloat16) + result_before = m.forward(input, weight) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(input, weight) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[6].target == ttnn.embedding) + self.assertTrue(nodes[6].args[0].target == ttnn.to_device) + self.assertTrue(nodes[6].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[6].args[1].target == ttnn.to_device) + self.assertTrue(nodes[6].args[1].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[7].target == ttnn.from_device) + self.assertTrue(nodes[8].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_embedding_tile_layout(self): + m = EmbeddingTileLayoutModule() + input_shapes = m.input_shapes() + input = torch.randint(*input_shapes[0]) + weights = torch.zeros(*input_shapes[1], dtype=torch.bfloat16).uniform_( + -0.1, 0.1 + ) + result_before = m.forward(input, weights) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(input, weights) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[6].target == ttnn.embedding) + self.assertTrue(nodes[6].args[0].target == ttnn.to_device) + self.assertTrue(nodes[6].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[6].args[1].target == ttnn.to_device) + self.assertTrue(nodes[6].args[1].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[7].target == ttnn.from_device) + self.assertTrue(nodes[8].target == ttnn.to_torch) + # Check inference result + 
self.assertTrue(torch.allclose(result_before, result_after)) + + def test_clone_from_arg(self): + m = CloneFromArgModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = False + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[5].target == ttnn.clone) + self.assertTrue(nodes[5].args[0].target == ttnn.to_device) + self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].args[1].target == ttnn.MemoryConfig) + self.assertTrue(nodes[6].target == ttnn.from_device) + self.assertTrue(nodes[7].target == ttnn.to_layout) + self.assertTrue(nodes[8].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_clone_from_node(self): + m = CloneFromNodeModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = False + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[6].target == ttnn.clone) + self.assertTrue(nodes[6].args[0].target == ttnn.add) + self.assertTrue(nodes[6].args[1].target == ttnn.MemoryConfig) + self.assertTrue(nodes[7].target == ttnn.from_device) + self.assertTrue(nodes[8].target == ttnn.to_layout) + self.assertTrue(nodes[9].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_layer_norm(self): + m = LayerNormModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[12].target == ttnn.layer_norm) + self.assertTrue(nodes[12].args[0].target == ttnn.to_device) + self.assertTrue(nodes[12].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[12].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[12].kwargs["weight"].target == ttnn.to_device) + self.assertTrue(nodes[12].kwargs["weight"].args[0].target == ttnn.to_layout) + self.assertTrue( + nodes[12].kwargs["weight"].args[0].args[0].target == ttnn.from_torch + ) + self.assertTrue(nodes[12].kwargs["bias"].target == ttnn.to_device) + 
self.assertTrue(nodes[12].kwargs["bias"].args[0].target == ttnn.to_layout) + self.assertTrue( + nodes[12].kwargs["bias"].args[0].args[0].target == ttnn.from_torch + ) + self.assertTrue(nodes[13].target == ttnn.from_device) + self.assertTrue(nodes[14].target == ttnn.to_layout) + self.assertTrue(nodes[15].target == ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after, 0.9998)) + + def test_layer_norm_with_other_op(self): + m = LayerNormWithOtherOpModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + # self.assertTrue(nodes[12].target == ttnn.layer_norm) + # self.assertTrue(nodes[12].args[0].target == ttnn.to_device) + # self.assertTrue(nodes[12].args[0].args[0].target == ttnn.to_layout) + # self.assertTrue(nodes[12].args[0].args[0].args[0].target == ttnn.from_torch) + # self.assertTrue(nodes[12].kwargs["weight"].target == ttnn.to_device) + # self.assertTrue(nodes[12].kwargs["weight"].args[0].target == ttnn.to_layout) + # self.assertTrue(nodes[12].kwargs["weight"].args[0].args[0].target == ttnn.from_torch) + # self.assertTrue(nodes[12].kwargs["bias"].target == ttnn.to_device) + # self.assertTrue(nodes[12].kwargs["bias"].args[0].target == ttnn.to_layout) + # self.assertTrue(nodes[12].kwargs["bias"].args[0].args[0].target == ttnn.from_torch) + # self.assertTrue(nodes[13].target == ttnn.to_layout) + # self.assertTrue(nodes[14].target == ttnn.from_device) + # self.assertTrue(nodes[15].target == ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after, 0.9998)) + + def test_neg(self): + m = NegModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.neg) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_ones(self): + m = OnesModule() + input_shapes = m.input_shapes()[0] + result_before = m.forward(input_shapes) + result_before = result_before.to(torch.bfloat16) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once 
to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(input_shapes) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[0].target == ttnn.ones) + self.assertTrue(nodes[1].target == ttnn.from_device) + self.assertTrue(nodes[2].target == ttnn.to_layout) + self.assertTrue(nodes[3].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_tril(self): + m = TrilModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.tril) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + @unittest.skip( + "NOTE(kevinwuTT) This test fails because ttnn.arange does not support start value of 0." 
+ ) + def test_arange(self): + m = ArangeModule() + input_shapes = m.input_shapes() + # inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*input_shapes).to(torch.bfloat16) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*input_shapes) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[0].target == ttnn.arange) + self.assertTrue(nodes[1].target == ttnn.from_device) + self.assertTrue(nodes[2].target == ttnn.to_layout) + self.assertTrue(nodes[3].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) - def forward(self, x, axis): - return torch.softmax(x, axis) + def test_arange_start(self): + m = ArangeStartModule() + input_shapes = m.input_shapes() + # inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*input_shapes).to(torch.bfloat16) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*input_shapes) + option._out_fx_graphs[0].print_tabular() - def input_shapes(self): - return [(1, 1, 64, 32)] + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[0].target == ttnn.arange) + self.assertTrue(nodes[1].target == ttnn.from_device) + self.assertTrue(nodes[2].target == ttnn.to_layout) + self.assertTrue(nodes[3].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + def test_arange_start_step(self): + m = ArangeStartStepModule() + input_shapes = m.input_shapes() + result_before = m.forward(*input_shapes).to(torch.bfloat16) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*input_shapes) + option._out_fx_graphs[0].print_tabular() -class TanhModule(torch.nn.Module): - def __init__(self): - super().__init__() + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[0].target == ttnn.arange) + self.assertTrue(nodes[1].target == ttnn.from_device) + self.assertTrue(nodes[2].target == ttnn.to_layout) + self.assertTrue(nodes[3].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) - def forward(self, x): - return torch.tanh(x) + def test_eq_tensor(self): + m = EqTensorModule() + input_shapes = m.input_shapes() + inputs = [ + torch.randint(0, 2, shape, dtype=torch.bfloat16) for shape in input_shapes + ] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + 
option._out_fx_graphs[0].print_tabular() - def input_shapes(self): - return [(4, 4)] + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[8].target == ttnn.eq) + self.assertTrue(nodes[8].args[0].target == ttnn.to_device) + self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[8].args[1].target == ttnn.to_device) + self.assertTrue(nodes[8].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[8].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].target == ttnn.from_device) + self.assertTrue(nodes[10].target == ttnn.to_layout) + self.assertTrue(nodes[11].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after.to(torch.bool))) + def test_eq_scalar(self): + m = EqScalarModule() + input_shapes = m.input_shapes() + inputs = [ + torch.randint(0, 2, shape, dtype=torch.bfloat16) for shape in input_shapes + ] + scalar = torch.randint(0, 2, (1,)).item() + result_before = m.forward(inputs[0], scalar) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(inputs[0], scalar) + option._out_fx_graphs[0].print_tabular() -class ReshapeModule(torch.nn.Module): - def __init__(self): - super().__init__() + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[1].target == ttnn.full) + self.assertTrue(nodes[5].target == ttnn.eq) + self.assertTrue(nodes[5].args[0].target == ttnn.to_device) + self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[6].target == ttnn.from_device) + self.assertTrue(nodes[7].target == ttnn.to_layout) + self.assertTrue(nodes[8].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after.to(torch.bool))) - def forward(self, x, new_shape): - return torch.reshape(x, new_shape) + def test_logical_not(self): + m = LogicalNotModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() - def input_shapes(self): - return [(4, 4)] + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.logical_not) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after.to(torch.bool))) + def 
test_zeros_like(self): + m = ZerosLikeModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() -class PermuteModule(torch.nn.Module): - def __init__(self): - super().__init__() + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[3].target == ttnn.zeros_like) + self.assertTrue(nodes[3].args[0].target == ttnn.to_device) + self.assertTrue(nodes[3].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[4].target == ttnn.from_device) + self.assertTrue(nodes[5].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) - def forward(self, x, order): - return torch.permute(x, order) + def test_mean_dim(self): + m = MeanDimModule() + input_shapes = m.input_shapes() + input = torch.zeros(input_shapes[0], dtype=torch.bfloat16).uniform_(-1, 1) + dim = input_shapes[1] + keepdim = True + result_before = m.forward(input, dim, keepdim) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(input, dim, keepdim) + option._out_fx_graphs[0].print_tabular() - def input_shapes(self): - return [(4, 4)] + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.mean) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after)) + def test_pow_tensor_scalar(self): + m = PowTensorScalarModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() -class TestModules(unittest.TestCase): - def setUp(self): - # Open device 0 - self.device: ttnn.Device = ttnn.open(0) + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.pow) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == 
ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after)) - def tearDown(self): - # Close the device - ttnn.close(self.device) + def test_rsqrt(self): + m = RsqrtModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() - def test_sub(self): - m = SubModule() + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.rsqrt) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after)) + + def test_silu(self): + m = SiluModule() input_shapes = m.input_shapes() inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] result_before = m.forward(*inputs) option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) result_after = m.forward(*inputs) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertTrue(nodes[6].target == ttnn.sub) - self.assertTrue(nodes[6].args[0].target == ttnn.to_device) - self.assertTrue(nodes[6].args[0].args[0].target == ttnn.from_torch) - self.assertTrue(nodes[7].target == ttnn.from_device) - self.assertTrue(nodes[8].target == ttnn.to_layout) - self.assertTrue(nodes[9].target == ttnn.to_torch) + self.assertTrue(nodes[4].target == ttnn.silu) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) # Check inference result - self.assertTrue(torch.allclose(result_before, result_after)) + self.assertTrue(check_with_pcc(result_before, result_after)) - def test_mul(self): - m = MulModule() + def test_adaptive_avg_pool_2d(self): + m = AdaptiveAvgPool2dModule() input_shapes = m.input_shapes() - inputs = [ - torch.randint(1, 5, shape).type(torch.bfloat16) for shape in input_shapes - ] + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] result_before = m.forward(*inputs) option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) 
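# Minimal sketch (not part of the patch) of the calling convention these tests switch to:
# `torch_ttnn.backend` is now passed as the backend itself, and the TorchTtnnOption goes
# through torch.compile's `options` argument instead of `torch_ttnn.backend(option)`.
# AddModule stands in for any two-input test module of shape (32, 32).
import torch
import torch_ttnn
from torch_ttnn import ttnn

device = ttnn.open_device(device_id=0)
try:
    option = torch_ttnn.TorchTtnnOption(device=device)
    model = torch.compile(AddModule(), backend=torch_ttnn.backend, options=option)
    # Compilation is lazy; the first forward call triggers it.
    out = model(
        torch.rand(32, 32, dtype=torch.bfloat16),
        torch.rand(32, 32, dtype=torch.bfloat16),
    )
finally:
    ttnn.close_device(device)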
result_after = m.forward(*inputs) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertTrue(nodes[6].target == ttnn.mul) - self.assertTrue(nodes[6].args[0].target == ttnn.to_device) - self.assertTrue(nodes[6].args[0].args[0].target == ttnn.from_torch) - self.assertTrue(nodes[7].target == ttnn.from_device) - self.assertTrue(nodes[8].target == ttnn.to_layout) - self.assertTrue(nodes[9].target == ttnn.to_torch) + self.assertTrue(nodes[4].target == ttnn.global_avg_pool2d) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) # Check inference result - self.assertTrue(torch.allclose(result_before, result_after)) + self.assertTrue(check_with_pcc(result_before, result_after)) - def test_softmax(self): - m = SoftmaxModule() + def test_clamp(self): + m = ClampModule() input_shapes = m.input_shapes() inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] - axis = -1 - result_before = m.forward(*inputs, axis) + min, max = -0.5, 0.5 + result_before = m.forward(inputs[0], min, max) option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) - result_after = m.forward(*inputs, axis) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(inputs[0], min, max) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertTrue(nodes[3].target == ttnn.softmax) - self.assertTrue(nodes[3].args[0].target == ttnn.to_device) - self.assertTrue(nodes[3].args[0].args[0].target == ttnn.from_torch) - self.assertTrue(nodes[4].target == ttnn.from_device) - self.assertTrue(nodes[5].target == ttnn.to_layout) - self.assertTrue(nodes[6].target == ttnn.to_torch) + self.assertTrue(nodes[4].target == ttnn.clip) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) # Check inference result - self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2)) + self.assertTrue(check_with_pcc(result_before, result_after)) - def test_tanh(self): - m = TanhModule() + def test_squeeze_dim(self): + m = SqueezeDimModule() + input_shapes = m.input_shapes() + inputs = [torch.zeros(shape, dtype=torch.bfloat16) for shape in input_shapes] + dim = 0 + result_before = m.forward(inputs[0], dim) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(inputs[0], dim) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = 
list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.squeeze) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_full(self): + m = FullModule() + input_shapes = m.input_shapes() + fill_value = 1.23 + result_before = m.forward(input_shapes[0], fill_value).to(torch.bfloat16) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(input_shapes[0], fill_value) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[0].target == ttnn.full) + self.assertTrue(nodes[1].target == ttnn.from_device) + self.assertTrue(nodes[2].target == ttnn.to_layout) + self.assertTrue(nodes[3].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after)) + + def test_lt_tensor(self): + m = LtTensorModule() input_shapes = m.input_shapes() inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] result_before = m.forward(*inputs) option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) result_after = m.forward(*inputs) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertTrue(nodes[3].target == ttnn.tanh) - self.assertTrue(nodes[3].args[0].target == ttnn.to_device) - self.assertTrue(nodes[3].args[0].args[0].target == ttnn.from_torch) - self.assertTrue(nodes[4].target == ttnn.from_device) - self.assertTrue(nodes[5].target == ttnn.to_layout) - self.assertTrue(nodes[6].target == ttnn.to_torch) + self.assertTrue(nodes[8].target == ttnn.lt) + self.assertTrue(nodes[8].args[0].target == ttnn.to_device) + self.assertTrue(nodes[8].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[8].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[8].args[1].target == ttnn.to_device) + self.assertTrue(nodes[8].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[8].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].target == ttnn.from_device) + self.assertTrue(nodes[10].target == ttnn.to_layout) + self.assertTrue(nodes[11].target == ttnn.to_torch) # Check inference result - self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2)) + self.assertTrue(torch.allclose(result_before, result_after.to(torch.bool))) - def test_reshape(self): - m = ReshapeModule() + def test_lt_scalar(self): + m = LtScalarModule() input_shapes = m.input_shapes() inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] - new_shape = (2, 8) - result_before = m.forward(*inputs, new_shape) + 
scalar = inputs[0][0][0].item() + result_before = m.forward(inputs[0], scalar) option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) - result_after = m.forward(*inputs, new_shape) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(inputs[0], scalar) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertTrue(nodes[3].target == ttnn.reshape) - self.assertTrue(nodes[3].args[0].target == ttnn.to_device) - self.assertTrue(nodes[3].args[0].args[0].target == ttnn.from_torch) - self.assertTrue(nodes[4].target == ttnn.from_device) - self.assertTrue(nodes[5].target == ttnn.to_layout) - self.assertTrue(nodes[6].target == ttnn.to_torch) + self.assertTrue(nodes[1].target == ttnn.full) + self.assertTrue(nodes[5].target == ttnn.lt) + self.assertTrue(nodes[5].args[0].target == ttnn.to_device) + self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[6].target == ttnn.from_device) + self.assertTrue(nodes[7].target == ttnn.to_layout) + self.assertTrue(nodes[8].target == ttnn.to_torch) # Check inference result - self.assertTrue(torch.allclose(result_before, result_after)) + self.assertTrue(check_with_pcc(result_before, result_after)) - # NOTE(yoco) This test failed because currently - # the ttnn.permute does nothing. Seems like the ttnn.permute - # is not implemented yet. - def test_permute(self): - m = PermuteModule() + def test_baddbmm(self): + m = BaddbmmModule() input_shapes = m.input_shapes() inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] - order = (1, 0) - result_before = m.forward(*inputs, order) option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True + + # (1) Test with default alpha and beta values + result_before = m.forward(*inputs) # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) - result_after = m.forward(*inputs, order) + m_ttnn = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m_ttnn.forward(*inputs) + option._out_fx_graphs[-1].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[-1].nodes) + self.assertTrue(nodes[9].target == ttnn.matmul) + self.assertTrue(nodes[9].args[0].target == ttnn.to_device) + self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].args[1].target == ttnn.to_device) + self.assertTrue(nodes[9].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[13].target == ttnn.add) + self.assertTrue(nodes[13].args[0].target == ttnn.to_device) + self.assertTrue(nodes[13].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[13].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[13].args[1].target == ttnn.matmul) + self.assertTrue(nodes[14].target == ttnn.from_device) + self.assertTrue(nodes[15].target == ttnn.to_layout) + self.assertTrue(nodes[16].target == ttnn.to_torch) + # Check 
inference result + self.assertTrue(check_with_pcc(result_before, result_after)) + + # (2) Test with alpha and default beta value + result_before = m.forward(*inputs, alpha=2) + m_ttnn = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m_ttnn.forward(*inputs, alpha=2) + option._out_fx_graphs[-1].print_tabular() + + nodes = list(option._out_fx_graphs[-1].nodes) + self.assertTrue(nodes[9].target == ttnn.matmul) + self.assertTrue(nodes[9].args[0].target == ttnn.to_device) + self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].args[1].target == ttnn.to_device) + self.assertTrue(nodes[9].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[10].target == ttnn.multiply) + self.assertTrue(nodes[10].args[0].target == ttnn.matmul) + self.assertTrue(nodes[14].target == ttnn.add) + self.assertTrue(nodes[14].args[0].target == ttnn.to_device) + self.assertTrue(nodes[14].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[14].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[14].args[1].target == ttnn.multiply) + self.assertTrue(nodes[15].target == ttnn.from_device) + self.assertTrue(nodes[16].target == ttnn.to_layout) + self.assertTrue(nodes[17].target == ttnn.to_torch) + self.assertTrue(check_with_pcc(result_before, result_after)) + + # (3) Test with beta and default alpha value + result_before = m.forward(*inputs, beta=2) + m_ttnn = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m_ttnn.forward(*inputs, beta=2) + option._out_fx_graphs[-1].print_tabular() + + nodes = list(option._out_fx_graphs[-1].nodes) + self.assertTrue(nodes[9].target == ttnn.matmul) + self.assertTrue(nodes[9].args[0].target == ttnn.to_device) + self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].args[1].target == ttnn.to_device) + self.assertTrue(nodes[9].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[13].target == ttnn.multiply) + self.assertTrue(nodes[13].args[0].target == ttnn.to_device) + self.assertTrue(nodes[13].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[13].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[14].target == ttnn.add) + self.assertTrue(nodes[14].args[0].target == ttnn.multiply) + self.assertTrue(nodes[14].args[1].target == ttnn.matmul) + self.assertTrue(nodes[15].target == ttnn.from_device) + self.assertTrue(nodes[16].target == ttnn.to_layout) + self.assertTrue(nodes[17].target == ttnn.to_torch) + self.assertTrue(check_with_pcc(result_before, result_after)) + + # (4) Test with beta and alpha values + result_before = m.forward(*inputs, beta=2, alpha=2) + m_ttnn = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m_ttnn.forward(*inputs, beta=2, alpha=2) + option._out_fx_graphs[-1].print_tabular() + + nodes = list(option._out_fx_graphs[-1].nodes) + self.assertTrue(nodes[9].target == ttnn.matmul) + self.assertTrue(nodes[9].args[0].target == ttnn.to_device) + self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch) + 
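# Illustrative reference (not part of the patch) for the graph shapes asserted in the
# five baddbmm cases: the lowering decomposes
#   torch.baddbmm(input, b1, b2, beta=beta, alpha=alpha)
# into ttnn.matmul plus optional ttnn.multiply / ttnn.add nodes, matching
#   beta * input + alpha * (b1 @ b2)
# with the scaling elided when alpha or beta is 1, and the add elided when beta is 0.
import torch

def baddbmm_reference(inp, b1, b2, beta=1.0, alpha=1.0):
    out = alpha * torch.bmm(b1, b2)
    if beta != 0:
        out = out + beta * inp
    return out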
self.assertTrue(nodes[9].args[1].target == ttnn.to_device) + self.assertTrue(nodes[9].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[10].target == ttnn.multiply) + self.assertTrue(nodes[10].args[0].target == ttnn.matmul) + self.assertTrue(nodes[14].target == ttnn.multiply) + self.assertTrue(nodes[14].args[0].target == ttnn.to_device) + self.assertTrue(nodes[14].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[14].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[15].target == ttnn.add) + self.assertTrue(nodes[15].args[0].target == ttnn.multiply) + self.assertTrue(nodes[15].args[1].target == ttnn.multiply) + self.assertTrue(nodes[16].target == ttnn.from_device) + self.assertTrue(nodes[17].target == ttnn.to_layout) + self.assertTrue(nodes[18].target == ttnn.to_torch) + self.assertTrue(check_with_pcc(result_before, result_after)) + + # (5) Test special case when beta is 0 + result_before = m.forward(*inputs, beta=0, alpha=2) + m_ttnn = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m_ttnn.forward(*inputs, beta=0, alpha=2) + option._out_fx_graphs[-1].print_tabular() + + nodes = list(option._out_fx_graphs[-1].nodes) + self.assertTrue(nodes[9].target == ttnn.matmul) + self.assertTrue(nodes[9].args[0].target == ttnn.to_device) + self.assertTrue(nodes[9].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[9].args[1].target == ttnn.to_device) + self.assertTrue(nodes[9].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[9].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[10].target == ttnn.multiply) + self.assertTrue(nodes[10].args[0].target == ttnn.matmul) + self.assertTrue(nodes[11].target == ttnn.from_device) + self.assertTrue(nodes[12].target == ttnn.to_layout) + self.assertTrue(nodes[13].target == ttnn.to_torch) + self.assertTrue(check_with_pcc(result_before, result_after)) + + def test_cos(self): + m = CosModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertTrue(nodes[3].target == ttnn.permute) - self.assertTrue(nodes[3].args[0].target == ttnn.to_device) - self.assertTrue(nodes[3].args[0].args[0].target == ttnn.from_torch) - self.assertTrue(nodes[4].target == ttnn.from_device) - self.assertTrue(nodes[5].target == ttnn.to_layout) - self.assertTrue(nodes[6].target == ttnn.to_torch) + self.assertTrue(nodes[4].target == ttnn.cos) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) # Check inference result - self.assertTrue(torch.allclose(result_before, 
result_after)) + self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2)) + + def test_sigmoid(self): + m = SigmoidModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[4].target == ttnn.sigmoid) + self.assertTrue(nodes[4].args[0].target == ttnn.to_device) + self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[5].target == ttnn.from_device) + self.assertTrue(nodes[6].target == ttnn.to_layout) + self.assertTrue(nodes[7].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2)) if __name__ == "__main__": diff --git a/tests/test_only_add_matmul.py b/tests/test_only_add_matmul.py index 4aa97e19e..9ffe58993 100644 --- a/tests/test_only_add_matmul.py +++ b/tests/test_only_add_matmul.py @@ -3,6 +3,8 @@ import unittest from torch_ttnn import ttnn +from torch_ttnn.utils import check_with_pcc + class AddModule(torch.nn.Module): def __init__(self): @@ -26,6 +28,17 @@ def input_shapes(self): return [(32, 32), (32, 32)] +class BatchMatmulModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.bmm(x, y) + + def input_shapes(self): + return [(10, 64, 32), (10, 32, 128)] + + # Nested module for demonstration, verify nested modules work class AddMatmulModule(torch.nn.Module): def __init__(self): @@ -43,11 +56,11 @@ def input_shapes(self): class TestModules(unittest.TestCase): def setUp(self): # Open device 0 - self.device: ttnn.Device = ttnn.open(0) + self.device: ttnn.Device = ttnn.open_device(device_id=0) def tearDown(self): # Close the device - ttnn.close(self.device) + ttnn.close_device(self.device) def test_add(self): m = AddModule() @@ -58,19 +71,20 @@ def test_add(self): result_before = m.forward(*inputs) option = torch_ttnn.TorchTtnnOption(device=self.device) # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) result_after = m.forward(*inputs) self.assertEqual(1, len(option._out_fx_graphs)) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertEqual(nodes[6].target, ttnn.add) - self.assertEqual(nodes[6].args[0].target, ttnn.to_device) - self.assertEqual(nodes[6].args[0].args[0].target, ttnn.from_torch) - self.assertEqual(nodes[7].target, ttnn.from_device) - self.assertEqual(nodes[8].target, ttnn.to_layout) - self.assertEqual(nodes[9].target, ttnn.to_torch) + self.assertEqual(nodes[8].target, ttnn.add) + self.assertEqual(nodes[8].args[0].target, ttnn.to_device) + self.assertEqual(nodes[8].args[0].args[0].target, ttnn.to_layout) + self.assertEqual(nodes[8].args[0].args[0].args[0].target, ttnn.from_torch) + self.assertEqual(nodes[9].target, 
ttnn.from_device) + self.assertEqual(nodes[10].target, ttnn.to_layout) + self.assertEqual(nodes[11].target, ttnn.to_torch) # Check inference result self.assertTrue(torch.allclose(result_before, result_after)) @@ -84,22 +98,51 @@ def test_matmul(self): option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) result_after = m.forward(*inputs) self.assertEqual(1, len(option._out_fx_graphs)) option._out_fx_graphs[0].print_tabular() # Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertEqual(nodes[6].target, ttnn.matmul) - self.assertEqual(nodes[6].args[0].target, ttnn.to_device) - self.assertEqual(nodes[6].args[0].args[0].target, ttnn.from_torch) - self.assertEqual(nodes[7].target, ttnn.from_device) - self.assertEqual(nodes[8].target, ttnn.to_layout) - self.assertEqual(nodes[9].target, ttnn.to_torch) + self.assertEqual(nodes[8].target, ttnn.matmul) + self.assertEqual(nodes[8].args[0].target, ttnn.to_device) + self.assertEqual(nodes[8].args[0].args[0].target, ttnn.to_layout) + self.assertEqual(nodes[8].args[0].args[0].args[0].target, ttnn.from_torch) + self.assertEqual(nodes[9].target, ttnn.from_device) + self.assertEqual(nodes[10].target, ttnn.to_layout) + self.assertEqual(nodes[11].target, ttnn.to_torch) # Check inference result self.assertTrue(torch.allclose(result_before, result_after)) + def test_batchmatmul(self): + m = BatchMatmulModule() + input_shapes = m.input_shapes() + inputs = [torch.rand(shape).type(torch.bfloat16) for shape in input_shapes] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + self.assertEqual(1, len(option._out_fx_graphs)) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertEqual(nodes[8].target, ttnn.matmul) + self.assertEqual(nodes[8].args[0].target, ttnn.to_device) + self.assertEqual(nodes[8].args[0].args[0].target, ttnn.to_layout) + self.assertEqual(nodes[8].args[0].args[0].args[0].target, ttnn.from_torch) + self.assertEqual(nodes[8].args[1].target, ttnn.to_device) + self.assertEqual(nodes[8].args[1].args[0].target, ttnn.to_layout) + self.assertEqual(nodes[8].args[1].args[0].args[0].target, ttnn.from_torch) + self.assertEqual(nodes[9].target, ttnn.from_device) + self.assertEqual(nodes[10].target, ttnn.to_layout) + self.assertEqual(nodes[11].target, ttnn.to_torch) + # Check inference result + self.assertTrue(check_with_pcc(result_before, result_after)) + def test_add_and_matmul(self): m = AddMatmulModule() input_shapes = m.input_shapes() @@ -110,20 +153,21 @@ def test_add_and_matmul(self): option = torch_ttnn.TorchTtnnOption(device=self.device) option.gen_graphviz = True # The compilation is lazy, so we need to run forward once to trigger the compilation - m = torch.compile(m, backend=torch_ttnn.backend(option)) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) result_after = m.forward(*inputs) self.assertEqual(1, len(option._out_fx_graphs)) option._out_fx_graphs[0].print_tabular() # 
Check the graph has be rewritten and contain ttnn ops nodes = list(option._out_fx_graphs[0].nodes) - self.assertEqual(nodes[6].target, ttnn.add) - self.assertEqual(nodes[6].args[0].target, ttnn.to_device) - self.assertEqual(nodes[6].args[0].args[0].target, ttnn.from_torch) - self.assertEqual(nodes[7].target, ttnn.matmul) - self.assertEqual(nodes[8].target, ttnn.from_device) - self.assertEqual(nodes[9].target, ttnn.to_layout) - self.assertEqual(nodes[10].target, ttnn.to_torch) + self.assertEqual(nodes[8].target, ttnn.add) + self.assertEqual(nodes[8].args[0].target, ttnn.to_device) + self.assertEqual(nodes[8].args[0].args[0].target, ttnn.to_layout) + self.assertEqual(nodes[8].args[0].args[0].args[0].target, ttnn.from_torch) + self.assertEqual(nodes[9].target, ttnn.matmul) + self.assertEqual(nodes[10].target, ttnn.from_device) + self.assertEqual(nodes[11].target, ttnn.to_layout) + self.assertEqual(nodes[12].target, ttnn.to_torch) # Check inference result self.assertTrue(torch.allclose(result_before, result_after)) diff --git a/tests/test_real_world.py b/tests/test_real_world.py index 6eed9bed1..f94aafc86 100644 --- a/tests/test_real_world.py +++ b/tests/test_real_world.py @@ -9,11 +9,11 @@ class TestRealWorld(unittest.TestCase): def setUp(self): # Open device 0 - self.device: ttnn.Device = ttnn.open(0) + self.device: ttnn.Device = ttnn.open_device(device_id=0) def tearDown(self): # Close the device - ttnn.close(self.device) + ttnn.close_device(self.device) @unittest.skip( "Skip this test because it take 135 MB to download the ResNet18 model. Un-skip it if you want to test it." diff --git a/tests/test_type_conversion.py b/tests/test_type_conversion.py new file mode 100644 index 000000000..fa0ac11cd --- /dev/null +++ b/tests/test_type_conversion.py @@ -0,0 +1,89 @@ +import torch +import torch_ttnn +import unittest +from torch_ttnn import ttnn +import tt_lib + + +class MulModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.mul(x, y) + + def input_shapes(self): + return [(4, 4), (4, 4)] + + +class TestModules(unittest.TestCase): + def setUp(self): + # Open device 0 + self.device: ttnn.Device = ttnn.open_device(device_id=0) + # For AutoFormat + tt_lib.device.SetDefaultDevice(self.device) + + def tearDown(self): + # Close the device + ttnn.close_device(self.device) + + def test_mul(self): + m = MulModule() + input_shapes = m.input_shapes() + inputs = [ + torch.randint(1, 5, shape).type(torch.float32) for shape in input_shapes + ] + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[10].target == ttnn.mul) + self.assertTrue(nodes[10].args[0].target == ttnn.to_device) + self.assertTrue(nodes[10].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[10].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[10].args[1].target == ttnn.to_device) + self.assertTrue(nodes[10].args[1].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[10].args[1].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[11].target == ttnn.from_device) + self.assertTrue(nodes[12].target == 
ttnn.to_layout) + self.assertTrue(nodes[13].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before.to(torch.bfloat16), result_after)) + + def test_mul_scalar(self): + m = MulModule() + input_shapes = m.input_shapes() + inputs = [ + torch.randint(1, 5, input_shapes[0]).type(torch.float32), + torch.randint(1, 5, (1,)).type(torch.float32).item(), + ] + print(inputs) + result_before = m.forward(*inputs) + option = torch_ttnn.TorchTtnnOption(device=self.device) + option.gen_graphviz = True + # The compilation is lazy, so we need to run forward once to trigger the compilation + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + result_after = m.forward(*inputs) + option._out_fx_graphs[0].print_tabular() + + # Check the graph has be rewritten and contain ttnn ops + nodes = list(option._out_fx_graphs[0].nodes) + self.assertTrue(nodes[5].target == ttnn.mul) + self.assertTrue(nodes[5].args[0].target == ttnn.to_device) + self.assertTrue(nodes[5].args[0].args[0].target == ttnn.to_layout) + self.assertTrue(nodes[5].args[0].args[0].args[0].target == ttnn.from_torch) + self.assertTrue(nodes[6].target == ttnn.from_device) + self.assertTrue(nodes[7].target == ttnn.to_layout) + self.assertTrue(nodes[8].target == ttnn.to_torch) + # Check inference result + self.assertTrue(torch.allclose(result_before.to(torch.bfloat16), result_after)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/run_torchvision.py b/tools/run_torchvision.py index e9f8f407a..3bc6d5e9d 100644 --- a/tools/run_torchvision.py +++ b/tools/run_torchvision.py @@ -91,7 +91,11 @@ def run_model( models = ["dinov2_vits14", "alexnet", "googlenet", "resnet18", "vgg11"] - device = torch_ttnn.ttnn.open(0) if args.backend == "torch_ttnn" else None + device = ( + torch_ttnn.ttnn.open_device(device_id=0) + if args.backend == "torch_ttnn" + else None + ) for m in models: try: run_model( @@ -106,4 +110,4 @@ def run_model( except: print(f"{m} FAIL") if args.backend == "torch_ttnn": - torch_ttnn.ttnn.close(device) + torch_ttnn.ttnn.close_device(device) diff --git a/tools/run_transformers.py b/tools/run_transformers.py new file mode 100644 index 000000000..aceed7772 --- /dev/null +++ b/tools/run_transformers.py @@ -0,0 +1,239 @@ +import os +import argparse +import torch +from PIL import Image +import requests + +# Load model directly +from transformers import ( + AutoTokenizer, + AutoModelForQuestionAnswering, + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoImageProcessor, + AutoModelForObjectDetection, +) + + +class TestModel: + def __init__(self, model_name, model_task, test_input): + self.model_name = model_name + self.model_task = model_task + self.test_input = test_input + + def __repr__(self): + return self.model_name + + +def run_model( + model: str, + backend: str, + backward: bool, + out_path: str, + graphviz: bool, + to_profile: bool, + device=None, +): + text_modules = [ + AutoModelForQuestionAnswering, + AutoModelForCausalLM, + AutoModelForSequenceClassification, + ] + vision_modules = [ + AutoModelForObjectDetection, + ] + + if model.model_task in text_modules: + tokenizer = AutoTokenizer.from_pretrained(model.model_name, padding_side="left") + elif model.model_task in vision_modules: + image_processor = AutoImageProcessor.from_pretrained(model.model_name) + else: + raise ValueError(f"model task: {model.model_task} not supported.") + + m = model.model_task.from_pretrained(model.model_name) + + if backward: + try: + m.train() + except: + 
print(f"{model.model_name} Cannot to the training mode, use just eval mode") + m.eval() + backward = False + else: + m.eval() + if backend == "torch_ttnn": + option = torch_ttnn.TorchTtnnOption(device=device) + m = torch.compile(m, backend=torch_ttnn.backend, options=option) + elif backend == "torch_stat": + option = torch_stat.TorchStatOption( + model_name=model.model_name, + backward=backward, + out=out_path, + gen_graphviz=graphviz, + ) + m = torch.compile(m, backend=torch_stat.backend(option)) + else: + assert 0 and "Unsupport backend" + + if model.model_task == AutoModelForQuestionAnswering: + inputs = tokenizer.encode_plus( + model.test_input["question"], + model.test_input["context"], + add_special_tokens=True, + return_tensors="pt", + max_length=256, + padding="max_length", + truncation=True, + ) + elif ( + model.model_task == AutoModelForCausalLM + or model.model_task == AutoModelForSequenceClassification + ): + inputs = tokenizer(model.test_input, return_tensors="pt") + elif model.model_task == AutoModelForObjectDetection: + image = Image.open(requests.get(model.test_input, stream=True).raw) + inputs = image_processor(images=image, return_tensors="pt") + else: + raise ValueError(f"model task: {model.model_task} not supported.") + + if to_profile: + from torch.profiler import profile, record_function, ProfilerActivity + + trace_file = os.path.join(out_path, "profile", model.model_name) + os.makedirs(os.path.dirname(trace_file), exist_ok=True) + activities = [ProfilerActivity.CPU] + if torch.cuda.is_available(): + activities.append(ProfilerActivity.CUDA) + with profile(activities=activities, record_shapes=True) as prof: + with torch.no_grad(): + outputs = m(**inputs) + # if backward: + # result.backward(torch.ones_like(result)) + prof.export_chrome_trace(trace_file) + else: + with torch.no_grad(): + outputs = m(**inputs) + if backend == "torch_ttnn": + option._out_fx_graphs[0].print_tabular() + # if backward: + # result.backward(torch.ones_like(result)) + + if model.model_task == AutoModelForQuestionAnswering: + response_start = torch.argmax(outputs.start_logits) + response_end = torch.argmax(outputs.end_logits) + 1 + response_tokens = inputs.input_ids[0, response_start:response_end] + result = tokenizer.decode(response_tokens) + elif model.model_task == AutoModelForCausalLM: + next_token_logits = outputs.logits[:, -1] + next_token = next_token_logits.softmax(dim=-1).argmax() + result = tokenizer.decode([next_token]) + elif model.model_task == AutoModelForSequenceClassification: + normalized = outputs.logits.softmax(dim=-1) + print(normalized) + result = normalized.argmax().item() + elif model.model_task == AutoModelForObjectDetection: + target_sizes = torch.tensor([image.size[::-1]]) + results = image_processor.post_process_object_detection( + outputs, threshold=0.9, target_sizes=target_sizes + )[0] + else: + raise ValueError(f"model task: {model.model_task} not supported.") + + if model.model_task in text_modules: + print( + f"model_name: {model.model_name}\ninput: {model.test_input}\nresult: {result}" + ) + elif model.model_task in vision_modules: + for score, label, box in zip( + results["scores"], results["labels"], results["boxes"] + ): + box = [round(i, 2) for i in box.tolist()] + print( + f"Detected {m.config.id2label[label.item()]} with confidence " + f"{round(score.item(), 3)} at location {box}" + ) + else: + raise ValueError(f"model task: {model.model_task} not supported.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + 
parser.add_argument("--model", type=str, required=True) + parser.add_argument( + "--out_path", "-o", type=str, default=os.path.join(os.getcwd(), "stat") + ) + parser.add_argument("--backend", type=str) + parser.add_argument("--graphviz", action="store_true") + parser.add_argument("--backward", action="store_true") + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + assert args.backend in ["torch_ttnn", "torch_stat"] + if args.backend == "torch_ttnn" and args.backward: + assert 0 and "torch_ttnn not yet support backward" + + if args.backend == "torch_ttnn": + import torch_ttnn + elif args.backend == "torch_stat": + from torch_ttnn import torch_stat + + models = [ + TestModel( + "phiyodr/bert-large-finetuned-squad2", + AutoModelForQuestionAnswering, + { + "question": "What discipline did Winkelmann create?", + "context": 'Johann Joachim Winckelmann was a German art historian and archaeologist. He was a pioneering Hellenist who first articulated the difference between Greek, Greco-Roman and Roman art. "The prophet and founding hero of modern archaeology", Winckelmann was one of the founders of scientific archaeology and first applied the categories of style on a large, systematic basis to the history of art. ', + }, + ), + TestModel( + "tiiuae/falcon-7b-instruct", AutoModelForCausalLM, "Once upon a time" + ), + TestModel( + "mistralai/Mistral-7B-Instruct-v0.2", + AutoModelForCausalLM, + "My name is Johnny Appleseed, and today I", + ), + TestModel("bigscience/bloom-1b1", AutoModelForCausalLM, "My cat and my dog"), + # Need torch 2.3.0+ + TestModel( + "state-spaces/mamba-130m-hf", AutoModelForCausalLM, "Hey how are you doing?" + ), + TestModel( + "huggyllama/llama-7b", AutoModelForCausalLM, "Spring is a good time to" + ), + TestModel( + "mnoukhov/gpt2-imdb-sentiment-classifier", + AutoModelForSequenceClassification, + "This is the kind of movie you put in the background while working on other things.", + ), + TestModel( + "hustvl/yolos-tiny", + AutoModelForObjectDetection, + "http://images.cocodataset.org/val2017/000000039769.jpg", + ), + ] + + def get_model(model_name): + for m in models: + if model_name == m.model_name: + return m + raise ValueError( + f"model: {model_name} not supported. Supported models: {models}" + ) + + device = ( + torch_ttnn.ttnn.open_device(device_id=0) + if args.backend == "torch_ttnn" + else None + ) + run_model( + get_model(args.model), + args.backend, + args.backward, + args.out_path, + args.graphviz, + args.profile, + device, + ) + if args.backend == "torch_ttnn": + torch_ttnn.ttnn.close_device(device) diff --git a/torch_ttnn/backend.py b/torch_ttnn/backend.py index b77b448e7..fb12e00b1 100644 --- a/torch_ttnn/backend.py +++ b/torch_ttnn/backend.py @@ -2,6 +2,7 @@ import torch from typing import List import torch._dynamo +from functorch.compile import make_boxed_func torch._dynamo.config.suppress_errors = False torch._dynamo.config.verbose = True @@ -26,11 +27,43 @@ def aten_backend( trace into low level ATen ops not only high level torch ops. 
""" + # Clone ops used for input aliasing workaround are no longer needed at this point + from .handle_input_aliasing import remove_clones_for_input_aliasing + + gm = remove_clones_for_input_aliasing(gm) + + # Change float types in dtype kwargs to bfloat16 + from .convert_type import convert_dtype_to_bfloat16, convert_float_to_bfloat16 + + gm = convert_float_to_bfloat16(gm) + gm = convert_dtype_to_bfloat16(gm) + option: TorchTtnnOption = options["torch_ttnn_option"] torch.fx.graph._register_custom_builtin("ttnn_Specified_Device", "", option.device) torch.fx.graph._register_custom_builtin( "ttnn_ROW_MAJOR_LAYOUT", "", ttnn.ROW_MAJOR_LAYOUT ) + torch.fx.graph._register_custom_builtin("ttnn_TILE_LAYOUT", "", ttnn.TILE_LAYOUT) + torch.fx.graph._register_custom_builtin("ttnn_uint32", "", ttnn.uint32) + torch.fx.graph._register_custom_builtin("ttnn_bfloat16", "", ttnn.bfloat16) + + # Some ttnn objects are unhashable because they are function calls. + # However, arguments for these functions are usually hashable. + import tt_lib as ttl + + # ttnn.DRAM_MEMORY_CONFIG = ttnn.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM) + torch.fx.graph._register_custom_builtin( + "ttl_tensor_TensorMemoryLayout_INTERLEAVED", + "", + ttl.tensor.TensorMemoryLayout.INTERLEAVED, + ) + torch.fx.graph._register_custom_builtin( + "ttl_tensor_BufferType_DRAM", "", ttl.tensor.BufferType.DRAM + ) + # ttnn.L1_MEMORY_CONFIG = ttnn.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1) + torch.fx.graph._register_custom_builtin( + "ttl_tensor_BufferType_L1", "", ttl.tensor.BufferType.L1 + ) # Rewrite with ttnn ops, will insert redundant data movement from torch.fx.passes.infra.pass_manager import PassManager @@ -66,11 +99,12 @@ def aten_backend( pm = PassManager(passes=passes) gm, modified = pm(gm) + gm.graph.lint() gm.recompile() gm.graph.print_tabular() print(gm.code) option._out_fx_graphs.append(gm.graph) - return gm + return make_boxed_func(gm) from torch._dynamo.backends.common import aot_autograd @@ -85,7 +119,19 @@ def __init__(self, device: ttnn.Device): self._out_fx_graphs = list() +from .handle_input_aliasing import insert_clones_for_input_aliasing + + # The wrapper of aot_autograd that takes a TorchTtnnOption as options. 
-def backend(torch_ttnn_option: TorchTtnnOption): - options = {"torch_ttnn_option": torch_ttnn_option} - return aot_autograd(fw_compiler=partial(aten_backend, options=options)) +def backend( + gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs +) -> torch.fx.GraphModule: + if options := kwargs.get("options"): + options = {"torch_ttnn_option": options} + else: + raise RuntimeError("TorchTtnnOption missing") + + gm = insert_clones_for_input_aliasing(gm) + return aot_autograd(fw_compiler=partial(aten_backend, options=options))( + gm, example_inputs + ) diff --git a/torch_ttnn/convert_type.py b/torch_ttnn/convert_type.py new file mode 100644 index 000000000..5f9f246f1 --- /dev/null +++ b/torch_ttnn/convert_type.py @@ -0,0 +1,69 @@ +import math +import torch +from typing import List +from .utils import GraphCleanup + + +# torch.fx defines a placeholder node as a function input +def get_input_nodes(gm: torch.fx.GraphModule) -> List[torch.fx.Node]: + input_nodes = [node for node in gm.graph.nodes if (node.op == "placeholder")] + return input_nodes + + +pytorch_float_types = [ + torch.float32, + torch.float64, + torch.float16, +] + + +def convert_float_to_bfloat16(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + input_nodes = get_input_nodes(gm) + modified = False + + for node in input_nodes: + arg_metadata = node.meta["val"] + if arg_metadata.dtype in pytorch_float_types: + with gm.graph.inserting_after(input_nodes[-1]): + to = gm.graph.call_method("to", args=(node, torch.bfloat16)) + node.replace_all_uses_with( + to, + delete_user_cb=lambda node: node != to, + ) + modified = True + + if modified: + gm = GraphCleanup(gm) + return gm + + +# Use on aten level +def convert_dtype_to_bfloat16(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + modified = False + for node in gm.graph.nodes: + # Convert min, max, of float32 to float16 + # TODO(kevinwuTT): Optimize this without needing to process every arg? + new_args = [] + for arg in node.args: + if type(arg) == int or type(arg) == float: + if arg == torch.finfo(torch.float32).min: + new_args.append(torch.finfo(torch.bfloat16).min) + elif arg == torch.finfo(torch.float32).max: + new_args.append(torch.finfo(torch.bfloat16).max) + else: + new_args.append(arg) + else: + new_args.append(arg) + node.args = tuple(new_args) + + if node.target == torch.ops.aten._to_copy.default: + new_kwargs = {"dtype": torch.bfloat16} + node.kwargs = new_kwargs + if node.target == torch.ops.aten.full.default: + new_kwargs = node.kwargs.copy() + new_kwargs["dtype"] = torch.bfloat16 + node.kwargs = new_kwargs + + gm = GraphCleanup(gm) + + return gm diff --git a/torch_ttnn/handle_input_aliasing.py b/torch_ttnn/handle_input_aliasing.py new file mode 100644 index 000000000..c448617ba --- /dev/null +++ b/torch_ttnn/handle_input_aliasing.py @@ -0,0 +1,83 @@ +import torch +from typing import List +from .utils import GraphCleanup + +""" +AOT Autograd has an optimization where if it determines that the storage of the +output is the same as the input, it will return data from the input. The output +is checked if it's an alias of the input. This becomes problematic when we +have data transfer between host and device. The issue has been raised before. +See: https://github.com/pytorch/pytorch/issues/108079 + +One workaround is to insert a clone op after every input node before the graph +is completely lowered to aten ops. 
This can be accomplished by pointing the +entry function with a `torch.fx.GraphModule` parameter to the `backend` +positional argument for `torch.compile`. At this point, the graph will have +higher-level torch ops where clone ops can be inserted. Then define another +function that will be passed to the `fw_compiler/bw_compiler` parameter for +`aot_autograd`. This is when the input aliasing metadata will be determined +and the graph will be lowered to aten ops. After this, the clone ops can +be removed with another pass. + +The method is inspired by TensorRT: +https://github.com/pytorch/TensorRT/commit/7daa1120dc1bc72d6f92f1e7aa2b357a65b6ea08 +""" + + +# torch.fx defines a placeholder node as a function input +def get_input_nodes(gm: torch.fx.GraphModule) -> List[torch.fx.Node]: + input_nodes = [node for node in gm.graph.nodes if (node.op == "placeholder")] + return input_nodes + + +# Insert aten.clone nodes after every input to prevent input aliasing +def insert_clones_for_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + input_nodes = get_input_nodes(gm) + modified = False + for node in input_nodes: + """TODO(kevinwuTT): This does not work if inserting right after the node itself. + Only works if inserting after all of the input_nodes. + TypeError: forward() missing `n` required positional arguments + Somehow the argument list will get truncated. + """ + with gm.graph.inserting_after(input_nodes[-1]): + clone_node = gm.graph.call_function( + torch.ops.aten.clone.default, args=(node,) + ) + node.replace_all_uses_with( + clone_node, + delete_user_cb=lambda node: node != clone_node, + ) + modified = True + + if modified: + gm = GraphCleanup(gm) + + return gm + + +def remove_clones_for_input_aliasing(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Remove the clone ops inserted to handle input aliasing + opcode name target args + ------------- --------------- ------------------- ------------ + placeholder l_input_ L_input_ () + call_function clone_default aten.clone.default (l_input_,) + call_function op (clone_default) + """ + modified = False + for node in gm.graph.nodes: + if ( + node.op == "placeholder" + and len(node.users) == 1 + and list(node.users)[0].target == torch.ops.aten.clone.default + ): + clone_node = list(node.users)[0] + clone_node.replace_all_uses_with(node) + gm.graph.erase_node(clone_node) + + modified = True + + if modified: + gm = GraphCleanup(gm) + + return gm diff --git a/torch_ttnn/mock_ttnn.py b/torch_ttnn/mock_ttnn.py index 7a8a5979a..ea8a4bdd0 100644 --- a/torch_ttnn/mock_ttnn.py +++ b/torch_ttnn/mock_ttnn.py @@ -22,12 +22,12 @@ def __init__(self, device_id): # return f"Device({self.device_id})" -def open(device_id): +def open_device(device_id): print(f"Device {device_id} is opened") return Device(device_id) -def close(device): +def close_device(device): print(f"Device {device.device_id} is closed") pass @@ -109,6 +109,7 @@ def permute(x, order): ROW_MAJOR_LAYOUT = 0 +TILE_LAYOUT = 1 # Wrap the functions so that they can be used in torch.fx # and block the further recusive tracing. 
See: diff --git a/torch_ttnn/passes/add_data_move_pass.py b/torch_ttnn/passes/add_data_move_pass.py index 7605b0e5d..75e3b6b15 100644 --- a/torch_ttnn/passes/add_data_move_pass.py +++ b/torch_ttnn/passes/add_data_move_pass.py @@ -1,4 +1,10 @@ import torch +from ..utils import ( + DummyTtnnUint32, + DummyTtnnRowMajorLayout, + DummyTtnnTileLayout, + DummyDevice, +) try: import ttnn @@ -9,6 +15,12 @@ from torch.fx.passes.infra.pass_base import PassBase, PassResult +class _Kwarg: + def __init__(self, key, value): + self.key = key + self.value = value + + def is_function_call(node) -> bool: if not isinstance(node, torch.fx.node.Node): return False @@ -30,6 +42,30 @@ def is_tt_compute(node) -> bool: ttnn.tanh, ttnn.reshape, ttnn.permute, + ttnn.relu, + ttnn.reciprocal, + ttnn.gelu, + ttnn.embedding, + ttnn.clone, + ttnn.layer_norm, + ttnn.neg, + ttnn.ones, + ttnn.tril, + ttnn.arange, + ttnn.eq, + ttnn.logical_not, + ttnn.zeros_like, + ttnn.mean, + ttnn.pow, + ttnn.rsqrt, + ttnn.silu, + ttnn.global_avg_pool2d, + ttnn.clip, + ttnn.squeeze, + ttnn.full, + ttnn.lt, + ttnn.cos, + ttnn.sigmoid, ] ) @@ -42,6 +78,7 @@ def is_tt_data_move(node) -> bool: ttnn.to_device, ttnn.from_torch, ttnn.to_torch, + ttnn.MemoryConfig, ] @@ -49,8 +86,17 @@ def is_tt(node): return is_tt_compute(node) or is_tt_data_move(node) +def is_reshape_rank_4(node): + if node.target == ttnn.reshape: + return len(node.args[1]) == 4 + else: + return False + + def should_add_data_move_in(src_node, dst_node) -> bool: - if isinstance(src_node, (int, float, list, tuple)): + if isinstance(src_node, (int, float, list, tuple)) or not isinstance( + src_node, torch.fx.node.Node + ): return False return is_tt_compute(dst_node) and not is_tt(src_node) @@ -76,44 +122,126 @@ def insert_node_between(src_node, dst_idx, dst_node, new_nodes): dst_node.update_arg(0, tuple(new_arg)) -def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> bool: +def insert_node_between_kwarg(src_node, key, dst_node, new_nodes): + """ + Insert new_node between src_node and dest_node's keyword arg. + + Output does not have keyword args. 
+ """ + assert dst_node.op != "output" + new_nodes[0].update_arg(0, src_node) + dst_node.update_kwarg(key, new_nodes[-1]) + + +def try_add_data_move_in_kwargs(src_node_kwarg, dst_node, device) -> torch.fx.node.Node: + if not isinstance(src_node_kwarg, _Kwarg): + return None + key = src_node_kwarg.key + src_node = src_node_kwarg.value if not should_add_data_move_in(src_node, dst_node): - return False + return None g = dst_node.graph + new_nodes = list() with g.inserting_before(dst_node): - from_torch = g.call_function(ttnn.from_torch, (src_node,)) - to_device = g.call_function(ttnn.to_device, (from_torch, device)) + new_nodes.append(g.call_function(ttnn.from_torch, (src_node,))) + new_nodes.append( + g.call_function(ttnn.to_layout, (new_nodes[-1], DummyTtnnTileLayout())) + ) + if is_tt_compute(dst_node): + new_nodes.append(g.call_function(ttnn.to_device, (new_nodes[-1], device))) - insert_node_between(src_node, dst_idx, dst_node, [from_torch, to_device]) - return True + insert_node_between_kwarg(src_node, key, dst_node, new_nodes) + return new_nodes[-1] -def try_add_data_move_out(src_node, dst_idx, dst_node) -> bool: +def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.Node: + if not should_add_data_move_in(src_node, dst_node): + return None + + g = dst_node.graph + new_nodes = list() + with g.inserting_before(dst_node): + new_nodes.append(g.call_function(ttnn.from_torch, (src_node,))) + if ( + dst_node.target != ttnn.reshape + and dst_node.target != ttnn.embedding + and dst_node.target != ttnn.zeros_like + ): + new_nodes.append( + g.call_function(ttnn.to_layout, (new_nodes[-1], DummyTtnnTileLayout())) + ) + # For reshape only put tensor on device if rank is 4 + if (is_tt_compute(dst_node) and dst_node.target != ttnn.reshape) or ( + dst_node.target == ttnn.reshape and len(dst_node.args[1]) == 4 + ): + new_nodes.append(g.call_function(ttnn.to_device, (new_nodes[-1], device))) + + insert_node_between(src_node, dst_idx, dst_node, new_nodes) + return new_nodes[-1] + + +def try_add_data_move_out(src_node, dst_idx, dst_node) -> torch.fx.node.Node: if not should_add_data_move_out(src_node, dst_node): - return False + return None g = dst_node.graph + new_nodes = list() with g.inserting_before(dst_node): - from_device = g.call_function(ttnn.from_device, (src_node,)) - row_major_layout = g.call_function( - ttnn.to_layout, (from_device, DummyTtnnRowMajorLayout()) + if (is_tt_compute(src_node) and src_node.target != ttnn.reshape) or ( + src_node.target == ttnn.reshape and len(src_node.args[1]) == 4 + ): + new_nodes.append(g.call_function(ttnn.from_device, (src_node,))) + if src_node.target != ttnn.embedding and src_node.target != ttnn.zeros_like: + new_nodes.append( + g.call_function( + ttnn.to_layout, (new_nodes[-1], DummyTtnnRowMajorLayout()) + ) + ) + new_nodes.append( + g.call_function(ttnn.to_torch, (new_nodes[-1] if new_nodes else src_node,)) ) - to_torch = g.call_function(ttnn.to_torch, (row_major_layout,)) - - insert_node_between(src_node, dst_idx, dst_node, [from_device, to_torch]) - return True + insert_node_between(src_node, dst_idx, dst_node, new_nodes) + return new_nodes[-1] -# See https://docs.google.com/document/d/1r2D4AagoeTRjEmXFnWzzafaWQkf-8hlIbX2ze-JAUFo/edit#heading=h.zad9rwqjv6cr -class DummyDevice: - def __repr__(self): - return f"ttnn_Specified_Device" +def try_add_data_move_out_for_layer_norm( + src_node, dst_idx, dst_node +) -> torch.fx.node.Node: + if not should_add_data_move_out(src_node, dst_node): + return None -class DummyTtnnRowMajorLayout: - 
def __repr__(self): - return f"ttnn_ROW_MAJOR_LAYOUT" + g = dst_node.graph + new_nodes = list() + with g.inserting_before(dst_node): + if is_tt_compute(src_node) and src_node.target == ttnn.layer_norm: + new_nodes.append( + g.call_function(ttnn.to_layout, (src_node, DummyTtnnRowMajorLayout())) + ) + new_nodes.append(g.call_function(ttnn.from_device, (new_nodes[-1],))) + new_nodes.append(g.call_function(ttnn.to_torch, (new_nodes[-1],))) + + # Workaround to output the same layer_norm output + # Before: layer_norm = aten.layer_norm + # getitem = getitem(layer_norm, 0) + # return ((getitem,),) + # After: layer_norm = ttnn.layer_norm + # return (layer_norm,) + # Need to match the tuple in the original return statement + if new_nodes: + old_args = dst_node.args[0] + if isinstance(old_args, tuple): + new_args = list(old_args) + for idx, old_arg in enumerate(old_args): + if old_arg == src_node: + new_args[idx] = new_nodes[-1] + dst_node.update_arg(0, tuple(new_args)) + else: + dst_node.update_arg(dst_idx, new_nodes[-1]) + return new_nodes[-1] + else: + return None class AddDataMovePass(PassBase): @@ -122,12 +250,42 @@ def call(self, gm: torch.fx.GraphModule): device = DummyDevice() i = 0 nodes = list(gm.graph.nodes) + # Track argument reuse + data_move_in_hash = {} + # This might not be needed if workaround is not needed + data_move_out_hash = {} for node in nodes: args = node.args[0] if node.op == "output" else node.args + kwargs = tuple( + _Kwarg(k, v) + for k, v in node.kwargs.items() + if isinstance(v, torch.fx.node.Node) + ) + if isinstance(args, tuple): + args += kwargs + for idx, arg in enumerate(args): - if try_add_data_move_in(arg, idx, node, device): + if isinstance(arg, _Kwarg): + try_add_data_move_in_kwargs(arg, node, device) + elif arg in data_move_in_hash and node.op != "output": + node.update_arg(idx, data_move_in_hash[arg]) + elif to_device := try_add_data_move_in(arg, idx, node, device): + data_move_in_hash[arg] = to_device + i += 1 + + if arg in data_move_out_hash and node.op == "output": + old_arg = node.args[0] + new_arg = list(old_arg) + new_arg[idx] = data_move_out_hash[arg] + node.update_arg(0, tuple(new_arg)) + i += 1 + elif (node.target != ttnn.layer_norm) and ( + to_torch := try_add_data_move_out(arg, idx, node) + ): + data_move_out_hash[arg] = to_torch i += 1 - if try_add_data_move_out(arg, idx, node): + elif to_torch := try_add_data_move_out_for_layer_norm(arg, idx, node): + data_move_out_hash[arg] = to_torch i += 1 modified = i > 0 diff --git a/torch_ttnn/passes/to_tt_pass.py b/torch_ttnn/passes/to_tt_pass.py index b1f2c8b3a..200cfba68 100644 --- a/torch_ttnn/passes/to_tt_pass.py +++ b/torch_ttnn/passes/to_tt_pass.py @@ -1,4 +1,12 @@ import torch +from ..utils import ( + GraphCleanup, + DummyTtlTensorTensorMemoryLayoutInterleaved, + DummyTtlTensorBufferTypeDram, + DummyTtnnBfloat16, + DummyDevice, + DummyTtnnTileLayout, +) try: import ttnn @@ -7,48 +15,335 @@ from .. 
import mock_ttnn as ttnn from torch.fx.passes.infra.pass_base import PassBase, PassResult +import torch.fx.traceback as fx_traceback + +relational_scalar_ops = { + torch.ops.aten.eq.Scalar: ttnn.eq, + torch.ops.aten.lt.Scalar: ttnn.lt, +} + +# Workaround: If an arg of the model output is argmax then skip conversion +# TODO(kevinwuTT): Handle this case with ttnn ops +int_output_ops = [ + torch.ops.aten.argmax.default, + torch.ops.aten.argmin.default, +] + + +def are_args_from_int_output_ops(args): + for arg in args: + if isinstance(arg, torch.fx.proxy.Proxy): + if arg.node.target in int_output_ops: + return True class ReplaceMoreTt(torch.fx.Transformer): def call_function(self, target, args, kwargs): - if target == torch.ops.aten.sub.Tensor: - return super().call_function(ttnn.sub, args, kwargs) + if are_args_from_int_output_ops(args): + call_func = super().call_function(target, args, kwargs) + elif target == torch.ops.aten.sub.Tensor: + call_func = super().call_function(ttnn.sub, args, kwargs) + elif target == torch.ops.aten.rsub.Tensor: + # TODO(kevinwuMCW): handle alpha parameter if exists + call_func = super().call_function(ttnn.sub, (args[1], args[0]), kwargs) elif target == torch.ops.aten.mul.Tensor: - return super().call_function(ttnn.mul, args, kwargs) + call_func = super().call_function(ttnn.mul, args, kwargs) elif target == torch.ops.aten._softmax.default: - return super().call_function(ttnn.softmax, args[:2], kwargs) + call_func = super().call_function(ttnn.softmax, args[:2], kwargs) elif target == torch.ops.aten.tanh.default: - return super().call_function(ttnn.tanh, args, kwargs) + call_func = super().call_function(ttnn.tanh, args, kwargs) elif target == torch.ops.aten.view.default: - return super().call_function(ttnn.reshape, args, kwargs) + # TODO(kevinwuTT): Handle restrictions from ttnn.reshape + call_func = super().call_function(target, args, kwargs) elif target == torch.ops.aten.permute.default: - return super().call_function(ttnn.permute, args, kwargs) - return super().call_function(target, args, kwargs) + call_func = super().call_function(ttnn.permute, args, kwargs) + elif target == torch.ops.aten.relu.default: + call_func = super().call_function(ttnn.relu, args, kwargs) + elif target == torch.ops.aten.addmm.default: + # TODO(kevinwuMCW): include beta, alpha, and optional args + mm = super().call_function(ttnn.matmul, (args[1], args[2]), kwargs) + call_func = super().call_function(ttnn.add, (args[0], mm), kwargs) + elif target == torch.ops.aten.bmm.default: + call_func = super().call_function(ttnn.matmul, args, kwargs) + elif target == torch.ops.aten.gelu.default: + call_func = super().call_function(ttnn.gelu, args, kwargs) + elif target == torch.ops.aten.neg.default: + call_func = super().call_function(ttnn.neg, args, kwargs) + elif target == torch.ops.aten.tril.default: + call_func = super().call_function(ttnn.tril, args, kwargs) + elif target == torch.ops.aten.eq.Tensor: + call_func = super().call_function(ttnn.eq, args, kwargs) + elif target == torch.ops.aten.logical_not.default: + call_func = super().call_function(ttnn.logical_not, args, kwargs) + elif target == torch.ops.aten.zeros_like.default: + call_func = super().call_function(ttnn.zeros_like, args, {}) + elif target == torch.ops.aten.mean.dim: + # change dim parameter to tuple + new_args = list(args) + new_args[1] = tuple(args[1]) if len(args[1]) > 1 else args[1][0] + call_func = super().call_function(ttnn.mean, tuple(new_args), kwargs) + elif target == torch.ops.aten.add.Tensor: + call_func = 
super().call_function(ttnn.add, args, kwargs) + elif target == torch.ops.aten.mm.default: + call_func = super().call_function(ttnn.matmul, args, kwargs) + elif target == torch.ops.aten.pow.Tensor_Scalar: + call_func = super().call_function(ttnn.pow, args, kwargs) + elif target == torch.ops.aten.rsqrt.default: + call_func = super().call_function(ttnn.rsqrt, args, kwargs) + elif target == torch.ops.aten.silu.default: + call_func = super().call_function(ttnn.silu, args, kwargs) + elif target == torch.ops.aten._adaptive_avg_pool2d.default: + # assumes output size is (1, 1) + call_func = super().call_function( + ttnn.global_avg_pool2d, (args[0],), kwargs + ) + elif target == torch.ops.aten.clamp.default: + call_func = super().call_function(ttnn.clip, args, kwargs) + elif target == torch.ops.aten.squeeze.dim: + # NOTE(kevinwuTT): ttnn.squeeze only supports dim 0 currently + if args[1] != 0: + call_func = super().call_function(target, args, kwargs) + else: + call_func = super().call_function(ttnn.squeeze, args, kwargs) + elif target == torch.ops.aten.lt.Tensor: + call_func = super().call_function(ttnn.lt, args, kwargs) + elif target == torch.ops.aten.cos.default: + call_func = super().call_function(ttnn.cos, args, kwargs) + elif target == torch.ops.aten.sigmoid.default: + call_func = super().call_function(ttnn.sigmoid, args, kwargs) + else: + call_func = super().call_function(target, args, kwargs) + + # Copy metadata of old node to replacement + meta = fx_traceback.get_current_meta() + if meta is not None and "val" in meta: + call_func.node.meta["val"] = meta["val"] + return call_func + + +def torch_dtype_to_dummy_ttnn_dtype(dtype: torch.dtype): + # Add newly supported dtypes here: + dtype_map = { + torch.float32: DummyTtnnBfloat16(), + torch.bfloat16: DummyTtnnBfloat16(), + } + if dtype in dtype_map: + return dtype_map.get(dtype) + else: + raise RuntimeError( + f"Missing conversion from torch.dtype: {dtype} to DummyTtnn dtype." + ) + + +# Certain ops don't support certain shapes and will emit a valid_page_size error +# RuntimeError: TT_FATAL @ ../tt_metal/impl/buffers/buffer.cpp:38: valid_page_size +# For valid non-interleaved buffers page size 2048 must equal buffer size X. For interleaved-buffers page size should be divisible by buffer size +def has_valid_page_size(shape): + for dim in shape: + if dim < 32: + return False + return True + + +def ReplaceMoreTtManually(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + nodes = list(gm.graph.nodes) + for node in nodes: + g = node.graph + args = node.args + + with g.inserting_before(node): + # TODO (kevinwuTT): consolidate and simplify these statements? 
+ if node.target == torch.ops.aten.clone.default: + arg_metadata = node.meta["val"] + dummy_dtype = torch_dtype_to_dummy_ttnn_dtype(arg_metadata.dtype) + # Add additional logic to choose the appropriate memory_config type: DRAM or L1 + memory_config = g.call_function( + ttnn.MemoryConfig, + ( + DummyTtlTensorTensorMemoryLayoutInterleaved(), + DummyTtlTensorBufferTypeDram(), + ), + ) + new_node = g.call_function( + ttnn.clone, args=(args[0], memory_config, dummy_dtype) + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + if node.target == torch.ops.aten.native_layer_norm.default: + new_node = g.call_function( + ttnn.layer_norm, + args=(args[0],), + kwargs={"epsilon": args[4], "weight": args[2], "bias": args[3]}, + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + node_users = list(new_node.users.keys()) + for node_user in node_users: + node_user.replace_all_uses_with(new_node) + if node.target == torch.ops.aten.ones.default: + new_node = g.call_function( + ttnn.ones, args=args, kwargs={"device": DummyDevice()} + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + """ + # NOTE(kevinwuTT): aten.arange.default starts with 0 which is unsupported by ttnn.arange at the moment + if node.target == torch.ops.aten.arange.default: + # start = 0, step = 1 + new_args = (0,) + new_kwargs = {"end": args[0], "step": 1, "device": DummyDevice()} + new_node = g.call_function(ttnn.arange, args=new_args, kwargs=new_kwargs) + node.replace_all_uses_with(new_node, delete_user_cb=lambda node: node != new_node,) + """ + if node.target == torch.ops.aten.arange.start: + # NOTE(kevinwuTT): ttnn.arange does not support starting values smaller than 2 currently + if args[0] >= 2: + # step = 1 + new_args = (args[0],) + new_kwargs = {"end": args[1], "step": 1, "device": DummyDevice()} + new_node = g.call_function( + ttnn.arange, args=new_args, kwargs=new_kwargs + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + if node.target == torch.ops.aten.arange.start_step: + # NOTE(kevinwuTT): ttnn.arange does not support starting values smaller than 2 currently + if args[0] >= 2: + new_args = (args[0],) + new_kwargs = { + "end": args[1], + "step": args[2], + "device": DummyDevice(), + } + new_node = g.call_function( + ttnn.arange, args=new_args, kwargs=new_kwargs + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + if node.target in relational_scalar_ops: + # NOTE(kevinwuTT): ttnn.eq shows error if passing a literal scalar as an argument. + # Instead, fill a tensor with the same size as args[0] with the scalar value using ttnn.full + arg_metadata = node.meta["val"] + if has_valid_page_size(arg_metadata.size()): + new_kwargs = { + "fill_value": args[1], + "device": DummyDevice(), + "layout": DummyTtnnTileLayout(), + } + full_node = g.call_function( + ttnn.full, args=(arg_metadata.size(),), kwargs=new_kwargs + ) + new_node = g.call_function( + relational_scalar_ops[node.target], + args=(args[0], full_node), + kwargs={}, + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + if node.target == torch.ops.aten.full.default: + # args[0] can be empty for aten.full which simply creates a scalar. Ignore conversion in this case. 
+ if args[0]: + new_kwargs = { + "fill_value": args[1], + "device": DummyDevice(), + "layout": DummyTtnnTileLayout(), + } + new_node = g.call_function( + ttnn.full, args=(tuple(args[0]),), kwargs=new_kwargs + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + if node.target == torch.ops.aten.baddbmm.default: + kwargs = node.kwargs + # out = beta * input + alpha * (batch1 @ batch2) + # if beta is 0, input is ignored, and nan and inf in it will not be propagated + new_node = g.call_function(ttnn.matmul, args=(args[1], args[2])) + if "alpha" in kwargs: + new_node = g.call_function( + ttnn.multiply, args=(new_node, kwargs["alpha"]) + ) + if "beta" in kwargs: + if kwargs["beta"] != 0: + beta_node = g.call_function( + ttnn.multiply, args=(args[0], kwargs["beta"]) + ) + new_node = g.call_function(ttnn.add, args=(beta_node, new_node)) + else: + new_node = g.call_function(ttnn.add, args=(args[0], new_node)) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + if node.target == torch.ops.aten.embedding.default: + # TODO(kevinwuTT): Add support for ROW_MAJOR_LAYOUT + new_kwargs = {"layout": DummyTtnnTileLayout()} + new_node = g.call_function( + ttnn.embedding, args=(args[1], args[0]), kwargs=new_kwargs + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + if node.target == torch.ops.aten.rsub.Scalar: + # NOTE(kevinwuTT): ttnn.sub shows error if passing a literal scalar as the first argument. + # Instead, fill a tensor with the same size as args[0] with the scalar value using aten.full + arg_metadata = node.meta["val"] + # NOTE(kevinwuTT): Only bfloat16 seems to work for now + # TODO(kevinwuTT): Use ttnn.full instead of aten + new_kwargs = {"dtype": torch.bfloat16} + full_node = g.call_function( + torch.ops.aten.full.default, + args=(arg_metadata.size(), args[1]), + kwargs=new_kwargs, + ) + new_node = g.call_function( + ttnn.sub, args=(full_node, args[0]), kwargs={} + ) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + if node.target == torch.ops.aten.div.Tensor: + # ttnn.recip does not support scalars. Call an aten.full and pass that to ttnn.recip + # TODO(kevinwuTT): Use a ttnn equivalent + node_metadata = node.meta["val"] + if isinstance(args[1], float): + full = g.call_function( + torch.ops.aten.full.default, (node_metadata.size(), args[1]), {} + ) + recip = g.call_function(ttnn.reciprocal, (full,), {}) + else: + recip = g.call_function(ttnn.reciprocal, (args[1],), {}) + new_node = g.call_function(ttnn.mul, (args[0], recip), {}) + node.replace_all_uses_with( + new_node, + delete_user_cb=lambda node: node != new_node, + ) + + gm = GraphCleanup(gm) + return gm class ToTtPass(PassBase): def call(self, gm: torch.fx.GraphModule): - # NOTE(yoco): In our case, subgraph_rewriter actually - # is not a best choice. We should use torch.fx.Transformer. - # However, we use subgraph_rewriter for demonstration - # and as a code stub. Because Transformer only support - # 1-to-N replacement. For N-to-M replacement, we need - # to use subgraph_rewriter.
- from ..patterns import add - from ..patterns import mm - - # Patterns and replacements - pat_rep_list = list() - pat_rep_list += add.pat_rep_list - pat_rep_list += mm.pat_rep_list - - # Replace patterns - modified = False - for pat, rep in pat_rep_list: - replaced_pats = torch.fx.subgraph_rewriter.replace_pattern(gm, pat, rep) - modified = modified or len(replaced_pats) > 0 - # Replace more patterns with torch.fx.Transformer gm = ReplaceMoreTt(gm).transform() + # Replace patterns manually + gm = ReplaceMoreTtManually(gm) + return PassResult(gm, True) diff --git a/torch_ttnn/patterns/add.py b/torch_ttnn/patterns/add.py index 14c62a9a0..065fdb6ed 100644 --- a/torch_ttnn/patterns/add.py +++ b/torch_ttnn/patterns/add.py @@ -1,3 +1,6 @@ +# TODO(kevinwuTT): I have a patch for tt-metal that assigns the __name__ attribute of each Operation. +# This file may not be needed anymore. + import torch try: diff --git a/torch_ttnn/patterns/mm.py b/torch_ttnn/patterns/mm.py index 133d7c2c7..4d257be74 100644 --- a/torch_ttnn/patterns/mm.py +++ b/torch_ttnn/patterns/mm.py @@ -1,3 +1,6 @@ +# TODO(kevinwuTT): I have a patch for tt-metal that assigns the __name__ attribute of each Operation. +# This file may not be needed anymore. + import torch try: diff --git a/torch_ttnn/torch_stat.py b/torch_ttnn/torch_stat.py index 8792e62c8..b5f38efa0 100644 --- a/torch_ttnn/torch_stat.py +++ b/torch_ttnn/torch_stat.py @@ -4,6 +4,7 @@ import torch._dynamo import os from collections import Counter +from functorch.compile import make_boxed_func torch._dynamo.config.suppress_errors = False torch._dynamo.config.verbose = True @@ -37,6 +38,7 @@ def aten_backend( "raw", f"{direction}_{option.model_name}_{option.counter['val']}.json", ) + print(stat_filename) os.makedirs(os.path.dirname(stat_filename), exist_ok=True) passes = [StatPass(filename=stat_filename, example_inputs=example_inputs)] if option.gen_graphviz: @@ -56,7 +58,7 @@ def aten_backend( gm.recompile() option.out_fx_graphs.append(gm.graph) option.counter["val"] += 1 - return gm + return make_boxed_func(gm) from torch._dynamo.backends.common import aot_autograd @@ -72,7 +74,7 @@ def __init__( out=os.path.join(os.getcwd(), "stat"), gen_graphviz=False, ): - self.model_name = model_name + self.model_name = model_name.replace("/", "_") self.backward = backward self.out = out self.gen_graphviz = gen_graphviz @@ -90,5 +92,6 @@ def backend(torch_stat_option: TorchStatOption): ) else: return aot_autograd( - fw_compiler=partial(aten_backend, options=options, direction="fw") + fw_compiler=partial(aten_backend, options=options, direction="fw"), + # inference_compiler=partial(aten_backend, options=options, direction="fw"), ) diff --git a/torch_ttnn/utils.py b/torch_ttnn/utils.py new file mode 100644 index 000000000..ec3206315 --- /dev/null +++ b/torch_ttnn/utils.py @@ -0,0 +1,129 @@ +import torch +import numpy as np + + +# Testing utils copied from tt-metal/tests/ttnn/utils_for_testing.py +def comp_pcc(golden, calculated, pcc=0.99): + golden = torch.Tensor(golden) + calculated = torch.Tensor(calculated) + + if golden.dtype != calculated.dtype: + calculated = calculated.type(golden.dtype) + + if torch.all(torch.isnan(golden)) and torch.all(torch.isnan(calculated)): + # logger.warning("Both tensors are 'nan'") + return True, 1.0 + + if torch.all(torch.isnan(golden)) or torch.all(torch.isnan(calculated)): + # logger.error("One tensor is all nan, the other is not.") + return False, 0.0 + + # Test if either is completely zero + if torch.any(golden.bool()) != 
torch.any(calculated.bool()): + # logger.error("One tensor is all zero") + return False, 0.0 + + # For now, mask all infs and nans so that we check the rest... TODO + golden = golden.clone() + golden[ + torch.logical_or( + torch.isnan(golden), + torch.logical_or(torch.isinf(golden), torch.isneginf(golden)), + ) + ] = 0 + calculated = calculated.clone() + calculated[ + torch.logical_or( + torch.isnan(calculated), + torch.logical_or(torch.isinf(calculated), torch.isneginf(calculated)), + ) + ] = 0 + + if torch.equal(golden, calculated): + return True, 1.0 + + if golden.dtype == torch.bfloat16: + golden = golden.type(torch.float32) + calculated = calculated.type(torch.float32) + cal_pcc = np.min( + np.ma.corrcoef( + np.ma.masked_invalid(torch.squeeze(golden).detach().numpy()).flatten(), + np.ma.masked_invalid(torch.squeeze(calculated).detach().numpy()).flatten(), + ) + ) + + if isinstance(cal_pcc, np.ma.core.MaskedConstant): + return True, 1.0 + + return cal_pcc >= pcc, cal_pcc + + +def construct_pcc_assert_message( + message, expected_pytorch_result, actual_pytorch_result +): + messages = [] + messages.append(message) + # messages.append("Expected") + # messages.append(str(expected_pytorch_result)) + # messages.append("Actual") + # messages.append(str(actual_pytorch_result)) + messages = [str(m) for m in messages] + return "\n".join(messages) + + +def check_with_pcc(expected_pytorch_result, actual_pytorch_result, pcc=0.9999): + if expected_pytorch_result.shape != actual_pytorch_result.shape: + return ( + False, + f"list(expected_pytorch_result.shape)={list(expected_pytorch_result.shape)} vs list(actual_pytorch_result.shape)={list(actual_pytorch_result.shape)}", + ) + pcc_passed, pcc_message = comp_pcc( + expected_pytorch_result, actual_pytorch_result, pcc + ) + return pcc_passed, construct_pcc_assert_message( + pcc_message, expected_pytorch_result, actual_pytorch_result + ) + + +def GraphCleanup(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + return gm + + +# See https://docs.google.com/document/d/1r2D4AagoeTRjEmXFnWzzafaWQkf-8hlIbX2ze-JAUFo/edit#heading=h.zad9rwqjv6cr +class DummyDevice: + def __repr__(self): + return f"ttnn_Specified_Device" + + +class DummyTtnnRowMajorLayout: + def __repr__(self): + return f"ttnn_ROW_MAJOR_LAYOUT" + + +class DummyTtnnTileLayout: + def __repr__(self): + return f"ttnn_TILE_LAYOUT" + + +class DummyTtnnUint32: + def __repr__(self): + return f"ttnn_uint32" + + +class DummyTtnnBfloat16: + def __repr__(self): + return f"ttnn_bfloat16" + + +class DummyTtlTensorTensorMemoryLayoutInterleaved: + def __repr__(self): + return f"ttl_tensor_TensorMemoryLayout_INTERLEAVED" + + +class DummyTtlTensorBufferTypeDram: + def __repr__(self): + return f"ttl_tensor_BufferType_DRAM"
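For reference, a few of the conversions above can be exercised without ttnn hardware. The data-move insertion in `add_data_move_pass.py` is plain `torch.fx` graph surgery: new `call_function` nodes (`ttnn.from_torch` → `ttnn.to_layout` → `ttnn.to_device` on the way in, and `ttnn.from_device` → `ttnn.to_layout` → `ttnn.to_torch` on the way out, with the per-op exceptions noted in the pass) are created under `graph.inserting_before(...)` and spliced between producer and consumer with `update_arg`. The sketch below mirrors that mechanism on a toy graph, using a hypothetical `wrap` function as a stand-in for the real data-move ops so it runs with stock PyTorch only:
```
import torch
import torch.fx


def wrap(x):
    # Hypothetical stand-in for a data-move op (e.g. ttnn.from_torch); identity here.
    return x


class TinyAdd(torch.nn.Module):
    def forward(self, a, b):
        return a + b


gm = torch.fx.symbolic_trace(TinyAdd())
g = gm.graph

for node in list(g.nodes):
    if node.op != "call_function":
        continue
    for idx, arg in enumerate(node.args):
        if isinstance(arg, torch.fx.Node) and arg.op == "placeholder":
            # Same pattern as insert_node_between: build the new node before the
            # consumer, then repoint the consumer's argument at it.
            with g.inserting_before(node):
                new_node = g.call_function(wrap, (arg,))
            node.update_arg(idx, new_node)

g.lint()
gm.recompile()
print(gm.graph)

a, b = torch.randn(2, 2), torch.randn(2, 2)
assert torch.equal(gm(a, b), a + b)  # identity wrappers do not change the result
```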
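The `aten.addmm.default` branch in `ReplaceMoreTt` lowers to `ttnn.matmul` followed by `ttnn.add`, which relies on the identity `addmm(input, mat1, mat2) == input + mat1 @ mat2` for the default `beta = alpha = 1` (the TODO notes that other values are not handled yet). A quick plain-PyTorch check of that identity, with made-up shapes:
```
import torch

inp = torch.randn(4, 8)    # bias term, args[0]
mat1 = torch.randn(4, 16)  # args[1]
mat2 = torch.randn(16, 8)  # args[2]

# With the default beta == alpha == 1, addmm is just a matmul plus an add,
# which is exactly the ttnn.matmul + ttnn.add pair emitted above.
assert torch.allclose(torch.addmm(inp, mat1, mat2), inp + mat1 @ mat2, atol=1e-6)
```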
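The `relational_scalar_ops` workaround materializes the scalar operand as a tensor via `ttnn.full` before calling `ttnn.eq`/`ttnn.lt`, and only does so when `has_valid_page_size` holds, i.e. every dimension is at least 32. The underlying equivalence is easy to check in plain PyTorch:
```
import torch

x = torch.randn(64, 64)  # every dim >= 32, so has_valid_page_size would accept this shape
scalar = 0.0

# aten.eq.Scalar(x, s) matches comparing x against a tensor filled with s,
# which is what the pass builds with ttnn.full before calling ttnn.eq.
filled = torch.full(x.shape, scalar)
assert torch.equal(torch.eq(x, scalar), torch.eq(x, filled))
```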
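Likewise, the `aten.baddbmm.default` branch follows `out = beta * input + alpha * (batch1 @ batch2)` and drops the `beta * input` term when `beta == 0`, so `nan`/`inf` values in `input` are not propagated. The same decomposition in plain PyTorch, again with made-up shapes:
```
import torch

inp = torch.randn(2, 4, 16)     # args[0]
batch1 = torch.randn(2, 4, 8)   # args[1]
batch2 = torch.randn(2, 8, 16)  # args[2]
alpha, beta = 0.5, 2.0

expected = torch.baddbmm(inp, batch1, batch2, beta=beta, alpha=alpha)
# Mirrors the emitted graph: matmul, scale by alpha, scale the input by beta, then add.
decomposed = beta * inp + alpha * torch.bmm(batch1, batch2)
assert torch.allclose(expected, decomposed, atol=1e-5)
```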
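Finally, the new `torch_ttnn/utils.py` helpers are intended for comparing compiled results against golden PyTorch outputs: `comp_pcc` returns a pass flag plus the Pearson correlation coefficient, and `check_with_pcc` additionally rejects shape mismatches and builds a message string. A minimal usage sketch, assuming the repository root is on `PYTHONPATH`; the tensors here are made up:
```
import torch

from torch_ttnn.utils import comp_pcc, check_with_pcc

golden = torch.randn(32, 32)
# Simulate small numerical drift, e.g. from bfloat16 rounding on device.
calculated = golden + 1e-4 * torch.randn(32, 32)

passed, pcc = comp_pcc(golden, calculated, pcc=0.99)
print(passed, pcc)

ok, message = check_with_pcc(golden, calculated, pcc=0.9999)
print(ok, message)
```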