Skip to content

Commit

Permalink
#13804: revise callback test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
mrshaw01 committed Oct 17, 2024
1 parent 8686dc4 commit 1ed86ea
Show file tree
Hide file tree
Showing 9 changed files with 101 additions and 62 deletions.
18 changes: 11 additions & 7 deletions tests/ttnn/unit_tests/operations/test_moreh_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
to_ttnn,
)


Expand Down Expand Up @@ -153,12 +154,15 @@ def forward(self, x):
([2, 2, 2, 2, 2, 2, 64, 64], 0.0, (0.9, 0.999), 1e-06, 0.0, False, False),
),
)
def test_moreh_adam_enable_cache(params, device, use_program_cache):
for i in range(4):
def test_moreh_adam_callback(params, device, use_program_cache):
torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
shape, lr, betas, eps, weight_decay, amsgrad, fp32_dest_acc_en = params
if i % 2 == 1:
amsgrad = not amsgrad

test_moreh_adam(shape, lr, betas, eps, weight_decay, amsgrad, fp32_dest_acc_en, device)

assert device.num_program_cache_entries() == 2
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
17 changes: 9 additions & 8 deletions tests/ttnn/unit_tests/operations/test_moreh_adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
to_ttnn,
)
from loguru import logger

Expand Down Expand Up @@ -200,16 +201,16 @@ def test_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device)
@pytest.mark.parametrize("amsgrad", [True, False])
@pytest.mark.parametrize("step", [8])
def test_moreh_adamw_callback(shape, lr, betas, eps, weight_decay, amsgrad, step, device, use_program_cache):
torch.manual_seed(0)
torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device)
# Add dummy tensor to make sure that created tensor in 2 iteration don't share the same addr
tt_dummy_tensor = ttnn.empty([1, 1, 32, 32], ttnn.bfloat16, ttnn.TILE_LAYOUT, device)
if i == 0:
num_program_cache_entries = device.num_program_cache_entries()
assert num_program_cache_entries > 0
else:
assert device.num_program_cache_entries() == num_program_cache_entries
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


@pytest.mark.parametrize(
Expand Down
18 changes: 5 additions & 13 deletions tests/ttnn/unit_tests/operations/test_moreh_arange.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,7 @@ def test_arange(start_end_step, optional_output, dtype, tilized, device):
@pytest.mark.parametrize(
"start_end_step",
[
[0, 32, 1],
[2.3, 15.3, 0.5],
[10.9, -13, -0.3],
[-100, 32 * 10, 1],
[0, 32000, 1],
[2300.3, 15300.3, 0.5392],
[10900.9, -13000, -0.3111],
[-10000, 32 * 10000, 1],
],
)
@pytest.mark.parametrize(
Expand All @@ -120,14 +113,13 @@ def test_arange(start_end_step, optional_output, dtype, tilized, device):
)
def test_arange_callback(start_end_step, optional_output, dtype, device, use_program_cache):
"""Test arange functionality with callback and program cache validation."""
torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(4):
if i % 2 == 0:
run_moreh_arange(start_end_step, optional_output, dtype, True, device)
else:
run_moreh_arange(start_end_step, optional_output, dtype, False, device)
for i in range(2):
run_moreh_arange(start_end_step, optional_output, dtype, True, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list == [1, 2, 2, 2]
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
15 changes: 12 additions & 3 deletions tests/ttnn/unit_tests/operations/test_moreh_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,15 @@ def run_getitem_RAW_MAJOR(shape_index_dim, dtype, index_size, device):
)
def test_getitem_RAW_MAJOR_callback(shape_index_dim, dtype, index_size, device, use_program_cache):
torch.manual_seed(2024)

for _ in range(2):
num_program_cache_entries_list = []
for i in range(2):
run_getitem_RAW_MAJOR(shape_index_dim, dtype, index_size, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


@skip_for_blackhole("Mismatching on Blackhole, see #12349")
Expand Down Expand Up @@ -816,7 +820,12 @@ def test_getitem_tilized_one_index_callback(
shape_index_dim, dtype, index_size, row_major_index, device, use_program_cache
):
torch.manual_seed(2024)
for _ in range(2):
num_program_cache_entries_list = []
for i in range(2):
run_moreh_geitem_tilized_one_index(shape_index_dim, dtype, index_size, row_major_index, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
16 changes: 14 additions & 2 deletions tests/ttnn/unit_tests/operations/test_moreh_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,16 @@ def test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, d
],
)
def test_moreh_group_norm_callback(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device, use_program_cache):
for _ in range(2):
torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
run_test_moreh_group_norm(N, C_num_groups, HW, eps, affine, compute_mean_rstd, device)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


def run_test_moreh_group_norm_backward(
Expand Down Expand Up @@ -525,9 +531,15 @@ def test_moreh_group_norm_backward_callback(
device,
use_program_cache,
):
for _ in range(2):
torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
run_test_moreh_group_norm_backward(
N, C_num_groups, HW, eps, affine, input_requires_grad, gamma_requires_grad, beta_requires_grad, device
)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]
24 changes: 18 additions & 6 deletions tests/ttnn/unit_tests/operations/test_moreh_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,10 +547,16 @@ def test_moreh_layer_norm_backward_compute_kernel_options(
],
)
def test_moreh_layer_norm_callback(input_shape_normalized_dims, elementwise_affine, eps, device, use_program_cache):
torch.manual_seed(2023)
for _ in range(2):
torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
run_moreh_layer_norm(input_shape_normalized_dims, elementwise_affine, eps, device)
assert device.num_program_cache_entries() == 1
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


