Commit

#0: Enforce tile layout when using bf4/bf8 data types (#16199)
### Ticket
N/A

### Problem description
Users can specify a row-major layout for tensors that use block-float formats (bf4/bf8), which is not supported.

### What's changed
Instead of silently overriding the requested layout to tile, throw an error.
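A minimal sketch of the enforced behavior, using the list-based `ttnn.Tensor` constructor that appears in the updated tests below; `ttnn.bfloat8_b` as the block-float dtype and the exact exception raised are assumptions, not taken from this diff:

```python
import torch
import ttnn

# Sketch only: block-float formats are tile-only, so requesting
# ROW_MAJOR_LAYOUT should now raise instead of being silently tilized.
shape = [1, 1, 32, 32]
data = torch.randn(shape).flatten().tolist()

# Supported: construct the host tensor directly in tile layout.
ok = ttnn.Tensor(data, shape, ttnn.bfloat8_b, ttnn.TILE_LAYOUT)

# Rejected after this change: row-major layout with a block-float dtype.
try:
    bad = ttnn.Tensor(data, shape, ttnn.bfloat8_b, ttnn.ROW_MAJOR_LAYOUT)
except Exception as err:  # the specific exception type is not shown in this commit
    print(f"row-major + bfloat8_b rejected: {err}")
```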

### Checklist
- [X] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12451630816)
- [X] [All T3K tests](https://github.com/tenstorrent/tt-metal/actions/runs/12439116600) - failures unrelated

---------

Co-authored-by: Oleg Milyutin <[email protected]>
omilyutin-tt and Oleg Milyutin authored Dec 23, 2024
1 parent 326f022 commit 87c5423
Showing 15 changed files with 336 additions and 518 deletions.
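The test updates shown below all follow one pattern: tensors are now constructed directly in `ttnn.TILE_LAYOUT` (on the device where applicable) instead of being created row-major and tilized afterwards, and bias tensors are pre-padded on the host with `torch.nn.functional.pad` rather than via `Tensor.pad`. A condensed before/after sketch; the helper names `make_tile_tensor` and `read_back` are illustrative, not part of the commit:

```python
import torch
import ttnn

def make_tile_tensor(x: torch.Tensor, dtype, device, mem_config):
    # Old style (removed in this commit): create row-major, then tilize and move.
    #   return ttnn.Tensor(x, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
    # New style: construct directly in tile layout on the device.
    return ttnn.Tensor(x, dtype, device, ttnn.TILE_LAYOUT, mem_config)

def read_back(t) -> torch.Tensor:
    # Old style: t.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
    return ttnn.to_torch(t)
```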
@@ -45,17 +45,8 @@ def generate_attn_mask(N, C, W, dev, offs, dtype, mem_config):
nc_tiles = [((top_row if i % 2 else neg_top_row) + zero_rows) for i in range(NC)]
nc_tiles_pt = torch.Tensor(nc_tiles).reshape(N, C, 32, W)
valtorch = torch.Tensor([(top_row if i % 2 else neg_top_row) for i in range(NC)]).reshape(N, C, 1, W)
-    val = (
-        ttnn.Tensor(
-            nc_tiles_pt,
-            dtype,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(
-            dev,
-            mem_config,
-        )
-    )
+    val = ttnn.Tensor(nc_tiles_pt, dtype, dev, ttnn.TILE_LAYOUT, mem_config)

# print("Attn mask=", valtorch)
return valtorch, val

@@ -70,17 +61,7 @@ def run_softmax_tests(dev, test_id, batch, dtype, in0_mem_config):
for N, C, H, W in test_dims:
x = torch.randn((N, C, H, W)) * 2.0 - 1.0

-        t0 = (
-            ttnn.Tensor(
-                x,
-                dtype,
-            )
-            .to(ttnn.TILE_LAYOUT)
-            .to(
-                dev,
-                in0_mem_config,
-            )
-        )
+        t0 = ttnn.Tensor(x, dtype, dev, ttnn.TILE_LAYOUT, in0_mem_config)

if test_id == 0:
logger.info("Running scale_mask_softmax")
@@ -95,7 +76,7 @@ def run_softmax_tests(dev, test_id, batch, dtype, in0_mem_config):
else:
assert False

-        tt_unt = t1_fused.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+        tt_unt = ttnn.to_torch(t1_fused)

passing = is_close(tt_unt, ref_sm, rtol=5e-2, atol=5e-2)
assert passing, "is_close check failed"
@@ -23,16 +23,12 @@ def run_bert_large_concatenate_heads_test(device, batch, dtype, in0_mem_config,

A = torch.randn(a_shape)

-    a_t = (
-        ttnn.Tensor(
-            A.flatten().tolist(),
-            a_shape,
-            dtype,
-            ttnn.ROW_MAJOR_LAYOUT,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(device, in0_mem_config)
-    )
+    a_t = ttnn.Tensor(
+        A.flatten().tolist(),
+        a_shape,
+        dtype,
+        ttnn.TILE_LAYOUT,
+    ).to(device, in0_mem_config)

out = ttnn.experimental.concatenate_heads(a_t, ttnn.CoreCoord(12, 9), memory_config=out_mem_config)

@@ -44,8 +40,7 @@ def run_bert_large_concatenate_heads_test(device, batch, dtype, in0_mem_config,
logger.debug(f"out: {out.memory_config().buffer_type} and {out.get_dtype()}")

assert out.shape.with_tile_padding() == [batch, 1, 384, 1024]
-    tt_host_rm_out = out.cpu().to(ttnn.ROW_MAJOR_LAYOUT)
-    pyt_got_back_rm_out = tt_host_rm_out.to_torch()
+    pyt_got_back_rm_out = ttnn.to_torch(out)

ref_out = torch.transpose(A, -3, -2).reshape([batch, 1, 384, 1024])
passing_pcc, output_pcc = comp_pcc(pyt_got_back_rm_out, ref_out, 0.99)
@@ -98,13 +93,13 @@ def test_bert_large_concatenate_heads_with_program_cache(device, use_program_cache):
run_bert_large_concatenate_heads_test(device, 9, dtype, mem_config, mem_config)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
-        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
+        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

mem_config = ttnn.L1_MEMORY_CONFIG
for _ in range(2):
run_bert_large_concatenate_heads_test(device, 9, dtype, mem_config, mem_config)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
-        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
+        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
@@ -45,41 +45,29 @@ def run_bert_large_ff1_matmul_test(
bias_pad_shape = [1, 1, 32, 4096]
A = torch.randn(a_shape)
B = torch.randn(b_shape) - 0.95
-    BIAS = torch.randint(-20, 20, bias_shape, dtype=torch.float)
+    bias = torch.randint(-20, 20, bias_shape, dtype=torch.float)
+    bias_padded = torch.nn.functional.pad(bias, (0, 0, 0, 32 - bias.size(2)))

-    a_t = (
-        ttnn.Tensor(
-            A.flatten().tolist(),
-            a_shape,
-            dtype,
-            ttnn.ROW_MAJOR_LAYOUT,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(device, in0_mem_config)
-    )
-    b_t = (
-        ttnn.Tensor(
-            B.flatten().tolist(),
-            b_shape,
-            dtype,
-            ttnn.ROW_MAJOR_LAYOUT,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(device, in1_mem_config)
-    )
+    a_t = ttnn.Tensor(
+        A.flatten().tolist(),
+        a_shape,
+        dtype,
+        ttnn.TILE_LAYOUT,
+    ).to(device, in0_mem_config)
+    b_t = ttnn.Tensor(
+        B.flatten().tolist(),
+        b_shape,
+        dtype,
+        ttnn.TILE_LAYOUT,
+    ).to(device, in1_mem_config)

if bias_mem_config is not None:
-        bias_t = (
-            ttnn.Tensor(
-                BIAS.flatten().tolist(),
-                bias_shape,
-                dtype,
-                ttnn.ROW_MAJOR_LAYOUT,
-            )
-            .pad(bias_pad_shape, [0, 0, 0, 0], 0)
-            .to(ttnn.TILE_LAYOUT)
-            .to(device, bias_mem_config)
-        )
+        bias_t = ttnn.Tensor(
+            bias_padded.flatten().tolist(),
+            bias_pad_shape,
+            dtype,
+            ttnn.TILE_LAYOUT,
+        ).to(device, bias_mem_config)
else:
bias_t = None

@@ -103,12 +91,11 @@ def run_bert_large_ff1_matmul_test(
logger.debug(f"out is on: {t2.memory_config().buffer_type}")

assert t2.shape.with_tile_padding() == [9, 1, 384, 4096]
-    tt_host_rm = t2.cpu().to(ttnn.ROW_MAJOR_LAYOUT)
-    pyt_got_back_rm = tt_host_rm.to_torch()
+    pyt_got_back_rm = ttnn.to_torch(t2)

ref_bmm = torch.matmul(A, B)
if bias_mem_config is not None:
-        ref_bmm = ref_bmm + BIAS
+        ref_bmm = ref_bmm + bias
if fused_activation is not None:
if fused_activation[0] == ttnn.UnaryOpType.GELU:
ref_bmm = torch.nn.functional.gelu(ref_bmm)
@@ -200,7 +187,7 @@ def test_bert_large_ff1_matmul_with_program_cache(device, use_program_cache):
)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
-        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
+        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

mem_config = ttnn.L1_MEMORY_CONFIG
for _ in range(2):
@@ -215,6 +202,6 @@ def test_bert_large_ff1_matmul_with_program_cache(device, use_program_cache):
)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
-        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
+        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
@@ -31,41 +31,29 @@ def run_bert_large_ff2_matmul_test(device, dtype, in0_mem_config, in1_mem_config

A = torch.randn(a_shape)
B = torch.randn(b_shape) - 0.95
-    BIAS = torch.randint(-20, 20, bias_shape, dtype=torch.float)
-
-    a_t = (
-        ttnn.Tensor(
-            A.flatten().tolist(),
-            a_shape,
-            dtype,
-            ttnn.ROW_MAJOR_LAYOUT,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(device, in0_mem_config)
-    )
-    b_t = (
-        ttnn.Tensor(
-            B.flatten().tolist(),
-            b_shape,
-            dtype,
-            ttnn.ROW_MAJOR_LAYOUT,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(device, in1_mem_config)
-    )
+    bias = torch.randint(-20, 20, bias_shape, dtype=torch.float)
+    bias_padded = torch.nn.functional.pad(bias, (0, 0, 0, 32 - bias.size(2)))
+
+    a_t = ttnn.Tensor(
+        A.flatten().tolist(),
+        a_shape,
+        dtype,
+        ttnn.TILE_LAYOUT,
+    ).to(device, in0_mem_config)
+    b_t = ttnn.Tensor(
+        B.flatten().tolist(),
+        b_shape,
+        dtype,
+        ttnn.TILE_LAYOUT,
+    ).to(device, in1_mem_config)

if bias_mem_config is not None:
-        bias_t = (
-            ttnn.Tensor(
-                BIAS.flatten().tolist(),
-                bias_shape,
-                dtype,
-                ttnn.ROW_MAJOR_LAYOUT,
-            )
-            .pad(bias_pad_shape, [0, 0, 0, 0], 0)
-            .to(ttnn.TILE_LAYOUT)
-            .to(device, bias_mem_config)
-        )
+        bias_t = ttnn.Tensor(
+            bias_padded.flatten().tolist(),
+            bias_pad_shape,
+            dtype,
+            ttnn.TILE_LAYOUT,
+        ).to(device, bias_mem_config)
else:
bias_t = None

@@ -84,12 +72,11 @@ def run_bert_large_ff2_matmul_test(device, dtype, in0_mem_config, in1_mem_config
logger.debug(f"out is on: {t2.memory_config().buffer_type}")

assert t2.shape.with_tile_padding() == [9, 1, 384, 1024]
-    tt_host_rm = t2.cpu().to(ttnn.ROW_MAJOR_LAYOUT)
-    pyt_got_back_rm = tt_host_rm.to_torch()
+    pyt_got_back_rm = ttnn.to_torch(t2)

ref_bmm = torch.matmul(A, B)
if bias_mem_config is not None:
-        ref_bmm = ref_bmm + BIAS
+        ref_bmm = ref_bmm + bias
passing_pcc, output_pcc = comp_pcc(ref_bmm, pyt_got_back_rm, 0.99)
logger.debug(f"Passing={passing_pcc}")
logger.debug(f"Output pcc={output_pcc}")
@@ -161,7 +148,7 @@ def test_bert_large_ff2_matmul_with_program_cache(device, use_program_cache):
)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
-        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
+        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

mem_config = ttnn.L1_MEMORY_CONFIG
for _ in range(2):
@@ -175,6 +162,6 @@ def test_bert_large_ff2_matmul_with_program_cache(device, use_program_cache):
)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
-        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
+        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
@@ -32,40 +32,28 @@ def run_bert_large_fused_qkv_matmul_test(

A = torch.randn(a_shape)
B = torch.randn(b_shape) - 0.95
-    BIAS = torch.randint(-20, 20, bias_shape, dtype=torch.float)
-
-    a_t = (
-        ttnn.Tensor(
-            A.flatten().tolist(),
-            a_shape,
-            dtype,
-            ttnn.ROW_MAJOR_LAYOUT,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(device, in0_mem_config)
-    )
-    b_t = (
-        ttnn.Tensor(
-            B.flatten().tolist(),
-            b_shape,
-            dtype,
-            ttnn.ROW_MAJOR_LAYOUT,
-        )
-        .to(ttnn.TILE_LAYOUT)
-        .to(device, in1_mem_config)
-    )
+    bias = torch.randint(-20, 20, bias_shape, dtype=torch.float)
+    bias_padded = torch.nn.functional.pad(bias, (0, 0, 0, 32 - bias.size(2)))
+
+    a_t = ttnn.Tensor(
+        A.flatten().tolist(),
+        a_shape,
+        dtype,
+        ttnn.TILE_LAYOUT,
+    ).to(device, in0_mem_config)
+    b_t = ttnn.Tensor(
+        B.flatten().tolist(),
+        b_shape,
+        dtype,
+        ttnn.TILE_LAYOUT,
+    ).to(device, in1_mem_config)
if bias_mem_config is not None:
-        bias_t = (
-            ttnn.Tensor(
-                BIAS.flatten().tolist(),
-                bias_shape,
-                dtype,
-                ttnn.ROW_MAJOR_LAYOUT,
-            )
-            .pad(bias_pad_shape, [0, 0, 0, 0], 0)
-            .to(ttnn.TILE_LAYOUT)
-            .to(device, bias_mem_config)
-        )
+        bias_t = ttnn.Tensor(
+            bias_padded.flatten().tolist(),
+            bias_pad_shape,
+            dtype,
+            ttnn.TILE_LAYOUT,
+        ).to(device, bias_mem_config)
else:
bias_t = None

@@ -84,12 +72,11 @@ def run_bert_large_fused_qkv_matmul_test(
logger.debug(f"out is on: {t2.memory_config().buffer_type}")

assert t2.shape.with_tile_padding() == [9, 1, 384, 3072]
-    tt_host_rm = t2.cpu().to(ttnn.ROW_MAJOR_LAYOUT)
-    pyt_got_back_rm = tt_host_rm.to_torch()
+    pyt_got_back_rm = ttnn.to_torch(t2)

ref_bmm = torch.matmul(A, B)
if bias_mem_config is not None:
-        ref_bmm = ref_bmm + BIAS
+        ref_bmm = ref_bmm + bias
passing_pcc, output_pcc = comp_pcc(ref_bmm, pyt_got_back_rm, 0.99)
logger.debug(f"Passing={passing_pcc}")
logger.debug(f"Output pcc={output_pcc}")
@@ -161,7 +148,7 @@ def test_bert_large_fused_qkv_matmul_with_program_cache(device, use_program_cache):
)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
-        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
+        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

mem_config = ttnn.L1_MEMORY_CONFIG
for _ in range(2):
@@ -175,6 +162,6 @@ def test_bert_large_fused_qkv_matmul_with_program_cache(device, use_program_cache):
)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
-        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
+        tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2