diff --git a/loki/transformations/transpile/fortran_c.py b/loki/transformations/transpile/fortran_c.py index bbb516a03..69f1396d0 100644 --- a/loki/transformations/transpile/fortran_c.py +++ b/loki/transformations/transpile/fortran_c.py @@ -182,6 +182,7 @@ def transform_subroutine(self, routine, **kwargs): depth = depths[item] if role == 'driver': + self.interface_to_import(routine, targets) return for arg in routine.arguments: @@ -243,6 +244,29 @@ def convert_kwargs_to_args(self, routine, targets): if inline_call_map: routine.body = SubstituteExpressions(inline_call_map).visit(routine.body) + def interface_to_import(self, routine, targets): + """ + Convert interface to import. + """ + for call in FindNodes(CallStatement).visit(routine.body): + if str(call.name).lower() in as_tuple(targets): + call.convert_kwargs_to_args() + intfs = FindNodes(Interface).visit(routine.spec) + removal_map = {} + for i in intfs: + for s in i.symbols: + if targets and s in targets: + # Create a new module import with explicitly qualified symbol + new_symbol = s.clone(name=f'{s.name}_FC', scope=routine) + modname = f'{new_symbol.name}_MOD' + new_import = Import(module=modname, c_import=False, symbols=(new_symbol,)) + routine.spec.prepend(new_import) + # Mark current import for removal + removal_map[i] = None + # Apply any scheduled interface removals to spec + if removal_map: + routine.spec = Transformer(removal_map).visit(routine.spec) + def c_struct_typedef(self, derived): """ Create the :class:`TypeDef` for the C-wrapped struct definition. diff --git a/loki/transformations/transpile/tests/test_transpile.py b/loki/transformations/transpile/tests/test_transpile.py index 506ea30c8..03c083612 100644 --- a/loki/transformations/transpile/tests/test_transpile.py +++ b/loki/transformations/transpile/tests/test_transpile.py @@ -1372,6 +1372,50 @@ def init_var(dtype, val=0): fc_function(in_var, out_var) assert int(out_var) == expected_results[i] +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transpile_interface_to_module(tmp_path, frontend): + driver_fcode = """ +SUBROUTINE driver_interface_to_module(a, b, c) + IMPLICIT NONE + INTERFACE + SUBROUTINE KERNEL(a, b, c) + INTEGER, INTENT(INOUT) :: a, b, c + END SUBROUTINE KERNEL + END INTERFACE + INTERFACE + SUBROUTINE KERNEL2(a, b) + INTEGER, INTENT(INOUT) :: a, b + END SUBROUTINE KERNEL2 + END INTERFACE + INTERFACE + SUBROUTINE KERNEL3(a) + INTEGER, INTENT(INOUT) :: a + END SUBROUTINE KERNEL3 + END INTERFACE + + INTEGER, INTENT(INOUT) :: a, b, c + + CALL kernel(a, b ,c) + CALL kernel2(a, b) +END SUBROUTINE driver_interface_to_module + """.strip() + + routine = Subroutine.from_source(driver_fcode, frontend=frontend) + + interfaces = FindNodes(ir.Interface).visit(routine.spec) + imports = FindNodes(ir.Import).visit(routine.spec) + assert len(interfaces) == 3 + assert not imports + + f2c = FortranCTransformation() + f2c.apply(source=routine, path=tmp_path, targets=('kernel',), role='driver') + + assert len(routine.interfaces) == 2 + imports = routine.imports + assert len(imports) == 1 + assert imports[0].module.upper() == 'KERNEL_FC_MOD' + assert imports[0].symbols == ('KERNEL_FC',) + @pytest.fixture(scope='module', name='horizontal') def fixture_horizontal():