From be1793d62eb5d534f9b8c52252d2bfded9738d2b Mon Sep 17 00:00:00 2001
From: umadevimcw
Date: Tue, 25 Feb 2025 09:50:41 +0000
Subject: [PATCH] #14470: Fix atan2 reverse arg issue

---
 .../tt_eager/python_api_testing/sweep_tests/pytorch_ops.py | 2 +-
 ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp  | 6 +++++-
 .../eltwise/binary/device/binary_composite_op.cpp          | 5 +++--
 .../eltwise/complex_unary/device/complex_unary_op.cpp      | 2 +-
 4 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
index fcc41f186a6..8ab4706538b 100644
--- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
+++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -934,7 +934,7 @@ def lerp_ternary(x, y, z, *args, **kwargs):
 ## Binary Ops
 
 
-def atan2(x, y, *args, **kwargs):
+def atan2(y, x, *args, **kwargs):
     return torch.atan2(y, x)
 
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
index 2f70f722368..372415a45cc 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
@@ -1631,7 +1631,11 @@ void py_module(py::module& module) {
         R"doc(Computes atan2 :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
         R"doc(\mathrm{output\_tensor}_i = \arctan\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)
         )doc",
-        R"doc(BFLOAT16, BFLOAT8_B)doc");
+        R"doc(BFLOAT16, BFLOAT8_B)doc",
+        R"doc(2, 3, 4)doc",
+        R"doc(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device))doc",
+        R"doc(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device))doc",
+        R"doc(Input arguments for the atan2 function are in the format (y, x))doc");
 
     detail::bind_binary_operation(
         module,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
index 7a9cbc4be60..2996845fa6c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -137,7 +137,8 @@ Tensor ExecuteMaximum::invoke(
     return result;
 }
 
-Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
+Tensor _atan2(const Tensor& input_b, const Tensor& input_a, const std::optional<MemoryConfig>& output_mem_config) {
+    tt::log_info(tt::LogOp, "Input arguments for the atan2 function are in the format (y, x)");
     Tensor result(input_a);
     {
         Tensor atan_input =
@@ -171,7 +172,7 @@ Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional<
                 altz_bltz,
                 ttnn::subtract(result, M_PI, std::nullopt, output_mem_config),
                 ttnn::where(
-                    az_bltz, M_PI_2, ttnn::where(az_bgtz, -M_PI_2, 0.0, output_mem_config), output_mem_config),
+                    az_bltz, -M_PI_2, ttnn::where(az_bgtz, M_PI_2, 0.0, output_mem_config), output_mem_config),
                 output_mem_config),
             output_mem_config),
         output_mem_config);
diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp
index 11fac7d9ae6..8cee3616bcf 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.cpp
@@ -17,7 +17,7 @@ Tensor _real(const ComplexTensor& input, const MemoryConfig& output_mem_config)
 Tensor _imag(const ComplexTensor& input, const MemoryConfig& output_mem_config) { return input[1]; }
 
 Tensor _angle(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
-    return atan2(input[0], input[1], output_mem_config);
+    return atan2(input[1], input[0], output_mem_config);
 }
 
 Tensor _is_imag(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
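Reviewer note (not part of the patch): a minimal sketch of the torch-style (y, x) argument order this change enforces. It assumes a ttnn-capable device at device_id 0 and arbitrary sample values; it is a hand-written check, not a test from the repository.

import torch
import ttnn

# Assumed setup: one local device at device_id 0 (adjust for your machine).
device = ttnn.open_device(device_id=0)

y = torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16)
x = torch.tensor([[4, 3], [2, 1]], dtype=torch.bfloat16)

y_tt = ttnn.from_torch(y, layout=ttnn.TILE_LAYOUT, device=device)
x_tt = ttnn.from_torch(x, layout=ttnn.TILE_LAYOUT, device=device)

# Both calls take (y, x): the first argument is the numerator of y / x.
expected = torch.atan2(y, x)
actual = ttnn.to_torch(ttnn.atan2(y_tt, x_tt))

# bfloat16 on device is low precision, so compare with a loose tolerance.
assert torch.allclose(actual.to(torch.float32), expected.to(torch.float32), atol=1e-2)

ttnn.close_device(device)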