diff --git a/scripts/ga-infer.sh b/scripts/ga-infer.sh new file mode 100755 index 0000000..9981f7a --- /dev/null +++ b/scripts/ga-infer.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +export MYDIR=`dirname $0` +. ./$MYDIR/setup.sh + +CHECKER=universe.UniverseInferenceChecker + +SOLVER=universe.solver.UniverseGASolverEngine +IS_HACK=true + +$CFI/scripts/inference-dev --checker "$CHECKER" --solver "$SOLVER" --solverArgs="collectStatistics=true,outputCNF=true" --hacks="$IS_HACK" -m ROUNDTRIP -afud ./annotated "$@" diff --git a/scripts/infer.sh b/scripts/infer.sh index e0a85fb..a7d8f0a 100755 --- a/scripts/infer.sh +++ b/scripts/infer.sh @@ -9,4 +9,4 @@ CHECKER=universe.UniverseInferenceChecker SOLVER=universe.solver.UniverseSolverEngine IS_HACK=true -$CFI/scripts/inference-dev --checker "$CHECKER" --solver "$SOLVER" --solverArgs="collectStatistics=true" --hacks="$IS_HACK" -m ROUNDTRIP -afud ./annotated "$@" +$CFI/scripts/inference-dev --checker "$CHECKER" --solver "$SOLVER" --solverArgs="collectStatistics=true,outputCNF=true" --hacks="$IS_HACK" -m ROUNDTRIP -afud ./annotated "$@" diff --git a/src/main/java/universe/UniverseInferenceVisitor.java b/src/main/java/universe/UniverseInferenceVisitor.java index 25bd4c2..22c48fb 100644 --- a/src/main/java/universe/UniverseInferenceVisitor.java +++ b/src/main/java/universe/UniverseInferenceVisitor.java @@ -3,6 +3,7 @@ import static universe.UniverseAnnotationMirrorHolder.ANY; import static universe.UniverseAnnotationMirrorHolder.BOTTOM; import static universe.UniverseAnnotationMirrorHolder.LOST; +import static universe.UniverseAnnotationMirrorHolder.PEER; import static universe.UniverseAnnotationMirrorHolder.REP; import static universe.UniverseAnnotationMirrorHolder.SELF; @@ -92,6 +93,10 @@ public Void visitVariable(VariableTree node, Void p) { InferenceMain.getInstance().getConstraintManager(); ConstantSlot rep = slotManager.createConstantSlot(REP); constraintManager.addPreferenceConstraint((VariableSlot) slot, rep, 80); + ConstantSlot peer = slotManager.createConstantSlot(PEER); + constraintManager.addPreferenceConstraint((VariableSlot) slot, peer, 50); + ConstantSlot any = slotManager.createConstantSlot(ANY); + constraintManager.addPreferenceConstraint((VariableSlot) slot, any, 10); } } return super.visitVariable(node, p); diff --git a/src/main/java/universe/solver/GeneticMaxSatSolverFactory.java b/src/main/java/universe/solver/GeneticMaxSatSolverFactory.java new file mode 100644 index 0000000..8573593 --- /dev/null +++ b/src/main/java/universe/solver/GeneticMaxSatSolverFactory.java @@ -0,0 +1,30 @@ +package universe.solver; + +import checkers.inference.model.Constraint; +import checkers.inference.model.Slot; +import checkers.inference.solver.backend.AbstractSolverFactory; +import checkers.inference.solver.backend.Solver; +import checkers.inference.solver.backend.maxsat.MaxSatFormatTranslator; +import checkers.inference.solver.frontend.Lattice; +import checkers.inference.solver.util.SolverEnvironment; + +import java.util.Collection; + +public class GeneticMaxSatSolverFactory extends AbstractSolverFactory { + + @Override + public MaxSatFormatTranslator createFormatTranslator(Lattice lattice) { + return new UniverseFormatTranslator(lattice); + } + + @Override + public Solver createSolver( + SolverEnvironment solverEnvironment, + Collection slots, + Collection constraints, + Lattice lattice) { + MaxSatFormatTranslator formatTranslator = createFormatTranslator(lattice); + return new UniverseGeneticMaxSatSolver( + solverEnvironment, slots, constraints, formatTranslator, lattice); + } +} diff --git a/src/main/java/universe/solver/UniverseGASolverEngine.java b/src/main/java/universe/solver/UniverseGASolverEngine.java new file mode 100644 index 0000000..5f441d4 --- /dev/null +++ b/src/main/java/universe/solver/UniverseGASolverEngine.java @@ -0,0 +1,11 @@ +package universe.solver; + +import checkers.inference.solver.SolverEngine; +import checkers.inference.solver.backend.SolverFactory; + +public class UniverseGASolverEngine extends SolverEngine { + @Override + protected SolverFactory createSolverFactory() { + return new GeneticMaxSatSolverFactory(); + } +} diff --git a/src/main/java/universe/solver/UniverseGeneticMaxSatSolver.java b/src/main/java/universe/solver/UniverseGeneticMaxSatSolver.java new file mode 100644 index 0000000..7c85d50 --- /dev/null +++ b/src/main/java/universe/solver/UniverseGeneticMaxSatSolver.java @@ -0,0 +1,122 @@ +package universe.solver; + +import static io.jenetics.engine.EvolutionResult.toBestEvolutionResult; +import static io.jenetics.engine.Limits.bySteadyFitness; + +import checkers.inference.model.Constraint; +import checkers.inference.model.Slot; +import checkers.inference.solver.backend.geneticmaxsat.GeneticMaxSatSolver; +import checkers.inference.solver.backend.maxsat.MaxSatFormatTranslator; +import checkers.inference.solver.frontend.Lattice; +import checkers.inference.solver.util.SolverEnvironment; + +import io.jenetics.IntegerGene; +import io.jenetics.MeanAlterer; +import io.jenetics.Mutator; +import io.jenetics.Optimize; +import io.jenetics.RouletteWheelSelector; +import io.jenetics.TournamentSelector; +import io.jenetics.engine.Codecs; +import io.jenetics.engine.Engine; +import io.jenetics.engine.EvolutionResult; +import io.jenetics.engine.EvolutionStatistics; +import io.jenetics.util.IntRange; + +import org.sat4j.maxsat.WeightedMaxSatDecorator; +import org.sat4j.maxsat.reader.WDimacsReader; +import org.sat4j.pb.IPBSolver; +import org.sat4j.reader.ParseFormatException; +import org.sat4j.specs.ContradictionException; +import org.sat4j.specs.TimeoutException; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import javax.lang.model.element.AnnotationMirror; + +public class UniverseGeneticMaxSatSolver extends GeneticMaxSatSolver { + public UniverseGeneticMaxSatSolver( + SolverEnvironment solverEnvironment, + Collection slots, + Collection constraints, + MaxSatFormatTranslator formatTranslator, + Lattice lattice) { + super(solverEnvironment, slots, constraints, formatTranslator, lattice); + } + + /** The fitness function in this case is the count of {@link universe.qual.Rep} */ + public int fitness(final int[] chromosome) { + IPBSolver solver = org.sat4j.maxsat.SolverFactory.newDefault(); + WDimacsReader reader = new WDimacsReader(new WeightedMaxSatDecorator(solver)); + Map solutions; + int fitness_count = 0; + + String WCNFModInput = changeSoftWeights(chromosome, this.wcnfFileContent, false); + + InputStream stream = + new ByteArrayInputStream(WCNFModInput.getBytes(StandardCharsets.UTF_8)); + + try { + solver = (IPBSolver) reader.parseInstance(stream); + } catch (ContradictionException | IOException | ParseFormatException e) { + System.out.println(e); + } + + try { + if (solver.isSatisfiable()) { + solutions = decode(solver.model()); + + List sol = new ArrayList<>(solutions.values()); + + for (AnnotationMirror sol_0 : sol) { + if (sol_0.toString().equals("@universe.qual.Rep")) { + fitness_count += 1; + } + } + } else { + System.out.println("UNSAT at " + chromosome[0]); + } + } catch (TimeoutException e) { + e.printStackTrace(); + } + + // System.out.println("Rep count: " + fitness_count); + + return fitness_count; + } + + @Override + public void fit() { + final Engine engine = + Engine.builder( + this::fitness, + Codecs.ofVector(IntRange.of(0, 700), this.allSoftWeightsCount)) + .populationSize(500) + .offspringFraction(0.7) + .survivorsSelector(new RouletteWheelSelector<>()) + .offspringSelector(new TournamentSelector<>()) + .optimize(Optimize.MAXIMUM) + .alterers(new Mutator<>(0.03), new MeanAlterer<>(0.6)) + .build(); + + final EvolutionStatistics statistics = EvolutionStatistics.ofNumber(); + + final EvolutionResult best_res = + engine.stream() + .limit(bySteadyFitness(7)) + .limit(100) + .peek(statistics) + .collect(toBestEvolutionResult()); + + System.out.println(statistics); + System.out.println("Genotype length: " + best_res.genotypes().length()); + System.out.println("Best Phenotype: " + best_res.bestPhenotype()); + System.out.println("Worst Phenotype: " + best_res.worstPhenotype()); + } +}