Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster Call for CompiledSDFG #1467

Merged
merged 28 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
20c6295
It is now possible to write `math.nan` in a (Python) tasklet.
philip-paul-mueller Nov 21, 2023
93e86e6
Fixed some bugs with the `typeless_nan`.
philip-paul-mueller Nov 21, 2023
a07eede
Removed the exception in the `int` conversion of the `typeless_nan`.
philip-paul-mueller Nov 22, 2023
8a9cb2c
Modified how the function hijacking for `typeless_nan` behaves.
philip-paul-mueller Nov 22, 2023
724a4d1
Merge remote-tracking branch 'spcl/master' into nan_in_math
philip-paul-mueller Nov 22, 2023
d4a656e
Merge branch 'master' into nan_in_math
philip-paul-mueller Nov 30, 2023
9b7254e
New way to call a `CompiledSDFG` object.
philip-paul-mueller Dec 4, 2023
d2a1c85
Split `CompiledSDFG._construct_args()`.
philip-paul-mueller Dec 4, 2023
f629887
Fixed some issues that only surfaced during the tests.
philip-paul-mueller Dec 4, 2023
a14d6c0
Fixed a missing argument.
philip-paul-mueller Dec 4, 2023
206be9b
Merge remote-tracking branch 'spcl/master' into new_call_method
philip-paul-mueller Dec 6, 2023
f22c7ad
Merge remote-tracking branch 'spcl/master' into new_call_method
Dec 16, 2023
ac2bcc4
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
ebad5e6
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
b8d7743
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
d85fd2f
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
eb6f719
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
8f8b101
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
be10506
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
ddd18c1
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
f600304
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
c4a8881
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
8e5b657
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
57017f8
Update dace/codegen/compiled_sdfg.py
philip-paul-mueller Dec 21, 2023
eb0198a
Addressed Tal's suggestions.
philip-paul-mueller Dec 21, 2023
ff3ced3
Updated the computation of the arguments.
Dec 27, 2023
1a1fca4
Merged the `_construct_args()` back into one function.
Dec 27, 2023
310479b
Made a small correction.
Dec 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 147 additions & 68 deletions dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,21 +366,56 @@ def _get_error_text(self, result: Union[str, int]) -> str:
return result

def __call__(self, *args, **kwargs):
# Update arguments from ordered list
if len(args) > 0 and self.argnames is not None:
kwargs.update({aname: arg for aname, arg in zip(self.argnames, args)})
"""This function forwards the Python call to the complied `C` code.

try:
argtuple, initargtuple = self._construct_args(kwargs)
The positional arguments (`args`) are expected to be in the same order as
`argnames` (which comes from `arg_names` in the source `SDFG`).
This function will roughly do the following steps:
- bring the arguments into the order dictated by the C callable.
- perform some basic checks on the arguments.
- transform `ndarray`s into their `C` equivalent.

If you know what you are doing, you can also use `_fast_call()`, which
allows you to bypass these operations and call the extension directly.
"""
if self.argnames is None and len(args) != 0:
raise KeyError(f"Passed positional arguments to an SDFG that does not accept them.")
elif len(args) > 0 and self.argnames is not None:
kwargs.update(
# `_construct_args` will handle all of its argument as kwargs.
{aname: arg for aname, arg in zip(self.argnames, args)}
)
argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here.
# Return values cached in `self._lastargs`.
return self._fast_call(argtuple, initargtuple)


def _fast_call(
self,
callargs: Tuple[Any, ...],
initargs: Tuple[Any, ...],
) -> Any:
"""This function allows to bypass the construction of arguments.

By default `self.__call__()` will reorder its argument, whose order is given by `argnames`,
to the one that is given by `_sig`, i.e. the one of the C callback,
which is done by `_construct_args()`, and transforms the arguments.
This function expects that this reordering and transformation have already been done.

:param callargs: Tuple containing the arguments for the call to the C function.
:param initargs: Tuple containing the arguments for the initialization.

This is an internal function that is used for benchmarking.
"""
try:
# Call initializer function if necessary, then SDFG
if self._initialized is False:
self._lib.load()
self._initialize(initargtuple)
self._initialize(initargs)

with hooks.invoke_compiled_sdfg_call_hooks(self, argtuple):
with hooks.invoke_compiled_sdfg_call_hooks(self, callargs):
if self.do_not_execute is False:
self._cfunc(self._libhandle, *argtuple)
self._cfunc(self._libhandle, *callargs)

if self.has_gpu_code:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section may also not belong in a fast call

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not understand this. According to my understanding, this calls the actual compiled function. Thus, without it nothing would be done, or am I missing something here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look at the line the comment points to: the GPU runtime check belongs in the normal call, not fast call. If you want to ensure execution is fast you can skip this check, which might be expensive.

Copy link
Collaborator Author

@philip-paul-mueller philip-paul-mueller Dec 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to give fast_call() the possibility of performing the check; however, it is disabled by default. The main reason for that was to avoid code duplication.

Thanks for the explanation, so the comment always points to the "last line that is shown"?

# Optionally get errors from call
Expand Down Expand Up @@ -438,33 +473,93 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
argtypes = []
argnames = []
sig = []

