Support non-aligned shape
jerrysky3 committed Sep 13, 2024
1 parent d270a13 commit affe6b0
Showing 4 changed files with 33 additions and 28 deletions.
29 changes: 14 additions & 15 deletions tests/lowering/eltwise/unary/test_fill.py
@@ -13,21 +13,24 @@ def forward(self, input, value):
 
 
 @pytest.mark.parametrize(
-    "input_shape",
+    "input_shape, value, converted",
     [
-        (64,),
-        (32, 32),
-        (1, 64, 32),
-        (2, 32, 64),
-        (32, 1),
-        (1, 32),
-        (16, 64),
+        ((32, 32), 0.5, True),
+        ((32, 1), 114, True),
+        ((1, 32), 120, True),
+        ((4, 4), 32, True),
+        ((1, 2, 100, 100), 16, True),
+        ((1, 1, 1, 1), 61, True),
+        ((2, 1, 1, 1), 64, True),
+        ((1, 1, 2, 1), 8, True),
+        # Not supported: dims > 4 or < 2
+        ((1, 1, 1, 2, 2), 1.0, False),
+        ((64,), 1.0, False),
     ],
 )
-def test_fill_scalar(device, input_shape):
+def test_fill_scalar(device, input_shape, value, converted):
     m = FillScalarModule()
     input = torch.rand(input_shape, dtype=torch.bfloat16)
-    value = 0.125
     result_before = m.forward(input, value)
 
     option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=True)
@@ -38,10 +38,6 @@ def test_fill_scalar(device, input_shape):
 
     # Check the graph has be rewritten and contains ttnn ops
     nodes = list(option._out_fx_graphs[0].nodes)
-    # Currently only tensor sizes divisible by 32x32 are converted.
-    if len(input_shape) > 1 and input_shape[-1] % ttnn.TILE_SIZE == 0 and input_shape[-2] % ttnn.TILE_SIZE == 0:
-        assert [node.target for node in nodes].count(ttnn.full) == 1
-    else:
-        assert [node.target for node in nodes].count(ttnn.full) == 0
+    assert [node.target for node in nodes].count(ttnn.full) == (1 if converted else 0)
     # Check inference result
     assert torch.allclose(result_before, result_after)
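For reference, the new `converted` flag simply tracks the shape's rank: with this commit any 2-D to 4-D fill is lowered regardless of tile alignment, while 1-D and 5-D shapes still fall back to aten. A minimal restatement of that predicate (illustrative only; `expect_ttnn_full` is not part of the commit):

def expect_ttnn_full(input_shape) -> bool:
    # 2-D to 4-D shapes lower to ttnn.full; alignment to the 32x32 tile no longer matters.
    return 2 <= len(input_shape) <= 4

assert expect_ttnn_full((4, 4))                 # non-aligned but now converted
assert not expect_ttnn_full((64,))              # rank 1: not converted
assert not expect_ttnn_full((1, 1, 1, 2, 2))    # rank 5: not converted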
2 changes: 1 addition & 1 deletion torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -123,7 +123,7 @@ def is_function_call(node) -> bool:
     ttnn.to_layout,
 ]
 
-TTNN_TARGET_WRAPPERS = [target_wrappers.clone, target_wrappers.repeat]
+TTNN_TARGET_WRAPPERS = [target_wrappers.clone, target_wrappers.repeat, target_wrappers.getitem]
 
 TTNN_NORM_OPS = [
     ttnn.group_norm,
5 changes: 5 additions & 0 deletions torch_ttnn/passes/lowering/target_wrappers.py
@@ -10,3 +10,8 @@ def clone(t):
 @torch.fx.wrap
 def repeat(t, sizes):
     return ttnn.repeat(t, ttnn.Shape(sizes))
+
+
+@torch.fx.wrap
+def getitem(t, bounds):
+    return t[[slice(lower, upper) for lower, upper in bounds]]
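The new wrapper trims a padded tensor by applying one slice per dimension. A quick illustration with a plain torch tensor (the 32x32 and 4x4 sizes are example values; standard tuple indexing is used here for clarity, equivalent to the per-dimension slices the wrapper builds):

import torch

padded = torch.zeros(32, 32)        # tile-aligned tensor, e.g. produced by ttnn.full
bounds = [(0, 4), (0, 4)]           # (lower, upper) per dimension, as built in to_tt_pass.py
trimmed = padded[tuple(slice(lower, upper) for lower, upper in bounds)]
assert trimmed.shape == (4, 4)      # back to the originally requested shape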
25 changes: 13 additions & 12 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -443,18 +443,19 @@ def rewrite_node(node):
             return None
 
         if node.target == torch.ops.aten.fill.Scalar:
-            shape = tuple(node.meta["val"].size())
-            if has_valid_page_size(shape, strict=True):
-                new_kwargs = {
-                    "fill_value": args[1],
-                    "device": TtnnDevice(),
-                    "layout": TtnnTileLayout(),
-                }
-                return g.call_function(
-                    ttnn.full,
-                    args=(shape,),
-                    kwargs=new_kwargs,
-                )
+            shape = list(node.meta["val"].size())
+            if len(shape) < 2 or len(shape) > 4:
+                return None
+            aligned_shape = shape[:-2] + [
+                (dim + ttnn.TILE_SIZE - 1) // ttnn.TILE_SIZE * ttnn.TILE_SIZE for dim in shape[-2:]
+            ]
+            new_node = g.call_function(
+                ttnn.full,
+                args=(aligned_shape,),
+                kwargs={"fill_value": args[1], "device": TtnnDevice(), "layout": TtnnTileLayout()},
+            )
+            bounds = [(0, dim) for dim in shape]
+            return g.call_function(target_wrappers.getitem, args=(new_node, bounds))
 
         if node.target == torch.ops.aten.baddbmm.default:
             # out = beta * input + alpha * (batch1 @ batch2)
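In this branch, shapes that are not multiples of the 32x32 tile are first rounded up so ttnn.full gets a tile-aligned allocation, then sliced back to the requested extent through target_wrappers.getitem. A standalone sketch of the same arithmetic (plain Python; align_up is a made-up helper name, ttnn.TILE_SIZE assumed to be 32):

TILE_SIZE = 32  # ttnn.TILE_SIZE

def align_up(dim, tile=TILE_SIZE):
    # Round dim up to the next multiple of tile, e.g. 100 -> 128, 32 -> 32, 1 -> 32.
    return (dim + tile - 1) // tile * tile

shape = [1, 2, 100, 100]                                  # requested (non-aligned) fill shape
aligned_shape = shape[:-2] + [align_up(d) for d in shape[-2:]]
bounds = [(0, d) for d in shape]                          # slice back to the original extent

assert aligned_shape == [1, 2, 128, 128]                  # what ttnn.full actually allocates
assert bounds == [(0, 1), (0, 2), (0, 100), (0, 100)]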
