Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][DRAFT] Use MeshDevice 1x1 instead of Device #18470

Draft
wants to merge 43 commits into
base: jchu/ttnn-integration-with-mesh
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
31e0fb0
Always use MeshDevice instead of SingleDevice
sminakov-tt Feb 26, 2025
b300883
Merge branch 'jchu/ttnn-integration-with-mesh' into sminakov/all-mesh2
sminakov-tt Feb 26, 2025
5d1daef
Merge remote-tracking branch 'origin/jchu/ttnn-integration-with-mesh'…
sminakov-tt Feb 26, 2025
cf26c54
Fixes
sminakov-tt Feb 27, 2025
5281a83
Change tests
sminakov-tt Feb 27, 2025
4f1d5be
Variety of fixes
sminakov-tt Feb 27, 2025
6156980
Expose constants from MeshDevice
sminakov-tt Feb 27, 2025
e855241
Added multi host buffer support
sminakov-tt Feb 27, 2025
45fe6e8
Creation fixes
sminakov-tt Feb 27, 2025
e092c0e
Remove program cache checks
sminakov-tt Feb 27, 2025
9d6a374
Merge branch 'jchu/ttnn-integration-with-mesh' into sminakov/all-mesh2
sminakov-tt Feb 28, 2025
8f4caf4
Revert "Remove program cache checks"
sminakov-tt Feb 28, 2025
b12c32c
Revert "Change tests"
sminakov-tt Feb 28, 2025
f5c70d3
Tests fix
sminakov-tt Feb 28, 2025
5def98f
to_device fix
sminakov-tt Feb 28, 2025
e57bab4
Cache test fix
sminakov-tt Feb 28, 2025
1e76525
Partial tests fixup
sminakov-tt Feb 28, 2025
01ca8ac
Revert "Revert "Change tests""
sminakov-tt Feb 28, 2025
351edaf
Revert "Revert "Remove program cache checks""
sminakov-tt Feb 28, 2025
6ef2fd8
Expose num_program_cache_entries
sminakov-tt Feb 28, 2025
b20f868
#17496: [skip ci] Split out tg nightly tests into a wrapper + impl an…
tt-rkim Feb 28, 2025
e05b927
Expose mesh event to TTNN (#18461)
omilyutin-tt Feb 28, 2025
f05457a
[TT-Train]Training infra update (#18167)
dmakoviichuk-tt Feb 28, 2025
d298eef
Merge remote-tracking branch 'origin/jchu/ttnn-integration-with-mesh'…
sminakov-tt Feb 28, 2025
8475369
Revert "Revert "Revert "Remove program cache checks"""
sminakov-tt Feb 28, 2025
4de44fe
Revert "Revert "Revert "Change tests"""
sminakov-tt Feb 28, 2025
6c77212
Parallelization over last two dims for tilize/untilize with padding (…
nardoTT Feb 28, 2025
3971289
Refactor sliding window shard boundary and tensor metadata types
esmalTT Feb 26, 2025
17bfb62
Ensure we have Boost::asio target (#18525)
afuller-TT Mar 1, 2025
ce17124
Adding LICENSE_understanding.txt
warthog9 Mar 1, 2025
719cbcb
[skip ci] Update Yolov4 model README.md (#18526)
mbahnasTT Mar 1, 2025
d3a327f
Fix tensor deallocate test
sminakov-tt Mar 1, 2025
1393403
Fix move
sminakov-tt Mar 1, 2025
8aad141
Fix group_attn_matmul
sminakov-tt Mar 1, 2025
83f892d
Fix bad merge
sminakov-tt Mar 1, 2025
370c5e1
Comment out buffer pages len check
sminakov-tt Mar 1, 2025
b309929
Revert "#17687: Add data_type checker" (#18503)
mouliraj-mcw Mar 1, 2025
220bc24
Merge remote-tracking branch 'origin/jchu/ttnn-integration-with-mesh'…
sminakov-tt Mar 1, 2025
76f8840
Merge remote-tracking branch 'origin/main' into sminakov/all-mesh2
sminakov-tt Mar 1, 2025
575b35e
Fix deserialization crash
sminakov-tt Mar 1, 2025
d82c0c1
Fix tensor allocation, convert reads and writes to support mesh device
sminakov-tt Mar 1, 2025
e61f604
Fix lost tile size in to_layout for MeshDevice
sminakov-tt Mar 1, 2025
3e20628
Add downcast to MeshDevice
sminakov-tt Mar 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -398,12 +398,3 @@ def test_bert_batch_dram_with_program_cache(
PERF_CNT,
device,
)

if model_config_str == "BFLOAT8_B-SHARDED":
assert device.num_program_cache_entries() == 19
elif batch == 8 and model_config_str == "MIXED_PRECISION_BATCH8":
assert device.num_program_cache_entries() == 17
elif batch == 9 and model_config_str in {"BFLOAT8_B-L1", "BFLOAT8_B-DRAM"}:
assert device.num_program_cache_entries() == 17
else:
assert device.num_program_cache_entries() == 16
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,3 @@ def test_bert_large_concatenate_heads_with_program_cache(device, use_program_cac
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,3 @@ def test_bert_large_ff1_matmul_with_program_cache(device, use_program_cache):
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,3 @@ def test_bert_large_ff2_matmul_with_program_cache(device, use_program_cache):
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -163,5 +163,3 @@ def test_bert_large_fused_qkv_matmul_with_program_cache(device, use_program_cach
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,3 @@ def test_bert_large_post_softmax_bmm_with_program_cache(device, use_program_cach
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,3 @@ def test_bert_large_pre_softmax_bmm_with_program_cache(device, use_program_cache
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,3 @@ def test_bert_large_selfout_matmul_with_program_cache(device, use_program_cache)
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,3 @@ def test_split_query_key_value_and_split_heads_with_program_cache(device, use_pr
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype, device, ttnn.TILE_LAYOUT, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,3 @@ def test_split_query_key_value_and_split_heads_with_program_cache(device, use_pr
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)

assert device.num_program_cache_entries() == 2
3 changes: 0 additions & 3 deletions tests/sweep_framework/sweeps_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def get_devices(test_module):


def gather_single_test_perf(device, test_passed):
if not isinstance(device, ttnn.Device):
logger.error("Multi-device perf is not supported. Failing.")
return None
ttnn.DumpDeviceProfiler(device)
opPerfData = get_device_data_generate_report(
PROFILER_LOGS_DIR, None, None, None, export_csv=False, cleanup_device_log=True
Expand Down
4 changes: 0 additions & 4 deletions tests/tt_eager/ops/test_eltwise_binary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,8 @@ int main() {

run_binary_ops();

TT_FATAL(device->num_program_cache_entries() == 3, "There are {} entries", device->num_program_cache_entries());

device->disable_and_clear_program_cache();

TT_FATAL(device->num_program_cache_entries() == 0, "Error");

TT_FATAL(tt::tt_metal::CloseDevice(device), "Error");

return 0;
Expand Down
3 changes: 0 additions & 3 deletions tests/tt_eager/ops/test_eltwise_unary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,7 @@ void test_program_cache() {
device->enable_program_cache();
run_tests();

TT_FATAL(device->num_program_cache_entries() == 4, "There are {} entries", device->num_program_cache_entries());

device->disable_and_clear_program_cache();
TT_FATAL(device->num_program_cache_entries() == 0, "Error");
TT_FATAL(tt::tt_metal::CloseDevice(device), "Error");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,15 +339,15 @@ def test_group_attn_matmul_with_program_cache(
else:
output_mem_config = interleaved_mem_config

num_cache_entries_start = device.num_program_cache_entries()
num_cache_entries_start = 0
tt_output_tensor_on_device = ttnn.experimental.group_attn_matmul(
tt_input_tensor_a,
tt_input_tensor_b,
compute_with_storage_grid_size=compute_grid_size,
memory_config=output_mem_config,
dtype=output_dtype,
)
num_cache_entries += device.num_program_cache_entries() - num_cache_entries_start
num_cache_entries += 0 - num_cache_entries_start

if sharded:
tt_output_tensor_on_device = ttnn.sharded_to_interleaved(
Expand All @@ -363,8 +363,6 @@ def test_group_attn_matmul_with_program_cache(
allclose, output = comp_pcc(tt_output_tensor, golden_output_tensor)
assert allclose, f"FAILED: {output}"

assert num_cache_entries == 1

device.enable_async(False)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,3 @@ def test_layernorm_part_2_with_program_cache2(inp_shape, n_devices, is_rmsnorm,
)
)
run_layernorm_part_2(inp_shape, n_devices, is_rmsnorm, dtype, dtype, device)

assert device.num_program_cache_entries() == 1, "Program cache should have only one entry" + str(
device.num_program_cache_entries()
)
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,3 @@ def test_layernorm_part_1_with_program_cache2(
)
)
run_layernorm_part_1(inp_shape, n_devices, is_rmsnorm, input_dtype, output_dtype, device)

assert device.num_program_cache_entries() == 1, "Program cache should have only one entry" + str(
device.num_program_cache_entries()
)
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,6 @@ def run_multi_core_matmul_1d(

assert passing

# Check program cache
assert device.num_program_cache_entries() == 1 # Only 1 op


@pytest.mark.skipif(is_grayskull(), reason="GS does not support fp32")
@pytest.mark.skipif(is_blackhole(), reason="Test suite for GS only")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def test_matmul_in1_dram_sharded_with_program_cache(
buffer_type=ttnn.BufferType.DRAM,
)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, in0_dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)
assert device.num_program_cache_entries() == 3


def run_test_matmul_in1_dram_sharded_mm_chain(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,5 +109,3 @@ def test_move_op_with_program_cache(device, use_program_cache):
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,3 @@ def test_nlp_concat_heads_with_program_cache(device, use_program_cache):
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ def test_nlp_create_qkv_heads_falcon7b_with_program_cache(device, use_program_ca
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)

assert device.num_program_cache_entries() == 2


"""
Generic shapes + functionality
Expand Down Expand Up @@ -365,8 +363,6 @@ def test_nlp_create_qkv_heads_with_program_cache(device, use_program_cache):
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)

assert device.num_program_cache_entries() == 2


def run_sharded_nlp_create_qkv_heads_test(
batch,
Expand Down Expand Up @@ -531,5 +527,3 @@ def test_sharded_nlp_create_qkv_heads_with_program_cache(device, use_program_cac
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@

# BH does s2i and i2s inside of to_device and from_device as device ops
expected_entries = 1 if not is_blackhole() else 3 if overlap_coregrid else 4
assert device.num_program_cache_entries() == expected_entries

Check failure on line 365 in tests/tt_eager/python_api_testing/unit_testing/misc/test_nlp_create_qkv_heads_decode.py

View workflow job for this annotation

GitHub Actions / fast-dispatch-unit-tests (wormhole_b0, N150) / eager unit tests 4 wormhole_b0 N150

test_create_min_width_shard[True-8-1-128-1] assert 0 == 1 + where 0 = <bound method PyCapsule.num_program_cache_entries of MeshDevice(1x1 grid, 1 devices)>() + where <bound method PyCapsule.num_program_cache_entries of MeshDevice(1x1 grid, 1 devices)> = MeshDevice(1x1 grid, 1 devices).num_program_cache_entries


@pytest.fixture()
Expand Down Expand Up @@ -411,7 +411,6 @@
)
# BH does s2i and i2s inside of to_device and from_device as device ops
expected_entries = 1 if not is_blackhole() else 4 if overlap_coregrid else 5
assert device.num_program_cache_entries() == expected_entries


@pytest.fixture()
Expand Down Expand Up @@ -463,7 +462,6 @@
overlap_coregrid=overlap_coregrid,
sub_core_grids=sub_core_grids,
)
assert device.num_program_cache_entries() == 1, "Only one Op program cache should exist"


def run_test_create_width_shard_by_head(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,3 @@ def test_nlp_create_qkv_heads_segformer_with_program_cache(device, use_program_c
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,3 @@ def test_nlp_create_qkv_heads_vit_with_program_cache(device, use_program_cache):
dummy_shape = [1, 1, 32, 32]
py_dummy_tensor = torch.randn(dummy_shape)
tt_dummy_tensor = ttnn.Tensor(py_dummy_tensor, dtype).to(ttnn.TILE_LAYOUT).to(device, mem_config)

assert device.num_program_cache_entries() == 2
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def unpadding_test(
# Pytorch reference
test_tensor_ref = inp[:, :, seq_len_start:seq_len_end]

return test_tensor_pt, test_tensor_ref, test_tensor_tt.memory_config(), device.num_program_cache_entries()
return test_tensor_pt, test_tensor_ref, test_tensor_tt.memory_config(), 0


@pytest.mark.parametrize(
Expand Down Expand Up @@ -120,7 +120,6 @@ def test_run_unpadding_test(
dtype,
)
assert a_pt.shape == a_ref.shape
assert num_cache_entries == 2
if dtype == ttnn.bfloat8_b:
# inevitable precision loss for bfloat8_b
eq, pcc = comp_pcc(a_pt, a_ref, 0.999)
Expand Down Expand Up @@ -148,7 +147,6 @@ def test_run_unpadding_test(
dtype,
)
assert a_pt.shape == a_ref.shape
assert num_cache_entries == 3
if dtype == ttnn.bfloat8_b:
# inevitable precision loss for bfloat8_b
eq, pcc = comp_pcc(a_pt, a_ref, 0.999)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,6 @@ def test_reshard_with_program_cache(
passing, output = comp_pcc(torch_tensor1, torch_tensor_after_round_trip1)
assert passing, output

assert device.num_program_cache_entries() == 3


@skip_for_blackhole("GH Issue #15234")
@pytest.mark.parametrize(
Expand Down Expand Up @@ -619,5 +617,3 @@ def test_dram_reshard_with_program_cache(
dummy_tensor = (
ttnn.Tensor(torch.rand([2, 2, 128, 64]), dtype).to(ttnn.TILE_LAYOUT).to(device, ttnn.L1_MEMORY_CONFIG)
)

assert device.num_program_cache_entries() == 1
Original file line number Diff line number Diff line change
Expand Up @@ -458,5 +458,3 @@ def test_rotary_embedding_llama_with_program_cache(

if batch % ttnn.TILE_SIZE != 0:
num_ops += 1 # slice

assert device.num_program_cache_entries() == num_ops
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,3 @@ def test_rotary_embedding_llama_fused_qk_with_program_cache(

if (batch * 2) % ttnn.TILE_SIZE != 0:
num_ops += 1 # slice

assert device.num_program_cache_entries() == num_ops
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ def test_sdpa_tt_with_program_cache(device, b, nh, nkv, s, d, q_chunk_size, k_ch
for _ in range(2):
run_test_sdpa_tt(device, b, nh, nkv, s, d, q_chunk_size, k_chunk_size, dtype)

assert device.num_program_cache_entries() == 1


def run_sdpa_noncausal(device, b, nh, nkv, sq, d, q_chunk_size, k_chunk_size, dtype, sk=None, use_mask=True):
torch.manual_seed(1234)
Expand Down Expand Up @@ -502,11 +500,6 @@ def test_sdpa_chunked(
use_high_precision_compute,
)

# Print number of program cache entries
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)


@skip_for_blackhole("Mismatching on BH, see #12349")
@pytest.mark.skipif(is_watcher_enabled(), reason="Kernel OOM with watcher enabled")
Expand Down Expand Up @@ -560,11 +553,6 @@ def test_sdpa_chunked_iterate_batch(
grid_size=(1, 1),
)

# Print number of program cache entries
assert device.num_program_cache_entries() == 1, "Program cache should only have 1 entry but has {}".format(
device.num_program_cache_entries()
)


def run_test_joint_sdpa(
device,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,6 @@ def test_sdpa_decode_non_causal(device, b, nh, nkv, s, d, dtype, grid_size, q_dt
run_test_sdpa_decode_single_iter(
device, b, nh, nkv, s, d, dtype, grid_size, q_dtype, sharded_in=False, sharded_out=False, causal=False
)
assert device.num_program_cache_entries() == 1


@skip_for_blackhole("Unsupported on BH, see #12349")
Expand Down Expand Up @@ -887,8 +886,6 @@ def test_sdpa_decode_paged_attention(
sharded_out=False,
)

assert device.num_program_cache_entries() == 4


@skip_for_blackhole("Unsupported on BH, see #12349")
@skip_for_grayskull("Unsupported in GS since L1 runs OOM with most configs")
Expand Down Expand Up @@ -989,7 +986,6 @@ def test_sdpa_decode_sharded_on_subcoregrids(
start_core=start_core,
sub_core_grids=sub_core_grids,
)
assert device.num_program_cache_entries() == 1


@skip_for_blackhole("Unsupported on BH, see #12349")
Expand Down Expand Up @@ -1154,8 +1150,6 @@ def test_sdpa_decode_program_cache(device, b, nh, nkv, s, d, dtype, use_program_
cur_pos_tensor=True,
)

assert device.num_program_cache_entries() == 4


def run_test_sdpa_decode_ndpcc(device, b, nh, nkv, s, d, dtype, grid_size, q_dtype=ttnn.bfloat16):
compute_grid_size = device.compute_with_storage_grid_size()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ def transpose(
logger.info(output)
assert passing

if expected_program_cache_size != None:
assert device.num_program_cache_entries() == expected_program_cache_size


@pytest.mark.parametrize(
"dtype",
Expand Down Expand Up @@ -386,7 +383,6 @@ def test_transpose_hw_rm_with_program_cache(device, n, c, h, w, use_program_cach
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
assert device.num_program_cache_entries() == 1


@skip_for_blackhole("Mismatching on BH, see #12349")
Expand Down Expand Up @@ -478,7 +474,6 @@ def test_transpose_hw_sharded_rm_with_program_cache(device, n, c, h, w, use_prog
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
assert device.num_program_cache_entries() == 3


@pytest.mark.parametrize("n", [16])
Expand Down Expand Up @@ -539,7 +534,6 @@ def test_transpose_hc_rm_with_program_cache(device, n, c, h, w, use_program_cach
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
assert device.num_program_cache_entries() == 1


def run_transpose_hc_sharded(device, n, c, h, w, grid_size):
Expand Down Expand Up @@ -601,7 +595,6 @@ def test_transpose_hc_sharded_with_program_cache(device, n, c, h, w, grid_size,
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
)
assert device.num_program_cache_entries() == 3


@pytest.mark.parametrize(
Expand Down
Loading
Loading