diff --git a/loki/expression/symbols.py b/loki/expression/symbols.py index 528898807..1025ef914 100644 --- a/loki/expression/symbols.py +++ b/loki/expression/symbols.py @@ -1236,6 +1236,8 @@ class LiteralList(pmbl.AlgebraicLeaf): A list of constant literals, e.g., as used in Array Initialization Lists. """ + init_arg_names = ('values', 'dtype') + def __init__(self, values, dtype=None, **kwargs): self.elements = values self.dtype = dtype @@ -1424,17 +1426,26 @@ class Cast(pmbl.Call): Internal representation of a data type cast. """ + init_arg_names = ('name', 'expression', 'kind') + def __init__(self, name, expression, kind=None, **kwargs): assert kind is None or isinstance(kind, pmbl.Expression) self.kind = kind super().__init__(pmbl.make_variable(name), as_tuple(expression), **kwargs) + def __getinitargs__(self): + return (self.name, self.expression, self.kind) + mapper_method = intern('map_cast') @property def name(self): return self.function.name + @property + def expression(self): + self.parameters + class Range(StrCompareMixin, pmbl.Slice): """ diff --git a/loki/expression/tests/test_symbols.py b/loki/expression/tests/test_symbols.py new file mode 100644 index 000000000..421d8f18a --- /dev/null +++ b/loki/expression/tests/test_symbols.py @@ -0,0 +1,78 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from loki.expression import symbols as sym +from loki.scope import Scope +from loki.types import BasicType, ProcedureType, SymbolAttributes + + +def test_symbol_recreation(): + """ """ + scope = Scope() + int_type = SymbolAttributes(BasicType.INTEGER, parameter=True) + real_type = SymbolAttributes(BasicType.REAL, kind='rick') + log_type = SymbolAttributes(BasicType.LOGICAL) + deferred_type = SymbolAttributes(BasicType.DEFERRED) + proc_type = SymbolAttributes( + ProcedureType(name='f', is_function=True, return_type=real_type) + ) + + i = sym.Scalar(name='i', type=int_type, scope=scope) + a = sym.Array(name='a', type=real_type, scope=scope) + b = sym.Variable( + name='b', dimensions=(i,), type=int_type, scope=scope + ) + t = sym.Scalar(name='t', type=log_type, scope=scope) + f = sym.ProcedureSymbol(name='f', type=proc_type, scope=scope) + + # Basic variables and symbols + exprs = [i, a, b, t, f] + + # Literals + exprs.append( sym.FloatLiteral(66.6) ) + exprs.append( sym.IntLiteral(42) ) + exprs.append( sym.LogicLiteral(True) ) + exprs.append( sym.StringLiteral('Dave') ) + exprs.append( sym.LiteralList( + values=(sym.Literal(1), sym.IntLiteral(2)), dtype=int_type + ) ) + + # Operations + exprs.append( sym.Sum((b, a)) ) # b(i) + a + exprs.append( sym.Product((b, a)) ) # b(i) +* a + exprs.append( sym.Sum((b, sym.Product((-1, a))))) # b(i) - a + exprs.append( sym.Quotient(numerator=b, denominator=a) ) # b(i) / a + + exprs.append( sym.Comparison(b, '==', a) ) # b(i) == a + exprs.append( sym.LogicalNot(t) ) + exprs.append( sym.LogicalAnd((t, sym.LogicalNot(t))) ) + exprs.append( sym.LogicalOr((t, sym.LogicalNot(t))) ) + + # Slightly special symbol types + exprs.append( sym.InlineCall(function=f, parameters=(a, b)) ) + exprs.append( sym.Range((sym.IntLiteral(1), i)) ) + exprs.append( sym.LoopRange((sym.IntLiteral(1), i)) ) + exprs.append( sym.RangeIndex((sym.IntLiteral(1), i)) ) + + exprs.append( sym.Cast(name='int', expression=b, kind=i) ) + exprs.append( sym.Reference(expression=b) ) + exprs.append( sym.Dereference(expression=b) ) + + for expr in exprs: + # Check that Pymbolic-style re-generation works for all + # TODO: Should we introduce a Mixin "Cloneable" to makes these sane? + cargs = dict(zip(expr.init_arg_names, expr.__getinitargs__())) + clone = type(expr)(**cargs) + assert clone == expr + + if isinstance(expr, sym.TypedSymbol): + # Check that TypedSymbols replicate scope via .clone() + scoped_clone = expr.clone() + assert scoped_clone == expr + assert scoped_clone.scope == expr.scope