From f4b9036f442bbaaa3cbb4a0a72553ae9fe934b34 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 29 Nov 2024 09:53:03 +0000 Subject: [PATCH 1/7] Sanitise: Add test for matching range indices in array expressions --- .../sanitise/tests/test_associates.py | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/loki/transformations/sanitise/tests/test_associates.py b/loki/transformations/sanitise/tests/test_associates.py index d1e3ff7be..f3d96a82f 100644 --- a/loki/transformations/sanitise/tests/test_associates.py +++ b/loki/transformations/sanitise/tests/test_associates.py @@ -148,6 +148,57 @@ def test_transform_associates_array_call(frontend): assert routine.variable_map['local_arr'].type.shape == ('some_obj%a%n',) +@pytest.mark.parametrize('frontend', available_frontends( + skip=[(OMNI, 'OMNI does not handle missing type definitions')] +)) +def test_transform_associates_array_slices(frontend): + """ + Test the resolution of associated array slices. + """ + fcode = """ +subroutine transform_associates_slices(arr2d) + use some_module, only: some_obj, another_routine + implicit none + real, intent(inout) :: arr2d(:,:) + integer :: i + integer, parameter :: idx_a = 2 + + associate (a => arr2d(:, 1), b=>arr2d(:, idx_a) ) + b(:) = 42.0 + do i=1, 5 + a(i) = b(i+2) + call another_routine(i, a(2:4), b) + end do + end associate +end subroutine transform_associates_slices +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 2 + calls = FindNodes(ir.CallStatement).visit(routine.body) + assert len(calls) == 1 + assert calls[0].arguments[1] == 'a(2:4)' + assert calls[0].arguments[2] == 'b' + + # Now apply the association resolver + do_resolve_associates(routine) + + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 2 + assert assigns[0].lhs == 'arr2d(:, idx_a)' + assert assigns[1].lhs == 'arr2d(i, 1)' + assert assigns[1].rhs == 'arr2d(i+2, idx_a)' + + calls = FindNodes(ir.CallStatement).visit(routine.body) + assert len(calls) == 1 + assert calls[0].arguments[1] == 'arr2d(2:4, 1)' + assert calls[0].arguments[2] == 'arr2d(:, idx_a)' + + @pytest.mark.parametrize('frontend', available_frontends( skip=[(OMNI, 'OMNI does not handle missing type definitions')] )) From 6618203454a7cc256707f4ecda634c472ef38177 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 29 Nov 2024 09:55:31 +0000 Subject: [PATCH 2/7] Sanitise: Add utility method to match indices to free range symbols --- loki/transformations/sanitise/associates.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/loki/transformations/sanitise/associates.py b/loki/transformations/sanitise/associates.py index bee89243d..a29a44799 100644 --- a/loki/transformations/sanitise/associates.py +++ b/loki/transformations/sanitise/associates.py @@ -13,8 +13,9 @@ """ from loki.batch import Transformation -from loki.expression import LokiIdentityMapper +from loki.expression import symbols as sym, LokiIdentityMapper from loki.ir import nodes as ir, Transformer, NestedTransformer +from loki.logging import warning from loki.scope import SymbolTable from loki.tools import dict_override @@ -111,6 +112,20 @@ def __init__(self, *args, start_depth=0, **kwargs): self.start_depth = start_depth super().__init__(*args, **kwargs) + @staticmethod + def _match_range_indices(expressions, indices): + """ Map :data:`indices` to free ranges in :data:`expressions` """ + assert isinstance(expressions, tuple) + assert isinstance(indices, tuple) + + free_symbols = tuple(e for e in expressions if isinstance(e, sym.RangeIndex)) + assert len(free_symbols) == len(indices) + if any(s.lower not in (None, 1) for s in free_symbols): + warning('WARNING: Bounds shifts through association is currently not supported') + symbol_map = dict(zip(free_symbols, indices)) + + return tuple(symbol_map.get(e, e) for e in expressions) + def map_scalar(self, expr, *args, **kwargs): # Skip unscoped expressions if not hasattr(expr, 'scope'): From 2f711429c3f2ff36f9bb4cfaa79129fa06ce7136 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Mon, 2 Dec 2024 08:27:37 +0000 Subject: [PATCH 3/7] Sanitise: Apply partial range resolution when resolving associates --- loki/transformations/sanitise/associates.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/loki/transformations/sanitise/associates.py b/loki/transformations/sanitise/associates.py index a29a44799..291340cd2 100644 --- a/loki/transformations/sanitise/associates.py +++ b/loki/transformations/sanitise/associates.py @@ -119,7 +119,6 @@ def _match_range_indices(expressions, indices): assert isinstance(indices, tuple) free_symbols = tuple(e for e in expressions if isinstance(e, sym.RangeIndex)) - assert len(free_symbols) == len(indices) if any(s.lower not in (None, 1) for s in free_symbols): warning('WARNING: Bounds shifts through association is currently not supported') symbol_map = dict(zip(free_symbols, indices)) @@ -158,8 +157,8 @@ def map_scalar(self, expr, *args, **kwargs): return expr.clone(scope=scope.parent) def map_array(self, expr, *args, **kwargs): - """ Special case for arrys: we need to preserve the dimensions """ - new = self.map_variable_symbol(expr, *args, **kwargs) + """ Partially resolve dimension indices and handle shape """ + new = self.map_scalar(expr, *args, **kwargs) # Recurse over the type's shape _type = expr.type @@ -168,7 +167,14 @@ def map_array(self, expr, *args, **kwargs): _type = expr.type.clone(shape=new_shape) # Recurse over array dimensions - new_dims = self.rec(expr.dimensions, *args, **kwargs) + if isinstance(new, sym.Array): + # Resolve unbound range symbols form existing indices + new_dims = self.rec(new.dimensions, *args, **kwargs) + new_dims = self._match_range_indices(new_dims, expr.dimensions) + else: + # Recurse over existing array dimensions + new_dims = self.rec(expr.dimensions, *args, **kwargs) + return new.clone(dimensions=new_dims, type=_type) map_variable_symbol = map_scalar From 344d0986fb8344cf08fbd30a41a2b471249fcba9 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Mon, 2 Dec 2024 13:06:21 +0000 Subject: [PATCH 4/7] Sanitise: Terminate after dims/type recursion when resolving assocs This is subtle, but to avoid false positives on the index-range matching, we need to terminate early if the symbol scope is not an association. However, to get this right, we still need to recurse over prior expression dimensions and type symbols. --- loki/transformations/sanitise/associates.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/loki/transformations/sanitise/associates.py b/loki/transformations/sanitise/associates.py index 291340cd2..cba02a2c0 100644 --- a/loki/transformations/sanitise/associates.py +++ b/loki/transformations/sanitise/associates.py @@ -158,7 +158,9 @@ def map_scalar(self, expr, *args, **kwargs): def map_array(self, expr, *args, **kwargs): """ Partially resolve dimension indices and handle shape """ - new = self.map_scalar(expr, *args, **kwargs) + + # Recurse over existing array dimensions + expr_dims = self.rec(expr.dimensions, *args, **kwargs) # Recurse over the type's shape _type = expr.type @@ -166,14 +168,19 @@ def map_array(self, expr, *args, **kwargs): new_shape = self.rec(expr.type.shape, *args, **kwargs) _type = expr.type.clone(shape=new_shape) + # Stop if scope is not an associate + if not isinstance(expr.scope, ir.Associate): + return expr.clone(dimensions=expr_dims, type=_type) + + new = self.map_scalar(expr, *args, **kwargs) + # Recurse over array dimensions if isinstance(new, sym.Array): # Resolve unbound range symbols form existing indices new_dims = self.rec(new.dimensions, *args, **kwargs) - new_dims = self._match_range_indices(new_dims, expr.dimensions) + new_dims = self._match_range_indices(new_dims, expr_dims) else: - # Recurse over existing array dimensions - new_dims = self.rec(expr.dimensions, *args, **kwargs) + new_dims = expr_dims return new.clone(dimensions=new_dims, type=_type) From c055d93c1e3f6b171322052f9ee11bf0dd364e1b Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 6 Dec 2024 04:16:02 +0000 Subject: [PATCH 5/7] Sanitise: Only apply partial range resolution if new symbols has dims Also slightly adjust the simply associate resolver test case to cover this basic use case. --- loki/transformations/sanitise/associates.py | 2 +- loki/transformations/sanitise/tests/test_associates.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/loki/transformations/sanitise/associates.py b/loki/transformations/sanitise/associates.py index cba02a2c0..5ccfc0ecb 100644 --- a/loki/transformations/sanitise/associates.py +++ b/loki/transformations/sanitise/associates.py @@ -175,7 +175,7 @@ def map_array(self, expr, *args, **kwargs): new = self.map_scalar(expr, *args, **kwargs) # Recurse over array dimensions - if isinstance(new, sym.Array): + if isinstance(new, sym.Array) and new.dimensions: # Resolve unbound range symbols form existing indices new_dims = self.rec(new.dimensions, *args, **kwargs) new_dims = self._match_range_indices(new_dims, expr_dims) diff --git a/loki/transformations/sanitise/tests/test_associates.py b/loki/transformations/sanitise/tests/test_associates.py index f3d96a82f..085c7016c 100644 --- a/loki/transformations/sanitise/tests/test_associates.py +++ b/loki/transformations/sanitise/tests/test_associates.py @@ -33,7 +33,7 @@ def test_transform_associates_simple(frontend): real :: local_var associate (a => some_obj%a) - local_var = a + local_var = a(:) end associate end subroutine transform_associates_simple """ @@ -42,7 +42,7 @@ def test_transform_associates_simple(frontend): assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 assign = FindNodes(ir.Assignment).visit(routine.body)[0] - assert assign.rhs == 'a' and 'some_obj' not in assign.rhs + assert assign.rhs == 'a(:)' and 'some_obj' not in assign.rhs assert assign.rhs.type.dtype == BasicType.DEFERRED # Now apply the association resolver @@ -51,7 +51,7 @@ def test_transform_associates_simple(frontend): assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 assign = FindNodes(ir.Assignment).visit(routine.body)[0] - assert assign.rhs == 'some_obj%a' + assert assign.rhs == 'some_obj%a(:)' assert assign.rhs.parent == 'some_obj' assert assign.rhs.type.dtype == BasicType.DEFERRED assert assign.rhs.scope == routine From bed1ae8faf1cff9f4b00fcd5af266782160e2446 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 11 Dec 2024 10:52:57 +0000 Subject: [PATCH 6/7] Sanitise: Use iterator, when matching free symbols in assoc resolve When using dict-mapping to match symbols, the range keys might be `:`, which alias and mean we'd miss susequent `:` matches. The test has been updated accordingly. --- loki/transformations/sanitise/associates.py | 12 ++++++++++-- .../sanitise/tests/test_associates.py | 19 +++++++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/loki/transformations/sanitise/associates.py b/loki/transformations/sanitise/associates.py index 5ccfc0ecb..9e50c9e1b 100644 --- a/loki/transformations/sanitise/associates.py +++ b/loki/transformations/sanitise/associates.py @@ -121,9 +121,17 @@ def _match_range_indices(expressions, indices): free_symbols = tuple(e for e in expressions if isinstance(e, sym.RangeIndex)) if any(s.lower not in (None, 1) for s in free_symbols): warning('WARNING: Bounds shifts through association is currently not supported') - symbol_map = dict(zip(free_symbols, indices)) - return tuple(symbol_map.get(e, e) for e in expressions) + if len(free_symbols) == len(indices): + # If the provided indices are enough to bind free symbols, + # we match them in sequence. + it = iter(indices) + return tuple( + next(it) if isinstance(e, sym.RangeIndex) else e + for e in expressions + ) + + return expressions def map_scalar(self, expr, *args, **kwargs): # Skip unscoped expressions diff --git a/loki/transformations/sanitise/tests/test_associates.py b/loki/transformations/sanitise/tests/test_associates.py index 085c7016c..33fb22cbc 100644 --- a/loki/transformations/sanitise/tests/test_associates.py +++ b/loki/transformations/sanitise/tests/test_associates.py @@ -156,18 +156,23 @@ def test_transform_associates_array_slices(frontend): Test the resolution of associated array slices. """ fcode = """ -subroutine transform_associates_slices(arr2d) +subroutine transform_associates_slices(arr2d, arr3d) use some_module, only: some_obj, another_routine implicit none - real, intent(inout) :: arr2d(:,:) - integer :: i + real, intent(inout) :: arr2d(:,:), arr3d(:,:,:) + integer :: i, j integer, parameter :: idx_a = 2 + integer, parameter :: idx_c = 3 - associate (a => arr2d(:, 1), b=>arr2d(:, idx_a) ) + associate (a => arr2d(:, 1), b=>arr2d(:, idx_a), & + & c => arr3d(:,:,idx_c) ) b(:) = 42.0 do i=1, 5 a(i) = b(i+2) call another_routine(i, a(2:4), b) + do j=1, 7 + c(i, j) = c(i, j) + b(j) + end do end do end associate end subroutine transform_associates_slices @@ -177,7 +182,7 @@ def test_transform_associates_array_slices(frontend): assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 assigns = FindNodes(ir.Assignment).visit(routine.body) - assert len(assigns) == 2 + assert len(assigns) == 3 calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 assert calls[0].arguments[1] == 'a(2:4)' @@ -188,10 +193,12 @@ def test_transform_associates_array_slices(frontend): assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 assigns = FindNodes(ir.Assignment).visit(routine.body) - assert len(assigns) == 2 + assert len(assigns) == 3 assert assigns[0].lhs == 'arr2d(:, idx_a)' assert assigns[1].lhs == 'arr2d(i, 1)' assert assigns[1].rhs == 'arr2d(i+2, idx_a)' + assert assigns[2].lhs == 'arr3d(i, j, idx_c)' + assert assigns[2].rhs == 'arr3d(i, j, idx_c) + arr2d(j, idx_a)' calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 From e350f50e08ec10c440e7f4474b1dd6f7b694c96a Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 11 Dec 2024 12:40:46 +0000 Subject: [PATCH 7/7] Sanitise: Add recursion + symbol-matching to array slice test --- loki/transformations/sanitise/tests/test_associates.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/loki/transformations/sanitise/tests/test_associates.py b/loki/transformations/sanitise/tests/test_associates.py index 33fb22cbc..f15ce9ae4 100644 --- a/loki/transformations/sanitise/tests/test_associates.py +++ b/loki/transformations/sanitise/tests/test_associates.py @@ -165,13 +165,14 @@ def test_transform_associates_array_slices(frontend): integer, parameter :: idx_c = 3 associate (a => arr2d(:, 1), b=>arr2d(:, idx_a), & - & c => arr3d(:,:,idx_c) ) + & c => arr3d(:,:,idx_c), idx => some_obj%idx) b(:) = 42.0 do i=1, 5 a(i) = b(i+2) call another_routine(i, a(2:4), b) do j=1, 7 c(i, j) = c(i, j) + b(j) + c(i, idx) = c(i, idx) + 42.0 end do end do end associate @@ -182,7 +183,7 @@ def test_transform_associates_array_slices(frontend): assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 assigns = FindNodes(ir.Assignment).visit(routine.body) - assert len(assigns) == 3 + assert len(assigns) == 4 calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 assert calls[0].arguments[1] == 'a(2:4)' @@ -193,12 +194,14 @@ def test_transform_associates_array_slices(frontend): assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 assigns = FindNodes(ir.Assignment).visit(routine.body) - assert len(assigns) == 3 + assert len(assigns) == 4 assert assigns[0].lhs == 'arr2d(:, idx_a)' assert assigns[1].lhs == 'arr2d(i, 1)' assert assigns[1].rhs == 'arr2d(i+2, idx_a)' assert assigns[2].lhs == 'arr3d(i, j, idx_c)' assert assigns[2].rhs == 'arr3d(i, j, idx_c) + arr2d(j, idx_a)' + assert assigns[3].lhs == 'arr3d(i, some_obj%idx, idx_c)' + assert assigns[3].rhs == 'arr3d(i, some_obj%idx, idx_c) + 42.0' calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1