Skip to content

Commit

Permalink
#13310: Implement Fill op using eltwise unary sfpu LLK (#13680)
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW authored Oct 17, 2024
1 parent 912473a commit d1d8e9a
Show file tree
Hide file tree
Showing 22 changed files with 338 additions and 4 deletions.
73 changes: 73 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

import ttnn

from tests.ttnn.utils_for_testing import assert_equal
from models.utility_functions import skip_for_grayskull


@pytest.mark.parametrize(
    "input_shapes",
    [
        torch.Size([1, 1, 32, 32]),
        torch.Size([1, 1, 20, 31]),
        torch.Size([1, 1, 320, 384]),
        torch.Size([1, 3, 320, 384]),
        torch.Size([2, 4, 320, 1024]),
    ],
)
@pytest.mark.parametrize("fill_value", [1, 0, 5.5, -2.235])
def test_fill(device, input_shapes, fill_value):
    """Check ttnn.fill on a bfloat16 tile-layout tensor against torch.full."""
    # The input is random on purpose: the expected result is built from
    # fill_value alone, so the comparison does not depend on these contents.
    host_input = torch.randn(input_shapes, dtype=torch.bfloat16)
    expected = torch.full(input_shapes, fill_value, dtype=torch.bfloat16)

    device_input = ttnn.from_torch(host_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
    actual = ttnn.to_torch(ttnn.fill(device_input, fill_value))

    assert_equal(expected, actual)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "input_shapes",
    [
        torch.Size([1, 1, 32, 32]),
        torch.Size([1, 1, 20, 31]),
        torch.Size([1, 1, 320, 384]),
    ],
)
@pytest.mark.parametrize("fill_value", [7, -9])
def test_fill_int(device, input_shapes, fill_value):
    """Check ttnn.fill on an int32 tile-layout tensor against torch.full."""
    # Random bfloat16 samples cast to int32 serve only as placeholder input;
    # the expected tensor is built from fill_value alone.
    host_input = torch.randn(input_shapes, dtype=torch.bfloat16).to(torch.int32)
    expected = torch.full(input_shapes, fill_value, dtype=torch.int32)

    device_input = ttnn.from_torch(host_input, dtype=ttnn.int32, layout=ttnn.TILE_LAYOUT, device=device)
    actual = ttnn.to_torch(ttnn.fill(device_input, fill_value))

    assert_equal(expected, actual)


@skip_for_grayskull("Unsupported dtype for Grayskull")
@pytest.mark.parametrize(
    "input_shapes",
    [
        torch.Size([1, 1, 32, 32]),
        torch.Size([1, 1, 20, 31]),
        torch.Size([1, 1, 320, 384]),
    ],
)
@pytest.mark.parametrize("fill_value", [5.88, -9.76])
def test_fill_fp32(device, input_shapes, fill_value):
    """Check ttnn.fill on a float32 tile-layout tensor against torch.full."""
    # Placeholder input; the expected tensor is built from fill_value alone.
    host_input = torch.randn(input_shapes, dtype=torch.float32)
    expected = torch.full(input_shapes, fill_value, dtype=torch.float32)

    device_input = ttnn.from_torch(host_input, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
    actual = ttnn.to_torch(ttnn.fill(device_input, fill_value))

    assert_equal(expected, actual)
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
#include "llk_math_eltwise_unary_sfpu_topk.h"
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_fill(const uint value) {

// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat fill_val = c_value.f;

#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++)
{
dst_reg[0] = fill_val;
dst_reg++;
}
}
} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

// Init entry point for the fill op: forwards to the generic
// llk_math_eltwise_unary_sfpu_init, tagged with SfpuType::fill.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_fill_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::fill, APPROXIMATE>();
}

// Dispatches ckernel::sfpu::calculate_fill through the shared
// llk_math_eltwise_unary_sfpu_params helper.
//   dst_index   - passed to the params helper as-is.
//   param0      - passed on as calculate_fill's `value` argument, which that
//                 kernel reinterprets as a float via Converter.
//   vector_mode - passed to the params helper; defaults to VectorMode::RC.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_fill(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_fill<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,6 @@ enum SfpuType {
fmod,
ceil,
unused,
cumsum
cumsum,
fill
};
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "llk_math_eltwise_unary_sfpu_topk.h"
#include "llk_math_eltwise_unary_sfpu_trigonometry.h"
#include "llk_math_eltwise_unary_sfpu_unary_comp.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"

namespace ckernel {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 4>
inline void calculate_fill(const uint value) {

// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat fill_val = c_value.f;

#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++)
{
dst_reg[0] = fill_val;
dst_reg++;
}
}
} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

// Init entry point for the fill op: forwards to the generic
// llk_math_eltwise_unary_sfpu_init, tagged with SfpuType::fill.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_fill_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::fill, APPROXIMATE>();
}

// Dispatches ckernel::sfpu::calculate_fill through the shared
// llk_math_eltwise_unary_sfpu_params helper.
//   dst_index   - passed to the params helper as-is.
//   param0      - passed on as calculate_fill's `value` argument, which that
//                 kernel reinterprets as a float via Converter.
//   vector_mode - passed to the params helper; defaults to VectorMode::RC.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_fill(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_fill<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@ enum SfpuType {
unary_lt,
tiled_prod,
unused,
fill,
};
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,4 @@
#include "llk_math_eltwise_unary_sfpu_bitwise_or.h"
#include "llk_math_eltwise_unary_sfpu_right_shift.h"
#include "llk_math_eltwise_unary_sfpu_left_shift.h"
#include "llk_math_eltwise_unary_sfpu_fill.h"
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel.h"
#include "ckernel_defs.h"
#include "noc_nonblocking_api.h"
#include "ckernel_sfpu_converter.h"


