Skip to content

Commit

Permalink
Parallel: Add utility to add/remove FIELD-API view updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Nov 7, 2024
1 parent 997683b commit da8b9f3
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 0 deletions.
1 change: 1 addition & 0 deletions loki/transformations/parallel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
"""

from loki.transformations.parallel.block_loop import * # noqa
from loki.transformations.parallel.field_api import * # noqa
from loki.transformations.parallel.openmp_region import * # noqa
154 changes: 154 additions & 0 deletions loki/transformations/parallel/field_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Transformation utilities to manage and inject FIELD-API boilerplate code.
"""

from loki.expression import symbols as sym
from loki.ir import (
nodes as ir, FindNodes, FindVariables, Transformer
)
from loki.logging import warning
from loki.tools import as_tuple


__all__ = [
'remove_field_api_view_updates', 'add_field_api_view_updates'
]


def remove_field_api_view_updates(routine, field_group_types, dim_object=None):
"""
Remove FIELD API boilerplate calls for view updates of derived types.
This utility is intended to remove the IFS-specific group type
objects that provide block-scope view pointers to deep kernel
trees. It will remove all calls to ``UPDATE_VIEW`` on derive-type
objects with the respective types.
Parameters
----------
routine : :any:`Subroutine`
The routine from which to remove FIELD API update calls
field_group_types : tuple of str
List of names of the derived types of "field group" objects to remove
dim_object : str, optional
Optional name of the "dimension" object; if provided it will remove the
call to ``<dim>%UPDATE(...)`` accordingly.
"""
field_group_types = as_tuple(field_group_types)

class RemoveFieldAPITransformer(Transformer):

def visit_CallStatement(self, call, **kwargs): # pylint: disable=unused-argument

if '%update_view' in str(call.name).lower():
if not call.name.parent:
warning(f'[Loki::ControlFlow] Removing {call.name} call without parent!')
if not str(call.name.parent.type.dtype) in field_group_types:
warning(f'[Loki::ControlFlow] Removing {call.name} call, but not in field group types!')

return None

if dim_object and f'{dim_object}%update'.lower() in str(call.name).lower():
return None

return call

def visit_Assignment(self, assign, **kwargs): # pylint: disable=unused-argument
if assign.lhs.type.dtype in field_group_types:
warning(f'[Loki::ControlFlow] Found LHS field group assign: {assign}')
return assign

def visit_Loop(self, loop, **kwargs):
loop = self.visit_Node(loop, **kwargs)
return loop if loop.body else None

def visit_Conditional(self, cond, **kwargs):
cond = super().visit_Node(cond, **kwargs)
return cond if cond.body else None

routine.body = RemoveFieldAPITransformer().visit(routine.body)


def add_field_api_view_updates(routine, dimension, field_group_types, dim_object=None):
"""
Adds FIELD API boilerplate calls for view updates.
The provided :any:`Dimension` object describes the local loop variables to
pass to the respective update calls. In particular, ``dimension.indices[1]``
is used to denote the block loop index that is passed to ``UPDATE_VIEW()``
calls on field group object. The list of type names ``field_group_types``
is used to identify for which objcets the view update calls get added.
Parameters
----------
routine : :any:`Subroutine`
The routine from which to remove FIELD API update calls
dimension : :any:`Dimension`
The dimension object describing the block loop variables.
field_group_types : tuple of str
List of names of the derived types of "field group" objects to remove
dim_object : str, optional
Optional name of the "dimension" object; if provided it will remove the
call to ``<dim>%UPDATE(...)`` accordingly.
"""

def _create_dim_update(scope, dim_object):
index = scope.parse_expr(dimension.index)
upper = scope.parse_expr(dimension.upper[1])
bindex = scope.parse_expr(dimension.indices[1])
idims = scope.get_symbol(dim_object)
csym = sym.ProcedureSymbol(name='UPDATE', parent=idims, scope=idims.scope)
return ir.CallStatement(name=csym, arguments=(bindex, upper, index), kwarguments=())

def _create_view_updates(section, scope):
bindex = scope.parse_expr(dimension.indices[1])

fgroup_vars = sorted(tuple(
v for v in FindVariables(unique=True).visit(section)
if str(v.type.dtype) in field_group_types
), key=str)
calls = ()
for fgvar in fgroup_vars:
fgsym = scope.get_symbol(fgvar.name)
csym = sym.ProcedureSymbol(name='UPDATE_VIEW', parent=fgsym, scope=fgsym.scope)
calls += (ir.CallStatement(name=csym, arguments=(bindex,), kwarguments=()),)

return calls

class InsertFieldAPIViewsTransformer(Transformer):
""" Injects FIELD-API view updates into block loops """

def visit_Loop(self, loop, **kwargs): # pylint: disable=unused-argument
if not loop.variable == 'JKGLO':
return loop

scope = kwargs.get('scope')

