From 9a3c705397f09c524dfd0f7de12978dd610f308a Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 10 Jan 2025 08:28:50 +0000 Subject: [PATCH] Sanitise: Fix rescoping of symbols for nested associates When nested associates are about to be removed, we need to account for that when updating the symbol-scope (a priori). This new logic does this by counting out the nested associates and picking the appropriate one, according to the septh of the symbol in the nest. I've extended the stat-func test, where this most often triggers. --- loki/transformations/sanitise/associates.py | 7 ++++-- .../sanitise/tests/test_associates.py | 25 +++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/loki/transformations/sanitise/associates.py b/loki/transformations/sanitise/associates.py index 9e50c9e1b..5335ff3c1 100644 --- a/loki/transformations/sanitise/associates.py +++ b/loki/transformations/sanitise/associates.py @@ -161,8 +161,11 @@ def map_scalar(self, expr, *args, **kwargs): expr = scope.inverse_map[expr.basename] return self.rec(expr, *args, **kwargs) - # Update the scope, as this one will be removed - return expr.clone(scope=scope.parent) + # Update the scope, as any inner associates will be removed. + # For this we count backwards the nested scopes, the tail of + # which will the (innermost) associates. + new_scope = scope.parents[::-1][depth-self.start_depth-1] + return expr.clone(scope=new_scope) def map_array(self, expr, *args, **kwargs): """ Partially resolve dimension indices and handle shape """ diff --git a/loki/transformations/sanitise/tests/test_associates.py b/loki/transformations/sanitise/tests/test_associates.py index f15ce9ae4..bf24da0c5 100644 --- a/loki/transformations/sanitise/tests/test_associates.py +++ b/loki/transformations/sanitise/tests/test_associates.py @@ -503,7 +503,8 @@ def test_associates_transformation(frontend, merge, resolve): @pytest.mark.parametrize('frontend', available_frontends( skip=[(OMNI, 'OMNI does not handle missing type definitions')] )) -def test_resolve_associates_stmt_func(frontend): +@pytest.mark.parametrize('depth', [0, 1, 2]) +def test_resolve_associates_stmt_func(frontend, depth): """ Test scope management for stmt funcs, either as :any:`ProcedureSymbol` or :any:`DeferredTypeSymbol`. @@ -518,32 +519,42 @@ def test_resolve_associates_stmt_func(frontend): real(kind=8) :: not_an_array not_an_array ( x, y ) = x * y +associate(d=>b) +associate(c=>a) associate(RTT=>YDCST%RTT) a = not_an_array(RTT, 1.0) + a b = some_stmt_func(RTT, 1.0) + b end associate +end associate +end associate end subroutine test_associates_stmt_func """ routine = Subroutine.from_source(fcode, frontend=frontend) - associate = FindNodes(ir.Associate).visit(routine.body)[0] + associates = FindNodes(ir.Associate).visit(routine.body) + assert len(associates) == 3 assigns = FindNodes(ir.Assignment).visit(routine.body) assert len(assigns) == 2 assert isinstance(assigns[0].rhs.children[0], sym.InlineCall) - assert assigns[0].rhs.children[0].function.scope == associate + assert assigns[0].rhs.children[0].function.scope == associates[2] assert isinstance(assigns[1].rhs.children[0], sym.InlineCall) - assert assigns[1].rhs.children[0].function.scope == associate + assert assigns[1].rhs.children[0].function.scope == associates[2] - do_resolve_associates(routine) + do_resolve_associates(routine, start_depth=depth) + + associates = FindNodes(ir.Associate).visit(routine.body) + assert len(associates) == depth assigns = FindNodes(ir.Assignment).visit(routine.body) + # Determine the outer routine or last associate left + outer_scope = routine if depth == 0 else associates[depth-1] assert len(assigns) == 2 assert assigns[0].rhs == 'not_an_array(YDCST%RTT, 1.0) + a' assert assigns[1].rhs == 'some_stmt_func(YDCST%RTT, 1.0) + b' assert isinstance(assigns[0].rhs.children[0], sym.InlineCall) - assert assigns[0].rhs.children[0].function.scope == routine + assert assigns[0].rhs.children[0].function.scope == outer_scope assert isinstance(assigns[1].rhs.children[0], sym.InlineCall) - assert assigns[1].rhs.children[0].function.scope == routine + assert assigns[1].rhs.children[0].function.scope == outer_scope # Trigger a full clone, which would fail if scopes are missing routine.clone()