From 972fef90deffd66a818bc45046356ca4051ccfa1 Mon Sep 17 00:00:00 2001 From: Heshan Padamsiri Date: Thu, 5 Dec 2024 15:19:06 +0530 Subject: [PATCH] Add optimization to compiler --- .../ballerina/types/typeops/BddCommonOps.java | 112 ++++++++++++++---- 1 file changed, 87 insertions(+), 25 deletions(-) diff --git a/semtypes/src/main/java/io/ballerina/types/typeops/BddCommonOps.java b/semtypes/src/main/java/io/ballerina/types/typeops/BddCommonOps.java index 6de40258eee1..cac04bd1e329 100644 --- a/semtypes/src/main/java/io/ballerina/types/typeops/BddCommonOps.java +++ b/semtypes/src/main/java/io/ballerina/types/typeops/BddCommonOps.java @@ -23,6 +23,9 @@ import io.ballerina.types.subtypedata.BddAllOrNothing; import io.ballerina.types.subtypedata.BddNode; +import java.util.HashMap; +import java.util.Map; + /** * Contain common BDD operations found in bdd.bal file. * @@ -38,6 +41,21 @@ public static BddNode bddAtom(Atom atom) { } public static Bdd bddUnion(Bdd b1, Bdd b2) { + return bddUnionWithMemo(BddOpMemo.create(), b1, b2); + } + + private static Bdd bddUnionWithMemo(BddOpMemo memoTable, Bdd b1, Bdd b2) { + BddOpMemoKey key = new BddOpMemoKey(b1, b2); + Bdd memoized = memoTable.unionMemo.get(key); + if (memoized != null) { + return memoized; + } + memoized = bddUnionInner(memoTable, b1, b2); + memoTable.unionMemo.put(key, memoized); + return memoized; + } + + private static Bdd bddUnionInner(BddOpMemo memo, Bdd b1, Bdd b2) { if (b1 == b2) { return b1; } else if (b1 instanceof BddAllOrNothing) { @@ -51,23 +69,38 @@ public static Bdd bddUnion(Bdd b1, Bdd b2) { if (cmp < 0L) { return bddCreate(b1Bdd.atom(), b1Bdd.left(), - bddUnion(b1Bdd.middle(), b2), + bddUnionWithMemo(memo, b1Bdd.middle(), b2), b1Bdd.right()); } else if (cmp > 0L) { return bddCreate(b2Bdd.atom(), b2Bdd.left(), - bddUnion(b1, b2Bdd.middle()), + bddUnionWithMemo(memo, b1, b2Bdd.middle()), b2Bdd.right()); } else { return bddCreate(b1Bdd.atom(), - bddUnion(b1Bdd.left(), b2Bdd.left()), - bddUnion(b1Bdd.middle(), b2Bdd.middle()), - bddUnion(b1Bdd.right(), b2Bdd.right())); + bddUnionWithMemo(memo, b1Bdd.left(), b2Bdd.left()), + bddUnionWithMemo(memo, b1Bdd.middle(), b2Bdd.middle()), + bddUnionWithMemo(memo, b1Bdd.right(), b2Bdd.right())); } } } public static Bdd bddIntersect(Bdd b1, Bdd b2) { + return bddIntersectWithMemo(BddOpMemo.create(), b1, b2); + } + + private static Bdd bddIntersectWithMemo(BddOpMemo memo, Bdd b1, Bdd b2) { + BddOpMemoKey key = new BddOpMemoKey(b1, b2); + Bdd memoized = memo.intersectionMemo.get(key); + if (memoized != null) { + return memoized; + } + memoized = bddIntersectInner(memo, b1, b2); + memo.intersectionMemo.put(key, memoized); + return memoized; + } + + private static Bdd bddIntersectInner(BddOpMemo memo, Bdd b1, Bdd b2) { if (b1 == b2) { return b1; } else if (b1 instanceof BddAllOrNothing) { @@ -80,28 +113,43 @@ public static Bdd bddIntersect(Bdd b1, Bdd b2) { long cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom()); if (cmp < 0L) { return bddCreate(b1Bdd.atom(), - bddIntersect(b1Bdd.left(), b2), - bddIntersect(b1Bdd.middle(), b2), - bddIntersect(b1Bdd.right(), b2)); + bddIntersectWithMemo(memo, b1Bdd.left(), b2), + bddIntersectWithMemo(memo, b1Bdd.middle(), b2), + bddIntersectWithMemo(memo, b1Bdd.right(), b2)); } else if (cmp > 0L) { return bddCreate(b2Bdd.atom(), - bddIntersect(b1, b2Bdd.left()), - bddIntersect(b1, b2Bdd.middle()), - bddIntersect(b1, b2Bdd.right())); + bddIntersectWithMemo(memo, b1, b2Bdd.left()), + bddIntersectWithMemo(memo, b1, b2Bdd.middle()), + bddIntersectWithMemo(memo, b1, b2Bdd.right())); } else { return bddCreate(b1Bdd.atom(), - bddIntersect( - bddUnion(b1Bdd.left(), b1Bdd.middle()), - bddUnion(b2Bdd.left(), b2Bdd.middle())), + bddIntersectWithMemo(memo, + bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()), + bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())), BddAllOrNothing.bddNothing(), - bddIntersect( - bddUnion(b1Bdd.right(), b1Bdd.middle()), - bddUnion(b2Bdd.right(), b2Bdd.middle()))); + bddIntersectWithMemo(memo, + bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()), + bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle()))); } } } public static Bdd bddDiff(Bdd b1, Bdd b2) { + return bddDiffWithMemo(BddOpMemo.create(), b1, b2); + } + + private static Bdd bddDiffWithMemo(BddOpMemo memo, Bdd b1, Bdd b2) { + BddOpMemoKey key = new BddOpMemoKey(b1, b2); + Bdd memoized = memo.diffMemo.get(key); + if (memoized != null) { + return memoized; + } + memoized = bddDiffInner(memo, b1, b2); + memo.diffMemo.put(key, memoized); + return memoized; + } + + private static Bdd bddDiffInner(BddOpMemo memo, Bdd b1, Bdd b2) { if (b1 == b2) { return BddAllOrNothing.bddNothing(); } else if (b2 instanceof BddAllOrNothing allOrNothing) { @@ -114,25 +162,27 @@ public static Bdd bddDiff(Bdd b1, Bdd b2) { long cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom()); if (cmp < 0L) { return bddCreate(b1Bdd.atom(), - bddDiff(bddUnion(b1Bdd.left(), b1Bdd.middle()), b2), + bddDiffWithMemo(memo, bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()), b2), BddAllOrNothing.bddNothing(), - bddDiff(bddUnion(b1Bdd.right(), b1Bdd.middle()), b2)); + bddDiffWithMemo(memo, bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()), b2)); } else if (cmp > 0L) { return bddCreate(b2Bdd.atom(), - bddDiff(b1, bddUnion(b2Bdd.left(), b2Bdd.middle())), + bddDiffWithMemo(memo, b1, bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())), BddAllOrNothing.bddNothing(), - bddDiff(b1, bddUnion(b2Bdd.right(), b2Bdd.middle()))); + bddDiffWithMemo(memo, b1, bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle()))); } else { // There is an error in the Castagna paper for this formula. // The union needs to be materialized here. // The original formula does not work in a case like (a0|a1) - a0. // Castagna confirms that the following formula is the correct one. return bddCreate(b1Bdd.atom(), - bddDiff(bddUnion(b1Bdd.left(), b1Bdd.middle()), - bddUnion(b2Bdd.left(), b2Bdd.middle())), + bddDiffWithMemo(memo, + bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()), + bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())), BddAllOrNothing.bddNothing(), - bddDiff(bddUnion(b1Bdd.right(), b1Bdd.middle()), - bddUnion(b2Bdd.right(), b2Bdd.middle()))); + bddDiffWithMemo(memo, + bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()), + bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle()))); } } } @@ -222,4 +272,16 @@ public static String bddToString(Bdd b, boolean inner) { return str; } } + + private record BddOpMemoKey(Bdd b1, Bdd b2) { + + } + + private record BddOpMemo(Map unionMemo, Map intersectionMemo, + Map diffMemo) { + + static BddOpMemo create() { + return new BddOpMemo(new HashMap<>(), new HashMap<>(), new HashMap<>()); + } + } }