ArgMax workaround implementation #2304

Draft · wants to merge 1 commit into base: main
9 changes: 7 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -716,8 +716,13 @@ def TTNN_ArgMaxOp : TTNN_NamedDPSOp<"argmax"> {
AnyRankedTensor:$output);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }

wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::
createArgMaxOpOperandsWorkarounds();
}
}];

let results = (outs AnyRankedTensor:$result);
}
5 changes: 4 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
@@ -251,9 +251,12 @@ class TTNNOperandsWorkaroundsFactory {
// dialect.
static TTNNOperandsWorkarounds createConstantOpOperandsWorkarounds();

// Create workarounds for concat op operands.
// Create workarounds for where op operands.
static TTNNOperandsWorkarounds
createWhereOpOperandsWorkarounds(mlir::Operation::operand_range inputs);

// Create workarounds for ArgMax op operands.
static TTNNOperandsWorkarounds createArgMaxOpOperandsWorkarounds();
};

} // namespace mlir::tt::ttnn::wa
29 changes: 29 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.h
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_ARGMAXOPREWRITEPATTERN_H
#define TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_ARGMAXOPREWRITEPATTERN_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir::tt::ttnn::workarounds::decomposition {

// tt-metal supports the ArgMax op only for 4D tensors.
// https://github.com/tenstorrent/tt-metal/issues/18241
// This workaround unsqueezes the input tensor to a 4D tensor (if required)
// and reshapes the result back to the original shape after performing the
// ArgMax op.
class ArgMaxOpRewritePattern : public OpRewritePattern<ttnn::ArgMaxOp> {
public:
using OpRewritePattern<ttnn::ArgMaxOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ttnn::ArgMaxOp srcOp,
PatternRewriter &rewriter) const override;
};

} // namespace mlir::tt::ttnn::workarounds::decomposition

#endif // TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_ARGMAXOPREWRITEPATTERN_H
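For reference, here is a minimal standalone sketch (not part of this patch) of the shape padding and dimension remapping described in the comment above; the helper names are illustrative only and do not exist in the tt-mlir codebase:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Prepend leading 1s so the shape becomes rank 4, e.g. {64, 64} -> {1, 1, 64, 64}.
std::vector<int64_t> unsqueezeTo4D(const std::vector<int64_t> &shape) {
  assert(shape.size() <= 4 && "workaround only handles rank <= 4");
  std::vector<int64_t> padded(4 - shape.size(), 1);
  padded.insert(padded.end(), shape.begin(), shape.end());
  return padded;
}

// The reduction dim shifts right by the number of leading unit dims added.
int32_t remapDim(int32_t dim, size_t originalRank) {
  return dim + static_cast<int32_t>(4 - originalRank);
}

int main() {
  std::vector<int64_t> input = {64, 64};               // 2D argmax input
  std::vector<int64_t> padded = unsqueezeTo4D(input);  // {1, 1, 64, 64}
  int32_t newDim = remapDim(/*dim=*/1, input.size());  // 1 -> 3
  for (int64_t d : padded)
    std::cout << d << ' ';
  std::cout << "\ndim -> " << newDim << '\n';
  return 0;
}
```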
18 changes: 18 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
@@ -357,4 +357,22 @@ TTNNOperandsWorkaroundsFactory::createWhereOpOperandsWorkarounds(
.addInputOperandWorkaround(typeWorkaround)
.addOutputOperandWorkaround(typeWorkaround);
}

// Factory method to create a set of workarounds for ArgMax op operands.
// Input tensor must have BFLOAT16 data type and ROW_MAJOR layout.
// tt-metal specs:
// https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api/ttnn.argmax.html
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createArgMaxOpOperandsWorkarounds() {
wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround;
rowMajorLayoutWorkaround.tensorLayoutWorkaround = Layout::RowMajor;
// rowMajorLayoutWorkaround.tensorDataTypeWorkaround = DataType::UInt32;
wa::TTNNOperandWorkarounds rowMajorLayoutBF16Workaround;
rowMajorLayoutBF16Workaround.tensorLayoutWorkaround = Layout::RowMajor;
rowMajorLayoutBF16Workaround.tensorDataTypeWorkaround = DataType::BFloat16;
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
.addInputOperandWorkaround(rowMajorLayoutWorkaround)
.addOutputOperandWorkaround(rowMajorLayoutWorkaround);
}
} // namespace mlir::tt::ttnn::wa
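For context, the workarounds pass satisfies these operand constraints by inserting ttnn.to_layout conversions around the op (see the lit test below). The following is a simplified, self-contained model of that decision; the struct and function names are hypothetical and do not mirror the real TTNNOperandWorkarounds API:

```cpp
#include <iostream>
#include <optional>
#include <string>

// Hypothetical, simplified model of applying one operand workaround; the real
// pass works on MLIR types/attributes and materializes ttnn.to_layout ops.
struct TensorDesc {
  std::string dtype;   // e.g. "f32", "bf16", "u32"
  std::string layout;  // e.g. "tile", "row_major"
};

struct OperandWorkaround {
  std::optional<std::string> dtype;   // required data type, if any
  std::optional<std::string> layout;  // required layout, if any
};

// Returns the descriptor the operand must be converted to before the op runs,
// or std::nullopt if it already satisfies the workaround.
std::optional<TensorDesc> requiredConversion(const TensorDesc &operand,
                                             const OperandWorkaround &wa) {
  TensorDesc target = operand;
  bool changed = false;
  if (wa.dtype && *wa.dtype != operand.dtype) {
    target.dtype = *wa.dtype;
    changed = true;
  }
  if (wa.layout && *wa.layout != operand.layout) {
    target.layout = *wa.layout;
    changed = true;
  }
  return changed ? std::optional<TensorDesc>(target) : std::nullopt;
}

int main() {
  TensorDesc input{"f32", "tile"};
  OperandWorkaround argmaxInput{"bf16", "row_major"};  // argmax input constraint
  if (auto conv = requiredConversion(input, argmaxInput))
    std::cout << "insert to_layout -> " << conv->dtype << ", " << conv->layout
              << '\n';
  return 0;
}
```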
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTTNNTransforms
TTNNLayout.cpp
TTNNDecomposeLayouts.cpp
TTNNToCpp.cpp
Workarounds/Decomposition/ArgMaxOpRewritePattern.cpp
Workarounds/Decomposition/CumSumOpRewritePattern.cpp
Workarounds/Decomposition/ReduceOpsRewritePattern.cpp
Workarounds/Decomposition/RepeatOpRewritePattern.cpp
96 changes: 96 additions & 0 deletions lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.cpp
@@ -0,0 +1,96 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.h"
#include "ttmlir/Conversion/TTIRToTTNN/Utils.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir::tt::ttnn::workarounds::decomposition {

LogicalResult
ArgMaxOpRewritePattern::matchAndRewrite(ttnn::ArgMaxOp srcOp,
PatternRewriter &rewriter) const {
mlir::RankedTensorType inputType =
mlir::cast<RankedTensorType>(srcOp.getInput().getType());
llvm::SmallVector<int64_t> inputTypeShape(inputType.getShape());
if (inputTypeShape.size() >= 4) {
return failure();
}

int64_t inputRank = inputType.getRank();
llvm::SmallVector<int64_t, 4> reshapeOutputShape(4 - inputRank, 1);
reshapeOutputShape.append(inputTypeShape.begin(), inputTypeShape.end());

llvm::ArrayRef<int64_t> reshapedShapeAttr(reshapeOutputShape);

ReshapeOp preReshapeOp = ttir_to_ttnn::utils::generateReshape(
srcOp.getInput(), reshapedShapeAttr, rewriter);

RankedTensorType outputType = srcOp.getResult().getType();
llvm::SmallVector<int64_t> outputTypeShape(outputType.getShape());
llvm::SmallVector<int64_t, 4> argMaxOutputShape(4 - inputRank, 1);
argMaxOutputShape.append(outputTypeShape.begin(), outputTypeShape.end());

ttnn::TTNNLayoutAttr newOutputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(outputType.getEncoding())
.withTensorShape(rewriter.getContext(), argMaxOutputShape);
RankedTensorType newOutputType = RankedTensorType::get(
argMaxOutputShape, outputType.getElementType(), newOutputLayoutAttr);

DataTypeAttr dTypeAttr = DataTypeAttr::get(rewriter.getContext(),
newOutputLayoutAttr.getDataType());
ttnn::LayoutAttr tensorLayoutAttr =
ttnn::LayoutAttr::get(getContext(), newOutputLayoutAttr.getLayout());

ttnn::ShapeAttr shapeAttr =
ttnn::ShapeAttr::get(rewriter.getContext(), newOutputType.getShape());

ttnn::BufferTypeAttr bufferTypeAttr = ttnn::BufferTypeAttr::get(
getContext(), newOutputLayoutAttr.getBufferType());
ttnn::ShardSpecAttr shardSpecAttr = ttnn::ShardSpecAttr::get(
getContext(),
ttnn::ShapeAttr::get(getContext(), newOutputLayoutAttr.getShardShape()));
ttnn::MemoryConfigAttr memoryConfigAttr =
ttnn::MemoryConfigAttr::get(getContext(), bufferTypeAttr, shardSpecAttr,
newOutputLayoutAttr.getMemLayout());

EmptyOp emptyOp = rewriter.create<ttnn::EmptyOp>(
srcOp->getLoc(), newOutputType, shapeAttr, dTypeAttr, tensorLayoutAttr,
ttnn::utils::getOrInsertDevice(rewriter, srcOp), memoryConfigAttr);

mlir::IntegerAttr dimAttr;
auto dimArg = srcOp.getDim();
if (dimArg) {
// Update the dimension according to reshaped input.
int32_t dim = *dimArg + (reshapeOutputShape.size() - inputTypeShape.size());
dimAttr =
mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 32), dim);
}
ArgMaxOp argMaxOp = rewriter.create<mlir::tt::ttnn::ArgMaxOp>(
srcOp->getLoc(), newOutputType, preReshapeOp->getResult(0),
dimArg ? dimAttr : nullptr, false, nullptr, emptyOp);

llvm::ArrayRef<int64_t> outputShapeAttr(outputType.getShape());
mlir::TypedValue<mlir::RankedTensorType> argMaxOutput =
mlir::cast<mlir::TypedValue<mlir::RankedTensorType>>(
argMaxOp->getResults().front());

ReshapeOp postReshapeOp = ttir_to_ttnn::utils::generateReshape(
argMaxOutput, outputShapeAttr, rewriter);

rewriter.replaceOp(srcOp, postReshapeOp);

return success();
}

} // namespace mlir::tt::ttnn::workarounds::decomposition
@@ -8,6 +8,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h"
#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.h"
#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRewritePattern.h"
#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h"
#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/RepeatOpRewritePattern.h"
@@ -436,7 +437,8 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
ttnn::MeanOp, /*keepDimUnsupported*/ false>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MinOp, /*keepDimUnsupported*/ false>,
workarounds::decomposition::CumSumOpRewritePattern>(
workarounds::decomposition::CumSumOpRewritePattern,
workarounds::decomposition::ArgMaxOpRewritePattern>(
&getContext());

runRewritePatterns(std::move(patterns),
@@ -0,0 +1,43 @@
// RUN: ttmlir-opt --ttnn-workaround --canonicalize %s | FileCheck %s

#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
#dram = #ttnn.buffer_type<dram>
#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x2x!tt.tile<32x32, f32>, #dram>, <interleaved>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x1x!tt.tile<32x32, u32>, #dram>, <interleaved>>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
func.func public @argmax_2d(%arg0: tensor<64x64xf32, #ttnn_layout>) -> tensor<64x1xui32, #ttnn_layout1> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
// CHECK: %[[PRE_RESHAPE:[0-9]+]] = "ttnn.reshape"(%arg0)
// CHECK-SAME: {shape = [1 : i32, 1 : i32, 64 : i32, 64 : i32]}
// CHECK-SAME: tensor<64x64xf32,
// CHECK-SAME: -> tensor<1x1x64x64xf32
// CHECK: %[[ARG0:[0-9]+]] = "ttnn.to_layout"(%[[PRE_RESHAPE]],
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: tensor<1x1x64x64xf32,
// CHECK-SAME: -> tensor<1x1x64x64xbf16,
// CHECK: %[[ARG1:[0-9]]] = "ttnn.to_layout"
// CHECK-SAME: dtype = #tt.supportedDataTypes<u32>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<1x1x64x1xui32
%1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<u32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x1>>, <interleaved>>, shape = #ttnn.shape<64x1>}> : (!tt.device<#device>) -> tensor<64x1xui32, #ttnn_layout1>
// CHECK: [[ARG_MAX:[0-9]+]] = "ttnn.argmax"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME: {dim = 3 : i32, use_multicore = false}
// CHECK-SAME: tensor<1x1x64x64xbf16
// CHECK-SAME: tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<1x1x64x1xui32
%2 = "ttnn.argmax"(%arg0, %1) <{dim = 1 : i32, use_multicore = false}> : (tensor<64x64xf32, #ttnn_layout>, tensor<64x1xui32, #ttnn_layout1>) -> tensor<64x1xui32, #ttnn_layout1>
// CHECK: %[[TO_LAYOUT:[0-9]+]] = "ttnn.to_layout"(%[[ARG_MAX]],
// CHECK-SAME: dtype = #tt.supportedDataTypes<u32>
// CHECK-SAME: layout = #ttnn.layout<tile>
// CHECK-SAME: (tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<1x1x64x1xui32
// CHECK: = "ttnn.reshape"(%[[TO_LAYOUT]])
// CHECK-SAME: {shape = [64 : i32, 1 : i32]}
// CHECK-SAME: tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<64x1xui32
return %2 : tensor<64x1xui32, #ttnn_layout1>
}
}
48 changes: 20 additions & 28 deletions test/ttmlir/Dialect/TTNN/reduction/simple_argmax.mlir
@@ -5,43 +5,35 @@ module attributes {} {
// CHECK-LABEL: func.func public @argmax_2d(
%0 = tensor.empty() : tensor<64x1xi32>
// CHECK: "ttnn.argmax"
// CHECK-SAME: {dim = 1 : i32, use_multicore = false}>
// CHECK-SAME: tensor<64x64xf32
// CHECK-SAME: tensor<64x1xui32
// CHECK-SAME: -> tensor<64x1xui32
// CHECK-SAME: {dim = 3 : i32, use_multicore = false}>
// CHECK-SAME: tensor<1x1x64x64xbf16
// CHECK-SAME: tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<1x1x64x1xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = true}> : (tensor<64x64xf32>, tensor<64x1xi32>) -> tensor<64x1xi32>
return %1 : tensor<64x1xi32>
}

func.func public @argmax_3d(%arg0: tensor<128x28x28xf32>) -> tensor<128x28xi32> {
func.func public @argmax_3d(%arg0: tensor<1x28x28xf32>) -> tensor<1x28xi32> {
// CHECK-LABEL: func.func public @argmax_3d(
%0 = tensor.empty() : tensor<128x28xi32>
// CHECK: %[[ARGMAX:[0-9]+]] = "ttnn.argmax"
// CHECK-SAME: {dim = 2 : i32, use_multicore = false}>
// CHECK-SAME: tensor<128x28x28xf32
// CHECK-SAME: tensor<128x28x1xui32
// CHECK-SAME: -> tensor<128x28x1xui32
// CHECK: "ttnn.reshape"(%[[ARGMAX]])
// CHECK-SAME: <{shape = [128 : i32, 28 : i32]}>
// CHECK-SAME: tensor<128x28x1xui32
// CHECK-SAME: -> tensor<128x28xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<128x28x28xf32>, tensor<128x28xi32>) -> tensor<128x28xi32>
return %1 : tensor<128x28xi32>
%0 = tensor.empty() : tensor<1x28xi32>
// CHECK: "ttnn.argmax"
// CHECK-SAME: {dim = 3 : i32, use_multicore = false}>
// CHECK-SAME: tensor<1x1x28x28xbf16
// CHECK-SAME: tensor<1x1x28x1xui32
// CHECK-SAME: -> tensor<1x1x28x1xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<1x28x28xf32>, tensor<1x28xi32>) -> tensor<1x28xi32>
return %1 : tensor<1x28xi32>
}

func.func public @argmax_4d(%arg0: tensor<4x8x128x64xf32>) -> tensor<4x8x128xi32> {
func.func public @argmax_4d(%arg0: tensor<1x1x128x64xf32>) -> tensor<1x1x128xi32> {
// CHECK-LABEL: func.func public @argmax_4d(
%0 = tensor.empty() : tensor<4x8x128xi32>
%0 = tensor.empty() : tensor<1x1x128xi32>
// CHECK: %[[ARGMAX:[0-9]+]] = "ttnn.argmax"
// CHECK-SAME: {dim = 3 : i32, use_multicore = false}>
// CHECK-SAME: tensor<4x8x128x64xf32
// CHECK-SAME: tensor<4x8x128x1xui32
// CHECK-SAME: -> tensor<4x8x128x1xui32
// CHECK: "ttnn.reshape"(%[[ARGMAX]])
// CHECK-SAME: <{shape = [4 : i32, 8 : i32, 128 : i32]}>
// CHECK-SAME: tensor<4x8x128x1xui32
// CHECK-SAME: -> tensor<4x8x128xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [3 : i32], keep_dim = false}> : (tensor<4x8x128x64xf32>, tensor<4x8x128xi32>) -> tensor<4x8x128xi32>
return %1 : tensor<4x8x128xi32>
// CHECK-SAME: tensor<1x1x128x64xbf16
// CHECK-SAME: tensor<1x1x128x1xui32
// CHECK-SAME: -> tensor<1x1x128x1xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [3 : i32], keep_dim = false}> : (tensor<1x1x128x64xf32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi32>
return %1 : tensor<1x1x128xi32>
}
}
15 changes: 0 additions & 15 deletions test/ttmlir/EmitC/TTNN/reduction/argmax.mlir
@@ -2,24 +2,9 @@
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %basename_t.ttnn
// RUN: ttmlir-opt --ttnn-modify-signatures-for-dylib --convert-ttnn-to-emitc %t.mlir > %t2.mlir
// RUN: ttmlir-translate --mlir-to-cpp %t2.mlir > %basename_t.cpp
//
// UNSUPPORTED: true
// These tests are currently failing due to tt-metal restrictions for argmax op.
// tt-metal specs:
// https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api/ttnn.argmax.html

// TODO(mmanzoor): Enable these tests after adding workarounds to overcome these
// limitations.
// https://github.com/tenstorrent/tt-mlir/issues/2057

func.func public @argmax_2d(%arg0: tensor<64x64xf32>) -> tensor<64xi32> {
// CHECK-LABEL: func.func public @argmax_2d(
%0 = tensor.empty() : tensor<64xi32>
// CHECK: "ttnn.argmax"
// CHECK-SAME: {dim = 1 : i32, use_multicore = false}>
// CHECK-SAME: tensor<64x64xf32
// CHECK-SAME: tensor<64xi32
// CHECK-SAME: -> tensor<64xi32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<64x64xf32>, tensor<64xi32>) -> tensor<64xi32>
return %1 : tensor<64xi32>
}