diff --git a/loki/transformations/tests/test_transform_loop.py b/loki/transformations/tests/test_transform_loop.py index 9f73e5039..0dc393c84 100644 --- a/loki/transformations/tests/test_transform_loop.py +++ b/loki/transformations/tests/test_transform_loop.py @@ -1876,6 +1876,9 @@ def test_transform_loop_unroll_nested_restricted_depth(tmp_path, frontend): unrolled_function(s=s) assert s == sum(a + b + 1 for (a, b) in itertools.product(range(1, 11), range(1, 6))) + # check unroll pragma has been removed + assert not FindNodes(ir.Pragma).visit(routine.body) + clean_test(filepath) clean_test(unrolled_filepath) @@ -2105,3 +2108,46 @@ def test_transform_loop_transformation(frontend, loop_interchange, loop_fusion, assert len(loops) == num_loops assert len(pragmas) == num_pragmas + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_loop_unroll_before_fuse(frontend): + fcode = """ + subroutine test_loop_unroll_before_fuse(n, map, a, b, c) + integer, intent(in) :: n + integer, intent(in) :: map(3) + real, intent(inout) :: a(n) + real, intent(in) :: b(:) + real, intent(in) :: c(:) + + integer :: i,j + + !$loki loop-unroll + do j=1,3 + !$loki loop-fusion + do i=1,n + a(i) = a(i) + b(map(j)) + enddo + enddo + + !$loki loop-unroll + do j=1,2 + !$loki loop-fusion + do i=1,n + a(i) = a(i) + c(map(j)) + enddo + enddo + + end subroutine test_loop_unroll_before_fuse +""" + + routine = Subroutine.from_source(fcode, frontend=frontend) + assert len(FindNodes(ir.Loop).visit(routine.body)) == 4 + + do_loop_unroll(routine) + loops = FindNodes(ir.Loop).visit(routine.body) + assert len(loops) == 5 + assert all(loop.variable == 'i' for loop in loops) + + do_loop_fusion(routine) + assert len(FindNodes(ir.Loop).visit(routine.body)) == 1 diff --git a/loki/transformations/transform_loop.py b/loki/transformations/transform_loop.py index 5700676e9..99e1c6f63 100644 --- a/loki/transformations/transform_loop.py +++ b/loki/transformations/transform_loop.py @@ -639,7 +639,7 @@ def __init__(self, warn_iterations_length=True): super().__init__() # depth is treated as an option of some depth or none, i.e. unroll all - def visit_Loop(self, o, depth=None): + def visit_Loop(self, o, depth=None, preserve_pragmas=False): """ Apply this :class:`Transformer` to an IR tree. @@ -675,14 +675,15 @@ def visit_Loop(self, o, depth=None): ()) if depth is None or depth >= 1: - acc = [self.visit(a, depth=depth) for a in acc] + acc = [self.visit(a, depth=depth, preserve_pragmas=True) for a in acc] return as_tuple(flatten(acc)) return Loop( variable=o.variable, - body=self.visit(o.body, depth=depth), - bounds=o.bounds + body=self.visit(o.body, depth=depth, preserve_pragmas=True), + bounds=o.bounds, pragma=o.pragma if preserve_pragmas else None, + pragma_post=o.pragma_post if preserve_pragmas else None )