Skip to content

Commit

Permalink
DataOffload: Apply symbol substitution over driver in field offload
Browse files Browse the repository at this point in the history
Instead of subbing just on calls, we apply the remapping over the
whole routine body.
  • Loading branch information
mlange05 committed Dec 5, 2024
1 parent 33429be commit 7bfbf38
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
17 changes: 10 additions & 7 deletions loki/transformations/data_offload/field_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from loki.ir import (
FindNodes, PragmaRegion, CallStatement,
Transformer, pragma_regions_attached,
SubstituteExpressions
SubstituteExpressions, FindVariables
)
from loki.logging import warning, error
from loki.tools import as_tuple
Expand Down Expand Up @@ -85,7 +85,7 @@ def process_driver(self, driver, targets):
)
declare_device_ptrs(driver, deviceptrs=offload_map.dataptrs)
add_field_offload_calls(driver, region, offload_map)
replace_kernel_args(driver, kernel_calls, offload_map, self.offload_index)
replace_kernel_args(driver, offload_map, self.offload_index)


def find_target_calls(region, targets):
Expand Down Expand Up @@ -158,17 +158,20 @@ def add_field_offload_calls(driver, region, offload_map):
Transformer(update_map, inplace=True).visit(driver.body)


def replace_kernel_args(driver, kernel_calls, offload_map, offload_index):
def replace_kernel_args(driver, offload_map, offload_index):
change_map = {}
offload_idx_expr = driver.variable_map[offload_index]
for arg in chain(offload_map.inargs, offload_map.inoutargs, offload_map.outargs):

args = tuple(chain(offload_map.inargs, offload_map.inoutargs, offload_map.outargs))
for arg in FindVariables().visit(driver.body):
if not arg.name in args:
continue

devptr = offload_map.dataptr_from_array(arg)
if len(arg.dimensions) != 0:
dims = arg.dimensions + (offload_idx_expr,)
else:
dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (offload_idx_expr,)
change_map[arg] = devptr.clone(dimensions=dims)

arg_transformer = SubstituteExpressions(change_map, inplace=True)
for call in kernel_calls:
arg_transformer.visit(call)
driver.body = SubstituteExpressions(change_map, inplace=True).visit(driver.body)
4 changes: 2 additions & 2 deletions loki/transformations/data_offload/tests/test_field_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from loki import Sourcefile, Module
import loki.expression.symbols as sym
from loki.frontend import available_frontends
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes, Pragma, CallStatement
from loki.logging import log_levels

Expand Down Expand Up @@ -632,5 +632,5 @@ def test_field_offload_aliasing(frontend, state_module, tmp_path):
assert calls[2].arguments == ()

decls = FindNodes(ir.VariableDeclaration).visit(driver.spec)
assert len(decls) == 4
assert len(decls) == 5 if frontend == OMNI else 4
assert decls[-1].symbols == ('state_a(:,:,:)',)
21 changes: 9 additions & 12 deletions loki/transformations/field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,12 @@ def __init__(self, inargs, inoutargs, outargs, scope, ptr_prefix='loki_devptr_')
outargs = tuple(v for v in outargs if v not in inoutargs)

# Filter out duplicates and return as tuple
self.inargs = tuple(dict.fromkeys(inargs))
self.inoutargs = tuple(dict.fromkeys(inoutargs))
self.outargs = tuple(dict.fromkeys(outargs))
self.inargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in inargs))
self.inoutargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in inoutargs))
self.outargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in outargs))

# Filter out duplicates across argument tuples
self.inargs = tuple(a for a in self.inargs if a not in self.inoutargs)

self.scope = scope

Expand Down Expand Up @@ -88,21 +91,15 @@ def host_to_device_calls(self):
"""
READ_ONLY, READ_WRITE = FieldAPITransferType.READ_ONLY, FieldAPITransferType.READ_WRITE

# Filter down to base symbols and avoid duplicates across sets
inargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.inargs))
inoutargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.inoutargs))
outargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.outargs))
inargs = tuple(a for a in inargs if a not in inoutargs)

host_to_device = tuple(field_get_device_data(
self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_ONLY, scope=self.scope
) for arg in inargs)
) for arg in self.inargs)
host_to_device += tuple(field_get_device_data(
self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_WRITE, scope=self.scope
) for arg in inoutargs)
) for arg in self.inoutargs)
host_to_device += tuple(field_get_device_data(
self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_WRITE, scope=self.scope
) for arg in outargs)
) for arg in self.outargs)

return tuple(dict.fromkeys(host_to_device))

Expand Down

0 comments on commit 7bfbf38

Please sign in to comment.