Skip to content

Commit

Permalink
Review comments - part 1
Browse files Browse the repository at this point in the history
Co-authored-by: Natalia Gavrilenko <[email protected]>
  • Loading branch information
natgavrilenko and Natalia Gavrilenko committed Aug 5, 2024
1 parent bd8b5c5 commit d59c8e1
Show file tree
Hide file tree
Showing 39 changed files with 93 additions and 80 deletions.
7 changes: 3 additions & 4 deletions benchmarks/opencl/CORR.cl
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
// clspv CORR.cl --cl-std=CL2.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

__kernel void test(global atomic_uint* x, global uint* r0, global uint* r1, global uint* r2, global uint* r3) {
// T0 @ wg 0
if (get_group_id(0) == 0) {
*r0 = atomic_load_explicit(x, memory_order_relaxed);
*r1 = atomic_load_explicit(x, memory_order_relaxed);
}
// T1 @ wg 1
if(get_group_id(0) == 1) {
*r2 = atomic_load_explicit(x, memory_order_relaxed);
*r3 = atomic_load_explicit(x, memory_order_relaxed);
}
// T2 @ wg 2
if(get_group_id(0) == 2) {
atomic_store_explicit(x, 2, memory_order_relaxed);
}
// T3 @ wg 3
if(get_group_id(0) == 3) {
atomic_store_explicit(x, 1, memory_order_relaxed);
}
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/opencl/IRIW.cl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// clspv IRIW.cl --cl-std=CL2.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

__kernel void test(global atomic_uint* x, global atomic_uint* y, global uint* r0, global uint* r1, global uint* r2, global uint* r3) {
if (get_local_id(0) == 0) {
atomic_store_explicit(x, 1, memory_order_release);
Expand Down
4 changes: 4 additions & 0 deletions benchmarks/opencl/MP.cl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
// clspv MP.cl --cl-std=CL2.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

#ifdef ACQ2RX
#define mo_acq memory_order_relaxed
#else
#define mo_acq memory_order_acquire
#endif

#ifdef REL2RX
#define mo_rel memory_order_relaxed
#else
Expand Down
4 changes: 4 additions & 0 deletions benchmarks/opencl/SB.cl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
// clspv SB.cl --cl-std=CL2.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

#ifdef ACQ2RX
#define mo_acq memory_order_relaxed
#else
#define mo_acq memory_order_acquire
#endif

#ifdef REL2RX
#define mo_rel memory_order_relaxed
#else
Expand Down
8 changes: 6 additions & 2 deletions benchmarks/opencl/caslock-sc.cl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
static void lock(global uint* l) {
// clspv caslock-cs.cl --cl-std=CL1.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

void lock(global uint* l) {
while (atom_cmpxchg(l, 0, 1) == 1) {}
}

static void unlock(global uint* l) {
void unlock(global uint* l) {
atom_xchg(l, 0);
}

Expand All @@ -12,4 +15,5 @@ __kernel void mutex_test(global uint* l, global int* x, global int* A) {
a = *x;
*x = a + 1;
unlock(l);
A[get_global_id(0)] = a;
}
6 changes: 5 additions & 1 deletion benchmarks/opencl/caslock.cl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
// clspv caslock.cl --cl-std=CL2.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

#ifdef ACQ2RX
#define mo_lock memory_order_relaxed
#else
#define mo_lock memory_order_acquire
#endif

#ifdef REL2RX
#define mo_unlock memory_order_relaxed
#else
Expand Down Expand Up @@ -33,4 +37,4 @@ __kernel void mutex_test(global atomic_uint* l, global int* x, global int* A) {
*x = a + 1;
unlock(l);
A[get_global_id(0)] = a;
}
}
4 changes: 4 additions & 0 deletions benchmarks/opencl/ticketlock.cl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
// clspv ticketlock.cl --cl-std=CL2.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

#ifdef ACQ2RX
#define mo_lock memory_order_relaxed
#else
#define mo_lock memory_order_acquire
#endif

#ifdef REL2RX
#define mo_unlock memory_order_relaxed
#else
Expand Down
4 changes: 4 additions & 0 deletions benchmarks/opencl/ttaslock.cl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
// clspv ttaslock.cl --cl-std=CL2.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

#ifdef ACQ2RX
#define mo_lock memory_order_relaxed
#else
#define mo_lock memory_order_acquire
#endif

#ifdef REL2RX
#define mo_unlock memory_order_relaxed
#else
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/opencl/xf-barrier.cl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// clspv xf-barrier.cl --cl-std=CL2.0 --inline-entry-points --spv-version=1.6
// spirv-dis a.spv

#ifdef FAIL1
#define mo1 memory_order_relaxed
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public static ProgramEncoder withContext(EncodingContext context) throws Invalid

public BooleanFormula encodeFullProgram() {
return context.getBooleanFormulaManager().and(
encodeControlBarrier(),
encodeControlBarriers(),
encodeConstants(),
encodeMemory(),
encodeControlFlow(),
Expand All @@ -81,7 +81,7 @@ public BooleanFormula encodeFullProgram() {
encodeDependencies());
}

public BooleanFormula encodeControlBarrier() {
public BooleanFormula encodeControlBarriers() {
BooleanFormulaManager bmgr = context.getBooleanFormulaManager();
BooleanFormula enc = bmgr.makeTrue();
Map<Integer, List<ControlBarrier>> groups = context.getTask().getProgram().getThreads().stream()
Expand Down Expand Up @@ -333,6 +333,7 @@ private int getWorkgroupId(Thread thread) {
}
return id;
}
return -1;
throw new IllegalArgumentException("Attempt to compute workgroup ID " +
"for a non-hierarchical thread");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,7 @@ private TrackableFormula encodeProgramSpecification() {
private BooleanFormula encodeProgramTermination() {
final BooleanFormulaManager bmgr = context.getBooleanFormulaManager();
return bmgr.and(program.getThreads().stream()
.map(t -> {
BooleanFormula started = context.execution(t.getEntry());
BooleanFormula finished = context.execution(t.getEvents().get(t.getEvents().size() - 1));
return bmgr.equivalence(started, finished);
})
.map(t -> bmgr.equivalence(context.execution(t.getEntry()), context.execution(t.getExit())))
.toList());
}

Expand Down Expand Up @@ -448,13 +444,9 @@ public TrackableFormula encodeDeadlocks() {
final Map<Thread, List<SpinIteration>> spinloopsMap =
Maps.toMap(program.getThreads(), t -> this.findSpinLoopsInThread(t, loopAnalysis));
// Compute "stuckness" encoding for all threads
final Map<Thread, BooleanFormula> isStuckMap = Maps.toMap(program.getThreads(), t -> {
List<BooleanFormula> stuckAtBarrier = t.getEvents().stream()
.filter(ControlBarrier.class::isInstance)
.map(e -> bmgr.and(context.controlFlow(e), bmgr.not(context.execution(e))))
.toList();
return bmgr.or(bmgr.or(stuckAtBarrier), this.generateStucknessEncoding(spinloopsMap.get(t), context));
});
final Map<Thread, BooleanFormula> isStuckMap = Maps.toMap(program.getThreads(), t ->
bmgr.or(generateBarrierStucknessEncoding(t, context),
this.generateSpinloopStucknessEncoding(spinloopsMap.get(t), context)));

// Deadlock <=> allStuckOrDone /\ atLeastOneStuck
BooleanFormula allStuckOrDone = bmgr.makeTrue();
Expand All @@ -476,9 +468,17 @@ public TrackableFormula encodeDeadlocks() {
return new TrackableFormula(bmgr.not(LIVENESS.getSMTVariable(context)), hasDeadlock);
}

private BooleanFormula generateBarrierStucknessEncoding(Thread thread, EncodingContext context) {
final BooleanFormulaManager bmgr = context.getBooleanFormulaManager();
return bmgr.or(thread.getEvents().stream()
.filter(ControlBarrier.class::isInstance)
.map(e -> bmgr.and(context.controlFlow(e), bmgr.not(context.execution(e))))
.toList());
}

// Compute "stuckness": A thread is stuck if it reaches a spin loop bound event
// while only reading from co-maximal stores.
private BooleanFormula generateStucknessEncoding(List<SpinIteration> loops, EncodingContext context) {
private BooleanFormula generateSpinloopStucknessEncoding(List<SpinIteration> loops, EncodingContext context) {
final BooleanFormulaManager bmgr = context.getBooleanFormulaManager();
final RelationAnalysis ra = PropertyEncoder.this.ra;
final Relation rf = memoryModel.getRelation(RelationNameRepository.RF);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

public class ScopedPointerType extends IntegerType {

private static final int ARCH_SIZE = TypeFactory.getInstance().getArchType().getBitWidth();

private final String scopeId;
private final Type pointedType;

ScopedPointerType(String scopeId, Type pointedType) {
super(64);
super(ARCH_SIZE);
this.scopeId = scopeId;
this.pointedType = pointedType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class VisitorSpirvOutput extends SpirvBaseVisitor<Expression> {
private final Map<Location, Type> locationTypes = new HashMap<>();
private final ProgramBuilder builder;
private Program.SpecificationType type;
private Expression assertion;
private Expression condition;

public VisitorSpirvOutput(ProgramBuilder builder) {
this.builder = builder;
Expand All @@ -48,9 +48,9 @@ public Expression visitSpvHeaders(SpirvParser.SpvHeadersContext ctx) {
}
if (type == null) {
type = FORALL;
assertion = ExpressionFactory.getInstance().makeTrue();
condition = ExpressionFactory.getInstance().makeTrue();
}
builder.setSpecification(type, assertion);
builder.setSpecification(type, condition);
return null;
}

Expand Down Expand Up @@ -109,14 +109,14 @@ public Expression visitAssertionValue(SpirvParser.AssertionValueContext ctx) {
}

private void appendAssertion(Program.SpecificationType newType, Expression expression) {
if (assertion == null) {
if (condition == null) {
type = newType;
assertion = expression;
condition = expression;
} else if (newType.equals(type)) {
if (type.equals(FORALL)) {
assertion = ExpressionFactory.getInstance().makeAnd(assertion, expression);
condition = ExpressionFactory.getInstance().makeAnd(condition, expression);
} else if (type.equals(NOT_EXISTS)) {
assertion = ExpressionFactory.getInstance().makeOr(assertion, expression);
condition = ExpressionFactory.getInstance().makeOr(condition, expression);
} else {
throw new ParsingException("Multiline assertion is not supported for type " + newType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class ControlFlowBuilder {
protected final Map<String, Label> blockLabels = new HashMap<>();
protected final Map<String, Event> lastBlockEvents = new HashMap<>();
protected final Map<String, String> mergeLabelIds = new HashMap<>();
protected final Deque<String> blockStack = new ArrayDeque<>();
protected final Deque<String> blockStack = new ArrayDeque<>();
protected final Map<String, Map<Register, String>> phiDefinitions = new HashMap<>();
protected final Map<String, Expression> expressions;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,11 @@ public void setEntryPointId(String id) {
entryPointId = id;
}

public void setSpecification(Program.SpecificationType type, Expression assertion) {
public void setSpecification(Program.SpecificationType type, Expression condition) {
if (program.getSpecification() != null) {
throw new ParsingException("Attempt to override program specification");
}
program.setSpecification(type, assertion);
program.setSpecification(type, condition);
}

public boolean hasInput(String id) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ private Expression makeScalar(String id, Type type, int x) {

private IntegerType getArrayElementType(String id, Type type) {
if (type instanceof ArrayType aType && aType.getNumElements() == 3) {
return getIntegerType(id, ((ArrayType) type).getElementType());
return getIntegerType(id, aType.getElementType());
}
throw new ParsingException("Illegal type of element '%s', " +
"expected array of three elements but received '%s'", id, type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public VisitorExtensionClspvReflection(ProgramBuilder builder) {

@Override
public Void visitKernel(SpirvParser.KernelContext ctx) {
// Do nothing, kernel name and the number of argument
// Do nothing, kernel name and the number of arguments
return null;
}

Expand All @@ -59,7 +59,7 @@ public Void visitArgumentWorkgroup(SpirvParser.ArgumentWorkgroupContext ctx) {

@Override
public Void visitSpecConstantWorkgroupSize(SpirvParser.SpecConstantWorkgroupSizeContext ctx) {
// Do nothing, will be overwritten but BuiltIn WorkgroupSize
// Do nothing, will be overwritten by BuiltIn WorkgroupSize
return null;
}

Expand Down Expand Up @@ -123,7 +123,7 @@ private Void setPushConstantValue(String decorationId, String sizeId) {
if (type instanceof ArrayType aType && aType.getNumElements() == 3 && typeSize == expectedSize) {
Type elType = aType.getElementType();
if (elType instanceof IntegerType iType) {
List<Integer> values = getPushConstantValue(decorationId);
List<Integer> values = computePushConstantValue(decorationId);
int localOffset = 0;
for (int value : values) {
Expression elExpr = expressions.makeValue(value, iType);
Expand All @@ -139,7 +139,7 @@ private Void setPushConstantValue(String decorationId, String sizeId) {
pushConstant.getId(), pushConstantIndex);
}

private List<Integer> getPushConstantValue(String command) {
private List<Integer> computePushConstantValue(String command) {
ThreadGrid grid = builder.getThreadGrid();
return switch (command) {
case "PushConstantGlobalSize" -> List.of(grid.dvSize(), 1, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,11 @@

public class HelperTags {
private static final List<String> scopes = mkScopesList();
private static final Set<String> moStrong = mkStrongMemoryOrderSet();
private static final Map<Integer, String> semantics = mkSemanticsMap();

private HelperTags() {
}

public static boolean isMemorySemanticsNone(String id, Expression expr) {
return getIntValue(id, expr) == 0;
}

public static Set<String> parseMemorySemanticsTags(String id, Expression expr) {
int value = getIntValue(id, expr);
Set<String> tags = new HashSet<>();
Expand All @@ -32,7 +27,7 @@ public static Set<String> parseMemorySemanticsTags(String id, Expression expr) {
tags.add(semantics.get(i));
}
}
int moSize = Sets.intersection(moStrong, tags).size();
int moSize = Sets.intersection(moTags, tags).size();
if (moSize > 1) {
throw new ParsingException("Selected multiple non-relaxed memory order bits");
}
Expand Down Expand Up @@ -91,15 +86,6 @@ private static List<String> mkScopesList() {
SHADER_CALL);
}

private static Set<String> mkStrongMemoryOrderSet() {
return Set.of(
ACQUIRE,
RELEASE,
ACQ_REL,
SEQ_CST
);
}

private static Map<Integer, String> mkSemanticsMap() {
Map<Integer, String> map = new HashMap<>();
map.put(0x2, ACQUIRE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ public ThreadGrid(int sg, int wg, int qf, int dv) {
if (elements.stream().anyMatch(i -> i <= 0)) {
throw new ParsingException("Thread grid dimensions must be positive");
}
if (elements.stream().reduce(1, (a, b) -> a * b) > 128) {
throw new ParsingException("Thread grid dimensions must be less than 128");
}
this.sg = sg;
this.wg = wg;
this.qf = qf;
Expand Down
Loading

0 comments on commit d59c8e1

Please sign in to comment.