Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unsoundness of symmetry learning #555

Merged
merged 2 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.dat3m.dartagnan;

import com.dat3m.dartagnan.solver.caat4wmm.Refiner;
import com.dat3m.dartagnan.solver.caat4wmm.coreReasoning.CoreReasoner;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.sosy_lab.common.configuration.Configuration;
Expand Down Expand Up @@ -35,14 +35,14 @@ public static void configure(Configuration config) throws InvalidConfigurationEx
*/
public static final boolean REFINEMENT_GENERATE_GRAPHVIZ_DEBUG_FILES = false;

public static final Refiner.SymmetryLearning REFINEMENT_SYMMETRY_LEARNING = Refiner.SymmetryLearning.FULL;
public static final CoreReasoner.SymmetricLearning REFINEMENT_SYMMETRIC_LEARNING = CoreReasoner.SymmetricLearning.FULL;

// --------------------

public static void LogGlobalSettings() {
// Refinement settings
logger.info("REFINEMENT_GENERATE_GRAPHVIZ_DEBUG_FILES: " + REFINEMENT_GENERATE_GRAPHVIZ_DEBUG_FILES);
logger.info("REFINEMENT_SYMMETRY_LEARNING: " + REFINEMENT_SYMMETRY_LEARNING.name());
logger.info("REFINEMENT_SYMMETRIC_LEARNING: " + REFINEMENT_SYMMETRIC_LEARNING.name());
}

public static String getOrCreateOutputDirectory() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,135 +1,61 @@
package com.dat3m.dartagnan.solver.caat4wmm;

import com.dat3m.dartagnan.encoding.EncodingContext;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.analysis.ThreadSymmetry;
import com.dat3m.dartagnan.program.event.core.Event;
import com.dat3m.dartagnan.program.event.core.MemoryCoreEvent;
import com.dat3m.dartagnan.solver.caat4wmm.coreReasoning.AddressLiteral;
import com.dat3m.dartagnan.solver.caat4wmm.coreReasoning.CoreLiteral;
import com.dat3m.dartagnan.solver.caat4wmm.coreReasoning.ExecLiteral;
import com.dat3m.dartagnan.solver.caat4wmm.coreReasoning.RelLiteral;
import com.dat3m.dartagnan.utils.equivalence.EquivalenceClass;
import com.dat3m.dartagnan.utils.logic.Conjunction;
import com.dat3m.dartagnan.utils.logic.DNF;
import com.dat3m.dartagnan.verification.Context;
import com.dat3m.dartagnan.wmm.Relation;
import org.sosy_lab.java_smt.api.BooleanFormula;
import org.sosy_lab.java_smt.api.BooleanFormulaManager;

import java.util.*;
import java.util.function.Function;

import static com.dat3m.dartagnan.GlobalSettings.REFINEMENT_SYMMETRY_LEARNING;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

