Skip to content

Commit

Permalink
Cleanup vulkan compilation visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
natgavrilenko committed Nov 19, 2024
1 parent 9160df5 commit fa2d527
Show file tree
Hide file tree
Showing 11 changed files with 252 additions and 105 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,14 @@ public static List<String> getScopeTags() {
public static String loadMO(String mo) {
return switch (mo) {
case ACQ_REL, ACQUIRE -> ACQUIRE;
default -> "";
default -> ATOM;
};
}

public static String storeMO(String mo) {
return switch (mo) {
case ACQ_REL, RELEASE -> RELEASE;
default -> "";
default -> ATOM;
};
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public List<Event> visitSpirvXchg(SpirvXchg e) {
VulkanRMW rmw = EventFactory.Vulkan.newRMW(e.getAddress(), e.getResultRegister(),
e.getValue(), mo, scope);
rmw.addTags(toVulkanTags(e.getTags()));
replaceAcqRelTag(rmw, Tag.Vulkan.ACQUIRE, Tag.Vulkan.RELEASE);
rmw.setFunction(e.getFunction());
return visitVulkanRMW(rmw);
}
Expand All @@ -82,6 +83,7 @@ public List<Event> visitSpirvRMW(SpirvRmw e) {
e.getOperand(), e.getOperator(), mo, scope);
rmwOp.setFunction(e.getFunction());
rmwOp.addTags(toVulkanTags(e.getTags()));
replaceAcqRelTag(rmwOp, Tag.Vulkan.ACQUIRE, Tag.Vulkan.RELEASE);
return visitVulkanRMWOp(rmwOp);
}

Expand All @@ -105,6 +107,7 @@ public List<Event> visitSpirvCmpXchg(SpirvCmpXchg e) {
e.getExpectedValue(), e.getStoreValue(), moToVulkanTag(spvMoEq), scope);
cmpXchg.setFunction(e.getFunction());
cmpXchg.addTags(toVulkanTags(eqTags));
replaceAcqRelTag(cmpXchg, Tag.Vulkan.ACQUIRE, Tag.Vulkan.RELEASE);

return visitVulkanCmpXchg(cmpXchg);
}
Expand All @@ -118,6 +121,7 @@ public List<Event> visitSpirvRmwExtremum(SpirvRmwExtremum e) {
e.getOperator(), e.getValue(), mo, scope);
rmw.setFunction(e.getFunction());
rmw.addTags(toVulkanTags(e.getTags()));
replaceAcqRelTag(rmw, Tag.Vulkan.ACQUIRE, Tag.Vulkan.RELEASE);
return visitVulkanRMWExtremum(rmw);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,29 @@
import com.dat3m.dartagnan.program.event.core.*;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static com.dat3m.dartagnan.program.event.EventFactory.*;

