Skip to content

Commit

Permalink
Fixed issue when specifying same file for loading and storing bounds.
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasHaas committed Oct 25, 2024
1 parent c875e5c commit 78969e5
Showing 1 changed file with 62 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import org.apache.logging.log4j.Logger;
import org.sosy_lab.common.configuration.*;

import java.io.*;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
Expand Down Expand Up @@ -69,6 +72,10 @@ public void setUnrollingBound(int bound) {

// =====================================================================

// We use this once for loading bounds from files (if any), and then to track
// all computed loop bounds for later storing them into a file (if desired).
private Map<Function, Map<CondJump, Integer>> globalLoopBoundsMap = new HashMap<>();

private LoopUnrolling() {
}

Expand All @@ -92,15 +99,18 @@ public void run(Program program) {
logger.warn("Skipped unrolling: Program is already unrolled.");
return;
}
if (pathIsSpecified(boundsSavePath)) {
ensureFileExistsAndIsEmpty(boundsSavePath);
}

globalLoopBoundsMap = loadLoopBoundsMapFromFile(program, boundsLoadPath);

final int defaultBound = this.bound;
program.getFunctions().forEach(this::run);
program.getThreads().forEach(this::run);
program.markAsUnrolled(defaultBound);
IdReassignment.newInstance().run(program); // Reassign ids because of newly created events

dumpLoopBoundsMapToFile(program, globalLoopBoundsMap, boundsSavePath);
globalLoopBoundsMap = null; // Save up some memory

logger.info("Program unrolled {} times", defaultBound);
}

Expand Down Expand Up @@ -145,14 +155,12 @@ private Map<CondJump, Integer> computeLoopBoundsMap(Function func, int defaultBo
}

// Merge with loaded bounds if those exist.
if(pathIsSpecified(boundsLoadPath)) {
final Map<CondJump, Integer> loopBoundsMapFromFile = loadLoopBoundsMapFromFile(func, boundsLoadPath);
if(globalLoopBoundsMap.containsKey(func)) {
final Map<CondJump, Integer> loopBoundsMapFromFile = globalLoopBoundsMap.get(func);
loopBoundsMapFromFile.forEach((key, value) -> loopBoundsMap.merge(key, value, Math::max));
}
// Store bounds we computed
if (pathIsSpecified(boundsSavePath)) {
dumpLoopBoundsMapToFile(func, loopBoundsMap, boundsSavePath);
}
// Remember bounds for function for dumping.
globalLoopBoundsMap.put(func, loopBoundsMap);

return loopBoundsMap;
}
Expand Down Expand Up @@ -238,63 +246,71 @@ private boolean pathIsSpecified(String path) {
return !path.isEmpty();
}

private void ensureFileExistsAndIsEmpty(String filePath) {
try {
final File file = new File(filePath);
if (!file.createNewFile()) {
// Clear file content
new FileWriter(file).close();
}
} catch (IOException e) {
e.printStackTrace();
}
}

public static int getPersistentLoopId(CondJump loopBackjump) {
return loopBackjump.getMetadata(UnrollingId.class).value();
final UnrollingId id = loopBackjump.getMetadata(UnrollingId.class);
return id != null ? id.value() : loopBackjump.getGlobalId();
}

public static int getUnrollingBoundAnnotation(CondJump boundEvent) {
Preconditions.checkArgument(boundEvent.hasTag(Tag.BOUND));
return boundEvent.getMetadata(UnrollingBound.class).value();
}

private Map<CondJump, Integer> loadLoopBoundsMapFromFile(Function func, String filePath) {
Preconditions.checkArgument(pathIsSpecified(filePath));
Preconditions.checkArgument(Files.exists(Path.of(filePath)));
private Map<Function, Map<CondJump, Integer>> loadLoopBoundsMapFromFile(Program program, String filePath) {
if (!pathIsSpecified(filePath)) {
return new HashMap<>();
}
if (!Files.exists(Path.of(filePath))) {
logger.warn("There is no bounds file at path {} . Using default bounds.", filePath);
return new HashMap<>();
}

final List<CondJump> jumps = func.getEvents(CondJump.class);
final Map<CondJump, Integer> loopBoundsMapFromFile = new HashMap<>();
// Compute mapping from ids to loop events
final Map<Integer, CondJump> idToJump = new HashMap<>();
program.getFunctions().forEach(f -> f.getEvents(CondJump.class).forEach(
jump -> idToJump.put(getPersistentLoopId(jump), jump))
);

// Read CSV file to find bounds for loop events
final Map<Function, Map<CondJump, Integer>> loopBoundsMapPerFunction = new HashMap<>();
try (Reader reader = new FileReader(filePath)) {
Iterable<CSVRecord> records = CSVFormat.DEFAULT.parse(reader);
for (CSVRecord record : records) {
final int loopId = Integer.parseInt(record.get(0));
final int bound = Integer.parseInt(record.get(1));
jumps.stream()
.filter(e -> getPersistentLoopId(e) == loopId)
.findFirst().ifPresent(loop -> loopBoundsMapFromFile.put(loop, bound));
final CondJump loopJump = idToJump.get(loopId);
if (loopJump == null) {
logger.warn("Loaded bounds file does not match with the program. Ignoring file.");
loopBoundsMapPerFunction.clear();
break;
}
loopBoundsMapPerFunction
.computeIfAbsent(loopJump.getFunction(), key -> new HashMap<>())
.put(loopJump, bound);
}
} catch (IOException e) {
e.printStackTrace();
}
return loopBoundsMapFromFile;

return loopBoundsMapPerFunction;
}

private void dumpLoopBoundsMapToFile(Function func, Map<CondJump, Integer> boundsMap, String filePath) {
Preconditions.checkArgument(pathIsSpecified(filePath));
Preconditions.checkArgument(Files.exists(Path.of(filePath)));

final SyntacticContextAnalysis synContext = SyntacticContextAnalysis.newInstance(func.getProgram());
try (Writer writer = new FileWriter(filePath, true);
CSVPrinter csvPrinter = new CSVPrinter(writer, CSVFormat.DEFAULT)) {
for (Map.Entry<CondJump, Integer> entry : boundsMap.entrySet()) {
final CondJump loopJump = entry.getKey();
final int loopId = getPersistentLoopId(loopJump);
final int loopBound = entry.getValue();
final String sourceLoc = synContext.getSourceLocationWithContext(loopJump, false);
csvPrinter.printRecord(loopId, loopBound, sourceLoc);
private void dumpLoopBoundsMapToFile(Program program, Map<Function, Map<CondJump, Integer>> loopBounds, String filePath) {
if (!pathIsSpecified(filePath)) {
return;
}

final SyntacticContextAnalysis synContext = SyntacticContextAnalysis.newInstance(program);
try (CSVPrinter csvPrinter = new CSVPrinter( new FileWriter(filePath, true), CSVFormat.DEFAULT)) {
for (Map<CondJump, Integer> loopBoundsMap : loopBounds.values()) {
for (Map.Entry<CondJump, Integer> entry : loopBoundsMap.entrySet()) {
final CondJump loopJump = entry.getKey();
final int loopId = getPersistentLoopId(loopJump);
final int loopBound = entry.getValue();
final String sourceLoc = synContext.getSourceLocationWithContext(loopJump, false);
csvPrinter.printRecord(loopId, loopBound, sourceLoc);
}
}
writer.flush();
csvPrinter.flush();
} catch (IOException e) {
e.printStackTrace();
Expand Down

0 comments on commit 78969e5

Please sign in to comment.