diff --git a/loki/transformations/extract/outline.py b/loki/transformations/extract/outline.py index a49e02f16..2299c04d9 100644 --- a/loki/transformations/extract/outline.py +++ b/loki/transformations/extract/outline.py @@ -6,7 +6,7 @@ # nor does it submit to any jurisdiction. from loki.analyse import dataflow_analysis_attached -from loki.expression import Variable +from loki.expression import symbols as sym, Variable from loki.ir import ( CallStatement, Import, PragmaRegion, Section, FindNodes, FindVariables, Transformer, is_loki_pragma, @@ -15,12 +15,39 @@ from loki.logging import info from loki.subroutine import Subroutine from loki.tools import as_tuple, CaseInsensitiveDict +from loki.types import BasicType, DerivedType __all__ = ['outline_region', 'outline_pragma_regions'] +def order_variables_by_type(variables, imports=None): + """ + Apply a default ordering to variables based on their type, so that + their use in declaration lists is unified. + """ + variables = sorted(variables, key=str) # Lexicographical base order + + derived = tuple( + v for v in variables + if isinstance(v.type.dtype, DerivedType) or v.type.dtype == BasicType.DEFERRED + ) + + if imports: + # Order derived types by the order of their type in imports + imported_symbols = tuple(s for i in imports for s in i.symbols if not i.c_import) + derived = tuple(sorted(derived, key=lambda x: imported_symbols.index(x.type.dtype.name))) + + # Order declarations by type and put arrays before scalars + non_derived = tuple(v for v in variables if v not in derived) + arrays = tuple(v for v in non_derived if isinstance(v, sym.Array)) + scalars = tuple(v for v in non_derived if isinstance(v, sym.Scalar)) + assert len(derived) + len(arrays) + len(scalars) == len(variables) + + return derived + arrays + scalars + + def outline_region(region, name, imports, intent_map=None): """ Creates a new :any:`Subroutine` object from a given :any:`PragmaRegion`. @@ -66,11 +93,6 @@ def outline_region(region, name, imports, intent_map=None): region_in_args = {arg for arg in region_in_args if not arg.type.parameter} # Extract arguments given in pragma annotations - region_var_map = CaseInsensitiveDict( - (v.name, v.clone(dimensions=None)) - for v in FindVariables().visit(region.body) - if v.clone(dimensions=None) not in imported_symbols - ) pragma_in_args = {v.clone(scope=region_routine) for v in intent_map['in']} pragma_inout_args = {v.clone(scope=region_routine) for v in intent_map['inout']} pragma_out_args = {v.clone(scope=region_routine) for v in intent_map['out']} @@ -87,23 +109,18 @@ def outline_region(region, name, imports, intent_map=None): # Set the list of variables used in region routine (to create declarations) # and put all in the new scope - region_routine_variables = { - v.clone(dimensions=v.type.shape or None) - for v in FindVariables().visit(region_routine.body) - if v.name in region_var_map - } + region_routine_variables = tuple( + v.clone(dimensions=v.type.shape or None, scope=region_routine) + for v in FindVariables().visit(region.body) + if v.clone(dimensions=None) not in imported_symbols + ) # Filter out derived-type component variables from declarations - region_routine_variables = { + region_routine_variables = tuple( v.parents[0] if v.parent else v for v in region_routine_variables - } - # Order the local devlaration list to put arguments first - region_routine_args = tuple(v for v in region_routine_variables if v.type.intent) - region_routine_locals = tuple(v for v in region_routine_variables if not v.type.intent) - region_routine.variables = region_routine_args + region_routine_locals - region_routine.rescope_symbols() + ) # Build the call signature - region_routine_var_map = region_routine.variable_map + region_routine_var_map = {v.name: v for v in region_routine_variables} region_routine_arguments = [] for intent, args in zip(('in', 'inout', 'out'), (region_in_args, region_inout_args, region_out_args)): for arg in args: @@ -116,12 +133,19 @@ def outline_region(region, name, imports, intent_map=None): region_routine_var_map[arg.name] = local_var region_routine_arguments += [local_var] - # We need to update the list of variables again to avoid duplicate declarations - region_routine.variables = as_tuple(region_routine_var_map.values()) - region_routine.arguments = as_tuple(region_routine_arguments) + # Order the arguments and local declaration lists and put arguments first + region_routine_locals = tuple( + v for v in region_routine_variables if not v in region_routine_arguments + ) + region_routine_arguments = order_variables_by_type(region_routine_arguments, imports=imports) + region_routine_locals = order_variables_by_type(region_routine_locals, imports=imports) + + region_routine.variables = region_routine_arguments + region_routine_locals + region_routine.arguments = region_routine_arguments # Create the call according to the wrapped code region - call_arguments = region_in_args + region_inout_args + region_out_args + call_arg_map = {v.name: v for v in region_in_args + region_inout_args + region_out_args} + call_arguments = tuple(call_arg_map[a.name] for a in region_routine_arguments) call = CallStatement(name=Variable(name=name), arguments=call_arguments, kwarguments=()) return call, region_routine diff --git a/loki/transformations/extract/tests/test_outline.py b/loki/transformations/extract/tests/test_outline.py index 2a35c32b1..c6ba5ffa0 100644 --- a/loki/transformations/extract/tests/test_outline.py +++ b/loki/transformations/extract/tests/test_outline.py @@ -505,19 +505,19 @@ def test_outline_pragma_regions_associates(tmp_path, builder, frontend): assert len(FindNodes(Assignment).visit(routine.body)) == 4 calls = FindNodes(CallStatement).visit(routine.body) assert len(calls) == 1 - assert calls[0].arguments == ('d', 'c') + assert calls[0].arguments == ('c', 'd') # Check for a single derived-type argument assert len(outlined) == 1 assert len(outlined[0].arguments) == 2 - assert outlined[0].arguments[0].name == 'd' + assert outlined[0].arguments[0].name == 'c' assert outlined[0].arguments[0].type.shape == (10,) assert outlined[0].arguments[0].type.dtype == BasicType.INTEGER - assert outlined[0].arguments[0].type.intent == 'inout' - assert outlined[0].arguments[1].name == 'c' + assert outlined[0].arguments[0].type.intent == 'out' + assert outlined[0].arguments[1].name == 'd' assert outlined[0].arguments[1].type.shape == (10,) assert outlined[0].arguments[1].type.dtype == BasicType.INTEGER - assert outlined[0].arguments[1].type.intent == 'out' + assert outlined[0].arguments[1].type.intent == 'inout' # Insert created routines into module module.contains.append(outlined)