From 32204661ee5f496512b224331e20443ab10bb00c Mon Sep 17 00:00:00 2001 From: stephen Date: Tue, 31 Dec 2024 19:38:54 +0800 Subject: [PATCH] [WIP] adjust agg pushdown strategy Signed-off-by: stephen --- be/src/exec/aggregator.cpp | 1 + .../com/starrocks/qe/SessionVariable.java | 13 +++++++ .../pdagg/PushDownAggregateCollector.java | 36 +++++++++++-------- .../StatisticsEstimateCoefficient.java | 2 +- .../starrocks/sql/plan/TPCDSPushAggTest.java | 20 +++++------ 5 files changed, 46 insertions(+), 26 deletions(-) diff --git a/be/src/exec/aggregator.cpp b/be/src/exec/aggregator.cpp index a14f33672f018..420a802293b21 100644 --- a/be/src/exec/aggregator.cpp +++ b/be/src/exec/aggregator.cpp @@ -1359,6 +1359,7 @@ void Aggregator::_init_agg_hash_variant(HashVariantType& hash_variant) { } } } + VLOG_ROW << "hash type is " << static_cast::type>(type); hash_variant.init(_state, type, _agg_stat); diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index 0a69d76ca9e21..06c77b4a43982 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -364,6 +364,8 @@ public class SessionVariable implements Serializable, Writable, Cloneable { public static final String CBO_PRUNE_SHUFFLE_COLUMN_RATE = "cbo_prune_shuffle_column_rate"; public static final String CBO_PUSH_DOWN_AGGREGATE_MODE = "cbo_push_down_aggregate_mode"; public static final String CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN = "cbo_push_down_aggregate_on_broadcast_join"; + public static final String CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN_ROW_COUNT_LIMIT = + "cbo_push_down_aggregate_on_broadcast_join_row_count_limit"; public static final String CBO_PUSH_DOWN_DISTINCT_BELOW_WINDOW = "cbo_push_down_distinct_below_window"; public static final String CBO_PUSH_DOWN_AGGREGATE = "cbo_push_down_aggregate"; @@ -1532,6 +1534,9 @@ public static 
MaterializedViewRewriteMode parse(String str) { @VarAttr(name = CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN, flag = VariableMgr.INVISIBLE) private boolean cboPushDownAggregateOnBroadcastJoin = true; + @VarAttr(name = CBO_PUSH_DOWN_AGGREGATE_ON_BROADCAST_JOIN_ROW_COUNT_LIMIT, flag = VariableMgr.INVISIBLE) + private long cboPushDownAggregateOnBroadcastJoinRowCountLimit = 250000; + // auto, global, local @VarAttr(name = CBO_PUSH_DOWN_AGGREGATE, flag = VariableMgr.INVISIBLE) private String cboPushDownAggregate = "global"; @@ -3595,6 +3600,14 @@ public void setCboPushDownAggregateOnBroadcastJoin(boolean cboPushDownAggregateO this.cboPushDownAggregateOnBroadcastJoin = cboPushDownAggregateOnBroadcastJoin; } + public long getCboPushDownAggregateOnBroadcastJoinRowCountLimit() { + return cboPushDownAggregateOnBroadcastJoinRowCountLimit; + } + + public void setCboPushDownAggregateOnBroadcastJoinRowCountLimit(long cboPushDownAggregateOnBroadcastJoinRowCountLimit) { + this.cboPushDownAggregateOnBroadcastJoinRowCountLimit = cboPushDownAggregateOnBroadcastJoinRowCountLimit; + } + public String getCboPushDownAggregate() { return cboPushDownAggregate; } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java index 2ec2a8f8cf771..faf0cd469a2d8 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/tree/pdagg/PushDownAggregateCollector.java @@ -56,6 +56,8 @@ import java.util.Set; import java.util.stream.Collectors; +import static com.starrocks.sql.optimizer.statistics.StatisticsEstimateCoefficient.SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT; + /* * Collect all can be push down aggregate context, to get which aggregation can be * pushed down and the push down path. 
@@ -473,13 +475,23 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g List[] cards = new List[] {lower, medium, high}; - groupBys.getStream().map(factory::getColumnRef) + List<ColumnStatistic> columnStatistics = groupBys.getStream() + .map(factory::getColumnRef) .map(s -> ExpressionStatisticCalculator.calculate(s, statistics)) - .forEach(s -> cards[groupByCardinality(s, statistics.getOutputRowCount())].add(s)); + .collect(Collectors.toList()); + columnStatistics.forEach(s -> cards[groupByCardinality(s, statistics.getOutputRowCount())].add(s)); double lowerCartesian = lower.stream().map(ColumnStatistic::getDistinctValuesCount).reduce((a, b) -> a * b) .orElse(Double.MAX_VALUE); + // forbid push down when the target is the immediate child of a small broadcast join + // and any group-by column's ndv exceeds SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT + if (pushDownMode == PUSH_DOWN_AGG_AUTO && context.immediateChildOfSmallBroadcastJoin) { + if (columnStatistics.stream().anyMatch(x -> x.getDistinctValuesCount() > SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT)) { + return false; + } + } + // pow(row_count/20, a half of lower column size) double lowerUpper = Math.max(statistics.getOutputRowCount() / 20, 1); lowerUpper = Math.pow(lowerUpper, Math.max(lower.size() / 2, 1)); @@ -516,15 +528,9 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g } } - // 2. forbidden rules - // 2.1 target is the immediate child of a small broadcast join and the cardinality of the aggregation is not lower. 
- if (pushDownMode == PUSH_DOWN_AGG_AUTO && context.immediateChildOfSmallBroadcastJoin) { - return false; - } - - // 2.2 high cardinality >= 2 - // 2.3 medium cardinality > 2 - // 2.4 high cardinality = 1 and medium cardinality > 0 + // 2.1 high cardinality >= 2 + // 2.2 medium cardinality > 2 + // 2.3 high cardinality = 1 and medium cardinality > 0 if (high.size() >= 2 || medium.size() > 2 || (high.size() == 1 && !medium.isEmpty())) { return false; } @@ -553,9 +559,9 @@ private boolean checkStatistics(AggregatePushDownContext context, ColumnRefSet g return false; } - // high(2): cardinality/count > MEDIUM_AGGREGATE - // medium(1): cardinality/count <= MEDIUM_AGGREGATE and > LOW_AGGREGATE - // lower(0): cardinality/count < LOW_AGGREGATE + // high(2): row_count / cardinality < MEDIUM_AGGREGATE_EFFECT_COEFFICIENT + // medium(1): row_count / cardinality >= MEDIUM_AGGREGATE_EFFECT_COEFFICIENT and < LOW_AGGREGATE_EFFECT_COEFFICIENT + // lower(0): row_count / cardinality >= LOW_AGGREGATE_EFFECT_COEFFICIENT private int groupByCardinality(ColumnStatistic statistic, double rowCount) { if (statistic.isUnknown()) { return 2; @@ -586,7 +592,7 @@ private boolean isSmallBroadcastJoin(OptExpression optExpression) { } double rightRows = rightStatistics.getOutputRowCount(); return rightRows <= sessionVariable.getBroadcastRowCountLimit() && - rightRows <= StatisticsEstimateCoefficient.SMALL_BROADCAST_JOIN_ROW_COUNT_UPPER_BOUND; + rightRows <= sessionVariable.getCboPushDownAggregateOnBroadcastJoinRowCountLimit(); } /** diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java index b9e8b151a291c..a2716bccc48e9 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/StatisticsEstimateCoefficient.java @@ -47,7 
+47,7 @@ public class StatisticsEstimateCoefficient { public static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000; public static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000; public static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100; - public static final int SMALL_BROADCAST_JOIN_ROW_COUNT_UPPER_BOUND = 4096; + public static final int SMALL_BROADCAST_JOIN_MAX_NDV_LIMIT = 100000; public static final double EXTREME_HIGH_AGGREGATE_EFFECT_COEFFICIENT = 3; // default selectivity for anti join diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSPushAggTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSPushAggTest.java index bede2783541ec..e16c673214dae 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSPushAggTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/TPCDSPushAggTest.java @@ -148,24 +148,24 @@ private static Stream testPushDownProvider() { Arguments[] cases = new Arguments[] { Arguments.of("Q01", 4, 4, false, 6, true, 4, false, 6, true), Arguments.of("Q02", 2, 6, true, 6, true, 6, true, 6, true), - Arguments.of("Q03", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q03", 2, 2, false, 4, true, 4, true, 4, true), // Although the number of aggregators is the same, the aggregator was pushed down. // This is caused by the CTE. 
orig: CTE inline, auto~high: CTE Arguments.of("Q04", 12, 12, true, 12, true, 12, true, 12, true), Arguments.of("Q05", 8, 16, true, 16, true, 16, true, 16, true), Arguments.of("Q08", 4, 6, true, 6, true, 6, true, 6, true), Arguments.of("Q11", 8, 8, true, 8, true, 8, true, 8, true), - Arguments.of("Q12", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q12", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q15", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q19", 2, 2, false, 4, true, 2, false, 2, false), - Arguments.of("Q20", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q20", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q23_1", 10, 13, true, 13, true, 13, true, 13, true), Arguments.of("Q24_1", 6, 6, false, 7, true, 6, false, 6, false), Arguments.of("Q24_2", 6, 6, false, 7, true, 6, false, 6, false), Arguments.of("Q30", 4, 4, false, 6, true, 4, false, 4, false), Arguments.of("Q31", 4, 8, true, 8, true, 8, true, 8, true), Arguments.of("Q33", 8, 8, false, 14, true, 14, true, 14, true), - Arguments.of("Q37", 2, 4, true, 8, true, 6, true, 7, true), + Arguments.of("Q37", 2, 2, false, 8, true, 6, true, 7, true), Arguments.of("Q38", 8, 14, true, 20, true, 14, true, 17, true), Arguments.of("Q41", 4, 4, false, 6, true, 4, false, 4, false), Arguments.of("Q42", 2, 4, true, 4, true, 4, true, 4, true), @@ -173,11 +173,11 @@ private static Stream testPushDownProvider() { Arguments.of("Q45", 6, 6, false, 8, true, 6, false, 8, true), Arguments.of("Q46", 2, 2, false, 4, true, 2, false, 2, false), Arguments.of("Q47", 2, 2, true, 4, true, 4, true, 4, true), - Arguments.of("Q51", 4, 8, true, 8, true, 8, true, 8, true), - Arguments.of("Q52", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q51", 4, 4, false, 8, true, 8, true, 8, true), + Arguments.of("Q52", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q53", 2, 2, false, 4, true, 4, true, 4, true), - Arguments.of("Q54", 9, 11, true, 18, true, 11, true, 17, true), - 
Arguments.of("Q55", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q54", 9, 9, false, 18, true, 11, true, 17, true), + Arguments.of("Q55", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q56", 8, 8, false, 14, true, 14, true, 14, true), Arguments.of("Q57", 2, 2, true, 4, true, 4, true, 4, true), Arguments.of("Q58", 6, 12, true, 12, true, 12, true, 12, true), @@ -194,13 +194,13 @@ private static Stream testPushDownProvider() { Arguments.of("Q78", 6, 6, false, 9, true, 6, false, 6, false), Arguments.of("Q79", 2, 2, false, 4, true, 2, false, 2, false), Arguments.of("Q81", 4, 4, false, 6, true, 4, false, 4, false), - Arguments.of("Q82", 2, 4, true, 8, true, 6, true, 7, true), + Arguments.of("Q82", 2, 2, false, 8, true, 6, true, 7, true), Arguments.of("Q83", 6, 12, true, 12, true, 12, true, 12, true), Arguments.of("Q87", 8, 14, true, 20, true, 14, true, 17, true), Arguments.of("Q89", 2, 2, false, 4, true, 4, true, 4, true), Arguments.of("Q91", 2, 4, true, 4, true, 4, true, 4, true), Arguments.of("Q97", 6, 6, false, 12, true, 10, true, 12, true), - Arguments.of("Q98", 2, 4, true, 4, true, 4, true, 4, true), + Arguments.of("Q98", 2, 2, false, 4, true, 4, true, 4, true), }; return Arrays.stream(cases);