Skip to content

Commit

Permalink
#14470: Fix atan2 reverse arg issue
Browse files Browse the repository at this point in the history
  • Loading branch information
umadevimcw committed Feb 27, 2025
1 parent 69a36b8 commit be1793d
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ def lerp_ternary(x, y, z, *args, **kwargs):


## Binary Ops
def atan2(x, y, *args, **kwargs):
def atan2(y, x, *args, **kwargs):
return torch.atan2(y, x)


Expand Down
6 changes: 5 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/binary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1631,7 +1631,11 @@ void py_module(py::module& module) {
R"doc(Computes atan2 :attr:`input_tensor_a` and :attr:`input_tensor_b` and returns the tensor with the same layout as :attr:`input_tensor_a`)doc",
R"doc(\mathrm{output\_tensor}_i = \arctan\left(\frac{\mathrm{input\_tensor\_a}_i}{\mathrm{input\_tensor\_b}_i}\right)
)doc",
R"doc(BFLOAT16, BFLOAT8_B)doc");
R"doc(BFLOAT16, BFLOAT8_B)doc",
R"doc(2, 3, 4)doc",
R"doc(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device))doc",
R"doc(ttnn.from_torch(torch.tensor([[1, 2], [3, 4]], dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device))doc",
R"doc(Input arguments for the atan2 function are in the format (y, x))doc");

detail::bind_binary_operation(
module,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ Tensor ExecuteMaximum::invoke(
return result;
}

Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional<MemoryConfig>& output_mem_config) {
Tensor _atan2(const Tensor& input_b, const Tensor& input_a, const std::optional<MemoryConfig>& output_mem_config) {
tt::log_info(tt::LogOp, "Input arguments for the atan2 function are in the format (y, x)");
Tensor result(input_a);
{
Tensor atan_input =
Expand Down Expand Up @@ -171,7 +172,7 @@ Tensor _atan2(const Tensor& input_a, const Tensor& input_b, const std::optional<
altz_bltz,
ttnn::subtract(result, M_PI, std::nullopt, output_mem_config),
ttnn::where(
az_bltz, M_PI_2, ttnn::where(az_bgtz, -M_PI_2, 0.0, output_mem_config), output_mem_config),
az_bltz, -M_PI_2, ttnn::where(az_bgtz, M_PI_2, 0.0, output_mem_config), output_mem_config),
output_mem_config),
output_mem_config),
output_mem_config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Tensor _real(const ComplexTensor& input, const MemoryConfig& output_mem_config)
Tensor _imag(const ComplexTensor& input, const MemoryConfig& output_mem_config) { return input[1]; }

Tensor _angle(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
return atan2(input[0], input[1], output_mem_config);
return atan2(input[1], input[0], output_mem_config);
}

Tensor _is_imag(const ComplexTensor& input, const MemoryConfig& output_mem_config) {
Expand Down

0 comments on commit be1793d

Please sign in to comment.