Skip to content

Commit

Permalink
Separate class for thread grid
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Jul 15, 2024
1 parent a975028 commit ce8bc10
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.dat3m.dartagnan.parsers.SpirvParser;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.*;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ProgramBuilder;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Register;
import org.antlr.v4.runtime.tree.ParseTree;
Expand All @@ -13,7 +14,6 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

Expand Down Expand Up @@ -105,7 +105,7 @@ public Program visitSpv(SpirvParser.SpvContext ctx) {
}

private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) {
List<Integer> threadGrid = List.of(1, 1, 1, 1);
ThreadGrid grid = new ThreadGrid(1, 1, 1, 1);
VisitorSpirvInput visitor = new VisitorSpirvInput();
boolean hasConfig = false;
for (SpirvParser.SpvHeaderContext header : ctx.spvHeaders().spvHeader()) {
Expand All @@ -120,10 +120,10 @@ private ProgramBuilder createBuilder(SpirvParser.SpvContext ctx) {
int threadAmount = Integer.parseInt(header.configHeader().literanHeaderUnsignedInteger().get(0).getText());
int subGroupAmount = Integer.parseInt(header.configHeader().literanHeaderUnsignedInteger().get(1).getText());
int workGroupAmount = Integer.parseInt(header.configHeader().literanHeaderUnsignedInteger().get(2).getText());
threadGrid = List.of(threadAmount, subGroupAmount, workGroupAmount, 1);
grid = new ThreadGrid(threadAmount, subGroupAmount, workGroupAmount, 1);
}
}
return new ProgramBuilder(threadGrid, visitor.getInputs());
return new ProgramBuilder(grid, visitor.getInputs());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,29 @@
import com.dat3m.dartagnan.expression.type.ArrayType;
import com.dat3m.dartagnan.expression.type.IntegerType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;
import com.dat3m.dartagnan.program.memory.MemoryObject;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

public class BuiltIn implements Decoration {

private static final TypeFactory types = TypeFactory.getInstance();
private static final ExpressionFactory expressions = ExpressionFactory.getInstance();
private static final int GRID_SIZE = 4;
// grid(0) - number of threads in a subgroup
// grid(1) - number of subgroups in a local workgroup
// assuming sgSize <= wgSize and a flat workgroup
// grid(2) - number of local workgroups in a queue family (global wg)
// assuming a flat workgroup
// grid(3) - number of queue families (global wg) in a device
private final List<Integer> grid;
private final List<Integer> threadId;
private final ThreadGrid grid;
private final Map<String, String> mapping;
private int tid;

public BuiltIn(List<Integer> grid) {
if (grid.size() != 4 || grid.stream().anyMatch(x -> x <= 0)) {
throw new ParsingException("Illegal BuiltIn size (%s)", grid);
}
public BuiltIn(ThreadGrid grid) {
this.grid = grid;
this.threadId = new ArrayList<>(Stream.generate(() -> 0)
.limit(GRID_SIZE).toList());
this.mapping = new HashMap<>();
}

public void setHierarchy(List<Integer> threadId) {
if (threadId.stream().anyMatch(e -> e < 0) || threadId.size() != GRID_SIZE) {
throw new ParsingException("Illegal BuiltIn hierarchy %s",
String.join(", ", threadId.stream().map(Object::toString).toList()));
}
this.threadId.clear();
this.threadId.addAll(threadId);
}

public int getGlobalIdAtIndex(int idx) {
if (idx >= GRID_SIZE) {
throw new ParsingException("Illegal thread hierarchy index %d", idx);
}
int id = 0;
for (int i = idx; i < GRID_SIZE; i++) {
int v = threadId.get(i);
for (int j = idx; j < i; j++) {
v *= grid.get(j);
}
id += v;
}
return id;
public void setThreadId(int tid) {
this.tid = tid;
}

@Override
Expand Down Expand Up @@ -106,15 +74,15 @@ public Expression getDecoration(String id, Type type) {

private Expression getDecorationExpressions(String id, Type type) {
return switch (mapping.get(id)) {
case "SubgroupLocalInvocationId" -> makeScalar(id, type, threadId.get(0));
case "LocalInvocationId" -> makeArray(id, type, threadId.get(0) + threadId.get(1) * grid.get(0), 0, 0);
case "LocalInvocationIndex" -> makeScalar(id, type, threadId.get(0) + threadId.get(1) * grid.get(0)); // scalar of LocalInvocationId
case "GlobalInvocationId" -> makeArray(id, type, threadId.get(0) + threadId.get(1) * grid.get(0) + threadId.get(2) * grid.get(0) * grid.get(1) + threadId.get(3) * grid.get(0) * grid.get(1) * grid.get(2), 0, 0);
case "SubgroupLocalInvocationId" -> makeScalar(id, type, tid % grid.sgSize());
case "LocalInvocationId" -> makeArray(id, type, tid % grid.wgSize(), 0, 0);
case "LocalInvocationIndex" -> makeScalar(id, type, tid % grid.wgSize()); // scalar of LocalInvocationId
case "GlobalInvocationId" -> makeArray(id, type, tid % grid.dvSize(), 0, 0);
case "DeviceIndex" -> makeScalar(id, type, 0);
case "SubgroupId" -> makeScalar(id, type, threadId.get(1));
case "WorkgroupId" -> makeArray(id, type, threadId.get(2), 0, 0);
case "SubgroupSize" -> makeScalar(id, type, grid.get(0));
case "WorkgroupSize" -> makeArray(id, type, grid.get(0) * grid.get(1), 1, 1);
case "SubgroupId" -> makeScalar(id, type, grid.sgId(tid));
case "WorkgroupId" -> makeArray(id, type, grid.wgId(tid), 0, 0);
case "SubgroupSize" -> makeScalar(id, type, grid.sgSize());
case "WorkgroupSize" -> makeArray(id, type, grid.wgSize(), 1, 1);
default -> throw new ParsingException("Unsupported decoration '%s'", mapping.get(id));
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.parsers.SpirvParser;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ProgramBuilder;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;
import com.dat3m.dartagnan.program.memory.ScopedPointerVariable;
import com.dat3m.dartagnan.program.event.Tag;

Expand Down Expand Up @@ -137,11 +138,11 @@ private Void setPushConstantValue(String decorationId, String sizeId) {
}

private List<Integer> getPushConstantValue(String command) {
List<Integer> grid = builder.getThreadGrid();
ThreadGrid grid = builder.getThreadGrid();
return switch (command) {
case "PushConstantGlobalSize" -> List.of(grid.get(0) * grid.get(1) * grid.get(2), 1, 1);
case "PushConstantEnqueuedLocalSize" -> List.of(grid.get(0) * grid.get(1), 1, 1);
case "PushConstantNumWorkgroups" -> List.of(grid.get(2), 1, 1);
case "PushConstantGlobalSize" -> List.of(grid.dvSize(), 1, 1);
case "PushConstantEnqueuedLocalSize" -> List.of(grid.wgSize(), 1, 1);
case "PushConstantNumWorkgroups" -> List.of(grid.qfSize() / grid.wgSize(), 1, 1);
case "PushConstantGlobalOffset",
"PushConstantRegionOffset",
"PushConstantRegionGroupOffset"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.Decoration;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.SpecId;
import com.dat3m.dartagnan.parsers.program.visitors.spirv.utils.ThreadGrid;

import java.util.EnumMap;
import java.util.List;

import static com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType.BUILT_IN;
import static com.dat3m.dartagnan.parsers.program.visitors.spirv.decorations.DecorationType.SPEC_ID;
Expand All @@ -16,8 +16,8 @@ public class HelperDecorations {

private final EnumMap<DecorationType, Decoration> mapping = new EnumMap<>(DecorationType.class);

public HelperDecorations(List<Integer> threadGrid) {
mapping.put(BUILT_IN, new BuiltIn(threadGrid));
public HelperDecorations(ThreadGrid grid) {
mapping.put(BUILT_IN, new BuiltIn(grid));
mapping.put(SPEC_ID, new SpecId());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import com.dat3m.dartagnan.program.misc.NonDetValue;

import java.util.*;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -23,38 +24,41 @@ public class MemoryTransformer extends ExprTransformer {

private static final List<String> namePrefixes = List.of("T", "S", "W", "Q");

private final List<Integer> grid;
private final Function function;
private final BuiltIn builtIn;
private final List<Integer> threadId;
private final List<? extends Map<MemoryObject, MemoryObject>> scopeMapping;
private final Map<MemoryObject, ScopedPointerVariable> pointerMapping;
private final List<IntUnaryOperator> scopeIdProvider;
private final List<IntUnaryOperator> namePrefixIdxProvider;
private Map<Register, Register> registerMapping;
private final List<? extends Map<MemoryObject, MemoryObject>> memoryScopeMapping;
private final Map<MemoryObject, ScopedPointerVariable> memoryPointersMapping;
private int tid;

public MemoryTransformer(List<Integer> grid, Function function, BuiltIn builtIn, Set<ScopedPointerVariable> variables) {
this.grid = grid;
public MemoryTransformer(ThreadGrid grid, Function function, BuiltIn builtIn, Set<ScopedPointerVariable> variables) {
this.function = function;
this.builtIn = builtIn;
this.threadId = new ArrayList<>(Stream.generate(() -> 0).limit(grid.size()).toList());
this.memoryScopeMapping = Stream.generate(() -> new HashMap<MemoryObject, MemoryObject>()).limit(grid.size()).toList();
this.memoryPointersMapping = variables.stream().collect(Collectors.toMap((ScopedPointerVariable::getAddress), (v -> v)));
this.scopeMapping = Stream.generate(() -> new HashMap<MemoryObject, MemoryObject>()).limit(1L + ThreadGrid.DEPTH).toList();
this.pointerMapping = variables.stream().collect(Collectors.toMap((ScopedPointerVariable::getAddress), (v -> v)));
this.scopeIdProvider = List.of(grid::thId, grid::sgId, grid::wgId, grid::qfId, grid::dvId);
this.namePrefixIdxProvider = List.of(
i -> i,
i -> i / grid.sgSize(),
i -> i / grid.wgSize(),
i -> i / grid.qfSize(),
i -> i / grid.dvSize());
}

public Register getRegisterMapping(Register register) {
return registerMapping.get(register);
}

public void setThread(Thread thread) {
List<Integer> newThreadId = getThreadHierarchicalId(thread);
for (int i = 0; i < newThreadId.size(); i++) {
if (!threadId.get(i).equals(newThreadId.get(i))) {
for (int j = 0; j <= i; j++) {
memoryScopeMapping.get(j).clear();
}
}
threadId.set(i, newThreadId.get(i));
int newTid = thread.getId();
int depth = getScopeIdx(newTid, scopeIdProvider);
for (int i = 0; i <= depth; i++) {
scopeMapping.get(i).clear();
}
builtIn.setHierarchy(newThreadId);
tid = newTid;
builtIn.setThreadId(tid);
registerMapping = function.getRegisters().stream().collect(
toMap(r -> r, r -> thread.getOrNewRegister(r.getName(), r.getType())));
}
Expand All @@ -66,7 +70,7 @@ public Expression visitRegister(Register register) {

@Override
public Expression visitMemoryObject(MemoryObject memObj) {
String storageClass = memoryPointersMapping.get(memObj).getScopeId();
String storageClass = pointerMapping.get(memObj).getScopeId();
return switch (storageClass) {
case Tag.Spirv.SC_UNIFORM_CONSTANT,
Tag.Spirv.SC_UNIFORM,
Expand All @@ -83,37 +87,38 @@ public Expression visitMemoryObject(MemoryObject memObj) {
};
}

private Expression applyMapping(MemoryObject memObj, int scopeLevel) {
private Expression applyMapping(MemoryObject memObj, int scopeDepth) {
Program program = function.getProgram();
Map<MemoryObject, MemoryObject> mapping = memoryScopeMapping.get(scopeLevel);
Map<MemoryObject, MemoryObject> mapping = scopeMapping.get(scopeDepth);
if (!mapping.containsKey(memObj)) {
MemoryObject copy = memObj instanceof VirtualMemoryObject
? program.getMemory().allocateVirtual(memObj.size(), true, null)
: program.getMemory().allocate(memObj.size());

copy.setName(makeVariableName(scopeLevel, memObj.getName()));
copy.setName(makeVariableName(scopeDepth, memObj.getName()));
for (int offset : memObj.getInitializedFields()) {
Expression value = memObj.getInitialValue(offset);
if (value instanceof NonDetValue) {
value = program.newConstant(value.getType());
}
copy.setInitialValue(offset, value);
}
builtIn.decorate(memObj.getName(), copy, memoryPointersMapping.get(memObj).getInnerType());
builtIn.decorate(memObj.getName(), copy, pointerMapping.get(memObj).getInnerType());
mapping.put(memObj, copy);
}
return mapping.getOrDefault(memObj, memObj);
}

private List<Integer> getThreadHierarchicalId(Thread thread) {
int sgId = thread.getScopeHierarchy().getScopeId(Tag.Vulkan.SUB_GROUP);
int wgId = thread.getScopeHierarchy().getScopeId(Tag.Vulkan.WORK_GROUP);
int qfId = thread.getScopeHierarchy().getScopeId(Tag.Vulkan.QUEUE_FAMILY);
int thId = thread.getId() % grid.get(0);
return List.of(thId, sgId, wgId, qfId);
private int getScopeIdx(int newTid, List<IntUnaryOperator> f) {
for (int i = f.size() - 1; i >= 0; i--) {
if (f.get(i).applyAsInt(newTid) != f.get(i).applyAsInt(tid)) {
return i;
}
}
return -1;
}

private String makeVariableName(int idx, String base) {
return String.format("%s@%s%s", base, namePrefixes.get(idx), builtIn.getGlobalIdAtIndex(idx));
return String.format("%s@%s%s", base, namePrefixes.get(idx),
namePrefixIdxProvider.get(idx).applyAsInt(tid));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public class ProgramBuilder {
protected final Map<String, Type> types = new HashMap<>();
protected final Map<String, Expression> expressions = new HashMap<>();
protected final Map<String, Function> forwardFunctions = new HashMap<>();
protected final List<Integer> threadGrid;
protected final ThreadGrid grid;
protected final Map<String, Expression> inputs;
public final Program program;

Expand All @@ -47,12 +47,11 @@ public class ProgramBuilder {
private final HelperDecorations helperDecorations;
protected ControlFlowBuilder cfBuilder = new ControlFlowBuilder(expressions);

public ProgramBuilder(List<Integer> threadGrid, Map<String, Expression> inputs) {
validateThreadGrid(threadGrid);
this.threadGrid = threadGrid;
public ProgramBuilder(ThreadGrid grid, Map<String, Expression> inputs) {
this.grid = grid;
this.inputs = inputs;
this.program = new Program(new Memory(), Program.SourceLanguage.SPV);
this.helperDecorations = new HelperDecorations(threadGrid);
this.helperDecorations = new HelperDecorations(grid);
}

public Program build() {
Expand All @@ -63,7 +62,7 @@ public Program build() {
.filter(ScopedPointerVariable.class::isInstance)
.map(v -> (ScopedPointerVariable) v)
.collect(Collectors.toSet());
new ThreadCreator(threadGrid, getEntryPointFunction(), variables, builtIn).create();
new ThreadCreator(grid, getEntryPointFunction(), variables, builtIn).create();
checkSpecification();
return program;
}
Expand Down Expand Up @@ -304,20 +303,8 @@ public String getCurrentFunctionName() {
return getCurrentFunctionOrThrowError().getName();
}

private void validateThreadGrid(List<Integer> threadGrid) {
if (threadGrid.size() != 4) {
throw new ParsingException("Thread grid must have 4 dimensions");
}
if (threadGrid.stream().anyMatch(i -> i <= 0)) {
throw new ParsingException("Thread grid dimensions must be positive");
}
if (threadGrid.stream().reduce(1, (a, b) -> a * b) > 128) {
throw new ParsingException("Thread grid dimensions must be less than 128");
}
}

public List<Integer> getThreadGrid() {
return threadGrid;
public ThreadGrid getThreadGrid() {
return grid;
}

public Set<String> getNextOps() {
Expand Down
Loading

0 comments on commit ce8bc10

Please sign in to comment.