Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expression: Expression cloning and mapper tests #419

Merged
merged 12 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 23 additions & 44 deletions loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Mappers for traversing and transforming the
:ref:`internal_representation:Expression tree`.
"""
from copy import deepcopy

import re
from itertools import zip_longest
import pymbolic.primitives as pmbl
Expand Down Expand Up @@ -515,37 +515,23 @@ class LokiIdentityMapper(IdentityMapper):
This can serve as basis for any transformation mappers
that apply changes to the expression tree. Expression nodes that
are unchanged are returned as is.

Parameters
----------
invalidate_source : bool, optional
By default the :attr:`source` property of nodes is discarded
when rebuilding the node, setting this to `False` allows to
retain that information
"""

def __init__(self, invalidate_source=True):
super().__init__()
self.invalidate_source = invalidate_source
@staticmethod
def _rebuild(expr):
""" Utility to safely rebuild any symbol """
if hasattr(expr, 'clone'):
return expr.clone()

# Re-create symbol Pymbolic-style
cargs = dict(zip(expr.init_arg_names, expr.__getinitargs__()))
return type(expr)(**cargs)

def __call__(self, expr, *args, **kwargs):
if expr is None:
return None
kwargs.setdefault('recurse_to_declaration_attributes', False)
new_expr = super().__call__(expr, *args, **kwargs)
if getattr(expr, 'source', None):
if isinstance(new_expr, tuple):
for e in new_expr:
if self.invalidate_source:
e.source = None
else:
e.source = deepcopy(expr.source)
else:
if self.invalidate_source:
new_expr.source = None
else:
new_expr.source = deepcopy(expr.source)
return new_expr
return super().__call__(expr, *args, **kwargs)

rec = __call__

Expand All @@ -557,7 +543,7 @@ def __call__(self, expr, *args, **kwargs):
def map_int_literal(self, expr, *args, **kwargs):
kind = self.rec(expr.kind, *args, **kwargs)
if kind is expr.kind:
return expr
return self._rebuild(expr)
return expr.__class__(expr.value, kind=kind)

map_float_literal = map_int_literal
Expand Down Expand Up @@ -615,11 +601,11 @@ def map_variable_symbol(self, expr, *args, **kwargs):
parent = self.rec(expr.parent, *args, **kwargs)
if expr.scope is None:
if parent is expr.parent and not is_type_changed:
return expr
return self._rebuild(expr)
return expr.clone(parent=parent, type=new_type)

if parent is expr.parent:
return expr
return self._rebuild(expr)
return expr.clone(parent=parent)

map_deferred_type_symbol = map_variable_symbol
Expand All @@ -631,7 +617,7 @@ def map_meta_symbol(self, expr, *args, **kwargs):
# but with no rebuilt it may return VariableSymbol. Therefore we need to return the
# original expression if the underlying symbol is unchanged
if symbol is expr._symbol:
return expr
return self._rebuild(expr)
return symbol

map_scalar = map_meta_symbol
Expand Down Expand Up @@ -659,7 +645,7 @@ def map_array(self, expr, *args, **kwargs):
if (getattr(symbol, 'symbol', symbol) is expr.symbol and
all(d is orig_d for d, orig_d in zip_longest(dimensions or (), expr.dimensions or ())) and
all(d is orig_d for d, orig_d in zip_longest(shape or (), symbol.type.shape or ()))):
return expr
return self._rebuild(expr)
return symbol.clone(dimensions=dimensions, type=symbol.type.clone(shape=shape), parent=parent)

def map_array_subscript(self, expr, *args, **kwargs):
Expand All @@ -678,14 +664,14 @@ def map_cast(self, expr, *args, **kwargs):
kind = self.rec(expr.kind, *args, **kwargs)
if (function is expr.function and kind is expr.kind and
all(p is orig_p for p, orig_p in zip_longest(parameters, expr.parameters))):
return expr
return self._rebuild(expr)
return expr.__class__(function, parameters, kind=kind)

def map_sum(self, expr, *args, **kwargs):
# Need to re-implement to avoid application of flattened_sum/flattened_product
children = self.rec(expr.children, *args, **kwargs)
if all(c is orig_c for c, orig_c in zip_longest(children, expr.children)):
return expr
return self._rebuild(expr)
return expr.__class__(children)

