Skip to content

Commit

Permalink
Add support for stablehlo.reduce op for logical or operator (#2160)
Browse files Browse the repository at this point in the history
TTNN does not support reduction for the logical or operator, so
stablehlo.reduce with a stablehlo.or body is decomposed into a reduction
sum op along the given dimension(s). If the ttnn.sum output is zero then the
reduce_or output is false; otherwise the output is true.

### Ticket
#1143

### Problem description
Add support for reduction operation for logical or operator

### What's changed
- `ttir.reduce_or` op is added in TTIR dialect
- `ttir.reduce_or` op is decomposed/converted to `ttir.sum` op as
tt-metal does not support reduction or operation.
- Stablehlo conversion for reduce or op.

### Checklist
- [X] New tests provide coverage for changes
  • Loading branch information
mmanzoorTT authored Feb 24, 2025
1 parent 3af5027 commit b020a93
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 3 deletions.
22 changes: 22 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,28 @@ def TTIR_ReduceAndOp : TTIR_ReductionOp<"reduce_and"> {
}];
}

// Logical-or reduction over the given dimension(s). tt-metal has no native
// reduce-or kernel, so TTIRToTTIRDecomposition rewrites this op into
// ttir.sum (see ReductionOrPattern): a zero sum means every element was
// false; any non-zero sum means at least one element was true.
def TTIR_ReduceOrOp : TTIR_ReductionOp<"reduce_or"> {
let summary = "Or reduction op.";
let description = [{
Reduces a given tensor using logical or operator along the given dimension(s).

Example:
input: [[True, False, False, False],
[True, True, False, True],
[False, False, False, True],
[False, False, False, False]]

// Reduction along dim 0
output: [True, True, False, True]

// Reduction along dim 1
output: [True, True, True, False]

// Reduction for both dimensions (entire tensor)
output: [True]
}];
}

def TTIR_ProdOp : TTIR_ReductionOp<"prod"> {
let summary = "Product reduction op.";
let description = [{
Expand Down
17 changes: 14 additions & 3 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ class StableHLOToTTIRReduceOpConversionPattern
return matchAndRewriteInternal<mlir::tt::ttir::ReduceAndOp>(
srcOp, adaptor, rewriter);
}
if (mlir::isa<mlir::stablehlo::OrOp>(innerOp)) {
return matchAndRewriteInternal<mlir::tt::ttir::ReduceOrOp>(srcOp, adaptor,
rewriter);
}
if (isArgMax(srcOp, adaptor, rewriter)) {
return matchAndRewriteInternalArgMax(srcOp, adaptor, rewriter);
}
Expand Down Expand Up @@ -129,16 +133,23 @@ class StableHLOToTTIRReduceOpConversionPattern
}

mlir::Operation &innerOp = srcOp.getBody().front().front();
if (mlir::isa<mlir::stablehlo::AndOp>(innerOp)) {
if (mlir::isa<mlir::stablehlo::AndOp>(innerOp) ||
mlir::isa<mlir::stablehlo::OrOp>(innerOp)) {
bool allOperandsAreBoolean = std::all_of(
srcOp->operand_begin(), srcOp->operand_end(), [](auto operand) {
return mlir::cast<RankedTensorType>(operand.getType())
.getElementTypeBitWidth() == 1;
});
// Stablehlo (unlike other dialects) has a single op for both the logical
// and the bitwise operation; the data type is used to distinguish the two.
// If the data type is boolean then it is a logical operation; otherwise it
// is a bitwise operation. This check ensures that the inputs are boolean,
// as tt-metal only supports the logical operations.
if (!allOperandsAreBoolean) {
return rewriter.notifyMatchFailure(
srcOp, "stablehlo.reduce for stablehlo.and operator is only "
"supported for logical and.");
srcOp,
"stablehlo.reduce for stablehlo.and/stablehlo.or operator is only "
"supported for logical operator.");
}
}

Expand Down
25 changes: 25 additions & 0 deletions lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,30 @@ struct ArgMaxOpKeepDimConversionPattern
};
} // namespace

// TTNN has no reduction op for logical or, so ttir.reduce_or is rewritten
// as a sum reduction (lowered later to ttnn.sum). For boolean-like inputs
// the sum is zero exactly when every element is false, so a zero sum
// corresponds to a false reduce_or result and a non-zero sum to true.
namespace {
struct ReductionOrPattern : public OpConversionPattern<ttir::ReduceOrOp> {
public:
  using OpConversionPattern<ttir::ReduceOrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ttir::ReduceOrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Result type of the replacement sum, after running the type converter.
    auto convertedResultType = mlir::cast<RankedTensorType>(
        getTypeConverter()->convertType(op.getResult().getType()));

