Skip to content

Commit

Permalink
Adapted _result_type method for NumPy 2.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 committed Feb 3, 2025
1 parent ae5338e commit b267a5f
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import numpy as np
import sympy as sp

numpy_version = int(np.version.version.split('.')[0])

Size = Union[int, dace.symbolic.symbol]
Shape = Sequence[Size]
if TYPE_CHECKING:
Expand Down Expand Up @@ -1680,22 +1682,28 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi

datatypes = []
dtypes_for_result = []
dtypes_for_result_np2 = []
for arg in arguments:
if isinstance(arg, (data.Array, data.Stream)):
datatypes.append(arg.dtype)
dtypes_for_result.append(arg.dtype.type)
dtypes_for_result_np2.append(arg.dtype.type)
elif isinstance(arg, data.Scalar):
datatypes.append(arg.dtype)
dtypes_for_result.append(_representative_num(arg.dtype))
dtypes_for_result_np2.append(arg.dtype.type)
elif isinstance(arg, (Number, np.bool_)):
datatypes.append(dtypes.dtype_to_typeclass(type(arg)))
dtypes_for_result.append(arg)
dtypes_for_result_np2.append(arg)
elif symbolic.issymbolic(arg):
datatypes.append(_sym_type(arg))
dtypes_for_result.append(_representative_num(_sym_type(arg)))
dtypes_for_result_np2.append(_sym_type(arg).type)
elif isinstance(arg, dtypes.typeclass):
datatypes.append(arg)
dtypes_for_result.append(_representative_num(arg))
dtypes_for_result_np2.append(arg.type)
else:
raise TypeError("Type {t} of argument {a} is not supported".format(t=type(arg), a=arg))

Expand Down Expand Up @@ -1728,8 +1736,11 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi
elif (operator in ('Fabs', 'Cbrt', 'Angles', 'SignBit', 'Spacing', 'Modf', 'Floor', 'Ceil', 'Trunc')
and coarse_types[0] == 3):
raise TypeError("ufunc '{}' not supported for complex input".format(operator))
elif operator in ('Ceil', 'Floor', 'Trunc') and coarse_types[0] < 2 and numpy_version < 2:
result_type = dace.float64
casting[0] = _cast_str(result_type)
elif (operator in ('Fabs', 'Rint', 'Exp', 'Log', 'Sqrt', 'Cbrt', 'Trigonometric', 'Angles', 'FpBoolean',
'Spacing', 'Modf', 'Floor', 'Ceil', 'Trunc') and coarse_types[0] < 2):
'Spacing', 'Modf') and coarse_types[0] < 2):
result_type = dace.float64
casting[0] = _cast_str(result_type)
elif operator in ('Frexp'):
Expand Down Expand Up @@ -1809,7 +1820,10 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi
result_type = dace.float64
# All other arithmetic operators and cases of the above operators
else:
result_type = _np_result_type(dtypes_for_result)
if numpy_version >= 2:
result_type = _np_result_type(dtypes_for_result_np2)
else:
result_type = _np_result_type(dtypes_for_result)

if dtype1 != result_type:
left_cast = _cast_str(result_type)
Expand Down

0 comments on commit b267a5f

Please sign in to comment.