Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize BDD operations #43689

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pull_request_full_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
strategy:
fail-fast: false
matrix:
level: [ 1, 2, 3, 4, 5, 6, 7, 8, 9 ]
level: [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ]

steps:
- name: Checkout Repository
Expand Down
112 changes: 87 additions & 25 deletions semtypes/src/main/java/io/ballerina/types/typeops/BddCommonOps.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import io.ballerina.types.subtypedata.BddAllOrNothing;
import io.ballerina.types.subtypedata.BddNode;

import java.util.HashMap;
import java.util.Map;

/**
* Contain common BDD operations found in bdd.bal file.
*
Expand All @@ -38,6 +41,21 @@ public static BddNode bddAtom(Atom atom) {
}

public static Bdd bddUnion(Bdd b1, Bdd b2) {
return bddUnionWithMemo(BddOpMemo.create(), b1, b2);
}

private static Bdd bddUnionWithMemo(BddOpMemo memoTable, Bdd b1, Bdd b2) {
BddOpMemoKey key = new BddOpMemoKey(b1, b2);
Bdd memoized = memoTable.unionMemo.get(key);
if (memoized != null) {
return memoized;
}
memoized = bddUnionInner(memoTable, b1, b2);
memoTable.unionMemo.put(key, memoized);
return memoized;
}

private static Bdd bddUnionInner(BddOpMemo memo, Bdd b1, Bdd b2) {
if (b1 == b2) {
return b1;
} else if (b1 instanceof BddAllOrNothing) {
Expand All @@ -51,23 +69,38 @@ public static Bdd bddUnion(Bdd b1, Bdd b2) {
if (cmp < 0L) {
return bddCreate(b1Bdd.atom(),
b1Bdd.left(),
bddUnion(b1Bdd.middle(), b2),
bddUnionWithMemo(memo, b1Bdd.middle(), b2),
b1Bdd.right());
} else if (cmp > 0L) {
return bddCreate(b2Bdd.atom(),
b2Bdd.left(),
bddUnion(b1, b2Bdd.middle()),
bddUnionWithMemo(memo, b1, b2Bdd.middle()),
b2Bdd.right());
} else {
return bddCreate(b1Bdd.atom(),
bddUnion(b1Bdd.left(), b2Bdd.left()),
bddUnion(b1Bdd.middle(), b2Bdd.middle()),
bddUnion(b1Bdd.right(), b2Bdd.right()));
bddUnionWithMemo(memo, b1Bdd.left(), b2Bdd.left()),
bddUnionWithMemo(memo, b1Bdd.middle(), b2Bdd.middle()),
bddUnionWithMemo(memo, b1Bdd.right(), b2Bdd.right()));
}
}
}

public static Bdd bddIntersect(Bdd b1, Bdd b2) {
return bddIntersectWithMemo(BddOpMemo.create(), b1, b2);
}

private static Bdd bddIntersectWithMemo(BddOpMemo memo, Bdd b1, Bdd b2) {
BddOpMemoKey key = new BddOpMemoKey(b1, b2);
Bdd memoized = memo.intersectionMemo.get(key);
if (memoized != null) {
return memoized;
}
memoized = bddIntersectInner(memo, b1, b2);
memo.intersectionMemo.put(key, memoized);
return memoized;
}

private static Bdd bddIntersectInner(BddOpMemo memo, Bdd b1, Bdd b2) {
if (b1 == b2) {
return b1;
} else if (b1 instanceof BddAllOrNothing) {
Expand All @@ -80,28 +113,43 @@ public static Bdd bddIntersect(Bdd b1, Bdd b2) {
long cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom());
if (cmp < 0L) {
return bddCreate(b1Bdd.atom(),
bddIntersect(b1Bdd.left(), b2),
bddIntersect(b1Bdd.middle(), b2),
bddIntersect(b1Bdd.right(), b2));
bddIntersectWithMemo(memo, b1Bdd.left(), b2),
bddIntersectWithMemo(memo, b1Bdd.middle(), b2),
bddIntersectWithMemo(memo, b1Bdd.right(), b2));
} else if (cmp > 0L) {
return bddCreate(b2Bdd.atom(),
bddIntersect(b1, b2Bdd.left()),
bddIntersect(b1, b2Bdd.middle()),
bddIntersect(b1, b2Bdd.right()));
bddIntersectWithMemo(memo, b1, b2Bdd.left()),
bddIntersectWithMemo(memo, b1, b2Bdd.middle()),
bddIntersectWithMemo(memo, b1, b2Bdd.right()));
} else {
return bddCreate(b1Bdd.atom(),
bddIntersect(
bddUnion(b1Bdd.left(), b1Bdd.middle()),
bddUnion(b2Bdd.left(), b2Bdd.middle())),
bddIntersectWithMemo(memo,
bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()),
bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())),
BddAllOrNothing.bddNothing(),
bddIntersect(
bddUnion(b1Bdd.right(), b1Bdd.middle()),
bddUnion(b2Bdd.right(), b2Bdd.middle())));
bddIntersectWithMemo(memo,
bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()),
bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle())));
}
}
}

public static Bdd bddDiff(Bdd b1, Bdd b2) {
return bddDiffWithMemo(BddOpMemo.create(), b1, b2);
}

private static Bdd bddDiffWithMemo(BddOpMemo memo, Bdd b1, Bdd b2) {
BddOpMemoKey key = new BddOpMemoKey(b1, b2);
Bdd memoized = memo.diffMemo.get(key);
if (memoized != null) {
return memoized;
}
memoized = bddDiffInner(memo, b1, b2);
memo.diffMemo.put(key, memoized);
return memoized;
}

private static Bdd bddDiffInner(BddOpMemo memo, Bdd b1, Bdd b2) {
if (b1 == b2) {
return BddAllOrNothing.bddNothing();
} else if (b2 instanceof BddAllOrNothing allOrNothing) {
Expand All @@ -114,25 +162,27 @@ public static Bdd bddDiff(Bdd b1, Bdd b2) {
long cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom());
if (cmp < 0L) {
return bddCreate(b1Bdd.atom(),
bddDiff(bddUnion(b1Bdd.left(), b1Bdd.middle()), b2),
bddDiffWithMemo(memo, bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()), b2),
BddAllOrNothing.bddNothing(),
bddDiff(bddUnion(b1Bdd.right(), b1Bdd.middle()), b2));
bddDiffWithMemo(memo, bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()), b2));
} else if (cmp > 0L) {
return bddCreate(b2Bdd.atom(),
bddDiff(b1, bddUnion(b2Bdd.left(), b2Bdd.middle())),
bddDiffWithMemo(memo, b1, bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())),
BddAllOrNothing.bddNothing(),
bddDiff(b1, bddUnion(b2Bdd.right(), b2Bdd.middle())));
bddDiffWithMemo(memo, b1, bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle())));
} else {
// There is an error in the Castagna paper for this formula.
// The union needs to be materialized here.
// The original formula does not work in a case like (a0|a1) - a0.
// Castagna confirms that the following formula is the correct one.
return bddCreate(b1Bdd.atom(),
bddDiff(bddUnion(b1Bdd.left(), b1Bdd.middle()),
bddUnion(b2Bdd.left(), b2Bdd.middle())),
bddDiffWithMemo(memo,
bddUnionWithMemo(memo, b1Bdd.left(), b1Bdd.middle()),
bddUnionWithMemo(memo, b2Bdd.left(), b2Bdd.middle())),
BddAllOrNothing.bddNothing(),
bddDiff(bddUnion(b1Bdd.right(), b1Bdd.middle()),
bddUnion(b2Bdd.right(), b2Bdd.middle())));
bddDiffWithMemo(memo,
bddUnionWithMemo(memo, b1Bdd.right(), b1Bdd.middle()),
bddUnionWithMemo(memo, b2Bdd.right(), b2Bdd.middle())));
}
}
}
Expand Down Expand Up @@ -222,4 +272,16 @@ public static String bddToString(Bdd b, boolean inner) {
return str;
}
}

private record BddOpMemoKey(Bdd b1, Bdd b2) {

}

private record BddOpMemo(Map<BddOpMemoKey, Bdd> unionMemo, Map<BddOpMemoKey, Bdd> intersectionMemo,
Map<BddOpMemoKey, Bdd> diffMemo) {

static BddOpMemo create() {
return new BddOpMemo(new HashMap<>(), new HashMap<>(), new HashMap<>());
}
}
}
Loading