Skip to content

Commit

Permalink
DataOffload: Derive offload variables via dataflow analysis for regions
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Dec 5, 2024
1 parent 7bfbf38 commit cce8b93
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 28 deletions.
64 changes: 36 additions & 28 deletions loki/transformations/data_offload/field_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@

from itertools import chain

from loki.analyse import dataflow_analysis_attached
from loki.batch import Transformation
from loki.expression import Array, symbols as sym
from loki.ir import (
FindNodes, PragmaRegion, CallStatement,
Transformer, pragma_regions_attached,
SubstituteExpressions, FindVariables
nodes as ir, FindNodes, PragmaRegion, CallStatement, Transformer,
pragma_regions_attached, SubstituteExpressions, FindVariables,
is_loki_pragma
)
from loki.logging import warning, error
from loki.tools import as_tuple
from loki.types import BasicType

from loki.transformations.data_offload.offload import DataOffloadTransformation
from loki.transformations.field_api import FieldPointerMap
from loki.transformations.parallel import remove_field_api_view_updates

Expand Down Expand Up @@ -72,20 +72,27 @@ def transform_subroutine(self, routine, **kwargs):
self.process_driver(routine, targets)

def process_driver(self, driver, targets):

# Remove the Field-API view-pointer boilerplate
remove_field_api_view_updates(driver, self.field_group_types)

with pragma_regions_attached(driver):
for region in FindNodes(PragmaRegion).visit(driver.body):
# Only work on active `!$loki data` regions
if not DataOffloadTransformation._is_active_loki_data_region(region, targets):
continue
kernel_calls = find_target_calls(region, targets)
offload_variables = find_offload_variables(driver, kernel_calls, self.field_group_types)
offload_map = FieldPointerMap(
*offload_variables, scope=driver, ptr_prefix=self.deviceptr_prefix
)
declare_device_ptrs(driver, deviceptrs=offload_map.dataptrs)
add_field_offload_calls(driver, region, offload_map)
replace_kernel_args(driver, offload_map, self.offload_index)
with dataflow_analysis_attached(driver):
for region in FindNodes(PragmaRegion).visit(driver.body):
# Only work on active `!$loki data` regions
if not region.pragma or not is_loki_pragma(region.pragma, starts_with='data'):
continue

# Determine the array variables for generating Field API offload
offload_variables = find_offload_variables(driver, region, self.field_group_types)
offload_map = FieldPointerMap(
*offload_variables, scope=driver, ptr_prefix=self.deviceptr_prefix
)

# Inject declarations and offload API calls into driver region
declare_device_ptrs(driver, deviceptrs=offload_map.dataptrs)
add_field_offload_calls(driver, region, offload_map)
replace_kernel_args(driver, offload_map, self.offload_index)


def find_target_calls(region, targets):
Expand All @@ -104,12 +111,20 @@ def find_target_calls(region, targets):
return calls


def find_offload_variables(driver, calls, field_group_types):
inargs = ()
inoutargs = ()
outargs = ()
def find_offload_variables(driver, region, field_group_types):

# Use dataflow analysis to find in, out and inout variables to that region
inargs = region.uses_symbols - region.defines_symbols
inoutargs = region.uses_symbols & region.defines_symbols
outargs = region.defines_symbols - region.uses_symbols

for call in calls:
# Filter out relevant array symbols
inargs = tuple(a for a in inargs if isinstance(a, sym.Array) and a.parent)
inoutargs = tuple(a for a in inoutargs if isinstance(a, sym.Array) and a.parent)
outargs = tuple(a for a in outargs if isinstance(a, sym.Array) and a.parent)

# Do some sanity checking and warning for enclosed calls
for call in FindNodes(ir.CallStatement).visit(region):
if call.routine is BasicType.DEFERRED:
error(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' +
f'in {str(call.name).lower()}')
Expand All @@ -128,13 +143,6 @@ def find_offload_variables(driver, calls, field_group_types):
+ f' {driver.name} that is not wrapped by a Field API object')
continue

if param.type.intent.lower() == 'in':
inargs += (arg, )
if param.type.intent.lower() == 'inout':
inoutargs += (arg, )
if param.type.intent.lower() == 'out':
outargs += (arg, )

return inargs, inoutargs, outargs


Expand Down
61 changes: 61 additions & 0 deletions loki/transformations/data_offload/tests/test_field_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,64 @@ def test_field_offload_aliasing(frontend, state_module, tmp_path):
decls = FindNodes(ir.VariableDeclaration).visit(driver.spec)
assert len(decls) == 5 if frontend == OMNI else 4
assert decls[-1].symbols == ('state_a(:,:,:)',)


@pytest.mark.parametrize('frontend', available_frontends())
def test_field_offload_driver_compute(frontend, state_module, tmp_path):
    """
    Apply :any:`FieldOffloadTransformation` to a driver whose
    ``!$loki data`` region holds plain compute loops on ``state%a`` and
    ``state%b`` instead of a call to a target kernel, and check the
    calls and assignments found in the driver body afterwards.
    """
    # NOTE: The Fortran fixture is kept flush-left so that the string
    # handed to the frontend stays exactly as it is.
    fcode = """
module driver_mod
use state_mod, only: state_type
use parkind1, only: jprb
implicit none
contains
subroutine driver_routine(nlon, nlev, state)
integer, intent(in) :: nlon, nlev
type(state_type), intent(inout) :: state
integer :: i, ibl
!$loki data
do ibl=1,nlev
call state%update_view(ibl)
do i=1, nlon
state%a(i, 1) = state%b(i, 1) + 0.1
state%a(i, 2) = state%a(i, 1)
end do
end do
!$loki end data
end subroutine driver_routine
end module driver_mod
"""
    module = Module.from_source(
        fcode, frontend=frontend, definitions=state_module, xmods=[tmp_path]
    )
    routine = module['driver_routine']

    # Before the transformation, the view update is the only call
    calls_before = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls_before) == 1
    assert calls_before[0].name == 'state%update_view'

    transformation = FieldOffloadTransformation(
        devptr_prefix='', offload_index='ibl', field_group_types=['state_type']
    )
    routine.apply(transformation, role='driver', targets=['kernel_routine'])

    # Expected calls after the transformation, as (name, arguments) in order
    expected_calls = (
        ('state%f_b%get_device_data_rdonly', ('state_b',)),
        ('state%f_a%get_device_data_rdwr', ('state_a',)),
        ('state%f_a%sync_host_rdwr', ()),
    )
    calls_after = FindNodes(ir.CallStatement).visit(routine.body)
    assert len(calls_after) == 3
    for call, (name, arguments) in zip(calls_after, expected_calls):
        assert call.name == name
        assert call.arguments == arguments

    # Expected assignments after the transformation, as (lhs, rhs) in order
    expected_assigns = (
        ('state_a(i,1,ibl)', 'state_b(i,1,ibl) + 0.1'),
        ('state_a(i,2,ibl)', 'state_a(i,1,ibl)'),
    )
    assigns_after = FindNodes(ir.Assignment).visit(routine.body)
    assert len(assigns_after) == 2
    for assign, (lhs, rhs) in zip(assigns_after, expected_assigns):
        assert assign.lhs == lhs
        assert assign.rhs == rhs

0 comments on commit cce8b93

Please sign in to comment.