using namespace sfpi;

namespace ckernel {
namespace sfpu {

template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
inline void calculate_fill(const uint value) {

// SFPU microcode
Converter c_value;
c_value.u = value;
vFloat fill_val = c_value.f;

#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++)
{
dst_reg[0] = fill_val;
dst_reg++;
}
}
} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "ckernel_sfpu_fill.h"
#include "llk_math_eltwise_unary_sfpu_params.h"
#include "llk_math_eltwise_unary_sfpu_init.h"

namespace ckernel {

// New LLK SFPU APIs

// Init entry point for the fill op: forwards to the generic
// llk_math_eltwise_unary_sfpu_init, tagged with SfpuType::fill.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_fill_init() {
llk_math_eltwise_unary_sfpu_init<SfpuType::fill, APPROXIMATE>();
}

// Dispatches ckernel::sfpu::calculate_fill through the shared
// llk_math_eltwise_unary_sfpu_params helper.
//   dst_index   - passed to the params helper as-is.
//   param0      - passed on as calculate_fill's `value` argument, which that
//                 kernel reinterprets as a float via Converter.
//   vector_mode - passed to the params helper; defaults to VectorMode::RC.
template <bool APPROXIMATE>
inline void llk_math_eltwise_unary_sfpu_fill(uint dst_index, uint param0, int vector_mode = (int)VectorMode::RC) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_fill<APPROXIMATE>,
dst_index,
vector_mode,
param0);
}

} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,6 @@ enum SfpuType {
ceil,
unused,
reshuffle_rows,
cumsum
cumsum,
fill
};
43 changes: 43 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/fill.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once


#include "compute_kernel_api/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_unary_sfpu_fill.h"
#define MAIN math_main()
#define MATH(x) x
#else
#define MATH(x)
#endif



namespace ckernel {

/**
* Performs an element-wise fill operation: the value given in param0 is written to the tile. The DST register buffer must be in
* acquired state via *acquire_dst* call. This call is blocking and is only
* available on the compute engine.
*
* param0 carries the bits of a float, not an integer count: the SFPU kernel (calculate_fill) reinterprets it as a
* float via Converter before storing it.
*
* Return value: None
*
* | Argument | Description | Type | Valid Range | Required |
* |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------|
* | idst | The index of the tile in DST register buffer to perform the computation on | uint32_t | Must be less than the size of the DST register buffer | True |
* | param0 | The fill value, passed as the raw bit pattern of a float | uint32_t | | True |
*/
ALWI void fill_tile(uint32_t idst, uint32_t param0) {
MATH((llk_math_eltwise_unary_sfpu_fill<APPROX>(idst, param0)));
}

/**
* Init counterpart of fill_tile; forwards to llk_math_eltwise_unary_sfpu_fill_init.
* Please refer to documentation for any_init.
*/
ALWI void fill_tile_init() { MATH((llk_math_eltwise_unary_sfpu_fill_init<APPROX>())); }


} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@
#include "compute_kernel_api/eltwise_unary/dropout.h"
#endif

#if SFPU_OP_FILL_INCLUDE
#include "compute_kernel_api/eltwise_unary/fill.h"
#endif

#if SFPU_OP_COMPUTE_KERNEL_API_INCLUDE
#include "compute_kernel_api.h"
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ enum class UnaryOpType {
LEFT_SHIFT,
REMAINDER,
FMOD,
DROPOUT
DROPOUT,
FILL
};

struct UnaryWithParam {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,21 @@ void update_macro_defines(UnaryOpType op_type, std::map<std::string, std::string
case UnaryOpType::REMAINDER: defines["SFPU_OP_REMAINDER_INCLUDE"] = "1"; break;
case UnaryOpType::FMOD: defines["SFPU_OP_FMOD_INCLUDE"] = "1"; break;
case UnaryOpType::DROPOUT: defines["SFPU_OP_DROPOUT_INCLUDE"] = "1"; break;
case UnaryOpType::FILL: defines["SFPU_OP_FILL_INCLUDE"] = "1"; break;
default: defines["SFPU_OP_COMPUTE_KERNEL_API_INCLUDE"] = "1"; break;
};
}

std::pair<std::string, std::string> get_op_init_and_func_parameterized(
UnaryOpType op_type, const std::vector<float>& params, const std::string& idst) {
std::pair<std::string, std::string> op_init_and_name;
TT_FATAL(is_parametrized_type(op_type) && "operator should support at least one parameter", "Error");
TT_FATAL(is_parametrized_type(op_type), "operator should support at least one parameter", "Error");
float param0 = params[0];
switch (op_type) {
case UnaryOpType::FILL:
op_init_and_name = {
"fill_tile_init();", fmt::format("fill_tile({}, {}u);", idst, Converter::to_hex(param0))};
break;
case UnaryOpType::RELU_MAX:
op_init_and_name = {
"relu_max_tile_init();", fmt::format("relu_max_tile({}, {}u);", idst, Converter::to_hex(param0))};
Expand Down
Loading

0 comments on commit d1d8e9a

Please sign in to comment.