From 6ebcd89e959090a24935fc55844cc7ca392ec6d0 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Thu, 31 Oct 2024 14:35:00 +0000 Subject: [PATCH] Expression: Add test for expression clone/re-create and fix Cast --- loki/expression/symbols.py | 9 ++++ loki/expression/tests/test_symbols.py | 75 +++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 loki/expression/tests/test_symbols.py diff --git a/loki/expression/symbols.py b/loki/expression/symbols.py index 528898807..436488c4a 100644 --- a/loki/expression/symbols.py +++ b/loki/expression/symbols.py @@ -1424,17 +1424,26 @@ class Cast(pmbl.Call): Internal representation of a data type cast. """ + init_arg_names = ('name', 'expression', 'kind') + def __init__(self, name, expression, kind=None, **kwargs): assert kind is None or isinstance(kind, pmbl.Expression) self.kind = kind super().__init__(pmbl.make_variable(name), as_tuple(expression), **kwargs) + def __getinitargs__(self): + return (self.name, self.expression, self.kind) + mapper_method = intern('map_cast') @property def name(self): return self.function.name + @property + def expression(self): + self.parameters + class Range(StrCompareMixin, pmbl.Slice): """ diff --git a/loki/expression/tests/test_symbols.py b/loki/expression/tests/test_symbols.py new file mode 100644 index 000000000..6cee63399 --- /dev/null +++ b/loki/expression/tests/test_symbols.py @@ -0,0 +1,75 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from loki.expression import symbols as sym +from loki.scope import Scope +from loki.types import BasicType, ProcedureType, SymbolAttributes + + +def test_symbol_recreation(): + """ """ + scope = Scope() + int_type = SymbolAttributes(BasicType.INTEGER, parameter=True) + real_type = SymbolAttributes(BasicType.REAL, kind='rick') + log_type = SymbolAttributes(BasicType.LOGICAL) + deferred_type = SymbolAttributes(BasicType.DEFERRED) + proc_type = SymbolAttributes( + ProcedureType(name='f', is_function=True, return_type=real_type) + ) + + i = sym.Scalar(name='i', type=int_type, scope=scope) + a = sym.Array(name='a', type=real_type, scope=scope) + b = sym.Variable( + name='b', dimensions=(i,), type=int_type, scope=scope + ) + t = sym.Scalar(name='t', type=log_type, scope=scope) + f = sym.ProcedureSymbol(name='f', type=proc_type, scope=scope) + + # Basic variables and symbols + exprs = [i, a, b, t, f] + + # Literals + exprs.append( sym.FloatLiteral(66.6) ) + exprs.append( sym.IntLiteral(42) ) + exprs.append( sym.LogicLiteral(True) ) + exprs.append( sym.StringLiteral('Dave') ) + + # Operations + exprs.append( sym.Sum((b, a)) ) # b(i) + a + exprs.append( sym.Product((b, a)) ) # b(i) +* a + exprs.append( sym.Sum((b, sym.Product((-1, a))))) # b(i) - a + exprs.append( sym.Quotient(numerator=b, denominator=a) ) # b(i) / a + + exprs.append( sym.Comparison(b, '==', a) ) # b(i) == a + exprs.append( sym.LogicalNot(t) ) + exprs.append( sym.LogicalAnd((t, sym.LogicalNot(t))) ) + exprs.append( sym.LogicalOr((t, sym.LogicalNot(t))) ) + + # Slightly special symbol types + exprs.append( sym.InlineCall(function=f, parameters=(a, b)) ) + exprs.append( sym.Range((sym.IntLiteral(1), i)) ) + exprs.append( sym.LoopRange((sym.IntLiteral(1), i)) ) + exprs.append( sym.RangeIndex((sym.IntLiteral(1), i)) ) + + exprs.append( sym.Cast(name='int', expression=b, kind=i) ) + exprs.append( sym.Reference(expression=b) ) + exprs.append( sym.Dereference(expression=b) ) + + for expr in exprs: + # Check that Pymbolic-style re-generation works for all + # TODO: Should we introduce a Mixin "Cloneable" to makes these sane? + cargs = dict(zip(expr.init_arg_names, expr.__getinitargs__())) + clone = type(expr)(**cargs) + assert clone == expr + + if isinstance(expr, sym.TypedSymbol): + # Check that TypedSymbols replicate scope via .clone() + scoped_clone = expr.clone() + assert scoped_clone == expr + assert scoped_clone.scope == expr.scope