Skip to content

Commit

Permalink
#16720: and #14898 update output dims for argmax and move pad for gen…
Browse files Browse the repository at this point in the history
…eric reduce (#16989)

### Ticket
Link to Github Issues #16720 and #14898

### Problem description
- when specify a dim, output tensor of argmax is not one rank smaller
than input
- transpose seems to insert it's own padding, which occurs for the early
tensor dimensions

### What's changed
- for argmax, change shape of output tensor to have the right rank
- for generic reduce, move pad filling to right before the reduce op is
called
- also update tests and add deprecated to another specialized reduce
function

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12956428377
- [x] Blackhole Post commit (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12939182511
- [x] Model regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12939185477/job/36091291360
in line with main
https://github.com/tenstorrent/tt-metal/actions/runs/12937069874/job/36084729557
- [x] Device performance regression CI testing passes (if applicable)
https://github.com/tenstorrent/tt-metal/actions/runs/12939183976
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests passes
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
bbradelTT authored Jan 25, 2025
1 parent 8b2c6cd commit 83145d2
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ttnn
from loguru import logger
from tests.tt_eager.python_api_testing.sweep_tests import comparison_funcs
from tests.ttnn.utils_for_testing import assert_with_pcc


@pytest.mark.parametrize(
Expand Down Expand Up @@ -42,23 +43,9 @@ def test_argmax(self, input_shapes, dim, memconfig, device):
tt_output_tensor_on_device = ttnn.argmax(input_tensor, dim=dim)
tt_out_tensor = tt_output_tensor_on_device.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
golden_tensor = torch.argmax(input_data, dim=dim)
if dim == 1 or dim == -3 or dim == 0 or dim == -4:
tt_out_tensor = tt_out_tensor[0, :, 0 : input_shapes[2], 0 : input_shapes[3]]
else:
if input_shapes[1] != 1 or input_shapes[0] != 1:
if dim == 2 or dim == -2:
tt_out_tensor = tt_out_tensor[0, :, :, 0 : input_shapes[3]]
else:
tt_out_tensor = tt_out_tensor[0, :, :, 0 : input_shapes[2]]
else:
if dim == 2 or dim == -2:
tt_out_tensor = tt_out_tensor[0, 0, 0, 0 : input_shapes[3]]
else:
tt_out_tensor = tt_out_tensor[0, 0, 0, 0 : input_shapes[2]]

pt_out_tensor = golden_tensor
tt_out_tensor = tt_output_tensor_on_device.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
comp_pass, comp_out = comparison_funcs.comp_pcc(pt_out_tensor, tt_out_tensor, pcc=0.99)
assert_with_pcc(pt_out_tensor, tt_out_tensor)
comp_all, _ = comparison_funcs.comp_allclose(pt_out_tensor, tt_out_tensor, atol=0, rtol=0)

# DEBUG
Expand All @@ -68,8 +55,5 @@ def test_argmax(self, input_shapes, dim, memconfig, device):
# print(flat)
# print(torch.topk(flat, 8))

logger.info(comp_pass)
logger.info(comp_all)
logger.info(comp_out)
status = comp_pass | comp_all
assert status
assert comp_all
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ def test_min_max_for_dim_hw(device, use_program_cache, shape_dim, kind, layout):
if kind == "max":
value = x.max()
elif kind == "min":
if N * C % 32 != 0:
pytest.skip("global min with Tensor dimension N*C not multiple of 32 is not supported at this time.")
value = x.min()
elif kind == "mean":
value = x.mean()
Expand Down
18 changes: 16 additions & 2 deletions tests/ttnn/unit_tests/operations/test_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc
from models.utility_functions import torch_random
from models.utility_functions import torch_random, is_grayskull


@pytest.mark.parametrize("batch_size", [1, 16, 1, 16])
Expand Down Expand Up @@ -99,11 +99,24 @@ def test_max_global(device, batch_size, h, w):
((2, 32, 32, 64), -3),
((32, 32, 64), -3),
((1, 2, 3, 4), -1),
((2, 22, 37, 41), -4),
((2, 32, 64, 64), -3),
((2, 22, 37, 41), -3),
((2, 32, 64, 64), -2),
((2, 22, 37, 41), -1),
((2, 32, 64, 64), -1),
((2, 22, 37), -3),
((2, 22, 37), -2),
((2, 22, 37), -1),
((1, 6, 7), -3),
((32, 6, 7), -3),
],
)
@pytest.mark.parametrize("keepdim", [True, False])
def test_max_dim(device, input_shape_and_dim, keepdim):
input_shape, max_dim = input_shape_and_dim
if is_grayskull() and (input_shape[-1] % 32 != 0 or input_shape[-2] % 32 != 0 or input_shape[max_dim] % 32 != 0):
pytest.skip("If not a tile size multiple, may fail on GS if run all the tests in this file. #17084")

torch_input_tensor = torch_random(input_shape, -100, 100, dtype=torch.bfloat16)
torch_output_tensor, _ = torch.max(torch_input_tensor, dim=max_dim, keepdim=keepdim)
Expand All @@ -116,4 +129,5 @@ def test_max_dim(device, input_shape_and_dim, keepdim):

output_tensor = ttnn.to_torch(output_tensor)

assert_with_pcc(torch_output_tensor, output_tensor)
pcc = 0.9999
assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc)
1 change: 1 addition & 0 deletions tests/ttnn/unit_tests/operations/test_reduction_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_mean(device, batch_size, h, w, dim):
torch_output_tensor = torch.mean(torch_input_tensor, dim=dim, keepdim=True, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
ttnn.fill_implicit_tile_padding(input_tensor, 42) # garbage padding to test that mean removes it

output_tensor = ttnn.mean(input_tensor, dim=dim)
output_tensor = ttnn.to_torch(output_tensor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ std::vector<TensorSpec> ArgMax::compute_output_specs(
ttnn::SimpleShape output_shape({1, 1, 1, 1});
if (this->dim.has_value()) {
auto input_shape = input_tensors[0].get_logical_shape();
output_shape = ttnn::SimpleShape{input_shape[0], input_shape[1], 1, input_shape[2]};
output_shape = ttnn::SimpleShape{input_shape[0], input_shape[1], input_shape[2]};
}
return {
TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(input_tensor.get_layout()), output_mem_config))};
Expand Down
15 changes: 10 additions & 5 deletions ttnn/cpp/ttnn/operations/reduction/generic/generic_reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ ttnn::SmallVector<int> generate_reduce_dim(
return dim;
}

float get_pad_value(ReduceType reduce_type) {
return reduce_type == ReduceType::Max
? -std::numeric_limits<float>::infinity()
: (reduce_type == ReduceType::Min ? std::numeric_limits<float>::infinity() : 0);
}

template <ReduceType reduce_type>
static Tensor reduce_impl(
const Tensor& input_tensor_arg,
Expand Down Expand Up @@ -79,6 +85,7 @@ static Tensor reduce_impl(
auto input_tensor = ttnn::unsqueeze_to_4D(input_tensor_arg);

Tensor output_tensor;
float pad_value = get_pad_value(reduce_type);
bool single_reduce_op = (dim.size() == 1 && (dim[0] == rank - 1 || dim[0] == rank - 2)) ||
(dim.size() == 2 && dim[1] == rank - 1 && dim[0] == rank - 2);
if (!single_reduce_op) {
Expand All @@ -92,7 +99,7 @@ static Tensor reduce_impl(
int adjusted_dim = offset + i_dim;
int reduce_dim = adjusted_dim;
if (transpose) {
output_tensor = ttnn::transpose(output_tensor, adjusted_dim, 2, memory_config);
output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config, pad_value);
reduce_dim = 2;
}
if (use_reduce_type) {
Expand All @@ -115,7 +122,7 @@ static Tensor reduce_impl(
/*reshape=*/false);
}
if (transpose) {
output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config);
output_tensor = ttnn::transpose(output_tensor, adjusted_dim, -2, memory_config, pad_value);
}
}
}
Expand Down Expand Up @@ -241,9 +248,7 @@ Tensor Reduce<reduce_type>::invoke(
const std::optional<DeviceComputeKernelConfig>& compute_kernel_config,
float scalar) {
ttnn::SmallVector<int> dim = generate_reduce_dim(input_tensor_arg, dim_arg);
float pad_value = reduce_type == ReduceType::Max
? -std::numeric_limits<float>::infinity()
: (reduce_type == ReduceType::Min ? std::numeric_limits<float>::infinity() : 0);
float pad_value = get_pad_value(reduce_type);
bool is_tiled = input_tensor_arg.get_layout() == TILE_LAYOUT;
auto input_tensor = is_tiled ? ttnn::fill_implicit_tile_padding(input_tensor_arg, pad_value) : input_tensor_arg;
if constexpr (reduce_type == ReduceType::Std || reduce_type == ReduceType::Var) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct Reduce {
};

// Entry point for pool op, which uses non-standard tensors that cannot be padded.
[[deprecated]]
Tensor pool_sum(
const Tensor& input_tensor_arg,
int dim_arg,
Expand Down

0 comments on commit 83145d2

Please sign in to comment.