diff --git a/loki/program_unit.py b/loki/program_unit.py index b7b51a072..daf501b99 100644 --- a/loki/program_unit.py +++ b/loki/program_unit.py @@ -327,7 +327,8 @@ def enrich(self, definitions, recurse=False): """ definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions)) - for imprt in self.imports: + # Enrich type info from all known imports (including parent scopes) + for imprt in self.all_imports: if not (module := definitions_map.get(imprt.module)): # Skip modules that are not available in the definitions list continue diff --git a/loki/tests/test_modules.py b/loki/tests/test_modules.py index a8397f75a..668e95322 100644 --- a/loki/tests/test_modules.py +++ b/loki/tests/test_modules.py @@ -7,15 +7,16 @@ import pytest -from loki import ( - Module, Subroutine, VariableDeclaration, TypeDef, fexprgen, - BasicType, Assignment, FindNodes, FindInlineCalls, FindTypedSymbols, - Transformer, fgen, SymbolAttributes, Variable, Import, Section, Intrinsic, - Scalar, DeferredTypeSymbol, FindVariables, SubstituteExpressions, Literal -) +from loki import Module, Subroutine, fexprgen, fgen from loki.build import jit_compile, clean_test +from loki.expression import symbols as sym from loki.frontend import available_frontends, OMNI +from loki.ir import ( + nodes as ir, FindNodes, FindInlineCalls, FindTypedSymbols, + FindVariables, SubstituteExpressions, Transformer +) from loki.sourcefile import Sourcefile +from loki.types import BasicType, DerivedType, SymbolAttributes @pytest.mark.parametrize('frontend', available_frontends()) @@ -40,8 +41,8 @@ def test_module_from_source(frontend, tmp_path): end module a_module """.strip() module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) - assert len([o for o in module.spec.body if isinstance(o, VariableDeclaration)]) == 2 - assert len([o for o in module.spec.body if isinstance(o, TypeDef)]) == 1 + assert len([o for o in module.spec.body if isinstance(o, ir.VariableDeclaration)]) == 2 + assert len([o for o in module.spec.body if isinstance(o, ir.TypeDef)]) == 1 assert 'derived_type' in module.typedef_map assert len(module.routines) == 1 assert module.routines[0].name == 'my_routine' @@ -100,7 +101,7 @@ def test_module_external_typedefs_subroutine(frontend, tmp_path): assert fexprgen(a.shape) == exptected_array_shape # Check the LHS of the assignment has correct meta-data - stmt = FindNodes(Assignment).visit(routine.body)[0] + stmt = FindNodes(ir.Assignment).visit(routine.body)[0] pt_ext_arr = stmt.lhs assert pt_ext_arr.type.dtype == BasicType.REAL assert fexprgen(pt_ext_arr.shape) == exptected_array_shape @@ -177,14 +178,14 @@ def test_module_external_typedefs_type(frontend, tmp_path): # Verify correct attachment of type information assert 'ext_type' in module.symbol_attrs - assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, TypeDef) - assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, TypeDef) - assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef) - assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, TypeDef) + assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, ir.TypeDef) + assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, ir.TypeDef) + assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef) + assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, ir.TypeDef) assert 'other_type' in module.symbol_attrs assert 'other_type' not in module['other_routine'].symbol_attrs - assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, TypeDef) - assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef) + assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, ir.TypeDef) + assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef) # OMNI resolves explicit shape parameters in the frontend parser exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)' @@ -206,7 +207,7 @@ def test_module_external_typedefs_type(frontend, tmp_path): assert fexprgen(pt_ext_a.shape) == exptected_array_shape # Check the LHS of the assignment has correct meta-data - stmt = FindNodes(Assignment).visit(routine.body)[0] + stmt = FindNodes(ir.Assignment).visit(routine.body)[0] pt_ext_arr = stmt.lhs assert pt_ext_arr.type.dtype == BasicType.REAL assert fexprgen(pt_ext_arr.shape) == exptected_array_shape @@ -412,9 +413,9 @@ def test_module_variables_add_remove(frontend, tmp_path): x = module.variable_map['x'] # That's the symbol for variable 'x' real_type = SymbolAttributes('real', kind=module.variable_map['jprb']) int_type = SymbolAttributes('integer') - a = Variable(name='a', type=real_type, scope=module) - b = Variable(name='b', dimensions=(x, ), type=real_type, scope=module) - c = Variable(name='c', type=int_type, scope=module) + a = sym.Variable(name='a', type=real_type, scope=module) + b = sym.Variable(name='b', dimensions=(x, ), type=real_type, scope=module) + c = sym.Variable(name='c', type=int_type, scope=module) # Add new variables and check that they are all in the module spec module.variables += (a, b, c) @@ -554,22 +555,22 @@ def test_module_deep_clone(frontend, tmp_path): new_module = module.clone() n = [v for v in FindVariables().visit(new_module.spec) if v.name == 'n'][0] - n_decl = FindNodes(VariableDeclaration).visit(new_module.spec)[0] + n_decl = FindNodes(ir.VariableDeclaration).visit(new_module.spec)[0] # Remove the declaration of `n` and replace it with `3` new_module.spec = Transformer({n_decl: None}).visit(new_module.spec) - new_module.spec = SubstituteExpressions({n: Literal(3)}).visit(new_module.spec) + new_module.spec = SubstituteExpressions({n: sym.Literal(3)}).visit(new_module.spec) # Check the new module has been changed - assert len(FindNodes(VariableDeclaration).visit(new_module.spec)) == 1 - new_type_decls = FindNodes(VariableDeclaration).visit(new_module['my_type'].body) + assert len(FindNodes(ir.VariableDeclaration).visit(new_module.spec)) == 1 + new_type_decls = FindNodes(ir.VariableDeclaration).visit(new_module['my_type'].body) assert len(new_type_decls) == 2 assert new_type_decls[0].symbols[0] == 'vector(3)' assert new_type_decls[1].symbols[0] == 'matrix(3, 3)' # Check the old one has not changed - assert len(FindNodes(VariableDeclaration).visit(module.spec)) == 2 - type_decls = FindNodes(VariableDeclaration).visit(module['my_type'].body) + assert len(FindNodes(ir.VariableDeclaration).visit(module.spec)) == 2 + type_decls = FindNodes(ir.VariableDeclaration).visit(module['my_type'].body) assert len(type_decls) == 2 assert type_decls[0].symbols[0] == 'vector(n)' assert type_decls[1].symbols[0] == 'matrix(n, n)' @@ -831,7 +832,7 @@ def test_module_rename_imports_with_definitions(frontend, tmp_path): assert mod3.symbol_attrs[s].compare(mod2.symbol_attrs[use_name or s], ignore=('imported', 'module', 'use_name')) # Verify Import IR node - for imprt in FindNodes(Import).visit(mod3.spec): + for imprt in FindNodes(ir.Import).visit(mod3.spec): if imprt.module == 'test_rename_mod': assert imprt.rename_list assert not imprt.symbols @@ -915,7 +916,7 @@ def test_module_rename_imports_no_definitions(frontend, tmp_path): assert mod3.symbol_attrs[s].use_name == use_name # Verify Import IR node - for imprt in FindNodes(Import).visit(mod3.spec): + for imprt in FindNodes(ir.Import).visit(mod3.spec): if imprt.module == 'test_rename_mod': assert imprt.rename_list assert not imprt.symbols @@ -969,7 +970,7 @@ def test_module_use_module_nature(frontend, tmp_path): # Check properties on the Import IR node in the external module assert ext_mod.imported_symbols == ('int16',) - imprt = FindNodes(Import).visit(ext_mod.spec)[0] + imprt = FindNodes(ir.Import).visit(ext_mod.spec)[0] assert imprt.nature.lower() == 'intrinsic' assert imprt.module.lower() == 'iso_c_binding' assert ext_mod.imported_symbol_map['int16'].type.imported is True @@ -988,8 +989,8 @@ def test_module_use_module_nature(frontend, tmp_path): assert set(my_kinds.imported_symbols) == {'int8', 'int16'} assert set(kinds.imported_symbols) == {'int8', 'int16'} - my_import_map = {s.name: imprt for imprt in FindNodes(Import).visit(my_kinds.spec) for s in imprt.symbols} - import_map = {s.name: imprt for imprt in FindNodes(Import).visit(kinds.spec) for s in imprt.symbols} + my_import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(my_kinds.spec) for s in imprt.symbols} + import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(kinds.spec) for s in imprt.symbols} assert my_import_map['int8'] is my_import_map['int16'] assert import_map['int8'] is import_map['int16'] @@ -1194,13 +1195,13 @@ def test_module_contains_auto_insert(frontend, tmp_path): assert routine1.contains is None routine1 = routine1.clone(contains=routine2) - assert isinstance(routine1.contains, Section) - assert isinstance(routine1.contains.body[0], Intrinsic) + assert isinstance(routine1.contains, ir.Section) + assert isinstance(routine1.contains.body[0], ir.Intrinsic) assert routine1.contains.body[0].text == 'CONTAINS' module = module.clone(contains=routine1) - assert isinstance(module.contains, Section) - assert isinstance(module.contains.body[0], Intrinsic) + assert isinstance(module.contains, ir.Section) + assert isinstance(module.contains.body[0], ir.Intrinsic) assert module.contains.body[0].text == 'CONTAINS' @@ -1243,14 +1244,14 @@ def test_module_missing_imported_symbol(frontend, only_list, complete_tree, tmp_ b = driver.symbol_map['b'] if complete_tree: - assert isinstance(a, Scalar) + assert isinstance(a, sym.Scalar) assert a.type.dtype is BasicType.INTEGER - assert isinstance(b, Scalar) + assert isinstance(b, sym.Scalar) assert b.type.dtype is BasicType.INTEGER else: - assert isinstance(a, DeferredTypeSymbol) + assert isinstance(a, sym.DeferredTypeSymbol) assert a.type.dtype is BasicType.DEFERRED - assert isinstance(b, DeferredTypeSymbol) + assert isinstance(b, sym.DeferredTypeSymbol) assert b.type.dtype is BasicType.DEFERRED assert a.type.imported @@ -1371,3 +1372,55 @@ def test_module_enrichment_within_file(frontend, tmp_path): assert calls[0].arguments[0].type.parameter assert calls[0].arguments[0].type.initial == 16 assert calls[0].arguments[0].type.module is source['foo'] + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_module_enrichment_typdefs(frontend, tmp_path): + """ Test that module-level enrihcment is propagated correctly """ + + fcode_state_mod = """ +module state_type_mod + implicit none + + type state_type + real, pointer, dimension(:,:) :: a + end type state_type + +end module state_type_mod +""" + + fcode_driver_mod = """ +module driver_mod + use state_type_mod, only: state_type + implicit none + +contains + subroutine driver_routine(state) + type(state_type), intent(inout) :: state + + state%a = 1 + + end subroutine driver_routine +end module driver_mod +""" + state_mod = Sourcefile.from_source(fcode_state_mod, frontend=frontend, xmods=[tmp_path])['state_type_mod'] + driver_mod = Sourcefile.from_source(fcode_driver_mod, frontend=frontend, xmods=[tmp_path])['driver_mod'] + driver = driver_mod['driver_routine'] + + state = driver.variable_map['state'] + assert isinstance(state.type.dtype, DerivedType) + assert state.type.dtype.typedef == BasicType.DEFERRED + + # Enrich typedef on the outer module Import + driver_mod.enrich([state_mod], recurse=True) + + state = driver.variable_map['state'] + + # Ensure type info has been propagated to inner subroutine + assert isinstance(state.type.dtype, DerivedType) + assert isinstance(state.type.dtype.typedef, ir.TypeDef) + + assigns = FindNodes(ir.Assignment).visit(driver.body) + assert len(assigns) == 1 + assert assigns[0].lhs.type.dtype == BasicType.REAL + assert assigns[0].lhs.type.shape == (':', ':')