    // Replace with a ttir.sum carrying the same keep_dim / dim_arg
    // attributes, so the reduction shape semantics are preserved.
    ttmlir::utils::replaceOpWithNewDPSOp<ttir::SumOp>(
        rewriter, op, convertedResultType, adaptor.getInput(),
        op.getKeepDim(), op.getDimArgAttr());

    return success();
  }
};
} // namespace

void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
RewritePatternSet &patterns,
TypeConverter &typeConverter) {
Expand All @@ -1356,6 +1380,7 @@ void populateTTIRToTTIRDecompositionPatterns(MLIRContext *ctx,
patterns.add<ArangeForceLastDimensionPattern>(typeConverter, ctx);
patterns.add<DotGeneralToMatmulConversionPattern>(typeConverter, ctx);
patterns.add<ReductionAndPattern>(typeConverter, ctx);
patterns.add<ReductionOrPattern>(typeConverter, ctx);
patterns.add<ArgMaxOpKeepDimConversionPattern>(typeConverter, ctx);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct TTIRToTTIRDecompositionPass
target.addIllegalOp<ttir::SelectOp>();
target.addIllegalOp<ttir::DotGeneralOp>();
target.addIllegalOp<ttir::ReduceAndOp>();
target.addIllegalOp<ttir::ReduceOrOp>();

// These are the ops that must satisfy some conditions after this pass
target.addDynamicallyLegalOp<ttir::ArangeOp>([&](ttir::ArangeOp op) {
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2563,6 +2563,22 @@ ::mlir::LogicalResult mlir::tt::ttir::ReduceAndOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// ReduceOrOp
//===----------------------------------------------------------------------===//

// ReduceOrOp kernel builder: populates the generic region with the shared
// reduce-op body, selecting the "or" reduction kind.
void mlir::tt::ttir::ReduceOrOp::buildGenericRegion(
::mlir::OpBuilder &opBuilder, ::mlir::Block *block) {
// NOLINTNEXTLINE
createReduceOp(opBuilder, block, getLoc(), "or");
}

// ReduceOrOp verification: delegates to the shared reduction verifier,
// which validates the input type together with the dim_arg attribute.
::mlir::LogicalResult mlir::tt::ttir::ReduceOrOp::verify() {
return verifyReduceOp(getOperation(), getInput().getType(), getDimArg());
}

//===----------------------------------------------------------------------===//
// Reduce ArgMaxOp
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 39 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/reduction/reduce_or_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
// Checks StableHLO -> TTIR conversion: stablehlo.reduce with a stablehlo.or
// body becomes ttir.reduce_or. Inputs are i1 because only the logical form
// of stablehlo.or (boolean element type) is supported by the conversion.
module @jit_reduce_or attributes {} {
// Rank 4 -> rank 3: reduce the trailing dimension.
func.func public @test_reduce_or_4to3dim(%arg0: tensor<128x10x32x4xi1>, %cst_0: tensor<i1>) -> tensor<128x10x32xi1> {
// CHECK-LABEL: func.func public @test_reduce_or_4to3dim
// CHECK: tensor.empty
// CHECK: "ttir.reduce_or"
// CHECK-SAME: dim_arg = [3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xbf16>
// CHECK-SAME: -> tensor<128x10x32xbf16>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.or across dimensions = [3] : (tensor<128x10x32x4xi1>, tensor<i1>) -> tensor<128x10x32xi1>
return %0 : tensor<128x10x32xi1>
}

// Rank 3 -> rank 2: reduce a middle dimension.
func.func public @test_reduce_or_3to2dim(%arg0: tensor<128x10x4xi1>, %cst_0: tensor<i1>) -> tensor<128x4xi1> {
// CHECK-LABEL: func.func public @test_reduce_or_3to2dim
// CHECK: tensor.empty
// CHECK: "ttir.reduce_or"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xbf16>
// CHECK-SAME: -> tensor<128x4xbf16>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.or across dimensions = [1] : (tensor<128x10x4xi1>, tensor<i1>) -> tensor<128x4xi1>
return %0 : tensor<128x4xi1>
}

// Rank 2 -> rank 1: reduce the leading dimension.
func.func public @test_reduce_or_2to1dim(%arg0: tensor<128x10xi1>, %cst_0: tensor<i1>) -> tensor<10xi1> {
// CHECK-LABEL: func.func public @test_reduce_or_2to1dim
// CHECK: tensor.empty
// CHECK: "ttir.reduce_or"
// CHECK-SAME: dim_arg = [0 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xbf16>
// CHECK-SAME: -> tensor<10xbf16>
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.or across dimensions = [0] : (tensor<128x10xi1>, tensor<i1>) -> tensor<10xi1>
return %0 : tensor<10xi1>
}
}
41 changes: 41 additions & 0 deletions test/ttmlir/Decomposition/TTIR/reduction/reduce_or.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s
// Checks the TTIR decomposition pass: ttir.reduce_or is rewritten into a
// single ttir.sum with identical dim_arg / keep_dim attributes
// (ReductionOrPattern), since tt-metal has no native reduce-or.
module attributes {} {
// Rank 4 -> rank 3 decomposition.
func.func public @test_reduce_or_4to3dim(%arg0: tensor<128x10x32x4xbf16>, %arg1: tensor<1xbf16>) -> tensor<128x10x32xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_4to3dim
// CHECK: %[[SUM:[0-9]+]] = "ttir.sum"
// CHECK-SAME: dim_arg = [3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xbf16>
// CHECK-SAME: -> tensor<128x10x32xbf16>
// CHECK: return %[[SUM]]
%0 = tensor.empty() : tensor<128x10x32xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [3 : i32], keep_dim = false}> : (tensor<128x10x32x4xbf16>, tensor<128x10x32xbf16>) -> tensor<128x10x32xbf16>
return %1 : tensor<128x10x32xbf16>
}

// Rank 3 -> rank 2 decomposition.
func.func public @test_reduce_or_3to2dim(%arg0: tensor<128x10x4xbf16>, %arg1: tensor<1xbf16>) -> tensor<128x4xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_3to2dim
// CHECK: %[[SUM:[0-9]+]] = "ttir.sum"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xbf16>
// CHECK-SAME: -> tensor<128x4xbf16>
// CHECK: return %[[SUM]]
%0 = tensor.empty() : tensor<128x4xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x4xbf16>, tensor<128x4xbf16>) -> tensor<128x4xbf16>
return %1 : tensor<128x4xbf16>
}

// Rank 2 -> rank 1 decomposition.
func.func public @test_reduce_or_2to1dim(%arg0: tensor<128x10xbf16>, %arg1: tensor<1xbf16>) -> tensor<10xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_2to1dim
// CHECK: %[[SUM:[0-9]+]] = "ttir.sum"
// CHECK-SAME: dim_arg = [0 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xbf16>
// CHECK-SAME: -> tensor<10xbf16>
// CHECK: return %[[SUM]]
%0 = tensor.empty() : tensor<10xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128x10xbf16>, tensor<10xbf16>) -> tensor<10xbf16>
return %1 : tensor<10xbf16>
}
}
39 changes: 39 additions & 0 deletions test/ttmlir/Dialect/TTNN/reduction/simple_reduce_or.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s

// Checks the full TTIR -> TTNN backend pipeline: ttir.reduce_or is lowered
// end-to-end to ttnn.sum with the same dim_arg / keep_dim attributes.
module attributes {} {
// Rank 4 -> rank 3 lowering.
func.func public @test_reduce_or_4to3dim(%arg0: tensor<128x10x32x4xbf16>, %arg1: tensor<1xbf16>) -> tensor<128x10x32xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_4to3dim
// CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
// CHECK-SAME: dim_arg = [3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xbf16,
// CHECK-SAME: -> tensor<128x10x32xbf16,
%0 = tensor.empty() : tensor<128x10x32xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [3 : i32], keep_dim = false}> : (tensor<128x10x32x4xbf16>, tensor<128x10x32xbf16>) -> tensor<128x10x32xbf16>
return %1 : tensor<128x10x32xbf16>
}

// Rank 3 -> rank 2 lowering.
func.func public @test_reduce_or_3to2dim(%arg0: tensor<128x10x4xbf16>, %arg1: tensor<1xbf16>) -> tensor<128x4xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_3to2dim
// CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xbf16,
// CHECK-SAME: -> tensor<128x4xbf16,
%0 = tensor.empty() : tensor<128x4xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x4xbf16>, tensor<128x4xbf16>) -> tensor<128x4xbf16>
return %1 : tensor<128x4xbf16>
}

