Relax restrictions when inserting ttnn.to_layout and ttnn.{to,from}_device ops #322

Status: Open. This PR wants to merge 65 commits into base: main. The diff shown below reflects changes from 5 of the 65 commits.

Commits:
650152e
Relax restrictions for inserting `ttnn.to_layout` and `ttnn.{to,from}…
kevinwuTT Oct 16, 2024
7a1e9d6
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Oct 16, 2024
db1c3c9
Restrict reshape for tensors with rank > 4
kevinwuTT Oct 17, 2024
fcc42ff
Fix restrictions to reshape again
kevinwuTT Oct 17, 2024
a6873b8
Consider reshape to 1-D
kevinwuTT Oct 17, 2024
b236038
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Oct 18, 2024
7567f38
Keep util helper function names consistent
kevinwuTT Oct 21, 2024
9c99eba
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Oct 22, 2024
61e3074
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Oct 22, 2024
3f722c1
Fix reporting of device related ttnn ops
kevinwuTT Oct 22, 2024
fb16e96
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Oct 22, 2024
723dc62
Remove a reshape input variant that has a workaround
kevinwuTT Oct 22, 2024
3f2bfa2
Remove check for torch.fx.Node since PsuedoNode is being used for tt_…
kevinwuTT Oct 23, 2024
86619a0
Revise some restrictions and interactions
kevinwuTT Oct 23, 2024
4d6f5cc
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Oct 24, 2024
0c3de8e
Revise reshape restrictions
kevinwuTT Oct 24, 2024
6bc8976
Fix slice
kevinwuTT Oct 31, 2024
e99cb00
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Oct 31, 2024
f7676e8
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 4, 2024
b3f0a2b
Add some inputs to blacklist for view
kevinwuTT Nov 5, 2024
6455d2c
Fix data movement with different layouts and host/device requirements…
jerrysky3 Nov 1, 2024
0aa9b3c
Remove the blocklist related to issue #358
swimdi Nov 7, 2024
67ec52f
Move aten.add.Tensor restricted from to_tt_guard_autogen to to_tt_pass
swimdi Nov 7, 2024
2b6d028
Add tests/pattern/test_vilt.py
swimdi Nov 7, 2024
fe4e4bf
Add lowering to ttnn.div and other guards
kevinwuTT Nov 7, 2024
31aac9d
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 7, 2024
4dac2c9
Merge branch 'stage1_swimdi_add' into kw/layout_device_ops
kevinwuTT Nov 7, 2024
6a6c7be
Fix data movement with different layouts and host/device requirements…
jerrysky3 Nov 1, 2024
0444cd0
Remove the blocklist related to issue #358
swimdi Nov 7, 2024
4910029
Move aten.add.Tensor restricted from to_tt_guard_autogen to to_tt_pass
swimdi Nov 7, 2024
c7c22dd
Add tests/pattern/test_vilt.py
swimdi Nov 7, 2024
39c812b
Not calculate_accuracy speecht5-tts in confest
swimdi Nov 8, 2024
3e6463c
Add tests/pattern/test_retinanet_pattern.py
swimdi Nov 8, 2024
1b68c5a
Relax more restrictions
kevinwuTT Nov 8, 2024
d21f1df
Merge branch 'stage1_swimdi_add' of github.com:tenstorrent/pytorch2.0…
kevinwuTT Nov 8, 2024
6f7878c
Merge branch 'stage1_swimdi_add' into kw/layout_device_ops
kevinwuTT Nov 8, 2024
5003605
Force TILE_LAYOUT and force fallback for lowerings that do not suppor…
kevinwuTT Nov 9, 2024
00c9b67
Unqueeze reshape does not support 5D
kevinwuTT Nov 11, 2024
6979151
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 11, 2024
449f170
Fix unsqueeze for 4D inputs
kevinwuTT Nov 11, 2024
fe33687
Clean up
kevinwuTT Nov 12, 2024
e03cd9b
Remove remaining reshape restrictions
kevinwuTT Nov 12, 2024
f8bc7db
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 12, 2024
6b3c2dc
Add blacklist for xglm and reshape to 1D
kevinwuTT Nov 12, 2024
5d30722
not rm autogen/all/conv
swimdi Nov 13, 2024
1a8786a
remove original_input_varations of to_torch and aten._to_copy because…
swimdi Nov 13, 2024
bca86b2
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 13, 2024
2836c83
Merge remote-tracking branch 'origin/fix-docs' into kw/layout_device_ops
kevinwuTT Nov 13, 2024
e9cd2a9
Optimize to_copy
kevinwuTT Nov 13, 2024
8a8fe5e
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 14, 2024
f2dfdd7
Rework after node layout change to mirror before node. Handle issues …
kevinwuTT Nov 15, 2024
b70b754
Include reshape when finding user of current target
kevinwuTT Nov 15, 2024
a455ffd
Use output_size for aten slice
kevinwuTT Nov 15, 2024
df6ce56
Handle input aliasing for get_attr nodes
kevinwuTT Nov 19, 2024
ab485dc
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 19, 2024
7cac473
Lower to fallback if input > 4D for aten.select
kevinwuTT Nov 19, 2024
62db956
Skip testing avg pool for now
kevinwuTT Nov 19, 2024
de003cc
Insert clone nodes after each get_attr instead of end
kevinwuTT Nov 19, 2024
7f98bd6
get_attr can non FakeTensor types. Only add clone to these
kevinwuTT Nov 19, 2024
d2862f5
Convert fallback unsafe_view to reshape
kevinwuTT Nov 20, 2024
c7bcb51
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 21, 2024
30e7f49
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 21, 2024
b397236
Disable reshaping or unsqueezing for outputs > 4D for now
kevinwuTT Nov 21, 2024
f3303b6
Last cleanup
kevinwuTT Nov 23, 2024
10f381a
Merge branch 'main' into kw/layout_device_ops
kevinwuTT Nov 23, 2024
tests/lowering/tensor_manipulation/test_transpose.py (1 addition, 0 deletions)
@@ -21,6 +21,7 @@ def forward(self, x, dim0, dim1):
         # If not, this runtime error will be thrown:
         # RuntimeError: TT_FATAL @ ../tt_metal/impl/buffers/buffer.cpp:41: page_size % sizeof(uint32_t) == 0
         ((5, 3, 2), 0, 2),
+        ((1, 4150, 192), 1, 2),
     ],
 )
 def test_transpose(device, input_shape, dim0, dim1):
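For context, the new case swaps dims 1 and 2 of a (1, 4150, 192) tensor, so the output's last dimension becomes 4150, which is not a multiple of the 32-element tile width. A minimal shape-only sketch of what the case exercises (plain torch, no ttnn needed):

import torch

# Transposing dims 1 and 2 yields a last dim of 4150; 4150 % 32 != 0, so
# the tensor is not tile-aligned and the relaxed layout handling has to
# choose row-major layout for it rather than failing with the page_size
# error quoted in the comment above.
x = torch.rand(1, 4150, 192)
y = x.transpose(1, 2)
assert y.shape == (1, 192, 4150)
assert y.shape[-1] % 32 != 0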
tests/lowering/tensor_manipulation/test_view.py (23 additions, 0 deletions)
@@ -80,6 +80,29 @@ def forward(self, x, new_shape):
         ((256, 4096), (1, 256, 4096)),
         ((1, 32, 16, 96), (1, 32, 1536)),
         ((1, 192, 4150), (1, 192, 50, 83)),
+        ((1, 32, 4608), (1, 32, 16, 3, 96)),
+        ((1, 100, 192), (100, 192)),
+        ((1, 1445, 192), (1, 1445, 3, 64)),
+        ((1, 1445, 192), (1445, 192)),
+        ((1, 1445, 3, 64), (1, 1445, 192)),
+        ((1, 1445, 768), (1445, 768)),
+        ((1, 192, 32, 42), (1, 192, 1344)),
+        ((1, 192, 4150), (1, 192, 50, 83)),
+        ((1, 3, 1445, 1445), (3, 1445, 1445)),
+        ((1, 3, 1445, 64), (3, 1445, 64)),
+        ((1, 3, 64, 1445), (3, 64, 1445)),
+        ((100, 192), (1, 100, 192)),
+        ((100, 4), (1, 100, 4)),
+        ((100, 92), (1, 100, 92)),
+        ((1445, 192), (1, 1445, 192)),
+        ((1445, 768), (1, 1445, 768)),
+        ((192), (1, 192, 1, 1)),
+        ((1), (1, 1, 1, 1)),
+        ((3, 1445, 1445), (1, 3, 1445, 1445)),
+        ((3, 1445, 64), (1, 3, 1445, 64)),
+        ((32), (1, 1, 32, 1)),
+        ((42), (1, 1, 1, 42)),
+        ((1, 10), (10,)),
     ],
 )
 def test_reshape(device, input_shape, new_shape, module):
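As a quick sanity sketch, a few of the newly covered reshape patterns restated with plain torch.reshape; the rank-1 and rank-5 targets are the ones that the relaxed pass (see add_data_move_pass.py below) now routes through the row-major path:

import torch

# Rank 2 -> rank 1: handled via the row-major/host reshape path.
assert torch.rand(1, 10).reshape(10).shape == (10,)
# Rank 3 -> rank 5: output rank > 4 is likewise kept off the tile path.
assert torch.rand(1, 32, 4608).reshape(1, 32, 16, 3, 96).ndim == 5
# Rank 1 -> rank 4: e.g. broadcasting a per-channel constant to NCHW.
assert torch.rand(192).reshape(1, 192, 1, 1).shape == (1, 192, 1, 1)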
torch_ttnn/passes/lowering/add_data_move_pass.py (47 additions, 25 deletions)
@@ -7,6 +7,7 @@
     TtnnBfloat16,
     TtnnUint32,
     HasValidPageSize,
+    CanBeTilized,
 )


@@ -137,13 +138,31 @@ def is_function_call(node) -> bool:
     [
         ttnn.reshape,
         ttnn.slice,
+        ttnn.full,
     ]
 )


-def can_be_tilized(node):
-    size = node.meta["val"].size()
-    return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0
+# FIXME: Workaround function for unsupported features for ttnn.reshape
+# BUG (https://github.com/tenstorrent/tt-metal/issues/13891)
+# BUG (https://github.com/tenstorrent/tt-metal/issues/13889)
+def can_reshape(node):
+    shape = node.meta["val"].size()
+    # Unsupported H dims
+    unsupported_H_dim = set([1, 1445, 100])
+    # Unsupported if output rank is > 4
+    return len(shape) >= 2 and shape[-2] not in unsupported_H_dim and len(shape) <= 4
+
+
+# FIXME: Workaround functions for unsupported features for ttnn.reshape
+def get_shape(node):
+    return node.meta["val"].size()
+
+
+def have_supported_ranks(src_node, dst_node):
+    dst_node_shape = get_shape(dst_node)
+    src_node_shape = get_shape(src_node)
+    return len(dst_node_shape) > 4 or len(src_node_shape) > 4 or len(dst_node_shape) == 1


 # For operations limitations
@@ -292,10 +311,11 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
     with g.inserting_before(dst_node):
         kwargs = {}
         if (
-            (dst_node.target in TTNN_LAYOUT_CHANGE_OPS and not can_be_tilized(dst_node))
+            dst_node.target == ttnn.slice
             or dst_node.target == ttnn.embedding
             or dst_node.target == ttnn.zeros_like
             or dst_node.target == target_wrappers.repeat
+            or (dst_node.target == ttnn.reshape and have_supported_ranks(src_node, dst_node))
         ):
             kwargs["layout"] = TtnnRowMajorLayout()
         else:
@@ -306,10 +326,9 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
         else:
             kwargs["dtype"] = TtnnBfloat16()

-        if (is_tt_compute(dst_node) and dst_node.target not in TTNN_LAYOUT_CHANGE_OPS) or (
-            dst_node.target in TTNN_LAYOUT_CHANGE_OPS and HasValidPageSize(src_node.meta["val"].size(), strict=True)
-        ):
-            kwargs["device"] = device
+        if is_tt_compute(dst_node):
+            if not (dst_node.target == ttnn.reshape and have_supported_ranks(src_node, dst_node)):
+                kwargs["device"] = device

         new_nodes.append(g.call_function(ttnn.from_torch, (src_node,), kwargs))

Expand All @@ -324,44 +343,47 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node) -> torch.fx.n
if not is_function_call(dst_node):
return None
if (
dst_node.target not in TTNN_LAYOUT_CHANGE_OPS
or dst_idx != 0
or not is_tt(src_node)
or (dst_node.target in TTNN_LAYOUT_CHANGE_OPS and can_be_tilized(dst_node))
not is_tt(src_node)
or dst_node.target not in TTNN_LAYOUT_CHANGE_OPS
or (dst_node.target == ttnn.reshape and not have_supported_ranks(src_node, dst_node))
or (dst_node.target == ttnn.full and CanBeTilized(dst_node))
or (dst_node.target == ttnn.slice and not HasValidPageSize(dst_node, strict=True))
):
return None

g = dst_node.graph
new_nodes = []
with g.inserting_before(dst_node):
from_device = g.call_function(ttnn.from_device, (src_node,))
to_layout = g.call_function(ttnn.to_layout, (from_device, TtnnRowMajorLayout()))
new_nodes.append(g.call_function(ttnn.to_layout, (src_node, TtnnRowMajorLayout())))
if len(get_shape(dst_node)) > 4 or len(get_shape(dst_node)) == 1:
new_nodes.append(g.call_function(ttnn.from_device, (new_nodes[-1],)))

insert_node_between(src_node, dst_idx, dst_node, [from_device, to_layout])
insert_node_between(src_node, dst_idx, dst_node, new_nodes)

return to_layout
return new_nodes[-1]


def try_add_layout_change_after_node(src_node, dst_idx, dst_node, device) -> torch.fx.node.Node:
# Consider src_node is ttnn.repeat, and dst_node should be any tt_compute node that uses ttnn.repeat
if not is_function_call(src_node):
return None
if (
src_node.target not in TTNN_LAYOUT_CHANGE_OPS.union(set([target_wrappers.repeat]))
or not is_tt_compute(dst_node)
not is_tt_compute(dst_node)
or dst_node.target == ttnn.embedding
or dst_node.target == target_wrappers.repeat
or src_node.target not in TTNN_LAYOUT_CHANGE_OPS.union(set([target_wrappers.repeat]))
or (src_node.target == ttnn.reshape and can_reshape(src_node))
or (src_node.target == ttnn.full and CanBeTilized(src_node))
or (src_node.target == ttnn.slice and not HasValidPageSize(src_node, strict=True))
):
return None

g = dst_node.graph
new_nodes = []
with g.inserting_before(dst_node):
if dst_node.target != target_wrappers.repeat:
new_nodes.append(
g.call_function(ttnn.to_layout, (new_nodes[-1] if new_nodes else src_node, TtnnTileLayout()))
)
new_nodes.append(
g.call_function(ttnn.to_device, (new_nodes[-1] if new_nodes else src_node,), {"device": device})
)
new_nodes.append(g.call_function(ttnn.to_layout, (new_nodes[-1] if new_nodes else src_node, TtnnTileLayout())))
if len(get_shape(src_node)) > 4 or len(get_shape(src_node)) == 1:
new_nodes.append(g.call_function(ttnn.to_device, (new_nodes[-1], TtnnDevice())))

insert_node_between(src_node, dst_idx, dst_node, new_nodes)

Expand Down
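To make the new guards concrete, here is a shape-only restatement of have_supported_ranks and can_reshape; it assumes the same semantics as the node-based versions above, with plain tuples standing in for fx nodes:

# Shape-only sketch of the guards; the real versions read node.meta["val"].
def have_supported_ranks(src_shape, dst_shape):
    # True when the reshape must run in row-major layout (and, per
    # try_add_layout_change_before_node, move off device when the output
    # rank is > 4 or exactly 1).
    return len(dst_shape) > 4 or len(src_shape) > 4 or len(dst_shape) == 1


def can_reshape(shape):
    # True when ttnn.reshape can keep the tensor on the tile/device path;
    # H dims 1, 1445, and 100 are blocked per tt-metal issues #13889 and
    # #13891, and the output rank must be 2..4.
    return len(shape) >= 2 and shape[-2] not in {1, 1445, 100} and len(shape) <= 4


# (1, 32, 4608) -> (1, 32, 16, 3, 96): rank-5 output, so row-major + host.
assert have_supported_ranks((1, 32, 4608), (1, 32, 16, 3, 96))
# (1, 192, 50, 83): rank 4 with H = 50, so it can stay on the device path.
assert can_reshape((1, 192, 50, 83))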
torch_ttnn/passes/lowering/to_tt_pass.py (4 additions, 2 deletions)
@@ -9,6 +9,7 @@
     TtnnDramMemoryConfig,
     TtnnRowMajorLayout,
     HasValidPageSize,
+    CanBeTilized,
 )
 import numpy as np
 from typing import Tuple
@@ -526,17 +527,18 @@ def rewrite_node(node):
             # Page size must be divisible by sizeof(uint32_t) because buffers hold uint32_t values
             if shape[-1] != 1 and HasValidPageSize(shape):
                 if isinstance(args[1], float):
+                    layout = TtnnTileLayout() if CanBeTilized(shape) else TtnnRowMajorLayout()
                     new_kwargs = {
                         "fill_value": args[1],
                         "device": TtnnDevice(),
+                        "layout": layout,
                     }
                     full = g.call_function(
                         ttnn.full,
                         args=(tuple(shape),),
                         kwargs=new_kwargs,
                     )
-                    to_layout = g.call_function(ttnn.to_layout, (full,), {"layout": TtnnTileLayout()})
-                    recip = g.call_function(ttnn.reciprocal, (to_layout,), {})
+                    recip = g.call_function(ttnn.reciprocal, (full,), {})
                 else:
                     recip = g.call_function(ttnn.reciprocal, (args[1],), {})
                 return g.call_function(ttnn.mul, (args[0], recip), {})
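The net effect of this hunk: aten.div with a float scalar still lowers to ttnn.mul by a reciprocal, but the ttnn.full constant is now created directly in tile layout whenever CanBeTilized holds, dropping the separate ttnn.to_layout hop. A rough plain-torch equivalent of the lowered graph (each torch call stands in for its ttnn counterpart):

import torch

x = torch.rand(32, 64)
scalar = 4.0

full = torch.full(x.shape, scalar)  # ttnn.full, now created in tile layout
recip = torch.reciprocal(full)      # ttnn.reciprocal, no to_layout needed
out = torch.mul(x, recip)           # ttnn.mul
assert torch.allclose(out, x / scalar)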
torch_ttnn/utils.py (21 additions, 3 deletions)
@@ -9,12 +9,30 @@ def GraphCleanup(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     return gm


+def GetShape(node_or_shape):
+    if isinstance(node_or_shape, torch.fx.node.Node):
+        if (val := node_or_shape.meta.get("val", None)) is not None:
+            return val.size()
+    elif isinstance(node_or_shape, torch.Size):
+        return node_or_shape
+
+    return None
+
+
 # Certain ops don't support certain shapes and will emit a valid_page_size error
 # RuntimeError: TT_FATAL @ ../tt_metal/impl/buffers/buffer.cpp:38: valid_page_size
 # For valid non-interleaved buffers page size 2048 must equal buffer size X. For interleaved-buffers page size should be divisible by buffer size
-def HasValidPageSize(shape, strict=False):
-    if len(shape) >= 2 and shape[-1] > 0:
-        return shape[-1] % 32 == 0 or (not strict and shape[-1] < 32)
+def HasValidPageSize(node_or_shape, strict=False):
+    if (shape := GetShape(node_or_shape)) is not None:
+        if len(shape) >= 2 and len(shape) <= 4 and shape[-1] > 0:
+            return shape[-1] % 32 == 0 or (not strict and shape[-1] < 32)
     return False
+
+
+def CanBeTilized(node_or_shape):
+    if (shape := GetShape(node_or_shape)) is not None:
+        if len(shape) >= 2 and len(shape) <= 4 and shape[-1] > 0 and shape[-2] > 1:
+            return shape[-1] % 32 == 0 or shape[-2] % 32 == 0
+    return False

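A quick usage sketch for the reworked helpers, assuming they are imported from torch_ttnn.utils as defined above. Both now accept either an fx node carrying a "val" meta entry or a plain torch.Size, and both reject ranks outside 2..4:

import torch

from torch_ttnn.utils import CanBeTilized, HasValidPageSize

assert HasValidPageSize(torch.Size([1, 64]))              # W % 32 == 0
assert not HasValidPageSize(torch.Size([1, 2, 3, 4, 5]))  # rank > 4 is now rejected
assert CanBeTilized(torch.Size([2, 64]))                  # H > 1 and W tile-aligned
assert not CanBeTilized(torch.Size([1, 64]))              # H == 1 cannot be tilized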