From 1b2787f21d3f653e112216d3a998661b1f85a995 Mon Sep 17 00:00:00 2001 From: Jay Kruer Date: Thu, 28 Nov 2024 05:28:49 +0000 Subject: [PATCH] #0: Revert refactor to test_run_tilize_test --- .../unit_testing/misc/test_tilize_test.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_tilize_test.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_tilize_test.py index 3f28195ebbc1..c7924edaa324 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_tilize_test.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_tilize_test.py @@ -10,14 +10,14 @@ @pytest.mark.parametrize( - "shape", + "nb, nc, nh, nw", ( - [5, 2, 4, 8], - [5, 2, 4, 7], + (5, 2, 4, 8), + (5, 2, 4, 7), ## resnet shapes - [1, 1, 784, 2], - [8, 1, 2, 64], - [1, 1, 1, 64], + (1, 1, 784, 2), + (8, 1, 2, 64), + (1, 1, 1, 64), ), ) @pytest.mark.parametrize( @@ -27,9 +27,8 @@ True, ), ) -def test_run_tilize_test(shape, multicore, device): - shape[-1] *= 32 - shape[-2] *= 32 +def test_run_tilize_test(nb, nc, nh, nw, multicore, device): + shape = [nb, nc, nh * 32, nw * 32] inp = torch.rand(*shape).bfloat16()