# Type checking
self._construct_args_type_checking(argnames, arglist, argtypes)

# Explicit casting
arg_ctypes = self._construct_args_explicit_casting(arglist, argtypes, kwargs)

# Creates the call parameters
callparams = self._construct_args_callparams(arglist, arg_ctypes, argtypes, argnames)

initargs = self._construct_args_initarg(callparams)

try:
# Replace arrays with their base host/device pointers
newargs = []
for arg, actype, atype, _ in callparams:
if dtypes.is_array(arg):
newargs.append( ctypes.c_void_p(_array_interface_ptr(arg, atype.storage)) ) # c_void_p` is subclass of `ctypes._SimpleCData`.
elif not isinstance(arg, (ctypes._SimpleCData)):
newargs.append( actype(arg) )
else:
newargs.append( arg )
#
except TypeError:
# Pinpoint bad argument
for i, (arg, actype, _) in enumerate(newargs):
try:
if not isinstance(arg, ctypes._SimpleCData):
actype(arg)
except TypeError as ex:
raise TypeError(f'Invalid type for scalar argument "{callparams[i][3]}": {ex}')

# Store the arguments.
self._lastargs = newargs, initargs
return self._lastargs


def _construct_args_type_checking(self, argnames, arglist, argtypes):
"""This function performs the type checking for `_construct_args()`

Will modify `arglist`.
It either succeeds or throws.
"""
for i, (a, arg, atype) in enumerate(zip(argnames, arglist, argtypes)):
if not dtypes.is_array(arg) and isinstance(atype, dt.Array):
arg_is_array = dtypes.is_array(arg)
atype_is_dtArray = isinstance(atype, dt.Array)
expr1 = atype_is_dtArray and isinstance(arg, np.ndarray)
if not arg_is_array and atype_is_dtArray:
if isinstance(arg, list):
print('WARNING: Casting list argument "%s" to ndarray' % a)
elif arg is None:
if atype.optional is False: # If array cannot be None
raise TypeError(f'Passing a None value to a non-optional array in argument "{a}"')
# Otherwise, None values are passed as null pointers below
else:
raise TypeError('Passing an object (type %s) to an array in argument "%s"' %
(type(arg).__name__, a))
elif dtypes.is_array(arg) and not isinstance(atype, dt.Array):
raise TypeError('Passing an object (type %s) to an array in argument "%s"' % (type(arg).__name__, a))
elif arg_is_array and not atype_is_dtArray:
# GPU scalars and return values are pointers, so this is fine
if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'):
raise TypeError('Passing an array to a scalar (type %s) in argument "%s"' % (atype.dtype.ctype, a))
elif (expr1 and not isinstance(atype, dt.StructArray)
and atype.dtype.as_numpy_dtype() != arg.dtype):
# Make exception for vector types
if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype):
pass
else:
print('WARNING: Passing %s array argument "%s" to a %s array' %
(arg.dtype, a, atype.dtype.type.__name__))
elif (expr1 and arg.base is not None and not '__return' in a and not Config.get_bool('compiler', 'allow_view_arguments')):
raise TypeError(f'Passing a numpy view (e.g., sub-array or "A.T") "{a}" to DaCe '
'programs is not allowed in order to retain analyzability. '
'Please make a copy with "numpy.copy(...)". If you know what '
'you are doing, you can override this error in the '
'configuration by setting compiler.allow_view_arguments '
'to True.')
elif (not isinstance(atype, (dt.Array, dt.Structure)) and
not isinstance(atype.dtype, dtypes.callback) and
not isinstance(arg, (atype.dtype.type, sp.Basic)) and
not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)):
if isinstance(arg, int) and atype.dtype.type == np.int64:
arg_is_int = isinstance(arg, int)
if arg_is_int and atype.dtype.type == np.int64:
pass
elif isinstance(arg, float) and atype.dtype.type == np.float64:
elif (arg_is_int and atype.dtype.type == np.int32 and abs(arg) <= (1 << 31) - 1):
pass
elif (isinstance(arg, int) and atype.dtype.type == np.int32 and abs(arg) <= (1 << 31) - 1):
elif (arg_is_int and atype.dtype.type == np.uint32 and arg >= 0 and arg <= (1 << 32) - 1):
pass
elif (isinstance(arg, int) and atype.dtype.type == np.uint32 and arg >= 0 and arg <= (1 << 32) - 1):
elif isinstance(arg, float) and atype.dtype.type == np.float64:
pass
elif (isinstance(arg, str) or arg is None) and atype.dtype == dtypes.string:
if arg is None:
Expand All @@ -475,24 +570,14 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
else:
warnings.warn(f'Casting scalar argument "{a}" from {type(arg).__name__} to {atype.dtype.type}')
arglist[i] = atype.dtype.type(arg)
elif (isinstance(atype, dt.Array) and isinstance(arg, np.ndarray) and not isinstance(atype, dt.StructArray)
and atype.dtype.as_numpy_dtype() != arg.dtype):
# Make exception for vector types
if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype):
pass
else:
print('WARNING: Passing %s array argument "%s" to a %s array' %
(arg.dtype, a, atype.dtype.type.__name__))
elif (isinstance(atype, dt.Array) and isinstance(arg, np.ndarray) and arg.base is not None
and not '__return' in a and not Config.get_bool('compiler', 'allow_view_arguments')):
raise TypeError(f'Passing a numpy view (e.g., sub-array or "A.T") "{a}" to DaCe '
'programs is not allowed in order to retain analyzability. '
'Please make a copy with "numpy.copy(...)". If you know what '
'you are doing, you can override this error in the '
'configuration by setting compiler.allow_view_arguments '
'to True.')