@skip_for_grayskull("Using the transpose function in copy_tile causes a hang.")
Expand All @@ -569,10 +575,16 @@ def test_moreh_layer_norm_callback(input_shape_normalized_dims, elementwise_affi
def test_moreh_layer_norm_backward_callback(
input_shape_normalized_dims, elementwise_affine, eps, device, use_program_cache
):
torch.manual_seed(2023)
for _ in range(2):
torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
run_moreh_layer_norm_backward(input_shape_normalized_dims, elementwise_affine, eps, device)
assert device.num_program_cache_entries() == (2 if elementwise_affine else 1)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


@skip_for_grayskull("Using the transpose function in copy_tile causes a hang.")
Expand Down
27 changes: 18 additions & 9 deletions tests/ttnn/unit_tests/operations/test_moreh_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_compute_kernel_options,
compute_kernel_options,
compute_kernel_ids,
to_ttnn,
)


Expand Down Expand Up @@ -105,13 +106,17 @@ def test_moreh_linear_wo_output(shapes, has_bias, device):
),
)
def test_moreh_linear_enable_cache(shapes, device, use_program_cache):
torch.manual_seed(3072)
compute_kernel_config = get_compute_kernel_options(False)
torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
passing = moreh_linear(shapes, True, True, compute_kernel_config, device)
passing = moreh_linear(shapes, True, True, get_compute_kernel_options(False), device)
assert passing

assert device.num_program_cache_entries() == 1
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


def moreh_linear_backward(
Expand Down Expand Up @@ -237,18 +242,22 @@ def test_moreh_linear_backward(shapes, requires_grads, requires_bias_grad, compu
),
)
def test_moreh_linear_backward_enable_cache(shapes, device, use_program_cache):
torch.manual_seed(3072)
requires_input_grad, requires_weight_grad, requires_bias_grad = (True, True, True)
compute_kernel_config = get_compute_kernel_options(False)
num_program_cache_entries_list = []

torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
passing = moreh_linear_backward(
shapes, requires_input_grad, requires_weight_grad, requires_bias_grad, compute_kernel_config, device
)
num_program_cache_entries_list.append(device.num_program_cache_entries())
assert passing
assert len(set(num_program_cache_entries_list)) == 1
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


@skip_for_grayskull("GS does not support fp32")
Expand Down
26 changes: 12 additions & 14 deletions tests/ttnn/unit_tests/operations/test_moreh_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,17 +202,16 @@ def test_moreh_mean_compute_kernel_options(input_shape_dim, compute_kernel_optio
],
)
def test_moreh_mean_callback(input_shape_dim, device, use_program_cache):
torch.manual_seed(2023)

torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
run_moreh_mean(input_shape_dim, device, keepdim=True)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
if i == 0:
num_program_cache_entries = device.num_program_cache_entries()
assert num_program_cache_entries > 0
else:
assert device.num_program_cache_entries() == num_program_cache_entries
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -275,17 +274,16 @@ def test_moreh_mean_backward_compute_kernel_options(input_shape_dim, compute_ker
],
)
def test_moreh_mean_backward_callback(input_shape_dim, device, use_program_cache):
torch.manual_seed(2023)

torch.manual_seed(2024)
num_program_cache_entries_list = []
for i in range(2):
run_moreh_mean_backward(input_shape_dim, device, keepdim=True)
torch_dummy = torch.randn([32, 32])
tt_dummy = to_ttnn(torch_dummy, device=device)
if i == 0:
num_program_cache_entries = device.num_program_cache_entries()
assert num_program_cache_entries > 0
else:
assert device.num_program_cache_entries() == num_program_cache_entries
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


@pytest.mark.parametrize(
Expand Down
2 changes: 2 additions & 0 deletions tests/ttnn/unit_tests/operations/test_moreh_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def test_moreh_norm_callback(input_shape, p, dim_rtol_atol, keepdim, device, use
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]


Expand Down Expand Up @@ -500,4 +501,5 @@ def test_moreh_norm_backward_callback(input_shape, p, dim_rtol_atol, keepdim, de
tt_dummy = to_ttnn(torch_dummy, device=device)
num_program_cache_entries_list.append(device.num_program_cache_entries())
logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}")
assert num_program_cache_entries_list[0] > 0
assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1]

0 comments on commit 1ed86ea

Please sign in to comment.