diff --git a/loki/transformations/tests/test_transform_loop.py b/loki/transformations/tests/test_transform_loop.py index d0cd9a6cd..d7b7b78b8 100644 --- a/loki/transformations/tests/test_transform_loop.py +++ b/loki/transformations/tests/test_transform_loop.py @@ -1714,7 +1714,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend): !Loop A !$loki loop-unroll - do a=1, 10, 2 + do a=-2, 7, 2 s = s + a + 1 end do @@ -1727,7 +1727,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend): # Test the reference solution s = np.zeros(1) function(s=s) - assert s == sum(x + 1 for x in range(1, 11, 2)) + assert s == sum(x + 1 for x in range(-2, 8, 2)) # Apply transformation assert len(FindNodes(Loop).visit(routine.body)) == 1 @@ -1740,7 +1740,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend): # Test transformation s = np.zeros(1) unrolled_function(s=s) - assert s == sum(x + 1 for x in range(1, 11, 2)) + assert s == sum(x + 1 for x in range(-2, 8, 2)) clean_test(filepath) clean_test(unrolled_filepath) @@ -1876,6 +1876,9 @@ def test_transform_loop_unroll_nested_restricted_depth(tmp_path, frontend): unrolled_function(s=s) assert s == sum(a + b + 1 for (a, b) in itertools.product(range(1, 11), range(1, 6))) + # check unroll pragma has been removed + assert not FindNodes(ir.Pragma).visit(routine.body) + clean_test(filepath) clean_test(unrolled_filepath) @@ -2105,3 +2108,44 @@ def test_transform_loop_transformation(frontend, loop_interchange, loop_fusion, assert len(loops) == num_loops assert len(pragmas) == num_pragmas + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_loop_unroll_before_fuse(frontend): + fcode = """ + subroutine test_loop_unroll_before_fuse(n, map, a, b) + integer, intent(in) :: n + integer, intent(in) :: map(3,3) + real, intent(inout) :: a(n) + real, intent(in) :: b(:) + + integer :: i,j,k + + !$loki loop-unroll + do k=1,3 + !$loki loop-unroll + do j=1,3 + !$loki loop-fusion + do i=1,n + a(i) = a(i) + b(map(j,k)) + enddo + enddo + enddo + + end subroutine test_loop_unroll_before_fuse +""" + + routine = Subroutine.from_source(fcode, frontend=frontend) + assert len(FindNodes(ir.Loop).visit(routine.body)) == 3 + + do_loop_unroll(routine) + loops = FindNodes(ir.Loop).visit(routine.body) + assert len(loops) == 9 + assert all(loop.variable == 'i' for loop in loops) + + pragmas = FindNodes(ir.Pragma).visit(routine.body) + assert len(pragmas) == 9 + assert all(p.content == 'loop-fusion' for p in pragmas) + + do_loop_fusion(routine) + assert len(FindNodes(ir.Loop).visit(routine.body)) == 1 diff --git a/loki/transformations/transform_loop.py b/loki/transformations/transform_loop.py index 6a8b83a73..d94b8dfe9 100644 --- a/loki/transformations/transform_loop.py +++ b/loki/transformations/transform_loop.py @@ -20,7 +20,7 @@ ) from loki.expression import ( symbols as sym, simplify, is_constant, symbolic_op, parse_expr, - IntLiteral, FloatLiteral + IntLiteral, get_pyrange, LoopRange ) from loki.ir import ( Loop, Conditional, Comment, Pragma, FindNodes, Transformer, @@ -658,12 +658,10 @@ def visit_Loop(self, o, depth=None): depth = depth - 1 if depth is not None else None # Only unroll if we have all literal bounds and step - if isinstance(start, (IntLiteral, FloatLiteral)) and\ - isinstance(stop, (IntLiteral, FloatLiteral)) and\ - isinstance(step, (IntLiteral, FloatLiteral)): + if is_constant(start) and is_constant(stop) and is_constant(step): # int() to truncate any floats - which are not invalid in all specs! - unroll_range = range(int(start), int(stop) + 1, int(step)) + unroll_range = get_pyrange(LoopRange((start, stop, step))) if self.warn_iterations_length and len(unroll_range) > 32: warning(f"Unrolling loop over 32 iterations ({len(unroll_range)}), this may take a long time & " f"provide few performance benefits.") @@ -681,11 +679,18 @@ def visit_Loop(self, o, depth=None): return as_tuple(flatten(acc)) + _pragma = tuple( + p for p in o.pragma if not is_loki_pragma(p, starts_with='loop-unroll') + ) if o.pragma else None + _pragma_post = tuple( + p for p in o.pragma_post if not is_loki_pragma(p, starts_with='loop-unroll') + ) if o.pragma_post else None + return Loop( variable=o.variable, body=self.visit(o.body, depth=depth), - bounds=o.bounds - ) + bounds=o.bounds, pragma=_pragma, pragma_post=_pragma_post + ) def do_loop_unroll(routine, warn_iterations_length=True):