# Explicit casting
return True


def _construct_args_explicit_casting(self, arglist, argtypes, kwargs):
"""Explicitly performs the type casting for `_construct_args()`.
"""
arg_ctypes = []
for index, (arg, argtype) in enumerate(zip(arglist, argtypes)):
# Call a wrapper function to make NumPy arrays from pointers.
if isinstance(argtype.dtype, dtypes.callback):
Expand All @@ -503,50 +588,43 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]:
# Null pointer
elif arg is None and isinstance(argtype, dt.Array):
arglist[index] = ctypes.c_void_p(0)
#

# Retain only the element datatype for upcoming checks and casts
arg_ctypes = [t.dtype.as_ctypes() for t in argtypes]
# Retain only the element datatype for upcoming checks and casts
arg_ctypes.append( argtypes[index].dtype.as_ctypes() )
#
return arg_ctypes

sdfg = self._sdfg

# Obtain SDFG constants
constants = sdfg.constants
def _construct_args_callparams(self, arglist, arg_ctypes, argtypes, argnames):
"""Construct the call parameters.
"""
constants = self.sdfg.constants
callparams = []

# Remove symbolic constants from arguments
callparams = tuple((arg, actype, atype, aname)
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames)
if not symbolic.issymbolic(arg) or (hasattr(arg, 'name') and arg.name not in constants))
for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames):
if symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants):
continue # Ignore symbolic constants with a compiletime value.
callparams.append( # If the argument is a symbol replace it with its actual value.
(actype(arg.get()) if isinstance(arg, symbolic.symbol) else arg, actype, atype, aname)
)
return tuple(callparams)

# Replace symbols with their values
callparams = tuple((actype(arg.get()) if isinstance(arg, symbolic.symbol) else arg, actype, atype, aname)
for arg, actype, atype, aname in callparams)

def _construct_args_initarg(self, callparams):
"""Constructs the initialarguments.
"""
# Construct init args, which only consist of the symbols
symbols = self._free_symbols
initargs = tuple(
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg
for arg, actype, atype, aname in callparams if aname in symbols)

# Replace arrays with their base host/device pointers
newargs = tuple((ctypes.c_void_p(_array_interface_ptr(arg, atype.storage)), actype,
atype) if dtypes.is_array(arg) else (arg, actype, atype)
for arg, actype, atype, _ in callparams)
initargs = []

try:
newargs = tuple(
actype(arg) if not isinstance(arg, (ctypes._SimpleCData)) else arg
for arg, actype, atype in newargs)
except TypeError:
# Pinpoint bad argument
for i, (arg, actype, _) in enumerate(newargs):
try:
if not isinstance(arg, ctypes._SimpleCData):
actype(arg)
except TypeError as ex:
raise TypeError(f'Invalid type for scalar argument "{callparams[i][3]}": {ex}')
# Construct init args, which only consist of the symbols
for arg, actype, atype, aname in filter(lambda X: X[-1] in symbols, callparams):
initargs.append(
actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg
)
return tuple(initargs)

self._lastargs = newargs, initargs
return self._lastargs

def clear_return_values(self):
self._create_new_arrays = True
Expand Down Expand Up @@ -578,7 +656,8 @@ def ndarray(*args, buffer=None, **kwargs):
def _initialize_return_values(self, kwargs):
# Obtain symbol values from arguments and constants
syms = dict()
syms.update({k: v for k, v in kwargs.items() if k not in self.sdfg.arrays})
sdfg_arrays = self.sdfg.arrays
syms.update({k: v for k, v in kwargs.items() if k not in sdfg_arrays})
syms.update(self.sdfg.constants)

# Clear references from last call (allow garbage collection)
Expand Down
8 changes: 5 additions & 3 deletions dace/runtime/include/dace/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
#ifndef __DACE_MATH_H
#define __DACE_MATH_H

#include "pi.h"
#include "types.h"

#include <complex>
#include <numeric>
#include <cmath>
#include <cfloat>
#include <type_traits>

#include "pi.h"
#include "nan.h"
#include "types.h"

#ifdef __CUDACC__
#include <thrust/complex.h>
#endif
Expand Down Expand Up @@ -457,6 +458,7 @@ namespace dace
namespace math
{
static DACE_CONSTEXPR typeless_pi pi{};
static DACE_CONSTEXPR typeless_nan nan{};
//////////////////////////////////////////////////////
template<typename T>
DACE_CONSTEXPR DACE_HDFI T exp(const T& a)
Expand Down
Loading