diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskExecutor.java b/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskExecutor.java index 7d7d31a50b4a..c57423caa4c7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskExecutor.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/timesharing/TimeSharingTaskExecutor.java @@ -325,9 +325,9 @@ private boolean doRemoveTask(TimeSharingTaskHandle taskHandle) splits = taskHandle.destroy(); // stop tracking splits (especially blocked splits which may never unblock) - allSplits.removeAll(splits); - intermediateSplits.removeAll(splits); - blockedSplits.keySet().removeAll(splits); + splits.forEach(allSplits::remove); + splits.forEach(intermediateSplits::remove); + splits.forEach(blockedSplits.keySet()::remove); waitingSplits.removeAll(splits); recordLeafSplitsSize(); } diff --git a/core/trino-main/src/main/java/io/trino/failuredetector/HeartbeatFailureDetector.java b/core/trino-main/src/main/java/io/trino/failuredetector/HeartbeatFailureDetector.java index b5f1e3b9876c..6308cc9a2d98 100644 --- a/core/trino-main/src/main/java/io/trino/failuredetector/HeartbeatFailureDetector.java +++ b/core/trino-main/src/main/java/io/trino/failuredetector/HeartbeatFailureDetector.java @@ -234,7 +234,7 @@ void updateMonitoredServices() .map(ServiceDescriptor::getId) .collect(toImmutableList()); - tasks.keySet().removeAll(expiredIds); + expiredIds.forEach(tasks.keySet()::remove); // 2. disable offline services tasks.values().stream() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/SymbolUtils.java b/core/trino-main/src/main/java/io/trino/sql/planner/SymbolUtils.java new file mode 100644 index 000000000000..1394eb463bda --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/SymbolUtils.java @@ -0,0 +1,34 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.List; + +public class SymbolUtils +{ + private SymbolUtils() {} + + public static boolean containsAll(List haystack, Collection needles) + { + return ImmutableSet.copyOf(haystack).containsAll(needles); + } + + public static boolean containsNone(Collection values, Collection testValues) + { + return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java index c85c794f0108..7ac1aac17265 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExtractSpatialJoins.java @@ -68,7 +68,6 @@ import java.io.IOException; import java.io.UncheckedIOException; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; @@ -88,6 +87,8 @@ import static io.trino.sql.ir.Comparison.Operator.LESS_THAN; import static io.trino.sql.ir.Comparison.Operator.LESS_THAN_OR_EQUAL; import static io.trino.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.trino.sql.planner.SymbolUtils.containsAll; +import static io.trino.sql.planner.SymbolUtils.containsNone; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.plan.JoinType.INNER; import static io.trino.sql.planner.plan.JoinType.LEFT; @@ -304,7 +305,7 @@ private static Result tryCreateSpatialJoin( // ST_Distance(a, b) <= r radius = spatialComparison.right(); Set radiusSymbols = extractUnique(radius); - if (radiusSymbols.isEmpty() || (rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) { + if (radiusSymbols.isEmpty() || containsAll(rightSymbols, radiusSymbols) && containsNone(leftSymbols, radiusSymbols)) { newRadiusSymbol = newRadiusSymbol(context, radius); newComparison = new Comparison(spatialComparison.operator(), spatialComparison.left(), toExpression(newRadiusSymbol, radius)); } @@ -316,7 +317,7 @@ private static Result tryCreateSpatialJoin( // r >= ST_Distance(a, b) radius = spatialComparison.left(); Set radiusSymbols = extractUnique(radius); - if (radiusSymbols.isEmpty() || (rightSymbols.containsAll(radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) { + if (radiusSymbols.isEmpty() || (containsAll(rightSymbols, radiusSymbols) && containsNone(leftSymbols, radiusSymbols))) { newRadiusSymbol = newRadiusSymbol(context, radius); newComparison = new Comparison(spatialComparison.operator().flip(), spatialComparison.right(), toExpression(newRadiusSymbol, radius)); } @@ -529,16 +530,16 @@ private static int checkAlignment(JoinNode joinNode, Set maybeLeftSymbol List leftSymbols = joinNode.getLeft().getOutputSymbols(); List rightSymbols = joinNode.getRight().getOutputSymbols(); - if (leftSymbols.containsAll(maybeLeftSymbols) + if (containsAll(leftSymbols, maybeLeftSymbols) && containsNone(leftSymbols, maybeRightSymbols) - && rightSymbols.containsAll(maybeRightSymbols) + && containsAll(rightSymbols, maybeRightSymbols) && containsNone(rightSymbols, maybeLeftSymbols)) { return 1; } - if (leftSymbols.containsAll(maybeRightSymbols) + if (containsAll(leftSymbols, maybeRightSymbols) && containsNone(leftSymbols, maybeLeftSymbols) - && rightSymbols.containsAll(maybeLeftSymbols) + && containsAll(rightSymbols, maybeLeftSymbols) && containsNone(rightSymbols, maybeRightSymbols)) { return -1; } @@ -606,9 +607,4 @@ private static PlanNode addPartitioningNodes(PlannerContext plannerContext, Cont Optional.empty(), INNER); } - - private static boolean containsNone(Collection values, Collection testValues) - { - return values.stream().noneMatch(ImmutableSet.copyOf(testValues)::contains); - } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java index 7df67f5bcfea..d9594372b206 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java @@ -47,6 +47,7 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.SystemSessionProperties.preferPartialAggregation; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.plan.AggregationNode.Step.FINAL; import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; @@ -123,7 +124,7 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context .map(Partitioning.ArgumentBinding::getColumn) .collect(Collectors.toList()); - if (!aggregationNode.getGroupingKeys().containsAll(partitioningColumns)) { + if (!containsAll(aggregationNode.getGroupingKeys(), partitioningColumns)) { return Result.empty(); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java index 7eb830c3da36..0fca63dbfdd1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughJoin.java @@ -48,6 +48,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Sets.intersection; import static io.trino.SystemSessionProperties.isPushPartialAggregationThroughJoin; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.iterative.rule.PushProjectionThroughJoin.pushProjectionThroughJoin; import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs; import static io.trino.sql.planner.plan.AggregationNode.Step.INTERMEDIATE; @@ -175,7 +176,7 @@ private static boolean allAggregationsOn(Map aggregations, .map(SymbolsExtractor::extractAll) .flatMap(List::stream) .collect(toImmutableSet()); - return symbols.containsAll(inputs); + return containsAll(symbols, inputs); } private Optional pushPartialToLeftChild(AggregationNode node, JoinNode child, Context context) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java index 0df6f82f6b6b..8178cdd14c91 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughJoin.java @@ -34,6 +34,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.plan.JoinType.INNER; @@ -69,11 +70,11 @@ public static Optional pushProjectionThroughJoin( for (Map.Entry assignment : projectNode.getAssignments().entrySet()) { Expression expression = assignment.getValue(); Set symbols = extractUnique(expression); - if (leftChild.getOutputSymbols().containsAll(symbols)) { + if (containsAll(leftChild.getOutputSymbols(), symbols)) { // expression is satisfied with left child symbols leftAssignmentsBuilder.put(assignment.getKey(), expression); } - else if (rightChild.getOutputSymbols().containsAll(symbols)) { + else if (containsAll(rightChild.getOutputSymbols(), symbols)) { // expression is satisfied with right child symbols rightAssignmentsBuilder.put(assignment.getKey(), expression); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java index 5af536318cb1..8cfe9270d5ba 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ReorderJoins.java @@ -79,6 +79,7 @@ import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic; import static io.trino.sql.planner.EqualityInference.isInferenceCandidate; import static io.trino.sql.planner.OptimizerConfig.JoinReorderingStrategy.AUTOMATIC; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.SymbolsExtractor.extractAll; import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.iterative.rule.DetermineJoinDistributionType.canReplicate; @@ -518,7 +519,7 @@ static class MultiJoinNode this.pushedProjectionThroughJoin = pushedProjectionThroughJoin; List inputSymbols = sources.stream().flatMap(source -> source.getOutputSymbols().stream()).collect(toImmutableList()); - checkArgument(inputSymbols.containsAll(outputSymbols), "inputs do not contain all output symbols"); + checkArgument(containsAll(inputSymbols, outputSymbols), "inputs do not contain all output symbols"); } public Expression getFilter() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java index 3f5697118240..e954ad02e46b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java @@ -82,6 +82,7 @@ import static io.trino.SystemSessionProperties.isDistributedSortEnabled; import static io.trino.SystemSessionProperties.isSpillEnabled; import static io.trino.SystemSessionProperties.isTaskScaleWritersEnabled; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; @@ -609,7 +610,7 @@ private List pruneMarkDistinctSymbols(MarkDistinctNode node, List) property).getColumn()); } - else if (!node.getDistinctSymbols().containsAll(property.getColumns())) { + else if (!containsAll(node.getDistinctSymbols(), property.getColumns())) { // Ran into a non-distinct symbol. There will be no more symbols that are functionally dependent on distinct symbols exclusively. break; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java index 5c8a4c84c074..db43a3d5bb75 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java @@ -79,6 +79,7 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.metadata.OperatorNameUtil.mangleOperatorName; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren; @@ -760,7 +761,7 @@ private PlanWithProperties plan(PlanNode node, HashComputationSet parentPreferen { PlanWithProperties result = node.accept(this, parentPreference); checkState( - result.getNode().getOutputSymbols().containsAll(result.getHashSymbols().values()), + containsAll(result.getNode().getOutputSymbols(), result.getHashSymbols().values()), "Node %s declares hash symbols not in the output", result.getNode().getClass().getSimpleName()); return result; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java index 70cd23eda863..c85df471568a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java @@ -61,6 +61,7 @@ import static io.trino.spi.function.FunctionKind.AGGREGATE; import static io.trino.sql.ir.Booleans.TRUE; import static io.trino.sql.ir.IrUtils.combineConjuncts; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.plan.WindowFrameType.RANGE; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; @@ -288,7 +289,7 @@ private PlanNode planTableScan(TableScanNode node, Expression predicate, Context .transformKeys(node.getAssignments()::get) .intersect(node.getEnforcedConstraint()); - checkState(node.getOutputSymbols().containsAll(context.getLookupSymbols())); + checkState(containsAll(node.getOutputSymbols(), context.getLookupSymbols())); Set lookupColumns = context.getLookupSymbols().stream() .map(node.getAssignments()::get) @@ -375,8 +376,7 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) } // Lookup symbols can only be passed through if they are part of the partitioning - - if (!node.getPartitionBy().containsAll(context.get().getLookupSymbols())) { + if (!containsAll(node.getPartitionBy(), context.get().getLookupSymbols())) { return node; } @@ -393,7 +393,7 @@ public PlanNode visitIndexSource(IndexSourceNode node, RewriteContext c public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext context) { // Lookup symbols can only be passed through the probe side of an index join - if (!node.getProbeSource().getOutputSymbols().containsAll(context.get().getLookupSymbols())) { + if (!containsAll(node.getProbeSource().getOutputSymbols(), context.get().getLookupSymbols())) { return node; } @@ -411,7 +411,7 @@ public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext conte public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { // Lookup symbols can only be passed through if they are part of the group by columns - if (!node.getGroupingKeys().containsAll(context.get().getLookupSymbols())) { + if (!containsAll(node.getGroupingKeys(), context.get().getLookupSymbols())) { return node; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java index a12b8b699c38..13ea823de815 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java @@ -32,6 +32,7 @@ import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static java.util.Objects.requireNonNull; public class LimitPushDown @@ -132,7 +133,7 @@ public PlanNode visitAggregation(AggregationNode node, RewriteContext context) // function is injective, but that's a rare case. The majority of window nodes are expected to be partitioned by // pre-projected symbols. Predicate isSupported = conjunct -> - isDeterministic(conjunct) && - partitionSymbols.containsAll(extractUnique(conjunct)); + isDeterministic(conjunct) && containsAll(partitionSymbols, extractUnique(conjunct)); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported)); @@ -260,8 +260,7 @@ public PlanNode visitTopNRanking(TopNRankingNode node, RewriteContext isSupported = conjunct -> - isDeterministic(conjunct) && - partitionSymbols.containsAll(extractUnique(conjunct)); + isDeterministic(conjunct) && containsAll(partitionSymbols, extractUnique(conjunct)); Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported)); @@ -500,7 +499,7 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) if (joinEqualityExpression(conjunct, node.getLeft().getOutputSymbols(), node.getRight().getOutputSymbols())) { Comparison equality = (Comparison) conjunct; - boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(equality.left())); + boolean alignedComparison = containsAll(node.getLeft().getOutputSymbols(), extractUnique(equality.left())); Expression leftExpression = alignedComparison ? equality.left() : equality.right(); Expression rightExpression = alignedComparison ? equality.right() : equality.left(); @@ -625,7 +624,7 @@ private DynamicFiltersResult createDynamicFilters( Comparison comparison = expression.getComparison(); Expression leftExpression = comparison.left(); Expression rightExpression = comparison.right(); - boolean alignedComparison = node.getLeft().getOutputSymbols().containsAll(extractUnique(leftExpression)); + boolean alignedComparison = containsAll(node.getLeft().getOutputSymbols(), extractUnique(leftExpression)); return new DynamicFilterExpression( new Comparison( alignedComparison ? comparison.operator() : comparison.operator().flip(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java index 201704550f52..d29987e89b43 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PropertyDerivations.java @@ -100,6 +100,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.predicate.TupleDomain.extractFixedValues; import static io.trino.sql.ir.optimizer.IrExpressionOptimizer.newOptimizer; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.SystemPartitioningHandle.ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.optimizations.ActualProperties.Global.arbitraryPartition; import static io.trino.sql.planner.optimizations.ActualProperties.Global.coordinatorSinglePartition; @@ -138,15 +139,15 @@ public static ActualProperties deriveProperties( ActualProperties output = node.accept(new Visitor(plannerContext, session), inputProperties); output.getNodePartitioning().ifPresent(partitioning -> - verify(node.getOutputSymbols().containsAll(partitioning.getColumns()), "Node-level partitioning properties contain columns not present in node's output")); + verify(containsAll(node.getOutputSymbols(), partitioning.getColumns()), "Node-level partitioning properties contain columns not present in node's output")); - verify(node.getOutputSymbols().containsAll(output.getConstants().keySet()), "Node-level constant properties contain columns not present in node's output"); + verify(containsAll(node.getOutputSymbols(), output.getConstants().keySet()), "Node-level constant properties contain columns not present in node's output"); Set localPropertyColumns = output.getLocalProperties().stream() .flatMap(property -> property.getColumns().stream()) .collect(Collectors.toSet()); - verify(node.getOutputSymbols().containsAll(localPropertyColumns), "Node-level local properties contain columns not present in node's output"); + verify(containsAll(node.getOutputSymbols(), localPropertyColumns), "Node-level local properties contain columns not present in node's output"); return output; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java index f26d3bf01a4a..3c922c6769b1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StreamPropertyDerivations.java @@ -95,6 +95,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.SystemSessionProperties.isSpillEnabled; import static io.trino.spi.predicate.TupleDomain.extractFixedValues; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.FIXED; import static io.trino.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution.MULTIPLE; @@ -158,7 +159,7 @@ public static StreamProperties deriveProperties( .flatMap(property -> property.getColumns().stream()) .collect(Collectors.toSet()); - verify(node.getOutputSymbols().containsAll(localPropertyColumns), "Stream-level local properties contain columns not present in node's output"); + verify(containsAll(node.getOutputSymbols(), localPropertyColumns), "Stream-level local properties contain columns not present in node's output"); return result; } @@ -202,7 +203,7 @@ public static StreamProperties deriveStreamPropertiesWithoutActualProperties(Pla StreamProperties result = node.accept(new Visitor(metadata, session), inputProperties); result.getPartitioningColumns().ifPresent(columns -> - verify(node.getOutputSymbols().containsAll(columns), "Stream-level partitioning properties contain columns not present in node's output")); + verify(containsAll(node.getOutputSymbols(), columns), "Stream-level partitioning properties contain columns not present in node's output")); return result; } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java index 1c797eb0faec..816b5186219c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/AggregationNode.java @@ -38,6 +38,7 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; import static java.util.Objects.requireNonNull; @@ -114,7 +115,7 @@ public AggregationNode( this.hashSymbol = hashSymbol; requireNonNull(preGroupedSymbols, "preGroupedSymbols is null"); - checkArgument(preGroupedSymbols.isEmpty() || groupingSets.getGroupingKeys().containsAll(preGroupedSymbols), "Pre-grouped symbols must be a subset of the grouping keys"); + checkArgument(preGroupedSymbols.isEmpty() || containsAll(groupingSets.getGroupingKeys(), preGroupedSymbols), "Pre-grouped symbols must be a subset of the grouping keys"); this.preGroupedSymbols = ImmutableList.copyOf(preGroupedSymbols); ImmutableList.Builder outputs = ImmutableList.builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ApplyNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ApplyNode.java index 48f544a1bdcc..fbda7eeaca56 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ApplyNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ApplyNode.java @@ -24,6 +24,7 @@ import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static java.util.Objects.requireNonNull; @Immutable @@ -81,7 +82,7 @@ public ApplyNode( requireNonNull(correlation, "correlation is null"); requireNonNull(originSubquery, "originSubquery is null"); - checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); + checkArgument(containsAll(input.getOutputSymbols(), correlation), "Input does not contain symbols from correlation"); this.input = input; this.subquery = subquery; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java index 08be129f4b3b..811019d92408 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/CorrelatedJoinNode.java @@ -24,6 +24,7 @@ import java.util.List; import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static java.util.Objects.requireNonNull; /** @@ -70,7 +71,7 @@ public CorrelatedJoinNode( requireNonNull(filter, "filter is null"); requireNonNull(originSubquery, "originSubquery is null"); - checkArgument(input.getOutputSymbols().containsAll(correlation), "Input does not contain symbols from correlation"); + checkArgument(containsAll(input.getOutputSymbols(), correlation), "Input does not contain symbols from correlation"); this.input = input; this.subquery = subquery; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java index 47ac2aa7ed43..1a4018c437a2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java @@ -30,6 +30,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; @@ -104,7 +105,7 @@ public ExchangeNode( PartitioningHandle partitioningHandle = partitioningScheme.getPartitioning().getHandle(); checkArgument(scope != REMOTE || partitioningHandle.equals(SINGLE_DISTRIBUTION), "remote merging exchange requires single distribution"); checkArgument(scope != LOCAL || partitioningHandle.equals(FIXED_PASSTHROUGH_DISTRIBUTION), "local merging exchange requires passthrough distribution"); - checkArgument(partitioningScheme.getOutputLayout().containsAll(ordering.orderBy()), "Partitioning scheme does not supply all required ordering symbols"); + checkArgument(containsAll(partitioningScheme.getOutputLayout(), ordering.orderBy()), "Partitioning scheme does not supply all required ordering symbols"); checkArgument(type == Type.GATHER, "Merging exchange must be of GATHER type"); checkArgument(inputs.size() == 1, "Merging exchange must have single input"); }); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/IndexSourceNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/IndexSourceNode.java index e93d18bdcfc6..dba8f5f12993 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/IndexSourceNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/IndexSourceNode.java @@ -28,6 +28,7 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static java.util.Objects.requireNonNull; public class IndexSourceNode @@ -57,7 +58,7 @@ public IndexSourceNode( checkArgument(!lookupSymbols.isEmpty(), "lookupSymbols is empty"); checkArgument(!outputSymbols.isEmpty(), "outputSymbols is empty"); checkArgument(assignments.keySet().containsAll(lookupSymbols), "Assignments do not include all lookup symbols"); - checkArgument(outputSymbols.containsAll(lookupSymbols), "Lookup symbols need to be part of the output symbols"); + checkArgument(containsAll(outputSymbols, lookupSymbols), "Lookup symbols need to be part of the output symbols"); } @JsonProperty diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PatternRecognitionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PatternRecognitionNode.java index 15087657d7f4..2862709afa15 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PatternRecognitionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PatternRecognitionNode.java @@ -37,6 +37,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Iterables.concat; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static io.trino.sql.planner.plan.FrameBoundType.CURRENT_ROW; import static io.trino.sql.planner.plan.RowsPerMatch.ONE; import static io.trino.sql.planner.plan.RowsPerMatch.WINDOW; @@ -96,7 +97,7 @@ public PatternRecognitionNode( requireNonNull(source, "source is null"); requireNonNull(specification, "specification is null"); requireNonNull(hashSymbol, "hashSymbol is null"); - checkArgument(specification.partitionBy().containsAll(prePartitionedInputs), "prePartitionedInputs must be contained in partitionBy"); + checkArgument(containsAll(specification.partitionBy(), prePartitionedInputs), "prePartitionedInputs must be contained in partitionBy"); Optional orderingScheme = specification.orderingScheme(); checkArgument(preSortedOrderPrefix == 0 || (orderingScheme.isPresent() && preSortedOrderPrefix <= orderingScheme.get().orderBy().size()), "Cannot have sorted more symbols than those requested"); checkArgument(preSortedOrderPrefix == 0 || ImmutableSet.copyOf(prePartitionedInputs).equals(ImmutableSet.copyOf(specification.partitionBy())), "preSortedOrderPrefix can only be greater than zero if all partition symbols are pre-partitioned"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/UnnestNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/UnnestNode.java index cdf67cae1927..cd97ce28a12f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/UnnestNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/UnnestNode.java @@ -26,6 +26,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.sql.planner.SymbolUtils.containsAll; import static java.util.Objects.requireNonNull; @Immutable @@ -50,7 +51,7 @@ public UnnestNode( super(id); this.source = requireNonNull(source, "source is null"); this.replicateSymbols = ImmutableList.copyOf(requireNonNull(replicateSymbols, "replicateSymbols is null")); - checkArgument(source.getOutputSymbols().containsAll(replicateSymbols), "Source does not contain all replicateSymbols"); + checkArgument(containsAll(source.getOutputSymbols(), replicateSymbols), "Source does not contain all replicateSymbols"); requireNonNull(mappings, "mappings is null"); checkArgument(!mappings.isEmpty(), "mappings is empty"); this.mappings = ImmutableList.copyOf(mappings);