Skip to content

Commit

Permalink
[F2C transpilation] (driver level) convert interface to import
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Nov 25, 2024
1 parent 8dbdafc commit f48f528
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
25 changes: 25 additions & 0 deletions loki/transformations/transpile/fortran_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def transform_subroutine(self, routine, **kwargs):
depth = depths[item]

if role == 'driver':
self.interface_to_import(routine, targets)
return

for arg in routine.arguments:
Expand Down Expand Up @@ -243,6 +244,30 @@ def convert_kwargs_to_args(self, routine, targets):
if inline_call_map:
routine.body = SubstituteExpressions(inline_call_map).visit(routine.body)

def interface_to_import(self, routine, targets):
"""
Convert interface to import.
"""
for call in FindNodes(CallStatement).visit(routine.body):
if str(call.name).lower() in as_tuple(targets):
call.convert_kwargs_to_args()
intfs = FindNodes(Interface).visit(routine.spec)
removal_map = {}
for i in intfs:
for b in i.body:
if isinstance(b, Subroutine):
if targets and b.name.lower() in targets:
# Create a new module import with explicitly qualified symbol
modname = f'{b.name}_FC_MOD'
new_symbol = Variable(name=f'{b.name}_FC', scope=routine)
new_import = Import(module=modname, c_import=False, symbols=(new_symbol,))
routine.spec.prepend(new_import)
# Mark current import for removal
removal_map[i] = None
# Apply any scheduled interface removals to spec
if removal_map:
routine.spec = Transformer(removal_map).visit(routine.spec)

def c_struct_typedef(self, derived):
"""
Create the :class:`TypeDef` for the C-wrapped struct definition.
Expand Down
44 changes: 44 additions & 0 deletions loki/transformations/transpile/tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,50 @@ def init_var(dtype, val=0):
fc_function(in_var, out_var)
assert int(out_var) == expected_results[i]

@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_interface_to_module(tmp_path, frontend):
driver_fcode = """
SUBROUTINE driver_interface_to_module(a, b, c)
IMPLICIT NONE
INTERFACE
SUBROUTINE KERNEL(a, b, c)
INTEGER, INTENT(INOUT) :: a, b, c
END SUBROUTINE KERNEL
END INTERFACE
INTERFACE
SUBROUTINE KERNEL2(a, b)
INTEGER, INTENT(INOUT) :: a, b
END SUBROUTINE KERNEL2
END INTERFACE
INTERFACE
SUBROUTINE KERNEL3(a)
INTEGER, INTENT(INOUT) :: a
END SUBROUTINE KERNEL3
END INTERFACE
INTEGER, INTENT(INOUT) :: a, b, c
CALL kernel(a, b ,c)
CALL kernel2(a, b)
END SUBROUTINE driver_interface_to_module
""".strip()

routine = Subroutine.from_source(driver_fcode, frontend=frontend)

interfaces = FindNodes(ir.Interface).visit(routine.spec)
imports = FindNodes(ir.Import).visit(routine.spec)
assert len(interfaces) == 3
assert not imports

f2c = FortranCTransformation()
f2c.apply(source=routine, path=tmp_path, targets=('kernel',), role='driver')

interfaces = FindNodes(ir.Interface).visit(routine.spec)
imports = FindNodes(ir.Import).visit(routine.spec)
assert len(interfaces) == 2
assert len(imports) == 1
assert imports[0].module.upper() == 'KERNEL_FC_MOD'


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
Expand Down

0 comments on commit f48f528

Please sign in to comment.