Skip to content

Commit

Permalink
Add stablehlo-complex-math-expander pass.
Browse files Browse the repository at this point in the history
Add StableHLO log_plus_one on complex inputs to stablehlo-complex-math-expander pass.

Fix inaccuracy issues in StableHLO log_plus_one on complex inputs.
  • Loading branch information
pearu committed Dec 17, 2024
1 parent 9e0c9e3 commit e4ec740
Show file tree
Hide file tree
Showing 17 changed files with 696 additions and 46 deletions.
17 changes: 17 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,21 @@ gentbl_cc_library(
],
)

gentbl_cc_library(
name = "stablehlo_create_complex_math_expander_inc_gen",
tbl_outs = [
(
["--gen-rewriters"],
"stablehlo/transforms/StablehloComplexMathExpanderPatterns.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/transforms/StablehloComplexMathExpanderPatterns.td",
deps = [
":stablehlo_ops_td_files",
],
)

cc_library(
name = "interpreter_ops",
srcs = [
Expand Down Expand Up @@ -1120,6 +1135,7 @@ cc_library(
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
"stablehlo/transforms/StablehloCompatibilityExpander.cpp",
"stablehlo/transforms/StablehloComplexMathExpander.cpp",
"stablehlo/transforms/StablehloConvertToSignless.cpp",
"stablehlo/transforms/StablehloLegalizeCompositeToCall.cpp",
"stablehlo/transforms/StablehloLegalizeDeprecatedOps.cpp",
Expand Down Expand Up @@ -1148,6 +1164,7 @@ cc_library(
":linalg_passes",
":stablehlo_aggressive_simplification_inc_gen",
":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_create_complex_math_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
":stablehlo_ops",
":stablehlo_ops_inc_gen",
Expand Down
13 changes: 12 additions & 1 deletion build_tools/math/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ following requirements:

- Python 3.11 or newer
- mpmath 1.3 or newer
- functional_algorithms 0.11.1 or newer
- functional_algorithms 0.12 or newer

that can be installed via pypi:

Expand Down Expand Up @@ -77,6 +77,13 @@ build/bin/stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify

and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.

A similar procedure is applied for updating
`stablehlo/tests/stablehlo_complex_math_expander.mlir`:
```sh
build/bin/stablehlo-opt --stablehlo-complex-math-expander --split-input-file --verify-diagnostics \
stablehlo/tests/stablehlo_complex_math_expander.mlir | python llvm-project/mlir/utils/generate-test-checks.py | less
```

## A procedure for adding a new algorithm to an existing operation

1. Implement a new algorithm in
Expand All @@ -98,6 +105,10 @@ and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.
7. Add a record of the operation to
`generate_ChloDecompositionPatternsMath.py`, see the for-loop in
`main` function.
- If the operation is a StableHLO operation on complex inputs, add
it to `stable-complex-math-expander` pass: update
`populateStablehloComplexMathExpanderPatterns` function in
`stablehlo/transforms/StablehloComplexMathExpander.cpp`.
8. Generate new implementations by running
`generate_ChloDecompositionPatternsMath.py` and remove existing
implementations in
Expand Down
57 changes: 45 additions & 12 deletions build_tools/math/generate_ChloDecompositionPatternsMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_functional_algorithms_required_version():
)


def main():
def main(kind="CHLO"):
try:
import functional_algorithms as fa
except ImportError as msg:
Expand All @@ -64,16 +64,15 @@ def main():
warnings.warn(msg)
return

output_filename = dict(
CHLO="ChloDecompositionPatternsMath.td",
StableHLO="StablehloComplexMathExpanderPatterns.td",
)[kind]

output_file = os.path.relpath(
os.path.normpath(
os.path.join(
os.path.dirname(__file__),
"..",
"..",
"stablehlo",
"transforms",
"ChloDecompositionPatternsMath.td",
)),
os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo",
"transforms", output_filename)),
os.getcwd(),
)

Expand All @@ -98,13 +97,15 @@ def main():
("CHLO_AtanhOp", "complex_atanh", ("z:complex",)),
("CHLO_SquareOp", "complex_square", ("z:complex",)),
("CHLO_SquareOp", "real_square", ("x:float",)),
("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)),
]:
if not chloname.startswith(kind):
continue
print(f'Generating {chloname} from {fname}{args}')
func = getattr(fa.algorithms, fname, None)
if func is None:
warnings.warn(
f"{fa.algorithms.__name__} does not define {fname}. Skipping."
)
f"{fa.algorithms.__name__} does not define {fname}. Skipping.")
continue
ctx = fa.Context(paths=[fa.algorithms],
parameters=dict(rewrite_keep_integer_literals=True))
Expand All @@ -115,6 +116,16 @@ def main():
sources[-1] += src
source = "\n\n".join(sources) + "\n"

if chloname.startswith('StableHLO_'):
# an ugly hack to fix the definition of stablehlo complex math
# functions. TODO(pearu): add the corresponding feature to
# functional_algorithms stablehlo printer
NameOp = chloname.split('_', 1)[1]
source = source.replace(
f'def : Pat<({chloname}',
f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}'
)

if os.path.isfile(output_file):
f = open(output_file, "r")
content = f.read()
Expand Down Expand Up @@ -146,10 +157,32 @@ def main():
This file is generated using functional_algorithms tool ({fa.__version__}).
See build_tools/math/README.md for more information.""") + "\n")

if kind == "StableHLO":
f.write("""\
include "mlir/IR/OpBase.td"
include "stablehlo/dialect/StablehloOps.td"
class StableHLO_ComparisonDirectionValue<string enumStr> :
ConstantAttr<StableHLO_ComparisonDirectionAttr,
"::mlir::stablehlo::ComparisonDirection::" # enumStr>;
class StableHLO_ConstantLike<string value> : NativeCodeCall<
"::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
def ComplexElementType : Type<
CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
"Complex element type">;
def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
"::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
""")
f.write(source)
f.close()
print(f"Created {output_file}")


if __name__ == "__main__":
main()
main(kind="CHLO")
main(kind="StableHLO")
66 changes: 36 additions & 30 deletions build_tools/math/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,31 @@
default_max_ulp_difference = 1

operations = [
# The following dictionaries may have additional keys like
#
# size - defines the number of samples: size ** 2
#
# max_ulp_difference - the maximal allowed ULP difference between
# function and reference values
#
# extra_prec_multiplier - the precison multiplier for mpmath.mp
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# When unspecifed, these parameters are retrieved from
# functional_algorithms database of support functions.
#
dict(name="asin", mpmath_name="arcsin"),
dict(name="acos", mpmath_name="arccos"),
dict(name="atan", mpmath_name="arctan"),
dict(name="asinh", mpmath_name="arcsinh"),
dict(name="acosh", mpmath_name="arccosh"),
dict(name="atanh", mpmath_name="arctanh"),
dict(name="square", mpmath_name="square"),
# The following dictionaries may have additional keys like
#
# size - defines the number of samples: size ** 2
#
# max_ulp_difference - the maximal allowed ULP difference between
# function and reference values
#
# extra_prec_multiplier - the precison multiplier for mpmath.mp
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# When unspecifed, these parameters are retrieved from
# functional_algorithms database of support functions.
#
dict(name="asin", mpmath_name="arcsin"),
dict(name="acos", mpmath_name="arccos"),
dict(name="atan", mpmath_name="arctan"),
dict(name="asinh", mpmath_name="arcsinh"),
dict(name="acosh", mpmath_name="arccosh"),
dict(name="atanh", mpmath_name="arctanh"),
dict(name="square", mpmath_name="square"),
dict(name="log_plus_one",
mpmath_name="log1p",
namespace="stablehlo",
passes="--stablehlo-complex-math-expander"),
]


Expand Down Expand Up @@ -127,19 +131,21 @@ def main():
for op in operations:
opname = op["name"]
mpmath_opname = op.get("mpmath_name", opname)
namespace = op.get("namespace", "chlo")
size_re = size_im = op.get("size", default_size)

passes = op.get("passes", "--chlo-legalize-to-stablehlo")
for dtype in [np.complex64, np.complex128, np.float32, np.float64]:
params = fa.utils.function_validation_parameters(opname, dtype)
max_ulp_difference = op.get(
"max_ulp_difference",
params.get("max_valid_ulp_count", default_max_ulp_difference))
"max_ulp_difference",
params.get("max_valid_ulp_count", default_max_ulp_difference))

nmp = fa.utils.numpy_with_mpmath(
extra_prec_multiplier = op.get(
"extra_prec_multiplier",
params.get("extra_prec_multiplier", default_extra_prec_multiplier)),
flush_subnormals=flush_subnormals,
extra_prec_multiplier=op.get(
"extra_prec_multiplier",
params.get("extra_prec_multiplier",
default_extra_prec_multiplier)),
flush_subnormals=flush_subnormals,
)

fi = np.finfo(dtype)
Expand Down Expand Up @@ -180,7 +186,7 @@ def main():
main_func = m.make_function("main", "", "", "public")

ref_samples = main_func.call("samples")
actual = main_func.composite(f"chlo.{opname}", ref_samples)
actual = main_func.composite(f"{namespace}.{opname}", ref_samples)
expected = main_func.call("expected")

main_func.void_call(
Expand All @@ -202,7 +208,7 @@ def main():
continue

f = open(fname, "w")
f.write("// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |"
f.write(f"// RUN: stablehlo-opt {passes} %s |"
" stablehlo-translate --interpret\n")
f.write(
"// This file is generated, see build_tools/math/README.md for more"
Expand Down
3 changes: 2 additions & 1 deletion docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -4066,7 +4066,8 @@ Performs element-wise logarithm plus one operation on `operand` tensor and
produces a `result` tensor. Depending on the element type, does the following:

* For floats: `logp1` from IEEE-754.
* For complex numbers: complex logarithm plus one.
* For complex numbers:
`complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))`
* For quantized types:
`dequantize_op_quantize(log_plus_one, operand, type(result))`.

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3924,4 +3924,4 @@ func.func @square_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32
func.func @square_f32(%arg : tensor<f32>) -> tensor<f32> {
%result = "chlo.square"(%arg) : (tensor<f32>) -> tensor<f32>
func.return %result : tensor<f32>
}
}
Loading

0 comments on commit e4ec740

Please sign in to comment.