From 1379b4b41d512b351ffe5b264a7ceb4876bef2c7 Mon Sep 17 00:00:00 2001 From: Heshan Padamsiri Date: Thu, 5 Dec 2024 15:44:22 +0530 Subject: [PATCH] Add optimization to runtime --- .../runtime/api/types/semtype/Bdd.java | 124 ++++++++++++------ 1 file changed, 86 insertions(+), 38 deletions(-) diff --git a/bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java b/bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java index 2573230f05a8..7341d35a24da 100644 --- a/bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java +++ b/bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java @@ -20,6 +20,9 @@ import io.ballerina.runtime.internal.types.semtype.SubTypeData; +import java.util.HashMap; +import java.util.Map; + import static io.ballerina.runtime.api.types.semtype.Conjunction.and; /** @@ -38,10 +41,20 @@ public abstract sealed class Bdd extends SubType implements SubTypeData permits @Override public SubType union(SubType other) { - return bddUnion((Bdd) other); + return bddUnion(BddOpMemo.create(), (Bdd) other); + } + + private Bdd bddUnion(BddOpMemo memoTable, Bdd other) { + BddOpMemoKey key = new BddOpMemoKey(this, other); + Bdd memoized = memoTable.unionMemo().get(key); + if (memoized == null) { + memoized = bddUnionInner(memoTable, other); + memoTable.unionMemo().put(key, memoized); + } + return memoized; } - private Bdd bddUnion(Bdd other) { + private Bdd bddUnionInner(BddOpMemo memo, Bdd other) { if (other == this) { return this; } else if (this == BddAllOrNothing.ALL || other == BddAllOrNothing.ALL) { @@ -57,18 +70,18 @@ private Bdd bddUnion(Bdd other) { if (cmp < 0) { return bddCreate(b1Bdd.atom(), b1Bdd.left(), - b1Bdd.middle().bddUnion(other), + b1Bdd.middle().bddUnion(memo, other), b1Bdd.right()); } else if (cmp > 0) { return bddCreate(b2Bdd.atom(), b2Bdd.left(), - this.bddUnion(b2Bdd.middle()), + this.bddUnion(memo, b2Bdd.middle()), b2Bdd.right()); } else { return bddCreate(b1Bdd.atom(), - b1Bdd.left().bddUnion(b2Bdd.left()), - b1Bdd.middle().bddUnion(b2Bdd.middle()), - b1Bdd.right().bddUnion(b2Bdd.right())); + b1Bdd.left().bddUnion(memo, b2Bdd.left()), + b1Bdd.middle().bddUnion(memo, b2Bdd.middle()), + b1Bdd.right().bddUnion(memo, b2Bdd.right())); } } @@ -88,10 +101,20 @@ private int atomCmp(Atom a1, Atom a2) { @Override public SubType intersect(SubType other) { - return bddIntersect((Bdd) other); + return bddIntersect(BddOpMemo.create(), (Bdd) other); } - private Bdd bddIntersect(Bdd other) { + private Bdd bddIntersect(BddOpMemo memoTable, Bdd other) { + BddOpMemoKey key = new BddOpMemoKey(this, other); + Bdd memoized = memoTable.intersectionMemo().get(key); + if (memoized == null) { + memoized = bddIntersectInner(memoTable, other); + memoTable.intersectionMemo().put(key, memoized); + } + return memoized; + } + + private Bdd bddIntersectInner(BddOpMemo memo, Bdd other) { if (other == this) { return this; } else if (this == BddAllOrNothing.NOTHING || other == BddAllOrNothing.NOTHING) { @@ -106,66 +129,80 @@ private Bdd bddIntersect(Bdd other) { int cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom()); if (cmp < 0) { return bddCreate(b1Bdd.atom(), - b1Bdd.left().bddIntersect(other), - b1Bdd.middle().bddIntersect(other), - b1Bdd.right().bddIntersect(other)); + b1Bdd.left().bddIntersect(memo, other), + b1Bdd.middle().bddIntersect(memo, other), + b1Bdd.right().bddIntersect(memo, other)); } else if (cmp > 0) { return bddCreate(b2Bdd.atom(), - this.bddIntersect(b2Bdd.left()), - this.bddIntersect(b2Bdd.middle()), - this.bddIntersect(b2Bdd.right())); + this.bddIntersect(memo, b2Bdd.left()), + this.bddIntersect(memo, b2Bdd.middle()), + this.bddIntersect(memo, b2Bdd.right())); } else { return bddCreate(b1Bdd.atom(), - b1Bdd.left().bddUnion(b1Bdd.middle()).bddIntersect(b2Bdd.left().bddUnion(b2Bdd.middle())), + b1Bdd.left().bddUnion(memo, b1Bdd.middle()) + .bddIntersect(memo, b2Bdd.left().bddUnion(memo, b2Bdd.middle())), BddAllOrNothing.NOTHING, - b1Bdd.right().bddUnion(b1Bdd.middle()).bddIntersect(b2Bdd.right().bddUnion(b2Bdd.middle()))); + b1Bdd.right().bddUnion(memo, b1Bdd.middle()) + .bddIntersect(memo, b2Bdd.right().bddUnion(memo, b2Bdd.middle()))); } } @Override public SubType diff(SubType other) { - return bddDiff((Bdd) other); + return bddDiff(BddOpMemo.create(), (Bdd) other); + } + + private Bdd bddDiff(BddOpMemo memoTable, Bdd other) { + var key = new BddOpMemoKey(this, other); + var memoized = memoTable.diffMemo.get(key); + if (memoized == null) { + memoized = bddDiffInner(memoTable, other); + memoTable.diffMemo.put(key, memoized); + } + return memoized; } - private Bdd bddDiff(Bdd other) { + private Bdd bddDiffInner(BddOpMemo memo, Bdd other) { if (this == other || other == BddAllOrNothing.ALL || this == BddAllOrNothing.NOTHING) { return BddAllOrNothing.NOTHING; } else if (other == BddAllOrNothing.NOTHING) { return this; } else if (this == BddAllOrNothing.ALL) { - return other.bddComplement(); + return other.bddComplement(memo); } BddNode b1Bdd = (BddNode) this; BddNode b2Bdd = (BddNode) other; int cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom()); if (cmp < 0L) { return bddCreate(b1Bdd.atom(), - b1Bdd.left().bddUnion(b1Bdd.middle()).bddDiff(other), + b1Bdd.left().bddUnion(memo, b1Bdd.middle()).bddDiff(memo, other), BddAllOrNothing.NOTHING, - b1Bdd.right().bddUnion(b1Bdd.middle()).bddDiff(other)); + b1Bdd.right().bddUnion(memo, b1Bdd.middle()).bddDiff(memo, other)); } else if (cmp > 0L) { return bddCreate(b2Bdd.atom(), - this.bddDiff(b2Bdd.left().bddUnion(b2Bdd.middle())), + this.bddDiff(memo, b2Bdd.left().bddUnion(memo, b2Bdd.middle())), BddAllOrNothing.NOTHING, - this.bddDiff(b2Bdd.right().bddUnion(b2Bdd.middle()))); + this.bddDiff(memo, b2Bdd.right().bddUnion(memo, b2Bdd.middle()))); } else { // There is an error in the Castagna paper for this formula. // The union needs to be materialized here. // The original formula does not work in a case like (a0|a1) - a0. // Castagna confirms that the following formula is the correct one. return bddCreate(b1Bdd.atom(), - b1Bdd.left().bddUnion(b1Bdd.middle()).bddDiff(b2Bdd.left().bddUnion(b2Bdd.middle())), + b1Bdd.left().bddUnion(memo, b1Bdd.middle()) + .bddDiff(memo, b2Bdd.left().bddUnion(memo, b2Bdd.middle())), BddAllOrNothing.NOTHING, - b1Bdd.right().bddUnion(b1Bdd.middle()).bddDiff(b2Bdd.right().bddUnion(b2Bdd.middle()))); + b1Bdd.right().bddUnion(memo, b1Bdd.middle()) + .bddDiff(memo, b2Bdd.right().bddUnion(memo, b2Bdd.middle()))); } } @Override public SubType complement() { - return bddComplement(); + return bddComplement(BddOpMemo.create()); } - private Bdd bddComplement() { + private Bdd bddComplement(BddOpMemo memo) { if (this == BddAllOrNothing.ALL) { return BddAllOrNothing.NOTHING; } else if (this == BddAllOrNothing.NOTHING) { @@ -176,26 +213,26 @@ private Bdd bddComplement() { if (b.right() == nothing) { return bddCreate(b.atom(), nothing, - b.left().bddUnion(b.middle()).bddComplement(), - b.middle().bddComplement()); + b.left().bddUnion(memo, b.middle()).bddComplement(memo), + b.middle().bddComplement(memo)); } else if (b.left() == nothing) { return bddCreate(b.atom(), - b.middle().bddComplement(), - b.right().bddUnion(b.middle()).bddComplement(), + b.middle().bddComplement(memo), + b.right().bddUnion(memo, b.middle()).bddComplement(memo), nothing); } else if (b.middle() == nothing) { return bddCreate(b.atom(), - b.left().bddComplement(), - b.left().bddUnion(b.right()).bddComplement(), - b.right().bddComplement()); + b.left().bddComplement(memo), + b.left().bddUnion(memo, b.right()).bddComplement(memo), + b.right().bddComplement(memo)); } else { // There is a typo in the Frisch PhD thesis for this formula. // (It has left and right swapped.) // Castagna (the PhD supervisor) confirms that this is the correct formula. return bddCreate(b.atom(), - b.left().bddUnion(b.middle()).bddComplement(), + b.left().bddUnion(memo, b.middle()).bddComplement(memo), nothing, - b.right().bddUnion(b.middle()).bddComplement()); + b.right().bddUnion(memo, b.middle()).bddComplement(memo)); } } @@ -204,7 +241,7 @@ private Bdd bddCreate(Atom atom, Bdd left, Bdd middle, Bdd right) { return middle; } if (left.equals(right)) { - return left.bddUnion(right); + return left.bddUnion(BddOpMemo.create(), right); } return new BddNodeImpl(atom, left, middle, right); @@ -282,4 +319,15 @@ public static String bddToString(Bdd b, boolean inner) { } } + private record BddOpMemoKey(Bdd b1, Bdd b2) { + + } + + private record BddOpMemo(Map unionMemo, Map intersectionMemo, + Map diffMemo) { + + static BddOpMemo create() { + return new BddOpMemo(new HashMap<>(), new HashMap<>(), new HashMap<>()); + } + } }