Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional fixes for IR #55

Merged
merged 10 commits into from
May 4, 2023
Prev Previous commit
Next Next commit
Do not invalidate source in statement function injection
reuterbal committed May 4, 2023
commit 16cd4424eeb871fba1442513040fc1d4c17439c4
8 changes: 4 additions & 4 deletions loki/frontend/util.py
Original file line number Diff line number Diff line change
@@ -271,15 +271,15 @@ def create_type(stmt_func):

# Apply transformer with the built maps
if spec_map:
routine.spec = Transformer(spec_map).visit(routine.spec)
routine.spec = Transformer(spec_map, invalidate_source=False).visit(routine.spec)
if body_map:
routine.body = Transformer(body_map).visit(routine.body)
routine.body = Transformer(body_map, invalidate_source=False).visit(routine.body)
if spec_appendix:
routine.spec.append(spec_appendix)
if expr_map_spec:
routine.spec = SubstituteExpressions(expr_map_spec).visit(routine.spec)
routine.spec = SubstituteExpressions(expr_map_spec, invalidate_source=False).visit(routine.spec)
if expr_map_body:
routine.body = SubstituteExpressions(expr_map_body).visit(routine.body)
routine.body = SubstituteExpressions(expr_map_body, invalidate_source=False).visit(routine.body)

# And make sure all symbols have the right type
routine.rescope_symbols()
7 changes: 6 additions & 1 deletion tests/test_subroutine.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,8 @@
Section, CallStatement, BasicType, Array, Scalar, Variable,
SymbolAttributes, StringLiteral, fgen, fexprgen, VariableDeclaration,
Transformer, FindTypedSymbols, ProcedureSymbol, ProcedureType,
StatementFunction, normalize_range_indexing, DeferredTypeSymbol
StatementFunction, normalize_range_indexing, DeferredTypeSymbol,
Assignment
)


@@ -1504,6 +1505,10 @@ def test_subroutine_stmt_func(here, frontend):
routine = Subroutine.from_source(fcode, frontend=frontend)
routine.name += f'_{frontend!s}'

# Make sure the statement function injection doesn't invalidate source
for assignment in FindNodes(Assignment).visit(routine.body):
assert assignment.source is not None

# OMNI inlines statement functions, so we can only check correct representation
# for fparser
if frontend != OMNI: