Skip to content

Commit

Permalink
Transformations: Remove recursive_expression_map_update
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Oct 30, 2024
1 parent 95e1ecc commit 8298e0c
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 166 deletions.
6 changes: 2 additions & 4 deletions loki/transformations/block_index_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
)
from loki.transformations.sanitise import resolve_associates
from loki.transformations.utilities import (
recursive_expression_map_update, get_integer_variable,
get_loop_bounds, check_routine_sequential
get_integer_variable, get_loop_bounds, check_routine_sequential
)
from loki.transformations.single_column.base import SCCBaseTransformation

Expand Down Expand Up @@ -227,7 +226,6 @@ def process_body(self, body, definitions, successors, targets, exclude_arrays):
if self.global_gfl_ptr:
vmap.update({v: self.build_ydvars_global_gfl_ptr(vmap.get(v, v))
for v in FindVariables(unique=False).visit(body) if 'ydvars%gfl_ptr' in v.name.lower()})
vmap = recursive_expression_map_update(vmap)

# filter out arrays marked for exclusion
vmap = {k: v for k, v in vmap.items() if not any(e in k for e in exclude_arrays)}
Expand All @@ -236,7 +234,7 @@ def process_body(self, body, definitions, successors, targets, exclude_arrays):
self.propagate_defs_to_children(self._key, definitions, successors)

# finally we perform the substitution
return SubstituteExpressions(vmap).visit(body)
return SubstituteExpressions(vmap, recursive=True).visit(body)


def process_kernel(self, routine, item, successors, targets, exclude_arrays):
Expand Down
18 changes: 7 additions & 11 deletions loki/transformations/inline/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@

from loki.transformations.inline.mapper import InlineSubstitutionMapper
from loki.transformations.inline.procedures import map_call_to_procedure_body
from loki.transformations.utilities import (
single_variable_declaration, recursive_expression_map_update
)
from loki.transformations.utilities import single_variable_declaration


