Skip to content

Commit

Permalink
#16174: Support for int32 subtraction for WHB0 and BH (#17359)
Browse files Browse the repository at this point in the history
### Ticket
#16174 

### Problem description
Subtraction of int32 dtype was not supported on WHB0 and BH.

### What's changed
Added the necessary APIs and LLKs, and modified the relevant code paths to
support the int32 sub operation.
The operation is performed on the SFPU.

### Checklist
- [x] Post commit CI passes
- [x] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] **(For models and ops writers)** Full [new
models](https://github.com/tenstorrent/tt-metal/actions/workflows/full-new-models-suite.yaml)
tests passes
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
pmilenkovicTT authored Feb 18, 2025
1 parent 0ef76c0 commit 64e4bad
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 5 deletions.
20 changes: 20 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ def test_add_int32(device, ttnn_function):
assert status


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "ttnn_function",
    [
        ttnn.sub,
    ],
)
def test_sub_int32(device, ttnn_function):
    """Check that int32 subtraction on device matches the op's golden function.

    The inputs mix positive, zero and negative values so that results on both
    sides of zero are exercised.
    """
    x_torch = torch.tensor([[11, 23, 0, -23, -1, -100]], dtype=torch.int32)
    y_torch = torch.tensor([[78, 99, 34, -33, -1, 100]], dtype=torch.int32)
    golden_fn = ttnn.get_golden_function(ttnn_function)
    z_torch = golden_fn(x_torch, y_torch)
    x_tt = ttnn.from_torch(x_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
    y_tt = ttnn.from_torch(y_torch, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
    # Invoke the parametrized op (not a hard-coded ttnn.sub) so the op under
    # test always matches the golden function looked up above.
    z_tt_sub = ttnn_function(x_tt, y_tt)
    tt_out = ttnn.to_torch(z_tt_sub)
    # Integer results must match exactly; float tolerances are meaningless here.
    assert torch.equal(z_torch, tt_out)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
"ttnn_function",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "sfpi.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

// Entry point for the int32 subtraction SFPU routine.
// Forwards all three template arguments and dst_offset, unchanged, to
// _sub_int32_, which is not defined in the visible part of this header.
// ITERATIONS defaults to 8.
// NOTE(review): dst_offset presumably locates the second operand in DST
// relative to the first - confirm against the _sub_int32_ implementation.
template <bool APPROXIMATION_MODE, bool SIGN_MAGNITUDE_FORMAT, int ITERATIONS = 8>
inline void calculate_sub_int32(const uint dst_offset) {
_sub_int32_<APPROXIMATION_MODE, SIGN_MAGNITUDE_FORMAT, ITERATIONS>(dst_offset);
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_binary_sfpu_init.h"
#include "llk_math_eltwise_binary_sfpu_params.h"
#include "ckernel_sfpu_sub_int32.h"

namespace ckernel {

// New LLK SFPU APIs

// Init hook for the int32 subtraction binary SFPU op.
// Delegates to the generic llk_math_eltwise_binary_sfpu_init, passing
// SfpuType::unused together with the caller's APPROXIMATE flag.
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_sub_int32_init() {
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>();
}

// Runs the int32 subtraction SFPU routine through the generic binary SFPU
// parameter dispatcher.
//   dst_index0, dst_index1 - the two indices forwarded to
//                            llk_math_eltwise_binary_sfpu_params
//   vector_mode            - forwarded as-is; defaults to VectorMode::RC
// The callable handed to the dispatcher is
// calculate_sub_int32<APPROXIMATE, SIGN_MAGNITUDE_FORMAT>.
// NOTE(review): the two index parameters use different spellings (uint vs
// uint32_t) - consider unifying them.
template <bool APPROXIMATE, bool SIGN_MAGNITUDE_FORMAT>
inline void llk_math_eltwise_binary_sfpu_sub_int32(
uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) {
llk_math_eltwise_binary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_sub_int32<APPROXIMATE, SIGN_MAGNITUDE_FORMAT>, dst_index0, dst_index1, vector_mode);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "sfpi.h"

using namespace sfpi;

namespace ckernel {
namespace sfpu {

// Entry point for the int32 subtraction SFPU routine.
// Forwards all three template arguments and dst_offset, unchanged, to
// _sub_int32_, which is not defined in the visible part of this header.
// ITERATIONS defaults to 8.
// NOTE(review): dst_offset presumably locates the second operand in DST
// relative to the first - confirm against the _sub_int32_ implementation.
template <bool APPROXIMATION_MODE, bool SIGN_MAGNITUDE_FORMAT, int ITERATIONS = 8>
inline void calculate_sub_int32(const uint dst_offset) {
_sub_int32_<APPROXIMATION_MODE, SIGN_MAGNITUDE_FORMAT, ITERATIONS>(dst_offset);
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_binary_sfpu_init.h"
#include "llk_math_eltwise_binary_sfpu_params.h"
#include "ckernel_sfpu_sub_int32.h"

namespace ckernel {

// New LLK SFPU APIs

// Init hook for the int32 subtraction binary SFPU op.
// Delegates to the generic llk_math_eltwise_binary_sfpu_init, passing
// SfpuType::unused together with the caller's APPROXIMATE flag.
template <bool APPROXIMATE>
inline void llk_math_eltwise_binary_sfpu_sub_int32_init() {
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>();
}

// Runs the int32 subtraction SFPU routine through the generic binary SFPU
// parameter dispatcher.
//   dst_index0, dst_index1 - the two indices forwarded to
//                            llk_math_eltwise_binary_sfpu_params
//   vector_mode            - forwarded as-is; defaults to VectorMode::RC
// The callable handed to the dispatcher is
// calculate_sub_int32<APPROXIMATE, SIGN_MAGNITUDE_FORMAT>.
// NOTE(review): the two index parameters use different spellings (uint vs
// uint32_t) - consider unifying them.
template <bool APPROXIMATE, bool SIGN_MAGNITUDE_FORMAT>
inline void llk_math_eltwise_binary_sfpu_sub_int32(
uint dst_index0, uint32_t dst_index1, int vector_mode = VectorMode::RC) {
llk_math_eltwise_binary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_sub_int32<APPROXIMATE, SIGN_MAGNITUDE_FORMAT>, dst_index0, dst_index1, vector_mode);
}

} // namespace ckernel
47 changes: 47 additions & 0 deletions tt_metal/include/compute_kernel_api/sub_int32_sfpu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: © 2025 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_binary_sfpu_sub_int32.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif

namespace ckernel {

// clang-format off
/**
* Performs an elementwise sub operation with the two integer inputs: y = sub(x0,x1)
* Output overwrites first operand in DST.
*
* The DST register buffer must be in acquired state via *acquire_dst* call. This call is blocking and is only available
* on the compute engine.
* A maximum of 4 tiles from each operand can be loaded into DST at once, for a total of 8 tiles,
* when using 16 bit formats. This gets reduced to 2 tiles from each operand for 32 bit formats.
*
* The underlying LLK call is wrapped in the MATH macro, so it is only emitted when TRISC_MATH is defined;
* otherwise this function body is empty.
*
* Return value: None
*
* | Argument | Description | Type | Valid Range | Required |
* |-----------------------|-----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | idst0 | The index of the tile in DST register buffer to use as first operand | uint32_t | Must be less than the size of the DST register buffer | True |
* | idst1 | The index of the tile in DST register buffer to use as second operand | uint32_t | Must be less than the size of the DST register buffer | True |
* | sign_magnitude_format | Template parameter (defaults to false): whether the Int32 values are in sign-magnitude format (not 2's complement) | bool | | False |
*/
// clang-format on
template <bool sign_magnitude_format = false>
ALWI void sub_int32_tile(uint32_t idst0, uint32_t idst1) {
MATH((llk_math_eltwise_binary_sfpu_sub_int32<APPROX, sign_magnitude_format>(idst0, idst1)));
}

/**
* Init call for the int32 subtraction tile op. Please refer to documentation for any_init.
*
* Forwards to llk_math_eltwise_binary_sfpu_sub_int32_init<APPROX>() through the MATH macro, so the call is only
* emitted when TRISC_MATH is defined.
* NOTE(review): presumably must be called before sub_int32_tile - confirm against the any_init documentation.
*/
ALWI void sub_int32_tile_init() { MATH((llk_math_eltwise_binary_sfpu_sub_int32_init<APPROX>())); }

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,13 @@ std::map<std::string, std::string> get_defines_fp32(
}
break;
case BinaryOpType::SUB:
new_defines.insert({"BINOP_INIT", fmt::format("sub_binary_tile_init();")});
op_name = "sub_binary_tile";
if (input_a_dtype == DataType::INT32 && input_b_dtype == DataType::INT32) {
new_defines.insert({"SUB_INT32_INIT", "sub_int32_tile_init();"});
op_name = "sub_int32_tile";
} else {
new_defines.insert({"BINOP_INIT", "sub_binary_tile_init();"});
op_name = "sub_binary_tile";
}
break;
case BinaryOpType::MUL:
new_defines.insert({"BINOP_INIT", fmt::format("mul_binary_tile_init();")});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ namespace ttnn::operations::binary {
namespace utils {
bool is_binary_sfpu_op(BinaryOpType val, DataType a, DataType b) {
switch (val) {
case BinaryOpType::ADD: return ((a == DataType::FLOAT32 && b == DataType::FLOAT32) || (a == DataType::INT32 && b == DataType::INT32));
case BinaryOpType::ADD:
case BinaryOpType::SUB:
return ((a == DataType::FLOAT32 && b == DataType::FLOAT32) || (a == DataType::INT32 && b == DataType::INT32));
case BinaryOpType::MUL:
case BinaryOpType::DIV_FAST:
case BinaryOpType::RSUB:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "compute_kernel_api/binary_bitwise_sfpu.h"
#include "compute_kernel_api/binary_shift.h"
#include "compute_kernel_api/add_int32_sfpu.h"
#include "compute_kernel_api/sub_int32_sfpu.h"

#define PRE_SCALE defined SFPU_OP_INIT_PRE_IN0_0 || defined SFPU_OP_INIT_PRE_IN1_0

Expand Down Expand Up @@ -113,6 +114,9 @@ void MAIN {
#ifdef ADD_INT32_INIT
ADD_INT32_INIT
#endif
#ifdef SUB_INT32_INIT
SUB_INT32_INIT
#endif
#ifdef BITWISE_INIT
BITWISE_INIT
#endif
Expand Down

0 comments on commit 64e4bad

Please sign in to comment.