From 388d56e645ef5297285387469b24818e98a05885 Mon Sep 17 00:00:00 2001 From: Virdhatchani Narayanamoorthy <138196495+VirdhatchaniKN@users.noreply.github.com> Date: Sun, 24 Nov 2024 19:14:08 +0530 Subject: [PATCH] #14982: Update threshold logic (#15362) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Ticket #14982 ### Problem description Provide context for the problem. ### What's changed - Updated threshold logic to handle cases when input_tensor=threshold value - Updated with supported data type and layout Tests : - `pytest tests/ttnn/python_api_testing/non_working_unit_tests/grayskull/test_eltwise_threshold.py` : Screenshot 2024-11-22 at 6 09 03 PM - `pytest tests/ttnn/unit_tests/operations/eltwise/test_activation.py::test_threshold` : Screenshot 2024-11-22 at 6 09 39 PM - `pytest tests/ttnn/unit_tests/operations/eltwise/test_composite.py::test_unary_composite_threshold_ttnn` : Screenshot 2024-11-22 at 6 10 21 PM - `python tests/ttnn/sweep_tests/run_sweeps.py --include threshold.py` - passed - `pytest tests/ttnn/unit_tests/operations/eltwise/test_composite.py::test_threshold_example` : Screenshot 2024-11-22 at 6 11 40 PM ### Checklist - [ ] Post commit CI passes ### Doc screenshot Screenshot 2024-11-22 at 11 06 48 PM --- .../unary/device/unary_composite_op.cpp | 8 +++---- .../unary/device/unary_composite_op.hpp | 9 -------- .../eltwise/unary/unary_composite.hpp | 11 +++++++-- .../operations/eltwise/unary/unary_pybind.hpp | 23 +++++++++++++++---- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp index 5395587302d..980a1f97bab 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp @@ -543,11 +543,9 @@ Tensor _selu(const Tensor& x, const float scale, const float alpha, const std::o } // threshold(a,t,v) = (a <= t)*v + (a > t)*a -Tensor _threshold(const Tensor& input_tensor, float threshold, float value, const std::optional& output_mem_config) { - Tensor t0 = ttnn::subtract(input_tensor, threshold, std::nullopt, output_mem_config); - Tensor t1 = ttnn::multiply(ttnn::lez(t0), value, std::nullopt, output_mem_config); - Tensor t2 = ttnn::multiply(ttnn::gtz(t0, output_mem_config), input_tensor, std::nullopt, output_mem_config); - return ttnn::add(t1, t2, std::nullopt, output_mem_config); +Tensor ExecuteUnaryCompositeThreshold::invoke(const Tensor& input_tensor, float threshold, float value, const std::optional& output_mem_config) { + Tensor sub_result = ttnn::subtract(input_tensor, threshold, std::nullopt, output_mem_config); + return ttnn::where(ttnn::lez(sub_result), value, input_tensor, output_mem_config); } std::vector split_tensor_for_glu(const Tensor& input_a, int32_t dim, const std::optional& output_mem_config) { diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp index 8194a669e76..95d5eaa7614 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp @@ -37,7 +37,6 @@ enum class UnaryCompositeOpType { HARDSIGMOID, HARDTANH, SELU, - THRESHOLD, GLU, REGLU, GEGLU, @@ -82,7 +81,6 @@ Tensor _hardswish(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, c Tensor _hardsigmoid(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, const std::optional& output_mem_config = std::nullopt); Tensor _hardtanh(const Tensor&, float min = -1, float max = 1, const std::optional& output_mem_config = std::nullopt); Tensor _selu(const Tensor&, float scale = 1.0507, float alpha = 1.67326, const std::optional& output_mem_config = std::nullopt); -Tensor _threshold(const Tensor&, float, float, const std::optional& ); Tensor _glu(const Tensor&, int32_t, const std::optional& ); Tensor _reglu(const Tensor&, int32_t, const std::optional& ); Tensor _geglu(const Tensor&, int32_t, const std::optional& ); @@ -267,13 +265,6 @@ struct OpHandler { } }; -template <> -struct OpHandler { - static Tensor handle(const Tensor& t1, float threshold, float value, const std::optional& mem_cfg ) { - return _threshold(t1, threshold, value, mem_cfg); - } -}; - //glu (geglu, reglu, swiglu, glu) varinats are supported only for last dimension. template <> struct OpHandler { diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp index 2532bfcf36b..3c4cb7ba7ca 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp @@ -129,6 +129,14 @@ struct ExecuteUnaryCompositeClamp { const std::optional &memory_config = std::nullopt); }; +struct ExecuteUnaryCompositeThreshold { + static Tensor invoke( + const Tensor &input_tensor, + float threshold, + float value, + const std::optional &memory_config = std::nullopt); +}; + struct ExecuteUnaryCompositeClip { static Tensor invoke( const Tensor &input_tensor, @@ -305,8 +313,7 @@ constexpr auto selu = ttnn::register_operation_with_auto_launch_op< operations::unary::ExecuteUnaryCompositeOpWithFloats>(); constexpr auto threshold = ttnn::register_operation_with_auto_launch_op< "ttnn::threshold", - operations::unary::ExecuteUnaryCompositeOpWithFloats>(); - + operations::unary::ExecuteUnaryCompositeThreshold>(); constexpr auto glu = ttnn::register_operation_with_auto_launch_op< "ttnn::glu", operations::unary::ExecuteUnaryCompositeOpWithDim>(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp index 8d98ee52660..f22fb9008f3 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp @@ -1319,7 +1319,7 @@ void bind_unary_composite_int(py::module& module, const unary_operation_t& opera //OpHandler_two_float_with_default template -void bind_unary_composite_floats( +void bind_unary_composite_threshold( py::module& module, const unary_operation_t& operation, const std::string& parameter_name_a, @@ -1342,8 +1342,23 @@ void bind_unary_composite_floats( Returns: ttnn.Tensor: the output tensor. + Note: + Supported dtypes, layouts, and ranks: + + .. list-table:: + :header-rows: 1 + + * - Dtypes + - Layouts + - Ranks + * - BFLOAT16 + - TILE + - 2, 3, 4 + Example: - >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) + >>> tensor = ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device) + >>> {2} = 1.0 + >>> {4} = 10.0 >>> output = {1}(tensor, {2}, {4}) )doc", operation.base_name(), @@ -1975,11 +1990,11 @@ void py_module(py::module& module) { ttnn::selu, "scale", "Scale value", 1.0507, "alpha", "Alpha value", 1.67326); - detail::bind_unary_composite_floats( + detail::bind_unary_composite_threshold( module, ttnn::threshold, "threshold", "Threshold value", - "value", "Value value", + "value", "Replacing value", R"doc(Performs threshold function on :attr:`input_tensor`, :attr:`threshold`, :attr:`value`.)doc"); detail::bind_unary_composite_int_with_default( module,