From 01eafe78871ee9a5b3235ad2f5952ef0b50011eb Mon Sep 17 00:00:00 2001 From: Michael Staneker <michael.staneker@ecmwf.int> Date: Wed, 30 Oct 2024 16:08:55 +0100 Subject: [PATCH 1/2] [F2C transpilation] (driver level) convert interface to import --- loki/transformations/transpile/fortran_c.py | 25 +++++++++++ .../transpile/tests/test_transpile.py | 44 +++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/loki/transformations/transpile/fortran_c.py b/loki/transformations/transpile/fortran_c.py index bbb516a03..11015656f 100644 --- a/loki/transformations/transpile/fortran_c.py +++ b/loki/transformations/transpile/fortran_c.py @@ -182,6 +182,7 @@ def transform_subroutine(self, routine, **kwargs): depth = depths[item] if role == 'driver': + self.interface_to_import(routine, targets) return for arg in routine.arguments: @@ -243,6 +244,30 @@ def convert_kwargs_to_args(self, routine, targets): if inline_call_map: routine.body = SubstituteExpressions(inline_call_map).visit(routine.body) + def interface_to_import(self, routine, targets): + """ + Convert interface to import. + """ + for call in FindNodes(CallStatement).visit(routine.body): + if str(call.name).lower() in as_tuple(targets): + call.convert_kwargs_to_args() + intfs = FindNodes(Interface).visit(routine.spec) + removal_map = {} + for i in intfs: + for b in i.body: + if isinstance(b, Subroutine): + if targets and b.name.lower() in targets: + # Create a new module import with explicitly qualified symbol + modname = f'{b.name}_FC_MOD' + new_symbol = Variable(name=f'{b.name}_FC', scope=routine) + new_import = Import(module=modname, c_import=False, symbols=(new_symbol,)) + routine.spec.prepend(new_import) + # Mark current import for removal + removal_map[i] = None + # Apply any scheduled interface removals to spec + if removal_map: + routine.spec = Transformer(removal_map).visit(routine.spec) + def c_struct_typedef(self, derived): """ Create the :class:`TypeDef` for the C-wrapped struct definition. diff --git a/loki/transformations/transpile/tests/test_transpile.py b/loki/transformations/transpile/tests/test_transpile.py index 506ea30c8..e9997f4c8 100644 --- a/loki/transformations/transpile/tests/test_transpile.py +++ b/loki/transformations/transpile/tests/test_transpile.py @@ -1372,6 +1372,50 @@ def init_var(dtype, val=0): fc_function(in_var, out_var) assert int(out_var) == expected_results[i] +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transpile_interface_to_module(tmp_path, frontend): + driver_fcode = """ +SUBROUTINE driver_interface_to_module(a, b, c) + IMPLICIT NONE + INTERFACE + SUBROUTINE KERNEL(a, b, c) + INTEGER, INTENT(INOUT) :: a, b, c + END SUBROUTINE KERNEL + END INTERFACE + INTERFACE + SUBROUTINE KERNEL2(a, b) + INTEGER, INTENT(INOUT) :: a, b + END SUBROUTINE KERNEL2 + END INTERFACE + INTERFACE + SUBROUTINE KERNEL3(a) + INTEGER, INTENT(INOUT) :: a + END SUBROUTINE KERNEL3 + END INTERFACE + + INTEGER, INTENT(INOUT) :: a, b, c + + CALL kernel(a, b ,c) + CALL kernel2(a, b) +END SUBROUTINE driver_interface_to_module + """.strip() + + routine = Subroutine.from_source(driver_fcode, frontend=frontend) + + interfaces = FindNodes(ir.Interface).visit(routine.spec) + imports = FindNodes(ir.Import).visit(routine.spec) + assert len(interfaces) == 3 + assert not imports + + f2c = FortranCTransformation() + f2c.apply(source=routine, path=tmp_path, targets=('kernel',), role='driver') + + interfaces = FindNodes(ir.Interface).visit(routine.spec) + imports = FindNodes(ir.Import).visit(routine.spec) + assert len(interfaces) == 2 + assert len(imports) == 1 + assert imports[0].module.upper() == 'KERNEL_FC_MOD' + @pytest.fixture(scope='module', name='horizontal') def fixture_horizontal(): From 384e2ec1b72dcbb57035ea63527d4921fa9eda9b Mon Sep 17 00:00:00 2001 From: Michael Staneker <michael.staneker@ecmwf.int> Date: Mon, 25 Nov 2024 13:25:04 +0100 Subject: [PATCH 2/2] [F2C transpilation] improve implementation for (driver level) convert interface to import --- loki/transformations/transpile/fortran_c.py | 19 +++++++++---------- .../transpile/tests/test_transpile.py | 6 +++--- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/loki/transformations/transpile/fortran_c.py b/loki/transformations/transpile/fortran_c.py index 11015656f..69f1396d0 100644 --- a/loki/transformations/transpile/fortran_c.py +++ b/loki/transformations/transpile/fortran_c.py @@ -254,16 +254,15 @@ def interface_to_import(self, routine, targets): intfs = FindNodes(Interface).visit(routine.spec) removal_map = {} for i in intfs: - for b in i.body: - if isinstance(b, Subroutine): - if targets and b.name.lower() in targets: - # Create a new module import with explicitly qualified symbol - modname = f'{b.name}_FC_MOD' - new_symbol = Variable(name=f'{b.name}_FC', scope=routine) - new_import = Import(module=modname, c_import=False, symbols=(new_symbol,)) - routine.spec.prepend(new_import) - # Mark current import for removal - removal_map[i] = None + for s in i.symbols: + if targets and s in targets: + # Create a new module import with explicitly qualified symbol + new_symbol = s.clone(name=f'{s.name}_FC', scope=routine) + modname = f'{new_symbol.name}_MOD' + new_import = Import(module=modname, c_import=False, symbols=(new_symbol,)) + routine.spec.prepend(new_import) + # Mark current import for removal + removal_map[i] = None # Apply any scheduled interface removals to spec if removal_map: routine.spec = Transformer(removal_map).visit(routine.spec) diff --git a/loki/transformations/transpile/tests/test_transpile.py b/loki/transformations/transpile/tests/test_transpile.py index e9997f4c8..03c083612 100644 --- a/loki/transformations/transpile/tests/test_transpile.py +++ b/loki/transformations/transpile/tests/test_transpile.py @@ -1410,11 +1410,11 @@ def test_transpile_interface_to_module(tmp_path, frontend): f2c = FortranCTransformation() f2c.apply(source=routine, path=tmp_path, targets=('kernel',), role='driver') - interfaces = FindNodes(ir.Interface).visit(routine.spec) - imports = FindNodes(ir.Import).visit(routine.spec) - assert len(interfaces) == 2 + assert len(routine.interfaces) == 2 + imports = routine.imports assert len(imports) == 1 assert imports[0].module.upper() == 'KERNEL_FC_MOD' + assert imports[0].symbols == ('KERNEL_FC',) @pytest.fixture(scope='module', name='horizontal')