From a73be8a1bc1c78915b95fc1d7297669fb1ebe72f Mon Sep 17 00:00:00 2001
From: Virdhatchani Narayanamoorthy <138196495+VirdhatchaniKN@users.noreply.github.com>
Date: Wed, 16 Oct 2024 11:02:19 +0530
Subject: [PATCH] #13527: Update ttnn.clamp logic to match PyTorch API (#13530)

* #13527: Cleanup clamp

* #13527: Update clamp to match PyTorch API
---
 .../operations/eltwise/test_composite.py      | 35 ++++++++---
 .../unary/device/unary_composite_op.cpp       | 20 ++++++-
 .../unary/device/unary_composite_op.hpp       |  9 ---
 .../eltwise/unary/unary_composite.hpp         | 10 +++-
 .../operations/eltwise/unary/unary_pybind.hpp | 58 +++++++++++++++++--
 ttnn/ttnn/operations/unary.py                 |  2 +-
 6 files changed, 108 insertions(+), 26 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_composite.py
index 5f43cd2ee17..758a7122f2a 100644
--- a/tests/ttnn/unit_tests/operations/eltwise/test_composite.py
+++ b/tests/ttnn/unit_tests/operations/eltwise/test_composite.py
@@ -93,16 +93,33 @@ def test_unary_composite_cbrt_ttnn(input_shapes, device):
         (torch.Size([1, 3, 320, 384])),
     ),
 )
-def test_unary_composite_clamp_ttnn(input_shapes, device):
+@pytest.mark.parametrize(
+    "min, max",
+    [
+        (-10, 10),
+        (1, -1),
+        (0, 0),
+        (-1.0, None),
+        (None, 1.0),
+        (None, None),
+        (-0.5, None),
+        (None, -0.5),
+        (1.0, 0.0),
+        (0.0, 1.0),
+    ],
+)
+def test_unary_composite_clamp_ttnn(input_shapes, min, max, device):
     in_data1, input_tensor1 = data_gen_with_range(input_shapes, -100, 100, device)
-    min = -10
-    max = 10
-    output_tensor = ttnn.clamp(input_tensor1, min, max)
-    golden_function = ttnn.get_golden_function(ttnn.clamp)
-    golden_tensor = golden_function(in_data1, min, max)
-
-    comp_pass = compare_pcc([output_tensor], [golden_tensor])
-    assert comp_pass
+    if min is None and max is None:
+        with pytest.raises(RuntimeError, match="Only one of 'min' or 'max' can be None. Please provide one value"):
+            ttnn.clamp(input_tensor1, min=min, max=max)
+        assert True
+    else:
+        output_tensor = ttnn.clamp(input_tensor1, min=min, max=max)
+        golden_function = ttnn.get_golden_function(ttnn.clamp)
+        golden_tensor = golden_function(in_data1, min=min, max=max)
+        comp_pass = compare_pcc([output_tensor], [golden_tensor])
+        assert comp_pass


 @pytest.mark.parametrize(
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
index 5c52b147acd..6489de03deb 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp
@@ -477,8 +477,24 @@ Tensor _clip(const Tensor& a, float low, float high, const std::optional<MemoryC
 }

 // clamp
-Tensor _clamp(const Tensor& a, float low, float high, const std::optional<MemoryConfig>& output_mem_config) {
-    return _clip(a, low, high, output_mem_config);
+Tensor ExecuteUnaryCompositeClamp::invoke(const Tensor& a, std::optional<float> min, std::optional<float> max, const std::optional<MemoryConfig>& output_mem_config) {
+    auto output_memory_config = output_mem_config.value_or(a.memory_config());
+    TT_FATAL((max.has_value() || min.has_value()), "Only one of 'min' or 'max' can be None. Please provide one value");
+    if (!max.has_value()) {
+        return ttnn::where( ttnn::ge(a, min.value(), std::nullopt, output_memory_config), a, min.value(), output_memory_config);
+    }else if(!min.has_value()) {
+        return ttnn::where( ttnn::le(a, max.value(), std::nullopt, output_memory_config), a, max.value(), output_memory_config);
+    }else if(min.value() > max.value()){
+        return full_like(a, max.value());
+    }
+    const Tensor h_const = full_like(a, max.value());
+    Tensor a_max = ttnn::minimum(a, h_const, output_memory_config);
+    if (min.value() == 0.0f) {
+        return ttnn::relu(a_max, output_memory_config);
+    } else {
+        const Tensor l_const = full_like(a, min.value());
+        return ttnn::maximum(a_max, l_const, output_memory_config);
+    }
 }

 // hardtanh
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
index 3608be72e02..446b7b80326 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp
@@ -38,7 +38,6 @@ enum class UnaryCompositeOpType {
     HARDSIGMOID,
     HARDTANH,
     CLIP,
-    CLAMP,
     SELU,
     THRESHOLD,
     GLU,
@@ -86,7 +85,6 @@ Tensor _hardswish(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, c
 Tensor _hardsigmoid(const Tensor&, float scale = 1.0f/6.0f, float shift = 0.5f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _hardtanh(const Tensor&, float min = -1, float max = 1, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _clip(const Tensor&, float, float, const std::optional<MemoryConfig>& );
-Tensor _clamp(const Tensor&, float, float, const std::optional<MemoryConfig>& );
 Tensor _selu(const Tensor&, float scale = 1.0507, float alpha = 1.67326, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 Tensor _threshold(const Tensor&, float, float, const std::optional<MemoryConfig>& );
 Tensor _glu(const Tensor&, int32_t, const std::optional<MemoryConfig>& );
@@ -280,13 +278,6 @@ struct OpHandler<UnaryCompositeOpType::CLIP> {
     }
 };

-template <>
-struct OpHandler<UnaryCompositeOpType::CLAMP> {
-    static Tensor handle(const Tensor& t1, float low, float high, const std::optional<MemoryConfig>& mem_cfg ) {
-        return _clamp(t1, low, high, mem_cfg);
-    }
-};
-
 template <>
 struct OpHandler<UnaryCompositeOpType::SELU> {
     static Tensor handle(const Tensor& t1, float scale, float alpha, const std::optional<MemoryConfig>& mem_cfg ) {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
index 92afe1d24fc..1bbe00a7e0e 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
@@ -115,6 +115,14 @@ struct ExecuteUnaryCompositeOpWithFloats {
     }
 };

+struct ExecuteUnaryCompositeClamp {
+    static Tensor invoke(
+        const Tensor &input_tensor,
+        std::optional<float> min = std::nullopt,
+        std::optional<float> max = std::nullopt,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+};
+

 template <UnaryCompositeOpType unary_comp_op_type>
 struct ExecuteUnaryCompositeOpWithInt {
@@ -265,7 +273,7 @@ constexpr auto clip = ttnn::register_operation_with_auto_launch_op<
     operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::CLIP>>();
 constexpr auto clamp = ttnn::register_operation_with_auto_launch_op<
     "ttnn::clamp",
-    operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::CLAMP>>();
+    operations::unary::ExecuteUnaryCompositeClamp>();
 constexpr auto selu = ttnn::register_operation_with_auto_launch_op<
     "ttnn::selu",
     operations::unary::ExecuteUnaryCompositeOpWithFloats<operations::unary::UnaryCompositeOpType::SELU>>();
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
index 7191b492d44..d102f1687fe 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
@@ -23,6 +23,56 @@ namespace unary {

 namespace detail {

+template <typename unary_operation_t>
+void bind_unary_composite_optional_floats_with_default(py::module& module, const unary_operation_t& operation, const std::string& parameter_name_a, const std::string& parameter_a_doc, std::optional<float> parameter_a_value, const std::string& parameter_name_b, const std::string& parameter_b_doc, std::optional<float> parameter_b_value, const std::string& description) {
+    auto doc = fmt::format(
+        R"doc(
+        {8}
+
+        Args:
+            input_tensor (ttnn.Tensor): the input tensor.
+
+        Keyword args:
+            {2} (float): {3}. Defaults to `{4}`.
+            {5} (float): {6}. Defaults to `{7}`.
+            memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`.
+
+        Returns:
+            ttnn.Tensor: the output tensor.
+
+        Example:
+            >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device)
+            >>> output = {1}(tensor, {2} = {4}, {5} = {7})
+        )doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        parameter_name_a,
+        parameter_a_doc,
+        parameter_a_value,
+        parameter_name_b,
+        parameter_b_doc,
+        parameter_b_value,
+        description);
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const unary_operation_t& self,
+               const ttnn::Tensor& input_tensor,
+               std::optional<float> parameter_a,
+               std::optional<float> parameter_b,
+               const std::optional<ttnn::MemoryConfig>& memory_config) {
+                return self(input_tensor, parameter_a, parameter_b, memory_config);
+            },
+            py::arg("input_tensor"),
+            py::kw_only(),
+            py::arg(parameter_name_a.c_str()) = parameter_a_value,
+            py::arg(parameter_name_b.c_str()) = parameter_b_value,
+            py::arg("memory_config") = std::nullopt});
+}
+
 template <typename unary_operation_t>
 void bind_unary_operation(py::module& module, const unary_operation_t& operation, const std::string& math, const std::string& info_doc = "" ) {
     auto doc = fmt::format(
@@ -1583,12 +1633,12 @@ void py_module(py::module& module) {
         "low", "Low value",
         "high", "High value",
         R"doc(Performs clip function on :attr:`input_tensor`, :attr:`low`, :attr:`high`.)doc");
-    detail::bind_unary_composite_floats(
+    detail::bind_unary_composite_optional_floats_with_default(
         module,
         ttnn::clamp,
-        "low", "Low value",
-        "high", "High value",
-        R"doc(Performs clamp function on :attr:`input_tensor`, :attr:`low`, :attr:`high`.)doc");
+        "min", "Minimum value", std::nullopt,
+        "max", "Maximum value", std::nullopt,
+        R"doc(Performs clamp function on :attr:`input_tensor`, :attr:`min`, :attr:`max`. Only one of 'min' or 'max' value can be None.)doc");
     detail::bind_unary_composite_floats_with_default(
         module,
         ttnn::selu,
diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py
index 5af235d576e..a1079860fac 100644
--- a/ttnn/ttnn/operations/unary.py
+++ b/ttnn/ttnn/operations/unary.py
@@ -286,7 +286,7 @@ def _golden_function_polygamma(input_tensor_a, k, *args, **kwargs):
 ttnn.attach_golden_function(ttnn.polygamma, golden_function=_golden_function_polygamma)


-def _golden_function_clamp(input_tensor_a, min, max, *args, **kwargs):
+def _golden_function_clamp(input_tensor_a, min=None, max=None, *args, **kwargs):
     import torch

     return torch.clamp(input=input_tensor_a, min=min, max=max)
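
Editor's note: for reference, the PyTorch behaviour this patch is matching can be sketched with torch.clamp, which the patch also attaches as the golden function in unary.py. This is a minimal host-side illustration only; the tensor values are arbitrary, and device-side bfloat16 results from ttnn.clamp may differ slightly in precision.

    import torch

    x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])

    # Both bounds given: elements are limited to [min, max].
    torch.clamp(x, min=-1.0, max=1.0)   # tensor([-1.0, -0.5, 0.0, 0.5, 1.0])

    # Only one bound given: the other side is left unclamped
    # (the case handled by the new ttnn::ge / ttnn::le where-branches).
    torch.clamp(x, min=0.0)             # tensor([0.0, 0.0, 0.0, 0.5, 2.0])
    torch.clamp(x, max=0.0)             # tensor([-2.0, -0.5, 0.0, 0.0, 0.0])

    # min > max: every element becomes max, matching the full_like(a, max) branch.
    torch.clamp(x, min=1.0, max=0.0)    # tensor([0.0, 0.0, 0.0, 0.0, 0.0])

    # Neither bound given: torch.clamp raises a RuntimeError; the new ttnn.clamp
    # raises "Only one of 'min' or 'max' can be None. Please provide one value".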