diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py index 6109e014c4d..b15a41f944d 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_argmax_int.py @@ -7,6 +7,7 @@ import ttnn from loguru import logger from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs +from tests.ttnn.utils_for_testing import assert_with_pcc @pytest.mark.parametrize( @@ -42,23 +43,9 @@ def test_argmax(self, input_shapes, dim, memconfig, device): tt_output_tensor_on_device = ttnn.argmax(input_tensor, dim=dim) tt_out_tensor = tt_output_tensor_on_device.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() golden_tensor = torch.argmax(input_data, dim=dim) - if dim == 1 or dim == -3 or dim == 0 or dim == -4: - tt_out_tensor = tt_out_tensor[0, :, 0 : input_shapes[2], 0 : input_shapes[3]] - else: - if input_shapes[1] != 1 or input_shapes[0] != 1: - if dim == 2 or dim == -2: - tt_out_tensor = tt_out_tensor[0, :, :, 0 : input_shapes[3]] - else: - tt_out_tensor = tt_out_tensor[0, :, :, 0 : input_shapes[2]] - else: - if dim == 2 or dim == -2: - tt_out_tensor = tt_out_tensor[0, 0, 0, 0 : input_shapes[3]] - else: - tt_out_tensor = tt_out_tensor[0, 0, 0, 0 : input_shapes[2]] pt_out_tensor = golden_tensor - tt_out_tensor = tt_output_tensor_on_device.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch() - comp_pass, comp_out = comparison_funcs.comp_pcc(pt_out_tensor, tt_out_tensor, pcc=0.99) + assert_with_pcc(pt_out_tensor, tt_out_tensor) comp_all, _ = comparison_funcs.comp_allclose(pt_out_tensor, tt_out_tensor, atol=0, rtol=0) # DEBUG @@ -68,8 +55,5 @@ def test_argmax(self, input_shapes, dim, memconfig, device): # print(flat) # print(torch.topk(flat, 8)) - logger.info(comp_pass) logger.info(comp_all) - logger.info(comp_out) - status = comp_pass | comp_all - assert status + assert comp_all diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py index 94cce42f17a..acae7124847 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_min_max.py @@ -53,8 +53,6 @@ def test_min_max_for_dim_hw(device, use_program_cache, shape_dim, kind, layout): if kind == "max": value = x.max() elif kind == "min": - if N * C % 32 != 0: - pytest.skip("global min with Tensor dimension N*C not multiple of 32 is not supported at this time.") value = x.min() elif kind == "mean": value = x.mean() diff --git a/tests/ttnn/unit_tests/operations/test_max.py b/tests/ttnn/unit_tests/operations/test_max.py index 411fbd0ab44..b2378d023e0 100644 --- a/tests/ttnn/unit_tests/operations/test_max.py +++ b/tests/ttnn/unit_tests/operations/test_max.py @@ -8,7 +8,7 @@ import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc -from models.utility_functions import torch_random +from models.utility_functions import torch_random, is_grayskull @pytest.mark.parametrize("batch_size", [1, 16, 1, 16]) @@ -99,11 +99,24 @@ def test_max_global(device, batch_size, h, w): ((2, 32, 32, 64), -3), ((32, 32, 64), -3), ((1, 2, 3, 4), -1), + ((2, 22, 37, 41), -4), + ((2, 32, 64, 64), -3), + ((2, 22, 37, 41), -3), + ((2, 32, 64, 64), -2), + ((2, 22, 37, 41), -1), + ((2, 32, 64, 64), -1), + ((2, 22, 37), -3), + ((2, 22, 37), -2), + ((2, 22, 37), -1), + ((1, 6, 7), -3), + ((32, 6, 7), -3), ], ) @pytest.mark.parametrize("keepdim", [True, False]) def test_max_dim(device, input_shape_and_dim, keepdim): input_shape, max_dim = input_shape_and_dim + if is_grayskull() and (input_shape[-1] % 32 != 0 or input_shape[-2] % 32 != 0 or input_shape[max_dim] % 32 != 0): + pytest.skip("If not a tile size multiple, may fail on GS if run all the tests in this file. #17084") torch_input_tensor = torch_random(input_shape, -100, 100, dtype=torch.bfloat16) torch_output_tensor, _ = torch.max(torch_input_tensor, dim=max_dim, keepdim=keepdim) @@ -116,4 +129,5 @@ def test_max_dim(device, input_shape_and_dim, keepdim): output_tensor = ttnn.to_torch(output_tensor) - assert_with_pcc(torch_output_tensor, output_tensor) + pcc = 0.9999 + assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) diff --git a/tests/ttnn/unit_tests/operations/test_reduction_mean.py b/tests/ttnn/unit_tests/operations/test_reduction_mean.py index b9e8786ca38..e9146dc8e61 100644 --- a/tests/ttnn/unit_tests/operations/test_reduction_mean.py +++ b/tests/ttnn/unit_tests/operations/test_reduction_mean.py @@ -22,6 +22,7 @@ def test_mean(device, batch_size, h, w, dim): torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=True, dtype=torch.bfloat16) input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) + ttnn.fill_implicit_tile_padding(input_tensor, 42) # garbage padding to test that mean removes it output_tensor = ttnn.mean(input_tensor, dim=dim) output_tensor = ttnn.to_torch(output_tensor) diff --git a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp index f2e5b3321da..a1b2b24bc85 100644 --- a/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/argmax/device/argmax_op.cpp @@ -58,7 +58,7 @@ std::vector ArgMax::compute_output_specs( ttnn::SimpleShape output_shape({1, 1, 1, 1}); if (this->dim.has_value()) { auto input_shape = input_tensors[0].get_logical_shape(); - output_shape = ttnn::SimpleShape{input_shape[0], input_shape[1], 1, input_shape[2]}; + output_shape = ttnn::SimpleShape{input_shape[0], input_shape[1], input_shape[2]}; } return { TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(input_tensor.get_layout()), output_mem_config))}; diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp index 6daaa330829..aa83584adfc 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp @@ -50,6 +50,12 @@ ttnn::SmallVector generate_reduce_dim( return dim; } +float get_pad_value(ReduceType reduce_type) { + return reduce_type == ReduceType::Max + ? -std::numeric_limits::infinity() + : (reduce_type == ReduceType::Min ? std::numeric_limits::infinity() : 0); +} + template static Tensor reduce_impl( const Tensor& input_tensor_arg, @@ -79,6 +85,7 @@ static Tensor reduce_impl( auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg); Tensor output_tensor; + float pad_value = get_pad_value(reduce_type); bool single_reduce_op = (dim.size() == 1 && (dim[0] == rank - 1 || dim[0] == rank - 2)) || (dim.size() == 2 && dim[1] == rank - 1 && dim[0] == rank - 2); if (!single_reduce_op) { @@ -92,7 +99,7 @@ static Tensor reduce_impl( int adjusted_dim = offset + i_dim; int reduce_dim = adjusted_dim; if (transpose) { - output_tensor = ttnn::transpose(output_tensor, adjusted_dim, 2, memory_config); + output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config, pad_value); reduce_dim = 2; } if (use_reduce_type) { @@ -115,7 +122,7 @@ static Tensor reduce_impl( /*reshape=*/false); } if (transpose) { - output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config); + output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config, pad_value); } } } @@ -241,9 +248,7 @@ Tensor Reduce::invoke( const std::optional& compute_kernel_config, float scalar) { ttnn::SmallVector dim = generate_reduce_dim(input_tensor_arg, dim_arg); - float pad_value = reduce_type == ReduceType::Max - ? -std::numeric_limits::infinity() - : (reduce_type == ReduceType::Min ? std::numeric_limits::infinity() : 0); + float pad_value = get_pad_value(reduce_type); bool is_tiled = input_tensor_arg.get_layout() == TILE_LAYOUT; auto input_tensor = is_tiled ? ttnn::fill_implicit_tile_padding(input_tensor_arg, pad_value) : input_tensor_arg; if constexpr (reduce_type == ReduceType::Std || reduce_type == ReduceType::Var) { diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp index 137dba6f7ce..f592a557508 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.hpp @@ -34,6 +34,7 @@ struct Reduce { }; // Entry point for pool op, which uses non-standard tensors that cannot be padded. +[[deprecated]] Tensor pool_sum( const Tensor& input_tensor_arg, int dim_arg,