diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 4ffdb39773..cc4555ea03 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -90,7 +90,7 @@ def _rebuild(self, *args, **kwargs): handle.update(kwargs) return type(self)(**handle) - @cached_property + @property def ccode(self): """ Generate C code. @@ -152,9 +152,6 @@ def writes(self): """All Basic objects modified by this node.""" return () - def _signature_items(self): - return (str(self.ccode),) - class ExprStmt: diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index ac6fa8806c..8db173e47d 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -176,8 +176,9 @@ class CGen(Visitor): Return a representation of the Iteration/Expression tree as a :module:`cgen` tree. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, language=None, **kwargs): super().__init__(*args, **kwargs) + self.language = language # The following mappers may be customized by subclasses (that is, # backend-specific CGen-erators) @@ -189,6 +190,9 @@ def __init__(self, *args, **kwargs): } _restrict_keyword = 'restrict' + def ccode(self, expr, **kwargs): + return ccode(expr, language=self.language, **kwargs) + def _gen_struct_decl(self, obj, masked=()): """ Convert ctypes.Struct -> cgen.Structure. @@ -222,7 +226,7 @@ def _gen_struct_decl(self, obj, masked=()): try: entries.append(self._gen_value(i, 0, masked=('const',))) except AttributeError: - cstr = ccode(ct) + cstr = self.ccode(ct) if ct is c_restrict_void_p: cstr = '%srestrict' % cstr entries.append(c.Value(cstr, n)) @@ -244,10 +248,10 @@ def _gen_value(self, obj, mode=1, masked=()): if getattr(obj.function, k, False) and v not in masked] if (obj._mem_stack or obj._mem_constant) and mode == 1: - strtype = ccode(obj._C_typedata) - strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape) + strtype = self.ccode(obj._C_typedata) + strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape) else: - strtype = ccode(obj._C_ctype) + strtype = self.ccode(obj._C_ctype) strshape = '' if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1: if not obj._mem_stack: @@ -261,7 +265,7 @@ def _gen_value(self, obj, mode=1, masked=()): strobj = '%s%s' % (strname, strshape) if obj.is_LocalObject and obj.cargs and mode == 1: - arguments = [ccode(i) for i in obj.cargs] + arguments = [self.ccode(i) for i in obj.cargs] strobj = MultilineCall(strobj, arguments, True) value = c.Value(strtype, strobj) @@ -275,9 +279,9 @@ def _gen_value(self, obj, mode=1, masked=()): if obj.is_Array and obj.initvalue is not None and mode == 1: init = ListInitializer(obj.initvalue) if not obj._mem_constant or init.is_numeric: - value = c.Initializer(value, ccode(init)) + value = c.Initializer(value, self.ccode(init)) elif obj.is_LocalObject and obj.initvalue is not None and mode == 1: - value = c.Initializer(value, ccode(obj.initvalue)) + value = c.Initializer(value, self.ccode(obj.initvalue)) return value @@ -311,7 +315,7 @@ def _args_call(self, args): else: ret.append(i._C_name) except AttributeError: - ret.append(ccode(i)) + ret.append(self.ccode(i)) return ret def _gen_signature(self, o, is_declaration=False): @@ -377,7 +381,7 @@ def visit_tuple(self, o): def visit_PointerCast(self, o): f = o.function i = f.indexed - cstr = ccode(i._C_typedata) + cstr = self.ccode(i._C_typedata) if f.is_PointerArray: # lvalue @@ -399,7 +403,7 @@ def visit_PointerCast(self, o): else: v = f.name if o.flat is None: - shape = ''.join("[%s]" % ccode(i) for i in o.castshape) + shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape) rshape = '(*)%s' % shape lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape)) else: @@ -432,9 +436,9 @@ def visit_Dereference(self, o): a0, a1 = o.functions if a1.is_PointerArray or a1.is_TempFunction: i = a1.indexed - cstr = ccode(i._C_typedata) + cstr = self.ccode(i._C_typedata) if o.flat is None: - shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:]) + shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:]) rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name, a1.dim.name) lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape)) @@ -473,8 +477,8 @@ def visit_Definition(self, o): return self._gen_value(o.function) def visit_Expression(self, o): - lhs = ccode(o.expr.lhs, dtype=o.dtype) - rhs = ccode(o.expr.rhs, dtype=o.dtype) + lhs = self.ccode(o.expr.lhs, dtype=o.dtype) + rhs = self.ccode(o.expr.rhs, dtype=o.dtype) if o.init: code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs) @@ -487,8 +491,8 @@ def visit_Expression(self, o): return code def visit_AugmentedExpression(self, o): - c_lhs = ccode(o.expr.lhs, dtype=o.dtype) - c_rhs = ccode(o.expr.rhs, dtype=o.dtype) + c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype) + c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype) code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs)) if o.pragmas: code = c.Module(self._visit(o.pragmas) + (code,)) @@ -507,7 +511,7 @@ def visit_Call(self, o, nested_call=False): o.templates) if retobj.is_Indexed or \ isinstance(retobj, (FieldFromComposite, FieldFromPointer)): - return c.Assign(ccode(retobj), call) + return c.Assign(self.ccode(retobj), call) else: return c.Initializer(c.Value(rettype, retobj._C_name), call) @@ -521,9 +525,9 @@ def visit_Conditional(self, o): then_body = c.Block(self._visit(then_body)) if else_body: else_body = c.Block(self._visit(else_body)) - return c.If(ccode(o.condition), then_body, else_body) + return c.If(self.ccode(o.condition), then_body, else_body) else: - return c.If(ccode(o.condition), then_body) + return c.If(self.ccode(o.condition), then_body) def visit_Iteration(self, o): body = flatten(self._visit(i) for i in self._blankline_logic(o.children)) @@ -533,23 +537,23 @@ def visit_Iteration(self, o): # For backward direction flip loop bounds if o.direction == Backward: - loop_init = 'int %s = %s' % (o.index, ccode(_max)) - loop_cond = '%s >= %s' % (o.index, ccode(_min)) + loop_init = 'int %s = %s' % (o.index, self.ccode(_max)) + loop_cond = '%s >= %s' % (o.index, self.ccode(_min)) loop_inc = '%s -= %s' % (o.index, o.limits[2]) else: - loop_init = 'int %s = %s' % (o.index, ccode(_min)) - loop_cond = '%s <= %s' % (o.index, ccode(_max)) + loop_init = 'int %s = %s' % (o.index, self.ccode(_min)) + loop_cond = '%s <= %s' % (o.index, self.ccode(_max)) loop_inc = '%s += %s' % (o.index, o.limits[2]) # Append unbounded indices, if any if o.uindices: - uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices] + uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices] loop_init = c.Line(', '.join([loop_init] + uinit)) ustep = [] for i in o.uindices: op = '=' if i.is_Modulo else '+=' - ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr))) + ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr))) loop_inc = c.Line(', '.join([loop_inc] + ustep)) # Create For header+body @@ -566,7 +570,7 @@ def visit_Pragma(self, o): return c.Pragma(o._generate) def visit_While(self, o): - condition = ccode(o.condition) + condition = self.ccode(o.condition) if o.body: body = flatten(self._visit(i) for i in o.children) return c.While(condition, c.Block(body)) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index a5fef9b84f..f210fd9065 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -23,7 +23,7 @@ from devito.operator.profiling import create_profile from devito.operator.registry import operator_selector from devito.mpi import MPI -from devito.parameters import configuration, switchconfig +from devito.parameters import configuration from devito.passes import (Graph, lower_index_derivatives, generate_implicit, generate_macros, minimize_symbols, unevaluate, error_mapper, is_on_device) @@ -761,12 +761,17 @@ def _soname(self): @cached_property def ccode(self): - with switchconfig(compiler=self._compiler, language=self._language): - try: - return self._ccode_handler().visit(self) - except (AttributeError, TypeError): - from devito.ir.iet.visitors import CGen - return CGen().visit(self) + try: + return self._ccode_handler(language=self._language).visit(self) + except (AttributeError, TypeError): + from devito.ir.iet.visitors import CGen + return CGen(language=self._language).visit(self) + + def _signature_items(self): + return (str(self),) + + def __str__(self): + return str(self.ccode) def _jit_compile(self): """ @@ -904,7 +909,8 @@ def apply(self, **kwargs): """ # Compile the operator before building the arguments list # to avoid out of memory with greedy compilers - cfunction = self.cfunction + with self._profiler.timer_on('jit-compile'): + cfunction = self.cfunction # Build the arguments list to invoke the kernel function with self._profiler.timer_on('arguments'): diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 7221e985d5..e4f0945d9c 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -40,7 +40,7 @@ class Graph: The `visit` method collects info about the nodes in the Graph. """ - def __init__(self, iet, options=None, sregistry=None, **kwargs): + def __init__(self, iet, options=None, sregistry=None, language=None, **kwargs): self.efuncs = OrderedDict([(iet.name, iet)]) self.sregistry = sregistry diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index f803117343..1e7a41c9b9 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -765,14 +765,11 @@ class SizeOf(DefFunction): def __new__(cls, intype, stars=None, **kwargs): stars = stars or '' - if not isinstance(intype, (str, ReservedWord)): - intype = dtype_to_ctype(intype) - if intype in ctypes_vector_mapper.values(): - idx = list(ctypes_vector_mapper.values()).index(intype) + ctype = dtype_to_ctype(intype) + if ctype in ctypes_vector_mapper.values(): + idx = list(ctypes_vector_mapper.values()).index(ctype) intype = list(ctypes_vector_mapper.keys())[idx] - else: - intype = ctypes_to_cstr(intype) newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs) newobj.stars = stars diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 76980ae557..d167b89c73 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -20,7 +20,7 @@ from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.symbolics.extended_dtypes import c_complex, c_double_complex from devito.types.basic import AbstractFunction -from devito.tools import ctypes_to_cstr +from devito.tools import ctypes_to_cstr, dtype_to_ctype __all__ = ['ccode'] @@ -95,6 +95,10 @@ def parenthesize(self, item, level, strict=False): return super().parenthesize(item, level, strict=strict) def _print_type(self, expr): + try: + expr = dtype_to_ctype(expr) + except TypeError: + pass try: return self.type_mappings[expr] except KeyError: @@ -422,7 +426,7 @@ class AccDevitoPrinter(CXXDevitoPrinter): 'openacc': AccDevitoPrinter} -def ccode(expr, **settings): +def ccode(expr, language=None, **settings): """Generate C++ code from an expression. Parameters @@ -438,5 +442,6 @@ def ccode(expr, **settings): The resulting code as a C++ string. If something went south, returns the input ``expr`` itself. """ - printer = printer_registry.get(configuration['language'], CDevitoPrinter) + lang = language or configuration['language'] + printer = printer_registry.get(lang, CDevitoPrinter) return printer(settings=settings).doprint(expr, None) diff --git a/devito/types/basic.py b/devito/types/basic.py index 75e2ca936e..e8d367d149 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1,7 +1,7 @@ import abc import inspect from collections import namedtuple -from ctypes import POINTER, _Pointer, c_char_p, c_char +from ctypes import POINTER, _Pointer, c_char_p, c_char, Structure from functools import reduce, cached_property from operator import mul @@ -87,13 +87,18 @@ def _C_typedata(self): if isinstance(_type, CustomDtype): return _type + _pointer = False while issubclass(_type, _Pointer): + _pointer = True _type = _type._type_ # `ctypes` treats C strings specially if _type is c_char_p: _type = c_char + if issubclass(_type, Structure) and _pointer: + _type = f'struct {_type.__name__}' + return _type @abc.abstractproperty diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 44d94dbe6a..a0dddec575 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -59,7 +59,8 @@ def _config_kwargs(platform: str, language: str) -> dict[str, str]: @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) @pytest.mark.parametrize('kwargs', _configs) -def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None: +def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str], + expected=None) -> None: """ Tests that half and complex floats' dtypes result in the correct type strings in generated code. @@ -78,7 +79,7 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> N params: dict[str, Basic] = {p.name: p for p in op.parameters} _u, _c = params['u'], params['c'] assert type(_u.indexed._C_ctype._type_()) == ctypes_vector_mapper[dtype] - assert _c._C_ctype == ctypes_vector_mapper[dtype] + assert _c._C_ctype == expected or ctypes_vector_mapper[dtype] @pytest.mark.parametrize('dtype', [np.complex64, np.complex128])