Skip to content

Commit

Permalink
Extract: Order declarations when extracting routines
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Oct 19, 2024
1 parent 2d3680c commit f28bf69
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 28 deletions.
70 changes: 47 additions & 23 deletions loki/transformations/extract/outline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# nor does it submit to any jurisdiction.

from loki.analyse import dataflow_analysis_attached
from loki.expression import Variable
from loki.expression import symbols as sym, Variable
from loki.ir import (
CallStatement, Import, PragmaRegion, Section, FindNodes,
FindVariables, Transformer, is_loki_pragma,
Expand All @@ -15,12 +15,39 @@
from loki.logging import info
from loki.subroutine import Subroutine
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.types import BasicType, DerivedType



__all__ = ['outline_region', 'outline_pragma_regions']


def order_variables_by_type(variables, imports=None):
"""
Apply a default ordering to variables based on their type, so that
their use in declaration lists is unified.
"""
variables = sorted(variables, key=str) # Lexicographical base order

derived = tuple(
v for v in variables
if isinstance(v.type.dtype, DerivedType) or v.type.dtype == BasicType.DEFERRED
)

if imports:
# Order derived types by the order of their type in imports
imported_symbols = tuple(s for i in imports for s in i.symbols if not i.c_import)
derived = tuple(sorted(derived, key=lambda x: imported_symbols.index(x.type.dtype.name)))

# Order declarations by type and put arrays before scalars
non_derived = tuple(v for v in variables if v not in derived)
arrays = tuple(v for v in non_derived if isinstance(v, sym.Array))
scalars = tuple(v for v in non_derived if isinstance(v, sym.Scalar))
assert len(derived) + len(arrays) + len(scalars) == len(variables)

return derived + arrays + scalars


def outline_region(region, name, imports, intent_map=None):
"""
Creates a new :any:`Subroutine` object from a given :any:`PragmaRegion`.
Expand Down Expand Up @@ -66,11 +93,6 @@ def outline_region(region, name, imports, intent_map=None):
region_in_args = {arg for arg in region_in_args if not arg.type.parameter}

# Extract arguments given in pragma annotations
region_var_map = CaseInsensitiveDict(
(v.name, v.clone(dimensions=None))
for v in FindVariables().visit(region.body)
if v.clone(dimensions=None) not in imported_symbols
)
pragma_in_args = {v.clone(scope=region_routine) for v in intent_map['in']}
pragma_inout_args = {v.clone(scope=region_routine) for v in intent_map['inout']}
pragma_out_args = {v.clone(scope=region_routine) for v in intent_map['out']}
Expand All @@ -87,23 +109,18 @@ def outline_region(region, name, imports, intent_map=None):

# Set the list of variables used in region routine (to create declarations)
# and put all in the new scope
region_routine_variables = {
v.clone(dimensions=v.type.shape or None)
for v in FindVariables().visit(region_routine.body)
if v.name in region_var_map
}
region_routine_variables = tuple(
v.clone(dimensions=v.type.shape or None, scope=region_routine)
for v in FindVariables().visit(region.body)
if v.clone(dimensions=None) not in imported_symbols
)
# Filter out derived-type component variables from declarations
region_routine_variables = {
region_routine_variables = tuple(
v.parents[0] if v.parent else v for v in region_routine_variables
}
# Order the local devlaration list to put arguments first
region_routine_args = tuple(v for v in region_routine_variables if v.type.intent)
region_routine_locals = tuple(v for v in region_routine_variables if not v.type.intent)
region_routine.variables = region_routine_args + region_routine_locals
region_routine.rescope_symbols()
)

# Build the call signature
region_routine_var_map = region_routine.variable_map
region_routine_var_map = {v.name: v for v in region_routine_variables}
region_routine_arguments = []
for intent, args in zip(('in', 'inout', 'out'), (region_in_args, region_inout_args, region_out_args)):
for arg in args:
Expand All @@ -116,12 +133,19 @@ def outline_region(region, name, imports, intent_map=None):
region_routine_var_map[arg.name] = local_var
region_routine_arguments += [local_var]

# We need to update the list of variables again to avoid duplicate declarations
region_routine.variables = as_tuple(region_routine_var_map.values())
region_routine.arguments = as_tuple(region_routine_arguments)
# Order the arguments and local declaration lists and put arguments first
region_routine_locals = tuple(
v for v in region_routine_variables if not v in region_routine_arguments
)
region_routine_arguments = order_variables_by_type(region_routine_arguments, imports=imports)
region_routine_locals = order_variables_by_type(region_routine_locals, imports=imports)

region_routine.variables = region_routine_arguments + region_routine_locals
region_routine.arguments = region_routine_arguments

# Create the call according to the wrapped code region
call_arguments = region_in_args + region_inout_args + region_out_args
call_arg_map = {v.name: v for v in region_in_args + region_inout_args + region_out_args}
call_arguments = tuple(call_arg_map[a.name] for a in region_routine_arguments)
call = CallStatement(name=Variable(name=name), arguments=call_arguments, kwarguments=())

return call, region_routine
Expand Down
10 changes: 5 additions & 5 deletions loki/transformations/extract/tests/test_outline.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,19 +505,19 @@ def test_outline_pragma_regions_associates(tmp_path, builder, frontend):
assert len(FindNodes(Assignment).visit(routine.body)) == 4
calls = FindNodes(CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments == ('d', 'c')
assert calls[0].arguments == ('c', 'd')

# Check for a single derived-type argument
assert len(outlined) == 1
assert len(outlined[0].arguments) == 2
assert outlined[0].arguments[0].name == 'd'
assert outlined[0].arguments[0].name == 'c'
assert outlined[0].arguments[0].type.shape == (10,)
assert outlined[0].arguments[0].type.dtype == BasicType.INTEGER
assert outlined[0].arguments[0].type.intent == 'inout'
assert outlined[0].arguments[1].name == 'c'
assert outlined[0].arguments[0].type.intent == 'out'
assert outlined[0].arguments[1].name == 'd'
assert outlined[0].arguments[1].type.shape == (10,)
assert outlined[0].arguments[1].type.dtype == BasicType.INTEGER
assert outlined[0].arguments[1].type.intent == 'out'
assert outlined[0].arguments[1].type.intent == 'inout'

# Insert created routines into module
module.contains.append(outlined)
Expand Down

0 comments on commit f28bf69

Please sign in to comment.