Skip to content

Commit

Permalink
compiler: make visitor language parametric
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 23, 2025
1 parent 8ccb91c commit b9b60c3
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 52 deletions.
5 changes: 1 addition & 4 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _rebuild(self, *args, **kwargs):
handle.update(kwargs)
return type(self)(**handle)

@cached_property
@property
def ccode(self):
"""
Generate C code.
Expand Down Expand Up @@ -152,9 +152,6 @@ def writes(self):
"""All Basic objects modified by this node."""
return ()

def _signature_items(self):
return (str(self.ccode),)


class ExprStmt:

Expand Down
58 changes: 31 additions & 27 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ class CGen(Visitor):
Return a representation of the Iteration/Expression tree as a :module:`cgen` tree.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, language=None, **kwargs):
super().__init__(*args, **kwargs)
self.language = language

# The following mappers may be customized by subclasses (that is,
# backend-specific CGen-erators)
Expand All @@ -189,6 +190,9 @@ def __init__(self, *args, **kwargs):
}
_restrict_keyword = 'restrict'

def ccode(self, expr, **kwargs):
return ccode(expr, language=self.language, **kwargs)

def _gen_struct_decl(self, obj, masked=()):
"""
Convert ctypes.Struct -> cgen.Structure.
Expand Down Expand Up @@ -222,7 +226,7 @@ def _gen_struct_decl(self, obj, masked=()):
try:
entries.append(self._gen_value(i, 0, masked=('const',)))
except AttributeError:
cstr = ccode(ct)
cstr = self.ccode(ct)
if ct is c_restrict_void_p:
cstr = '%srestrict' % cstr
entries.append(c.Value(cstr, n))
Expand All @@ -244,10 +248,10 @@ def _gen_value(self, obj, mode=1, masked=()):
if getattr(obj.function, k, False) and v not in masked]

if (obj._mem_stack or obj._mem_constant) and mode == 1:
strtype = ccode(obj._C_typedata)
strshape = ''.join('[%s]' % ccode(i) for i in obj.symbolic_shape)
strtype = self.ccode(obj._C_typedata)
strshape = ''.join('[%s]' % self.ccode(i) for i in obj.symbolic_shape)
else:
strtype = ccode(obj._C_ctype)
strtype = self.ccode(obj._C_ctype)
strshape = ''
if isinstance(obj, (AbstractFunction, IndexedData)) and mode >= 1:
if not obj._mem_stack:
Expand All @@ -261,7 +265,7 @@ def _gen_value(self, obj, mode=1, masked=()):
strobj = '%s%s' % (strname, strshape)

if obj.is_LocalObject and obj.cargs and mode == 1:
arguments = [ccode(i) for i in obj.cargs]
arguments = [self.ccode(i) for i in obj.cargs]
strobj = MultilineCall(strobj, arguments, True)

value = c.Value(strtype, strobj)
Expand All @@ -275,9 +279,9 @@ def _gen_value(self, obj, mode=1, masked=()):
if obj.is_Array and obj.initvalue is not None and mode == 1:
init = ListInitializer(obj.initvalue)
if not obj._mem_constant or init.is_numeric:
value = c.Initializer(value, ccode(init))
value = c.Initializer(value, self.ccode(init))
elif obj.is_LocalObject and obj.initvalue is not None and mode == 1:
value = c.Initializer(value, ccode(obj.initvalue))
value = c.Initializer(value, self.ccode(obj.initvalue))

return value

Expand Down Expand Up @@ -311,7 +315,7 @@ def _args_call(self, args):
else:
ret.append(i._C_name)
except AttributeError:
ret.append(ccode(i))
ret.append(self.ccode(i))
return ret

def _gen_signature(self, o, is_declaration=False):
Expand Down Expand Up @@ -377,7 +381,7 @@ def visit_tuple(self, o):
def visit_PointerCast(self, o):
f = o.function
i = f.indexed
cstr = ccode(i._C_typedata)
cstr = self.ccode(i._C_typedata)

if f.is_PointerArray:
# lvalue
Expand All @@ -399,7 +403,7 @@ def visit_PointerCast(self, o):
else:
v = f.name
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in o.castshape)
shape = ''.join("[%s]" % self.ccode(i) for i in o.castshape)
rshape = '(*)%s' % shape
lvalue = c.Value(cstr, '(*restrict %s)%s' % (v, shape))
else:
Expand Down Expand Up @@ -432,9 +436,9 @@ def visit_Dereference(self, o):
a0, a1 = o.functions
if a1.is_PointerArray or a1.is_TempFunction:
i = a1.indexed
cstr = ccode(i._C_typedata)
cstr = self.ccode(i._C_typedata)
if o.flat is None:
shape = ''.join("[%s]" % ccode(i) for i in a0.symbolic_shape[1:])
shape = ''.join("[%s]" % self.ccode(i) for i in a0.symbolic_shape[1:])
rvalue = '(%s (*)%s) %s[%s]' % (cstr, shape, a1.name,
a1.dim.name)
lvalue = c.Value(cstr, '(*restrict %s)%s' % (a0.name, shape))
Expand Down Expand Up @@ -473,8 +477,8 @@ def visit_Definition(self, o):
return self._gen_value(o.function)

