Skip to content

Commit

Permalink
#10778: Update Argmin with ttnn support
Browse files Browse the repository at this point in the history
  • Loading branch information
bharane-ab committed Jul 27, 2024
1 parent 74c1595 commit 5bc2eae
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 53 deletions.
3 changes: 0 additions & 3 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -586,9 +586,6 @@ Other Operations

.. autofunction:: tt_lib.tensor.argmax

.. autofunction:: tt_lib.tensor.argmin


Loss Functions
==============

Expand Down
18 changes: 9 additions & 9 deletions tests/ttnn/profiling/ops_for_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,23 +1536,23 @@ def argmax_all(x):


def argmin_1(x):
    # Profile argmin along the innermost dimension (W).
    ttnn.argmin(x, dim=-1)


def argmin_2(x):
    # Profile argmin along dim -2 (Y).
    ttnn.argmin(x, dim=-2)


def argmin_3(x):
    # Profile argmin along dim -3 (Z).
    ttnn.argmin(x, dim=-3)


def argmin_4(x):
    # Profile argmin along dim -4 (outermost, W of [W, Z, Y, X]).
    ttnn.argmin(x, dim=-4)


def argmin_all(x):
    # Profile argmin over all elements (all=True ignores the dim argument).
    ttnn.argmin(x, dim=-1, all=True)


def primary_moreh_softmax_0(x):
Expand Down Expand Up @@ -2284,22 +2284,22 @@ def clone(x):
},
{
"op": argmin_1,
"name": "tt_lib.tensor.argmin_dim_3",
"name": "ttnn.argmin_dim_3",
"num_repeats": 2,
},
{
"op": argmin_2,
"name": "tt_lib.tensor.argmin_dim_2",
"name": "ttnn.argmin_dim_2",
"num_repeats": 2,
},
{
"op": argmin_3,
"name": "tt_lib.tensor.argmin_dim_1",
"name": "ttnn.argmin_dim_1",
"num_repeats": 2,
},
{
"op": argmin_all,
"name": "tt_lib.tensor.argmin_all",
"name": "ttnn.argmin_all",
"num_repeats": 2,
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1601,17 +1601,6 @@ Tensor argmax(
return operation::decorate_as_composite(__func__, _argmax)(input_a, dim, all, output_mem_config);
}

// Composite argmin implementation: the index of the minimum of x is the
// index of the maximum of -x, so negate the input and delegate to argmax.
Tensor _argmin(const Tensor& input_a, int64_t _dim, bool all, const MemoryConfig& output_mem_config) {
    Tensor neg_input = ttnn::neg(input_a, output_mem_config);
    return (argmax(neg_input, _dim, all, output_mem_config));
}
// Public entry point for argmin. Wraps _argmin with the composite-op
// decorator (used for op reporting/profiling) and forwards all arguments.
Tensor argmin(
    const Tensor& input_a,
    int64_t dim,
    bool all,
    const MemoryConfig& output_mem_config /* = operation::DEFAULT_OUTPUT_MEMORY_CONFIG */) {
    return operation::decorate_as_composite(__func__, _argmin)(input_a, dim, all, output_mem_config);
}
} // namespace tt_metal

} // namespace tt
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,6 @@ Tensor argmax(
bool all = false,
const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

// Returns the indices of the minimum values of the input along `dim`;
// when `all` is true the reduction covers every element (dim is ignored).
Tensor argmin(
    const Tensor& input_a,
    int64_t dim = 0,
    bool all = false,
    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

} // namespace tt_metal

} // namespace tt
Original file line number Diff line number Diff line change
Expand Up @@ -375,30 +375,6 @@ void TensorModuleCompositeOPs(py::module& m_tensor) {
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

    // Python binding for tt_lib.tensor.argmin; the R"doc(...)doc" literal
    // below is the runtime docstring shown to Python users.
    m_tensor.def(
        "argmin",
        &argmin,
        py::arg("input").noconvert(),
        py::arg("dim"),
        py::arg("all") = false,
        py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
        R"doc(
        Returns the indices of the minimum value of elements in the ``input`` tensor
        If ``all`` is set to ``true`` irrespective of given dimension it will return the indices of minimum value of all elements in given ``input``
        Input tensor must have BFLOAT16 data type.
        Output tensor will have BFLOAT16 data type.
        .. csv-table::
        :header: "Argument", "Description", "Data type", "Valid range", "Required"
        "input", "Tensor argmin is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
        "dim", "Dimension to perform argmin", "int", "", "Yes"
        "all", "Consider all dimension (ignores ``dim`` param)", "bool", "default to false", "No"
        "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
        )doc");

m_tensor.def(
"hardtanh",
&hardtanh,
Expand Down

0 comments on commit 5bc2eae

Please sign in to comment.