From 529d9ec5dab4002adbf506f2cc3b1c59e3c50c5f Mon Sep 17 00:00:00 2001
From: Hernan Ponce de Leon <zeta96@gmail.com>
Date: Sun, 27 Oct 2024 12:50:30 +0100
Subject: [PATCH] Allow bounds to be saved to and loaded from files  (#759)

Co-authored-by: Thomas Haas <tomy.haas@t-online.de>
---
 dartagnan/pom.xml                             |   7 +-
 .../java/com/dat3m/dartagnan/Dartagnan.java   |  70 ++++++++--
 .../dartagnan/configuration/OptionNames.java  |   2 +
 .../event/metadata/UnrollingBound.java        |   3 +
 .../program/processing/LoopUnrolling.java     | 132 +++++++++++++++++-
 pom.xml                                       |   2 +
 6 files changed, 201 insertions(+), 15 deletions(-)
 create mode 100644 dartagnan/src/main/java/com/dat3m/dartagnan/program/event/metadata/UnrollingBound.java

diff --git a/dartagnan/pom.xml b/dartagnan/pom.xml
index 50aa8ad175..63b113485b 100644
--- a/dartagnan/pom.xml
+++ b/dartagnan/pom.xml
@@ -45,7 +45,12 @@
         <dependency>
             <groupId>org.apache.maven</groupId>
             <artifactId>maven-model</artifactId>
-            <version>3.3.9</version>
+            <version>${maven-model.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-csv</artifactId>
+            <version>${commons-csv.version}</version>
         </dependency>
 
         <!-- Z3 dependency (OS independent) -->
diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/Dartagnan.java b/dartagnan/src/main/java/com/dat3m/dartagnan/Dartagnan.java
index 780e2cfe73..2bbb4721ef 100644
--- a/dartagnan/src/main/java/com/dat3m/dartagnan/Dartagnan.java
+++ b/dartagnan/src/main/java/com/dat3m/dartagnan/Dartagnan.java
@@ -16,6 +16,7 @@
 import com.dat3m.dartagnan.program.event.core.Assert;
 import com.dat3m.dartagnan.program.event.core.CondJump;
 import com.dat3m.dartagnan.program.event.core.Load;
+import com.dat3m.dartagnan.program.processing.LoopUnrolling;
 import com.dat3m.dartagnan.utils.Result;
 import com.dat3m.dartagnan.utils.Utils;
 import com.dat3m.dartagnan.utils.options.BaseOptions;
@@ -34,6 +35,10 @@
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.io.CharSource;
+import org.apache.commons.csv.CSVFormat;
+import org.apache.commons.csv.CSVParser;
+import org.apache.commons.csv.CSVPrinter;
+import org.apache.commons.csv.CSVRecord;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.maven.model.io.xpp3.MavenXpp3Reader;
@@ -51,6 +56,7 @@
 
 import java.io.File;
 import java.io.FileReader;
+import java.io.FileWriter;
 import java.io.IOException;
 import java.math.BigInteger;
 import java.nio.file.Path;
@@ -58,8 +64,7 @@
 
 import static com.dat3m.dartagnan.GlobalSettings.getOrCreateOutputDirectory;
 import static com.dat3m.dartagnan.configuration.OptionInfo.collectOptions;
-import static com.dat3m.dartagnan.configuration.OptionNames.PHANTOM_REFERENCES;
-import static com.dat3m.dartagnan.configuration.OptionNames.TARGET;
+import static com.dat3m.dartagnan.configuration.OptionNames.*;
 import static com.dat3m.dartagnan.configuration.Property.*;
 import static com.dat3m.dartagnan.program.analysis.SyntacticContextAnalysis.*;
 import static com.dat3m.dartagnan.utils.GitInfo.*;
@@ -333,17 +338,26 @@ public static String generateResultSummary(VerificationTask task, ProverEnvironm
                 }
             } else if (result == UNKNOWN && modelChecker.hasModel()) {
                 // We reached unrolling bounds.
-                summary.append("=========== Not fully unrolled loops ============\n");
+                final List<Event> reachedBounds = new ArrayList<>();
                 for (Event ev : p.getThreadEventsWithAllTags(Tag.BOUND)) {
-                    final boolean isReached = TRUE.equals(model.evaluate(encCtx.execution(ev)));
-                    if (isReached) {
-                        summary
-                                .append("\t")
-                                .append(synContext.getSourceLocationWithContext(ev, true))
-                                .append("\n");
+                    if (TRUE.equals(model.evaluate(encCtx.execution(ev)))) {
+                        reachedBounds.add(ev);
                     }
                 }
+                summary.append("=========== Not fully unrolled loops ============\n");
+                for (Event bound : reachedBounds) {
+                    summary
+                            .append("\t")
+                            .append(synContext.getSourceLocationWithContext(bound, true))
+                            .append("\n");
+                }
                 summary.append("=================================================\n");
+
+                try {
+                    increaseBoundAndDump(reachedBounds, task.getConfig());
+                } catch (IOException e) {
+                    logger.warn("Failed to save bounds file: {}", e.getLocalizedMessage());
+                }
             }
             summary.append(result).append("\n");
         } else {
@@ -398,6 +412,44 @@ public static String generateResultSummary(VerificationTask task, ProverEnvironm
         return summary.toString();
     }
 
+    private static void increaseBoundAndDump(List<Event> boundEvents, Configuration config) throws IOException {
+        if(!config.hasProperty(BOUNDS_SAVE_PATH)) {
+            return;
+        }
+        final File boundsFile = new File(config.getProperty(BOUNDS_SAVE_PATH));
+
+        // Parse old entries
+        final List<CSVRecord> entries;
+        try (CSVParser parser = CSVParser.parse(new FileReader(boundsFile), CSVFormat.DEFAULT)) {
+            entries = parser.getRecords();
+        }
+
+        // Compute update for entries
+        final Map<Integer, Integer> loopId2UpdatedBound = new HashMap<>();
+        for (Event e : boundEvents) {
+            assert e instanceof CondJump;
+            final CondJump loopJump = (CondJump) e;
+            final int loopId = LoopUnrolling.getPersistentLoopId(loopJump);
+            final int bound = LoopUnrolling.getUnrollingBoundAnnotation(loopJump);
+            loopId2UpdatedBound.put(loopId, bound + 1);
+        }
+
+        // Write new entries
+        try (CSVPrinter csvPrinter = new CSVPrinter(new FileWriter(boundsFile, false), CSVFormat.DEFAULT)) {
+            for (CSVRecord entry : entries) {
+                final int entryId = Integer.parseInt(entry.get(0));
+                if (!loopId2UpdatedBound.containsKey(entryId)) {
+                    csvPrinter.printRecord(entry);
+                } else {
+                    final String[] content = entry.values();
+                    content[1] = String.valueOf(loopId2UpdatedBound.get(entryId));
+                    csvPrinter.printRecord(Arrays.asList(content));
+                }
+            }
+            csvPrinter.flush();
+        }
+    }
+
     private static void printWarningIfThreadStartFailed(Program p, EncodingContext encoder, ProverEnvironment prover)
             throws SolverException {
         for (Event e : p.getThreadEvents()) {
diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/configuration/OptionNames.java b/dartagnan/src/main/java/com/dat3m/dartagnan/configuration/OptionNames.java
index db9a0a0a52..20e39a44a5 100644
--- a/dartagnan/src/main/java/com/dat3m/dartagnan/configuration/OptionNames.java
+++ b/dartagnan/src/main/java/com/dat3m/dartagnan/configuration/OptionNames.java
@@ -5,6 +5,8 @@ public class OptionNames {
     // Base Options
     public static final String PROPERTY = "property";
     public static final String BOUND = "bound";
+    public static final String BOUNDS_LOAD_PATH = "bound.load";
+    public static final String BOUNDS_SAVE_PATH = "bound.save";
     public static final String TARGET = "target";
     public static final String METHOD = "method";
     public static final String SOLVER = "solver";
diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/event/metadata/UnrollingBound.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/event/metadata/UnrollingBound.java
new file mode 100644
index 0000000000..8c0ba81e82
--- /dev/null
+++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/event/metadata/UnrollingBound.java
@@ -0,0 +1,3 @@
+package com.dat3m.dartagnan.program.event.metadata;
+
+public record UnrollingBound(int value) implements Metadata { }
diff --git a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/LoopUnrolling.java b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/LoopUnrolling.java
index 293f15f489..5ba72b9c4f 100644
--- a/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/LoopUnrolling.java
+++ b/dartagnan/src/main/java/com/dat3m/dartagnan/program/processing/LoopUnrolling.java
@@ -1,9 +1,11 @@
 package com.dat3m.dartagnan.program.processing;
 
+import com.dat3m.dartagnan.GlobalSettings;
 import com.dat3m.dartagnan.expression.ExpressionFactory;
 import com.dat3m.dartagnan.program.Function;
 import com.dat3m.dartagnan.program.Program;
 import com.dat3m.dartagnan.program.Thread;
+import com.dat3m.dartagnan.program.analysis.SyntacticContextAnalysis;
 import com.dat3m.dartagnan.program.event.Event;
 import com.dat3m.dartagnan.program.event.EventFactory;
 import com.dat3m.dartagnan.program.event.EventUser;
@@ -11,15 +13,25 @@
 import com.dat3m.dartagnan.program.event.core.CondJump;
 import com.dat3m.dartagnan.program.event.core.Label;
 import com.dat3m.dartagnan.program.event.lang.svcomp.LoopBound;
+import com.dat3m.dartagnan.program.event.metadata.UnrollingBound;
 import com.dat3m.dartagnan.program.event.metadata.UnrollingId;
 import com.google.common.base.Preconditions;
+import org.apache.commons.csv.CSVFormat;
+import org.apache.commons.csv.CSVPrinter;
+import org.apache.commons.csv.CSVRecord;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.sosy_lab.common.configuration.*;
 
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.Reader;
+import java.nio.file.Files;
+import java.nio.file.Path;
 import java.util.*;
 
-import static com.dat3m.dartagnan.configuration.OptionNames.BOUND;
+import static com.dat3m.dartagnan.configuration.OptionNames.*;
 
 @Options
 public class LoopUnrolling implements ProgramProcessor {
@@ -39,15 +51,33 @@ public class LoopUnrolling implements ProgramProcessor {
     @IntegerOption(min = 1)
     private int bound = 1;
 
-    public int getUnrollingBound() { return bound; }
+    public int getUnrollingBound() {
+        return bound;
+    }
+
     public void setUnrollingBound(int bound) {
         Preconditions.checkArgument(bound >= 1, "The unrolling bound must be positive.");
         this.bound = bound;
     }
 
+    @Option(name = BOUNDS_LOAD_PATH,
+            description = "Path to the CSV file containing loop bounds.",
+            secure = true)
+    private String boundsLoadPath = "";
+
+    @Option(name = BOUNDS_SAVE_PATH,
+            description = "Path to the CSV file to save loop bounds.",
+            secure = true)
+    private String boundsSavePath = "";
+
     // =====================================================================
 
-    private LoopUnrolling() { }
+    // We use this once for loading bounds from files (if any), and then to track
+    // all computed loop bounds for later storing them into a file (if desired).
+    private Map<Function, Map<CondJump, Integer>> globalLoopBoundsMap = new HashMap<>();
+
+    private LoopUnrolling() {
+    }
 
     private LoopUnrolling(Configuration config) throws InvalidConfigurationException {
         this();
@@ -69,16 +99,21 @@ public void run(Program program) {
             logger.warn("Skipped unrolling: Program is already unrolled.");
             return;
         }
+
+        globalLoopBoundsMap = loadLoopBoundsMapFromFile(program, boundsLoadPath);
+
         final int defaultBound = this.bound;
         program.getFunctions().forEach(this::run);
         program.getThreads().forEach(this::run);
         program.markAsUnrolled(defaultBound);
         IdReassignment.newInstance().run(program); // Reassign ids because of newly created events
 
+        dumpLoopBoundsMapToFile(program, globalLoopBoundsMap, boundsSavePath);
+        globalLoopBoundsMap = null; // Save up some memory
+
         logger.info("Program unrolled {} times", defaultBound);
     }
 
-
     private void run(Function function) {
         function.getEvents().forEach(e -> e.setMetadata(new UnrollingId(e.getGlobalId()))); // Track ids before unrolling
         unrollLoopsInFunction(function, bound);
@@ -95,7 +130,6 @@ private void unrollLoopsInFunction(Function func, int defaultBound) {
     }
 
     private Map<CondJump, Integer> computeLoopBoundsMap(Function func, int defaultBound) {
-
         LoopBound curBoundAnnotation = null;
         final Map<CondJump, Integer> loopBoundsMap = new HashMap<>();
         for (Event event : func.getEvents()) {
@@ -119,6 +153,15 @@ private Map<CondJump, Integer> computeLoopBoundsMap(Function func, int defaultBo
                 }
             }
         }
+
+        // Merge with loaded bounds if those exist.
+        if(globalLoopBoundsMap.containsKey(func)) {
+            final Map<CondJump, Integer> loopBoundsMapFromFile = globalLoopBoundsMap.get(func);
+            loopBoundsMapFromFile.forEach((key, value) -> loopBoundsMap.merge(key, value, Math::max));
+        }
+        // Remember bounds for function for dumping.
+        globalLoopBoundsMap.put(func, loopBoundsMap);
+
         return loopBoundsMap;
     }
 
@@ -146,6 +189,7 @@ private void unrollLoop(CondJump loopBackJump, int bound) {
                 boundEvent.getPredecessor().insertAfter(endOfLoopMarker);
 
                 boundEvent.copyAllMetadataFrom(loopBackJump);
+                boundEvent.setMetadata(new UnrollingBound(bound));
                 endOfLoopMarker.copyAllMetadataFrom(loopBackJump);
 
             } else {
@@ -195,4 +239,82 @@ private Event newBoundEvent(Function func) {
         return boundEvent;
     }
 
+    // ------------------------------------------------------------------------
+    // Functions related to loading and storing bound maps
+
+    private boolean pathIsSpecified(String path) {
+        return !path.isEmpty();
+    }
+
+    public static int getPersistentLoopId(CondJump loopBackjump) {
+        final UnrollingId id = loopBackjump.getMetadata(UnrollingId.class);
+        return id != null ? id.value() : loopBackjump.getGlobalId();
+    }
+
+    public static int getUnrollingBoundAnnotation(CondJump boundEvent) {
+        Preconditions.checkArgument(boundEvent.hasTag(Tag.BOUND));
+        return boundEvent.getMetadata(UnrollingBound.class).value();
+    }
+
+    private Map<Function, Map<CondJump, Integer>> loadLoopBoundsMapFromFile(Program program, String filePath) {
+        if (!pathIsSpecified(filePath)) {
+            return new HashMap<>();
+        }
+        if (!Files.exists(Path.of(filePath))) {
+            logger.warn("There is no bounds file at path {} . Using default bounds.", filePath);
+            return new HashMap<>();
+        }
+
+        // Compute mapping from ids to loop events
+        final Map<Integer, CondJump> idToJump = new HashMap<>();
+        program.getFunctions().forEach(f -> f.getEvents(CondJump.class).forEach(
+                jump -> idToJump.put(getPersistentLoopId(jump), jump))
+        );
+
+        // Read CSV file to find bounds for loop events
+        final Map<Function, Map<CondJump, Integer>> loopBoundsMapPerFunction = new HashMap<>();
+        try (Reader reader = new FileReader(filePath)) {
+            Iterable<CSVRecord> records = CSVFormat.DEFAULT.parse(reader);
+            for (CSVRecord record : records) {
+                final int loopId = Integer.parseInt(record.get(0));
+                final int bound = Integer.parseInt(record.get(1));
+                final CondJump loopJump = idToJump.get(loopId);
+                if (loopJump == null) {
+                    logger.warn("Loaded bounds file does not match with the program. Ignoring file.");
+                    loopBoundsMapPerFunction.clear();
+                    break;
+                }
+                loopBoundsMapPerFunction
+                        .computeIfAbsent(loopJump.getFunction(), key -> new HashMap<>())
+                        .put(loopJump, bound);
+            }
+        } catch (IOException e) {
+            logger.warn("Failed to read bounds file: {}", e.getLocalizedMessage());
+        }
+
+        return loopBoundsMapPerFunction;
+    }
+
+    private void dumpLoopBoundsMapToFile(Program program, Map<Function, Map<CondJump, Integer>> loopBounds, String filePath) {
+        if (!pathIsSpecified(filePath)) {
+            return;
+        }
+
+        final SyntacticContextAnalysis synContext = SyntacticContextAnalysis.newInstance(program);
+        try (CSVPrinter csvPrinter = new CSVPrinter( new FileWriter(filePath, false), CSVFormat.DEFAULT)) {
+            for (Map<CondJump, Integer> loopBoundsMap : loopBounds.values()) {
+                for (Map.Entry<CondJump, Integer> entry : loopBoundsMap.entrySet()) {
+                    final CondJump loopJump = entry.getKey();
+                    final int loopId = getPersistentLoopId(loopJump);
+                    final int loopBound = entry.getValue();
+                    final String sourceLoc = synContext.getSourceLocationWithContext(loopJump, false);
+                    csvPrinter.printRecord(loopId, loopBound, sourceLoc);
+                }
+            }
+            csvPrinter.flush();
+        } catch (IOException e) {
+            logger.warn("Failed to save bounds file: {}", e.getLocalizedMessage());
+        }
+    }
+
 }
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 4776af4439..4b6665d857 100644
--- a/pom.xml
+++ b/pom.xml
@@ -47,6 +47,8 @@
         <guava.version>32.1.2-jre</guava.version>
         <junit.version>4.13.2</junit.version>
         <log4j.version>2.23.0</log4j.version>
+        <maven-model.version>3.3.9</maven-model.version>
+        <commons-csv.version>1.12.0</commons-csv.version>
         <mockito.version>5.11.0</mockito.version>
         <rsyntaxtextarea.version>3.3.4</rsyntaxtextarea.version>