Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert aten.ceil to ttnn.ceil #198

Merged
merged 20 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions tests/lowering/eltwise/unary/test_ceil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch
import torch_ttnn
import pytest
import ttnn
from tests.utils import assert_with_pcc


class CeilModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.ceil(x)


@pytest.mark.skip_platform("grayskull")
@pytest.mark.parametrize(
"input_shape",
(
(1066,),
(120,),
(128,),
(160,),
(1, 1066),
(4, 4),
(4, 32),
),
)
def test_ceil(device, input_shape):
m = CeilModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) * 20 - 10
result_before = m.forward(input)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input)
option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.ceil) == 1

# Check inference result
assert_with_pcc(result_before, result_after)
8 changes: 4 additions & 4 deletions tests/lowering/eltwise/unary/test_erf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ def forward(self, x):


@pytest.mark.parametrize(
("input_shape", "init_offset"),
[((4, 4), 0)],
"input_shape",
((4, 4), (1066,)),
)
def test_erf(device, input_shape, init_offset):
def test_erf(device, input_shape):
m = ErfModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) + init_offset
input = torch.rand(input_shape, dtype=torch.bfloat16)
result_before = m.forward(input)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
Expand Down
8 changes: 4 additions & 4 deletions tests/lowering/eltwise/unary/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ def forward(self, x):


@pytest.mark.parametrize(
("input_shape", "init_offset"),
[((4, 4), 0)],
"input_shape",
((4, 4), (1066,)),
)
def test_exp(device, input_shape, init_offset):
def test_exp(device, input_shape):
m = ExpModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) + init_offset
input = torch.rand(input_shape, dtype=torch.bfloat16)
result_before = m.forward(input)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
Expand Down
11 changes: 8 additions & 3 deletions tests/lowering/eltwise/unary/test_floor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch_ttnn
import pytest
import ttnn
from tests.utils import assert_with_pcc


class FloorModule(torch.nn.Module):
Expand All @@ -15,7 +16,13 @@ def forward(self, x):
@pytest.mark.skip_platform("grayskull")
@pytest.mark.parametrize(
"input_shape",
[(4, 4)],
(
(1, 1, 1, 42),
(1, 1, 32, 1),
(4, 4),
(4, 32),
(1066,),
),
)
def test_floor(device, input_shape):
m = FloorModule()
Expand All @@ -33,6 +40,4 @@ def test_floor(device, input_shape):
assert [node.target for node in nodes].count(ttnn.floor) == 1

# Check inference result
from tests.utils import assert_with_pcc

assert_with_pcc(result_before, result_after)
12 changes: 6 additions & 6 deletions tests/lowering/eltwise/unary/test_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ def forward(self, input):


@pytest.mark.parametrize(
"input_shapes",
[[(4, 4)]],
"input_shape",
((4, 4), (1066,)),
)
def test_gelu(device, input_shapes):
def test_gelu(device, input_shape):
m = GeluModule()
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
result_before = m.forward(*inputs)
input = torch.rand(input_shape, dtype=torch.bfloat16)
result_before = m.forward(input)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
result_after = m.forward(input)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
Expand Down
2 changes: 1 addition & 1 deletion tests/lowering/eltwise/unary/test_hardsigmoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def forward(self, x):
(2, 3, 4, 5),
(7, 7, 7),
(420, 69),
pytest.param((1337,), marks=pytest.mark.xfail(reason="1D cases solved in #198, waiting for review")),
(1337,),
),
)
def test_hardsigmoid(device, input_shape):
Expand Down
2 changes: 1 addition & 1 deletion tests/lowering/eltwise/unary/test_hardswish.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def forward(self, x):
(2, 3, 4, 5),
(7, 7, 7),
(420, 69),
pytest.param((1337,), marks=pytest.mark.xfail(reason="1D cases solved in #198, waiting for review")),
(1337,),
),
)
def test_hardswish(device, input_shape):
Expand Down
71 changes: 71 additions & 0 deletions tests/lowering/eltwise/unary/test_round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import torch
import torch_ttnn
import pytest
import ttnn
from tests.utils import assert_with_pcc


class RoundModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, *args, **kwargs):
return torch.round(*args, **kwargs)


@pytest.mark.skip_platform("grayskull")
@pytest.mark.parametrize(
"input_shape",
(
(1, 1, 1, 42),
(1, 1, 32, 1),
(4, 4),
(4, 32),
(1066,),
),
)
def test_round_default(device, input_shape):
m = RoundModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) * 20 - 10
result_before = m.forward(input)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input)
option._out_fx_graphs[0].print_tabular()
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.round) == 1
assert_with_pcc(result_before, result_after, 0.99)


@pytest.mark.skip_platform("grayskull")
@pytest.mark.parametrize(
"input_shape, decimals",
(
((1, 1, 1, 42), 1),
((1, 1, 32, 1), 2),
((4, 4), 3),
((4, 32), 0),
((1066,), 2),
((1066,), 1),
((1066,), 0),
pytest.param(
(1066,),
-1,
# NOTE(jdh8): skip instead of xfail because it takes a long time
marks=pytest.mark.skip(reason="decimals < 0 not supported (until tt-metal#13851?)"),
),
),
)
def test_round_decimals(device, input_shape, decimals):
m = RoundModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) * 20 - 10
result_before = m.forward(input, decimals=decimals)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input, decimals=decimals)
option._out_fx_graphs[0].print_tabular()
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.round) == 1
assert_with_pcc(result_before, result_after, 0.99)
12 changes: 6 additions & 6 deletions tests/lowering/eltwise/unary/test_rsqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ def forward(self, input):


@pytest.mark.parametrize(
"input_shapes",
[[(4, 4)]],
"input_shape",
((4, 4), (1066,)),
)
def test_rsqrt(device, input_shapes):
def test_rsqrt(device, input_shape):
m = RsqrtModule()
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
result_before = m.forward(*inputs)
input = torch.rand(input_shape, dtype=torch.bfloat16)
result_before = m.forward(input)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
result_after = m.forward(input)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
Expand Down
8 changes: 4 additions & 4 deletions tests/lowering/eltwise/unary/test_sqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ def forward(self, x):


@pytest.mark.parametrize(
("input_shape", "init_offset"),
[((4, 4), 0)],
"input_shape",
((4, 4), (1066,)),
)
def test_sqrt(device, input_shape, init_offset):
def test_sqrt(device, input_shape):
m = SqrtModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) + init_offset
input = torch.rand(input_shape, dtype=torch.bfloat16)
result_before = m.forward(input)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
Expand Down
43 changes: 43 additions & 0 deletions tests/lowering/eltwise/unary/test_trunc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import torch
import torch_ttnn
import pytest
import ttnn
from tests.utils import assert_with_pcc


class TruncModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.trunc(x)


@pytest.mark.skip_platform("grayskull")
@pytest.mark.parametrize(
"input_shape",
(
(1, 1, 1, 42),
(1, 1, 32, 1),
(4, 4),
(4, 32),
(1066,),
),
)
def test_trunc(device, input_shape):
m = TruncModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) * 20 - 10
result_before = m.forward(input)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten and contain ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.trunc) == 1

# Check inference result
assert_with_pcc(result_before, result_after)
6 changes: 4 additions & 2 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@ def is_function_call(node) -> bool:
ttnn.asin,
ttnn.asinh,
ttnn.atan,
ttnn.atan2, # binary
ttnn.atanh,
# ttnn.clone, in target_wrappers
ttnn.ceil,
ttnn.cos,
ttnn.cosh,
ttnn.elu,
Expand All @@ -58,6 +57,7 @@ def is_function_call(node) -> bool:
ttnn.reciprocal,
ttnn.relu,
ttnn.remainder,
ttnn.round,
ttnn.rsqrt,
ttnn.sigmoid,
ttnn.softmax,
Expand All @@ -68,11 +68,13 @@ def is_function_call(node) -> bool:
ttnn.sqrt,
ttnn.tan,
ttnn.tanh,
ttnn.trunc,
]


TTNN_POINTWISE_BINARY_OPS = [
ttnn.add,
ttnn.atan2,
ttnn.div,
ttnn.eqz,
ttnn.gez,
Expand Down
Loading
Loading