Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to convert aten.amin/amax to ttnn.min/max #241

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/lowering/reduction/test_amax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch
import torch_ttnn
import pytest
import ttnn


class AmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, dim, keepdim):
return torch.amax(input, dim=dim, keepdim=keepdim)


@pytest.mark.parametrize("sign", [1, -1])
@pytest.mark.parametrize(
"input_shape, dim, keepdim",
[
((32, 32), [], True),
((16, 32, 32), [], True),
((16, 32, 32), [1], True),
((16, 32, 32), 1, True),
((16, 32, 32), [2], True),
((16, 32, 32), [1, 2], True),
# TODO(#240): keepdim = false is not supported
pytest.param((32, 32), [1], False, marks=pytest.mark.xfail(reason="keepdim = false is not supported (#240)")),
# TODO(#240): Not support reduction on < rank - 2 dims
pytest.param(
(16, 32, 32), [0], True, marks=pytest.mark.xfail(reason="Not support reduction on < rank - 2 dims (#240)")
),
pytest.param(
(32, 32, 32),
[0, 1, 2],
True,
marks=pytest.mark.xfail(reason="Not support reduction on < rank - 2 dims (#240)"),
),
# TODO(#240): Unexpected output shape (1, 1) instead of (1)
pytest.param((32,), [], True, marks=pytest.mark.xfail(reason="Need -inf padding value (#240)")),
# TODO(#240): Need -inf padding value
pytest.param((1, 32), [], True, marks=pytest.mark.xfail(reason="Need -inf padding value (#240)")),
pytest.param((32, 1), [], True, marks=pytest.mark.xfail(reason="Need -inf padding value (#240)")),
# TODO(#240): Output reshape inside generic reduction can't handle non-tile-aligned size
pytest.param(
(1, 32),
[1],
True,
marks=pytest.mark.xfail(
reason="Output reshape inside generic reduction can't handle non-tile-aligned size (#240)"
),
),
],
)
def test_amax(device, sign, input_shape, dim, keepdim):
m = AmaxModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) * sign
result_before = m.forward(input, dim, keepdim)

option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input, dim, keepdim)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.max) == 1
# Check inference result
assert result_before.shape == result_after.shape
assert torch.allclose(result_before, result_after)
71 changes: 71 additions & 0 deletions tests/lowering/reduction/test_amin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import torch
import torch_ttnn
import pytest
import ttnn


class AminModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, dim, keepdim):
return torch.amin(input, dim=dim, keepdim=keepdim)


@pytest.mark.parametrize("sign", [1, -1])
@pytest.mark.parametrize(
"input_shape, dim, keepdim",
[
((32, 32), [], True),
((16, 32, 32), [], True),
((16, 32, 32), 1, True),
((16, 32, 32), [1], True),
((16, 32, 32), [2], True),
((16, 32, 32), [1, 2], True),
# TODO(#240): keepdim = false is not supported
pytest.param((32, 32), [1], False, marks=pytest.mark.xfail(reason="keepdim = false is not supported (#240)")),
# TODO(#240): Not support reduction on < rank - 2 dims
pytest.param(
(16, 32, 32), [0], True, marks=pytest.mark.xfail(reason="Not support reduction on < rank - 2 dims (#240)")
),
pytest.param(
(32, 32, 32),
[0, 1, 2],
True,
marks=pytest.mark.xfail(reason="Not support reduction on < rank - 2 dims (#240)"),
),
# TODO(#240): Unexpected output shape (1, 1) instead of (1)
pytest.param(
(32,), [], True, marks=pytest.mark.xfail(reason="Unexpected output shape (1, 1) instead of (1) (#240)")
),
# TODO(#240): Need inf padding value
pytest.param((1, 32), [], True, marks=pytest.mark.xfail(reason="Need inf padding value (#240)")),
pytest.param((32, 1), [], True, marks=pytest.mark.xfail(reason="Need inf padding value (#240)")),
# TODO(#240): Output reshape inside generic reduction can't handle non-tile-aligned size
pytest.param(
(1, 32),
[1],
True,
marks=pytest.mark.xfail(
reason="Output reshape inside generic reduction can't handle non-tile-aligned size (#240)"
),
),
],
)
def test_amin(device, sign, input_shape, dim, keepdim):
m = AminModule()
input = torch.rand(input_shape, dtype=torch.bfloat16) * sign
result_before = m.forward(input, dim, keepdim)

option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(input, dim, keepdim)
option._out_fx_graphs[0].print_tabular()

# Check the graph has be rewritten
nodes = list(option._out_fx_graphs[0].nodes)
assert [node.target for node in nodes].count(ttnn.min) == 1
# Check inference result
assert result_before.shape == result_after.shape
assert torch.allclose(result_before, result_after)
9 changes: 7 additions & 2 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def is_function_call(node) -> bool:
ttnn.log1p,
ttnn.log2,
ttnn.logical_not,
ttnn.min,
ttnn.neg,
ttnn.reciprocal,
ttnn.relu,
Expand Down Expand Up @@ -109,6 +108,12 @@ def is_function_call(node) -> bool:
ttnn.where,
]

TTNN_REDUCTION_OPS = [
ttnn.max,
ttnn.mean,
ttnn.min,
]

TTNN_MATRIX_MULPIPLICATION_OPS = [
ttnn.matmul,
ttnn.linear,
Expand Down Expand Up @@ -148,6 +153,7 @@ def is_tt_compute(node) -> bool:
+ TTNN_POINTWISE_BINARY_OPS
+ TTNN_POINTWISE_TRINARY_OPS
+ TTNN_MATRIX_MULPIPLICATION_OPS
+ TTNN_REDUCTION_OPS
+ TTNN_TARGET_WRAPPERS
+ TTNN_DATAMOVE_OPS
+ TTNN_NORM_OPS
Expand All @@ -157,7 +163,6 @@ def is_tt_compute(node) -> bool:
ttnn.tril,
ttnn.arange,
ttnn.zeros_like,
ttnn.mean,
ttnn.global_avg_pool2d,
ttnn.clip,
ttnn.squeeze,
Expand Down
23 changes: 23 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,29 @@ def rewrite_node(node):
input = g.call_function(ttnn.to_layout, args=(input, TtnnRowMajorLayout()))
return g.call_function(ttnn.pad, args=(input, full_pad, value))

if node.target in [torch.ops.aten.amin.default, torch.ops.aten.amax.default]:
input_shape = args[0].meta["val"].size()
# TODO(#240): Not support keepdim = false (default value)
if len(args) < 3 or args[2] == False:
return None
# TODO(#240): Not support rank < 2 or non-tile-size-aligned tensor
if len(input_shape) < 2 or any(size % ttnn.TILE_SIZE != 0 for size in input_shape[-2:]):
return None
new_args = list(args)
# Convert dim int/list to tuple
if len(args) >= 2:
dim = args[1]
dim = (dim,) if isinstance(dim, int) else tuple(dim)
# TODO(#240): Not support reduction on < rank - 2 dims
if any(idx < len(input_shape) - 2 for idx in dim):
return None
new_args[1] = dim if len(dim) > 0 else None
return g.call_function(
ttnn.min if node.target == torch.ops.aten.amin.default else ttnn.max,
tuple(new_args),
kwargs,
)

with g.inserting_before(node):
new_node = rewrite_node(node)
if new_node is not None:
Expand Down