Skip to content

Commit

Permalink
#16679: K min values support for TopK (#16917)
Browse files Browse the repository at this point in the history
### Ticket
[Link to GitHub Issue](#16679)

### Problem description
TopK currently supports only max sorting, where the K largest values are
returned. We need to make the necessary changes to the LLKs to support
returning the K smallest values as well.

### What's changed
The LLKs were updated to pass down a flag specifying which behavior (largest
or smallest k values) is expected. The ckernel was updated to place min values
into the register instead of max values when the flag is set, returning the
k min values as a result.

### Checklist
- [x] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12932508914)
- [x] [Blackhole Post
commit](https://github.com/tenstorrent/tt-metal/actions/runs/12932523648)
(if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests pass
- [ ] New/Existing tests provide coverage for changes
  • Loading branch information
atatuzunerTT authored Jan 24, 2025
1 parent c645503 commit 080f063
Show file tree
Hide file tree
Showing 13 changed files with 54 additions and 36 deletions.
11 changes: 6 additions & 5 deletions tests/ttnn/unit_tests/operations/test_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
from models.utility_functions import skip_for_grayskull


def run_topk_test(N, C, H, W, k, dtype, device):
def run_topk_test(N, C, H, W, k, largest, dtype, device):
torch.manual_seed(2005)
shape = [N, C, H, W]
torch_dtype = torch.bfloat16

input = torch.randn(shape, dtype=torch_dtype)
pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=-1, largest=True, sorted=True)
pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=-1, largest=largest, sorted=True)

ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device)
ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=-1, largest=True, sorted=True)
ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=-1, largest=largest, sorted=True)

assert list(ttnn_topk_values.shape.with_tile_padding()) == [N, C, H, k]
assert list(ttnn_topk_indices.shape.with_tile_padding()) == [N, C, H, k]
Expand Down Expand Up @@ -73,5 +73,6 @@ def run_topk_test(N, C, H, W, k, dtype, device):
(1, 1, 8192, 64, 32),
),
)
def test_topk(N, C, H, W, k, dtype, device):
run_topk_test(N, C, H, W, k, dtype, device)
@pytest.mark.parametrize("largest", (True, False))
def test_topk(N, C, H, W, k, largest, dtype, device):
run_topk_test(N, C, H, W, k, largest, dtype, device)
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ inline void calculate_bitonic_topk_phases_steps(
idir, i_end_phase, i_start_phase, i_end_step, i_start_step);
}

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
template <bool APPROXIMATION_MODE, bool idir = false, int ITERATIONS = 8>
inline void calculate_bitonic_topk_merge(uint m_iter, uint k) {
_bitonic_topk_merge<APPROXIMATION_MODE, ITERATIONS>(m_iter, k);
_bitonic_topk_merge<APPROXIMATION_MODE, idir, ITERATIONS>(m_iter, k);
}

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ inline void llk_math_eltwise_unary_sfpu_topk_local_sort(
i_start_step);
}

template <bool APPROXIMATE>
template <bool APPROXIMATE, bool idir = false>
inline void llk_math_eltwise_unary_sfpu_topk_merge(
uint dst_index, int m_iter, int k, int vector_mode = (int)VectorMode::RC_custom) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_bitonic_topk_merge<APPROXIMATE>, dst_index, vector_mode, m_iter, k);
ckernel::sfpu::calculate_bitonic_topk_merge<APPROXIMATE, idir>, dst_index, vector_mode, m_iter, k);
}

template <bool APPROXIMATE>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ inline void calculate_bitonic_topk_phases_steps(
idir, i_end_phase, i_start_phase, i_end_step, i_start_step);
}

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
template <bool APPROXIMATION_MODE, bool idir = false, int ITERATIONS = 8>
inline void calculate_bitonic_topk_merge(uint m_iter, uint k) {
_bitonic_topk_merge<APPROXIMATION_MODE, ITERATIONS>(m_iter, k);
_bitonic_topk_merge<APPROXIMATION_MODE, idir, ITERATIONS>(m_iter, k);
}

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ inline void llk_math_eltwise_unary_sfpu_topk_local_sort(
i_start_step);
}

template <bool APPROXIMATE>
template <bool APPROXIMATE, bool idir = false>
inline void llk_math_eltwise_unary_sfpu_topk_merge(
uint dst_index, int m_iter, int k, int vector_mode = (int)VectorMode::RC_custom) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_bitonic_topk_merge<APPROXIMATE>, dst_index, vector_mode, m_iter, k);
ckernel::sfpu::calculate_bitonic_topk_merge<APPROXIMATE, idir>, dst_index, vector_mode, m_iter, k);
}

