Skip to content

Commit

Permalink
Merge pull request #455 from ecmwf-ifs/naml-resolve-assoc-array-indices
Browse files Browse the repository at this point in the history
Sanitise: Resolve free range indices when resolving associates
  • Loading branch information
reuterbal authored Dec 13, 2024
2 parents a26cb44 + e350f50 commit 64b0bbf
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 7 deletions.
44 changes: 40 additions & 4 deletions loki/transformations/sanitise/associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
"""

from loki.batch import Transformation
from loki.expression import LokiIdentityMapper
from loki.expression import symbols as sym, LokiIdentityMapper
from loki.ir import nodes as ir, Transformer, NestedTransformer
from loki.logging import warning
from loki.scope import SymbolTable
from loki.tools import dict_override

Expand Down Expand Up @@ -111,6 +112,27 @@ def __init__(self, *args, start_depth=0, **kwargs):
self.start_depth = start_depth
super().__init__(*args, **kwargs)

@staticmethod
def _match_range_indices(expressions, indices):
""" Map :data:`indices` to free ranges in :data:`expressions` """
assert isinstance(expressions, tuple)
assert isinstance(indices, tuple)

free_symbols = tuple(e for e in expressions if isinstance(e, sym.RangeIndex))
if any(s.lower not in (None, 1) for s in free_symbols):
warning('WARNING: Bounds shifts through association is currently not supported')

if len(free_symbols) == len(indices):
# If the provided indices are enough to bind free symbols,
# we match them in sequence.
it = iter(indices)
return tuple(
next(it) if isinstance(e, sym.RangeIndex) else e
for e in expressions
)

return expressions

def map_scalar(self, expr, *args, **kwargs):
# Skip unscoped expressions
if not hasattr(expr, 'scope'):
Expand Down Expand Up @@ -143,17 +165,31 @@ def map_scalar(self, expr, *args, **kwargs):
return expr.clone(scope=scope.parent)

def map_array(self, expr, *args, **kwargs):
""" Special case for arrys: we need to preserve the dimensions """
new = self.map_variable_symbol(expr, *args, **kwargs)
""" Partially resolve dimension indices and handle shape """

# Recurse over existing array dimensions
expr_dims = self.rec(expr.dimensions, *args, **kwargs)

# Recurse over the type's shape
_type = expr.type
if expr.type.shape:
new_shape = self.rec(expr.type.shape, *args, **kwargs)
_type = expr.type.clone(shape=new_shape)

# Stop if scope is not an associate
if not isinstance(expr.scope, ir.Associate):
return expr.clone(dimensions=expr_dims, type=_type)

new = self.map_scalar(expr, *args, **kwargs)

# Recurse over array dimensions
new_dims = self.rec(expr.dimensions, *args, **kwargs)
if isinstance(new, sym.Array) and new.dimensions:
# Resolve unbound range symbols form existing indices
new_dims = self.rec(new.dimensions, *args, **kwargs)
new_dims = self._match_range_indices(new_dims, expr_dims)
else:
new_dims = expr_dims

return new.clone(dimensions=new_dims, type=_type)

map_variable_symbol = map_scalar
Expand Down
67 changes: 64 additions & 3 deletions loki/transformations/sanitise/tests/test_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_transform_associates_simple(frontend):
real :: local_var
associate (a => some_obj%a)
local_var = a
local_var = a(:)
end associate
end subroutine transform_associates_simple
"""
Expand All @@ -42,7 +42,7 @@ def test_transform_associates_simple(frontend):
assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1
assign = FindNodes(ir.Assignment).visit(routine.body)[0]
assert assign.rhs == 'a' and 'some_obj' not in assign.rhs
assert assign.rhs == 'a(:)' and 'some_obj' not in assign.rhs
assert assign.rhs.type.dtype == BasicType.DEFERRED

# Now apply the association resolver
Expand All @@ -51,7 +51,7 @@ def test_transform_associates_simple(frontend):
assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1
assign = FindNodes(ir.Assignment).visit(routine.body)[0]
assert assign.rhs == 'some_obj%a'
assert assign.rhs == 'some_obj%a(:)'
assert assign.rhs.parent == 'some_obj'
assert assign.rhs.type.dtype == BasicType.DEFERRED
assert assign.rhs.scope == routine
Expand Down Expand Up @@ -148,6 +148,67 @@ def test_transform_associates_array_call(frontend):
assert routine.variable_map['local_arr'].type.shape == ('some_obj%a%n',)


@pytest.mark.parametrize('frontend', available_frontends(
skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_array_slices(frontend):
"""
Test the resolution of associated array slices.
"""
fcode = """
subroutine transform_associates_slices(arr2d, arr3d)
use some_module, only: some_obj, another_routine
implicit none
real, intent(inout) :: arr2d(:,:), arr3d(:,:,:)
integer :: i, j
integer, parameter :: idx_a = 2
integer, parameter :: idx_c = 3
associate (a => arr2d(:, 1), b=>arr2d(:, idx_a), &
& c => arr3d(:,:,idx_c), idx => some_obj%idx)
b(:) = 42.0
do i=1, 5
a(i) = b(i+2)
call another_routine(i, a(2:4), b)
do j=1, 7
c(i, j) = c(i, j) + b(j)
c(i, idx) = c(i, idx) + 42.0
end do
end do
end associate
end subroutine transform_associates_slices
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 4
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments[1] == 'a(2:4)'
assert calls[0].arguments[2] == 'b'

# Now apply the association resolver
do_resolve_associates(routine)

assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 4
assert assigns[0].lhs == 'arr2d(:, idx_a)'
assert assigns[1].lhs == 'arr2d(i, 1)'
assert assigns[1].rhs == 'arr2d(i+2, idx_a)'
assert assigns[2].lhs == 'arr3d(i, j, idx_c)'
assert assigns[2].rhs == 'arr3d(i, j, idx_c) + arr2d(j, idx_a)'
assert assigns[3].lhs == 'arr3d(i, some_obj%idx, idx_c)'
assert assigns[3].rhs == 'arr3d(i, some_obj%idx, idx_c) + 42.0'

calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments[1] == 'arr2d(2:4, 1)'
assert calls[0].arguments[2] == 'arr2d(:, idx_a)'


@pytest.mark.parametrize('frontend', available_frontends(
skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
Expand Down

0 comments on commit 64b0bbf

Please sign in to comment.