From 33429be53dfb5c140710d68dcdf0e409a8a8289b Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 4 Dec 2024 13:13:52 +0000 Subject: [PATCH] FieldAPI: Filter duplicates in FieldPointerMap when generating calls --- .../data_offload/field_offload.py | 9 ------ loki/transformations/field_api.py | 32 +++++++++++++------ 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/loki/transformations/data_offload/field_offload.py b/loki/transformations/data_offload/field_offload.py index a5d1e24f1..d09d77b98 100644 --- a/loki/transformations/data_offload/field_offload.py +++ b/loki/transformations/data_offload/field_offload.py @@ -135,15 +135,6 @@ def find_offload_variables(driver, calls, field_group_types): if param.type.intent.lower() == 'out': outargs += (arg, ) - inoutargs += tuple(v for v in inargs if v in outargs) - inargs = tuple(v for v in inargs if v not in inoutargs) - outargs = tuple(v for v in outargs if v not in inoutargs) - - # Filter out duplicates and return as tuple - inargs = tuple(dict.fromkeys(inargs)) - inoutargs = tuple(dict.fromkeys(inoutargs)) - outargs = tuple(dict.fromkeys(outargs)) - return inargs, inoutargs, outargs diff --git a/loki/transformations/field_api.py b/loki/transformations/field_api.py index 826c6fb71..43f9b8935 100644 --- a/loki/transformations/field_api.py +++ b/loki/transformations/field_api.py @@ -41,9 +41,15 @@ class FieldPointerMap: properties, based on the intent of the kernel argument. """ def __init__(self, inargs, inoutargs, outargs, scope, ptr_prefix='loki_devptr_'): - self.inargs = inargs - self.inoutargs = inoutargs - self.outargs = outargs + # Ensure no duplication between in/inout/out args + inoutargs += tuple(v for v in inargs if v in outargs) + inargs = tuple(v for v in inargs if v not in inoutargs) + outargs = tuple(v for v in outargs if v not in inoutargs) + + # Filter out duplicates and return as tuple + self.inargs = tuple(dict.fromkeys(inargs)) + self.inoutargs = tuple(dict.fromkeys(inoutargs)) + self.outargs = tuple(dict.fromkeys(outargs)) self.scope = scope @@ -70,10 +76,10 @@ def field_ptr_from_view(field_view): @property def dataptrs(self): """ Create a list of contiguous data pointer symbols """ - return tuple( + return tuple(dict.fromkeys( self.dataptr_from_array(a) for a in chain(*(self.inargs, self.inoutargs, self.outargs)) - ) + )) @property def host_to_device_calls(self): @@ -82,17 +88,23 @@ def host_to_device_calls(self): """ READ_ONLY, READ_WRITE = FieldAPITransferType.READ_ONLY, FieldAPITransferType.READ_WRITE + # Filter down to base symbols and avoid duplicates across sets + inargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.inargs)) + inoutargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.inoutargs)) + outargs = tuple(dict.fromkeys(a.clone(dimensions=None) for a in self.outargs)) + inargs = tuple(a for a in inargs if a not in inoutargs) + host_to_device = tuple(field_get_device_data( self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_ONLY, scope=self.scope - ) for arg in self.inargs) + ) for arg in inargs) host_to_device += tuple(field_get_device_data( self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_WRITE, scope=self.scope - ) for arg in self.inoutargs) + ) for arg in inoutargs) host_to_device += tuple(field_get_device_data( self.field_ptr_from_view(arg), self.dataptr_from_array(arg), READ_WRITE, scope=self.scope - ) for arg in self.outargs) + ) for arg in outargs) - return host_to_device + return tuple(dict.fromkeys(host_to_device)) @property def sync_host_calls(self): @@ -105,7 +117,7 @@ def sync_host_calls(self): sync_host += tuple( field_sync_host(self.field_ptr_from_view(arg), scope=self.scope) for arg in self.outargs ) - return sync_host + return tuple(dict.fromkeys(sync_host)) def get_field_type(a: sym.Array) -> sym.DerivedType: