Skip to content

Commit

Permalink
Loki: Sanitise import in module test
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Nov 24, 2024
1 parent c05b3eb commit 83222f4
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions loki/tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@

import pytest

from loki import (
Module, Subroutine, VariableDeclaration, TypeDef, fexprgen,
Assignment, FindNodes, FindInlineCalls, FindTypedSymbols,
Transformer, fgen, SymbolAttributes, Variable, Import, Section, Intrinsic,
Scalar, DeferredTypeSymbol, FindVariables, SubstituteExpressions, Literal
)
from loki import Module, Subroutine, fexprgen, fgen
from loki.build import jit_compile, clean_test
from loki.expression import symbols as sym
from loki.frontend import available_frontends, OMNI
from loki.ir import (
nodes as ir, FindNodes, FindInlineCalls, FindTypedSymbols,
FindVariables, SubstituteExpressions, Transformer
)
from loki.sourcefile import Sourcefile
from loki.types import BasicType, DerivedType
from loki.types import BasicType, DerivedType, SymbolAttributes


@pytest.mark.parametrize('frontend', available_frontends())
Expand All @@ -41,8 +41,8 @@ def test_module_from_source(frontend, tmp_path):
end module a_module
""".strip()
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
assert len([o for o in module.spec.body if isinstance(o, VariableDeclaration)]) == 2
assert len([o for o in module.spec.body if isinstance(o, TypeDef)]) == 1
assert len([o for o in module.spec.body if isinstance(o, ir.VariableDeclaration)]) == 2
assert len([o for o in module.spec.body if isinstance(o, ir.TypeDef)]) == 1
assert 'derived_type' in module.typedef_map
assert len(module.routines) == 1
assert module.routines[0].name == 'my_routine'
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_module_external_typedefs_subroutine(frontend, tmp_path):
assert fexprgen(a.shape) == exptected_array_shape

# Check the LHS of the assignment has correct meta-data
stmt = FindNodes(Assignment).visit(routine.body)[0]
stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
pt_ext_arr = stmt.lhs
assert pt_ext_arr.type.dtype == BasicType.REAL
assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
Expand Down Expand Up @@ -178,14 +178,14 @@ def test_module_external_typedefs_type(frontend, tmp_path):

# Verify correct attachment of type information
assert 'ext_type' in module.symbol_attrs
assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, TypeDef)
assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, TypeDef)
assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, ir.TypeDef)
assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, ir.TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)
assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, ir.TypeDef)
assert 'other_type' in module.symbol_attrs
assert 'other_type' not in module['other_routine'].symbol_attrs
assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, TypeDef)
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, ir.TypeDef)
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)

# OMNI resolves explicit shape parameters in the frontend parser
exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'
Expand All @@ -207,7 +207,7 @@ def test_module_external_typedefs_type(frontend, tmp_path):
assert fexprgen(pt_ext_a.shape) == exptected_array_shape

# Check the LHS of the assignment has correct meta-data
stmt = FindNodes(Assignment).visit(routine.body)[0]
stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
pt_ext_arr = stmt.lhs
assert pt_ext_arr.type.dtype == BasicType.REAL
assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
Expand Down Expand Up @@ -413,9 +413,9 @@ def test_module_variables_add_remove(frontend, tmp_path):
x = module.variable_map['x'] # That's the symbol for variable 'x'
real_type = SymbolAttributes('real', kind=module.variable_map['jprb'])
int_type = SymbolAttributes('integer')
a = Variable(name='a', type=real_type, scope=module)
b = Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
c = Variable(name='c', type=int_type, scope=module)
a = sym.Variable(name='a', type=real_type, scope=module)
b = sym.Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
c = sym.Variable(name='c', type=int_type, scope=module)

# Add new variables and check that they are all in the module spec
module.variables += (a, b, c)
Expand Down Expand Up @@ -555,22 +555,22 @@ def test_module_deep_clone(frontend, tmp_path):
new_module = module.clone()

n = [v for v in FindVariables().visit(new_module.spec) if v.name == 'n'][0]
n_decl = FindNodes(VariableDeclaration).visit(new_module.spec)[0]
n_decl = FindNodes(ir.VariableDeclaration).visit(new_module.spec)[0]

# Remove the declaration of `n` and replace it with `3`
new_module.spec = Transformer({n_decl: None}).visit(new_module.spec)
new_module.spec = SubstituteExpressions({n: Literal(3)}).visit(new_module.spec)
new_module.spec = SubstituteExpressions({n: sym.Literal(3)}).visit(new_module.spec)

# Check the new module has been changed
assert len(FindNodes(VariableDeclaration).visit(new_module.spec)) == 1
new_type_decls = FindNodes(VariableDeclaration).visit(new_module['my_type'].body)
assert len(FindNodes(ir.VariableDeclaration).visit(new_module.spec)) == 1
new_type_decls = FindNodes(ir.VariableDeclaration).visit(new_module['my_type'].body)
assert len(new_type_decls) == 2
assert new_type_decls[0].symbols[0] == 'vector(3)'
assert new_type_decls[1].symbols[0] == 'matrix(3, 3)'

# Check the old one has not changed
assert len(FindNodes(VariableDeclaration).visit(module.spec)) == 2
type_decls = FindNodes(VariableDeclaration).visit(module['my_type'].body)
assert len(FindNodes(ir.VariableDeclaration).visit(module.spec)) == 2
type_decls = FindNodes(ir.VariableDeclaration).visit(module['my_type'].body)
assert len(type_decls) == 2
assert type_decls[0].symbols[0] == 'vector(n)'
assert type_decls[1].symbols[0] == 'matrix(n, n)'
Expand Down Expand Up @@ -832,7 +832,7 @@ def test_module_rename_imports_with_definitions(frontend, tmp_path):
assert mod3.symbol_attrs[s].compare(mod2.symbol_attrs[use_name or s], ignore=('imported', 'module', 'use_name'))

# Verify Import IR node
for imprt in FindNodes(Import).visit(mod3.spec):
for imprt in FindNodes(ir.Import).visit(mod3.spec):
if imprt.module == 'test_rename_mod':
assert imprt.rename_list
assert not imprt.symbols
Expand Down Expand Up @@ -916,7 +916,7 @@ def test_module_rename_imports_no_definitions(frontend, tmp_path):
assert mod3.symbol_attrs[s].use_name == use_name

# Verify Import IR node
for imprt in FindNodes(Import).visit(mod3.spec):
for imprt in FindNodes(ir.Import).visit(mod3.spec):
if imprt.module == 'test_rename_mod':
assert imprt.rename_list
assert not imprt.symbols
Expand Down Expand Up @@ -970,7 +970,7 @@ def test_module_use_module_nature(frontend, tmp_path):

# Check properties on the Import IR node in the external module
assert ext_mod.imported_symbols == ('int16',)
imprt = FindNodes(Import).visit(ext_mod.spec)[0]
imprt = FindNodes(ir.Import).visit(ext_mod.spec)[0]
assert imprt.nature.lower() == 'intrinsic'
assert imprt.module.lower() == 'iso_c_binding'
assert ext_mod.imported_symbol_map['int16'].type.imported is True
Expand All @@ -989,8 +989,8 @@ def test_module_use_module_nature(frontend, tmp_path):
assert set(my_kinds.imported_symbols) == {'int8', 'int16'}
assert set(kinds.imported_symbols) == {'int8', 'int16'}

my_import_map = {s.name: imprt for imprt in FindNodes(Import).visit(my_kinds.spec) for s in imprt.symbols}
import_map = {s.name: imprt for imprt in FindNodes(Import).visit(kinds.spec) for s in imprt.symbols}
my_import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(my_kinds.spec) for s in imprt.symbols}
import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(kinds.spec) for s in imprt.symbols}

assert my_import_map['int8'] is my_import_map['int16']
assert import_map['int8'] is import_map['int16']
Expand Down Expand Up @@ -1195,13 +1195,13 @@ def test_module_contains_auto_insert(frontend, tmp_path):
assert routine1.contains is None

routine1 = routine1.clone(contains=routine2)
assert isinstance(routine1.contains, Section)
assert isinstance(routine1.contains.body[0], Intrinsic)
assert isinstance(routine1.contains, ir.Section)
assert isinstance(routine1.contains.body[0], ir.Intrinsic)
assert routine1.contains.body[0].text == 'CONTAINS'

module = module.clone(contains=routine1)
assert isinstance(module.contains, Section)
assert isinstance(module.contains.body[0], Intrinsic)
assert isinstance(module.contains, ir.Section)
assert isinstance(module.contains.body[0], ir.Intrinsic)
assert module.contains.body[0].text == 'CONTAINS'


Expand Down Expand Up @@ -1244,14 +1244,14 @@ def test_module_missing_imported_symbol(frontend, only_list, complete_tree, tmp_
b = driver.symbol_map['b']

if complete_tree:
assert isinstance(a, Scalar)
assert isinstance(a, sym.Scalar)
assert a.type.dtype is BasicType.INTEGER
assert isinstance(b, Scalar)
assert isinstance(b, sym.Scalar)
assert b.type.dtype is BasicType.INTEGER
else:
assert isinstance(a, DeferredTypeSymbol)
assert isinstance(a, sym.DeferredTypeSymbol)
assert a.type.dtype is BasicType.DEFERRED
assert isinstance(b, DeferredTypeSymbol)
assert isinstance(b, sym.DeferredTypeSymbol)
assert b.type.dtype is BasicType.DEFERRED

assert a.type.imported
Expand Down Expand Up @@ -1418,9 +1418,9 @@ def test_module_enrichment_typdefs(frontend, tmp_path):

# Ensure type info has been propagated to inner subroutine
assert isinstance(state.type.dtype, DerivedType)
assert isinstance(state.type.dtype.typedef, TypeDef)
assert isinstance(state.type.dtype.typedef, ir.TypeDef)

assigns = FindNodes(Assignment).visit(driver.body)
assigns = FindNodes(ir.Assignment).visit(driver.body)
assert len(assigns) == 1
assert assigns[0].lhs.type.dtype == BasicType.REAL
assert assigns[0].lhs.type.shape == (':', ':')

0 comments on commit 83222f4

Please sign in to comment.