diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py
index 4909b7d10..44cdfdcb2 100644
--- a/loki/expression/mappers.py
+++ b/loki/expression/mappers.py
@@ -93,6 +93,7 @@ def map_variable_symbol(self, expr, enclosing_prec, *args, **kwargs):
 
     map_deferred_type_symbol = map_variable_symbol
     map_procedure_symbol = map_variable_symbol
+    map_derived_type_symbol = map_variable_symbol
 
     def map_meta_symbol(self, expr, enclosing_prec, *args, **kwargs):
         return self.rec(expr._symbol, enclosing_prec, *args, **kwargs)
@@ -234,6 +235,7 @@ def map_variable_symbol(self, expr, *args, **kwargs):
 
     map_deferred_type_symbol = map_variable_symbol
     map_procedure_symbol = map_variable_symbol
+    map_derived_type_symbol = map_variable_symbol
 
     def map_meta_symbol(self, expr, *args, **kwargs):
         if not self.visit(expr):
@@ -611,6 +613,7 @@ def map_variable_symbol(self, expr, *args, **kwargs):
 
     map_deferred_type_symbol = map_variable_symbol
    map_procedure_symbol = map_variable_symbol
+    map_derived_type_symbol = map_variable_symbol
 
     def map_meta_symbol(self, expr, *args, **kwargs):
         symbol = self.rec(expr._symbol, *args, **kwargs)
@@ -823,6 +826,7 @@ def map_procedure_symbol(self, expr, *args, **kwargs):
             return expr.clone(scope=kwargs['scope'])
         return self.map_variable_symbol(expr, *args, **kwargs)
 
+
 class DetachScopesMapper(LokiIdentityMapper):
     """
     A Pymbolic expression mapper (i.e., a visitor for the expression tree)
diff --git a/loki/expression/symbols.py b/loki/expression/symbols.py
index 1fa35592c..7daa870a8 100644
--- a/loki/expression/symbols.py
+++ b/loki/expression/symbols.py
@@ -30,7 +30,7 @@
     # Mix-ins
     'StrCompareMixin',
     # Typed leaf nodes
-    'TypedSymbol', 'DeferredTypeSymbol', 'VariableSymbol', 'ProcedureSymbol',
+    'TypedSymbol', 'DeferredTypeSymbol', 'VariableSymbol', 'ProcedureSymbol', 'DerivedTypeSymbol',
     'MetaSymbol', 'Scalar', 'Array', 'Variable',
     # Non-typed leaf nodes
     'FloatLiteral', 'IntLiteral', 'LogicLiteral', 'StringLiteral',
@@ -481,6 +481,34 @@ def __init__(self, name, scope=None, type=None, **kwargs):
     mapper_method = intern('map_procedure_symbol')
 
 
+class DerivedTypeSymbol(StrCompareMixin, TypedSymbol, _FunctionSymbol):
+    """
+    Internal representation of a symbol that represents a named
+    derived type.
+
+    This is used to represent the derived type symbolically in
+    :any:`Import` statements and when defining derived types.
+
+    Parameters
+    ----------
+    name : str
+        The name of the symbol.
+    scope : :any:`Scope`
+        The scope in which the symbol is declared.
+    type : optional
+        The type of that symbol. Defaults to :any:`BasicType.DEFERRED`.
+    """
+
+    def __init__(self, name, scope=None, type=None, **kwargs):
+        # pylint: disable=redefined-builtin
+        assert type is None or isinstance(type.dtype, DerivedType)
+        if type is not None:
+            assert name.lower() == type.dtype.name.lower()
+        super().__init__(name=name, scope=scope, type=type, **kwargs)
+
+    mapper_method = intern('map_derived_type_symbol')
+
+
 class MetaSymbol(StrCompareMixin, pmbl.AlgebraicLeaf):
     """
     Base class for meta symbols to encapsulate a symbol node with optional
@@ -868,9 +896,8 @@ def __new__(cls, **kwargs):
             return ProcedureSymbol(**kwargs)
 
         if _type and isinstance(_type.dtype, DerivedType) and name.lower() == _type.dtype.name.lower():
-            # This is a constructor call (or a type imported in an ``IMPORT`` statement, in which
-            # case this is classified wrong...)
-            return ProcedureSymbol(**kwargs)
+            # This is the name of a derived type, as found in USE import statements
+            return DerivedTypeSymbol(**kwargs)
 
         if 'dimensions' in kwargs and kwargs['dimensions'] is None:
             # Convenience: This way we can construct Scalar variables with `dimensions=None`
@@ -1315,7 +1342,9 @@ def __init__(self, function, parameters=None, kw_parameters=None, **kwargs):
         # Unfortunately, have to accept MetaSymbol here for the time being as
         # rescoping before injecting statement functions may create InlineCalls
         # with Scalar/Variable function names.
-        assert isinstance(function, (ProcedureSymbol, DeferredTypeSymbol, MetaSymbol))
+        assert isinstance(function, (
+            ProcedureSymbol, DerivedTypeSymbol, DeferredTypeSymbol, MetaSymbol
+        ))
 
         parameters = parameters or ()
         kw_parameters = kw_parameters or {}
diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py
index 693a89aa4..9564c2d0e 100644
--- a/loki/frontend/fparser.py
+++ b/loki/frontend/fparser.py
@@ -1780,6 +1780,26 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
         # symbols in the spec part to make them coherent with the symbol table
         spec = AttachScopes().visit(spec, scope=routine, recurse_to_declaration_attributes=True)
 
+        # To simplify things, we always declare the result-type of a function with
+        # a declaration in the spec as this can capture every possible situation.
+        # Therefore, if it has been declared as a prefix in the subroutine statement,
+        # we now have to inject a declaration instead. To ensure we do this in the
+        # right place in the spec to not violate the intrinsic order Fortran mandates,
+        # we search for the first occurrence of any VariableDeclaration or
+        # ProcedureDeclaration and inject it before that one
+        if return_type is not None:
+            routine.symbol_attrs[routine.name] = return_type
+            return_var = sym.Variable(name=routine.name, scope=routine)
+            decl_source = self.get_source(subroutine_stmt, source=None)
+            return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source)
+
+            decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec)
+            if not decls:
+                # No other declarations: add it to the end
+                spec.append(return_var_decl)
+            else:
+                spec.insert(spec.body.index(decls[0]), return_var_decl)
+
         # Now all declarations are well-defined and we can parse the member routines
         if contains_ast is not None:
             contains = self.visit(contains_ast, **kwargs)
@@ -1821,26 +1841,6 @@ def visit_Subroutine_Subprogram(self, o, **kwargs):
                 comment_map[node] = None
             spec = Transformer(comment_map, invalidate_source=False).visit(spec)
 
-        # To simplify things, we always declare the result-type of a function with
-        # a declaration in the spec as this can capture every possible situation.
-        # Therefore, if it has been declared as a prefix in the subroutine statement,
-        # we now have to inject a declaration instead.
To ensure we do this in the - # right place in the spec to not violate the intrinsic order Fortran mandates, - # we search for the first occurence of any VariableDeclaration or - # ProcedureDeclaration and inject it before that one - if return_type is not None: - routine.symbol_attrs[routine.name] = return_type - return_var = sym.Variable(name=routine.name, scope=routine) - decl_source = self.get_source(subroutine_stmt, source=None) - return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source) - - decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec) - if not decls: - # No other declarations: add it to the end - spec.append(return_var_decl) - else: - spec.insert(spec.body.index(decls[0]), return_var_decl) - # Finally, call the subroutine constructor on the object again to register all # bits and pieces in place and rescope all symbols # pylint: disable=unnecessary-dunder-call diff --git a/loki/frontend/tests/test_frontends.py b/loki/frontend/tests/test_frontends.py index 86ea0762b..61f5791cd 100644 --- a/loki/frontend/tests/test_frontends.py +++ b/loki/frontend/tests/test_frontends.py @@ -5,40 +5,21 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# pylint: disable=too-many-lines - """ Verify correct frontend behaviour and correct parsing of certain Fortran language features. """ -# pylint: disable=too-many-lines - -import platform -from pathlib import Path -from time import perf_counter import numpy as np import pytest from loki import ( - Module, Subroutine, FindVariables, BasicType, config, Sourcefile, - RawSource, RegexParserClass, ProcedureType, DerivedType, - PreprocessorDirective, config_override + Module, Subroutine, Sourcefile, BasicType, config, config_override ) from loki.build import jit_compile, clean_test from loki.expression import symbols as sym -from loki.frontend import available_frontends, OMNI, FP, REGEX -from loki.ir import nodes as ir, FindNodes - - -@pytest.fixture(scope='module', name='here') -def fixture_here(): - return Path(__file__).parent - - -@pytest.fixture(scope='module', name='testdir') -def fixture_testdir(here): - return here.parent.parent/'tests' +from loki.frontend import available_frontends, OMNI, FP +from loki.ir import nodes as ir, FindNodes, FindVariables @pytest.fixture(name='reset_frontend_mode') @@ -48,13 +29,6 @@ def fixture_reset_frontend_mode(): config['frontend-strict-mode'] = original_frontend_mode -@pytest.fixture(name='reset_regex_frontend_timeout') -def fixture_reset_regex_frontend_timeout(): - original_timeout = config['regex-frontend-timeout'] - yield - config['regex-frontend-timeout'] = original_timeout - - @pytest.mark.parametrize('frontend', available_frontends()) def test_check_alloc_opts(tmp_path, frontend): """ @@ -391,1016 +365,6 @@ def test_frontend_strict_mode(frontend, tmp_path): assert 'matrix' in module.typedef_map -def test_regex_subroutine_from_source(): - """ - Verify that the regex frontend is able to parse subroutines - """ - fcode = """ -subroutine routine_b( - ! arg 1 - i, - ! arg2 - j -) - use parkind1, only : jpim - implicit none - integer, intent(in) :: i, j - integer b - b = 4 - - call contained_c(i) - - call routine_a() -contains -!abc ^$^** - integer(kind=jpim) function contained_e(i) - integer, intent(in) :: i - contained_e = i - end function - - subroutine contained_c(i) - integer, intent(in) :: i - integer c - c = 5 - end subroutine contained_c - ! 
cc£$^£$^ - - subroutine contained_d(i) - integer, intent(in) :: i - integer c - c = 8 - end subroutine !add"£^£$ -end subroutine routine_b - """.strip() - - routine = Subroutine.from_source(fcode, frontend=REGEX) - assert routine.name == 'routine_b' - assert not routine.is_function - assert routine.arguments == () - assert routine.argnames == [] - assert [r.name for r in routine.subroutines] == ['contained_e', 'contained_c', 'contained_d'] - - contained_c = routine['contained_c'] - assert contained_c.name == 'contained_c' - assert not contained_c.is_function - assert contained_c.arguments == () - assert contained_c.argnames == [] - - contained_e = routine['contained_e'] - assert contained_e.name == 'contained_e' - assert contained_e.is_function - assert contained_e.arguments == () - assert contained_e.argnames == [] - - contained_d = routine['contained_d'] - assert contained_d.name == 'contained_d' - assert not contained_d.is_function - assert contained_d.arguments == () - assert contained_d.argnames == [] - - code = routine.to_fortran() - assert code.count('SUBROUTINE') == 6 - assert code.count('FUNCTION') == 2 - assert code.count('CONTAINS') == 1 - - -def test_regex_module_from_source(): - """ - Verify that the regex frontend is able to parse modules - """ - fcode = """ -module some_module - use foobar - implicit none - integer, parameter :: k = selected_int_kind(5) -contains - subroutine module_routine - integer m - m = 2 - - call routine_b(m, 6) - end subroutine module_routine - - integer(kind=k) function module_function(n) - integer n - module_function = n + 2 - end function module_function -end module some_module - """.strip() - - module = Module.from_source(fcode, frontend=REGEX) - assert module.name == 'some_module' - assert [r.name for r in module.subroutines] == ['module_routine', 'module_function'] - - code = module.to_fortran() - assert code.count('MODULE') == 2 - assert code.count('SUBROUTINE') == 2 - assert code.count('FUNCTION') == 2 - assert code.count('CONTAINS') == 1 - - -def test_regex_sourcefile_from_source(): - """ - Verify that the regex frontend is able to parse source files containing - multiple modules and subroutines - """ - fcode = """ -subroutine routine_a - integer a, i - a = 1 - i = a + 1 - - call routine_b(a, i) -end subroutine routine_a - -module some_module -contains - subroutine module_routine - integer m - m = 2 - - call routine_b(m, 6) - end subroutine module_routine - - function module_function(n) - integer n - integer module_function - module_function = n + 3 - end function module_function -end module some_module - -module other_module - integer :: n -end module - -subroutine routine_b( - ! arg 1 - i, - ! arg2 - j, - k!arg3 -) - integer, intent(in) :: i, j, k - integer b - b = 4 - - call contained_c(i) - - call routine_a() -contains -!abc ^$^** - subroutine contained_c(i) - integer, intent(in) :: i - integer c - c = 5 - end subroutine contained_c - ! 
cc£$^£$^ - integer function contained_e(i) - integer, intent(in) :: i - contained_e = i - end function - - subroutine contained_d(i) - integer, intent(in) :: i - integer c - c = 8 - end subroutine !add"£^£$ -endsubroutine routine_b - -function function_d(d) - integer d - d = 6 -end function function_d - -module last_module - implicit none -contains - subroutine last_routine1 - call contained() - contains - subroutine contained - integer n - n = 1 - end subroutine contained - end subroutine last_routine1 - subroutine last_routine2 - call contained2() - contains - subroutine contained2 - integer m - m = 1 - end subroutine contained2 - end subroutine last_routine2 -end module last_module - """.strip() - - sourcefile = Sourcefile.from_source(fcode, frontend=REGEX) - assert [m.name for m in sourcefile.modules] == ['some_module', 'other_module', 'last_module'] - assert [r.name for r in sourcefile.routines] == [ - 'routine_a', 'routine_b', 'function_d' - ] - assert [r.name for r in sourcefile.all_subroutines] == [ - 'routine_a', 'routine_b', 'function_d', 'module_routine', 'module_function', - 'last_routine1', 'last_routine2' - ] - - assert len(r := sourcefile['last_module']['last_routine1'].routines) == 1 and r[0].name == 'contained' - assert len(r := sourcefile['last_module']['last_routine2'].routines) == 1 and r[0].name == 'contained2' - - code = sourcefile.to_fortran() - assert code.count('SUBROUTINE') == 18 - assert code.count('FUNCTION') == 6 - assert code.count('CONTAINS') == 5 - assert code.count('MODULE') == 6 - - -def test_regex_sourcefile_from_file(testdir): - """ - Verify that the regex frontend is able to parse source files containing - multiple modules and subroutines - """ - - sourcefile = Sourcefile.from_file(testdir/'sources/sourcefile.f90', frontend=REGEX) - assert [m.name for m in sourcefile.modules] == ['some_module'] - assert [r.name for r in sourcefile.routines] == [ - 'routine_a', 'routine_b', 'function_d' - ] - assert [r.name for r in sourcefile.all_subroutines] == [ - 'routine_a', 'routine_b', 'function_d', 'module_routine', 'module_function' - ] - - routine_b = sourcefile['ROUTINE_B'] - assert routine_b.name == 'routine_b' - assert not routine_b.is_function - assert routine_b.arguments == () - assert routine_b.argnames == [] - assert [r.name for r in routine_b.subroutines] == ['contained_c'] - - function_d = sourcefile['function_d'] - assert function_d.name == 'function_d' - assert function_d.is_function - assert function_d.arguments == () - assert function_d.argnames == [] - assert not function_d.contains - - code = sourcefile.to_fortran() - assert code.count('SUBROUTINE') == 8 - assert code.count('FUNCTION') == 4 - assert code.count('CONTAINS') == 2 - assert code.count('MODULE') == 2 - - -def test_regex_sourcefile_from_file_parser_classes(testdir): - - filepath = testdir/'sources/Fortran-extract-interface-source.f90' - module_names = {'bar', 'foo'} - routine_names = { - 'func_simple', 'func_simple_1', 'func_simple_2', 'func_simple_pure', 'func_simple_recursive_pure', - 'func_simple_elemental', 'func_with_use_and_args', 'func_with_parameters', 'func_with_parameters_1', - 'func_with_contains', 'func_mix_local_and_result', 'sub_simple', 'sub_simple_1', 'sub_simple_2', - 'sub_simple_3', 'sub_with_contains', 'sub_with_renamed_import', 'sub_with_external', 'sub_with_end' - } - module_routine_names = {'foo_sub', 'foo_func'} - - # Empty parse (since we don't match typedef without having the enclosing module first) - sourcefile = Sourcefile.from_file(filepath, frontend=REGEX, 
parser_classes=RegexParserClass.TypeDefClass) - assert not sourcefile.subroutines - assert not sourcefile.modules - assert FindNodes(RawSource).visit(sourcefile.ir) - assert sourcefile._incomplete - assert sourcefile._parser_classes == RegexParserClass.TypeDefClass - - # Incremental addition of program unit objects - sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) - assert sourcefile._incomplete - assert sourcefile._parser_classes == RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass - # Note that the program unit objects don't include the TypeDefClass because it's lower in the hierarchy - # and was not matched previously - assert all( - module._parser_classes == RegexParserClass.ProgramUnitClass - for module in sourcefile.modules - ) - assert all( - routine._parser_classes == RegexParserClass.ProgramUnitClass - for routine in sourcefile.routines - ) - - assert {module.name.lower() for module in sourcefile.modules} == module_names - assert {routine.name.lower() for routine in sourcefile.routines} == routine_names - assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names - - assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} - assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { - 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' - } - - for module in sourcefile.modules: - assert not module.imports - for routine in sourcefile.all_subroutines: - assert not routine.imports - assert not sourcefile['bar'].typedefs - - # Validate that a re-parse with same parser classes does not change anything - sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) - assert sourcefile._incomplete - assert sourcefile._parser_classes == RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass - for module in sourcefile.modules: - assert not module.imports - for routine in sourcefile.all_subroutines: - assert not routine.imports - assert not sourcefile['bar'].typedefs - - # Incremental addition of imports - sourcefile.make_complete( - frontend=REGEX, - parser_classes=RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass - ) - assert sourcefile._parser_classes == ( - RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass | RegexParserClass.ImportClass - ) - # Note that the program unit objects don't include the TypeDefClass because it's lower in the hierarchy - # and was not matched previously - assert all( - module._parser_classes == ( - RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass - ) for module in sourcefile.modules - ) - assert all( - routine._parser_classes == ( - RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass - ) for routine in sourcefile.routines - ) - - assert {module.name.lower() for module in sourcefile.modules} == module_names - assert {routine.name.lower() for routine in sourcefile.routines} == routine_names - assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names - - assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} - assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { - 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' - } - - program_units_with_imports = { - 'foo': 
['bar'], 'func_with_use_and_args': ['foo', 'bar'], 'sub_with_contains': ['bar'], - 'sub_with_renamed_import': ['bar'] - } - - for unit in module_names | routine_names | module_routine_names: - if unit in program_units_with_imports: - assert [import_.module.lower() for import_ in sourcefile[unit].imports] == program_units_with_imports[unit] - else: - assert not sourcefile[unit].imports - assert not sourcefile['bar'].typedefs - - # Parse the rest - sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.AllClasses) - assert sourcefile._parser_classes == RegexParserClass.AllClasses - assert all( - module._parser_classes == RegexParserClass.AllClasses - for module in sourcefile.modules - ) - assert all( - routine._parser_classes == RegexParserClass.AllClasses - for routine in sourcefile.routines - ) - - assert {module.name.lower() for module in sourcefile.modules} == module_names - assert {routine.name.lower() for routine in sourcefile.routines} == routine_names - assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names - - assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} - assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { - 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' - } - - program_units_with_imports = { - 'foo': ['bar'], 'func_with_use_and_args': ['foo', 'bar'], 'sub_with_contains': ['bar'], - 'sub_with_renamed_import': ['bar'] - } - - for unit in module_names | routine_names | module_routine_names: - if unit in program_units_with_imports: - assert [import_.module.lower() for import_ in sourcefile[unit].imports] == program_units_with_imports[unit] - else: - assert not sourcefile[unit].imports - - # Check access via properties - assert 'bar' in sourcefile - assert 'food' in sourcefile['bar'] - assert sorted(sourcefile['bar'].typedef_map) == ['food', 'organic'] - assert sourcefile['bar'].definitions == sourcefile['bar'].typedefs + ('i_am_dim',) - assert 'cooking_method' in sourcefile['bar']['food'] - assert 'foobar' not in sourcefile['bar']['food'] - assert sourcefile['bar']['food'].interface_symbols == () - - # Check that triggering a full parse works from nested scopes - assert sourcefile['bar']._incomplete - sourcefile['bar']['food'].make_complete() - assert not sourcefile['bar']._incomplete - - -def test_regex_raw_source(): - """ - Verify that unparsed source appears in-between matched objects - """ - fcode = """ -! Some comment before the module -! -module some_mod - ! Some docstring - ! docstring - ! docstring - use some_mod - ! Some comment - ! comment - ! comment -end module some_mod - -! Other comment at the end - """.strip() - source = Sourcefile.from_source(fcode, frontend=REGEX) - - assert len(source.ir.body) == 3 - - assert isinstance(source.ir.body[0], RawSource) - assert source.ir.body[0].source.lines == (1, 2) - assert source.ir.body[0].text == '! Some comment before the module\n!' - assert source.ir.body[0].source.string == source.ir.body[0].text - - assert isinstance(source.ir.body[1], Module) - assert source.ir.body[1].source.lines == (3, 11) - assert source.ir.body[1].source.string.startswith('module') - - assert isinstance(source.ir.body[2], RawSource) - assert source.ir.body[2].source.lines == (12, 13) - assert source.ir.body[2].text == '\n! 
Other comment at the end' - assert source.ir.body[2].source.string == source.ir.body[2].text - - module = source['some_mod'] - assert len(module.spec.body) == 3 - assert isinstance(module.spec.body[0], RawSource) - assert isinstance(module.spec.body[1], ir.Import) - assert isinstance(module.spec.body[2], RawSource) - - assert module.spec.body[0].text.count('docstring') == 3 - assert module.spec.body[2].text.count('comment') == 3 - - -def test_regex_raw_source_with_cpp(): - """ - Verify that unparsed source appears in-between matched objects - and preprocessor statements are preserved - """ - fcode = """ -! Some comment before the subroutine -#ifdef RS6K -@PROCESS HOT(NOVECTOR) NOSTRICT -#endif -SUBROUTINE SOME_ROUTINE (KLON, KLEV) -IMPLICIT NONE -INTEGER, INTENT(IN) :: KLON, KLEV -! Comment inside routine -END SUBROUTINE SOME_ROUTINE - """.strip() - source = Sourcefile.from_source(fcode, frontend=REGEX) - - assert len(source.ir.body) == 2 - - assert isinstance(source.ir.body[0], RawSource) - assert source.ir.body[0].source.lines == (1, 4) - assert source.ir.body[0].text.startswith('! Some comment before the subroutine\n#') - assert source.ir.body[0].text.endswith('#endif') - assert source.ir.body[0].source.string == source.ir.body[0].text - - assert isinstance(source.ir.body[1], Subroutine) - assert source.ir.body[1].source.lines == (5, 9) - assert source.ir.body[1].source.string.startswith('SUBROUTINE') - - -def test_regex_raw_source_with_cpp_incomplete(): - """ - Verify that unparsed source appears inside matched objects if - parser classes are used to restrict the matching - """ - fcode = """ -SUBROUTINE driver(a, b, c) - INTEGER, INTENT(INOUT) :: a, b, c - -#include "kernel.intfb.h" - - CALL kernel(a, b ,c) -END SUBROUTINE driver - """.strip() - parser_classes = RegexParserClass.ProgramUnitClass - source = Sourcefile.from_source(fcode, frontend=REGEX, parser_classes=parser_classes) - - assert len(source.ir.body) == 1 - driver = source['driver'] - assert isinstance(driver, Subroutine) - assert not driver.docstring - assert not driver.body - assert not driver.contains - assert driver.spec and len(driver.spec.body) == 1 - assert isinstance(driver.spec.body[0], RawSource) - assert 'INTEGER, INTENT' in driver.spec.body[0].text - assert '#include' in driver.spec.body[0].text - - -@pytest.mark.parametrize('frontend', available_frontends( - xfail=[(OMNI, 'Non-standard notation needs full preprocessing')] -)) -def test_make_complete_sanitize(frontend): - """ - Test that attempts to first REGEX-parse and then complete source code - with unsupported features that require "frontend sanitization". - """ - fcode = """ -! Some comment before the subroutine -#ifdef RS6K -@PROCESS HOT(NOVECTOR) NOSTRICT -#endif -SUBROUTINE SOME_ROUTINE (KLON, KLEV) - IMPLICIT NONE - INTEGER, INTENT(IN) :: KLON, KLEV - ! Comment inside routine -END SUBROUTINE SOME_ROUTINE - """.strip() - source = Sourcefile.from_source(fcode, frontend=REGEX) - - # Ensure completion handles the non-supported features (@PROCESS) - source.make_complete(frontend=frontend) - - comments = FindNodes(ir.Comment).visit(source.ir) - assert len(comments) == 2 if frontend == FP else 1 - assert comments[0].text == '! 
Some comment before the subroutine' - if frontend == FP: - assert comments[1].text == '@PROCESS HOT(NOVECTOR) NOSTRICT' - - directives = FindNodes(PreprocessorDirective).visit(source.ir) - assert len(directives) == 2 - assert directives[0].text == '#ifdef RS6K' - assert directives[1].text == '#endif' - - -@pytest.mark.skipif(platform.system() == 'Darwin', - reason='Timeout utility test sporadically fails on MacOS CI runners.' -) -@pytest.mark.usefixtures('reset_regex_frontend_timeout') -def test_regex_timeout(): - """ - This source fails to parse because of missing SUBROUTINE in END - statement, and the test verifies that a timeout is encountered - """ - fcode = """ -subroutine some_routine(a) - real, intent(in) :: a -end - """.strip() - - # Test timeout - config['regex-frontend-timeout'] = 1 - start = perf_counter() - with pytest.raises(RuntimeError) as exc: - _ = Sourcefile.from_source(fcode, frontend=REGEX) - stop = perf_counter() - assert .9 < stop - start < 1.1 - assert 'REGEX frontend timeout of 1 s exceeded' in str(exc.value) - - # Test it works fine with proper Fortran: - fcode += ' subroutine' - source = Sourcefile.from_source(fcode, frontend=REGEX) - assert len(source.subroutines) == 1 - assert source.subroutines[0].name == 'some_routine' - - -def test_regex_module_imports(): - """ - Verify that the regex frontend is able to find and correctly parse - Fortran imports - """ - fcode = """ -module some_mod - use no_symbols_mod - use only_mod, only: my_var - use test_rename_mod, first_var1 => var1, first_var3 => var3 - use test_other_rename_mod, only: second_var1 => var1 - use test_other_rename_mod, only: other_var2 => var2, other_var3 => var3 - implicit none -end module some_mod - """.strip() - - module = Module.from_source(fcode, frontend=REGEX) - imports = FindNodes(ir.Import).visit(module.spec) - assert len(imports) == 5 - assert [import_.module for import_ in imports] == [ - 'no_symbols_mod', 'only_mod', 'test_rename_mod', 'test_other_rename_mod', - 'test_other_rename_mod' - ] - assert set(module.imported_symbols) == { - 'my_var', 'first_var1', 'first_var3', 'second_var1', 'other_var2', 'other_var3' - } - assert module.imported_symbol_map['first_var1'].type.use_name == 'var1' - assert module.imported_symbol_map['first_var3'].type.use_name == 'var3' - assert module.imported_symbol_map['second_var1'].type.use_name == 'var1' - assert module.imported_symbol_map['other_var2'].type.use_name == 'var2' - assert module.imported_symbol_map['other_var3'].type.use_name == 'var3' - - -def test_regex_subroutine_imports(): - """ - Verify that the regex frontend is able to find and correctly parse - Fortran imports - """ - fcode = """ -subroutine some_routine - use no_symbols_mod - use only_mod, only: my_var - use test_rename_mod, first_var1 => var1, first_var3 => var3 - use test_other_rename_mod, only: second_var1 => var1 - use test_other_rename_mod, only: other_var2 => var2, other_var3 => var3 - implicit none -end subroutine some_routine - """.strip() - - routine = Subroutine.from_source(fcode, frontend=REGEX) - imports = FindNodes(ir.Import).visit(routine.spec) - assert len(imports) == 5 - assert [import_.module for import_ in imports] == [ - 'no_symbols_mod', 'only_mod', 'test_rename_mod', 'test_other_rename_mod', - 'test_other_rename_mod' - ] - assert set(routine.imported_symbols) == { - 'my_var', 'first_var1', 'first_var3', 'second_var1', 'other_var2', 'other_var3' - } - assert routine.imported_symbol_map['first_var1'].type.use_name == 'var1' - assert 
routine.imported_symbol_map['first_var3'].type.use_name == 'var3' - assert routine.imported_symbol_map['second_var1'].type.use_name == 'var1' - assert routine.imported_symbol_map['other_var2'].type.use_name == 'var2' - assert routine.imported_symbol_map['other_var3'].type.use_name == 'var3' - - -def test_regex_import_linebreaks(): - """ - Verify correct handling of line breaks in import statements - """ - fcode = """ -module file_io_mod - USE PARKIND1 , ONLY : JPIM, JPRB, JPRD - -#ifdef HAVE_SERIALBOX - USE m_serialize, ONLY: & - fs_create_savepoint, & - fs_add_serializer_metainfo, & - fs_get_serializer_metainfo, & - fs_read_field, & - fs_write_field - USE utils_ppser, ONLY: & - ppser_initialize, & - ppser_finalize, & - ppser_serializer, & - ppser_serializer_ref, & - ppser_set_mode, & - ppser_savepoint -#endif - -#ifdef HAVE_HDF5 - USE hdf5_file_mod, only: hdf5_file -#endif - - implicit none -end module file_io_mod - """.strip() - module = Module.from_source(fcode, frontend=REGEX) - imports = FindNodes(ir.Import).visit(module.spec) - assert len(imports) == 4 - assert [import_.module for import_ in imports] == ['PARKIND1', 'm_serialize', 'utils_ppser', 'hdf5_file_mod'] - assert all( - s in module.imported_symbols for s in [ - 'JPIM', 'JPRB', 'JPRD', 'fs_create_savepoint', 'fs_add_serializer_metainfo', 'fs_get_serializer_metainfo', - 'fs_read_field', 'fs_write_field', 'ppser_initialize', 'ppser_finalize', 'ppser_serializer', - 'ppser_serializer_ref', 'ppser_set_mode', 'ppser_savepoint', 'hdf5_file' - ] - ) - - -def test_regex_typedef(): - """ - Verify that the regex frontend is able to parse type definitions and - correctly parse procedure bindings. - """ - fcode = """ -module typebound_item - implicit none - type some_type - contains - procedure, nopass :: routine => module_routine - procedure :: some_routine - procedure, pass :: other_routine - procedure :: routine1, & - & routine2 => routine - ! procedure :: routine1 - ! 
procedure :: routine2 => routine - end type some_type -contains - subroutine module_routine - integer m - m = 2 - end subroutine module_routine - - subroutine some_routine(self) - class(some_type) :: self - - call self%routine - end subroutine some_routine - - subroutine other_routine(self, m) - class(some_type), intent(inout) :: self - integer, intent(in) :: m - integer :: j - - j = m - call self%routine1 - call self%routine2 - end subroutine other_routine - - subroutine routine(self) - class(some_type) :: self - call self%some_routine - end subroutine routine - - subroutine routine1(self) - class(some_type) :: self - call module_routine - end subroutine routine1 -end module typebound_item - """.strip() - - module = Module.from_source(fcode, frontend=REGEX) - - assert 'some_type' in module.typedef_map - some_type = module.typedef_map['some_type'] - - proc_bindings = { - 'routine': ('module_routine',), - 'some_routine': None, - 'other_routine': None, - 'routine1': None, - 'routine2': ('routine',) - } - assert len(proc_bindings) == len(some_type.variables) - assert all(proc in some_type.variables for proc in proc_bindings) - assert all( - some_type.variable_map[proc].type.bind_names == bind - for proc, bind in proc_bindings.items() - ) - - -def test_regex_typedef_generic(): - fcode = """ -module typebound_header - implicit none - - type header_type - contains - procedure :: member_routine => header_member_routine - procedure :: routine_real => header_routine_real - procedure :: routine_integer - generic :: routine => routine_real, routine_integer - end type header_type - -contains - - subroutine header_member_routine(self, val) - class(header_type) :: self - integer, intent(in) :: val - integer :: j - j = val - end subroutine header_member_routine - - subroutine header_routine_real(self, val) - class(header_type) :: self - real, intent(out) :: val - val = 1.0 - end subroutine header_routine_real - - subroutine routine_integer(self, val) - class(header_type) :: self - integer, intent(out) :: val - val = 1 - end subroutine routine_integer -end module typebound_header - """.strip() - - module = Module.from_source(fcode, frontend=REGEX) - - assert 'header_type' in module.typedef_map - header_type = module.typedef_map['header_type'] - - proc_bindings = { - 'member_routine': ('header_member_routine',), - 'routine_real': ('header_routine_real',), - 'routine_integer': None, - 'routine': ('routine_real', 'routine_integer') - } - assert len(proc_bindings) == len(header_type.variables) - assert all(proc in header_type.variables for proc in proc_bindings) - assert all( - ( - header_type.variable_map[proc].type.bind_names == bind - and header_type.variable_map[proc].type.initial is None - ) - for proc, bind in proc_bindings.items() - ) - - -def test_regex_loki_69(): - """ - Test compliance of REGEX frontend with edge cases reported in LOKI-69. - This should become a full-blown Scheduler test when REGEX frontend undeprins the scheduler. - """ - fcode = """ -subroutine random_call_0(v_out,v_in,v_inout) -implicit none - - real(kind=jprb),intent(in) :: v_in - real(kind=jprb),intent(out) :: v_out - real(kind=jprb),intent(inout) :: v_inout - - -end subroutine random_call_0 - -!subroutine random_call_1(v_out,v_in,v_inout) -!implicit none -! -! real(kind=jprb),intent(in) :: v_in -! real(kind=jprb),intent(out) :: v_out -! real(kind=jprb),intent(inout) :: v_inout -! -! 
-!end subroutine random_call_1 - -subroutine random_call_2(v_out,v_in,v_inout) -implicit none - - real(kind=jprb),intent(in) :: v_in - real(kind=jprb),intent(out) :: v_out - real(kind=jprb),intent(inout) :: v_inout - - -end subroutine random_call_2 - -subroutine test(v_out,v_in,v_inout,some_logical) -implicit none - - real(kind=jprb),intent(in ) :: v_in - real(kind=jprb),intent(out ) :: v_out - real(kind=jprb),intent(inout) :: v_inout - - logical,intent(in) :: some_logical - - v_inout = 0._jprb - if(some_logical)then - call random_call_0(v_out,v_in,v_inout) - endif - - if(some_logical) call random_call_2 - -end subroutine test - """.strip() - - source = Sourcefile.from_source(fcode, frontend=REGEX) - assert [r.name for r in source.all_subroutines] == ['random_call_0', 'random_call_2', 'test'] - - calls = FindNodes(ir.CallStatement).visit(source['test'].ir) - assert [call.name for call in calls] == ['RANDOM_CALL_0', 'random_call_2'] - - variable_map_test = source['test'].variable_map - v_in_type = variable_map_test['v_in'].type - assert v_in_type.dtype is BasicType.REAL - assert v_in_type.kind == 'jprb' - - -def test_regex_variable_declaration(testdir): - """ - Test correct parsing of derived type variable declarations - """ - filepath = testdir/'sources/projTypeBound/typebound_item.F90' - source = Sourcefile.from_file(filepath, frontend=REGEX) - - driver = source['driver'] - assert driver.variables == ('constant', 'obj', 'obj2', 'header', 'other_obj', 'derived', 'x', 'i') - assert source['module_routine'].variables == ('m',) - assert source['other_routine'].variables == ('self', 'm', 'j') - assert source['routine'].variables == ('self',) - assert source['routine1'].variables == ('self',) - - # Check this for REGEX and complete parse to make sure their behaviour is aligned - for _ in range(2): - var_map = driver.symbol_map - assert isinstance(var_map['obj'].type.dtype, DerivedType) - assert var_map['obj'].type.dtype.name == 'some_type' - assert isinstance(var_map['obj2'].type.dtype, DerivedType) - assert var_map['obj2'].type.dtype.name == 'some_type' - assert isinstance(var_map['header'].type.dtype, DerivedType) - assert var_map['header'].type.dtype.name == 'header_type' - assert isinstance(var_map['other_obj'].type.dtype, DerivedType) - assert var_map['other_obj'].type.dtype.name == 'other' - assert isinstance(var_map['derived'].type.dtype, DerivedType) - assert var_map['derived'].type.dtype.name == 'other' - assert isinstance(var_map['x'].type.dtype, BasicType) - assert var_map['x'].type.dtype is BasicType.REAL - assert isinstance(var_map['i'].type.dtype, BasicType) - assert var_map['i'].type.dtype is BasicType.INTEGER - - # While we're here: let's check the call statements, too - calls = FindNodes(ir.CallStatement).visit(driver.ir) - assert len(calls) == 7 - assert all(isinstance(call.name.type.dtype, ProcedureType) for call in calls) - - # Note: we're explicitly accessing the string name here (instead of relying - # on the StrCompareMixin) as some have dimensions that only show up in the full - # parse - assert calls[0].name.name == 'obj%other_routine' - assert calls[0].name.parent.name == 'obj' - assert calls[1].name.name == 'obj2%some_routine' - assert calls[1].name.parent.name == 'obj2' - assert calls[2].name.name == 'header%member_routine' - assert calls[2].name.parent.name == 'header' - assert calls[3].name.name == 'header%routine' - assert calls[3].name.parent.name == 'header' - assert calls[4].name.name == 'header%routine' - assert calls[4].name.parent.name == 'header' - assert 
calls[5].name.name == 'other_obj%member' - assert calls[5].name.parent.name == 'other_obj' - assert calls[6].name.name == 'derived%var%member_routine' - assert calls[6].name.parent.name == 'derived%var' - assert calls[6].name.parent.parent.name == 'derived' - - # Hack: Split the procedure binding into one-per-line until Fparser - # supports this... - module = source['typebound_item'] - module.source.string = module.source.string.replace( - 'procedure :: routine1,', 'procedure :: routine1\nprocedure ::' - ) - - source.make_complete() - - -def test_regex_variable_declaration_parentheses(): - fcode = """ -subroutine definitely_not_allfpos(ydfpdata) -implicit none -integer, parameter :: NMaxCloudTypes = 12 -type(tfpdata), intent(in) :: ydfpdata -type(tfpofn) :: ylofn(size(ydfpdata%yfpos%yfpgeometry%yfpusergeo)) -real, dimension(nproma, max(nang, 1), max(nfre, 1)) :: not_an_annoying_ecwam_var -character(len=511) :: cloud_type_name(NMaxCloudTypes) = ["","","","","","","","","","","",""], other_name = "", names(3) = (/ "", "", "" /) -character(len=511) :: more_names(2) = (/ "What", " is" /), naaaames(2) = [ " going ", "on?" ] -end subroutine definitely_not_allfpos - """.strip() - - source = Sourcefile.from_source(fcode, frontend=REGEX) - routine = source['definitely_not_allfpos'] - assert routine.variables == ( - 'nmaxcloudtypes', 'ydfpdata', 'ylofn', 'not_an_annoying_ecwam_var', - 'cloud_type_name', 'other_name', 'names', 'more_names', 'naaaames' - ) - assert routine.symbol_map['not_an_annoying_ecwam_var'].type.dtype is BasicType.REAL - assert routine.symbol_map['cloud_type_name'].type.dtype is BasicType.CHARACTER - - -def test_regex_preproc_in_contains(): - fcode = """ -module preproc_in_contains - implicit none - public :: routine1, routine2, func -contains -#include "some_include.h" - subroutine routine1 - end subroutine routine1 - - module subroutine mod_routine - call other_routine - contains -#define something - subroutine other_routine - end subroutine other_routine - end subroutine mod_routine - - elemental function func - real func - end function func -end module preproc_in_contains - """.strip() - source = Sourcefile.from_source(fcode, frontend=REGEX) - - expected_names = {'preproc_in_contains', 'routine1', 'mod_routine', 'func'} - actual_names = {r.name for r in source.all_subroutines} | {m.name for m in source.modules} - assert expected_names == actual_names - - assert isinstance(source['mod_routine']['other_routine'], Subroutine) - - @pytest.mark.parametrize('frontend', available_frontends()) def test_frontend_pragma_vs_comment(frontend, tmp_path): """ @@ -1472,231 +436,6 @@ def test_frontend_source_lineno(frontend): assert calls[0].source.lines[0] < calls[1].source.lines[0] < calls[2].source.lines[0] -def test_regex_interface_subroutine(): - fcode = """ -subroutine test(callback) - -implicit none -interface - subroutine some_kernel(a, b, c) - integer, intent(in) :: a, b - integer, intent(out) :: c - end subroutine some_kernel - - SUBROUTINE other_kernel(a) - integer, intent(inout) :: a - end subroutine -end interface - -INTERFACE - function other_func(a) - integer, intent(in) :: a - integer, other_func - end function other_func -end interface - -abstract interface - function callback_func(a) result(b) - integer, intent(in) :: a - integer :: b - end FUNCTION callback_func -end INTERFACE - -procedure(callback_func), pointer, intent(in) :: callback -integer :: a, b, c - -a = callback(1) -b = other_func(a) - -call some_kernel(a, b, c) -call other_kernel(c) - -end subroutine test - 
""".strip() - - # Make sure only the host subroutine is captured - source = Sourcefile.from_source(fcode, frontend=REGEX) - assert len(source.subroutines) == 1 - assert source.subroutines[0].name == 'test' - assert source.subroutines[0].source.lines == (1, 38) - - # Make sure this also works for module procedures - fcode = f""" -module my_mod - implicit none -contains -{fcode} -end module my_mod - """.strip() - source = Sourcefile.from_source(fcode, frontend=REGEX) - assert not source.subroutines - assert len(source.all_subroutines) == 1 - assert source.all_subroutines[0].name == 'test' - assert source.all_subroutines[0].source.lines == (4, 41) - - -def test_regex_interface_module(): - fcode = """ -module my_mod - implicit none - interface - subroutine ext1 (x, y, z) - real, dimension(100, 100), intent(inout) :: x, y, z - end subroutine ext1 - subroutine ext2 (x, z) - real, intent(in) :: x - complex(kind = 4), intent(inout) :: z(2000) - end subroutine ext2 - function ext3 (p, q) - logical ext3 - integer, intent(in) :: p(1000) - logical, intent(in) :: q(1000) - end function ext3 - end interface - interface sub - subroutine sub_int (a) - integer, intent(in) :: a(:) - end subroutine sub_int - subroutine sub_real (a) - real, intent(in) :: a(:) - end subroutine sub_real - end interface sub - interface func - module procedure func_int - module procedure func_real - end interface func -contains - subroutine sub_int (a) - integer, intent(in) :: a(:) - end subroutine sub_int - subroutine sub_real (a) - real, intent(in) :: a(:) - end subroutine sub_real - integer module function func_int (a) - integer, intent(in) :: a(:) - end function func_int - real module function func_real (a) - real, intent(in) :: a(:) - end function func_real -end module my_mod - """.strip() - source = Sourcefile.from_source(fcode, frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) - - assert len(source.modules) == 1 - assert source['my_mod'] is not None - assert not source['my_mod'].interfaces - - source.make_complete( - frontend=REGEX, - parser_class=RegexParserClass.ProgramUnitClass | RegexParserClass.InterfaceClass - ) - assert len(source['my_mod'].interfaces) == 3 - assert source['my_mod'].symbols == ( - 'ext1', 'ext2', 'ext3', - 'sub', 'sub_int', 'sub_real', - 'func', 'func_int', 'func_real', 'func_int', 'func_real', - 'sub_int', 'sub_real', - 'func_int', 'func_real' - ) - - -def test_regex_function_inline_return_type(): - fcode = """ -REAL(KIND=JPRB) FUNCTION DOT_PRODUCT_ECV() - -END FUNCTION DOT_PRODUCT_ECV - -SUBROUTINE DOT_PROD_SP_2D() - -END SUBROUTINE DOT_PROD_SP_2D - """.strip() - source = Sourcefile.from_source(fcode, frontend=REGEX) - assert { - routine.name.lower() for routine in source.subroutines - } == {'dot_product_ecv', 'dot_prod_sp_2d'} - - source.make_complete() - routine = source['dot_product_ecv'] - assert 'dot_product_ecv' in routine.variables - - -@pytest.mark.parametrize('frontend', available_frontends()) -def test_regex_prefix(frontend, tmp_path): - fcode = """ -module some_mod - implicit none -contains - pure elemental real function f_elem(a) - real, intent(in) :: a - f_elem = a - end function f_elem - - pure recursive integer function fib(i) result(fib_i) - integer, intent(in) :: i - if (i <= 0) then - fib_i = 0 - else if (i == 1) then - fib_i = 1 - else - fib_i = fib(i-1) + fib(i-2) - end if - end function fib -end module some_mod - """.strip() - source = Sourcefile.from_source(fcode, frontend=REGEX) - assert source['f_elem'].prefix == ('pure elemental real',) - assert 
source['fib'].prefix == ('pure recursive integer',) - source.make_complete(frontend=frontend, xmods=[tmp_path]) - assert tuple(p.lower() for p in source['f_elem'].prefix) == ('pure', 'elemental') - assert tuple(p.lower() for p in source['fib'].prefix) == ('pure', 'recursive') - - -def test_regex_fypp(): - """ - Test that unexpanded fypp-annotations are handled gracefully in the REGEX frontend. - """ - fcode = """ -module fypp_mod -! A pre-set array of pre-prcessor variables -#:mute -#:set foo = [2,3,4,5] -#:endmute - -contains - -! A non-templated routine -subroutine first_routine(i, x) - integer, intent(in) :: i - real, intent(inout) :: x(3) -end subroutine first_routine - -! A fypp-loop with in-place directives for subroutine names -#:for bar in foo -#:set rname = 'routine_%s' % (bar,) -subroutine ${rname}$ (i, x) - integer, intent(in) :: i - real, intent(inout) :: x(3) -end subroutine ${rname}$ -#:endfor - -! Another non-templated routine -subroutine last_routine(i, x) - integer, intent(in) :: i - real, intent(inout) :: x(3) -end subroutine last_routine - -end module fypp_mod -""" - source = Sourcefile.from_source(fcode, frontend=REGEX) - module = source['fypp_mod'] - assert isinstance(module, Module) - - # Check that only non-templated routines are included - assert len(module.routines) == 2 - assert module.routines[0].name == 'first_routine' - assert module.routines[1].name == 'last_routine' - - @pytest.mark.parametrize( 'frontend', available_frontends(include_regex=True, xfail=[(OMNI, 'OMNI may segfault on empty files')]) @@ -2204,3 +943,75 @@ def test_intrinsic_shadowing(tmp_path, frontend): assert isinstance(assigns[2].rhs.function, sym.ProcedureSymbol) assert not assigns[2].rhs.function.type.is_intrinsic assert assigns[2].rhs.function.type.dtype.procedure == algebra['min'] + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_function_symbol_scoping(frontend): + """ Check that the return symbol of a function has the right scope """ + fcode = """ +real function double_real(i) + implicit none + integer, intent(in) :: i + + double_real = dble(i*2) +end function double_real +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + rsym = routine.variable_map['double_real'] + assert isinstance(rsym, sym.Scalar) + assert rsym.type.dtype == BasicType.REAL + assert rsym.scope == routine + + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 1 + assert assigns[0].lhs == 'double_real' + assert isinstance(assigns[0].lhs, sym.Scalar) + assert assigns[0].lhs.type.dtype == BasicType.REAL + assert assigns[0].lhs.scope == routine + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_frontend_derived_type_imports(tmp_path, frontend): + """ Checks that provided module and type info is attached during parse """ + fcode_module = """ +module my_type_mod + type my_type + real(kind=8) :: a, b(:) + end type my_type +end module my_type_mod +""" + + fcode = """ +subroutine test_derived_type_parse + use my_type_mod, only: my_type + implicit none + type(my_type) :: obj + + obj%a = 42.0 + obj%b = 66.6 +end subroutine test_derived_type_parse +""" + module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path]) + routine = Subroutine.from_source( + fcode, definitions=module, frontend=frontend, xmods=[tmp_path] + ) + + assert len(module.typedefs) == 1 + assert module.typedefs[0].name == 'my_type' + + # Ensure that the imported type is recognised as such + assert len(routine.imports) == 1 + assert 
routine.imports[0].module == 'my_type_mod' + assert len(routine.imports[0].symbols) == 1 + assert routine.imports[0].symbols[0] == 'my_type' + assert isinstance(routine.imports[0].symbols[0], sym.DerivedTypeSymbol) + + # Ensure that the declared variable and its components are recognised + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 2 + assert isinstance(assigns[0].lhs, sym.Scalar) + assert assigns[0].lhs.type.dtype == BasicType.REAL + assert isinstance(assigns[1].lhs, sym.Array) + assert assigns[1].lhs.type.dtype == BasicType.REAL + assert assigns[1].lhs.type.shape == (':',) diff --git a/loki/frontend/tests/test_regex_frontend.py b/loki/frontend/tests/test_regex_frontend.py index d1e0b642a..73e4ae9de 100644 --- a/loki/frontend/tests/test_regex_frontend.py +++ b/loki/frontend/tests/test_regex_frontend.py @@ -9,9 +9,1271 @@ Verify correct parsing behaviour of the REGEX frontend """ -from loki.frontend import REGEX -from loki.types import BasicType, DerivedType -from loki.subroutine import Subroutine +from pathlib import Path + +import platform +from time import perf_counter +import pytest + +from loki import Module, Subroutine, Sourcefile, RawSource, config +from loki.frontend import ( + available_frontends, OMNI, FP, REGEX, RegexParserClass +) +from loki.ir import nodes as ir, FindNodes, PreprocessorDirective +from loki.types import BasicType, ProcedureType, DerivedType + + +@pytest.fixture(scope='module', name='here') +def fixture_here(): + return Path(__file__).parent + + +@pytest.fixture(scope='module', name='testdir') +def fixture_testdir(here): + return here.parent.parent/'tests' + + +@pytest.fixture(name='reset_regex_frontend_timeout') +def fixture_reset_regex_frontend_timeout(): + original_timeout = config['regex-frontend-timeout'] + yield + config['regex-frontend-timeout'] = original_timeout + + +def test_regex_subroutine_from_source(): + """ + Verify that the regex frontend is able to parse subroutines + """ + fcode = """ +subroutine routine_b( + ! arg 1 + i, + ! arg2 + j +) + use parkind1, only : jpim + implicit none + integer, intent(in) :: i, j + integer b + b = 4 + + call contained_c(i) + + call routine_a() +contains +!abc ^$^** + integer(kind=jpim) function contained_e(i) + integer, intent(in) :: i + contained_e = i + end function + + subroutine contained_c(i) + integer, intent(in) :: i + integer c + c = 5 + end subroutine contained_c + ! 
cc£$^£$^ + + subroutine contained_d(i) + integer, intent(in) :: i + integer c + c = 8 + end subroutine !add"£^£$ +end subroutine routine_b + """.strip() + + routine = Subroutine.from_source(fcode, frontend=REGEX) + assert routine.name == 'routine_b' + assert not routine.is_function + assert routine.arguments == () + assert routine.argnames == [] + assert [r.name for r in routine.subroutines] == ['contained_e', 'contained_c', 'contained_d'] + + contained_c = routine['contained_c'] + assert contained_c.name == 'contained_c' + assert not contained_c.is_function + assert contained_c.arguments == () + assert contained_c.argnames == [] + + contained_e = routine['contained_e'] + assert contained_e.name == 'contained_e' + assert contained_e.is_function + assert contained_e.arguments == () + assert contained_e.argnames == [] + + contained_d = routine['contained_d'] + assert contained_d.name == 'contained_d' + assert not contained_d.is_function + assert contained_d.arguments == () + assert contained_d.argnames == [] + + code = routine.to_fortran() + assert code.count('SUBROUTINE') == 6 + assert code.count('FUNCTION') == 2 + assert code.count('CONTAINS') == 1 + + +def test_regex_module_from_source(): + """ + Verify that the regex frontend is able to parse modules + """ + fcode = """ +module some_module + use foobar + implicit none + integer, parameter :: k = selected_int_kind(5) +contains + subroutine module_routine + integer m + m = 2 + + call routine_b(m, 6) + end subroutine module_routine + + integer(kind=k) function module_function(n) + integer n + module_function = n + 2 + end function module_function +end module some_module + """.strip() + + module = Module.from_source(fcode, frontend=REGEX) + assert module.name == 'some_module' + assert [r.name for r in module.subroutines] == ['module_routine', 'module_function'] + + code = module.to_fortran() + assert code.count('MODULE') == 2 + assert code.count('SUBROUTINE') == 2 + assert code.count('FUNCTION') == 2 + assert code.count('CONTAINS') == 1 + + +def test_regex_sourcefile_from_source(): + """ + Verify that the regex frontend is able to parse source files containing + multiple modules and subroutines + """ + fcode = """ +subroutine routine_a + integer a, i + a = 1 + i = a + 1 + + call routine_b(a, i) +end subroutine routine_a + +module some_module +contains + subroutine module_routine + integer m + m = 2 + + call routine_b(m, 6) + end subroutine module_routine + + function module_function(n) + integer n + integer module_function + module_function = n + 3 + end function module_function +end module some_module + +module other_module + integer :: n +end module + +subroutine routine_b( + ! arg 1 + i, + ! arg2 + j, + k!arg3 +) + integer, intent(in) :: i, j, k + integer b + b = 4 + + call contained_c(i) + + call routine_a() +contains +!abc ^$^** + subroutine contained_c(i) + integer, intent(in) :: i + integer c + c = 5 + end subroutine contained_c + ! 
cc£$^£$^ + integer function contained_e(i) + integer, intent(in) :: i + contained_e = i + end function + + subroutine contained_d(i) + integer, intent(in) :: i + integer c + c = 8 + end subroutine !add"£^£$ +endsubroutine routine_b + +function function_d(d) + integer d + d = 6 +end function function_d + +module last_module + implicit none +contains + subroutine last_routine1 + call contained() + contains + subroutine contained + integer n + n = 1 + end subroutine contained + end subroutine last_routine1 + subroutine last_routine2 + call contained2() + contains + subroutine contained2 + integer m + m = 1 + end subroutine contained2 + end subroutine last_routine2 +end module last_module + """.strip() + + sourcefile = Sourcefile.from_source(fcode, frontend=REGEX) + assert [m.name for m in sourcefile.modules] == ['some_module', 'other_module', 'last_module'] + assert [r.name for r in sourcefile.routines] == [ + 'routine_a', 'routine_b', 'function_d' + ] + assert [r.name for r in sourcefile.all_subroutines] == [ + 'routine_a', 'routine_b', 'function_d', 'module_routine', 'module_function', + 'last_routine1', 'last_routine2' + ] + + assert len(r := sourcefile['last_module']['last_routine1'].routines) == 1 and r[0].name == 'contained' + assert len(r := sourcefile['last_module']['last_routine2'].routines) == 1 and r[0].name == 'contained2' + + code = sourcefile.to_fortran() + assert code.count('SUBROUTINE') == 18 + assert code.count('FUNCTION') == 6 + assert code.count('CONTAINS') == 5 + assert code.count('MODULE') == 6 + + +def test_regex_sourcefile_from_file(testdir): + """ + Verify that the regex frontend is able to parse source files containing + multiple modules and subroutines + """ + + sourcefile = Sourcefile.from_file(testdir/'sources/sourcefile.f90', frontend=REGEX) + assert [m.name for m in sourcefile.modules] == ['some_module'] + assert [r.name for r in sourcefile.routines] == [ + 'routine_a', 'routine_b', 'function_d' + ] + assert [r.name for r in sourcefile.all_subroutines] == [ + 'routine_a', 'routine_b', 'function_d', 'module_routine', 'module_function' + ] + + routine_b = sourcefile['ROUTINE_B'] + assert routine_b.name == 'routine_b' + assert not routine_b.is_function + assert routine_b.arguments == () + assert routine_b.argnames == [] + assert [r.name for r in routine_b.subroutines] == ['contained_c'] + + function_d = sourcefile['function_d'] + assert function_d.name == 'function_d' + assert function_d.is_function + assert function_d.arguments == () + assert function_d.argnames == [] + assert not function_d.contains + + code = sourcefile.to_fortran() + assert code.count('SUBROUTINE') == 8 + assert code.count('FUNCTION') == 4 + assert code.count('CONTAINS') == 2 + assert code.count('MODULE') == 2 + + +def test_regex_sourcefile_from_file_parser_classes(testdir): + + filepath = testdir/'sources/Fortran-extract-interface-source.f90' + module_names = {'bar', 'foo'} + routine_names = { + 'func_simple', 'func_simple_1', 'func_simple_2', 'func_simple_pure', 'func_simple_recursive_pure', + 'func_simple_elemental', 'func_with_use_and_args', 'func_with_parameters', 'func_with_parameters_1', + 'func_with_contains', 'func_mix_local_and_result', 'sub_simple', 'sub_simple_1', 'sub_simple_2', + 'sub_simple_3', 'sub_with_contains', 'sub_with_renamed_import', 'sub_with_external', 'sub_with_end' + } + module_routine_names = {'foo_sub', 'foo_func'} + + # Empty parse (since we don't match typedef without having the enclosing module first) + sourcefile = Sourcefile.from_file(filepath, frontend=REGEX, 
parser_classes=RegexParserClass.TypeDefClass) + assert not sourcefile.subroutines + assert not sourcefile.modules + assert FindNodes(RawSource).visit(sourcefile.ir) + assert sourcefile._incomplete + assert sourcefile._parser_classes == RegexParserClass.TypeDefClass + + # Incremental addition of program unit objects + sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) + assert sourcefile._incomplete + assert sourcefile._parser_classes == RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass + # Note that the program unit objects don't include the TypeDefClass because it's lower in the hierarchy + # and was not matched previously + assert all( + module._parser_classes == RegexParserClass.ProgramUnitClass + for module in sourcefile.modules + ) + assert all( + routine._parser_classes == RegexParserClass.ProgramUnitClass + for routine in sourcefile.routines + ) + + assert {module.name.lower() for module in sourcefile.modules} == module_names + assert {routine.name.lower() for routine in sourcefile.routines} == routine_names + assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names + + assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} + assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { + 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' + } + + for module in sourcefile.modules: + assert not module.imports + for routine in sourcefile.all_subroutines: + assert not routine.imports + assert not sourcefile['bar'].typedefs + + # Validate that a re-parse with same parser classes does not change anything + sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) + assert sourcefile._incomplete + assert sourcefile._parser_classes == RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass + for module in sourcefile.modules: + assert not module.imports + for routine in sourcefile.all_subroutines: + assert not routine.imports + assert not sourcefile['bar'].typedefs + + # Incremental addition of imports + sourcefile.make_complete( + frontend=REGEX, + parser_classes=RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass + ) + assert sourcefile._parser_classes == ( + RegexParserClass.ProgramUnitClass | RegexParserClass.TypeDefClass | RegexParserClass.ImportClass + ) + # Note that the program unit objects don't include the TypeDefClass because it's lower in the hierarchy + # and was not matched previously + assert all( + module._parser_classes == ( + RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass + ) for module in sourcefile.modules + ) + assert all( + routine._parser_classes == ( + RegexParserClass.ProgramUnitClass | RegexParserClass.ImportClass + ) for routine in sourcefile.routines + ) + + assert {module.name.lower() for module in sourcefile.modules} == module_names + assert {routine.name.lower() for routine in sourcefile.routines} == routine_names + assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names + + assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} + assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { + 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' + } + + program_units_with_imports = { + 'foo': 
['bar'], 'func_with_use_and_args': ['foo', 'bar'], 'sub_with_contains': ['bar'], + 'sub_with_renamed_import': ['bar'] + } + + for unit in module_names | routine_names | module_routine_names: + if unit in program_units_with_imports: + assert [import_.module.lower() for import_ in sourcefile[unit].imports] == program_units_with_imports[unit] + else: + assert not sourcefile[unit].imports + assert not sourcefile['bar'].typedefs + + # Parse the rest + sourcefile.make_complete(frontend=REGEX, parser_classes=RegexParserClass.AllClasses) + assert sourcefile._parser_classes == RegexParserClass.AllClasses + assert all( + module._parser_classes == RegexParserClass.AllClasses + for module in sourcefile.modules + ) + assert all( + routine._parser_classes == RegexParserClass.AllClasses + for routine in sourcefile.routines + ) + + assert {module.name.lower() for module in sourcefile.modules} == module_names + assert {routine.name.lower() for routine in sourcefile.routines} == routine_names + assert {routine.name.lower() for routine in sourcefile.all_subroutines} == routine_names | module_routine_names + + assert {routine.name.lower() for routine in sourcefile['func_with_contains'].routines} == {'func_with_contains_1'} + assert {routine.name.lower() for routine in sourcefile['sub_with_contains'].routines} == { + 'sub_with_contains_first', 'sub_with_contains_second', 'sub_with_contains_third' + } + + program_units_with_imports = { + 'foo': ['bar'], 'func_with_use_and_args': ['foo', 'bar'], 'sub_with_contains': ['bar'], + 'sub_with_renamed_import': ['bar'] + } + + for unit in module_names | routine_names | module_routine_names: + if unit in program_units_with_imports: + assert [import_.module.lower() for import_ in sourcefile[unit].imports] == program_units_with_imports[unit] + else: + assert not sourcefile[unit].imports + + # Check access via properties + assert 'bar' in sourcefile + assert 'food' in sourcefile['bar'] + assert sorted(sourcefile['bar'].typedef_map) == ['food', 'organic'] + assert sourcefile['bar'].definitions == sourcefile['bar'].typedefs + ('i_am_dim',) + assert 'cooking_method' in sourcefile['bar']['food'] + assert 'foobar' not in sourcefile['bar']['food'] + assert sourcefile['bar']['food'].interface_symbols == () + + # Check that triggering a full parse works from nested scopes + assert sourcefile['bar']._incomplete + sourcefile['bar']['food'].make_complete() + assert not sourcefile['bar']._incomplete + + +def test_regex_raw_source(): + """ + Verify that unparsed source appears in-between matched objects + """ + fcode = """ +! Some comment before the module +! +module some_mod + ! Some docstring + ! docstring + ! docstring + use some_mod + ! Some comment + ! comment + ! comment +end module some_mod + +! Other comment at the end + """.strip() + source = Sourcefile.from_source(fcode, frontend=REGEX) + + assert len(source.ir.body) == 3 + + assert isinstance(source.ir.body[0], RawSource) + assert source.ir.body[0].source.lines == (1, 2) + assert source.ir.body[0].text == '! Some comment before the module\n!' + assert source.ir.body[0].source.string == source.ir.body[0].text + + assert isinstance(source.ir.body[1], Module) + assert source.ir.body[1].source.lines == (3, 11) + assert source.ir.body[1].source.string.startswith('module') + + assert isinstance(source.ir.body[2], RawSource) + assert source.ir.body[2].source.lines == (12, 13) + assert source.ir.body[2].text == '\n! 
Other comment at the end' + assert source.ir.body[2].source.string == source.ir.body[2].text + + module = source['some_mod'] + assert len(module.spec.body) == 3 + assert isinstance(module.spec.body[0], RawSource) + assert isinstance(module.spec.body[1], ir.Import) + assert isinstance(module.spec.body[2], RawSource) + + assert module.spec.body[0].text.count('docstring') == 3 + assert module.spec.body[2].text.count('comment') == 3 + + +def test_regex_raw_source_with_cpp(): + """ + Verify that unparsed source appears in-between matched objects + and preprocessor statements are preserved + """ + fcode = """ +! Some comment before the subroutine +#ifdef RS6K +@PROCESS HOT(NOVECTOR) NOSTRICT +#endif +SUBROUTINE SOME_ROUTINE (KLON, KLEV) +IMPLICIT NONE +INTEGER, INTENT(IN) :: KLON, KLEV +! Comment inside routine +END SUBROUTINE SOME_ROUTINE + """.strip() + source = Sourcefile.from_source(fcode, frontend=REGEX) + + assert len(source.ir.body) == 2 + + assert isinstance(source.ir.body[0], RawSource) + assert source.ir.body[0].source.lines == (1, 4) + assert source.ir.body[0].text.startswith('! Some comment before the subroutine\n#') + assert source.ir.body[0].text.endswith('#endif') + assert source.ir.body[0].source.string == source.ir.body[0].text + + assert isinstance(source.ir.body[1], Subroutine) + assert source.ir.body[1].source.lines == (5, 9) + assert source.ir.body[1].source.string.startswith('SUBROUTINE') + + +def test_regex_raw_source_with_cpp_incomplete(): + """ + Verify that unparsed source appears inside matched objects if + parser classes are used to restrict the matching + """ + fcode = """ +SUBROUTINE driver(a, b, c) + INTEGER, INTENT(INOUT) :: a, b, c + +#include "kernel.intfb.h" + + CALL kernel(a, b ,c) +END SUBROUTINE driver + """.strip() + parser_classes = RegexParserClass.ProgramUnitClass + source = Sourcefile.from_source(fcode, frontend=REGEX, parser_classes=parser_classes) + + assert len(source.ir.body) == 1 + driver = source['driver'] + assert isinstance(driver, Subroutine) + assert not driver.docstring + assert not driver.body + assert not driver.contains + assert driver.spec and len(driver.spec.body) == 1 + assert isinstance(driver.spec.body[0], RawSource) + assert 'INTEGER, INTENT' in driver.spec.body[0].text + assert '#include' in driver.spec.body[0].text + + +@pytest.mark.parametrize('frontend', available_frontends( + xfail=[(OMNI, 'Non-standard notation needs full preprocessing')] +)) +def test_make_complete_sanitize(frontend): + """ + Test that attempts to first REGEX-parse and then complete source code + with unsupported features that require "frontend sanitization". + """ + fcode = """ +! Some comment before the subroutine +#ifdef RS6K +@PROCESS HOT(NOVECTOR) NOSTRICT +#endif +SUBROUTINE SOME_ROUTINE (KLON, KLEV) + IMPLICIT NONE + INTEGER, INTENT(IN) :: KLON, KLEV + ! Comment inside routine +END SUBROUTINE SOME_ROUTINE + """.strip() + source = Sourcefile.from_source(fcode, frontend=REGEX) + + # Ensure completion handles the non-supported features (@PROCESS) + source.make_complete(frontend=frontend) + + comments = FindNodes(ir.Comment).visit(source.ir) + assert len(comments) == 2 if frontend == FP else 1 + assert comments[0].text == '! 
Some comment before the subroutine' + if frontend == FP: + assert comments[1].text == '@PROCESS HOT(NOVECTOR) NOSTRICT' + + directives = FindNodes(PreprocessorDirective).visit(source.ir) + assert len(directives) == 2 + assert directives[0].text == '#ifdef RS6K' + assert directives[1].text == '#endif' + + +@pytest.mark.skipif(platform.system() == 'Darwin', + reason='Timeout utility test sporadically fails on MacOS CI runners.' +) +@pytest.mark.usefixtures('reset_regex_frontend_timeout') +def test_regex_timeout(): + """ + This source fails to parse because of missing SUBROUTINE in END + statement, and the test verifies that a timeout is encountered + """ + fcode = """ +subroutine some_routine(a) + real, intent(in) :: a +end + """.strip() + + # Test timeout + config['regex-frontend-timeout'] = 1 + start = perf_counter() + with pytest.raises(RuntimeError) as exc: + _ = Sourcefile.from_source(fcode, frontend=REGEX) + stop = perf_counter() + assert .9 < stop - start < 1.1 + assert 'REGEX frontend timeout of 1 s exceeded' in str(exc.value) + + # Test it works fine with proper Fortran: + fcode += ' subroutine' + source = Sourcefile.from_source(fcode, frontend=REGEX) + assert len(source.subroutines) == 1 + assert source.subroutines[0].name == 'some_routine' + + +def test_regex_module_imports(): + """ + Verify that the regex frontend is able to find and correctly parse + Fortran imports + """ + fcode = """ +module some_mod + use no_symbols_mod + use only_mod, only: my_var + use test_rename_mod, first_var1 => var1, first_var3 => var3 + use test_other_rename_mod, only: second_var1 => var1 + use test_other_rename_mod, only: other_var2 => var2, other_var3 => var3 + implicit none +end module some_mod + """.strip() + + module = Module.from_source(fcode, frontend=REGEX) + imports = FindNodes(ir.Import).visit(module.spec) + assert len(imports) == 5 + assert [import_.module for import_ in imports] == [ + 'no_symbols_mod', 'only_mod', 'test_rename_mod', 'test_other_rename_mod', + 'test_other_rename_mod' + ] + assert set(module.imported_symbols) == { + 'my_var', 'first_var1', 'first_var3', 'second_var1', 'other_var2', 'other_var3' + } + assert module.imported_symbol_map['first_var1'].type.use_name == 'var1' + assert module.imported_symbol_map['first_var3'].type.use_name == 'var3' + assert module.imported_symbol_map['second_var1'].type.use_name == 'var1' + assert module.imported_symbol_map['other_var2'].type.use_name == 'var2' + assert module.imported_symbol_map['other_var3'].type.use_name == 'var3' + + +def test_regex_subroutine_imports(): + """ + Verify that the regex frontend is able to find and correctly parse + Fortran imports + """ + fcode = """ +subroutine some_routine + use no_symbols_mod + use only_mod, only: my_var + use test_rename_mod, first_var1 => var1, first_var3 => var3 + use test_other_rename_mod, only: second_var1 => var1 + use test_other_rename_mod, only: other_var2 => var2, other_var3 => var3 + implicit none +end subroutine some_routine + """.strip() + + routine = Subroutine.from_source(fcode, frontend=REGEX) + imports = FindNodes(ir.Import).visit(routine.spec) + assert len(imports) == 5 + assert [import_.module for import_ in imports] == [ + 'no_symbols_mod', 'only_mod', 'test_rename_mod', 'test_other_rename_mod', + 'test_other_rename_mod' + ] + assert set(routine.imported_symbols) == { + 'my_var', 'first_var1', 'first_var3', 'second_var1', 'other_var2', 'other_var3' + } + assert routine.imported_symbol_map['first_var1'].type.use_name == 'var1' + assert 
routine.imported_symbol_map['first_var3'].type.use_name == 'var3' + assert routine.imported_symbol_map['second_var1'].type.use_name == 'var1' + assert routine.imported_symbol_map['other_var2'].type.use_name == 'var2' + assert routine.imported_symbol_map['other_var3'].type.use_name == 'var3' + + +def test_regex_import_linebreaks(): + """ + Verify correct handling of line breaks in import statements + """ + fcode = """ +module file_io_mod + USE PARKIND1 , ONLY : JPIM, JPRB, JPRD + +#ifdef HAVE_SERIALBOX + USE m_serialize, ONLY: & + fs_create_savepoint, & + fs_add_serializer_metainfo, & + fs_get_serializer_metainfo, & + fs_read_field, & + fs_write_field + USE utils_ppser, ONLY: & + ppser_initialize, & + ppser_finalize, & + ppser_serializer, & + ppser_serializer_ref, & + ppser_set_mode, & + ppser_savepoint +#endif + +#ifdef HAVE_HDF5 + USE hdf5_file_mod, only: hdf5_file +#endif + + implicit none +end module file_io_mod + """.strip() + module = Module.from_source(fcode, frontend=REGEX) + imports = FindNodes(ir.Import).visit(module.spec) + assert len(imports) == 4 + assert [import_.module for import_ in imports] == ['PARKIND1', 'm_serialize', 'utils_ppser', 'hdf5_file_mod'] + assert all( + s in module.imported_symbols for s in [ + 'JPIM', 'JPRB', 'JPRD', 'fs_create_savepoint', 'fs_add_serializer_metainfo', 'fs_get_serializer_metainfo', + 'fs_read_field', 'fs_write_field', 'ppser_initialize', 'ppser_finalize', 'ppser_serializer', + 'ppser_serializer_ref', 'ppser_set_mode', 'ppser_savepoint', 'hdf5_file' + ] + ) + + +def test_regex_typedef(): + """ + Verify that the regex frontend is able to parse type definitions and + correctly parse procedure bindings. + """ + fcode = """ +module typebound_item + implicit none + type some_type + contains + procedure, nopass :: routine => module_routine + procedure :: some_routine + procedure, pass :: other_routine + procedure :: routine1, & + & routine2 => routine + ! procedure :: routine1 + ! 
procedure :: routine2 => routine + end type some_type +contains + subroutine module_routine + integer m + m = 2 + end subroutine module_routine + + subroutine some_routine(self) + class(some_type) :: self + + call self%routine + end subroutine some_routine + + subroutine other_routine(self, m) + class(some_type), intent(inout) :: self + integer, intent(in) :: m + integer :: j + + j = m + call self%routine1 + call self%routine2 + end subroutine other_routine + + subroutine routine(self) + class(some_type) :: self + call self%some_routine + end subroutine routine + + subroutine routine1(self) + class(some_type) :: self + call module_routine + end subroutine routine1 +end module typebound_item + """.strip() + + module = Module.from_source(fcode, frontend=REGEX) + + assert 'some_type' in module.typedef_map + some_type = module.typedef_map['some_type'] + + proc_bindings = { + 'routine': ('module_routine',), + 'some_routine': None, + 'other_routine': None, + 'routine1': None, + 'routine2': ('routine',) + } + assert len(proc_bindings) == len(some_type.variables) + assert all(proc in some_type.variables for proc in proc_bindings) + assert all( + some_type.variable_map[proc].type.bind_names == bind + for proc, bind in proc_bindings.items() + ) + + +def test_regex_typedef_generic(): + fcode = """ +module typebound_header + implicit none + + type header_type + contains + procedure :: member_routine => header_member_routine + procedure :: routine_real => header_routine_real + procedure :: routine_integer + generic :: routine => routine_real, routine_integer + end type header_type + +contains + + subroutine header_member_routine(self, val) + class(header_type) :: self + integer, intent(in) :: val + integer :: j + j = val + end subroutine header_member_routine + + subroutine header_routine_real(self, val) + class(header_type) :: self + real, intent(out) :: val + val = 1.0 + end subroutine header_routine_real + + subroutine routine_integer(self, val) + class(header_type) :: self + integer, intent(out) :: val + val = 1 + end subroutine routine_integer +end module typebound_header + """.strip() + + module = Module.from_source(fcode, frontend=REGEX) + + assert 'header_type' in module.typedef_map + header_type = module.typedef_map['header_type'] + + proc_bindings = { + 'member_routine': ('header_member_routine',), + 'routine_real': ('header_routine_real',), + 'routine_integer': None, + 'routine': ('routine_real', 'routine_integer') + } + assert len(proc_bindings) == len(header_type.variables) + assert all(proc in header_type.variables for proc in proc_bindings) + assert all( + ( + header_type.variable_map[proc].type.bind_names == bind + and header_type.variable_map[proc].type.initial is None + ) + for proc, bind in proc_bindings.items() + ) + + +def test_regex_loki_69(): + """ + Test compliance of REGEX frontend with edge cases reported in LOKI-69. + This should become a full-blown Scheduler test when REGEX frontend undeprins the scheduler. + """ + fcode = """ +subroutine random_call_0(v_out,v_in,v_inout) +implicit none + + real(kind=jprb),intent(in) :: v_in + real(kind=jprb),intent(out) :: v_out + real(kind=jprb),intent(inout) :: v_inout + + +end subroutine random_call_0 + +!subroutine random_call_1(v_out,v_in,v_inout) +!implicit none +! +! real(kind=jprb),intent(in) :: v_in +! real(kind=jprb),intent(out) :: v_out +! real(kind=jprb),intent(inout) :: v_inout +! +! 
+!end subroutine random_call_1 + +subroutine random_call_2(v_out,v_in,v_inout) +implicit none + + real(kind=jprb),intent(in) :: v_in + real(kind=jprb),intent(out) :: v_out + real(kind=jprb),intent(inout) :: v_inout + + +end subroutine random_call_2 + +subroutine test(v_out,v_in,v_inout,some_logical) +implicit none + + real(kind=jprb),intent(in ) :: v_in + real(kind=jprb),intent(out ) :: v_out + real(kind=jprb),intent(inout) :: v_inout + + logical,intent(in) :: some_logical + + v_inout = 0._jprb + if(some_logical)then + call random_call_0(v_out,v_in,v_inout) + endif + + if(some_logical) call random_call_2 + +end subroutine test + """.strip() + + source = Sourcefile.from_source(fcode, frontend=REGEX) + assert [r.name for r in source.all_subroutines] == ['random_call_0', 'random_call_2', 'test'] + + calls = FindNodes(ir.CallStatement).visit(source['test'].ir) + assert [call.name for call in calls] == ['RANDOM_CALL_0', 'random_call_2'] + + variable_map_test = source['test'].variable_map + v_in_type = variable_map_test['v_in'].type + assert v_in_type.dtype is BasicType.REAL + assert v_in_type.kind == 'jprb' + + +def test_regex_variable_declaration(testdir): + """ + Test correct parsing of derived type variable declarations + """ + filepath = testdir/'sources/projTypeBound/typebound_item.F90' + source = Sourcefile.from_file(filepath, frontend=REGEX) + + driver = source['driver'] + assert driver.variables == ('constant', 'obj', 'obj2', 'header', 'other_obj', 'derived', 'x', 'i') + assert source['module_routine'].variables == ('m',) + assert source['other_routine'].variables == ('self', 'm', 'j') + assert source['routine'].variables == ('self',) + assert source['routine1'].variables == ('self',) + + # Check this for REGEX and complete parse to make sure their behaviour is aligned + for _ in range(2): + var_map = driver.symbol_map + assert isinstance(var_map['obj'].type.dtype, DerivedType) + assert var_map['obj'].type.dtype.name == 'some_type' + assert isinstance(var_map['obj2'].type.dtype, DerivedType) + assert var_map['obj2'].type.dtype.name == 'some_type' + assert isinstance(var_map['header'].type.dtype, DerivedType) + assert var_map['header'].type.dtype.name == 'header_type' + assert isinstance(var_map['other_obj'].type.dtype, DerivedType) + assert var_map['other_obj'].type.dtype.name == 'other' + assert isinstance(var_map['derived'].type.dtype, DerivedType) + assert var_map['derived'].type.dtype.name == 'other' + assert isinstance(var_map['x'].type.dtype, BasicType) + assert var_map['x'].type.dtype is BasicType.REAL + assert isinstance(var_map['i'].type.dtype, BasicType) + assert var_map['i'].type.dtype is BasicType.INTEGER + + # While we're here: let's check the call statements, too + calls = FindNodes(ir.CallStatement).visit(driver.ir) + assert len(calls) == 7 + assert all(isinstance(call.name.type.dtype, ProcedureType) for call in calls) + + # Note: we're explicitly accessing the string name here (instead of relying + # on the StrCompareMixin) as some have dimensions that only show up in the full + # parse + assert calls[0].name.name == 'obj%other_routine' + assert calls[0].name.parent.name == 'obj' + assert calls[1].name.name == 'obj2%some_routine' + assert calls[1].name.parent.name == 'obj2' + assert calls[2].name.name == 'header%member_routine' + assert calls[2].name.parent.name == 'header' + assert calls[3].name.name == 'header%routine' + assert calls[3].name.parent.name == 'header' + assert calls[4].name.name == 'header%routine' + assert calls[4].name.parent.name == 'header' + assert 
calls[5].name.name == 'other_obj%member' + assert calls[5].name.parent.name == 'other_obj' + assert calls[6].name.name == 'derived%var%member_routine' + assert calls[6].name.parent.name == 'derived%var' + assert calls[6].name.parent.parent.name == 'derived' + + # Hack: Split the procedure binding into one-per-line until Fparser + # supports this... + module = source['typebound_item'] + module.source.string = module.source.string.replace( + 'procedure :: routine1,', 'procedure :: routine1\nprocedure ::' + ) + + source.make_complete() + + +def test_regex_variable_declaration_parentheses(): + fcode = """ +subroutine definitely_not_allfpos(ydfpdata) +implicit none +integer, parameter :: NMaxCloudTypes = 12 +type(tfpdata), intent(in) :: ydfpdata +type(tfpofn) :: ylofn(size(ydfpdata%yfpos%yfpgeometry%yfpusergeo)) +real, dimension(nproma, max(nang, 1), max(nfre, 1)) :: not_an_annoying_ecwam_var +character(len=511) :: cloud_type_name(NMaxCloudTypes) = ["","","","","","","","","","","",""], other_name = "", names(3) = (/ "", "", "" /) +character(len=511) :: more_names(2) = (/ "What", " is" /), naaaames(2) = [ " going ", "on?" ] +end subroutine definitely_not_allfpos + """.strip() + + source = Sourcefile.from_source(fcode, frontend=REGEX) + routine = source['definitely_not_allfpos'] + assert routine.variables == ( + 'nmaxcloudtypes', 'ydfpdata', 'ylofn', 'not_an_annoying_ecwam_var', + 'cloud_type_name', 'other_name', 'names', 'more_names', 'naaaames' + ) + assert routine.symbol_map['not_an_annoying_ecwam_var'].type.dtype is BasicType.REAL + assert routine.symbol_map['cloud_type_name'].type.dtype is BasicType.CHARACTER + + +def test_regex_preproc_in_contains(): + fcode = """ +module preproc_in_contains + implicit none + public :: routine1, routine2, func +contains +#include "some_include.h" + subroutine routine1 + end subroutine routine1 + + module subroutine mod_routine + call other_routine + contains +#define something + subroutine other_routine + end subroutine other_routine + end subroutine mod_routine + + elemental function func + real func + end function func +end module preproc_in_contains + """.strip() + source = Sourcefile.from_source(fcode, frontend=REGEX) + + expected_names = {'preproc_in_contains', 'routine1', 'mod_routine', 'func'} + actual_names = {r.name for r in source.all_subroutines} | {m.name for m in source.modules} + assert expected_names == actual_names + + assert isinstance(source['mod_routine']['other_routine'], Subroutine) + + +def test_regex_interface_subroutine(): + fcode = """ +subroutine test(callback) + +implicit none +interface + subroutine some_kernel(a, b, c) + integer, intent(in) :: a, b + integer, intent(out) :: c + end subroutine some_kernel + + SUBROUTINE other_kernel(a) + integer, intent(inout) :: a + end subroutine +end interface + +INTERFACE + function other_func(a) + integer, intent(in) :: a + integer, other_func + end function other_func +end interface + +abstract interface + function callback_func(a) result(b) + integer, intent(in) :: a + integer :: b + end FUNCTION callback_func +end INTERFACE + +procedure(callback_func), pointer, intent(in) :: callback +integer :: a, b, c + +a = callback(1) +b = other_func(a) + +call some_kernel(a, b, c) +call other_kernel(c) + +end subroutine test + """.strip() + + # Make sure only the host subroutine is captured + source = Sourcefile.from_source(fcode, frontend=REGEX) + assert len(source.subroutines) == 1 + assert source.subroutines[0].name == 'test' + assert source.subroutines[0].source.lines == (1, 38) + + # Make sure 
this also works for module procedures + fcode = f""" +module my_mod + implicit none +contains +{fcode} +end module my_mod + """.strip() + source = Sourcefile.from_source(fcode, frontend=REGEX) + assert not source.subroutines + assert len(source.all_subroutines) == 1 + assert source.all_subroutines[0].name == 'test' + assert source.all_subroutines[0].source.lines == (4, 41) + + +def test_regex_interface_module(): + fcode = """ +module my_mod + implicit none + interface + subroutine ext1 (x, y, z) + real, dimension(100, 100), intent(inout) :: x, y, z + end subroutine ext1 + subroutine ext2 (x, z) + real, intent(in) :: x + complex(kind = 4), intent(inout) :: z(2000) + end subroutine ext2 + function ext3 (p, q) + logical ext3 + integer, intent(in) :: p(1000) + logical, intent(in) :: q(1000) + end function ext3 + end interface + interface sub + subroutine sub_int (a) + integer, intent(in) :: a(:) + end subroutine sub_int + subroutine sub_real (a) + real, intent(in) :: a(:) + end subroutine sub_real + end interface sub + interface func + module procedure func_int + module procedure func_real + end interface func +contains + subroutine sub_int (a) + integer, intent(in) :: a(:) + end subroutine sub_int + subroutine sub_real (a) + real, intent(in) :: a(:) + end subroutine sub_real + integer module function func_int (a) + integer, intent(in) :: a(:) + end function func_int + real module function func_real (a) + real, intent(in) :: a(:) + end function func_real +end module my_mod + """.strip() + source = Sourcefile.from_source(fcode, frontend=REGEX, parser_classes=RegexParserClass.ProgramUnitClass) + + assert len(source.modules) == 1 + assert source['my_mod'] is not None + assert not source['my_mod'].interfaces + + source.make_complete( + frontend=REGEX, + parser_class=RegexParserClass.ProgramUnitClass | RegexParserClass.InterfaceClass + ) + assert len(source['my_mod'].interfaces) == 3 + assert source['my_mod'].symbols == ( + 'ext1', 'ext2', 'ext3', + 'sub', 'sub_int', 'sub_real', + 'func', 'func_int', 'func_real', 'func_int', 'func_real', + 'sub_int', 'sub_real', + 'func_int', 'func_real' + ) + + +def test_regex_function_inline_return_type(): + fcode = """ +REAL(KIND=JPRB) FUNCTION DOT_PRODUCT_ECV() + +END FUNCTION DOT_PRODUCT_ECV + +SUBROUTINE DOT_PROD_SP_2D() + +END SUBROUTINE DOT_PROD_SP_2D + """.strip() + source = Sourcefile.from_source(fcode, frontend=REGEX) + assert { + routine.name.lower() for routine in source.subroutines + } == {'dot_product_ecv', 'dot_prod_sp_2d'} + + source.make_complete() + routine = source['dot_product_ecv'] + assert 'dot_product_ecv' in routine.variables + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_regex_prefix(frontend, tmp_path): + fcode = """ +module some_mod + implicit none +contains + pure elemental real function f_elem(a) + real, intent(in) :: a + f_elem = a + end function f_elem + + pure recursive integer function fib(i) result(fib_i) + integer, intent(in) :: i + if (i <= 0) then + fib_i = 0 + else if (i == 1) then + fib_i = 1 + else + fib_i = fib(i-1) + fib(i-2) + end if + end function fib +end module some_mod + """.strip() + source = Sourcefile.from_source(fcode, frontend=REGEX) + assert source['f_elem'].prefix == ('pure elemental real',) + assert source['fib'].prefix == ('pure recursive integer',) + source.make_complete(frontend=frontend, xmods=[tmp_path]) + assert tuple(p.lower() for p in source['f_elem'].prefix) == ('pure', 'elemental') + assert tuple(p.lower() for p in source['fib'].prefix) == ('pure', 'recursive') + + +def 
test_regex_fypp(): + """ + Test that unexpanded fypp-annotations are handled gracefully in the REGEX frontend. + """ + fcode = """ +module fypp_mod +! A pre-set array of pre-prcessor variables +#:mute +#:set foo = [2,3,4,5] +#:endmute + +contains + +! A non-templated routine +subroutine first_routine(i, x) + integer, intent(in) :: i + real, intent(inout) :: x(3) +end subroutine first_routine + +! A fypp-loop with in-place directives for subroutine names +#:for bar in foo +#:set rname = 'routine_%s' % (bar,) +subroutine ${rname}$ (i, x) + integer, intent(in) :: i + real, intent(inout) :: x(3) +end subroutine ${rname}$ +#:endfor + +! Another non-templated routine +subroutine last_routine(i, x) + integer, intent(in) :: i + real, intent(inout) :: x(3) +end subroutine last_routine + +end module fypp_mod +""" + source = Sourcefile.from_source(fcode, frontend=REGEX) + module = source['fypp_mod'] + assert isinstance(module, Module) + + # Check that only non-templated routines are included + assert len(module.routines) == 2 + assert module.routines[0].name == 'first_routine' + assert module.routines[1].name == 'last_routine' + def test_declaration_whitespace_attributes(): """ diff --git a/loki/ir/expr_visitors.py b/loki/ir/expr_visitors.py index f67c30837..2676b5274 100644 --- a/loki/ir/expr_visitors.py +++ b/loki/ir/expr_visitors.py @@ -16,7 +16,8 @@ from loki.ir.transformer import Transformer from loki.tools import flatten, as_tuple from loki.expression.mappers import ( - SubstituteExpressionsMapper, ExpressionRetriever, AttachScopesMapper + SubstituteExpressionsMapper, ExpressionRetriever, + AttachScopesMapper, LokiIdentityMapper ) from loki.expression.symbols import ( Array, Scalar, InlineCall, TypedSymbol, FloatLiteral, IntLiteral, @@ -24,10 +25,11 @@ ) __all__ = [ - 'FindExpressions', 'FindVariables', 'FindTypedSymbols', - 'FindInlineCalls', 'FindLiterals', 'FindRealLiterals', + 'ExpressionFinder', 'FindExpressions', 'FindVariables', + 'FindTypedSymbols', 'FindInlineCalls', 'FindLiterals', + 'FindRealLiterals', 'ExpressionTransformer', 'SubstituteExpressions', 'SubstituteStringExpressions', - 'ExpressionFinder', 'AttachScopes' + 'AttachScopes' ] @@ -212,7 +214,39 @@ class FindRealLiterals(ExpressionFinder): retriever = ExpressionRetriever(lambda e: isinstance(e, FloatLiteral)) -class SubstituteExpressions(Transformer): +class ExpressionTransformer(Transformer): + """ + The :any:`Transformer` base class for manipulating expressions. + + This transformer uses the class attribute :data:`expr_mapper` to + map an existing expression sub-tree to an new one. By default, it + uses the :any:`LokiIdentityMapper` to replicate the existing tree. + + Attributes + ---------- + expr_mapper : :class:`pymbolic.mapper.Mapper` + An implementation of an expression mapper, e.g., + :any:`SubstituteExpressionsMapper`, that is used to map an + expression tree to a new one. + + Parameters + ---------- + inplace : bool, optional + If set to `True`, all updates are performed on existing :any:`Node` + objects, instead of rebuilding them, keeping the original tree intact. 
+ """ + expr_mapper = LokiIdentityMapper() + + def visit_Expression(self, o, **kwargs): + """ + Call the associated mapper for the given expression node + """ + if kwargs.get('recurse_to_declaration_attributes'): + return self.expr_mapper(o, recurse_to_declaration_attributes=True) + return self.expr_mapper(o) + + +class SubstituteExpressions(ExpressionTransformer): """ A dedicated visitor to perform expression substitution in all IR nodes @@ -246,16 +280,9 @@ class SubstituteExpressions(Transformer): def __init__(self, expr_map, invalidate_source=True, **kwargs): super().__init__(invalidate_source=invalidate_source, **kwargs) + # Override the static default with a substitution mapper from ``expr_map`` self.expr_mapper = SubstituteExpressionsMapper(expr_map) - def visit_Expression(self, o, **kwargs): - """ - call :any:`SubstituteExpressionsMapper` for the given expression node - """ - if kwargs.get('recurse_to_declaration_attributes'): - return self.expr_mapper(o, recurse_to_declaration_attributes=True) - return self.expr_mapper(o) - def visit_Import(self, o, **kwargs): """ For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`) diff --git a/loki/program_unit.py b/loki/program_unit.py index 6b04a4013..73a8e245d 100644 --- a/loki/program_unit.py +++ b/loki/program_unit.py @@ -12,7 +12,9 @@ Frontend, parse_omni_source, parse_ofp_source, parse_fparser_source, RegexParserClass, preprocess_cpp, sanitize_input ) -from loki.ir import nodes as ir, FindNodes, Transformer +from loki.ir import ( + nodes as ir, FindNodes, Transformer, ExpressionTransformer +) from loki.logging import debug from loki.scope import Scope from loki.tools import CaseInsensitiveDict, as_tuple, flatten @@ -401,6 +403,9 @@ def enrich(self, definitions, recurse=False): updated_symbol_attrs[name] = attrs.clone(dtype=self.parent.symbol_attrs[name].dtype) self.symbol_attrs.update(updated_symbol_attrs) + # Rebuild local symbols to ensure correct symbol types + self.spec = ExpressionTransformer(inplace=True).visit(self.spec) + if recurse: for routine in self.subroutines: routine.enrich(definitions, recurse=True) diff --git a/loki/subroutine.py b/loki/subroutine.py index c931f1e5b..f2f82f78f 100644 --- a/loki/subroutine.py +++ b/loki/subroutine.py @@ -11,7 +11,8 @@ parse_regex_source ) from loki.ir import ( - nodes as ir, FindNodes, Transformer, pragmas_attached + nodes as ir, FindNodes, Transformer, ExpressionTransformer, + pragmas_attached ) from loki.logging import debug from loki.program_unit import ProgramUnit @@ -488,6 +489,9 @@ def enrich(self, definitions, recurse=False): symbol = symbol.clone(scope=self, type=symbol.type.clone(dtype=routine.procedure_type)) call._update(name=symbol) + # Rebuild local symbols to ensure correct symbol types + self.body = ExpressionTransformer(inplace=True).visit(self.body) + def __repr__(self): """ String representation. 
diff --git a/loki/tests/test_modules.py b/loki/tests/test_modules.py index 668e95322..620faf8a7 100644 --- a/loki/tests/test_modules.py +++ b/loki/tests/test_modules.py @@ -1375,8 +1375,8 @@ def test_module_enrichment_within_file(frontend, tmp_path): @pytest.mark.parametrize('frontend', available_frontends()) -def test_module_enrichment_typdefs(frontend, tmp_path): - """ Test that module-level enrihcment is propagated correctly """ +def test_module_enrichment_typedefs(frontend, tmp_path): + """ Test that module-level enrichment is propagated correctly """ fcode_state_mod = """ module state_type_mod @@ -1420,7 +1420,10 @@ def test_module_enrichment_typdefs(frontend, tmp_path): assert isinstance(state.type.dtype, DerivedType) assert isinstance(state.type.dtype.typedef, ir.TypeDef) + # Verify that we have the right symbol and type info assigns = FindNodes(ir.Assignment).visit(driver.body) assert len(assigns) == 1 assert assigns[0].lhs.type.dtype == BasicType.REAL assert assigns[0].lhs.type.shape == (':', ':') + assert isinstance(assigns[0].lhs, sym.Array) + assert assigns[0].lhs.parent.type.dtype.typedef == state_mod['state_type'] diff --git a/loki/tests/test_subroutine.py b/loki/tests/test_subroutine.py index 32dc9d79d..1093581d8 100644 --- a/loki/tests/test_subroutine.py +++ b/loki/tests/test_subroutine.py @@ -10,17 +10,17 @@ import pytest import numpy as np -from loki import ( - Sourcefile, Module, Subroutine, FindVariables, FindNodes, Section, - Array, Scalar, Variable, - SymbolAttributes, StringLiteral, fgen, fexprgen, - VariableDeclaration, Transformer, FindTypedSymbols, - ProcedureSymbol, StatementFunction, DeferredTypeSymbol -) +from loki import Sourcefile, Module, Subroutine, fgen, fexprgen from loki.build import jit_compile, jit_compile_lib, clean_test +from loki.expression import symbols as sym from loki.frontend import available_frontends, OMNI, REGEX -from loki.types import BasicType, DerivedType, ProcedureType -from loki.ir import nodes as ir +from loki.ir import ( + nodes as ir, FindNodes, FindVariables, FindTypedSymbols, + Transformer +) +from loki.types import ( + BasicType, DerivedType, ProcedureType, SymbolAttributes +) @pytest.fixture(scope='module', name='here') @@ -56,8 +56,8 @@ def test_routine_simple(tmp_path, frontend): # Test the internals of the subroutine routine = Subroutine.from_source(fcode, frontend=frontend) - assert isinstance(routine.body, Section) - assert isinstance(routine.spec, Section) + assert isinstance(routine.body, ir.Section) + assert isinstance(routine.spec, ir.Section) assert len(routine.docstring) == 1 assert routine.docstring[0].text == '! 
This is the docstring' assert routine.definitions == () @@ -216,9 +216,9 @@ def test_routine_arguments_add_remove(frontend): # Create a new set of variables and add to local routine variables x = routine.variables[1] # That's the symbol for variable 'x' real_type = routine.symbol_attrs['scalar'] # Type of variable 'maximum' - a = Scalar(name='a', type=real_type, scope=routine) - b = Array(name='b', dimensions=(x, ), type=real_type, scope=routine) - c = Variable(name='c', type=x.type, scope=routine) + a = sym.Scalar(name='a', type=real_type, scope=routine) + b = sym.Array(name='b', dimensions=(x, ), type=real_type, scope=routine) + c = sym.Variable(name='c', type=x.type, scope=routine) # Add new arguments and check that they are all in the routine spec routine.arguments += (a, b, c) @@ -375,9 +375,9 @@ def test_routine_variables_add_remove(frontend): x = routine.variable_map['x'] # That's the symbol for variable 'x' real_type = SymbolAttributes('real', kind=routine.variable_map['jprb']) int_type = SymbolAttributes('integer') - a = Scalar(name='a', type=real_type, scope=routine) - b = Array(name='b', dimensions=(x, ), type=real_type, scope=routine) - c = Variable(name='c', type=int_type, scope=routine) + a = sym.Scalar(name='a', type=real_type, scope=routine) + b = sym.Array(name='b', dimensions=(x, ), type=real_type, scope=routine) + c = sym.Variable(name='c', type=int_type, scope=routine) # Add new variables and check that they are all in the routine spec routine.variables += (a, b, c) @@ -493,17 +493,17 @@ def test_routine_variables_dim_shapes(frontend): assert routine.arguments == ('v1', 'v2', 'v3(:)', 'v4(v1, v2)', 'v5(0:v1, v2 - 1)') # Make sure variable/argument shapes on the routine work - shapes = [fexprgen(v.shape) for v in routine.arguments if isinstance(v, Array)] + shapes = [fexprgen(v.shape) for v in routine.arguments if isinstance(v, sym.Array)] assert shapes == ['(v1,)', '(v1, v2)', '(0:v1, v2 - 1)'] # Ensure that all spec variables (including dimension symbols) are scoped correctly spec_vars = [v for v in FindVariables(unique=False).visit(routine.spec) if v.name.lower() != 'selected_real_kind'] assert all(v.scope == routine for v in spec_vars) - assert all(isinstance(v, (Scalar, Array)) for v in spec_vars) + assert all(isinstance(v, (sym.Scalar, sym.Array)) for v in spec_vars) # Ensure shapes of body variables are ok b_shapes = [fexprgen(v.shape) for v in FindVariables(unique=False).visit(routine.body) - if isinstance(v, Array)] + if isinstance(v, sym.Array)] assert b_shapes == ['(v1,)', '(v1,)', '(v1, v2)', '(0:v1, v2 - 1)'] @@ -540,7 +540,7 @@ def test_routine_variables_shape_propagation(tmp_path, header_path, frontend): # Verify that all variable instances have type and shape information variables = FindVariables().visit(routine.body) - assert all(v.shape is not None for v in variables if isinstance(v, Array)) + assert all(v.shape is not None for v in variables if isinstance(v, sym.Array)) vmap = {v.name: v for v in variables} assert fexprgen(vmap['vector'].shape) == '(x,)' @@ -576,7 +576,7 @@ def test_routine_variables_shape_propagation(tmp_path, header_path, frontend): # Verify that all derived type variables have shape info variables = FindVariables().visit(routine.body) - assert all(v.shape is not None for v in variables if isinstance(v, Array)) + assert all(v.shape is not None for v in variables if isinstance(v, sym.Array)) # Verify shape info from imported derived type is propagated vmap = {v.name: v for v in variables} @@ -656,7 +656,7 @@ def 
test_routine_type_propagation(header_path, frontend, tmp_path): # Verify that all variable instances have type information variables = FindVariables().visit(routine.body) - assert all(v.type is not None for v in variables if isinstance(v, (Scalar, Array))) + assert all(v.type is not None for v in variables if isinstance(v, (sym.Scalar, sym.Array))) vmap = {v.name: v for v in variables} assert vmap['x'].type.dtype == BasicType.INTEGER @@ -744,11 +744,11 @@ def test_routine_call_arrays(header_path, frontend, tmp_path): assert str(call.arguments[3]) == 'matrix' assert str(call.arguments[4]) == 'item%matrix' - assert isinstance(call.arguments[0], Scalar) - assert isinstance(call.arguments[1], Scalar) - assert isinstance(call.arguments[2], Array) - assert isinstance(call.arguments[3], Array) - assert isinstance(call.arguments[4], Array) + assert isinstance(call.arguments[0], sym.Scalar) + assert isinstance(call.arguments[1], sym.Scalar) + assert isinstance(call.arguments[2], sym.Array) + assert isinstance(call.arguments[3], sym.Array) + assert isinstance(call.arguments[4], sym.Array) assert fexprgen(call.arguments[2].shape) == '(x,)' assert fexprgen(call.arguments[3].shape) == '(x, y)' @@ -791,10 +791,10 @@ def test_call_kwargs(frontend): assert all(isinstance(arg, tuple) and len(arg) == 2 for arg in calls[0].kwarguments) assert calls[0].kwarguments[0][0] == 'kprocs' - assert (isinstance(calls[0].kwarguments[0][1], Scalar) and + assert (isinstance(calls[0].kwarguments[0][1], sym.Scalar) and calls[0].kwarguments[0][1].name == 'kprocs') - assert calls[0].kwarguments[1] == ('cdstring', StringLiteral('routine_call_kwargs')) + assert calls[0].kwarguments[1] == ('cdstring', sym.StringLiteral('routine_call_kwargs')) @pytest.mark.parametrize('frontend', available_frontends()) @@ -812,7 +812,7 @@ def test_call_args_kwargs(frontend): assert calls[0].name == 'mpl_send' assert len(calls[0].arguments) == 3 assert all(a.name == b.name for a, b in zip(calls[0].arguments, routine.arguments)) - assert calls[0].kwarguments == (('cdstring', StringLiteral('routine_call_args_kwargs')),) + assert calls[0].kwarguments == (('cdstring', sym.StringLiteral('routine_call_args_kwargs')),) @pytest.mark.parametrize('frontend', available_frontends()) @@ -1178,15 +1178,12 @@ def test_external_stmt(tmp_path, frontend): routine = source['routine_external_stmt'] assert len(routine.arguments) == 8 - for decl in FindNodes(VariableDeclaration).visit(routine.spec): - # Skip local variables - if decl.symbols[0].name in ('invar', 'outvar', 'tmp'): - continue + for decl in FindNodes(ir.ProcedureDeclaration).visit(routine.spec): # Is the EXTERNAL attribute set? assert decl.external for v in decl.symbols: # Are procedure names represented as Scalar objects? 
- assert isinstance(v, ProcedureSymbol) + assert isinstance(v, sym.ProcedureSymbol) assert isinstance(v.type.dtype, ProcedureType) assert v.type.external is True assert v.type.dtype.procedure == BasicType.DEFERRED @@ -1499,12 +1496,12 @@ def test_subroutine_stmt_func(tmp_path, frontend): # OMNI inlines statement functions, so we can only check correct representation # for fparser if frontend != OMNI: - stmt_func_decls = {d.variable: d for d in FindNodes(StatementFunction).visit(routine.spec)} + stmt_func_decls = {d.variable: d for d in FindNodes(ir.StatementFunction).visit(routine.spec)} assert len(stmt_func_decls) == 3 for name in ('plus', 'minus', 'mult'): var = routine.variable_map[name] - assert isinstance(var, ProcedureSymbol) + assert isinstance(var, sym.ProcedureSymbol) assert isinstance(var.type.dtype, ProcedureType) assert var.type.dtype.procedure is stmt_func_decls[var] assert stmt_func_decls[var].source is not None @@ -1530,8 +1527,8 @@ def test_mixed_declaration_interface(frontend): with pytest.raises(AssertionError) as error: routine = Subroutine.from_source(fcode, frontend=frontend) - assert isinstance(routine.body, Section) - assert isinstance(routine.spec, Section) + assert isinstance(routine.body, ir.Section) + assert isinstance(routine.spec, ir.Section) _ = routine.interface assert "Declarations must have intents" in str(error.value) @@ -1556,7 +1553,7 @@ def test_subroutine_prefix(frontend): assert routine.return_type.dtype is BasicType.REAL assert routine.name in routine.symbol_map - decl = [d for d in FindNodes(VariableDeclaration).visit(routine.spec) if routine.name in d.symbols] + decl = [d for d in FindNodes(ir.VariableDeclaration).visit(routine.spec) if routine.name in d.symbols] assert len(decl) == 1 decl = decl[0] @@ -1752,15 +1749,15 @@ def test_subroutine_lazy_arguments_incomplete1(frontend): assert routine.arguments == () assert routine.argnames == [] assert routine._dummies == () - assert all(isinstance(arg, DeferredTypeSymbol) for arg in routine.arguments) + assert all(isinstance(arg, sym.DeferredTypeSymbol) for arg in routine.arguments) routine.make_complete(frontend=frontend) assert not routine._incomplete assert routine.arguments == ('n', 'a(n)', 'b(n)', 'd(n)') assert routine.argnames == ['n', 'a', 'b', 'd'] assert routine._dummies == ('n', 'a', 'b', 'd') - assert isinstance(routine.arguments[0], Scalar) - assert all(isinstance(arg, Array) for arg in routine.arguments[1:]) + assert isinstance(routine.arguments[0], sym.Scalar) + assert all(isinstance(arg, sym.Array) for arg in routine.arguments[1:]) @pytest.mark.parametrize('frontend', available_frontends()) @@ -1841,15 +1838,15 @@ def test_subroutine_lazy_arguments_incomplete2(frontend): assert routine.arguments == () assert routine.argnames == [] assert routine._dummies == () - assert all(isinstance(arg, DeferredTypeSymbol) for arg in routine.arguments) + assert all(isinstance(arg, sym.DeferredTypeSymbol) for arg in routine.arguments) routine.make_complete(frontend=frontend) assert not routine._incomplete assert routine.arguments == argnames_with_dim assert [arg.upper() for arg in routine.argnames] == [arg.upper() for arg in argnames] assert routine._dummies == argnames - assert all(isinstance(arg, Scalar) for arg in routine.arguments[:4]) - assert all(isinstance(arg, Array) for arg in routine.arguments[4:]) + assert all(isinstance(arg, sym.Scalar) for arg in routine.arguments[:4]) + assert all(isinstance(arg, sym.Array) for arg in routine.arguments[4:]) @pytest.mark.parametrize('frontend', 
available_frontends()) @@ -2079,6 +2076,9 @@ def test_enrich_derived_types(tmp_path, frontend): # Enrich the routine with module definitions routine.enrich(module) + field_3rb_symbol = routine.symbol_map['field_3rb_array'] + yda_array_p = routine.resolve_typebound_var('yda_array%p') + # Ensure the imported type symbol is correctly enriched assert field_3rb_symbol.type.imported assert field_3rb_symbol.type.module is module @@ -2089,6 +2089,21 @@ def test_enrich_derived_types(tmp_path, frontend): assert yda_array.type.dtype.typedef is field_3rb_tdef assert yda_array_p.type.dtype is BasicType.REAL assert yda_array_p.type.shape == (':', ':', ':') + assert isinstance(yda_array_p, sym.Array) + + # Double-check body and spec expressions + decls = FindNodes(ir.VariableDeclaration).visit(routine.spec) + assert len(decls) == 1 + assert len(decls[0].symbols) == 1 + assert isinstance(decls[0].symbols[0], sym.Scalar) + assert decls[0].symbols[0].type.dtype.typedef == field_3rb_tdef + + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 1 + assert isinstance(assigns[0].lhs, sym.Array) + assert assigns[0].lhs.type.dtype == BasicType.REAL + assert assigns[0].lhs.type.shape == (':', ':', ':') + assert assigns[0].lhs.parent.type.dtype.typedef == field_3rb_tdef @pytest.mark.parametrize('frontend', available_frontends( @@ -2127,7 +2142,7 @@ def test_subroutine_deep_clone(frontend): map_nodes={} for assign in FindNodes(ir.Assignment).visit(new_routine.body): map_nodes[assign] = ir.CallStatement( - name=DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine + name=sym.DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine ) new_routine.body = Transformer(map_nodes).visit(new_routine.body) @@ -2273,7 +2288,7 @@ def test_resolve_typebound_var(frontend, tmp_path): # Instead, we can creatae a deferred type variable in the scope and # resolve members relative to it - not_tt = Variable(name='not_tt', scope=routine) + not_tt = sym.Variable(name='not_tt', scope=routine) assert not_tt.type.dtype == BasicType.DEFERRED # pylint: disable=no-member not_tt_invalid = not_tt.get_derived_type_member('invalid') # pylint: disable=no-member assert not_tt_invalid == 'not_tt%invalid'
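For reference (illustration only, not part of the patch): the extended assertions in test_enrich_derived_types and test_module_enrichment_typedefs check that enrichment now rebuilds spec and body expressions, so members of imported derived types acquire their declared type, shape and symbol class. A minimal sketch of that flow, with module, routine and member names invented for the example and assuming a full frontend (e.g. fparser) is available; the asserts show what is expected to hold after enrichment.

from loki import Module, Subroutine
from loki.expression import symbols as sym
from loki.types import BasicType

fcode_mod = """
module geom_mod
  implicit none
  type point_type
    real :: coords(3)
  end type point_type
end module geom_mod
""".strip()

fcode_kernel = """
subroutine kernel(p)
  use geom_mod, only: point_type
  implicit none
  type(point_type), intent(inout) :: p
  p%coords(1) = 0.
end subroutine kernel
""".strip()

module = Module.from_source(fcode_mod)
# Parsed on its own, the routine only knows 'point_type' as a deferred import
routine = Subroutine.from_source(fcode_kernel)

# Enrichment attaches the typedef from the module to the imported symbol and,
# with this patch, also rebuilds the routine's expressions via
# ExpressionTransformer, so derived-type members pick up their declared
# type and shape.
routine.enrich(module)

coords = routine.resolve_typebound_var('p%coords')
assert isinstance(coords, sym.Array)          # member rebuilt as an Array symbol
assert coords.type.dtype is BasicType.REAL    # with the type declared in the typedef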