From e72225010be37004d40561541ba8fb8cb0b5625c Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Mon, 9 Oct 2023 09:18:36 +0200 Subject: [PATCH 01/27] identified arrays passed as scalars --- loki/transform/__init__.py | 1 + loki/transform/transform_scalar_syntax.py | 71 +++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 loki/transform/transform_scalar_syntax.py diff --git a/loki/transform/__init__.py b/loki/transform/__init__.py index 56504cdb4..8f1dc4448 100644 --- a/loki/transform/__init__.py +++ b/loki/transform/__init__.py @@ -19,3 +19,4 @@ from loki.transform.build_system_transform import * # noqa from loki.transform.transform_hoist_variables import * # noqa from loki.transform.transform_parametrise import * # noqa +from loki.transform.transform_scalar_syntax import * # noqa diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py new file mode 100644 index 000000000..a2e30c99b --- /dev/null +++ b/loki/transform/transform_scalar_syntax.py @@ -0,0 +1,71 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from loki.expression import ( + Product, IntLiteral + ) +from loki.ir import CallStatement +from loki.visitors import FindNodes +from loki import Array, RangeIndex + + +__all__ = [ + 'fix_scalar_syntax' +] + + +def fix_scalar_syntax(routine): + """ + Housekeeping routine to replace scalar syntax when passing arrays as arguments + For example, a call like + + call myroutine(a(i,j)) + + where myroutine looks like + + subroutine myroutine(a) + real :: a(5) + end subroutine myroutine + + should be changed to + + call myroutine(a(i:i+5,j) + + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine where calls will be changed + """ + + #Define minus one for later + minus_one = Product((-1, IntLiteral(1))) + + calls = FindNodes(CallStatement).visit(routine.body) + + for call in calls: + + arg_map = {} + + + for dummy, arg in call.arg_map.items(): + if isinstance(arg, Array) and isinstance(dummy, Array): + if arg.dimensions: + n_dummy_ranges = sum(1 for d in arg.dimensions if isinstance(d, RangeIndex)) + if n_dummy_ranges == 0: + print(call) + print(arg, dummy) + for s in dummy.shape: + print(s, s.__class__) + print() + + + + + + + From 294b89ecf321c4f5433bc8c2d64df9684c4d7b67 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Tue, 10 Oct 2023 09:27:23 +0200 Subject: [PATCH 02/27] scalars sort of fixed --- loki/transform/transform_scalar_syntax.py | 80 ++++++++++++++++++----- 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index a2e30c99b..9392fdfbc 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -6,17 +6,45 @@ # nor does it submit to any jurisdiction. from loki.expression import ( - Product, IntLiteral + Sum, Product, IntLiteral, Scalar, Array, RangeIndex, DeferredTypeSymbol ) from loki.ir import CallStatement from loki.visitors import FindNodes -from loki import Array, RangeIndex +from loki.tools import as_tuple __all__ = [ 'fix_scalar_syntax' ] +def check_if_scalar_syntax(arg, dummy): + if isinstance(arg, Array) and isinstance(dummy, Array): + if arg.dimensions: + n_dummy_ranges = sum(1 for d in arg.dimensions if isinstance(d, RangeIndex)) + if n_dummy_ranges == 0: + return True + return False + +def construct_range_index(lower, length): + + #Define one and minus one for later + + one = IntLiteral(1) + minus_one = Product((-1, IntLiteral(1))) + + if lower == one: + new_high = length + elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral): + new_high = IntLiteral(value = length.value + lower.value - 1) + elif isinstance(lower, IntLiteral): + new_high = Sum((length,IntLiteral(value = lower.value - 1))) + elif isinstance(length, IntLiteral): + new_high = Sum((lower,IntLiteral(value = length.value - 1))) + else: + new_high = Sum((lower, length, minus_one)) + + return RangeIndex((lower, new_high)) + def fix_scalar_syntax(routine): """ @@ -42,26 +70,46 @@ def fix_scalar_syntax(routine): The subroutine where calls will be changed """ - #Define minus one for later - minus_one = Product((-1, IntLiteral(1))) - calls = FindNodes(CallStatement).visit(routine.body) for call in calls: - arg_map = {} - + new_arg_map = {} for dummy, arg in call.arg_map.items(): - if isinstance(arg, Array) and isinstance(dummy, Array): - if arg.dimensions: - n_dummy_ranges = sum(1 for d in arg.dimensions if isinstance(d, RangeIndex)) - if n_dummy_ranges == 0: - print(call) - print(arg, dummy) - for s in dummy.shape: - print(s, s.__class__) - print() + if check_if_scalar_syntax(arg, dummy): + print(routine) + print(call) + print(arg, dummy) + new_dims = [] + for s, lower in zip(dummy.shape, arg.dimensions): + + if isinstance(s, IntLiteral): + new_dims += [construct_range_index(lower, s)] + + elif isinstance(s, Scalar): + if s in call.routine.arguments: + new_dims += [construct_range_index(lower,call.arg_map[s])] + elif call.routine in routine.members and s in routine.variables: + new_dims += [construct_range_index(lower,s)] + else: + raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?') + + elif isinstance(s, DeferredTypeSymbol): + + if s.parents[0] in call.routine.arguments: + print(s, s.parents[0], s.parents[0].scope) + print(call.arg_map[s.parents[0]]) + print() + + + if len(arg.dimensions) > len(dummy.shape): + new_dims += [d for d in arg.dimensions[len(dummy.shape):]] + + new_dims = as_tuple(new_dims) + new_arg = arg.clone(dimensions=new_dims) + print('new_arg: ', new_arg) + print() From d576e991cbf57c10da2e3dd8a275342d62efc2fa Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Tue, 10 Oct 2023 09:29:18 +0200 Subject: [PATCH 03/27] commit call --- transformations/transformations/single_column_coalesced.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformations/transformations/single_column_coalesced.py b/transformations/transformations/single_column_coalesced.py index 6185a3c88..9065fc840 100644 --- a/transformations/transformations/single_column_coalesced.py +++ b/transformations/transformations/single_column_coalesced.py @@ -7,7 +7,7 @@ import re from loki.expression import symbols as sym -from loki.transform import resolve_associates, inline_member_procedures +from loki.transform import resolve_associates, inline_member_procedures, fix_scalar_syntax from loki import ( Transformation, FindNodes, FindScopes, Transformer, info, pragmas_attached, as_tuple, flatten, ir, FindExpressions, @@ -238,6 +238,8 @@ def process_kernel(self, routine): # Find the iteration index variable for the specified horizontal v_index = self.get_integer_variable(routine, name=self.horizontal.index) + fix_scalar_syntax(routine) + # Perform full source-inlining for member subroutines if so requested if self.inline_members: inline_member_procedures(routine) From 40b90d61e2c4f12ab3348e19d6bb3a6e6e85f444 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Tue, 10 Oct 2023 16:56:36 +0200 Subject: [PATCH 04/27] constructed new arguments --- loki/transform/transform_scalar_syntax.py | 92 ++++++++++++++--------- 1 file changed, 56 insertions(+), 36 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 9392fdfbc..12b86279a 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -6,7 +6,8 @@ # nor does it submit to any jurisdiction. from loki.expression import ( - Sum, Product, IntLiteral, Scalar, Array, RangeIndex, DeferredTypeSymbol + Sum, Product, IntLiteral, Scalar, Array, RangeIndex, + DeferredTypeSymbol, SubstituteExpressions ) from loki.ir import CallStatement from loki.visitors import FindNodes @@ -25,14 +26,10 @@ def check_if_scalar_syntax(arg, dummy): return True return False -def construct_range_index(lower, length): - - #Define one and minus one for later - one = IntLiteral(1) - minus_one = Product((-1, IntLiteral(1))) +def construct_range_index(lower, length): - if lower == one: + if lower == IntLiteral(1): new_high = length elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral): new_high = IntLiteral(value = length.value + lower.value - 1) @@ -41,11 +38,53 @@ def construct_range_index(lower, length): elif isinstance(length, IntLiteral): new_high = Sum((lower,IntLiteral(value = length.value - 1))) else: - new_high = Sum((lower, length, minus_one)) - + new_high = Sum((lower, length, Product((-1, IntLiteral(1))))) + return RangeIndex((lower, new_high)) +def merge_parents(parent, symbol): + + new_parent = parent.clone() + for p in symbol.parents[1:]: + new_parent = DeferredTypeSymbol(name=p.name_parts[-1], scope=parent.scope, parent=new_parent) + return symbol.clone(parent=new_parent, scope=parent.scope) + + +def process_symbol(symbol, caller, call): + + if isinstance(symbol, IntLiteral): + return symbol + + elif isinstance(symbol, Scalar): + if symbol in call.routine.arguments: + return call.arg_map[symbol] + + elif isinstance(symbol, DeferredTypeSymbol): + if symbol.parents[0] in call.routine.arguments: + return merge_parents(call.arg_map[symbol.parents[0]], symbol) + + if call.routine in caller.members and symbol in caller.variables: + return symbol + + raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?') + + +def construct_length(xrange, routine, call): + + new_start = process_symbol(xrange.start, routine, call) + new_stop = process_symbol(xrange.stop, routine, call) + + if isinstance(new_start, IntLiteral) and isinstance(new_stop, IntLiteral): + return IntLiteral(value = new_stop.value - new_start.value + 1) + elif isinstance(new_start, IntLiteral): + return Sum((new_stop, Product((-1,(IntLiteral(value = new_start.value - 1)))))) + elif isinstance(new_stop, IntLiteral): + return Sum((IntLiteral(value = new_stop.value + 1), Product((-1,new_start)))) + else: + return Sum((new_stop, Product((-1,new_start)), IntLiteral(1))) + + def fix_scalar_syntax(routine): """ Housekeeping routine to replace scalar syntax when passing arrays as arguments @@ -78,42 +117,23 @@ def fix_scalar_syntax(routine): for dummy, arg in call.arg_map.items(): if check_if_scalar_syntax(arg, dummy): - print(routine) - print(call) - print(arg, dummy) + new_dims = [] for s, lower in zip(dummy.shape, arg.dimensions): - - if isinstance(s, IntLiteral): - new_dims += [construct_range_index(lower, s)] - - elif isinstance(s, Scalar): - if s in call.routine.arguments: - new_dims += [construct_range_index(lower,call.arg_map[s])] - elif call.routine in routine.members and s in routine.variables: - new_dims += [construct_range_index(lower,s)] - else: - raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?') - - elif isinstance(s, DeferredTypeSymbol): - - if s.parents[0] in call.routine.arguments: - print(s, s.parents[0], s.parents[0].scope) - print(call.arg_map[s.parents[0]]) - print() + if isinstance(s, RangeIndex): + new_dims += [construct_range_index(lower, construct_length(s, routine, call))] + else: + new_dims += [construct_range_index(lower, process_symbol(s, routine, call))] if len(arg.dimensions) > len(dummy.shape): new_dims += [d for d in arg.dimensions[len(dummy.shape):]] new_dims = as_tuple(new_dims) new_arg = arg.clone(dimensions=new_dims) - print('new_arg: ', new_arg) - print() - - - - + print(arg, new_arg) + new_arg_map[arg] = new_arg + routine.body = SubstituteExpressions(new_arg_map).visit(routine.body) From f46453ee68329e82789551f80ef9baf85bf7f245 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Wed, 11 Oct 2023 12:38:16 +0200 Subject: [PATCH 05/27] Might work now --- loki/transform/transform_scalar_syntax.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 12b86279a..036a33ed2 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -10,7 +10,7 @@ DeferredTypeSymbol, SubstituteExpressions ) from loki.ir import CallStatement -from loki.visitors import FindNodes +from loki.visitors import FindNodes, Transformer from loki.tools import as_tuple @@ -110,10 +110,11 @@ def fix_scalar_syntax(routine): """ calls = FindNodes(CallStatement).visit(routine.body) + call_map = {} for call in calls: - new_arg_map = {} + new_args = [] for dummy, arg in call.arg_map.items(): if check_if_scalar_syntax(arg, dummy): @@ -129,11 +130,12 @@ def fix_scalar_syntax(routine): if len(arg.dimensions) > len(dummy.shape): new_dims += [d for d in arg.dimensions[len(dummy.shape):]] - new_dims = as_tuple(new_dims) - new_arg = arg.clone(dimensions=new_dims) + new_args += [arg.clone(dimensions=as_tuple(new_dims)),] - print(arg, new_arg) - new_arg_map[arg] = new_arg + else: - routine.body = SubstituteExpressions(new_arg_map).visit(routine.body) + new_args += [arg,] + call_map[call] = call.clone(arguments = as_tuple(new_args)) + + routine.body = Transformer(call_map).visit(routine.body) From 2afa4724c59194491caddc3cadfedb4c61f634f0 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Mon, 16 Oct 2023 15:49:33 +0200 Subject: [PATCH 06/27] Check that we have the called routine definition --- loki/transform/transform_scalar_syntax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 036a33ed2..65e8f1eb6 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -12,6 +12,7 @@ from loki.ir import CallStatement from loki.visitors import FindNodes, Transformer from loki.tools import as_tuple +from loki.types import BasicType __all__ = [ @@ -109,7 +110,8 @@ def fix_scalar_syntax(routine): The subroutine where calls will be changed """ - calls = FindNodes(CallStatement).visit(routine.body) + #List calls in routine, but make sure we have the called routine definition + calls = (c for c in FindNodes(CallStatement).visit(routine.body) if not c.procedure_type is BasicType.DEFERRED) call_map = {} for call in calls: From e68135baeff03f4f20b34e2703e436d6cde57a59 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Mon, 16 Oct 2023 16:59:16 +0200 Subject: [PATCH 07/27] some documentation --- loki/transform/transform_scalar_syntax.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 65e8f1eb6..497ba79c1 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -6,8 +6,8 @@ # nor does it submit to any jurisdiction. from loki.expression import ( - Sum, Product, IntLiteral, Scalar, Array, RangeIndex, - DeferredTypeSymbol, SubstituteExpressions + Sum, Product, IntLiteral, Scalar, Array, RangeIndex, + TypedSymbol, SubstituteExpressions ) from loki.ir import CallStatement from loki.visitors import FindNodes, Transformer @@ -20,10 +20,19 @@ ] def check_if_scalar_syntax(arg, dummy): + """ + Check if an array argument, arg, + is passed to an array dummy argument, dummy, + using scalar syntax. i.e. arg(1,1) -> d(m,n) + + Parameters + ---------- + arg: variable + dummy: variable + """ if isinstance(arg, Array) and isinstance(dummy, Array): if arg.dimensions: - n_dummy_ranges = sum(1 for d in arg.dimensions if isinstance(d, RangeIndex)) - if n_dummy_ranges == 0: + if not any(isinstance(d, RangeIndex) for d in arg.dimensions): return True return False @@ -48,7 +57,7 @@ def merge_parents(parent, symbol): new_parent = parent.clone() for p in symbol.parents[1:]: - new_parent = DeferredTypeSymbol(name=p.name_parts[-1], scope=parent.scope, parent=new_parent) + new_parent = TypedSymbol(name=p.name_parts[-1], scope=parent.scope, parent=new_parent) return symbol.clone(parent=new_parent, scope=parent.scope) @@ -61,7 +70,7 @@ def process_symbol(symbol, caller, call): if symbol in call.routine.arguments: return call.arg_map[symbol] - elif isinstance(symbol, DeferredTypeSymbol): + elif isinstance(symbol, TypedSymbol): if symbol.parents[0] in call.routine.arguments: return merge_parents(call.arg_map[symbol.parents[0]], symbol) From 2653b84245133476e38126710685505643d9e488 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Tue, 17 Oct 2023 13:29:54 +0200 Subject: [PATCH 08/27] simplify TypedSymbol handling --- loki/transform/transform_scalar_syntax.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 497ba79c1..db2b2c62c 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -53,14 +53,6 @@ def construct_range_index(lower, length): return RangeIndex((lower, new_high)) -def merge_parents(parent, symbol): - - new_parent = parent.clone() - for p in symbol.parents[1:]: - new_parent = TypedSymbol(name=p.name_parts[-1], scope=parent.scope, parent=new_parent) - return symbol.clone(parent=new_parent, scope=parent.scope) - - def process_symbol(symbol, caller, call): if isinstance(symbol, IntLiteral): @@ -72,7 +64,7 @@ def process_symbol(symbol, caller, call): elif isinstance(symbol, TypedSymbol): if symbol.parents[0] in call.routine.arguments: - return merge_parents(call.arg_map[symbol.parents[0]], symbol) + return SubstituteExpressions(call.arg_map).visit(symbol) if call.routine in caller.members and symbol in caller.variables: return symbol From 7560abb7c69e96d6601cf38d7c170684cdfd5d2c Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Wed, 18 Oct 2023 17:02:22 +0200 Subject: [PATCH 09/27] some cleanup --- loki/transform/transform_scalar_syntax.py | 57 ++++++++++++++++++----- 1 file changed, 46 insertions(+), 11 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index db2b2c62c..60876df58 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -7,12 +7,13 @@ from loki.expression import ( Sum, Product, IntLiteral, Scalar, Array, RangeIndex, - TypedSymbol, SubstituteExpressions + SubstituteExpressions ) from loki.ir import CallStatement from loki.visitors import FindNodes, Transformer from loki.tools import as_tuple from loki.types import BasicType +import pymbolic.primitives as pmbl __all__ = [ @@ -37,6 +38,39 @@ def check_if_scalar_syntax(arg, dummy): return False +def single_sum(expr): + if isinstance(expr, pmbl.Sum): + return expr + else: + return Sum((expr,)) + + +def sum_ints(expr): + if isinstance(expr, pmbl.Sum): + n = 0 + new_children = [] + for c in expr.children: + if isinstance(c, IntLiteral): + n += c.value + elif (isinstance(c, pmbl.Product) and + all(isinstance(cc, IntLiteral) or isinstance(cc,int) for cc in c.children)): + m = 1 + for cc in c.children: + if isinstance(cc, IntLiteral): + m = m*cc.value + else: + m = m*cc + n += m + else: + new_children += [c] + + if n != 0: + new_children += [IntLiteral(n)] + + expr.children = as_tuple(new_children) + + + def construct_range_index(lower, length): if lower == IntLiteral(1): @@ -44,11 +78,13 @@ def construct_range_index(lower, length): elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral): new_high = IntLiteral(value = length.value + lower.value - 1) elif isinstance(lower, IntLiteral): - new_high = Sum((length,IntLiteral(value = lower.value - 1))) + new_high = single_sum(length) + IntLiteral(value = lower.value - 1) elif isinstance(length, IntLiteral): - new_high = Sum((lower,IntLiteral(value = length.value - 1))) + new_high = single_sum(lower) + IntLiteral(value = length.value - 1) else: - new_high = Sum((lower, length, Product((-1, IntLiteral(1))))) + new_high = single_sum(length) + lower - IntLiteral(1) + + sum_ints(new_high) return RangeIndex((lower, new_high)) @@ -58,13 +94,12 @@ def process_symbol(symbol, caller, call): if isinstance(symbol, IntLiteral): return symbol - elif isinstance(symbol, Scalar): + elif not symbol.parents: if symbol in call.routine.arguments: return call.arg_map[symbol] - elif isinstance(symbol, TypedSymbol): - if symbol.parents[0] in call.routine.arguments: - return SubstituteExpressions(call.arg_map).visit(symbol) + elif symbol.parents[0] in call.routine.arguments: + return SubstituteExpressions(call.arg_map).visit(symbol.clone(scope=caller)) if call.routine in caller.members and symbol in caller.variables: return symbol @@ -80,11 +115,11 @@ def construct_length(xrange, routine, call): if isinstance(new_start, IntLiteral) and isinstance(new_stop, IntLiteral): return IntLiteral(value = new_stop.value - new_start.value + 1) elif isinstance(new_start, IntLiteral): - return Sum((new_stop, Product((-1,(IntLiteral(value = new_start.value - 1)))))) + return single_sum(new_stop) - IntLiteral(value = new_start.value - 1) elif isinstance(new_stop, IntLiteral): - return Sum((IntLiteral(value = new_stop.value + 1), Product((-1,new_start)))) + return single_sum(IntLiteral(value = new_stop.value + 1)) - new_start else: - return Sum((new_stop, Product((-1,new_start)), IntLiteral(1))) + return single_sum(new_stop) - new_start + IntLiteral(1) def fix_scalar_syntax(routine): From afb17af34765caffa09a1b3912aa6ddd351b7d23 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 19 Oct 2023 14:12:43 +0200 Subject: [PATCH 10/27] Simplify (?) and add docstrings --- loki/transform/transform_scalar_syntax.py | 144 ++++++++++++++++------ 1 file changed, 108 insertions(+), 36 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 60876df58..9d1627157 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -39,57 +39,126 @@ def check_if_scalar_syntax(arg, dummy): def single_sum(expr): + """ + Return a Sum object of expr if expr is not an instance of pymbolic.primitives.Sum. + Otherwise return expr + + Parameters + ---------- + expr: any pymbolic expression + """ if isinstance(expr, pmbl.Sum): return expr else: return Sum((expr,)) -def sum_ints(expr): +def product_value(expr): + """ + If expr is an instance of pymbolic.primitives.Product, try to evaluate it + If it is possible, return the value as an int. + If it is not possible, try to simplify the the product and return as a Product + If it is not a pymbolic.primitives.Product , return expr + + Note: Negative numbers and subtractions in Sums are represented as Product of + the integer -1 and the symbol. This complicates matters. + + Parameters + ---------- + expr: any pymbolic expression + """ + if isinstance(expr, pmbl.Product): + m = 1 + new_children = [] + for c in expr.children: + if isinstance(c, IntLiteral): + m = m*c.value + elif isinstance(c, int): + m = m*c + else: + new_children += [c] + + if m == 0: + return 0 + elif not new_children: + return m + else: + if m > 1: + m = IntLiteral(m) + elif m < -1: + m = Product((-1, IntLiteral(abs(m)))) + return m*Product(as_tuple(new_children)) + else: + return expr + + +def simplify_sum(expr): + """ + If expr is an instance of pymbolic.primitives.Sum, + try to simplify it by evaluating any Products and adding up ints and IntLiterals. + If the sum can be reduced to a number, it returns an IntLiteral + If the Sum reduces to one expression, it returns that expression + + Parameters + ---------- + expr: any pymbolic expression + """ + if isinstance(expr, pmbl.Sum): n = 0 new_children = [] for c in expr.children: + c = product_value(c) if isinstance(c, IntLiteral): n += c.value - elif (isinstance(c, pmbl.Product) and - all(isinstance(cc, IntLiteral) or isinstance(cc,int) for cc in c.children)): - m = 1 - for cc in c.children: - if isinstance(cc, IntLiteral): - m = m*cc.value - else: - m = m*cc - n += m + elif isinstance(c, int): + n += c else: new_children += [c] - if n != 0: - new_children += [IntLiteral(n)] + if new_children: + if n > 0: + new_children += [IntLiteral(n)] + elif n < 0: + new_children += [Product((-1,IntLiteral(abs(n))))] + + if len(new_children) > 1: + return Sum(as_tuple(new_children)) + else: + return new_children[0] - expr.children = as_tuple(new_children) - + else: + return IntLiteral(n) + else: + return expr def construct_range_index(lower, length): + """ + Construct a range index from lower to lower + length - 1 - if lower == IntLiteral(1): - new_high = length - elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral): - new_high = IntLiteral(value = length.value + lower.value - 1) - elif isinstance(lower, IntLiteral): - new_high = single_sum(length) + IntLiteral(value = lower.value - 1) - elif isinstance(length, IntLiteral): - new_high = single_sum(lower) + IntLiteral(value = length.value - 1) - else: - new_high = single_sum(length) + lower - IntLiteral(1) + Parameters + ---------- + lower : any pymbolic expression + length: any pymbolic expression + """ - sum_ints(new_high) + new_high = simplify_sum(single_sum(length) + lower - IntLiteral(1)) return RangeIndex((lower, new_high)) def process_symbol(symbol, caller, call): + """ + Map symbol in call.routine to the appropriate symbol in caller, + taking any parents into account + + Parameters + ---------- + symbol: Loki variable in call.routine + caller: Subroutine object containing call + call : Call object + """ if isinstance(symbol, IntLiteral): return symbol @@ -107,19 +176,22 @@ def process_symbol(symbol, caller, call): raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?') -def construct_length(xrange, routine, call): +def construct_length(xrange, caller, call): + """ + Construct an expression for the length of xrange, + defined in call.routine, in caller. + + Parameters + ---------- + xrange: RangeIndex object defined in call.routine + caller: Subroutine object + call : call contained in caller + """ - new_start = process_symbol(xrange.start, routine, call) - new_stop = process_symbol(xrange.stop, routine, call) + new_start = process_symbol(xrange.start, caller, call) + new_stop = process_symbol(xrange.stop, caller, call) - if isinstance(new_start, IntLiteral) and isinstance(new_stop, IntLiteral): - return IntLiteral(value = new_stop.value - new_start.value + 1) - elif isinstance(new_start, IntLiteral): - return single_sum(new_stop) - IntLiteral(value = new_start.value - 1) - elif isinstance(new_stop, IntLiteral): - return single_sum(IntLiteral(value = new_stop.value + 1)) - new_start - else: - return single_sum(new_stop) - new_start + IntLiteral(1) + return simplify_sum(single_sum(new_stop) - new_start + IntLiteral(1)) def fix_scalar_syntax(routine): From b6b66cc65fc563b88d214c68a6eefe98c4994d6c Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 19 Oct 2023 14:24:38 +0200 Subject: [PATCH 11/27] fix style --- loki/transform/transform_scalar_syntax.py | 42 ++++++++++++----------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 9d1627157..a5078a4c0 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -5,15 +5,16 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import pymbolic.primitives as pmbl + from loki.expression import ( - Sum, Product, IntLiteral, Scalar, Array, RangeIndex, + Sum, Product, IntLiteral, Array, RangeIndex, SubstituteExpressions ) from loki.ir import CallStatement from loki.visitors import FindNodes, Transformer from loki.tools import as_tuple from loki.types import BasicType -import pymbolic.primitives as pmbl __all__ = [ @@ -49,8 +50,7 @@ def single_sum(expr): """ if isinstance(expr, pmbl.Sum): return expr - else: - return Sum((expr,)) + return Sum((expr,)) def product_value(expr): @@ -80,14 +80,16 @@ def product_value(expr): if m == 0: return 0 - elif not new_children: + if not new_children: return m - else: - if m > 1: - m = IntLiteral(m) - elif m < -1: - m = Product((-1, IntLiteral(abs(m)))) - return m*Product(as_tuple(new_children)) + + if m > 1: + m = IntLiteral(m) + elif m < -1: + m = Product((-1, IntLiteral(abs(m)))) + + return m*Product(as_tuple(new_children)) + else: return expr @@ -124,8 +126,7 @@ def simplify_sum(expr): if len(new_children) > 1: return Sum(as_tuple(new_children)) - else: - return new_children[0] + return new_children[0] else: return IntLiteral(n) @@ -163,7 +164,7 @@ def process_symbol(symbol, caller, call): if isinstance(symbol, IntLiteral): return symbol - elif not symbol.parents: + if not symbol.parents: if symbol in call.routine.arguments: return call.arg_map[symbol] @@ -226,8 +227,10 @@ def fix_scalar_syntax(routine): new_args = [] + found_scalar = False for dummy, arg in call.arg_map.items(): if check_if_scalar_syntax(arg, dummy): + found_scalar = True new_dims = [] for s, lower in zip(dummy.shape, arg.dimensions): @@ -238,14 +241,13 @@ def fix_scalar_syntax(routine): new_dims += [construct_range_index(lower, process_symbol(s, routine, call))] if len(arg.dimensions) > len(dummy.shape): - new_dims += [d for d in arg.dimensions[len(dummy.shape):]] - + new_dims += arg.dimensions[len(dummy.shape):] new_args += [arg.clone(dimensions=as_tuple(new_dims)),] - else: - new_args += [arg,] - call_map[call] = call.clone(arguments = as_tuple(new_args)) + if found_scalar: + call_map[call] = call.clone(arguments = as_tuple(new_args)) - routine.body = Transformer(call_map).visit(routine.body) + if call_map: + routine.body = Transformer(call_map).visit(routine.body) From 674a6f59018be6cdc3d321490c04834552aebbac Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 19 Oct 2023 14:29:17 +0200 Subject: [PATCH 12/27] more style --- loki/transform/transform_scalar_syntax.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index a5078a4c0..f76d09e07 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -90,8 +90,7 @@ def product_value(expr): return m*Product(as_tuple(new_children)) - else: - return expr + return expr def simplify_sum(expr): @@ -130,8 +129,8 @@ def simplify_sum(expr): else: return IntLiteral(n) - else: - return expr + + return expr def construct_range_index(lower, length): From c2ec915e8c086a4a19a8e7a6c8b4efa0e125cf34 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 19 Oct 2023 14:34:08 +0200 Subject: [PATCH 13/27] even more style --- loki/transform/transform_scalar_syntax.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index f76d09e07..64b1f7522 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -126,10 +126,7 @@ def simplify_sum(expr): if len(new_children) > 1: return Sum(as_tuple(new_children)) return new_children[0] - - else: - return IntLiteral(n) - + return IntLiteral(n) return expr From 9c2f44173a93545aa545d4d3559f30b001d33819 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 19 Oct 2023 16:01:59 +0200 Subject: [PATCH 14/27] Ensure Loki versions of Sum and Product --- loki/transform/transform_scalar_syntax.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 64b1f7522..9a1931f8e 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -62,6 +62,7 @@ def product_value(expr): Note: Negative numbers and subtractions in Sums are represented as Product of the integer -1 and the symbol. This complicates matters. + Note: Ensure that a Loki Product is returned, not a pymbolic Product Parameters ---------- @@ -84,11 +85,11 @@ def product_value(expr): return m if m > 1: - m = IntLiteral(m) + new_children = [IntLiteral(m)] + new_children elif m < -1: - m = Product((-1, IntLiteral(abs(m)))) + new_children = [-1, IntLiteral(abs(m))] + new_children - return m*Product(as_tuple(new_children)) + return Product(as_tuple(new_children)) return expr @@ -100,6 +101,8 @@ def simplify_sum(expr): If the sum can be reduced to a number, it returns an IntLiteral If the Sum reduces to one expression, it returns that expression + Note: Ensure that a Loki Sum is returned, not a pymbolic Sum + Parameters ---------- expr: any pymbolic expression @@ -188,7 +191,7 @@ def construct_length(xrange, caller, call): new_start = process_symbol(xrange.start, caller, call) new_stop = process_symbol(xrange.stop, caller, call) - return simplify_sum(single_sum(new_stop) - new_start + IntLiteral(1)) + return single_sum(new_stop) - new_start + IntLiteral(1) def fix_scalar_syntax(routine): @@ -208,6 +211,10 @@ def fix_scalar_syntax(routine): call myroutine(a(i:i+5,j) + Note: Using the __add__ and __mul__ functions of Sum and Product, respectively, + returns the pymbolic.primitives version of the objuect, not the loki.expressions version. + simplify_sum and product_value returns loki versions, so this is currently not an issue, + but this can cause unexpected behaviour Parameters ---------- From b3b78ca4029a66de99709de5a77e2d3900c701fb Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 20 Oct 2023 12:10:51 +0200 Subject: [PATCH 15/27] fix another negative bug --- loki/transform/transform_scalar_syntax.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 9a1931f8e..f62200790 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -48,6 +48,7 @@ def single_sum(expr): ---------- expr: any pymbolic expression """ + if isinstance(expr, pmbl.Sum): return expr return Sum((expr,)) @@ -78,7 +79,6 @@ def product_value(expr): m = m*c else: new_children += [c] - if m == 0: return 0 if not new_children: @@ -86,6 +86,8 @@ def product_value(expr): if m > 1: new_children = [IntLiteral(m)] + new_children + elif m == -1: + new_children = [-1] + new_children elif m < -1: new_children = [-1, IntLiteral(abs(m))] + new_children From 7fac0d274ad6da91ef12080df52ca3b1b1c2f954 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 20 Oct 2023 12:11:28 +0200 Subject: [PATCH 16/27] add tests --- tests/test_transform_scalar.py | 207 +++++++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) create mode 100644 tests/test_transform_scalar.py diff --git a/tests/test_transform_scalar.py b/tests/test_transform_scalar.py new file mode 100644 index 000000000..7c1bc8b15 --- /dev/null +++ b/tests/test_transform_scalar.py @@ -0,0 +1,207 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from conftest import available_frontends +from loki.transform import fix_scalar_syntax +from loki.module import Module +from loki.ir import CallStatement +from loki.visitors import FindNodes +from loki.expression import Sum, IntLiteral, Scalar, Product, RangeIndex + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_utilities_recursive_expression_map_update(frontend): + fcode = """ +module mod_a + implicit none + + type type_b + integer :: c + integer :: d + end type type_b + + type type_a + type(type_b) :: b + end type type_a + +contains + + subroutine main() + + type(type_a) :: a + integer :: k, m, n + + real :: array(10,10) + + call sub_a(array(1, 1), k) + call sub_a(array(2, 2), k) + call sub_a(array(m, m), k) + call sub_a(array(m-1, m-1), k) + call sub_a(array(a%b%c, a%b%c), k) + + call sub_b(array(1, 1)) + call sub_b(array(2, 2)) + call sub_b(array(m, 2)) + call sub_b(array(m-1, m), k) + call sub_b(array(a%b%c, 2)) + + call sub_c(array(1, 1), k) + call sub_c(array(2, 2), k) + call sub_c(array(m, 1), k) + call sub_c(array(m-1, m), k) + call sub_c(array(a%b%c, 1), k) + + call sub_d(array(1, 1), 1, n) + call sub_d(array(2, 2), 1, n) + call sub_d(array(m, 1), k, n) + call sub_d(array(m-1, 1), k, n-1) + call sub_d(array(a%b%c, 1), 1, n) + + call sub_e(array(1, 1), a%b) + call sub_e(array(2, 2), a%b) + call sub_e(array(m, 1), a%b) + call sub_e(array(m-1, 1), a%b) + call sub_e(array(a%b%c, 1), a%b) + + call sub_x(array(1, 1), 1) + call sub_x(array(2, 2), 2) + call sub_x(array(m, 1), k) + call sub_x(array(m-1, 1), k-1) + call sub_x(array(a%b%c, 1), a%b%d) + + contains + + subroutine sub_x(array, k) + + integer, intent(in) :: k + real, intent(in) :: array(k:n) + + end subroutine sub_x + + end subroutine main + + subroutine sub_a(array, k) + + integer, intent(in) :: k + real, intent(in) :: array(k) + + end subroutine sub_a + + subroutine sub_b(array) + + real, intent(in) :: array(1:3) + + end subroutine sub_b + + subroutine sub_c(array, k) + + integer, intent(in) :: k + real, intent(in) :: array(2:k) + + end subroutine sub_c + + subroutine sub_d(array, k, n) + + integer, intent(in) :: k, n + real, intent(in) :: array(k:n) + + end subroutine sub_d + + subroutine sub_e(array, x) + + type(type_b), intent(in) :: x + real, intent(in) :: array(x%d) + + end subroutine sub_e + +end module mod_a + """.strip() + + module = Module.from_source(fcode, frontend=frontend) + routine = module['main'] + + fix_scalar_syntax(routine) + + calls = FindNodes(CallStatement).visit(routine.body) + + one = IntLiteral(1) + two = IntLiteral(2) + three = IntLiteral(3) + four = IntLiteral(4) + m_one = Product((-1,one)) + m_two = Product((-1,two)) + m_three = Product((-1,three)) + m = Scalar('m') + n = Scalar('n') + k = Scalar('k') + m_k = Product((-1,k)) + abc = Scalar(name='a%b%c', parent=Scalar(name='a%b', parent=Scalar('a'))) + abd = Scalar(name='a%b%d', parent=Scalar(name='a%b', parent=Scalar('a'))) + m_abd = Product((-1,abd)) + + #Check that second dimension is properly added + assert calls[0].arguments[0].dimensions[1] == one + assert calls[1].arguments[0].dimensions[1] == two + assert calls[2].arguments[0].dimensions[1] == m + assert calls[3].arguments[0].dimensions[1] == Sum((m,m_one)) + assert calls[4].arguments[0].dimensions[1] == abc + + #Check that start of ranges is correct + assert calls[0].arguments[0].dimensions[0].start == one + assert calls[1].arguments[0].dimensions[0].start == two + assert calls[2].arguments[0].dimensions[0].start == m + assert calls[3].arguments[0].dimensions[0].start == Sum((m,m_one)) + assert calls[4].arguments[0].dimensions[0].start == abc + + #Check that stop of ranges is correct + #sub_a + assert calls[0].arguments[0].dimensions[0].stop == k + assert calls[1].arguments[0].dimensions[0].stop == Sum((k,one)) + assert calls[2].arguments[0].dimensions[0].stop == Sum((k,m,m_one)) + assert calls[3].arguments[0].dimensions[0].stop == Sum((k,m,m_two)) + assert calls[4].arguments[0].dimensions[0].stop == Sum((k,abc,m_one)) + + #sub_b + assert calls[5].arguments[0].dimensions[0].stop == three + assert calls[6].arguments[0].dimensions[0].stop == four + assert calls[7].arguments[0].dimensions[0].stop == Sum((m,two)) + assert calls[8].arguments[0].dimensions[0].stop == Sum((m,one)) + assert calls[9].arguments[0].dimensions[0].stop == Sum((abc,two)) + + #sub_c + assert calls[10].arguments[0].dimensions[0].stop == Sum((k,m_one)) + assert calls[11].arguments[0].dimensions[0].stop == k + assert calls[12].arguments[0].dimensions[0].stop == Sum((k,m,m_two)) + assert calls[13].arguments[0].dimensions[0].stop == Sum((k,m,m_three)) + assert calls[14].arguments[0].dimensions[0].stop == Sum((k,abc,m_two)) + + #sub_d + assert calls[15].arguments[0].dimensions[0].stop == n + assert calls[16].arguments[0].dimensions[0].stop == Sum((n,one)) + assert calls[17].arguments[0].dimensions[0].stop == Sum((n,m_k,m)) + assert calls[18].arguments[0].dimensions[0].stop == Sum((n,m_k,m,m_two)) + assert calls[19].arguments[0].dimensions[0].stop == Sum((n,abc,m_one)) + + #sub_e + assert calls[20].arguments[0].dimensions[0].stop == abd + assert calls[21].arguments[0].dimensions[0].stop == Sum((abd,one)) + assert calls[22].arguments[0].dimensions[0].stop == Sum((abd,m,m_one)) + assert calls[23].arguments[0].dimensions[0].stop == Sum((abd,m,m_two)) + assert calls[24].arguments[0].dimensions[0].stop == Sum((abd,abc,m_one)) + + #sub_x + assert calls[25].arguments[0].dimensions[0].stop == n + assert calls[26].arguments[0].dimensions[0].stop == n + assert calls[27].arguments[0].dimensions[0].stop == Sum((n,m_k,m)) + assert calls[28].arguments[0].dimensions[0].stop == Sum((n,Product((-1,Sum((k, m_one)))),m,m_one)) + assert calls[29].arguments[0].dimensions[0].stop == Sum((n,m_abd,abc)) + + + + From b546733550047aa8313be4d9dda0cb69c920d237 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 20 Oct 2023 12:19:19 +0200 Subject: [PATCH 17/27] going out in style --- ...nsform_scalar.py => test_transform_scalar_notation.py} | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) rename tests/{test_transform_scalar.py => test_transform_scalar_notation.py} (97%) diff --git a/tests/test_transform_scalar.py b/tests/test_transform_scalar_notation.py similarity index 97% rename from tests/test_transform_scalar.py rename to tests/test_transform_scalar_notation.py index 7c1bc8b15..ed414eace 100644 --- a/tests/test_transform_scalar.py +++ b/tests/test_transform_scalar_notation.py @@ -12,11 +12,11 @@ from loki.module import Module from loki.ir import CallStatement from loki.visitors import FindNodes -from loki.expression import Sum, IntLiteral, Scalar, Product, RangeIndex +from loki.expression import Sum, IntLiteral, Scalar, Product @pytest.mark.parametrize('frontend', available_frontends()) -def test_transform_utilities_recursive_expression_map_update(frontend): +def test_transform_scalar_notation(frontend): fcode = """ module mod_a implicit none @@ -201,7 +201,3 @@ def test_transform_utilities_recursive_expression_map_update(frontend): assert calls[27].arguments[0].dimensions[0].stop == Sum((n,m_k,m)) assert calls[28].arguments[0].dimensions[0].stop == Sum((n,Product((-1,Sum((k, m_one)))),m,m_one)) assert calls[29].arguments[0].dimensions[0].stop == Sum((n,m_abd,abc)) - - - - From 94c89c60e82b505c97d7c8a0d8342a2abc13117f Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Mon, 30 Oct 2023 14:47:41 +0100 Subject: [PATCH 18/27] Add option for turning scalar fix on and off --- cmake/loki_transform.cmake | 13 +++++++++---- cmake/loki_transform_helpers.cmake | 4 ++++ scripts/loki_transform.py | 8 ++++++-- .../transformations/single_column_coalesced.py | 7 +++++-- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/cmake/loki_transform.cmake b/cmake/loki_transform.cmake index dc14af58e..6ca391992 100644 --- a/cmake/loki_transform.cmake +++ b/cmake/loki_transform.cmake @@ -25,6 +25,7 @@ include( loki_transform_helpers ) # [CPP] # [FRONTEND ] # [INLINE_MEMBERS] +# [FIX_SCALAR_SYNTAX] # [BUILDDIR ] # [SOURCES [ ...]] # [HEADERS [ ...]] @@ -46,7 +47,7 @@ function( loki_transform ) set( options CPP DATA_OFFLOAD REMOVE_OPENMP ASSUME_DEVICEPTR TRIM_VECTOR_SECTIONS GLOBAL_VAR_OFFLOAD - REMOVE_DERIVED_ARGS INLINE_MEMBERS DERIVE_ARGUMENT_ARRAY_SHAPE + REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SCALAR_SYNTAX DERIVE_ARGUMENT_ARRAY_SHAPE ) set( oneValueArgs COMMAND MODE DIRECTIVE FRONTEND CONFIG BUILDDIR @@ -193,7 +194,7 @@ endfunction() # [DIRECTIVE ] # [SOURCES [ ...]] # [HEADERS [ ...]] -# [NO_PLAN_SOURCEDIR COPY_UNMODIFIED INLINE_MEMBERS] +# [NO_PLAN_SOURCEDIR COPY_UNMODIFIED INLINE_MEMBERS FIX_SCALAR_SYNTAX] # ) # # Applies a Loki bulk transformation to the source files belonging to a particular @@ -222,7 +223,7 @@ endfunction() function( loki_transform_target ) - set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN INLINE_MEMBERS ) + set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN INLINE_MEMBERS FIX_SCALAR_SYNTAX ) set( single_value_args TARGET COMMAND MODE DIRECTIVE FRONTEND CONFIG PLAN ) set( multi_value_args SOURCES HEADERS ) @@ -291,6 +292,10 @@ function( loki_transform_target ) list( APPEND _TRANSFORM_OPTIONS INLINE_MEMBERS ) endif() + if( _PAR_FIX_SCALAR_SYNTAX ) + list( APPEND _TRANSFORM_OPTIONS FIX_SCALAR_SYNTAX ) + endif() + loki_transform( COMMAND ${_PAR_COMMAND} OUTPUT ${LOKI_SOURCES_TO_APPEND} @@ -384,7 +389,7 @@ or set( options CPP DATA_OFFLOAD REMOVE_OPENMP ASSUME_DEVICEPTR GLOBAL_VAR_OFFLOAD - TRIM_VECTOR_SECTIONS REMOVE_DERIVED_ARGS INLINE_MEMBERS + TRIM_VECTOR_SECTIONS REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SCALAR_SYNTAX ) set( oneValueArgs MODE DIRECTIVE FRONTEND CONFIG PATH OUTPATH diff --git a/cmake/loki_transform_helpers.cmake b/cmake/loki_transform_helpers.cmake index a3a5b344d..4dc871e36 100644 --- a/cmake/loki_transform_helpers.cmake +++ b/cmake/loki_transform_helpers.cmake @@ -112,6 +112,10 @@ macro( _loki_transform_parse_options ) list( APPEND _ARGS --inline-members ) endif() + if( _PAR_FIX_SCALAR_SYNTAX ) + list( APPEND _ARGS --fix-scalar-syntax ) + endif() + if( _PAR_DERIVE_ARGUMENT_ARRAY_SHAPE ) list( APPEND _ARGS --derive-argument-array-shape ) endif() diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index 1679a66a4..61ee9e54b 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -108,12 +108,15 @@ def cli(debug): help="Remove derived-type arguments and replace with canonical arguments") @click.option('--inline-members/--no-inline-members', default=False, help='Inline member functions for SCC-class transformations.') +@click.option('--fix-scalar-syntax/--no-fix-scalar-syntax', default=False, + help='Replace array arguments passed as scalars with arrays.') @click.option('--derive-argument-array-shape/--no-derive-argument-array-shape', default=False, help="Recursively derive explicit shape dimension for argument arrays") def convert( mode, config, build, source, header, cpp, directive, include, define, omni_include, xmod, data_offload, remove_openmp, assume_deviceptr, frontend, trim_vector_sections, - global_var_offload, remove_derived_args, inline_members, derive_argument_array_shape + global_var_offload, remove_derived_args, inline_members, fix_scalar_syntax, + derive_argument_array_shape ): """ Batch-processing mode for Fortran-to-Fortran transformations that @@ -207,7 +210,8 @@ def convert( if mode in ['scc', 'scc-hoist', 'scc-stack']: # Apply the basic SCC transformation set scheduler.process( SCCBaseTransformation( - horizontal=horizontal, directive=directive, inline_members=inline_members + horizontal=horizontal, directive=directive, + inline_members=inline_members, fix_scalar_syntax=fix_scalar_syntax )) scheduler.process( SCCDevectorTransformation( horizontal=horizontal, trim_vector_sections=trim_vector_sections diff --git a/transformations/transformations/single_column_coalesced.py b/transformations/transformations/single_column_coalesced.py index 77712e0fc..d1ac75bde 100644 --- a/transformations/transformations/single_column_coalesced.py +++ b/transformations/transformations/single_column_coalesced.py @@ -40,12 +40,13 @@ class methods can be called directly. Enable full source-inlining of member subroutines; default: False. """ - def __init__(self, horizontal, directive=None, inline_members=False): + def __init__(self, horizontal, directive=None, inline_members=False, fix_scalar_syntax=False): self.horizontal = horizontal assert directive in [None, 'openacc'] self.directive = directive + self.fix_scalar_syntax = fix_scalar_syntax self.inline_members = inline_members @classmethod @@ -291,7 +292,9 @@ def process_kernel(self, routine): # Find the iteration index variable for the specified horizontal v_index = self.get_integer_variable(routine, name=self.horizontal.index) - fix_scalar_syntax(routine) + # Transform arrays passed with scalar syntax to array syntax + if self.fix_scalar_syntax: + fix_scalar_syntax(routine) # Perform full source-inlining for member subroutines if so requested if self.inline_members: From 85e93bf834fc368843c1047fdb73d9f5c8704ff5 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Mon, 30 Oct 2023 14:56:59 +0100 Subject: [PATCH 19/27] moved option setting a bit --- transformations/transformations/single_column_coalesced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformations/transformations/single_column_coalesced.py b/transformations/transformations/single_column_coalesced.py index d1ac75bde..00246d8c1 100644 --- a/transformations/transformations/single_column_coalesced.py +++ b/transformations/transformations/single_column_coalesced.py @@ -46,8 +46,8 @@ def __init__(self, horizontal, directive=None, inline_members=False, fix_scalar_ assert directive in [None, 'openacc'] self.directive = directive - self.fix_scalar_syntax = fix_scalar_syntax self.inline_members = inline_members + self.fix_scalar_syntax = fix_scalar_syntax @classmethod def check_routine_pragmas(cls, routine, directive): From bf0590ea0e4eb3ca3ea9c29f7a406ab107863c7f Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Mon, 30 Oct 2023 15:03:25 +0100 Subject: [PATCH 20/27] fix linter complaints --- cmake/loki_transform.cmake | 14 +++++++------- cmake/loki_transform_helpers.cmake | 4 ++-- scripts/loki_transform.py | 6 +++--- .../transformations/single_column_coalesced.py | 6 +++--- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/cmake/loki_transform.cmake b/cmake/loki_transform.cmake index 6ca391992..cb979ae53 100644 --- a/cmake/loki_transform.cmake +++ b/cmake/loki_transform.cmake @@ -25,7 +25,7 @@ include( loki_transform_helpers ) # [CPP] # [FRONTEND ] # [INLINE_MEMBERS] -# [FIX_SCALAR_SYNTAX] +# [FIX_SCALARS] # [BUILDDIR ] # [SOURCES [ ...]] # [HEADERS [ ...]] @@ -47,7 +47,7 @@ function( loki_transform ) set( options CPP DATA_OFFLOAD REMOVE_OPENMP ASSUME_DEVICEPTR TRIM_VECTOR_SECTIONS GLOBAL_VAR_OFFLOAD - REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SCALAR_SYNTAX DERIVE_ARGUMENT_ARRAY_SHAPE + REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SCALARS DERIVE_ARGUMENT_ARRAY_SHAPE ) set( oneValueArgs COMMAND MODE DIRECTIVE FRONTEND CONFIG BUILDDIR @@ -194,7 +194,7 @@ endfunction() # [DIRECTIVE ] # [SOURCES [ ...]] # [HEADERS [ ...]] -# [NO_PLAN_SOURCEDIR COPY_UNMODIFIED INLINE_MEMBERS FIX_SCALAR_SYNTAX] +# [NO_PLAN_SOURCEDIR COPY_UNMODIFIED INLINE_MEMBERS FIX_SCALARS] # ) # # Applies a Loki bulk transformation to the source files belonging to a particular @@ -223,7 +223,7 @@ endfunction() function( loki_transform_target ) - set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN INLINE_MEMBERS FIX_SCALAR_SYNTAX ) + set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN INLINE_MEMBERS FIX_SCALARS ) set( single_value_args TARGET COMMAND MODE DIRECTIVE FRONTEND CONFIG PLAN ) set( multi_value_args SOURCES HEADERS ) @@ -292,8 +292,8 @@ function( loki_transform_target ) list( APPEND _TRANSFORM_OPTIONS INLINE_MEMBERS ) endif() - if( _PAR_FIX_SCALAR_SYNTAX ) - list( APPEND _TRANSFORM_OPTIONS FIX_SCALAR_SYNTAX ) + if( _PAR_FIX_SCALARS ) + list( APPEND _TRANSFORM_OPTIONS FIX_SCALARS ) endif() loki_transform( @@ -389,7 +389,7 @@ or set( options CPP DATA_OFFLOAD REMOVE_OPENMP ASSUME_DEVICEPTR GLOBAL_VAR_OFFLOAD - TRIM_VECTOR_SECTIONS REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SCALAR_SYNTAX + TRIM_VECTOR_SECTIONS REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SCALARS ) set( oneValueArgs MODE DIRECTIVE FRONTEND CONFIG PATH OUTPATH diff --git a/cmake/loki_transform_helpers.cmake b/cmake/loki_transform_helpers.cmake index 4dc871e36..9e9aa9601 100644 --- a/cmake/loki_transform_helpers.cmake +++ b/cmake/loki_transform_helpers.cmake @@ -112,8 +112,8 @@ macro( _loki_transform_parse_options ) list( APPEND _ARGS --inline-members ) endif() - if( _PAR_FIX_SCALAR_SYNTAX ) - list( APPEND _ARGS --fix-scalar-syntax ) + if( _PAR_FIX_SCALARS ) + list( APPEND _ARGS --fix-scalars ) endif() if( _PAR_DERIVE_ARGUMENT_ARRAY_SHAPE ) diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index 61ee9e54b..fabf34f0d 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -108,14 +108,14 @@ def cli(debug): help="Remove derived-type arguments and replace with canonical arguments") @click.option('--inline-members/--no-inline-members', default=False, help='Inline member functions for SCC-class transformations.') -@click.option('--fix-scalar-syntax/--no-fix-scalar-syntax', default=False, +@click.option('--fix-scalars/--no-fix-scalars', default=False, help='Replace array arguments passed as scalars with arrays.') @click.option('--derive-argument-array-shape/--no-derive-argument-array-shape', default=False, help="Recursively derive explicit shape dimension for argument arrays") def convert( mode, config, build, source, header, cpp, directive, include, define, omni_include, xmod, data_offload, remove_openmp, assume_deviceptr, frontend, trim_vector_sections, - global_var_offload, remove_derived_args, inline_members, fix_scalar_syntax, + global_var_offload, remove_derived_args, inline_members, fix_scalars, derive_argument_array_shape ): """ @@ -211,7 +211,7 @@ def convert( # Apply the basic SCC transformation set scheduler.process( SCCBaseTransformation( horizontal=horizontal, directive=directive, - inline_members=inline_members, fix_scalar_syntax=fix_scalar_syntax + inline_members=inline_members, fix_scalars=fix_scalars )) scheduler.process( SCCDevectorTransformation( horizontal=horizontal, trim_vector_sections=trim_vector_sections diff --git a/transformations/transformations/single_column_coalesced.py b/transformations/transformations/single_column_coalesced.py index 00246d8c1..2b4cf310e 100644 --- a/transformations/transformations/single_column_coalesced.py +++ b/transformations/transformations/single_column_coalesced.py @@ -40,14 +40,14 @@ class methods can be called directly. Enable full source-inlining of member subroutines; default: False. """ - def __init__(self, horizontal, directive=None, inline_members=False, fix_scalar_syntax=False): + def __init__(self, horizontal, directive=None, inline_members=False, fix_scalars=False): self.horizontal = horizontal assert directive in [None, 'openacc'] self.directive = directive self.inline_members = inline_members - self.fix_scalar_syntax = fix_scalar_syntax + self.fix_scalars = fix_scalars @classmethod def check_routine_pragmas(cls, routine, directive): @@ -293,7 +293,7 @@ def process_kernel(self, routine): v_index = self.get_integer_variable(routine, name=self.horizontal.index) # Transform arrays passed with scalar syntax to array syntax - if self.fix_scalar_syntax: + if self.fix_scalars: fix_scalar_syntax(routine) # Perform full source-inlining for member subroutines if so requested From 64a959d5318ab9a0d9244571ad5fbad38ce47f45 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 9 Nov 2023 16:57:40 +0100 Subject: [PATCH 21/27] scalar_syntax -> sequence_association --- cmake/loki_transform.cmake | 14 +++++++------- cmake/loki_transform_helpers.cmake | 4 ++-- loki/transform/__init__.py | 2 +- ...syntax.py => transform_sequence_association.py} | 6 +++--- scripts/loki_transform.py | 6 +++--- ...n.py => test_transform_sequence_association.py} | 4 ++-- .../transformations/single_column_coalesced.py | 10 +++++----- 7 files changed, 23 insertions(+), 23 deletions(-) rename loki/transform/{transform_scalar_syntax.py => transform_sequence_association.py} (97%) rename tests/{test_transform_scalar_notation.py => test_transform_sequence_association.py} (98%) diff --git a/cmake/loki_transform.cmake b/cmake/loki_transform.cmake index cb979ae53..595bd9fae 100644 --- a/cmake/loki_transform.cmake +++ b/cmake/loki_transform.cmake @@ -25,7 +25,7 @@ include( loki_transform_helpers ) # [CPP] # [FRONTEND ] # [INLINE_MEMBERS] -# [FIX_SCALARS] +# [FIX_SEQUENCE_ASSOCIATION] # [BUILDDIR ] # [SOURCES [ ...]] # [HEADERS [ ...]] @@ -47,7 +47,7 @@ function( loki_transform ) set( options CPP DATA_OFFLOAD REMOVE_OPENMP ASSUME_DEVICEPTR TRIM_VECTOR_SECTIONS GLOBAL_VAR_OFFLOAD - REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SCALARS DERIVE_ARGUMENT_ARRAY_SHAPE + REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SEQUENCE_ASSOCIATION DERIVE_ARGUMENT_ARRAY_SHAPE ) set( oneValueArgs COMMAND MODE DIRECTIVE FRONTEND CONFIG BUILDDIR @@ -194,7 +194,7 @@ endfunction() # [DIRECTIVE ] # [SOURCES [ ...]] # [HEADERS [ ...]] -# [NO_PLAN_SOURCEDIR COPY_UNMODIFIED INLINE_MEMBERS FIX_SCALARS] +# [NO_PLAN_SOURCEDIR COPY_UNMODIFIED INLINE_MEMBERS FIX_SEQUENCE_ASSOCIATION] # ) # # Applies a Loki bulk transformation to the source files belonging to a particular @@ -223,7 +223,7 @@ endfunction() function( loki_transform_target ) - set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN INLINE_MEMBERS FIX_SCALARS ) + set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN INLINE_MEMBERS FIX_SEQUENCE_ASSOCIATION ) set( single_value_args TARGET COMMAND MODE DIRECTIVE FRONTEND CONFIG PLAN ) set( multi_value_args SOURCES HEADERS ) @@ -292,8 +292,8 @@ function( loki_transform_target ) list( APPEND _TRANSFORM_OPTIONS INLINE_MEMBERS ) endif() - if( _PAR_FIX_SCALARS ) - list( APPEND _TRANSFORM_OPTIONS FIX_SCALARS ) + if( _PAR_FIX_SEQUENCE_ASSOCIATION ) + list( APPEND _TRANSFORM_OPTIONS FIX_SEQUENCE_ASSOCIATION ) endif() loki_transform( @@ -389,7 +389,7 @@ or set( options CPP DATA_OFFLOAD REMOVE_OPENMP ASSUME_DEVICEPTR GLOBAL_VAR_OFFLOAD - TRIM_VECTOR_SECTIONS REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SCALARS + TRIM_VECTOR_SECTIONS REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SEQUENCE_ASSOCIATION ) set( oneValueArgs MODE DIRECTIVE FRONTEND CONFIG PATH OUTPATH diff --git a/cmake/loki_transform_helpers.cmake b/cmake/loki_transform_helpers.cmake index 9e9aa9601..6b33267e5 100644 --- a/cmake/loki_transform_helpers.cmake +++ b/cmake/loki_transform_helpers.cmake @@ -112,8 +112,8 @@ macro( _loki_transform_parse_options ) list( APPEND _ARGS --inline-members ) endif() - if( _PAR_FIX_SCALARS ) - list( APPEND _ARGS --fix-scalars ) + if( _PAR_FIX_SEQUENCE_ASSOCIATION ) + list( APPEND _ARGS --fix-sequence-association ) endif() if( _PAR_DERIVE_ARGUMENT_ARRAY_SHAPE ) diff --git a/loki/transform/__init__.py b/loki/transform/__init__.py index 8f1dc4448..e3a51ed85 100644 --- a/loki/transform/__init__.py +++ b/loki/transform/__init__.py @@ -19,4 +19,4 @@ from loki.transform.build_system_transform import * # noqa from loki.transform.transform_hoist_variables import * # noqa from loki.transform.transform_parametrise import * # noqa -from loki.transform.transform_scalar_syntax import * # noqa +from loki.transform.transform_sequence_association import * # noqa diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_sequence_association.py similarity index 97% rename from loki/transform/transform_scalar_syntax.py rename to loki/transform/transform_sequence_association.py index f62200790..9b1454cfe 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_sequence_association.py @@ -18,7 +18,7 @@ __all__ = [ - 'fix_scalar_syntax' + 'fix_sequence_association' ] def check_if_scalar_syntax(arg, dummy): @@ -175,7 +175,7 @@ def process_symbol(symbol, caller, call): if call.routine in caller.members and symbol in caller.variables: return symbol - raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?') + raise RuntimeError('[Loki::fix_sequence_association] Unable to resolve argument dimension. Module variable?') def construct_length(xrange, caller, call): @@ -196,7 +196,7 @@ def construct_length(xrange, caller, call): return single_sum(new_stop) - new_start + IntLiteral(1) -def fix_scalar_syntax(routine): +def fix_sequence_association(routine): """ Housekeeping routine to replace scalar syntax when passing arrays as arguments For example, a call like diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index fabf34f0d..89eb302de 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -108,14 +108,14 @@ def cli(debug): help="Remove derived-type arguments and replace with canonical arguments") @click.option('--inline-members/--no-inline-members', default=False, help='Inline member functions for SCC-class transformations.') -@click.option('--fix-scalars/--no-fix-scalars', default=False, +@click.option('--fix-sequence-association/--no-fix-sequence-association', default=False, help='Replace array arguments passed as scalars with arrays.') @click.option('--derive-argument-array-shape/--no-derive-argument-array-shape', default=False, help="Recursively derive explicit shape dimension for argument arrays") def convert( mode, config, build, source, header, cpp, directive, include, define, omni_include, xmod, data_offload, remove_openmp, assume_deviceptr, frontend, trim_vector_sections, - global_var_offload, remove_derived_args, inline_members, fix_scalars, + global_var_offload, remove_derived_args, inline_members, fix_sequence_association, derive_argument_array_shape ): """ @@ -211,7 +211,7 @@ def convert( # Apply the basic SCC transformation set scheduler.process( SCCBaseTransformation( horizontal=horizontal, directive=directive, - inline_members=inline_members, fix_scalars=fix_scalars + inline_members=inline_members, fix_sequence_association=fix_sequence_association )) scheduler.process( SCCDevectorTransformation( horizontal=horizontal, trim_vector_sections=trim_vector_sections diff --git a/tests/test_transform_scalar_notation.py b/tests/test_transform_sequence_association.py similarity index 98% rename from tests/test_transform_scalar_notation.py rename to tests/test_transform_sequence_association.py index ed414eace..679a3b4df 100644 --- a/tests/test_transform_scalar_notation.py +++ b/tests/test_transform_sequence_association.py @@ -8,7 +8,7 @@ import pytest from conftest import available_frontends -from loki.transform import fix_scalar_syntax +from loki.transform import fix_sequence_association from loki.module import Module from loki.ir import CallStatement from loki.visitors import FindNodes @@ -126,7 +126,7 @@ def test_transform_scalar_notation(frontend): module = Module.from_source(fcode, frontend=frontend) routine = module['main'] - fix_scalar_syntax(routine) + fix_sequence_association(routine) calls = FindNodes(CallStatement).visit(routine.body) diff --git a/transformations/transformations/single_column_coalesced.py b/transformations/transformations/single_column_coalesced.py index 2b4cf310e..c5d1aa033 100644 --- a/transformations/transformations/single_column_coalesced.py +++ b/transformations/transformations/single_column_coalesced.py @@ -7,7 +7,7 @@ import re from loki.expression import symbols as sym -from loki.transform import resolve_associates, inline_member_procedures, fix_scalar_syntax +from loki.transform import resolve_associates, inline_member_procedures, fix_sequence_association from loki import ( Transformation, FindNodes, Transformer, info, pragmas_attached, as_tuple, flatten, ir, FindExpressions, @@ -40,14 +40,14 @@ class methods can be called directly. Enable full source-inlining of member subroutines; default: False. """ - def __init__(self, horizontal, directive=None, inline_members=False, fix_scalars=False): + def __init__(self, horizontal, directive=None, inline_members=False, fix_sequence_association=False): self.horizontal = horizontal assert directive in [None, 'openacc'] self.directive = directive self.inline_members = inline_members - self.fix_scalars = fix_scalars + self.fix_sequence_association = fix_sequence_association @classmethod def check_routine_pragmas(cls, routine, directive): @@ -293,8 +293,8 @@ def process_kernel(self, routine): v_index = self.get_integer_variable(routine, name=self.horizontal.index) # Transform arrays passed with scalar syntax to array syntax - if self.fix_scalars: - fix_scalar_syntax(routine) + if self.fix_sequence_association: + fix_sequence_association(routine) # Perform full source-inlining for member subroutines if so requested if self.inline_members: From 8f069d58a11af8c068327b5d91b9d4f6f548be14 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 9 Nov 2023 17:19:57 +0100 Subject: [PATCH 22/27] fix -> resolve --- cmake/loki_transform.cmake | 14 +++++++------- cmake/loki_transform_helpers.cmake | 4 ++-- loki/transform/transform_sequence_association.py | 6 +++--- scripts/loki_transform.py | 6 +++--- tests/test_transform_sequence_association.py | 4 ++-- .../transformations/single_column_coalesced.py | 10 +++++----- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/cmake/loki_transform.cmake b/cmake/loki_transform.cmake index 595bd9fae..2c3cab406 100644 --- a/cmake/loki_transform.cmake +++ b/cmake/loki_transform.cmake @@ -25,7 +25,7 @@ include( loki_transform_helpers ) # [CPP] # [FRONTEND ] # [INLINE_MEMBERS] -# [FIX_SEQUENCE_ASSOCIATION] +# [RESOLVE_SEQUENCE_ASSOCIATION] # [BUILDDIR ] # [SOURCES [ ...]] # [HEADERS [ ...]] @@ -47,7 +47,7 @@ function( loki_transform ) set( options CPP DATA_OFFLOAD REMOVE_OPENMP ASSUME_DEVICEPTR TRIM_VECTOR_SECTIONS GLOBAL_VAR_OFFLOAD - REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SEQUENCE_ASSOCIATION DERIVE_ARGUMENT_ARRAY_SHAPE + REMOVE_DERIVED_ARGS INLINE_MEMBERS RESOLVE_SEQUENCE_ASSOCIATION DERIVE_ARGUMENT_ARRAY_SHAPE ) set( oneValueArgs COMMAND MODE DIRECTIVE FRONTEND CONFIG BUILDDIR @@ -194,7 +194,7 @@ endfunction() # [DIRECTIVE ] # [SOURCES [ ...]] # [HEADERS [ ...]] -# [NO_PLAN_SOURCEDIR COPY_UNMODIFIED INLINE_MEMBERS FIX_SEQUENCE_ASSOCIATION] +# [NO_PLAN_SOURCEDIR COPY_UNMODIFIED INLINE_MEMBERS RESOLVE_SEQUENCE_ASSOCIATION] # ) # # Applies a Loki bulk transformation to the source files belonging to a particular @@ -223,7 +223,7 @@ endfunction() function( loki_transform_target ) - set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN INLINE_MEMBERS FIX_SEQUENCE_ASSOCIATION ) + set( options NO_PLAN_SOURCEDIR COPY_UNMODIFIED CPP CPP_PLAN INLINE_MEMBERS RESOLVE_SEQUENCE_ASSOCIATION ) set( single_value_args TARGET COMMAND MODE DIRECTIVE FRONTEND CONFIG PLAN ) set( multi_value_args SOURCES HEADERS ) @@ -292,8 +292,8 @@ function( loki_transform_target ) list( APPEND _TRANSFORM_OPTIONS INLINE_MEMBERS ) endif() - if( _PAR_FIX_SEQUENCE_ASSOCIATION ) - list( APPEND _TRANSFORM_OPTIONS FIX_SEQUENCE_ASSOCIATION ) + if( _PAR_RESOLVE_SEQUENCE_ASSOCIATION ) + list( APPEND _TRANSFORM_OPTIONS RESOLVE_SEQUENCE_ASSOCIATION ) endif() loki_transform( @@ -389,7 +389,7 @@ or set( options CPP DATA_OFFLOAD REMOVE_OPENMP ASSUME_DEVICEPTR GLOBAL_VAR_OFFLOAD - TRIM_VECTOR_SECTIONS REMOVE_DERIVED_ARGS INLINE_MEMBERS FIX_SEQUENCE_ASSOCIATION + TRIM_VECTOR_SECTIONS REMOVE_DERIVED_ARGS INLINE_MEMBERS RESOLVE_SEQUENCE_ASSOCIATION ) set( oneValueArgs MODE DIRECTIVE FRONTEND CONFIG PATH OUTPATH diff --git a/cmake/loki_transform_helpers.cmake b/cmake/loki_transform_helpers.cmake index 6b33267e5..680ae0e72 100644 --- a/cmake/loki_transform_helpers.cmake +++ b/cmake/loki_transform_helpers.cmake @@ -112,8 +112,8 @@ macro( _loki_transform_parse_options ) list( APPEND _ARGS --inline-members ) endif() - if( _PAR_FIX_SEQUENCE_ASSOCIATION ) - list( APPEND _ARGS --fix-sequence-association ) + if( _PAR_RESOLVE_SEQUENCE_ASSOCIATION ) + list( APPEND _ARGS --resolve-sequence-association ) endif() if( _PAR_DERIVE_ARGUMENT_ARRAY_SHAPE ) diff --git a/loki/transform/transform_sequence_association.py b/loki/transform/transform_sequence_association.py index 9b1454cfe..817da9afe 100644 --- a/loki/transform/transform_sequence_association.py +++ b/loki/transform/transform_sequence_association.py @@ -18,7 +18,7 @@ __all__ = [ - 'fix_sequence_association' + 'transform_sequence_association' ] def check_if_scalar_syntax(arg, dummy): @@ -175,7 +175,7 @@ def process_symbol(symbol, caller, call): if call.routine in caller.members and symbol in caller.variables: return symbol - raise RuntimeError('[Loki::fix_sequence_association] Unable to resolve argument dimension. Module variable?') + raise RuntimeError('[Loki::transform_sequence_association] Unable to resolve argument dimension. Module variable?') def construct_length(xrange, caller, call): @@ -196,7 +196,7 @@ def construct_length(xrange, caller, call): return single_sum(new_stop) - new_start + IntLiteral(1) -def fix_sequence_association(routine): +def transform_sequence_association(routine): """ Housekeeping routine to replace scalar syntax when passing arrays as arguments For example, a call like diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index 89eb302de..3e549fa66 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -108,14 +108,14 @@ def cli(debug): help="Remove derived-type arguments and replace with canonical arguments") @click.option('--inline-members/--no-inline-members', default=False, help='Inline member functions for SCC-class transformations.') -@click.option('--fix-sequence-association/--no-fix-sequence-association', default=False, +@click.option('--resolve-sequence-association/--no-resolve-sequence-association', default=False, help='Replace array arguments passed as scalars with arrays.') @click.option('--derive-argument-array-shape/--no-derive-argument-array-shape', default=False, help="Recursively derive explicit shape dimension for argument arrays") def convert( mode, config, build, source, header, cpp, directive, include, define, omni_include, xmod, data_offload, remove_openmp, assume_deviceptr, frontend, trim_vector_sections, - global_var_offload, remove_derived_args, inline_members, fix_sequence_association, + global_var_offload, remove_derived_args, inline_members, resolve_sequence_association, derive_argument_array_shape ): """ @@ -211,7 +211,7 @@ def convert( # Apply the basic SCC transformation set scheduler.process( SCCBaseTransformation( horizontal=horizontal, directive=directive, - inline_members=inline_members, fix_sequence_association=fix_sequence_association + inline_members=inline_members, resolve_sequence_association=resolve_sequence_association )) scheduler.process( SCCDevectorTransformation( horizontal=horizontal, trim_vector_sections=trim_vector_sections diff --git a/tests/test_transform_sequence_association.py b/tests/test_transform_sequence_association.py index 679a3b4df..650d63845 100644 --- a/tests/test_transform_sequence_association.py +++ b/tests/test_transform_sequence_association.py @@ -8,7 +8,7 @@ import pytest from conftest import available_frontends -from loki.transform import fix_sequence_association +from loki.transform import transform_sequence_association from loki.module import Module from loki.ir import CallStatement from loki.visitors import FindNodes @@ -126,7 +126,7 @@ def test_transform_scalar_notation(frontend): module = Module.from_source(fcode, frontend=frontend) routine = module['main'] - fix_sequence_association(routine) + transform_sequence_association(routine) calls = FindNodes(CallStatement).visit(routine.body) diff --git a/transformations/transformations/single_column_coalesced.py b/transformations/transformations/single_column_coalesced.py index c5d1aa033..5bf2f9cb1 100644 --- a/transformations/transformations/single_column_coalesced.py +++ b/transformations/transformations/single_column_coalesced.py @@ -7,7 +7,7 @@ import re from loki.expression import symbols as sym -from loki.transform import resolve_associates, inline_member_procedures, fix_sequence_association +from loki.transform import resolve_associates, inline_member_procedures, transform_sequence_association from loki import ( Transformation, FindNodes, Transformer, info, pragmas_attached, as_tuple, flatten, ir, FindExpressions, @@ -40,14 +40,14 @@ class methods can be called directly. Enable full source-inlining of member subroutines; default: False. """ - def __init__(self, horizontal, directive=None, inline_members=False, fix_sequence_association=False): + def __init__(self, horizontal, directive=None, inline_members=False, resolve_sequence_association=False): self.horizontal = horizontal assert directive in [None, 'openacc'] self.directive = directive self.inline_members = inline_members - self.fix_sequence_association = fix_sequence_association + self.resolve_sequence_association = resolve_sequence_association @classmethod def check_routine_pragmas(cls, routine, directive): @@ -293,8 +293,8 @@ def process_kernel(self, routine): v_index = self.get_integer_variable(routine, name=self.horizontal.index) # Transform arrays passed with scalar syntax to array syntax - if self.fix_sequence_association: - fix_sequence_association(routine) + if self.resolve_sequence_association: + transform_sequence_association(routine) # Perform full source-inlining for member subroutines if so requested if self.inline_members: From 95e9cbf9a1bfe6c1eb3f4adcedf3ceb5ae2afcc9 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 10 Nov 2023 10:43:03 +0100 Subject: [PATCH 23/27] inline and sequence association tests in single column transformation --- .../tests/test_single_column_coalesced.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/transformations/tests/test_single_column_coalesced.py b/transformations/tests/test_single_column_coalesced.py index 8294adcb6..8a218a031 100644 --- a/transformations/tests/test_single_column_coalesced.py +++ b/transformations/tests/test_single_column_coalesced.py @@ -1721,3 +1721,85 @@ def test_single_column_coalesced_vector_section_trim_complex(frontend, horizonta else: assert assign in loop.body assert(len(FindNodes(Assignment).visit(loop.body)) == 4) + + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('inline_members', [False, True]) +@pytest.mark.parametrize('resolve_sequence_association', [False, True]) +def test_single_column_coalesced_inline_and_sequence_association(frontend, horizontal, + inline_members, resolve_sequence_association, + capsys): + """ + Test the combinations of routine inlining and sequence association + """ + + fcode_kernel = """ + subroutine some_kernel(nlon, start, end) + implicit none + + integer, intent(in) :: nlon, start, end + real, dimension(nlon) :: work + + call contained_kernel(work(1)) + + contains + + subroutine contained_kernel(work) + implicit none + + real, dimension(nlon) :: work + integer :: jl + + do jl = start, end + work(jl) = 1. + enddo + + end subroutine contained_kernel + end subroutine some_kernel + """ + + routine = Subroutine.from_source(fcode_kernel, frontend=frontend) + + scc_transform = SCCBaseTransformation(horizontal=horizontal, + inline_members=inline_members, + resolve_sequence_association=resolve_sequence_association) + + #Not really doing anything for contained routines + if (not inline_members and not resolve_sequence_association): + scc_transform.apply(routine, role='kernel') + + assert len(routine.members) == 1 + assert not FindNodes(Loop).visit(routine.body) + + #Should fail because it can't resolve sequence association + elif (inline_members and not resolve_sequence_association): + with pytest.raises(RuntimeError) as e_info: + scc_transform.apply(routine, role='kernel') + assert(e_info.exconly() == + 'RuntimeError: [Loki::TransformInline] Unable to resolve member subroutine call') + + #Check that the call is properly modified + elif (not inline_members and resolve_sequence_association): + scc_transform.apply(routine, role='kernel') + + assert len(routine.members) == 1 + call = FindNodes(CallStatement).visit(routine.body)[0] + assert fgen(call).lower() == 'call contained_kernel(work(1:nlon))' + + #Check that the contained subroutine has been inlined + else: + scc_transform.apply(routine, role='kernel') + + assert len(routine.members) == 0 + + loop = FindNodes(Loop).visit(routine.body)[0] + assert loop.variable == 'jl' + assert loop.bounds == 'start:end' + + assign = FindNodes(Assignment).visit(loop.body)[0] + assert fgen(assign).lower() == 'work(jl) = 1.' + + + + + From 286fb8df56cd51c277f3fcb6cb897aebbc39daf5 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 10 Nov 2023 10:58:27 +0100 Subject: [PATCH 24/27] cleanup --- .../tests/test_single_column_coalesced.py | 44 ++++++++----------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/transformations/tests/test_single_column_coalesced.py b/transformations/tests/test_single_column_coalesced.py index 8a218a031..b9f1be0ae 100644 --- a/transformations/tests/test_single_column_coalesced.py +++ b/transformations/tests/test_single_column_coalesced.py @@ -1727,8 +1727,7 @@ def test_single_column_coalesced_vector_section_trim_complex(frontend, horizonta @pytest.mark.parametrize('inline_members', [False, True]) @pytest.mark.parametrize('resolve_sequence_association', [False, True]) def test_single_column_coalesced_inline_and_sequence_association(frontend, horizontal, - inline_members, resolve_sequence_association, - capsys): + inline_members, resolve_sequence_association): """ Test the combinations of routine inlining and sequence association """ @@ -1766,40 +1765,35 @@ def test_single_column_coalesced_inline_and_sequence_association(frontend, horiz #Not really doing anything for contained routines if (not inline_members and not resolve_sequence_association): - scc_transform.apply(routine, role='kernel') + scc_transform.apply(routine, role='kernel') - assert len(routine.members) == 1 - assert not FindNodes(Loop).visit(routine.body) + assert len(routine.members) == 1 + assert not FindNodes(Loop).visit(routine.body) #Should fail because it can't resolve sequence association elif (inline_members and not resolve_sequence_association): - with pytest.raises(RuntimeError) as e_info: - scc_transform.apply(routine, role='kernel') - assert(e_info.exconly() == - 'RuntimeError: [Loki::TransformInline] Unable to resolve member subroutine call') + with pytest.raises(RuntimeError) as e_info: + scc_transform.apply(routine, role='kernel') + assert(e_info.exconly() == + 'RuntimeError: [Loki::TransformInline] Unable to resolve member subroutine call') #Check that the call is properly modified elif (not inline_members and resolve_sequence_association): - scc_transform.apply(routine, role='kernel') + scc_transform.apply(routine, role='kernel') - assert len(routine.members) == 1 - call = FindNodes(CallStatement).visit(routine.body)[0] - assert fgen(call).lower() == 'call contained_kernel(work(1:nlon))' + assert len(routine.members) == 1 + call = FindNodes(CallStatement).visit(routine.body)[0] + assert fgen(call).lower() == 'call contained_kernel(work(1:nlon))' #Check that the contained subroutine has been inlined else: - scc_transform.apply(routine, role='kernel') - - assert len(routine.members) == 0 - - loop = FindNodes(Loop).visit(routine.body)[0] - assert loop.variable == 'jl' - assert loop.bounds == 'start:end' - - assign = FindNodes(Assignment).visit(loop.body)[0] - assert fgen(assign).lower() == 'work(jl) = 1.' - - + scc_transform.apply(routine, role='kernel') + assert len(routine.members) == 0 + loop = FindNodes(Loop).visit(routine.body)[0] + assert loop.variable == 'jl' + assert loop.bounds == 'start:end' + assign = FindNodes(Assignment).visit(loop.body)[0] + assert fgen(assign).lower() == 'work(jl) = 1.' From 94a0f6a4cc14aca82163e8bd5cfd4312bc0570fd Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 10 Nov 2023 13:53:56 +0100 Subject: [PATCH 25/27] greatly simplify by using callers dimensions --- .../transform_sequence_association.py | 179 +----------------- tests/test_transform_sequence_association.py | 142 +------------- 2 files changed, 13 insertions(+), 308 deletions(-) diff --git a/loki/transform/transform_sequence_association.py b/loki/transform/transform_sequence_association.py index 817da9afe..5d9427f31 100644 --- a/loki/transform/transform_sequence_association.py +++ b/loki/transform/transform_sequence_association.py @@ -5,12 +5,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import pymbolic.primitives as pmbl - -from loki.expression import ( - Sum, Product, IntLiteral, Array, RangeIndex, - SubstituteExpressions - ) +from loki.expression import Array, RangeIndex from loki.ir import CallStatement from loki.visitors import FindNodes, Transformer from loki.tools import as_tuple @@ -39,168 +34,13 @@ def check_if_scalar_syntax(arg, dummy): return False -def single_sum(expr): - """ - Return a Sum object of expr if expr is not an instance of pymbolic.primitives.Sum. - Otherwise return expr - - Parameters - ---------- - expr: any pymbolic expression - """ - - if isinstance(expr, pmbl.Sum): - return expr - return Sum((expr,)) - - -def product_value(expr): - """ - If expr is an instance of pymbolic.primitives.Product, try to evaluate it - If it is possible, return the value as an int. - If it is not possible, try to simplify the the product and return as a Product - If it is not a pymbolic.primitives.Product , return expr - - Note: Negative numbers and subtractions in Sums are represented as Product of - the integer -1 and the symbol. This complicates matters. - Note: Ensure that a Loki Product is returned, not a pymbolic Product - - Parameters - ---------- - expr: any pymbolic expression - """ - if isinstance(expr, pmbl.Product): - m = 1 - new_children = [] - for c in expr.children: - if isinstance(c, IntLiteral): - m = m*c.value - elif isinstance(c, int): - m = m*c - else: - new_children += [c] - if m == 0: - return 0 - if not new_children: - return m - - if m > 1: - new_children = [IntLiteral(m)] + new_children - elif m == -1: - new_children = [-1] + new_children - elif m < -1: - new_children = [-1, IntLiteral(abs(m))] + new_children - - return Product(as_tuple(new_children)) - - return expr - - -def simplify_sum(expr): - """ - If expr is an instance of pymbolic.primitives.Sum, - try to simplify it by evaluating any Products and adding up ints and IntLiterals. - If the sum can be reduced to a number, it returns an IntLiteral - If the Sum reduces to one expression, it returns that expression - - Note: Ensure that a Loki Sum is returned, not a pymbolic Sum - - Parameters - ---------- - expr: any pymbolic expression - """ - - if isinstance(expr, pmbl.Sum): - n = 0 - new_children = [] - for c in expr.children: - c = product_value(c) - if isinstance(c, IntLiteral): - n += c.value - elif isinstance(c, int): - n += c - else: - new_children += [c] - - if new_children: - if n > 0: - new_children += [IntLiteral(n)] - elif n < 0: - new_children += [Product((-1,IntLiteral(abs(n))))] - - if len(new_children) > 1: - return Sum(as_tuple(new_children)) - return new_children[0] - return IntLiteral(n) - return expr - - -def construct_range_index(lower, length): - """ - Construct a range index from lower to lower + length - 1 - - Parameters - ---------- - lower : any pymbolic expression - length: any pymbolic expression - """ - - new_high = simplify_sum(single_sum(length) + lower - IntLiteral(1)) - - return RangeIndex((lower, new_high)) - - -def process_symbol(symbol, caller, call): - """ - Map symbol in call.routine to the appropriate symbol in caller, - taking any parents into account - - Parameters - ---------- - symbol: Loki variable in call.routine - caller: Subroutine object containing call - call : Call object - """ - - if isinstance(symbol, IntLiteral): - return symbol - - if not symbol.parents: - if symbol in call.routine.arguments: - return call.arg_map[symbol] - - elif symbol.parents[0] in call.routine.arguments: - return SubstituteExpressions(call.arg_map).visit(symbol.clone(scope=caller)) - - if call.routine in caller.members and symbol in caller.variables: - return symbol - - raise RuntimeError('[Loki::transform_sequence_association] Unable to resolve argument dimension. Module variable?') - - -def construct_length(xrange, caller, call): - """ - Construct an expression for the length of xrange, - defined in call.routine, in caller. - - Parameters - ---------- - xrange: RangeIndex object defined in call.routine - caller: Subroutine object - call : call contained in caller - """ - - new_start = process_symbol(xrange.start, caller, call) - new_stop = process_symbol(xrange.stop, caller, call) - - return single_sum(new_stop) - new_start + IntLiteral(1) - - def transform_sequence_association(routine): """ Housekeeping routine to replace scalar syntax when passing arrays as arguments For example, a call like + real :: a(m,n) + call myroutine(a(i,j)) where myroutine looks like @@ -211,12 +51,7 @@ def transform_sequence_association(routine): should be changed to - call myroutine(a(i:i+5,j) - - Note: Using the __add__ and __mul__ functions of Sum and Product, respectively, - returns the pymbolic.primitives version of the objuect, not the loki.expressions version. - simplify_sum and product_value returns loki versions, so this is currently not an issue, - but this can cause unexpected behaviour + call myroutine(a(i:m,j) Parameters ---------- @@ -238,12 +73,12 @@ def transform_sequence_association(routine): found_scalar = True new_dims = [] - for s, lower in zip(dummy.shape, arg.dimensions): + for s, lower, d in zip(arg.shape, arg.dimensions, dummy.shape): if isinstance(s, RangeIndex): - new_dims += [construct_range_index(lower, construct_length(s, routine, call))] + new_dims += [RangeIndex((lower, s.stop))] else: - new_dims += [construct_range_index(lower, process_symbol(s, routine, call))] + new_dims += [RangeIndex((lower, s))] if len(arg.dimensions) > len(dummy.shape): new_dims += arg.dimensions[len(dummy.shape):] diff --git a/tests/test_transform_sequence_association.py b/tests/test_transform_sequence_association.py index 650d63845..66b711eee 100644 --- a/tests/test_transform_sequence_association.py +++ b/tests/test_transform_sequence_association.py @@ -7,6 +7,7 @@ import pytest +from loki import fgen from conftest import available_frontends from loki.transform import transform_sequence_association from loki.module import Module @@ -14,7 +15,6 @@ from loki.visitors import FindNodes from loki.expression import Sum, IntLiteral, Scalar, Product - @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_scalar_notation(frontend): fcode = """ @@ -39,36 +39,6 @@ def test_transform_scalar_notation(frontend): real :: array(10,10) - call sub_a(array(1, 1), k) - call sub_a(array(2, 2), k) - call sub_a(array(m, m), k) - call sub_a(array(m-1, m-1), k) - call sub_a(array(a%b%c, a%b%c), k) - - call sub_b(array(1, 1)) - call sub_b(array(2, 2)) - call sub_b(array(m, 2)) - call sub_b(array(m-1, m), k) - call sub_b(array(a%b%c, 2)) - - call sub_c(array(1, 1), k) - call sub_c(array(2, 2), k) - call sub_c(array(m, 1), k) - call sub_c(array(m-1, m), k) - call sub_c(array(a%b%c, 1), k) - - call sub_d(array(1, 1), 1, n) - call sub_d(array(2, 2), 1, n) - call sub_d(array(m, 1), k, n) - call sub_d(array(m-1, 1), k, n-1) - call sub_d(array(a%b%c, 1), 1, n) - - call sub_e(array(1, 1), a%b) - call sub_e(array(2, 2), a%b) - call sub_e(array(m, 1), a%b) - call sub_e(array(m-1, 1), a%b) - call sub_e(array(a%b%c, 1), a%b) - call sub_x(array(1, 1), 1) call sub_x(array(2, 2), 2) call sub_x(array(m, 1), k) @@ -86,40 +56,6 @@ def test_transform_scalar_notation(frontend): end subroutine main - subroutine sub_a(array, k) - - integer, intent(in) :: k - real, intent(in) :: array(k) - - end subroutine sub_a - - subroutine sub_b(array) - - real, intent(in) :: array(1:3) - - end subroutine sub_b - - subroutine sub_c(array, k) - - integer, intent(in) :: k - real, intent(in) :: array(2:k) - - end subroutine sub_c - - subroutine sub_d(array, k, n) - - integer, intent(in) :: k, n - real, intent(in) :: array(k:n) - - end subroutine sub_d - - subroutine sub_e(array, x) - - type(type_b), intent(in) :: x - real, intent(in) :: array(x%d) - - end subroutine sub_e - end module mod_a """.strip() @@ -130,74 +66,8 @@ def test_transform_scalar_notation(frontend): calls = FindNodes(CallStatement).visit(routine.body) - one = IntLiteral(1) - two = IntLiteral(2) - three = IntLiteral(3) - four = IntLiteral(4) - m_one = Product((-1,one)) - m_two = Product((-1,two)) - m_three = Product((-1,three)) - m = Scalar('m') - n = Scalar('n') - k = Scalar('k') - m_k = Product((-1,k)) - abc = Scalar(name='a%b%c', parent=Scalar(name='a%b', parent=Scalar('a'))) - abd = Scalar(name='a%b%d', parent=Scalar(name='a%b', parent=Scalar('a'))) - m_abd = Product((-1,abd)) - - #Check that second dimension is properly added - assert calls[0].arguments[0].dimensions[1] == one - assert calls[1].arguments[0].dimensions[1] == two - assert calls[2].arguments[0].dimensions[1] == m - assert calls[3].arguments[0].dimensions[1] == Sum((m,m_one)) - assert calls[4].arguments[0].dimensions[1] == abc - - #Check that start of ranges is correct - assert calls[0].arguments[0].dimensions[0].start == one - assert calls[1].arguments[0].dimensions[0].start == two - assert calls[2].arguments[0].dimensions[0].start == m - assert calls[3].arguments[0].dimensions[0].start == Sum((m,m_one)) - assert calls[4].arguments[0].dimensions[0].start == abc - - #Check that stop of ranges is correct - #sub_a - assert calls[0].arguments[0].dimensions[0].stop == k - assert calls[1].arguments[0].dimensions[0].stop == Sum((k,one)) - assert calls[2].arguments[0].dimensions[0].stop == Sum((k,m,m_one)) - assert calls[3].arguments[0].dimensions[0].stop == Sum((k,m,m_two)) - assert calls[4].arguments[0].dimensions[0].stop == Sum((k,abc,m_one)) - - #sub_b - assert calls[5].arguments[0].dimensions[0].stop == three - assert calls[6].arguments[0].dimensions[0].stop == four - assert calls[7].arguments[0].dimensions[0].stop == Sum((m,two)) - assert calls[8].arguments[0].dimensions[0].stop == Sum((m,one)) - assert calls[9].arguments[0].dimensions[0].stop == Sum((abc,two)) - - #sub_c - assert calls[10].arguments[0].dimensions[0].stop == Sum((k,m_one)) - assert calls[11].arguments[0].dimensions[0].stop == k - assert calls[12].arguments[0].dimensions[0].stop == Sum((k,m,m_two)) - assert calls[13].arguments[0].dimensions[0].stop == Sum((k,m,m_three)) - assert calls[14].arguments[0].dimensions[0].stop == Sum((k,abc,m_two)) - - #sub_d - assert calls[15].arguments[0].dimensions[0].stop == n - assert calls[16].arguments[0].dimensions[0].stop == Sum((n,one)) - assert calls[17].arguments[0].dimensions[0].stop == Sum((n,m_k,m)) - assert calls[18].arguments[0].dimensions[0].stop == Sum((n,m_k,m,m_two)) - assert calls[19].arguments[0].dimensions[0].stop == Sum((n,abc,m_one)) - - #sub_e - assert calls[20].arguments[0].dimensions[0].stop == abd - assert calls[21].arguments[0].dimensions[0].stop == Sum((abd,one)) - assert calls[22].arguments[0].dimensions[0].stop == Sum((abd,m,m_one)) - assert calls[23].arguments[0].dimensions[0].stop == Sum((abd,m,m_two)) - assert calls[24].arguments[0].dimensions[0].stop == Sum((abd,abc,m_one)) - - #sub_x - assert calls[25].arguments[0].dimensions[0].stop == n - assert calls[26].arguments[0].dimensions[0].stop == n - assert calls[27].arguments[0].dimensions[0].stop == Sum((n,m_k,m)) - assert calls[28].arguments[0].dimensions[0].stop == Sum((n,Product((-1,Sum((k, m_one)))),m,m_one)) - assert calls[29].arguments[0].dimensions[0].stop == Sum((n,m_abd,abc)) + assert fgen(calls[0]).lower() == 'call sub_x(array(1:10, 1), 1)' + assert fgen(calls[1]).lower() == 'call sub_x(array(2:10, 2), 2)' + assert fgen(calls[2]).lower() == 'call sub_x(array(m:10, 1), k)' + assert fgen(calls[3]).lower() == 'call sub_x(array(m - 1:10, 1), k - 1)' + assert fgen(calls[4]).lower() == 'call sub_x(array(a%b%c:10, 1), a%b%d)' From 80435d7228ec7662964658885d7cd6bfa40c2ab3 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 10 Nov 2023 14:12:34 +0100 Subject: [PATCH 26/27] cleanup --- loki/transform/transform_sequence_association.py | 4 +++- tests/test_transform_sequence_association.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/loki/transform/transform_sequence_association.py b/loki/transform/transform_sequence_association.py index 5d9427f31..cecc340c8 100644 --- a/loki/transform/transform_sequence_association.py +++ b/loki/transform/transform_sequence_association.py @@ -73,7 +73,9 @@ def transform_sequence_association(routine): found_scalar = True new_dims = [] - for s, lower, d in zip(arg.shape, arg.dimensions, dummy.shape): + for i in range(len(dummy.shape)): + s = arg.shape[i] + lower = arg.dimensions[i] if isinstance(s, RangeIndex): new_dims += [RangeIndex((lower, s.stop))] diff --git a/tests/test_transform_sequence_association.py b/tests/test_transform_sequence_association.py index 66b711eee..bdd2c4978 100644 --- a/tests/test_transform_sequence_association.py +++ b/tests/test_transform_sequence_association.py @@ -7,13 +7,13 @@ import pytest -from loki import fgen from conftest import available_frontends + +from loki.backend import fgen from loki.transform import transform_sequence_association from loki.module import Module from loki.ir import CallStatement from loki.visitors import FindNodes -from loki.expression import Sum, IntLiteral, Scalar, Product @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_scalar_notation(frontend): From 60c3a5b3295c68be34b50e8b4e5e95a6bd570d5e Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Fri, 10 Nov 2023 14:19:55 +0100 Subject: [PATCH 27/27] more cleanup --- loki/transform/transform_sequence_association.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/loki/transform/transform_sequence_association.py b/loki/transform/transform_sequence_association.py index cecc340c8..08f7b9d22 100644 --- a/loki/transform/transform_sequence_association.py +++ b/loki/transform/transform_sequence_association.py @@ -72,17 +72,16 @@ def transform_sequence_association(routine): if check_if_scalar_syntax(arg, dummy): found_scalar = True + n_dims = len(dummy.shape) new_dims = [] - for i in range(len(dummy.shape)): - s = arg.shape[i] - lower = arg.dimensions[i] + for s, lower in zip(arg.shape[:n_dims], arg.dimensions[:n_dims]): if isinstance(s, RangeIndex): new_dims += [RangeIndex((lower, s.stop))] else: new_dims += [RangeIndex((lower, s))] - if len(arg.dimensions) > len(dummy.shape): + if len(arg.dimensions) > n_dims: new_dims += arg.dimensions[len(dummy.shape):] new_args += [arg.clone(dimensions=as_tuple(new_dims)),] else: