From ca3147d8c81e2ad5d7f370f8a2f4ec0b5babf5bf Mon Sep 17 00:00:00 2001
From: Natalia Gavrilenko <natgavrilenko@gmail.com>
Date: Fri, 12 Jul 2024 11:36:32 +0200
Subject: [PATCH] Refactored control flow

---
 .../visitors/spirv/ProgramBuilderSpv.java     | 123 +-----
 .../visitors/spirv/VisitorOpsControlFlow.java | 206 +++++-----
 .../spirv/helpers/HelperControlFlow.java      | 107 +++++
 .../visitors/spirv/ProgramBuilderTest.java    |   7 +-
 .../visitors/spirv/VisitorOpsBarrierTest.java |   2 +-
 .../spirv/VisitorOpsControlFlowTest.java      | 375 +++++++++---------
 .../spirv/mocks/MockHelperControlFlow.java    |  32 ++
 .../spirv/mocks/MockProgramBuilderSpv.java    |  28 +-
 8 files changed, 440 insertions(+), 440 deletions(-)
 create mode 100644 dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperControlFlow.java
 create mode 100644 dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockHelperControlFlow.java

diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderSpv.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderSpv.java
index e7aba4a36a..a073ff566a 100644
--- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderSpv.java
+++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderSpv.java
@@ -8,6 +8,7 @@
 import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.BuiltIn;
 import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Decoration;
 import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType;
+import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow;
 import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperDecorations;
 import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperTags;
 import com.dat3m.dartagnan.parsers.program.visitors.spirv.transformers.ThreadCreator;
@@ -16,8 +17,6 @@
 import com.dat3m.dartagnan.expression.type.ScopedPointerType;
 import com.dat3m.dartagnan.program.*;
 import com.dat3m.dartagnan.program.event.*;
-import com.dat3m.dartagnan.program.event.core.Label;
-import com.dat3m.dartagnan.program.event.core.Local;
 import com.dat3m.dartagnan.program.memory.Memory;
 import com.dat3m.dartagnan.program.memory.MemoryObject;
 import org.apache.logging.log4j.LogManager;
