From 0008018c96f25a657c6b2b705a4c78711d2cfce0 Mon Sep 17 00:00:00 2001 From: Pearu Peterson Date: Wed, 8 Jan 2025 20:51:39 +0200 Subject: [PATCH] Add StableHLO complex sqrt to stablehlo-complex-math-expander pass (#2679) As in the title. The [existing implementation of `sqrt`](https://github.com/openxla/xla/blob/30caa6782b2f49c9ecb8f4727d8628a12ee9a861/xla/service/elemental_ir_emitter.cc#L2111) on complex inputs uses the polar form of complex sqrt, which is inaccurate/incorrect on about 28 % of samples distributed uniformly over the whole complex plane. The JAX complex sqrt accuracy statistics are as follows: ``` test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 2792619172 ULP difference == 0: 297588 ULP difference == 1: 1106945 ULP difference == 2: 78921 ULP difference == 3: 5924 ULP difference == 4: 2938 ULP difference == 5: 2231 ULP difference == 6: 1855 ULP difference == 7: 1648 ULP difference == 8: 1449 ULP difference == 9: 1263 ULP difference == 10: 1089 ULP difference >= 11: 598949 test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2760370645 ULP difference == 0: 699582 ULP difference == 1: 759750 ULP difference == 2: 58127 ULP difference == 3: 3237 ULP difference == 4: 1665 ULP difference == 5: 1185 ULP difference == 6: 1021 ULP difference == 7: 871 ULP difference == 8: 804 ULP difference == 9: 653 ULP difference == 10: 622 ULP difference >= 11: 573283 ``` This PR provides an algorithm for complex sqrt that is accurate up to 3/6 ULP difference error on complex samples. 
The corresponding JAX complex sqrt accuracy statistics are as follows: ``` test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 5 ULP difference == 0: 1060571 ULP difference == 1: 1008268 ULP difference == 2: 31136 ULP difference == 3: 686 ULP difference == 4: 129 ULP difference == 5: 10 test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2 ULP difference == 0: 1348868 ULP difference == 1: 751504 ULP difference == 2: 428 ``` It is interesting to note that, although the same algorithm is used for both CUDA and CPU platforms and its expected maximal ULP difference is 4 (obtained by applying the algorithm to numpy arrays), the measured maxima differ (5 on CUDA, 2 on CPU). Hence: - the accuracy of complex sqrt on CPU is better than expected - the accuracy of complex sqrt on CUDA is worse than expected because CUDA sqrt produces slightly different results from std sqrt on float inputs. --- build_tools/math/README.md | 2 +- .../generate_ChloDecompositionPatternsMath.py | 22 +- build_tools/math/generate_tests.py | 8 + stablehlo/tests/math/sqrt_complex128.mlir | 19 ++ stablehlo/tests/math/sqrt_complex64.mlir | 19 ++ stablehlo/tests/math/sqrt_float32.mlir | 19 ++ stablehlo/tests/math/sqrt_float64.mlir | 19 ++ .../StablehloComplexMathExpander.cpp | 7 + .../StablehloComplexMathExpanderPatterns.td | 232 +++++++++++++++++- 9 files changed, 332 insertions(+), 15 deletions(-) create mode 100644 stablehlo/tests/math/sqrt_complex128.mlir create mode 100644 stablehlo/tests/math/sqrt_complex64.mlir create mode 100644 stablehlo/tests/math/sqrt_float32.mlir create mode 100644 stablehlo/tests/math/sqrt_float64.mlir diff --git a/build_tools/math/README.md b/build_tools/math/README.md index e2f10a8881..f2dca9c228 100644 --- a/build_tools/math/README.md +++ b/build_tools/math/README.md @@ -31,7 +31,7 @@ following requirements: - Python 3.11 or newer - mpmath 1.3 or newer -- functional_algorithms 0.12 or newer +- functional_algorithms 0.13.2 or newer that can be installed via pypi: diff --git 
a/build_tools/math/generate_ChloDecompositionPatternsMath.py b/build_tools/math/generate_ChloDecompositionPatternsMath.py index 62b99474dc..3d07361027 100644 --- a/build_tools/math/generate_ChloDecompositionPatternsMath.py +++ b/build_tools/math/generate_ChloDecompositionPatternsMath.py @@ -98,10 +98,16 @@ def main(kind="CHLO"): ("CHLO_SquareOp", "complex_square", ("z:complex",)), ("CHLO_SquareOp", "real_square", ("x:float",)), ("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)), + ("StableHLO_SqrtOp", "complex_sqrt", ("z:complex",)), ]: if not chloname.startswith(kind): continue - print(f'Generating {chloname} from {fname}{args}') + if chloname.startswith("StableHLO_"): + NameOp = chloname.split("_", 1)[1] + expander_name = f"{NameOp}_ComplexElementType_ComplexMathExpander" + else: + expander_name = "" + print(f"Generating {chloname} from {fname}{args}") func = getattr(fa.algorithms, fname, None) if func is None: warnings.warn( @@ -110,22 +116,12 @@ def main(kind="CHLO"): ctx = fa.Context(paths=[fa.algorithms], parameters=dict(rewrite_keep_integer_literals=True)) graph = ctx.trace(func, *args).rewrite(target, fa.rewrite) - graph.props.update(name=chloname) + graph.props.update(name=chloname, expander_name=expander_name) src = graph.tostring(target) sources.append(target.make_comment(func.__doc__)) if func.__doc__ else None sources[-1] += src source = "\n\n".join(sources) + "\n" - if chloname.startswith('StableHLO_'): - # an ugly hack to fix the definition of stablehlo complex math - # functions. 
TODO(pearu): add the corresponding feature to - # functional_algorithms stablehlo printer - NameOp = chloname.split('_', 1)[1] - source = source.replace( - f'def : Pat<({chloname}', - f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}' - ) - if os.path.isfile(output_file): f = open(output_file, "r") content = f.read() @@ -177,6 +173,8 @@ def ComplexElementType : Type< def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< "::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; +def StableHLO_ConstantLikePosInfValue : NativeCodeCall< + "::mlir::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">; """) f.write(source) f.close() diff --git a/build_tools/math/generate_tests.py b/build_tools/math/generate_tests.py index fe20a1ae0b..8583eeb997 100644 --- a/build_tools/math/generate_tests.py +++ b/build_tools/math/generate_tests.py @@ -54,6 +54,10 @@ # that defines the precision of computing reference values: # mpmath.mp.prec * extra_prec_multiplier # + # namespace - "chlo" or "stablehlo" + # + # passes - a string of pass arguments + # # When unspecifed, these parameters are retrieved from # functional_algorithms database of support functions. # @@ -68,6 +72,10 @@ mpmath_name="log1p", namespace="stablehlo", passes="--stablehlo-complex-math-expander"), + dict(name="sqrt", + mpmath_name="sqrt", + namespace="stablehlo", + passes="--stablehlo-complex-math-expander"), ] diff --git a/stablehlo/tests/math/sqrt_complex128.mlir b/stablehlo/tests/math/sqrt_complex128.mlir new file mode 100644 index 0000000000..adb5edbbab --- /dev/null +++ b/stablehlo/tests/math/sqrt_complex128.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. 
+module @sqrt_complex128 { + func.func private @samples() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0x000000000000F0FF000000000000F0FFFFFFFFFFFFFFEFFF000000000000F0FFFEFFFFFFFFFFEFFF000000000000F0FF000000000000F8BF000000000000F0FF000000000000FC9F000000000000F0FF0100000000000080000000000000F0FF0000000000000000000000000000F0FF0100000000000000000000000000F0FF000000000000FC1F000000000000F0FF000000000000F83F000000000000F0FFFEFFFFFFFFFFEF7F000000000000F0FFFFFFFFFFFFFFEF7F000000000000F0FF000000000000F07F000000000000F0FF000000000000F0FFFFFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFFFFFFFFFFFFEFFF000000000000F8BFFFFFFFFFFFFFEFFF000000000000FC9FFFFFFFFFFFFFEFFF0100000000000080FFFFFFFFFFFFEFFF0000000000000000FFFFFFFFFFFFEFFF0100000000000000FFFFFFFFFFFFEFFF000000000000FC1FFFFFFFFFFFFFEFFF000000000000F83FFFFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFFFFFFFFFFFFEFFF000000000000F07FFFFFFFFFFFFFEFFF000000000000F0FFFEFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFEFFFFFFFFFFEFFF000000000000F8BFFEFFFFFFFFFFEFFF000000000000FC9FFEFFFFFFFFFFEFFF0100000000000080FEFFFFFFFFFFEFFF0000000000000000FEFFFFFFFFFFEFFF0100000000000000FEFFFFFFFFFFEFFF000000000000FC1FFEFFFFFFFFFFEFFF000000000000F83FFEFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFEFFFFFFFFFFEFFF000000000000F07FFEFFFFFFFFFFEFFF000000000000F0FF000000000000F8BFFFFFFFFFFFFFEFFF000000000000F8BFFEFFFFFFFFFFEFFF000000000000F8BF000000000000F8BF000000000000F8BF000000000000FC9F000000000000F8BF0100000000000080000000000000F8BF0000000000000000000000000000F8BF0100000000000000000000000000F8BF000000000000FC1F000000000000F8BF000000000000F83F000000000000F8BFFEFFFFFFFFFFEF7F000000000000F8BFFFFFFFFFFFFFEF7F000000000000F8BF000000000000F07F000000000000F8BF000000000000F0FF000000000000FC9FFFFFFFFFFFFFEFFF000000000000FC9FFEFFFFFFFFFFEFFF000000000000FC9F000000000000F8BF000000000000FC9F000000000000FC9F000000000000FC9F0100000000000080000000000000FC9F00000000000000000000000000
00FC9F0100000000000000000000000000FC9F000000000000FC1F000000000000FC9F000000000000F83F000000000000FC9FFEFFFFFFFFFFEF7F000000000000FC9FFFFFFFFFFFFFEF7F000000000000FC9F000000000000F07F000000000000FC9F000000000000F0FF0100000000000080FFFFFFFFFFFFEFFF0100000000000080FEFFFFFFFFFFEFFF0100000000000080000000000000F8BF0100000000000080000000000000FC9F0100000000000080010000000000008001000000000000800000000000000000010000000000008001000000000000000100000000000080000000000000FC1F0100000000000080000000000000F83F0100000000000080FEFFFFFFFFFFEF7F0100000000000080FFFFFFFFFFFFEF7F0100000000000080000000000000F07F0100000000000080000000000000F0FF0000000000000000FFFFFFFFFFFFEFFF0000000000000000FEFFFFFFFFFFEFFF0000000000000000000000000000F8BF0000000000000000000000000000FC9F0000000000000000010000000000008000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000FC1F0000000000000000000000000000F83F0000000000000000FEFFFFFFFFFFEF7F0000000000000000FFFFFFFFFFFFEF7F0000000000000000000000000000F07F0000000000000000000000000000F0FF0100000000000000FFFFFFFFFFFFEFFF0100000000000000FEFFFFFFFFFFEFFF0100000000000000000000000000F8BF0100000000000000000000000000FC9F0100000000000000010000000000008001000000000000000000000000000000010000000000000001000000000000000100000000000000000000000000FC1F0100000000000000000000000000F83F0100000000000000FEFFFFFFFFFFEF7F0100000000000000FFFFFFFFFFFFEF7F0100000000000000000000000000F07F0100000000000000000000000000F0FF000000000000FC1FFFFFFFFFFFFFEFFF000000000000FC1FFEFFFFFFFFFFEFFF000000000000FC1F000000000000F8BF000000000000FC1F000000000000FC9F000000000000FC1F0100000000000080000000000000FC1F0000000000000000000000000000FC1F0100000000000000000000000000FC1F000000000000FC1F000000000000FC1F000000000000F83F000000000000FC1FFEFFFFFFFFFFEF7F000000000000FC1FFFFFFFFFFFFFEF7F000000000000FC1F000000000000F07F000000000000FC1F000000000000F0FF000000000000F83FFFFFFFFFFFFFEFFF000000000000F83FFEFFFFFFFFFFEFFF000000000000F83F000000000000F8BF000000000000F83F0000000000
00FC9F000000000000F83F0100000000000080000000000000F83F0000000000000000000000000000F83F0100000000000000000000000000F83F000000000000FC1F000000000000F83F000000000000F83F000000000000F83FFEFFFFFFFFFFEF7F000000000000F83FFFFFFFFFFFFFEF7F000000000000F83F000000000000F07F000000000000F83F000000000000F0FFFEFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFEFFFFFFFFFFEF7F000000000000F8BFFEFFFFFFFFFFEF7F000000000000FC9FFEFFFFFFFFFFEF7F0100000000000080FEFFFFFFFFFFEF7F0000000000000000FEFFFFFFFFFFEF7F0100000000000000FEFFFFFFFFFFEF7F000000000000FC1FFEFFFFFFFFFFEF7F000000000000F83FFEFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFEFFFFFFFFFFEF7F000000000000F07FFEFFFFFFFFFFEF7F000000000000F0FFFFFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFFFFFFFFFFFFEF7F000000000000F8BFFFFFFFFFFFFFEF7F000000000000FC9FFFFFFFFFFFFFEF7F0100000000000080FFFFFFFFFFFFEF7F0000000000000000FFFFFFFFFFFFEF7F0100000000000000FFFFFFFFFFFFEF7F000000000000FC1FFFFFFFFFFFFFEF7F000000000000F83FFFFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFFFFFFFFFFFFEF7F000000000000F07FFFFFFFFFFFFFEF7F000000000000F0FF000000000000F07FFFFFFFFFFFFFEFFF000000000000F07FFEFFFFFFFFFFEFFF000000000000F07F000000000000F8BF000000000000F07F000000000000FC9F000000000000F07F0100000000000080000000000000F07F0000000000000000000000000000F07F0100000000000000000000000000F07F000000000000FC1F000000000000F07F000000000000F83F000000000000F07FFEFFFFFFFFFFEF7F000000000000F07FFFFFFFFFFFFFEF7F000000000000F07F000000000000F07F000000000000F07F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func private @expected() -> tensor<169xcomplex> { + %0 = stablehlo.constant 
dense<"0x000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF0000000000000000000000000000F0FF28C8F6383120DD5FF8A9FFCA3594F1DF28C8F6383120DD5FF8A9FFCA3594F1DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFF8A9FFCA3594F15F28C8F6383120DDDFF8A9FFCA3594F15F28C8F6383120DDDF000000000000F07F00000000000000000000000000000000000000000000F0FF27C8F6383120DD5FF8A9FFCA3594F1DF27C8F6383120DD5FF8A9FFCA3594F1DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFF8A9FFCA3594F15F27C8F6383120DDDFF8A9FFCA3594F15F27C8F6383120DDDF000000000000F07F00000000000000000000000000000000000000000000F0FF000000000000E81FFFFFFFFFFFFFEFDF010000000000E81FFFFFFFFFFFFFEFDFDBF5B774F7D5E13F54CB78F99B87F5BFAA4C58E87AB6EB3FAA4C58E87AB6EBBFAA4C58E87AB6EB3FAA4C58E87AB6EBBFAA4C58E87AB6EB3FAA4C58E87AB6EBBFAA4C58E87AB6EB3FAA4C58E87AB6EBBFAA4C58E87AB6EB3FAA4C58E87AB6EBBF54CB78F99B87F53FDBF5B774F7D5E1BFFFFFFFFFFFFFEF5F010000000000E89FFFFFFFFFFFFFEF5F000000000000E89F000000000000F07F00000000000000000000000000000000000000000000F0FF0000000000800300FFFFFFFFFFFFEFDF0000000000800300FFFFFFFFFFFFEFDFB6A60AC2A5DCE61F2E2109148E98F3BF1B1B52C0CE43E32F6EACA4EA3741F7AF493F6811EAEEED2F493F6811EAEEEDAF493F6811EAEEED2F493F6811EAEEEDAF493F6811EAEEED2F493F6811EAEEEDAF6EACA4EA3741F72F1B1B52C0CE43E3AF2E2109148E98F33FB6A60AC2A5DCE69FFFFFFFF
FFFFFEF5F0000000000800380FFFFFFFFFFFFEF5F0000000000800380000000000000F07F00000000000000000000000000000000000000000000F0FF0000000000000000FFFFFFFFFFFFEFDF0000000000000000FFFFFFFFFFFFEFDF00000000000000002E2109148E98F3BFE7F7A7E69130B80CEAF8D2A97F2AF5AF28C8F63831204D1EF9A9FFCA3594619ECD3B7F669EA0561ECD3B7F669EA0569EF9A9FFCA3594611E28C8F63831204D9EEAF8D2A97F2AF52FE7F7A7E69130B88C2E2109148E98F33F0000000000000080FFFFFFFFFFFFEF5F0000000000000080FFFFFFFFFFFFEF5F0000000000000080000000000000F07F00000000000000000000000000000000000000000000F07F0000000000000000FFFFFFFFFFFFEF5F0000000000000000FFFFFFFFFFFFEF5F00000000000000002E2109148E98F33F0000000000000000EAF8D2A97F2AF52F0000000000000000000000000000601E00000000000000000000000000000000000000000000601E0000000000000000EAF8D2A97F2AF52F00000000000000002E2109148E98F33F0000000000000000FFFFFFFFFFFFEF5F0000000000000000FFFFFFFFFFFFEF5F0000000000000000000000000000F07F00000000000000000000000000000000000000000000F07F0000000000000000FFFFFFFFFFFFEF5F0000000000000000FFFFFFFFFFFFEF5F00000000000000002E2109148E98F33FE7F7A7E69130B80CEAF8D2A97F2AF52F28C8F63831204D1EF9A9FFCA3594611ECD3B7F669EA0561ECD3B7F669EA0561EF9A9FFCA3594611E28C8F63831204D1EEAF8D2A97F2AF52FE7F7A7E69130B80C2E2109148E98F33F0000000000000000FFFFFFFFFFFFEF5F0000000000000000FFFFFFFFFFFFEF5F0000000000000000000000000000F07F00000000000000000000000000000000000000000000F07F0000000000800300FFFFFFFFFFFFEF5F0000000000800300FFFFFFFFFFFFEF5FB6A60AC2A5DCE61F2E2109148E98F33F1B1B52C0CE43E32F6EACA4EA3741F72F493F6811EAEEED2F493F6811EAEEED2F493F6811EAEEED2F493F6811EAEEED2F493F6811EAEEED2F493F6811EAEEED2F6EACA4EA3741F72F1B1B52C0CE43E32F2E2109148E98F33FB6A60AC2A5DCE61FFFFFFFFFFFFFEF5F0000000000800300FFFFFFFFFFFFEF5F0000000000800300000000000000F07F00000000000000000000000000000000000000000000F07F000000000000E81FFFFFFFFFFFFFEF5F010000000000E81FFFFFFFFFFFFFEF5FDBF5B774F7D5E13F54CB78F99B87F53FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E
87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3F54CB78F99B87F53FDBF5B774F7D5E13FFFFFFFFFFFFFEF5F010000000000E81FFFFFFFFFFFFFEF5F000000000000E81F000000000000F07F00000000000000000000000000000000000000000000F07F27C8F6383120DD5FF8A9FFCA3594F15F27C8F6383120DD5FF8A9FFCA3594F15FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FF8A9FFCA3594F15F27C8F6383120DD5FF8A9FFCA3594F15F27C8F6383120DD5F000000000000F07F00000000000000000000000000000000000000000000F07F28C8F6383120DD5FF8A9FFCA3594F15F28C8F6383120DD5FF8A9FFCA3594F15FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FF8A9FFCA3594F15F28C8F6383120DD5FF8A9FFCA3594F15F28C8F6383120DD5F000000000000F07F0000000000000000000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xcomplex> + %1 = "stablehlo.sqrt"(%0) : (tensor<169xcomplex>) -> tensor<169xcomplex> + %2 = call @expected() : () -> tensor<169xcomplex> + check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xcomplex>, tensor<169xcomplex> + func.return + } +} diff --git a/stablehlo/tests/math/sqrt_complex64.mlir b/stablehlo/tests/math/sqrt_complex64.mlir new file mode 100644 index 0000000000..4787262192 --- /dev/null +++ 
b/stablehlo/tests/math/sqrt_complex64.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @sqrt_complex64 { + func.func private @samples() -> tensor<169xcomplex> { + %0 = stablehlo.constant dense<"0x000080FF000080FFFFFF7FFF000080FFFEFF7FFF000080FF0000C0BF000080FF0000E09F000080FF01000080000080FF00000000000080FF01000000000080FF0000E01F000080FF0000C03F000080FFFEFF7F7F000080FFFFFF7F7F000080FF0000807F000080FF000080FFFFFF7FFFFFFF7FFFFFFF7FFFFEFF7FFFFFFF7FFF0000C0BFFFFF7FFF0000E09FFFFF7FFF01000080FFFF7FFF00000000FFFF7FFF01000000FFFF7FFF0000E01FFFFF7FFF0000C03FFFFF7FFFFEFF7F7FFFFF7FFFFFFF7F7FFFFF7FFF0000807FFFFF7FFF000080FFFEFF7FFFFFFF7FFFFEFF7FFFFEFF7FFFFEFF7FFF0000C0BFFEFF7FFF0000E09FFEFF7FFF01000080FEFF7FFF00000000FEFF7FFF01000000FEFF7FFF0000E01FFEFF7FFF0000C03FFEFF7FFFFEFF7F7FFEFF7FFFFFFF7F7FFEFF7FFF0000807FFEFF7FFF000080FF0000C0BFFFFF7FFF0000C0BFFEFF7FFF0000C0BF0000C0BF0000C0BF0000E09F0000C0BF010000800000C0BF000000000000C0BF010000000000C0BF0000E01F0000C0BF0000C03F0000C0BFFEFF7F7F0000C0BFFFFF7F7F0000C0BF0000807F0000C0BF000080FF0000E09FFFFF7FFF0000E09FFEFF7FFF0000E09F0000C0BF0000E09F0000E09F0000E09F010000800000E09F000000000000E09F010000000000E09F0000E01F0000E09F0000C03F0000E09FFEFF7F7F0000E09FFFFF7F7F0000E09F0000807F0000E09F000080FF01000080FFFF7FFF01000080FEFF7FFF010000800000C0BF010000800000E09F010000800100008001000080000000000100008001000000010000800000E01F010000800000C03F01000080FEFF7F7F01000080FFFF7F7F010000800000807F01000080000080FF00000000FFFF7FFF00000000FEFF7FFF000000000000C0BF000000000000E09F000000000100008000000000000000000000000001000000000000000000E01F000000000000C03F00000000FEFF7F7F00000000FFFF7F7F000000000000807F00000000000080FF01000000FFFF7FFF01000000FEFF7FFF010000000000C0BF010000000000E09F010000000100008001000000000000000100000001000000010000000000E01F010000000000C03F01000000FEFF7F7F01000000FFFF7F7F01000000
0000807F01000000000080FF0000E01FFFFF7FFF0000E01FFEFF7FFF0000E01F0000C0BF0000E01F0000E09F0000E01F010000800000E01F000000000000E01F010000000000E01F0000E01F0000E01F0000C03F0000E01FFEFF7F7F0000E01FFFFF7F7F0000E01F0000807F0000E01F000080FF0000C03FFFFF7FFF0000C03FFEFF7FFF0000C03F0000C0BF0000C03F0000E09F0000C03F010000800000C03F000000000000C03F010000000000C03F0000E01F0000C03F0000C03F0000C03FFEFF7F7F0000C03FFFFF7F7F0000C03F0000807F0000C03F000080FFFEFF7F7FFFFF7FFFFEFF7F7FFEFF7FFFFEFF7F7F0000C0BFFEFF7F7F0000E09FFEFF7F7F01000080FEFF7F7F00000000FEFF7F7F01000000FEFF7F7F0000E01FFEFF7F7F0000C03FFEFF7F7FFEFF7F7FFEFF7F7FFFFF7F7FFEFF7F7F0000807FFEFF7F7F000080FFFFFF7F7FFFFF7FFFFFFF7F7FFEFF7FFFFFFF7F7F0000C0BFFFFF7F7F0000E09FFFFF7F7F01000080FFFF7F7F00000000FFFF7F7F01000000FFFF7F7F0000E01FFFFF7F7F0000C03FFFFF7F7FFEFF7F7FFFFF7F7FFFFF7F7FFFFF7F7F0000807FFFFF7F7F000080FF0000807FFFFF7FFF0000807FFEFF7FFF0000807F0000C0BF0000807F0000E09F0000807F010000800000807F000000000000807F010000000000807F0000E01F0000807F0000C03F0000807FFEFF7F7F0000807FFFFF7F7F0000807F0000807F0000807F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func private @expected() -> tensor<169xcomplex> { + %0 = stablehlo.constant 
dense<"0x0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF0000807F000080FF00000000000080FF8901E95EAEA18CDF8A01E95EAEA18CDFF304355FF30435DFF304355FF30435DFF304355FF30435DFF304355FF30435DFF304355FF30435DFF304355FF30435DFF304355FF30435DFAEA18C5F8A01E9DEAEA18C5F8901E9DE0000807F0000000000000000000080FF8901E95EAEA18CDF8901E95EAEA18CDFF204355FF20435DFF204355FF20435DFF204355FF20435DFF204355FF20435DFF204355FF20435DFF204355FF20435DFF204355FF20435DFAEA18C5F8901E9DEAEA18C5F8901E9DE0000807F0000000000000000000080FF0000401FFFFF7FDF0100401FFFFF7FDFBCAF0E3FE03CACBFD7B35D3FD7B35DBFD7B35D3FD7B35DBFD7B35D3FD7B35DBFD7B35D3FD7B35DBFD7B35D3FD7B35DBFE03CAC3FBCAF0EBFFFFF7F5F0100409FFFFF7F5F0000409F0000807F0000000000000000000080FF00001C00FFFF7FDF00001C00FFFF7FDF2EE5361F71C49CBF761E1A2FBF09BAAF51776F2F51776FAF51776F2F51776FAF51776F2F51776FAFBF09BA2F761E1AAF71C49C3F2EE5369FFFFF7F5F00001C80FFFF7F5F00001C800000807F0000000000000000000080FF00000000FFFF7FDF00000000FFFF7FDF0000000071C49CBF8F844104FD53A9AF98C2A41911E2469A0000001A0000009A11E2461A98C2A499FD53A92F8F84418471C49C3F00000080FFFF7F5F00000080FFFF7F5F000000800000807F00000000000000000000807F00000000FFFF7F5F00000000FFFF7F5F0000000071C49C3F00000000FD53A92F00000000F304351A0000000000000000F304351A00000000FD53A92F0000000071C49C3F00000000FFFF7F5F00000000FFFF7F5F000000000000807F00000000000000000000807F00000000FFFF7F5F00000000FFFF7F5F0000000071C49C3F8F844104FD53A92F98C2A41911E2461A0000001A0000001A11E2461A98C2A419FD53A92F8F84410471C49C3F00000000FFFF7F5F00000000FFFF7F5F000000000000807F00000000000000000000807F00001C00FFFF7F5F00001C00FFFF7F5F2EE5361F71C49C3F761E1A2FBF09BA2F51776F2F51776F2F51776F2F51776F2F51776F2F51776F2FBF09BA2F761E1A2F71C49C3F2EE5361FFFFF7F5F00001C00FFFF7F5F00001C000000807F00000000000000000000807F0000401FFFFF7F5F0100401FFFFF7F5FBCAF0E3FE03CAC3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3FD7B35D3
FD7B35D3FD7B35D3FD7B35D3FE03CAC3FBCAF0E3FFFFF7F5F0100401FFFFF7F5F0000401F0000807F00000000000000000000807F8901E95EAEA18C5F8901E95EAEA18C5FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FF204355FAEA18C5F8901E95EAEA18C5F8901E95E0000807F00000000000000000000807F8901E95EAEA18C5F8A01E95EAEA18C5FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FF304355FAEA18C5F8A01E95EAEA18C5F8901E95E0000807F000000000000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F0000807F"> : tensor<169xcomplex> + return %0 : tensor<169xcomplex> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xcomplex> + %1 = "stablehlo.sqrt"(%0) : (tensor<169xcomplex>) -> tensor<169xcomplex> + %2 = call @expected() : () -> tensor<169xcomplex> + check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xcomplex>, tensor<169xcomplex> + func.return + } +} diff --git a/stablehlo/tests/math/sqrt_float32.mlir b/stablehlo/tests/math/sqrt_float32.mlir new file mode 100644 index 0000000000..e4f96eef2d --- /dev/null +++ b/stablehlo/tests/math/sqrt_float32.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. 
+module @sqrt_float32 { + func.func private @samples() -> tensor<169xf32> { + %0 = stablehlo.constant dense<"0x000080FFFFFF7FFFFEFF7FFF05E763FC88DAD5FA0BCE47F98EC1B9F711B52BF695A89DF4189C0FF39B8F81F11E83F3EFA17665EE246AD7ECA75D49EB2B51BBE9AE442DE831389FE6B42B11E5371F83E3BA12F5E13D0667E0C1F9D8DE44ED4ADDC7E0BCDB4AD42EDACDC7A0D850BB12D7D3AE84D557A2F6D3DA9568D25D89DAD0E07C4CCF6370BECDE66330CC6957A2CAED4A14C9703E86C7F331F8C576256AC4F918DCC27C0C4EC10000C0BF83F331BE06E7A3BC89DA15BB0CCE87B98FC1F9B712B56BB696A8DDB4199C4FB39C8FC1B11F8333B0A276A5AE256A17ADA85D89AB2C51FBA9AF446DA83238DFA6B52B51A5381FC3A3BB1235A23E06A7A0C2F9189F45ED8A9DC8E0FC9B4BD46E9ACEC7E09851BB5297D4AEC49558A23694DB95A8925E891A91E17C8C8F6470FE8DE763708C6A57E28AEE4A5489713EC687F43138867725AA84FA181C837D0C8E810100008000000000010000007D0C8E01FA181C037725AA04F4313806713EC607EE4A54096A57E20AE763700C6470FE0DE17C8C0F5E891A11DB95A81258A23614D4AEC41551BB5217CEC7E0184BD46E1AC8E0FC1B45ED8A1DC2F9181F3E06A720BB123522381FC323B52B51253238DF26AF446D282C51FB29A85D892B256A172DA276A52E1F8333309C8FC131199C4F3396A8DD3412B56B368FC1F9370CCE873989DA153B06E7A33C83F3313E0000C03F7C0C4E41F918DC4276256A44F331F845703E8647ED4A14496957A24AE663304C6370BE4DE07C4C4F5D89DA50DA95685257A2F653D3AE845550BB1257CDC7A0584AD42E5AC7E0BC5B44ED4A5DC1F9D85E3D066760BA12F561371F8363B42B116531389F66AE442D682B51BB69A75D496B246AD76CA176656E1E83F36F9B8F8171189C0F7395A89D7411B52B768EC1B9770BCE477988DAD57A05E7637CFEFF7F7FFFFF7F7F0000807F"> : tensor<169xf32> + return %0 : tensor<169xf32> + } + func.func private @expected() -> tensor<169xf32> { + %0 = stablehlo.constant 
dense<"0x0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F0000C07F00000000F304351A70D7862005E74721819313224D26D922C84B9F23D11F6924F5352A258712F8257377B4263F19862780E64628CEE512293F3AD829E9AA9E2AFF43682B719F292C0044F72C84E9B32DFF59852EAFE4452F4C3712302E4DD73067099E315E67673268082933CD74F633255BB334AA9984358CE14436FA871137185FD63740679D38E9896639D670283AEAA4F53A54CCB23B3ED8833C12DD433DD3D7103EF86FD53E71C49C3F9EAB6540BCD8274157D4F441103DB242B41583433DD74244D4261045CA7FD445F7209C467CCC6447174027481203F4485AADB1490752824A05D0414BFC740F4C8C8ED34CD27C9B4D7FEC634EE7A6264F1831F34F2E1DB150338D815166C7405245C20E533A9CD253FFD79A54A50B6355290D2656695EF2568D8CB05733C780585ABD3F59AF0E0E5ACFA8D15A7C329A5BEC29625CDD72255D028BF15DFFFF7F5FFFFF7F5F0000807F"> : tensor<169xf32> + return %0 : tensor<169xf32> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xf32> + %1 = "stablehlo.sqrt"(%0) : (tensor<169xf32>) -> tensor<169xf32> + %2 = call @expected() : () -> tensor<169xf32> + check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xf32>, tensor<169xf32> + func.return + } +} diff --git a/stablehlo/tests/math/sqrt_float64.mlir b/stablehlo/tests/math/sqrt_float64.mlir new file mode 100644 index 0000000000..17e2bb62d4 --- /dev/null +++ b/stablehlo/tests/math/sqrt_float64.mlir @@ -0,0 +1,19 @@ +// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | 
stablehlo-translate --interpret +// This file is generated, see build_tools/math/README.md for more information. +module @sqrt_float64 { + func.func private @samples() -> tensor<169xf64> { + %0 = stablehlo.constant dense<"0x000000000000F0FFFFFFFFFFFFFFEFFFFEFFFFFFFFFFEFFF2A51BB12B52BD1FCC0F9189C8FC141FB56A276256A57B2F9EC4AD4AE44ED22F882F331381F8393F6189C8FC1F91804F5AE44ED4AD4AE74F343ED4AD4AE44E5F1D995A85D89DA55F06F3E06E76370C6EE05E763703E0637ED9B8FC1F9189CA7EB31381F83F33118EAC7E07C0CCEC788E85D89DA95A85DF9E6F231381F83F369E588DA95A85D89DAE31E83F331381F4BE2B42B51BB12B5BBE04AD4AE44ED4A2CDFE07C0CCEC7E09CDD76256A57A2760DDC0CCEC7E07C0C7EDAA176256A57A2EED8371F83F331385FD7CDC7E07C0CCECFD563703E06E76340D4F9189C8FC1F9B0D28FC1F9189C8F21D1256A57A2762592CFBB12B52B51BB02CE50BB12B52B5173CCE663703E06E7E3CA7C0CCEC7E07C54C912B52B51BB12C5C7A85D89DA95A835C63E06E763703EA6C4D4AE44ED4AD416C36A57A276256A87C1000000000000F8BF95A85D89DA9568BE2B51BB12B52BD9BCC1F9189C8FC149BB57A276256A57BAB9ED4AD4AE44ED2AB883F331381F839BB6199C8FC1F9180CB5AF44ED4AD4AE7CB344ED4AD4AE44EDB1DA95A85D89DA5DB0703E06E76370CEAE06E763703E063FAD9C8FC1F9189CAFAB32381F83F33120AAC8E07C0CCEC790A85E89DA95A85D01A7F331381F83F371A589DA95A85D89E2A31F83F331381F53A2B52B51BB12B5C3A04BD4AE44ED4A349FE17C0CCEC7E0A49D77256A57A276159C0DCEC7E07C0C869AA276256A57A2F698381F83F331386797CEC7E07C0CCED79564703E06E7634894FA189C8FC1F9B89290C1F9189C8F2991266A57A276259A8FBC12B52B51BB0A8E51BB12B52B517B8CE763703E06E7EB8A7D0CCEC7E07C5C8913B52B51BB12CD87A95D89DA95A83D863F06E763703EAE84D5AE44ED4AD41E836B57A276256A8F810100000000000080000000000000000001000000000000006B57A276256A8F01D5AE44ED4AD41E033F06E763703EAE04A95D89DA95A83D0613B52B51BB12CD077D0CCEC7E07C5C09E763703E06E7EB0A51BB12B52B517B0CBC12B52B51BB0A0E266A57A276259A0F90C1F9189C8F2911FA189C8FC1F9B81264703E06E7634814CEC7E07C0CCED715381F83F331386717A276256A57A2F6180DCEC7E07C0C861A77256A57A276151CE17C0CCEC7E0A41D4BD4AE44ED4A341FB52B51BB12B5C3201F83F331381F532289DA95A85D89E223F331381F83F371255E89DA95A85D0127
C8E07C0CCEC7902832381F83F331202A9C8FC1F9189CAF2B06E763703E063F2D703E06E76370CE2EDA95A85D89DA5D3044ED4AD4AE44ED31AF44ED4AD4AE7C33199C8FC1F9180C3583F331381F839B36ED4AD4AE44ED2A3857A276256A57BA39C1F9189C8FC1493B2B51BB12B52BD93C95A85D89DA95683E000000000000F83F6A57A276256A8741D4AE44ED4AD416433E06E763703EA644A85D89DA95A8354612B52B51BB12C5477C0CCEC7E07C5449E663703E06E7E34A50BB12B52B51734CBB12B52B51BB024E256A57A27625924F8FC1F9189C8F2151F9189C8FC1F9B05263703E06E7634054CDC7E07C0CCECF55371F83F331385F57A176256A57A2EE580CCEC7E07C0C7E5A76256A57A2760D5CE07C0CCEC7E09C5D4AD4AE44ED4A2C5FB42B51BB12B5BB601E83F331381F4B6288DA95A85D89DA63F231381F83F369655D89DA95A85DF966C7E07C0CCEC7886831381F83F331186A9B8FC1F9189CA76B05E763703E06376D6F3E06E76370C66ED995A85D89DA557043ED4AD4AE44E571AE44ED4AD4AE7473189C8FC1F918047582F331381F839376EC4AD4AE44ED227856A276256A57B279C0F9189C8FC1417B2A51BB12B52BD17CFEFFFFFFFFFFEF7FFFFFFFFFFFFFEF7F000000000000F07F"> : tensor<169xf64> + return %0 : tensor<169xf64> + } + func.func private @expected() -> tensor<169xf64> { + %0 = stablehlo.constant 
dense<"0x000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F000000000000F87F0000000000000000000000000000601EE1946D33BAB4BF208F51F73CAB35862115C07C490C1C4F228B51A617ABC8152320B524436280DE23AC8AB58D7E59A524A3F3B97A8DE16D25FF4704F302E834266485B7575A3FFD26FC6461D91174C42789BD348A8F998C287705E77780FD532948A3521DEDEF1B2A9A2FB6F11E84E32ABA6BA5552B42AB2BA0E41E73B707732C0235F74CF98F3A2DAB1F33190D88022E699A9F37FBD8C92EE99E6F93DA04922F9F335838C81C5930C824EA68D07D2131899ACF9CE75AE831B5414FC192F2B032BF31714DCD92773371514384B6624034D2A0E422D5C3063596620E43E5CDCF35A4B9A015A24796366EE8C579B2355F379CC71789FBDA253874501EEB8A9AEE38B89EFB0A2E6CB5393288DAB340FC7D3A0CAC928217FB443BB87D461BA15A0D3C7EC3D72B9287D43C01ACA4D473B59C3DB803910
37411643E6F8920197A0C2C3F2E2109148E98F33FFD59CA8F6D5FBB407361E299AB1C8341AAE2ABF5FEAD4A42DE4097F5909D1243228E2D73D4F7D94319C9285AFA1AA244B5401E85873C69451F6D9F239A943146F6D9B155A27BF8468FD071B7160AC14755A105539CB4874853AC18C4077B5049690B6BBAD5E6164AAFB1AD77FCE6DF4AD3B5DB7D8A59A64BF5E9FEA6434F6F4C4DB68BAE3CED354D6F08134A9DB4FE4DE2BA134ACD7EC54EE4BFB33BDC168E4F7375B9C71A0E5550CD40B79DCE751D51F7908B09009BE45127372D203DD1AC52CE2706CD53257453E13B8C22EA283C5440858DFDE7AC035583A7DDA4907CCB55F9759BE0883193561DD49FFBE2CB5A579F5BC010FCB22258D0F483348916EA5842B8C038FF30B259FC3FAE131F5C795AC536967A46AB415B4571BD87319C085C1B6713667A21D15C94540D653BD6975D4CCDF5563593605EFFFFFFFFFFFFEF5FFFFFFFFFFFFFEF5F000000000000F07F"> : tensor<169xf64> + return %0 : tensor<169xf64> + } + func.func public @main() { + %0 = call @samples() : () -> tensor<169xf64> + %1 = "stablehlo.sqrt"(%0) : (tensor<169xf64>) -> tensor<169xf64> + %2 = call @expected() : () -> tensor<169xf64> + check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xf64>, tensor<169xf64> + func.return + } +} diff --git a/stablehlo/transforms/StablehloComplexMathExpander.cpp b/stablehlo/transforms/StablehloComplexMathExpander.cpp index b830db6196..de05542115 100644 --- a/stablehlo/transforms/StablehloComplexMathExpander.cpp +++ b/stablehlo/transforms/StablehloComplexMathExpander.cpp @@ -29,6 +29,13 @@ static Value getConstantLikeMaxFiniteValue(OpBuilder &b, Location loc, b, loc, llvm::APFloat::getLargest(ty.getFloatSemantics()), val); } +static Value getConstantLikeInfValue(OpBuilder &b, Location loc, Value val, + bool negative) { + auto ty = cast<FloatType>(getElementTypeOrSelf(val.getType())); + return getConstantLike( + b, loc, llvm::APFloat::getInf(ty.getFloatSemantics(), negative), val); +} + //===----------------------------------------------------------------------===// // Pass //===----------------------------------------------------------------------===// diff --git
a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td index 42cb7f72fc..c488f16595 100644 --- a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td +++ b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // -// This file is generated using functional_algorithms tool (0.12.0). +// This file is generated using functional_algorithms tool (0.13.1). // See build_tools/math/README.md for more information. include "mlir/IR/OpBase.td" @@ -34,6 +34,8 @@ def ComplexElementType : Type< def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< "::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">; +def StableHLO_ConstantLikePosInfValue : NativeCodeCall< + "::mlir::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">; // Logarithm of 1 + z on complex input: // // log1p(x + I * y) = 0.5 * log((x + 1) ** 2 + y ** 2) + I * arctan2(y, x + 1) @@ -139,7 +141,7 @@ def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall< // the Case A method [verified numerically for float32 and float64]. 
// // -def Log1pOp_ComplexElementType_ComplexMathExpander : Pat<(StableHLO_Log1pOp ComplexElementType:$z), +def Log1pOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_Log1pOp ComplexElementType:$z), (StableHLO_ComplexOp (StableHLO_SelectOp (StableHLO_CompareOp @@ -273,3 +275,229 @@ def Log1pOp_ComplexElementType_ComplexMathExpander : Pat<(StableHLO_Log1pOp Comp (StableHLO_SubtractOp:$subtract_add_2sum_high__add_2sum_high_0_ $add_2sum_high, $_add_2sum_high_0_))), (StableHLO_SubtractOp $_square_dekker_low_0_, $subtract_add_2sum_high__add_2sum_high_0_)))))))), (StableHLO_Atan2Op $y, $xp1))>; + +// Square root on complex inputs: +// +// sqrt(z) = sqrt((hypot(x, y) + x)/2) + I * sgn(y) * sqrt((hypot(x, y) - x) / 2) +// +// where z = x + I * y, sgn(y) = 1 if y >= 0, and sgn(y) = -1 otherwise. +// +// Algorithm +// --------- +// +// In the above formula, catastrophic cancellation errors occur in +// the imaginary part when x is positive, and in the real part when x +// is negative. To avoid these, let us define +// +// u = sqrt((hypot(x, y) + abs(x))/2) +// v = sgn(y) * sqrt((hypot(x, y) - abs(x))/2) +// +// and find +// +// u * v = sgn(y) * sqrt(hypot(x, y) ** 2 - x ** 2) / 2 = y / 2 +// +// That is, if x > 0, then we have +// +// sqrt(z) = u + I * y / u / 2 +// +// and if x < 0, +// +// sqrt(z) = abs(y) / u / 2 + I * sgn(y) * u +// +// If abs(x) and abs(y) are smaller than the smallest normal, then as a +// result of underflow, u will be zero and v will be undefined. On +// the other hand, if abs(x) and abs(y) are close to the largest floating +// point number, then `hypot(x, y) + abs(x)` will overflow, and u +// will be `inf`. To address the issues from underflow and overflow, +// we'll use the following formulas: +// +// 1. If abs(x) == abs(y), or abs(x) == inf and abs(y) == inf, then +// +// u_eq = sqrt(abs(x)) * sqrt((1 + sqrt(2))/2) +// abs(y) / u = sqrt(abs(x)) / sqrt((1 + sqrt(2))/2) +// +// 2.
If abs(x) > abs(y) and u == 0 (the underflow case) or u == inf +// (the overflow case), denote r = abs(y) / abs(x), then +// +// u_gt = sqrt(abs(x)) * sqrt((1 + hypot(1, r)) / 2) +// abs(y) / u = sqrt(abs(y)) * sqrt(r) / sqrt((1 + hypot(1, r)) / 2) +// +// 3. If abs(x) < abs(y) and u == 0 (the underflow case) or u == inf +// (the overflow case), denote r = abs(x) / abs(y), then +// +// u_lt = sqrt(abs(y)) * sqrt((r + hypot(1, r)) / 2) +// abs(y) / u = sqrt(abs(y)) / sqrt((r + hypot(1, r)) / 2) +// Note: in the pattern below, $ay_div_u holds abs(y) / u / 2, i.e. the factor 1/2 is already folded into every branch. +def SqrtOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_SqrtOp ComplexElementType:$z), + (StableHLO_ComplexOp + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_RealOp:$x $z), + (StableHLO_ConstantLike<"0">:$constant_0 $x), + StableHLO_ComparisonDirectionValue<"GE">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_SelectOp:$u + (StableHLO_CompareOp:$eq_ax_ay + (StableHLO_AbsOp:$ax $x), + (StableHLO_AbsOp:$ay + (StableHLO_ImagOp:$y $z)), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_DivOp + (StableHLO_MulOp + (StableHLO_SqrtOp:$sq_ax $ax), + (StableHLO_ConstantLike<"1.5537739740300374"> $x)), + (StableHLO_ConstantLike<"1.4142135623730951">:$sq_2 $x)), + (StableHLO_SelectOp + (StableHLO_OrOp:$logical_or_eq_u_general_constant_0_eq_u_general_constant_posinf + (StableHLO_CompareOp + (StableHLO_SqrtOp:$u_general + (StableHLO_AddOp + (StableHLO_DivOp + (StableHLO_SelectOp + (StableHLO_CompareOp + (StableHLO_MaxOp:$mx $ax, $ay), + (StableHLO_MinOp:$mn $ax, $ay), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp $sq_2, $mx), + (StableHLO_SelectOp + (StableHLO_AndOp + (StableHLO_CompareOp + (StableHLO_SqrtOp:$sqa + (StableHLO_AddOp + (StableHLO_ConstantLike<"1">:$one $x), + (StableHLO_MulOp:$r + (StableHLO_DivOp:$mn_over_mx $mn, $mx), + $mn_over_mx))), + $one, + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), +
(StableHLO_CompareOp + $r, + $constant_0, + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE))), + (StableHLO_AddOp + $mx, + (StableHLO_DivOp + (StableHLO_MulOp $mx, $r), + (StableHLO_ConstantLike<"2">:$two $x))), + (StableHLO_MulOp $mx, $sqa))), + $two), + (StableHLO_DivOp $ax, $two))), + $constant_0, + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_CompareOp + $u_general, + (StableHLO_ConstantLikePosInfValue $x), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE))), + (StableHLO_SelectOp + (StableHLO_CompareOp:$gt_ax_ay + $ax, + $ay, + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp + $sq_ax, + (StableHLO_DivOp + (StableHLO_SqrtOp:$sq_1h + (StableHLO_AddOp + $one, + (StableHLO_SelectOp:$h + (StableHLO_CompareOp + (StableHLO_MaxOp:$_mx_0_ + $one, + (StableHLO_AbsOp:$abs__r_0_ + (StableHLO_SelectOp:$_r_0_ + $eq_ax_ay, + $one, + (StableHLO_SelectOp + (StableHLO_CompareOp:$lt_ax_ay + $ax, + $ay, + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_DivOp $ax, $ay), + (StableHLO_DivOp $ay, $ax))))), + (StableHLO_MinOp:$_mn_0_ $one, $abs__r_0_), + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_MulOp $sq_2, $_mx_0_), + (StableHLO_SelectOp + (StableHLO_AndOp + (StableHLO_CompareOp + (StableHLO_SqrtOp:$_sqa_0_ + (StableHLO_AddOp + $one, + (StableHLO_MulOp:$_r_1_ + (StableHLO_DivOp:$_mn_over_mx_0_ $_mn_0_, $_mx_0_), + $_mn_over_mx_0_))), + $one, + StableHLO_ComparisonDirectionValue<"EQ">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_CompareOp + $_r_1_, + $constant_0, + StableHLO_ComparisonDirectionValue<"GT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE))), + (StableHLO_AddOp + $_mx_0_, + (StableHLO_DivOp + (StableHLO_MulOp $_mx_0_, $_r_1_), + $two)), + (StableHLO_MulOp $_mx_0_, $_sqa_0_))))), + $sq_2)), + (StableHLO_MulOp
+ (StableHLO_SqrtOp:$sq_ay $ay), + (StableHLO_DivOp + (StableHLO_SqrtOp:$sq_rh + (StableHLO_AddOp $_r_0_, $h)), + $sq_2))), + $u_general)), + (StableHLO_SelectOp:$ay_div_u + $eq_ax_ay, + (StableHLO_DivOp + $sq_ay, + (StableHLO_ConstantLike<"2.19736822693562"> $x)), + (StableHLO_SelectOp + $logical_or_eq_u_general_constant_0_eq_u_general_constant_posinf, + (StableHLO_SelectOp + $gt_ax_ay, + (StableHLO_DivOp + (StableHLO_MulOp + $sq_ay, + (StableHLO_SelectOp + $eq_ax_ay, + $one, + (StableHLO_SelectOp + $lt_ax_ay, + (StableHLO_DivOp $sq_ax, $sq_ay), + (StableHLO_DivOp $sq_ay, $sq_ax)))), + (StableHLO_MulOp $sq_1h, $sq_2)), + (StableHLO_DivOp + $sq_ay, + (StableHLO_MulOp $sq_rh, $sq_2))), + (StableHLO_DivOp + $ay, + (StableHLO_MulOp $u_general, $two))))), + (StableHLO_SelectOp + (StableHLO_CompareOp + $x, + $constant_0, + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_SelectOp + (StableHLO_CompareOp:$lt_y_constant_0 + $y, + $constant_0, + StableHLO_ComparisonDirectionValue<"LT">, + (STABLEHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_NegOp $u), + $u), + (StableHLO_SelectOp + $lt_y_constant_0, + (StableHLO_NegOp $ay_div_u), + $ay_div_u)))>;