def map_quotient(self, expr, *args, **kwargs):
Expand All @@ -707,7 +693,7 @@ def map_literal_list(self, expr, *args, **kwargs):
values = tuple(v if isinstance(v, str) else self.rec(v, *args, **kwargs)
for v in expr.elements)
if all(v is orig_v for v, orig_v in zip_longest(values, expr.elements)):
return expr
return self._rebuild(expr)
return expr.__class__(values, dtype=expr.dtype)

def map_inline_do(self, expr, *args, **kwargs):
Expand Down Expand Up @@ -750,15 +736,11 @@ class SubstituteExpressionsMapper(LokiIdentityMapper):
----------
expr_map : dict
Expression mapping to apply to the expression tree.
invalidate_source : bool, optional
By default the :attr:`source` property of nodes is discarded
when rebuilding the node, setting this to `False` allows to
retain that information
"""
# pylint: disable=abstract-method

def __init__(self, expr_map, invalidate_source=True):
super().__init__(invalidate_source=invalidate_source)
def __init__(self, expr_map):
super().__init__()

self.expr_map = expr_map
for expr in self.expr_map.keys():
Expand All @@ -770,7 +752,7 @@ def map_from_expr_map(self, expr, *args, **kwargs):
otherwise continue tree traversal
"""
if expr in self.expr_map:
return self.expr_map[expr]
return self._rebuild(self.expr_map[expr])
map_fn = getattr(super(), expr.mapper_method)
return map_fn(expr, *args, **kwargs)

Expand All @@ -789,7 +771,7 @@ class AttachScopesMapper(LokiIdentityMapper):
"""

def __init__(self, fail=False):
super().__init__(invalidate_source=False)
super().__init__()
self.fail = fail

def _update_symbol_scope(self, expr, scope):
Expand Down Expand Up @@ -847,9 +829,6 @@ class DetachScopesMapper(LokiIdentityMapper):
analysis passes.
"""

def __init__(self):
super().__init__(invalidate_source=False)

def map_variable_symbol(self, expr, *args, **kwargs):
new_expr = super().map_variable_symbol(expr, *args, **kwargs)
new_expr = new_expr.clone(scope=None)
Expand Down
4 changes: 2 additions & 2 deletions loki/expression/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,8 +518,8 @@ class SimplifyMapper(LokiIdentityMapper):
"""
# pylint: disable=abstract-method

def __init__(self, enabled_simplifications=Simplification.ALL, invalidate_source=True):
super().__init__(invalidate_source=invalidate_source)
def __init__(self, enabled_simplifications=Simplification.ALL):
super().__init__()

self.enabled_simplifications = enabled_simplifications

Expand Down
11 changes: 11 additions & 0 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,8 @@ class LiteralList(pmbl.AlgebraicLeaf):
A list of constant literals, e.g., as used in Array Initialization Lists.
"""

init_arg_names = ('values', 'dtype')

def __init__(self, values, dtype=None, **kwargs):
self.elements = values
self.dtype = dtype
Expand Down Expand Up @@ -1424,17 +1426,26 @@ class Cast(pmbl.Call):
Internal representation of a data type cast.
"""

init_arg_names = ('name', 'expression', 'kind')

def __init__(self, name, expression, kind=None, **kwargs):
assert kind is None or isinstance(kind, pmbl.Expression)
self.kind = kind
super().__init__(pmbl.make_variable(name), as_tuple(expression), **kwargs)

def __getinitargs__(self):
return (self.name, self.expression, self.kind)

mapper_method = intern('map_cast')

@property
def name(self):
return self.function.name

@property
def expression(self):
return self.parameters


class Range(StrCompareMixin, pmbl.Slice):
"""
Expand Down
1 change: 0 additions & 1 deletion loki/expression/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import numpy as np

import pymbolic.primitives as pmbl
import pymbolic.mapper as pmbl_mapper

