diff --git a/linq4j/build.gradle.kts b/linq4j/build.gradle.kts index 7a9db175f24e..05f7e1965dd3 100644 --- a/linq4j/build.gradle.kts +++ b/linq4j/build.gradle.kts @@ -19,4 +19,5 @@ dependencies { implementation("com.google.guava:guava") implementation("org.apache.calcite.avatica:avatica-core") + implementation("org.checkerframework:checker-qual") } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/BaseQueryable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/BaseQueryable.java index fe4a63506224..6099ce7dd1e7 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/BaseQueryable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/BaseQueryable.java @@ -18,6 +18,8 @@ import org.apache.calcite.linq4j.tree.Expression; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.Iterator; @@ -34,10 +36,10 @@ public abstract class BaseQueryable extends AbstractQueryable { protected final QueryProvider provider; protected final Type elementType; - protected final Expression expression; + protected final @Nullable Expression expression; public BaseQueryable(QueryProvider provider, Type elementType, - Expression expression) { + @Nullable Expression expression) { this.provider = provider; this.elementType = elementType; this.expression = expression; @@ -51,7 +53,7 @@ public Type getElementType() { return elementType; } - public Expression getExpression() { + public @Nullable Expression getExpression() { return expression; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java index 3cc4c9ed0c8a..42a46962dba6 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultEnumerable.java @@ -33,6 +33,9 @@ import org.apache.calcite.linq4j.function.Predicate1; import org.apache.calcite.linq4j.function.Predicate2; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; + import java.math.BigDecimal; import java.util.Collection; import java.util.Comparator; @@ -69,7 +72,7 @@ protected OrderedEnumerable getThisOrdered() { return this; } - public R foreach(Function1 func) { + public @Nullable R foreach(Function1 func) { R result = null; try (Enumerator enumerator = enumerator()) { while (enumerator.moveNext()) { @@ -89,12 +92,12 @@ protected OrderedQueryable asOrderedQueryable() { return EnumerableDefaults.asOrderedQueryable(this); } - public T aggregate(Function2 func) { + public @Nullable T aggregate(Function2<@Nullable T, T, T> func) { return EnumerableDefaults.aggregate(getThis(), func); } - public TAccumulate aggregate(TAccumulate seed, - Function2 func) { + public @PolyNull TAccumulate aggregate(@PolyNull TAccumulate seed, + Function2<@PolyNull TAccumulate, T, @PolyNull TAccumulate> func) { return EnumerableDefaults.aggregate(getThis(), seed, func); } @@ -191,11 +194,11 @@ public OrderedEnumerable createOrderedEnumerable( keySelector, comparator, descending); } - public Enumerable defaultIfEmpty() { + public Enumerable<@Nullable T> defaultIfEmpty() { return EnumerableDefaults.defaultIfEmpty(getThis()); } - public Enumerable defaultIfEmpty(T value) { + public Enumerable<@PolyNull T> defaultIfEmpty(@PolyNull T value) { return EnumerableDefaults.defaultIfEmpty(getThis(), value); } @@ -211,7 +214,7 @@ public T elementAt(int index) { return EnumerableDefaults.elementAt(getThis(), index); } - public T elementAtOrDefault(int index) { + public @Nullable T elementAtOrDefault(int index) { return EnumerableDefaults.elementAtOrDefault(getThis(), index); } @@ -241,11 +244,11 @@ public T first(Predicate1 predicate) { return EnumerableDefaults.first(getThis(), predicate); } - public T firstOrDefault() { + public @Nullable T firstOrDefault() { return EnumerableDefaults.firstOrDefault(getThis()); } - public T firstOrDefault(Predicate1 predicate) { + public @Nullable T firstOrDefault(Predicate1 predicate) { return EnumerableDefaults.firstOrDefault(getThis(), predicate); } @@ -428,11 +431,11 @@ public T last(Predicate1 predicate) { return EnumerableDefaults.last(getThis(), predicate); } - public T lastOrDefault() { + public @Nullable T lastOrDefault() { return EnumerableDefaults.lastOrDefault(getThis()); } - public T lastOrDefault(Predicate1 predicate) { + public @Nullable T lastOrDefault(Predicate1 predicate) { return EnumerableDefaults.lastOrDefault(getThis(), predicate); } @@ -621,11 +624,11 @@ public T single(Predicate1 predicate) { return EnumerableDefaults.single(getThis(), predicate); } - public T singleOrDefault() { + public @Nullable T singleOrDefault() { return EnumerableDefaults.singleOrDefault(getThis()); } - public T singleOrDefault(Predicate1 predicate) { + public @Nullable T singleOrDefault(Predicate1 predicate) { return EnumerableDefaults.singleOrDefault(getThis(), predicate); } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultQueryable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultQueryable.java index 4d09d433ad4d..0e474c8e5d4c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultQueryable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/DefaultQueryable.java @@ -33,6 +33,8 @@ import org.apache.calcite.linq4j.function.Predicate2; import org.apache.calcite.linq4j.tree.FunctionExpression; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.math.BigDecimal; import java.util.Comparator; @@ -148,7 +150,7 @@ public Enumerator enumerator() { return factory.ofType(getThis(), clazz); } - @Override public Queryable defaultIfEmpty() { + @Override public Queryable<@Nullable T> defaultIfEmpty() { return factory.defaultIfEmpty(getThis()); } @@ -162,7 +164,7 @@ public Enumerator enumerator() { // End disambiguate - public T aggregate(FunctionExpression> selector) { + public @Nullable T aggregate(FunctionExpression> selector) { return factory.aggregate(getThis(), selector); } @@ -243,7 +245,7 @@ public T first(FunctionExpression> predicate) { return factory.first(getThis(), predicate); } - public T firstOrDefault(FunctionExpression> predicate) { + public @Nullable T firstOrDefault(FunctionExpression> predicate) { return factory.firstOrDefault(getThis(), predicate); } @@ -421,7 +423,7 @@ public T single(FunctionExpression> predicate) { return factory.single(getThis(), predicate); } - public T singleOrDefault(FunctionExpression> predicate) { + public @Nullable T singleOrDefault(FunctionExpression> predicate) { return factory.singleOrDefault(getThis(), predicate); } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/Enumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/Enumerable.java index 181d0358ee6c..19d90242b0d5 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/Enumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/Enumerable.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j; +import org.checkerframework.framework.qual.Covariant; + /** * Exposes the enumerator, which supports a simple iteration over a collection. * @@ -26,6 +28,7 @@ * * @param Element type */ +@Covariant(0) public interface Enumerable extends RawEnumerable, Iterable, ExtendedEnumerable { /** diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java index 512130d1b75b..4c811aed5631 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableDefaults.java @@ -41,6 +41,11 @@ import com.google.common.collect.Sets; import org.apiguardian.api.API; +import org.checkerframework.checker.nullness.qual.KeyFor; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; +import org.checkerframework.dataflow.qual.Pure; +import org.checkerframework.framework.qual.HasQualifierParameter; import java.math.BigDecimal; import java.util.AbstractList; @@ -67,6 +72,8 @@ import static org.apache.calcite.linq4j.Linq4j.ListEnumerable; import static org.apache.calcite.linq4j.function.Functions.adapt; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Default implementations of methods in the {@link Enumerable} interface. */ @@ -75,10 +82,13 @@ public abstract class EnumerableDefaults { /** * Applies an accumulator function over a sequence. */ - public static TSource aggregate(Enumerable source, - Function2 func) { - TSource result = null; + public static @Nullable TSource aggregate(Enumerable source, + Function2<@Nullable TSource, TSource, TSource> func) { try (Enumerator os = source.enumerator()) { + if (!os.moveNext()) { + return null; + } + TSource result = os.current(); while (os.moveNext()) { TSource o = os.current(); result = func.apply(result, o); @@ -347,7 +357,7 @@ public static boolean contains(Enumerable enumerable, try (Enumerator os = enumerable.enumerator()) { while (os.moveNext()) { TSource o = os.current(); - if (o.equals(element)) { + if (Objects.equals(o, element)) { return true; } } @@ -391,7 +401,7 @@ public static int count(Enumerable enumerable, * the type parameter's default value in a singleton collection if * the sequence is empty. */ - public static Enumerable defaultIfEmpty( + public static Enumerable<@Nullable TSource> defaultIfEmpty( Enumerable enumerable) { return defaultIfEmpty(enumerable, null); } @@ -401,24 +411,25 @@ public static Enumerable defaultIfEmpty( * the specified value in a singleton collection if the sequence * is empty. */ - public static Enumerable defaultIfEmpty( + @SuppressWarnings("return.type.incompatible") + public static Enumerable<@PolyNull TSource> defaultIfEmpty( Enumerable enumerable, - TSource value) { + @PolyNull TSource value) { try (Enumerator os = enumerable.enumerator()) { if (os.moveNext()) { - return Linq4j.asEnumerable(() -> new Iterator() { + return Linq4j.asEnumerable(() -> new Iterator() { private boolean nonFirst; - private Iterator rest; + private @Nullable Iterator rest; public boolean hasNext() { - return !nonFirst || rest.hasNext(); + return !nonFirst || castNonNull(rest).hasNext(); } public TSource next() { if (nonFirst) { - return rest.next(); + return castNonNull(rest).next(); } else { final TSource first = os.current(); nonFirst = true; @@ -501,7 +512,7 @@ public static TSource elementAt(Enumerable enumerable, * sequence or a default value if the index is out of * range. */ - public static TSource elementAtOrDefault( + public static @Nullable TSource elementAtOrDefault( Enumerable enumerable, int index) { final ListEnumerable list = enumerable instanceof ListEnumerable ? ((ListEnumerable) enumerable) @@ -552,7 +563,8 @@ public static Enumerable except( try (Enumerator os = source1.enumerator()) { while (os.moveNext()) { TSource o = os.current(); - collection.remove(o); + @SuppressWarnings("argument.type.incompatible") + boolean unused = collection.remove(o); } return Linq4j.asEnumerable(collection); } @@ -625,7 +637,7 @@ public static TSource first(Enumerable enumerable, * Returns the first element of a sequence, or a * default value if the sequence contains no elements. */ - public static TSource firstOrDefault( + public static @Nullable TSource firstOrDefault( Enumerable enumerable) { try (Enumerator os = enumerable.enumerator()) { if (os.moveNext()) { @@ -640,7 +652,7 @@ public static TSource firstOrDefault( * satisfies a condition or a default value if no such element is * found. */ - public static TSource firstOrDefault(Enumerable enumerable, + public static @Nullable TSource firstOrDefault(Enumerable enumerable, Predicate1 predicate) { for (TSource o : enumerable) { if (predicate.apply(o)) { @@ -860,9 +872,9 @@ private static class SortedAggregateEnumerator comparator; private boolean isInitialized; private boolean isLastMoveNextFalse; - private TAccumulate curAccumulator; + private @Nullable TAccumulate curAccumulator; private Enumerator enumerator; - private TResult curResult; + private @Nullable TResult curResult; SortedAggregateEnumerator( Enumerable enumerable, @@ -888,7 +900,7 @@ private static class SortedAggregateEnumerator Enumerable groupBy while (os.moveNext()) { TSource o = os.current(); TKey key = keySelector.apply(o); + @SuppressWarnings("argument.type.incompatible") TAccumulate accumulator = map.get(key); if (accumulator == null) { accumulator = accumulatorInitializer.apply(); @@ -989,6 +1002,7 @@ private static Enumerable groupBy for (Function1 keySelector : keySelectors) { TSource o = os.current(); TKey key = keySelector.apply(o); + @SuppressWarnings("argument.type.incompatible") TAccumulate accumulator = map.get(key); if (accumulator == null) { accumulator = accumulatorInitializer.apply(); @@ -1041,6 +1055,7 @@ public Enumerator enumerator() { return new Enumerator() { public TResult current() { final Map.Entry entry = entries.current(); + @SuppressWarnings("argument.type.incompatible") final Enumerable inners = innerLookup.get(entry.getKey()); return resultSelector.apply(entry.getValue(), inners == null ? Linq4j.emptyEnumerable() : inners); @@ -1082,6 +1097,7 @@ public Enumerator enumerator() { return new Enumerator() { public TResult current() { final Map.Entry entry = entries.current(); + @SuppressWarnings("argument.type.incompatible") final Enumerable inners = innerLookup.get(entry.getKey()); return resultSelector.apply(entry.getValue(), inners == null ? Linq4j.emptyEnumerable() : inners); @@ -1126,7 +1142,9 @@ public static Enumerable intersect( try (Enumerator os = source0.enumerator()) { while (os.moveNext()) { TSource o = os.current(); - if (set1.remove(o)) { + @SuppressWarnings("argument.type.incompatible") + boolean removed = set1.remove(o); + if (removed) { resultCollection.add(o); } } @@ -1225,7 +1243,7 @@ public static Enumerable hashJoin( Function1 outerKeySelector, Function1 innerKeySelector, Function2 resultSelector, - EqualityComparer comparer, boolean generateNullsOnLeft, + @Nullable EqualityComparer comparer, boolean generateNullsOnLeft, boolean generateNullsOnRight) { return hashEquiJoin_( outer, @@ -1248,7 +1266,7 @@ public static Enumerable hashJoin( Function1 outerKeySelector, Function1 innerKeySelector, Function2 resultSelector, - EqualityComparer comparer, boolean generateNullsOnLeft, + @Nullable EqualityComparer comparer, boolean generateNullsOnLeft, boolean generateNullsOnRight, Predicate2 predicate) { if (predicate == null) { @@ -1281,7 +1299,8 @@ private static Enumerable hashEquiJoin final Function1 outerKeySelector, final Function1 innerKeySelector, final Function2 resultSelector, - final EqualityComparer comparer, final boolean generateNullsOnLeft, + final @Nullable EqualityComparer comparer, + final boolean generateNullsOnLeft, final boolean generateNullsOnRight) { return new AbstractEnumerable() { public Enumerator enumerator() { @@ -1293,7 +1312,7 @@ public Enumerator enumerator() { return new Enumerator() { Enumerator outers = outer.enumerator(); Enumerator inners = Linq4j.emptyEnumerator(); - Set unmatchedKeys = + @Nullable Set unmatchedKeys = generateNullsOnLeft ? new HashSet<>(innerLookup.keySet()) : null; @@ -1314,7 +1333,9 @@ public boolean moveNext() { // not the left. List list = new ArrayList<>(); for (TKey key : unmatchedKeys) { - for (TInner tInner : innerLookup.get(key)) { + @SuppressWarnings("argument.type.incompatible") + Enumerable innerValues = castNonNull(innerLookup.get(key)); + for (TInner tInner : innerValues) { list.add(tInner); } } @@ -1374,7 +1395,8 @@ private static Enumerable hashJoinWith final Function1 outerKeySelector, final Function1 innerKeySelector, final Function2 resultSelector, - final EqualityComparer comparer, final boolean generateNullsOnLeft, + final @Nullable EqualityComparer comparer, + final boolean generateNullsOnLeft, final boolean generateNullsOnRight, final Predicate2 predicate) { return new AbstractEnumerable() { @@ -1397,7 +1419,7 @@ public Enumerator enumerator() { return new Enumerator() { Enumerator outers = outer.enumerator(); Enumerator inners = Linq4j.emptyEnumerator(); - List innersUnmatched = + @Nullable List innersUnmatched = generateNullsOnLeft ? new ArrayList<>(innerToLookUp.toList()) : null; @@ -1483,7 +1505,7 @@ public void close() { public static Enumerable correlateJoin( final JoinType joinType, final Enumerable outer, final Function1> inner, - final Function2 resultSelector) { + final Function2 resultSelector) { if (joinType == JoinType.RIGHT || joinType == JoinType.FULL) { throw new IllegalArgumentException("JoinType " + joinType + " is not valid for correlation"); } @@ -1491,14 +1513,14 @@ public static Enumerable correlateJoin( return new AbstractEnumerable() { public Enumerator enumerator() { return new Enumerator() { - private Enumerator outerEnumerator = outer.enumerator(); - private Enumerator innerEnumerator; - TSource outerValue; - TInner innerValue; + private final Enumerator outerEnumerator = outer.enumerator(); + private @Nullable Enumerator innerEnumerator; + @Nullable TSource outerValue; + @Nullable TInner innerValue; int state = 0; // 0 -- moving outer, 1 moving inner; public TResult current() { - return resultSelector.apply(outerValue, innerValue); + return resultSelector.apply(castNonNull(outerValue), innerValue); } public boolean moveNext() { @@ -1544,6 +1566,7 @@ public boolean moveNext() { continue; case 1: // subsequent move inner + Enumerator innerEnumerator = castNonNull(this.innerEnumerator); if (innerEnumerator.moveNext()) { innerValue = innerEnumerator.current(); return true; @@ -1635,18 +1658,19 @@ public static Enumerable correlateBatchJoin( return new AbstractEnumerable() { @Override public Enumerator enumerator() { return new Enumerator() { - Enumerator outerEnumerator = outer.enumerator(); - List outerValues = new ArrayList<>(batchSize); - List innerValues = new ArrayList<>(); - TSource outerValue; - TInner innerValue; - Enumerable innerEnumerable; - Enumerator innerEnumerator; + final Enumerator outerEnumerator = outer.enumerator(); + final List outerValues = new ArrayList<>(batchSize); + final List innerValues = new ArrayList<>(); + @Nullable TSource outerValue; + @Nullable TInner innerValue; + @Nullable Enumerable innerEnumerable; + @Nullable Enumerator innerEnumerator; boolean innerEnumHasNext = false; boolean atLeastOneResult = false; int i = -1; // outer position int j = -1; // inner position + @SuppressWarnings("argument.type.incompatible") @Override public TResult current() { return resultSelector.apply(outerValue, innerValue); } @@ -1694,7 +1718,7 @@ public static Enumerable correlateBatchJoin( outerValue = outerValues.get(i); // get current outer value nextInnerValue(); // Compare current block row to current inner value - if (predicate.apply(outerValue, innerValue)) { + if (predicate.apply(castNonNull(outerValue), castNonNull(innerValue))) { atLeastOneResult = true; // Skip the rest of inner values in case of // ANTI and SEMI when a match is found @@ -1702,6 +1726,7 @@ public static Enumerable correlateBatchJoin( // Two ways of skipping inner values, // enumerator way and ArrayList way if (i == 0) { + Enumerator innerEnumerator = castNonNull(this.innerEnumerator); while (innerEnumHasNext) { innerValues.add(innerEnumerator.current()); innerEnumHasNext = innerEnumerator.moveNext(); @@ -1737,6 +1762,7 @@ public void nextOuterValue() { private void nextInnerValue() { if (i == 0) { + Enumerator innerEnumerator = castNonNull(this.innerEnumerator); innerValue = innerEnumerator.current(); innerValues.add(innerValue); innerEnumHasNext = innerEnumerator.moveNext(); // next enumerator inner value @@ -1883,10 +1909,11 @@ public Enumerator enumerator() { final Predicate1 predicate = v0 -> { TKey key = outerKeySelector.apply(v0); - if (!innerLookup.get().containsKey(key)) { + @SuppressWarnings("argument.type.incompatible") + Enumerable innersOfKey = innerLookup.get().get(key); + if (innersOfKey == null) { return anti; } - Enumerable innersOfKey = innerLookup.get().get(key); try (Enumerator os = innersOfKey.enumerator()) { while (os.moveNext()) { TInner v1 = os.current(); @@ -1911,7 +1938,7 @@ private static Enumerable semiEquiJoin_( final Enumerable outer, final Enumerable inner, final Function1 outerKeySelector, final Function1 innerKeySelector, - final EqualityComparer comparer, + final @Nullable EqualityComparer comparer, final boolean anti) { return new AbstractEnumerable() { public Enumerator enumerator() { @@ -1937,7 +1964,7 @@ public Enumerator enumerator() { public static Enumerable nestedLoopJoin( final Enumerable outer, final Enumerable inner, final Predicate2 predicate, - Function2 resultSelector, + Function2 resultSelector, final JoinType joinType) { if (!joinType.generatesNullsOnLeft()) { return nestedLoopJoinOptimized(outer, inner, predicate, resultSelector, joinType); @@ -1952,7 +1979,7 @@ public static Enumerable nestedLoopJoin( private static Enumerable nestedLoopJoinAsList( final Enumerable outer, final Enumerable inner, final Predicate2 predicate, - Function2 resultSelector, + Function2 resultSelector, final JoinType joinType) { final boolean generateNullsOnLeft = joinType.generatesNullsOnLeft(); final boolean generateNullsOnRight = joinType.generatesNullsOnRight(); @@ -1976,7 +2003,8 @@ private static Enumerable nestedLoopJoinAsLi break; } else { if (rightUnmatched != null) { - rightUnmatched.remove(right); + @SuppressWarnings("argument.type.incompatible") + boolean unused = rightUnmatched.remove(right); } result.add(resultSelector.apply(left, right)); if (joinType == JoinType.SEMI) { @@ -2007,7 +2035,7 @@ private static Enumerable nestedLoopJoinAsLi private static Enumerable nestedLoopJoinOptimized( final Enumerable outer, final Enumerable inner, final Predicate2 predicate, - Function2 resultSelector, + Function2 resultSelector, final JoinType joinType) { if (joinType == JoinType.RIGHT || joinType == JoinType.FULL) { throw new IllegalArgumentException("JoinType " + joinType + " is unsupported"); @@ -2017,14 +2045,14 @@ private static Enumerable nestedLoopJoinOpti public Enumerator enumerator() { return new Enumerator() { private Enumerator outerEnumerator = outer.enumerator(); - private Enumerator innerEnumerator = null; + private @Nullable Enumerator innerEnumerator = null; private boolean outerMatch = false; // whether the outerValue has matched an innerValue - private TSource outerValue; - private TInner innerValue; + private @Nullable TSource outerValue; + private @Nullable TInner innerValue; private int state = 0; // 0 moving outer, 1 moving inner @Override public TResult current() { - return resultSelector.apply(outerValue, innerValue); + return resultSelector.apply(castNonNull(outerValue), castNonNull(innerValue)); } @Override public boolean moveNext() { @@ -2043,9 +2071,11 @@ public Enumerator enumerator() { continue; case 1: // move inner + Enumerator innerEnumerator = castNonNull(this.innerEnumerator); if (innerEnumerator.moveNext()) { - innerValue = innerEnumerator.current(); - if (predicate.apply(outerValue, innerValue)) { + TInner innerValue = innerEnumerator.current(); + this.innerValue = innerValue; + if (predicate.apply(castNonNull(outerValue), innerValue)) { outerMatch = true; switch (joinType) { case ANTI: // try next outer row @@ -2106,7 +2136,7 @@ private void closeInner() { final Enumerable inner, final Function1 outerKeySelector, final Function1 innerKeySelector, - final Function2 resultSelector, + final Function2 resultSelector, boolean generateNullsOnLeft, boolean generateNullsOnRight) { if (generateNullsOnLeft) { @@ -2146,7 +2176,7 @@ public static boolean isMergeJoinSupported(JoinType joinType) { final Enumerable inner, final Function1 outerKeySelector, final Function1 innerKeySelector, - final Function2 resultSelector, + final Function2 resultSelector, final JoinType joinType, final Comparator comparator) { return mergeJoin(outer, inner, outerKeySelector, innerKeySelector, null, resultSelector, @@ -2175,10 +2205,10 @@ public static boolean isMergeJoinSupported(JoinType joinType) { final Enumerable inner, final Function1 outerKeySelector, final Function1 innerKeySelector, - final Predicate2 extraPredicate, - final Function2 resultSelector, + final @Nullable Predicate2 extraPredicate, + final Function2 resultSelector, final JoinType joinType, - final Comparator comparator) { + final @Nullable Comparator comparator) { if (!isMergeJoinSupported(joinType)) { throw new UnsupportedOperationException("MergeJoin unsupported for join type " + joinType); } @@ -2231,7 +2261,7 @@ public static TSource last(Enumerable enumerable, * Returns the last element of a sequence, or a * default value if the sequence contains no elements. */ - public static TSource lastOrDefault( + public static @Nullable TSource lastOrDefault( Enumerable enumerable) { final ListEnumerable list = enumerable instanceof ListEnumerable ? ((ListEnumerable) enumerable) @@ -2261,7 +2291,7 @@ public static TSource lastOrDefault( * satisfies a condition or a default value if no such element is * found. */ - public static TSource lastOrDefault(Enumerable enumerable, + public static @Nullable TSource lastOrDefault(Enumerable enumerable, Predicate1 predicate) { final ListEnumerable list = enumerable instanceof ListEnumerable ? ((ListEnumerable) enumerable) @@ -2330,8 +2360,7 @@ public static long longCount(Enumerable enumerable, */ public static > TSource max( Enumerable source) { - Function2 max = maxFunction(); - return aggregate(source, null, max); + return aggregate(source, maxFunction()); } /** @@ -2340,8 +2369,7 @@ public static > TSource max( */ public static BigDecimal max(Enumerable source, BigDecimalFunction1 selector) { - Function2 max = maxFunction(); - return aggregate(source.select(selector), null, max); + return aggregate(source.select(selector), maxFunction()); } /** @@ -2351,8 +2379,7 @@ public static BigDecimal max(Enumerable source, */ public static BigDecimal max(Enumerable source, NullableBigDecimalFunction1 selector) { - Function2 max = maxFunction(); - return aggregate(source.select(selector), null, max); + return aggregate(source.select(selector), maxFunction()); } /** @@ -2361,8 +2388,7 @@ public static BigDecimal max(Enumerable source, */ public static double max(Enumerable source, DoubleFunction1 selector) { - return aggregate(source.select(adapt(selector)), null, - Extensions.DOUBLE_MAX); + return castNonNull(aggregate(source.select(adapt(selector)), Extensions.DOUBLE_MAX)); } /** @@ -2372,7 +2398,7 @@ public static double max(Enumerable source, */ public static Double max(Enumerable source, NullableDoubleFunction1 selector) { - return aggregate(source.select(selector), null, Extensions.DOUBLE_MAX); + return aggregate(source.select(selector), Extensions.DOUBLE_MAX); } /** @@ -2381,8 +2407,7 @@ public static Double max(Enumerable source, */ public static int max(Enumerable source, IntegerFunction1 selector) { - return aggregate(source.select(adapt(selector)), null, - Extensions.INTEGER_MAX); + return castNonNull(aggregate(source.select(adapt(selector)), Extensions.INTEGER_MAX)); } /** @@ -2392,7 +2417,7 @@ public static int max(Enumerable source, */ public static Integer max(Enumerable source, NullableIntegerFunction1 selector) { - return aggregate(source.select(selector), null, Extensions.INTEGER_MAX); + return aggregate(source.select(selector), Extensions.INTEGER_MAX); } /** @@ -2401,7 +2426,7 @@ public static Integer max(Enumerable source, */ public static long max(Enumerable source, LongFunction1 selector) { - return aggregate(source.select(adapt(selector)), null, Extensions.LONG_MAX); + return castNonNull(aggregate(source.select(adapt(selector)), Extensions.LONG_MAX)); } /** @@ -2411,7 +2436,7 @@ public static long max(Enumerable source, */ public static Long max(Enumerable source, NullableLongFunction1 selector) { - return aggregate(source.select(selector), null, Extensions.LONG_MAX); + return aggregate(source.select(selector), Extensions.LONG_MAX); } /** @@ -2420,8 +2445,7 @@ public static Long max(Enumerable source, */ public static float max(Enumerable source, FloatFunction1 selector) { - return aggregate(source.select(adapt(selector)), null, - Extensions.FLOAT_MAX); + return castNonNull(aggregate(source.select(adapt(selector)), Extensions.FLOAT_MAX)); } /** @@ -2431,7 +2455,7 @@ public static float max(Enumerable source, */ public static Float max(Enumerable source, NullableFloatFunction1 selector) { - return aggregate(source.select(selector), null, Extensions.FLOAT_MAX); + return aggregate(source.select(selector), Extensions.FLOAT_MAX); } /** @@ -2441,8 +2465,7 @@ public static Float max(Enumerable source, */ public static > TResult max( Enumerable source, Function1 selector) { - Function2 max = maxFunction(); - return aggregate(source.select(selector), null, max); + return aggregate(source.select(selector), maxFunction()); } /** @@ -2451,8 +2474,7 @@ public static > TResult max( */ public static > TSource min( Enumerable source) { - Function2 min = minFunction(); - return aggregate(source, null, min); + return aggregate(source, minFunction()); } @SuppressWarnings("unchecked") @@ -2484,8 +2506,7 @@ public static BigDecimal min(Enumerable source, */ public static BigDecimal min(Enumerable source, NullableBigDecimalFunction1 selector) { - Function2 min = minFunction(); - return aggregate(source.select(selector), null, min); + return aggregate(source.select(selector), minFunction()); } /** @@ -2494,8 +2515,7 @@ public static BigDecimal min(Enumerable source, */ public static double min(Enumerable source, DoubleFunction1 selector) { - return aggregate(source.select(adapt(selector)), null, - Extensions.DOUBLE_MIN); + return castNonNull(aggregate(source.select(adapt(selector)), Extensions.DOUBLE_MIN)); } /** @@ -2505,7 +2525,7 @@ public static double min(Enumerable source, */ public static Double min(Enumerable source, NullableDoubleFunction1 selector) { - return aggregate(source.select(selector), null, Extensions.DOUBLE_MIN); + return aggregate(source.select(selector), Extensions.DOUBLE_MIN); } /** @@ -2514,8 +2534,7 @@ public static Double min(Enumerable source, */ public static int min(Enumerable source, IntegerFunction1 selector) { - return aggregate(source.select(adapt(selector)), null, - Extensions.INTEGER_MIN); + return castNonNull(aggregate(source.select(adapt(selector)), Extensions.INTEGER_MIN)); } /** @@ -2611,7 +2630,7 @@ public static Enumerable orderBy( */ public static Enumerable orderBy( Enumerable source, Function1 keySelector, - Comparator comparator) { + @Nullable Comparator comparator) { return new AbstractEnumerable() { @Override public Enumerator enumerator() { // NOTE: TreeMap allows null comparator. But the caller of this method @@ -2960,7 +2979,7 @@ public static boolean sequenceEqual(Enumerable first, * {@code EqualityComparer}. */ public static boolean sequenceEqual(Enumerable first, - Enumerable second, EqualityComparer comparer) { + Enumerable second, @Nullable EqualityComparer comparer) { Objects.requireNonNull(first); Objects.requireNonNull(second); if (comparer == null) { @@ -3052,7 +3071,7 @@ public static TSource single(Enumerable source, * exception if there is more than one element in the * sequence. */ - public static TSource singleOrDefault(Enumerable source) { + public static @Nullable TSource singleOrDefault(Enumerable source) { TSource toRet = null; try (Enumerator os = source.enumerator()) { if (os.moveNext()) { @@ -3073,7 +3092,7 @@ public static TSource singleOrDefault(Enumerable source) { * element exists; this method throws an exception if more than * one element satisfies the condition. */ - public static TSource singleOrDefault(Enumerable source, + public static @Nullable TSource singleOrDefault(Enumerable source, Predicate1 predicate) { TSource toRet = null; for (TSource s : source) { @@ -3469,6 +3488,7 @@ static LookupImpl toLookup_( while (os.moveNext()) { TSource o = os.current(); final TKey key = keySelector.apply(o); + @SuppressWarnings("nullness") List list = map.get(key); if (list == null) { // for first entry, use a singleton list to save space @@ -3667,7 +3687,7 @@ public static OrderedQueryable asOrderedQueryable( return source instanceof OrderedQueryable ? ((OrderedQueryable) source) : new EnumerableOrderedQueryable<>( - source, (Class) Object.class, null, null); + source, (Class) Object.class, castNonNull(null), null); } /** Default implementation of {@link ExtendedEnumerable#into(Collection)}. */ @@ -3825,12 +3845,15 @@ public void close() { /** Enumerator that casts each value. * - * @param element type */ - static class CastingEnumerator implements Enumerator { - private final Enumerator enumerator; + * @param source element type + * @param element type*/ + @HasQualifierParameter(Nullable.class) + static class CastingEnumerator + implements Enumerator { + private final Enumerator enumerator; private final Class clazz; - CastingEnumerator(Enumerator enumerator, Class clazz) { + CastingEnumerator(Enumerator enumerator, Class clazz) { this.enumerator = enumerator; this.clazz = clazz; } @@ -3872,7 +3895,7 @@ static Wrapped upAs(EqualityComparer comparer, T element) { return comparer.hashCode(element); } - @Override public boolean equals(Object obj) { + @Override public boolean equals(@Nullable Object obj) { //noinspection unchecked return obj == this || obj instanceof Wrapped && comparer.equal(element, ((Wrapped) obj).element); @@ -3896,8 +3919,9 @@ protected WrapMap(Function0, V>> mapProvider, EqualityComparer this.comparer = comparer; } - @Override public Set> entrySet() { - return new AbstractSet>() { + @Override public Set> entrySet() { + return new AbstractSet>() { + @SuppressWarnings("override.return.invalid") @Override public Iterator> iterator() { final Iterator, V>> iterator = map.entrySet().iterator(); @@ -3924,23 +3948,26 @@ public void remove() { }; } + @SuppressWarnings("contracts.conditional.postcondition.not.satisfied") @Override public boolean containsKey(Object key) { return map.containsKey(wrap((K) key)); } + @Pure private Wrapped wrap(K key) { return Wrapped.upAs(comparer, key); } - @Override public V get(Object key) { + @Override public @Nullable V get(Object key) { return map.get(wrap((K) key)); } - @Override public V put(K key, V value) { + @SuppressWarnings("contracts.postcondition.not.satisfied") + @Override public @Nullable V put(K key, V value) { return map.put(wrap(key), value); } - @Override public V remove(Object key) { + @Override public @Nullable V remove(Object key) { return map.remove(wrap((K) key)); } @@ -4003,30 +4030,31 @@ private static class MergeJoinEnumerator rights = new ArrayList<>(); private final Enumerable leftEnumerable; private final Enumerable rightEnumerable; - private Enumerator leftEnumerator = null; - private Enumerator rightEnumerator = null; + private @Nullable Enumerator leftEnumerator = null; + private @Nullable Enumerator rightEnumerator = null; private final Function1 outerKeySelector; private final Function1 innerKeySelector; // extra predicate in case of non equi-join, in case of equi-join it will be null - private final Predicate2 extraPredicate; - private final Function2 resultSelector; + private final @Nullable Predicate2 extraPredicate; + private final Function2 resultSelector; private final JoinType joinType; // key comparator, possibly null (Comparable#compareTo to be used in that case) - private final Comparator comparator; + private final @Nullable Comparator comparator; private boolean done; - private Enumerator results = null; + private @Nullable Enumerator results = null; // used for LEFT/ANTI join: if right input is over, all remaining elements from left are results private boolean remainingLeft; private TResult current = (TResult) DUMMY; + @SuppressWarnings("method.invocation.invalid") MergeJoinEnumerator(Enumerable leftEnumerable, Enumerable rightEnumerable, Function1 outerKeySelector, Function1 innerKeySelector, - Predicate2 extraPredicate, - Function2 resultSelector, + @Nullable Predicate2 extraPredicate, + Function2 resultSelector, JoinType joinType, - Comparator comparator) { + @Nullable Comparator comparator) { this.leftEnumerable = leftEnumerable; this.rightEnumerable = rightEnumerable; this.outerKeySelector = outerKeySelector; @@ -4119,9 +4147,9 @@ private int compareNullsLast(TKey v0, TKey v1) { * enumerator. */ private boolean advance() { for (;;) { - TSource left = leftEnumerator.current(); + TSource left = castNonNull(leftEnumerator).current(); TKey leftKey = outerKeySelector.apply(left); - TInner right = rightEnumerator.current(); + TInner right = castNonNull(rightEnumerator).current(); TKey rightKey = innerKeySelector.apply(right); // iterate until finding matching keys (or ANTI join results) for (;;) { @@ -4315,7 +4343,9 @@ public void reset() { results = null; current = (TResult) DUMMY; remainingLeft = false; - leftEnumerator.reset(); + if (leftEnumerator != null) { + leftEnumerator.reset(); + } if (rightEnumerator != null) { rightEnumerator.reset(); } @@ -4323,7 +4353,9 @@ public void reset() { } public void close() { - leftEnumerator.close(); + if (leftEnumerator != null) { + leftEnumerator.close(); + } if (rightEnumerator != null) { rightEnumerator.close(); } @@ -4337,10 +4369,10 @@ public void close() { * @param right input record type */ private static class CartesianProductJoinEnumerator extends CartesianProductEnumerator { - private final Function2 resultSelector; + private final Function2 resultSelector; @SuppressWarnings("unchecked") - CartesianProductJoinEnumerator(Function2 resultSelector, + CartesianProductJoinEnumerator(Function2 resultSelector, Enumerator outer, Enumerator inner) { super(ImmutableList.of((Enumerator) outer, (Enumerator) inner)); this.resultSelector = resultSelector; @@ -4384,7 +4416,7 @@ public static Enumerable repeatUnion( private boolean seedProcessed = false; private int currentIteration = 0; private final Enumerator seedEnumerator = seed.enumerator(); - private Enumerator iterativeEnumerator = null; + private @Nullable Enumerator iterativeEnumerator = null; // Set to control duplicates, only used if "all" is false private final Set> processed = new HashSet<>(); @@ -4434,8 +4466,9 @@ private boolean checkValue(TSource value) { return false; } + Enumerator iterativeEnumerator = this.iterativeEnumerator; if (iterativeEnumerator == null) { - iterativeEnumerator = iteration.enumerator(); + this.iterativeEnumerator = iterativeEnumerator = iteration.enumerator(); } while (iterativeEnumerator.moveNext()) { @@ -4454,7 +4487,7 @@ private boolean checkValue(TSource value) { // current iteration level (which returned some values) is finished, go to next one current = (TSource) DUMMY; iterativeEnumerator.close(); - iterativeEnumerator = null; + this.iterativeEnumerator = null; currentIteration++; } } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableOrderedQueryable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableOrderedQueryable.java index 0fe30bfeba86..675fd24f7e35 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableOrderedQueryable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableOrderedQueryable.java @@ -20,6 +20,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.FunctionExpression; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Comparator; /** @@ -31,7 +33,7 @@ class EnumerableOrderedQueryable extends EnumerableQueryable implements OrderedQueryable { EnumerableOrderedQueryable(Enumerable enumerable, Class rowType, - QueryProvider provider, Expression expression) { + QueryProvider provider, @Nullable Expression expression) { super(provider, rowType, expression, enumerable); } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableQueryable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableQueryable.java index 77f1ba12540a..02c6ed075e5c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableQueryable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/EnumerableQueryable.java @@ -34,6 +34,8 @@ import org.apache.calcite.linq4j.tree.Expression; import org.apache.calcite.linq4j.tree.FunctionExpression; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.math.BigDecimal; import java.util.Comparator; @@ -49,10 +51,10 @@ class EnumerableQueryable extends DefaultEnumerable private final QueryProvider provider; private final Class elementType; private final Enumerable enumerable; - private final Expression expression; + private final @Nullable Expression expression; EnumerableQueryable(QueryProvider provider, Class elementType, - Expression expression, Enumerable enumerable) { + @Nullable Expression expression, Enumerable enumerable) { this.enumerable = enumerable; this.elementType = elementType; this.provider = provider; @@ -153,7 +155,7 @@ public Queryable reverse() { return EnumerableDefaults.ofType(getThis(), clazz).asQueryable(); } - @Override public Queryable defaultIfEmpty() { + @Override public Queryable<@Nullable T> defaultIfEmpty() { return EnumerableDefaults.defaultIfEmpty(getThis()).asQueryable(); } @@ -167,7 +169,7 @@ public Type getElementType() { return elementType; } - public Expression getExpression() { + public @Nullable Expression getExpression() { return expression; } @@ -177,7 +179,7 @@ public QueryProvider getProvider() { // ............. - public T aggregate(FunctionExpression> selector) { + public @Nullable T aggregate(FunctionExpression> selector) { return EnumerableDefaults.aggregate(getThis(), selector.getFunction()); } @@ -260,7 +262,7 @@ public T first(FunctionExpression> predicate) { return EnumerableDefaults.first(getThis(), predicate.getFunction()); } - public T firstOrDefault(FunctionExpression> predicate) { + public @Nullable T firstOrDefault(FunctionExpression> predicate) { return EnumerableDefaults.firstOrDefault(getThis(), predicate.getFunction()); } @@ -373,7 +375,7 @@ public T last(FunctionExpression> predicate) { return EnumerableDefaults.last(getThis(), predicate.getFunction()); } - public T lastOrDefault(FunctionExpression> predicate) { + public @Nullable T lastOrDefault(FunctionExpression> predicate) { return EnumerableDefaults.lastOrDefault(getThis(), predicate.getFunction()); } @@ -466,7 +468,7 @@ public T single(FunctionExpression> predicate) { return EnumerableDefaults.single(getThis(), predicate.getFunction()); } - public T singleOrDefault(FunctionExpression> predicate) { + public @Nullable T singleOrDefault(FunctionExpression> predicate) { return EnumerableDefaults.singleOrDefault(getThis(), predicate.getFunction()); } @@ -557,7 +559,7 @@ public Queryable zip(Enumerable source1, resultSelector.getFunction()).asQueryable(); } - public T aggregate(Function2 func) { + public @Nullable T aggregate(Function2<@Nullable T, T, T> func) { return EnumerableDefaults.aggregate(getThis(), func); } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/Enumerator.java b/linq4j/src/main/java/org/apache/calcite/linq4j/Enumerator.java index 90bc29ce039d..df879fb16989 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/Enumerator.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/Enumerator.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j; +import org.checkerframework.framework.qual.Covariant; + /** * Supports a simple iteration over a collection. * @@ -26,6 +28,7 @@ * * @param Element type */ +@Covariant(0) public interface Enumerator extends AutoCloseable { /** * Gets the current element in the collection. diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java index 5ed792424254..d29083ed9778 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedEnumerable.java @@ -33,6 +33,10 @@ import org.apache.calcite.linq4j.function.Predicate1; import org.apache.calcite.linq4j.function.Predicate2; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; +import org.checkerframework.framework.qual.Covariant; + import java.math.BigDecimal; import java.util.Collection; import java.util.Comparator; @@ -44,6 +48,7 @@ * * @param Element type */ +@Covariant(0) public interface ExtendedEnumerable { /** @@ -55,21 +60,21 @@ public interface ExtendedEnumerable { * @param func Operation * @param Return type */ - R foreach(Function1 func); + @Nullable R foreach(Function1 func); /** * Applies an accumulator function over a * sequence. */ - TSource aggregate(Function2 func); + @Nullable TSource aggregate(Function2<@Nullable TSource, TSource, TSource> func); /** * Applies an accumulator function over a * sequence. The specified seed value is used as the initial * accumulator value. */ - TAccumulate aggregate(TAccumulate seed, - Function2 func); + @PolyNull TAccumulate aggregate(@PolyNull TAccumulate seed, + Function2<@PolyNull TAccumulate, TSource, @PolyNull TAccumulate> func); /** * Applies an accumulator function over a @@ -263,14 +268,14 @@ TResult aggregate(TAccumulate seed, * the type parameter's default value in a singleton collection if * the sequence is empty. */ - Enumerable defaultIfEmpty(); + Enumerable<@Nullable TSource> defaultIfEmpty(); /** * Returns the elements of the specified sequence or * the specified value in a singleton collection if the sequence * is empty. */ - Enumerable defaultIfEmpty(TSource value); + Enumerable<@PolyNull TSource> defaultIfEmpty(@PolyNull TSource value); /** * Returns distinct elements from a sequence by using @@ -295,7 +300,7 @@ TResult aggregate(TAccumulate seed, * sequence or a default value if the index is out of * range. */ - TSource elementAtOrDefault(int index); + @Nullable TSource elementAtOrDefault(int index); /** * Produces the set difference of two sequences by @@ -344,14 +349,14 @@ Enumerable except(Enumerable enumerable1, * Returns the first element of a sequence, or a * default value if the sequence contains no elements. */ - TSource firstOrDefault(); + @Nullable TSource firstOrDefault(); /** * Returns the first element of the sequence that * satisfies a condition or a default value if no such element is * found. */ - TSource firstOrDefault(Predicate1 predicate); + @Nullable TSource firstOrDefault(Predicate1 predicate); /** * Groups the elements of a sequence according to a @@ -644,14 +649,14 @@ Enumerable correlateJoin( * Returns the last element of a sequence, or a * default value if the sequence contains no elements. */ - TSource lastOrDefault(); + @Nullable TSource lastOrDefault(); /** * Returns the last element of a sequence that * satisfies a condition or a default value if no such element is * found. */ - TSource lastOrDefault(Predicate1 predicate); + @Nullable TSource lastOrDefault(Predicate1 predicate); /** * Returns an long that represents the total number @@ -956,7 +961,7 @@ boolean sequenceEqual(Enumerable enumerable1, * exception if there is more than one element in the * sequence. */ - TSource singleOrDefault(); + @Nullable TSource singleOrDefault(); /** * Returns the only element of a sequence that @@ -964,7 +969,7 @@ boolean sequenceEqual(Enumerable enumerable1, * element exists; this method throws an exception if more than * one element satisfies the condition. */ - TSource singleOrDefault(Predicate1 predicate); + @Nullable TSource singleOrDefault(Predicate1 predicate); /** * Bypasses a specified number of elements in a diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedQueryable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedQueryable.java index 573dcb908492..8af18f3424ab 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedQueryable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ExtendedQueryable.java @@ -33,6 +33,9 @@ import org.apache.calcite.linq4j.function.Predicate2; import org.apache.calcite.linq4j.tree.FunctionExpression; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.framework.qual.Covariant; + import java.math.BigDecimal; import java.util.Comparator; @@ -41,13 +44,14 @@ * * @param Element type */ +@Covariant(0) interface ExtendedQueryable extends ExtendedEnumerable { /** * Applies an accumulator function over a sequence. */ - TSource aggregate( - FunctionExpression> selector); + @Nullable TSource aggregate( + FunctionExpression> selector); /** * Applies an accumulator function over a @@ -172,7 +176,7 @@ Long averageNullableLong( * the type parameter's default value in a singleton collection if * the sequence is empty. */ - Queryable defaultIfEmpty(); + Queryable<@Nullable TSource> defaultIfEmpty(); /** * Returns distinct elements from a sequence by using @@ -227,7 +231,7 @@ Queryable except(Enumerable enumerable, * satisfies a specified condition or a default value if no such * element is found. */ - TSource firstOrDefault(FunctionExpression> predicate); + @Nullable TSource firstOrDefault(FunctionExpression> predicate); /** * Groups the elements of a sequence according to a @@ -400,7 +404,7 @@ Queryable join(Enumerable inner, * satisfies a condition or a default value if no such element is * found. */ - TSource lastOrDefault(FunctionExpression> predicate); + @Nullable TSource lastOrDefault(FunctionExpression> predicate); /** * Returns an long that represents the number of @@ -564,7 +568,7 @@ Queryable selectManyN( * exception if there is more than one element in the * sequence. */ - TSource singleOrDefault(); + @Nullable TSource singleOrDefault(); /** * Returns the only element of a sequence that @@ -572,7 +576,7 @@ Queryable selectManyN( * element exists; this method throws an exception if more than * one element satisfies the condition. */ - TSource singleOrDefault(FunctionExpression> predicate); + @Nullable TSource singleOrDefault(FunctionExpression> predicate); /** * Bypasses a specified number of elements in a diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/Grouping.java b/linq4j/src/main/java/org/apache/calcite/linq4j/Grouping.java index f3091afc27dc..aede1d0254f9 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/Grouping.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/Grouping.java @@ -16,12 +16,15 @@ */ package org.apache.calcite.linq4j; +import org.checkerframework.framework.qual.Covariant; + /** * Represents a collection of objects that have a common key. * * @param Key type * @param Element type */ +@Covariant(0) public interface Grouping extends Enumerable { /** * Gets the key of this Grouping. diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/GroupingImpl.java b/linq4j/src/main/java/org/apache/calcite/linq4j/GroupingImpl.java index 9444dc540cfd..49c1bfe4b427 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/GroupingImpl.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/GroupingImpl.java @@ -16,6 +16,9 @@ */ package org.apache.calcite.linq4j; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Collection; import java.util.Map; import java.util.Objects; @@ -26,7 +29,8 @@ * @param Key type * @param Value type */ -class GroupingImpl extends AbstractEnumerable +@SuppressWarnings("type.argument.type.incompatible") +class GroupingImpl extends AbstractEnumerable implements Grouping, Map.Entry> { private final K key; private final Collection values; @@ -48,7 +52,7 @@ class GroupingImpl extends AbstractEnumerable return key.hashCode() ^ values.hashCode(); } - @Override public boolean equals(Object obj) { + @Override public boolean equals(@Nullable Object obj) { return obj instanceof GroupingImpl && key.equals(((GroupingImpl) obj).key) && values.equals(((GroupingImpl) obj).values); diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/Linq4j.java b/linq4j/src/main/java/org/apache/calcite/linq4j/Linq4j.java index b1251e898b02..4c4f94227e7c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/Linq4j.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/Linq4j.java @@ -18,6 +18,8 @@ import org.apache.calcite.linq4j.function.Function1; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; @@ -28,6 +30,8 @@ import java.util.Objects; import java.util.RandomAccess; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Utility and factory methods for Linq4j. */ @@ -36,7 +40,7 @@ private Linq4j() {} private static final Object DUMMY = new Object(); - public static Method getMethod(String className, String methodName, + public static @Nullable Method getMethod(String className, String methodName, Class... parameterTypes) { try { return Class.forName(className).getMethod(methodName, parameterTypes); @@ -423,7 +427,7 @@ public static T requireNonNull(T o) { } /** Closes an iterator, if it can be closed. */ - private static void closeIterator(Iterator iterator) { + private static void closeIterator(@Nullable Iterator iterator) { if (iterator instanceof AutoCloseable) { try { ((AutoCloseable) iterator).close(); @@ -441,7 +445,7 @@ private static void closeIterator(Iterator iterator) { @SuppressWarnings("unchecked") static class IterableEnumerator implements Enumerator { private final Iterable iterable; - Iterator iterator; + @Nullable Iterator iterator; T current; IterableEnumerator(Iterable iterable) { @@ -458,7 +462,7 @@ public T current() { } public boolean moveNext() { - if (iterator.hasNext()) { + if (castNonNull(iterator).hasNext()) { current = iterator.next(); return true; } @@ -564,6 +568,7 @@ protected Collection getCollection() { return getCollection().size(); } + @SuppressWarnings("argument.type.incompatible") @Override public boolean contains(T element) { return getCollection().contains(element); } @@ -644,7 +649,7 @@ public void close() { /** Enumerator that returns one null element. * * @param element type */ - private static class SingletonNullEnumerator implements Enumerator { + private static class SingletonNullEnumerator<@Nullable E> implements Enumerator { int i = 0; public E current() { diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/LookupImpl.java b/linq4j/src/main/java/org/apache/calcite/linq4j/LookupImpl.java index def63c4fad20..ab210c8eef88 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/LookupImpl.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/LookupImpl.java @@ -18,6 +18,9 @@ import org.apache.calcite.linq4j.function.Function2; +import org.checkerframework.checker.nullness.qual.KeyFor; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.AbstractCollection; import java.util.AbstractMap; import java.util.AbstractSet; @@ -81,6 +84,7 @@ public boolean isEmpty() { return map.isEmpty(); } + @SuppressWarnings("contracts.conditional.postcondition.not.satisfied") public boolean containsKey(Object key) { return map.containsKey(key); } @@ -91,17 +95,18 @@ public boolean containsValue(Object value) { return map.containsValue(list); } - public Enumerable get(Object key) { + public @Nullable Enumerable get(Object key) { final List list = map.get(key); return list == null ? null : Linq4j.asEnumerable(list); } - public Enumerable put(K key, Enumerable value) { + @SuppressWarnings("contracts.postcondition.not.satisfied") + public @Nullable Enumerable put(K key, Enumerable value) { final List list = map.put(key, value.toList()); return list == null ? null : Linq4j.asEnumerable(list); } - public Enumerable remove(Object key) { + public @Nullable Enumerable remove(Object key) { final List list = map.remove(key); return list == null ? null : Linq4j.asEnumerable(list); } @@ -116,7 +121,8 @@ public void clear() { map.clear(); } - public Set keySet() { + @SuppressWarnings("return.type.incompatible") + public Set<@KeyFor("this") K> keySet() { return map.keySet(); } @@ -147,8 +153,9 @@ public int size() { }; } - public Set>> entrySet() { - final Set>> entries = map.entrySet(); + @SuppressWarnings("return.type.incompatible") + public Set>> entrySet() { + final Set>> entries = map.entrySet(); return new AbstractSet>>() { public Iterator>> iterator() { final Iterator>> iterator = entries.iterator(); diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/MemoryEnumerator.java b/linq4j/src/main/java/org/apache/calcite/linq4j/MemoryEnumerator.java index 13167b122844..c09661cde38b 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/MemoryEnumerator.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/MemoryEnumerator.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.concurrent.atomic.AtomicInteger; /** @@ -23,7 +25,7 @@ * * @param Row value */ -public class MemoryEnumerator implements Enumerator> { +public class MemoryEnumerator<@Nullable E> implements Enumerator> { private final Enumerator enumerator; private final MemoryFactory memoryFactory; private final AtomicInteger prevCounter; diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/MemoryFactory.java b/linq4j/src/main/java/org/apache/calcite/linq4j/MemoryFactory.java index 45d641a6c051..f68fa91777e7 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/MemoryFactory.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/MemoryFactory.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Arrays; /** @@ -30,7 +32,7 @@ public class MemoryFactory { // Index: 0 1 2 3 4 // Idea -2 -1 0 +1 +2 ModularInteger offset; - private Object[] values; + private final @Nullable Object[] values; public MemoryFactory(int history, int future) { this.history = history; @@ -62,10 +64,10 @@ public static class Memory { private final int history; private final int future; private final ModularInteger offset; - private final Object[] values; + private final @Nullable Object[] values; public Memory(int history, int future, - ModularInteger offset, Object[] values) { + ModularInteger offset, @Nullable Object[] values) { this.history = history; this.future = future; this.offset = offset; diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/ModularInteger.java b/linq4j/src/main/java/org/apache/calcite/linq4j/ModularInteger.java index e6b7695eefe7..12cc470c468d 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/ModularInteger.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/ModularInteger.java @@ -18,6 +18,8 @@ import com.google.common.base.Preconditions; +import org.checkerframework.checker.nullness.qual.Nullable; + /** * Represents an integer in modular arithmetic. * Its {@code value} is between 0 and {@code m - 1} for some modulus {@code m}. @@ -35,7 +37,7 @@ class ModularInteger { this.modulus = modulus; } - @Override public boolean equals(Object obj) { + @Override public boolean equals(@Nullable Object obj) { return obj == this || obj instanceof ModularInteger && value == ((ModularInteger) obj).value diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/NullnessUtil.java b/linq4j/src/main/java/org/apache/calcite/linq4j/NullnessUtil.java new file mode 100644 index 000000000000..afd8127834a1 --- /dev/null +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/NullnessUtil.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.linq4j; + +import org.checkerframework.checker.nullness.qual.EnsuresNonNull; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.dataflow.qual.Pure; + +/** + * The methods in this class allow to cast nullable reference to a non-nullable one. + * This is an internal class, and it is not meant to be used as a public API. + *

The class enables to remove checker-qual runtime dependency, and helps IDEs to see + * the resulting types of {@code castNonNull} better

+ */ +@SuppressWarnings({"cast.unsafe", "NullableProblems", "contracts.postcondition.not.satisfied"}) +public class NullnessUtil { + private NullnessUtil() { + } + + /** + * Casts nullable reference to a non-null type. + * + * @param the type of the reference + * @param ref a reference of @Nullable type, that is non-null at run time + * @return the argument, casted to have the type qualifier @NonNull + */ + @Pure + public static @EnsuresNonNull("#1") + @NonNull T castNonNull( + @Nullable T ref) { + assert ref != null : "Misuse of castNonNull: called with a null argument"; + return (@NonNull T) ref; + } + + /** + * Casts nullable reference to a non-null type. + * + * @param the type of the reference + * @param ref a reference of @Nullable type, that is non-null at run time + * @param message extra message in case the argument is null + * @return the argument, casted to have the type qualifier @NonNull + */ + @Pure + public static @EnsuresNonNull("#1") + @NonNull T castNonNull( + @Nullable T ref, String message) { + assert ref != null : "Misuse of castNonNull: called with a null argument " + message; + return (@NonNull T) ref; + } +} diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/Queryable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/Queryable.java index afdda6a16640..9f9f049744ee 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/Queryable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/Queryable.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j; +import org.checkerframework.framework.qual.Covariant; + /** * Provides functionality to evaluate queries against a specific data source * wherein the type of the data is known. @@ -24,5 +26,6 @@ * * @param Element type */ +@Covariant(0) public interface Queryable extends RawQueryable, ExtendedQueryable { } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableDefaults.java b/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableDefaults.java index 8200c17470d1..37f0813b7992 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableDefaults.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableDefaults.java @@ -36,11 +36,15 @@ import org.apache.calcite.linq4j.tree.Expressions; import org.apache.calcite.linq4j.tree.FunctionExpression; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.math.BigDecimal; import java.util.Comparator; import java.util.Iterator; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Default implementations for methods in the {@link Queryable} interface. */ @@ -741,13 +745,13 @@ public static Queryable reverse(Queryable source) { public static Queryable select(Queryable source, FunctionExpression> selector) { return source.getProvider().createQuery( - Expressions.call(source.getExpression(), "select", selector), + Expressions.call(castNonNull(source.getExpression()), "select", selector), functionResultType(selector)); } private static Type functionResultType( FunctionExpression> selector) { - return selector.body.getType(); + return castNonNull(selector.body).getType(); } /** @@ -1202,7 +1206,7 @@ public Type getElementType() { return original.getElementType(); } - public Expression getExpression() { + public @Nullable Expression getExpression() { return original.getExpression(); } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableFactory.java b/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableFactory.java index 9db00aefed06..2fb884735830 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableFactory.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableFactory.java @@ -33,6 +33,10 @@ import org.apache.calcite.linq4j.function.Predicate2; import org.apache.calcite.linq4j.tree.FunctionExpression; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; +import org.checkerframework.framework.qual.Covariant; + import java.math.BigDecimal; import java.util.Comparator; @@ -41,13 +45,14 @@ * * @param Element type */ +@Covariant(0) public interface QueryableFactory { /** * Applies an accumulator function over a sequence. */ - T aggregate(Queryable source, - FunctionExpression> selector); + @Nullable T aggregate(Queryable source, + FunctionExpression> selector); /** * Applies an accumulator function over a @@ -201,14 +206,14 @@ boolean contains(Queryable source, T element, * the type parameter's default value in a singleton collection if * the sequence is empty. */ - Queryable defaultIfEmpty(Queryable source); + Queryable<@Nullable T> defaultIfEmpty(Queryable source); /** * Returns the elements of the specified sequence or * the specified value in a singleton collection if the sequence * is empty. */ - Queryable defaultIfEmpty(Queryable source, T value); + Queryable<@PolyNull T> defaultIfEmpty(Queryable source, @PolyNull T value); /** * Returns distinct elements from a sequence by using @@ -282,14 +287,14 @@ Queryable except(Queryable source, Enumerable enumerable, * Returns the first element of a sequence, or a * default value if the sequence contains no elements. */ - T firstOrDefault(Queryable source); + @Nullable T firstOrDefault(Queryable source); /** * Returns the first element of a sequence that * satisfies a specified condition or a default value if no such * element is found. */ - T firstOrDefault(Queryable source, + @Nullable T firstOrDefault(Queryable source, FunctionExpression> predicate); /** diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableRecorder.java b/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableRecorder.java index 17d399d58f70..f480ee31b91c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableRecorder.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/QueryableRecorder.java @@ -33,12 +33,18 @@ import org.apache.calcite.linq4j.function.Predicate2; import org.apache.calcite.linq4j.tree.FunctionExpression; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; +import org.checkerframework.framework.qual.Covariant; + import java.lang.reflect.Type; import java.math.BigDecimal; import java.util.Comparator; import static org.apache.calcite.linq4j.QueryableDefaults.NonLeafReplayableQueryable; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Implementation of {@link QueryableFactory} that records each event * and returns an object that can replay the event when you call its @@ -47,6 +53,7 @@ * * @param Element type */ +@Covariant(0) public class QueryableRecorder implements QueryableFactory { private static final QueryableRecorder INSTANCE = new QueryableRecorder(); @@ -55,8 +62,8 @@ public static QueryableRecorder instance() { return INSTANCE; } - public T aggregate(final Queryable source, - final FunctionExpression> func) { + public @Nullable T aggregate(final Queryable source, + final FunctionExpression> func) { return new QueryableDefaults.NonLeafReplayableQueryable(source) { public void replay(QueryableFactory factory) { factory.aggregate(source, func); @@ -253,7 +260,7 @@ public void replay(QueryableFactory factory) { }.castSingle(); // CHECKSTYLE: IGNORE 0 } - public Queryable defaultIfEmpty(final Queryable source) { + public Queryable<@Nullable T> defaultIfEmpty(final Queryable source) { return new NonLeafReplayableQueryable(source) { public void replay(QueryableFactory factory) { factory.defaultIfEmpty(source); @@ -261,7 +268,8 @@ public void replay(QueryableFactory factory) { }; } - public Queryable defaultIfEmpty(final Queryable source, final T value) { + @SuppressWarnings("return.type.incompatible") + public Queryable<@PolyNull T> defaultIfEmpty(final Queryable source, final @PolyNull T value) { return new NonLeafReplayableQueryable(source) { public void replay(QueryableFactory factory) { factory.defaultIfEmpty(source, value); @@ -347,7 +355,7 @@ public void replay(QueryableFactory factory) { }.single(); // CHECKSTYLE: IGNORE 0 } - public T firstOrDefault(final Queryable source) { + public @Nullable T firstOrDefault(final Queryable source) { return new NonLeafReplayableQueryable(source) { public void replay(QueryableFactory factory) { factory.firstOrDefault(source); @@ -355,7 +363,7 @@ public void replay(QueryableFactory factory) { }.single(); // CHECKSTYLE: IGNORE 0 } - public T firstOrDefault(final Queryable source, + public @Nullable T firstOrDefault(final Queryable source, final FunctionExpression> predicate) { return new NonLeafReplayableQueryable(source) { public void replay(QueryableFactory factory) { @@ -688,7 +696,7 @@ public void replay(QueryableFactory factory) { } @Override public Type getElementType() { - return selector.body.type; + return castNonNull(selector.body).type; } }.castQueryable(); // CHECKSTYLE: IGNORE 0 } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/RawEnumerable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/RawEnumerable.java index 4dc6538f6bc8..35d3ac1db50e 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/RawEnumerable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/RawEnumerable.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j; +import org.checkerframework.framework.qual.Covariant; + /** * Exposes the enumerator, which supports a simple iteration over a collection, * without the extension methods. @@ -29,6 +31,7 @@ * @param Element type * @see Enumerable */ +@Covariant(0) public interface RawEnumerable { /** * Returns an enumerator that iterates through a collection. diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/RawQueryable.java b/linq4j/src/main/java/org/apache/calcite/linq4j/RawQueryable.java index 7a2bacfe3867..953f6a07411c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/RawQueryable.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/RawQueryable.java @@ -18,6 +18,9 @@ import org.apache.calcite.linq4j.tree.Expression; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.framework.qual.Covariant; + import java.lang.reflect.Type; /** @@ -29,6 +32,7 @@ * * @param Element type */ +@Covariant(0) public interface RawQueryable extends Enumerable { /** * Gets the type of the element(s) that are returned when the expression @@ -38,8 +42,9 @@ public interface RawQueryable extends Enumerable { /** * Gets the expression tree that is associated with this Queryable. + * @return null if the expression is not available */ - Expression getExpression(); + @Nullable Expression getExpression(); /** * Gets the query provider that is associated with this data source. diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/function/Functions.java b/linq4j/src/main/java/org/apache/calcite/linq4j/function/Functions.java index 539ea5016d31..15c57038793b 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/function/Functions.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/function/Functions.java @@ -16,6 +16,10 @@ */ package org.apache.calcite.linq4j.function; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.framework.qual.DefaultQualifier; +import org.checkerframework.framework.qual.TypeUseLocation; + import java.io.Serializable; import java.lang.reflect.Type; import java.math.BigDecimal; @@ -75,7 +79,8 @@ private Functions() {} private static final EqualityComparer ARRAY_COMPARER = new ArrayEqualityComparer(); - private static final Function1 CONSTANT_NULL_FUNCTION1 = s -> null; + private static final Function1 CONSTANT_NULL_FUNCTION1 = + (Function1) s -> null; private static final Function1 TO_STRING_FUNCTION1 = (Function1) Object::toString; @@ -532,7 +537,7 @@ public boolean equal(T v1, T v2) { } public int hashCode(T t) { - return t == null ? 0x789d : selector.apply(t).hashCode(); + return t == null ? 0x789d : Objects.hashCode(selector.apply(t)); } } @@ -613,7 +618,7 @@ public int compare(Comparable o1, Comparable o2) { * @param result type * @param first argument type * @param second argument type */ - private static final class Ignore + private static final class Ignore<@Nullable R, T0, T1> implements Function0, Function1, Function2 { public R apply() { return null; @@ -627,7 +632,13 @@ public R apply(T0 p0, T1 p1) { return null; } - static final Ignore INSTANCE = new Ignore(); + @DefaultQualifier( + value = Nullable.class, + locations = { + TypeUseLocation.LOWER_BOUND, + TypeUseLocation.UPPER_BOUND, + }) + static final Ignore INSTANCE = new Ignore<>(); } /** List that generates each element using a function. diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/package-info.java b/linq4j/src/main/java/org/apache/calcite/linq4j/package-info.java index cf57933faeaa..726df9815271 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/package-info.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/package-info.java @@ -18,4 +18,11 @@ /** * Language-integrated query for Java (linq4j) main package. */ +@DefaultQualifier(value = NonNull.class, locations = TypeUseLocation.FIELD) +@DefaultQualifier(value = NonNull.class, locations = TypeUseLocation.PARAMETER) +@DefaultQualifier(value = NonNull.class, locations = TypeUseLocation.RETURN) package org.apache.calcite.linq4j; + +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.framework.qual.DefaultQualifier; +import org.checkerframework.framework.qual.TypeUseLocation; diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/AbstractNode.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/AbstractNode.java index 014db9045ccb..0a2622564d9d 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/AbstractNode.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/AbstractNode.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.Objects; @@ -70,12 +72,12 @@ public Node accept(Shuttle shuttle) { "visit not supported: " + getClass() + ":" + nodeType); } - public Object evaluate(Evaluator evaluator) { + public @Nullable Object evaluate(Evaluator evaluator) { throw new RuntimeException( "evaluation not supported: " + getClass() + ":" + nodeType); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ArrayLengthRecordField.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ArrayLengthRecordField.java index d6a6e2061e83..f04fb444d066 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ArrayLengthRecordField.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ArrayLengthRecordField.java @@ -16,10 +16,14 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Array; import java.lang.reflect.Type; import java.util.Objects; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Length field of a RecordType. */ @@ -50,15 +54,15 @@ public int getModifiers() { return 0; } - public Object get(Object o) throws IllegalAccessException { - return Array.getLength(o); + @Override public Object get(@Nullable Object o) throws IllegalAccessException { + return Array.getLength(castNonNull(o)); } - public Type getDeclaringClass() { + @Override public Type getDeclaringClass() { return clazz; } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BinaryExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BinaryExpression.java index a18025c2c9f0..83278f5c6342 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BinaryExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BinaryExpression.java @@ -16,16 +16,20 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.Objects; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Represents an expression that has a binary operator. */ public class BinaryExpression extends Expression { public final Expression expression0; public final Expression expression1; - private final Primitive primitive; + private final @Nullable Primitive primitive; BinaryExpression(ExpressionType nodeType, Type type, Expression expression0, Expression expression1) { @@ -51,9 +55,12 @@ public R accept(Visitor visitor) { public Object evaluate(Evaluator evaluator) { switch (nodeType) { case AndAlso: - return (Boolean) expression0.evaluate(evaluator) - && (Boolean) expression1.evaluate(evaluator); + return evaluateBoolean(evaluator, expression0) + && evaluateBoolean(evaluator, expression1); case Add: + if (primitive == null) { + throw cannotEvaluate(); + } switch (primitive) { case INT: return evaluateInt(expression0, evaluator) + evaluateInt(expression1, evaluator); @@ -71,6 +78,9 @@ public Object evaluate(Evaluator evaluator) { throw cannotEvaluate(); } case Divide: + if (primitive == null) { + throw cannotEvaluate(); + } switch (primitive) { case INT: return evaluateInt(expression0, evaluator) / evaluateInt(expression1, evaluator); @@ -88,9 +98,11 @@ public Object evaluate(Evaluator evaluator) { throw cannotEvaluate(); } case Equal: - return expression0.evaluate(evaluator) - .equals(expression1.evaluate(evaluator)); + return Objects.equals(expression0.evaluate(evaluator), expression1.evaluate(evaluator)); case GreaterThan: + if (primitive == null) { + throw cannotEvaluate(); + } switch (primitive) { case INT: return evaluateInt(expression0, evaluator) > evaluateInt(expression1, evaluator); @@ -108,6 +120,9 @@ public Object evaluate(Evaluator evaluator) { throw cannotEvaluate(); } case GreaterThanOrEqual: + if (primitive == null) { + throw cannotEvaluate(); + } switch (primitive) { case INT: return evaluateInt(expression0, evaluator) >= evaluateInt(expression1, evaluator); @@ -125,6 +140,9 @@ public Object evaluate(Evaluator evaluator) { throw cannotEvaluate(); } case LessThan: + if (primitive == null) { + throw cannotEvaluate(); + } switch (primitive) { case INT: return evaluateInt(expression0, evaluator) < evaluateInt(expression1, evaluator); @@ -142,6 +160,9 @@ public Object evaluate(Evaluator evaluator) { throw cannotEvaluate(); } case LessThanOrEqual: + if (primitive == null) { + throw cannotEvaluate(); + } switch (primitive) { case INT: return evaluateInt(expression0, evaluator) <= evaluateInt(expression1, evaluator); @@ -159,6 +180,9 @@ public Object evaluate(Evaluator evaluator) { throw cannotEvaluate(); } case Multiply: + if (primitive == null) { + throw cannotEvaluate(); + } switch (primitive) { case INT: return evaluateInt(expression0, evaluator) * evaluateInt(expression1, evaluator); @@ -176,12 +200,14 @@ public Object evaluate(Evaluator evaluator) { throw cannotEvaluate(); } case NotEqual: - return !expression0.evaluate(evaluator) - .equals(expression1.evaluate(evaluator)); + return !Objects.equals(expression0.evaluate(evaluator), expression1.evaluate(evaluator)); case OrElse: - return (Boolean) expression0.evaluate(evaluator) - || (Boolean) expression1.evaluate(evaluator); + return evaluateBoolean(evaluator, expression0) + || evaluateBoolean(evaluator, expression1); case Subtract: + if (primitive == null) { + throw cannotEvaluate(); + } switch (primitive) { case INT: return evaluateInt(expression0, evaluator) - evaluateInt(expression1, evaluator); @@ -217,31 +243,39 @@ private RuntimeException cannotEvaluate() { + nodeType + ", primitive=" + primitive); } + private boolean evaluateBoolean(Evaluator evaluator, Expression expression) { + return (Boolean) castNonNull(expression.evaluate(evaluator)); + } + + private Number evaluateNumber(Expression expression, Evaluator evaluator) { + return (Number) castNonNull(expression.evaluate(evaluator)); + } + private int evaluateInt(Expression expression, Evaluator evaluator) { - return ((Number) expression.evaluate(evaluator)).intValue(); + return evaluateNumber(expression, evaluator).intValue(); } private short evaluateShort(Expression expression, Evaluator evaluator) { - return ((Number) expression.evaluate(evaluator)).shortValue(); + return evaluateNumber(expression, evaluator).shortValue(); } private long evaluateLong(Expression expression, Evaluator evaluator) { - return ((Number) expression.evaluate(evaluator)).longValue(); + return evaluateNumber(expression, evaluator).longValue(); } private byte evaluateByte(Expression expression, Evaluator evaluator) { - return ((Number) expression.evaluate(evaluator)).byteValue(); + return evaluateNumber(expression, evaluator).byteValue(); } private float evaluateFloat(Expression expression, Evaluator evaluator) { - return ((Number) expression.evaluate(evaluator)).floatValue(); + return evaluateNumber(expression, evaluator).floatValue(); } private double evaluateDouble(Expression expression, Evaluator evaluator) { - return ((Number) expression.evaluate(evaluator)).doubleValue(); + return evaluateNumber(expression, evaluator).doubleValue(); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BlockBuilder.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BlockBuilder.java index f09b03ff4ef2..923a6bc18e56 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BlockBuilder.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BlockBuilder.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.ArrayList; @@ -26,6 +28,8 @@ import java.util.Map; import java.util.Set; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Builder for {@link BlockStatement}. * @@ -41,7 +45,7 @@ public class BlockBuilder { new HashMap<>(); private final boolean optimizing; - private final BlockBuilder parent; + private final @Nullable BlockBuilder parent; private static final Shuttle OPTIMIZE_SHUTTLE = new OptimizeShuttle(); @@ -66,7 +70,7 @@ public BlockBuilder(boolean optimizing) { * * @param optimizing Whether to eliminate common sub-expressions */ - public BlockBuilder(boolean optimizing, BlockBuilder parent) { + public BlockBuilder(boolean optimizing, @Nullable BlockBuilder parent) { this.optimizing = optimizing; this.parent = parent; } @@ -85,7 +89,7 @@ public void clear() { * (possibly a variable) that represents the result of the newly added * block. */ - public Expression append(String name, BlockStatement block) { + public @Nullable Expression append(String name, BlockStatement block) { return append(name, block, true); } @@ -99,7 +103,7 @@ public Expression append(String name, BlockStatement block) { * a variable. Do not do this if the expression has * side-effects or a time-dependent value. */ - public Expression append(String name, BlockStatement block, + public @Nullable Expression append(String name, BlockStatement block, boolean optimize) { if (statements.size() > 0) { Statement lastStatement = statements.get(statements.size() - 1); @@ -153,7 +157,7 @@ public Expression append(String name, BlockStatement block, result = ((DeclarationStatement) statement).parameter; } else if (statement instanceof GotoStatement) { statements.remove(statements.size() - 1); - result = append_(name, ((GotoStatement) statement).expression, + result = append_(name, castNonNull(((GotoStatement) statement).expression), optimize); if (isSimpleExpression(result)) { // already simple; no need to declare a variable or @@ -184,7 +188,7 @@ public Expression append(String name, Expression expression) { /** * Appends an expression to a list of statements, if it is not null. */ - public Expression appendIfNotNull(String name, Expression expression) { + public @Nullable Expression appendIfNotNull(String name, @Nullable Expression expression) { if (expression == null) { return null; } @@ -233,7 +237,7 @@ private Expression append_(String name, Expression expression, * @param expr expression to test * @return true when given expression is safe to always inline */ - protected boolean isSimpleExpression(Expression expr) { + protected boolean isSimpleExpression(@Nullable Expression expr) { if (expr instanceof ParameterExpression || expr instanceof ConstantExpression) { return true; @@ -285,7 +289,7 @@ private Expression normalizeDeclaration(DeclarationStatement decl) { * @param expr expression to test * @return existing ParameterExpression or null */ - public DeclarationStatement getComputedExpression(Expression expr) { + public @Nullable DeclarationStatement getComputedExpression(Expression expr) { if (parent != null) { DeclarationStatement decl = parent.getComputedExpression(expr); if (decl != null) { diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BlockStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BlockStatement.java index 2aaa40bbf9da..a5ee64eedb19 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BlockStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/BlockStatement.java @@ -16,6 +16,9 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.initialization.qual.UnderInitialization; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.HashSet; import java.util.List; @@ -38,7 +41,10 @@ public class BlockStatement extends Statement { assert distinctVariables(true); } - private boolean distinctVariables(boolean fail) { + private boolean distinctVariables( + @UnderInitialization(BlockStatement.class) BlockStatement this, + boolean fail + ) { Set names = new HashSet<>(); for (Statement statement : statements) { if (statement instanceof DeclarationStatement) { @@ -75,7 +81,7 @@ public R accept(Visitor visitor) { writer.end("}\n"); } - @Override public Object evaluate(Evaluator evaluator) { + @Override public @Nullable Object evaluate(Evaluator evaluator) { Object o = null; for (Statement statement : statements) { o = statement.evaluate(evaluator); @@ -83,7 +89,7 @@ public R accept(Visitor visitor) { return o; } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Blocks.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Blocks.java index 05e3201e661e..b34a1bbad0a4 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Blocks.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Blocks.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** *

Helper methods concerning {@link BlockStatement}s.

* @@ -71,7 +73,7 @@ public static Expression simple(BlockStatement block) { if (block.statements.size() == 1) { Statement statement = block.statements.get(0); if (statement instanceof GotoStatement) { - return ((GotoStatement) statement).expression; + return castNonNull(((GotoStatement) statement).expression); } } throw new AssertionError("not a simple block: " + block); diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/CatchBlock.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/CatchBlock.java index 732d63b6b88b..21aff57d96f7 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/CatchBlock.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/CatchBlock.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Objects; /** @@ -31,7 +33,7 @@ public CatchBlock(ParameterExpression parameter, this.body = body; } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ClassDeclaration.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ClassDeclaration.java index f4e0c9a266c7..8baaabad88da 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ClassDeclaration.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ClassDeclaration.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.List; @@ -70,7 +72,7 @@ public R accept(Visitor visitor) { return visitor.visit(this); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ClassDeclarationFinder.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ClassDeclarationFinder.java index 6373931df3da..fe1a4554d3d6 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ClassDeclarationFinder.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ClassDeclarationFinder.java @@ -18,6 +18,8 @@ import org.apache.calcite.linq4j.function.Function1; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; @@ -30,7 +32,7 @@ * created for optimizing a new expression tree. */ public class ClassDeclarationFinder extends Shuttle { - protected final ClassDeclarationFinder parent; + protected final @Nullable ClassDeclarationFinder parent; /** * The list of new final static fields to be added to the current class. @@ -152,7 +154,7 @@ protected ClassDeclarationFinder(ClassDeclarationFinder parent) { } @Override public Expression visit(NewExpression newExpression, - List arguments, List memberDeclarations) { + List arguments, @Nullable List memberDeclarations) { if (parent == null) { // Unable to optimize since no wrapper class exists to put fields to. arguments = newExpression.arguments; @@ -254,7 +256,7 @@ protected boolean isConstant(Iterable list) { * @param expression input expression * @return always returns null */ - protected ParameterExpression findDeclaredExpression(Expression expression) { + protected @Nullable ParameterExpression findDeclaredExpression(Expression expression) { return null; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConditionalExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConditionalExpression.java index 55c00cefba48..92875e737934 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConditionalExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConditionalExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.List; import java.util.Objects; @@ -59,7 +61,7 @@ public R accept(Visitor visitor) { } } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConditionalStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConditionalStatement.java index 9d15a458bed9..fe3f7203324a 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConditionalStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConditionalStatement.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.List; import java.util.Objects; @@ -72,7 +74,7 @@ private static E last(List collection) { return collection.get(collection.size() - 1); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java index 38b75879bc5c..e05f7f887c43 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Type; @@ -28,22 +30,21 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Represents an expression that has a constant value. */ public class ConstantExpression extends Expression { - public final Object value; + public final @Nullable Object value; - public ConstantExpression(Type type, Object value) { + public ConstantExpression(Type type, @Nullable Object value) { super(ExpressionType.Constant, type); this.value = value; if (value != null) { if (type instanceof Class) { Class clazz = (Class) type; - Primitive primitive = Primitive.of(clazz); - if (primitive != null) { - clazz = primitive.boxClass; - } + clazz = Primitive.box(clazz); if (!clazz.isInstance(value) && !((clazz == Float.class || clazz == Double.class) && value instanceof BigDecimal)) { @@ -54,7 +55,7 @@ public ConstantExpression(Type type, Object value) { } } - public Object evaluate(Evaluator evaluator) { + public @Nullable Object evaluate(Evaluator evaluator) { return value; } @@ -77,15 +78,13 @@ public R accept(Visitor visitor) { } private static ExpressionWriter write(ExpressionWriter writer, - final Object value, Type type) { + final Object value, @Nullable Type type) { if (value == null) { return writer.append("null"); } if (type == null) { type = value.getClass(); - if (Primitive.isBox(type)) { - type = Primitive.ofBox(type).primitiveClass; - } + type = Primitive.unbox(type); } if (value instanceof String) { escapeString(writer.getBuf(), (String) value); @@ -171,7 +170,7 @@ private static ExpressionWriter write(ExpressionWriter writer, return writer.append(recordType.getName()).append(".class"); } if (value.getClass().isArray()) { - writer.append("new ").append(value.getClass().getComponentType()); + writer.append("new ").append(castNonNull(value.getClass().getComponentType())); list(writer, Primitive.asList(value), "[] {\n", ",\n", "}"); return writer; } @@ -195,7 +194,8 @@ private static ExpressionWriter write(ExpressionWriter writer, writer.append("new ").append(value.getClass()); list(writer, Arrays.stream(value.getClass().getFields()) - .map(field -> { + // <@Nullable Object> is needed for CheckerFramework + .<@Nullable Object>map(field -> { try { return field.get(value); } catch (IllegalAccessException e) { @@ -275,7 +275,7 @@ private static ExpressionWriter set(ExpressionWriter writer, Set set, return writer.append(end); } - private static Constructor matchingConstructor(Object value) { + private static @Nullable Constructor matchingConstructor(Object value) { final Field[] fields = value.getClass().getFields(); for (Constructor constructor : value.getClass().getConstructors()) { if (argsMatchFields(fields, constructor.getParameterTypes())) { @@ -328,7 +328,7 @@ private static void escapeString(StringBuilder buf, String s) { buf.append('"'); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { // REVIEW: Should constants with the same value and different type // (e.g. 3L and 3) be considered equal. if (this == o) { diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantUntypedNull.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantUntypedNull.java index 48ab4ee8452b..204f82d8e901 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantUntypedNull.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstantUntypedNull.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + /** * Represents a constant null of unknown type * Java allows type inference for such nulls, thus "null" cannot always be @@ -36,7 +38,7 @@ private ConstantUntypedNull() { writer.append("null"); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { return o == INSTANCE; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstructorDeclaration.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstructorDeclaration.java index ebc244531ec2..47677dd6fe61 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstructorDeclaration.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ConstructorDeclaration.java @@ -16,12 +16,14 @@ */ package org.apache.calcite.linq4j.tree; -import com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; /** * Declaration of a constructor. @@ -29,13 +31,13 @@ public class ConstructorDeclaration extends MemberDeclaration { public final int modifier; public final Type resultType; - public final List parameters; + public final List<@NonNull ParameterExpression> parameters; public final BlockStatement body; /** Cached hash code for the expression. */ private int hash; public ConstructorDeclaration(int modifier, Type declaredAgainst, - List parameters, BlockStatement body) { + List<@NonNull ParameterExpression> parameters, BlockStatement body) { assert parameters != null : "parameters should not be null"; assert body != null : "body should not be null"; assert declaredAgainst != null : "declaredAgainst should not be null"; @@ -65,18 +67,18 @@ public void accept(ExpressionWriter writer) { writer .append(resultType) .list("(", ", ", ")", - Lists.transform(parameters, parameter -> { + parameters.stream().map(parameter -> { final String modifiers1 = Modifier.toString(parameter.modifier); return modifiers1 + (modifiers1.isEmpty() ? "" : " ") + Types.className(parameter.getType()) + " " + parameter.name; - })) + }).collect(Collectors.toList())) .append(' ').append(body); writer.newlineAndIndent(); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/DeclarationStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/DeclarationStatement.java index 31a529dcecce..d4fe2ccdc1ff 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/DeclarationStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/DeclarationStatement.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Modifier; import java.util.Objects; @@ -25,10 +27,10 @@ public class DeclarationStatement extends Statement { public final int modifiers; public final ParameterExpression parameter; - public final Expression initializer; + public final @Nullable Expression initializer; public DeclarationStatement(int modifiers, ParameterExpression parameter, - Expression initializer) { + @Nullable Expression initializer) { super(ExpressionType.Declaration, Void.TYPE); assert parameter != null : "parameter should not be null"; this.modifiers = modifiers; @@ -78,7 +80,7 @@ public void accept2(ExpressionWriter writer, boolean withType) { } } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/DeterministicCodeOptimizer.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/DeterministicCodeOptimizer.java index e9168071c704..e20bd255dc71 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/DeterministicCodeOptimizer.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/DeterministicCodeOptimizer.java @@ -21,6 +21,8 @@ import com.google.common.collect.ImmutableSet; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.lang.reflect.Modifier; @@ -167,7 +169,7 @@ && isMethodDeterministic(methodCallExpression.method)) { } @Override public Expression visit(MethodCallExpression methodCallExpression, - Expression targetExpression, List expressions) { + @Nullable Expression targetExpression, List expressions) { Expression result = super.visit(methodCallExpression, targetExpression, expressions); @@ -176,7 +178,7 @@ && isMethodDeterministic(methodCallExpression.method)) { } @Override public Expression visit(MemberExpression memberExpression, - Expression expression) { + @Nullable Expression expression) { Expression result = super.visit(memberExpression, expression); if (isConstant(expression) @@ -187,7 +189,7 @@ && isMethodDeterministic(methodCallExpression.method)) { } @Override public MemberDeclaration visit(FieldDeclaration fieldDeclaration, - Expression initializer) { + @Nullable Expression initializer) { if (Modifier.isStatic(fieldDeclaration.modifier)) { // Avoid optimization of static fields, since we'll have to track order // of static declarations. @@ -224,7 +226,7 @@ && isMethodDeterministic(methodCallExpression.method)) { * @param expression input expression * @return parameter of the already existing declaration, or null */ - protected ParameterExpression findDeclaredExpression(Expression expression) { + protected @Nullable ParameterExpression findDeclaredExpression(Expression expression) { if (!dedup.isEmpty()) { ParameterExpression pe = dedup.get(expression); if (pe != null) { @@ -294,7 +296,7 @@ protected String inventFieldName(Expression expression) { * @param expression expression to test * @return true when the expression is known to be constant */ - @Override protected boolean isConstant(Expression expression) { + @Override protected boolean isConstant(@Nullable Expression expression) { return expression == null || expression instanceof ConstantExpression || !constants.isEmpty() && constants.containsKey(expression) @@ -330,7 +332,7 @@ protected boolean isConstructorDeterministic(NewExpression newExpression) { && constructor.isAnnotationPresent(Deterministic.class); } - private Constructor getConstructor(Class klass) { + private @Nullable Constructor getConstructor(Class klass) { try { return klass.getConstructor(); } catch (NoSuchMethodException e) { @@ -347,7 +349,7 @@ private Constructor getConstructor(Class klass) { */ protected boolean allMethodsDeterministic(Class klass) { return DETERMINISTIC_CLASSES.contains(klass) - || klass.getCanonicalName().equals("org.apache.calcite.avatica.util.DateTimeUtils") + || "org.apache.calcite.avatica.util.DateTimeUtils".equals(klass.getCanonicalName()) || klass.isAnnotationPresent(Deterministic.class); } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Evaluator.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Evaluator.java index a7fc34d5a042..f6cc41b7dd2c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Evaluator.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Evaluator.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.ArrayList; import java.util.List; @@ -24,12 +26,12 @@ */ class Evaluator { final List parameters = new ArrayList<>(); - final List values = new ArrayList<>(); + final List<@Nullable Object> values = new ArrayList<>(); Evaluator() { } - void push(ParameterExpression parameter, Object value) { + void push(ParameterExpression parameter, @Nullable Object value) { parameters.add(parameter); values.add(value); } @@ -42,7 +44,7 @@ void pop(int n) { } } - Object peek(ParameterExpression param) { + @Nullable Object peek(ParameterExpression param) { for (int i = parameters.size() - 1; i >= 0; i--) { if (parameters.get(i) == param) { return values.get(i); @@ -51,7 +53,7 @@ Object peek(ParameterExpression param) { throw new RuntimeException("parameter " + param + " not on stack"); } - Object evaluate(Node expression) { + @Nullable Object evaluate(Node expression) { return ((AbstractNode) expression).evaluate(this); } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ExpressionType.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ExpressionType.java index 636620af0318..5f20ae169d37 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ExpressionType.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ExpressionType.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + /** *

Analogous to LINQ's System.Linq.Expressions.ExpressionType.

*/ @@ -596,8 +598,8 @@ public enum ExpressionType { */ While; - final String op; - final String op2; + final @Nullable String op; + final @Nullable String op2; final boolean postfix; final int lprec; final int rprec; @@ -607,16 +609,16 @@ public enum ExpressionType { this(null, false, 0, false); } - ExpressionType(String op, boolean postfix, int prec, boolean right) { + ExpressionType(@Nullable String op, boolean postfix, int prec, boolean right) { this(op, null, postfix, prec, right); } - ExpressionType(String op, String op2, boolean postfix, int prec, + ExpressionType(@Nullable String op, @Nullable String op2, boolean postfix, int prec, boolean right) { this(op, op2, postfix, prec, right, false); } - ExpressionType(String op, String op2, boolean postfix, int prec, + ExpressionType(@Nullable String op, @Nullable String op2, boolean postfix, int prec, boolean right, boolean modifiesLvalue) { this.op = op; this.op2 = op2; diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ExpressionWriter.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ExpressionWriter.java index cb7c9113d90e..8bc640bac6e4 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ExpressionWriter.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ExpressionWriter.java @@ -18,6 +18,8 @@ import org.apache.calcite.avatica.util.Spacer; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.Iterator; @@ -130,13 +132,13 @@ public ExpressionWriter append(AbstractNode o) { return this; } - public ExpressionWriter append(Object o) { + public ExpressionWriter append(@Nullable Object o) { checkIndent(); buf.append(o); return this; } - public ExpressionWriter append(String s) { + public ExpressionWriter append(@Nullable String s) { checkIndent(); buf.append(s); return this; diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java index 8260055542d9..9f7bac24db55 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Expressions.java @@ -26,6 +26,10 @@ import com.google.common.collect.ImmutableList; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.checkerframework.checker.nullness.qual.PolyNull; + import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Member; @@ -41,6 +45,8 @@ import java.util.Objects; import java.util.UUID; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Utility methods for expressions, including a lot of factory methods. */ @@ -283,7 +289,7 @@ public static BlockStatement block(Statement... statements) { * Creates a BlockExpression that contains the given expressions, * has no variables and has specific result type. */ - public static BlockStatement block(Type type, + public static BlockStatement block(@Nullable Type type, Iterable expressions) { List list = toList(expressions); if (type == null) { @@ -343,7 +349,7 @@ public static GotoStatement break_(LabelTarget labelTarget, * static method that has arguments. */ public static MethodCallExpression call(Method method, - Iterable arguments) { + Iterable arguments) { return new MethodCallExpression(method, null, toList(arguments)); } @@ -360,8 +366,8 @@ public static MethodCallExpression call(Method method, * Creates a MethodCallExpression that represents a call to a * method that takes arguments. */ - public static MethodCallExpression call(Expression expression, Method method, - Iterable arguments) { + public static MethodCallExpression call(@Nullable Expression expression, Method method, + Iterable arguments) { return new MethodCallExpression(method, expression, toList(arguments)); } @@ -501,14 +507,6 @@ public static Expression condition(Expression test, Expression ifTrue, return makeTernary(ExpressionType.Conditional, test, ifTrue, ifFalse); } - private static Type box(Type type) { - Primitive primitive = Primitive.of(type); - if (primitive != null) { - return primitive.boxClass; - } - return type; - } - /** Returns whether an expression always evaluates to null. */ public static boolean isConstantNull(Expression e) { return e instanceof ConstantExpression @@ -541,19 +539,11 @@ public static ConditionalExpression condition(Expression test, * classes that have a constructor with a parameter for each field, and * arrays.

*/ - public static ConstantExpression constant(Object value) { - Class type; + public static ConstantExpression constant(@Nullable Object value) { if (value == null) { return ConstantUntypedNull.INSTANCE; - } else { - final Class clazz = value.getClass(); - final Primitive primitive = Primitive.ofBox(clazz); - if (primitive != null) { - type = primitive.primitiveClass; - } else { - type = clazz; - } } + Class type = Primitive.unbox(value.getClass()); return new ConstantExpression(type, value); } @@ -561,13 +551,13 @@ public static ConstantExpression constant(Object value) { * Creates a ConstantExpression that has the Value and Type * properties set to the specified values. */ - public static ConstantExpression constant(Object value, Type type) { + public static ConstantExpression constant(@Nullable Object value, Type type) { if (value != null && type instanceof Class) { // Fix up value so that it matches type. - Class clazz = (Class) type; + Class clazz = (Class) type; Primitive primitive = Primitive.ofBoxOr(clazz); if (primitive != null) { - clazz = primitive.boxClass; + clazz = castNonNull(primitive.boxClass); } if ((clazz == Float.class || clazz == Double.class) && value instanceof BigDecimal) { @@ -839,7 +829,7 @@ public static MemberExpression field(Expression expression, Field field) { /** * Creates a MemberExpression that represents accessing a field. */ - public static MemberExpression field(Expression expression, + public static MemberExpression field(@Nullable Expression expression, PseudoField field) { return makeMemberAccess(expression, field); } @@ -857,7 +847,7 @@ public static MemberExpression field(Expression expression, /** * Creates a MemberExpression that represents accessing a field. */ - public static MemberExpression field(Expression expression, Type type, + public static MemberExpression field(@Nullable Expression expression, Type type, String fieldName) { PseudoField field = Types.getField(fieldName, type); return makeMemberAccess(expression, field); @@ -1384,7 +1374,7 @@ public static ListInitExpression listInit(NewExpression newExpression, */ public static ForStatement for_( Iterable declarations, - Expression condition, Expression post, Statement body) { + @Nullable Expression condition, @Nullable Expression post, Statement body) { return new ForStatement(toList(declarations), condition, post, body); } @@ -1393,7 +1383,7 @@ public static ForStatement for_( */ public static ForStatement for_( DeclarationStatement declaration, - Expression condition, Expression post, Statement body) { + @Nullable Expression condition, @Nullable Expression post, Statement body) { return new ForStatement(Collections.singletonList(declaration), condition, post, body); } @@ -1434,7 +1424,7 @@ public static BinaryExpression makeBinary(ExpressionType binaryType, /** Returns an expression to box the value of a primitive expression. * E.g. {@code box(e, Primitive.INT)} returns {@code Integer.valueOf(e)}. */ public static Expression box(Expression expression, Primitive primitive) { - return call(primitive.boxClass, "valueOf", expression); + return call(castNonNull(primitive.boxClass), "valueOf", expression); } /** Converts e.g. "anInteger" to "Integer.valueOf(anInteger)". */ @@ -1450,7 +1440,7 @@ public static Expression box(Expression expression) { * E.g. {@code unbox(e, Primitive.INT)} returns {@code e.intValue()}. * It is assumed that e is of the right box type (or {@link Number})."Value */ public static Expression unbox(Expression expression, Primitive primitive) { - return call(expression, primitive.primitiveName + "Value"); + return call(expression, castNonNull(primitive.primitiveName) + "Value"); } /** Converts e.g. "anInteger" to "anInteger.intValue()". */ @@ -1526,12 +1516,12 @@ public static TernaryExpression makeTernary(ExpressionType ternaryType, switch (ternaryType) { case Conditional: if (e1 instanceof ConstantUntypedNull) { - type = box(e2.getType()); + type = Primitive.box(e2.getType()); if (e1.getType() != type) { e1 = constant(null, type); } } else if (e2 instanceof ConstantUntypedNull) { - type = box(e1.getType()); + type = Primitive.box(e1.getType()); if (e2.getType() != type) { e2 = constant(null, type); } @@ -1585,7 +1575,7 @@ public static GotoStatement makeGoto(GotoExpressionKind kind, /** * Creates a MemberExpression that represents accessing a field. */ - public static MemberExpression makeMemberAccess(Expression expression, + public static MemberExpression makeMemberAccess(@Nullable Expression expression, PseudoField member) { return new MemberExpression(expression, member); } @@ -1631,7 +1621,7 @@ public static UnaryExpression makeUnary(ExpressionType expressionType, * method, by calling the appropriate factory method. */ public static UnaryExpression makeUnary(ExpressionType expressionType, - Expression expression, Type type, Method method) { + Expression expression, Type type, @Nullable Method method) { assert type != null; return new UnaryExpression(expressionType, type, expression); } @@ -1716,7 +1706,7 @@ public static ConstructorDeclaration constructorDecl(int modifier, * Declares a field with an initializer. */ public static FieldDeclaration fieldDecl(int modifier, - ParameterExpression parameter, Expression initializer) { + ParameterExpression parameter, @Nullable Expression initializer) { return new FieldDeclaration(modifier, parameter, initializer); } @@ -1899,7 +1889,8 @@ public static UnaryExpression negate(Expression expression) { * negation operation. */ public static UnaryExpression negate(Expression expression, Method method) { - return makeUnary(ExpressionType.Negate, expression, null, method); + // TODO: use method + return negate(expression); } /** @@ -1917,7 +1908,8 @@ public static UnaryExpression negateChecked(Expression expression) { */ public static UnaryExpression negateChecked(Expression expression, Method method) { - return makeUnary(ExpressionType.NegateChecked, expression, null, method); + throw new UnsupportedOperationException("not implemented"); + //return makeUnary(ExpressionType.NegateChecked, expression, null, method); } /** @@ -1966,7 +1958,7 @@ public static NewExpression new_(Type type, Expression... arguments) { */ public static NewExpression new_(Type type, Iterable arguments, - Iterable memberDeclarations) { + @Nullable Iterable memberDeclarations) { return new NewExpression(type, toList(arguments), toList(memberDeclarations)); } @@ -2038,7 +2030,7 @@ public static NewExpression new_(Constructor constructor, * that has a specified rank. */ public static NewArrayExpression newArrayBounds(Type type, int dimension, - Expression bound) { + @Nullable Expression bound) { return new NewArrayExpression(type, dimension, bound, null); } @@ -2103,7 +2095,8 @@ public static UnaryExpression not(Expression expression) { * operation. The implementing method can be specified. */ public static UnaryExpression not(Expression expression, Method method) { - return makeUnary(ExpressionType.Not, expression, null, method); + // TODO: use method + return not(expression); } /** @@ -2509,13 +2502,13 @@ public static GotoStatement return_(LabelTarget labelTarget) { * Creates a GotoExpression representing a return statement. The * value passed to the label upon jumping can be specified. */ - public static GotoStatement return_(LabelTarget labelTarget, - Expression expression) { + public static GotoStatement return_(@Nullable LabelTarget labelTarget, + @Nullable Expression expression) { return makeGoto(GotoExpressionKind.Return, labelTarget, expression); } public static GotoStatement makeGoto(GotoExpressionKind kind, - LabelTarget labelTarget, Expression expression) { + @Nullable LabelTarget labelTarget, @Nullable Expression expression) { return new GotoStatement(kind, labelTarget, expression); } @@ -2698,6 +2691,7 @@ public static BinaryExpression subtractChecked(Expression left, * Creates a SwitchExpression that represents a switch statement * without a default case. */ + @SuppressWarnings("nullness") public static SwitchStatement switch_(Expression switchValue, SwitchCase... cases) { return switch_(switchValue, null, null, toList(cases)); @@ -2707,6 +2701,7 @@ public static SwitchStatement switch_(Expression switchValue, * Creates a SwitchExpression that represents a switch statement * that has a default case. */ + @SuppressWarnings("nullness") public static SwitchStatement switch_(Expression switchValue, Expression defaultBody, SwitchCase... cases) { return switch_(switchValue, defaultBody, null, toList(cases)); @@ -2937,7 +2932,7 @@ public static WhileStatement while_(Expression condition, Statement body) { * Creates a statement that declares a variable. */ public static DeclarationStatement declare(int modifiers, - ParameterExpression parameter, Expression initializer) { + ParameterExpression parameter, @Nullable Expression initializer) { return new DeclarationStatement(modifiers, parameter, initializer); } @@ -2959,7 +2954,7 @@ public static DeclarationStatement declare(int modifiers, String name, /** * Creates a statement that executes an expression. */ - public static Statement statement(Expression expression) { + public static Statement statement(@Nullable Expression expression) { return new GotoStatement(GotoExpressionKind.Sequence, null, expression); } @@ -3055,7 +3050,7 @@ public static FluentList list(Iterable ts) { /** * Evaluates an expression and returns the result. */ - public static Object evaluate(Node node) { + public static @Nullable Object evaluate(Node node) { Objects.requireNonNull(node); final Evaluator evaluator = new Evaluator(); return ((AbstractNode) node).evaluate(evaluator); @@ -3083,7 +3078,7 @@ private static Class deduceType(List parameterList, } } - private static List toList(Iterable iterable) { + private static @PolyNull List toList(@PolyNull Iterable iterable) { if (iterable == null) { return null; } @@ -3112,20 +3107,16 @@ private static Collection toCollection(Iterable iterable) { return toList(iterable); } - private static T[] toArray(Iterable iterable, T[] a) { - return toCollection(iterable).toArray(a); - } - - static Expression accept(T node, Shuttle shuttle) { + static @PolyNull Expression accept(T node, Shuttle shuttle) { if (node == null) { - return null; + return node; } return node.accept(shuttle); } - static Statement accept(T node, Shuttle shuttle) { + static @PolyNull Statement accept(T node, Shuttle shuttle) { if (node == null) { - return null; + return node; } return node.accept(shuttle); } @@ -3186,8 +3177,8 @@ static List acceptDeclarations( return declarations1; } - static List acceptMemberDeclarations( - List memberDeclarations, Shuttle shuttle) { + static @PolyNull List acceptMemberDeclarations( + @PolyNull List memberDeclarations, Shuttle shuttle) { if (memberDeclarations == null || memberDeclarations.isEmpty()) { return memberDeclarations; // short cut } @@ -3210,7 +3201,8 @@ static List acceptExpressions(List expressions, return expressions1; } - static R acceptNodes(List nodes, Visitor visitor) { + static @Nullable R acceptNodes(@Nullable List nodes, + Visitor visitor) { R r = null; if (nodes != null) { for (Node node : nodes) { diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/FieldDeclaration.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/FieldDeclaration.java index 7780f8c4dfb4..9d754cbe6cdc 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/FieldDeclaration.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/FieldDeclaration.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Modifier; import java.util.Objects; @@ -25,10 +27,10 @@ public class FieldDeclaration extends MemberDeclaration { public final int modifier; public final ParameterExpression parameter; - public final Expression initializer; + public final @Nullable Expression initializer; public FieldDeclaration(int modifier, ParameterExpression parameter, - Expression initializer) { + @Nullable Expression initializer) { assert parameter != null : "parameter should not be null"; this.modifier = modifier; this.parameter = parameter; @@ -61,7 +63,7 @@ public void accept(ExpressionWriter writer) { writer.newlineAndIndent(); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ForEachStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ForEachStatement.java index d2cc3eff9dd0..eec7cd623ebf 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ForEachStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ForEachStatement.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Objects; /** @@ -59,7 +61,7 @@ public R accept(Visitor visitor) { .append(Blocks.toBlock(body)); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { return this == o || o instanceof ForEachStatement && parameter.equals(((ForEachStatement) o).parameter) diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ForStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ForStatement.java index feb2b6164516..b9ccf47b244b 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ForStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ForStatement.java @@ -18,6 +18,8 @@ import org.apache.calcite.linq4j.Ord; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.List; import java.util.Objects; @@ -26,14 +28,14 @@ */ public class ForStatement extends Statement { public final List declarations; - public final Expression condition; - public final Expression post; + public final @Nullable Expression condition; + public final @Nullable Expression post; public final Statement body; /** Cached hash code for the expression. */ private int hash; public ForStatement(List declarations, - Expression condition, Expression post, Statement body) { + @Nullable Expression condition, @Nullable Expression post, Statement body) { super(ExpressionType.For, Void.TYPE); assert declarations != null; assert body != null; @@ -74,7 +76,7 @@ public R accept(Visitor visitor) { writer.append(") ").append(Blocks.toBlock(body)); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/FunctionExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/FunctionExpression.java index 6ebf8b60a7c1..231369c79fa3 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/FunctionExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/FunctionExpression.java @@ -22,6 +22,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.lang.reflect.Type; @@ -30,6 +33,10 @@ import java.util.List; import java.util.Objects; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + +import static java.util.Objects.requireNonNull; + /** * Represents a strongly typed lambda expression as a data structure in the form * of an expression tree. This class cannot be inherited. @@ -38,14 +45,14 @@ */ public final class FunctionExpression> extends LambdaExpression { - public final F function; - public final BlockStatement body; + public final @Nullable F function; + public final @Nullable BlockStatement body; public final List parameterList; - private F dynamicFunction; + private @Nullable F dynamicFunction; /** Cached hash code for the expression. */ private int hash; - private FunctionExpression(Class type, F function, BlockStatement body, + private FunctionExpression(Class type, @Nullable F function, @Nullable BlockStatement body, List parameterList) { super(ExpressionType.Lambda, type); assert type != null : "type should not be null"; @@ -57,7 +64,7 @@ private FunctionExpression(Class type, F function, BlockStatement body, this.parameterList = parameterList; } - public FunctionExpression(F function) { + public FunctionExpression(@NonNull F function) { this((Class) function.getClass(), function, null, ImmutableList.of()); } @@ -68,7 +75,7 @@ public FunctionExpression(Class type, BlockStatement body, @Override public Expression accept(Shuttle shuttle) { shuttle = shuttle.preVisit(this); - BlockStatement body = this.body.accept(shuttle); + BlockStatement body = this.body == null ? null : this.body.accept(shuttle); return shuttle.visit(this, body); } @@ -82,7 +89,7 @@ public Invokable compile() { for (int i = 0; i < args.length; i++) { evaluator.push(parameterList.get(i), args[i]); } - return evaluator.evaluate(body); + return evaluator.evaluate(castNonNull(body)); }; } @@ -93,8 +100,9 @@ public F getFunction() { if (dynamicFunction == null) { final Invokable x = compile(); + ClassLoader classLoader = requireNonNull(castNonNull(getClass().getClassLoader())); //noinspection unchecked - dynamicFunction = (F) Proxy.newProxyInstance(getClass().getClassLoader(), + dynamicFunction = (F) Proxy.newProxyInstance(classLoader, new Class[]{Types.toClass(type)}, (proxy, method, args) -> x.dynamicInvoke(args)); } return dynamicFunction; @@ -142,9 +150,10 @@ public F getFunction() { boxBridgeParams.add(parameterExpression.declString(parameterBoxType)); boxBridgeArgs.add(parameterExpression.name + (Primitive.is(parameterType) - ? "." + Primitive.of(parameterType).primitiveName + "Value()" + ? "." + castNonNull(Primitive.of(parameterType)).primitiveName + "Value()" : "")); } + castNonNull(body); Type bridgeResultType = Functions.FUNCTION_RESULT_TYPES.get(this.type); if (bridgeResultType == null) { bridgeResultType = body.getType(); @@ -202,13 +211,11 @@ public F getFunction() { private boolean isAbstractMethodPrimitive() { Method method = getAbstractMethod(); - assert method != null; return Primitive.is(method.getReturnType()); } private String getAbstractMethodName() { final Method abstractMethod = getAbstractMethod(); - assert abstractMethod != null; return abstractMethod.getName(); } @@ -222,10 +229,10 @@ private Method getAbstractMethod() { return declaredMethods.get(0); } } - return null; + throw new IllegalStateException("Method not found, type = " + type); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } @@ -266,6 +273,6 @@ private Method getAbstractMethod() { /** Function that can be invoked with a variable number of arguments. */ public interface Invokable { - Object dynamicInvoke(Object... args); + @Nullable Object dynamicInvoke(@Nullable Object... args); } } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/GotoStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/GotoStatement.java index 4c00e62f70cc..be366b60a05f 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/GotoStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/GotoStatement.java @@ -16,19 +16,23 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Objects; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Represents an unconditional jump. This includes return statements, break and * continue statements, and other jumps. */ public class GotoStatement extends Statement { public final GotoExpressionKind kind; - public final LabelTarget labelTarget; - public final Expression expression; + public final @Nullable LabelTarget labelTarget; + public final @Nullable Expression expression; - GotoStatement(GotoExpressionKind kind, LabelTarget labelTarget, - Expression expression) { + GotoStatement(GotoExpressionKind kind, @Nullable LabelTarget labelTarget, + @Nullable Expression expression) { super(ExpressionType.Goto, expression == null ? Void.TYPE : expression.getType()); assert kind != null : "kind should not be null"; @@ -88,19 +92,19 @@ public R accept(Visitor visitor) { writer.append(';').newlineAndIndent(); } - @Override public Object evaluate(Evaluator evaluator) { + @Override public @Nullable Object evaluate(Evaluator evaluator) { switch (kind) { case Return: case Sequence: // NOTE: We ignore control flow. This is only correct if "return" // is the last statement in the block. - return expression.evaluate(evaluator); + return castNonNull(expression).evaluate(evaluator); default: throw new AssertionError("evaluate not implemented"); } } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/IndexExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/IndexExpression.java index 0f1c909f7220..3fbedb940643 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/IndexExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/IndexExpression.java @@ -16,9 +16,13 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.List; import java.util.Objects; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Represents indexing a property or array. */ @@ -27,8 +31,7 @@ public class IndexExpression extends Expression { public final List indexExpressions; public IndexExpression(Expression array, List indexExpressions) { - super(ExpressionType.ArrayIndex, Types.getComponentType(array.getType())); - assert array != null : "array should not be null"; + super(ExpressionType.ArrayIndex, castNonNull(Types.getComponentType(array.getType()))); assert indexExpressions != null : "indexExpressions should not be null"; assert !indexExpressions.isEmpty() : "indexExpressions should not be empty"; this.array = array; @@ -52,7 +55,7 @@ public R accept(Visitor visitor) { writer.list("[", ", ", "]", indexExpressions); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/LabelStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/LabelStatement.java index 38a5d279892c..1f98718f0dd8 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/LabelStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/LabelStatement.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Objects; /** @@ -41,7 +43,7 @@ public R accept(Visitor visitor) { return visitor.visit(this); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/LabelTarget.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/LabelTarget.java index 49cbb4d03b6b..16679ab02343 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/LabelTarget.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/LabelTarget.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Objects; /** @@ -28,7 +30,7 @@ public LabelTarget(String name) { this.name = name; } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MemberExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MemberExpression.java index d746ea19e583..29e6a82f704c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MemberExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MemberExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.Objects; @@ -24,14 +26,14 @@ * Represents accessing a field or property. */ public class MemberExpression extends Expression { - public final Expression expression; + public final @Nullable Expression expression; public final PseudoField field; public MemberExpression(Expression expression, Field field) { this(expression, Types.field(field)); } - public MemberExpression(Expression expression, PseudoField field) { + public MemberExpression(@Nullable Expression expression, PseudoField field) { super(ExpressionType.MemberAccess, field.getType()); assert field != null : "field should not be null"; assert expression != null || Modifier.isStatic(field.getModifiers()) @@ -52,7 +54,7 @@ public R accept(Visitor visitor) { return visitor.visit(this); } - public Object evaluate(Evaluator evaluator) { + @Override public @Nullable Object evaluate(Evaluator evaluator) { final Object o = expression == null ? null : expression.evaluate(evaluator); @@ -76,7 +78,7 @@ public Object evaluate(Evaluator evaluator) { writer.append('.').append(field.getName()); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MethodCallExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MethodCallExpression.java index f98cfe0be02a..53e0d5bbeae7 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MethodCallExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MethodCallExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; @@ -28,13 +30,13 @@ */ public class MethodCallExpression extends Expression { public final Method method; - public final Expression targetExpression; // null for call to static method + public final @Nullable Expression targetExpression; // null for call to static method public final List expressions; /** Cached hash code for the expression. */ private int hash; MethodCallExpression(Type returnType, Method method, - Expression targetExpression, List expressions) { + @Nullable Expression targetExpression, List expressions) { super(ExpressionType.Call, returnType); assert expressions != null : "expressions should not be null"; assert method != null : "method should not be null"; @@ -46,7 +48,7 @@ public class MethodCallExpression extends Expression { this.expressions = expressions; } - MethodCallExpression(Method method, Expression targetExpression, + MethodCallExpression(Method method, @Nullable Expression targetExpression, List expressions) { this(method.getReturnType(), method, targetExpression, expressions); } @@ -64,14 +66,14 @@ public R accept(Visitor visitor) { return visitor.visit(this); } - @Override public Object evaluate(Evaluator evaluator) { + @Override public @Nullable Object evaluate(Evaluator evaluator) { final Object target; if (targetExpression == null) { target = null; } else { target = targetExpression.evaluate(evaluator); } - final Object[] args = new Object[expressions.size()]; + final @Nullable Object[] args = new Object[expressions.size()]; for (int i = 0; i < expressions.size(); i++) { Expression expression = expressions.get(i); args[i] = expression.evaluate(evaluator); @@ -105,7 +107,7 @@ public R accept(Visitor visitor) { writer.append(')'); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MethodDeclaration.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MethodDeclaration.java index e67fa6bb207f..1cd58a48f225 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MethodDeclaration.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/MethodDeclaration.java @@ -16,12 +16,14 @@ */ package org.apache.calcite.linq4j.tree; -import com.google.common.collect.Lists; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.Nullable; import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.List; import java.util.Objects; +import java.util.stream.Collectors; /** * Declaration of a method. @@ -30,11 +32,11 @@ public class MethodDeclaration extends MemberDeclaration { public final int modifier; public final String name; public final Type resultType; - public final List parameters; + public final List<@NonNull ParameterExpression> parameters; public final BlockStatement body; public MethodDeclaration(int modifier, String name, Type resultType, - List parameters, BlockStatement body) { + List<@NonNull ParameterExpression> parameters, BlockStatement body) { assert name != null : "name should not be null"; assert resultType != null : "resultType should not be null"; assert parameters != null : "parameters should not be null"; @@ -68,13 +70,13 @@ public void accept(ExpressionWriter writer) { .append(' ') .append(name) .list("(", ", ", ")", - Lists.transform(parameters, ParameterExpression::declString)) + parameters.stream().map(ParameterExpression::declString).collect(Collectors.toList())) .append(' ') .append(body); writer.newlineAndIndent(); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/NewArrayExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/NewArrayExpression.java index 54cba746fde3..9380a10ed779 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/NewArrayExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/NewArrayExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.List; import java.util.Objects; @@ -26,13 +28,13 @@ */ public class NewArrayExpression extends Expression { public final int dimension; - public final Expression bound; - public final List expressions; + public final @Nullable Expression bound; + public final @Nullable List expressions; /** Cached hash code for the expression. */ private int hash; - public NewArrayExpression(Type type, int dimension, Expression bound, - List expressions) { + public NewArrayExpression(Type type, int dimension, @Nullable Expression bound, + @Nullable List expressions) { super(ExpressionType.NewArrayInit, Types.arrayType(type, dimension)); this.dimension = dimension; this.bound = bound; @@ -67,7 +69,7 @@ public R accept(Visitor visitor) { } } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/NewExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/NewExpression.java index d42d67f5cf2e..b65b2c884eee 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/NewExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/NewExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.List; import java.util.Objects; @@ -29,12 +31,12 @@ public class NewExpression extends Expression { public final Type type; public final List arguments; - public final List memberDeclarations; + public final @Nullable List memberDeclarations; /** Cached hash code for the expression. */ private int hash; public NewExpression(Type type, List arguments, - List memberDeclarations) { + @Nullable List memberDeclarations) { super(ExpressionType.New, type); this.type = type; this.arguments = arguments; @@ -61,7 +63,7 @@ public R accept(Visitor visitor) { } } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/OptimizeShuttle.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/OptimizeShuttle.java index c4a508ef803c..0dfb8703644c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/OptimizeShuttle.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/OptimizeShuttle.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.ArrayList; @@ -215,7 +217,7 @@ && eq(cmp.expression1, expression2)) { return super.visit(binary, expression0, expression1); } - private Expression visit0( + private @Nullable Expression visit0( BinaryExpression binary, Expression expression0, Expression expression1) { @@ -361,7 +363,7 @@ && isKnownNotNull(expression0)) { } @Override public Expression visit(MethodCallExpression methodCallExpression, - Expression targetExpression, + @Nullable Expression targetExpression, List expressions) { if (BOOLEAN_VALUEOF_BOOL.equals(methodCallExpression.method)) { Boolean always = always(expressions.get(0)); @@ -381,7 +383,7 @@ private boolean isConstantNull(Expression expression) { * Returns whether an expression always evaluates to true or false. * Assumes that expression has already been optimized. */ - private static Boolean always(Expression x) { + private static @Nullable Boolean always(Expression x) { if (x.equals(FALSE_EXPR) || x.equals(BOXED_FALSE_EXPR)) { return Boolean.FALSE; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ParameterExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ParameterExpression.java index fb7eaf2d22d3..03178320d802 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ParameterExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ParameterExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Modifier; import java.lang.reflect.Type; import java.util.concurrent.atomic.AtomicInteger; @@ -51,7 +53,7 @@ public R accept(Visitor visitor) { return visitor.visit(this); } - public Object evaluate(Evaluator evaluator) { + @Override public @Nullable Object evaluate(Evaluator evaluator) { return evaluator.peek(this); } @@ -69,7 +71,7 @@ String declString(Type type) { + " " + name; } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { return this == o; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java index f95105d93917..82dd9a7aff02 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Primitive.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Array; import java.lang.reflect.Field; import java.lang.reflect.Type; @@ -28,6 +30,8 @@ import java.util.List; import java.util.Map; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Enumeration of Java's primitive types. * @@ -54,30 +58,30 @@ public enum Primitive { VOID(Void.TYPE, Void.class, 3, null, null, null, null, null, -1), OTHER(null, null, 4, null, null, null, null, null, -1); - public final Class primitiveClass; - public final Class boxClass; - public final String primitiveName; // e.g. "int" - public final String boxName; + public final @Nullable Class primitiveClass; + public final @Nullable Class boxClass; + public final @Nullable String primitiveName; // e.g. "int" + public final @Nullable String boxName; private final int family; /** The default value of this primitive class. This is the value * taken by uninitialized fields, for instance; 0 for {@code int}, false for * {@code boolean}, etc. */ - public final Object defaultValue; + public final @Nullable Object defaultValue; /** The minimum value of this primitive class. */ - public final Object min; + public final @Nullable Object min; /** The largest value that is less than zero. Null if not applicable for this * type. */ - public final Object maxNegative; + public final @Nullable Object maxNegative; /** The smallest value that is greater than zero. Null if not applicable for * this type. */ - public final Object minPositive; + public final @Nullable Object minPositive; /** The maximum value of this primitive class. */ - public final Object max; + public final @Nullable Object max; /** The size of a value of this type, in bits. Null if not applicable for this * type. */ @@ -98,9 +102,9 @@ public enum Primitive { } } - Primitive(Class primitiveClass, Class boxClass, int family, - Object defaultValue, Object min, Object maxNegative, Object minPositive, - Object max, int size) { + Primitive(@Nullable Class primitiveClass, @Nullable Class boxClass, int family, + @Nullable Object defaultValue, @Nullable Object min, @Nullable Object maxNegative, + @Nullable Object minPositive, @Nullable Object max, int size) { this.primitiveClass = primitiveClass; this.family = family; this.primitiveName = @@ -123,7 +127,7 @@ public enum Primitive { * of(Long.class) and of(String.class) return * {@code null}. */ - public static Primitive of(Type type) { + public static @Nullable Primitive of(Type type) { //noinspection SuspiciousMethodCalls return PRIMITIVE_MAP.get(type); } @@ -134,7 +138,7 @@ public static Primitive of(Type type) { *

For example, ofBox(java.util.Long.class) * returns {@link #LONG}. */ - public static Primitive ofBox(Type type) { + public static @Nullable Primitive ofBox(Type type) { //noinspection SuspiciousMethodCalls return BOX_MAP.get(type); } @@ -145,7 +149,7 @@ public static Primitive ofBox(Type type) { *

For example, ofBoxOr(Long.class) and * ofBoxOr(long.class) both return {@link #LONG}. */ - public static Primitive ofBoxOr(Type type) { + public static @Nullable Primitive ofBoxOr(Type type) { Primitive primitive = of(type); if (primitive == null) { primitive = ofBox(type); @@ -216,7 +220,7 @@ public boolean isFixedNumeric() { */ public static Type box(Type type) { Primitive primitive = of(type); - return primitive == null ? type : primitive.boxClass; + return primitive == null ? type : castNonNull(primitive.boxClass); } /** @@ -225,7 +229,7 @@ public static Type box(Type type) { */ public static Class box(Class type) { Primitive primitive = of(type); - return primitive == null ? type : primitive.boxClass; + return primitive == null ? type : castNonNull(primitive.boxClass); } /** @@ -234,7 +238,7 @@ public static Class box(Class type) { */ public static Type unbox(Type type) { Primitive primitive = ofBox(type); - return primitive == null ? type : primitive.primitiveClass; + return primitive == null ? type : castNonNull(primitive.primitiveClass); } /** @@ -243,7 +247,7 @@ public static Type unbox(Type type) { */ public static Class unbox(Class type) { Primitive primitive = ofBox(type); - return primitive == null ? type : primitive.primitiveClass; + return primitive == null ? type : castNonNull(primitive.primitiveClass); } /** @@ -696,7 +700,7 @@ public void send(Field field, Object o, Sink sink) /** * Gets an item from an array. */ - public Object arrayItem(Object dataSet, int ordinal) { + public @Nullable Object arrayItem(Object dataSet, int ordinal) { // Plain old Array.get doesn't cut it when you have an array of // Integer values but you want to read Short values. Array.getShort // does the right thing. @@ -727,6 +731,7 @@ public Object arrayItem(Object dataSet, int ordinal) { /** * Reads value from a source into an array. */ + @SuppressWarnings("argument.type.incompatible") public void arrayItem(Source source, Object dataSet, int ordinal) { switch (this) { case DOUBLE: @@ -804,7 +809,7 @@ public void arrayItem(Object dataSet, int ordinal, Sink sink) { * @param resultSet Result set * @param i Ordinal of column (1-based, per JDBC) */ - public Object jdbcGet(ResultSet resultSet, int i) throws SQLException { + public @Nullable Object jdbcGet(ResultSet resultSet, int i) throws SQLException { switch (this) { case BOOLEAN: return resultSet.getBoolean(i); @@ -976,7 +981,7 @@ public interface Sink { void set(double v); - void set(Object v); + void set(@Nullable Object v); } /** @@ -999,7 +1004,7 @@ public interface Source { double getDouble(); - Object getObject(); + @Nullable Object getObject(); } /** Whether a type is primitive (e.g. {@code int}), diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/PseudoField.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/PseudoField.java index ec1192334106..93f781954a7d 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/PseudoField.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/PseudoField.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; /** @@ -29,7 +31,7 @@ public interface PseudoField { int getModifiers(); - Object get(Object o) throws IllegalAccessException; + @Nullable Object get(@Nullable Object o) throws IllegalAccessException; Type getDeclaringClass(); } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ReflectedPseudoField.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ReflectedPseudoField.java index 4dc11a7b53d3..0372d4359714 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ReflectedPseudoField.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ReflectedPseudoField.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Field; import java.lang.reflect.Type; @@ -43,7 +45,7 @@ public int getModifiers() { return field.getModifiers(); } - public Object get(Object o) throws IllegalAccessException { + @Override public @Nullable Object get(@Nullable Object o) throws IllegalAccessException { return field.get(o); } @@ -51,7 +53,7 @@ public Class getDeclaringClass() { return field.getDeclaringClass(); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Shuttle.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Shuttle.java index 52806c0a3bca..ef5aad5595d9 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Shuttle.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Shuttle.java @@ -16,9 +16,13 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.List; import java.util.Objects; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Extension to {@link Visitor} that returns a mutated tree. */ @@ -61,7 +65,7 @@ public Shuttle preVisit(GotoStatement gotoStatement) { return this; } - public Statement visit(GotoStatement gotoStatement, Expression expression) { + public Statement visit(GotoStatement gotoStatement, @Nullable Expression expression) { return expression == gotoStatement.expression ? gotoStatement : Expressions.makeGoto( @@ -78,8 +82,8 @@ public Shuttle preVisit(ForStatement forStatement) { } public ForStatement visit(ForStatement forStatement, - List declarations, Expression condition, - Expression post, Statement body) { + List declarations, @Nullable Expression condition, + @Nullable Expression post, Statement body) { return declarations.equals(forStatement.declarations) && condition == forStatement.condition && post == forStatement.post @@ -116,7 +120,7 @@ public Shuttle preVisit(DeclarationStatement declarationStatement) { } public DeclarationStatement visit(DeclarationStatement declarationStatement, - Expression initializer) { + @Nullable Expression initializer) { return declarationStatement.initializer == initializer ? declarationStatement : Expressions.declare( @@ -133,10 +137,10 @@ public Shuttle preVisit(FunctionExpression functionExpression) { } public Expression visit(FunctionExpression functionExpression, - BlockStatement body) { - return functionExpression.body.equals(body) + @Nullable BlockStatement body) { + return Objects.equals(body, functionExpression.body) ? functionExpression - : Expressions.lambda(body, functionExpression.parameterList); + : Expressions.lambda(castNonNull(body), functionExpression.parameterList); } public Shuttle preVisit(BinaryExpression binaryExpression) { @@ -195,7 +199,7 @@ public Shuttle preVisit(MethodCallExpression methodCallExpression) { } public Expression visit(MethodCallExpression methodCallExpression, - Expression targetExpression, List expressions) { + @Nullable Expression targetExpression, List expressions) { return methodCallExpression.targetExpression == targetExpression && methodCallExpression.expressions.equals(expressions) ? methodCallExpression @@ -216,7 +220,7 @@ public Shuttle preVisit(MemberExpression memberExpression) { } public Expression visit(MemberExpression memberExpression, - Expression expression) { + @Nullable Expression expression) { return memberExpression.expression == expression ? memberExpression : Expressions.field(expression, memberExpression.field); @@ -231,7 +235,7 @@ public Shuttle preVisit(NewArrayExpression newArrayExpression) { } public Expression visit(NewArrayExpression newArrayExpression, int dimension, - Expression bound, List expressions) { + @Nullable Expression bound, @Nullable List expressions) { return Objects.equals(expressions, newArrayExpression.expressions) && Objects.equals(bound, newArrayExpression.bound) ? newArrayExpression @@ -252,7 +256,7 @@ public Shuttle preVisit(NewExpression newExpression) { } public Expression visit(NewExpression newExpression, - List arguments, List memberDeclarations) { + List arguments, @Nullable List memberDeclarations) { return arguments.equals(newExpression.arguments) && Objects.equals(memberDeclarations, newExpression.memberDeclarations) ? newExpression @@ -268,7 +272,7 @@ public Shuttle preVisit(TryStatement tryStatement) { } public Statement visit(TryStatement tryStatement, - Statement body, List catchBlocks, Statement fynally) { + Statement body, List catchBlocks, @Nullable Statement fynally) { return body.equals(tryStatement.body) && Objects.equals(catchBlocks, tryStatement.catchBlocks) && Objects.equals(fynally, tryStatement.fynally) @@ -310,7 +314,7 @@ public Shuttle preVisit(FieldDeclaration fieldDeclaration) { } public MemberDeclaration visit(FieldDeclaration fieldDeclaration, - Expression initializer) { + @Nullable Expression initializer) { return Objects.equals(initializer, fieldDeclaration.initializer) ? fieldDeclaration : Expressions.fieldDecl(fieldDeclaration.modifier, diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TernaryExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TernaryExpression.java index 00ef60c370f7..427192cadccd 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TernaryExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TernaryExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.Objects; @@ -61,7 +63,7 @@ void accept(ExpressionWriter writer, int lprec, int rprec) { expression2.accept(writer, nodeType.rprec, rprec); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ThrowStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ThrowStatement.java index 0cd427d70d0e..66472d80ad72 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ThrowStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/ThrowStatement.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Objects; /** @@ -43,7 +45,7 @@ public R accept(Visitor visitor) { writer.append("throw ").append(expression).append(';').newlineAndIndent(); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TryStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TryStatement.java index 0ee3d7a96b12..1fd4d6ef4ba1 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TryStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TryStatement.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -26,10 +28,10 @@ public class TryStatement extends Statement { public final Statement body; public final List catchBlocks; - public final Statement fynally; + public final @Nullable Statement fynally; public TryStatement(Statement body, List catchBlocks, - Statement fynally) { + @Nullable Statement fynally) { super(ExpressionType.Try, body.getType()); this.body = Objects.requireNonNull(body); this.catchBlocks = Objects.requireNonNull(catchBlocks); @@ -67,7 +69,7 @@ public R accept(Visitor visitor) { } } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TypeBinaryExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TypeBinaryExpression.java index 967c228849c9..50d03e3dc323 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TypeBinaryExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/TypeBinaryExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.Objects; @@ -53,7 +55,7 @@ void accept(ExpressionWriter writer, int lprec, int rprec) { writer.append(type); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java index a449a46a925c..ef484432a6ad 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/Types.java @@ -18,6 +18,8 @@ import org.apache.calcite.linq4j.Enumerator; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Array; import java.lang.reflect.Constructor; import java.lang.reflect.Field; @@ -33,6 +35,8 @@ import java.util.Iterator; import java.util.List; +import static org.apache.calcite.linq4j.NullnessUtil.castNonNull; + /** * Utilities for converting between {@link Expression}, {@link Type} and * {@link Class}. @@ -60,7 +64,7 @@ public static Type of(Type type, Type... typeArguments) { * *

Returns null if the type is not one of these.

*/ - public static Type getElementType(Type type) { + public static @Nullable Type getElementType(Type type) { if (type instanceof ArrayType) { return ((ArrayType) type).getComponentType(); } @@ -167,7 +171,7 @@ public static Class[] toClassArray(Iterable arguments) { /** * Returns the component type of an array. */ - public static Type getComponentType(Type type) { + public static @Nullable Type getComponentType(Type type) { if (type instanceof Class) { return ((Class) type).getComponentType(); } @@ -189,30 +193,20 @@ public static Type getComponentType(Type type) { static Type getComponentTypeN(Type type) { for (;;) { - final Type oldType = type; - type = getComponentType(type); - if (type == null) { - return oldType; + Type componentType = getComponentType(type); + if (componentType == null) { + return type; } + type = componentType; } } public static Type box(Type type) { - Primitive primitive = Primitive.of(type); - if (primitive != null) { - return primitive.boxClass; - } else { - return type; - } + return Primitive.box(type); } public static Type unbox(Type type) { - Primitive primitive = Primitive.ofBox(type); - if (primitive != null) { - return primitive.primitiveClass; - } else { - return type; - } + return Primitive.unbox(type); } static String className(Type type) { @@ -289,6 +283,7 @@ public static boolean allAssignable(boolean varArgs, Class[] parameterTypes, * * @return Whether parameter can be assigned from argument */ + @SuppressWarnings("nullness") private static boolean assignableFrom(Class parameter, Class argument) { return parameter.isAssignableFrom(argument) || parameter.isPrimitive() @@ -404,7 +399,7 @@ static Type gcd(Type... types) { return Object.class; } } - return bestPrimitive.primitiveClass; + return castNonNull(bestPrimitive.primitiveClass); } else { for (int i = 1; i < types.length; i++) { if (types[i] != types[0]) { @@ -438,7 +433,7 @@ public static Expression castIfNecessary(Type returnType, // Integer foo(BigDecimal o) { // return o.intValue(); // } - return Expressions.unbox(expression, Primitive.ofBox(returnType)); + return Expressions.unbox(expression, castNonNull(Primitive.ofBox(returnType))); } if (Primitive.is(returnType) && !Primitive.is(type)) { // E.g. @@ -447,7 +442,7 @@ public static Expression castIfNecessary(Type returnType, // } return Expressions.unbox( Expressions.convert_(expression, Types.box(returnType)), - Primitive.of(returnType)); + castNonNull(Primitive.of(returnType))); } if (!Primitive.is(returnType) && Primitive.is(type)) { // E.g. @@ -526,10 +521,10 @@ public static Type stripGenerics(Type type) { static class ParameterizedTypeImpl implements ParameterizedType { private final Type rawType; private final List typeArguments; - private final Type ownerType; + private final @Nullable Type ownerType; ParameterizedTypeImpl(Type rawType, List typeArguments, - Type ownerType) { + @Nullable Type ownerType) { super(); this.rawType = rawType; this.typeArguments = typeArguments; @@ -563,7 +558,7 @@ public Type getRawType() { return rawType; } - public Type getOwnerType() { + public @Nullable Type getOwnerType() { return ownerType; } } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/UnaryExpression.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/UnaryExpression.java index d07e545aa816..5ae22bc60226 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/UnaryExpression.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/UnaryExpression.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.lang.reflect.Type; import java.util.Objects; @@ -59,7 +61,7 @@ void accept(ExpressionWriter writer, int lprec, int rprec) { } } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/VisitorImpl.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/VisitorImpl.java index 8d8431f27757..4873f90ad86c 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/VisitorImpl.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/VisitorImpl.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.List; /** @@ -24,7 +26,7 @@ * * @param Return type */ -public class VisitorImpl implements Visitor { +public class VisitorImpl<@Nullable R> implements Visitor { public VisitorImpl() { super(); } @@ -84,8 +86,8 @@ public R visit(FieldDeclaration fieldDeclaration) { public R visit(ForStatement forStatement) { R r0 = Expressions.acceptNodes(forStatement.declarations, this); - R r1 = forStatement.condition.accept(this); - R r2 = forStatement.post.accept(this); + R r1 = forStatement.condition == null ? null : forStatement.condition.accept(this); + R r2 = forStatement.post == null ? null : forStatement.post.accept(this); return forStatement.body.accept(this); } @@ -99,7 +101,7 @@ public R visit(FunctionExpression functionExpression) { @SuppressWarnings("unchecked") final List parameterList = functionExpression.parameterList; R r0 = Expressions.acceptNodes(parameterList, this); - return functionExpression.body.accept(this); + return functionExpression.body == null ? null : functionExpression.body.accept(this); } public R visit(GotoStatement gotoStatement) { diff --git a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/WhileStatement.java b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/WhileStatement.java index 76ab8b0f2860..8e5e0c505c91 100644 --- a/linq4j/src/main/java/org/apache/calcite/linq4j/tree/WhileStatement.java +++ b/linq4j/src/main/java/org/apache/calcite/linq4j/tree/WhileStatement.java @@ -16,6 +16,8 @@ */ package org.apache.calcite.linq4j.tree; +import org.checkerframework.checker.nullness.qual.Nullable; + import java.util.Objects; /** @@ -49,7 +51,7 @@ public R accept(Visitor visitor) { Blocks.toBlock(body)); } - @Override public boolean equals(Object o) { + @Override public boolean equals(@Nullable Object o) { if (this == o) { return true; } diff --git a/src/main/config/checkerframework/Constructor.astub b/src/main/config/checkerframework/Constructor.astub new file mode 100644 index 000000000000..1e7bba8ecb22 --- /dev/null +++ b/src/main/config/checkerframework/Constructor.astub @@ -0,0 +1,7 @@ +package java.lang.reflect; + +import org.checkerframework.checker.nullness.qual.*; + +class Constructor { + public @NonNull T newInstance(@Nullable Object... initargs); +} diff --git a/src/main/config/checkerframework/Field.astub b/src/main/config/checkerframework/Field.astub new file mode 100644 index 000000000000..bf9bf25bfd5d --- /dev/null +++ b/src/main/config/checkerframework/Field.astub @@ -0,0 +1,8 @@ +package java.lang.reflect; + +import org.checkerframework.checker.nullness.qual.*; + +class Field { + @Nullable Object get(@Nullable Object obj); + int getInt(@Nullable Object obj); +} diff --git a/src/main/config/checkerframework/InvocationHandler.astub b/src/main/config/checkerframework/InvocationHandler.astub new file mode 100644 index 000000000000..27aff377745b --- /dev/null +++ b/src/main/config/checkerframework/InvocationHandler.astub @@ -0,0 +1,7 @@ +package java.lang.reflect; + +import org.checkerframework.checker.nullness.qual.*; + +interface InvocationHandler { + @Nullable Object invoke(Object proxy, Method method, @Nullable Object[] args); +} diff --git a/src/main/config/checkerframework/Method.astub b/src/main/config/checkerframework/Method.astub new file mode 100644 index 000000000000..3a20f211ba1f --- /dev/null +++ b/src/main/config/checkerframework/Method.astub @@ -0,0 +1,7 @@ +package java.lang.reflect; + +import org.checkerframework.checker.nullness.qual.*; + +class Method { + @Nullable Object invoke(@Nullable Object obj, @Nullable Object... args); +} diff --git a/src/main/config/checkerframework/Proxy.astub b/src/main/config/checkerframework/Proxy.astub new file mode 100644 index 000000000000..a77969d949a4 --- /dev/null +++ b/src/main/config/checkerframework/Proxy.astub @@ -0,0 +1,7 @@ +package java.lang.reflect; + +import org.checkerframework.checker.nullness.qual.*; + +class Proxy { + Object newProxyInstance(@Nullable ClassLoader loader, Class[] interfaces, InvocationHandler h); +}