diff --git a/loki/transformations/data_offload/field_offload.py b/loki/transformations/data_offload/field_offload.py index 74bb6c6ff..ed7244573 100644 --- a/loki/transformations/data_offload/field_offload.py +++ b/loki/transformations/data_offload/field_offload.py @@ -7,18 +7,18 @@ from itertools import chain +from loki.analyse import dataflow_analysis_attached from loki.batch import Transformation from loki.expression import Array, symbols as sym from loki.ir import ( - FindNodes, PragmaRegion, CallStatement, - Transformer, pragma_regions_attached, - SubstituteExpressions, FindVariables + nodes as ir, FindNodes, PragmaRegion, CallStatement, Transformer, + pragma_regions_attached, SubstituteExpressions, FindVariables, + is_loki_pragma ) from loki.logging import warning, error from loki.tools import as_tuple from loki.types import BasicType -from loki.transformations.data_offload.offload import DataOffloadTransformation from loki.transformations.field_api import FieldPointerMap from loki.transformations.parallel import remove_field_api_view_updates @@ -72,20 +72,27 @@ def transform_subroutine(self, routine, **kwargs): self.process_driver(routine, targets) def process_driver(self, driver, targets): + + # Remove the Field-API view-pointer boilerplate remove_field_api_view_updates(driver, self.field_group_types) + with pragma_regions_attached(driver): - for region in FindNodes(PragmaRegion).visit(driver.body): - # Only work on active `!$loki data` regions - if not DataOffloadTransformation._is_active_loki_data_region(region, targets): - continue - kernel_calls = find_target_calls(region, targets) - offload_variables = find_offload_variables(driver, kernel_calls, self.field_group_types) - offload_map = FieldPointerMap( - *offload_variables, scope=driver, ptr_prefix=self.deviceptr_prefix - ) - declare_device_ptrs(driver, deviceptrs=offload_map.dataptrs) - add_field_offload_calls(driver, region, offload_map) - replace_kernel_args(driver, offload_map, self.offload_index) + with dataflow_analysis_attached(driver): + for region in FindNodes(PragmaRegion).visit(driver.body): + # Only work on active `!$loki data` regions + if not region.pragma or not is_loki_pragma(region.pragma, starts_with='data'): + continue + + # Determine the array variables for generating Field API offload + offload_variables = find_offload_variables(driver, region, self.field_group_types) + offload_map = FieldPointerMap( + *offload_variables, scope=driver, ptr_prefix=self.deviceptr_prefix + ) + + # Inject declarations and offload API calls into driver region + declare_device_ptrs(driver, deviceptrs=offload_map.dataptrs) + add_field_offload_calls(driver, region, offload_map) + replace_kernel_args(driver, offload_map, self.offload_index) def find_target_calls(region, targets): @@ -104,12 +111,20 @@ def find_target_calls(region, targets): return calls -def find_offload_variables(driver, calls, field_group_types): - inargs = () - inoutargs = () - outargs = () +def find_offload_variables(driver, region, field_group_types): + + # Use dataflow analysis to find in, out and inout variables to that region + inargs = region.uses_symbols - region.defines_symbols + inoutargs = region.uses_symbols & region.defines_symbols + outargs = region.defines_symbols - region.uses_symbols - for call in calls: + # Filter out relevant array symbols + inargs = tuple(a for a in inargs if isinstance(a, sym.Array) and a.parent) + inoutargs = tuple(a for a in inoutargs if isinstance(a, sym.Array) and a.parent) + outargs = tuple(a for a in outargs if isinstance(a, sym.Array) and a.parent) + + # Do some sanity checking and warning for enclosed calls + for call in FindNodes(ir.CallStatement).visit(region): if call.routine is BasicType.DEFERRED: error(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' + f'in {str(call.name).lower()}') @@ -128,13 +143,6 @@ def find_offload_variables(driver, calls, field_group_types): + f' {driver.name} that is not wrapped by a Field API object') continue - if param.type.intent.lower() == 'in': - inargs += (arg, ) - if param.type.intent.lower() == 'inout': - inoutargs += (arg, ) - if param.type.intent.lower() == 'out': - outargs += (arg, ) - return inargs, inoutargs, outargs diff --git a/loki/transformations/data_offload/tests/test_field_offload.py b/loki/transformations/data_offload/tests/test_field_offload.py index f430fbc96..0067e4a7f 100644 --- a/loki/transformations/data_offload/tests/test_field_offload.py +++ b/loki/transformations/data_offload/tests/test_field_offload.py @@ -634,3 +634,64 @@ def test_field_offload_aliasing(frontend, state_module, tmp_path): decls = FindNodes(ir.VariableDeclaration).visit(driver.spec) assert len(decls) == 5 if frontend == OMNI else 4 assert decls[-1].symbols == ('state_a(:,:,:)',) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_field_offload_driver_compute(frontend, state_module, tmp_path): + fcode = """ + module driver_mod + use state_mod, only: state_type + use parkind1, only: jprb + implicit none + + contains + + subroutine driver_routine(nlon, nlev, state) + integer, intent(in) :: nlon, nlev + type(state_type), intent(inout) :: state + integer :: i, ibl + + !$loki data + do ibl=1,nlev + call state%update_view(ibl) + + do i=1, nlon + state%a(i, 1) = state%b(i, 1) + 0.1 + state%a(i, 2) = state%a(i, 1) + end do + + end do + !$loki end data + + end subroutine driver_routine + end module driver_mod + """ + driver_mod = Module.from_source( + fcode, frontend=frontend, definitions=state_module, xmods=[tmp_path] + ) + driver = driver_mod['driver_routine'] + + calls = FindNodes(ir.CallStatement).visit(driver.body) + assert len(calls) == 1 + assert calls[0].name == 'state%update_view' + + field_offload = FieldOffloadTransformation( + devptr_prefix='', offload_index='ibl', field_group_types=['state_type'] + ) + driver.apply(field_offload, role='driver', targets=['kernel_routine']) + + calls = FindNodes(ir.CallStatement).visit(driver.body) + assert len(calls) == 3 + assert calls[0].name == 'state%f_b%get_device_data_rdonly' + assert calls[0].arguments == ('state_b',) + assert calls[1].name == 'state%f_a%get_device_data_rdwr' + assert calls[1].arguments == ('state_a',) + assert calls[2].name == 'state%f_a%sync_host_rdwr' + assert calls[2].arguments == () + + assigns = FindNodes(ir.Assignment).visit(driver.body) + assert len(assigns) == 2 + assert assigns[0].lhs == 'state_a(i,1,ibl)' + assert assigns[0].rhs == 'state_b(i,1,ibl) + 0.1' + assert assigns[1].lhs == 'state_a(i,2,ibl)' + assert assigns[1].rhs == 'state_a(i,1,ibl)'