-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Transformation utility to fix sequence association (#173)
* identified arrays passed as scalars * scalars sort of fixed * commit call * constructed new arguments * Might work now * Check that we have the called routine definition * some documentation * simplify TypedSymbol handling * some cleanup * Simplify (?) and add docstrings * fix style * more style * even more style * Ensure Loki versions of Sum and Product * fix another negative bug * add tests * going out in style * Add option for turning scalar fix on and off * moved option setting a bit * fix linter complaints * scalar_syntax -> sequence_association * fix -> resolve * inline and sequence association tests in single column transformation * cleanup * greatly simplify by using callers dimensions * cleanup * more cleanup
- Loading branch information
Showing
8 changed files
with
270 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# (C) Copyright 2018- ECMWF. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
from loki.expression import Array, RangeIndex | ||
from loki.ir import CallStatement | ||
from loki.visitors import FindNodes, Transformer | ||
from loki.tools import as_tuple | ||
from loki.types import BasicType | ||
|
||
|
||
__all__ = [ | ||
'transform_sequence_association' | ||
] | ||
|
||
def check_if_scalar_syntax(arg, dummy): | ||
""" | ||
Check if an array argument, arg, | ||
is passed to an array dummy argument, dummy, | ||
using scalar syntax. i.e. arg(1,1) -> d(m,n) | ||
Parameters | ||
---------- | ||
arg: variable | ||
dummy: variable | ||
""" | ||
if isinstance(arg, Array) and isinstance(dummy, Array): | ||
if arg.dimensions: | ||
if not any(isinstance(d, RangeIndex) for d in arg.dimensions): | ||
return True | ||
return False | ||
|
||
|
||
def transform_sequence_association(routine): | ||
""" | ||
Housekeeping routine to replace scalar syntax when passing arrays as arguments | ||
For example, a call like | ||
real :: a(m,n) | ||
call myroutine(a(i,j)) | ||
where myroutine looks like | ||
subroutine myroutine(a) | ||
real :: a(5) | ||
end subroutine myroutine | ||
should be changed to | ||
call myroutine(a(i:m,j) | ||
Parameters | ||
---------- | ||
routine : :any:`Subroutine` | ||
The subroutine where calls will be changed | ||
""" | ||
|
||
#List calls in routine, but make sure we have the called routine definition | ||
calls = (c for c in FindNodes(CallStatement).visit(routine.body) if not c.procedure_type is BasicType.DEFERRED) | ||
call_map = {} | ||
|
||
for call in calls: | ||
|
||
new_args = [] | ||
|
||
found_scalar = False | ||
for dummy, arg in call.arg_map.items(): | ||
if check_if_scalar_syntax(arg, dummy): | ||
found_scalar = True | ||
|
||
n_dims = len(dummy.shape) | ||
new_dims = [] | ||
for s, lower in zip(arg.shape[:n_dims], arg.dimensions[:n_dims]): | ||
|
||
if isinstance(s, RangeIndex): | ||
new_dims += [RangeIndex((lower, s.stop))] | ||
else: | ||
new_dims += [RangeIndex((lower, s))] | ||
|
||
if len(arg.dimensions) > n_dims: | ||
new_dims += arg.dimensions[len(dummy.shape):] | ||
new_args += [arg.clone(dimensions=as_tuple(new_dims)),] | ||
else: | ||
new_args += [arg,] | ||
|
||
if found_scalar: | ||
call_map[call] = call.clone(arguments = as_tuple(new_args)) | ||
|
||
if call_map: | ||
routine.body = Transformer(call_map).visit(routine.body) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# (C) Copyright 2018- ECMWF. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
import pytest | ||
|
||
from conftest import available_frontends | ||
|
||
from loki.backend import fgen | ||
from loki.transform import transform_sequence_association | ||
from loki.module import Module | ||
from loki.ir import CallStatement | ||
from loki.visitors import FindNodes | ||
|
||
@pytest.mark.parametrize('frontend', available_frontends()) | ||
def test_transform_scalar_notation(frontend): | ||
fcode = """ | ||
module mod_a | ||
implicit none | ||
type type_b | ||
integer :: c | ||
integer :: d | ||
end type type_b | ||
type type_a | ||
type(type_b) :: b | ||
end type type_a | ||
contains | ||
subroutine main() | ||
type(type_a) :: a | ||
integer :: k, m, n | ||
real :: array(10,10) | ||
call sub_x(array(1, 1), 1) | ||
call sub_x(array(2, 2), 2) | ||
call sub_x(array(m, 1), k) | ||
call sub_x(array(m-1, 1), k-1) | ||
call sub_x(array(a%b%c, 1), a%b%d) | ||
contains | ||
subroutine sub_x(array, k) | ||
integer, intent(in) :: k | ||
real, intent(in) :: array(k:n) | ||
end subroutine sub_x | ||
end subroutine main | ||
end module mod_a | ||
""".strip() | ||
|
||
module = Module.from_source(fcode, frontend=frontend) | ||
routine = module['main'] | ||
|
||
transform_sequence_association(routine) | ||
|
||
calls = FindNodes(CallStatement).visit(routine.body) | ||
|
||
assert fgen(calls[0]).lower() == 'call sub_x(array(1:10, 1), 1)' | ||
assert fgen(calls[1]).lower() == 'call sub_x(array(2:10, 2), 2)' | ||
assert fgen(calls[2]).lower() == 'call sub_x(array(m:10, 1), k)' | ||
assert fgen(calls[3]).lower() == 'call sub_x(array(m - 1:10, 1), k - 1)' | ||
assert fgen(calls[4]).lower() == 'call sub_x(array(a%b%c:10, 1), a%b%d)' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters