Skip to content

Commit

Permalink
Removed automatic inference of offload indexes from view updates
Browse files Browse the repository at this point in the history
  • Loading branch information
wertysas committed Nov 11, 2024
1 parent b29c6d7 commit 561424e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 23 deletions.
35 changes: 12 additions & 23 deletions loki/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from loki.logging import warning
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, CaseInsensitiveDefaultDict
from loki.types import BasicType, DerivedType
from loki.transformations.parallel import FieldAPITransferType, field_get_device_data, field_sync_host
from loki.transformations.parallel import (
FieldAPITransferType, field_get_device_data, field_sync_host, remove_field_api_view_updates
)

__all__ = [
'DataOffloadTransformation', 'GlobalVariableAnalysis',
Expand Down Expand Up @@ -1014,6 +1016,7 @@ def __init__(self, **kwargs):
self.deviceptr_prefix = kwargs.get('devptr_prefix', 'loki_devptr_')
field_group_types = kwargs.get('field_group_types', ['CLOUDSC_STATE_TYPE', 'CLOUDSC_AUX_TYPE', 'CLOUDSC_FLUX_TYPE'])
self.field_group_types = tuple(typename.lower() for typename in field_group_types)
self.offload_index = kwargs.get('offload_index', 'IBL')

def transform_subroutine(self, routine, **kwargs):
role = kwargs['role']
Expand All @@ -1032,12 +1035,14 @@ def process_driver(self, driver, targets):
# Only work on active `!$loki data` regions
if not DataOffloadTransformation._is_active_loki_data_region(region, targets):
continue
# remove_field_api_view_updates(driver, self.field_group_types) # FIXME: if called here, driver.body will not be updated in subsequent routines
kernel_calls = find_target_calls(region, targets)
offload_variables = self.find_offload_variables(driver, kernel_calls)
device_ptrs = self._declare_device_ptrs(driver, offload_variables)
offload_map = self.FieldPointerMap(device_ptrs, *offload_variables)
old_offload_calls = self._replace_data_offload_calls(driver, region, offload_map)
self._replace_kernel_args(kernel_calls, old_offload_calls, offload_map)
self._add_field_offload_calls(driver, region, offload_map)
self._replace_kernel_args(driver, kernel_calls, offload_map)
remove_field_api_view_updates(driver, self.field_group_types) # if called here it works

def find_offload_variables(self, driver, calls):
inargs = ()
Expand Down Expand Up @@ -1102,12 +1107,7 @@ def _devptr_from_array(self, driver, a: sym.Array):
devptr = sym.Variable(name=devptr_name, type=devptr_type, dimensions=shape)
return devptr

def _replace_data_offload_calls(self, driver, region, offload_map):
# remove calls to [field_group_type]%update_view
calls = FindNodes(CallStatement).visit(region)
field_group_updates = tuple(c for c in calls if self._is_field_group_update(driver, c))
# c.arguments contains Scalar(IBL)
Transformer(dict.fromkeys(field_group_updates, None), inplace=True).visit(region.body)
def _add_field_offload_calls(self, driver, region, offload_map):
host_to_device = tuple(field_get_device_data(self._get_field_ptr_from_view(inarg), devptr,
FieldAPITransferType.READ_ONLY, driver) for inarg, devptr in offload_map.in_pairs)
host_to_device += tuple(field_get_device_data(self._get_field_ptr_from_view(inarg), devptr,
Expand All @@ -1116,10 +1116,8 @@ def _replace_data_offload_calls(self, driver, region, offload_map):
FieldAPITransferType.READ_WRITE, driver) for inarg, devptr in offload_map.out_pairs)
device_to_host = tuple(field_sync_host(self._get_field_ptr_from_view(inarg), driver)
for inarg, _ in chain(offload_map.inout_pairs, offload_map.out_pairs))
# field_deletes = tuple(field_delete(field_ptr_map[var], routine) for var in blocking_arrays)
update_map = {region: host_to_device + (region,) + device_to_host}
Transformer(update_map, inplace=True).visit(driver.body)
return field_group_updates

def _is_field_group_update(self, driver, call):
try:
Expand All @@ -1136,20 +1134,11 @@ def _get_field_ptr_from_view(self, field_view):
field_type_name = 'F_' + type_chain[-1]
return field_view.parent.get_derived_type_member(field_type_name)

def _replace_kernel_args(self, kernel_calls, old_offload_calls, offload_map):
"""TODO: Docstring for _replace_kernel_calls.
:kernel_calls: TODO
:old_offload_calls: TODO
:device_ptrs: TODO
:returns: TODO
"""
def _replace_kernel_args(self, driver, kernel_calls, offload_map):
change_map = {}
offload_idx_expr = driver.variable_map[self.offload_index]
for arg, devptr in chain(offload_map.in_pairs, offload_map.inout_pairs, offload_map.out_pairs):
group_update = next((c for c in old_offload_calls if c.name.parent == arg.parent), None)
assert group_update is not None, "Group update should not be none"
block_idx = group_update.arguments[0]
dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (block_idx,)
dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (offload_idx_expr,)
change_map[arg] = devptr.clone(dimensions=dims)
arg_transformer = SubstituteExpressions(change_map, inplace=True)
for call in kernel_calls:
Expand Down
3 changes: 3 additions & 0 deletions loki/transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,7 @@ def test_field_offload(frontend):
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])
Expand Down Expand Up @@ -1000,6 +1001,7 @@ def test_field_offload_multiple_calls(frontend):
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])
Expand Down Expand Up @@ -1084,6 +1086,7 @@ def test_field_offload_no_targets(frontend):
driver = driver_mod['driver_routine']
deviceptr_prefix = 'loki_devptr_prefix_'
driver.apply(FieldOffloadTransformation(devptr_prefix=deviceptr_prefix,
offload_index='i',
field_group_types=['state_type']),
role='driver',
targets=['kernel_routine'])
Expand Down

0 comments on commit 561424e

Please sign in to comment.