public class VisitorVulkan extends VisitorBase {

private static final Set<String> commonTags = Set.of(
Tag.Vulkan.SUB_GROUP, Tag.Vulkan.WORK_GROUP,
Tag.Vulkan.QUEUE_FAMILY, Tag.Vulkan.DEVICE,
Tag.Vulkan.NON_PRIVATE, Tag.Vulkan.ATOM,
Tag.Vulkan.SC0, Tag.Vulkan.SC1);

private static final Set<String> semScTags = Set.of(Tag.Vulkan.SEMSC0, Tag.Vulkan.SEMSC1);

private static final Set<String> loadTags = Set.of(
Tag.Vulkan.ACQUIRE, Tag.Vulkan.VISIBLE, Tag.Vulkan.SEM_VISIBLE);

private static final Set<String> storeTags = Set.of(
Tag.Vulkan.RELEASE, Tag.Vulkan.AVAILABLE, Tag.Vulkan.SEM_AVAILABLE);

private int labelIdx = 0;

@Override
public List<Event> visitVulkanRMW(VulkanRMW e) {
Register resultRegister = e.getResultRegister();
Expand All @@ -24,8 +42,8 @@ public List<Event> visitVulkanRMW(VulkanRMW e) {
Register dummy = e.getFunction().newRegister(resultRegister.getType());
Load load = newRMWLoadWithMo(dummy, address, Tag.Vulkan.loadMO(mo));
RMWStore store = newRMWStoreWithMo(load, address, e.getValue(), Tag.Vulkan.storeMO(mo));
this.propagateTags(e, load);
this.propagateTags(e, store);
propagateLoadTags(e, load);
propagateStoreTags(e, store);
return eventSequence(
load,
store,
Expand All @@ -42,8 +60,8 @@ public List<Event> visitVulkanRMWOp(VulkanRMWOp e) {
Load load = newRMWLoadWithMo(dummy, address, Tag.Vulkan.loadMO(mo));
RMWStore store = newRMWStoreWithMo(load, address,
expressions.makeIntBinary(dummy, e.getOperator(), e.getOperand()), Tag.Vulkan.storeMO(mo));
this.propagateTags(e, load);
this.propagateTags(e, store);
propagateLoadTags(e, load);
propagateStoreTags(e, store);
return eventSequence(
load,
store,
Expand All @@ -61,8 +79,8 @@ public List<Event> visitVulkanRMWExtremum(VulkanRMWExtremum e) {
Expression cmpExpr = expressions.makeIntCmp(dummy, e.getOperator(), e.getValue());
Expression ite = expressions.makeITE(cmpExpr, dummy, e.getValue());
RMWStore store = newRMWStoreWithMo(load, address, ite, Tag.Vulkan.storeMO(mo));
this.propagateTags(e, load);
this.propagateTags(e, store);
propagateLoadTags(e, load);
propagateStoreTags(e, store);
return eventSequence(
load,
store,
Expand All @@ -78,13 +96,13 @@ public List<Event> visitVulkanCmpXchg(VulkanCmpXchg e) {
Expression expected = e.getExpectedValue();
Expression value = e.getStoreValue();
Register cmpResultRegister = e.getFunction().newRegister(types.getBooleanType());
Label casEnd = newLabel("CAS_end");
Label casEnd = newLabel("CAS_end_" + labelIdx++);
Load load = newRMWLoadWithMo(resultRegister, address, Tag.Vulkan.loadMO(mo));
RMWStore store = newRMWStoreWithMo(load, address, value, Tag.Vulkan.storeMO(mo));
Local local = newLocal(cmpResultRegister, expressions.makeEQ(resultRegister, expected));
CondJump condJump = newJumpUnless(cmpResultRegister, casEnd);
this.propagateTags(e, load);
this.propagateTags(e, store);
propagateLoadTags(e, load);
propagateStoreTags(e, store);
return eventSequence(
load,
local,
Expand All @@ -94,45 +112,19 @@ public List<Event> visitVulkanCmpXchg(VulkanCmpXchg e) {
);
}

private void propagateTags(Event source, Event target) {
for (String tag : List.of(Tag.Vulkan.SUB_GROUP, Tag.Vulkan.WORK_GROUP, Tag.Vulkan.QUEUE_FAMILY, Tag.Vulkan.DEVICE,
Tag.Vulkan.NON_PRIVATE, Tag.Vulkan.ATOM, Tag.Vulkan.SC0, Tag.Vulkan.SC1, Tag.Vulkan.SEMSC0, Tag.Vulkan.SEMSC1)) {
if (source.hasTag(tag)) {
target.addTags(tag);
}
}
if (target instanceof Load) {
// Atomic loads are always visible
if (source.hasTag(Tag.Vulkan.ATOM)) {
target.addTags(Tag.Vulkan.VISIBLE);
}
if (source.hasTag(Tag.Vulkan.SEM_VISIBLE)) {
target.addTags(Tag.Vulkan.SEM_VISIBLE);
}
// Remove tag if it refers to the release write
if (!source.hasTag(Tag.Vulkan.ACQUIRE) && source.hasTag(Tag.Vulkan.RELEASE)) {
target.removeTags(Tag.Vulkan.SEMSC0);
target.removeTags(Tag.Vulkan.SEMSC1);
}
if (source.hasTag(Tag.Vulkan.VISDEVICE)) {
target.addTags(Tag.Vulkan.VISDEVICE);
}
} else if (target instanceof Store) {
// Atomic stores are always available
if (source.hasTag(Tag.Vulkan.ATOM)) {
target.addTags(Tag.Vulkan.AVAILABLE);
}
if (source.hasTag(Tag.Vulkan.SEM_AVAILABLE)) {
target.addTags(Tag.Vulkan.SEM_AVAILABLE);
}
// Remove tag if it refers to the acquire read
if (!source.hasTag(Tag.Vulkan.RELEASE) && source.hasTag(Tag.Vulkan.ACQUIRE)) {
target.removeTags(Tag.Vulkan.SEMSC0);
target.removeTags(Tag.Vulkan.SEMSC1);
}
if (source.hasTag(Tag.Vulkan.AVDEVICE)) {
target.addTags(Tag.Vulkan.AVDEVICE);
}
}
private void propagateLoadTags(Event source, Event target) {
boolean isAcq = source.hasTag(Tag.Vulkan.ACQUIRE);
Set<String> tags = source.getTags().stream()
.filter(t -> commonTags.contains(t) || loadTags.contains(t) || isAcq && semScTags.contains(t))
.collect(Collectors.toSet());
target.addTags(tags);
}

private void propagateStoreTags(Event source, Event target) {
boolean isRel = source.hasTag(Tag.Vulkan.RELEASE);
Set<String> tags = source.getTags().stream()
.filter(t -> commonTags.contains(t) || storeTags.contains(t) || isRel && semScTags.contains(t))
.collect(Collectors.toSet());
target.addTags(tags);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package com.dat3m.dartagnan.spirv.basic;

import com.dat3m.dartagnan.configuration.Arch;
import com.dat3m.dartagnan.encoding.ProverWithTracker;
import com.dat3m.dartagnan.parsers.cat.ParserCat;
import com.dat3m.dartagnan.parsers.program.ProgramParser;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.utils.Result;
import com.dat3m.dartagnan.verification.VerificationTask;
import com.dat3m.dartagnan.verification.solving.AssumeSolver;
import com.dat3m.dartagnan.verification.solving.RefinementSolver;
import com.dat3m.dartagnan.wmm.Wmm;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.sosy_lab.common.ShutdownManager;
import org.sosy_lab.common.configuration.Configuration;
import org.sosy_lab.common.configuration.InvalidConfigurationException;
import org.sosy_lab.common.log.BasicLogManager;
import org.sosy_lab.java_smt.SolverContextFactory;
import org.sosy_lab.java_smt.api.SolverContext;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;

import static com.dat3m.dartagnan.configuration.Property.CAT_SPEC;
import static com.dat3m.dartagnan.configuration.Property.PROGRAM_SPEC;
import static com.dat3m.dartagnan.utils.ResourceHelper.getRootPath;
import static com.dat3m.dartagnan.utils.ResourceHelper.getTestResourcePath;
import static com.dat3m.dartagnan.utils.Result.*;
import static org.junit.Assert.assertEquals;

@RunWith(Parameterized.class)
public class SpirvChecksTest {

private final String modelPath = getRootPath("cat/spirv-check.cat");
private final String programPath;
private final int bound;
private final Result expected;

public SpirvChecksTest(String file, int bound, Result expected) {
this.programPath = getTestResourcePath("spirv/basic/" + file);
this.bound = bound;
this.expected = expected;
}

@Parameterized.Parameters(name = "{index}: {0}, {1}, {2}")
public static Iterable<Object[]> data() throws IOException {
return Arrays.asList(new Object[][]{
{"empty-exists-false.spv.dis", 1, PASS},
{"empty-exists-true.spv.dis", 1, PASS},
{"empty-forall-false.spv.dis", 1, PASS},
{"empty-forall-true.spv.dis", 1, PASS},
{"empty-not-exists-false.spv.dis", 1, PASS},
{"empty-not-exists-true.spv.dis", 1, PASS},
{"init-forall.spv.dis", 1, PASS},
{"init-forall-split.spv.dis", 1, PASS},
{"init-forall-not-exists.spv.dis", 1, PASS},
{"init-forall-not-exists-fail.spv.dis", 1, PASS},
{"uninitialized-exists.spv.dis", 1, PASS},
{"uninitialized-forall.spv.dis", 1, PASS},
{"uninitialized-private-exists.spv.dis", 1, PASS},
{"uninitialized-private-forall.spv.dis", 1, PASS},
{"undef-exists.spv.dis", 1, PASS},
{"undef-forall.spv.dis", 1, PASS},
{"read-write.spv.dis", 1, PASS},
{"vector-init.spv.dis", 1, PASS},
{"vector.spv.dis", 1, PASS},
{"array.spv.dis", 1, PASS},
{"array-of-vector.spv.dis", 1, PASS},
{"array-of-vector1.spv.dis", 1, PASS},
{"vector-read-write.spv.dis", 1, PASS},
{"spec-id-integer.spv.dis", 1, PASS},
{"spec-id-boolean.spv.dis", 1, PASS},
{"mixed-size.spv.dis", 1, PASS},
{"ids.spv.dis", 1, PASS},
{"builtin-constant.spv.dis", 1, PASS},
{"builtin-variable.spv.dis", 1, PASS},
{"builtin-default-config.spv.dis", 1, PASS},
{"builtin-all-123.spv.dis", 1, PASS},
{"builtin-all-321.spv.dis", 1, PASS},
{"branch-cond-ff.spv.dis", 1, PASS},
{"branch-cond-ff-inverted.spv.dis", 1, PASS},
{"branch-cond-bf.spv.dis", 1, UNKNOWN},
{"branch-cond-bf.spv.dis", 2, PASS},
{"branch-cond-fb.spv.dis", 1, UNKNOWN},
{"branch-cond-fb.spv.dis", 2, PASS},
{"branch-cond-struct.spv.dis", 1, PASS},
{"branch-cond-struct-read-write.spv.dis", 1, PASS},
{"branch-race.spv.dis", 1, PASS},
{"branch-loop.spv.dis", 2, UNKNOWN},
{"branch-loop.spv.dis", 3, PASS},
{"loop-struct-cond.spv.dis", 1, UNKNOWN},
{"loop-struct-cond.spv.dis", 2, PASS},
{"loop-struct-cond-suffix.spv.dis", 1, UNKNOWN},
{"loop-struct-cond-suffix.spv.dis", 2, PASS},
{"loop-struct-cond-sequence.spv.dis", 2, UNKNOWN},
{"loop-struct-cond-sequence.spv.dis", 3, PASS},
{"loop-struct-cond-nested.spv.dis", 2, UNKNOWN},
{"loop-struct-cond-nested.spv.dis", 3, PASS},
{"phi.spv.dis", 1, PASS},
{"phi-unstruct-true.spv.dis", 1, PASS},
{"phi-unstruct-false.spv.dis", 1, PASS},
{"cmpxchg-const-const.spv.dis", 1, PASS},
{"cmpxchg-const-reg.spv.dis", 1, PASS},
{"cmpxchg-reg-const.spv.dis", 1, PASS},
{"cmpxchg-reg-reg.spv.dis", 1, PASS},
{"memory-scopes.spv.dis", 1, PASS},
{"rmw-extremum-true.spv.dis", 1, PASS},
{"rmw-extremum-false.spv.dis", 1, PASS},
{"push-constants.spv.dis", 1, PASS},
{"push-constants-pod.spv.dis", 1, PASS},
{"push-constant-mixed.spv.dis", 1, PASS}
});
}

@Test
public void testAllSolvers() throws Exception {
try (SolverContext ctx = mkCtx(); ProverWithTracker prover = mkProver(ctx)) {
assertEquals(expected, RefinementSolver.run(ctx, prover, mkTask()).getResult());
}
try (SolverContext ctx = mkCtx(); ProverWithTracker prover = mkProver(ctx)) {
assertEquals(expected, AssumeSolver.run(ctx, prover, mkTask()).getResult());
}
}

private SolverContext mkCtx() throws InvalidConfigurationException {
Configuration cfg = Configuration.builder().build();
return SolverContextFactory.createSolverContext(
cfg,
BasicLogManager.create(cfg),
ShutdownManager.create().getNotifier(),
SolverContextFactory.Solvers.Z3);
}

private ProverWithTracker mkProver(SolverContext ctx) {
return new ProverWithTracker(ctx, "", SolverContext.ProverOptions.GENERATE_MODELS);
}

private VerificationTask mkTask() throws Exception {
VerificationTask.VerificationTaskBuilder builder = VerificationTask.builder()
.withConfig(Configuration.builder().build())
.withBound(bound)
.withTarget(Arch.VULKAN);
Program program = new ProgramParser().parse(new File(programPath));
Wmm mcm = new ParserCat().parse(new File(modelPath));
return builder.build(program, mcm, EnumSet.of(CAT_SPEC));
}
}
Loading

0 comments on commit fa2d527

Please sign in to comment.