Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deep-cloning of subroutiens and modules (fix #174) #175

Merged
merged 2 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,13 @@ def clone(self, **kwargs):
kwargs.setdefault('incomplete', self._incomplete)

# Rebuild IRs
rebuild = Transformer({}, rebuild_scopes=True)
if 'docstring' in kwargs:
kwargs['docstring'] = Transformer({}).visit(kwargs['docstring'])
kwargs['docstring'] = rebuild.visit(kwargs['docstring'])
if 'spec' in kwargs:
kwargs['spec'] = Transformer({}).visit(kwargs['spec'])
kwargs['spec'] = rebuild.visit(kwargs['spec'])
if 'contains' in kwargs:
kwargs['contains'] = Transformer({}).visit(kwargs['contains'])
kwargs['contains'] = rebuild.visit(kwargs['contains'])

# Rescope symbols if not explicitly disabled
kwargs.setdefault('rescope_symbols', True)
Expand Down
2 changes: 1 addition & 1 deletion loki/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def clone(self, **kwargs):

# Rebuild body (other IR components are taken care of in super class)
if 'body' in kwargs:
kwargs['body'] = Transformer({}).visit(kwargs['body'])
kwargs['body'] = Transformer({}, rebuild_scopes=True).visit(kwargs['body'])

# Escalate to parent class
return super().clone(**kwargs)
Expand Down
51 changes: 50 additions & 1 deletion tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
OFP, OMNI, Module, Subroutine, VariableDeclaration, TypeDef, fexprgen,
BasicType, Assignment, FindNodes, FindInlineCalls, FindTypedSymbols,
Transformer, fgen, SymbolAttributes, Variable, Import, Section, Intrinsic,
Scalar, DeferredTypeSymbol
Scalar, DeferredTypeSymbol, FindVariables, SubstituteExpressions, Literal
)


Expand Down Expand Up @@ -529,6 +529,55 @@ def test_module_rescope_clone(frontend):
with pytest.raises(AttributeError):
fgen(other_module_copy)

@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'Parsing fails without dummy module provided')]
))
def test_module_deep_clone(frontend):
"""
Test the rescoping of variables in clone with nested scopes.
"""
fcode = """
module test_module_rescope_clone
use parkind1, only : jpim, jprb
implicit none

integer :: n

real :: array(n)

type my_type
real :: vector(n)
real :: matrix(n, n)
end type

end module test_module_rescope_clone
"""
module = Module.from_source(fcode, frontend=frontend)

# Deep-copy/clone the module
new_module = module.clone()

n = [v for v in FindVariables().visit(new_module.spec) if v.name == 'n'][0]
n_decl = FindNodes(VariableDeclaration).visit(new_module.spec)[0]

# Remove the declaration of `n` and replace it with `3`
new_module.spec = Transformer({n_decl: None}).visit(new_module.spec)
new_module.spec = SubstituteExpressions({n: Literal(3)}).visit(new_module.spec)

# Check the new module has been changed
assert len(FindNodes(VariableDeclaration).visit(new_module.spec)) == 1
new_type_decls = FindNodes(VariableDeclaration).visit(new_module['my_type'].body)
assert len(new_type_decls) == 2
assert new_type_decls[0].symbols[0] == 'vector(3)'
assert new_type_decls[1].symbols[0] == 'matrix(3, 3)'

# Check the old one has not changed
assert len(FindNodes(VariableDeclaration).visit(module.spec)) == 2
type_decls = FindNodes(VariableDeclaration).visit(module['my_type'].body)
assert len(type_decls) == 2
assert type_decls[0].symbols[0] == 'vector(n)'
assert type_decls[1].symbols[0] == 'matrix(n, n)'


@pytest.mark.parametrize('frontend', available_frontends())
def test_module_access_spec_none(frontend):
Expand Down
45 changes: 45 additions & 0 deletions tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2054,3 +2054,48 @@ def test_enrich_calls_explicit_interface(frontend):
# confirm that rescoping symbols has no effect
driver.rescope_symbols()
assert calls[0].routine is kernel


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'OMNI cannot handle external type defs without source')]
))
def test_subroutine_deep_clone(frontend):
"""
Test that deep-cloning a subroutine actually ensures clean scope separation.
"""

fcode = """
subroutine myroutine(something)
use parkind1, only : jpim, jprb
implicit none

type(that_thing), intent(inout) :: something
real(kind=jprb) :: foo(something%n)

foo(:)=0.0_jprb

associate(thing=>something%else)
if (something%entirely%different) then
foo(:)=42.0_jprb
else
foo(:)=66.6_jprb
end if
end associate
end subroutine myroutine
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

# Create a deep-copy of the routine
new_routine = routine.clone()

# Replace all assignments with dummy calls
map_nodes={}
for assign in FindNodes(Assignment).visit(new_routine.body):
map_nodes[assign] = CallStatement(
name=DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine
)
new_routine.body = Transformer(map_nodes).visit(new_routine.body)

# Ensure that the original copy of the routine remains unaffected
assert len(FindNodes(Assignment).visit(routine.body)) == 3
assert len(FindNodes(Assignment).visit(new_routine.body)) == 0