-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Parallel: Add utility to add/remove FIELD-API view updates
- Loading branch information
Showing
3 changed files
with
276 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# (C) Copyright 2018- ECMWF. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
""" | ||
Transformation utilities to manage and inject FIELD-API boilerplate code. | ||
""" | ||
|
||
from loki.expression import symbols as sym | ||
from loki.ir import ( | ||
nodes as ir, FindNodes, FindVariables, Transformer | ||
) | ||
from loki.logging import warning | ||
from loki.tools import as_tuple | ||
|
||
|
||
__all__ = [ | ||
'remove_field_api_view_updates', 'add_field_api_view_updates' | ||
] | ||
|
||
|
||
def remove_field_api_view_updates(routine, field_group_types, dim_object=None): | ||
""" | ||
Remove FIELD API boilerplate calls for view updates of derived types. | ||
This utility is intended to remove the IFS-specific group type | ||
objects that provide block-scope view pointers to deep kernel | ||
trees. It will remove all calls to ``UPDATE_VIEW`` on derive-type | ||
objects with the respective types. | ||
Parameters | ||
---------- | ||
routine : :any:`Subroutine` | ||
The routine from which to remove FIELD API update calls | ||
field_group_types : tuple of str | ||
List of names of the derived types of "field group" objects to remove | ||
dim_object : str, optional | ||
Optional name of the "dimension" object; if provided it will remove the | ||
call to ``<dim>%UPDATE(...)`` accordingly. | ||
""" | ||
field_group_types = as_tuple(field_group_types) | ||
|
||
class RemoveFieldAPITransformer(Transformer): | ||
|
||
def visit_CallStatement(self, call, **kwargs): # pylint: disable=unused-argument | ||
|
||
if '%update_view' in str(call.name).lower(): | ||
if not call.name.parent: | ||
warning(f'[Loki::ControlFlow] Removing {call.name} call without parent!') | ||
if not str(call.name.parent.type.dtype) in field_group_types: | ||
warning(f'[Loki::ControlFlow] Removing {call.name} call, but not in field group types!') | ||
|
||
return None | ||
|
||
if dim_object and f'{dim_object}%update'.lower() in str(call.name).lower(): | ||
return None | ||
|
||
return call | ||
|
||
def visit_Assignment(self, assign, **kwargs): # pylint: disable=unused-argument | ||
if assign.lhs.type.dtype in field_group_types: | ||
warning(f'[Loki::ControlFlow] Found LHS field group assign: {assign}') | ||
return assign | ||
|
||
def visit_Loop(self, loop, **kwargs): | ||
loop = self.visit_Node(loop, **kwargs) | ||
return loop if loop.body else None | ||
|
||
def visit_Conditional(self, cond, **kwargs): | ||
cond = super().visit_Node(cond, **kwargs) | ||
return cond if cond.body else None | ||
|
||
routine.body = RemoveFieldAPITransformer().visit(routine.body) | ||
|
||
|
||
def add_field_api_view_updates(routine, dimension, field_group_types, dim_object=None): | ||
""" | ||
Adds FIELD API boilerplate calls for view updates. | ||
The provided :any:`Dimension` object describes the local loop variables to | ||
pass to the respective update calls. In particular, ``dimension.indices[1]`` | ||
is used to denote the block loop index that is passed to ``UPDATE_VIEW()`` | ||
calls on field group object. The list of type names ``field_group_types`` | ||
is used to identify for which objcets the view update calls get added. | ||
Parameters | ||
---------- | ||
routine : :any:`Subroutine` | ||
The routine from which to remove FIELD API update calls | ||
dimension : :any:`Dimension` | ||
The dimension object describing the block loop variables. | ||
field_group_types : tuple of str | ||
List of names of the derived types of "field group" objects to remove | ||
dim_object : str, optional | ||
Optional name of the "dimension" object; if provided it will remove the | ||
call to ``<dim>%UPDATE(...)`` accordingly. | ||
""" | ||
|
||
def _create_dim_update(scope, dim_object): | ||
index = scope.parse_expr(dimension.index) | ||
upper = scope.parse_expr(dimension.upper[1]) | ||
bindex = scope.parse_expr(dimension.indices[1]) | ||
idims = scope.get_symbol(dim_object) | ||
csym = sym.ProcedureSymbol(name='UPDATE', parent=idims, scope=idims.scope) | ||
return ir.CallStatement(name=csym, arguments=(bindex, upper, index), kwarguments=()) | ||
|
||
def _create_view_updates(section, scope): | ||
bindex = scope.parse_expr(dimension.indices[1]) | ||
|
||
fgroup_vars = sorted(tuple( | ||
v for v in FindVariables(unique=True).visit(section) | ||
if str(v.type.dtype) in field_group_types | ||
), key=str) | ||
calls = () | ||
for fgvar in fgroup_vars: | ||
fgsym = scope.get_symbol(fgvar.name) | ||
csym = sym.ProcedureSymbol(name='UPDATE_VIEW', parent=fgsym, scope=fgsym.scope) | ||
calls += (ir.CallStatement(name=csym, arguments=(bindex,), kwarguments=()),) | ||
|
||
return calls | ||
|
||
class InsertFieldAPIViewsTransformer(Transformer): | ||
""" Injects FIELD-API view updates into block loops """ | ||
|
||
def visit_Loop(self, loop, **kwargs): # pylint: disable=unused-argument | ||
if not loop.variable == 'JKGLO': | ||
return loop | ||
|
||
scope = kwargs.get('scope') | ||
|
||
# Find the loop-setup assignments | ||
_loop_symbols = dimension.indices | ||
_loop_symbols += as_tuple(dimension.lower) + as_tuple(dimension.upper) | ||
loop_setup = tuple( | ||
a for a in FindNodes(ir.Assignment).visit(loop.body) | ||
if a.lhs in _loop_symbols | ||
) | ||
idx = max(loop.body.index(a) for a in loop_setup) + 1 | ||
|
||
# Prepend FIELD API boilerplate | ||
preamble = ( | ||
ir.Comment(''), ir.Comment('! Set up thread-local view pointers') | ||
) | ||
if dim_object: | ||
preamble += (_create_dim_update(scope, dim_object=dim_object),) | ||
preamble += _create_view_updates(loop.body, scope) | ||
|
||
loop._update(body=loop.body[:idx] + preamble + loop.body[idx:]) | ||
return loop | ||
|
||
routine.body = InsertFieldAPIViewsTransformer().visit(routine.body, scope=routine) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# (C) Copyright 2018- ECMWF. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
import pytest | ||
|
||
from loki import Subroutine, Module, Dimension | ||
from loki.frontend import available_frontends, OMNI | ||
from loki.ir import nodes as ir, FindNodes | ||
|
||
from loki.transformations.parallel import ( | ||
remove_field_api_view_updates, add_field_api_view_updates | ||
) | ||
|
||
|
||
@pytest.mark.parametrize('frontend', available_frontends( | ||
skip=[(OMNI, 'OMNI needs full type definitions for derived types')] | ||
)) | ||
def test_field_api_remove_view_updates(frontend): | ||
""" | ||
A simple test for :any:`remove_field_api_view_updates` | ||
""" | ||
|
||
fcode = """ | ||
subroutine test_remove_block_loop(ngptot, nproma, nflux, dims, state, aux_fields, fluxes) | ||
implicit none | ||
integer(kind=4), intent(in) :: ngptot, nproma, nflux | ||
type(dimension_type), intent(inout) :: dims | ||
type(state_type), intent(inout) :: state | ||
type(aux_type), intent(inout) :: aux_fields | ||
type(flux_type), intent(inout) :: fluxes(nflux) | ||
integer :: JKGLO, IBL, ICEND, JK, JL, JF | ||
DO jkglo=1, ngptot, nproma | ||
icend = min(nproma, ngptot - JKGLO + 1) | ||
ibl = (jkglo - 1) / nproma + 1 | ||
CALL DIMS%UPDATE(IBL, ICEND, JKGLO) | ||
CALL STATE%UPDATE_VIEW(IBL) | ||
CALL AUX_FIELDS%UPDATE_VIEW(block_index=IBL) | ||
DO jf=1, nflux | ||
CALL FLUXES(JF)%UPDATE_VIEW(IBL) | ||
END DO | ||
CALL MY_KERNEL(STATE%U, STATE%V, AUX_FIELDS%STUFF, FLUXES(1)%FOO, FLUXES(2)%BAR) | ||
END DO | ||
end subroutine test_remove_block_loop | ||
""" | ||
routine = Subroutine.from_source(fcode, frontend=frontend) | ||
|
||
assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 5 | ||
assert len(FindNodes(ir.Loop).visit(routine.body)) == 2 | ||
|
||
field_group_types = ['state_type', 'aux_type', 'flux_type'] | ||
remove_field_api_view_updates( | ||
routine, field_group_types=field_group_types, dim_object='DIMS' | ||
) | ||
|
||
calls = FindNodes(ir.CallStatement).visit(routine.body) | ||
assert len(calls) == 1 | ||
assert calls[0].name == 'MY_KERNEL' | ||
|
||
loops = FindNodes(ir.Loop).visit(routine.body) | ||
assert len(loops) == 1 | ||
assert loops[0].variable == 'jkglo' | ||
|
||
|
||
@pytest.mark.parametrize('frontend', available_frontends( | ||
skip=[(OMNI, 'OMNI needs full type definitions for derived types')] | ||
)) | ||
def test_field_api_add_view_updates(frontend): | ||
""" | ||
A simple test for :any:`add_field_api_view_updates`. | ||
""" | ||
|
||
fcode = """ | ||
subroutine test_remove_block_loop(ngptot, nproma, nflux, dims, state, aux_fields, fluxes) | ||
implicit none | ||
integer(kind=4), intent(in) :: ngptot, nproma, nflux | ||
type(dimension_type), intent(inout) :: dims | ||
type(state_type), intent(inout) :: state | ||
type(aux_type), intent(inout) :: aux_fields | ||
type(flux_type), intent(inout) :: fluxes | ||
integer :: JKGLO, IBL, ICEND, JK, JL, JF | ||
DO jkglo=1, ngptot, nproma | ||
icend = min(nproma, ngptot - jkglo + 1) | ||
ibl = (jkglo - 1) / nproma + 1 | ||
CALL MY_KERNEL(STATE%U, STATE%V, AUX_FIELDS%STUFF, FLUXES%FOO, FLUXES%BAR) | ||
END DO | ||
end subroutine test_remove_block_loop | ||
""" | ||
routine = Subroutine.from_source(fcode, frontend=frontend) | ||
|
||
assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 | ||
assert len(FindNodes(ir.Loop).visit(routine.body)) == 1 | ||
|
||
block = Dimension( | ||
index=('jkglo', 'ibl'), step='NPROMA', | ||
lower=('1', 'ICST'), upper=('NGPTOT', 'ICEND') | ||
) | ||
field_group_types = ['state_type', 'aux_type', 'flux_type'] | ||
add_field_api_view_updates( | ||
routine, dimension=block, field_group_types=field_group_types, | ||
dim_object='DIMS' | ||
) | ||
|
||
calls = FindNodes(ir.CallStatement).visit(routine.body) | ||
assert len(calls) == 5 | ||
assert calls[0].name == 'DIMS%UPDATE' and calls[0].arguments == ('IBL', 'ICEND', 'JKGLO') | ||
assert calls[1].name == 'AUX_FIELDS%UPDATE_VIEW' and calls[1].arguments == ('IBL',) | ||
assert calls[2].name == 'FLUXES%UPDATE_VIEW' and calls[2].arguments == ('IBL',) | ||
assert calls[3].name == 'STATE%UPDATE_VIEW' and calls[3].arguments == ('IBL',) | ||
|
||
assert len(FindNodes(ir.Loop).visit(routine.body)) == 1 |