__all__ = [
Expand Down Expand Up @@ -217,10 +215,10 @@ def inline_statement_functions(routine):
if proc_type.is_function and isinstance(call.routine, StatementFunction):
exprmap[call] = InlineSubstitutionMapper()(call, scope=routine)
removed_functions.add(call.routine)
# Apply the map to itself to handle nested statement function calls
exprmap = recursive_expression_map_update(exprmap, max_iterations=10, mapper_cls=InlineSubstitutionMapper)
# Apply expression-level substitution to routine
routine.body = SubstituteExpressions(exprmap).visit(routine.body)

# Apply expression-level substitution to routine and handle nested
# statement function calls via recursion
routine.body = SubstituteExpressions(exprmap, recursive=True).visit(routine.body)

# remove statement function declarations as well as statement function argument(s) declarations
vars_to_remove = {stmt_func.variable.name.lower() for stmt_func in stmt_func_decls}
Expand Down Expand Up @@ -272,8 +270,7 @@ def rename_result_name(routine, rename):
callee_vars = [var for var in FindVariables().visit(callee.body)
if var.name.lower() == callee_result_var.name.lower()]
var_map.update({var: var.clone(name=rename) for var in callee_vars})
var_map = recursive_expression_map_update(var_map)
callee.body = SubstituteExpressions(var_map).visit(callee.body)
callee.body = SubstituteExpressions(var_map, recursive=True).visit(callee.body)
return callee, new_callee_result_var

allowed_aliases = as_tuple(allowed_aliases)
Expand Down Expand Up @@ -302,8 +299,7 @@ def rename_result_name(routine, rename):
if v.name.lower() in duplicate_names:
var_map[v] = v.clone(name=f'{callee.name}_{v.name}')

var_map = recursive_expression_map_update(var_map)
callee.body = SubstituteExpressions(var_map).visit(callee.body)
callee.body = SubstituteExpressions(var_map, recursive=True).visit(callee.body)

# Separate allowed aliases from other variables to ensure clean hoisting
if allowed_aliases:
Expand Down
14 changes: 4 additions & 10 deletions loki/transformations/inline/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
from loki.subroutine import Subroutine

from loki.transformations.sanitise import transform_sequence_association_append_map
from loki.transformations.utilities import (
single_variable_declaration, recursive_expression_map_update
)
from loki.transformations.utilities import single_variable_declaration


__all__ = [
Expand Down Expand Up @@ -146,11 +144,8 @@ def _map_unbound_dims(var, val):
}
argmap.update(present_map)

# Recursive update of the map in case of nested variables to map
argmap = recursive_expression_map_update(argmap, max_iterations=10)

# Substitute argument calls into a copy of the body
callee_body = SubstituteExpressions(argmap, rebuild_scopes=True).visit(
# Substitute argument calls into a copy of the body and capture nesting
callee_body = SubstituteExpressions(argmap, recursive=True, rebuild_scopes=True).visit(
callee.body.body, scope=caller
)

Expand Down Expand Up @@ -212,8 +207,7 @@ def inline_subroutine_calls(routine, calls, callee, allowed_aliases=None):
for v in FindVariables(unique=False).visit(callee.body):
if v.name.lower() in duplicate_names:
var_map[v] = v.clone(name=f'{callee.name}_{v.name}')
var_map = recursive_expression_map_update(var_map)
callee.body = SubstituteExpressions(var_map).visit(callee.body)
callee.body = SubstituteExpressions(var_map, recursive=True).visit(callee.body)

# Separate allowed aliases from other variables to ensure clean hoisting
if allowed_aliases:
Expand Down
5 changes: 1 addition & 4 deletions loki/transformations/pool_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
from loki.tools import as_tuple
from loki.types import SymbolAttributes, BasicType, DerivedType

from loki.transformations.utilities import recursive_expression_map_update


__all__ = ['TemporariesPoolAllocatorTransformation']

Expand Down Expand Up @@ -527,8 +525,7 @@ def _determine_stack_size(self, routine, successors, local_stack_size=None, item
if expr in arg_map
}
if expr_map:
expr_map = recursive_expression_map_update(expr_map)
successor_stack_size = SubstituteExpressions(expr_map).visit(successor_stack_size)
successor_stack_size = SubstituteExpressions(expr_map, recursive=True).visit(successor_stack_size)
stack_sizes += [successor_stack_size]

# Unwind "max" expressions from successors and inject the local stack size into the expressions
Expand Down
8 changes: 2 additions & 6 deletions loki/transformations/single_column/claw.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.types import SymbolAttributes, BasicType

from loki.transformations.utilities import recursive_expression_map_update


__all__ = ['ExtractSCATransformation', 'CLAWTransformation']

Expand Down Expand Up @@ -131,10 +129,8 @@ def remove_dimension(self, routine):
# Apply vmap to variable and argument list and subroutine body
routine.variables = [vmap.get(v, v) for v in routine.variables]

# Apply substitution map to replacements to capture nesting
vmap = recursive_expression_map_update(vmap)

routine.body = SubstituteExpressions(vmap).visit(routine.body)
# Apply substitution map and capture nesting via recursion
routine.body = SubstituteExpressions(vmap, recursive=True).visit(routine.body)
for m in as_tuple(routine.members):
m.body = SubstituteExpressions(vmap).visit(m.body)

Expand Down
119 changes: 56 additions & 63 deletions loki/transformations/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from loki.types import BasicType

from loki.transformations.utilities import (
single_variable_declaration, recursive_expression_map_update,
convert_to_lower_case, replace_intrinsics, rename_variables,
get_integer_variable, get_loop_bounds, is_driver_loop,
find_driver_loops, get_local_arrays, check_routine_sequential
single_variable_declaration, convert_to_lower_case,
replace_intrinsics, rename_variables, get_integer_variable,
get_loop_bounds, is_driver_loop, find_driver_loops,
get_local_arrays, check_routine_sequential
)


Expand Down Expand Up @@ -140,65 +140,58 @@ def test_transform_convert_to_lower_case(frontend):
assert all(var.name.islower() and str(var).islower() for var in FindVariables(unique=False).visit(routine.ir))


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_utilities_recursive_expression_map_update(frontend, tmp_path):
fcode = """
module some_mod
implicit none
type some_type
integer :: m, n
real, allocatable :: a(:, :)
contains
procedure, pass :: my_add
end type some_type
contains
function my_add(self, data, val)
class(some_type), intent(inout) :: self
real, intent(in) :: data(:,:)
real, value :: val
real :: my_add(:,:)
my_add(:,:) = self%a(:,:) + data(:,:) + val
end function my_add
subroutine do(my_obj)
type(some_type), intent(inout) :: my_obj
my_obj%a = my_obj%my_add(MY_OBJ%a(1:my_obj%m, 1:MY_OBJ%n), 1.)
end subroutine do
end module some_mod
""".strip()

module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = module['do']

expr_map = {}
expr_map[routine.variable_map['my_obj']] = routine.variable_map['my_obj'].clone(name='obj')
for var in FindVariables().visit(routine.body):
if var.parent == 'my_obj':
expr_map[var] = var.clone(name=f'obj%{var.basename}', parent=var.parent.clone(name='obj'))

# There are "my_obj" nodes still around...
assert any(
var == 'my_obj' or var.parent == 'my_obj' for var in FindVariables().visit(list(expr_map.values()))
)

# ...and application performs only a partial substitution
cloned = routine.clone()
cloned.body = SubstituteExpressions(expr_map).visit(cloned.body)
assert fgen(cloned.body.body[0]).lower() == 'obj%a = obj%my_add(obj%a(1:my_obj%m, 1:my_obj%n), 1.)'

# Apply recursive update
expr_map = recursive_expression_map_update(expr_map)

# No more "my_obj" nodes...
assert all(
var != 'my_obj' and var.parent != 'my_obj' for var in FindVariables().visit(list(expr_map.values()))
)

# ...and full substitution
assert fgen(routine.body.body[0]).lower() == 'my_obj%a = my_obj%my_add(my_obj%a(1:my_obj%m, 1:my_obj%n), 1.)'
routine.body = SubstituteExpressions(expr_map).visit(routine.body)
assert fgen(routine.body.body[0]) == 'obj%a = obj%my_add(obj%a(1:obj%m, 1:obj%n), 1.)'
# @pytest.mark.parametrize('frontend', available_frontends())
# def test_transform_utilities_recursive_expression_map_update(frontend, tmp_path):
# fcode = """
# module some_mod
# implicit none

# type some_type
# integer :: m, n
# real, allocatable :: a(:, :)
# contains
# procedure, pass :: my_add
# end type some_type
# contains
# function my_add(self, data, val)
# class(some_type), intent(inout) :: self
# real, intent(in) :: data(:,:)
# real, value :: val
# real :: my_add(:,:)
# my_add(:,:) = self%a(:,:) + data(:,:) + val
# end function my_add

# subroutine do(my_obj)
# type(some_type), intent(inout) :: my_obj
# my_obj%a = my_obj%my_add(MY_OBJ%a(1:my_obj%m, 1:MY_OBJ%n), 1.)
# end subroutine do
# end module some_mod
# """.strip()

# module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
# routine = module['do']

# expr_map = {}
# expr_map[routine.variable_map['my_obj']] = routine.variable_map['my_obj'].clone(name='obj')
# for var in FindVariables().visit(routine.body):
# if var.parent == 'my_obj':
# expr_map[var] = var.clone(name=f'obj%{var.basename}', parent=var.parent.clone(name='obj'))

# # There are "my_obj" nodes still around...
# assert any(
# var == 'my_obj' or var.parent == 'my_obj' for var in FindVariables().visit(list(expr_map.values()))
# )

# # ...and application performs only a partial substitution
# cloned = routine.clone()
# cloned.body = SubstituteExpressions(expr_map, recurse=False).visit(cloned.body)
# assert fgen(cloned.body.body[0]).lower() == 'obj%a = obj%my_add(obj%a(1:my_obj%m, 1:my_obj%n), 1.)'

# # ...and full substitution
# assert fgen(routine.body.body[0]).lower() == 'my_obj%a = my_obj%my_add(my_obj%a(1:my_obj%m, 1:my_obj%n), 1.)'
# routine.body = SubstituteExpressions(expr_map, recurse=True).visit(routine.body)

# assert fgen(routine.body.body[0]) == 'obj%a = obj%my_add(obj%a(1:obj%m, 1:obj%n), 1.)'

@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'Argument mismatch for "min"')]))
def test_transform_utilites_replace_intrinsics(frontend):
Expand Down
12 changes: 4 additions & 8 deletions loki/transformations/transform_derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
from loki.tools import as_tuple, flatten, CaseInsensitiveDict
from loki.types import BasicType, DerivedType, ProcedureType

