
Commit

Merge branch 'main' into kw/layout_device_ops
kevinwuTT committed Nov 11, 2024
2 parents 00c9b67 + 7fc898a commit 1751c98
Showing 7 changed files with 130 additions and 171 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -6,4 +6,4 @@ tabulate==0.9.0
networkx==3.1
graphviz
matplotlib==3.7.1
https://github.com/tenstorrent/tt-metal/releases/download/v0.53.0-rc36/metal_libs-0.53.0rc36+wormhole.b0-cp38-cp38-linux_x86_64.whl
https://github.com/tenstorrent/tt-metal/releases/download/v0.53.0-rc37/metal_libs-0.53.0rc37+wormhole.b0-cp38-cp38-linux_x86_64.whl
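
The only change here bumps the pinned tt-metal wheel from v0.53.0-rc36 to v0.53.0-rc37. A quick post-install sanity check, offered as a sketch: it assumes the wheel registers under the distribution name "metal_libs", which is taken from the wheel filename and not confirmed here.

# Sketch: confirm the upgraded wheel is the one actually installed.
from importlib.metadata import version

assert version("metal_libs").startswith("0.53.0rc37")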
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -116,7 +116,7 @@ def compile_and_run(device, reset_torch_dynamo, request):
# Compile model with ttnn backend
option = torch_ttnn.TorchTtnnOption(
device=device,
gen_graphviz=True,
gen_graphviz=False,
run_mem_analysis=False,
metrics_path=model_name,
verbose=True,
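The fixture now defaults gen_graphviz to False, presumably to cut per-test overhead in CI. For local debugging the flag can simply be flipped back; a sketch mirroring the option fields visible in the hunk above:

# Sketch for local debugging: re-enable graphviz output of the FX graphs.
option = torch_ttnn.TorchTtnnOption(
    device=device,
    gen_graphviz=True,  # False in CI after this commit; True to dump graphs
    run_mem_analysis=False,
    metrics_path=model_name,
    verbose=True,
)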
57 changes: 57 additions & 0 deletions tests/lowering/misc/test_cumsum.py
@@ -0,0 +1,57 @@
import torch
import torch_ttnn
import pytest

from tests.utils import assert_with_pcc


class CumsumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, dim):
        return torch.ops.aten.cumsum.default(input, dim=dim)


@pytest.mark.parametrize(
    "input_shapes, dim",
    [
        ((1, 32), -1),
        ((1, 45), -1),
        ((1, 59), 1),
        ((1, 5), -1),
        ((1, 60), 1),
        ((1, 10), 1),
        ((4, 32, 32), 0),
        ((1, 4, 32, 32), 1),
        ((4, 4, 32, 32), 0),
        ((1, 23, 40), 1),
        ((4, 32), 0),
        pytest.param(
            (1, 1, 32, 32),
            3,
            marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
        ),
        pytest.param(
            (1, 23, 40),
            2,
            marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
        ),
    ],
)
def test_cumsum(device, input_shapes, dim):
    m = CumsumModule()
    inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
    result_before = m.forward(inputs, dim)

    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)

    result_after = m.forward(inputs, dim)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and no longer contains the aten op
    nodes = [node.target for node in option._out_fx_graphs[0].nodes]
    assert nodes.count(torch.ops.aten.cumsum.default) == 0
    assert_with_pcc(result_before, result_after, pcc=0.99)
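
Note that the test only asserts the aten op is gone. The add_data_move_pass.py change further down registers ttnn.moreh_cumsum as a tt-compute op, which suggests that is the lowering target; a stronger positive assertion could look like the following sketch, assuming that mapping and an extra import ttnn:

# Hypothetical stronger check, assuming aten.cumsum lowers to
# ttnn.moreh_cumsum (the op registered in add_data_move_pass.py below).
import ttnn

nodes = [node.target for node in option._out_fx_graphs[0].nodes]
assert nodes.count(torch.ops.aten.cumsum.default) == 0
assert ttnn.moreh_cumsum in nodes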
172 changes: 19 additions & 153 deletions tests/lowering/tensor_manipulation/test_expand.py
@@ -1,178 +1,44 @@
import torch
import torch_ttnn
import pytest
import ttnn
from torch_ttnn.utils import (
TtnnRowMajorLayout,
TtnnTileLayout,
)

from tests.utils import assert_with_pcc

class ExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, new_shape):
return x.expand(new_shape)


class ExpandAfterOpModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, new_shape):
a = torch.clone(x)
return a.expand(new_shape)


class ExpandBetweenOpsModule(torch.nn.Module):
class ExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, new_shape):
a = torch.clone(x)
ex = a.expand(new_shape)
return torch.add(ex, ex)
def forward(self, input_tensor, shape):
return torch.ops.aten.expand.default(input_tensor, shape)


@pytest.mark.xfail(reason="lowering issue (#67)")
@pytest.mark.parametrize(
"input_shape, new_shape",
"input_shape, output_shape",
[
((1, 4), (4, 4)),
((1, 2), (32, -1)),
((1, 4), (32, -1)),
((1, 6), (32, -1)),
pytest.param((1, 3), (32, -1), marks=pytest.mark.xfail()),
pytest.param((12, 1), (-1, 32), marks=pytest.mark.xfail()),
],
)
def test_expand(device, input_shape, new_shape):
def test_expand(device, input_shape, output_shape):
m = ExpandModule()
tensor = torch.rand(input_shape, dtype=torch.bfloat16)
inputs = [tensor, new_shape]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
target = [node.target for node in nodes]
assert target.count(ttnn.repeat) == 1
assert nodes[target.index(ttnn.repeat)].args[1].target == ttnn.Shape
# Check inference result
assert torch.allclose(result_before, result_after, rtol=0.2)

input_tensor = torch.rand(input_shape, dtype=torch.bfloat16)
result_before = m.forward(input_tensor, output_shape)

@pytest.mark.xfail(reason="lowering issue (#67)")
@pytest.mark.parametrize(
"input_shape, new_shape",
[
((1, 4), (4, 4)),
],
)
def test_expand_after_op(device, input_shape, new_shape):
m = ExpandAfterOpModule()
tensor = torch.rand(input_shape, dtype=torch.bfloat16)
inputs = [tensor, new_shape]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
target = [node.target for node in nodes]
assert target.count(ttnn.repeat) == 1
repeat_node = nodes[target.index(ttnn.repeat)]
assert repeat_node.args[0].target == ttnn.to_layout
assert repeat_node.args[0].args[0].target == ttnn.clone
assert type(repeat_node.args[0].args[1]) is type(TtnnRowMajorLayout())
assert repeat_node.args[1].target == ttnn.Shape
# Check inference result
assert torch.allclose(result_before, result_after, rtol=0.2)


@pytest.mark.xfail(reason="lowering issue (#67)")
@pytest.mark.parametrize(
"input_shape, new_shape",
[
((1, 4), (4, 4)),
],
)
def test_expand_before_op(device, input_shape, new_shape):
class ExpandBeforeOpModule(torch.nn.Module):
def __init__(self):
super().__init__()
option.gen_graphviz = False

def forward(self, x, new_shape):
ex = x.expand(new_shape)
return torch.add(ex, ex)

m = ExpandBeforeOpModule()
tensor = torch.rand(input_shape, dtype=torch.bfloat16)
inputs = [tensor, new_shape]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
target = [node.target for node in nodes]
assert target.count(ttnn.repeat) == 1
assert nodes[target.index(ttnn.repeat)].args[1].target == ttnn.Shape
# to_layout that follows ttnn.repeat
to_layout_idx = target.index(ttnn.to_layout, target.index(ttnn.repeat))
to_layout_node = nodes[to_layout_idx]
assert to_layout_node.args[0].target == ttnn.repeat
assert type(to_layout_node.args[1]) is type(TtnnTileLayout())
assert target.count(ttnn.add) == 1
assert to_layout_idx < target.index(ttnn.add)

# Check inference result
assert torch.allclose(result_before, result_after, rtol=0.2)


@pytest.mark.xfail(reason="lowering issue (#67)")
@pytest.mark.parametrize(
"input_shape, new_shape",
[
((1, 4), (4, 4)),
],
)
def test_expand_between_ops(device, input_shape, new_shape):
m = ExpandBetweenOpsModule()
tensor = torch.rand(input_shape, dtype=torch.bfloat16)
inputs = [tensor, new_shape]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
result_after = m.forward(*inputs)
result_after = m.forward(input_tensor, output_shape)
option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and contains ttnn ops
nodes = list(option._out_fx_graphs[0].nodes)
target = [node.target for node in nodes]
assert target.count(ttnn.repeat) == 1
repeat_node = nodes[target.index(ttnn.repeat)]
assert repeat_node.args[0].target == ttnn.to_layout
assert repeat_node.args[0].args[0].target == ttnn.clone
assert type(repeat_node.args[0].args[1]) is type(TtnnRowMajorLayout())
assert repeat_node.args[1].target == ttnn.Shape
# to_layout that follows ttnn.repeat
to_layout_idx = target.index(ttnn.to_layout, target.index(ttnn.repeat))
to_layout_node = nodes[to_layout_idx]
assert to_layout_node.args[0].target == ttnn.repeat
assert type(to_layout_node.args[1]) is type(TtnnTileLayout())
assert target.count(ttnn.add) == 1
assert to_layout_idx < target.index(ttnn.add)
# Check inference result
assert torch.allclose(result_before, result_after, rtol=0.2)
nodes = [node.target for node in option._out_fx_graphs[0].nodes]
assert nodes.count(torch.ops.aten.expand.default) == 0

assert_with_pcc(result_before, result_after, pcc=0.99)
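
As a reminder of the semantics exercised by the (32, -1) cases above: -1 in an expand shape keeps that dimension's size unchanged. A minimal eager-mode illustration in plain PyTorch, independent of the ttnn backend:

import torch

x = torch.rand(1, 4, dtype=torch.bfloat16)
y = torch.ops.aten.expand.default(x, (32, -1))
print(y.shape)  # torch.Size([32, 4]): dim 0 broadcast from 1 to 32, dim 1 kept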
14 changes: 10 additions & 4 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -143,6 +143,7 @@ def is_function_call(node) -> bool:
)


<<<<<<< HEAD
# FIXME: Workaround function for unsupported features for ttnn.reshape
# BUG (https://github.com/tenstorrent/tt-metal/issues/13889)
def can_reshape(node):
@@ -163,6 +164,8 @@ def have_unsupported_ranks(src_node, dst_node):
return len(dst_node_shape) > 5 or len(dst_node_shape) == 1


=======
>>>>>>> main
# For operations limitations
# See https://github.com/tenstorrent-metal/tt-metal/blob/main/ttnn/README.md?plain=1#L19
def is_tt_compute(node) -> bool:
@@ -195,6 +198,8 @@ def is_tt_compute(node) -> bool:
ttnn.squeeze,
ttnn.full,
ttnn.as_tensor,
ttnn.expand,
ttnn.moreh_cumsum,
]
)
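
The two new entries above extend the set of targets that is_tt_compute recognizes, so data movement is now planned around ttnn.expand and ttnn.moreh_cumsum as well. Conceptually the gate reduces to a membership test; a simplified sketch follows (the real function checks a much longer list and additional conditions):

import ttnn
import torch.fx


def is_function_call(node: torch.fx.Node) -> bool:
    # Same spirit as the helper near the top of this file's diff.
    return node.op == "call_function"


TT_COMPUTE_TARGETS = {
    ttnn.squeeze,
    ttnn.full,
    ttnn.as_tensor,
    ttnn.expand,        # newly recognized in this commit
    ttnn.moreh_cumsum,  # newly recognized in this commit
}


def is_tt_compute_sketch(node: torch.fx.Node) -> bool:
    return is_function_call(node) and node.target in TT_COMPUTE_TARGETS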

@@ -357,10 +362,6 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
else:
kwargs["dtype"] = TtnnBfloat16()

# if is_tt_compute(dst_node):
# if not (dst_node.target == ttnn.reshape and have_unsupported_ranks(src_node, dst_node)):
# kwargs["device"] = device

new_nodes.append(g.call_function(ttnn.from_torch, (src_node,), kwargs))

insert_node_between(src_node, dst_idx, dst_node, new_nodes)
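
This hunk sits in the pass that wraps host tensors before tt-compute nodes: a ttnn.from_torch call is built with an appropriate dtype (TtnnBfloat16 in the fallback branch) and spliced in via insert_node_between. For readers unfamiliar with torch.fx surgery, here is a toy rewrite in the same spirit, with torch.clone standing in for ttnn.from_torch (an illustration, not the pass itself):

import operator

import torch
from torch.fx import symbolic_trace


def f(x):
    return x + 1


gm = symbolic_trace(f)
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is operator.add:
        # Splice a stand-in conversion in front of the add's first input,
        # mirroring how the pass wraps inputs with ttnn.from_torch.
        with gm.graph.inserting_before(node):
            wrapped = gm.graph.call_function(torch.clone, (node.args[0],))
        node.args = (wrapped,) + tuple(node.args[1:])
gm.recompile()
print(gm.code)  # the clone now sits between the placeholder and the add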
@@ -390,6 +391,11 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node, device) -> to
if dst_node.target == ttnn.slice:
need_to_layout = True

# # TODO(#372): #322 will enable tile layout for more layout change ops
# if dst_node.target in TTNN_LAYOUT_CHANGE_OPS and dst_idx == 0 and is_tt(src_node):
# need_from_device = True
# need_to_layout = True

if dst_node.target in [ttnn.embedding, ttnn.zeros_like, target_wrappers.repeat]:
# TODO: Only uint32 needs to to_layout on host
need_from_device = True
2 changes: 0 additions & 2 deletions torch_ttnn/passes/lowering/to_tt_guard_autogen.py
@@ -234,8 +234,6 @@
aten__log_softmax_default_blocklist = [["Tensor<[19, 256008]> self = ?", "int dim = 1", "bool half_to_float = False"]]
aten_expand_default_blocklist = [
["Tensor<[1, 1, 1, 19]> self = ?", "List[int] size = [1, 1, 19, 19]"],
["Tensor<[256, 1280]> self = ?", "List[int] size = [1, -1, -1]"],
["Tensor<[2048, 768]> self = ?", "List[int] size = [1, -1, -1]"],
["Tensor<[1, 5]> self = ?", "List[int] size = [5, 5]"],
["Tensor<[1, 3]> self = ?", "List[int] size = [3, 3]"],
["Tensor<[1, 17]> self = ?", "List[int] size = [13, 17]"],
