diff --git a/be/src/exec/aggregator.cpp b/be/src/exec/aggregator.cpp index a14f33672f018a..420a802293b213 100644 --- a/be/src/exec/aggregator.cpp +++ b/be/src/exec/aggregator.cpp @@ -1359,6 +1359,7 @@ void Aggregator::_init_agg_hash_variant(HashVariantType& hash_variant) { } } } + VLOG_ROW << "hash type is " << static_cast::type>(type); hash_variant.init(_state, type, _agg_stat); diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index 0a69d76ca9e210..06c77b4a43982a 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -364,6 +364,8 @@ public class SessionVariable implements Serializable, Writable, Cloneable { public static final String CBO_PRUNE_SHUFFLE_COLUMN_RATE = "cbo_prune_shuffle_column_rate"; public static final String CBO_PUSH_DOWN_AGGREGATE_MODE = "cbo_push_down_aggregate_mode"; public static final String CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN = "cbo_push_down_aggregate_on_broadcast_join"; + public static final String CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN_ROW_COUNT_LIMIT = + "cbo_push_down_aggregate_on_broadcast_join_row_count_limit"; public static final String CBO_PUSH_DOWN_DISTINCT_BELOW_WINDOW = "cbo_push_down_distinct_below_window"; public static final String CBO_PUSH_DOWN_AGGREGATE = "cbo_push_down_aggregate"; @@ -1532,6 +1534,9 @@ public static MaterializedViewRewriteMode parse(String str) { @VarAttr(name = CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN, flag = VariableMgr.INVISIBLE) private boolean cboPushDownAggregateOnBroadcastJoin = true; + @VarAttr(name = CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN_ROW_COUNT_LIMIT, flag = VariableMgr.INVISIBLE) + private long cboPushDownAggregateOnBroadcastJoinRowCountLimit = 250000; + // auto, global, local @VarAttr(name = CBO_PUSH_DOWN_AGGREGATE, flag = VariableMgr.INVISIBLE) private String cboPushDownAggregate = "global"; @@ 
-3595,6 +3600,14 @@ public void setCboPushDownAggregateOnBroadcastJoin(boolean cboPushDownAggregateO this.cboPushDownAggregateOnBroadcastJoin = cboPushDownAggregateOnBroadcastJoin; } + public long getCboPushDownAggregateOnBroadcastJoinRowCountLimit() { + return cboPushDownAggregateOnBroadcastJoinRowCountLimit; + } + + public void setCboPushDownAggregateOnBroadcastJoinRowCountLimit(long cboPushDownAggregateOnBroadcastJoinRowCountLimit) { + this.cboPushDownAggregateOnBroadcastJoinRowCountLimit = cboPushDownAggregateOnBroadcastJoinRowCountLimit; + } + public String getCboPushDownAggregate() { return cboPushDownAggregate; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java index 2ec2a8f8cf7717..faf0cd469a2d80 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java @@ -56,6 +56,8 @@ import java.util.Set; import java.util.stream.Collectors; +import static com.starrocks.sql.optimizer.statistics.StatisticsEstimateCoefficient.SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT; + /* * Collect all can be push down aggregate context, to get which aggregation can be * pushed down and the push down path. 
@@ -473,13 +475,23 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g List[] cards = new List[] {lower, medium, high}; - groupBys.getStream().map(factory::getColumnRef) + List<ColumnStatistic> columnStatistics = groupBys.getStream() + .map(factory::getColumnRef) .map(s -> ExpressionStatisticCalculator.calculate(s, statistics)) - .forEach(s -> cards[groupByCardinality(s, statistics.getOutputRowCount())].add(s)); + .collect(Collectors.toList()); // a List keeps one entry per group-by column, even for equal statistics + columnStatistics.forEach(s -> cards[groupByCardinality(s, statistics.getOutputRowCount())].add(s)); double lowerCartesian = lower.stream().map(ColumnStatistic::getDistinctValuesCount).reduce((a, b) -> a * b) .orElse(Double.MAX_VALUE); + // forbidden rule for auto mode: when the target is the immediate child of a small broadcast join, + // do not push down if any group-by column's ndv exceeds SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT + if (pushDownMode == PUSH_DOWN_AGG_AUTO && context.immediateChildOfSmallBroadcastJoin) { + if (columnStatistics.stream().anyMatch(x -> x.getDistinctValuesCount() > SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT)) { + return false; + } + } + // pow(row_count/20, a half of lower column size) double lowerUpper = Math.max(statistics.getOutputRowCount() / 20, 1); lowerUpper = Math.pow(lowerUpper, Math.max(lower.size() / 2, 1)); @@ -516,15 +528,9 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g } } - // 2. forbidden rules - // 2.1 target is the immediate child of a small broadcast join and the cardinality of the aggregation is not lower. 
- if (pushDownMode == PUSH_DOWN_AGG_AUTO && context.immediateChildOfSmallBroadcastJoin) { - return false; - } - - // 2.2 high cardinality >= 2 - // 2.3 medium cardinality > 2 - // 2.4 high cardinality = 1 and medium cardinality > 0 + // 2.1 high cardinality >= 2 + // 2.2 medium cardinality > 2 + // 2.3 high cardinality = 1 and medium cardinality > 0 if (high.size() >= 2 || medium.size() > 2 || (high.size() == 1 && !medium.isEmpty())) { return false; } @@ -553,9 +559,9 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g return false; } - // high(2): cardinality/count > MEDIUM_AGGREGATE - // medium(1): cardinality/count <= MEDIUM_AGGREGATE and > LOW_AGGREGATE - // lower(0): cardinality/count < LOW_AGGREGATE + // high(2): row_count / cardinality < MEDIUM_AGGREGATE_EFFECT_COEFFICIENT + // medium(1): row_count / cardinality >= MEDIUM_AGGREGATE_EFFECT_COEFFICIENT and < LOW_AGGREGATE_EFFECT_COEFFICIENT + // lower(0): row_count / cardinality >= LOW_AGGREGATE_EFFECT_COEFFICIENT private int groupByCardinality(ColumnStatistic statistic, double rowCount) { if (statistic.isUnknown()) { return 2; @@ -586,7 +592,7 @@ private boolean isSmallBroadcastJoin(OptExpression optExpression) { } double rightRows = rightStatistics.getOutputRowCount(); return rightRows <= sessionVariable.getBroadcastRowCountLimit() && - rightRows <= StatisticsEstimateCoefficient.SMALL_BROADCAST_JOIN_ROW_COUNT_UPPER_BOUND; + rightRows <= sessionVariable.getCboPushDownAggregateOnBroadcastJoinRowCountLimit(); } /** diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java index b9e8b151a291c0..a2716bccc48e98 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java @@ 
-47,7 +47,7 @@ public class StatisticsEstimateCoefficient { public static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000; public static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000; public static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100; - public static final int SMALL_BROADCAST_JOIN_ROW_COUNT_UPPER_BOUND = 4096; + public static final int SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT = 100000; public static final double EXTREME_HIGH_AGGREGATE_EFFECT_COEFFICIENT = 3; // default selectivity for anti join diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSPushAggTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSPushAggTest.java index bede2783541ec1..e16c673214dae1 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSPushAggTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSPushAggTest.java @@ -148,24 +148,24 @@ private static Stream testPushDownProvider() { Arguments[] cases = new Arguments[] { Arguments.of("Q01", 4, 4, false, 6, true, 4, false, 6, true), Arguments.of("Q02", 2, 6, true, 6, true, 6, true, 6, true), - Arguments.of("Q03", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q03", 2, 2, false, 4, true, 4, true, 4, true), // Although the number of aggregators is the same, the aggregator was pushed down. // This is caused by the CTE. 
orig: CTE inline, auto~high: CTE Arguments.of("Q04", 12, 12, true, 12, true, 12, true, 12, true), Arguments.of("Q05", 8, 16, true, 16, true, 16, true, 16, true), Arguments.of("Q08", 4, 6, true, 6, true, 6, true, 6, true), Arguments.of("Q11", 8, 8, true, 8, true, 8, true, 8, true), - Arguments.of("Q12", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q12", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q15", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q19", 2, 2, false, 4, true, 2, false, 2, false), - Arguments.of("Q20", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q20", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q23_1", 10, 13, true, 13, true, 13, true, 13, true), Arguments.of("Q24_1", 6, 6, false, 7, true, 6, false, 6, false), Arguments.of("Q24_2", 6, 6, false, 7, true, 6, false, 6, false), Arguments.of("Q30", 4, 4, false, 6, true, 4, false, 4, false), Arguments.of("Q31", 4, 8, true, 8, true, 8, true, 8, true), Arguments.of("Q33", 8, 8, false, 14, true, 14, true, 14, true), - Arguments.of("Q37", 2, 4, true, 8, true, 6, true, 7, true), + Arguments.of("Q37", 2, 2, false, 8, true, 6, true, 7, true), Arguments.of("Q38", 8, 14, true, 20, true, 14, true, 17, true), Arguments.of("Q41", 4, 4, false, 6, true, 4, false, 4, false), Arguments.of("Q42", 2, 4, true, 4, true, 4, true, 4, true), @@ -173,11 +173,11 @@ private static Stream testPushDownProvider() { Arguments.of("Q45", 6, 6, false, 8, true, 6, false, 8, true), Arguments.of("Q46", 2, 2, false, 4, true, 2, false, 2, false), Arguments.of("Q47", 2, 2, true, 4, true, 4, true, 4, true), - Arguments.of("Q51", 4, 8, true, 8, true, 8, true, 8, true), - Arguments.of("Q52", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q51", 4, 4, false, 8, true, 8, true, 8, true), + Arguments.of("Q52", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q53", 2, 2, false, 4, true, 4, true, 4, true), - Arguments.of("Q54", 9, 11, true, 18, true, 11, true, 17, true), - 
Arguments.of("Q55", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q54", 9, 9, false, 18, true, 11, true, 17, true), + Arguments.of("Q55", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q56", 8, 8, false, 14, true, 14, true, 14, true), Arguments.of("Q57", 2, 2, true, 4, true, 4, true, 4, true), Arguments.of("Q58", 6, 12, true, 12, true, 12, true, 12, true), @@ -194,13 +194,13 @@ private static Stream testPushDownProvider() { Arguments.of("Q78", 6, 6, false, 9, true, 6, false, 6, false), Arguments.of("Q79", 2, 2, false, 4, true, 2, false, 2, false), Arguments.of("Q81", 4, 4, false, 6, true, 4, false, 4, false), - Arguments.of("Q82", 2, 4, true, 8, true, 6, true, 7, true), + Arguments.of("Q82", 2, 2, false, 8, true, 6, true, 7, true), Arguments.of("Q83", 6, 12, true, 12, true, 12, true, 12, true), Arguments.of("Q87", 8, 14, true, 20, true, 14, true, 17, true), Arguments.of("Q89", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q91", 2, 4, true, 4, true, 4, true, 4, true), Arguments.of("Q97", 6, 6, false, 12, true, 10, true, 12, true), - Arguments.of("Q98", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q98", 2, 2, false, 4, true, 4, true, 4, true), }; return Arrays.stream(cases);