Skip to content

Commit

Permalink
Add StableHLO complex sqrt to stablehlo-complex-math-expander pass (#…
Browse files Browse the repository at this point in the history
…2679)

As in the title.

The [existing implementation of
`sqrt`](https://github.com/openxla/xla/blob/30caa6782b2f49c9ecb8f4727d8628a12ee9a861/xla/service/elemental_ir_emitter.cc#L2111)
on complex inputs uses the polar form of complex sqrt, which is
inaccurate/incorrect on about 28% of samples distributed uniformly over
the whole complex plane. The JAX complex sqrt accuracy statistics are as
follows:

```
test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 2792619172
ULP difference == 0: 297588
ULP difference == 1: 1106945
ULP difference == 2: 78921
ULP difference == 3: 5924
ULP difference == 4: 2938
ULP difference == 5: 2231
ULP difference == 6: 1855
ULP difference == 7: 1648
ULP difference == 8: 1449
ULP difference == 9: 1263
ULP difference == 10: 1089
ULP difference >= 11: 598949

test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2760370645
ULP difference == 0: 699582
ULP difference == 1: 759750
ULP difference == 2: 58127
ULP difference == 3: 3237
ULP difference == 4: 1665
ULP difference == 5: 1185
ULP difference == 6: 1021
ULP difference == 7: 871
ULP difference == 8: 804
ULP difference == 9: 653
ULP difference == 10: 622
ULP difference >= 11: 573283
```

This PR provides an algorithm for complex sqrt that is accurate to
within a 3/6 ULP difference on complex samples. The corresponding JAX
complex sqrt accuracy statistics are as follows:
```
test_unary[sqrt-jax-cuda-complex64-default] maximal ULP difference: 5
ULP difference == 0: 1060571
ULP difference == 1: 1008268
ULP difference == 2: 31136
ULP difference == 3: 686
ULP difference == 4: 129
ULP difference == 5: 10

test_unary[sqrt-jax-cpu-complex64-default] maximal ULP difference: 2
ULP difference == 0: 1348868
ULP difference == 1: 751504
ULP difference == 2: 428
```

It is interesting to note that, although the same algorithm is used on
both the CUDA and CPU platforms, the measured accuracy differs between
them. The expected maximal ULP difference is 4 (obtained by applying the
algorithm to NumPy arrays). Hence:
- the accuracy of complex sqrt on CPU is better than expected;
- the accuracy of complex sqrt on CUDA is worse than expected, because
CUDA sqrt produces slightly different results from std sqrt on float
inputs.
  • Loading branch information
pearu authored Jan 8, 2025
1 parent c06b1a1 commit 0008018
Show file tree
Hide file tree
Showing 9 changed files with 332 additions and 15 deletions.
2 changes: 1 addition & 1 deletion build_tools/math/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ following requirements:

- Python 3.11 or newer
- mpmath 1.3 or newer
- functional_algorithms 0.12 or newer
- functional_algorithms 0.13.2 or newer

that can be installed via pypi:

Expand Down
22 changes: 10 additions & 12 deletions build_tools/math/generate_ChloDecompositionPatternsMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,16 @@ def main(kind="CHLO"):
("CHLO_SquareOp", "complex_square", ("z:complex",)),
("CHLO_SquareOp", "real_square", ("x:float",)),
("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)),
("StableHLO_SqrtOp", "complex_sqrt", ("z:complex",)),
]:
if not chloname.startswith(kind):
continue
print(f'Generating {chloname} from {fname}{args}')
if chloname.startswith("StableHLO_"):
NameOp = chloname.split("_", 1)[1]
expander_name = f"{NameOp}_ComplexElementType_ComplexMathExpander"
else:
expander_name = ""
print(f"Generating {chloname} from {fname}{args}")
func = getattr(fa.algorithms, fname, None)
if func is None:
warnings.warn(
Expand All @@ -110,22 +116,12 @@ def main(kind="CHLO"):
ctx = fa.Context(paths=[fa.algorithms],
parameters=dict(rewrite_keep_integer_literals=True))
graph = ctx.trace(func, *args).rewrite(target, fa.rewrite)
graph.props.update(name=chloname)
graph.props.update(name=chloname, expander_name=expander_name)
src = graph.tostring(target)
sources.append(target.make_comment(func.__doc__)) if func.__doc__ else None
sources[-1] += src
source = "\n\n".join(sources) + "\n"

if chloname.startswith('StableHLO_'):
# an ugly hack to fix the definition of stablehlo complex math
# functions. TODO(pearu): add the corresponding feature to
# functional_algorithms stablehlo printer
NameOp = chloname.split('_', 1)[1]
source = source.replace(
f'def : Pat<({chloname}',
f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}'
)

if os.path.isfile(output_file):
f = open(output_file, "r")
content = f.read()
Expand Down Expand Up @@ -177,6 +173,8 @@ def ComplexElementType : Type<
def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
"::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
def StableHLO_ConstantLikePosInfValue : NativeCodeCall<
"::mlir::stablehlo::getConstantLikeInfValue($_builder, $_loc, $0, /*negative=*/false)">;
""")
f.write(source)
f.close()
Expand Down
8 changes: 8 additions & 0 deletions build_tools/math/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# namespace - "chlo" or "stablehlo"
#
# passes - a string of pass arguments
#
# When unspecified, these parameters are retrieved from
# functional_algorithms database of support functions.
#
Expand All @@ -68,6 +72,10 @@
mpmath_name="log1p",
namespace="stablehlo",
passes="--stablehlo-complex-math-expander"),
dict(name="sqrt",
mpmath_name="sqrt",
namespace="stablehlo",
passes="--stablehlo-complex-math-expander"),
]


Expand Down
19 changes: 19 additions & 0 deletions stablehlo/tests/math/sqrt_complex128.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: stablehlo-opt --stablehlo-complex-math-expander %s | stablehlo-translate --interpret
// This file is generated, see build_tools/math/README.md for more information.
module @sqrt_complex128 {
func.func private @samples() -> tensor<169xcomplex<f64>> {
%0 = stablehlo.constant dense<"0xtensor<169xcomplex<f64>>
return %0 : tensor<169xcomplex<f64>>
}
func.func private @expected() -> tensor<169xcomplex<f64>> {
%0 = stablehlo.constant dense<"0xtensor<169xcomplex<f64>>
return %0 : tensor<169xcomplex<f64>>
}
func.func public @main() {
%0 = call @samples() : () -> tensor<169xcomplex<f64>>
%1 = "stablehlo.sqrt"(%0) : (tensor<169xcomplex<f64>>) -> tensor<169xcomplex<f64>>
%2 = call @expected() : () -> tensor<169xcomplex<f64>>
check.expect_close %1, %2, max_ulp_difference = 4 : tensor<169xcomplex<f64>>, tensor<169xcomplex<f64>>
func.return
}
}
Loading

0 comments on commit 0008018

Please sign in to comment.