From ca3147d8c81e2ad5d7f370f8a2f4ec0b5babf5bf Mon Sep 17 00:00:00 2001 From: Natalia Gavrilenko <natgavrilenko@gmail.com> Date: Fri, 12 Jul 2024 11:36:32 +0200 Subject: [PATCH] Refactored control flow --- .../visitors/spirv/ProgramBuilderSpv.java | 123 +----- .../visitors/spirv/VisitorOpsControlFlow.java | 206 +++++----- .../spirv/helpers/HelperControlFlow.java | 107 +++++ .../visitors/spirv/ProgramBuilderTest.java | 7 +- .../visitors/spirv/VisitorOpsBarrierTest.java | 2 +- .../spirv/VisitorOpsControlFlowTest.java | 375 +++++++++--------- .../spirv/mocks/MockHelperControlFlow.java | 32 ++ .../spirv/mocks/MockProgramBuilderSpv.java | 28 +- 8 files changed, 440 insertions(+), 440 deletions(-) create mode 100644 dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperControlFlow.java create mode 100644 dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockHelperControlFlow.java diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderSpv.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderSpv.java index e7aba4a36a..a073ff566a 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderSpv.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderSpv.java @@ -8,6 +8,7 @@ import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.BuiltIn; import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Decoration; import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow; import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperDecorations; import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperTags; import com.dat3m.dartagnan.parsers.program.visitors.spirv.transformers.ThreadCreator; @@ -16,8 +17,6 @@ import com.dat3m.dartagnan.expression.type.ScopedPointerType; import com.dat3m.dartagnan.program.*; import com.dat3m.dartagnan.program.event.*; -import com.dat3m.dartagnan.program.event.core.Label; -import com.dat3m.dartagnan.program.event.core.Local; import com.dat3m.dartagnan.program.memory.Memory; import com.dat3m.dartagnan.program.memory.MemoryObject; import org.apache.logging.log4j.LogManager; @@ -37,14 +36,9 @@ public class ProgramBuilderSpv { protected final Map<String, Type> types = new HashMap<>(); protected final Map<String, Expression> expressions = new HashMap<>(); protected final Map<String, Function> forwardFunctions = new HashMap<>(); - protected final Map<String, Label> labels = new HashMap<>(); - protected final Deque<Label> blocks = new ArrayDeque<>(); - protected final Map<Label, Event> blockEndEvents = new HashMap<>(); - protected final Map<Label, Map<Register, String>> phiDefinitions = new HashMap<>(); - protected final Map<Label, Label> cfDefinitions = new HashMap<>(); protected final List<Integer> threadGrid; protected final Map<String, Expression> inputs; - protected final Program program; + public final Program program; protected Function currentFunction; protected String entryPointId; @@ -53,6 +47,7 @@ public class ProgramBuilderSpv { private final HelperTags helperTags = new HelperTags(); private final HelperDecorations helperDecorations; + protected HelperControlFlow helperControlFlow = new HelperControlFlow(expressions); public ProgramBuilderSpv(List<Integer> threadGrid, Map<String, Expression> inputs) { validateThreadGrid(threadGrid); @@ -64,7 +59,7 @@ public ProgramBuilderSpv(List<Integer> threadGrid, Map<String, Expression> input public Program build() { validateBeforeBuild(); - preprocessBlocks(); + helperControlFlow.build(); BuiltIn builtIn = (BuiltIn) getDecoration(DecorationType.BUILT_IN); Set<ScopedPointerVariable> variables = expressions.values().stream() .filter(ScopedPointerVariable.class::isInstance) @@ -163,26 +158,11 @@ private Function checkFunctionType(String id, Function function, Type type) { "function type doesn't match the function definition", id); } - public void startBlock(Label label) { - if (blockEndEvents.containsKey(label)) { - throw new ParsingException("Attempt to redefine label '%s'", label.getName()); - } - blocks.push(label); - } - - public Event endBlock(Event event) { - if (blocks.isEmpty()) { - throw new ParsingException("Attempt to exit block while not in a block definition"); - } - blockEndEvents.put(blocks.pop(), event); - return event; - } - public Event addEvent(Event event) { if (currentFunction == null) { throw new ParsingException("Attempt to add an event outside a function definition"); } - if (blocks.isEmpty()) { + if (!helperControlFlow.isInsideBlock()) { throw new ParsingException("Attempt to add an event outside a control flow block"); } if (event instanceof RegWriter regWriter) { @@ -223,14 +203,6 @@ public String getExpressionStorageClass(String name) { throw new ParsingException("Reference to undefined pointer '%s'", name); } - public Label makeBranchBackJumpLabel(Label label) { - String id = label.getName() + "_back"; - if (labels.containsKey(id)) { - throw new ParsingException("Overlapping blocks with back jump in label '%s'", label.getName()); - } - return getOrCreateLabel(id); - } - public boolean hasInput(String id) { return inputs.containsKey(id); } @@ -253,22 +225,6 @@ public Register addRegister(String id, String typeId) { return getCurrentFunctionOrThrowError().newRegister(id, getType(typeId)); } - public boolean hasBlock(String id) { - if (!labels.containsKey(id)) { - return false; - } - Label label = labels.get(id); - return blockEndEvents.containsKey(label) || blocks.contains(label); - } - - public Label getOrCreateLabel(String id) { - Label label = labels.computeIfAbsent(id, EventFactory::newLabel); - if (label.getFunction() == null) { - label.setFunction(getCurrentFunctionOrThrowError()); - } - return label; - } - public String getScope(String id) { return helperTags.visitScope(id, getExpression(id)); } @@ -289,7 +245,11 @@ public Decoration getDecoration(DecorationType type) { return helperDecorations.getDecoration(type); } - private Function getCurrentFunctionOrThrowError() { + public HelperControlFlow getHelperControlFlow() { + return helperControlFlow; + } + + public Function getCurrentFunctionOrThrowError() { if (currentFunction != null) { return currentFunction; } @@ -306,28 +266,11 @@ private void validateBeforeBuild() { throw new ParsingException("Unclosed definition for function '%s'", currentFunction.getName()); } - if (!blocks.isEmpty()) { - throw new ParsingException("Unclosed blocks for labels %s", - String.join(",", blocks.stream().map(Label::getName).toList())); - } if (nextOps != null) { throw new ParsingException("Missing expected op: %s", String.join(",", nextOps)); - } - // TODO: Validate no event refers to an undefined label - } - private void preprocessBlocks() { - Map<Event, Label> blockEndToLabelMap = new HashMap<>(); - for (Map.Entry<Label, Event> entry : blockEndEvents.entrySet()) { - if (blockEndToLabelMap.containsKey(entry.getValue())) { - throw new ParsingException("Malformed control flow, " + - "multiple block refer to the same end event " + entry.getValue()); - } - blockEndToLabelMap.put(entry.getValue(), entry.getKey()); } - insertPhiDefinitions(blockEndToLabelMap); - insertBlockEndLabels(blockEndToLabelMap); } private void checkSpecification() { @@ -338,40 +281,6 @@ private void checkSpecification() { } } - public void addPhiDefinition(Label label, Register register, String id) { - phiDefinitions.computeIfAbsent(label, k -> new HashMap<>()).put(register, id); - } - - private void insertPhiDefinitions(Map<Event, Label> blockEndToLabelMap) { - for (Function function : program.getFunctions()) { - for (Event event : function.getEvents()) { - Label label = blockEndToLabelMap.get(event.getSuccessor()); - if (label != null) { - Map<Register, String> phi = phiDefinitions.get(label); - if (phi != null) { - for (Map.Entry<Register, String> entry : phi.entrySet()) { - event.insertAfter(new Local(entry.getKey(), getExpression(entry.getValue()))); - } - } - } - } - } - } - - private void insertBlockEndLabels(Map<Event, Label> blockEndToLabelMap) { - for (Function function : program.getFunctions()) { - for (Event event : function.getEvents()) { - Label label = blockEndToLabelMap.get(event.getSuccessor()); - if (label != null) { - Label endLabel = cfDefinitions.get(label); - if (endLabel != null) { - event.insertAfter(endLabel); - } - } - } - } - } - private Function getEntryPointFunction() { if (entryPointId == null) { throw new ParsingException("Cannot build the program, entryPointId is missing"); @@ -389,8 +298,6 @@ private Function getEntryPointFunction() { throw new ParsingException("Entry point expression '%s' must be a function", entryPointId); } - // TODO: Move to the corresponding visitors, tests, mocks etc ====================================================== - public FunctionType getCurrentFunctionType() { return getCurrentFunctionOrThrowError().getFunctionType(); } @@ -399,16 +306,6 @@ public String getCurrentFunctionName() { return getCurrentFunctionOrThrowError().getName(); } - public Label makeBranchEndLabel(Label label) { - String id = label.getName() + "_end"; - if (labels.containsKey(id)) { - throw new ParsingException("Overlapping blocks with endpoint in label '%s'", label.getName()); - } - Label endLabel = getOrCreateLabel(id); - cfDefinitions.put(label, endLabel); - return endLabel; - } - private void validateThreadGrid(List<Integer> threadGrid) { if (threadGrid.size() != 4) { throw new ParsingException("Thread grid must have 4 dimensions"); diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlow.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlow.java index d5657b245c..47343c13b4 100644 --- a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlow.java +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlow.java @@ -2,17 +2,19 @@ import com.dat3m.dartagnan.exception.ParsingException; import com.dat3m.dartagnan.expression.Expression; +import com.dat3m.dartagnan.expression.ExpressionFactory; import com.dat3m.dartagnan.expression.Type; import com.dat3m.dartagnan.expression.type.TypeFactory; import com.dat3m.dartagnan.parsers.SpirvBaseVisitor; import com.dat3m.dartagnan.parsers.SpirvParser; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow; import com.dat3m.dartagnan.program.Register; import com.dat3m.dartagnan.program.event.Event; import com.dat3m.dartagnan.program.event.EventFactory; import com.dat3m.dartagnan.program.event.core.Label; import com.dat3m.dartagnan.program.event.functions.Return; -import java.util.Set; +import java.util.*; import static com.dat3m.dartagnan.program.event.EventFactory.newFunctionReturn; @@ -20,12 +22,14 @@ public class VisitorOpsControlFlow extends SpirvBaseVisitor<Event> { private static final TypeFactory types = TypeFactory.getInstance(); private final ProgramBuilderSpv builder; + private final HelperControlFlow helper; private String continueLabelId; private String mergeLabelId; private String nextLabelId; public VisitorOpsControlFlow(ProgramBuilderSpv builder) { this.builder = builder; + this.helper = builder.getHelperControlFlow(); } @Override @@ -35,8 +39,9 @@ public Event visitOpPhi(SpirvParser.OpPhiContext ctx) { Register register = builder.addRegister(id, typeId); for (SpirvParser.VariableContext vCtx : ctx.variable()) { SpirvParser.PairIdRefIdRefContext pCtx = vCtx.pairIdRefIdRef(); - Label event = builder.getOrCreateLabel(pCtx.idRef(1).getText()); - builder.addPhiDefinition(event, register, pCtx.idRef(0).getText()); + String labelId = pCtx.idRef(1).getText(); + String expressionId = pCtx.idRef(0).getText(); + helper.addPhiDefinition(labelId, register, expressionId); } builder.addExpression(id, register); return null; @@ -51,61 +56,12 @@ public Event visitOpLabel(SpirvParser.OpLabelContext ctx) { } nextLabelId = null; } - Label event = builder.getOrCreateLabel(ctx.idResult().getText()); - builder.startBlock(event); + String labelId = ctx.idResult().getText(); + Label event = helper.getOrCreateLabel(labelId); + helper.startBlock(labelId); return builder.addEvent(event); } - @Override - public Event visitOpBranch(SpirvParser.OpBranchContext ctx) { - String labelId = ctx.targetLabel().getText(); - if (continueLabelId == null && mergeLabelId == null) { - Label label = builder.getOrCreateLabel(labelId); - Event event = EventFactory.newGoto(label); - builder.addEvent(event); - return builder.endBlock(event); - } - - // TODO: Test me! - continueLabelId = null; - mergeLabelId = null; - Label label = builder.getOrCreateLabel(labelId); - Event event = EventFactory.newGoto(label); - builder.addEvent(event); - return builder.endBlock(event); - //throw new ParsingException("Unsupported control flow around OpBranch '%s'", labelId); - } - - @Override - public Event visitOpBranchConditional(SpirvParser.OpBranchConditionalContext ctx) { - if (ctx.trueLabel().getText().equals(ctx.falseLabel().getText())) { - throw new ParsingException("Labels of conditional branch cannot be the same"); - } - if (mergeLabelId == null) { - return visitOpBranchConditionalUnstructured(ctx); - } - if (continueLabelId != null) { - if (mergeLabelId.equals(ctx.trueLabel().getText()) - && continueLabelId.equals(ctx.falseLabel().getText())) { - mergeLabelId = null; - continueLabelId = null; - nextLabelId = ctx.trueLabel().getText(); - builder.setNextOps(Set.of("OpLabel")); - return visitOpBranchConditionalStructuredLoop(ctx); - } - throw new ParsingException("Illegal labels, expected mergeLabel='%s' and continueLabel='%s' " + - "but received mergeLabel='%s' and continueLabel='%s'", - mergeLabelId, continueLabelId, ctx.trueLabel().getText(), ctx.falseLabel().getText()); - } else if (mergeLabelId.equals(ctx.falseLabel().getText())) { - mergeLabelId = null; - nextLabelId = ctx.trueLabel().getText(); - builder.setNextOps(Set.of("OpLabel")); - return visitOpBranchConditionalStructured(ctx); - } - throw new ParsingException("Illegal last label in conditional branch, " + - "expected '%s' but received '%s'", mergeLabelId, ctx.falseLabel().getText()); - } - @Override public Event visitOpSelectionMerge(SpirvParser.OpSelectionMergeContext ctx) { if (mergeLabelId == null) { @@ -121,14 +77,40 @@ public Event visitOpLoopMerge(SpirvParser.OpLoopMergeContext ctx) { if (continueLabelId == null && mergeLabelId == null) { continueLabelId = ctx.continueTarget().getText(); mergeLabelId = ctx.mergeBlock().getText(); - builder.setNextOps(Set.of("OpBranch", "OpBranchConditional")); + builder.setNextOps(Set.of("OpBranchConditional", "OpBranch")); return null; } throw new ParsingException("End and continue labels must be null"); + } - // TODO: For a structured while loop, a control dependency - // should be generated in the same way as for a structured if. - // Add a new event for this. + @Override + public Event visitOpBranch(SpirvParser.OpBranchContext ctx) { + String labelId = ctx.targetLabel().getText(); + if (continueLabelId == null && mergeLabelId == null) { + return visitGoto(labelId); + } + if (continueLabelId != null && mergeLabelId != null) { + return visitLoopBranch(labelId); + } + throw new ParsingException("OpBranch '%s' must be either " + + "a part of a loop definition or an arbitrary goto", labelId); + } + + @Override + public Event visitOpBranchConditional(SpirvParser.OpBranchConditionalContext ctx) { + Expression guard = builder.getExpression(ctx.condition().getText()); + String trueLabelId = ctx.trueLabel().getText(); + String falseLabelId = ctx.falseLabel().getText(); + if (trueLabelId.equals(falseLabelId)) { + throw new ParsingException("Labels of conditional branch cannot be the same"); + } + if (mergeLabelId != null) { + if (continueLabelId != null) { + return visitLoopBranchConditional(guard, trueLabelId, falseLabelId); + } + return visitIfBranch(guard, trueLabelId, falseLabelId); + } + return visitConditionalJump(guard, trueLabelId, falseLabelId); } @Override @@ -137,7 +119,7 @@ public Event visitOpReturn(SpirvParser.OpReturnContext ctx) { if (types.getVoidType().equals(returnType)) { Return event = newFunctionReturn(null); builder.addEvent(event); - return builder.endBlock(event); + return helper.endBlock(event); } throw new ParsingException("Illegal non-value return for a non-void function '%s'", builder.getCurrentFunctionName()); @@ -151,76 +133,72 @@ public Event visitOpReturnValue(SpirvParser.OpReturnValueContext ctx) { Expression expression = builder.getExpression(valueId); Event event = newFunctionReturn(expression); builder.addEvent(event); - return builder.endBlock(event); + return helper.endBlock(event); } throw new ParsingException("Illegal value return for a void function '%s'", builder.getCurrentFunctionName()); } - private Event visitOpBranchConditionalStructuredLoop(SpirvParser.OpBranchConditionalContext ctx) { - String trueLabelId = validateForwardLabel(ctx.trueLabel().getText()); - String falseLabelId = validateBackwardLabel(ctx.falseLabel().getText()); - Expression guard = builder.getExpression(ctx.condition().getText()); - Label falseLabel = builder.getOrCreateLabel(falseLabelId); - Label trueLabel = builder.getOrCreateLabel(trueLabelId); - Event event = builder.addEvent(EventFactory.newIfJump(guard, trueLabel, trueLabel)); - builder.addEvent(EventFactory.newGoto(falseLabel)); - return builder.endBlock(event); + private Event visitGoto(String labelId) { + Label label = helper.getOrCreateLabel(labelId); + Event event = EventFactory.newGoto(label); + builder.addEvent(event); + return helper.endBlock(event); } - private Event visitOpBranchConditionalStructured(SpirvParser.OpBranchConditionalContext ctx) { - validateForwardLabel(ctx.trueLabel().getText()); - String falseLabelId = validateForwardLabel(ctx.falseLabel().getText()); - Expression guard = builder.getExpression(ctx.condition().getText()); - Label falseLabel = builder.getOrCreateLabel(falseLabelId); - Label mergeLabel = builder.makeBranchEndLabel(falseLabel); + private Event visitLoopBranch(String labelId) { + continueLabelId = null; + mergeLabelId = null; + return visitGoto(labelId); + } + + private Event visitIfBranch(Expression guard, String trueLabelId, String falseLabelId) { + for (String labelId : List.of(trueLabelId, falseLabelId)) { + if (helper.isBlockStarted(labelId)) { + throw new ParsingException("Illegal backward jump to '%s' from a structured branch", labelId); + } + } + mergeLabelId = null; + nextLabelId = trueLabelId; + builder.setNextOps(Set.of("OpLabel")); + Label falseLabel = helper.getOrCreateLabel(falseLabelId); + Label mergeLabel = helper.createMergeLabel(falseLabelId); Event event = EventFactory.newIfJumpUnless(guard, falseLabel, mergeLabel); builder.addEvent(event); - return builder.endBlock(event); + return helper.endBlock(event); } - private Event visitOpBranchConditionalUnstructured(SpirvParser.OpBranchConditionalContext ctx) { - // Representing a Spir-V two-labels conditional jump as a pair of jumps - // TODO: A clean implementation with a new event type, - // support for unstructured back jumps - if (!builder.hasBlock(ctx.trueLabel().getText()) && !builder.hasBlock(ctx.falseLabel().getText())) { - Label trueLabel = builder.getOrCreateLabel(ctx.trueLabel().getText()); - Label falseLabel = builder.getOrCreateLabel(ctx.falseLabel().getText()); - Expression guard = builder.getExpression(ctx.condition().getText()); - Event trueJump = builder.addEvent(EventFactory.newJump(guard, trueLabel)); - builder.addEvent(EventFactory.newJumpUnless(guard, falseLabel)); - return builder.endBlock(trueJump); + private Event visitConditionalJump(Expression guard, String trueLabelId, String falseLabelId) { + if (helper.isBlockStarted(trueLabelId)) { + if (helper.isBlockStarted(falseLabelId)) { + throw new ParsingException("Unsupported conditional branch " + + "with two backward jumps to '%s' and '%s'", trueLabelId, falseLabelId); + } + String labelId = trueLabelId; + trueLabelId = falseLabelId; + falseLabelId = labelId; + guard = ExpressionFactory.getInstance().makeNot(guard); } - - Expression guard = builder.getExpression(ctx.condition().getText()); - Label trueLabel = builder.getOrCreateLabel(ctx.trueLabel().getText()); - Label falseLabel = builder.getOrCreateLabel(ctx.falseLabel().getText()); - Label trueEnd = builder.makeBranchBackJumpLabel(trueLabel); - Label falseEnd = builder.makeBranchBackJumpLabel(falseLabel); - - Event trueJump = builder.addEvent(EventFactory.newJump(guard, trueEnd)); - builder.addEvent(EventFactory.newJumpUnless(guard, falseEnd)); - builder.addEvent(trueEnd); - builder.addEvent(EventFactory.newGoto(trueLabel)); - builder.addEvent(falseEnd); + Label trueLabel = helper.getOrCreateLabel(trueLabelId); + Label falseLabel = helper.getOrCreateLabel(falseLabelId); + Event trueJump = builder.addEvent(EventFactory.newJump(guard, trueLabel)); builder.addEvent(EventFactory.newGoto(falseLabel)); - return builder.endBlock(trueJump); + return helper.endBlock(trueJump); } - private String validateBackwardLabel(String id) { - if (!builder.hasBlock(id)) { - throw new ParsingException("Illegal forward jump to label '%s' " + - "from a structured loop", id); - } - return id; - } + private Event visitLoopBranchConditional(Expression guard, String trueLabelId, String falseLabelId) { + mergeLabelId = null; + continueLabelId = null; + nextLabelId = trueLabelId; + builder.setNextOps(Set.of("OpLabel")); - private String validateForwardLabel(String id) { - if (builder.hasBlock(id)) { - throw new ParsingException("Illegal backward jump to label '%s' " + - "from a structured branch", id); - } - return id; + // TODO: For a structured while loop, a control dependency + // should be generated in the same way as for a structured if branch. + // We need to add a new event type for this. + // For now, we can treat while loop as unstructured jumps, + // because Vulkan memory model has no control dependency. + + return visitConditionalJump(guard, trueLabelId, falseLabelId); } public Set<String> getSupportedOps() { diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperControlFlow.java b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperControlFlow.java new file mode 100644 index 0000000000..a23b09b0dc --- /dev/null +++ b/dartagnan/src/main/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/helpers/HelperControlFlow.java @@ -0,0 +1,107 @@ +package com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers; + +import com.dat3m.dartagnan.exception.ParsingException; +import com.dat3m.dartagnan.expression.Expression; +import com.dat3m.dartagnan.program.Register; +import com.dat3m.dartagnan.program.event.Event; +import com.dat3m.dartagnan.program.event.EventFactory; +import com.dat3m.dartagnan.program.event.core.Label; +import com.google.common.collect.Sets; + +import java.util.*; + +public class HelperControlFlow { + + protected final Map<String, Label> blockLabels = new HashMap<>(); + protected final Map<String, Event> lastBlockEvents = new HashMap<>(); + protected final Map<String, String> mergeLabelIds = new HashMap<>(); + protected final Deque<String> blockStack = new ArrayDeque<>(); + protected final Map<String, Map<Register, String>> phiDefinitions = new HashMap<>(); + protected final Map<String, Expression> expressions; + + public HelperControlFlow(Map<String, Expression> expressions) { + this.expressions = expressions; + } + + public boolean isInsideBlock() { + return !blockStack.isEmpty(); + } + + public boolean isBlockStarted(String id) { + return lastBlockEvents.containsKey(id) || blockStack.contains(id); + } + + public void startBlock(String id) { + if (lastBlockEvents.containsKey(id)) { + throw new ParsingException("Attempt to redefine label '%s'", id); + } + blockStack.push(id); + } + + public Event endBlock(Event event) { + if (blockStack.isEmpty()) { + throw new ParsingException("Attempt to exit block while not in a block definition"); + } + lastBlockEvents.put(blockStack.pop(), event); + return event; + } + + public Label getOrCreateLabel(String id) { + return blockLabels.computeIfAbsent(id, EventFactory::newLabel); + } + + public Label createMergeLabel(String id) { + String mergeId = id + "_end"; + mergeLabelIds.put(id, mergeId); + return createLabel(mergeId); + } + + public Label createBackJumpLabel(String id) { + return createLabel(id + "_back"); + } + + public void addPhiDefinition(String blockId, Register register, String expressionId) { + phiDefinitions.computeIfAbsent(blockId, k -> new HashMap<>()).put(register, expressionId); + } + + public void build() { + validateBeforeBuild(); + phiDefinitions.forEach((blockId, def) -> + def.forEach((k, v) -> { + Event event = EventFactory.newLocal(k, expressions.get(v)); + lastBlockEvents.get(blockId).getPredecessor().insertAfter(event); + })); + mergeLabelIds.forEach((jumpLabelId, endLabelId) -> + lastBlockEvents.get(jumpLabelId).getPredecessor().insertAfter(blockLabels.get(endLabelId))); + } + + private void validateBeforeBuild() { + if (!blockStack.isEmpty()) { + throw new ParsingException("Unclosed blocks %s", String.join(",", blockStack)); + } + Set<String> missingPhiBlocks = Sets.difference(phiDefinitions.keySet(), blockLabels.keySet()); + if (!missingPhiBlocks.isEmpty()) { + throw new ParsingException("Phi operation(s) refer to undefined block(s) %s", + String.join(", ", missingPhiBlocks)); + } + Set<String> missingMergeBlocks = Sets.difference(mergeLabelIds.keySet(), blockLabels.keySet()); + if (!missingMergeBlocks.isEmpty()) { + throw new ParsingException("Branch merge label(s) refer to undefined block(s) %s", + String.join(", ", missingMergeBlocks)); + } + Map<Event, String> reverse = new HashMap<>(); + lastBlockEvents.forEach((k, v) -> { + if (reverse.containsKey(v)) { + throw new ParsingException("Multiple blocks end in the same event '%s'", v); + } + reverse.put(v, k); + }); + } + + private Label createLabel(String id) { + if (blockLabels.containsKey(id)) { + throw new ParsingException("Attempt to redefine label '%s'", id); + } + return getOrCreateLabel(id); + } +} diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderTest.java index e69041c4fd..62b513099f 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/ProgramBuilderTest.java @@ -3,6 +3,7 @@ import com.dat3m.dartagnan.exception.ParsingException; import com.dat3m.dartagnan.expression.type.FunctionType; import com.dat3m.dartagnan.expression.type.TypeFactory; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow; import com.dat3m.dartagnan.program.event.Event; import com.dat3m.dartagnan.program.event.core.Skip; import org.junit.Test; @@ -18,6 +19,7 @@ public class ProgramBuilderTest { private static final TypeFactory types = TypeFactory.getInstance(); private final ProgramBuilderSpv builder = new ProgramBuilderSpv(List.of(1, 1, 1, 1), Map.of()); + private final HelperControlFlow helper = builder.getHelperControlFlow(); @Test public void testAddEventOutsideFunction() { @@ -35,8 +37,9 @@ public void testAddEventBeforeBlock() { public void testAddEventAfterBlock() { FunctionType type = types.getFunctionType(types.getVoidType(), List.of()); builder.startFunctionDefinition("test_func", type, List.of()); - builder.startBlock(builder.getOrCreateLabel("test_label")); - builder.endBlock(new Skip()); + helper.getOrCreateLabel("test_label"); + helper.startBlock("test_label"); + helper.endBlock(new Skip()); testAddChildError("Attempt to add an event outside a control flow block"); } diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsBarrierTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsBarrierTest.java index accf81802d..af3ea3e056 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsBarrierTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsBarrierTest.java @@ -86,7 +86,7 @@ public void testMemoryBarrierNoneSemantics() { private Event visit(String text) { builder.mockFunctionStart(); - builder.startBlock(builder.getOrCreateLabel("test")); + builder.mockLabel("test"); return new MockSpirvParser(text).spv().spvInstructions().accept(new VisitorOpsBarrier(builder)); } } diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlowTest.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlowTest.java index e4dc3fe6de..f6c9fa7815 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlowTest.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/VisitorOpsControlFlowTest.java @@ -1,16 +1,18 @@ package com.dat3m.dartagnan.parsers.program.visitors.spirv; import com.dat3m.dartagnan.exception.ParsingException; -import com.dat3m.dartagnan.expression.Expression; import com.dat3m.dartagnan.expression.booleans.BoolLiteral; -import com.dat3m.dartagnan.expression.booleans.BoolUnaryExpr; import com.dat3m.dartagnan.expression.type.FunctionType; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockHelperControlFlow; import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockProgramBuilderSpv; import com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks.MockSpirvParser; import com.dat3m.dartagnan.program.Function; import com.dat3m.dartagnan.program.Register; import com.dat3m.dartagnan.program.event.Event; -import com.dat3m.dartagnan.program.event.core.*; +import com.dat3m.dartagnan.program.event.core.CondJump; +import com.dat3m.dartagnan.program.event.core.IfAsJump; +import com.dat3m.dartagnan.program.event.core.Label; +import com.dat3m.dartagnan.program.event.core.Skip; import com.dat3m.dartagnan.program.event.functions.Return; import org.junit.Test; @@ -22,6 +24,7 @@ public class VisitorOpsControlFlowTest { private final MockProgramBuilderSpv builder = new MockProgramBuilderSpv(); + private final MockHelperControlFlow helper = (MockHelperControlFlow) builder.getHelperControlFlow(); @Test public void testOpPhi() { @@ -37,13 +40,11 @@ public void testOpPhi() { Register register = builder.getRegister("%phi"); assertEquals(builder.getType("%int"), register.getType()); - Label label1 = builder.getOrCreateLabel("%label1"); - Map<Register, String> phi1 = builder.getPhiDefinitions(label1); + Map<Register, String> phi1 = helper.getPhiDefinitions("%label1"); assertEquals(1, phi1.size()); assertEquals("%value1", phi1.get(register)); - Label label2 = builder.getOrCreateLabel("%label2"); - Map<Register, String> phi2 = builder.getPhiDefinitions(label2); + Map<Register, String> phi2 = helper.getPhiDefinitions("%label2"); assertEquals(1, phi2.size()); assertEquals("%value2", phi2.get(register)); } @@ -61,21 +62,19 @@ public void testOpLabel() { visit(input); // then - Label label1 = builder.getOrCreateLabel("%label1"); - Label label2 = builder.getOrCreateLabel("%label2"); - assertEquals(List.of(label2, label1), builder.getBlocks()); + assertEquals(List.of("%label2", "%label1"), helper.getBlockStack()); // when - builder.endBlock(new Skip()); + helper.endBlock(new Skip()); // then - assertEquals(List.of(label1), builder.getBlocks()); + assertEquals(List.of("%label1"), helper.getBlockStack()); // when - builder.endBlock(new Skip()); + helper.endBlock(new Skip()); // then - assertEquals(List.of(), builder.getBlocks()); + assertEquals(List.of(), helper.getBlockStack()); } @Test @@ -92,10 +91,10 @@ public void testOpBranch() { // then Function function = builder.getCurrentFunction(); - CondJump event = (CondJump) function.getEvents().get(1); - assertTrue(((BoolLiteral) (event.getGuard())).getValue()); - assertEquals(builder.getOrCreateLabel("%label"), event.getLabel()); - assertTrue(builder.getBlocks().isEmpty()); + CondJump condJump = (CondJump) function.getEvents().get(1); + assertTrue(((BoolLiteral) (condJump.getGuard())).getValue()); + assertEquals("%label", condJump.getLabel().getName()); + assertTrue(helper.getBlockStack().isEmpty()); } @Test @@ -116,15 +115,15 @@ public void testOpBranchNested() { // then Function function = builder.getCurrentFunction(); - CondJump event1 = (CondJump) function.getEvents().get(2); - assertTrue(((BoolLiteral) (event1.getGuard())).getValue()); - assertEquals(builder.getOrCreateLabel("%label3"), event1.getLabel()); + CondJump condJump1 = (CondJump) function.getEvents().get(2); + assertTrue(((BoolLiteral) (condJump1.getGuard())).getValue()); + assertEquals("%label3", condJump1.getLabel().getName()); - CondJump event2 = (CondJump) function.getEvents().get(4); - assertTrue(((BoolLiteral) (event2.getGuard())).getValue()); - assertEquals(builder.getOrCreateLabel("%label2"), event2.getLabel()); + CondJump condJump2 = (CondJump) function.getEvents().get(4); + assertTrue(((BoolLiteral) (condJump2.getGuard())).getValue()); + assertEquals("%label2", condJump2.getLabel().getName()); - assertEquals(List.of(builder.getOrCreateLabel("%label1")), builder.getBlocks()); + assertEquals(List.of("%label1"), helper.getBlockStack()); } @Test @@ -157,16 +156,17 @@ public void testStructuredBranch() { Label label2 = (Label) events.get(4); Return ret = (Return) events.get(5); - Label label2End = builder.getCfDefinition().get(label2); + assertEquals("%label0", label0.getName()); + assertEquals("%label2", ifJump.getLabel().getName()); + assertEquals("%label2_end", ifJump.getEndIf().getName()); + assertEquals("%label1", label1.getName()); + assertEquals("%label2", jump.getLabel().getName()); + assertEquals("%label2", label2.getName()); - assertEquals(label2, ifJump.getLabel()); - assertEquals(label2End, ifJump.getEndIf()); assertTrue(jump.isGoto()); - assertEquals(label2, jump.getLabel()); - - assertEquals(ifJump, builder.getBlockEndEvents().get(label0)); - assertEquals(jump, builder.getBlockEndEvents().get(label1)); - assertEquals(ret, builder.getBlockEndEvents().get(label2)); + assertEquals(Map.of("%label2", "%label2_end"), helper.getMergeLabelIds()); + assertEquals(Map.of("%label0", ifJump, "%label1", jump, "%label2", ret), + helper.getLastBlockEvents()); } @Test @@ -208,24 +208,29 @@ public void testStructuredBranchNestedTrue() { Label label2 = (Label) events.get(8); Return ret = (Return) events.get(9); - Label label2End = builder.getCfDefinition().get(label2); - Label label2EndInner = builder.getCfDefinition().get(label2Inner); - - assertEquals(label2, ifJump.getLabel()); - assertEquals(label2End, ifJump.getEndIf()); - assertEquals(label2Inner, ifJumpInner.getLabel()); - assertEquals(label2EndInner, ifJumpInner.getEndIf()); + assertEquals("%label0", label0.getName()); + assertEquals("%label2", ifJump.getLabel().getName()); + assertEquals("%label2_end", ifJump.getEndIf().getName()); + assertEquals("%label1", label1.getName()); + assertEquals("%label2_inner", ifJumpInner.getLabel().getName()); + assertEquals("%label2_inner_end", ifJumpInner.getEndIf().getName()); + assertEquals("%label1_inner", label1Inner.getName()); + assertEquals("%label2_inner", jumpInner.getLabel().getName()); + assertEquals("%label2_inner", label2Inner.getName()); + assertEquals("%label2", jump.getLabel().getName()); + assertEquals("%label2", jump.getLabel().getName()); + assertEquals("%label2", label2.getName()); assertTrue(jump.isGoto()); - assertEquals(label2, jump.getLabel()); assertTrue(jumpInner.isGoto()); - assertEquals(label2Inner, jumpInner.getLabel()); - assertEquals(ifJump, builder.getBlockEndEvents().get(label0)); - assertEquals(ifJumpInner, builder.getBlockEndEvents().get(label1)); - assertEquals(jumpInner, builder.getBlockEndEvents().get(label1Inner)); - assertEquals(jump, builder.getBlockEndEvents().get(label2Inner)); - assertEquals(ret, builder.getBlockEndEvents().get(label2)); + assertEquals(Map.of( + "%label0", ifJump, + "%label1", ifJumpInner, + "%label1_inner", jumpInner, + "%label2_inner", jump, + "%label2", ret + ), helper.getLastBlockEvents()); } @Test @@ -234,7 +239,7 @@ public void testStructuredBranchNestedFalse() { String input = """ %label0 = OpLabel OpSelectionMerge %label2 None - OpBranchConditional %value %label1 %label2 + OpBranchConditional %value %label1 %label3 %label1 = OpLabel OpBranch %label2 %label2 = OpLabel @@ -243,6 +248,8 @@ public void testStructuredBranchNestedFalse() { %label1_inner = OpLabel OpBranch %label2_inner %label2_inner = OpLabel + OpBranch %label3 + %label3 = OpLabel OpReturn """; @@ -259,32 +266,42 @@ public void testStructuredBranchNestedFalse() { Label label0 = (Label) events.get(0); IfAsJump ifJump = (IfAsJump) events.get(1); Label label1 = (Label) events.get(2); - CondJump jump = (CondJump) events.get(3); + CondJump jump1 = (CondJump) events.get(3); Label label2 = (Label) events.get(4); IfAsJump ifJumpInner = (IfAsJump) events.get(5); Label label1Inner = (Label) events.get(6); CondJump jumpInner = (CondJump) events.get(7); Label label2Inner = (Label) events.get(8); - Return ret = (Return) events.get(9); - - Label label2End = builder.getCfDefinition().get(label2); - Label label2EndInner = builder.getCfDefinition().get(label2Inner); - - assertEquals(label2, ifJump.getLabel()); - assertEquals(label2End, ifJump.getEndIf()); - assertEquals(label2Inner, ifJumpInner.getLabel()); - assertEquals(label2EndInner, ifJumpInner.getEndIf()); - - assertTrue(jump.isGoto()); - assertEquals(label2, jump.getLabel()); + CondJump jump2 = (CondJump) events.get(9); + Label label3 = (Label) events.get(10); + Return ret = (Return) events.get(11); + + assertEquals("%label0", label0.getName()); + assertEquals("%label3", ifJump.getLabel().getName()); + assertEquals("%label3_end", ifJump.getEndIf().getName()); + assertEquals("%label2", jump1.getLabel().getName()); + assertEquals("%label1", label1.getName()); + assertEquals("%label2", label2.getName()); + assertEquals("%label2_inner", ifJumpInner.getLabel().getName()); + assertEquals("%label2_inner_end", ifJumpInner.getEndIf().getName()); + assertEquals("%label1_inner", label1Inner.getName()); + assertEquals("%label2_inner", jumpInner.getLabel().getName()); + assertEquals("%label2_inner", label2Inner.getName()); + assertEquals("%label3", jump2.getLabel().getName()); + assertEquals("%label3", label3.getName()); + + assertTrue(jump1.isGoto()); + assertTrue(jump2.isGoto()); assertTrue(jumpInner.isGoto()); - assertEquals(label2Inner, jumpInner.getLabel()); - assertEquals(ifJump, builder.getBlockEndEvents().get(label0)); - assertEquals(jump, builder.getBlockEndEvents().get(label1)); - assertEquals(ifJumpInner, builder.getBlockEndEvents().get(label2)); - assertEquals(jumpInner, builder.getBlockEndEvents().get(label1Inner)); - assertEquals(ret, builder.getBlockEndEvents().get(label2Inner)); + assertEquals(Map.of( + "%label0", ifJump, + "%label1", jump1, + "%label2", ifJumpInner, + "%label1_inner", jumpInner, + "%label2_inner", jump2, + "%label3", ret + ), helper.getLastBlockEvents()); } @Test @@ -313,13 +330,60 @@ public void testStructuredBranchNestedSameLabel() { fail("Should throw exception"); } catch (ParsingException e) { // then - assertEquals("Overlapping blocks with endpoint in label '%label2'", - e.getMessage()); + assertEquals("Attempt to redefine label '%label2_end'", e.getMessage()); } } @Test - public void testStructuredLoop() { + public void testLoopWithForwardLabels() { + // given + String input = """ + %label0 = OpLabel + OpLoopMerge %label1 %label2 None + OpBranchConditional %value %label1 %label2 + %label1 = OpLabel + OpBranch %label2 + %label2 = OpLabel + OpReturn + """; + + builder.mockFunctionStart(); + builder.mockBoolType("%bool"); + builder.mockRegister("%value", "%bool"); + + // when + visit(input); + + // then + List<Event> events = builder.getCurrentFunction().getEvents(); + + Label label0 = (Label) events.get(0); + CondJump jump1 = (CondJump) events.get(1); + CondJump jump2 = (CondJump) events.get(2); + Label label1 = (Label) events.get(3); + CondJump jump3 = (CondJump) events.get(4); + Label label2 = (Label) events.get(5); + Return ret = (Return) events.get(6); + + assertEquals("%label0", label0.getName()); + assertEquals("%label1", jump1.getLabel().getName()); + assertEquals("%label2", jump2.getLabel().getName()); + assertEquals("%label1", label1.getName()); + assertEquals("%label2", jump2.getLabel().getName()); + assertEquals("%label2", label2.getName()); + + assertFalse(jump1.isGoto()); + assertTrue(jump2.isGoto()); + assertTrue(jump3.isGoto()); + + assertTrue(helper.getMergeLabelIds().isEmpty()); + + assertEquals(Map.of("%label0", jump1, "%label1", jump3, "%label2", ret), + helper.getLastBlockEvents()); + } + + @Test + public void testLoopWithBackwardLabel() { // given String input = """ %label0 = OpLabel @@ -340,20 +404,23 @@ public void testStructuredLoop() { List<Event> events = builder.getCurrentFunction().getEvents(); Label label0 = (Label) events.get(0); - IfAsJump ifJump = (IfAsJump) events.get(1); - CondJump jump = (CondJump) events.get(2); + CondJump jump1 = (CondJump) events.get(1); + CondJump jump2 = (CondJump) events.get(2); Label label1 = (Label) events.get(3); Return ret = (Return) events.get(4); - assertTrue(builder.getCfDefinition().isEmpty()); + assertEquals("%label0", label0.getName()); + assertEquals("%label1", jump1.getLabel().getName()); + assertEquals("%label0", jump2.getLabel().getName()); + assertEquals("%label1", label1.getName()); - assertEquals(label1, ifJump.getLabel()); - assertEquals(label1, ifJump.getEndIf()); - assertTrue(jump.isGoto()); - assertEquals(label0, jump.getLabel()); + assertFalse(jump1.isGoto()); + assertTrue(jump2.isGoto()); - assertEquals(ifJump, builder.getBlockEndEvents().get(label0)); - assertEquals(ret, builder.getBlockEndEvents().get(label1)); + assertTrue(helper.getMergeLabelIds().isEmpty()); + + assertEquals(Map.of("%label0", jump1, "%label1", ret), + helper.getLastBlockEvents()); } @Test @@ -365,7 +432,7 @@ public void testStructuredBranchBackwardTrue() { %label1 = OpLabel OpReturn """, - "Illegal backward jump to label '%label0' " + + "Illegal backward jump to '%label0' " + "from a structured branch"); } @@ -378,57 +445,10 @@ public void testStructuredBranchBackwardFalse() { %label1 = OpLabel OpReturn """, - "Illegal backward jump to label '%label0' " + - "from a structured branch"); - } - - @Test - public void testLoopMergeBackward() { - doTestIllegalStructuredBranch(""" - %label2 = OpLabel - OpBranch %label0 - %label0 = OpLabel - OpLoopMerge %label2 %label0 None - OpBranchConditional %value1 %label2 %label0 - %label1 = OpLabel - OpReturn - """, - "Illegal backward jump to label '%label2' " + + "Illegal backward jump to '%label0' " + "from a structured branch"); } - @Test - public void testLoopContinueForward() { - doTestIllegalStructuredBranch(""" - %label0 = OpLabel - OpLoopMerge %label1 %label2 None - OpBranchConditional %value1 %label1 %label2 - %label1 = OpLabel - OpBranch %label2 - %label2 = OpLabel - OpReturn - """, - "Illegal forward jump to label '%label2' " + - "from a structured loop"); - } - - @Test - public void testStructuredBranchIllegalMergeLabel() { - doTestIllegalStructuredBranch(""" - %label0 = OpLabel - OpSelectionMerge %label3 None - OpBranchConditional %value1 %label1 %label2 - %label1 = OpLabel - OpBranch %label2 - %label2 = OpLabel - OpBranch %label3 - %label3 = OpLabel - OpReturn - """, - "Illegal last label in conditional branch, " + - "expected '%label3' but received '%label2'"); - } - @Test public void testStructuredBranchLabelsIllegalOrder() { doTestIllegalStructuredBranch(""" @@ -444,39 +464,22 @@ public void testStructuredBranchLabelsIllegalOrder() { } @Test - public void testOpLoopLabelsIllegalContinueLabel() { + public void testLoopMergeWithTwoBackwardLabels() { doTestIllegalStructuredBranch(""" - %label0 = OpLabel - OpLoopMerge %label1 %label2 None - OpBranchConditional %value1 %label1 %label0 - %label1 = OpLabel - OpBranch %label2 %label2 = OpLabel - OpReturn - """, - "Illegal labels, expected mergeLabel='%label1' " + - "and continueLabel='%label2' but received " + - "mergeLabel='%label1' and continueLabel='%label0'"); - } - - @Test - public void testOpLoopLabelsIllegalMergeLabel() { - doTestIllegalStructuredBranch(""" + OpBranch %label0 %label0 = OpLabel OpLoopMerge %label2 %label0 None - OpBranchConditional %value1 %label1 %label0 + OpBranchConditional %value1 %label2 %label0 %label1 = OpLabel - OpBranch %label2 - %label2 = OpLabel OpReturn """, - "Illegal labels, expected mergeLabel='%label2' " + - "and continueLabel='%label0' but received " + - "mergeLabel='%label1' and continueLabel='%label0'"); + "Unsupported conditional branch " + + "with two backward jumps to '%label2' and '%label0'"); } @Test - public void testOpLoopLabelsIllegalOrder() { + public void testOpLoopIllegalTrueLabel() { doTestIllegalStructuredBranch(""" %label0 = OpLabel OpLoopMerge %label1 %label0 None @@ -484,9 +487,7 @@ public void testOpLoopLabelsIllegalOrder() { %label1 = OpLabel OpReturn """, - "Illegal labels, expected mergeLabel='%label1' " + - "and continueLabel='%label0' but received " + - "mergeLabel='%label0' and continueLabel='%label1'"); + "Illegal label, expected '%label0' but received '%label1'"); } private void doTestIllegalStructuredBranch(String input, String error) { @@ -510,10 +511,10 @@ public void testOpBranchConditionalUnstructured() { // given String input = """ %label0 = OpLabel - OpBranchConditional %value1 %label2 %label3 + OpBranchConditional %value1 %label1 %label2 + %label1 = OpLabel + OpBranchConditional %value2 %label2 %label1 %label2 = OpLabel - OpBranchConditional %value2 %label3 %label1 - %label3 = OpLabel OpReturn """; builder.mockFunctionStart(); @@ -527,23 +528,31 @@ public void testOpBranchConditionalUnstructured() { // then List<Event> events = builder.getCurrentFunction().getEvents(); - CondJump event1 = (CondJump) events.get(1); - assertEquals("%value1", getGuardRegister(event1).getName()); - assertEquals(builder.getOrCreateLabel("%label2"), event1.getLabel()); - - CondJump event2 = (CondJump) events.get(2); - assertEquals("%value1", getGuardRegister(event2).getName()); - assertEquals(builder.getOrCreateLabel("%label3"), event2.getLabel()); - - CondJump event4 = (CondJump) events.get(4); - assertEquals("%value2", getGuardRegister(event4).getName()); - assertEquals(builder.getOrCreateLabel("%label3"), event4.getLabel()); - - CondJump event5 = (CondJump) events.get(5); - assertEquals("%value2", getGuardRegister(event5).getName()); - assertEquals(builder.getOrCreateLabel("%label1"), event5.getLabel()); - - assertTrue(builder.getBlocks().isEmpty()); + Label label0 = (Label) events.get(0); + CondJump jump1 = (CondJump) events.get(1); + CondJump jump2 = (CondJump) events.get(2); + Label label1 = (Label) events.get(3); + CondJump jump3 = (CondJump) events.get(4); + CondJump jump4 = (CondJump) events.get(5); + Label label2 = (Label) events.get(6); + Return ret = (Return) events.get(7); + + assertEquals("%label0", label0.getName()); + assertEquals("%label1", jump1.getLabel().getName()); + assertEquals("%label2", jump2.getLabel().getName()); + assertEquals("%label1", label1.getName()); + assertEquals("%label2", jump3.getLabel().getName()); + assertEquals("%label1", jump4.getLabel().getName()); + assertEquals("%label2", label2.getName()); + + assertFalse(jump1.isGoto()); + assertTrue(jump2.isGoto()); + assertFalse(jump3.isGoto()); + assertTrue(jump4.isGoto()); + + assertTrue(helper.getBlockStack().isEmpty()); + assertEquals(Map.of("%label0", jump1, "%label1", jump3, "%label2", ret), + helper.getLastBlockEvents()); } @Test @@ -563,8 +572,7 @@ public void testOpBranchConditionalSameLabels() { fail("Should throw exception"); } catch (ParsingException e) { // then - assertEquals("Labels of conditional branch cannot be the same", - e.getMessage()); + assertEquals("Labels of conditional branch cannot be the same", e.getMessage()); } } @@ -604,7 +612,7 @@ public void testReturn() { Return event = (Return) function.getEvents().get(1); assertNotNull(event); assertTrue(event.getValue().isEmpty()); - assertTrue(builder.getBlocks().isEmpty()); + assertTrue(helper.getBlockStack().isEmpty()); } @Test @@ -623,7 +631,7 @@ public void testReturnValue() { Function function = builder.getCurrentFunction(); Return event = (Return) function.getEvents().get(1); assertEquals(builder.getExpression("%value"), event.getValue().orElseThrow()); - assertTrue(builder.getBlocks().isEmpty()); + assertTrue(helper.getBlockStack().isEmpty()); } @Test @@ -667,17 +675,6 @@ public void testReturnValueForVoidFunction() { } } - private Register getGuardRegister(CondJump event) { - Expression guard = event.getGuard(); - if (guard instanceof Register register) { - return register; - } - if (guard instanceof BoolUnaryExpr expr) { - return (Register) expr.getOperand(); - } - throw new RuntimeException("Unexpected expression type"); - } - private void visit(String text) { new MockSpirvParser(text).spv().accept(new VisitorOpsControlFlow(builder)); } diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockHelperControlFlow.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockHelperControlFlow.java new file mode 100644 index 0000000000..f758d665c0 --- /dev/null +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockHelperControlFlow.java @@ -0,0 +1,32 @@ +package com.dat3m.dartagnan.parsers.program.visitors.spirv.mocks; + +import com.dat3m.dartagnan.expression.Expression; +import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow; +import com.dat3m.dartagnan.program.Register; +import com.dat3m.dartagnan.program.event.Event; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class MockHelperControlFlow extends HelperControlFlow { + public MockHelperControlFlow(Map<String, Expression> expressions) { + super(expressions); + } + + public List<String> getBlockStack() { + return blockStack.stream().toList(); + } + + public Map<String, String> getMergeLabelIds() { + return Map.copyOf(mergeLabelIds); + } + + public Map<String, Event> getLastBlockEvents() { + return Map.copyOf(lastBlockEvents); + } + + public Map<Register, String> getPhiDefinitions(String blockId) { + return Map.copyOf(phiDefinitions.computeIfAbsent(blockId, k -> new HashMap<>())); + } +} diff --git a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilderSpv.java b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilderSpv.java index a348eea92f..c17ef1d81f 100644 --- a/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilderSpv.java +++ b/dartagnan/src/test/java/com/dat3m/dartagnan/parsers/program/visitors/spirv/mocks/MockProgramBuilderSpv.java @@ -25,15 +25,16 @@ public class MockProgramBuilderSpv extends ProgramBuilderSpv { private static final ExpressionFactory exprFactory = ExpressionFactory.getInstance(); public MockProgramBuilderSpv() { - super(List.of(1, 1, 1, 1), Map.of()); + this(List.of(1, 1, 1, 1), Map.of()); } public MockProgramBuilderSpv(Map<String, Expression> input) { - super(List.of(1, 1, 1, 1), input); + this(List.of(1, 1, 1, 1), input); } public MockProgramBuilderSpv(List<Integer> grid, Map<String, Expression> input) { super(grid, input); + helperControlFlow = new MockHelperControlFlow(expressions); } @Override @@ -141,12 +142,13 @@ public Register mockRegister(String id, String typeId) { } public void mockLabel() { - startBlock(new Label("%mock_label")); + helperControlFlow.getOrCreateLabel("%mock_label"); + helperControlFlow.startBlock("%mock_label"); } public void mockLabel(String id) { - Label label = getOrCreateLabel(id); - startBlock(label); + Label label = helperControlFlow.getOrCreateLabel(id); + helperControlFlow.startBlock(id); addEvent(label); } @@ -175,23 +177,7 @@ public Map<String, Expression> getExpressions() { return Map.copyOf(expressions); } - public List<Label> getBlocks() { - return blocks.stream().toList(); - } - - public Map<Label, Label> getCfDefinition() { - return Map.copyOf(cfDefinitions); - } - - public Map<Label, Event> getBlockEndEvents() { - return Map.copyOf(blockEndEvents); - } - public Set<Function> getForwardFunctions() { return Set.copyOf(forwardFunctions.values()); } - - public Map<Register, String> getPhiDefinitions(Label label) { - return Map.copyOf(phiDefinitions.computeIfAbsent(label, k -> new HashMap<>())); - } }