diff --git a/loki/transformations/extract/outline.py b/loki/transformations/extract/outline.py index b6082b82a..bd97d2155 100644 --- a/loki/transformations/extract/outline.py +++ b/loki/transformations/extract/outline.py @@ -6,18 +6,155 @@ # nor does it submit to any jurisdiction. from loki.analyse import dataflow_analysis_attached -from loki.expression import Variable +from loki.expression import symbols as sym, Variable from loki.ir import ( - CallStatement, Import, PragmaRegion, Section, FindNodes, - FindVariables, MaskedTransformer, Transformer, is_loki_pragma, + CallStatement, PragmaRegion, Section, FindNodes, + FindVariables, Transformer, is_loki_pragma, get_pragma_parameters, pragma_regions_attached ) from loki.logging import info from loki.subroutine import Subroutine -from loki.tools import as_tuple, CaseInsensitiveDict +from loki.tools import as_tuple +from loki.types import BasicType -__all__ = ['outline_pragma_regions'] + +__all__ = ['outline_region', 'outline_pragma_regions'] + + +def order_variables_by_type(variables, imports=None): + """ + Apply a default ordering to variables based on their type, so that + their use in declaration lists is unified. 
+ """ + variables = sorted(variables, key=str) # Lexicographical base order + + derived = tuple( + v for v in variables + if not isinstance(v, (sym.Scalar, sym.Array)) or not isinstance(v.type.dtype, BasicType) + ) + + if imports: + # Order derived types by the order of their type in imports + imported_symbols = tuple(s for i in imports for s in i.symbols if not i.c_import) + derived = tuple(sorted(derived, key=lambda x: imported_symbols.index(x.type.dtype.name))) + + # Order declarations by type and put arrays before scalars + non_derived = tuple(v for v in variables if v not in derived) + arrays = tuple(v for v in non_derived if isinstance(v, sym.Array)) + scalars = tuple(v for v in non_derived if isinstance(v, sym.Scalar)) + assert len(derived) + len(arrays) + len(scalars) == len(variables) + + return derived + arrays + scalars + + +def outline_region(region, name, imports, intent_map=None): + """ + Creates a new :any:`Subroutine` object from a given :any:`PragmaRegion`. + + Parameters + ---------- + region : :any:`PragmaRegion` + The region that holds the body for which to create a subroutine. + name : str + Name of the new subroutine + imports : tuple of :any:`Import`, optional + List of imports to replicate in the new subroutine + intent_map : dict, optional + Mapping of intent strings to list of variables to override intents + + Returns + ------- + tuple of :any:`CallStatement` and :any:`Subroutine` + The newly created call and respective subroutine. 
+ """ + intent_map = intent_map or {} + imports = as_tuple(imports) + imported_symbols = {var for imp in imports for var in imp.symbols} + # Special-case for IFS-style C-imports + imported_symbols |= { + str(imp.module).split('.', maxsplit=1)[0] for imp in imports if imp.c_import + } + + # Create the external subroutine containing the routine's imports and the region's body + spec = Section(body=imports) + body = Section(body=Transformer().visit(region.body)) + region_routine = Subroutine(name, spec=spec, body=body) + + # Filter derived-type component accesses and only use the root parent + region_uses_symbols = {s.parents[0] if s.parent else s for s in region.uses_symbols} + region_defines_symbols = {s.parents[0] if s.parent else s for s in region.defines_symbols} + + # Use dataflow analysis to find in, out and inout variables to that region + # (ignoring any symbols that are external imports) + region_in_args = region_uses_symbols - region_defines_symbols - imported_symbols + region_inout_args = region_uses_symbols & region_defines_symbols - imported_symbols + region_out_args = region_defines_symbols - region_uses_symbols - imported_symbols + + # Remove any parameters from in args + region_in_args = {arg for arg in region_in_args if not arg.type.parameter} + + # Extract arguments given in pragma annotations + pragma_in_args = {v.clone(scope=region_routine) for v in intent_map['in']} + pragma_inout_args = {v.clone(scope=region_routine) for v in intent_map['inout']} + pragma_out_args = {v.clone(scope=region_routine) for v in intent_map['out']} + + # Override arguments according to pragma annotations + region_in_args = (region_in_args - (pragma_inout_args | pragma_out_args)) | pragma_in_args + region_inout_args = (region_inout_args - (pragma_in_args | pragma_out_args)) | pragma_inout_args + region_out_args = (region_out_args - (pragma_in_args | pragma_inout_args)) | pragma_out_args + + # Now fix the order + region_inout_args = as_tuple(region_inout_args) + 
region_in_args = as_tuple(region_in_args) + region_out_args = as_tuple(region_out_args) + + # Set the list of variables used in region routine (to create declarations) + # and put all in the new scope + region_routine_variables = tuple( + v.clone(dimensions=v.type.shape or None, scope=region_routine) + for v in FindVariables().visit(region.body) + if v.clone(dimensions=None) not in imported_symbols + ) + # Filter out derived-type component variables from declarations + region_routine_variables = tuple( + v.parents[0] if v.parent else v for v in region_routine_variables + ) + + # Build the call signature + region_routine_var_map = {v.name: v for v in region_routine_variables} + region_routine_arguments = [] + for intent, args in zip(('in', 'inout', 'out'), (region_in_args, region_inout_args, region_out_args)): + for arg in args: + local_var = region_routine_var_map.get(arg.name, arg) + # Sanitise argument types + local_var = local_var.clone( + type=local_var.type.clone(intent=intent, allocatable=None, target=None), + scope=region_routine + ) + + region_routine_var_map[arg.name] = local_var + region_routine_arguments += [local_var] + + # Order the arguments and local declaration lists and put arguments first + region_routine_locals = tuple( + v for v in region_routine_variables if not v in region_routine_arguments + ) + region_routine_arguments = order_variables_by_type(region_routine_arguments, imports=imports) + region_routine_locals = order_variables_by_type(region_routine_locals, imports=imports) + + region_routine.variables = region_routine_arguments + region_routine_locals + region_routine.arguments = region_routine_arguments + + # Ensure everything has been rescoped + region_routine.rescope_symbols() + + # Create the call according to the wrapped code region + call_arg_map = {v.name: v for v in region_in_args + region_inout_args + region_out_args} + call_arguments = tuple(call_arg_map[a.name] for a in region_routine_arguments) + call = 
CallStatement(name=Variable(name=name), arguments=call_arguments, kwarguments=()) + + return call, region_routine def outline_pragma_regions(routine): @@ -46,12 +183,12 @@ def outline_pragma_regions(routine): ------- list of :any:`Subroutine` the list of newly created subroutines. - """ counter = 0 - routines, starts, stops = [], [], [] - imports = {var for imprt in FindNodes(Import).visit(routine.spec) for var in imprt.symbols} - mask_map = {} + routines = [] + imports = routine.imports + parent_vmap = routine.variable_map + mapper = {} with pragma_regions_attached(routine): with dataflow_analysis_attached(routine): for region in FindNodes(PragmaRegion).visit(routine.body): @@ -63,74 +200,21 @@ def outline_pragma_regions(routine): name = parameters.get('name', f'{routine.name}_outlined_{counter}') counter += 1 - # Create the external subroutine containing the routine's imports and the region's body - spec = Section(body=Transformer().visit(FindNodes(Import).visit(routine.spec))) - body = Section(body=Transformer().visit(region.body)) - region_routine = Subroutine(name, spec=spec, body=body) - - # Use dataflow analysis to find in, out and inout variables to that region - # (ignoring any symbols that are external imports) - region_in_args = region.uses_symbols - region.defines_symbols - imports - region_inout_args = region.uses_symbols & region.defines_symbols - imports - region_out_args = region.defines_symbols - region.uses_symbols - imports - - # Remove any parameters from in args - region_in_args = {arg for arg in region_in_args if not arg.type.parameter} - - # Extract arguments given in pragma annotations - region_var_map = CaseInsensitiveDict( - (v.name, v.clone(dimensions=None)) - for v in FindVariables().visit(region.body) if v.clone(dimensions=None) not in imports - ) - pragma_in_args = {region_var_map[v.lower()] for v in parameters.get('in', '').split(',') if v} - pragma_inout_args = {region_var_map[v.lower()] for v in parameters.get('inout', 
'').split(',') if v} - pragma_out_args = {region_var_map[v.lower()] for v in parameters.get('out', '').split(',') if v} - - # Override arguments according to pragma annotations - region_in_args = (region_in_args - (pragma_inout_args | pragma_out_args)) | pragma_in_args - region_inout_args = (region_inout_args - (pragma_in_args | pragma_out_args)) | pragma_inout_args - region_out_args = (region_out_args - (pragma_in_args | pragma_inout_args)) | pragma_out_args - - # Now fix the order - region_inout_args = as_tuple(region_inout_args) - region_in_args = as_tuple(region_in_args) - region_out_args = as_tuple(region_out_args) - - # Set the list of variables used in region routine (to create declarations) - # and put all in the new scope - region_routine_variables = {v.clone(dimensions=v.type.shape or None) - for v in FindVariables().visit(region_routine.body) - if v.name in region_var_map} - region_routine.variables = as_tuple(region_routine_variables) - region_routine.rescope_symbols() - - # Build the call signature - region_routine_var_map = region_routine.variable_map - region_routine_arguments = [] - for intent, args in zip(('in', 'inout', 'out'), (region_in_args, region_inout_args, region_out_args)): - for arg in args: - local_var = region_routine_var_map[arg.name] - local_var = local_var.clone(type=local_var.type.clone(intent=intent)) - region_routine_var_map[arg.name] = local_var - region_routine_arguments += [local_var] - - # We need to update the list of variables again to avoid duplicate declarations - region_routine.variables = as_tuple(region_routine_var_map.values()) - region_routine.arguments = as_tuple(region_routine_arguments) + # Extract explicitly requested symbols from context + intent_map = {} + intent_map['in'] = tuple(parent_vmap[v] for v in parameters.get('in', '').split(',') if v) + intent_map['inout'] = tuple(parent_vmap[v] for v in parameters.get('inout', '').split(',') if v) + intent_map['out'] = tuple(parent_vmap[v] for v in 
parameters.get('out', '').split(',') if v) + + call, region_routine = outline_region(region, name, imports, intent_map=intent_map) # insert into list of new routines routines.append(region_routine) - # Register start and end nodes in transformer mask for original routine - starts += [region.pragma_post] - stops += [region.pragma] - - # Replace end pragma by call in original routine - call_arguments = region_in_args + region_inout_args + region_out_args - call = CallStatement(name=Variable(name=name), arguments=call_arguments) - mask_map[region.pragma_post] = call + # Replace region by call in original routine + mapper[region] = call - routine.body = MaskedTransformer(active=True, start=starts, stop=stops, mapper=mask_map).visit(routine.body) + routine.body = Transformer(mapper=mapper).visit(routine.body) info('%s: converted %d region(s) to calls', routine.name, counter) return routines diff --git a/loki/transformations/extract/tests/test_outline.py b/loki/transformations/extract/tests/test_outline.py index a17ae7361..c6ba5ffa0 100644 --- a/loki/transformations/extract/tests/test_outline.py +++ b/loki/transformations/extract/tests/test_outline.py @@ -13,6 +13,7 @@ from loki.frontend import available_frontends from loki.ir import FindNodes, Section, Assignment, CallStatement, Intrinsic from loki.tools import as_tuple +from loki.types import BasicType from loki.transformations.extract.outline import outline_pragma_regions @@ -361,3 +362,172 @@ def test_outline_pragma_regions_imports(tmp_path, builder, frontend): mod_function(a, b) assert np.all(a == [1] * 10) assert np.all(b == range(1,11)) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_outline_pragma_regions_derived_args(tmp_path, builder, frontend): + """ + Test subroutine extraction with derived-type arguments. 
+ """ + + fcode = """ +module test_outline_dertype_mod + implicit none + + type rick + integer :: a(10), b(10) + end type rick +contains + + subroutine test_outline_imps(a, b) + integer, intent(out) :: a(10), b(10) + type(rick) :: dave + integer :: j + + dave%a(:) = a(:) + dave%b(:) = b(:) + +!$loki outline + do j=1,10 + dave%a(j) = j + 1 + end do + + dave%b(:) = dave%b(:) + 42 +!$loki end outline + + a(:) = dave%a(:) + b(:) = dave%b(:) + end subroutine test_outline_imps +end module test_outline_dertype_mod +""" + module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + refname = f'ref_{module.name}_{frontend}' + reference = jit_compile_lib([module], path=tmp_path, name=refname, builder=builder) + function = getattr(getattr(reference, module.name), module.subroutines[0].name) + + # Test the reference solution + a = np.zeros(shape=(10,), dtype=np.int32) + b = np.zeros(shape=(10,), dtype=np.int32) + function(a, b) + assert np.all(a == range(2,12)) + assert np.all(b == 42) + (tmp_path/f'{module.name}.f90').unlink() + + assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 6 + assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 0 + + # Apply transformation + routines = outline_pragma_regions(module.subroutines[0]) + + assert len(FindNodes(Assignment).visit(module.subroutines[0].body)) == 4 + assert len(FindNodes(CallStatement).visit(module.subroutines[0].body)) == 1 + + # Check for a single derived-type argument + assert len(routines) == 1 + assert len(routines[0].arguments) == 1 + assert routines[0].arguments[0] == 'dave' + assert routines[0].arguments[0].type.dtype.name == 'rick' + assert routines[0].arguments[0].type.intent == 'inout' + + # Insert created routines into module + module.contains.append(routines) + + obj = jit_compile_lib([module], path=tmp_path, name=f'{module.name}_{frontend}', builder=builder) + mod_function = getattr(getattr(obj, module.name), module.subroutines[0].name) + + # Test the 
transformed module solution + a = np.zeros(shape=(10,), dtype=np.int32) + b = np.zeros(shape=(10,), dtype=np.int32) + mod_function(a, b) + assert np.all(a == range(2,12)) + assert np.all(b == 42) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_outline_pragma_regions_associates(tmp_path, builder, frontend): + """ + Test subroutine extraction with derived-type arguments. + """ + + fcode = """ +module test_outline_assoc_mod + implicit none + + type rick + integer :: a(10), b(10) + end type rick +contains + + subroutine test_outline_imps(a, b) + integer, intent(out) :: a(10), b(10) + type(rick) :: dave + integer :: j + + associate(c=>dave%a, d=>dave%b) + + c(:) = a(:) + d(:) = b(:) + +!$loki outline + do j=1,10 + c(j) = j + 1 + end do + + d(:) = d(:) + 42 +!$loki end outline + + a(:) = c(:) + b(:) = d(:) + end associate + end subroutine test_outline_imps +end module test_outline_assoc_mod +""" + module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + routine = module.subroutines[0] + refname = f'ref_{module.name}_{frontend}' + reference = jit_compile_lib([module], path=tmp_path, name=refname, builder=builder) + function = getattr(getattr(reference, module.name), routine.name) + + # Test the reference solution + a = np.zeros(shape=(10,), dtype=np.int32) + b = np.zeros(shape=(10,), dtype=np.int32) + function(a, b) + assert np.all(a == range(2,12)) + assert np.all(b == 42) + (tmp_path/f'{module.name}.f90').unlink() + + assert len(FindNodes(Assignment).visit(routine.body)) == 6 + assert len(FindNodes(CallStatement).visit(routine.body)) == 0 + + # Apply transformation + outlined = outline_pragma_regions(routine) + + assert len(FindNodes(Assignment).visit(routine.body)) == 4 + calls = FindNodes(CallStatement).visit(routine.body) + assert len(calls) == 1 + assert calls[0].arguments == ('c', 'd') + + # Check for a single derived-type argument + assert len(outlined) == 1 + assert len(outlined[0].arguments) == 2 + assert 
outlined[0].arguments[0].name == 'c' + assert outlined[0].arguments[0].type.shape == (10,) + assert outlined[0].arguments[0].type.dtype == BasicType.INTEGER + assert outlined[0].arguments[0].type.intent == 'out' + assert outlined[0].arguments[1].name == 'd' + assert outlined[0].arguments[1].type.shape == (10,) + assert outlined[0].arguments[1].type.dtype == BasicType.INTEGER + assert outlined[0].arguments[1].type.intent == 'inout' + + # Insert created routines into module + module.contains.append(outlined) + + obj = jit_compile_lib( + [module], path=tmp_path, name=f'{module.name}_{frontend}', builder=builder + ) + mod_function = getattr(getattr(obj, module.name), routine.name) + a = np.zeros(shape=(10,), dtype=np.int32) + b = np.zeros(shape=(10,), dtype=np.int32) + mod_function(a, b) + assert np.all(a == range(2,12)) + assert np.all(b == 42)