Determine indirectly defined constant for clamp op
A constant can be converted, reshaped, or broadcast before it is used as a bound in a stablehlo.clamp
op. Trace back through these ops to determine the base constant value for min/max and use it directly in the ttir.clamp op.
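
For illustration, a minimal sketch of the kind of indirection this now handles, adapted from the new test cases in this change (value names, shapes, and exact attribute syntax are illustrative):

    %cst = arith.constant dense<3.0> : tensor<1xf64>
    %cst_0 = arith.constant dense<6> : tensor<1xi64>
    %0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
    %1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
    %2 = stablehlo.convert %cst_0 : (tensor<1xi64>) -> tensor<1xbf16>
    %3 = stablehlo.reshape %2 : (tensor<1xbf16>) -> tensor<bf16>
    %4 = stablehlo.clamp %1, %arg0, %3 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>

now lowers to a single ttir.clamp with the traced splat values folded into attributes, roughly:

    %5 = tensor.empty() : tensor<1x16xbf16>
    %6 = "ttir.clamp"(%arg0, %5) <{max = 6.000000e+00 : f32, min = 3.000000e+00 : f32}> : (tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>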
mmanzoorTT committed Feb 27, 2025
1 parent 63bc18c commit ac19198
Showing 3 changed files with 181 additions and 16 deletions.
77 changes: 66 additions & 11 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -4,6 +4,7 @@

#include <cmath>
#include <limits>
#include <mlir/IR/Operation.h>
#include <vector>

#include "mlir/Dialect/Traits.h"
@@ -1905,8 +1906,8 @@ class StableHLOToTTIROpClampOpConversionPattern
RankedTensorType outputType = mlir::cast<RankedTensorType>(
this->getTypeConverter()->convertType(srcOp.getResult().getType()));

if (std::optional<float> minValue = getConstantValue(adaptor.getMin()),
maxValue = getConstantValue(adaptor.getMax());
if (std::optional<float> minValue = getConstantValue(srcOp.getMin()),
maxValue = getConstantValue(srcOp.getMax());
minValue && maxValue) {
ttmlir::utils::replaceOpWithNewDPSOp<ttir::ClampOp>(
rewriter, srcOp, outputType, adaptor.getOperand(),
@@ -1915,27 +1916,81 @@
return success();
}

mlir::Value min =
broadcastAttr(adaptor.getMin(), outputType, srcOp, rewriter);
mlir::Value max =
broadcastAttr(adaptor.getMax(), outputType, srcOp, rewriter);

ttir::MaximumOp maximumOp = ttmlir::utils::createDPSOp<ttir::MaximumOp>(
rewriter, srcOp->getLoc(), outputType, adaptor.getMin(),
adaptor.getOperand());
rewriter, srcOp->getLoc(), outputType, min, adaptor.getOperand());
ttmlir::utils::replaceOpWithNewDPSOp<ttir::MinimumOp>(
rewriter, srcOp, outputType, maximumOp.getResult(0), adaptor.getMax());
rewriter, srcOp, outputType, maximumOp.getResult(0), max);

return success();
}

private:
std::optional<float> getConstantValue(Value value) const {
if (auto constantOp = value.getDefiningOp<ttir::ConstantOp>()) {
Operation *op = value.getDefiningOp();
while (op &&
(isa<stablehlo::BroadcastInDimOp>(op) ||
isa<stablehlo::ReshapeOp>(op) || isa<stablehlo::ConvertOp>(op))) {
op = op->getOperand(0).getDefiningOp();
}
if (!op) {
return std::nullopt;
}

if (auto constantOp = mlir::dyn_cast<stablehlo::ConstantOp>(op)) {
auto attr = constantOp.getValueAttr();
if (!attr.isSplat()) {
return {};
return std::nullopt;
}
return attr.getElementType().isInteger()
? static_cast<float>(attr.getSplatValue<int>())
: attr.getSplatValue<float>();
mlir::Type elementType = attr.getElementType();
mlir::APFloat fillValue(mlir::APFloat::IEEEsingle());
if (isa<IntegerType>(elementType)) {
fillValue.convertFromAPInt(attr.getSplatValue<llvm::APInt>(),
attr.getElementType().isSignedInteger(),
llvm::RoundingMode::TowardZero);
return fillValue.convertToFloat();
}
if (isa<FloatType>(elementType)) {
return static_cast<float>(
attr.getSplatValue<mlir::APFloat>().convertToDouble());
}
assert(false && "Unsupported data type.");
}
return {};
return std::nullopt;
}

mlir::Value broadcastAttr(mlir::Value input, RankedTensorType desiredType,
mlir::stablehlo::ClampOp srcOp,
ConversionPatternRewriter &rewriter) const {
auto inputType = mlir::cast<RankedTensorType>(input.getType());
if (inputType.getShape() == desiredType.getShape()) {
return input;
}

SmallVector<int64_t> unsqueezeShape(desiredType.getRank(), 1);
for (int64_t i = 0; i < inputType.getRank(); i++) {
unsqueezeShape[i] = inputType.getDimSize(i);
}
SmallVector<int32_t> reshapeDim(unsqueezeShape.begin(),
unsqueezeShape.end());

auto reshapeDimAttr = rewriter.getI32ArrayAttr(reshapeDim);
ttir::ReshapeOp reshapeOp = ttmlir::utils::createDPSOp<ttir::ReshapeOp>(
rewriter, srcOp.getLoc(), unsqueezeShape, desiredType.getElementType(),
desiredType.getEncoding(), input, reshapeDimAttr);

::llvm::ArrayRef<int64_t> inputShape = unsqueezeShape;
::llvm::ArrayRef<int64_t> outputShape = desiredType.getShape();

SmallVector<int64_t> broadcastShape =
ttmlir::utils::getBroadcastDimensions<int64_t>(inputShape, outputShape);

return ttmlir::utils::createDPSOp<ttir::BroadcastOp>(
rewriter, srcOp->getLoc(), desiredType, reshapeOp, broadcastShape);
}
};
} // namespace
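
If either bound cannot be traced back to a splat constant (for example, it is a function argument), the pattern instead reshapes the bound up to the output rank, broadcasts it to the output shape, and lowers the clamp to a ttir.maximum followed by a ttir.minimum. A hand-simplified sketch of that fallback for a tensor<bf16> bound and a tensor<1x16xbf16> operand, where %max_bound, %min_bcast, and %e0-%e3 (tensor.empty() destination operands) are placeholder names; see the test_clamp_tensor_constant CHECK patterns below for the authoritative form:

    // Unsqueeze the bound to the output rank (tensor<1xbf16> -> tensor<1x1xbf16>).
    %r = "ttir.reshape"(%max_bound, %e0) <{shape = [1 : i32, 1 : i32]}> : (tensor<1xbf16>, tensor<1x1xbf16>) -> tensor<1x1xbf16>
    // Broadcast it to the full output shape.
    %max_bcast = "ttir.broadcast"(%r, %e1) <{broadcast_dimensions = array<i64: 1, 16>}> : (tensor<1x1xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
    // The min bound (%min_bcast) is prepared the same way; then clamp(x, min, max) = minimum(maximum(min, x), max).
    %lo = "ttir.maximum"(%min_bcast, %arg0, %e2) : (tensor<1x16xbf16>, tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
    %out = "ttir.minimum"(%lo, %max_bcast, %e3) : (tensor<1x16xbf16>, tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>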
65 changes: 64 additions & 1 deletion test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir
@@ -1,7 +1,8 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
module @jit_transpose attributes {} {
module @jit_clamp attributes {} {
func.func public @test_clamp_constant(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func.func public @test_clamp_constant
%cst = stablehlo.constant dense<2.000000e+00> : tensor<4xf32>
%cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
@@ -12,7 +13,38 @@ module @jit_transpose attributes {} {
return %0 : tensor<4xf32>
}

func.func public @test_clamp_indirect_constant_reshape(%arg0: tensor<1x16xbf16>) -> tensor<1x16xbf16> {
// CHECK-LABEL: func.func public @test_clamp_indirect_constant_reshape
%cst = arith.constant dense<3.0> : tensor<1xf64>
%cst_0 = arith.constant dense<6> : tensor<1xi64>
%0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
%1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
%2 = stablehlo.convert %cst_0 : (tensor<1xi64>) -> tensor<1xbf16>
%3 = stablehlo.reshape %2 : (tensor<1xbf16>) -> tensor<bf16>
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : [[TENSOR:tensor<1x16xbf16>]]
// CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
// CHECK-SAME: max = 6.000000e+00 : f32, min = 3.000000e+00 : f32
// CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%4 = stablehlo.clamp %1, %arg0, %3 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>
return %4 : tensor<1x16xbf16>
}

func.func public @test_clamp_indirect_constant_broadcast(%arg0: tensor<1x32xbf16>) -> (tensor<1x32xbf16>) {
// CHECK-LABEL: func.func public @test_clamp_indirect_constant_broadcast
%cst = stablehlo.constant dense<2.000000e+00> : tensor<bf16>
%cst_0 = stablehlo.constant dense<5.000000e+00> : tensor<bf16>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
%1 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() : [[TENSOR:tensor<1x32xbf16>]]
// CHECK: "ttir.clamp"(%arg0, %[[EMPTY]])
// CHECK-SAME: max = 5.000000e+00 : f32, min = 2.000000e+00 : f32
// CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
%2 = stablehlo.clamp %0, %arg0, %1 : tensor<1x32xbf16>
return %2 : tensor<1x32xbf16>
}

func.func public @test_clamp_tensor(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func.func public @test_clamp_tensor
// CHECK: %[[EMPTY0:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]]
// CHECK: %[[MAX:.*]] = "ttir.maximum"(%arg1, %arg0, %[[EMPTY0]])
// CHECK-SAME: ([[TENSOR]], [[TENSOR]], [[TENSOR]]) -> [[TENSOR]]
@@ -23,4 +55,35 @@ module @jit_transpose attributes {} {
// CHECK: return %[[MIN]] : [[TENSOR]]
return %0 : tensor<4xf32>
}

func.func public @test_clamp_tensor_constant(%arg0: tensor<1x16xbf16>, %arg1: tensor<bf16>) -> tensor<1x16xbf16> {
// CHECK-LABEL: func.func public @test_clamp_tensor_constant(
// CHECK: %[[CONSTANT:[0-9]+]] = "ttir.constant"() <{value = dense<3.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%cst = arith.constant dense<3.0> : tensor<1xf64>
// CHECK: %[[CAST:[0-9]+]] = "ttir.typecast"(%[[CONSTANT]],
// CHECK-SAME: (tensor<1xf32>, tensor<1xbf16>) -> tensor<1xbf16>
%0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
// CHECK: %[[RESHAPE0:[0-9]+]] = "ttir.reshape"(%[[CAST]],
// CHECK-SAME: shape = [1 : i32]
// CHECK-SAME: (tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
%1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
// CHECK: %[[RESHAPE1:[0-9]+]] = "ttir.reshape"(%[[RESHAPE0]],
// CHECK-SAME: shape = [1 : i32, 1 : i32]
// CHECK-SAME: (tensor<1xbf16>, tensor<1x1xbf16>) -> tensor<1x1xbf16>
// CHECK: %[[MIN:[0-9]+]] = "ttir.broadcast"(%[[RESHAPE1]]
// CHECK-SAME: {broadcast_dimensions = array<i64: 1, 16>}
// CHECK-SAME: (tensor<1x1xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
// CHECK: %[[RESHAPE2:[0-9]+]] = "ttir.reshape"(%arg1,
// CHECK-SAME: <{shape = [1 : i32, 1 : i32]}>
// CHECK-SAME: (tensor<1xbf16>, tensor<1x1xbf16>) -> tensor<1x1xbf16>
// CHECK: %[[MAX:[0-9]+]] = "ttir.broadcast"(%[[RESHAPE2]],
// CHECK-SAME: <{broadcast_dimensions = array<i64: 1, 16>}>
// CHECK-SAME: (tensor<1x1xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
// CHECK: %[[ARG:[0-9]+]] = "ttir.maximum"(%[[MIN]], %arg0,
// CHECK-SAME: (tensor<1x16xbf16>, tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
// CHECK: "ttir.minimum"(%[[ARG]], %[[MAX]],
// CHECK-SAME: (tensor<1x16xbf16>, tensor<1x16xbf16>, tensor<1x16xbf16>) -> tensor<1x16xbf16>
%2 = stablehlo.clamp %1, %arg0, %arg1 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>
return %2 : tensor<1x16xbf16>
}
}
55 changes: 51 additions & 4 deletions test/ttmlir/Silicon/StableHLO/n150/Unary/clamp_op.mlir
@@ -1,14 +1,14 @@
// REQUIRES: stablehlo
// RUN: rm -rf %t.ttnn
// RUN: rm -rf %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | \
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s \
// RUN: --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" > %t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
// RUN: FileCheck --input-file=%t.mlir %s

module @jit_transpose attributes {} {
func.func public @test_clamp_constant(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK-LABEL: func.func public @test_clamp_constant
// CHECK-LABEL: func.func public @test_clamp_constant(
// CHECK: ttnn.clamp
// CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}
// CHECK-SAME: tensor<64x128xf32,
@@ -19,8 +19,38 @@ module @jit_transpose attributes {} {
return %0 : tensor<64x128xf32>
}

func.func public @test_clamp_indirect_constant_reshape(%arg0: tensor<1x16xbf16>) -> tensor<1x16xbf16> {
// CHECK-LABEL: func.func public @test_clamp_indirect_constant_reshape
%cst = arith.constant dense<3.0> : tensor<1xf64>
%cst_0 = arith.constant dense<6> : tensor<1xi64>
%0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
%1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
%2 = stablehlo.convert %cst_0 : (tensor<1xi64>) -> tensor<1xbf16>
%3 = stablehlo.reshape %2 : (tensor<1xbf16>) -> tensor<bf16>
// CHECK: ttnn.clamp
// CHECK-SAME: {max = 6.000000e+00 : f32, min = 3.000000e+00 : f32}
// CHECK-SAME: tensor<1x16xbf16,
// CHECK-SAME: -> tensor<1x16xbf16,
%4 = stablehlo.clamp %1, %arg0, %3 : (tensor<bf16>, tensor<1x16xbf16>, tensor<bf16>) -> tensor<1x16xbf16>
return %4 : tensor<1x16xbf16>
}

func.func public @test_clamp_indirect_constant_broadcast(%arg0: tensor<1x32xbf16>) -> (tensor<1x32xbf16>) {
// CHECK-LABEL: func.func public @test_clamp_indirect_constant_broadcast
%cst = stablehlo.constant dense<2.000000e+00> : tensor<bf16>
%cst_0 = stablehlo.constant dense<5.000000e+00> : tensor<bf16>
%0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
%1 = stablehlo.broadcast_in_dim %cst_0, dims = [] : (tensor<bf16>) -> tensor<1x32xbf16>
// CHECK: ttnn.clamp
// CHECK-SAME: {max = 5.000000e+00 : f32, min = 2.000000e+00 : f32}
// CHECK-SAME: tensor<1x32xbf16,
// CHECK-SAME: -> tensor<1x32xbf16,
%2 = stablehlo.clamp %0, %arg0, %1 : tensor<1x32xbf16>
return %2 : tensor<1x32xbf16>
}

func.func public @test_clamp_tensor(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>, %arg2: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK-LABEL: func.func public @test_clamp_tensor
// CHECK-LABEL: func.func public @test_clamp_tensor(
// CHECK: %[[MAX:.*]] = "ttnn.maximum"
// CHECK-SAME: tensor<64x128xf32,
// CHECK-SAME: tensor<64x128xf32,
@@ -32,4 +62,21 @@ module @jit_transpose attributes {} {
%0 = stablehlo.clamp %arg1, %arg0, %arg2 : tensor<64x128xf32>
return %0 : tensor<64x128xf32>
}

func.func public @test_clamp_tensor_constant(%arg0: tensor<32x32xbf16>, %arg1: tensor<bf16>) -> tensor<32x32xbf16> {
// CHECK-LABEL: func.func public @test_clamp_tensor_constant(
%cst = arith.constant dense<3.0> : tensor<1xf64>
%0 = stablehlo.convert %cst : (tensor<1xf64>) -> tensor<1xbf16>
%1 = stablehlo.reshape %0 : (tensor<1xbf16>) -> tensor<bf16>
// CHECK: %[[MAX:.*]] = "ttnn.maximum"
// CHECK-SAME: tensor<32x32xbf16,
// CHECK-SAME: tensor<32x32xbf16,
// CHECK-SAME: -> tensor<32x32xbf16,
// CHECK: "ttnn.minimum"(%[[MAX]]
// CHECK-SAME: tensor<32x32xbf16,
// CHECK-SAME: tensor<32x32xbf16,
// CHECK-SAME: -> tensor<32x32xbf16,
%2 = stablehlo.clamp %1, %arg0, %arg1 : (tensor<bf16>, tensor<32x32xbf16>, tensor<bf16>) -> tensor<32x32xbf16>
return %2 : tensor<32x32xbf16>
}
}
