diff --git a/tests/lowering/eltwise/unary/test_trunc.py b/tests/lowering/eltwise/unary/test_trunc.py
new file mode 100644
index 000000000..7df0c00b3
--- /dev/null
+++ b/tests/lowering/eltwise/unary/test_trunc.py
@@ -0,0 +1,43 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+from tests.utils import assert_with_pcc
+
+
+class TruncModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.trunc(x)
+
+
+@pytest.mark.skip_platform("grayskull")
+@pytest.mark.parametrize(
+    "input_shape",
+    (
+        (1, 1, 1, 42),
+        (1, 1, 32, 1),
+        (4, 4),
+        (4, 32),
+        (1066,),
+    ),
+)
+def test_trunc(device, input_shape):
+    m = TruncModule()
+    input = torch.rand(input_shape, dtype=torch.bfloat16) * 20 - 10
+    result_before = m.forward(input)
+    option = torch_ttnn.TorchTtnnOption(device=device)
+    option.gen_graphviz = True
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+    result_after = m.forward(input)
+    option._out_fx_graphs[0].print_tabular()
+
+    # Check that the graph has been rewritten and contains ttnn ops
+    nodes = list(option._out_fx_graphs[0].nodes)
+    assert [node.target for node in nodes].count(ttnn.trunc) == 1
+
+    # Check inference result
+    assert_with_pcc(result_before, result_after)
diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py
index 1f636c31e..f8e142da4 100644
--- a/torch_ttnn/passes/lowering/add_data_move_pass.py
+++ b/torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -67,6 +67,7 @@ def is_function_call(node) -> bool:
     ttnn.sqrt,
     ttnn.tan,
     ttnn.tanh,
+    ttnn.trunc,
 ]
diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 984c21cc1..a81b2b756 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -40,6 +40,7 @@
     torch.ops.aten.ceil.default,
     torch.ops.aten.floor.default,
     torch.ops.aten.round.default,
+    torch.ops.aten.trunc.default,
 }
@@ -98,6 +99,7 @@ def create_call_function(transformer, target, args, kwargs):
     torch.ops.aten.tan.default: ttnn.tan,
     torch.ops.aten.tanh.default: ttnn.tanh,
     torch.ops.aten.tril.default: ttnn.tril,
+    torch.ops.aten.trunc.default: ttnn.trunc,
 }
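
Reviewer note (not part of the diff): a minimal sketch of the semantics the test exercises. torch.trunc rounds each element toward zero, so the input range produced by `* 20 - 10` covers both signs and both directions of rounding, and these values are exactly representable in bfloat16:

    import torch

    # torch.trunc rounds toward zero, elementwise
    x = torch.tensor([-1.7, -0.5, 0.5, 1.7], dtype=torch.bfloat16)
    print(torch.trunc(x))  # tensor([-1., -0., 0., 1.], dtype=torch.bfloat16)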