From 639f32ad3b41e0f4eddb8b6fe386033731a7e964 Mon Sep 17 00:00:00 2001 From: geno Date: Fri, 12 Jul 2024 20:11:00 +0200 Subject: [PATCH] first implementation of optimization in while's guard --- progs/test.py | 6 ++- src/Main.java | 7 ++- src/ast/Python3VisitorImpl.java | 75 +++++++++++++++++++++++---------- src/ast/nodes/CompoundNode.java | 3 +- src/codegen/Label.java | 5 +++ 5 files changed, 66 insertions(+), 30 deletions(-) diff --git a/progs/test.py b/progs/test.py index a1732ac..cbfa201 100644 --- a/progs/test.py +++ b/progs/test.py @@ -2,7 +2,9 @@ x = int(input()) y = int(input()) m = 1 -while n < (2 * x - 3 * y + 5): - m = m + n +_tmp0 = x + 2 * y +while n < 2 * x - 3 * y + 5: + g = _tmp0 + m = m + n + g n = n + 1 print(m) diff --git a/src/Main.java b/src/Main.java index 966a26a..fa38654 100644 --- a/src/Main.java +++ b/src/Main.java @@ -60,11 +60,10 @@ public static void main(String[] args) { Node ast = visitor.visit(tree); CommonTokenStream updatedTokens = visitor.getTokens(); - System.out.println("AAA"); + System.out.println("Tokens:"); for (Token token : updatedTokens.getTokens()) { System.out.print(token.getText() + " "); } - System.out.println("AAA"); ArrayList errorsWithDup = ast.checkSemantics(ST, 0, null); ArrayList errors = Share.removeDuplicates(errorsWithDup); if (!errors.isEmpty()) { @@ -73,9 +72,9 @@ public static void main(String[] args) { System.out.println("\t" + e); } } else { + System.out.println("Visualizing AST..."); + System.out.println(ast.toPrint("")); /* - * System.out.println("Visualizing AST..."); - * System.out.println(ast.toPrint("")); * System.out.println("Visualizing CFG..."); * System.out.println(cfg.printCode()); * System.out.println("Creating VM code..."); diff --git a/src/ast/Python3VisitorImpl.java b/src/ast/Python3VisitorImpl.java index 063c6d2..0e92f13 100644 --- a/src/ast/Python3VisitorImpl.java +++ b/src/ast/Python3VisitorImpl.java @@ -6,6 +6,9 @@ import java.util.Map; import ast.nodes.*; +import ast.types.*; +import codegen.Label; +import java.lang.reflect.Array; import parser.Python3Lexer; import parser.Python3ParserBaseVisitor; import parser.Python3Parser.*; @@ -22,9 +25,11 @@ public class Python3VisitorImpl extends Python3ParserBaseVisitor { Map R; private CommonTokenStream tokens; + private TokenStreamRewriter rewriter; public Python3VisitorImpl(CommonTokenStream tokens) { this.tokens = tokens; + this.rewriter = new TokenStreamRewriter(tokens); } public CommonTokenStream getTokens() { @@ -52,10 +57,7 @@ public Node visitRoot(RootContext ctx) { } } - System.out.println(R); - // cfg.addEdge(cfg.getExitNode()); - return new RootNode(childs); } @@ -323,37 +325,64 @@ public Node visitWhile_stmt(While_stmtContext ctx) { // Block 1 is for the while-else statement Node block = visit(ctx.block(0)); - Token newToken = new CommonToken(Python3Lexer.NAME, "x"); - Token equalsToken = new CommonToken(Python3Lexer.ASSIGN, "="); - Token fiveToken = new CommonToken(Python3Lexer.NUMBER, "5"); - - List newTokens = new ArrayList<>(); - newTokens.add(newToken); - newTokens.add(equalsToken); - newTokens.add(fiveToken); - - List updatedTokens = new ArrayList<>(tokens.getTokens()); + int lineStart = ctx.getStart().getLine(); + int lineStop = ctx.getStop().getLine(); int index = ctx.getStart().getTokenIndex(); // Add the new tokens before the "while" statement - updatedTokens.addAll(index, newTokens); - tokens = new CommonTokenStream(new ListTokenSource(updatedTokens)); - - System.out.println(tokens); - + // updatedTokens.addAll(index, newTokens); + // this.tokens = new CommonTokenStream(new ListTokenSource(updatedTokens)); System.out.println(R); - System.out - .println(ctx.getParent().getChild(0) + " " + ctx.getStart().getLine() + " " + ctx.getStop().getLine()); - for (var e : expr.getExprs()) { - AtomNode a = ((ExprNode) e).getAtom(); - System.out.println(a.toPrint("->")); + var exprs = expr.getExprs(); + System.out.println("text1 " + this.rewriter.getText()); + int counter = 0; + // check nella guardia + for (var e : exprs) { + ArrayList al = findAtomPresent(e, new ArrayList<>()); + if (!al.isEmpty()) { + boolean constant = true; + for (String a : al) { + int n = R.get(a); + if (n > lineStart && n <= lineStop) { + constant = false; + break; + } + } + if (constant) { + String newVar = Label.newVar(); + rewriter.insertBefore(index, newVar + "=" + e.toPrint("") + "\n"); + int lastToken = ctx.expr().expr(counter).getStop().getTokenIndex(); + int firstToken = ctx.expr().expr(counter).getStart().getTokenIndex(); + this.rewriter.replace(firstToken, lastToken, newVar); + } + } + counter++; } + System.out.println("text2 " + this.rewriter.getText()); WhileStmtNode whileStmt = new WhileStmtNode(expr, block); return whileStmt; } + private ArrayList findAtomPresent(Node e, ArrayList Acc) { + if (e instanceof ExprNode) { + ExprNode expNode = (ExprNode) e; + ArrayList exprs = expNode.getExprs(); + if (!exprs.isEmpty()) { + for (Node i : exprs) { + findAtomPresent(i, Acc); + } + } else { + AtomNode a = (AtomNode) expNode.getAtom(); + if (a.typeCheck() instanceof AtomType) { + Acc.add(a.getId()); + } + } + } + return Acc; + } + /** * Returns a `ForSmtNode`. We do not provide 'else' branch. * diff --git a/src/ast/nodes/CompoundNode.java b/src/ast/nodes/CompoundNode.java index a165165..ad3420d 100644 --- a/src/ast/nodes/CompoundNode.java +++ b/src/ast/nodes/CompoundNode.java @@ -22,6 +22,7 @@ public CompoundNode(Node ifNode, Node funcDef, Node forStmt, Node whileStmt) { this.whileStmt = whileStmt; } + @Override public ArrayList checkSemantics(SymbolTable ST, int _nesting, FunctionType ft) { ArrayList errors = new ArrayList<>(); @@ -117,7 +118,7 @@ public String toPrint(String prefix) { } return str; - + } public Node getForStmt() { diff --git a/src/codegen/Label.java b/src/codegen/Label.java index f0a4c6e..a301ede 100644 --- a/src/codegen/Label.java +++ b/src/codegen/Label.java @@ -6,6 +6,7 @@ public class Label { private static int labelCounter = 0; private static int globalVarNum = 0; private static int functionLabelCounter = 0; + private static int varDefCount = 0; public static void addFunDef(String s) { funDef += s; @@ -38,4 +39,8 @@ public static String newFun(String base) { return base + (functionLabelCounter++); } + + public static String newVar() { + return "_tmp" + (varDefCount++); + } }