// Rank 2 -> rank 1 lowering.
func.func public @test_reduce_or_2to1dim(%arg0: tensor<128x10xbf16>, %arg1: tensor<1xbf16>) -> tensor<10xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_2to1dim
// CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
// CHECK-SAME: dim_arg = [0 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xbf16,
// CHECK-SAME: -> tensor<10xbf16,
%0 = tensor.empty() : tensor<10xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128x10xbf16>, tensor<10xbf16>) -> tensor<10xbf16>
return %1 : tensor<10xbf16>
}
}
19 changes: 19 additions & 0 deletions test/ttmlir/Silicon/StableHLO/n150/reduction/reduce_or_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline \
// RUN: --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

// Silicon test: StableHLO reduce-or lowered through the TTNN backend
// pipeline; the stablehlo.or reduction ends up as a single ttnn.sum.
// NOTE: the module was named @jit_reduce_add (copy-paste from the reduce-add
// test); renamed to @jit_reduce_or to match the op under test.
module @jit_reduce_or attributes {} {
func.func public @test_reduce_or_4to3dim(%arg0: tensor<128x10x32x4xi1>, %cst_0: tensor<i1>) -> tensor<128x10x32xi1> {
// CHECK-LABEL: func.func public @test_reduce_or_4to3dim
// CHECK: "ttnn.sum"
// CHECK-SAME: dim_arg = [3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: -> tensor<128x10x32xbf16,
%0 = stablehlo.reduce(%arg0 init: %cst_0) applies stablehlo.or across dimensions = [3] : (tensor<128x10x32x4xi1>, tensor<i1>) -> tensor<128x10x32xi1>
return %0 : tensor<128x10x32xi1>
}
}
43 changes: 43 additions & 0 deletions test/ttmlir/Silicon/TTNN/n150/simple_reduce_or.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn

// Silicon test for ttir.reduce_or through the TTNN backend pipeline,
// including flatbuffer serialization (see RUN lines above). Each function
// checks that the op lowers to a single ttnn.sum with matching attributes.
module attributes {} {
// Multi-dimension reduction: dims 1 and 3 of a rank-4 tensor.
func.func public @test_reduce_or_4to2dim(%arg0: tensor<128x10x32x4xbf16>, %arg1: tensor<1xbf16>) -> tensor<128x32xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_4to2dim
// CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
// CHECK-SAME: dim_arg = [1 : i32, 3 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x32x4xbf16,
// CHECK-SAME: -> tensor<128x32xbf16,
%0 = tensor.empty() : tensor<128x32xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [1: i32, 3 : i32], keep_dim = false}> : (tensor<128x10x32x4xbf16>, tensor<128x32xbf16>) -> tensor<128x32xbf16>
return %1 : tensor<128x32xbf16>
}

// Single middle-dimension reduction on a rank-3 tensor.
func.func public @test_reduce_or_3to2dim(%arg0: tensor<128x10x4xbf16>, %arg1: tensor<1xbf16>) -> tensor<128x4xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_3to2dim
// CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
// CHECK-SAME: dim_arg = [1 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10x4xbf16,
// CHECK-SAME: -> tensor<128x4xbf16,
%0 = tensor.empty() : tensor<128x4xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [1 : i32], keep_dim = false}> : (tensor<128x10x4xbf16>, tensor<128x4xbf16>) -> tensor<128x4xbf16>
return %1 : tensor<128x4xbf16>
}

// Leading-dimension reduction on a rank-2 tensor.
func.func public @test_reduce_or_2to1dim(%arg0: tensor<128x10xbf16>, %arg1: tensor<1xbf16>) -> tensor<10xbf16> {
// CHECK-LABEL: func.func public @test_reduce_or_2to1dim
// CHECK: %[[SUM:[0-9]+]] = "ttnn.sum"
// CHECK-SAME: dim_arg = [0 : i32]
// CHECK-SAME: keep_dim = false
// CHECK-SAME: tensor<128x10xbf16,
// CHECK-SAME: -> tensor<10xbf16,
%0 = tensor.empty() : tensor<10xbf16>
%1 = "ttir.reduce_or"(%arg0, %0) <{dim_arg = [0 : i32], keep_dim = false}> : (tensor<128x10xbf16>, tensor<10xbf16>) -> tensor<10xbf16>
return %1 : tensor<10xbf16>
}
}

0 comments on commit b020a93

Please sign in to comment.