# Find the loop-setup assignments
_loop_symbols = dimension.indices
_loop_symbols += as_tuple(dimension.lower) + as_tuple(dimension.upper)
loop_setup = tuple(
a for a in FindNodes(ir.Assignment).visit(loop.body)
if a.lhs in _loop_symbols
)
idx = max(loop.body.index(a) for a in loop_setup) + 1

# Prepend FIELD API boilerplate
preamble = (
ir.Comment(''), ir.Comment('! Set up thread-local view pointers')
)
if dim_object:
preamble += (_create_dim_update(scope, dim_object=dim_object),)
preamble += _create_view_updates(loop.body, scope)

loop._update(body=loop.body[:idx] + preamble + loop.body[idx:])
return loop

routine.body = InsertFieldAPIViewsTransformer().visit(routine.body, scope=routine)
121 changes: 121 additions & 0 deletions loki/transformations/parallel/tests/test_field_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Module, Dimension
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes

from loki.transformations.parallel import (
remove_field_api_view_updates, add_field_api_view_updates
)


@pytest.mark.parametrize('frontend', available_frontends(
skip=[(OMNI, 'OMNI needs full type definitions for derived types')]
))
def test_field_api_remove_view_updates(frontend):
"""
A simple test for :any:`remove_field_api_view_updates`
"""

fcode = """
subroutine test_remove_block_loop(ngptot, nproma, nflux, dims, state, aux_fields, fluxes)
implicit none
integer(kind=4), intent(in) :: ngptot, nproma, nflux
type(dimension_type), intent(inout) :: dims
type(state_type), intent(inout) :: state
type(aux_type), intent(inout) :: aux_fields
type(flux_type), intent(inout) :: fluxes(nflux)
integer :: JKGLO, IBL, ICEND, JK, JL, JF
DO jkglo=1, ngptot, nproma
icend = min(nproma, ngptot - JKGLO + 1)
ibl = (jkglo - 1) / nproma + 1
CALL DIMS%UPDATE(IBL, ICEND, JKGLO)
CALL STATE%UPDATE_VIEW(IBL)
CALL AUX_FIELDS%UPDATE_VIEW(block_index=IBL)
DO jf=1, nflux
CALL FLUXES(JF)%UPDATE_VIEW(IBL)
END DO
CALL MY_KERNEL(STATE%U, STATE%V, AUX_FIELDS%STUFF, FLUXES(1)%FOO, FLUXES(2)%BAR)
END DO
end subroutine test_remove_block_loop
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 5
assert len(FindNodes(ir.Loop).visit(routine.body)) == 2

field_group_types = ['state_type', 'aux_type', 'flux_type']
remove_field_api_view_updates(
routine, field_group_types=field_group_types, dim_object='DIMS'
)

calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].name == 'MY_KERNEL'

loops = FindNodes(ir.Loop).visit(routine.body)
assert len(loops) == 1
assert loops[0].variable == 'jkglo'


@pytest.mark.parametrize('frontend', available_frontends(
skip=[(OMNI, 'OMNI needs full type definitions for derived types')]
))
def test_field_api_add_view_updates(frontend):
"""
A simple test for :any:`add_field_api_view_updates`.
"""

fcode = """
subroutine test_remove_block_loop(ngptot, nproma, nflux, dims, state, aux_fields, fluxes)
implicit none
integer(kind=4), intent(in) :: ngptot, nproma, nflux
type(dimension_type), intent(inout) :: dims
type(state_type), intent(inout) :: state
type(aux_type), intent(inout) :: aux_fields
type(flux_type), intent(inout) :: fluxes
integer :: JKGLO, IBL, ICEND, JK, JL, JF
DO jkglo=1, ngptot, nproma
icend = min(nproma, ngptot - jkglo + 1)
ibl = (jkglo - 1) / nproma + 1
CALL MY_KERNEL(STATE%U, STATE%V, AUX_FIELDS%STUFF, FLUXES%FOO, FLUXES%BAR)
END DO
end subroutine test_remove_block_loop
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1
assert len(FindNodes(ir.Loop).visit(routine.body)) == 1

block = Dimension(
index=('jkglo', 'ibl'), step='NPROMA',
lower=('1', 'ICST'), upper=('NGPTOT', 'ICEND')
)
field_group_types = ['state_type', 'aux_type', 'flux_type']
add_field_api_view_updates(
routine, dimension=block, field_group_types=field_group_types,
dim_object='DIMS'
)

calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 5
assert calls[0].name == 'DIMS%UPDATE' and calls[0].arguments == ('IBL', 'ICEND', 'JKGLO')
assert calls[1].name == 'AUX_FIELDS%UPDATE_VIEW' and calls[1].arguments == ('IBL',)
assert calls[2].name == 'FLUXES%UPDATE_VIEW' and calls[2].arguments == ('IBL',)
assert calls[3].name == 'STATE%UPDATE_VIEW' and calls[3].arguments == ('IBL',)

assert len(FindNodes(ir.Loop).visit(routine.body)) == 1

0 comments on commit da8b9f3

Please sign in to comment.