Skip to content

Commit

Permalink
Refactored control flow
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Jul 14, 2024
1 parent afc0f7a commit ca3147d
Show file tree
Hide file tree
Showing 8 changed files with 440 additions and 440 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.BuiltIn;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Decoration;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperControlFlow;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperDecorations;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperTags;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.transformers.ThreadCreator;
Expand All @@ -16,8 +17,6 @@
import com.dat3m.dartagnan.expression.type.ScopedPointerType;
import com.dat3m.dartagnan.program.*;
import com.dat3m.dartagnan.program.event.*;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.core.Local;
import com.dat3m.dartagnan.program.memory.Memory;
import com.dat3m.dartagnan.program.memory.MemoryObject;
import org.apache.logging.log4j.LogManager;
Expand All @@ -37,14 +36,9 @@ public class ProgramBuilderSpv {
protected final Map<String, Type> types = new HashMap<>();
protected final Map<String, Expression> expressions = new HashMap<>();
protected final Map<String, Function> forwardFunctions = new HashMap<>();
protected final Map<String, Label> labels = new HashMap<>();
protected final Deque<Label> blocks = new ArrayDeque<>();
protected final Map<Label, Event> blockEndEvents = new HashMap<>();
protected final Map<Label, Map<Register, String>> phiDefinitions = new HashMap<>();
protected final Map<Label, Label> cfDefinitions = new HashMap<>();
protected final List<Integer> threadGrid;
protected final Map<String, Expression> inputs;
protected final Program program;
public final Program program;

protected Function currentFunction;
protected String entryPointId;
Expand All @@ -53,6 +47,7 @@ public class ProgramBuilderSpv {

private final HelperTags helperTags = new HelperTags();
private final HelperDecorations helperDecorations;
protected HelperControlFlow helperControlFlow = new HelperControlFlow(expressions);

public ProgramBuilderSpv(List<Integer> threadGrid, Map<String, Expression> inputs) {
validateThreadGrid(threadGrid);
Expand All @@ -64,7 +59,7 @@ public ProgramBuilderSpv(List<Integer> threadGrid, Map<String, Expression> input

public Program build() {
validateBeforeBuild();
preprocessBlocks();
helperControlFlow.build();
BuiltIn builtIn = (BuiltIn) getDecoration(DecorationType.BUILT_IN);
Set<ScopedPointerVariable> variables = expressions.values().stream()
.filter(ScopedPointerVariable.class::isInstance)
Expand Down Expand Up @@ -163,26 +158,11 @@ private Function checkFunctionType(String id, Function function, Type type) {
"function type doesn't match the function definition", id);
}

public void startBlock(Label label) {
if (blockEndEvents.containsKey(label)) {
throw new ParsingException("Attempt to redefine label '%s'", label.getName());
}
blocks.push(label);
}

public Event endBlock(Event event) {
if (blocks.isEmpty()) {
throw new ParsingException("Attempt to exit block while not in a block definition");
}
blockEndEvents.put(blocks.pop(), event);
return event;
}

public Event addEvent(Event event) {
if (currentFunction == null) {
throw new ParsingException("Attempt to add an event outside a function definition");
}
if (blocks.isEmpty()) {
if (!helperControlFlow.isInsideBlock()) {
throw new ParsingException("Attempt to add an event outside a control flow block");
}
if (event instanceof RegWriter regWriter) {
Expand Down Expand Up @@ -223,14 +203,6 @@ public String getExpressionStorageClass(String name) {
throw new ParsingException("Reference to undefined pointer '%s'", name);
}

public Label makeBranchBackJumpLabel(Label label) {
String id = label.getName() + "_back";
if (labels.containsKey(id)) {
throw new ParsingException("Overlapping blocks with back jump in label '%s'", label.getName());
}
return getOrCreateLabel(id);
}

public boolean hasInput(String id) {
return inputs.containsKey(id);
}
Expand All @@ -253,22 +225,6 @@ public Register addRegister(String id, String typeId) {
return getCurrentFunctionOrThrowError().newRegister(id, getType(typeId));
}

public boolean hasBlock(String id) {
if (!labels.containsKey(id)) {
return false;
}
Label label = labels.get(id);
return blockEndEvents.containsKey(label) || blocks.contains(label);
}

public Label getOrCreateLabel(String id) {
Label label = labels.computeIfAbsent(id, EventFactory::newLabel);
if (label.getFunction() == null) {
label.setFunction(getCurrentFunctionOrThrowError());
}
return label;
}

public String getScope(String id) {
return helperTags.visitScope(id, getExpression(id));
}
Expand All @@ -289,7 +245,11 @@ public Decoration getDecoration(DecorationType type) {
return helperDecorations.getDecoration(type);
}

private Function getCurrentFunctionOrThrowError() {
public HelperControlFlow getHelperControlFlow() {
return helperControlFlow;
}

public Function getCurrentFunctionOrThrowError() {
if (currentFunction != null) {
return currentFunction;
}
Expand All @@ -306,28 +266,11 @@ private void validateBeforeBuild() {
throw new ParsingException("Unclosed definition for function '%s'",
currentFunction.getName());
}
if (!blocks.isEmpty()) {
throw new ParsingException("Unclosed blocks for labels %s",
String.join(",", blocks.stream().map(Label::getName).toList()));
}
if (nextOps != null) {
throw new ParsingException("Missing expected op: %s",
String.join(",", nextOps));
}
// TODO: Validate no event refers to an undefined label
}

private void preprocessBlocks() {
Map<Event, Label> blockEndToLabelMap = new HashMap<>();
for (Map.Entry<Label, Event> entry : blockEndEvents.entrySet()) {
if (blockEndToLabelMap.containsKey(entry.getValue())) {
throw new ParsingException("Malformed control flow, " +
"multiple block refer to the same end event " + entry.getValue());
}
blockEndToLabelMap.put(entry.getValue(), entry.getKey());
}
insertPhiDefinitions(blockEndToLabelMap);
insertBlockEndLabels(blockEndToLabelMap);
}

private void checkSpecification() {
Expand All @@ -338,40 +281,6 @@ private void checkSpecification() {
}
}

public void addPhiDefinition(Label label, Register register, String id) {
phiDefinitions.computeIfAbsent(label, k -> new HashMap<>()).put(register, id);
}

private void insertPhiDefinitions(Map<Event, Label> blockEndToLabelMap) {
for (Function function : program.getFunctions()) {
for (Event event : function.getEvents()) {
Label label = blockEndToLabelMap.get(event.getSuccessor());
if (label != null) {
Map<Register, String> phi = phiDefinitions.get(label);
if (phi != null) {
for (Map.Entry<Register, String> entry : phi.entrySet()) {
event.insertAfter(new Local(entry.getKey(), getExpression(entry.getValue())));
}
}
}
}
}
}

private void insertBlockEndLabels(Map<Event, Label> blockEndToLabelMap) {
for (Function function : program.getFunctions()) {
for (Event event : function.getEvents()) {
Label label = blockEndToLabelMap.get(event.getSuccessor());
if (label != null) {
Label endLabel = cfDefinitions.get(label);
if (endLabel != null) {
event.insertAfter(endLabel);
}
}
}
}
}

private Function getEntryPointFunction() {
if (entryPointId == null) {
throw new ParsingException("Cannot build the program, entryPointId is missing");
Expand All @@ -389,8 +298,6 @@ private Function getEntryPointFunction() {
throw new ParsingException("Entry point expression '%s' must be a function", entryPointId);
}

// TODO: Move to the corresponding visitors, tests, mocks etc ======================================================

public FunctionType getCurrentFunctionType() {
return getCurrentFunctionOrThrowError().getFunctionType();
}
Expand All @@ -399,16 +306,6 @@ public String getCurrentFunctionName() {
return getCurrentFunctionOrThrowError().getName();
}

public Label makeBranchEndLabel(Label label) {
String id = label.getName() + "_end";
if (labels.containsKey(id)) {
throw new ParsingException("Overlapping blocks with endpoint in label '%s'", label.getName());
}
Label endLabel = getOrCreateLabel(id);
cfDefinitions.put(label, endLabel);
return endLabel;
}

private void validateThreadGrid(List<Integer> threadGrid) {
if (threadGrid.size() != 4) {
throw new ParsingException("Thread grid must have 4 dimensions");
Expand Down
Loading

0 comments on commit ca3147d

Please sign in to comment.