diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 1efc3c8ac..42c8a42ac 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -2100,8 +2100,17 @@ def visit_If_Construct(self, o, **kwargs): else_if_stmt_index, else_if_stmts = zip(*else_if_stmts) else: else_if_stmt_index = () - else_stmt = get_child(o, Fortran2003.Else_Stmt) - else_stmt_index = o.children.index(else_stmt) if else_stmt else end_if_stmt_index + + # Note: we need to use here the same method as for else-if because finding Else_Stmt + # directly and checking its position via o.children.index may give the wrong result. + # This is because Else_Stmt may erronously compare equal to other node types. + # See https://github.com/stfc/fparser/issues/400 + else_stmt = tuple((i, c) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Else_Stmt)) + if else_stmt: + assert len(else_stmt) == 1 + else_stmt_index, else_stmt = else_stmt[0] + else: + else_stmt_index = end_if_stmt_index conditions = as_tuple(self.visit(c, **kwargs) for c in (if_then_stmt,) + else_if_stmts) bodies = tuple( tuple(flatten(as_tuple(self.visit(c, **kwargs) for c in o.children[start+1:stop]))) diff --git a/tests/test_control_flow.py b/tests/test_control_flow.py index dbccb6b30..5b558f93b 100644 --- a/tests/test_control_flow.py +++ b/tests/test_control_flow.py @@ -10,7 +10,7 @@ import numpy as np from conftest import jit_compile, clean_test, available_frontends -from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node +from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node, Intrinsic @pytest.fixture(scope='module', name='here') @@ -455,3 +455,45 @@ def test_conditional_bodies(frontend): c.else_body and isinstance(c.else_body, tuple) and all(isinstance(n, Node) for n in c.else_body) for c in conditionals ) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_conditional_else_body_return(frontend): + fcode = """ +FUNCTION FUNC(PX,KN) +IMPLICIT NONE +INTEGER,INTENT(INOUT) :: KN +REAL,INTENT(IN) :: PX +REAL :: FUNC +INTEGER :: J +REAL :: Z0, Z1, Z2 +Z0= 1.0 +Z1= PX +IF (KN == 0) THEN + FUNC= Z0 + RETURN +ELSEIF (KN == 1) THEN + FUNC= Z1 + RETURN +ELSE + DO J=2,KN + Z2= Z0+Z1 + Z0= Z1 + Z1= Z2 + ENDDO + FUNC= Z2 + RETURN +ENDIF +END FUNCTION FUNC + """.strip() + + routine = Subroutine.from_source(fcode, frontend=frontend) + conditionals = FindNodes(Conditional).visit(routine.body) + assert len(conditionals) == 2 + assert isinstance(conditionals[0].body[-1], Intrinsic) + assert conditionals[0].body[-1].text.upper() == 'RETURN' + assert conditionals[0].else_body == (conditionals[1],) + assert isinstance(conditionals[1].body[-1], Intrinsic) + assert conditionals[1].body[-1].text.upper() == 'RETURN' + assert isinstance(conditionals[1].else_body[-1], Intrinsic) + assert conditionals[1].else_body[-1].text.upper() == 'RETURN'