diff --git a/loki/transformations/transpile/fortran_c.py b/loki/transformations/transpile/fortran_c.py index 11015656f..69f1396d0 100644 --- a/loki/transformations/transpile/fortran_c.py +++ b/loki/transformations/transpile/fortran_c.py @@ -254,16 +254,15 @@ def interface_to_import(self, routine, targets): intfs = FindNodes(Interface).visit(routine.spec) removal_map = {} for i in intfs: - for b in i.body: - if isinstance(b, Subroutine): - if targets and b.name.lower() in targets: - # Create a new module import with explicitly qualified symbol - modname = f'{b.name}_FC_MOD' - new_symbol = Variable(name=f'{b.name}_FC', scope=routine) - new_import = Import(module=modname, c_import=False, symbols=(new_symbol,)) - routine.spec.prepend(new_import) - # Mark current import for removal - removal_map[i] = None + for s in i.symbols: + if targets and s in targets: + # Create a new module import with explicitly qualified symbol + new_symbol = s.clone(name=f'{s.name}_FC', scope=routine) + modname = f'{new_symbol.name}_MOD' + new_import = Import(module=modname, c_import=False, symbols=(new_symbol,)) + routine.spec.prepend(new_import) + # Mark current import for removal + removal_map[i] = None # Apply any scheduled interface removals to spec if removal_map: routine.spec = Transformer(removal_map).visit(routine.spec) diff --git a/loki/transformations/transpile/tests/test_transpile.py b/loki/transformations/transpile/tests/test_transpile.py index e9997f4c8..03c083612 100644 --- a/loki/transformations/transpile/tests/test_transpile.py +++ b/loki/transformations/transpile/tests/test_transpile.py @@ -1410,11 +1410,11 @@ def test_transpile_interface_to_module(tmp_path, frontend): f2c = FortranCTransformation() f2c.apply(source=routine, path=tmp_path, targets=('kernel',), role='driver') - interfaces = FindNodes(ir.Interface).visit(routine.spec) - imports = FindNodes(ir.Import).visit(routine.spec) - assert len(interfaces) == 2 + assert len(routine.interfaces) == 2 + imports = routine.imports assert len(imports) == 1 assert imports[0].module.upper() == 'KERNEL_FC_MOD' + assert imports[0].symbols == ('KERNEL_FC',) @pytest.fixture(scope='module', name='horizontal')