Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unroll negative loop bounds and retain pragmas inside unrolled loop body #443

Merged
merged 2 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions loki/transformations/tests/test_transform_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend):

!Loop A
!$loki loop-unroll
do a=1, 10, 2
do a=-2, 7, 2
s = s + a + 1
end do

Expand All @@ -1727,7 +1727,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend):
# Test the reference solution
s = np.zeros(1)
function(s=s)
assert s == sum(x + 1 for x in range(1, 11, 2))
assert s == sum(x + 1 for x in range(-2, 8, 2))

# Apply transformation
assert len(FindNodes(Loop).visit(routine.body)) == 1
Expand All @@ -1740,7 +1740,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend):
# Test transformation
s = np.zeros(1)
unrolled_function(s=s)
assert s == sum(x + 1 for x in range(1, 11, 2))
assert s == sum(x + 1 for x in range(-2, 8, 2))

clean_test(filepath)
clean_test(unrolled_filepath)
Expand Down Expand Up @@ -1876,6 +1876,9 @@ def test_transform_loop_unroll_nested_restricted_depth(tmp_path, frontend):
unrolled_function(s=s)
assert s == sum(a + b + 1 for (a, b) in itertools.product(range(1, 11), range(1, 6)))

# check unroll pragma has been removed
assert not FindNodes(ir.Pragma).visit(routine.body)

clean_test(filepath)
clean_test(unrolled_filepath)

Expand Down Expand Up @@ -2105,3 +2108,44 @@ def test_transform_loop_transformation(frontend, loop_interchange, loop_fusion,

assert len(loops) == num_loops
assert len(pragmas) == num_pragmas


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_before_fuse(frontend):
    """
    Check that loop unrolling keeps non-unroll pragmas on the replicated
    loop bodies, so that a subsequent loop fusion can still pick them up.

    The two outer loops (``k`` and ``j``, both ``1..3``) are marked for
    unrolling; the innermost ``i`` loop is marked for fusion only.
    """
    fcode = """
subroutine test_loop_unroll_before_fuse(n, map, a, b)
integer, intent(in) :: n
integer, intent(in) :: map(3,3)
real, intent(inout) :: a(n)
real, intent(in) :: b(:)

integer :: i,j,k

!$loki loop-unroll
do k=1,3
!$loki loop-unroll
do j=1,3
!$loki loop-fusion
do i=1,n
a(i) = a(i) + b(map(j,k))
enddo
enddo
enddo

end subroutine test_loop_unroll_before_fuse
"""

    routine = Subroutine.from_source(fcode, frontend=frontend)

    def nodes_of(node_type):
        # Build a fresh visitor for every query against the current body
        return FindNodes(node_type).visit(routine.body)

    # Untransformed routine: the k/j/i loop nest
    assert len(nodes_of(ir.Loop)) == 3

    do_loop_unroll(routine)

    # Unrolling k and j (3 x 3) leaves nine copies of the i loop ...
    remaining_loops = nodes_of(ir.Loop)
    assert len(remaining_loops) == 9
    assert all(loop.variable == 'i' for loop in remaining_loops)

    # ... each still carrying its fusion pragma, with no unroll pragmas left
    remaining_pragmas = nodes_of(ir.Pragma)
    assert len(remaining_pragmas) == 9
    assert all(pragma.content == 'loop-fusion' for pragma in remaining_pragmas)

    # The retained pragmas allow all nine loops to be fused into one
    do_loop_fusion(routine)
    assert len(nodes_of(ir.Loop)) == 1
19 changes: 12 additions & 7 deletions loki/transformations/transform_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from loki.expression import (
symbols as sym, simplify, is_constant, symbolic_op, parse_expr,
IntLiteral, FloatLiteral
IntLiteral, get_pyrange, LoopRange
)
from loki.ir import (
Loop, Conditional, Comment, Pragma, FindNodes, Transformer,
Expand Down Expand Up @@ -658,12 +658,10 @@ def visit_Loop(self, o, depth=None):
depth = depth - 1 if depth is not None else None

# Only unroll if we have all literal bounds and step
if isinstance(start, (IntLiteral, FloatLiteral)) and\
isinstance(stop, (IntLiteral, FloatLiteral)) and\
isinstance(step, (IntLiteral, FloatLiteral)):
if is_constant(start) and is_constant(stop) and is_constant(step):

# Float bounds are permitted by some language specs, so derive an integer iteration range from the constant bounds
unroll_range = range(int(start), int(stop) + 1, int(step))
unroll_range = get_pyrange(LoopRange((start, stop, step)))
if self.warn_iterations_length and len(unroll_range) > 32:
warning(f"Unrolling loop over 32 iterations ({len(unroll_range)}), this may take a long time & "
f"provide few performance benefits.")
Expand All @@ -681,11 +679,18 @@ def visit_Loop(self, o, depth=None):

return as_tuple(flatten(acc))

_pragma = tuple(
p for p in o.pragma if not is_loki_pragma(p, starts_with='loop-unroll')
) if o.pragma else None
_pragma_post = tuple(
p for p in o.pragma_post if not is_loki_pragma(p, starts_with='loop-unroll')
) if o.pragma_post else None

return Loop(
variable=o.variable,
body=self.visit(o.body, depth=depth),
bounds=o.bounds
)
bounds=o.bounds, pragma=_pragma, pragma_post=_pragma_post
)


def do_loop_unroll(routine, warn_iterations_length=True):
Expand Down
Loading