diff --git a/loki/transformations/data_offload/field_offload.py b/loki/transformations/data_offload/field_offload.py index ed7244573..58dbd7a54 100644 --- a/loki/transformations/data_offload/field_offload.py +++ b/loki/transformations/data_offload/field_offload.py @@ -11,23 +11,19 @@ from loki.batch import Transformation from loki.expression import Array, symbols as sym from loki.ir import ( - nodes as ir, FindNodes, PragmaRegion, CallStatement, Transformer, - pragma_regions_attached, SubstituteExpressions, FindVariables, - is_loki_pragma + nodes as ir, FindNodes, FindVariables, Transformer, + SubstituteExpressions, pragma_regions_attached, is_loki_pragma ) from loki.logging import warning, error -from loki.tools import as_tuple from loki.types import BasicType from loki.transformations.field_api import FieldPointerMap from loki.transformations.parallel import remove_field_api_view_updates - __all__ = [ - 'FieldOffloadTransformation', 'find_target_calls', - 'find_offload_variables', 'add_field_offload_calls', - 'replace_kernel_args' + 'FieldOffloadTransformation', 'find_offload_variables', + 'add_field_offload_calls', 'replace_kernel_args' ] @@ -67,18 +63,17 @@ def __init__(self, devptr_prefix=None, field_group_types=None, offload_index=Non def transform_subroutine(self, routine, **kwargs): role = kwargs['role'] - targets = as_tuple(kwargs.get('targets'), (None)) if role == 'driver': - self.process_driver(routine, targets) + self.process_driver(routine) - def process_driver(self, driver, targets): + def process_driver(self, driver): # Remove the Field-API view-pointer boilerplate remove_field_api_view_updates(driver, self.field_group_types) with pragma_regions_attached(driver): with dataflow_analysis_attached(driver): - for region in FindNodes(PragmaRegion).visit(driver.body): + for region in FindNodes(ir.PragmaRegion).visit(driver.body): # Only work on active `!$loki data` regions if not region.pragma or not is_loki_pragma(region.pragma, starts_with='data'): continue @@ -95,22 +90,6 @@ def process_driver(self, driver, targets): replace_kernel_args(driver, offload_map, self.offload_index) -def find_target_calls(region, targets): - """ - Returns a list of all calls to targets inside the region. - - Parameters - ---------- - :region: :any:`PragmaRegion` - :targets: collection of :any:`Subroutine` - Iterable object of subroutines or functions called - :returns: list of :any:`CallStatement` - """ - calls = FindNodes(CallStatement).visit(region) - calls = [c for c in calls if str(c.name).lower() in targets] - return calls - - def find_offload_variables(driver, region, field_group_types): # Use dataflow analysis to find in, out and inout variables to that region diff --git a/loki/transformations/data_offload/tests/test_field_offload.py b/loki/transformations/data_offload/tests/test_field_offload.py index 0067e4a7f..f1bdb3608 100644 --- a/loki/transformations/data_offload/tests/test_field_offload.py +++ b/loki/transformations/data_offload/tests/test_field_offload.py @@ -315,93 +315,6 @@ def test_field_offload_multiple_calls(frontend, state_module, tmp_path): assert devptr.name in (arg.name for kernel_call in kernel_calls for arg in kernel_call.arguments) -@pytest.mark.parametrize('frontend', available_frontends()) -def test_field_offload_no_targets(frontend, state_module, tmp_path): - fother = """ - module another_module - implicit none - contains - subroutine another_kernel(nlon, nlev, a, b, c) - integer, intent(in) :: nlon, nlev - real, intent(in) :: a(nlon,nlev) - real, intent(inout) :: b(nlon,nlev) - real, intent(out) :: c(nlon,nlev) - integer :: i, j - end subroutine - end module - """ - - fcode = """ - module driver_mod - use parkind1, only: jprb - use state_mod, only: state_type - use another_module, only: another_kernel - - implicit none - - contains - - subroutine kernel_routine(nlon, nlev, a, b, c) - integer, intent(in) :: nlon, nlev - real(kind=jprb), intent(in) :: a(nlon,nlev) - real(kind=jprb), intent(inout) :: b(nlon,nlev) - real(kind=jprb), intent(out) :: c(nlon,nlev) - integer :: i, j - - do j=1, nlon - do i=1, nlev - b(i,j) = a(i,j) + 0.1 - c(i,j) = 0.1 - end do - end do - end subroutine kernel_routine - - subroutine driver_routine(nlon, nlev, state) - integer, intent(in) :: nlon, nlev - type(state_type), intent(inout) :: state - integer :: i - - !$loki data - do i=1,nlev - call state%update_view(i) - call another_kernel(nlon, state%a, state%b, state%c) - end do - !$loki end data - - end subroutine driver_routine - end module driver_mod - """ - - Sourcefile.from_source(fother, frontend=frontend, xmods=[tmp_path]) - driver_mod = Module.from_source( - fcode, frontend=frontend, definitions=state_module,xmods=[tmp_path] - ) - driver = driver_mod['driver_routine'] - deviceptr_prefix = 'loki_devptr_prefix_' - driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix, - offload_index='i', - field_group_types=['state_type']), - role='driver', - targets=['kernel_routine']) - - calls = FindNodes(CallStatement).visit(driver.body) - assert not any(c for c in calls if c.name=='kernel_routine') - - # verify that no field offloads are generated - in_calls = [c for c in calls if 'get_device_data_rdonly' in c.name.name.lower()] - assert len(in_calls) == 0 - inout_calls = [c for c in calls if 'get_device_data_rdwr' in c.name.name.lower()] - assert len(inout_calls) == 0 - # verify that no field sync host calls are generated - sync_calls = [c for c in calls if 'sync_host_rdwr' in c.name.name.lower()] - assert len(sync_calls) == 0 - - # verify that data offload pragmas remain - pragmas = FindNodes(Pragma).visit(driver.body) - assert len(pragmas) == 2 - assert all(p.keyword=='loki' and p.content==c for p, c in zip(pragmas, ['data', 'end data'])) - - @pytest.mark.parametrize('frontend', available_frontends()) def test_field_offload_unknown_kernel(caplog, frontend, state_module, tmp_path): fother = """