From 3004928b65ef9d27b74913113b348b6cf068dcc8 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Tue, 17 Dec 2024 18:44:57 +0200 Subject: [PATCH] Add stablehlo-complex-math-expander pass. --- BUILD.bazel | 17 ++++ stablehlo/transforms/CMakeLists.txt | 6 ++ stablehlo/transforms/Passes.h | 6 ++ stablehlo/transforms/Passes.td | 34 ++++++++ .../StablehloComplexMathExpander.cpp | 81 +++++++++++++++++++ .../StablehloComplexMathExpanderPatterns.td | 18 +++++ 6 files changed, 162 insertions(+) create mode 100644 stablehlo/transforms/StablehloComplexMathExpander.cpp create mode 100644 stablehlo/transforms/StablehloComplexMathExpanderPatterns.td diff --git a/BUILD.bazel b/BUILD.bazel index b755374dc7..e6aed3b385 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -370,6 +370,21 @@ gentbl_cc_library( ], ) +gentbl_cc_library( + name = "stablehlo_create_complex_math_expander_inc_gen", + tbl_outs = [ + ( + ["--gen-rewriters"], + "stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td", + deps = [ + ":stablehlo_ops_td_files", + ], +) + cc_library( name = "interpreter_ops", srcs = [ @@ -1120,6 +1135,7 @@ cc_library( "stablehlo/transforms/StablehloAggressiveSimplification.cpp", "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp", "stablehlo/transforms/StablehloCompatibilityExpander.cpp", + "stablehlo/transforms/StablehloComplexMathExpander.cpp", "stablehlo/transforms/StablehloConvertToSignless.cpp", "stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp", "stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp", @@ -1148,6 +1164,7 @@ cc_library( ":linalg_passes", ":stablehlo_aggressive_simplification_inc_gen", ":stablehlo_create_compatibility_expander_inc_gen", + ":stablehlo_create_complex_math_expander_inc_gen", ":stablehlo_legalize_deprecated_ops_inc_gen", ":stablehlo_ops", ":stablehlo_ops_inc_gen", diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index c69f7ff58c..f5193cb61e 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -28,6 +28,10 @@ set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td) mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen) +set(LLVM_TARGET_DEFINITIONS StablehloComplexMathExpanderPatterns.td) +mlir_tablegen(StablehloComplexMathExpanderPatterns.h.inc --gen-rewriters) +add_public_tablegen_target(StablehloComplexMathExpanderPatternsIncGen) + set(LLVM_TARGET_DEFINITIONS StablehloLegalizeDeprecatedOpsPatterns.td) mlir_tablegen(StablehloLegalizeDeprecatedOpsPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloLegalizeDeprecatedOpsPatternsIncGen) @@ -47,6 +51,7 @@ add_mlir_dialect_library(StablehloPasses StablehloCanonicalizeDynamism.cpp StablehloConvertToSignless.cpp StablehloCompatibilityExpander.cpp + StablehloComplexMathExpander.cpp StablehloLegalizeCompositeToCall.cpp StablehloLegalizeDeprecatedOps.cpp StablehloLegalizeQuantToMath.cpp @@ -64,6 +69,7 @@ add_mlir_dialect_library(StablehloPasses PassesIncGen StablehloAggressiveSimplificationPatternsIncGen StablehloCompatibilityExpanderPatternsIncGen + StablehloComplexMathExpanderPatternsIncGen StablehloLegalizeDeprecatedOpsPatternsIncGen VhloToVersionPatterns diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index 43ee22f355..7b12646bba 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -144,6 +144,12 @@ void createStablehloRemoveDynamismPipeline(OpPassManager &pm, // operations into a primitive math operations. void createStablehloLowerQuantPipeline(OpPassManager &pm); +/// Collection of patterns to create expander for StableHLO complex +/// math operations. +void populateStablehloComplexMathExpanderPatterns( + RewritePatternSet *patterns, MLIRContext *context/*, + vhlo::Version targetVersion*/); + // Adds `stablehlo-deserialize` pipeline as a registered pass pipeline // for opt tools. void registerPassPipelines(); diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index 186560d012..aa9d696664 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -123,6 +123,40 @@ def StablehloCompatibilityExpanderPass : Pass<"stablehlo-compatibility-expander" ]; } +def StablehloComplexMathExpanderPass : Pass<"stablehlo-complex-math-expander", "mlir::func::FuncOp"> { + let summary = "Expander for StableHLO complex math operations."; + + let description = [{ + StableHLO complex math operations are decompositions using + StableHLO real math operations. + + This statement is based on the assumption that no hardware exists + that supports complex numbers nor complex math operations + natively. This means that the fallback mechanisms on complex math + operations that compilers may implement, are redundant. With + enabling this pass, all StableHLO complex math operations will be + expanded. + + ```mlir + func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %1 = stablehlo.sqrt %arg0 : tensor<4xcomplex> + func.return %1 : tensor<4xcomplex> + } + + ==> + + func.func @sqrt_op_complex(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + TBD + return %2 : tensor<4xcomplex> + } + ``` + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::chlo::ChloDialect", + ]; +} + def StablehloConvertToSignlessPass : Pass<"stablehlo-convert-to-signless", "ModuleOp"> { let summary = "Pass to transform the IR to be on signless integers."; } diff --git a/stablehlo/transforms/StablehloComplexMathExpander.cpp b/stablehlo/transforms/StablehloComplexMathExpander.cpp new file mode 100644 index 0000000000..fb4751389b --- /dev/null +++ b/stablehlo/transforms/StablehloComplexMathExpander.cpp @@ -0,0 +1,81 @@ +/* Copyright 2024 The StableHLO Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/Version.h" +#include "stablehlo/transforms/PassUtils.h" +#include "stablehlo/transforms/Passes.h" + +namespace mlir { +namespace stablehlo { +#define GEN_PASS_DEF_STABLEHLOCOMPLEXMATHEXPANDERPASS +#include "stablehlo/transforms/Passes.h.inc" + +namespace { + +static Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val, + bool negative) { + auto ty = cast(getElementTypeOrSelf(val.getType())); + return getConstantLike( + b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val); +} + +static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, + Value val) { + auto ty = cast(getElementTypeOrSelf(val.getType())); + return getConstantLike( + b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val); +} + +static Value getConstantLikeSmallestNormalizedValue(OpBuilder &b, Location loc, + Value val) { + auto ty = cast(getElementTypeOrSelf(val.getType())); + return getConstantLike( + b, loc, llvm::APFloat::getSmallestNormalized(ty.getFloatSemantics()), + val); +} + +#include "stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc" + +} // namespace + +void populateStablehloComplexMathExpanderPatterns( + RewritePatternSet *patterns, MLIRContext *context/*, + vhlo::Version targetVersion*/) { + // StableHLO Log1pOp is introduced in v0.9.0. + patterns->add(context); +} + +} // namespace stablehlo +} // namespace mlir diff --git a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td new file mode 100644 index 0000000000..f1d7451325 --- /dev/null +++ b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td @@ -0,0 +1,18 @@ +/* Copyright 2024 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "ChloDecompositionPatterns.td" + +def Log1pOp_ComplexElementType_ComplexMathExpander : Pat<(StableHLO_Log1pOp ComplexElementType:$input), (CHLO_Log1pOp $input)>;