Skip to content

Commit

Permalink
Implement conversion to ttnn.trunc
Browse files Browse the repository at this point in the history
  • Loading branch information
jdh8 committed Oct 16, 2024
1 parent c0775ea commit c64f168
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/lowering/eltwise/unary/test_trunc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch
import torch_ttnn
import pytest
import ttnn
from tests.utils import assert_with_pcc


class TruncModule(torch.nn.Module):
    """Minimal module wrapping element-wise truncation, used to exercise
    the aten.trunc -> ttnn.trunc lowering."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return *x* with each element rounded toward zero."""
        return x.trunc()


@pytest.mark.skip_platform("grayskull")
@pytest.mark.parametrize(
    "input_shape",
    (
        (1, 1, 1, 42),
        (1, 1, 32, 1),
        (4, 4),
        (4, 32),
        (1066,),
    ),
)
def test_trunc(device, input_shape):
    """Compile TruncModule with the ttnn backend and verify that
    aten.trunc is lowered to exactly one ttnn.trunc op and that the
    compiled output matches eager execution.
    """
    m = TruncModule()
    # Values in [-10, 10) so truncation is non-trivial for most elements.
    # NOTE: renamed from `input` to avoid shadowing the builtin.
    input_tensor = torch.rand(input_shape, dtype=torch.bfloat16) * 20 - 10
    result_before = m.forward(input_tensor)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(input_tensor)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.trunc) == 1

    # Check inference result
    assert_with_pcc(result_before, result_after)
1 change: 1 addition & 0 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def is_function_call(node) -> bool:
ttnn.sqrt,
ttnn.tan,
ttnn.tanh,
ttnn.trunc,
]


Expand Down
2 changes: 2 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
torch.ops.aten.ceil.default,
torch.ops.aten.floor.default,
torch.ops.aten.round.default,
torch.ops.aten.trunc.default,
}


Expand Down Expand Up @@ -98,6 +99,7 @@ def create_call_function(transformer, target, args, kwargs):
torch.ops.aten.tan.default: ttnn.tan,
torch.ops.aten.tanh.default: ttnn.tanh,
torch.ops.aten.tril.default: ttnn.tril,
torch.ops.aten.trunc.default: ttnn.trunc,
}


Expand Down

0 comments on commit c64f168

Please sign in to comment.