/*
This class handles the computation of refinement clauses from violations found by the WMM-solver procedure.
Furthermore, it incorporates symmetry reasoning if possible.
*/
public class Refiner {

public enum SymmetryLearning { NONE, LINEAR, QUADRATIC, FULL }

private final ThreadSymmetry symm;
private final List<Function<Event, Event>> symmPermutations;
private final SymmetryLearning learningOption;

public Refiner(Context analysisContext) {
this.learningOption = REFINEMENT_SYMMETRY_LEARNING;
symm = analysisContext.requires(ThreadSymmetry.class);
symmPermutations = computeSymmetryPermutations();
}

public Refiner() { }

// This method computes a refinement clause from a set of violations.
// Furthermore, it computes symmetric violations if symmetry learning is enabled.
public BooleanFormula refine(DNF<CoreLiteral> coreReasons, EncodingContext context) {
//TODO: A specialized algorithm that computes the orbit under permutation may be better,
// since most violations involve only few threads and hence the orbit is far smaller than the full
// set of permutations.
HashSet<BooleanFormula> addedFormulas = new HashSet<>(); // To avoid adding duplicates
BooleanFormulaManager bmgr = context.getBooleanFormulaManager();
final BooleanFormulaManager bmgr = context.getBooleanFormulaManager();
List<BooleanFormula> refinement = new ArrayList<>();
HashSet<BooleanFormula> addedFormulas = new HashSet<>(); // To avoid adding duplicates
// For each symmetry permutation, we will create refinement clauses
for (Function<Event, Event> perm : symmPermutations) {
for (Conjunction<CoreLiteral> reason : coreReasons.getCubes()) {
BooleanFormula permutedClause = bmgr.makeFalse();
for (CoreLiteral lit : reason.getLiterals()) {
BooleanFormula litFormula = permuteAndConvert(lit, perm, context);
if (bmgr.isFalse(litFormula)) {
permutedClause = bmgr.makeTrue();
break;
} else {
permutedClause = bmgr.or(permutedClause, bmgr.not(litFormula));
}
}
if (addedFormulas.add(permutedClause)) {
refinement.add(permutedClause);
for (Conjunction<CoreLiteral> reason : coreReasons.getCubes()) {
BooleanFormula permutedClause = bmgr.makeFalse();
for (CoreLiteral lit : reason.getLiterals()) {
final BooleanFormula litFormula = encode(lit, context);
if (bmgr.isFalse(litFormula)) {
permutedClause = bmgr.makeTrue();
break;
} else {
permutedClause = bmgr.or(permutedClause, bmgr.not(litFormula));
}
}
}
return bmgr.and(refinement);
}

// Computes a list of permutations allowed by the program.
// Depending on the <learningOption>, the set of computed permutations differs.
// In particular, for the option NONE, only the identity permutation will be returned.
private List<Function<Event, Event>> computeSymmetryPermutations() {
Set<? extends EquivalenceClass<Thread>> symmClasses = symm.getNonTrivialClasses();
List<Function<Event, Event>> perms = new ArrayList<>();
perms.add(Function.identity());

for (EquivalenceClass<Thread> c : symmClasses) {
List<Thread> threads = new ArrayList<>(c);
threads.sort(Comparator.comparingInt(Thread::getId));

switch (learningOption) {
case NONE:
break;
case LINEAR:
for (int i = 0; i < threads.size(); i++) {
int j = (i + 1) % threads.size();
perms.add(symm.createEventTransposition(threads.get(i), threads.get(j)));
}
break;
case QUADRATIC:
for (int i = 0; i < threads.size(); i++) {
for (int j = i + 1; j < threads.size(); j++) {
perms.add(symm.createEventTransposition(threads.get(i), threads.get(j)));
}
}
break;
case FULL:
List<Function<Event, Event>> allPerms = symm.createAllEventPermutations(c);
allPerms.remove(Function.identity()); // We avoid adding multiple identities
perms.addAll(allPerms);
break;
default:
throw new UnsupportedOperationException("Symmetry learning option: "
+ learningOption + " is not recognized.");
if (addedFormulas.add(permutedClause)) {
refinement.add(permutedClause);
}
}

return perms;
return bmgr.and(refinement);
}


// Changes a reasoning <literal> based on a given permutation <perm> and translates the result
// into a BooleanFormula for Refinement.
private BooleanFormula permuteAndConvert(CoreLiteral literal, Function<Event, Event> perm, EncodingContext encoder) {
BooleanFormulaManager bmgr = encoder.getBooleanFormulaManager();
BooleanFormula enc;
private BooleanFormula encode(CoreLiteral literal, EncodingContext encoder) {
final BooleanFormulaManager bmgr = encoder.getBooleanFormulaManager();
final BooleanFormula enc;
if (literal instanceof ExecLiteral lit) {
enc = encoder.execution(perm.apply(lit.getData()));
enc = encoder.execution(lit.getData());
} else if (literal instanceof AddressLiteral loc) {
MemoryCoreEvent e1 = (MemoryCoreEvent) perm.apply(loc.getFirst());
MemoryCoreEvent e2 = (MemoryCoreEvent) perm.apply(loc.getSecond());
enc = encoder.sameAddress(e1, e2);
enc = encoder.sameAddress((MemoryCoreEvent) loc.getFirst(), (MemoryCoreEvent) loc.getSecond());
} else if (literal instanceof RelLiteral lit) {
Relation rel = encoder.getTask().getMemoryModel().getRelation(lit.getName());
enc = encoder.edge(rel,
perm.apply(lit.getData().first()),
perm.apply(lit.getData().second()));
final Relation rel = encoder.getTask().getMemoryModel().getRelation(lit.getName());
enc = encoder.edge(rel, lit.getData().first(), lit.getData().second());
} else {
throw new IllegalArgumentException("CoreLiteral " + literal.toString() + " is not supported");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public Result check(Model model) {
curTime = System.currentTimeMillis();
List<Conjunction<CoreLiteral>> coreReasons = new ArrayList<>(caatResult.getBaseReasons().getNumberOfCubes());
for (Conjunction<CAATLiteral> baseReason : caatResult.getBaseReasons().getCubes()) {
coreReasons.add(reasoner.toCoreReason(baseReason));
coreReasons.addAll(reasoner.toCoreReasons(baseReason));
}
stats.numComputedCoreReasons = coreReasons.size();
result.coreReasons = new DNF<>(coreReasons);
Expand Down
Loading