Skip to content

Commit

Permalink
LoopUnroll: retain pragmas when unrolling loops
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed Nov 21, 2024
1 parent 27ee3b7 commit bac536f
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
46 changes: 46 additions & 0 deletions loki/transformations/tests/test_transform_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,6 +1876,9 @@ def test_transform_loop_unroll_nested_restricted_depth(tmp_path, frontend):
unrolled_function(s=s)
assert s == sum(a + b + 1 for (a, b) in itertools.product(range(1, 11), range(1, 6)))

# check unroll pragma has been removed
assert not FindNodes(ir.Pragma).visit(routine.body)

clean_test(filepath)
clean_test(unrolled_filepath)

Expand Down Expand Up @@ -2105,3 +2108,46 @@ def test_transform_loop_transformation(frontend, loop_interchange, loop_fusion,

assert len(loops) == num_loops
assert len(pragmas) == num_pragmas


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_before_fuse(frontend):
fcode = """
subroutine test_loop_unroll_before_fuse(n, map, a, b, c)
integer, intent(in) :: n
integer, intent(in) :: map(3)
real, intent(inout) :: a(n)
real, intent(in) :: b(:)
real, intent(in) :: c(:)
integer :: i,j
!$loki loop-unroll
do j=1,3
!$loki loop-fusion
do i=1,n
a(i) = a(i) + b(map(j))
enddo
enddo
!$loki loop-unroll
do j=1,2
!$loki loop-fusion
do i=1,n
a(i) = a(i) + c(map(j))
enddo
enddo
end subroutine test_loop_unroll_before_fuse
"""

routine = Subroutine.from_source(fcode, frontend=frontend)
assert len(FindNodes(ir.Loop).visit(routine.body)) == 4

do_loop_unroll(routine)
loops = FindNodes(ir.Loop).visit(routine.body)
assert len(loops) == 5
assert all(loop.variable == 'i' for loop in loops)

do_loop_fusion(routine)
assert len(FindNodes(ir.Loop).visit(routine.body)) == 1
9 changes: 5 additions & 4 deletions loki/transformations/transform_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,7 @@ def __init__(self, warn_iterations_length=True):
super().__init__()

# depth is treated as an option of some depth or none, i.e. unroll all
def visit_Loop(self, o, depth=None):
def visit_Loop(self, o, depth=None, preserve_pragmas=False):
"""
Apply this :class:`Transformer` to an IR tree.
Expand Down Expand Up @@ -675,14 +675,15 @@ def visit_Loop(self, o, depth=None):
())

if depth is None or depth >= 1:
acc = [self.visit(a, depth=depth) for a in acc]
acc = [self.visit(a, depth=depth, preserve_pragmas=True) for a in acc]

return as_tuple(flatten(acc))

return Loop(
variable=o.variable,
body=self.visit(o.body, depth=depth),
bounds=o.bounds
body=self.visit(o.body, depth=depth, preserve_pragmas=True),
bounds=o.bounds, pragma=o.pragma if preserve_pragmas else None,
pragma_post=o.pragma_post if preserve_pragmas else None
)


Expand Down

0 comments on commit bac536f

Please sign in to comment.