From 0008018c96f25a657c6b2b705a4c78711d2cfce0 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 8 Jan 2025 20:51:39 +0200 Subject: [PATCH] Add StableHLO complex sqrt to stablehlo-complex-math-expander pass (#2679) As in the title. The [existing implementation of `sqrt`](https://github.com/openxla/xla/blob/30caa6782b2f49c9ecb8f4727d8628a12ee9a861/xla/service/elemental_ir_emitter.cc#L2111) on complex inputs uses the polar form of complex sqrt, which is inaccurate/incorrect on about 28% of uniformly distributed samples over the whole complex plane. The JAX complex sqrt accuracy statistics are as follows: ``` test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 2792619172 ULP difference == 0: 297588 ULP difference == 1: 1106945 ULP difference == 2: 78921 ULP difference == 3: 5924 ULP difference == 4: 2938 ULP difference == 5: 2231 ULP difference == 6: 1855 ULP difference == 7: 1648 ULP difference == 8: 1449 ULP difference == 9: 1263 ULP difference == 10: 1089 ULP difference >= 11: 598949 test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2760370645 ULP difference == 0: 699582 ULP difference == 1: 759750 ULP difference == 2: 58127 ULP difference == 3: 3237 ULP difference == 4: 1665 ULP difference == 5: 1185 ULP difference == 6: 1021 ULP difference == 7: 871 ULP difference == 8: 804 ULP difference == 9: 653 ULP difference == 10: 622 ULP difference >= 11: 573283 ``` This PR provides an algorithm for complex sqrt that is accurate up to 3/6 ULP difference error on complex samples. 
The corresponding JAX complex sqrt accuracy statistics are as follows: ``` test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 5 ULP difference == 0: 1060571 ULP difference == 1: 1008268 ULP difference == 2: 31136 ULP difference == 3: 686 ULP difference == 4: 129 ULP difference == 5: 10 test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2 ULP difference == 0: 1348868 ULP difference == 1: 751504 ULP difference == 2: 428 ``` It is interesting to note that, although the same algorithm is used for both CUDA and CPU platforms, the measured accuracy differs between them: the expected maximal ULP difference is 4 (obtained from applying the algorithm to numpy arrays). Hence - the accuracy of complex sqrt on CPU is better than expected - the accuracy of complex sqrt on CUDA is worse than expected because CUDA sqrt produces slightly different results from std sqrt on float inputs. --- build_tools/math/README.md | 2 +- .../generate_ChloDecompositionPatternsMath.py | 22 +- build_tools/math/generate_tests.py | 8 + stablehlo/tests/math/sqrt_complex128.mlir | 19 ++ stablehlo/tests/math/sqrt_complex64.mlir | 19 ++ stablehlo/tests/math/sqrt_float32.mlir | 19 ++ stablehlo/tests/math/sqrt_float64.mlir | 19 ++ .../StablehloComplexMathExpander.cpp | 7 + .../StablehloComplexMathExpanderPatterns.td | 232 +++++++++++++++++- 9 files changed, 332 insertions(+), 15 deletions(-) create mode 100644 stablehlo/tests/math/sqrt_complex128.mlir create mode 100644 stablehlo/tests/math/sqrt_complex64.mlir create mode 100644 stablehlo/tests/math/sqrt_float32.mlir create mode 100644 stablehlo/tests/math/sqrt_float64.mlir diff --git a/build_tools/math/README.md b/build_tools/math/README.md index e2f10a8881..f2dca9c228 100644 --- a/build_tools/math/README.md +++ b/build_tools/math/README.md @@ -31,7 +31,7 @@ following requirements: - Python 3.11 or newer - mpmath 1.3 or newer -- functional_algorithms 0.12 or newer +- functional_algorithms 0.13.2 or newer that can be installed via pypi: diff --git 
a/build_tools/math/generate_ChloDecompositionPatternsMath.py b/build_tools/math/generate_ChloDecompositionPatternsMath.py index 62b99474dc..3d07361027 100644 --- a/build_tools/math/generate_ChloDecompositionPatternsMath.py +++ b/build_tools/math/generate_ChloDecompositionPatternsMath.py @@ -98,10 +98,16 @@ def main(kind="CHLO"): ("CHLO_SquareOp", "complex_square", ("z:complex",)), ("CHLO_SquareOp", "real_square", ("x:float",)), ("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)), + ("StableHLO_SqrtOp", "complex_sqrt", ("z:complex",)), ]: if not chloname.startswith(kind): continue - print(f'Generating {chloname} from {fname}{args}') + if chloname.startswith("StableHLO_"): + NameOp = chloname.split("_", 1)[1] + expander_name = f"{NameOp}_ComplexElementType_ComplexMathExpander" + else: + expander_name = "" + print(f"Generating {chloname} from {fname}{args}") func = getattr(fa.algorithms, fname, None) if func is None: warnings.warn( @@ -110,22 +116,12 @@ def main(kind="CHLO"): ctx = fa.Context(paths=[fa.algorithms], parameters=dict(rewrite_keep_integer_literals=True)) graph = ctx.trace(func, *args).rewrite(target, fa.rewrite) - graph.props.update(name=chloname) + graph.props.update(name=chloname, expander_name=expander_name) src = graph.tostring(target) sources.append(target.make_comment(func.__doc__)) if func.__doc__ else None sources[-1] += src source = "\n\n".join(sources) + "\n" - if chloname.startswith('StableHLO_'): - # an ugly hack to fix the definition of stablehlo complex math - # functions. 
TODO(pearu): add the corresponding feature to - # functional_algorithms stablehlo printer - NameOp = chloname.split('_', 1)[1] - source = source.replace( - f'def : Pat<({chloname}', - f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}' - ) - if os.path.isfile(output_file): f = open(output_file, "r") content = f.read() @@ -177,6 +173,8 @@ def ComplexElementType : Type< def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< "::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; +def StableHLO_ConstantLikePosInfValue : NativeCodeCall< + "::mlir::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">; """) f.write(source) f.close() diff --git a/build_tools/math/generate_tests.py b/build_tools/math/generate_tests.py index fe20a1ae0b..8583eeb997 100644 --- a/build_tools/math/generate_tests.py +++ b/build_tools/math/generate_tests.py @@ -54,6 +54,10 @@ # that defines the precision of computing reference values: # mpmath.mp.prec * extra_prec_multiplier # + # namespace - "chlo" or "stablehlo" + # + # passes - a string of pass arguments + # # When unspecifed, these parameters are retrieved from # functional_algorithms database of support functions. # @@ -68,6 +72,10 @@ mpmath_name="log1p", namespace="stablehlo", passes="--stablehlo-complex-math-expander"), + dict(name="sqrt", + mpmath_name="sqrt", + namespace="stablehlo", + passes="--stablehlo-complex-math-expander"), ] diff --git a/stablehlo/tests/math/sqrt_complex128.mlir b/stablehlo/tests/math/sqrt_complex128.mlir new file mode 100644 index 0000000000..adb5edbbab --- /dev/null +++ b/stablehlo/tests/math/sqrt_complex128.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. 
+module @sqrt_complex128 { + func.func private @samples() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0xtensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func private @expected() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0xtensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xcomplex> + %1 = "stablehlo.sqrt"(%0) : (tensor<169xcomplex>) -> tensor<169xcomplex> + %2 = call @expected() : () -> tensor<169xcomplex> + check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xcomplex>, tensor<169xcomplex> + func.return + } +} diff --git a/stablehlo/tests/math/sqrt_complex64.mlir b/stablehlo/tests/math/sqrt_complex64.mlir new file mode 100644 index 0000000000..4787262192 --- /dev/null +++ b/stablehlo/tests/math/sqrt_complex64.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. 
+module @sqrt_complex64 { + func.func private @samples() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0xtensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func private @expected() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0x0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF00000000000080FF8901E95EAEA18CDF8A01E95EAEA18CDFF304355FF30435DFF304355FF30435DFF304355FF30435DFF304355FF30435DFF304355FF30435DFF304355FF30435DFF304355FF30435DFAEA18C5F8A01E9DEAEA18C5F8901E9DE0000807F0000000000000000000080FF8901E95EAEA18CDF8901E95EAEA18CDFF204355FF20435DFF204355FF20435DFF204355FF20435DFF204355FF20435DFF204355FF20435DFF204355FF20435DFF204355FF20435DFAEA18C5F8901E9DEAEA18C5F8901E9DE0000807F0000000000000000000080FF0000401FFFFF7FDF0100401FFFFF7FDFBCAF0E3FE03CACBFD7B35D3FD7B35DBFD7B35D3FD7B35DBFD7B35D3FD7B35DBFD7B35D3FD7B35DBFD7B35D3FD7B35DBFE03CAC3FBCAF0EBFFFFF7F5F0100409FFFFF7F5F0000409F0000807F0000000000000000000080FF00001C00FFFF7FDF00001C00FFFF7FDF2EE5361F71C49CBF761E1A2FBF09BAAF51776F2F51776FAF51776F2F51776FAF51776F2F51776FAFBF09BA2F761E1AAF71C49C3F2EE5369FFFFF7F5F00001C80FFFF7F5F00001C800000807F0000000000000000000080FF00000000FFFF7FDF00000000FFFF7FDF0000000071C49CBF8F844104FD53A9AF98C2A41911E2469A0000001A0000009A11E2461A98C2A499FD53A92F8F84418471C49C3F00000080FFFF7F5F00000080FFFF7F5F000000800000807F00000000000000000000807F00000000FFFF7F5F00000000FFFF7F5F0000000071C49C3F00000000FD53A92F00000000F304351A0000000000000000F304351A00000000FD53A92F0000000071C49C3F00000000FFFF7F5F00000000FFFF7F5F000000000000807F00000000000000000000807F00000000FFFF7F5F00000000FFFF7F5F0000000071C49C3F8F844104FD53A92F98C2A41911E2461A0000001A0000001A11E2461A98C2A419FD53A92F8F84410471C49C3F00000000FFFF7F5F00000000FFFF7F5F000000000000807F00000000000000000000807F00001C00FFFF7F5F00001C00FFFF7F5F2EE5361F71C49C3F761E1
A2FBF09BA2F51776F2F51776F2F51776F2F51776F2F51776F2F51776F2FBF09BA2F761E1A2F71C49C3F2EE5361FFFFF7F5F00001C00FFFF7F5F00001C000000807F00000000000000000000807F0000401FFFFF7F5F0100401FFFFF7F5FBCAF0E3FE03CAC3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FE03CAC3FBCAF0E3FFFFF7F5F0100401FFFFF7F5F0000401F0000807F00000000000000000000807F8901E95EAEA18C5F8901E95EAEA18C5FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FAEA18C5F8901E95EAEA18C5F8901E95E0000807F00000000000000000000807F8901E95EAEA18C5F8A01E95EAEA18C5FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FAEA18C5F8A01E95EAEA18C5F8901E95E0000807F000000000000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xcomplex> + %1 = "stablehlo.sqrt"(%0) : (tensor<169xcomplex>) -> tensor<169xcomplex> + %2 = call @expected() : () -> tensor<169xcomplex> + check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xcomplex>, tensor<169xcomplex> + func.return + } +} diff --git a/stablehlo/tests/math/sqrt_float32.mlir b/stablehlo/tests/math/sqrt_float32.mlir new file mode 100644 index 0000000000..e4f96eef2d --- /dev/null +++ b/stablehlo/tests/math/sqrt_float32.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. 
+module @sqrt_float32 { + func.func private @samples() -> tensor<169xf32> { + %0 = stablehlo.constant dense<"0x000080FFFFFF7FFFFEFF7FFF05E763FC88DAD5FA0BCE47F98EC1B9F711B52BF695A89DF4189C0FF39B8F81F11E83F3EFA17665EE246AD7ECA75D49EB2B51BBE9AE442DE831389FE6B42B11E5371F83E3BA12F5E13D0667E0C1F9D8DE44ED4ADDC7E0BCDB4AD42EDACDC7A0D850BB12D7D3AE84D557A2F6D3DA9568D25D89DAD0E07C4CCF6370BECDE66330CC6957A2CAED4A14C9703E86C7F331F8C576256AC4F918DCC27C0C4EC10000C0BF83F331BE06E7A3BC89DA15BB0CCE87B98FC1F9B712B56BB696A8DDB4199C4FB39C8FC1B11F8333B0A276A5AE256A17ADA85D89AB2C51FBA9AF446DA83238DFA6B52B51A5381FC3A3BB1235A23E06A7A0C2F9189F45ED8A9DC8E0FC9B4BD46E9ACEC7E09851BB5297D4AEC49558A23694DB95A8925E891A91E17C8C8F6470FE8DE763708C6A57E28AEE4A5489713EC687F43138867725AA84FA181C837D0C8E810100008000000000010000007D0C8E01FA181C037725AA04F4313806713EC607EE4A54096A57E20AE763700C6470FE0DE17C8C0F5E891A11DB95A81258A23614D4AEC41551BB5217CEC7E0184BD46E1AC8E0FC1B45ED8A1DC2F9181F3E06A720BB123522381FC323B52B51253238DF26AF446D282C51FB29A85D892B256A172DA276A52E1F8333309C8FC131199C4F3396A8DD3412B56B368FC1F9370CCE873989DA153B06E7A33C83F3313E0000C03F7C0C4E41F918DC4276256A44F331F845703E8647ED4A14496957A24AE663304C6370BE4DE07C4C4F5D89DA50DA95685257A2F653D3AE845550BB1257CDC7A0584AD42E5AC7E0BC5B44ED4A5DC1F9D85E3D066760BA12F561371F8363B42B116531389F66AE442D682B51BB69A75D496B246AD76CA176656E1E83F36F9B8F8171189C0F7395A89D7411B52B768EC1B9770BCE477988DAD57A05E7637CFEFF7F7FFFFF7F7F0000807F"> : tensor<169xf32> + return %0 : tensor<169xf32> + } + func.func private @expected() -> tensor<169xf32> { + %0 = stablehlo.constant 
dense<"0x0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F00000000F304351A70D7862005E74721819313224D26D922C84B9F23D11F6924F5352A258712F8257377B4263F19862780E64628CEE512293F3AD829E9AA9E2AFF43682B719F292C0044F72C84E9B32DFF59852EAFE4452F4C3712302E4DD73067099E315E67673268082933CD74F633255BB334AA9984358CE14436FA871137185FD63740679D38E9896639D670283AEAA4F53A54CCB23B3ED8833C12DD433DD3D7103EF86FD53E71C49C3F9EAB6540BCD8274157D4F441103DB242B41583433DD74244D4261045CA7FD445F7209C467CCC6447174027481203F4485AADB1490752824A05D0414BFC740F4C8C8ED34CD27C9B4D7FEC634EE7A6264F1831F34F2E1DB150338D815166C7405245C20E533A9CD253FFD79A54A50B6355290D2656695EF2568D8CB05733C780585ABD3F59AF0E0E5ACFA8D15A7C329A5BEC29625CDD72255D028BF15DFFFF7F5FFFFF7F5F0000807F"> : tensor<169xf32> + return %0 : tensor<169xf32> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xf32> + %1 = "stablehlo.sqrt"(%0) : (tensor<169xf32>) -> tensor<169xf32> + %2 = call @expected() : () -> tensor<169xf32> + check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xf32>, tensor<169xf32> + func.return + } +} diff --git a/stablehlo/tests/math/sqrt_float64.mlir b/stablehlo/tests/math/sqrt_float64.mlir new file mode 100644 index 0000000000..17e2bb62d4 --- /dev/null +++ b/stablehlo/tests/math/sqrt_float64.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | 
stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @sqrt_float64 { + func.func private @samples() -> tensor<169xf64> { + %0 = stablehlo.constant dense<"0xtensor<169xf64> + return %0 : tensor<169xf64> + } + func.func private @expected() -> tensor<169xf64> { + %0 = stablehlo.constant dense<"0xtensor<169xf64> + return %0 : tensor<169xf64> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xf64> + %1 = "stablehlo.sqrt"(%0) : (tensor<169xf64>) -> tensor<169xf64> + %2 = call @expected() : () -> tensor<169xf64> + check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xf64>, tensor<169xf64> + func.return + } +} diff --git a/stablehlo/transforms/StablehloComplexMathExpander.cpp b/stablehlo/transforms/StablehloComplexMathExpander.cpp index b830db6196..de05542115 100644 --- a/stablehlo/transforms/StablehloComplexMathExpander.cpp +++ b/stablehlo/transforms/StablehloComplexMathExpander.cpp @@ -29,6 +29,13 @@ static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val); } +static Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val, + bool negative) { + auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType())); + return getConstantLike( + b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val); +} + //===----------------------------------------------------------------------===// // Pass //===----------------------------------------------------------------------===// diff --git a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td index 42cb7f72fc..c488f16595 100644 --- a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td +++ b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td @@ -14,7 +14,7 @@ limitations under the License. 
==============================================================================*/ // -// This file is generated using functional_algorithms tool (0.12.0). +// This file is generated using functional_algorithms tool (0.13.1). // See build_tools/math/README.md for more information. include "mlir/IR/OpBase.td" @@ -34,6 +34,8 @@ def ComplexElementType : Type< def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< "::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; +def StableHLO_ConstantLikePosInfValue : NativeCodeCall< + "::mlir::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">; // Logarithm of 1 + z on complex input: // // log1p(x + I * y) = 0.5 * log((x + 1) ** 2 + y ** 2) + I * arctan2(y, x + 1) @@ -139,7 +141,7 @@ def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< // the Case A method [verified numerically for float32 and float64]. // // -def Log1pOp_ComplexElementType_ComplexMathExpander : Pat<(StableHLO_Log1pOp ComplexElementType:$z), +def Log1pOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_Log1pOp ComplexElementType:$z), (StableHLO_ComplexOp (StableHLO_SelectOp (StableHLO_CompareOp @@ -273,3 +275,229 @@ def Log1pOp_ComplexElementType_ComplexMathExpander : Pat<(StableHLO_Log1pOp Comp (StableHLO_SubtractOp:$subtract_add_2sum_high__add_2sum_high_0_ $add_2sum_high, $_add_2sum_high_0_))), (StableHLO_SubtractOp $_square_dekker_low_0_, $subtract_add_2sum_high__add_2sum_high_0_)))))))), (StableHLO_Atan2Op $y, $xp1))>; + +// Square root on complex inputs: +// +// sqrt(z) = sqrt((hypot(x, y) + x)/2) + I * sgn(y) * sqrt((hypot(x, y) - x) / 2) +// +// where z = x + I * y, sgn(y) = 1 if y >= 0, and sgn(y) = -1 otherwise. +// +// Algorithm +// --------- +// +// In the above formula, catastrophic cancellation errors occur in +// the imaginary part when x is positive, and in the real part when x +// is negative. 
To avoid these, let us define +// +// u = sqrt((hypot(x, y) + abs(x))/2) +// v = sgn(y) * sqrt((hypot(x, y) - abs(x))/2) +// +// and find +// +// u * v = sgn(y) * sqrt(hypot(x, y) ** 2 - x ** 2) / 2 = y / 2 +// +// That is, if x > 0, then we have +// +// sqrt(z) = u + I * y / u / 2 +// +// and if x < 0, +// +// sqrt(z) = abs(y) / u / 2 + I * sgn(y) * u +// +// If abs(x) and abs(y) are smaller than the smallest normal, then as a +// result of underflow, u will be zero and v will be undefined. On +// the other hand, if abs(x) and abs(y) are close to the largest floating +// point number, then `hypot(x, y) + abs(x)` will overflow, and u +// will be `inf`. To address the issues from underflow and overflow, +// we'll use the following formulas: +// +// 1. If abs(x) == abs(y), or abs(x) == inf and abs(y) == inf, then +// +// u_eq = sqrt(abs(x)) * sqrt((1 + sqrt(2))/2) +// abs(y) / u = sqrt(abs(x)) / sqrt((1 + sqrt(2))/2) +// +// 2. If abs(x) > abs(y) and u == 0 (the underflow case) or u == inf +// (the overflow case), denote r = abs(y) / abs(x), then +// +// u_gt = sqrt(abs(x)) * sqrt((1 + hypot(1, r)) / 2) +// abs(y) / u = sqrt(abs(y)) * sqrt(r) / sqrt((1 + hypot(1, r)) / 2) +// +// 3. 
If abs(x) < abs(y) and u == 0 (the underflow case) or u == inf +// (the overflow case), denote r = abs(x) / abs(y), then +// +// u_lt = sqrt(abs(y)) * sqrt((r + hypot(1, r)) / 2) +// abs(y) / u = sqrt(abs(y)) / sqrt((r + hypot(1, r)) / 2) +// +def SqrtOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_SqrtOp ComplexElementType:$z), + (StableHLO_ComplexOp + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_RealOp:$x $z), + (StableHLO_ConstantLike<"0">:$constant_0 $x), + StableHLO_ComparisonDirectionValue<"GE">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_SelectOp:$u + (StableHLO_CompareOp:$eq_ax_ay + (StableHLO_AbsOp:$ax $x), + (StableHLO_AbsOp:$ay + (StableHLO_ImagOp:$y $z)), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_DivOp + (StableHLO_MulOp + (StableHLO_SqrtOp:$sq_ax $ax), + (StableHLO_ConstantLike<"1.5537739740300374"> $x)), + (StableHLO_ConstantLike<"1.4142135623730951">:$sq_2 $x)), + (StableHLO_SelectOp + (StableHLO_OrOp:$logical_or_eq_u_general_constant_0_eq_u_general_constant_posinf + (StableHLO_CompareOp + (StableHLO_SqrtOp:$u_general + (StableHLO_AddOp + (StableHLO_DivOp + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_MaxOp:$mx $ax, $ay), + (StableHLO_MinOp:$mn $ax, $ay), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp $sq_2, $mx), + (StableHLO_SelectOp + (StableHLO_AndOp + (StableHLO_CompareOp + (StableHLO_SqrtOp:$sqa + (StableHLO_AddOp + (StableHLO_ConstantLike<"1">:$one $x), + (StableHLO_MulOp:$r + (StableHLO_DivOp:$mn_over_mx $mn, $mx), + $mn_over_mx))), + $one, + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_CompareOp + $r, + $constant_0, + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE))), + (StableHLO_AddOp + $mx, + (StableHLO_DivOp + (StableHLO_MulOp $mx, $r), + (StableHLO_ConstantLike<"2">:$two $x))), + (StableHLO_MulOp $mx, 
$sqa))), + $two), + (StableHLO_DivOp $ax, $two))), + $constant_0, + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_CompareOp + $u_general, + (StableHLO_ConstantLikePosInfValue $x), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE))), + (StableHLO_SelectOp + (StableHLO_CompareOp:$gt_ax_ay + $ax, + $ay, + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp + $sq_ax, + (StableHLO_DivOp + (StableHLO_SqrtOp:$sq_1h + (StableHLO_AddOp + $one, + (StableHLO_SelectOp:$h + (StableHLO_CompareOp + (StableHLO_MaxOp:$_mx_0_ + $one, + (StableHLO_AbsOp:$abs__r_0_ + (StableHLO_SelectOp:$_r_0_ + $eq_ax_ay, + $one, + (StableHLO_SelectOp + (StableHLO_CompareOp:$lt_ax_ay + $ax, + $ay, + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_DivOp $ax, $ay), + (StableHLO_DivOp $ay, $ax))))), + (StableHLO_MinOp:$_mn_0_ $one, $abs__r_0_), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp $sq_2, $_mx_0_), + (StableHLO_SelectOp + (StableHLO_AndOp + (StableHLO_CompareOp + (StableHLO_SqrtOp:$_sqa_0_ + (StableHLO_AddOp + $one, + (StableHLO_MulOp:$_r_1_ + (StableHLO_DivOp:$_mn_over_mx_0_ $_mn_0_, $_mx_0_), + $_mn_over_mx_0_))), + $one, + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_CompareOp + $_r_1_, + $constant_0, + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE))), + (StableHLO_AddOp + $_mx_0_, + (StableHLO_DivOp + (StableHLO_MulOp $_mx_0_, $_r_1_), + $two)), + (StableHLO_MulOp $_mx_0_, $_sqa_0_))))), + $sq_2)), + (StableHLO_MulOp + (StableHLO_SqrtOp:$sq_ay $ay), + (StableHLO_DivOp + (StableHLO_SqrtOp:$sq_rh + (StableHLO_AddOp $_r_0_, $h)), + $sq_2))), + $u_general)), + (StableHLO_SelectOp:$ay_div_u + $eq_ax_ay, + (StableHLO_DivOp + $sq_ay, + (StableHLO_ConstantLike<"2.19736822693562"> $x)), + 
(StableHLO_SelectOp + $logical_or_eq_u_general_constant_0_eq_u_general_constant_posinf, + (StableHLO_SelectOp + $gt_ax_ay, + (StableHLO_DivOp + (StableHLO_MulOp + $sq_ay, + (StableHLO_SelectOp + $eq_ax_ay, + $one, + (StableHLO_SelectOp + $lt_ax_ay, + (StableHLO_DivOp $sq_ax, $sq_ay), + (StableHLO_DivOp $sq_ay, $sq_ax)))), + (StableHLO_MulOp $sq_1h, $sq_2)), + (StableHLO_DivOp + $sq_ay, + (StableHLO_MulOp $sq_rh, $sq_2))), + (StableHLO_DivOp + $ay, + (StableHLO_MulOp $u_general, $two))))), + (StableHLO_SelectOp + (StableHLO_CompareOp + $x, + $constant_0, + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_SelectOp + (StableHLO_CompareOp:$lt_y_constant_0 + $y, + $constant_0, + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_NegOp $u), + $u), + (StableHLO_SelectOp + $lt_y_constant_0, + (StableHLO_NegOp $ay_div_u), + $ay_div_u)))>;