Skip to content

Commit

Permalink
first implementation of optimization in while's guard
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielegenovese committed Jul 12, 2024
1 parent bdf1e40 commit 639f32a
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 30 deletions.
6 changes: 4 additions & 2 deletions progs/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
x = int(input())
y = int(input())
m = 1
while n < (2 * x - 3 * y + 5):
m = m + n
_tmp0 = x + 2 * y
while n < 2 * x - 3 * y + 5:
g = _tmp0
m = m + n + g
n = n + 1
print(m)
7 changes: 3 additions & 4 deletions src/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,10 @@ public static void main(String[] args) {

Node ast = visitor.visit(tree);
CommonTokenStream updatedTokens = visitor.getTokens();
System.out.println("AAA");
System.out.println("Tokens:");
for (Token token : updatedTokens.getTokens()) {
System.out.print(token.getText() + " ");
}
System.out.println("AAA");
ArrayList<SemanticError> errorsWithDup = ast.checkSemantics(ST, 0, null);
ArrayList<SemanticError> errors = Share.removeDuplicates(errorsWithDup);
if (!errors.isEmpty()) {
Expand All @@ -73,9 +72,9 @@ public static void main(String[] args) {
System.out.println("\t" + e);
}
} else {
System.out.println("Visualizing AST...");
System.out.println(ast.toPrint(""));
/*
* System.out.println("Visualizing AST...");
* System.out.println(ast.toPrint(""));
* System.out.println("Visualizing CFG...");
* System.out.println(cfg.printCode());
* System.out.println("Creating VM code...");
Expand Down
75 changes: 52 additions & 23 deletions src/ast/Python3VisitorImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import java.util.Map;

import ast.nodes.*;
import ast.types.*;
import codegen.Label;
import java.lang.reflect.Array;
import parser.Python3Lexer;
import parser.Python3ParserBaseVisitor;
import parser.Python3Parser.*;
Expand All @@ -22,9 +25,11 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor<Node> {

Map<String, Integer> R;
private CommonTokenStream tokens;
private TokenStreamRewriter rewriter;

public Python3VisitorImpl(CommonTokenStream tokens) {
this.tokens = tokens;
this.rewriter = new TokenStreamRewriter(tokens);
}

public CommonTokenStream getTokens() {
Expand Down Expand Up @@ -52,10 +57,7 @@ public Node visitRoot(RootContext ctx) {
}
}

System.out.println(R);

// cfg.addEdge(cfg.getExitNode());

return new RootNode(childs);
}

Expand Down Expand Up @@ -323,37 +325,64 @@ public Node visitWhile_stmt(While_stmtContext ctx) {
// Block 1 is for the while-else statement
Node block = visit(ctx.block(0));

Token newToken = new CommonToken(Python3Lexer.NAME, "x");
Token equalsToken = new CommonToken(Python3Lexer.ASSIGN, "=");
Token fiveToken = new CommonToken(Python3Lexer.NUMBER, "5");

List<Token> newTokens = new ArrayList<>();
newTokens.add(newToken);
newTokens.add(equalsToken);
newTokens.add(fiveToken);

List<Token> updatedTokens = new ArrayList<>(tokens.getTokens());
int lineStart = ctx.getStart().getLine();
int lineStop = ctx.getStop().getLine();
int index = ctx.getStart().getTokenIndex();

// Add the new tokens before the "while" statement
updatedTokens.addAll(index, newTokens);
tokens = new CommonTokenStream(new ListTokenSource(updatedTokens));

System.out.println(tokens);

// updatedTokens.addAll(index, newTokens);
// this.tokens = new CommonTokenStream(new ListTokenSource(updatedTokens));
System.out.println(R);
System.out
.println(ctx.getParent().getChild(0) + " " + ctx.getStart().getLine() + " " + ctx.getStop().getLine());
for (var e : expr.getExprs()) {
AtomNode a = ((ExprNode) e).getAtom();
System.out.println(a.toPrint("->"));
var exprs = expr.getExprs();
System.out.println("text1 " + this.rewriter.getText());
int counter = 0;
// check nella guardia
for (var e : exprs) {
ArrayList<String> al = findAtomPresent(e, new ArrayList<>());
if (!al.isEmpty()) {
boolean constant = true;
for (String a : al) {
int n = R.get(a);
if (n > lineStart && n <= lineStop) {
constant = false;
break;
}
}
if (constant) {
String newVar = Label.newVar();
rewriter.insertBefore(index, newVar + "=" + e.toPrint("") + "\n");
int lastToken = ctx.expr().expr(counter).getStop().getTokenIndex();
int firstToken = ctx.expr().expr(counter).getStart().getTokenIndex();
this.rewriter.replace(firstToken, lastToken, newVar);
}
}
counter++;
}
System.out.println("text2 " + this.rewriter.getText());

WhileStmtNode whileStmt = new WhileStmtNode(expr, block);

return whileStmt;
}

private ArrayList<String> findAtomPresent(Node e, ArrayList<String> Acc) {
if (e instanceof ExprNode) {
ExprNode expNode = (ExprNode) e;
ArrayList<Node> exprs = expNode.getExprs();
if (!exprs.isEmpty()) {
for (Node i : exprs) {
findAtomPresent(i, Acc);
}
} else {
AtomNode a = (AtomNode) expNode.getAtom();
if (a.typeCheck() instanceof AtomType) {
Acc.add(a.getId());
}
}
}
return Acc;
}

/**
* Returns a `ForSmtNode`. We do not provide 'else' branch.
*
Expand Down
3 changes: 2 additions & 1 deletion src/ast/nodes/CompoundNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public CompoundNode(Node ifNode, Node funcDef, Node forStmt, Node whileStmt) {
this.whileStmt = whileStmt;
}


@Override
public ArrayList<SemanticError> checkSemantics(SymbolTable ST, int _nesting, FunctionType ft) {
ArrayList<SemanticError> errors = new ArrayList<>();
Expand Down Expand Up @@ -117,7 +118,7 @@ public String toPrint(String prefix) {
}

return str;

}

public Node getForStmt() {
Expand Down
5 changes: 5 additions & 0 deletions src/codegen/Label.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ public class Label {
private static int labelCounter = 0;
private static int globalVarNum = 0;
private static int functionLabelCounter = 0;
private static int varDefCount = 0;

public static void addFunDef(String s) {
funDef += s;
Expand Down Expand Up @@ -38,4 +39,8 @@ public static String newFun(String base) {
return base + (functionLabelCounter++);
}


public static String newVar() {
return "_tmp" + (varDefCount++);
}
}

0 comments on commit 639f32a

Please sign in to comment.