def visit_Expression(self, o):
lhs = ccode(o.expr.lhs, dtype=o.dtype)
rhs = ccode(o.expr.rhs, dtype=o.dtype)
lhs = self.ccode(o.expr.lhs, dtype=o.dtype)
rhs = self.ccode(o.expr.rhs, dtype=o.dtype)

if o.init:
code = c.Initializer(self._gen_value(o.expr.lhs, 0), rhs)
Expand All @@ -487,8 +491,8 @@ def visit_Expression(self, o):
return code

def visit_AugmentedExpression(self, o):
c_lhs = ccode(o.expr.lhs, dtype=o.dtype)
c_rhs = ccode(o.expr.rhs, dtype=o.dtype)
c_lhs = self.ccode(o.expr.lhs, dtype=o.dtype)
c_rhs = self.ccode(o.expr.rhs, dtype=o.dtype)
code = c.Statement("%s %s= %s" % (c_lhs, o.op, c_rhs))
if o.pragmas:
code = c.Module(self._visit(o.pragmas) + (code,))
Expand All @@ -507,7 +511,7 @@ def visit_Call(self, o, nested_call=False):
o.templates)
if retobj.is_Indexed or \
isinstance(retobj, (FieldFromComposite, FieldFromPointer)):
return c.Assign(ccode(retobj), call)
return c.Assign(self.ccode(retobj), call)
else:
return c.Initializer(c.Value(rettype, retobj._C_name), call)

Expand All @@ -521,9 +525,9 @@ def visit_Conditional(self, o):
then_body = c.Block(self._visit(then_body))
if else_body:
else_body = c.Block(self._visit(else_body))
return c.If(ccode(o.condition), then_body, else_body)
return c.If(self.ccode(o.condition), then_body, else_body)
else:
return c.If(ccode(o.condition), then_body)
return c.If(self.ccode(o.condition), then_body)

def visit_Iteration(self, o):
body = flatten(self._visit(i) for i in self._blankline_logic(o.children))
Expand All @@ -533,23 +537,23 @@ def visit_Iteration(self, o):

# For backward direction flip loop bounds
if o.direction == Backward:
loop_init = 'int %s = %s' % (o.index, ccode(_max))
loop_cond = '%s >= %s' % (o.index, ccode(_min))
loop_init = 'int %s = %s' % (o.index, self.ccode(_max))
loop_cond = '%s >= %s' % (o.index, self.ccode(_min))
loop_inc = '%s -= %s' % (o.index, o.limits[2])
else:
loop_init = 'int %s = %s' % (o.index, ccode(_min))
loop_cond = '%s <= %s' % (o.index, ccode(_max))
loop_init = 'int %s = %s' % (o.index, self.ccode(_min))
loop_cond = '%s <= %s' % (o.index, self.ccode(_max))
loop_inc = '%s += %s' % (o.index, o.limits[2])

# Append unbounded indices, if any
if o.uindices:
uinit = ['%s = %s' % (i.name, ccode(i.symbolic_min)) for i in o.uindices]
uinit = ['%s = %s' % (i.name, self.ccode(i.symbolic_min)) for i in o.uindices]
loop_init = c.Line(', '.join([loop_init] + uinit))

ustep = []
for i in o.uindices:
op = '=' if i.is_Modulo else '+='
ustep.append('%s %s %s' % (i.name, op, ccode(i.symbolic_incr)))
ustep.append('%s %s %s' % (i.name, op, self.ccode(i.symbolic_incr)))
loop_inc = c.Line(', '.join([loop_inc] + ustep))

# Create For header+body
Expand All @@ -566,7 +570,7 @@ def visit_Pragma(self, o):
return c.Pragma(o._generate)

def visit_While(self, o):
condition = ccode(o.condition)
condition = self.ccode(o.condition)
if o.body:
body = flatten(self._visit(i) for i in o.children)
return c.While(condition, c.Block(body))
Expand Down
22 changes: 14 additions & 8 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from devito.operator.profiling import create_profile
from devito.operator.registry import operator_selector
from devito.mpi import MPI
from devito.parameters import configuration, switchconfig
from devito.parameters import configuration
from devito.passes import (Graph, lower_index_derivatives, generate_implicit,
generate_macros, minimize_symbols, unevaluate,
error_mapper, is_on_device)
Expand Down Expand Up @@ -761,12 +761,17 @@ def _soname(self):

@cached_property
def ccode(self):
with switchconfig(compiler=self._compiler, language=self._language):
try:
return self._ccode_handler().visit(self)
except (AttributeError, TypeError):
from devito.ir.iet.visitors import CGen
return CGen().visit(self)
try:
return self._ccode_handler(language=self._language).visit(self)
except (AttributeError, TypeError):
from devito.ir.iet.visitors import CGen
return CGen(language=self._language).visit(self)

