From b267a5f194f4fbb675245da7ff3e6f0b6127f775 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Mon, 3 Feb 2025 16:39:13 +0100 Subject: [PATCH] Adapted `_result_type` method for NumPy 2. --- dace/frontend/python/replacements.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index ac371abe2a..c8758496f5 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -26,6 +26,8 @@ import numpy as np import sympy as sp +numpy_version = int(np.version.version.split('.')[0]) + Size = Union[int, dace.symbolic.symbol] Shape = Sequence[Size] if TYPE_CHECKING: @@ -1680,22 +1682,28 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi datatypes = [] dtypes_for_result = [] + dtypes_for_result_np2 = [] for arg in arguments: if isinstance(arg, (data.Array, data.Stream)): datatypes.append(arg.dtype) dtypes_for_result.append(arg.dtype.type) + dtypes_for_result_np2.append(arg.dtype.type) elif isinstance(arg, data.Scalar): datatypes.append(arg.dtype) dtypes_for_result.append(_representative_num(arg.dtype)) + dtypes_for_result_np2.append(arg.dtype.type) elif isinstance(arg, (Number, np.bool_)): datatypes.append(dtypes.dtype_to_typeclass(type(arg))) dtypes_for_result.append(arg) + dtypes_for_result_np2.append(arg) elif symbolic.issymbolic(arg): datatypes.append(_sym_type(arg)) dtypes_for_result.append(_representative_num(_sym_type(arg))) + dtypes_for_result_np2.append(_sym_type(arg).type) elif isinstance(arg, dtypes.typeclass): datatypes.append(arg) dtypes_for_result.append(_representative_num(arg)) + dtypes_for_result_np2.append(arg.type) else: raise TypeError("Type {t} of argument {a} is not supported".format(t=type(arg), a=arg)) @@ -1728,8 +1736,11 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi elif (operator in ('Fabs', 'Cbrt', 'Angles', 'SignBit', 'Spacing', 'Modf', 'Floor', 'Ceil', 'Trunc') and coarse_types[0] == 3): raise TypeError("ufunc '{}' not supported for complex input".format(operator)) + elif operator in ('Ceil', 'Floor', 'Trunc') and coarse_types[0] < 2 and numpy_version < 2: + result_type = dace.float64 + casting[0] = _cast_str(result_type) elif (operator in ('Fabs', 'Rint', 'Exp', 'Log', 'Sqrt', 'Cbrt', 'Trigonometric', 'Angles', 'FpBoolean', - 'Spacing', 'Modf', 'Floor', 'Ceil', 'Trunc') and coarse_types[0] < 2): + 'Spacing', 'Modf') and coarse_types[0] < 2): result_type = dace.float64 casting[0] = _cast_str(result_type) elif operator in ('Frexp'): @@ -1809,7 +1820,10 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi result_type = dace.float64 # All other arithmetic operators and cases of the above operators else: - result_type = _np_result_type(dtypes_for_result) + if numpy_version >= 2: + result_type = _np_result_type(dtypes_for_result_np2) + else: + result_type = _np_result_type(dtypes_for_result) if dtype1 != result_type: left_cast = _cast_str(result_type)