Skip to content

Commit

Permalink
FieldAPI: Filter duplicates in FieldPointerMap when generating calls
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Dec 5, 2024
1 parent a4a446e commit 33429be
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
9 changes: 0 additions & 9 deletions loki/transformations/data_offload/field_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,6 @@ def find_offload_variables(driver, calls, field_group_types):
if param.type.intent.lower() == 'out':
outargs += (arg, )

inoutargs += tuple(v for v in inargs if v in outargs)
inargs = tuple(v for v in inargs if v not in inoutargs)
outargs = tuple(v for v in outargs if v not in inoutargs)

# Filter out duplicates and return as tuple
inargs = tuple(dict.fromkeys(inargs))
inoutargs = tuple(dict.fromkeys(inoutargs))
outargs = tuple(dict.fromkeys(outargs))

return inargs, inoutargs, outargs


Expand Down
32 changes: 22 additions & 10 deletions loki/transformations/field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,15 @@ class FieldPointerMap:
properties, based on the intent of the kernel argument.
"""
def __init__(self, inargs, inoutargs, outargs, scope, ptr_prefix='loki_devptr_'):
self.inargs = inargs
self.inoutargs = inoutargs
self.outargs = outargs
# Ensure no duplication between in/inout/out args
inoutargs += tuple(v for v in inargs if v in outargs)
inargs = tuple(v for v in inargs if v not in inoutargs)
outargs = tuple(v for v in outargs if v not in inoutargs)

# Filter out duplicates and return as tuple
self.inargs = tuple(dict.fromkeys(inargs))
self.inoutargs = tuple(dict.fromkeys(inoutargs))
self.outargs = tuple(dict.fromkeys(outargs))

self.scope = scope

Expand All @@ -70,10 +76,10 @@ def field_ptr_from_view(field_view):
@property
def dataptrs(self):
""" Create a list of contiguous data pointer symbols """
return tuple(
return tuple(dict.fromkeys(
self.dataptr_from_array(a)
for a in chain(*(self.inargs, self.inoutargs, self.outargs))
)
))

@property
def host_to_device_calls(self):
Expand All @@ -82,17 +88,23 @@ def host_to_device_calls(self):
"""
READ_ONLY, READ_WRITE = FieldAPITransferType.READ_ONLY, FieldAPITransferType.READ_WRITE

# Filter down to base symbols and avoid duplicates across sets
inargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.inargs))
inoutargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.inoutargs))
outargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.outargs))
inargs = tuple(a for a in inargs if a not in inoutargs)

host_to_device = tuple(field_get_device_data(
self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_ONLY, scope=self.scope
) for arg in self.inargs)
) for arg in inargs)
host_to_device += tuple(field_get_device_data(
self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_WRITE, scope=self.scope
) for arg in self.inoutargs)
) for arg in inoutargs)
host_to_device += tuple(field_get_device_data(
self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_WRITE, scope=self.scope
) for arg in self.outargs)
) for arg in outargs)

return host_to_device
return tuple(dict.fromkeys(host_to_device))

@property
def sync_host_calls(self):
Expand All @@ -105,7 +117,7 @@ def sync_host_calls(self):
sync_host += tuple(
field_sync_host(self.field_ptr_from_view(arg), scope=self.scope) for arg in self.outargs
)
return sync_host
return tuple(dict.fromkeys(sync_host))


def get_field_type(a: sym.Array) -> sym.DerivedType:
Expand Down

0 comments on commit 33429be

Please sign in to comment.