Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Jul 9, 2024
1 parent 0171d04 commit b306be5
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 243 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import com.dat3m.dartagnan.expression.integers.IntLiteral;
import com.dat3m.dartagnan.expression.processing.ExprTransformer;
import com.dat3m.dartagnan.expression.type.FunctionType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.BuiltIn;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Decoration;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperDecorations;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.helpers.HelperTags;
import com.dat3m.dartagnan.program.memory.ScopedPointer;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.transformers.MemoryTransformer;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.transformers.RegisterTransformer;
import com.dat3m.dartagnan.program.memory.ScopedPointerVariable;
import com.dat3m.dartagnan.expression.type.ScopedPointerType;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.*;
import com.dat3m.dartagnan.program.event.*;
Expand Down Expand Up @@ -41,28 +43,26 @@ public class ProgramBuilderSpv {

private static final Logger logger = LogManager.getLogger(ProgramBuilderSpv.class);

protected final Map<String, Type> types = new HashMap<>();
protected final Map<String, Expression> expressions = new HashMap<>();
protected final Set<String> specConstants = new HashSet<>();
protected final Map<String, Function> forwardFunctions = new HashMap<>();
protected final Map<String, Label> labels = new HashMap<>();
protected final Deque<Label> blocks = new ArrayDeque<>();
protected final Map<Label, Event> blockEndEvents = new HashMap<>();
protected final Map<Label, Map<Register, String>> phiDefinitions = new HashMap<>();
protected final Map<Label, Label> cfDefinitions = new HashMap<>();
protected final List<Integer> threadGrid;
protected final Map<String, Expression> inputs;
protected final Program program;

protected Function currentFunction;
protected String entryPointId;
protected int nextFunctionId = 0;
protected Set<String> nextOps;

private final HelperTags helperTags = new HelperTags();
private final HelperDecorations helperDecorations;
private final Map<String, Type> types = new HashMap<>();
private final Map<String, String> pointedTypes = new HashMap<>();
private final Map<String, Type> variableTypes = new HashMap<>();
private final Map<String, String> pointerClasses = new HashMap<>();
private final Map<String, String> registerClasses = new HashMap<>();
private final Map<String, Expression> expressions = new HashMap<>();
private final Map<String, Function> forwardFunctions = new HashMap<>();
private final Map<String, Label> labels = new HashMap<>();
private final Deque<Label> blocks = new ArrayDeque<>();
private final Map<Label, Event> blockEndEvents = new HashMap<>();
private final Map<Label, Map<Register, String>> phiDefinitions = new HashMap<>();
private final Map<Label, Label> cfDefinitions = new HashMap<>();
private final Set<String> specConstants = new HashSet<>();
private final List<Integer> threadGrid;
private final Map<String, Expression> inputs;
private final Program program;
protected Function currentFunction;
private String entryPointId;
private int nextFunctionId = 0;
private Set<String> nextOps;

public ProgramBuilderSpv(List<Integer> threadGrid, Map<String, Expression> inputs) {
validateThreadGrid(threadGrid);
Expand Down Expand Up @@ -109,7 +109,16 @@ public Program build() {

Function entry = getEntryPointFunction();
BuiltIn builtIn = (BuiltIn) getDecoration(DecorationType.BUILT_IN);
MemoryTransformer transformer = new MemoryTransformer(program, builtIn, variableTypes, registerClasses);

Map<String, Type> vType = expressions.entrySet().stream()
.filter(e -> e.getValue() instanceof ScopedPointerVariable)
.collect(Collectors.toMap((Map.Entry::getKey), (e -> ((ScopedPointerVariable)e.getValue()).getInnerType())));

Map<String, String> stClsMap = expressions.entrySet().stream()
.filter(e -> e.getValue() instanceof ScopedPointerVariable)
.collect(Collectors.toMap((Map.Entry::getKey), (e -> ((ScopedPointerVariable)e.getValue()).getScopeId())));

MemoryTransformer transformer = new MemoryTransformer(program, builtIn, vType, stClsMap);
for (int z = 0; z < threadGrid.get(2); z++) {
for (int y = 0; y < threadGrid.get(1); y++) {
for (int x = 0; x < threadGrid.get(0); x++) {
Expand Down Expand Up @@ -237,12 +246,10 @@ public Event addEvent(Event event) {
return event;
}

public MemoryObject allocateMemory(int bytes) {
return program.getMemory().allocate(bytes);
}

public MemoryObject allocateMemoryVirtual(int bytes) {
return program.getMemory().allocateVirtual(bytes, true, null);
public ScopedPointerVariable allocateMemoryVirtual(String id, String typeId, Type type, int bytes) {
MemoryObject memoryObject = program.getMemory().allocateVirtual(bytes, true, null);
memoryObject.setName(id);
return new ScopedPointerVariable(id, ((ScopedPointerType) getType(typeId)).getScopeId(), type, memoryObject);
}

public Expression newUndefinedValue(Type type) {
Expand All @@ -261,73 +268,28 @@ public Type addType(String name, Type type) {
if (types.containsKey(name) || expressions.containsKey(name)) {
throw new ParsingException("Duplicated definition '%s'", name);
}
//if (TypeFactory.getInstance().isPointerType(type)) {
// throw new ParsingException("Unexpected pointer type '%s'", name);
//}
types.put(name, type);
return type;
}

public Type addPointerType(String name, String innerTypeId, String cls) {
if (types.containsKey(name) || expressions.containsKey(name)) {
throw new ParsingException("Duplicated definition '%s'", name);
}
if (pointedTypes.containsKey(name)) {
throw new ParsingException("Duplicated pointer type definition '%s'", name);
}
getType(innerTypeId);
pointedTypes.put(name, innerTypeId);
if (pointerClasses.containsKey(name)) {
throw new ParsingException("Duplicated variable storage class definition '%s'", name);
}
pointerClasses.put(name, getStorageClass(cls));
Type type = TypeFactory.getInstance().getPointerType();
types.put(name, type);
return type;
}

public Type getPointedType(String name) {
String typeId = pointedTypes.get(name);
if (typeId == null) {
if (!types.containsKey(name)) {
throw new ParsingException("Reference to undefined pointer type '%s'", name);
}
throw new ParsingException("Type '%s' is not a pointer type", name);
}
return getType(typeId);
}

public Type getVariableType(String name) {
Type type = variableTypes.get(name);
if (type == null) {
throw new ParsingException("Reference to undefined variable '%s'", name);
}
return type;
}

public Type addVariableType(String name, Type type) {
if (variableTypes.containsKey(name)) {
throw new ParsingException("Duplicated variable type definition '%s'", name);
}
variableTypes.put(name, type);
return type;
}

public List<MemoryObject> getVariablesWithStorageClass(String storageClass) {
return registerClasses.entrySet().stream()
.filter(e -> e.getValue().equals(storageClass))
.map(e -> getExpression(e.getKey()))
.filter(MemoryObject.class::isInstance)
.map(e -> (MemoryObject) e)
public List<ScopedPointerVariable> getVariablesWithStorageClass(String storageClass) {
return expressions.values().stream()
.filter(ScopedPointerVariable.class::isInstance)
.map(e -> (ScopedPointerVariable)e)
.filter(e -> e.getScopeId().equals(storageClass))
.toList();
}

public String getExpressionStorageClass(String name) {
String storageClass = registerClasses.get(name);
if (storageClass == null) {
throw new ParsingException("Reference to undefined pointer '%s'", name);
Expression expression = getExpression(name);
if (expression instanceof ScopedPointer pExpr) {
return pExpr.getScopeId();
}
if (expression instanceof Register) {
// TODO: Hacky, ideally new pointer type for registers
return Tag.Spirv.SC_FUNCTION;
}
return storageClass;
throw new ParsingException("Reference to undefined pointer '%s'", name);
}

public Label makeBranchBackJumpLabel(Label label) {
Expand Down Expand Up @@ -391,24 +353,10 @@ public Register getRegister(String id) {
}

public Register addRegister(String id, String typeId) {
addStorageClassForExpr(id, typeId);
return getCurrentFunctionOrThrowError().newRegister(id, getType(typeId));
}

public String addStorageClassForExpr(String id, String typeId) {
if (pointedTypes.containsKey(typeId)) {
String storageClass = pointerClasses.get(typeId);
if (storageClass == null) {
throw new ParsingException("Missing storage class for pointer '%s'", typeId);
}
if (registerClasses.containsKey(id)) {
throw new ParsingException("Duplicated storage class definition for expression'%s'", id);
}
registerClasses.put(id, storageClass);
return storageClass;
if (getType(typeId) instanceof ScopedPointerType) {
throw new ParsingException("Register cannot be a pointer");
}
// TODO:
return null;
return getCurrentFunctionOrThrowError().newRegister(id, getType(typeId));
}

public boolean hasBlock(String id) {
Expand Down Expand Up @@ -625,18 +573,6 @@ public String getCurrentFunctionName() {
return getCurrentFunctionOrThrowError().getName();
}

public Set<Function> getForwardFunctions() {
return Set.copyOf(forwardFunctions.values());
}

public Map<Label, Event> getBlockEndEvents() {
return Map.copyOf(blockEndEvents);
}

public Map<Label, Label> getCfDefinition() {
return Map.copyOf(cfDefinitions);
}

public Label makeBranchEndLabel(Label label) {
String id = label.getName() + "_end";
if (labels.containsKey(id)) {
Expand All @@ -650,23 +586,4 @@ public Label makeBranchEndLabel(Label label) {
public Map<Register, String> getPhiDefinitions(Label label) {
return phiDefinitions.computeIfAbsent(label, k -> new HashMap<>());
}

public Map<String, Type> getTypes() {
return Map.copyOf(types);
}

public Map<String, Expression> getExpressions() {
return Map.copyOf(expressions);
}

public List<Label> getBlocks() {
return blocks.stream().toList();
}

public MemoryObject getMemoryObject(String id) {
return program.getMemory().getObjects().stream()
.filter(o -> o.getName().equals(id))
.findFirst()
.orElseThrow(() -> new ParsingException("Undefined memory object '%s'", id));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.dat3m.dartagnan.expression.integers.IntBinaryOp;
import com.dat3m.dartagnan.expression.integers.IntCmpOp;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.parsers.SpirvBaseVisitor;
import com.dat3m.dartagnan.parsers.SpirvParser;
import com.dat3m.dartagnan.program.Register;
Expand All @@ -20,7 +19,6 @@

public class VisitorOpsAtomic extends SpirvBaseVisitor<Event> {

private static final TypeFactory TYPE_FACTORY = TypeFactory.getInstance();
private final ProgramBuilderSpv builder;

public VisitorOpsAtomic(ProgramBuilderSpv builder) {
Expand Down Expand Up @@ -229,11 +227,13 @@ private IntegerType getIntegerType(String typeId) {
}

private Expression getPointer(String ptrId) {
return builder.getExpression(ptrId);
/*
Expression result = builder.getExpression(ptrId);
if (builder.getVariableType(ptrId) != null) {
if (result instanceof MemoryObject) {
return result;
}
throw new ParsingException("Unexpected type at '%s', expected pointer but received '%s'", ptrId, result.getType());
throw new ParsingException("Unexpected type at '%s', expected pointer but received '%s'", ptrId, result.getType());*/
}

public Set<String> getSupportedOps() {
Expand Down
Loading

0 comments on commit b306be5

Please sign in to comment.