@@ -37,14 +36,9 @@ public class ProgramBuilderSpv {
     protected final Map<String, Type> types = new HashMap<>();
     protected final Map<String, Expression> expressions = new HashMap<>();
     protected final Map<String, Function> forwardFunctions = new HashMap<>();
-    protected final Map<String, Label> labels = new HashMap<>();
-    protected final Deque<Label> blocks = new ArrayDeque<>();
-    protected final Map<Label, Event> blockEndEvents = new HashMap<>();
-    protected final Map<Label, Map<Register, String>> phiDefinitions = new HashMap<>();
-    protected final Map<Label, Label> cfDefinitions = new HashMap<>();
     protected final List<Integer> threadGrid;
     protected final Map<String, Expression> inputs;
-    protected final Program program;
+    public final Program program;
 
     protected Function currentFunction;
     protected String entryPointId;
@@ -53,6 +47,7 @@ public class ProgramBuilderSpv {
 
     private final HelperTags helperTags = new HelperTags();
     private final HelperDecorations helperDecorations;
+    protected HelperControlFlow helperControlFlow = new HelperControlFlow(expressions);
 
     public ProgramBuilderSpv(List<Integer> threadGrid, Map<String, Expression> inputs) {
         validateThreadGrid(threadGrid);
@@ -64,7 +59,7 @@ public ProgramBuilderSpv(List<Integer> threadGrid, Map<String, Expression> input
 
     public Program build() {
         validateBeforeBuild();
-        preprocessBlocks();
+        helperControlFlow.build();
         BuiltIn builtIn = (BuiltIn) getDecoration(DecorationType.BUILT_IN);
         Set<ScopedPointerVariable> variables = expressions.values().stream()
                 .filter(ScopedPointerVariable.class::isInstance)
@@ -163,26 +158,11 @@ private Function checkFunctionType(String id, Function function, Type type) {
                 "function type doesn't match the function definition", id);
     }
 
-    public void startBlock(Label label) {
-        if (blockEndEvents.containsKey(label)) {
-            throw new ParsingException("Attempt to redefine label '%s'", label.getName());
-        }
-        blocks.push(label);
-    }
-
-    public Event endBlock(Event event) {
-        if (blocks.isEmpty()) {
-            throw new ParsingException("Attempt to exit block while not in a block definition");
-        }
-        blockEndEvents.put(blocks.pop(), event);
-        return event;
-    }
-
     public Event addEvent(Event event) {
         if (currentFunction == null) {
             throw new ParsingException("Attempt to add an event outside a function definition");
         }
-        if (blocks.isEmpty()) {
+        if (!helperControlFlow.isInsideBlock()) {
             throw new ParsingException("Attempt to add an event outside a control flow block");
         }
         if (event instanceof RegWriter regWriter) {
@@ -223,14 +203,6 @@ public String getExpressionStorageClass(String name) {
         throw new ParsingException("Reference to undefined pointer '%s'", name);
     }
 
-    public Label makeBranchBackJumpLabel(Label label) {
-        String id = label.getName() + "_back";
-        if (labels.containsKey(id)) {
-            throw new ParsingException("Overlapping blocks with back jump in label '%s'", label.getName());
-        }
-        return getOrCreateLabel(id);
-    }
-
     public boolean hasInput(String id) {
         return inputs.containsKey(id);
     }
@@ -253,22 +225,6 @@ public Register addRegister(String id, String typeId) {
         return getCurrentFunctionOrThrowError().newRegister(id, getType(typeId));
     }
 
-    public boolean hasBlock(String id) {
-        if (!labels.containsKey(id)) {
-            return false;
-        }
-        Label label = labels.get(id);
-        return blockEndEvents.containsKey(label) || blocks.contains(label);
-    }
-
-    public Label getOrCreateLabel(String id) {
-        Label label = labels.computeIfAbsent(id, EventFactory::newLabel);
-        if (label.getFunction() == null) {
-            label.setFunction(getCurrentFunctionOrThrowError());
-        }
-        return label;
-    }
-
     public String getScope(String id) {
         return helperTags.visitScope(id, getExpression(id));
     }
@@ -289,7 +245,11 @@ public Decoration getDecoration(DecorationType type) {
         return helperDecorations.getDecoration(type);
     }
 
-    private Function getCurrentFunctionOrThrowError() {
+    public HelperControlFlow getHelperControlFlow() {
+        return helperControlFlow;
+    }
+
+    public Function getCurrentFunctionOrThrowError() {
         if (currentFunction != null) {
             return currentFunction;
         }
@@ -306,28 +266,11 @@ private void validateBeforeBuild() {
             throw new ParsingException("Unclosed definition for function '%s'",
                     currentFunction.getName());
         }
-        if (!blocks.isEmpty()) {
-            throw new ParsingException("Unclosed blocks for labels %s",
-                    String.join(",", blocks.stream().map(Label::getName).toList()));
-        }
         if (nextOps != null) {
             throw new ParsingException("Missing expected op: %s",
                     String.join(",", nextOps));
-        }
-        // TODO: Validate no event refers to an undefined label
-    }
 
-    private void preprocessBlocks() {
-        Map<Event, Label> blockEndToLabelMap = new HashMap<>();
-        for (Map.Entry<Label, Event> entry : blockEndEvents.entrySet()) {
-            if (blockEndToLabelMap.containsKey(entry.getValue())) {
-                throw new ParsingException("Malformed control flow, " +
-                        "multiple block refer to the same end event " + entry.getValue());
-            }
-            blockEndToLabelMap.put(entry.getValue(), entry.getKey());
         }
-        insertPhiDefinitions(blockEndToLabelMap);
-        insertBlockEndLabels(blockEndToLabelMap);
     }
 
     private void checkSpecification() {
@@ -338,40 +281,6 @@ private void checkSpecification() {
         }
     }
 
-    public void addPhiDefinition(Label label, Register register, String id) {
-        phiDefinitions.computeIfAbsent(label, k -> new HashMap<>()).put(register, id);
-    }
-
-    private void insertPhiDefinitions(Map<Event, Label> blockEndToLabelMap) {
-        for (Function function : program.getFunctions()) {
-            for (Event event : function.getEvents()) {
-                Label label = blockEndToLabelMap.get(event.getSuccessor());
-                if (label != null) {
-                    Map<Register, String> phi = phiDefinitions.get(label);
-                    if (phi != null) {
-                        for (Map.Entry<Register, String> entry : phi.entrySet()) {
-                            event.insertAfter(new Local(entry.getKey(), getExpression(entry.getValue())));
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    private void insertBlockEndLabels(Map<Event, Label> blockEndToLabelMap) {
-        for (Function function : program.getFunctions()) {
-            for (Event event : function.getEvents()) {
-                Label label = blockEndToLabelMap.get(event.getSuccessor());
-                if (label != null) {
-                    Label endLabel = cfDefinitions.get(label);
-                    if (endLabel != null) {
-                        event.insertAfter(endLabel);
-                    }
-                }
-            }
-        }
-    }
-
     private Function getEntryPointFunction() {
         if (entryPointId == null) {
             throw new ParsingException("Cannot build the program, entryPointId is missing");
@@ -389,8 +298,6 @@ private Function getEntryPointFunction() {
         throw new ParsingException("Entry point expression '%s' must be a function", entryPointId);
     }
 
-    // TODO: Move to the corresponding visitors, tests, mocks etc ======================================================
-
     public FunctionType getCurrentFunctionType() {
         return getCurrentFunctionOrThrowError().getFunctionType();
     }
@@ -399,16 +306,6 @@ public String getCurrentFunctionName() {
         return getCurrentFunctionOrThrowError().getName();
     }
 
-    public Label makeBranchEndLabel(Label label) {
-        String id = label.getName() + "_end";
-        if (labels.containsKey(id)) {
-            throw new ParsingException("Overlapping blocks with endpoint in label '%s'", label.getName());
-        }
-        Label endLabel = getOrCreateLabel(id);
-        cfDefinitions.put(label, endLabel);
-        return endLabel;
-    }
-
     private void validateThreadGrid(List<Integer> threadGrid) {
         if (threadGrid.size() != 4) {
             throw new ParsingException("Thread grid must have 4 dimensions");
diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlow.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlow.java
index d5657b245c..47343c13b4 100644
--- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlow.java
+++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlow.java
@@ -2,17 +2,19 @@
 
 import com.dat3m.dartagnan.exception.ParsingException;
 import com.dat3m.dartagnan.expression.Expression;
+import com.dat3m.dartagnan.expression.ExpressionFactory;
 import com.dat3m.dartagnan.expression.Type;
 import com.dat3m.dartagnan.expression.type.TypeFactory;
 import com.dat3m.dartagnan.parsers.SpirvBaseVisitor;
 import com.dat3m.dartagnan.parsers.SpirvParser;
+import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow;
 import com.dat3m.dartagnan.program.Register;
 import com.dat3m.dartagnan.program.event.Event;
 import com.dat3m.dartagnan.program.event.EventFactory;
 import com.dat3m.dartagnan.program.event.core.Label;
 import com.dat3m.dartagnan.program.event.functions.Return;
 
-import java.util.Set;
+import java.util.*;
 
 import static com.dat3m.dartagnan.program.event.EventFactory.newFunctionReturn;
 
@@ -20,12 +22,14 @@ public class VisitorOpsControlFlow extends SpirvBaseVisitor<Event> {
 
     private static final TypeFactory types = TypeFactory.getInstance();
     private final ProgramBuilderSpv builder;
+    private final HelperControlFlow helper;
     private String continueLabelId;
     private String mergeLabelId;
     private String nextLabelId;
 
     public VisitorOpsControlFlow(ProgramBuilderSpv builder) {
         this.builder = builder;
+        this.helper = builder.getHelperControlFlow();
     }
 
     @Override
@@ -35,8 +39,9 @@ public Event visitOpPhi(SpirvParser.OpPhiContext ctx) {
         Register register = builder.addRegister(id, typeId);
         for (SpirvParser.VariableContext vCtx : ctx.variable()) {
             SpirvParser.PairIdRefIdRefContext pCtx = vCtx.pairIdRefIdRef();
-            Label event = builder.getOrCreateLabel(pCtx.idRef(1).getText());
-            builder.addPhiDefinition(event, register, pCtx.idRef(0).getText());
+            String labelId = pCtx.idRef(1).getText();
+            String expressionId = pCtx.idRef(0).getText();
+            helper.addPhiDefinition(labelId, register, expressionId);
         }
         builder.addExpression(id, register);
         return null;
@@ -51,61 +56,12 @@ public Event visitOpLabel(SpirvParser.OpLabelContext ctx) {
             }
             nextLabelId = null;
         }
-        Label event = builder.getOrCreateLabel(ctx.idResult().getText());
-        builder.startBlock(event);
+        String labelId = ctx.idResult().getText();
+        Label event = helper.getOrCreateLabel(labelId);
+        helper.startBlock(labelId);
         return builder.addEvent(event);
     }
 
-    @Override
-    public Event visitOpBranch(SpirvParser.OpBranchContext ctx) {
-        String labelId = ctx.targetLabel().getText();
-        if (continueLabelId == null && mergeLabelId == null) {
-            Label label = builder.getOrCreateLabel(labelId);
-            Event event = EventFactory.newGoto(label);
-            builder.addEvent(event);
-            return builder.endBlock(event);
-        }
-
-        // TODO: Test me!
-        continueLabelId = null;
-        mergeLabelId = null;
-        Label label = builder.getOrCreateLabel(labelId);
-        Event event = EventFactory.newGoto(label);
-        builder.addEvent(event);
-        return builder.endBlock(event);
-        //throw new ParsingException("Unsupported control flow around OpBranch '%s'", labelId);
-    }
-
-    @Override
-    public Event visitOpBranchConditional(SpirvParser.OpBranchConditionalContext ctx) {
-        if (ctx.trueLabel().getText().equals(ctx.falseLabel().getText())) {
-            throw new ParsingException("Labels of conditional branch cannot be the same");
-        }
-        if (mergeLabelId == null) {
-            return visitOpBranchConditionalUnstructured(ctx);
-        }
-        if (continueLabelId != null) {
-            if (mergeLabelId.equals(ctx.trueLabel().getText())
-                    && continueLabelId.equals(ctx.falseLabel().getText())) {
-                mergeLabelId = null;
-                continueLabelId = null;
-                nextLabelId = ctx.trueLabel().getText();
-                builder.setNextOps(Set.of("OpLabel"));
-                return visitOpBranchConditionalStructuredLoop(ctx);
-            }
-            throw new ParsingException("Illegal labels, expected mergeLabel='%s' and continueLabel='%s' " +
-                    "but received mergeLabel='%s' and continueLabel='%s'",
-                    mergeLabelId, continueLabelId, ctx.trueLabel().getText(), ctx.falseLabel().getText());
-        } else if (mergeLabelId.equals(ctx.falseLabel().getText())) {
-            mergeLabelId = null;
-            nextLabelId = ctx.trueLabel().getText();
-            builder.setNextOps(Set.of("OpLabel"));
-            return visitOpBranchConditionalStructured(ctx);
-        }
-        throw new ParsingException("Illegal last label in conditional branch, " +
-                "expected '%s' but received '%s'", mergeLabelId, ctx.falseLabel().getText());
-    }
-
     @Override
     public Event visitOpSelectionMerge(SpirvParser.OpSelectionMergeContext ctx) {
         if (mergeLabelId == null) {
@@ -121,14 +77,40 @@ public Event visitOpLoopMerge(SpirvParser.OpLoopMergeContext ctx) {
         if (continueLabelId == null && mergeLabelId == null) {
             continueLabelId = ctx.continueTarget().getText();
             mergeLabelId = ctx.mergeBlock().getText();
-            builder.setNextOps(Set.of("OpBranch", "OpBranchConditional"));
+            builder.setNextOps(Set.of("OpBranchConditional", "OpBranch"));
             return null;
         }
         throw new ParsingException("End and continue labels must be null");
+    }
 
-        // TODO: For a structured while loop, a control dependency
-        //  should be generated in the same way as for a structured if.
-        //  Add a new event for this.
+    @Override
+    public Event visitOpBranch(SpirvParser.OpBranchContext ctx) {
+        String labelId = ctx.targetLabel().getText();
+        if (continueLabelId == null && mergeLabelId == null) {
+            return visitGoto(labelId);
+        }
+        if (continueLabelId != null && mergeLabelId != null) {
+            return visitLoopBranch(labelId);
+        }
+        throw new ParsingException("OpBranch '%s' must be either " +
+                "a part of a loop definition or an arbitrary goto", labelId);
+    }
+
+    @Override
+    public Event visitOpBranchConditional(SpirvParser.OpBranchConditionalContext ctx) {
+        Expression guard = builder.getExpression(ctx.condition().getText());
+        String trueLabelId = ctx.trueLabel().getText();
+        String falseLabelId = ctx.falseLabel().getText();
+        if (trueLabelId.equals(falseLabelId)) {
+            throw new ParsingException("Labels of conditional branch cannot be the same");
+        }
+        if (mergeLabelId != null) {
+            if (continueLabelId != null) {
+                return visitLoopBranchConditional(guard, trueLabelId, falseLabelId);
+            }
+            return visitIfBranch(guard, trueLabelId, falseLabelId);
+        }
+        return visitConditionalJump(guard, trueLabelId, falseLabelId);
     }
 
     @Override
@@ -137,7 +119,7 @@ public Event visitOpReturn(SpirvParser.OpReturnContext ctx) {
         if (types.getVoidType().equals(returnType)) {
             Return event = newFunctionReturn(null);
             builder.addEvent(event);
-            return builder.endBlock(event);
+            return helper.endBlock(event);
         }
         throw new ParsingException("Illegal non-value return for a non-void function '%s'",
                 builder.getCurrentFunctionName());
@@ -151,76 +133,72 @@ public Event visitOpReturnValue(SpirvParser.OpReturnValueContext ctx) {
             Expression expression = builder.getExpression(valueId);
             Event event = newFunctionReturn(expression);
             builder.addEvent(event);
-            return builder.endBlock(event);
+            return helper.endBlock(event);
         }
         throw new ParsingException("Illegal value return for a void function '%s'",
                 builder.getCurrentFunctionName());
     }
 
-    private Event visitOpBranchConditionalStructuredLoop(SpirvParser.OpBranchConditionalContext ctx) {
-        String trueLabelId = validateForwardLabel(ctx.trueLabel().getText());
-        String falseLabelId = validateBackwardLabel(ctx.falseLabel().getText());
-        Expression guard = builder.getExpression(ctx.condition().getText());
-        Label falseLabel = builder.getOrCreateLabel(falseLabelId);
-        Label trueLabel = builder.getOrCreateLabel(trueLabelId);
-        Event event = builder.addEvent(EventFactory.newIfJump(guard, trueLabel, trueLabel));
-        builder.addEvent(EventFactory.newGoto(falseLabel));
-        return builder.endBlock(event);
+    private Event visitGoto(String labelId) {
+        Label label = helper.getOrCreateLabel(labelId);
+        Event event = EventFactory.newGoto(label);
+        builder.addEvent(event);
+        return helper.endBlock(event);
     }
 
-    private Event visitOpBranchConditionalStructured(SpirvParser.OpBranchConditionalContext ctx) {
-        validateForwardLabel(ctx.trueLabel().getText());
-        String falseLabelId = validateForwardLabel(ctx.falseLabel().getText());
-        Expression guard = builder.getExpression(ctx.condition().getText());
-        Label falseLabel = builder.getOrCreateLabel(falseLabelId);
-        Label mergeLabel = builder.makeBranchEndLabel(falseLabel);
+    private Event visitLoopBranch(String labelId) {
+        continueLabelId = null;
+        mergeLabelId = null;
+        return visitGoto(labelId);
+    }
+
+    private Event visitIfBranch(Expression guard, String trueLabelId, String falseLabelId) {
+        for (String labelId : List.of(trueLabelId, falseLabelId)) {
+            if (helper.isBlockStarted(labelId)) {
+                throw new ParsingException("Illegal backward jump to '%s' from a structured branch", labelId);
+            }
+        }
+        mergeLabelId = null;
+        nextLabelId = trueLabelId;
+        builder.setNextOps(Set.of("OpLabel"));
+        Label falseLabel = helper.getOrCreateLabel(falseLabelId);
+        Label mergeLabel = helper.createMergeLabel(falseLabelId);
         Event event = EventFactory.newIfJumpUnless(guard, falseLabel, mergeLabel);
         builder.addEvent(event);
-        return builder.endBlock(event);
+        return helper.endBlock(event);
     }
 
-    private Event visitOpBranchConditionalUnstructured(SpirvParser.OpBranchConditionalContext ctx) {
-        // Representing a Spir-V two-labels conditional jump as a pair of jumps
-        // TODO: A clean implementation with a new event type,
-        //  support for unstructured back jumps
-        if (!builder.hasBlock(ctx.trueLabel().getText()) && !builder.hasBlock(ctx.falseLabel().getText())) {
-            Label trueLabel = builder.getOrCreateLabel(ctx.trueLabel().getText());
-            Label falseLabel = builder.getOrCreateLabel(ctx.falseLabel().getText());
-            Expression guard = builder.getExpression(ctx.condition().getText());
-            Event trueJump = builder.addEvent(EventFactory.newJump(guard, trueLabel));
-            builder.addEvent(EventFactory.newJumpUnless(guard, falseLabel));
-            return builder.endBlock(trueJump);
+    private Event visitConditionalJump(Expression guard, String trueLabelId, String falseLabelId) {
+        if (helper.isBlockStarted(trueLabelId)) {
+            if (helper.isBlockStarted(falseLabelId)) {
+                throw new ParsingException("Unsupported conditional branch " +
+                        "with two backward jumps to '%s' and '%s'", trueLabelId, falseLabelId);
+            }
+            String labelId = trueLabelId;
+            trueLabelId = falseLabelId;
+            falseLabelId = labelId;
+            guard = ExpressionFactory.getInstance().makeNot(guard);
         }
-
-        Expression guard = builder.getExpression(ctx.condition().getText());
-        Label trueLabel = builder.getOrCreateLabel(ctx.trueLabel().getText());
-        Label falseLabel = builder.getOrCreateLabel(ctx.falseLabel().getText());
-        Label trueEnd = builder.makeBranchBackJumpLabel(trueLabel);
-        Label falseEnd = builder.makeBranchBackJumpLabel(falseLabel);
-
-        Event trueJump = builder.addEvent(EventFactory.newJump(guard, trueEnd));
-        builder.addEvent(EventFactory.newJumpUnless(guard, falseEnd));
-        builder.addEvent(trueEnd);
-        builder.addEvent(EventFactory.newGoto(trueLabel));
-        builder.addEvent(falseEnd);
+        Label trueLabel = helper.getOrCreateLabel(trueLabelId);
+        Label falseLabel = helper.getOrCreateLabel(falseLabelId);
+        Event trueJump = builder.addEvent(EventFactory.newJump(guard, trueLabel));
         builder.addEvent(EventFactory.newGoto(falseLabel));
-        return builder.endBlock(trueJump);
+        return helper.endBlock(trueJump);
     }
 
-    private String validateBackwardLabel(String id) {
-        if (!builder.hasBlock(id)) {
-            throw new ParsingException("Illegal forward jump to label '%s' " +
-                    "from a structured loop", id);
-        }
-        return id;
-    }
+    private Event visitLoopBranchConditional(Expression guard, String trueLabelId, String falseLabelId) {
+        mergeLabelId = null;
+        continueLabelId = null;
+        nextLabelId = trueLabelId;
+        builder.setNextOps(Set.of("OpLabel"));
 
-    private String validateForwardLabel(String id) {
-        if (builder.hasBlock(id)) {
-            throw new ParsingException("Illegal backward jump to label '%s' " +
-                    "from a structured branch", id);
-        }
-        return id;
+        // TODO: For a structured while loop, a control dependency
+        //  should be generated in the same way as for a structured if branch.
+        //  We need to add a new event type for this.
+        //  For now, we can treat while loop as unstructured jumps,
+        //  because Vulkan memory model has no control dependency.
+
+        return visitConditionalJump(guard, trueLabelId, falseLabelId);
     }
 
     public Set<String> getSupportedOps() {
diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperControlFlow.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperControlFlow.java
new file mode 100644
index 0000000000..a23b09b0dc
--- /dev/null
+++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperControlFlow.java
@@ -0,0 +1,107 @@
+package com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers;
+
+import com.dat3m.dartagnan.exception.ParsingException;
+import com.dat3m.dartagnan.expression.Expression;
+import com.dat3m.dartagnan.program.Register;
+import com.dat3m.dartagnan.program.event.Event;
+import com.dat3m.dartagnan.program.event.EventFactory;
+import com.dat3m.dartagnan.program.event.core.Label;
+import com.google.common.collect.Sets;
+
+import java.util.*;
+
+public class HelperControlFlow {
+
+    protected final Map<String, Label> blockLabels = new HashMap<>();
+    protected final Map<String, Event> lastBlockEvents = new HashMap<>();
+    protected final Map<String, String> mergeLabelIds = new HashMap<>();
+    protected final  Deque<String> blockStack = new ArrayDeque<>();
+    protected final Map<String, Map<Register, String>> phiDefinitions = new HashMap<>();
+    protected final Map<String, Expression> expressions;
+
+    public HelperControlFlow(Map<String, Expression> expressions) {
+        this.expressions = expressions;
+    }
+
+    public boolean isInsideBlock() {
+        return !blockStack.isEmpty();
+    }
+
+    public boolean isBlockStarted(String id) {
+        return lastBlockEvents.containsKey(id) || blockStack.contains(id);
+    }
+
+    public void startBlock(String id) {
+        if (lastBlockEvents.containsKey(id)) {
+            throw new ParsingException("Attempt to redefine label '%s'", id);
+        }
+        blockStack.push(id);
+    }
+
+    public Event endBlock(Event event) {
+        if (blockStack.isEmpty()) {
+            throw new ParsingException("Attempt to exit block while not in a block definition");
+        }
+        lastBlockEvents.put(blockStack.pop(), event);
+        return event;
+    }
+
+    public Label getOrCreateLabel(String id) {
+        return blockLabels.computeIfAbsent(id, EventFactory::newLabel);
+    }
+
+    public Label createMergeLabel(String id) {
+        String mergeId = id + "_end";
+        mergeLabelIds.put(id, mergeId);
+        return createLabel(mergeId);
+    }
+
+    public Label createBackJumpLabel(String id) {
+        return createLabel(id + "_back");
+    }
+
+    public void addPhiDefinition(String blockId, Register register, String expressionId) {
+        phiDefinitions.computeIfAbsent(blockId, k -> new HashMap<>()).put(register, expressionId);
+    }
+
+    public void build() {
+        validateBeforeBuild();
+        phiDefinitions.forEach((blockId, def) ->
+                def.forEach((k, v) -> {
+                    Event event = EventFactory.newLocal(k, expressions.get(v));
+                    lastBlockEvents.get(blockId).getPredecessor().insertAfter(event);
+                }));
+        mergeLabelIds.forEach((jumpLabelId, endLabelId) ->
+                lastBlockEvents.get(jumpLabelId).getPredecessor().insertAfter(blockLabels.get(endLabelId)));
+    }
+
+    private void validateBeforeBuild() {
+        if (!blockStack.isEmpty()) {
+            throw new ParsingException("Unclosed blocks %s", String.join(",", blockStack));
+        }
+        Set<String> missingPhiBlocks = Sets.difference(phiDefinitions.keySet(), blockLabels.keySet());
+        if (!missingPhiBlocks.isEmpty()) {
+            throw new ParsingException("Phi operation(s) refer to undefined block(s) %s",
+                    String.join(", ", missingPhiBlocks));
+        }
+        Set<String> missingMergeBlocks = Sets.difference(mergeLabelIds.keySet(), blockLabels.keySet());
+        if (!missingMergeBlocks.isEmpty()) {
+            throw new ParsingException("Branch merge label(s) refer to undefined block(s) %s",
+                    String.join(", ", missingMergeBlocks));
+        }
+        Map<Event, String> reverse = new HashMap<>();
+        lastBlockEvents.forEach((k, v) -> {
+            if (reverse.containsKey(v)) {
+                throw new ParsingException("Multiple blocks end in the same event '%s'", v);
+            }
+            reverse.put(v, k);
+        });
+    }
+
+    private Label createLabel(String id) {
+        if (blockLabels.containsKey(id)) {
+            throw new ParsingException("Attempt to redefine label '%s'", id);
+        }
+        return getOrCreateLabel(id);
+    }
+}
diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderTest.java
index e69041c4fd..62b513099f 100644
--- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderTest.java
+++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderTest.java
@@ -3,6 +3,7 @@
 import com.dat3m.dartagnan.exception.ParsingException;
 import com.dat3m.dartagnan.expression.type.FunctionType;
 import com.dat3m.dartagnan.expression.type.TypeFactory;
+import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow;
 import com.dat3m.dartagnan.program.event.Event;
 import com.dat3m.dartagnan.program.event.core.Skip;
 import org.junit.Test;
@@ -18,6 +19,7 @@ public class ProgramBuilderTest {
     private static final TypeFactory types = TypeFactory.getInstance();
 
     private final ProgramBuilderSpv builder = new ProgramBuilderSpv(List.of(1, 1, 1, 1), Map.of());
+    private final HelperControlFlow helper = builder.getHelperControlFlow();
 
     @Test
     public void testAddEventOutsideFunction() {
@@ -35,8 +37,9 @@ public void testAddEventBeforeBlock() {
     public void testAddEventAfterBlock() {
         FunctionType type = types.getFunctionType(types.getVoidType(), List.of());
         builder.startFunctionDefinition("test_func", type, List.of());
-        builder.startBlock(builder.getOrCreateLabel("test_label"));
-        builder.endBlock(new Skip());
+        helper.getOrCreateLabel("test_label");
+        helper.startBlock("test_label");
+        helper.endBlock(new Skip());
         testAddChildError("Attempt to add an event outside a control flow block");
     }
 
diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsBarrierTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsBarrierTest.java
index accf81802d..af3ea3e056 100644
--- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsBarrierTest.java
+++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsBarrierTest.java
@@ -86,7 +86,7 @@ public void testMemoryBarrierNoneSemantics() {
 
     private Event visit(String text) {
         builder.mockFunctionStart();
-        builder.startBlock(builder.getOrCreateLabel("test"));
+        builder.mockLabel("test");
         return new MockSpirvParser(text).spv().spvInstructions().accept(new VisitorOpsBarrier(builder));
     }
 }
diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlowTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlowTest.java
index e4dc3fe6de..f6c9fa7815 100644
--- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlowTest.java
+++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlowTest.java
@@ -1,16 +1,18 @@
 package com.dat3m.dartagnan.parsers.program.visitors.spirv;
 
 import com.dat3m.dartagnan.exception.ParsingException;
-import com.dat3m.dartagnan.expression.Expression;
 import com.dat3m.dartagnan.expression.booleans.BoolLiteral;
-import com.dat3m.dartagnan.expression.booleans.BoolUnaryExpr;
 import com.dat3m.dartagnan.expression.type.FunctionType;
+import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockHelperControlFlow;
 import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockProgramBuilderSpv;
 import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockSpirvParser;
 import com.dat3m.dartagnan.program.Function;
 import com.dat3m.dartagnan.program.Register;
 import com.dat3m.dartagnan.program.event.Event;
-import com.dat3m.dartagnan.program.event.core.*;
+import com.dat3m.dartagnan.program.event.core.CondJump;
+import com.dat3m.dartagnan.program.event.core.IfAsJump;
+import com.dat3m.dartagnan.program.event.core.Label;
+import com.dat3m.dartagnan.program.event.core.Skip;
 import com.dat3m.dartagnan.program.event.functions.Return;
 import org.junit.Test;
 
@@ -22,6 +24,7 @@
 public class VisitorOpsControlFlowTest {
 
     private final MockProgramBuilderSpv builder = new MockProgramBuilderSpv();
+    private final MockHelperControlFlow helper = (MockHelperControlFlow) builder.getHelperControlFlow();
 
     @Test
     public void testOpPhi() {
@@ -37,13 +40,11 @@ public void testOpPhi() {
         Register register = builder.getRegister("%phi");
         assertEquals(builder.getType("%int"), register.getType());
 
-        Label label1 = builder.getOrCreateLabel("%label1");
-        Map<Register, String> phi1 = builder.getPhiDefinitions(label1);
+        Map<Register, String> phi1 = helper.getPhiDefinitions("%label1");
         assertEquals(1, phi1.size());
         assertEquals("%value1", phi1.get(register));
 
-        Label label2 = builder.getOrCreateLabel("%label2");
-        Map<Register, String> phi2 = builder.getPhiDefinitions(label2);
+        Map<Register, String> phi2 = helper.getPhiDefinitions("%label2");
         assertEquals(1, phi2.size());
         assertEquals("%value2", phi2.get(register));
     }
@@ -61,21 +62,19 @@ public void testOpLabel() {
         visit(input);
 
         // then
-        Label label1 = builder.getOrCreateLabel("%label1");
-        Label label2 = builder.getOrCreateLabel("%label2");
-        assertEquals(List.of(label2, label1), builder.getBlocks());
+        assertEquals(List.of("%label2", "%label1"), helper.getBlockStack());
 
         // when
-        builder.endBlock(new Skip());
+        helper.endBlock(new Skip());
 
         // then
-        assertEquals(List.of(label1), builder.getBlocks());
+        assertEquals(List.of("%label1"), helper.getBlockStack());
 
         // when
-        builder.endBlock(new Skip());
+        helper.endBlock(new Skip());
 
         // then
-        assertEquals(List.of(), builder.getBlocks());
+        assertEquals(List.of(), helper.getBlockStack());
     }
 
     @Test
@@ -92,10 +91,10 @@ public void testOpBranch() {
 
         // then
         Function function = builder.getCurrentFunction();
-        CondJump event = (CondJump) function.getEvents().get(1);
-        assertTrue(((BoolLiteral) (event.getGuard())).getValue());
-        assertEquals(builder.getOrCreateLabel("%label"), event.getLabel());
-        assertTrue(builder.getBlocks().isEmpty());
+        CondJump condJump = (CondJump) function.getEvents().get(1);
+        assertTrue(((BoolLiteral) (condJump.getGuard())).getValue());
+        assertEquals("%label", condJump.getLabel().getName());
+        assertTrue(helper.getBlockStack().isEmpty());
     }
 
     @Test
@@ -116,15 +115,15 @@ public void testOpBranchNested() {
         // then
         Function function = builder.getCurrentFunction();
 
-        CondJump event1 = (CondJump) function.getEvents().get(2);
-        assertTrue(((BoolLiteral) (event1.getGuard())).getValue());
-        assertEquals(builder.getOrCreateLabel("%label3"), event1.getLabel());
+        CondJump condJump1 = (CondJump) function.getEvents().get(2);
+        assertTrue(((BoolLiteral) (condJump1.getGuard())).getValue());
+        assertEquals("%label3", condJump1.getLabel().getName());
 
-        CondJump event2 = (CondJump) function.getEvents().get(4);
-        assertTrue(((BoolLiteral) (event2.getGuard())).getValue());
-        assertEquals(builder.getOrCreateLabel("%label2"), event2.getLabel());
+        CondJump condJump2 = (CondJump) function.getEvents().get(4);
+        assertTrue(((BoolLiteral) (condJump2.getGuard())).getValue());
+        assertEquals("%label2", condJump2.getLabel().getName());
 
-        assertEquals(List.of(builder.getOrCreateLabel("%label1")), builder.getBlocks());
+        assertEquals(List.of("%label1"), helper.getBlockStack());
     }
 
     @Test
@@ -157,16 +156,17 @@ public void testStructuredBranch() {
         Label label2 = (Label) events.get(4);
         Return ret = (Return) events.get(5);
 
-        Label label2End = builder.getCfDefinition().get(label2);
+        assertEquals("%label0", label0.getName());
+        assertEquals("%label2", ifJump.getLabel().getName());
+        assertEquals("%label2_end", ifJump.getEndIf().getName());
+        assertEquals("%label1", label1.getName());
+        assertEquals("%label2", jump.getLabel().getName());
+        assertEquals("%label2", label2.getName());
 
-        assertEquals(label2, ifJump.getLabel());
-        assertEquals(label2End, ifJump.getEndIf());
         assertTrue(jump.isGoto());
-        assertEquals(label2, jump.getLabel());
-
-        assertEquals(ifJump, builder.getBlockEndEvents().get(label0));
-        assertEquals(jump, builder.getBlockEndEvents().get(label1));
-        assertEquals(ret, builder.getBlockEndEvents().get(label2));
+        assertEquals(Map.of("%label2", "%label2_end"), helper.getMergeLabelIds());
+        assertEquals(Map.of("%label0", ifJump, "%label1", jump, "%label2", ret),
+                helper.getLastBlockEvents());
     }
 
     @Test
@@ -208,24 +208,29 @@ public void testStructuredBranchNestedTrue() {
         Label label2 = (Label) events.get(8);
         Return ret = (Return) events.get(9);
 
-        Label label2End = builder.getCfDefinition().get(label2);
-        Label label2EndInner = builder.getCfDefinition().get(label2Inner);
-
-        assertEquals(label2, ifJump.getLabel());
-        assertEquals(label2End, ifJump.getEndIf());
-        assertEquals(label2Inner, ifJumpInner.getLabel());
-        assertEquals(label2EndInner, ifJumpInner.getEndIf());
+        assertEquals("%label0", label0.getName());
+        assertEquals("%label2", ifJump.getLabel().getName());
+        assertEquals("%label2_end", ifJump.getEndIf().getName());
+        assertEquals("%label1", label1.getName());
+        assertEquals("%label2_inner", ifJumpInner.getLabel().getName());
+        assertEquals("%label2_inner_end", ifJumpInner.getEndIf().getName());
+        assertEquals("%label1_inner", label1Inner.getName());
+        assertEquals("%label2_inner", jumpInner.getLabel().getName());
+        assertEquals("%label2_inner", label2Inner.getName());
+        assertEquals("%label2", jump.getLabel().getName());
+        assertEquals("%label2", jump.getLabel().getName());
+        assertEquals("%label2", label2.getName());
 
         assertTrue(jump.isGoto());
-        assertEquals(label2, jump.getLabel());
         assertTrue(jumpInner.isGoto());
-        assertEquals(label2Inner, jumpInner.getLabel());
 
-        assertEquals(ifJump, builder.getBlockEndEvents().get(label0));
-        assertEquals(ifJumpInner, builder.getBlockEndEvents().get(label1));
-        assertEquals(jumpInner, builder.getBlockEndEvents().get(label1Inner));
-        assertEquals(jump, builder.getBlockEndEvents().get(label2Inner));
-        assertEquals(ret, builder.getBlockEndEvents().get(label2));
+        assertEquals(Map.of(
+                "%label0", ifJump,
+                "%label1", ifJumpInner,
+                "%label1_inner", jumpInner,
+                "%label2_inner", jump,
+                "%label2", ret
+        ), helper.getLastBlockEvents());
     }
 
     @Test
@@ -234,7 +239,7 @@ public void testStructuredBranchNestedFalse() {
         String input = """
                 %label0 = OpLabel
                 OpSelectionMerge %label2 None
-                OpBranchConditional %value %label1 %label2
+                OpBranchConditional %value %label1 %label3
                 %label1 = OpLabel
                 OpBranch %label2
                 %label2 = OpLabel
@@ -243,6 +248,8 @@ public void testStructuredBranchNestedFalse() {
                 %label1_inner = OpLabel
                 OpBranch %label2_inner
                 %label2_inner = OpLabel
+                OpBranch %label3
+                %label3 = OpLabel
                 OpReturn
                 """;
 
@@ -259,32 +266,42 @@ public void testStructuredBranchNestedFalse() {
         Label label0 = (Label) events.get(0);
         IfAsJump ifJump = (IfAsJump) events.get(1);
         Label label1 = (Label) events.get(2);
-        CondJump jump = (CondJump) events.get(3);
+        CondJump jump1 = (CondJump) events.get(3);
         Label label2 = (Label) events.get(4);
         IfAsJump ifJumpInner = (IfAsJump) events.get(5);
         Label label1Inner = (Label) events.get(6);
         CondJump jumpInner = (CondJump) events.get(7);
         Label label2Inner = (Label) events.get(8);
-        Return ret = (Return) events.get(9);
-
-        Label label2End = builder.getCfDefinition().get(label2);
-        Label label2EndInner = builder.getCfDefinition().get(label2Inner);
-
-        assertEquals(label2, ifJump.getLabel());
-        assertEquals(label2End, ifJump.getEndIf());
-        assertEquals(label2Inner, ifJumpInner.getLabel());
-        assertEquals(label2EndInner, ifJumpInner.getEndIf());
-
-        assertTrue(jump.isGoto());
-        assertEquals(label2, jump.getLabel());
+        CondJump jump2 = (CondJump) events.get(9);
+        Label label3 = (Label) events.get(10);
+        Return ret = (Return) events.get(11);
+
+        assertEquals("%label0", label0.getName());
+        assertEquals("%label3", ifJump.getLabel().getName());
+        assertEquals("%label3_end", ifJump.getEndIf().getName());
+        assertEquals("%label2", jump1.getLabel().getName());
+        assertEquals("%label1", label1.getName());
+        assertEquals("%label2", label2.getName());
+        assertEquals("%label2_inner", ifJumpInner.getLabel().getName());
+        assertEquals("%label2_inner_end", ifJumpInner.getEndIf().getName());
+        assertEquals("%label1_inner", label1Inner.getName());
+        assertEquals("%label2_inner", jumpInner.getLabel().getName());
+        assertEquals("%label2_inner", label2Inner.getName());
+        assertEquals("%label3", jump2.getLabel().getName());
+        assertEquals("%label3", label3.getName());
+
+        assertTrue(jump1.isGoto());
+        assertTrue(jump2.isGoto());
         assertTrue(jumpInner.isGoto());
-        assertEquals(label2Inner, jumpInner.getLabel());
 
-        assertEquals(ifJump, builder.getBlockEndEvents().get(label0));
-        assertEquals(jump, builder.getBlockEndEvents().get(label1));
-        assertEquals(ifJumpInner, builder.getBlockEndEvents().get(label2));
-        assertEquals(jumpInner, builder.getBlockEndEvents().get(label1Inner));
-        assertEquals(ret, builder.getBlockEndEvents().get(label2Inner));
+        assertEquals(Map.of(
+                "%label0", ifJump,
+                "%label1", jump1,
+                "%label2", ifJumpInner,
+                "%label1_inner", jumpInner,
+                "%label2_inner", jump2,
+                "%label3", ret
+        ), helper.getLastBlockEvents());
     }
 
     @Test
@@ -313,13 +330,60 @@ public void testStructuredBranchNestedSameLabel() {
             fail("Should throw exception");
         } catch (ParsingException e) {
             // then
-            assertEquals("Overlapping blocks with endpoint in label '%label2'",
-                    e.getMessage());
+            assertEquals("Attempt to redefine label '%label2_end'", e.getMessage());
         }
     }
 
     @Test
-    public void testStructuredLoop() {
+    public void testLoopWithForwardLabels() {
+        // given
+        String input = """
+                        %label0 = OpLabel
+                        OpLoopMerge %label1 %label2 None
+                        OpBranchConditional %value %label1 %label2
+                        %label1 = OpLabel
+                        OpBranch %label2
+                        %label2 = OpLabel
+                        OpReturn
+                        """;
+
+        builder.mockFunctionStart();
+        builder.mockBoolType("%bool");
+        builder.mockRegister("%value", "%bool");
+
+        // when
+        visit(input);
+
+        // then
+        List<Event> events = builder.getCurrentFunction().getEvents();
+
+        Label label0 = (Label) events.get(0);
+        CondJump jump1 = (CondJump) events.get(1);
+        CondJump jump2 = (CondJump) events.get(2);
+        Label label1 = (Label) events.get(3);
+        CondJump jump3 = (CondJump) events.get(4);
+        Label label2 = (Label) events.get(5);
+        Return ret = (Return) events.get(6);
+
+        assertEquals("%label0", label0.getName());
+        assertEquals("%label1", jump1.getLabel().getName());
+        assertEquals("%label2", jump2.getLabel().getName());
+        assertEquals("%label1", label1.getName());
+        assertEquals("%label2", jump2.getLabel().getName());
+        assertEquals("%label2", label2.getName());
+
+        assertFalse(jump1.isGoto());
+        assertTrue(jump2.isGoto());
+        assertTrue(jump3.isGoto());
+
+        assertTrue(helper.getMergeLabelIds().isEmpty());
+
+        assertEquals(Map.of("%label0", jump1, "%label1", jump3, "%label2", ret),
+                helper.getLastBlockEvents());
+    }
+
+    @Test
+    public void testLoopWithBackwardLabel() {
         // given
         String input = """
                 %label0 = OpLabel
@@ -340,20 +404,23 @@ public void testStructuredLoop() {
         List<Event> events = builder.getCurrentFunction().getEvents();
 
         Label label0 = (Label) events.get(0);
-        IfAsJump ifJump = (IfAsJump) events.get(1);
-        CondJump jump = (CondJump) events.get(2);
+        CondJump jump1 = (CondJump) events.get(1);
+        CondJump jump2 = (CondJump) events.get(2);
         Label label1 = (Label) events.get(3);
         Return ret = (Return) events.get(4);
 
-        assertTrue(builder.getCfDefinition().isEmpty());
+        assertEquals("%label0", label0.getName());
+        assertEquals("%label1", jump1.getLabel().getName());
+        assertEquals("%label0", jump2.getLabel().getName());
+        assertEquals("%label1", label1.getName());
 
-        assertEquals(label1, ifJump.getLabel());
-        assertEquals(label1, ifJump.getEndIf());
-        assertTrue(jump.isGoto());
-        assertEquals(label0, jump.getLabel());
+        assertFalse(jump1.isGoto());
+        assertTrue(jump2.isGoto());
 
-        assertEquals(ifJump, builder.getBlockEndEvents().get(label0));
-        assertEquals(ret, builder.getBlockEndEvents().get(label1));
+        assertTrue(helper.getMergeLabelIds().isEmpty());
+
+        assertEquals(Map.of("%label0", jump1, "%label1", ret),
+                helper.getLastBlockEvents());
     }
 
     @Test
@@ -365,7 +432,7 @@ public void testStructuredBranchBackwardTrue() {
                         %label1 = OpLabel
                         OpReturn
                         """,
-                "Illegal backward jump to label '%label0' " +
+                "Illegal backward jump to '%label0' " +
                         "from a structured branch");
     }
 
@@ -378,57 +445,10 @@ public void testStructuredBranchBackwardFalse() {
                         %label1 = OpLabel
                         OpReturn
                         """,
-                "Illegal backward jump to label '%label0' " +
-                        "from a structured branch");
-    }
-
-    @Test
-    public void testLoopMergeBackward() {
-        doTestIllegalStructuredBranch("""
-                        %label2 = OpLabel
-                        OpBranch %label0
-                        %label0 = OpLabel
-                        OpLoopMerge %label2 %label0 None
-                        OpBranchConditional %value1 %label2 %label0
-                        %label1 = OpLabel
-                        OpReturn
-                        """,
-                "Illegal backward jump to label '%label2' " +
+                "Illegal backward jump to '%label0' " +
                         "from a structured branch");
     }
 
-    @Test
-    public void testLoopContinueForward() {
-        doTestIllegalStructuredBranch("""
-                        %label0 = OpLabel
-                        OpLoopMerge %label1 %label2 None
-                        OpBranchConditional %value1 %label1 %label2
-                        %label1 = OpLabel
-                        OpBranch %label2
-                        %label2 = OpLabel
-                        OpReturn
-                        """,
-                "Illegal forward jump to label '%label2' " +
-                        "from a structured loop");
-    }
-
-    @Test
-    public void testStructuredBranchIllegalMergeLabel() {
-        doTestIllegalStructuredBranch("""
-                        %label0 = OpLabel
-                        OpSelectionMerge %label3 None
-                        OpBranchConditional %value1 %label1 %label2
-                        %label1 = OpLabel
-                        OpBranch %label2
-                        %label2 = OpLabel
-                        OpBranch %label3
-                        %label3 = OpLabel
-                        OpReturn
-                        """,
-                "Illegal last label in conditional branch, " +
-                        "expected '%label3' but received '%label2'");
-    }
-
     @Test
     public void testStructuredBranchLabelsIllegalOrder() {
         doTestIllegalStructuredBranch("""
@@ -444,39 +464,22 @@ public void testStructuredBranchLabelsIllegalOrder() {
     }
 
     @Test
-    public void testOpLoopLabelsIllegalContinueLabel() {
+    public void testLoopMergeWithTwoBackwardLabels() {
         doTestIllegalStructuredBranch("""
-                        %label0 = OpLabel
-                        OpLoopMerge %label1 %label2 None
-                        OpBranchConditional %value1 %label1 %label0
-                        %label1 = OpLabel
-                        OpBranch %label2
                         %label2 = OpLabel
-                        OpReturn
-                        """,
-                "Illegal labels, expected mergeLabel='%label1' " +
-                        "and continueLabel='%label2' but received " +
-                        "mergeLabel='%label1' and continueLabel='%label0'");
-    }
-
-    @Test
-    public void testOpLoopLabelsIllegalMergeLabel() {
-        doTestIllegalStructuredBranch("""
+                        OpBranch %label0
                         %label0 = OpLabel
                         OpLoopMerge %label2 %label0 None
-                        OpBranchConditional %value1 %label1 %label0
+                        OpBranchConditional %value1 %label2 %label0
                         %label1 = OpLabel
-                        OpBranch %label2
-                        %label2 = OpLabel
                         OpReturn
                         """,
-                "Illegal labels, expected mergeLabel='%label2' " +
-                        "and continueLabel='%label0' but received " +
-                        "mergeLabel='%label1' and continueLabel='%label0'");
+                "Unsupported conditional branch " +
+                        "with two backward jumps to '%label2' and '%label0'");
     }
 
     @Test
-    public void testOpLoopLabelsIllegalOrder() {
+    public void testOpLoopIllegalTrueLabel() {
         doTestIllegalStructuredBranch("""
                         %label0 = OpLabel
                         OpLoopMerge %label1 %label0 None
@@ -484,9 +487,7 @@ public void testOpLoopLabelsIllegalOrder() {
                         %label1 = OpLabel
                         OpReturn
                         """,
-                "Illegal labels, expected mergeLabel='%label1' " +
-                        "and continueLabel='%label0' but received " +
-                        "mergeLabel='%label0' and continueLabel='%label1'");
+                "Illegal label, expected '%label0' but received '%label1'");
     }
 
     private void doTestIllegalStructuredBranch(String input, String error) {
@@ -510,10 +511,10 @@ public void testOpBranchConditionalUnstructured() {
         // given
         String input = """
                 %label0 = OpLabel
-                OpBranchConditional %value1 %label2 %label3
+                OpBranchConditional %value1 %label1 %label2
+                %label1 = OpLabel
+                OpBranchConditional %value2 %label2 %label1
                 %label2 = OpLabel
-                OpBranchConditional %value2 %label3 %label1
-                %label3 = OpLabel
                 OpReturn
                 """;
         builder.mockFunctionStart();
@@ -527,23 +528,31 @@ public void testOpBranchConditionalUnstructured() {
         // then
         List<Event> events = builder.getCurrentFunction().getEvents();
 
-        CondJump event1 = (CondJump) events.get(1);
-        assertEquals("%value1", getGuardRegister(event1).getName());
-        assertEquals(builder.getOrCreateLabel("%label2"), event1.getLabel());
-
-        CondJump event2 = (CondJump) events.get(2);
-        assertEquals("%value1", getGuardRegister(event2).getName());
-        assertEquals(builder.getOrCreateLabel("%label3"), event2.getLabel());
-
-        CondJump event4 = (CondJump) events.get(4);
-        assertEquals("%value2", getGuardRegister(event4).getName());
-        assertEquals(builder.getOrCreateLabel("%label3"), event4.getLabel());
-
-        CondJump event5 = (CondJump) events.get(5);
-        assertEquals("%value2", getGuardRegister(event5).getName());
-        assertEquals(builder.getOrCreateLabel("%label1"), event5.getLabel());
-
-        assertTrue(builder.getBlocks().isEmpty());
+        Label label0 = (Label) events.get(0);
+        CondJump jump1 = (CondJump) events.get(1);
+        CondJump jump2 = (CondJump) events.get(2);
+        Label label1 = (Label) events.get(3);
+        CondJump jump3 = (CondJump) events.get(4);
+        CondJump jump4 = (CondJump) events.get(5);
+        Label label2 = (Label) events.get(6);
+        Return ret = (Return) events.get(7);
+
+        assertEquals("%label0", label0.getName());
+        assertEquals("%label1", jump1.getLabel().getName());
+        assertEquals("%label2", jump2.getLabel().getName());
+        assertEquals("%label1", label1.getName());
+        assertEquals("%label2", jump3.getLabel().getName());
+        assertEquals("%label1", jump4.getLabel().getName());
+        assertEquals("%label2", label2.getName());
+
+        assertFalse(jump1.isGoto());
+        assertTrue(jump2.isGoto());
+        assertFalse(jump3.isGoto());
+        assertTrue(jump4.isGoto());
+
+        assertTrue(helper.getBlockStack().isEmpty());
+        assertEquals(Map.of("%label0", jump1, "%label1", jump3, "%label2", ret),
+                helper.getLastBlockEvents());
     }
 
     @Test
@@ -563,8 +572,7 @@ public void testOpBranchConditionalSameLabels() {
             fail("Should throw exception");
         } catch (ParsingException e) {
             // then
-            assertEquals("Labels of conditional branch cannot be the same",
-                    e.getMessage());
+            assertEquals("Labels of conditional branch cannot be the same", e.getMessage());
         }
     }
 
@@ -604,7 +612,7 @@ public void testReturn() {
         Return event = (Return) function.getEvents().get(1);
         assertNotNull(event);
         assertTrue(event.getValue().isEmpty());
-        assertTrue(builder.getBlocks().isEmpty());
+        assertTrue(helper.getBlockStack().isEmpty());
     }
 
     @Test
@@ -623,7 +631,7 @@ public void testReturnValue() {
         Function function = builder.getCurrentFunction();
         Return event = (Return) function.getEvents().get(1);
         assertEquals(builder.getExpression("%value"), event.getValue().orElseThrow());
-        assertTrue(builder.getBlocks().isEmpty());
+        assertTrue(helper.getBlockStack().isEmpty());
     }
 
     @Test
@@ -667,17 +675,6 @@ public void testReturnValueForVoidFunction() {
         }
     }
 
-    private Register getGuardRegister(CondJump event) {
-        Expression guard = event.getGuard();
-        if (guard instanceof Register register) {
-            return register;
-        }
-        if (guard instanceof BoolUnaryExpr expr) {
-            return (Register) expr.getOperand();
-        }
-        throw new RuntimeException("Unexpected expression type");
-    }
-
     private void visit(String text) {
         new MockSpirvParser(text).spv().accept(new VisitorOpsControlFlow(builder));
     }
diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockHelperControlFlow.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockHelperControlFlow.java
new file mode 100644
index 0000000000..f758d665c0
--- /dev/null
+++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockHelperControlFlow.java
@@ -0,0 +1,32 @@
+package com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks;
+
+import com.dat3m.dartagnan.expression.Expression;
+import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow;
+import com.dat3m.dartagnan.program.Register;
+import com.dat3m.dartagnan.program.event.Event;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class MockHelperControlFlow extends HelperControlFlow {
+    public MockHelperControlFlow(Map<String, Expression> expressions) {
+        super(expressions);
+    }
+
+    public List<String> getBlockStack() {
+        return blockStack.stream().toList();
+    }
+
+    public Map<String, String> getMergeLabelIds() {
+        return Map.copyOf(mergeLabelIds);
+    }
+
+    public Map<String, Event> getLastBlockEvents() {
+        return Map.copyOf(lastBlockEvents);
+    }
+
+    public Map<Register, String> getPhiDefinitions(String blockId) {
+        return Map.copyOf(phiDefinitions.computeIfAbsent(blockId, k -> new HashMap<>()));
+    }
+}
diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilderSpv.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilderSpv.java
index a348eea92f..c17ef1d81f 100644
--- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilderSpv.java
+++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilderSpv.java
@@ -25,15 +25,16 @@ public class MockProgramBuilderSpv extends ProgramBuilderSpv {
     private static final ExpressionFactory exprFactory = ExpressionFactory.getInstance();
 
     public MockProgramBuilderSpv() {
-        super(List.of(1, 1, 1, 1), Map.of());
+        this(List.of(1, 1, 1, 1), Map.of());
     }
 
     public MockProgramBuilderSpv(Map<String, Expression> input) {
-        super(List.of(1, 1, 1, 1), input);
+        this(List.of(1, 1, 1, 1), input);
     }
 
     public MockProgramBuilderSpv(List<Integer> grid, Map<String, Expression> input) {
         super(grid, input);
+        helperControlFlow = new MockHelperControlFlow(expressions);
     }
 
     @Override
@@ -141,12 +142,13 @@ public Register mockRegister(String id, String typeId) {
     }
 
     public void mockLabel() {
-        startBlock(new Label("%mock_label"));
+        helperControlFlow.getOrCreateLabel("%mock_label");
+        helperControlFlow.startBlock("%mock_label");
     }
 
     public void mockLabel(String id) {
-        Label label = getOrCreateLabel(id);
-        startBlock(label);
+        Label label = helperControlFlow.getOrCreateLabel(id);
+        helperControlFlow.startBlock(id);
         addEvent(label);
     }
 
@@ -175,23 +177,7 @@ public Map<String, Expression> getExpressions() {
         return Map.copyOf(expressions);
     }
 
-    public List<Label> getBlocks() {
-        return blocks.stream().toList();
-    }
-
-    public Map<Label, Label> getCfDefinition() {
-        return Map.copyOf(cfDefinitions);
-    }
-
-    public Map<Label, Event> getBlockEndEvents() {
-        return Map.copyOf(blockEndEvents);
-    }
-
     public Set<Function> getForwardFunctions() {
         return Set.copyOf(forwardFunctions.values());
     }
-
-    public Map<Register, String> getPhiDefinitions(Label label) {
-        return Map.copyOf(phiDefinitions.computeIfAbsent(label, k -> new HashMap<>()));
-    }
 }