diff --git a/loki/transformations/parallel/__init__.py b/loki/transformations/parallel/__init__.py index 825921022..0071a9814 100644 --- a/loki/transformations/parallel/__init__.py +++ b/loki/transformations/parallel/__init__.py @@ -11,4 +11,5 @@ """ from loki.transformations.parallel.block_loop import * # noqa +from loki.transformations.parallel.field_api import * # noqa from loki.transformations.parallel.openmp_region import * # noqa diff --git a/loki/transformations/parallel/field_api.py b/loki/transformations/parallel/field_api.py new file mode 100644 index 000000000..2648276a2 --- /dev/null +++ b/loki/transformations/parallel/field_api.py @@ -0,0 +1,154 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Transformation utilities to manage and inject FIELD-API boilerplate code. +""" + +from loki.expression import symbols as sym +from loki.ir import ( + nodes as ir, FindNodes, FindVariables, Transformer +) +from loki.logging import warning +from loki.tools import as_tuple + + +__all__ = [ + 'remove_field_api_view_updates', 'add_field_api_view_updates' +] + + +def remove_field_api_view_updates(routine, field_group_types, dim_object=None): + """ + Remove FIELD API boilerplate calls for view updates of derived types. + + This utility is intended to remove the IFS-specific group type + objects that provide block-scope view pointers to deep kernel + trees. It will remove all calls to ``UPDATE_VIEW`` on derive-type + objects with the respective types. + + Parameters + ---------- + routine : :any:`Subroutine` + The routine from which to remove FIELD API update calls + field_group_types : tuple of str + List of names of the derived types of "field group" objects to remove + dim_object : str, optional + Optional name of the "dimension" object; if provided it will remove the + call to ``%UPDATE(...)`` accordingly. + """ + field_group_types = as_tuple(field_group_types) + + class RemoveFieldAPITransformer(Transformer): + + def visit_CallStatement(self, call, **kwargs): # pylint: disable=unused-argument + + if '%update_view' in str(call.name).lower(): + if not call.name.parent: + warning(f'[Loki::ControlFlow] Removing {call.name} call without parent!') + if not str(call.name.parent.type.dtype) in field_group_types: + warning(f'[Loki::ControlFlow] Removing {call.name} call, but not in field group types!') + + return None + + if dim_object and f'{dim_object}%update'.lower() in str(call.name).lower(): + return None + + return call + + def visit_Assignment(self, assign, **kwargs): # pylint: disable=unused-argument + if assign.lhs.type.dtype in field_group_types: + warning(f'[Loki::ControlFlow] Found LHS field group assign: {assign}') + return assign + + def visit_Loop(self, loop, **kwargs): + loop = self.visit_Node(loop, **kwargs) + return loop if loop.body else None + + def visit_Conditional(self, cond, **kwargs): + cond = super().visit_Node(cond, **kwargs) + return cond if cond.body else None + + routine.body = RemoveFieldAPITransformer().visit(routine.body) + + +def add_field_api_view_updates(routine, dimension, field_group_types, dim_object=None): + """ + Adds FIELD API boilerplate calls for view updates. + + The provided :any:`Dimension` object describes the local loop variables to + pass to the respective update calls. In particular, ``dimension.indices[1]`` + is used to denote the block loop index that is passed to ``UPDATE_VIEW()`` + calls on field group object. The list of type names ``field_group_types`` + is used to identify for which objcets the view update calls get added. + + Parameters + ---------- + routine : :any:`Subroutine` + The routine from which to remove FIELD API update calls + dimension : :any:`Dimension` + The dimension object describing the block loop variables. + field_group_types : tuple of str + List of names of the derived types of "field group" objects to remove + dim_object : str, optional + Optional name of the "dimension" object; if provided it will remove the + call to ``%UPDATE(...)`` accordingly. + """ + + def _create_dim_update(scope, dim_object): + index = scope.parse_expr(dimension.index) + upper = scope.parse_expr(dimension.upper[1]) + bindex = scope.parse_expr(dimension.indices[1]) + idims = scope.get_symbol(dim_object) + csym = sym.ProcedureSymbol(name='UPDATE', parent=idims, scope=idims.scope) + return ir.CallStatement(name=csym, arguments=(bindex, upper, index), kwarguments=()) + + def _create_view_updates(section, scope): + bindex = scope.parse_expr(dimension.indices[1]) + + fgroup_vars = sorted(tuple( + v for v in FindVariables(unique=True).visit(section) + if str(v.type.dtype) in field_group_types + ), key=str) + calls = () + for fgvar in fgroup_vars: + fgsym = scope.get_symbol(fgvar.name) + csym = sym.ProcedureSymbol(name='UPDATE_VIEW', parent=fgsym, scope=fgsym.scope) + calls += (ir.CallStatement(name=csym, arguments=(bindex,), kwarguments=()),) + + return calls + + class InsertFieldAPIViewsTransformer(Transformer): + """ Injects FIELD-API view updates into block loops """ + + def visit_Loop(self, loop, **kwargs): # pylint: disable=unused-argument + if not loop.variable == 'JKGLO': + return loop + + scope = kwargs.get('scope') + + # Find the loop-setup assignments + _loop_symbols = dimension.indices + _loop_symbols += as_tuple(dimension.lower) + as_tuple(dimension.upper) + loop_setup = tuple( + a for a in FindNodes(ir.Assignment).visit(loop.body) + if a.lhs in _loop_symbols + ) + idx = max(loop.body.index(a) for a in loop_setup) + 1 + + # Prepend FIELD API boilerplate + preamble = ( + ir.Comment(''), ir.Comment('! Set up thread-local view pointers') + ) + if dim_object: + preamble += (_create_dim_update(scope, dim_object=dim_object),) + preamble += _create_view_updates(loop.body, scope) + + loop._update(body=loop.body[:idx] + preamble + loop.body[idx:]) + return loop + + routine.body = InsertFieldAPIViewsTransformer().visit(routine.body, scope=routine) diff --git a/loki/transformations/parallel/tests/test_field_api.py b/loki/transformations/parallel/tests/test_field_api.py new file mode 100644 index 000000000..1dbc76d10 --- /dev/null +++ b/loki/transformations/parallel/tests/test_field_api.py @@ -0,0 +1,121 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from loki import Subroutine, Module, Dimension +from loki.frontend import available_frontends, OMNI +from loki.ir import nodes as ir, FindNodes + +from loki.transformations.parallel import ( + remove_field_api_view_updates, add_field_api_view_updates +) + + +@pytest.mark.parametrize('frontend', available_frontends( + skip=[(OMNI, 'OMNI needs full type definitions for derived types')] +)) +def test_field_api_remove_view_updates(frontend): + """ + A simple test for :any:`remove_field_api_view_updates` + """ + + fcode = """ +subroutine test_remove_block_loop(ngptot, nproma, nflux, dims, state, aux_fields, fluxes) + implicit none + integer(kind=4), intent(in) :: ngptot, nproma, nflux + type(dimension_type), intent(inout) :: dims + type(state_type), intent(inout) :: state + type(aux_type), intent(inout) :: aux_fields + type(flux_type), intent(inout) :: fluxes(nflux) + + integer :: JKGLO, IBL, ICEND, JK, JL, JF + + DO jkglo=1, ngptot, nproma + icend = min(nproma, ngptot - JKGLO + 1) + ibl = (jkglo - 1) / nproma + 1 + + CALL DIMS%UPDATE(IBL, ICEND, JKGLO) + CALL STATE%UPDATE_VIEW(IBL) + CALL AUX_FIELDS%UPDATE_VIEW(block_index=IBL) + DO jf=1, nflux + CALL FLUXES(JF)%UPDATE_VIEW(IBL) + END DO + + CALL MY_KERNEL(STATE%U, STATE%V, AUX_FIELDS%STUFF, FLUXES(1)%FOO, FLUXES(2)%BAR) + END DO +end subroutine test_remove_block_loop +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 5 + assert len(FindNodes(ir.Loop).visit(routine.body)) == 2 + + field_group_types = ['state_type', 'aux_type', 'flux_type'] + remove_field_api_view_updates( + routine, field_group_types=field_group_types, dim_object='DIMS' + ) + + calls = FindNodes(ir.CallStatement).visit(routine.body) + assert len(calls) == 1 + assert calls[0].name == 'MY_KERNEL' + + loops = FindNodes(ir.Loop).visit(routine.body) + assert len(loops) == 1 + assert loops[0].variable == 'jkglo' + + +@pytest.mark.parametrize('frontend', available_frontends( + skip=[(OMNI, 'OMNI needs full type definitions for derived types')] +)) +def test_field_api_add_view_updates(frontend): + """ + A simple test for :any:`add_field_api_view_updates`. + """ + + fcode = """ +subroutine test_remove_block_loop(ngptot, nproma, nflux, dims, state, aux_fields, fluxes) + implicit none + integer(kind=4), intent(in) :: ngptot, nproma, nflux + type(dimension_type), intent(inout) :: dims + type(state_type), intent(inout) :: state + type(aux_type), intent(inout) :: aux_fields + type(flux_type), intent(inout) :: fluxes + + integer :: JKGLO, IBL, ICEND, JK, JL, JF + + DO jkglo=1, ngptot, nproma + icend = min(nproma, ngptot - jkglo + 1) + ibl = (jkglo - 1) / nproma + 1 + + CALL MY_KERNEL(STATE%U, STATE%V, AUX_FIELDS%STUFF, FLUXES%FOO, FLUXES%BAR) + END DO +end subroutine test_remove_block_loop +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 + assert len(FindNodes(ir.Loop).visit(routine.body)) == 1 + + block = Dimension( + index=('jkglo', 'ibl'), step='NPROMA', + lower=('1', 'ICST'), upper=('NGPTOT', 'ICEND') + ) + field_group_types = ['state_type', 'aux_type', 'flux_type'] + add_field_api_view_updates( + routine, dimension=block, field_group_types=field_group_types, + dim_object='DIMS' + ) + + calls = FindNodes(ir.CallStatement).visit(routine.body) + assert len(calls) == 5 + assert calls[0].name == 'DIMS%UPDATE' and calls[0].arguments == ('IBL', 'ICEND', 'JKGLO') + assert calls[1].name == 'AUX_FIELDS%UPDATE_VIEW' and calls[1].arguments == ('IBL',) + assert calls[2].name == 'FLUXES%UPDATE_VIEW' and calls[2].arguments == ('IBL',) + assert calls[3].name == 'STATE%UPDATE_VIEW' and calls[3].arguments == ('IBL',) + + assert len(FindNodes(ir.Loop).visit(routine.body)) == 1