Skip to content

Commit

Permalink
Merge pull request #470 from ecmwf-ifs/naml-fix-symbol-scopes-remove-…
Browse files Browse the repository at this point in the history
…assocs

Sanitise: Fix rescoping of symbols for nested associates
  • Loading branch information
reuterbal authored Jan 10, 2025
2 parents 7668989 + 9a3c705 commit ddd15e0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
7 changes: 5 additions & 2 deletions loki/transformations/sanitise/associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,11 @@ def map_scalar(self, expr, *args, **kwargs):
expr = scope.inverse_map[expr.basename]
return self.rec(expr, *args, **kwargs)

# Update the scope, as this one will be removed
return expr.clone(scope=scope.parent)
# Update the scope, as any inner associates will be removed.
# For this we count backwards the nested scopes, the tail of
# which will the (innermost) associates.
new_scope = scope.parents[::-1][depth-self.start_depth-1]
return expr.clone(scope=new_scope)

def map_array(self, expr, *args, **kwargs):
""" Partially resolve dimension indices and handle shape """
Expand Down
25 changes: 18 additions & 7 deletions loki/transformations/sanitise/tests/test_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,8 @@ def test_associates_transformation(frontend, merge, resolve):
@pytest.mark.parametrize('frontend', available_frontends(
skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_resolve_associates_stmt_func(frontend):
@pytest.mark.parametrize('depth', [0, 1, 2])
def test_resolve_associates_stmt_func(frontend, depth):
"""
Test scope management for stmt funcs, either as
:any:`ProcedureSymbol` or :any:`DeferredTypeSymbol`.
Expand All @@ -518,32 +519,42 @@ def test_resolve_associates_stmt_func(frontend):
real(kind=8) :: not_an_array
not_an_array ( x, y ) = x * y
associate(d=>b)
associate(c=>a)
associate(RTT=>YDCST%RTT)
a = not_an_array(RTT, 1.0) + a
b = some_stmt_func(RTT, 1.0) + b
end associate
end associate
end associate
end subroutine test_associates_stmt_func
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

associate = FindNodes(ir.Associate).visit(routine.body)[0]
associates = FindNodes(ir.Associate).visit(routine.body)
assert len(associates) == 3
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 2
assert isinstance(assigns[0].rhs.children[0], sym.InlineCall)
assert assigns[0].rhs.children[0].function.scope == associate
assert assigns[0].rhs.children[0].function.scope == associates[2]
assert isinstance(assigns[1].rhs.children[0], sym.InlineCall)
assert assigns[1].rhs.children[0].function.scope == associate
assert assigns[1].rhs.children[0].function.scope == associates[2]

do_resolve_associates(routine)
do_resolve_associates(routine, start_depth=depth)

associates = FindNodes(ir.Associate).visit(routine.body)
assert len(associates) == depth

assigns = FindNodes(ir.Assignment).visit(routine.body)
# Determine the outer routine or last associate left
outer_scope = routine if depth == 0 else associates[depth-1]
assert len(assigns) == 2
assert assigns[0].rhs == 'not_an_array(YDCST%RTT, 1.0) + a'
assert assigns[1].rhs == 'some_stmt_func(YDCST%RTT, 1.0) + b'
assert isinstance(assigns[0].rhs.children[0], sym.InlineCall)
assert assigns[0].rhs.children[0].function.scope == routine
assert assigns[0].rhs.children[0].function.scope == outer_scope
assert isinstance(assigns[1].rhs.children[0], sym.InlineCall)
assert assigns[1].rhs.children[0].function.scope == routine
assert assigns[1].rhs.children[0].function.scope == outer_scope

# Trigger a full clone, which would fail if scopes are missing
routine.clone()

0 comments on commit ddd15e0

Please sign in to comment.