def _signature_items(self):
return (str(self),)

def __str__(self):
return str(self.ccode)

def _jit_compile(self):
"""
Expand Down Expand Up @@ -904,7 +909,8 @@ def apply(self, **kwargs):
"""
# Compile the operator before building the arguments list
# to avoid out of memory with greedy compilers
cfunction = self.cfunction
with self._profiler.timer_on('jit-compile'):
cfunction = self.cfunction

# Build the arguments list to invoke the kernel function
with self._profiler.timer_on('arguments'):
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class Graph:
The `visit` method collects info about the nodes in the Graph.
"""

def __init__(self, iet, options=None, sregistry=None, **kwargs):
def __init__(self, iet, options=None, sregistry=None, language=None, **kwargs):
self.efuncs = OrderedDict([(iet.name, iet)])

self.sregistry = sregistry
Expand Down
9 changes: 3 additions & 6 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,14 +765,11 @@ class SizeOf(DefFunction):

def __new__(cls, intype, stars=None, **kwargs):
stars = stars or ''

if not isinstance(intype, (str, ReservedWord)):
intype = dtype_to_ctype(intype)
if intype in ctypes_vector_mapper.values():
idx = list(ctypes_vector_mapper.values()).index(intype)
ctype = dtype_to_ctype(intype)
if ctype in ctypes_vector_mapper.values():
idx = list(ctypes_vector_mapper.values()).index(ctype)
intype = list(ctypes_vector_mapper.keys())[idx]
else:
intype = ctypes_to_cstr(intype)

newobj = super().__new__(cls, 'sizeof', arguments=f'{intype}{stars}', **kwargs)
newobj.stars = stars
Expand Down
11 changes: 8 additions & 3 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from devito.symbolics.inspection import has_integer_args, sympy_dtype
from devito.symbolics.extended_dtypes import c_complex, c_double_complex
from devito.types.basic import AbstractFunction
from devito.tools import ctypes_to_cstr
from devito.tools import ctypes_to_cstr, dtype_to_ctype

__all__ = ['ccode']

Expand Down Expand Up @@ -95,6 +95,10 @@ def parenthesize(self, item, level, strict=False):
return super().parenthesize(item, level, strict=strict)

def _print_type(self, expr):
try:
expr = dtype_to_ctype(expr)
except TypeError:
pass
try:
return self.type_mappings[expr]
except KeyError:
Expand Down Expand Up @@ -422,7 +426,7 @@ class AccDevitoPrinter(CXXDevitoPrinter):
'openacc': AccDevitoPrinter}


def ccode(expr, **settings):
def ccode(expr, language=None, **settings):
"""Generate C++ code from an expression.
Parameters
Expand All @@ -438,5 +442,6 @@ def ccode(expr, **settings):
The resulting code as a C++ string. If something went south, returns
the input ``expr`` itself.
"""
printer = printer_registry.get(configuration['language'], CDevitoPrinter)
lang = language or configuration['language']
printer = printer_registry.get(lang, CDevitoPrinter)
return printer(settings=settings).doprint(expr, None)
7 changes: 6 additions & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import inspect
from collections import namedtuple
from ctypes import POINTER, _Pointer, c_char_p, c_char
from ctypes import POINTER, _Pointer, c_char_p, c_char, Structure
from functools import reduce, cached_property
from operator import mul

Expand Down Expand Up @@ -87,13 +87,18 @@ def _C_typedata(self):
if isinstance(_type, CustomDtype):
return _type

_pointer = False
while issubclass(_type, _Pointer):
_pointer = True
_type = _type._type_

# `ctypes` treats C strings specially
if _type is c_char_p:
_type = c_char

if issubclass(_type, Structure) and _pointer:
_type = f'struct {_type.__name__}'

return _type

@abc.abstractproperty
Expand Down
5 changes: 3 additions & 2 deletions tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def _config_kwargs(platform: str, language: str) -> dict[str, str]:

@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
@pytest.mark.parametrize('kwargs', _configs)
def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> None:
def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str],
expected=None) -> None:
"""
Tests that half and complex floats' dtypes result in the correct type
strings in generated code.
Expand All @@ -78,7 +79,7 @@ def test_dtype_mapping(dtype: np.dtype[np.inexact], kwargs: dict[str, str]) -> N
params: dict[str, Basic] = {p.name: p for p in op.parameters}
_u, _c = params['u'], params['c']
assert type(_u.indexed._C_ctype._type_()) == ctypes_vector_mapper[dtype]
assert _c._C_ctype == ctypes_vector_mapper[dtype]
assert _c._C_ctype == expected or ctypes_vector_mapper[dtype]


@pytest.mark.parametrize('dtype', [np.complex64, np.complex128])
Expand Down

0 comments on commit b9b60c3

Please sign in to comment.