From 388d56e645ef5297285387469b24818e98a05885 Mon Sep 17 00:00:00 2001
From: Virdhatchani Narayanamoorthy
<138196495+VirdhatchaniKN@users.noreply.github.com>
Date: Sun, 24 Nov 2024 19:14:08 +0530
Subject: [PATCH] #14982: Update threshold logic (#15362)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
### Ticket
#14982
### Problem description
The composite `threshold` op produced incorrect results when an input element is equal to the threshold value; per its definition `threshold(a, t, v) = (a <= t)*v + (a > t)*a`, such elements must map to `v`. The documented supported data types and layouts were also out of date.
### What's changed
- Updated the threshold logic to handle the case where an `input_tensor` element equals the `threshold` value: the composite is now a single `ttnn::where` on the sign of `input_tensor - threshold` (see the sketch below)
- Updated the documentation with the supported data types and layouts
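Because `threshold(a, t, v) = (a <= t)*v + (a > t)*a`, an element sitting exactly on the threshold must be replaced by `value`, which is also what `torch.nn.functional.threshold` does. A minimal sketch of the expected behaviour (the explicit `ttnn.open_device` call stands in for the pytest `device` fixture the tests below use, and the tensor values are illustrative):

```python
import torch
import ttnn

# Stand-in for the pytest `device` fixture used by the unit tests.
device = ttnn.open_device(device_id=0)

# 2.0 lies exactly on the threshold, so it must map to `value` (10.0),
# not pass through: threshold(a, t, v) = (a <= t)*v + (a > t)*a.
torch_input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.bfloat16)
golden = torch.nn.functional.threshold(torch_input, threshold=2.0, value=10.0)

input_tensor = ttnn.from_torch(
    torch_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
)
output = ttnn.to_torch(ttnn.threshold(input_tensor, 2.0, 10.0))

assert torch.equal(output.to(torch.bfloat16), golden)
ttnn.close_device(device)
```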
Tests:
- `pytest tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_threshold.py`
- `pytest tests/ttnn/unit_tests/operations/eltwise/test_activation.py::test_threshold`
- `pytest tests/ttnn/unit_tests/operations/eltwise/test_composite.py::test_unary_composite_threshold_ttnn`
- `python tests/ttnn/sweep_tests/run_sweeps.py --include threshold.py` - passed
- `pytest tests/ttnn/unit_tests/operations/eltwise/test_composite.py::test_threshold_example`
### Checklist
- [ ] Post commit CI passes
---
.../unary/device/unary_composite_op.cpp | 8 +++----
.../unary/device/unary_composite_op.hpp | 9 --------
.../eltwise/unary/unary_composite.hpp | 11 +++++++--
.../operations/eltwise/unary/unary_pybind.hpp | 23 +++++++++++++++----
4 files changed, 31 insertions(+), 20 deletions(-)
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
index 5395587302d..980a1f97bab 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
@@ -543,11 +543,9 @@ Tensor _selu(const Tensor& x, const float scale, const float alpha, const std::o
 }
 
 // threshold(a,t,v) = (a <= t)*v + (a > t)*a
-Tensor _threshold(const Tensor& input_tensor, float threshold, float value, const std::optional<MemoryConfig>& output_mem_config) {
-    Tensor t0 = ttnn::subtract(input_tensor, threshold, std::nullopt, output_mem_config);
-    Tensor t1 = ttnn::multiply(ttnn::lez(t0), value, std::nullopt, output_mem_config);
-    Tensor t2 = ttnn::multiply(ttnn::gtz(t0, output_mem_config), input_tensor, std::nullopt, output_mem_config);
-    return ttnn::add(t1, t2, std::nullopt, output_mem_config);
+Tensor ExecuteUnaryCompositeThreshold::invoke(const Tensor& input_tensor, float threshold, float value, const std::optional<MemoryConfig>& output_mem_config) {
+    Tensor sub_result = ttnn::subtract(input_tensor, threshold, std::nullopt, output_mem_config);
+    return ttnn::where(ttnn::lez(sub_result), value, input_tensor, output_mem_config);
 }
 
 std::vector<Tensor> split_tensor_for_glu(const Tensor& input_a, int32_t dim, const std::optional<MemoryConfig>& output_mem_config) {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
index 8194a669e76..95d5eaa7614 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
@@ -37,7 +37,6 @@ enum class UnaryCompositeOpType {
HARDSIGMOID,
HARDTANH,
SELU,
- THRESHOLD,
GLU,
REGLU,
GEGLU,
@@ -82,7 +81,6 @@ Tensor _hardswish(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, c
 Tensor _hardsigmoid(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _hardtanh(const Tensor&, float min = -1, float max = 1, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _selu(const Tensor&, float scale = 1.0507, float alpha = 1.67326, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-Tensor _threshold(const Tensor&, float, float, const std::optional<MemoryConfig>&);
 Tensor _glu(const Tensor&, int32_t, const std::optional<MemoryConfig>&);
 Tensor _reglu(const Tensor&, int32_t, const std::optional<MemoryConfig>&);
 Tensor _geglu(const Tensor&, int32_t, const std::optional<MemoryConfig>&);
@@ -267,13 +265,6 @@ struct OpHandler<UnaryCompositeOpType::SELU> {
     }
 };
 
-template <>
-struct OpHandler<UnaryCompositeOpType::THRESHOLD> {
-    static Tensor handle(const Tensor& t1, float threshold, float value, const std::optional<MemoryConfig>& mem_cfg) {
-        return _threshold(t1, threshold, value, mem_cfg);
-    }
-};
-
//glu (geglu, reglu, swiglu, glu) varinats are supported only for last dimension.
 template <>
 struct OpHandler<UnaryCompositeOpType::GLU> {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
index 2532bfcf36b..3c4cb7ba7ca 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -129,6 +129,14 @@ struct ExecuteUnaryCompositeClamp {
         const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
+struct ExecuteUnaryCompositeThreshold {
+    static Tensor invoke(
+        const Tensor &input_tensor,
+        float threshold,
+        float value,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+
 struct ExecuteUnaryCompositeClip {
     static Tensor invoke(
         const Tensor &input_tensor,
@@ -305,8 +313,7 @@ constexpr auto selu = ttnn::register_operation_with_auto_launch_op<
     operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::SELU>>();
 constexpr auto threshold = ttnn::register_operation_with_auto_launch_op<
     "ttnn::threshold",
-    operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::THRESHOLD>>();
-
+    operations::unary::ExecuteUnaryCompositeThreshold>();
 constexpr auto glu = ttnn::register_operation_with_auto_launch_op<
     "ttnn::glu",
     operations::unary::ExecuteUnaryCompositeOpWithDim<operations::unary::UnaryCompositeOpType::GLU>>();
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
index 8d98ee52660..f22fb9008f3 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -1319,7 +1319,7 @@ void bind_unary_composite_int(py::module& module, const unary_operation_t& opera
 
 //OpHandler_two_float_with_default
 template <typename unary_operation_t>
-void bind_unary_composite_floats(
+void bind_unary_composite_threshold(
     py::module& module,
     const unary_operation_t& operation,
     const std::string& parameter_name_a,
@@ -1342,8 +1342,23 @@ void bind_unary_composite_floats(
         Returns:
             ttnn.Tensor: the output tensor.
 
+        Note:
+            Supported dtypes, layouts, and ranks:
+
+            .. list-table::
+               :header-rows: 1
+
+               * - Dtypes
+                 - Layouts
+                 - Ranks
+               * - BFLOAT16
+                 - TILE
+                 - 2, 3, 4
+
         Example:
-            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
+            >>> {2} = 1.0
+            >>> {4} = 10.0
             >>> output = {1}(tensor, {2}, {4})
)doc",
operation.base_name(),
@@ -1975,11 +1990,11 @@ void py_module(py::module& module) {
         ttnn::selu,
         "scale", "Scale value", 1.0507,
         "alpha", "Alpha value", 1.67326);
-    detail::bind_unary_composite_floats(
+    detail::bind_unary_composite_threshold(
         module,
         ttnn::threshold,
         "threshold", "Threshold value",
-        "value", "Value value",
+        "value", "Replacement value",
R"doc(Performs threshold function on :attr:`input_tensor`, :attr:`threshold`, :attr:`value`.)doc");
detail::bind_unary_composite_int_with_default(
module,