ArgMax workaround implementation
mmanzoorTT committed Feb 27, 2025
1 parent 340e05f commit a46fec3
Showing 13 changed files with 321 additions and 75 deletions.
9 changes: 7 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -716,8 +716,13 @@ def TTNN_ArgMaxOp : TTNN_NamedDPSOp<"argmax"> {
AnyRankedTensor:$output);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }

wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::
createArgMaxOpOperandsWorkarounds();
}
}];

let results = (outs AnyRankedTensor:$result);
}
5 changes: 4 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
@@ -251,9 +251,12 @@ class TTNNOperandsWorkaroundsFactory {
// dialect.
static TTNNOperandsWorkarounds createConstantOpOperandsWorkarounds();

// Create workarounds for concat op operands.
// Create workarounds for where op operands.
static TTNNOperandsWorkarounds
createWhereOpOperandsWorkarounds(mlir::Operation::operand_range inputs);

// Create workarounds for ArgMax op operands.
static TTNNOperandsWorkarounds createArgMaxOpOperandsWorkarounds();
};

} // namespace mlir::tt::ttnn::wa
29 changes: 29 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.h
@@ -0,0 +1,29 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_ARGMAXOPREWRITEPATTERN_H
#define TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_ARGMAXOPREWRITEPATTERN_H

#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"

namespace mlir::tt::ttnn::workarounds::decomposition {

// tt-metal supports ArgMax op for 4D tensors only.
// https://github.com/tenstorrent/tt-metal/issues/18241
// This workaround unsqueezes the input tensor to a 4D tensor (if required) and
// reshapes the result back to the original shape after performing the ArgMax op.
class ArgMaxOpRewritePattern : public OpRewritePattern<ttnn::ArgMaxOp> {
public:
using OpRewritePattern<ttnn::ArgMaxOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ttnn::ArgMaxOp srcOp,
PatternRewriter &rewriter) const override;
};

} // namespace mlir::tt::ttnn::workarounds::decomposition

#endif // TTMLIR_DIALECT_TTNN_TRANSFORMS_WORKAROUNDS_DECOMPOSITION_ARGMAXOPREWRITEPATTERN_H
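For a 2D input, this pattern is intended to produce roughly the following IR (a minimal sketch assuming a 64x64 input with dim = 1; #layout and #layout1 stand in for the real TTNN layout encodings, and the ttnn.empty that feeds the DPS output operand is elided):

// Unsqueeze the input to 4D by prepending unit dimensions.
%unsqueezed = "ttnn.reshape"(%input) <{shape = [1 : i32, 1 : i32, 64 : i32, 64 : i32]}> : (tensor<64x64xf32, #layout>) -> tensor<1x1x64x64xf32, #layout>
// Run argmax on the 4D tensor; the reduction dim is rebased from 1 to 3.
%argmax = "ttnn.argmax"(%unsqueezed, %empty) <{dim = 3 : i32, use_multicore = false}> : (tensor<1x1x64x64xf32, #layout>, tensor<1x1x64x1xui32, #layout1>) -> tensor<1x1x64x1xui32, #layout1>
// Reshape the result back to the original output shape.
%result = "ttnn.reshape"(%argmax) <{shape = [64 : i32, 1 : i32]}> : (tensor<1x1x64x1xui32, #layout1>) -> tensor<64x1xui32, #layout1>

The bf16/row-major conversions that appear in the lit tests are inserted by the separate operand workarounds below, not by this rewrite pattern.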
18 changes: 18 additions & 0 deletions lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp
@@ -357,4 +357,22 @@ TTNNOperandsWorkaroundsFactory::createWhereOpOperandsWorkarounds(
.addInputOperandWorkaround(typeWorkaround)
.addOutputOperandWorkaround(typeWorkaround);
}

// Factory method to create a set of workarounds for ArgMax op operands.
// Input tensor must have BFLOAT16 data type and ROW_MAJOR layout.
// tt-metal specs:
// https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api/ttnn.argmax.html
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createArgMaxOpOperandsWorkarounds() {
wa::TTNNOperandWorkarounds rowMajorLayoutWorkaround;
rowMajorLayoutWorkaround.tensorLayoutWorkaround = Layout::RowMajor;
// rowMajorLayoutWorkaround.tensorDataTypeWorkaround = DataType::UInt32;
wa::TTNNOperandWorkarounds rowMajorLayoutBF16Workaround;
rowMajorLayoutBF16Workaround.tensorLayoutWorkaround = Layout::RowMajor;
rowMajorLayoutBF16Workaround.tensorDataTypeWorkaround = DataType::BFloat16;
return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds()
.addInputOperandWorkaround(rowMajorLayoutBF16Workaround)
.addInputOperandWorkaround(rowMajorLayoutWorkaround)
.addOutputOperandWorkaround(rowMajorLayoutWorkaround);
}
} // namespace mlir::tt::ttnn::wa
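When the workaround pass applies these operand requirements to the 4D argmax produced by the rewrite pattern, it inserts layout conversions roughly like the following (an illustrative sketch; exact operands and attributes follow the lit test further below):

// Input operand: convert to bf16, row-major.
%in = "ttnn.to_layout"(%unsqueezed, %device)  // -> tensor<1x1x64x64xbf16, ...>, dtype = bf16, layout = row_major
// Output operand: keep u32, row-major layout.
%out = "ttnn.to_layout"(%empty, %device)      // -> tensor<1x1x64x1xui32, ...>, dtype = u32, layout = row_major

As a result, ttnn.argmax consumes a bf16 row-major input and writes a row-major output, matching the tt-metal constraints linked above.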
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTTNNTransforms
TTNNLayout.cpp
TTNNDecomposeLayouts.cpp
TTNNToCpp.cpp
Workarounds/Decomposition/ArgMaxOpRewritePattern.cpp
Workarounds/Decomposition/CumSumOpRewritePattern.cpp
Workarounds/Decomposition/ReduceOpsRewritePattern.cpp
Workarounds/Decomposition/RepeatOpRewritePattern.cpp
96 changes: 96 additions & 0 deletions lib/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.cpp
@@ -0,0 +1,96 @@
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.h"
#include "ttmlir/Conversion/TTIRToTTNN/Utils.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/Utils/TransformUtils.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

namespace mlir::tt::ttnn::workarounds::decomposition {

LogicalResult
ArgMaxOpRewritePattern::matchAndRewrite(ttnn::ArgMaxOp srcOp,
PatternRewriter &rewriter) const {
mlir::RankedTensorType inputType =
mlir::cast<RankedTensorType>(srcOp.getInput().getType());
llvm::SmallVector<int64_t> inputTypeShape(inputType.getShape());
if (inputTypeShape.size() >= 4) {
return failure();
}

int64_t inputRank = inputType.getRank();
llvm::SmallVector<int64_t, 4> reshapeOutputShape(4 - inputRank, 1);
reshapeOutputShape.append(inputTypeShape.begin(), inputTypeShape.end());

llvm::ArrayRef<int64_t> reshapedShapeAttr(reshapeOutputShape);

ReshapeOp preReshapeOp = ttir_to_ttnn::utils::generateReshape(
srcOp.getInput(), reshapedShapeAttr, rewriter);

RankedTensorType outputType = srcOp.getResult().getType();
llvm::SmallVector<int64_t> outputTypeShape(outputType.getShape());
llvm::SmallVector<int64_t, 4> argMaxOutputShape(4 - inputRank, 1);
argMaxOutputShape.append(outputTypeShape.begin(), outputTypeShape.end());

ttnn::TTNNLayoutAttr newOutputLayoutAttr =
mlir::cast<ttnn::TTNNLayoutAttr>(outputType.getEncoding())
.withTensorShape(rewriter.getContext(), argMaxOutputShape);
RankedTensorType newOutputType = RankedTensorType::get(
argMaxOutputShape, outputType.getElementType(), newOutputLayoutAttr);

DataTypeAttr dTypeAttr = DataTypeAttr::get(rewriter.getContext(),
newOutputLayoutAttr.getDataType());
ttnn::LayoutAttr tensorLayoutAttr =
ttnn::LayoutAttr::get(getContext(), newOutputLayoutAttr.getLayout());

ttnn::ShapeAttr shapeAttr =
ttnn::ShapeAttr::get(rewriter.getContext(), newOutputType.getShape());

ttnn::BufferTypeAttr bufferTypeAttr = ttnn::BufferTypeAttr::get(
getContext(), newOutputLayoutAttr.getBufferType());
ttnn::ShardSpecAttr shardSpecAttr = ttnn::ShardSpecAttr::get(
getContext(),
ttnn::ShapeAttr::get(getContext(), newOutputLayoutAttr.getShardShape()));
ttnn::MemoryConfigAttr memoryConfigAttr =
ttnn::MemoryConfigAttr::get(getContext(), bufferTypeAttr, shardSpecAttr,
newOutputLayoutAttr.getMemLayout());

EmptyOp emptyOp = rewriter.create<ttnn::EmptyOp>(
srcOp->getLoc(), newOutputType, shapeAttr, dTypeAttr, tensorLayoutAttr,
ttnn::utils::getOrInsertDevice(rewriter, srcOp), memoryConfigAttr);

mlir::IntegerAttr dimAttr;
auto dimArg = srcOp.getDim();
if (dimArg) {
// Shift the reduction dimension by the number of leading unit dims added
// by the pre-reshape (e.g. dim = 1 on a 2D input becomes dim = 3 in 4D).
int32_t dim = *dimArg + (reshapeOutputShape.size() - inputTypeShape.size());
dimAttr =
mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 32), dim);
}
ArgMaxOp argMaxOp = rewriter.create<mlir::tt::ttnn::ArgMaxOp>(
srcOp->getLoc(), newOutputType, preReshapeOp->getResult(0),
dimArg ? dimAttr : nullptr, false, nullptr, emptyOp);

llvm::ArrayRef<int64_t> outputShapeAttr(outputType.getShape());
mlir::TypedValue<mlir::RankedTensorType> argMaxOutput =
mlir::cast<mlir::TypedValue<mlir::RankedTensorType>>(
argMaxOp->getResults().front());

ReshapeOp postReshapeOp = ttir_to_ttnn::utils::generateReshape(
argMaxOutput, outputShapeAttr, rewriter);

rewriter.replaceOp(srcOp, postReshapeOp);

return success();
}

} // namespace mlir::tt::ttnn::workarounds::decomposition
4 changes: 3 additions & 1 deletion lib/Dialect/TTNN/Transforms/Workarounds/TTNNWorkarounds.cpp
@@ -8,6 +8,7 @@
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h"
#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ArgMaxOpRewritePattern.h"
#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/CumSumOpRewritePattern.h"
#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/ReduceOpsRewritePattern.h"
#include "ttmlir/Dialect/TTNN/Transforms/Workarounds/Decomposition/RepeatOpRewritePattern.h"
@@ -436,7 +437,8 @@ class TTNNWorkarounds : public impl::TTNNWorkaroundsBase<TTNNWorkarounds> {
ttnn::MeanOp, /*keepDimUnsupported*/ false>,
workarounds::decomposition::ReduceOpsKeepDimRewritePattern<
ttnn::MinOp, /*keepDimUnsupported*/ false>,
workarounds::decomposition::CumSumOpRewritePattern>(
workarounds::decomposition::CumSumOpRewritePattern,
workarounds::decomposition::ArgMaxOpRewritePattern>(
&getContext());

runRewritePatterns(std::move(patterns),
@@ -0,0 +1,43 @@
// RUN: ttmlir-opt --ttnn-workaround --canonicalize %s | FileCheck %s

#device = #tt.device<workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, l1Map = (d0, d1)[s0, s1] -> (0, d0 floordiv s0, d1 floordiv s1, (d0 mod s0) * s1 + d1 mod s1), dramMap = (d0, d1)[s0, s1] -> (0, 0, ((((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 8192) mod 12, (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) floordiv 98304 + (((d0 floordiv s0) * 8 + d1 floordiv s1) * (s1 * s0) + (d0 mod s0) * s1 + d1 mod s1) mod 8192), meshShape = , chipIds = [0]>
#dram = #ttnn.buffer_type<dram>
#system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [<f32>, <f16>, <bf16>, <bfp_f8>, <bfp_bf8>, <bfp_f4>, <bfp_bf4>, <bfp_f2>, <bfp_bf2>, <u32>, <u16>, <u8>], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 32}], [0], [3 : i32], [ 0x0x0x0]>
#ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x2x!tt.tile<32x32, f32>, #dram>, <interleaved>>
#ttnn_layout1 = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x1x!tt.tile<32x32, u32>, #dram>, <interleaved>>
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
func.func public @argmax_2d(%arg0: tensor<64x64xf32, #ttnn_layout>) -> tensor<64x1xui32, #ttnn_layout1> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
// CHECK: %[[PRE_RESHAPE:[0-9]+]] = "ttnn.reshape"(%arg0)
// CHECK-SAME: {shape = [1 : i32, 1 : i32, 64 : i32, 64 : i32]}
// CHECK-SAME: tensor<64x64xf32,
// CHECK-SAME: -> tensor<1x1x64x64xf32
// CHECK: %[[ARG0:[0-9]+]] = "ttnn.to_layout"(%[[PRE_RESHAPE]],
// CHECK-SAME: dtype = #tt.supportedDataTypes<bf16>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: tensor<1x1x64x64xf32,
// CHECK-SAME: -> tensor<1x1x64x64xbf16,
// CHECK: %[[ARG1:[0-9]]] = "ttnn.to_layout"
// CHECK-SAME: dtype = #tt.supportedDataTypes<u32>
// CHECK-SAME: layout = #ttnn.layout<row_major>
// CHECK-SAME: tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<1x1x64x1xui32
%1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<u32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<2x1>>, <interleaved>>, shape = #ttnn.shape<64x1>}> : (!tt.device<#device>) -> tensor<64x1xui32, #ttnn_layout1>
// CHECK: [[ARG_MAX:[0-9]+]] = "ttnn.argmax"(%[[ARG0]], %[[ARG1]])
// CHECK-SAME: {dim = 3 : i32, use_multicore = false}
// CHECK-SAME: tensor<1x1x64x64xbf16
// CHECK-SAME: tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<1x1x64x1xui32
%2 = "ttnn.argmax"(%arg0, %1) <{dim = 1 : i32, use_multicore = false}> : (tensor<64x64xf32, #ttnn_layout>, tensor<64x1xui32, #ttnn_layout1>) -> tensor<64x1xui32, #ttnn_layout1>
// CHECK: %[[TO_LAYOUT:[0-9]+]] = "ttnn.to_layout"(%[[ARG_MAX]],
// CHECK-SAME: dtype = #tt.supportedDataTypes<u32>
// CHECK-SAME: layout = #ttnn.layout<tile>
// CHECK-SAME: (tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<1x1x64x1xui32
// CHECK: = "ttnn.reshape"(%[[TO_LAYOUT]])
// CHECK-SAME: {shape = [64 : i32, 1 : i32]}
// CHECK-SAME: tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<64x1xui32
return %2 : tensor<64x1xui32, #ttnn_layout1>
}
}
48 changes: 20 additions & 28 deletions test/ttmlir/Dialect/TTNN/reduction/simple_argmax.mlir
@@ -5,43 +5,35 @@ module attributes {} {
// CHECK-LABEL: func.func public @argmax_2d(
%0 = tensor.empty() : tensor<64x1xi32>
// CHECK: "ttnn.argmax"
// CHECK-SAME: {dim = 1 : i32, use_multicore = false}>
// CHECK-SAME: tensor<64x64xf32
// CHECK-SAME: tensor<64x1xui32
// CHECK-SAME: -> tensor<64x1xui32
// CHECK-SAME: {dim = 3 : i32, use_multicore = false}>
// CHECK-SAME: tensor<1x1x64x64xbf16
// CHECK-SAME: tensor<1x1x64x1xui32
// CHECK-SAME: -> tensor<1x1x64x1xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = true}> : (tensor<64x64xf32>, tensor<64x1xi32>) -> tensor<64x1xi32>
return %1 : tensor<64x1xi32>
}

func.func public @argmax_3d(%arg0: tensor<128x28x28xf32>) -> tensor<128x28xi32> {
func.func public @argmax_3d(%arg0: tensor<1x28x28xf32>) -> tensor<1x28xi32> {
// CHECK-LABEL: func.func public @argmax_3d(
%0 = tensor.empty() : tensor<128x28xi32>
// CHECK: %[[ARGMAX:[0-9]+]] = "ttnn.argmax"
// CHECK-SAME: {dim = 2 : i32, use_multicore = false}>
// CHECK-SAME: tensor<128x28x28xf32
// CHECK-SAME: tensor<128x28x1xui32
// CHECK-SAME: -> tensor<128x28x1xui32
// CHECK: "ttnn.reshape"(%[[ARGMAX]])
// CHECK-SAME: <{shape = [128 : i32, 28 : i32]}>
// CHECK-SAME: tensor<128x28x1xui32
// CHECK-SAME: -> tensor<128x28xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<128x28x28xf32>, tensor<128x28xi32>) -> tensor<128x28xi32>
return %1 : tensor<128x28xi32>
%0 = tensor.empty() : tensor<1x28xi32>
// CHECK: "ttnn.argmax"
// CHECK-SAME: {dim = 3 : i32, use_multicore = false}>
// CHECK-SAME: tensor<1x1x28x28xbf16
// CHECK-SAME: tensor<1x1x28x1xui32
// CHECK-SAME: -> tensor<1x1x28x1xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [2 : i32], keep_dim = false}> : (tensor<1x28x28xf32>, tensor<1x28xi32>) -> tensor<1x28xi32>
return %1 : tensor<1x28xi32>
}

func.func public @argmax_4d(%arg0: tensor<4x8x128x64xf32>) -> tensor<4x8x128xi32> {
func.func public @argmax_4d(%arg0: tensor<1x1x128x64xf32>) -> tensor<1x1x128xi32> {
// CHECK-LABEL: func.func public @argmax_4d(
%0 = tensor.empty() : tensor<4x8x128xi32>
%0 = tensor.empty() : tensor<1x1x128xi32>
// CHECK: %[[ARGMAX:[0-9]+]] = "ttnn.argmax"
// CHECK-SAME: {dim = 3 : i32, use_multicore = false}>
// CHECK-SAME: tensor<4x8x128x64xf32
// CHECK-SAME: tensor<4x8x128x1xui32
// CHECK-SAME: -> tensor<4x8x128x1xui32
// CHECK: "ttnn.reshape"(%[[ARGMAX]])
// CHECK-SAME: <{shape = [4 : i32, 8 : i32, 128 : i32]}>
// CHECK-SAME: tensor<4x8x128x1xui32
// CHECK-SAME: -> tensor<4x8x128xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [3 : i32], keep_dim = false}> : (tensor<4x8x128x64xf32>, tensor<4x8x128xi32>) -> tensor<4x8x128xi32>
return %1 : tensor<4x8x128xi32>
// CHECK-SAME: tensor<1x1x128x64xbf16
// CHECK-SAME: tensor<1x1x128x1xui32
// CHECK-SAME: -> tensor<1x1x128x1xui32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [3 : i32], keep_dim = false}> : (tensor<1x1x128x64xf32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi32>
return %1 : tensor<1x1x128xi32>
}
}
15 changes: 0 additions & 15 deletions test/ttmlir/EmitC/TTNN/reduction/argmax.mlir
@@ -2,24 +2,9 @@
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %basename_t.ttnn
// RUN: ttmlir-opt --ttnn-modify-signatures-for-dylib --convert-ttnn-to-emitc %t.mlir > %t2.mlir
// RUN: ttmlir-translate --mlir-to-cpp %t2.mlir > %basename_t.cpp
//
// UNSUPPORTED: true
// These tests are currently failing due to tt-metal restrictions for argmax op.
// tt-metal specs:
// https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api/ttnn.argmax.html

// TODO(mmanzoor): Enable these tests after adding workarounds to overcome these
// limitations.
// https://github.com/tenstorrent/tt-mlir/issues/2057

func.func public @argmax_2d(%arg0: tensor<64x64xf32>) -> tensor<64xi32> {
// CHECK-LABEL: func.func public @argmax_2d(
%0 = tensor.empty() : tensor<64xi32>
// CHECK: "ttnn.argmax"
// CHECK-SAME: {dim = 1 : i32, use_multicore = false}>
// CHECK-SAME: tensor<64x64xf32
// CHECK-SAME: tensor<64xi32
// CHECK-SAME: -> tensor<64xi32
%1 = "ttir.argmax"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<64x64xf32>, tensor<64xi32>) -> tensor<64xi32>
return %1 : tensor<64xi32>
}