Skip to content

Commit

Permalink
Add StableHLO complex sqrt to stablehlo-complex-math-expander pass (#…
Browse files Browse the repository at this point in the history
…2679)

As in the title.

The [existing implementation of
`sqrt`](https://github.com/openxla/xla/blob/30caa6782b2f49c9ecb8f4727d8628a12ee9a861/xla/service/elemental_ir_emitter.cc#L2111)
on complex inputs uses polar form of complex sqrt which is
inaccurate/incorrect on about 28 % of uniformly distributed samples over
all complex plane. The JAX complex sqrt accuracy statistics is as
follows:

```
test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 2792619172
ULP difference == 0: 297588
ULP difference == 1: 1106945
ULP difference == 2: 78921
ULP difference == 3: 5924
ULP difference == 4: 2938
ULP difference == 5: 2231
ULP difference == 6: 1855
ULP difference == 7: 1648
ULP difference == 8: 1449
ULP difference == 9: 1263
ULP difference == 10: 1089
ULP difference >= 11: 598949

test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2760370645
ULP difference == 0: 699582
ULP difference == 1: 759750
ULP difference == 2: 58127
ULP difference == 3: 3237
ULP difference == 4: 1665
ULP difference == 5: 1185
ULP difference == 6: 1021
ULP difference == 7: 871
ULP difference == 8: 804
ULP difference == 9: 653
ULP difference == 10: 622
ULP difference >= 11: 573283
```

This PR provides an algorithm for complex sqrt that is accurate up to
3/6 ULP difference error on complex samples. The corresponding JAX
complex sqrt accuracy statistics is as follows:
```
test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 5
ULP difference == 0: 1060571
ULP difference == 1: 1008268
ULP difference == 2: 31136
ULP difference == 3: 686
ULP difference == 4: 129
ULP difference == 5: 10

test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2
ULP difference == 0: 1348868
ULP difference == 1: 751504
ULP difference == 2: 428
```

It is interesting to note that although the same algorithm is used for
both CUDA and CPU platforms, then the expected maximal ULP difference is
4 (obtained from applying the algorithm to numpy arrays). Hence
- the accuracy of complex sqrt on CPU is better than expected
- the accuracy of complex sqrt on CUDA is worse than expected because
CUDA sqrt produces slightly different results from std sqrt on float
inputs.
  • Loading branch information
pearu authored Jan 8, 2025
1 parent c06b1a1 commit 0008018
Show file tree
Hide file tree
Showing 9 changed files with 332 additions and 15 deletions.
2 changes: 1 addition & 1 deletion build_tools/math/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ following requirements:

- Python 3.11 or newer
- mpmath 1.3 or newer
- functional_algorithms 0.12 or newer
- functional_algorithms 0.13.2 or newer

that can be installed via pypi:

Expand Down
22 changes: 10 additions & 12 deletions build_tools/math/generate_ChloDecompositionPatternsMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,16 @@ def main(kind="CHLO"):
("CHLO_SquareOp", "complex_square", ("z:complex",)),
("CHLO_SquareOp", "real_square", ("x:float",)),
("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)),
("StableHLO_SqrtOp", "complex_sqrt", ("z:complex",)),
]:
if not chloname.startswith(kind):
continue
print(f'Generating {chloname} from {fname}{args}')
if chloname.startswith("StableHLO_"):
NameOp = chloname.split("_", 1)[1]
expander_name = f"{NameOp}_ComplexElementType_ComplexMathExpander"
else:
expander_name = ""
print(f"Generating {chloname} from {fname}{args}")
func = getattr(fa.algorithms, fname, None)
if func is None:
warnings.warn(
Expand All @@ -110,22 +116,12 @@ def main(kind="CHLO"):
ctx = fa.Context(paths=[fa.algorithms],
parameters=dict(rewrite_keep_integer_literals=True))
graph = ctx.trace(func, *args).rewrite(target, fa.rewrite)
graph.props.update(name=chloname)
graph.props.update(name=chloname, expander_name=expander_name)
src = graph.tostring(target)
sources.append(target.make_comment(func.__doc__)) if func.__doc__ else None
sources[-1] += src
source = "\n\n".join(sources) + "\n"

if chloname.startswith('StableHLO_'):
# an ugly hack to fix the definition of stablehlo complex math
# functions. TODO(pearu): add the corresponding feature to
# functional_algorithms stablehlo printer
NameOp = chloname.split('_', 1)[1]
source = source.replace(
f'def : Pat<({chloname}',
f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}'
)

if os.path.isfile(output_file):
f = open(output_file, "r")
content = f.read()
Expand Down Expand Up @@ -177,6 +173,8 @@ def ComplexElementType : Type<
def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
"::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
def StableHLO_ConstantLikePosInfValue : NativeCodeCall<
"::mlir::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">;
""")
f.write(source)
f.close()
Expand Down
8 changes: 8 additions & 0 deletions build_tools/math/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# namespace - "chlo" or "stablehlo"
#
# passes - a string of pass arguments
#
# When unspecifed, these parameters are retrieved from
# functional_algorithms database of support functions.
#
Expand All @@ -68,6 +72,10 @@
mpmath_name="log1p",
namespace="stablehlo",
passes="--stablehlo-complex-math-expander"),
dict(name="sqrt",
mpmath_name="sqrt",
namespace="stablehlo",
passes="--stablehlo-complex-math-expander"),
]


Expand Down
19 changes: 19 additions & 0 deletions stablehlo/tests/math/sqrt_complex128.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret
// This file is generated, see build_tools/math/README.md for more information.
module @sqrt_complex128 {
func.func private @samples() -> tensor<169xcomplex<f64>> {
%0 = stablehlo.constant dense<"0x000000000000F0FF000000000000F0FFFFFFFFFFFFFFEFFF000000000000F0FFFEFFFFFFFFFFEFFF000000000000F0FF000000000000F8BF000000000000F0FF000000000000FC9F000000000000F0FF0100000000000080000000000000F0FF0000000000000000000000000000F0FF0100000000000000000000000000F0FF000000000000FC1F000000000000F0FF000000000000F83F000000000000F0FFFEFFFFFFFFFFEF7F000000000000F0FFFFFFFFFFFFFFEF7F000000000000F0FF000000000000F07F000000000000F0FF000000000000F0FFFFFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFFFFFFFFFFFFEFFF000000000000F8BFFFFFFFFFFFFFEFFF000000000000FC9FFFFFFFFFFFFFEFFF0100000000000080FFFFFFFFFFFFEFFF0000000000000000FFFFFFFFFFFFEFFF0100000000000000FFFFFFFFFFFFEFFF000000000000FC1FFFFFFFFFFFFFEFFF000000000000F83FFFFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFFFFFFFFFFFFEFFF000000000000F07FFFFFFFFFFFFFEFFF000000000000F0FFFEFFFFFFFFFFEFFFFFFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFEFFFFFFFFFFEFFFFEFFFFFFFFFFEFFF000000000000F8BFFEFFFFFFFFFFEFFF000000000000FC9FFEFFFFFFFFFFEFFF0100000000000080FEFFFFFFFFFFEFFF0000000000000000FEFFFFFFFFFFEFFF0100000000000000FEFFFFFFFFFFEFFF000000000000FC1FFEFFFFFFFFFFEFFF000000000000F83FFEFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFEFFFFFFFFFFEFFF000000000000F07FFEFFFFFFFFFFEFFF000000000000F0FF000000000000F8BFFFFFFFFFFFFFEFFF000000000000F8BFFEFFFFFFFFFFEFFF000000000000F8BF000000000000F8BF000000000000F8BF000000000000FC9F000000000000F8BF0100000000000080000000000000F8BF0000000000000000000000000000F8BF0100000000000000000000000000F8BF000000000000FC1F000000000000F8BF000000000000F83F000000000000F8BFFEFFFFFFFFFFEF7F000000000000F8BFFFFFFFFFFFFFEF7F000000000000F8BF000000000000F07F000000000000F8BF000000000000F0FF000000000000FC9FFFFFFFFFFFFFEFFF000000000000FC9FFEFFFFFFFFFFEFFF000000000000FC9F000000000000F8BF000000000000FC9F000000000000FC9F000000000000FC9F0100000000000080000000000000FC9F0000000000000000000000000000FC9F0100000000000000000000000000FC9F000000000000FC1F000000000000FC9F000000000000F83F000000000000FC9FFEFFFFFFFFFFEF7F000000000000FC9FFFFFFFFFFFFFEF7F000000000000FC9F000000000000F07F000000000000FC9F000000000000F0FF0100000000000080FFFFFFFFFFFFEFFF0100000000000080FEFFFFFFFFFFEFFF0100000000000080000000000000F8BF0100000000000080000000000000FC9F0100000000000080010000000000008001000000000000800000000000000000010000000000008001000000000000000100000000000080000000000000FC1F0100000000000080000000000000F83F0100000000000080FEFFFFFFFFFFEF7F0100000000000080FFFFFFFFFFFFEF7F0100000000000080000000000000F07F0100000000000080000000000000F0FF0000000000000000FFFFFFFFFFFFEFFF0000000000000000FEFFFFFFFFFFEFFF0000000000000000000000000000F8BF0000000000000000000000000000FC9F0000000000000000010000000000008000000000000000000000000000000000000000000000000001000000000000000000000000000000000000000000FC1F0000000000000000000000000000F83F0000000000000000FEFFFFFFFFFFEF7F0000000000000000FFFFFFFFFFFFEF7F0000000000000000000000000000F07F0000000000000000000000000000F0FF0100000000000000FFFFFFFFFFFFEFFF0100000000000000FEFFFFFFFFFFEFFF0100000000000000000000000000F8BF0100000000000000000000000000FC9F0100000000000000010000000000008001000000000000000000000000000000010000000000000001000000000000000100000000000000000000000000FC1F0100000000000000000000000000F83F0100000000000000FEFFFFFFFFFFEF7F0100000000000000FFFFFFFFFFFFEF7F0100000000000000000000000000F07F0100000000000000000000000000F0FF000000000000FC1FFFFFFFFFFFFFEFFF000000000000FC1FFEFFFFFFFFFFEFFF000000000000FC1F000000000000F8BF000000000000FC1F000000000000FC9F000000000000FC1F0100000000000080000000000000FC1F0000000000000000000000000000FC1F0100000000000000000000000000FC1F000000000000FC1F000000000000FC1F000000000000F83F000000000000FC1FFEFFFFFFFFFFEF7F000000000000FC1FFFFFFFFFFFFFEF7F000000000000FC1F000000000000F07F000000000000FC1F000000000000F0FF000000000000F83FFFFFFFFFFFFFEFFF000000000000F83FFEFFFFFFFFFFEFFF000000000000F83F000000000000F8BF000000000000F83F000000000000FC9F000000000000F83F0100000000000080000000000000F83F0000000000000000000000000000F83F0100000000000000000000000000F83F000000000000FC1F000000000000F83F000000000000F83F000000000000F83FFEFFFFFFFFFFEF7F000000000000F83FFFFFFFFFFFFFEF7F000000000000F83F000000000000F07F000000000000F83F000000000000F0FFFEFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFEFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFEFFFFFFFFFFEF7F000000000000F8BFFEFFFFFFFFFFEF7F000000000000FC9FFEFFFFFFFFFFEF7F0100000000000080FEFFFFFFFFFFEF7F0000000000000000FEFFFFFFFFFFEF7F0100000000000000FEFFFFFFFFFFEF7F000000000000FC1FFEFFFFFFFFFFEF7F000000000000F83FFEFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFEFFFFFFFFFFEF7F000000000000F07FFEFFFFFFFFFFEF7F000000000000F0FFFFFFFFFFFFFFEF7FFFFFFFFFFFFFEFFFFFFFFFFFFFFFEF7FFEFFFFFFFFFFEFFFFFFFFFFFFFFFEF7F000000000000F8BFFFFFFFFFFFFFEF7F000000000000FC9FFFFFFFFFFFFFEF7F0100000000000080FFFFFFFFFFFFEF7F0000000000000000FFFFFFFFFFFFEF7F0100000000000000FFFFFFFFFFFFEF7F000000000000FC1FFFFFFFFFFFFFEF7F000000000000F83FFFFFFFFFFFFFEF7FFEFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFFFFFFFFFFFFEF7FFFFFFFFFFFFFEF7F000000000000F07FFFFFFFFFFFFFEF7F000000000000F0FF000000000000F07FFFFFFFFFFFFFEFFF000000000000F07FFEFFFFFFFFFFEFFF000000000000F07F000000000000F8BF000000000000F07F000000000000FC9F000000000000F07F0100000000000080000000000000F07F0000000000000000000000000000F07F0100000000000000000000000000F07F000000000000FC1F000000000000F07F000000000000F83F000000000000F07FFEFFFFFFFFFFEF7F000000000000F07FFFFFFFFFFFFFEF7F000000000000F07F000000000000F07F000000000000F07F"> : tensor<169xcomplex<f64>>
return %0 : tensor<169xcomplex<f64>>
}
func.func private @expected() -> tensor<169xcomplex<f64>> {
%0 = stablehlo.constant dense<"0x000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF000000000000F07F000000000000F0FF0000000000000000000000000000F0FF28C8F6383120DD5FF8A9FFCA3594F1DF28C8F6383120DD5FF8A9FFCA3594F1DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFF8A9FFCA3594F15F28C8F6383120DDDFF8A9FFCA3594F15F28C8F6383120DDDF000000000000F07F00000000000000000000000000000000000000000000F0FF27C8F6383120DD5FF8A9FFCA3594F1DF27C8F6383120DD5FF8A9FFCA3594F1DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFCC3B7F669EA0E65FCC3B7F669EA0E6DFF8A9FFCA3594F15F27C8F6383120DDDFF8A9FFCA3594F15F27C8F6383120DDDF000000000000F07F00000000000000000000000000000000000000000000F0FF000000000000E81FFFFFFFFFFFFFEFDF010000000000E81FFFFFFFFFFFFFEFDFDBF5B774F7D5E13F54CB78F99B87F5BFAA4C58E87AB6EB3FAA4C58E87AB6EBBFAA4C58E87AB6EB3FAA4C58E87AB6EBBFAA4C58E87AB6EB3FAA4C58E87AB6EBBFAA4C58E87AB6EB3FAA4C58E87AB6EBBFAA4C58E87AB6EB3FAA4C58E87AB6EBBF54CB78F99B87F53FDBF5B774F7D5E1BFFFFFFFFFFFFFEF5F010000000000E89FFFFFFFFFFFFFEF5F000000000000E89F000000000000F07F00000000000000000000000000000000000000000000F0FF0000000000800300FFFFFFFFFFFFEFDF0000000000800300FFFFFFFFFFFFEFDFB6A60AC2A5DCE61F2E2109148E98F3BF1B1B52C0CE43E32F6EACA4EA3741F7AF493F6811EAEEED2F493F6811EAEEEDAF493F6811EAEEED2F493F6811EAEEEDAF493F6811EAEEED2F493F6811EAEEEDAF6EACA4EA3741F72F1B1B52C0CE43E3AF2E2109148E98F33FB6A60AC2A5DCE69FFFFFFFFFFFFFEF5F0000000000800380FFFFFFFFFFFFEF5F0000000000800380000000000000F07F00000000000000000000000000000000000000000000F0FF0000000000000000FFFFFFFFFFFFEFDF0000000000000000FFFFFFFFFFFFEFDF00000000000000002E2109148E98F3BFE7F7A7E69130B80CEAF8D2A97F2AF5AF28C8F63831204D1EF9A9FFCA3594619ECD3B7F669EA0561ECD3B7F669EA0569EF9A9FFCA3594611E28C8F63831204D9EEAF8D2A97F2AF52FE7F7A7E69130B88C2E2109148E98F33F0000000000000080FFFFFFFFFFFFEF5F0000000000000080FFFFFFFFFFFFEF5F0000000000000080000000000000F07F00000000000000000000000000000000000000000000F07F0000000000000000FFFFFFFFFFFFEF5F0000000000000000FFFFFFFFFFFFEF5F00000000000000002E2109148E98F33F0000000000000000EAF8D2A97F2AF52F0000000000000000000000000000601E00000000000000000000000000000000000000000000601E0000000000000000EAF8D2A97F2AF52F00000000000000002E2109148E98F33F0000000000000000FFFFFFFFFFFFEF5F0000000000000000FFFFFFFFFFFFEF5F0000000000000000000000000000F07F00000000000000000000000000000000000000000000F07F0000000000000000FFFFFFFFFFFFEF5F0000000000000000FFFFFFFFFFFFEF5F00000000000000002E2109148E98F33FE7F7A7E69130B80CEAF8D2A97F2AF52F28C8F63831204D1EF9A9FFCA3594611ECD3B7F669EA0561ECD3B7F669EA0561EF9A9FFCA3594611E28C8F63831204D1EEAF8D2A97F2AF52FE7F7A7E69130B80C2E2109148E98F33F0000000000000000FFFFFFFFFFFFEF5F0000000000000000FFFFFFFFFFFFEF5F0000000000000000000000000000F07F00000000000000000000000000000000000000000000F07F0000000000800300FFFFFFFFFFFFEF5F0000000000800300FFFFFFFFFFFFEF5FB6A60AC2A5DCE61F2E2109148E98F33F1B1B52C0CE43E32F6EACA4EA3741F72F493F6811EAEEED2F493F6811EAEEED2F493F6811EAEEED2F493F6811EAEEED2F493F6811EAEEED2F493F6811EAEEED2F6EACA4EA3741F72F1B1B52C0CE43E32F2E2109148E98F33FB6A60AC2A5DCE61FFFFFFFFFFFFFEF5F0000000000800300FFFFFFFFFFFFEF5F0000000000800300000000000000F07F00000000000000000000000000000000000000000000F07F000000000000E81FFFFFFFFFFFFFEF5F010000000000E81FFFFFFFFFFFFFEF5FDBF5B774F7D5E13F54CB78F99B87F53FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3FAA4C58E87AB6EB3F54CB78F99B87F53FDBF5B774F7D5E13FFFFFFFFFFFFFEF5F010000000000E81FFFFFFFFFFFFFEF5F000000000000E81F000000000000F07F00000000000000000000000000000000000000000000F07F27C8F6383120DD5FF8A9FFCA3594F15F27C8F6383120DD5FF8A9FFCA3594F15FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FF8A9FFCA3594F15F27C8F6383120DD5FF8A9FFCA3594F15F27C8F6383120DD5F000000000000F07F00000000000000000000000000000000000000000000F07F28C8F6383120DD5FF8A9FFCA3594F15F28C8F6383120DD5FF8A9FFCA3594F15FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FCC3B7F669EA0E65FF8A9FFCA3594F15F28C8F6383120DD5FF8A9FFCA3594F15F28C8F6383120DD5F000000000000F07F0000000000000000000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F000000000000F07F"> : tensor<169xcomplex<f64>>
return %0 : tensor<169xcomplex<f64>>
}
func.func public @main() {
%0 = call @samples() : () -> tensor<169xcomplex<f64>>
%1 = "stablehlo.sqrt"(%0) : (tensor<169xcomplex<f64>>) -> tensor<169xcomplex<f64>>
%2 = call @expected() : () -> tensor<169xcomplex<f64>>
check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xcomplex<f64>>, tensor<169xcomplex<f64>>
func.return
}
}
Loading

0 comments on commit 0008018

Please sign in to comment.