Skip to content

Commit

Permalink
DataOffload: Tidy up imports and remove obsolete utility and test
Browse files Browse the repository at this point in the history
Without tying the transformation to calls, explicit no-target skipping
becomes virtually impossible; hence removing the test for it.
  • Loading branch information
mlange05 committed Dec 5, 2024
1 parent cce8b93 commit 9d572ee
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 115 deletions.
35 changes: 7 additions & 28 deletions loki/transformations/data_offload/field_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,19 @@
from loki.batch import Transformation
from loki.expression import Array, symbols as sym
from loki.ir import (
nodes as ir, FindNodes, PragmaRegion, CallStatement, Transformer,
pragma_regions_attached, SubstituteExpressions, FindVariables,
is_loki_pragma
nodes as ir, FindNodes, FindVariables, Transformer,
SubstituteExpressions, pragma_regions_attached, is_loki_pragma
)
from loki.logging import warning, error
from loki.tools import as_tuple
from loki.types import BasicType

from loki.transformations.field_api import FieldPointerMap
from loki.transformations.parallel import remove_field_api_view_updates



__all__ = [
'FieldOffloadTransformation', 'find_target_calls',
'find_offload_variables', 'add_field_offload_calls',
'replace_kernel_args'
'FieldOffloadTransformation', 'find_offload_variables',
'add_field_offload_calls', 'replace_kernel_args'
]


Expand Down Expand Up @@ -67,18 +63,17 @@ def __init__(self, devptr_prefix=None, field_group_types=None, offload_index=Non

def transform_subroutine(self, routine, **kwargs):
role = kwargs['role']
targets = as_tuple(kwargs.get('targets'), (None))
if role == 'driver':
self.process_driver(routine, targets)
self.process_driver(routine)

def process_driver(self, driver, targets):
def process_driver(self, driver):

# Remove the Field-API view-pointer boilerplate
remove_field_api_view_updates(driver, self.field_group_types)

with pragma_regions_attached(driver):
with dataflow_analysis_attached(driver):
for region in FindNodes(PragmaRegion).visit(driver.body):
for region in FindNodes(ir.PragmaRegion).visit(driver.body):
# Only work on active `!$loki data` regions
if not region.pragma or not is_loki_pragma(region.pragma, starts_with='data'):
continue
Expand All @@ -95,22 +90,6 @@ def process_driver(self, driver, targets):
replace_kernel_args(driver, offload_map, self.offload_index)


def find_target_calls(region, targets):
"""
Returns a list of all calls to targets inside the region.
Parameters
----------
:region: :any:`PragmaRegion`
:targets: collection of :any:`Subroutine`
Iterable object of subroutines or functions called
:returns: list of :any:`CallStatement`
"""
calls = FindNodes(CallStatement).visit(region)
calls = [c for c in calls if str(c.name).lower() in targets]
return calls


def find_offload_variables(driver, region, field_group_types):

# Use dataflow analysis to find in, out and inout variables to that region
Expand Down
87 changes: 0 additions & 87 deletions loki/transformations/data_offload/tests/test_field_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,93 +315,6 @@ def test_field_offload_multiple_calls(frontend, state_module, tmp_path):
assert devptr.name in (arg.name for kernel_call in kernel_calls for arg in kernel_call.arguments)


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_no_targets(frontend, state_module, tmp_path):
fother = """
module another_module
implicit none
contains
subroutine another_kernel(nlon, nlev, a, b, c)
integer, intent(in) :: nlon, nlev
real, intent(in) :: a(nlon,nlev)
real, intent(inout) :: b(nlon,nlev)
real, intent(out) :: c(nlon,nlev)
integer :: i, j
end subroutine
end module
"""

fcode = """
module driver_mod
use parkind1, only: jprb
use state_mod, only: state_type
use another_module, only: another_kernel
implicit none
contains
subroutine kernel_routine(nlon, nlev, a, b, c)
integer, intent(in) :: nlon, nlev
real(kind=jprb), intent(in) :: a(nlon,nlev)
real(kind=jprb), intent(inout) :: b(nlon,nlev)
real(kind=jprb), intent(out) :: c(nlon,nlev)
integer :: i, j
do j=1, nlon
do i=1, nlev
b(i,j) = a(i,j) + 0.1
c(i,j) = 0.1
end do
end do
end subroutine kernel_routine
subroutine driver_routine(nlon, nlev, state)
integer, intent(in) :: nlon, nlev
type(state_type), intent(inout) :: state
integer :: i
!$loki data
do i=1,nlev
call state%update_view(i)
call another_kernel(nlon, state%a, state%b, state%c)
end do
!$loki end data
end subroutine driver_routine
end module driver_mod
"""

Sourcefile.from_source(fother, frontend=frontend, xmods=[tmp_path])
driver_mod = Module.from_source(
fcode, frontend=frontend, definitions=state_module,xmods=[tmp_path]
)
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])

calls = FindNodes(CallStatement).visit(driver.body)
assert not any(c for c in calls if c.name=='kernel_routine')

# verify that no field offloads are generated
in_calls = [c for c in calls if 'get_device_data_rdonly' in c.name.name.lower()]
assert len(in_calls) == 0
inout_calls = [c for c in calls if 'get_device_data_rdwr' in c.name.name.lower()]
assert len(inout_calls) == 0
# verify that no field sync host calls are generated
sync_calls = [c for c in calls if 'sync_host_rdwr' in c.name.name.lower()]
assert len(sync_calls) == 0

# verify that data offload pragmas remain
pragmas = FindNodes(Pragma).visit(driver.body)
assert len(pragmas) == 2
assert all(p.keyword=='loki' and p.content==c for p, c in zip(pragmas, ['data', 'end data']))


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_unknown_kernel(caplog, frontend, state_module, tmp_path):
fother = """
Expand Down

0 comments on commit 9d572ee

Please sign in to comment.