Skip to content

Commit

Permalink
Add log_plus_one tests and generate StablehloComplexMathExpanderPatte…
Browse files Browse the repository at this point in the history
…rns.td. Remove CHLO Log1p.
  • Loading branch information
pearu committed Dec 17, 2024
1 parent a56b4ed commit 4164522
Show file tree
Hide file tree
Showing 13 changed files with 404 additions and 403 deletions.
58 changes: 45 additions & 13 deletions build_tools/math/generate_ChloDecompositionPatternsMath.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_functional_algorithms_required_version():
)


def main():
def main(kind="CHLO"):
try:
import functional_algorithms as fa
except ImportError as msg:
Expand All @@ -64,16 +64,15 @@ def main():
warnings.warn(msg)
return

output_filename = dict(
CHLO="ChloDecompositionPatternsMath.td",
StableHLO="StablehloComplexMathExpanderPatterns.td",
)[kind]

output_file = os.path.relpath(
os.path.normpath(
os.path.join(
os.path.dirname(__file__),
"..",
"..",
"stablehlo",
"transforms",
"ChloDecompositionPatternsMath.td",
)),
os.path.join(os.path.dirname(__file__), "..", "..", "stablehlo",
"transforms", output_filename)),
os.getcwd(),
)

Expand All @@ -98,14 +97,15 @@ def main():
("CHLO_AtanhOp", "complex_atanh", ("z:complex",)),
("CHLO_SquareOp", "complex_square", ("z:complex",)),
("CHLO_SquareOp", "real_square", ("x:float",)),
("CHLO_Log1pOp", "complex_log1p", ("z:complex",)),
("StableHLO_Log1pOp", "complex_log1p", ("z:complex",)),
]:
if not chloname.startswith(kind):
continue
print(f'Generating {chloname} from {fname}{args}')
func = getattr(fa.algorithms, fname, None)
if func is None:
warnings.warn(
f"{fa.algorithms.__name__} does not define {fname}. Skipping."
)
f"{fa.algorithms.__name__} does not define {fname}. Skipping.")
continue
ctx = fa.Context(paths=[fa.algorithms],
parameters=dict(rewrite_keep_integer_literals=True))
Expand All @@ -116,6 +116,16 @@ def main():
sources[-1] += src
source = "\n\n".join(sources) + "\n"

if chloname.startswith('StableHLO_'):
# an ugly hack to fix the definition of stablehlo complex math
# functions. TODO(pearu): add the corresponding feature to
# functional_algorithms stablehlo printer
NameOp = chloname.split('_', 1)[1]
source = source.replace(
f'def : Pat<({chloname}',
f'def {NameOp}_ComplexElementType_ComplexMathExpander : Pat<({chloname}'
)

