Skip to content

Commit

Permalink
#10641: remove where op
Browse files Browse the repository at this point in the history
  • Loading branch information
KalaivaniMCW committed Jul 29, 2024
1 parent bd31709 commit 6632038
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 405 deletions.
7 changes: 0 additions & 7 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,6 @@ Tensor elementwise operations

.. autofunction:: tt_lib.tensor.rfloor_div

Tensor relational operations
============================


Tensor ternary operations
=========================
.. autofunction:: tt_lib.tensor.where

Tensor manipulation operations
-=============================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

shapes = [
[[1, 1, 32, 32]], # Single core
# [[1, 1, 320, 384]], # Multi core
# [[1, 3, 320, 384]], # Multi core
[[1, 1, 320, 384]], # Multi core
[[1, 3, 320, 384]], # Multi core
]
input_mem_cfgs = copy.copy(generation_funcs.supported_mem_configs)
output_mem_cfgs = copy.copy(generation_funcs.supported_mem_configs)
Expand Down
6 changes: 3 additions & 3 deletions tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,7 @@ def where(x, y, z, device, dtype, layout, input_mem_config, output_mem_config, *
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1])
t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2])
t3 = ttl.tensor.where(t0, t1, t2, output_mem_config=output_mem_config)
t3 = ttnn.where(t0, t1, t2, memory_config=output_mem_config)

return tt2torch_tensor(t3)

Expand All @@ -1644,7 +1644,7 @@ def where_optional(x, y, z, out, device, dtype, layout, input_mem_config, output
t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2])
t3 = setup_tt_tensor(out, device, layout[3], input_mem_config[3], dtype[3])
cq_id = 0
ttl.tensor.where(t0, t1, t2, output_tensor=t3, queue_id=cq_id)
ttnn.where(t0, t1, t2, output_tensor=t3, queue_id=cq_id)

return tt2torch_tensor(t3)

Expand All @@ -1656,7 +1656,7 @@ def where_scalar_optional(
t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
t3 = setup_tt_tensor(out, device, layout[1], input_mem_config[1], dtype[1])
cq_id = 0
ttl.tensor.where(t0, scalar_true, scalar_false, output_tensor=t3, queue_id=cq_id)
ttnn.where(t0, scalar_true, scalar_false, output_tensor=t3, queue_id=cq_id)

return tt2torch_tensor(t3)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "ttnn/operations/data_movement/pad/pad.hpp"
#include "tt_numpy/functions.hpp"
#include "ttnn/deprecated/tt_dnn/op_library/copy/copy_op.hpp"
#include "ttnn/cpp/ttnn/operations/eltwise/ternary/where_op.hpp"

#include "ttnn/operations/eltwise/binary/binary.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
Expand Down Expand Up @@ -56,7 +57,7 @@ std::vector<Tensor> _complex_recip_bw(const Tensor& grad, const Tensor& input, c
input_i.deallocate();
Tensor nan_flag = mk_complex(condition_nan, condition_nan, output_mem_config);
condition_nan.deallocate();
Tensor grad_result = where(
Tensor grad_result = ttnn::where(
nan_flag,
full_like(input, std::nanf(""), output_mem_config),
complex_mul(
Expand Down Expand Up @@ -118,15 +119,15 @@ std::vector<Tensor> _angle_bw(
Tensor abs_squared = ttnn::reciprocal(
ttnn::add(ttnn::square(inp_r, output_mem_config), ttnn::square(inp_i, output_mem_config), std::nullopt, output_mem_config),
output_mem_config);
Tensor real = where(
Tensor real = ttnn::where(
condition_zero,
zeros_like(inp_r, output_mem_config),
ttnn::multiply(grad,
ttnn::multiply(ttnn::neg(inp_i, output_mem_config), abs_squared, std::nullopt, output_mem_config),
std::nullopt,
output_mem_config),
output_mem_config);
Tensor imag = where(
Tensor imag = ttnn::where(
condition_zero,
zeros_like(inp_i, output_mem_config),
ttnn::multiply(grad, ttnn::multiply(inp_r, abs_squared, std::nullopt, output_mem_config), std::nullopt, output_mem_config),
Expand Down Expand Up @@ -158,7 +159,7 @@ std::vector<Tensor> _complex_abs_bw(const Tensor& grad, const Tensor& input, con
Tensor result = complex_abs(input, output_mem_config);
result = mk_complex(result, result, output_mem_config);
Tensor grad_c = mk_complex(grad, grad, output_mem_config);
Tensor grad_result = where(
Tensor grad_result = ttnn::where(
ttnn::eqz(result, output_mem_config),
zeros_like(result, output_mem_config),
ttnn::multiply(grad_c,
Expand All @@ -184,7 +185,7 @@ std::vector<Tensor> _polar_bw(
Tensor result = polar(input_a, input_b, output_mem_config);
Tensor abs_result = complex_abs(result, output_mem_config);
abs_result = mk_complex(abs_result, abs_result, output_mem_config);
Tensor sgn_result = where(
Tensor sgn_result = ttnn::where(
ttnn::eqz(abs_result, output_mem_config),
zeros_like(result, output_mem_config),
ttnn::multiply(result, ttnn::reciprocal(abs_result, output_mem_config), std::nullopt, output_mem_config),
Expand Down Expand Up @@ -229,14 +230,14 @@ std::vector<Tensor> _complex_div_bw(
other_i.deallocate();
Tensor nan_flag = mk_complex(condition_nan, condition_nan, output_mem_config);
condition_nan.deallocate();
Tensor grad_a = where(
Tensor grad_a = ttnn::where(
nan_flag,
full_like(input, std::nanf(""), output_mem_config),
complex_div(grad, conj(other, output_mem_config), output_mem_config),
output_mem_config);
grad_tensor.emplace_back(grad_a);
Tensor result = complex_div(input, other, output_mem_config);
Tensor grad_b = where(
Tensor grad_b = ttnn::where(
nan_flag,
full_like(input, std::nanf(""), output_mem_config),
complex_mul(
Expand Down
Loading

0 comments on commit 6632038

Please sign in to comment.