From 64e4badc6951d2bde32f4b69ebed4814af47e63d Mon Sep 17 00:00:00 2001 From: Pavle Milenkovic Date: Tue, 18 Feb 2025 13:48:43 +0100 Subject: [PATCH] #16174: Support for int32 subtraction for WHB0 and BH (#17359) ### Ticket #16174 ### Problem description Subtraction of int32 dtype was not supported on WHB0 and BH. ### What's changed Added necessary APIs, LLKs, and modified codepaths to include sub int32 operation. This operation was done through SFPU. ### Checklist - [x] Post commit CI passes - [x] Blackhole Post commit (if applicable) - [ ] Model regression CI testing passes (if applicable) - [ ] Device performance regression CI testing passes (if applicable) - [ ] **(For models and ops writers)** Full [new models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml) tests passes - [x] New/Existing tests provide coverage for changes --- .../operations/eltwise/test_binary_fp32.py | 20 ++++++++ .../llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h | 22 +++++++++ .../llk_math_eltwise_binary_sfpu_sub_int32.h | 27 +++++++++++ .../llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h | 22 +++++++++ .../llk_math_eltwise_binary_sfpu_sub_int32.h | 27 +++++++++++ .../compute_kernel_api/sub_int32_sfpu.h | 47 +++++++++++++++++++ tt_metal/third_party/tt_llk_blackhole | 2 +- tt_metal/third_party/tt_llk_wormhole_b0 | 2 +- .../eltwise/binary/common/binary_op_utils.cpp | 9 +++- .../binary/device/binary_device_operation.cpp | 3 +- .../compute/eltwise_binary_sfpu_kernel.cpp | 4 ++ 11 files changed, 180 insertions(+), 5 deletions(-) create mode 100644 tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h create mode 100644 tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_sub_int32.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_sub_int32.h create mode 100644 tt_metal/include/compute_kernel_api/sub_int32_sfpu.h diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py index 6c3c37fc7d5..eb73010e54f 100644 --- a/tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py @@ -93,6 +93,26 @@ def test_add_int32(device, ttnn_function): assert status +@skip_for_grayskull("Unsupported dtype for Grayskull") +@pytest.mark.parametrize( + "ttnn_function", + [ + ttnn.sub, + ], +) +def test_sub_int32(device, ttnn_function): + x_torch = torch.tensor([[11, 23, 0, -23, -1, -100]], dtype=torch.int32) + y_torch = torch.tensor([[78, 99, 34, -33, -1, 100]], dtype=torch.int32) + golden_fn = ttnn.get_golden_function(ttnn_function) + z_torch = golden_fn(x_torch, y_torch) + x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device) + y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device) + z_tt = ttnn.from_torch(z_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device) + z_tt_sub = ttnn.sub(x_tt, y_tt) + tt_out = ttnn.to_torch(z_tt_sub) + assert torch.allclose(z_torch, tt_out, atol=1e-10, rtol=1e-5, equal_nan=False) + + @skip_for_grayskull("Unsupported dtype for Grayskull") @pytest.mark.parametrize( "ttnn_function", diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h new file mode 100644 index 00000000000..154cf20122e --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_sub_int32(const uint dst_offset) { + _sub_int32_(dst_offset); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_sub_int32.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_sub_int32.h new file mode 100644 index 00000000000..4efe45a1c23 --- /dev/null +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_sub_int32.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_sub_int32.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_sub_int32_init() { + llk_math_eltwise_binary_sfpu_init(); +} + +template +inline void llk_math_eltwise_binary_sfpu_sub_int32( + uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_sub_int32, dst_index0, dst_index1, vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h new file mode 100644 index 00000000000..154cf20122e --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_sub_int32.h @@ -0,0 +1,22 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ckernel.h" +#include "ckernel_defs.h" +#include "sfpi.h" + +using namespace sfpi; + +namespace ckernel { +namespace sfpu { + +template +inline void calculate_sub_int32(const uint dst_offset) { + _sub_int32_(dst_offset); +} + +} // namespace sfpu +} // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_sub_int32.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_sub_int32.h new file mode 100644 index 00000000000..4efe45a1c23 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_sub_int32.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_binary_sfpu_init.h" +#include "llk_math_eltwise_binary_sfpu_params.h" +#include "ckernel_sfpu_sub_int32.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_binary_sfpu_sub_int32_init() { + llk_math_eltwise_binary_sfpu_init(); +} + +template +inline void llk_math_eltwise_binary_sfpu_sub_int32( + uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) { + llk_math_eltwise_binary_sfpu_params( + ckernel::sfpu::calculate_sub_int32, dst_index0, dst_index1, vector_mode); +} + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/sub_int32_sfpu.h b/tt_metal/include/compute_kernel_api/sub_int32_sfpu.h new file mode 100644 index 00000000000..ee3c9b998c7 --- /dev/null +++ b/tt_metal/include/compute_kernel_api/sub_int32_sfpu.h @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_binary_sfpu_sub_int32.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + +namespace ckernel { + +// clang-format off +/** + * Performs an elementwise sub operation with the two integer inputs: y = sub(x0,x1) + * Output overwrites first operand in DST. + * + * The DST register buffer must be in acquired state via *acquire_dst* call. This call is blocking and is only available + * on the compute engine. + * A maximum of 4 tiles from each operand can be loaded into DST at once, for a total of 8 tiles, + * when using 16 bit formats. This gets reduced to 2 tiles from each operand for 32 bit formats. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | Required | + * |-----------------------|-----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst0 | The index of the tile in DST register buffer to use as first operand | uint32_t | Must be less than the size of the DST register buffer | True | + * | idst1 | The index of the tile in DST register buffer to use as second operand | uint32_t | Must be less than the size of the DST register buffer | True | + * | sign_magnitude_format | Whether the Int32 values are in sign-magnitude format (not 2's complement) | bool | | False | + */ +// clang-format on +template +ALWI void sub_int32_tile(uint32_t idst0, uint32_t idst1) { + MATH((llk_math_eltwise_binary_sfpu_sub_int32(idst0, idst1))); +} + +/** + * Please refer to documentation for any_init. + */ +ALWI void sub_int32_tile_init() { MATH((llk_math_eltwise_binary_sfpu_sub_int32_init())); } + +} // namespace ckernel diff --git a/tt_metal/third_party/tt_llk_blackhole b/tt_metal/third_party/tt_llk_blackhole index 9fd3e2d93d1..76b5357a75b 160000 --- a/tt_metal/third_party/tt_llk_blackhole +++ b/tt_metal/third_party/tt_llk_blackhole @@ -1 +1 @@ -Subproject commit 9fd3e2d93d1532373f52e11e963de40c1cdf9a55 +Subproject commit 76b5357a75bfed7dac22a7b0417bb5589c2e0c5b diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index 0ec3177bfc2..a34e1966683 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit 0ec3177bfc262f7edf6cfc19531ecb8f669895d2 +Subproject commit a34e1966683c478d575d5ea79413004955c8a57f diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp index 153c99488ba..1b2d48bf618 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/common/binary_op_utils.cpp @@ -191,8 +191,13 @@ std::map get_defines_fp32( } break; case BinaryOpType::SUB: - new_defines.insert({"BINOP_INIT", fmt::format("sub_binary_tile_init();")}); - op_name = "sub_binary_tile"; + if (input_a_dtype == DataType::INT32 && input_b_dtype == DataType::INT32) { + new_defines.insert({"SUB_INT32_INIT", "sub_int32_tile_init();"}); + op_name = "sub_int32_tile"; + } else { + new_defines.insert({"BINOP_INIT", "sub_binary_tile_init();"}); + op_name = "sub_binary_tile"; + } break; case BinaryOpType::MUL: new_defines.insert({"BINOP_INIT", fmt::format("mul_binary_tile_init();")}); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp index a3c7d86cc81..094d5d2a0cc 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp @@ -17,8 +17,9 @@ namespace ttnn::operations::binary { namespace utils { bool is_binary_sfpu_op(BinaryOpType val, DataType a, DataType b) { switch (val) { - case BinaryOpType::ADD: return ((a == DataType::FLOAT32 && b == DataType::FLOAT32) || (a == DataType::INT32 && b == DataType::INT32)); + case BinaryOpType::ADD: case BinaryOpType::SUB: + return ((a == DataType::FLOAT32 && b == DataType::FLOAT32) || (a == DataType::INT32 && b == DataType::INT32)); case BinaryOpType::MUL: case BinaryOpType::DIV_FAST: case BinaryOpType::RSUB: diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_sfpu_kernel.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_sfpu_kernel.cpp index c083a354fae..032118851f7 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_sfpu_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/kernels/compute/eltwise_binary_sfpu_kernel.cpp @@ -13,6 +13,7 @@ #include "compute_kernel_api/binary_bitwise_sfpu.h" #include "compute_kernel_api/binary_shift.h" #include "compute_kernel_api/add_int32_sfpu.h" +#include "compute_kernel_api/sub_int32_sfpu.h" #define PRE_SCALE defined SFPU_OP_INIT_PRE_IN0_0 || defined SFPU_OP_INIT_PRE_IN1_0 @@ -113,6 +114,9 @@ void MAIN { #ifdef ADD_INT32_INIT ADD_INT32_INIT #endif +#ifdef SUB_INT32_INIT + SUB_INT32_INIT +#endif #ifdef BITWISE_INIT BITWISE_INIT #endif