Skip to content

Commit

Permalink
Retain source object for injected statement functions
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal committed Mar 26, 2023
1 parent e6ead63 commit 4983405
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
10 changes: 7 additions & 3 deletions loki/frontend/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def inject_statement_functions(routine):
def create_stmt_func(assignment):
arguments = assignment.lhs.dimensions
variable = assignment.lhs.clone(dimensions=None)
return StatementFunction(variable, arguments, assignment.rhs, variable.type)
return StatementFunction(variable, arguments, assignment.rhs, variable.type, source=assignment.source)

def create_type(stmt_func):
name = str(stmt_func.variable)
Expand Down Expand Up @@ -253,15 +253,19 @@ def create_type(stmt_func):
if variable.name.lower() in stmt_funcs:
if isinstance(variable, Array):
parameters = variable.dimensions
expr_map_spec[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters)
expr_map_spec[variable] = InlineCall(
variable.clone(dimensions=None), parameters=parameters, source=variable.source
)
elif not isinstance(variable, ProcedureSymbol):
expr_map_spec[variable] = variable.clone()
expr_map_body = {}
for variable in FindVariables().visit(routine.body):
if variable.name.lower() in stmt_funcs:
if isinstance(variable, Array):
parameters = variable.dimensions
expr_map_body[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters)
expr_map_body[variable] = InlineCall(
variable.clone(dimensions=None), parameters=parameters, source=variable.source
)
elif not isinstance(variable, ProcedureSymbol):
expr_map_body[variable] = variable.clone()

Expand Down
1 change: 1 addition & 0 deletions tests/test_sourcefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def test_sourcefile_cpp_stmt_func(here, frontend):
assert isinstance(var, ProcedureSymbol)
assert isinstance(var.type.dtype, ProcedureType)
assert var.type.dtype.procedure is decl
assert decl.source is not None

# Generate code and compile
filepath = here/f'{module.name}.f90'
Expand Down
1 change: 1 addition & 0 deletions tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1520,6 +1520,7 @@ def test_subroutine_stmt_func(here, frontend):
assert isinstance(var, ProcedureSymbol)
assert isinstance(var.type.dtype, ProcedureType)
assert var.type.dtype.procedure is stmt_func_decls[var]
assert stmt_func_decls[var].source is not None

# Make sure this produces the correct result
filepath = here/f'{routine.name}.f90'
Expand Down

0 comments on commit 4983405

Please sign in to comment.