template <bool APPROXIMATE>
Expand Down
13 changes: 8 additions & 5 deletions tt_metal/include/compute_kernel_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -731,13 +731,16 @@ ALWI void topk_local_sort(
* Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be
* less than the size of the DST register buffer | True | | m_iter | The index of the merge & rebuild
* iteration of the algorithm | int32 | 0 to 9 | True |
* | k | The number of sorted values to return | int32 | {4, 8,
* 16, 32, 64} | True |
* less than the size of the DST register buffer | True |
* | idir | The sorting direction of the local sort (0 == decreasing, 1 == increasing) | bool | 0 to 1 | True |
* | m_iter | The index of the merge & rebuild iteration of the algorithm | int32 | 0 to 9 | True |
* | k | The number of sorted values to return | int32 | {4, 8, 16, 32, 64} | True |
*/
template <bool idir = false>
ALWI void topk_merge(uint32_t idst, int m_iter, int k) {
MATH((llk_math_eltwise_unary_sfpu_topk_merge<true>(idst, m_iter, k)));
MATH((llk_math_eltwise_unary_sfpu_topk_merge<true, idir>(idst, m_iter, k)));
}

// topK rebuild
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/third_party/tt_llk_blackhole
2 changes: 1 addition & 1 deletion tt_metal/third_party/tt_llk_wormhole_b0
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void MAIN {
constexpr uint32_t K = get_compile_time_arg_val(8);
constexpr uint32_t logk = get_compile_time_arg_val(9);
constexpr uint32_t logWt = get_compile_time_arg_val(10);
constexpr uint32_t largest = get_compile_time_arg_val(11);

// dest indices for where to unpack the tiles for the llk
// the input goes in index 0,1 and the index goes in index 2,3
Expand All @@ -40,7 +41,8 @@ void MAIN {
transpose_wh_init(input_cb_index, input_transposed_cb_index);

for (uint32_t ht = 0; ht < Ht; ++ht) {
bool ascending = false;
bool ascending = !largest;
// bool ascending = false;
cb_reserve_back(input_transposed_cb_index, Wt);
cb_reserve_back(index_transposed_cb_index, Wt);

Expand Down Expand Up @@ -87,7 +89,8 @@ void MAIN {
// pair. second iteration we compare 0th and 2nd tile, then 4th and 6th, etc. logWt iteration we compare 0th and
// Wt/2 tile single buffer as we can pack tiles back in-place
for (uint32_t m_iter = 0; m_iter < logWt; ++m_iter) {
bool a = false;
bool a = !largest;
// bool a = false;
cb_wait_front(input_transposed_cb_index, Wt);
cb_wait_front(index_transposed_cb_index, Wt);

Expand All @@ -105,7 +108,7 @@ void MAIN {
copy_tile(index_transposed_cb_index, right_ind, index_dest_end);

// merge values - move larger 32 values into 0th dest and lower 32 values into 1st dest
ckernel::topk_merge(0, m_iter, K);
ckernel::topk_merge<!largest>(0, m_iter, K);
// sort within the larger 32 values
ckernel::topk_rebuild(0, (uint32_t)a, m_iter, K, logk, true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void MAIN {
constexpr uint32_t Kt = get_compile_time_arg_val(9);
constexpr uint32_t logk = get_compile_time_arg_val(10);
constexpr uint32_t logWt = get_compile_time_arg_val(11);
constexpr uint32_t largest = get_compile_time_arg_val(12);

// dest indices for where to unpack the tiles for the llk
// the input goes in index 0,1 and the index goes in index 2,3
Expand Down Expand Up @@ -83,7 +84,7 @@ void MAIN {
// pair. second iteration we compare 0th and 2nd tile, then 4th and 6th, etc. logWt iteration we compare 0th and
// Wt/2 tile single buffer as we can pack tiles back in-place
for (uint32_t m_iter = 0; m_iter < logWt; ++m_iter) {
bool direction = false;
bool direction = !largest;
cb_wait_front(input_transposed_cb_index, Wt);
cb_wait_front(index_transposed_cb_index, Wt);
uint32_t stride = 1 << m_iter;
Expand All @@ -102,7 +103,7 @@ void MAIN {
copy_tile(index_transposed_cb_index, right_ind, index_dest_end);

// merge values - move larger 32 values into 0th dest and lower 32 values into 1st dest
ckernel::topk_merge(0, m_iter, K);
ckernel::topk_merge<!largest>(0, m_iter, K);
// sort within the larger 32 values
ckernel::topk_rebuild(0, (uint32_t)direction, m_iter, K, logk, true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ void MAIN {
constexpr uint32_t Kt = get_compile_time_arg_val(9);
constexpr uint32_t logk = get_compile_time_arg_val(10);
constexpr uint32_t logWt = get_compile_time_arg_val(11);
constexpr bool ascending = false;
constexpr uint32_t largest = get_compile_time_arg_val(12);

uint32_t direction_init = get_arg_val<uint32_t>(0);

Expand Down Expand Up @@ -65,7 +65,7 @@ void MAIN {
transpose_wh_tile(index_cb_index, 1, 3);

// llk_topk_sort -> inplace
ckernel::topk_local_sort(0, (int)ascending, logk - 1);
ckernel::topk_local_sort(0, (int)direction_init, logk - 1);

// pack value tiles into cb_intermed0
pack_reconfig_data_format(input_transposed_cb_index);
Expand Down Expand Up @@ -108,7 +108,7 @@ void MAIN {
copy_tile(index_transposed_cb_index, right_ind, index_dest_end);

// merge values - move larger 32 values into 0th dest and lower 32 values into 1st dest
ckernel::topk_merge(0, m_iter, K);
ckernel::topk_merge<!largest>(0, m_iter, K);
// sort within the larger 32 values
ckernel::topk_rebuild(0, (uint32_t)direction, m_iter, K, logk, true);

Expand Down
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ operation::ProgramWithCallbacks TopK::create_program(
const auto& input_tensor = input_tensors.at(0);
if (input_tensor.get_padded_shape()[dim] < topk_utils::multi_core_min_width) {
return detail::topk_single_core_interleaved(
input_tensor, this->k, this->dim, output_tensors.at(0), output_tensors.at(1));
input_tensor, this->k, this->dim, this->largest, output_tensors.at(0), output_tensors.at(1));
} else {
return detail::topk_multicore_interleaved(
input_tensor, this->k, this->dim, output_tensors.at(0), output_tensors.at(1));
input_tensor, this->k, this->dim, this->largest, output_tensors.at(0), output_tensors.at(1));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
namespace ttnn::operations::reduction::detail {

operation::ProgramWithCallbacks topk_single_core_interleaved(
const Tensor& input_tensor, const uint16_t k, const int8_t dim, Tensor& value_tensor, Tensor& index_tensor) {
const Tensor& input_tensor,
const uint16_t k,
const int8_t dim,
const bool largest,
Tensor& value_tensor,
Tensor& index_tensor) {
using namespace tt::constants;
tt::tt_metal::Program program{};
CoreRange core({0, 0}, {0, 0});
Expand Down Expand Up @@ -131,7 +136,7 @@ operation::ProgramWithCallbacks topk_single_core_interleaved(
k,
(std::uint32_t)std::log2(k),
(std::uint32_t)std::log2(Wt),
};
largest};
tt::tt_metal::KernelHandle topk_compute_kernel_id = tt::tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp",
Expand Down Expand Up @@ -203,7 +208,12 @@ static inline std::tuple<uint16_t, uint16_t, uint16_t, uint16_t> cores_utilized(
*
*/
operation::ProgramWithCallbacks topk_multicore_interleaved(
const Tensor& input_tensor, const uint16_t k, const int8_t dim, Tensor& value_tensor, Tensor& index_tensor) {
const Tensor& input_tensor,
const uint16_t k,
const int8_t dim,
const bool largest,
Tensor& value_tensor,
Tensor& index_tensor) {
using namespace tt::constants;
tt::tt_metal::Program program{};

Expand Down Expand Up @@ -401,7 +411,7 @@ operation::ProgramWithCallbacks topk_multicore_interleaved(
Kt,
(std::uint32_t)std::log2(k),
(std::uint32_t)std::log2(Wt_local),
};
largest};
tt::tt_metal::KernelHandle topk_compute_kernel_id = tt::tt_metal::CreateKernel(
program,
"ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp",
Expand All @@ -421,7 +431,7 @@ operation::ProgramWithCallbacks topk_multicore_interleaved(
Kt,
(std::uint32_t)std::log2(k),
(std::uint32_t)std::log2(Wt_final),
};
largest};

tt::tt_metal::KernelHandle topk_final_compute_kernel_id = tt::tt_metal::CreateKernel(
program,
Expand All @@ -430,7 +440,7 @@ operation::ProgramWithCallbacks topk_multicore_interleaved(
tt::tt_metal::ComputeConfig{.compile_args = compute_args_final});

for (uint32_t core_h = 0; core_h < 1; core_h++) {
bool ascending = false;
bool ascending = !largest;
for (uint32_t core_w = 0; core_w < num_cores - 1; core_w++) {
CoreCoord core = {core_h, core_w};
SetRuntimeArgs(
Expand Down

0 comments on commit 080f063

Please sign in to comment.