diff --git a/dag_in_context/src/greedy_dag_extractor.rs b/dag_in_context/src/greedy_dag_extractor.rs index 25b947410..d725106de 100644 --- a/dag_in_context/src/greedy_dag_extractor.rs +++ b/dag_in_context/src/greedy_dag_extractor.rs @@ -894,7 +894,7 @@ impl CostModel for DefaultCostModel { fn get_op_cost(&self, op: &str) -> Cost { match op { // Leaves - "Const" => 1., + "Const" => 0., "Arg" => 0., _ if op.parse::().is_ok() || op.parse::().is_ok() || op.starts_with('"') => { 0. @@ -909,7 +909,7 @@ impl CostModel for DefaultCostModel { // Algebra "Add" | "PtrAdd" | "Sub" | "And" | "Or" | "Not" | "Shl" | "Shr" => 10., "FAdd" | "FSub" | "Fmax" | "Fmin" => 50., - "Mul" => 30., + "Mul" => 100., "FMul" => 150., "Div" => 50., "FDiv" => 250., diff --git a/dag_in_context/src/lib.rs b/dag_in_context/src/lib.rs index ffed419ab..c664269ff 100644 --- a/dag_in_context/src/lib.rs +++ b/dag_in_context/src/lib.rs @@ -64,8 +64,10 @@ pub fn prologue() -> String { include_str!("optimizations/loop_unroll.egg"), include_str!("optimizations/passthrough.egg"), include_str!("optimizations/loop_strength_reduction.egg"), + include_str!("optimizations/loop_multiply_motion.egg"), include_str!("utility/debug-helper.egg"), &rulesets(), + include_str!("optimizations/ivt.egg"), ] .join("\n") } diff --git a/dag_in_context/src/optimizations/ivt.egg b/dag_in_context/src/optimizations/ivt.egg new file mode 100644 index 000000000..c26698c72 --- /dev/null +++ b/dag_in_context/src/optimizations/ivt.egg @@ -0,0 +1,128 @@ + +;; A Perm is a reverse list of integers +(datatype Perm (PermPush i64 Perm) (PNil)) +;; expr1 is a list of expressions of the form (Get expr2 i), +;; where all the i's form a permutation +(relation IVTPermutationAnalysisDemand (Expr)) +;; expr1 curr expr2 +(relation IVTPermutationAnalysisImpl (Expr Expr Expr Perm)) +;; expr1 expr2 +(relation IVTPermutationAnalysis (Expr Expr Perm)) + +(rule ( + (DoWhile inpW outW) +) ( + (IVTPermutationAnalysisDemand outW) +) :ruleset always-run) + +(rule ( + (IVTPermutationAnalysisDemand loop-body) + (= loop-body (Concat (Single (Get expr ith)) rest)) +) ( + (let perm (PermPush ith (PNil))) + (IVTPermutationAnalysisImpl loop-body rest expr perm) +) :ruleset always-run) + +(rule ( + (IVTPermutationAnalysisImpl loop-body curr expr perm) + (= curr (Concat (Single (Get expr ith)) rest)) +) ( + (let new-perm (PermPush ith perm)) + (IVTPermutationAnalysisImpl loop-body rest expr new-perm) +) :ruleset always-run) + +(rule ( + (IVTPermutationAnalysisImpl loop-body (Single last) expr perm) + (= last (Get expr ith)) +) ( + (let new-perm (PermPush ith perm)) + (IVTPermutationAnalysis loop-body expr new-perm) +) :ruleset always-run) + +(function ApplyPerm (Perm Expr) Expr) + +(rewrite (ApplyPerm (PermPush ith rest) expr) + (Concat (ApplyPerm rest expr) (Single (Get expr ith))) + :when ((= rest (PermPush _jth _rest))) + :ruleset always-run) +(rewrite (ApplyPerm (PermPush ith (PNil)) expr) (Single (Get expr ith)) + :ruleset always-run) + +(ruleset loop-inversion) + +;; This is for unified handling of thn/els branches +(relation ith-arg-is-bool (Expr Expr Expr i64 Expr Expr)) + +(rule ( + (= loop (DoWhile inpW outW)) + (IVTPermutationAnalysis outW conditional perm) + (= conditional (If condI inpI thn els)) + + (= (Get thn ith) (Const (Bool true) _unused1 _unused2)) + (= (Get els ith) (Const (Bool false) _unused3 _unused4)) +) ( + (ith-arg-is-bool conditional condI inpI ith thn els) +) :ruleset always-run) + +(rule ( + (= loop (DoWhile inpW outW)) + (IVTPermutationAnalysis outW conditional perm) + (= conditional (If condI inpI thn els)) + + (= (Get thn ith) (Const (Bool false) _unused1 _unused2)) + (= (Get els ith) (Const (Bool true) _unused3 _unused4)) +) ( + ;; TODO: this may introduce overhead, but is the only way to + ;; not have two rules + (ith-arg-is-bool conditional (Uop (Not) condI) inpI ith els thn) +) :ruleset always-run) + +(rule ( + (= loop (DoWhile inpW outW)) + (IVTPermutationAnalysis outW conditional perm) + ;; This generalizes the following conditions: + ;; (= conditional (If condI inpI thn els)) + ;; (= (Get thn ith) (Const (Bool true) _unused1 _unused2)) + ;; (= (Get els ith) (Const (Bool false) _unused3 _unused4)) + (ith-arg-is-bool conditional condI inpI ith thn els) + + (ContextOf inpW outer-ctx) + (ContextOf inpI if-ctx) + (HasType inpI argI) +) ( + ;; first peeled condition + (let new-if-cond (Subst outer-ctx inpW condI)) + ;; if contexts + (let new-if-inp (Subst outer-ctx inpW inpI)) + (let new-if-true-ctx (InIf true new-if-cond new-if-inp)) + (let new-if-false-ctx (InIf false new-if-cond new-if-inp)) + + (let new-loop-context (TmpCtx)) + + ;; body + (let new-loop-outputs_ + (TupleRemoveAt (ApplyPerm perm thn) 0)) + (let new-loop-outputs + (Subst new-loop-context new-loop-outputs_ + (Concat (Single condI) inpI))) + + (let new-loop (DoWhile (Arg argI new-if-true-ctx) new-loop-outputs)) + (let new-if + (If new-if-cond new-if-inp + new-loop + (Arg argI new-if-false-ctx))) + + ;; Apply the body of the false branch as an afterprocessing wrapper + (let new-expr_ + (Subst outer-ctx new-if els)) + (let new-expr + (TupleRemoveAt + (ApplyPerm perm new-expr_) + 0)) + + (union new-expr loop) + (union new-loop-context (InLoop (Arg argI new-if-true-ctx) new-loop-outputs)) + + (delete (DoWhile inpW outW)) + (delete (TmpCtx)) +) :ruleset loop-inversion) diff --git a/dag_in_context/src/optimizations/loop_multiply_motion.egg b/dag_in_context/src/optimizations/loop_multiply_motion.egg new file mode 100644 index 000000000..35bd30dfb --- /dev/null +++ b/dag_in_context/src/optimizations/loop_multiply_motion.egg @@ -0,0 +1,156 @@ +(ruleset loop-multiply-motion) + + +;;Example: +;; +;;original: +;; int x = 0; +;; while (x < 3) { +;; x += 1; +;; } +;; return x * 5; +;; optimized: +;; +;; int x = 0; +;; while (x < 15) { +;; x += 5; +;; } +;; return x; +(rule ((= loop (DoWhile in pred_out)) + (= argi (Get (Arg ty ctx) i)) + (= inputs-len (tuple-length (Arg ty ctx))) + ;; variable is incremented by constant each iteration + (= (Get pred_out (+ i 1)) + (Bop (Add) argi (Const (Int c) ty ctx))) + ;; check that it is less than n + (= (Get pred_out 0) + (Bop (LessThan) (Get pred_out (+ i 1)) + (Const (Int n) ty ctx))) + ;; overapproximate check that we won't overflow + (< (+ n (+ c k)) 10000) + + ;; after the loop, we multiply by some constant + (= res (Bop (Mul) (Get loop i) (Const (Int k) ty_outer ctx_outer))) + (= ty (TupleT ty_list))) + ( ;; new type + (let new-arg-ty + (TupleT (TLConcat ty_list (TCons (IntT) (TNil))))) + ;; new inputs with i duplicated + (let new-inputs (Concat in (Single (Get in i)))) + (let new-pred-out + (Subst (TmpCtx) + (SubTuple (Arg new-arg-ty (TmpCtx)) 0 inputs-len) pred_out)) + (let new-x (Get (Arg new-arg-ty (TmpCtx)) inputs-len)) + (let new-body + (Concat new-pred-out + (Single (Bop (Add) new-x + (Const (Int (* c k)) new-arg-ty (TmpCtx)))))) + ;; add another value to the loop like i but multiplied + (let new-loop + (DoWhile new-inputs + new-body)) + ;; union context + (union (TmpCtx) (InLoop new-inputs new-body)) + ;; old loop equal to new loop + (union (SubTuple new-loop 0 inputs-len) loop) + ;; multiplication is equal to the new value + (union res (Get new-loop inputs-len)) + (delete (TmpCtx))) + :ruleset loop-multiply-motion) + + +(relation IfGreaterThanThenOne (i64 i64 i64)) + +(rule ((DoWhile inputs pred_out) + (= argx (Get (Arg ty ctx) x)) + (= argy (Get (Arg ty ctx) y)) + ;; iter variable x + (= (Get pred_out (+ x 1)) + (Bop (Add) argx (Const (Int xconst) ty ctx))) + ;; iter variable y + (= (Get pred_out (+ y 1)) + (Bop (Add) argy (Const (Int yconst) ty ctx))) + (> x y)) + ((IfGreaterThanThenOne x y 1)) + :ruleset always-run) + + +(rule ((DoWhile inputs pred_out) + (= argx (Get (Arg ty ctx) x)) + (= argy (Get (Arg ty ctx) y)) + ;; iter variable x + (= (Get pred_out (+ x 1)) + (Bop (Add) argx (Const (Int xconst) ty ctx))) + ;; iter variable y + (= (Get pred_out (+ y 1)) + (Bop (Add) argy (Const (Int yconst) ty ctx))) + (< x y)) + ((IfGreaterThanThenOne x y 0)) + :ruleset always-run) + + + + +;; try to consolidate loop iter variables by finding +;; equivalence between predicates +;; by finding the one that is used by the predicate and changing it to use the other +(rule ((= loop (DoWhile inputs pred_out)) + (HasArgType inputs ty-outer) + (ContextOf inputs ctx-outer) + + (= argx (Get (Arg ty ctx) x)) + (= argy (Get (Arg ty ctx) y)) + (!= x y) + (= inputs-len (tuple-length inputs)) + ;; iter variable x + (= (Get pred_out (+ x 1)) + (Bop (Add) argx (Const (Int xconst) ty ctx))) + ;; iter variable y + (= (Get pred_out (+ y 1)) + (Bop (Add) argy (Const (Int yconst) ty ctx))) + ;; loop condition is over y + (= (Get pred_out 0) + (Bop (LessThan) (Get pred_out (+ y 1)) + (Const (Int n) ty ctx))) + ;; x starts at a constant zero + (= (Get inputs x) (Const (Int 0) ty-outer ctx-outer)) + ;; y starts at a constant + (= (Get inputs y) (Const (Int 0) ty-outer ctx-outer)) + ;; isgreater is 1 when x comes after y + (IfGreaterThanThenOne x y isgreater) + ;; x increment is divisible by y increment + (= factor (/ xconst yconst)) + (= (* factor yconst) xconst) + + ;; we won't run into overflow issues + (< (+ factor n) 10000) + (= ty (TupleT ty_list))) + (;; find another way to express predicate in old loop + (let old-x-predicate + (Bop (LessThan) (Get pred_out (+ x 1)) + (Const (Int (* factor n)) ty ctx))) + (union (Get pred_out 0) old-x-predicate) + + ;; new inputs with y removed + (let new-inputs (TupleRemoveAt inputs y)) + ;; new type + (let new-arg-ty (TupleT (TypeListRemoveAt ty_list y))) + ;; new body with y removed + (let new-body + (DropAt (TmpCtx) y (TupleRemoveAt pred_out (+ y 1)))) + ;; new loop + (let new-loop + (DoWhile new-inputs + new-body)) + ;; union context + (union (TmpCtx) (InLoop new-inputs new-body)) + (union loop + (Concat + (SubTuple new-loop 0 y) + (Concat + (Single (Bop (Div) (Get new-loop (- x isgreater)) + (Const (Int factor) ty-outer ctx-outer))) + (SubTuple new-loop y (- inputs-len (+ y 1)))))) + + (delete (TmpCtx))) + :ruleset loop-multiply-motion) diff --git a/dag_in_context/src/optimizations/passthrough.egg b/dag_in_context/src/optimizations/passthrough.egg index c420723d0..276c8a749 100644 --- a/dag_in_context/src/optimizations/passthrough.egg +++ b/dag_in_context/src/optimizations/passthrough.egg @@ -23,7 +23,7 @@ (= (Get branch1 i) (Get (Arg _ _ctx1) j)) (= passed-through (Get inputs j)) (HasType lhs lhs_ty) - (!= lhs_ty (Base (StateT)))) + (NonStateType lhs_ty)) ((union lhs passed-through)) :ruleset passthrough) @@ -45,7 +45,7 @@ (= then-branch (Get (Arg arg_ty _then_ctx) j)) (= else-branch (Get (Arg arg_ty _else_ctx) j)) (HasType then-branch lhs_ty) - (!= lhs_ty (Base (StateT)))) + (NonStateType lhs_ty)) ((union (Get if i) (Get inputs j))) :ruleset passthrough) diff --git a/dag_in_context/src/schedule.rs b/dag_in_context/src/schedule.rs index 0d6cc7738..c7d6990d9 100644 --- a/dag_in_context/src/schedule.rs +++ b/dag_in_context/src/schedule.rs @@ -57,6 +57,9 @@ fn optimizations() -> Vec { "switch_rewrite", "loop-inv-motion", "loop-strength-reduction", + "loop-peel", + "loop-inversion", + "loop-multiply-motion", ] .iter() .map(|opt| opt.to_string()) diff --git a/dag_in_context/src/type_analysis.egg b/dag_in_context/src/type_analysis.egg index 69da1e8cd..2b513aecd 100644 --- a/dag_in_context/src/type_analysis.egg +++ b/dag_in_context/src/type_analysis.egg @@ -43,7 +43,7 @@ (extract actual) (extract "with message") (extract msg) - (panic "type mismatch")) + (panic "type mismatch, use RUST_LOG=info to see type error")) :ruleset error-checking) (relation HasArgType (Expr Type)) @@ -508,6 +508,41 @@ ((panic "input types and output types don't match")) :ruleset error-checking) +; ================================= +; Predicate to check if a type is not a state type +; ================================= + +(relation NonStateBaseType (BaseType)) +(rule () + ((NonStateBaseType (IntT)) + (NonStateBaseType (BoolT)) + (NonStateBaseType (FloatT))) + :ruleset type-analysis) +(rule ((= ptr-type (PointerT ty)) + (NonStateBaseType ty)) + ((NonStateBaseType ptr-type)) + :ruleset type-analysis) + +(relation NonStateTypeList (TypeList)) +(rule () + ((NonStateTypeList (TNil))) + :ruleset type-analysis) +(rule ((NonStateTypeList tl) + (NonStateBaseType hd) + (= type-list (TCons hd tl))) + ((NonStateTypeList type-list)) + :ruleset type-analysis) + +(relation NonStateType (Type)) +(rule ((NonStateBaseType bt) + (= t (Base bt))) + ((NonStateType t)) + :ruleset type-analysis) +(rule ((NonStateTypeList tl) + (= t (TupleT tl))) + ((NonStateType t)) + :ruleset type-analysis) + ; ================================= ; Functions ; ================================= diff --git a/tests/passing/small/fuse_iteration_counter.bril b/tests/passing/small/fuse_iteration_counter.bril new file mode 100644 index 000000000..881b6954b --- /dev/null +++ b/tests/passing/small/fuse_iteration_counter.bril @@ -0,0 +1,18 @@ +# ARGS: +@main { +.b0_: + c1_: int = const 0; + i: int = const 0; + x: int = const 0; +.b6_: + c2_: int = const 3; + c7_: int = const 1; + i: int = add i c7_; + x: int = add x c2_; + v11_: bool = lt i c2_; + br v11_ .b6_ .b12_; +.b12_: + print x; + ret; +} + diff --git a/tests/snapshots/files__block-diamond-optimize-sequential.snap b/tests/snapshots/files__block-diamond-optimize-sequential.snap index 646549c44..774bc48b7 100644 --- a/tests/snapshots/files__block-diamond-optimize-sequential.snap +++ b/tests/snapshots/files__block-diamond-optimize-sequential.snap @@ -5,23 +5,23 @@ expression: visualization.result # ARGS: 1 @main(v0: int) { .b1_: - c2_: int = const 1; - c3_: int = const 2; - v4_: bool = lt v0 c3_; - c5_: int = const 0; + c2_: int = const 2; + v3_: bool = lt v0 c2_; + c4_: int = const 0; + c5_: int = const 1; c6_: int = const 5; - v7_: int = id c2_; - v8_: int = id c2_; - v9_: int = id c3_; - br v4_ .b10_ .b11_; + v7_: int = id c5_; + v8_: int = id c5_; + v9_: int = id c2_; + br v3_ .b10_ .b11_; .b10_: c12_: int = const 4; v7_: int = id c12_; - v8_: int = id c2_; - v9_: int = id c3_; + v8_: int = id c5_; + v9_: int = id c2_; v13_: int = id v7_; v14_: int = id v8_; - v15_: int = add c2_ v13_; + v15_: int = add c5_ v13_; print v15_; ret; jmp .b16_; @@ -31,7 +31,7 @@ expression: visualization.result v17_: int = add v7_ v9_; v13_: int = id v17_; v14_: int = id v8_; - v15_: int = add c2_ v13_; + v15_: int = add c5_ v13_; print v15_; ret; .b16_: diff --git a/tests/snapshots/files__block-diamond-optimize.snap b/tests/snapshots/files__block-diamond-optimize.snap index 76b2474dd..a10c015d2 100644 --- a/tests/snapshots/files__block-diamond-optimize.snap +++ b/tests/snapshots/files__block-diamond-optimize.snap @@ -5,24 +5,24 @@ expression: visualization.result # ARGS: 1 @main(v0: int) { .b1_: - c2_: int = const 1; - c3_: int = const 2; - v4_: bool = lt v0 c3_; - c5_: int = const 4; - v6_: int = select v4_ c5_ c2_; + c2_: int = const 2; + v3_: bool = lt v0 c2_; + c4_: int = const 4; + c5_: int = const 1; + v6_: int = select v3_ c4_ c5_; v7_: int = id v6_; - v8_: int = id c2_; - br v4_ .b9_ .b10_; + v8_: int = id c5_; + br v3_ .b9_ .b10_; .b9_: - v11_: int = add c2_ v7_; + v11_: int = add c5_ v7_; print v11_; ret; jmp .b12_; .b10_: - v13_: int = add c3_ v6_; + v13_: int = add c2_ v6_; v7_: int = id v13_; - v8_: int = id c2_; - v11_: int = add c2_ v7_; + v8_: int = id c5_; + v11_: int = add c5_ v7_; print v11_; ret; .b12_: diff --git a/tests/snapshots/files__fib_recursive-optimize-sequential.snap b/tests/snapshots/files__fib_recursive-optimize-sequential.snap index 6abe76b39..53e24b650 100644 --- a/tests/snapshots/files__fib_recursive-optimize-sequential.snap +++ b/tests/snapshots/files__fib_recursive-optimize-sequential.snap @@ -119,7 +119,11 @@ expression: visualization.result br v57_ .b58_ .b59_; .b58_: v60_: int = call @fac c2_; +<<<<<<< HEAD v61_: int = id c4_; +======= + v61_: int = id v0; +>>>>>>> 3e24659a (fix bug in loop strength reduction, add helper) v39_: int = id v61_; v42_: int = id c4_; jmp .b43_; @@ -281,7 +285,11 @@ expression: visualization.result br v137_ .b138_ .b139_; .b138_: v140_: int = call @fac c2_; +<<<<<<< HEAD v141_: int = id c4_; +======= + v141_: int = id v67_; +>>>>>>> 3e24659a (fix bug in loop strength reduction, add helper) v133_: int = id v141_; v136_: int = id c4_; v71_: int = id v136_; @@ -374,8 +382,19 @@ expression: visualization.result v14_: int = id c13_; br v12_ .b15_ .b16_; .b15_: +<<<<<<< HEAD v17_: int = id c4_; v5_: int = id v17_; +======= + v17_: bool = eq c1_ c2_; + br v17_ .b18_ .b19_; +.b18_: + v20_: int = call @fac c2_; + v21_: int = id c1_; + v13_: int = id v21_; + v16_: int = id c1_; + v5_: int = id v16_; +>>>>>>> 3e24659a (fix bug in loop strength reduction, add helper) print v5_; ret; jmp .b8_; diff --git a/tests/snapshots/files__fib_recursive-optimize.snap b/tests/snapshots/files__fib_recursive-optimize.snap index a2a803728..815e995f0 100644 --- a/tests/snapshots/files__fib_recursive-optimize.snap +++ b/tests/snapshots/files__fib_recursive-optimize.snap @@ -119,7 +119,11 @@ expression: visualization.result br v57_ .b58_ .b59_; .b58_: v60_: int = call @fac c2_; +<<<<<<< HEAD v61_: int = id c4_; +======= + v61_: int = id v0; +>>>>>>> 3e24659a (fix bug in loop strength reduction, add helper) v39_: int = id v61_; v42_: int = id c4_; jmp .b43_; @@ -281,7 +285,11 @@ expression: visualization.result br v137_ .b138_ .b139_; .b138_: v140_: int = call @fac c2_; +<<<<<<< HEAD v141_: int = id c4_; +======= + v141_: int = id v67_; +>>>>>>> 3e24659a (fix bug in loop strength reduction, add helper) v133_: int = id v141_; v136_: int = id c4_; v71_: int = id v136_; @@ -374,8 +382,19 @@ expression: visualization.result v14_: int = id c13_; br v12_ .b15_ .b16_; .b15_: +<<<<<<< HEAD v17_: int = id c4_; v5_: int = id v17_; +======= + v17_: bool = eq c1_ c2_; + br v17_ .b18_ .b19_; +.b18_: + v20_: int = call @fac c2_; + v21_: int = id c1_; + v13_: int = id v21_; + v16_: int = id c1_; + v5_: int = id v16_; +>>>>>>> 3e24659a (fix bug in loop strength reduction, add helper) print v5_; ret; jmp .b8_;