diff --git a/tests/ttnn/unit_tests/operations/test_topk.py b/tests/ttnn/unit_tests/operations/test_topk.py index 7f7ca2ba4a7..19a35425965 100644 --- a/tests/ttnn/unit_tests/operations/test_topk.py +++ b/tests/ttnn/unit_tests/operations/test_topk.py @@ -12,16 +12,16 @@ from models.utility_functions import skip_for_grayskull -def run_topk_test(N, C, H, W, k, dtype, device): +def run_topk_test(N, C, H, W, k, largest, dtype, device): torch.manual_seed(2005) shape = [N, C, H, W] torch_dtype = torch.bfloat16 input = torch.randn(shape, dtype=torch_dtype) - pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=-1, largest=True, sorted=True) + pyt_topk_values, pyt_topk_indices = torch.topk(input, k, dim=-1, largest=largest, sorted=True) ttnn_input = ttnn.from_torch(input, dtype, layout=ttnn.Layout.TILE, device=device) - ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=-1, largest=True, sorted=True) + ttnn_topk_values, ttnn_topk_indices = ttnn.topk(ttnn_input, k, dim=-1, largest=largest, sorted=True) assert list(ttnn_topk_values.shape.with_tile_padding()) == [N, C, H, k] assert list(ttnn_topk_indices.shape.with_tile_padding()) == [N, C, H, k] @@ -73,5 +73,6 @@ def run_topk_test(N, C, H, W, k, dtype, device): (1, 1, 8192, 64, 32), ), ) -def test_topk(N, C, H, W, k, dtype, device): - run_topk_test(N, C, H, W, k, dtype, device) +@pytest.mark.parametrize("largest", (True, False)) +def test_topk(N, C, H, W, k, largest, dtype, device): + run_topk_test(N, C, H, W, k, largest, dtype, device) diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_topk.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_topk.h index 41aaaa243b2..fef913c489e 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_topk.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_topk.h @@ -20,9 +20,9 @@ inline void calculate_bitonic_topk_phases_steps( idir, i_end_phase, i_start_phase, i_end_step, i_start_step); } -template +template inline void calculate_bitonic_topk_merge(uint m_iter, uint k) { - _bitonic_topk_merge(m_iter, k); + _bitonic_topk_merge(m_iter, k); } template diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_topk.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_topk.h index 8525bc7d466..dbda76bf480 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_topk.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_topk.h @@ -38,11 +38,11 @@ inline void llk_math_eltwise_unary_sfpu_topk_local_sort( i_start_step); } -template +template inline void llk_math_eltwise_unary_sfpu_topk_merge( uint dst_index, int m_iter, int k, int vector_mode = (int)VectorMode::RC_custom) { llk_math_eltwise_unary_sfpu_params( - ckernel::sfpu::calculate_bitonic_topk_merge, dst_index, vector_mode, m_iter, k); + ckernel::sfpu::calculate_bitonic_topk_merge, dst_index, vector_mode, m_iter, k); } template diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_topk.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_topk.h index 41aaaa243b2..fef913c489e 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_topk.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_topk.h @@ -20,9 +20,9 @@ inline void calculate_bitonic_topk_phases_steps( idir, i_end_phase, i_start_phase, i_end_step, i_start_step); } -template +template inline void calculate_bitonic_topk_merge(uint m_iter, uint k) { - _bitonic_topk_merge(m_iter, k); + _bitonic_topk_merge(m_iter, k); } template diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_topk.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_topk.h index 8525bc7d466..dbda76bf480 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_topk.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_topk.h @@ -38,11 +38,11 @@ inline void llk_math_eltwise_unary_sfpu_topk_local_sort( i_start_step); } -template +template inline void llk_math_eltwise_unary_sfpu_topk_merge( uint dst_index, int m_iter, int k, int vector_mode = (int)VectorMode::RC_custom) { llk_math_eltwise_unary_sfpu_params( - ckernel::sfpu::calculate_bitonic_topk_merge, dst_index, vector_mode, m_iter, k); + ckernel::sfpu::calculate_bitonic_topk_merge, dst_index, vector_mode, m_iter, k); } template diff --git a/tt_metal/include/compute_kernel_api.h b/tt_metal/include/compute_kernel_api.h index d642fed4b99..b81c24ab90d 100644 --- a/tt_metal/include/compute_kernel_api.h +++ b/tt_metal/include/compute_kernel_api.h @@ -731,13 +731,16 @@ ALWI void topk_local_sort( * Range | Required | * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| * | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be - * less than the size of the DST register buffer | True | | m_iter | The index of the merge & rebuild - * iteration of the algorithm | int32 | 0 to 9 | True | - * | k | The number of sorted values to return | int32 | {4, 8, - * 16, 32, 64} | True | + * less than the size of the DST register buffer | True | + * | idir | The sorting direction of the local sort (0 == decreasing, 1 == increasing) | bool | 0 to 1 + * | True | + * | m_iter | The index of the merge & rebuild iteration of the algorithm | int32 | 0 to 9 | + * True | | k | The number of sorted values to return | int32 | + * {4, 8, 16, 32, 64} | True | */ +template ALWI void topk_merge(uint32_t idst, int m_iter, int k) { - MATH((llk_math_eltwise_unary_sfpu_topk_merge(idst, m_iter, k))); + MATH((llk_math_eltwise_unary_sfpu_topk_merge(idst, m_iter, k))); } // topK rebuild diff --git a/tt_metal/third_party/tt_llk_blackhole b/tt_metal/third_party/tt_llk_blackhole index 80befe6a571..c6b6fb75d30 160000 --- a/tt_metal/third_party/tt_llk_blackhole +++ b/tt_metal/third_party/tt_llk_blackhole @@ -1 +1 @@ -Subproject commit 80befe6a57116d80d41794c7f823e83d204249a0 +Subproject commit c6b6fb75d3028c49ea8a7f4de5a24532737607c1 diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index a2645b5d639..1a637922ff0 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit a2645b5d639c03f54c2e84f4972462dd23606dca +Subproject commit 1a637922ff0779bb692aeb922c033716ea7b9401 diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp index 8fdc7e8e99e..89d0a7797cb 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp @@ -27,6 +27,7 @@ void MAIN { constexpr uint32_t K = get_compile_time_arg_val(8); constexpr uint32_t logk = get_compile_time_arg_val(9); constexpr uint32_t logWt = get_compile_time_arg_val(10); + constexpr uint32_t largest = get_compile_time_arg_val(11); // dest indices for where to unpack the tiles for the llk // the input goes in index 0,1 and the index goes in index 2,3 @@ -40,7 +41,8 @@ void MAIN { transpose_wh_init(input_cb_index, input_transposed_cb_index); for (uint32_t ht = 0; ht < Ht; ++ht) { - bool ascending = false; + bool ascending = !largest; + // bool ascending = false; cb_reserve_back(input_transposed_cb_index, Wt); cb_reserve_back(index_transposed_cb_index, Wt); @@ -87,7 +89,8 @@ void MAIN { // pair. second iteration we compare 0th and 2nd tile, then 4th and 6th, etc. logWt iteration we compare 0th and // Wt/2 tile single buffer as we can pack tiles back in-place for (uint32_t m_iter = 0; m_iter < logWt; ++m_iter) { - bool a = false; + bool a = !largest; + // bool a = false; cb_wait_front(input_transposed_cb_index, Wt); cb_wait_front(index_transposed_cb_index, Wt); @@ -105,7 +108,7 @@ void MAIN { copy_tile(index_transposed_cb_index, right_ind, index_dest_end); // merge values - move larger 32 values into 0th dest and lower 32 values into 1st dest - ckernel::topk_merge(0, m_iter, K); + ckernel::topk_merge(0, m_iter, K); // sort within the larger 32 values ckernel::topk_rebuild(0, (uint32_t)a, m_iter, K, logk, true); diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_final.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_final.cpp index 101662ddc1b..a272c9098d0 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_final.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_final.cpp @@ -29,6 +29,7 @@ void MAIN { constexpr uint32_t Kt = get_compile_time_arg_val(9); constexpr uint32_t logk = get_compile_time_arg_val(10); constexpr uint32_t logWt = get_compile_time_arg_val(11); + constexpr uint32_t largest = get_compile_time_arg_val(12); // dest indices for where to unpack the tiles for the llk // the input goes in index 0,1 and the index goes in index 2,3 @@ -83,7 +84,7 @@ void MAIN { // pair. second iteration we compare 0th and 2nd tile, then 4th and 6th, etc. logWt iteration we compare 0th and // Wt/2 tile single buffer as we can pack tiles back in-place for (uint32_t m_iter = 0; m_iter < logWt; ++m_iter) { - bool direction = false; + bool direction = !largest; cb_wait_front(input_transposed_cb_index, Wt); cb_wait_front(index_transposed_cb_index, Wt); uint32_t stride = 1 << m_iter; @@ -102,7 +103,7 @@ void MAIN { copy_tile(index_transposed_cb_index, right_ind, index_dest_end); // merge values - move larger 32 values into 0th dest and lower 32 values into 1st dest - ckernel::topk_merge(0, m_iter, K); + ckernel::topk_merge(0, m_iter, K); // sort within the larger 32 values ckernel::topk_rebuild(0, (uint32_t)direction, m_iter, K, logk, true); diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp index ab77072416d..9b11151c6ee 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp @@ -28,7 +28,7 @@ void MAIN { constexpr uint32_t Kt = get_compile_time_arg_val(9); constexpr uint32_t logk = get_compile_time_arg_val(10); constexpr uint32_t logWt = get_compile_time_arg_val(11); - constexpr bool ascending = false; + constexpr uint32_t largest = get_compile_time_arg_val(12); uint32_t direction_init = get_arg_val(0); @@ -65,7 +65,7 @@ void MAIN { transpose_wh_tile(index_cb_index, 1, 3); // llk_topk_sort -> inplace - ckernel::topk_local_sort(0, (int)ascending, logk - 1); + ckernel::topk_local_sort(0, (int)direction_init, logk - 1); // pack value tiles into cb_intermed0 pack_reconfig_data_format(input_transposed_cb_index); @@ -108,7 +108,7 @@ void MAIN { copy_tile(index_transposed_cb_index, right_ind, index_dest_end); // merge values - move larger 32 values into 0th dest and lower 32 values into 1st dest - ckernel::topk_merge(0, m_iter, K); + ckernel::topk_merge(0, m_iter, K); // sort within the larger 32 values ckernel::topk_rebuild(0, (uint32_t)direction, m_iter, K, logk, true); diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp index 1c1f4cfdc85..a3bc59e8124 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_op.cpp @@ -122,10 +122,10 @@ operation::ProgramWithCallbacks TopK::create_program( const auto& input_tensor = input_tensors.at(0); if (input_tensor.get_padded_shape()[dim] < topk_utils::multi_core_min_width) { return detail::topk_single_core_interleaved( - input_tensor, this->k, this->dim, output_tensors.at(0), output_tensors.at(1)); + input_tensor, this->k, this->dim, this->largest, output_tensors.at(0), output_tensors.at(1)); } else { return detail::topk_multicore_interleaved( - input_tensor, this->k, this->dim, output_tensors.at(0), output_tensors.at(1)); + input_tensor, this->k, this->dim, this->largest, output_tensors.at(0), output_tensors.at(1)); } } diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp index fa739cd6e30..d70adcb00fe 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/topk_program_factory.hpp @@ -11,7 +11,12 @@ namespace ttnn::operations::reduction::detail { operation::ProgramWithCallbacks topk_single_core_interleaved( - const Tensor& input_tensor, const uint16_t k, const int8_t dim, Tensor& value_tensor, Tensor& index_tensor) { + const Tensor& input_tensor, + const uint16_t k, + const int8_t dim, + const bool largest, + Tensor& value_tensor, + Tensor& index_tensor) { using namespace tt::constants; tt::tt_metal::Program program{}; CoreRange core({0, 0}, {0, 0}); @@ -131,7 +136,7 @@ operation::ProgramWithCallbacks topk_single_core_interleaved( k, (std::uint32_t)std::log2(k), (std::uint32_t)std::log2(Wt), - }; + largest}; tt::tt_metal::KernelHandle topk_compute_kernel_id = tt::tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp", @@ -203,7 +208,12 @@ static inline std::tuple cores_utilized( * */ operation::ProgramWithCallbacks topk_multicore_interleaved( - const Tensor& input_tensor, const uint16_t k, const int8_t dim, Tensor& value_tensor, Tensor& index_tensor) { + const Tensor& input_tensor, + const uint16_t k, + const int8_t dim, + const bool largest, + Tensor& value_tensor, + Tensor& index_tensor) { using namespace tt::constants; tt::tt_metal::Program program{}; @@ -401,7 +411,7 @@ operation::ProgramWithCallbacks topk_multicore_interleaved( Kt, (std::uint32_t)std::log2(k), (std::uint32_t)std::log2(Wt_local), - }; + largest}; tt::tt_metal::KernelHandle topk_compute_kernel_id = tt::tt_metal::CreateKernel( program, "ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp", @@ -421,7 +431,7 @@ operation::ProgramWithCallbacks topk_multicore_interleaved( Kt, (std::uint32_t)std::log2(k), (std::uint32_t)std::log2(Wt_final), - }; + largest}; tt::tt_metal::KernelHandle topk_final_compute_kernel_id = tt::tt_metal::CreateKernel( program, @@ -430,7 +440,7 @@ operation::ProgramWithCallbacks topk_multicore_interleaved( tt::tt_metal::ComputeConfig{.compile_args = compute_args_final}); for (uint32_t core_h = 0; core_h < 1; core_h++) { - bool ascending = false; + bool ascending = !largest; for (uint32_t core_w = 0; core_w < num_cores - 1; core_w++) { CoreCoord core = {core_h, core_w}; SetRuntimeArgs(