diff --git a/loki/transformations/sanitise/associates.py b/loki/transformations/sanitise/associates.py index 9e50c9e1b..5335ff3c1 100644 --- a/loki/transformations/sanitise/associates.py +++ b/loki/transformations/sanitise/associates.py @@ -161,8 +161,11 @@ def map_scalar(self, expr, *args, **kwargs): expr = scope.inverse_map[expr.basename] return self.rec(expr, *args, **kwargs) - # Update the scope, as this one will be removed - return expr.clone(scope=scope.parent) + # Update the scope, as any inner associates will be removed. + # For this we count backwards the nested scopes, the tail of + # which will the (innermost) associates. + new_scope = scope.parents[::-1][depth-self.start_depth-1] + return expr.clone(scope=new_scope) def map_array(self, expr, *args, **kwargs): """ Partially resolve dimension indices and handle shape """ diff --git a/loki/transformations/sanitise/tests/test_associates.py b/loki/transformations/sanitise/tests/test_associates.py index f15ce9ae4..bf24da0c5 100644 --- a/loki/transformations/sanitise/tests/test_associates.py +++ b/loki/transformations/sanitise/tests/test_associates.py @@ -503,7 +503,8 @@ def test_associates_transformation(frontend, merge, resolve): @pytest.mark.parametrize('frontend', available_frontends( skip=[(OMNI, 'OMNI does not handle missing type definitions')] )) -def test_resolve_associates_stmt_func(frontend): +@pytest.mark.parametrize('depth', [0, 1, 2]) +def test_resolve_associates_stmt_func(frontend, depth): """ Test scope management for stmt funcs, either as :any:`ProcedureSymbol` or :any:`DeferredTypeSymbol`. @@ -518,32 +519,42 @@ def test_resolve_associates_stmt_func(frontend): real(kind=8) :: not_an_array not_an_array ( x, y ) = x * y +associate(d=>b) +associate(c=>a) associate(RTT=>YDCST%RTT) a = not_an_array(RTT, 1.0) + a b = some_stmt_func(RTT, 1.0) + b end associate +end associate +end associate end subroutine test_associates_stmt_func """ routine = Subroutine.from_source(fcode, frontend=frontend) - associate = FindNodes(ir.Associate).visit(routine.body)[0] + associates = FindNodes(ir.Associate).visit(routine.body) + assert len(associates) == 3 assigns = FindNodes(ir.Assignment).visit(routine.body) assert len(assigns) == 2 assert isinstance(assigns[0].rhs.children[0], sym.InlineCall) - assert assigns[0].rhs.children[0].function.scope == associate + assert assigns[0].rhs.children[0].function.scope == associates[2] assert isinstance(assigns[1].rhs.children[0], sym.InlineCall) - assert assigns[1].rhs.children[0].function.scope == associate + assert assigns[1].rhs.children[0].function.scope == associates[2] - do_resolve_associates(routine) + do_resolve_associates(routine, start_depth=depth) + + associates = FindNodes(ir.Associate).visit(routine.body) + assert len(associates) == depth assigns = FindNodes(ir.Assignment).visit(routine.body) + # Determine the outer routine or last associate left + outer_scope = routine if depth == 0 else associates[depth-1] assert len(assigns) == 2 assert assigns[0].rhs == 'not_an_array(YDCST%RTT, 1.0) + a' assert assigns[1].rhs == 'some_stmt_func(YDCST%RTT, 1.0) + b' assert isinstance(assigns[0].rhs.children[0], sym.InlineCall) - assert assigns[0].rhs.children[0].function.scope == routine + assert assigns[0].rhs.children[0].function.scope == outer_scope assert isinstance(assigns[1].rhs.children[0], sym.InlineCall) - assert assigns[1].rhs.children[0].function.scope == routine + assert assigns[1].rhs.children[0].function.scope == outer_scope # Trigger a full clone, which would fail if scopes are missing routine.clone()