if os.path.isfile(output_file):
f = open(output_file, "r")
content = f.read()
Expand Down Expand Up @@ -147,10 +157,32 @@ def main():
This file is generated using functional_algorithms tool ({fa.__version__}).
See build_tools/math/README.md for more information.""") + "\n")

if kind == "StableHLO":
f.write("""\
include "mlir/IR/OpBase.td"
include "stablehlo/dialect/StablehloOps.td"
class StableHLO_ComparisonDirectionValue<string enumStr> :
ConstantAttr<StableHLO_ComparisonDirectionAttr,
"::mlir::stablehlo::ComparisonDirection::" # enumStr>;
class StableHLO_ConstantLike<string value> : NativeCodeCall<
"::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">;
def ComplexElementType : Type<
CPred<"isa<ComplexType>(cast<ShapedType>($_self).getElementType())">,
"Complex element type">;
def StableHLO_ConstantLikeMaxFiniteValue : NativeCodeCall<
"::mlir::stablehlo::getConstantLikeMaxFiniteValue($_builder, $_loc, $0)">;
""")
f.write(source)
f.close()
print(f"Created {output_file}")


if __name__ == "__main__":
main()
main(kind="CHLO")
main(kind="StableHLO")
67 changes: 36 additions & 31 deletions build_tools/math/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,31 @@
default_max_ulp_difference = 1

operations = [
# The following dictionaries may have additional keys like
#
# size - defines the number of samples: size ** 2
#
# max_ulp_difference - the maximal allowed ULP difference between
# function and reference values
#
# extra_prec_multiplier - the precison multiplier for mpmath.mp
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# When unspecifed, these parameters are retrieved from
# functional_algorithms database of support functions.
#
dict(name="asin", mpmath_name="arcsin"),
dict(name="acos", mpmath_name="arccos"),
dict(name="atan", mpmath_name="arctan"),
dict(name="asinh", mpmath_name="arcsinh"),
dict(name="acosh", mpmath_name="arccosh"),
dict(name="atanh", mpmath_name="arctanh"),
dict(name="square", mpmath_name="square"),
dict(name="log1p", mpmath_name="log1p"),
# The following dictionaries may have additional keys like
#
# size - defines the number of samples: size ** 2
#
# max_ulp_difference - the maximal allowed ULP difference between
# function and reference values
#
# extra_prec_multiplier - the precison multiplier for mpmath.mp
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# When unspecifed, these parameters are retrieved from
# functional_algorithms database of support functions.
#
dict(name="asin", mpmath_name="arcsin"),
dict(name="acos", mpmath_name="arccos"),
dict(name="atan", mpmath_name="arctan"),
dict(name="asinh", mpmath_name="arcsinh"),
dict(name="acosh", mpmath_name="arccosh"),
dict(name="atanh", mpmath_name="arctanh"),
dict(name="square", mpmath_name="square"),
dict(name="log_plus_one",
mpmath_name="log1p",
namespace="stablehlo",
passes="--stablehlo-complex-math-expander"),
]


Expand Down Expand Up @@ -128,19 +131,21 @@ def main():
for op in operations:
opname = op["name"]
mpmath_opname = op.get("mpmath_name", opname)
namespace = op.get("namespace", "chlo")
size_re = size_im = op.get("size", default_size)

passes = op.get("passes", "--chlo-legalize-to-stablehlo")
for dtype in [np.complex64, np.complex128, np.float32, np.float64]:
params = fa.utils.function_validation_parameters(opname, dtype)
max_ulp_difference = op.get(
"max_ulp_difference",
params.get("max_valid_ulp_count", default_max_ulp_difference))
"max_ulp_difference",
params.get("max_valid_ulp_count", default_max_ulp_difference))

nmp = fa.utils.numpy_with_mpmath(
extra_prec_multiplier = op.get(
"extra_prec_multiplier",
params.get("extra_prec_multiplier", default_extra_prec_multiplier)),
flush_subnormals=flush_subnormals,
extra_prec_multiplier=op.get(
"extra_prec_multiplier",
params.get("extra_prec_multiplier",
default_extra_prec_multiplier)),
flush_subnormals=flush_subnormals,
)

fi = np.finfo(dtype)
Expand Down Expand Up @@ -181,7 +186,7 @@ def main():
main_func = m.make_function("main", "", "", "public")

ref_samples = main_func.call("samples")
actual = main_func.composite(f"chlo.{opname}", ref_samples)
actual = main_func.composite(f"{namespace}.{opname}", ref_samples)
expected = main_func.call("expected")

main_func.void_call(
Expand All @@ -203,7 +208,7 @@ def main():
continue

f = open(fname, "w")
f.write("// RUN: stablehlo-opt --chlo-legalize-to-stablehlo %s |"
f.write(f"// RUN: stablehlo-opt {passes} %s |"
" stablehlo-translate --interpret\n")
f.write(
"// This file is generated, see build_tools/math/README.md for more"
Expand Down
3 changes: 2 additions & 1 deletion docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -4066,7 +4066,8 @@ Performs element-wise logarithm plus one operation on `operand` tensor and
produces a `result` tensor. Depending on the element type, does the following:

* For floats: `logp1` from IEEE-754.
* For complex numbers: complex logarithm plus one.
* For complex numbers:
`complex(log(hypot(real(x) + 1, imag(x))), atan2(imag(x), real(x) + 1))`
* For quantized types:
`dequantize_op_quantize(log_plus_one, operand, type(result))`.

Expand Down
1 change: 0 additions & 1 deletion stablehlo/dialect/ChloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfcOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(ErfInvOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LgammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Log1pOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NextAfterOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(PolygammaOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinhOp)
Expand Down
14 changes: 0 additions & 14 deletions stablehlo/dialect/ChloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -747,20 +747,6 @@ def CHLO_LgammaOp : CHLO_UnaryElementwiseOp<"lgamma",
}];
}

def CHLO_Log1pOp : CHLO_UnaryElementwiseOp<"log1p",
[HLO_CompatibleOperandsAndResultType], HLO_AnyFpOrComplexTensor> {
let summary = "Log1p function";

let description = [{
Returns `Log1p(operand)` element-wise.

$$
\log1p(x) = complex(log(hypot(x.real + 1, x.imag)), arctan2(x.imag, x.real + 1)) if x is a complex number
= log(x + 1) otherwise
$$
}];
}

def CHLO_SquareOp : CHLO_UnaryElementwiseOp<"square",
[HLO_CompatibleOperandsAndResultType], HLO_AnyFpOrComplexTensor> {
let summary = "Square operation";
Expand Down
Loading

0 comments on commit 4164522

Please sign in to comment.