Skip to content

Commit

Permalink
Add optimization to runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
heshanpadmasiri committed Dec 5, 2024
1 parent 972fef9 commit 1379b4b
Showing 1 changed file with 86 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import io.ballerina.runtime.internal.types.semtype.SubTypeData;

import java.util.HashMap;
import java.util.Map;

import static io.ballerina.runtime.api.types.semtype.Conjunction.and;

/**
Expand All @@ -38,10 +41,20 @@ public abstract sealed class Bdd extends SubType implements SubTypeData permits

@Override
public SubType union(SubType other) {
return bddUnion((Bdd) other);
return bddUnion(BddOpMemo.create(), (Bdd) other);
}

private Bdd bddUnion(BddOpMemo memoTable, Bdd other) {
BddOpMemoKey key = new BddOpMemoKey(this, other);
Bdd memoized = memoTable.unionMemo().get(key);
if (memoized == null) {
memoized = bddUnionInner(memoTable, other);
memoTable.unionMemo().put(key, memoized);
}
return memoized;
}

private Bdd bddUnion(Bdd other) {
private Bdd bddUnionInner(BddOpMemo memo, Bdd other) {
if (other == this) {
return this;
} else if (this == BddAllOrNothing.ALL || other == BddAllOrNothing.ALL) {
Expand All @@ -57,18 +70,18 @@ private Bdd bddUnion(Bdd other) {
if (cmp < 0) {
return bddCreate(b1Bdd.atom(),
b1Bdd.left(),
b1Bdd.middle().bddUnion(other),
b1Bdd.middle().bddUnion(memo, other),
b1Bdd.right());
} else if (cmp > 0) {
return bddCreate(b2Bdd.atom(),
b2Bdd.left(),
this.bddUnion(b2Bdd.middle()),
this.bddUnion(memo, b2Bdd.middle()),
b2Bdd.right());
} else {
return bddCreate(b1Bdd.atom(),
b1Bdd.left().bddUnion(b2Bdd.left()),
b1Bdd.middle().bddUnion(b2Bdd.middle()),
b1Bdd.right().bddUnion(b2Bdd.right()));
b1Bdd.left().bddUnion(memo, b2Bdd.left()),
b1Bdd.middle().bddUnion(memo, b2Bdd.middle()),
b1Bdd.right().bddUnion(memo, b2Bdd.right()));
}
}

Expand All @@ -88,10 +101,20 @@ private int atomCmp(Atom a1, Atom a2) {

@Override
public SubType intersect(SubType other) {
return bddIntersect((Bdd) other);
return bddIntersect(BddOpMemo.create(), (Bdd) other);
}

private Bdd bddIntersect(Bdd other) {
private Bdd bddIntersect(BddOpMemo memoTable, Bdd other) {
BddOpMemoKey key = new BddOpMemoKey(this, other);
Bdd memoized = memoTable.intersectionMemo().get(key);
if (memoized == null) {
memoized = bddIntersectInner(memoTable, other);
memoTable.intersectionMemo().put(key, memoized);
}
return memoized;
}

private Bdd bddIntersectInner(BddOpMemo memo, Bdd other) {
if (other == this) {
return this;
} else if (this == BddAllOrNothing.NOTHING || other == BddAllOrNothing.NOTHING) {
Expand All @@ -106,66 +129,80 @@ private Bdd bddIntersect(Bdd other) {
int cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom());
if (cmp < 0) {
return bddCreate(b1Bdd.atom(),
b1Bdd.left().bddIntersect(other),
b1Bdd.middle().bddIntersect(other),
b1Bdd.right().bddIntersect(other));
b1Bdd.left().bddIntersect(memo, other),
b1Bdd.middle().bddIntersect(memo, other),
b1Bdd.right().bddIntersect(memo, other));
} else if (cmp > 0) {
return bddCreate(b2Bdd.atom(),
this.bddIntersect(b2Bdd.left()),
this.bddIntersect(b2Bdd.middle()),
this.bddIntersect(b2Bdd.right()));
this.bddIntersect(memo, b2Bdd.left()),
this.bddIntersect(memo, b2Bdd.middle()),
this.bddIntersect(memo, b2Bdd.right()));
} else {
return bddCreate(b1Bdd.atom(),
b1Bdd.left().bddUnion(b1Bdd.middle()).bddIntersect(b2Bdd.left().bddUnion(b2Bdd.middle())),
b1Bdd.left().bddUnion(memo, b1Bdd.middle())
.bddIntersect(memo, b2Bdd.left().bddUnion(memo, b2Bdd.middle())),
BddAllOrNothing.NOTHING,
b1Bdd.right().bddUnion(b1Bdd.middle()).bddIntersect(b2Bdd.right().bddUnion(b2Bdd.middle())));
b1Bdd.right().bddUnion(memo, b1Bdd.middle())
.bddIntersect(memo, b2Bdd.right().bddUnion(memo, b2Bdd.middle())));
}
}

@Override
public SubType diff(SubType other) {
return bddDiff((Bdd) other);
return bddDiff(BddOpMemo.create(), (Bdd) other);
}

private Bdd bddDiff(BddOpMemo memoTable, Bdd other) {
var key = new BddOpMemoKey(this, other);
var memoized = memoTable.diffMemo.get(key);
if (memoized == null) {
memoized = bddDiffInner(memoTable, other);
memoTable.diffMemo.put(key, memoized);
}
return memoized;
}

private Bdd bddDiff(Bdd other) {
private Bdd bddDiffInner(BddOpMemo memo, Bdd other) {
if (this == other || other == BddAllOrNothing.ALL || this == BddAllOrNothing.NOTHING) {
return BddAllOrNothing.NOTHING;
} else if (other == BddAllOrNothing.NOTHING) {
return this;
} else if (this == BddAllOrNothing.ALL) {
return other.bddComplement();
return other.bddComplement(memo);
}
BddNode b1Bdd = (BddNode) this;
BddNode b2Bdd = (BddNode) other;
int cmp = atomCmp(b1Bdd.atom(), b2Bdd.atom());
if (cmp < 0L) {
return bddCreate(b1Bdd.atom(),
b1Bdd.left().bddUnion(b1Bdd.middle()).bddDiff(other),
b1Bdd.left().bddUnion(memo, b1Bdd.middle()).bddDiff(memo, other),
BddAllOrNothing.NOTHING,
b1Bdd.right().bddUnion(b1Bdd.middle()).bddDiff(other));
b1Bdd.right().bddUnion(memo, b1Bdd.middle()).bddDiff(memo, other));
} else if (cmp > 0L) {
return bddCreate(b2Bdd.atom(),
this.bddDiff(b2Bdd.left().bddUnion(b2Bdd.middle())),
this.bddDiff(memo, b2Bdd.left().bddUnion(memo, b2Bdd.middle())),
BddAllOrNothing.NOTHING,
this.bddDiff(b2Bdd.right().bddUnion(b2Bdd.middle())));
this.bddDiff(memo, b2Bdd.right().bddUnion(memo, b2Bdd.middle())));
} else {
// There is an error in the Castagna paper for this formula.
// The union needs to be materialized here.
// The original formula does not work in a case like (a0|a1) - a0.
// Castagna confirms that the following formula is the correct one.
return bddCreate(b1Bdd.atom(),
b1Bdd.left().bddUnion(b1Bdd.middle()).bddDiff(b2Bdd.left().bddUnion(b2Bdd.middle())),
b1Bdd.left().bddUnion(memo, b1Bdd.middle())
.bddDiff(memo, b2Bdd.left().bddUnion(memo, b2Bdd.middle())),
BddAllOrNothing.NOTHING,
b1Bdd.right().bddUnion(b1Bdd.middle()).bddDiff(b2Bdd.right().bddUnion(b2Bdd.middle())));
b1Bdd.right().bddUnion(memo, b1Bdd.middle())
.bddDiff(memo, b2Bdd.right().bddUnion(memo, b2Bdd.middle())));
}
}

@Override
public SubType complement() {
return bddComplement();
return bddComplement(BddOpMemo.create());
}

private Bdd bddComplement() {
private Bdd bddComplement(BddOpMemo memo) {
if (this == BddAllOrNothing.ALL) {
return BddAllOrNothing.NOTHING;
} else if (this == BddAllOrNothing.NOTHING) {
Expand All @@ -176,26 +213,26 @@ private Bdd bddComplement() {
if (b.right() == nothing) {
return bddCreate(b.atom(),
nothing,
b.left().bddUnion(b.middle()).bddComplement(),
b.middle().bddComplement());
b.left().bddUnion(memo, b.middle()).bddComplement(memo),
b.middle().bddComplement(memo));
} else if (b.left() == nothing) {
return bddCreate(b.atom(),
b.middle().bddComplement(),
b.right().bddUnion(b.middle()).bddComplement(),
b.middle().bddComplement(memo),
b.right().bddUnion(memo, b.middle()).bddComplement(memo),

Check warning on line 221 in bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java

View check run for this annotation

Codecov / codecov/patch

bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java#L220-L221

Added lines #L220 - L221 were not covered by tests
nothing);
} else if (b.middle() == nothing) {
return bddCreate(b.atom(),
b.left().bddComplement(),
b.left().bddUnion(b.right()).bddComplement(),
b.right().bddComplement());
b.left().bddComplement(memo),
b.left().bddUnion(memo, b.right()).bddComplement(memo),
b.right().bddComplement(memo));

Check warning on line 227 in bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java

View check run for this annotation

Codecov / codecov/patch

bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java#L225-L227

Added lines #L225 - L227 were not covered by tests
} else {
// There is a typo in the Frisch PhD thesis for this formula.
// (It has left and right swapped.)
// Castagna (the PhD supervisor) confirms that this is the correct formula.
return bddCreate(b.atom(),
b.left().bddUnion(b.middle()).bddComplement(),
b.left().bddUnion(memo, b.middle()).bddComplement(memo),

Check warning on line 233 in bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java

View check run for this annotation

Codecov / codecov/patch

bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java#L233

Added line #L233 was not covered by tests
nothing,
b.right().bddUnion(b.middle()).bddComplement());
b.right().bddUnion(memo, b.middle()).bddComplement(memo));

Check warning on line 235 in bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java

View check run for this annotation

Codecov / codecov/patch

bvm/ballerina-runtime/src/main/java/io/ballerina/runtime/api/types/semtype/Bdd.java#L235

Added line #L235 was not covered by tests
}
}

Expand All @@ -204,7 +241,7 @@ private Bdd bddCreate(Atom atom, Bdd left, Bdd middle, Bdd right) {
return middle;
}
if (left.equals(right)) {
return left.bddUnion(right);
return left.bddUnion(BddOpMemo.create(), right);
}

return new BddNodeImpl(atom, left, middle, right);
Expand Down Expand Up @@ -282,4 +319,15 @@ public static String bddToString(Bdd b, boolean inner) {
}
}

private record BddOpMemoKey(Bdd b1, Bdd b2) {

}

private record BddOpMemo(Map<BddOpMemoKey, Bdd> unionMemo, Map<BddOpMemoKey, Bdd> intersectionMemo,
Map<BddOpMemoKey, Bdd> diffMemo) {

static BddOpMemo create() {
return new BddOpMemo(new HashMap<>(), new HashMap<>(), new HashMap<>());
}
}
}

0 comments on commit 1379b4b

Please sign in to comment.