from loki import (
Sourcefile, Subroutine, Module, Scope, BasicType,
Expand Down
127 changes: 127 additions & 0 deletions loki/expression/tests/test_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from loki import Subroutine, Scope
from loki.expression import symbols as sym, parse_expr
from loki.expression.mappers import (
ExpressionRetriever, LokiIdentityMapper, SubstituteExpressionsMapper
)
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes


@pytest.mark.parametrize('frontend', available_frontends())
def test_expression_retriever(frontend):
""" Test for :any:`ExpressionRetriever` (a :any:`LokiWalkMapper`) """

fcode = """
subroutine test_expr_retriever(n, a, b, c)
integer, intent(inout) :: n, a, b(n), c

a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
routine = Subroutine.from_source(fcode, frontend=frontend)
expr = FindNodes(ir.Assignment).visit(routine.body)[0].rhs

def q_symbol(n):
return isinstance(n, sym.TypedSymbol)

def q_array(n):
return isinstance(n, sym.Array)

def q_scalar(n):
return isinstance(n, sym.Scalar)

def q_deferred(n):
return isinstance(n, sym.DeferredTypeSymbol)

def q_literal(n):
return isinstance(n, sym.IntLiteral)

assert ExpressionRetriever(q_symbol).retrieve(expr) == ['a', 'b', 'c', 'a']
assert ExpressionRetriever(q_array).retrieve(expr) == ['b(c)']
assert ExpressionRetriever(q_scalar).retrieve(expr) == ['a', 'c', 'a']
assert ExpressionRetriever(q_literal).retrieve(expr) == [5, 4]

scope = Scope()
expr = parse_expr('5 * a + 4 * b(c) + a', scope=scope)

assert ExpressionRetriever(q_symbol).retrieve(expr) == ['a', 'b', 'c', 'a']
assert ExpressionRetriever(q_array).retrieve(expr) == ['b(c)']
# Cannot determine Scalar without declarations, so check for deferred
assert ExpressionRetriever(q_deferred).retrieve(expr) == ['a', 'c', 'a']
assert ExpressionRetriever(q_literal).retrieve(expr) == [5, 4]


@pytest.mark.parametrize('frontend', available_frontends())
def test_identity_mapper(frontend):
"""
Test for :any:`LokiIdentityMapper`, in particular deep-copying
expression nodes.
"""

fcode = """
subroutine test_expr_retriever(n, a, b, c)
integer, intent(inout) :: n, a, b(n), c

a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
routine = Subroutine.from_source(fcode, frontend=frontend)
expr = FindNodes(ir.Assignment).visit(routine.body)[0].rhs

# Run the identity mapper over the expression
new_expr = LokiIdentityMapper()(expr)

# Check that symbols and literals are equivalent, but distinct objects!
get_symbols = ExpressionRetriever(lambda e: isinstance(e, sym.TypedSymbol)).retrieve
get_literals = ExpressionRetriever(lambda e: isinstance(e, sym.IntLiteral)).retrieve

for old, new in zip(get_symbols(expr), get_symbols(new_expr)):
assert old == new
assert not old is new

for old, new in zip(get_literals(expr), get_literals(new_expr)):
assert old == new
assert not old is new


@pytest.mark.parametrize('frontend', available_frontends())
def test_substitute_expression_mapper(frontend):
"""
Test for :any:`SubstituteExpressionsMapper`.
"""

fcode = """
subroutine test_expr_retriever(n, a, b, c, d)
integer, intent(inout) :: n, a, b(n), c, d

a = 5 * a + 4 * b(c) + a
end subroutine test_expr_retriever
"""
routine = Subroutine.from_source(fcode, frontend=frontend)
expr = FindNodes(ir.Assignment).visit(routine.body)[0].rhs

retriever = ExpressionRetriever(lambda e: isinstance(e, sym.TypedSymbol))
symbols = retriever.retrieve(expr)
assert symbols == ['a', 'b', 'c', 'a']
assert symbols[0] == symbols[3]
assert not symbols[0] is symbols[3]
a = symbols[0]
d = routine.variable_map['d']

new_expr = SubstituteExpressionsMapper(expr_map={a: d})(expr)

assert new_expr == '5*d + 4*b(c) + d'
new_symbols = retriever.retrieve(new_expr)
assert new_symbols == ['d', 'b', 'c', 'd']
assert new_symbols[0] == new_symbols[3]
# Ensure multiple inserted symbols are still unique
assert not new_symbols[0] is new_symbols[3]
2 changes: 1 addition & 1 deletion loki/expression/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from loki import Subroutine, Module, Scope
from loki.expression import symbols as sym, parse_expr
from loki.frontend import (
available_frontends, OMNI, HAVE_FP, parse_fparser_expression
available_frontends, HAVE_FP, parse_fparser_expression
)
from loki.ir import FindVariables

Expand Down
Loading
Loading