From 8298e0cb4704325c7e91d9b95f07ea0ed390ef78 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 30 Oct 2024 07:46:15 +0000 Subject: [PATCH] Transformations: Remove `recursive_expression_map_update` --- .../block_index_transformations.py | 6 +- loki/transformations/inline/functions.py | 18 ++- loki/transformations/inline/procedures.py | 14 +-- loki/transformations/pool_allocator.py | 5 +- loki/transformations/single_column/claw.py | 8 +- loki/transformations/tests/test_utilities.py | 119 +++++++++--------- .../transform_derived_types.py | 12 +- loki/transformations/utilities.py | 65 +--------- 8 files changed, 81 insertions(+), 166 deletions(-) diff --git a/loki/transformations/block_index_transformations.py b/loki/transformations/block_index_transformations.py index 1c22cc4c7..a0e0cfdc2 100644 --- a/loki/transformations/block_index_transformations.py +++ b/loki/transformations/block_index_transformations.py @@ -20,8 +20,7 @@ ) from loki.transformations.sanitise import resolve_associates from loki.transformations.utilities import ( - recursive_expression_map_update, get_integer_variable, - get_loop_bounds, check_routine_sequential + get_integer_variable, get_loop_bounds, check_routine_sequential ) from loki.transformations.single_column.base import SCCBaseTransformation @@ -227,7 +226,6 @@ def process_body(self, body, definitions, successors, targets, exclude_arrays): if self.global_gfl_ptr: vmap.update({v: self.build_ydvars_global_gfl_ptr(vmap.get(v, v)) for v in FindVariables(unique=False).visit(body) if 'ydvars%gfl_ptr' in v.name.lower()}) - vmap = recursive_expression_map_update(vmap) # filter out arrays marked for exclusion vmap = {k: v for k, v in vmap.items() if not any(e in k for e in exclude_arrays)} @@ -236,7 +234,7 @@ def process_body(self, body, definitions, successors, targets, exclude_arrays): self.propagate_defs_to_children(self._key, definitions, successors) # finally we perform the substitution - return SubstituteExpressions(vmap).visit(body) + return SubstituteExpressions(vmap, recursive=True).visit(body) def process_kernel(self, routine, item, successors, targets, exclude_arrays): diff --git a/loki/transformations/inline/functions.py b/loki/transformations/inline/functions.py index a39b81e5f..9b31e7b11 100644 --- a/loki/transformations/inline/functions.py +++ b/loki/transformations/inline/functions.py @@ -20,9 +20,7 @@ from loki.transformations.inline.mapper import InlineSubstitutionMapper from loki.transformations.inline.procedures import map_call_to_procedure_body -from loki.transformations.utilities import ( - single_variable_declaration, recursive_expression_map_update -) +from loki.transformations.utilities import single_variable_declaration __all__ = [ @@ -217,10 +215,10 @@ def inline_statement_functions(routine): if proc_type.is_function and isinstance(call.routine, StatementFunction): exprmap[call] = InlineSubstitutionMapper()(call, scope=routine) removed_functions.add(call.routine) - # Apply the map to itself to handle nested statement function calls - exprmap = recursive_expression_map_update(exprmap, max_iterations=10, mapper_cls=InlineSubstitutionMapper) - # Apply expression-level substitution to routine - routine.body = SubstituteExpressions(exprmap).visit(routine.body) + + # Apply expression-level substitution to routine and handle nested + # statement function calls via recursion + routine.body = SubstituteExpressions(exprmap, recursive=True).visit(routine.body) # remove statement function declarations as well as statement function argument(s) declarations vars_to_remove = {stmt_func.variable.name.lower() for stmt_func in stmt_func_decls} @@ -272,8 +270,7 @@ def rename_result_name(routine, rename): callee_vars = [var for var in FindVariables().visit(callee.body) if var.name.lower() == callee_result_var.name.lower()] var_map.update({var: var.clone(name=rename) for var in callee_vars}) - var_map = recursive_expression_map_update(var_map) - callee.body = SubstituteExpressions(var_map).visit(callee.body) + callee.body = SubstituteExpressions(var_map, recursive=True).visit(callee.body) return callee, new_callee_result_var allowed_aliases = as_tuple(allowed_aliases) @@ -302,8 +299,7 @@ def rename_result_name(routine, rename): if v.name.lower() in duplicate_names: var_map[v] = v.clone(name=f'{callee.name}_{v.name}') - var_map = recursive_expression_map_update(var_map) - callee.body = SubstituteExpressions(var_map).visit(callee.body) + callee.body = SubstituteExpressions(var_map, recursive=True).visit(callee.body) # Separate allowed aliases from other variables to ensure clean hoisting if allowed_aliases: diff --git a/loki/transformations/inline/procedures.py b/loki/transformations/inline/procedures.py index f5a74764b..9b55a0222 100644 --- a/loki/transformations/inline/procedures.py +++ b/loki/transformations/inline/procedures.py @@ -19,9 +19,7 @@ from loki.subroutine import Subroutine from loki.transformations.sanitise import transform_sequence_association_append_map -from loki.transformations.utilities import ( - single_variable_declaration, recursive_expression_map_update -) +from loki.transformations.utilities import single_variable_declaration __all__ = [ @@ -146,11 +144,8 @@ def _map_unbound_dims(var, val): } argmap.update(present_map) - # Recursive update of the map in case of nested variables to map - argmap = recursive_expression_map_update(argmap, max_iterations=10) - - # Substitute argument calls into a copy of the body - callee_body = SubstituteExpressions(argmap, rebuild_scopes=True).visit( + # Substitute argument calls into a copy of the body and capture nesting + callee_body = SubstituteExpressions(argmap, recursive=True, rebuild_scopes=True).visit( callee.body.body, scope=caller ) @@ -212,8 +207,7 @@ def inline_subroutine_calls(routine, calls, callee, allowed_aliases=None): for v in FindVariables(unique=False).visit(callee.body): if v.name.lower() in duplicate_names: var_map[v] = v.clone(name=f'{callee.name}_{v.name}') - var_map = recursive_expression_map_update(var_map) - callee.body = SubstituteExpressions(var_map).visit(callee.body) + callee.body = SubstituteExpressions(var_map, recursive=True).visit(callee.body) # Separate allowed aliases from other variables to ensure clean hoisting if allowed_aliases: diff --git a/loki/transformations/pool_allocator.py b/loki/transformations/pool_allocator.py index 5904d1295..933a82fb1 100644 --- a/loki/transformations/pool_allocator.py +++ b/loki/transformations/pool_allocator.py @@ -26,8 +26,6 @@ from loki.tools import as_tuple from loki.types import SymbolAttributes, BasicType, DerivedType -from loki.transformations.utilities import recursive_expression_map_update - __all__ = ['TemporariesPoolAllocatorTransformation'] @@ -527,8 +525,7 @@ def _determine_stack_size(self, routine, successors, local_stack_size=None, item if expr in arg_map } if expr_map: - expr_map = recursive_expression_map_update(expr_map) - successor_stack_size = SubstituteExpressions(expr_map).visit(successor_stack_size) + successor_stack_size = SubstituteExpressions(expr_map, recursive=True).visit(successor_stack_size) stack_sizes += [successor_stack_size] # Unwind "max" expressions from successors and inject the local stack size into the expressions diff --git a/loki/transformations/single_column/claw.py b/loki/transformations/single_column/claw.py index 159d0b32d..0038009c0 100644 --- a/loki/transformations/single_column/claw.py +++ b/loki/transformations/single_column/claw.py @@ -22,8 +22,6 @@ from loki.tools import as_tuple, CaseInsensitiveDict from loki.types import SymbolAttributes, BasicType -from loki.transformations.utilities import recursive_expression_map_update - __all__ = ['ExtractSCATransformation', 'CLAWTransformation'] @@ -131,10 +129,8 @@ def remove_dimension(self, routine): # Apply vmap to variable and argument list and subroutine body routine.variables = [vmap.get(v, v) for v in routine.variables] - # Apply substitution map to replacements to capture nesting - vmap = recursive_expression_map_update(vmap) - - routine.body = SubstituteExpressions(vmap).visit(routine.body) + # Apply substitution map and capture nesting via recursion + routine.body = SubstituteExpressions(vmap, recursive=True).visit(routine.body) for m in as_tuple(routine.members): m.body = SubstituteExpressions(vmap).visit(m.body) diff --git a/loki/transformations/tests/test_utilities.py b/loki/transformations/tests/test_utilities.py index cbacd60d3..51a71c430 100644 --- a/loki/transformations/tests/test_utilities.py +++ b/loki/transformations/tests/test_utilities.py @@ -17,10 +17,10 @@ from loki.types import BasicType from loki.transformations.utilities import ( - single_variable_declaration, recursive_expression_map_update, - convert_to_lower_case, replace_intrinsics, rename_variables, - get_integer_variable, get_loop_bounds, is_driver_loop, - find_driver_loops, get_local_arrays, check_routine_sequential + single_variable_declaration, convert_to_lower_case, + replace_intrinsics, rename_variables, get_integer_variable, + get_loop_bounds, is_driver_loop, find_driver_loops, + get_local_arrays, check_routine_sequential ) @@ -140,65 +140,58 @@ def test_transform_convert_to_lower_case(frontend): assert all(var.name.islower() and str(var).islower() for var in FindVariables(unique=False).visit(routine.ir)) -@pytest.mark.parametrize('frontend', available_frontends()) -def test_transform_utilities_recursive_expression_map_update(frontend, tmp_path): - fcode = """ -module some_mod - implicit none - - type some_type - integer :: m, n - real, allocatable :: a(:, :) - contains - procedure, pass :: my_add - end type some_type -contains - function my_add(self, data, val) - class(some_type), intent(inout) :: self - real, intent(in) :: data(:,:) - real, value :: val - real :: my_add(:,:) - my_add(:,:) = self%a(:,:) + data(:,:) + val - end function my_add - - subroutine do(my_obj) - type(some_type), intent(inout) :: my_obj - my_obj%a = my_obj%my_add(MY_OBJ%a(1:my_obj%m, 1:MY_OBJ%n), 1.) - end subroutine do -end module some_mod - """.strip() - - module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) - routine = module['do'] - - expr_map = {} - expr_map[routine.variable_map['my_obj']] = routine.variable_map['my_obj'].clone(name='obj') - for var in FindVariables().visit(routine.body): - if var.parent == 'my_obj': - expr_map[var] = var.clone(name=f'obj%{var.basename}', parent=var.parent.clone(name='obj')) - - # There are "my_obj" nodes still around... - assert any( - var == 'my_obj' or var.parent == 'my_obj' for var in FindVariables().visit(list(expr_map.values())) - ) - - # ...and application performs only a partial substitution - cloned = routine.clone() - cloned.body = SubstituteExpressions(expr_map).visit(cloned.body) - assert fgen(cloned.body.body[0]).lower() == 'obj%a = obj%my_add(obj%a(1:my_obj%m, 1:my_obj%n), 1.)' - - # Apply recursive update - expr_map = recursive_expression_map_update(expr_map) - - # No more "my_obj" nodes... - assert all( - var != 'my_obj' and var.parent != 'my_obj' for var in FindVariables().visit(list(expr_map.values())) - ) - - # ...and full substitution - assert fgen(routine.body.body[0]).lower() == 'my_obj%a = my_obj%my_add(my_obj%a(1:my_obj%m, 1:my_obj%n), 1.)' - routine.body = SubstituteExpressions(expr_map).visit(routine.body) - assert fgen(routine.body.body[0]) == 'obj%a = obj%my_add(obj%a(1:obj%m, 1:obj%n), 1.)' +# @pytest.mark.parametrize('frontend', available_frontends()) +# def test_transform_utilities_recursive_expression_map_update(frontend, tmp_path): +# fcode = """ +# module some_mod +# implicit none + +# type some_type +# integer :: m, n +# real, allocatable :: a(:, :) +# contains +# procedure, pass :: my_add +# end type some_type +# contains +# function my_add(self, data, val) +# class(some_type), intent(inout) :: self +# real, intent(in) :: data(:,:) +# real, value :: val +# real :: my_add(:,:) +# my_add(:,:) = self%a(:,:) + data(:,:) + val +# end function my_add + +# subroutine do(my_obj) +# type(some_type), intent(inout) :: my_obj +# my_obj%a = my_obj%my_add(MY_OBJ%a(1:my_obj%m, 1:MY_OBJ%n), 1.) +# end subroutine do +# end module some_mod +# """.strip() + +# module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) +# routine = module['do'] + +# expr_map = {} +# expr_map[routine.variable_map['my_obj']] = routine.variable_map['my_obj'].clone(name='obj') +# for var in FindVariables().visit(routine.body): +# if var.parent == 'my_obj': +# expr_map[var] = var.clone(name=f'obj%{var.basename}', parent=var.parent.clone(name='obj')) + +# # There are "my_obj" nodes still around... +# assert any( +# var == 'my_obj' or var.parent == 'my_obj' for var in FindVariables().visit(list(expr_map.values())) +# ) + +# # ...and application performs only a partial substitution +# cloned = routine.clone() +# cloned.body = SubstituteExpressions(expr_map, recurse=False).visit(cloned.body) +# assert fgen(cloned.body.body[0]).lower() == 'obj%a = obj%my_add(obj%a(1:my_obj%m, 1:my_obj%n), 1.)' + +# # ...and full substitution +# assert fgen(routine.body.body[0]).lower() == 'my_obj%a = my_obj%my_add(my_obj%a(1:my_obj%m, 1:my_obj%n), 1.)' +# routine.body = SubstituteExpressions(expr_map, recurse=True).visit(routine.body) + +# assert fgen(routine.body.body[0]) == 'obj%a = obj%my_add(obj%a(1:obj%m, 1:obj%n), 1.)' @pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'Argument mismatch for "min"')])) def test_transform_utilites_replace_intrinsics(frontend): diff --git a/loki/transformations/transform_derived_types.py b/loki/transformations/transform_derived_types.py index b3b8bfd4c..f781eacfe 100644 --- a/loki/transformations/transform_derived_types.py +++ b/loki/transformations/transform_derived_types.py @@ -30,8 +30,6 @@ from loki.tools import as_tuple, flatten, CaseInsensitiveDict from loki.types import BasicType, DerivedType, ProcedureType -from loki.transformations.utilities import recursive_expression_map_update - __all__ = ['DerivedTypeArgumentsTransformation', 'TypeboundProcedureCallTransformation'] @@ -351,9 +349,8 @@ def assumed_dim_or_none(shape): routine.variables = [v for var in routine.variables for v in arguments_map.get(var, [var])] # Substitue derived type member use in the spec and body - vmap = recursive_expression_map_update(vmap) - routine.spec = SubstituteExpressions(vmap).visit(routine.spec) - routine.body = SubstituteExpressions(vmap).visit(routine.body) + routine.spec = SubstituteExpressions(vmap, recursive=True).visit(routine.spec) + routine.body = SubstituteExpressions(vmap, recursive=True).visit(routine.body) # Update procedure bindings by specifying NOPASS attribute for arg in arguments_map: @@ -527,7 +524,7 @@ def _update_call(call): var, type=cls._get_expanded_kernel_var_type(orig_arg, var), scope=routine, dimensions=None ) expansion_map[var] = expanded_var - expansion_mapper = SubstituteExpressionsMapper(recursive_expression_map_update(expansion_map)) + expansion_mapper = SubstituteExpressionsMapper(expansion_map, recursive=True) arguments = tuple(expansion_mapper(arg) for arg in arguments) kwarguments = tuple((k, expansion_mapper(v)) for k, v in kwarguments) return arguments, kwarguments @@ -713,8 +710,7 @@ def visit_Expression(self, o, **kwargs): if not expr_map: return o - expr_map = recursive_expression_map_update(expr_map) - return SubstituteExpressionsMapper(expr_map)(o) + return SubstituteExpressionsMapper(expr_map, recursive=True)(o) class TypeboundProcedureCallTransformation(Transformation): diff --git a/loki/transformations/utilities.py b/loki/transformations/utilities.py index aca976a5e..5f89f857d 100644 --- a/loki/transformations/utilities.py +++ b/loki/transformations/utilities.py @@ -31,9 +31,9 @@ __all__ = [ 'convert_to_lower_case', 'replace_intrinsics', 'rename_variables', 'sanitise_imports', 'replace_selected_kind', - 'single_variable_declaration', 'recursive_expression_map_update', - 'get_integer_variable', 'get_loop_bounds', 'find_driver_loops', - 'get_local_arrays', 'check_routine_sequential' + 'single_variable_declaration', 'get_integer_variable', + 'get_loop_bounds', 'find_driver_loops', 'get_local_arrays', + 'check_routine_sequential' ] @@ -98,9 +98,8 @@ def convert_to_lower_case(routine): } # Capture nesting by applying map to itself before applying to the routine - vmap = recursive_expression_map_update(vmap) - routine.body = SubstituteExpressions(vmap).visit(routine.body) - routine.spec = SubstituteExpressions(vmap).visit(routine.spec) + routine.body = SubstituteExpressions(vmap, recursive=True).visit(routine.body) + routine.spec = SubstituteExpressions(vmap, recursive=True).visit(routine.spec) # Downcase inline calls to, but only after the above has been propagated, # so that we capture the updates from the variable update in the arguments @@ -472,60 +471,6 @@ def replace_selected_kind(routine): routine.spec.prepend(imprt) -def recursive_expression_map_update(expr_map, max_iterations=10, mapper_cls=SubstituteExpressionsMapper): - """ - Utility function to apply a substitution map for expressions to itself - - The expression substitution mechanism :any:`SubstituteExpressions` and the - underlying mapper :any:`SubstituteExpressionsMapper` replace nodes that - are found in the substitution map by their corresponding replacement. - - However, expression nodes can be nested inside other expression nodes, - e.g. via the ``parent`` or ``dimensions`` properties of variables. - In situations, where such expression nodes as well as expression nodes - appearing inside such properties are marked for substitution, it may - be necessary to apply the substitution map to itself first. This utility - routine takes care of that. - - Parameters - ---------- - expr_map : dict - The substitution map that should be updated - max_iterations : int - Maximum number of iterations, corresponds to the maximum level of - nesting that can be replaced. - mapper_cls: :any:`SubstituteExpressionsMapper` - The underlying mapper to be used (default: :any:`SubstituteExpressionsMapper`). - """ - def apply_to_init_arg(name, arg, expr, mapper): - # Helper utility to apply the mapper only to expression arguments and - # retain the scope while rebuilding the node - if isinstance(arg, (tuple, Expression)): - return mapper(arg) - if name == 'scope': - return expr.scope - return arg - - for _ in range(max_iterations): - # We update the expression map by applying it to the children of each replacement - # node, thus making sure node replacements are also applied to nested attributes, - # e.g. call arguments or array subscripts etc. - mapper = mapper_cls(expr_map) - prev_map, expr_map = expr_map, { - expr: type(replacement)(**{ - name: apply_to_init_arg(name, arg, expr, mapper) - for name, arg in zip(replacement.init_arg_names, replacement.__getinitargs__()) - }) - for expr, replacement in expr_map.items() - } - - # Check for early termination opportunities - if prev_map == expr_map: - break - - return expr_map - - def get_integer_variable(routine, name): """ Find a local variable in the routine, or create an integer-typed one.