from loki.transformations.utilities import recursive_expression_map_update


__all__ = ['DerivedTypeArgumentsTransformation', 'TypeboundProcedureCallTransformation']

Expand Down Expand Up @@ -351,9 +349,8 @@ def assumed_dim_or_none(shape):
routine.variables = [v for var in routine.variables for v in arguments_map.get(var, [var])]

# Substitue derived type member use in the spec and body
vmap = recursive_expression_map_update(vmap)
routine.spec = SubstituteExpressions(vmap).visit(routine.spec)
routine.body = SubstituteExpressions(vmap).visit(routine.body)
routine.spec = SubstituteExpressions(vmap, recursive=True).visit(routine.spec)
routine.body = SubstituteExpressions(vmap, recursive=True).visit(routine.body)

# Update procedure bindings by specifying NOPASS attribute
for arg in arguments_map:
Expand Down Expand Up @@ -527,7 +524,7 @@ def _update_call(call):
var, type=cls._get_expanded_kernel_var_type(orig_arg, var), scope=routine, dimensions=None
)
expansion_map[var] = expanded_var
expansion_mapper = SubstituteExpressionsMapper(recursive_expression_map_update(expansion_map))
expansion_mapper = SubstituteExpressionsMapper(expansion_map, recursive=True)
arguments = tuple(expansion_mapper(arg) for arg in arguments)
kwarguments = tuple((k, expansion_mapper(v)) for k, v in kwarguments)
return arguments, kwarguments
Expand Down Expand Up @@ -713,8 +710,7 @@ def visit_Expression(self, o, **kwargs):
if not expr_map:
return o

expr_map = recursive_expression_map_update(expr_map)
return SubstituteExpressionsMapper(expr_map)(o)
return SubstituteExpressionsMapper(expr_map, recursive=True)(o)


class TypeboundProcedureCallTransformation(Transformation):
Expand Down
Loading

0 comments on commit 8298e0c

Please sign in to comment.