From 7929173fe2c6703f2580baec5c6ac0d5858ff322 Mon Sep 17 00:00:00 2001 From: stdpain <34912776+stdpain@users.noreply.github.com> Date: Mon, 27 Jan 2025 10:00:21 +0800 Subject: [PATCH] [Enhancement] support push down agg distinct limit (#55455) Signed-off-by: stdpain (cherry picked from commit 4f452658bedf7b71fa2259045d737fa85779de68) # Conflicts: # be/src/exec/aggregate/agg_hash_set.h # be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.cpp # fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java --- be/src/exec/aggregate/agg_hash_set.h | 50 +++++++++++++++++-- be/src/exec/aggregate/agg_hash_variant.cpp | 5 +- be/src/exec/aggregate/agg_hash_variant.h | 1 + .../aggregate/aggregate_blocking_node.cpp | 7 ++- .../aggregate_blocking_sink_operator.h | 2 +- ...egate_distinct_streaming_sink_operator.cpp | 25 +++++++++- ...gregate_distinct_streaming_sink_operator.h | 1 + ...lable_aggregate_blocking_sink_operator.cpp | 2 + .../com/starrocks/qe/SessionVariable.java | 14 ++++++ .../transformation/SplitTwoPhaseAggRule.java | 8 ++- .../sql/plan/PlanFragmentBuilder.java | 1 + .../com/starrocks/sql/plan/AggregateTest.java | 12 +++++ 12 files changed, 116 insertions(+), 12 deletions(-) diff --git a/be/src/exec/aggregate/agg_hash_set.h b/be/src/exec/aggregate/agg_hash_set.h index 58d6589d36784..691178935ec54 100644 --- a/be/src/exec/aggregate/agg_hash_set.h +++ b/be/src/exec/aggregate/agg_hash_set.h @@ -18,12 +18,18 @@ #include "column/column_helper.h" #include "column/hash_set.h" #include "column/type_traits.h" +<<<<<<< HEAD +======= +#include "column/vectorized_fwd.h" +#include "exec/aggregate/agg_profile.h" +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) #include "gutil/casts.h" #include "runtime/mem_pool.h" #include "runtime/runtime_state.h" #include "util/fixed_hash_map.h" #include "util/hash_util.hpp" #include "util/phmap/phmap.h" +#include "util/runtime_profile.h" namespace starrocks { @@ -70,9 +76,10 @@ using SliceAggTwoLevelHashSet = template struct AggHashSet { - AggHashSet() = default; + AggHashSet(size_t chunk_size, AggStatistics* agg_stat_) : agg_stat(agg_stat_) {} using HHashSetType = HashSet; HashSet hash_set; + AggStatistics* agg_stat; ////// Common Methods //////// void build_hash_set(size_t chunk_size, const Columns& key_columns, MemPool* pool) { @@ -88,6 +95,7 @@ struct AggHashSet { // handle one number hash key template struct AggHashSetOfOneNumberKey : public AggHashSet> { + using Base = AggHashSet>; using KeyType = typename HashSet::key_type; using Iterator = typename HashSet::iterator; using ColumnType = RunTimeColumnType; @@ -95,7 +103,8 @@ struct AggHashSetOfOneNumberKey : public AggHashSet; static_assert(sizeof(FieldType) <= sizeof(KeyType), "hash set key size needs to be larger than the actual element"); - AggHashSetOfOneNumberKey(int32_t chunk_size) {} + template + AggHashSetOfOneNumberKey(Args&&... args) : Base(std::forward(args)...) 
{} // When compute_and_allocate=false: // Elements queried in HashSet will be added to HashSet @@ -132,6 +141,7 @@ struct AggHashSetOfOneNumberKey : public AggHashSet struct AggHashSetOfOneNullableNumberKey : public AggHashSet> { + using Base = AggHashSet>; using KeyType = typename HashSet::key_type; using Iterator = typename HashSet::iterator; using ColumnType = RunTimeColumnType; @@ -140,7 +150,8 @@ struct AggHashSetOfOneNullableNumberKey static_assert(sizeof(FieldType) <= sizeof(KeyType), "hash set key size needs to be larger than the actual element"); - AggHashSetOfOneNullableNumberKey(int32_t chunk_size) {} + template + AggHashSetOfOneNullableNumberKey(Args&&... args) : Base(std::forward(args)...) {} // When compute_and_allocate=false: // Elements queried in HashSet will be added to HashSet @@ -200,11 +211,13 @@ struct AggHashSetOfOneNullableNumberKey template struct AggHashSetOfOneStringKey : public AggHashSet> { + using Base = AggHashSet>; using Iterator = typename HashSet::iterator; using KeyType = typename HashSet::key_type; using ResultVector = typename std::vector; - AggHashSetOfOneStringKey(int32_t chunk_size) {} + template + AggHashSetOfOneStringKey(Args&&... args) : Base(std::forward(args)...) {} // When compute_and_allocate=false: // Elements queried in HashSet will be added to HashSet @@ -248,11 +261,17 @@ struct AggHashSetOfOneStringKey : public AggHashSet struct AggHashSetOfOneNullableStringKey : public AggHashSet> { + using Base = AggHashSet>; using Iterator = typename HashSet::iterator; using KeyType = typename HashSet::key_type; +<<<<<<< HEAD using ResultVector = typename std::vector; +======= + using ResultVector = Buffer; +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) - AggHashSetOfOneNullableStringKey(int32_t chunk_size) {} + template + AggHashSetOfOneNullableStringKey(Args&&... args) : Base(std::forward(args)...) {} // When compute_and_allocate=false: // Elements queried in HashSet will be added to HashSet @@ -329,13 +348,25 @@ struct AggHashSetOfOneNullableStringKey : public AggHashSet struct AggHashSetOfSerializedKey : public AggHashSet> { + using Base = AggHashSet>; using Iterator = typename HashSet::iterator; +<<<<<<< HEAD using ResultVector = typename std::vector; using KeyType = typename HashSet::key_type; AggHashSetOfSerializedKey(int32_t chunk_size) : _mem_pool(std::make_unique()), _buffer(_mem_pool->allocate(max_one_row_size * chunk_size)), +======= + using ResultVector = Buffer; + using KeyType = typename HashSet::key_type; + + template + AggHashSetOfSerializedKey(int32_t chunk_size, Args&&... args) + : Base(chunk_size, std::forward(args)...), + _mem_pool(std::make_unique()), + _buffer(_mem_pool->allocate(max_one_row_size * chunk_size + SLICE_MEMEQUAL_OVERFLOW_PADDING)), +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) _chunk_size(chunk_size) {} // When compute_and_allocate=false: @@ -422,6 +453,7 @@ struct AggHashSetOfSerializedKey : public AggHashSet struct AggHashSetOfSerializedKeyFixedSize : public AggHashSet> { + using Base = AggHashSet>; using Iterator = typename HashSet::iterator; using KeyType = typename HashSet::key_type; using FixedSizeSliceKey = typename HashSet::key_type; @@ -431,9 +463,17 @@ struct AggHashSetOfSerializedKeyFixedSize : public AggHashSet()), buffer(_mem_pool->allocate(max_fixed_size * chunk_size)), +======= + template + AggHashSetOfSerializedKeyFixedSize(int32_t chunk_size, Args&&... 
args) + : Base(chunk_size, std::forward(args)...), + _mem_pool(std::make_unique()), + buffer(_mem_pool->allocate(max_fixed_size * chunk_size + SLICE_MEMEQUAL_OVERFLOW_PADDING)), +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) _chunk_size(chunk_size) { memset(buffer, 0x0, max_fixed_size * _chunk_size); } diff --git a/be/src/exec/aggregate/agg_hash_variant.cpp b/be/src/exec/aggregate/agg_hash_variant.cpp index f873536ef1769..6bd15efb08ec6 100644 --- a/be/src/exec/aggregate/agg_hash_variant.cpp +++ b/be/src/exec/aggregate/agg_hash_variant.cpp @@ -241,11 +241,12 @@ size_t AggHashMapVariant::allocated_memory_usage(const MemPool* pool) const { void AggHashSetVariant::init(RuntimeState* state, Type type, AggStatistics* agg_stat) { _type = type; + _agg_stat = agg_stat; switch (_type) { #define M(NAME) \ case Type::NAME: \ hash_set_with_key = std::make_unique::HashSetWithKeyType>( \ - state->chunk_size()); \ + state->chunk_size(), _agg_stat); \ break; APPLY_FOR_AGG_VARIANT_ALL(M) #undef M @@ -255,7 +256,7 @@ void AggHashSetVariant::init(RuntimeState* state, Type type, AggStatistics* agg_ #define CONVERT_TO_TWO_LEVEL_SET(DST, SRC) \ if (_type == AggHashSetVariant::Type::SRC) { \ auto dst = std::make_unique::HashSetWithKeyType>( \ - state->chunk_size()); \ + state->chunk_size(), _agg_stat); \ std::visit( \ [&](auto& hash_set_with_key) { \ if constexpr (std::is_same_vhash_set)::key_type, \ diff --git a/be/src/exec/aggregate/agg_hash_variant.h b/be/src/exec/aggregate/agg_hash_variant.h index 377fe96536937..65701eaba08f9 100644 --- a/be/src/exec/aggregate/agg_hash_variant.h +++ b/be/src/exec/aggregate/agg_hash_variant.h @@ -592,6 +592,7 @@ struct AggHashSetVariant { private: Type _type = Type::phase1_slice; + AggStatistics* _agg_stat = nullptr; }; } // namespace starrocks diff --git a/be/src/exec/aggregate/aggregate_blocking_node.cpp b/be/src/exec/aggregate/aggregate_blocking_node.cpp index 93597c3cbfbbb..8ba934223889b 100644 --- a/be/src/exec/aggregate/aggregate_blocking_node.cpp +++ b/be/src/exec/aggregate/aggregate_blocking_node.cpp @@ -291,7 +291,12 @@ pipeline::OpFactories AggregateBlockingNode::decompose_to_pipeline(pipeline::Pip _decompose_to_pipeline(ops_with_sink, context, false); } else { - if (runtime_state()->enable_spill() && runtime_state()->enable_agg_spill() && has_group_by_keys) { + // disable spill when group by with a small limit + bool enable_agg_spill = runtime_state()->enable_spill() && runtime_state()->enable_agg_spill(); + if (limit() != -1 && limit() < runtime_state()->chunk_size()) { + enable_agg_spill = false; + } + if (enable_agg_spill && has_group_by_keys) { ops_with_source = _decompose_to_pipeline(ops_with_sink, context, false); diff --git a/be/src/exec/pipeline/aggregate/aggregate_blocking_sink_operator.h b/be/src/exec/pipeline/aggregate/aggregate_blocking_sink_operator.h index e15c1ad74defa..e30dea14aab14 100644 --- a/be/src/exec/pipeline/aggregate/aggregate_blocking_sink_operator.h +++ b/be/src/exec/pipeline/aggregate/aggregate_blocking_sink_operator.h @@ -57,12 +57,12 @@ class AggregateBlockingSinkOperator : public Operator { // - reffed at constructor() of both sink and source operator, // - unreffed at close() of both sink and source operator. 
AggregatorPtr _aggregator = nullptr; + bool _agg_group_by_with_limit = false; private: // Whether prev operator has no output std::atomic_bool _is_finished = false; // whether enable aggregate group by limit optimize - bool _agg_group_by_with_limit = false; std::atomic& _shared_limit_countdown; }; diff --git a/be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.cpp b/be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.cpp index ec3eeec18219d..4204b846b60fa 100644 --- a/be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.cpp +++ b/be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.cpp @@ -26,6 +26,11 @@ Status AggregateDistinctStreamingSinkOperator::prepare(RuntimeState* state) { if (_aggregator->streaming_preaggregation_mode() == TStreamingPreaggregationMode::LIMITED_MEM) { _limited_mem_state.limited_memory_size = config::streaming_agg_limited_memory_size; } +<<<<<<< HEAD +======= + _aggregator->streaming_preaggregation_mode() = TStreamingPreaggregationMode::FORCE_PREAGGREGATION; + _aggregator->attach_sink_observer(state, this->_observer); +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) return _aggregator->open(state); } @@ -37,7 +42,17 @@ void AggregateDistinctStreamingSinkOperator::close(RuntimeState* state) { } Status AggregateDistinctStreamingSinkOperator::set_finishing(RuntimeState* state) { +<<<<<<< HEAD _is_finished = true; +======= + if (_is_finished) return Status::OK(); + ONCE_DETECT(_set_finishing_once); + auto notify = _aggregator->defer_notify_source(); + auto defer = DeferOp([this]() { + _aggregator->sink_complete(); + _is_finished = true; + }); +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) // skip processing if cancelled if (state->is_cancelled()) { @@ -48,7 +63,6 @@ Status AggregateDistinctStreamingSinkOperator::set_finishing(RuntimeState* state _aggregator->set_ht_eos(); } - _aggregator->sink_complete(); return Status::OK(); } @@ -68,7 +82,14 @@ Status AggregateDistinctStreamingSinkOperator::push_chunk(RuntimeState* state, c _aggregator->update_num_input_rows(chunk_size); COUNTER_SET(_aggregator->input_row_count(), _aggregator->num_input_rows()); - + bool limit_with_no_agg = _aggregator->limit() != -1; + if (limit_with_no_agg) { + auto size = _aggregator->hash_set_variant().size(); + if (size >= _aggregator->limit()) { + (void)set_finishing(state); + return Status::OK(); + } + } RETURN_IF_ERROR(_aggregator->evaluate_groupby_exprs(chunk.get())); if (_aggregator->streaming_preaggregation_mode() == TStreamingPreaggregationMode::FORCE_STREAMING) { diff --git a/be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.h b/be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.h index ee9f4026d0a34..a48c5d1e5997d 100644 --- a/be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.h +++ b/be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.h @@ -72,6 +72,7 @@ class AggregateDistinctStreamingSinkOperator : public Operator { // Whether prev operator has no output bool _is_finished = false; LimitedMemAggState _limited_mem_state; + DECLARE_ONCE_DETECTOR(_set_finishing_once); }; class AggregateDistinctStreamingSinkOperatorFactory final : public OperatorFactory { diff --git a/be/src/exec/pipeline/aggregate/spillable_aggregate_blocking_sink_operator.cpp b/be/src/exec/pipeline/aggregate/spillable_aggregate_blocking_sink_operator.cpp index 
1667507ca0a95..3ba413592e2fb 100644 --- a/be/src/exec/pipeline/aggregate/spillable_aggregate_blocking_sink_operator.cpp +++ b/be/src/exec/pipeline/aggregate/spillable_aggregate_blocking_sink_operator.cpp @@ -114,6 +114,8 @@ Status SpillableAggregateBlockingSinkOperator::prepare(RuntimeState* state) { _peak_revocable_mem_bytes = _unique_metrics->AddHighWaterMarkCounter( "PeakRevocableMemoryBytes", TUnit::BYTES, RuntimeProfile::Counter::create_strategy(TUnit::BYTES)); _hash_table_spill_times = ADD_COUNTER(_unique_metrics.get(), "HashTableSpillTimes", TUnit::UNIT); + _agg_group_by_with_limit = false; + _aggregator->params()->enable_pipeline_share_limit = false; return Status::OK(); } diff --git a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java index 22dff61ea397d..cecfeeec26c5d 100644 --- a/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java @@ -651,7 +651,11 @@ public static MaterializedViewRewriteMode parse(String str) { public static final String CBO_PUSHDOWN_TOPN_LIMIT = "cbo_push_down_topn_limit"; +<<<<<<< HEAD public static final String ENABLE_CONSTANT_UNION_TO_VALUES = "enable_constant_union_to_values"; +======= + public static final String CBO_PUSHDOWN_DISTINCT_LIMIT = "cbo_push_down_distinct_limit"; +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) public static final String ENABLE_AGGREGATION_PIPELINE_SHARE_LIMIT = "enable_aggregation_pipeline_share_limit"; @@ -1344,8 +1348,13 @@ public static MaterializedViewRewriteMode parse(String str) { @VarAttr(name = CBO_PUSHDOWN_TOPN_LIMIT) private long cboPushDownTopNLimit = 1000; +<<<<<<< HEAD @VarAttr(name = ENABLE_CONSTANT_UNION_TO_VALUES, flag = VariableMgr.INVISIBLE) private boolean enableConstantUnionToValues = true; +======= + @VarAttr(name = CBO_PUSHDOWN_DISTINCT_LIMIT) + private long cboPushDownDistinctLimit = 4096; +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) @VarAttr(name = ENABLE_AGGREGATION_PIPELINE_SHARE_LIMIT, flag = VariableMgr.INVISIBLE) private boolean enableAggregationPipelineShareLimit = true; @@ -1371,8 +1380,13 @@ public long getCboPushDownTopNLimit() { return cboPushDownTopNLimit; } +<<<<<<< HEAD public boolean isEnableConstantUnionToValues() { return enableConstantUnionToValues; +======= + public long cboPushDownDistinctLimit() { + return cboPushDownDistinctLimit; +>>>>>>> 4f452658be ([Enhancement] support push down agg distinct limit (#55455)) } public void setCboPushDownTopNLimit(long cboPushDownTopNLimit) { diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/SplitTwoPhaseAggRule.java b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/SplitTwoPhaseAggRule.java index 9c2a18dd2b11e..56634f0d3276b 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/SplitTwoPhaseAggRule.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/optimizer/rule/transformation/SplitTwoPhaseAggRule.java @@ -92,12 +92,18 @@ public List transform(OptExpression input, OptimizerContext conte } } + long localAggLimit = Operator.DEFAULT_LIMIT; + boolean isOnlyGroupBy = aggOp.getAggregations().isEmpty(); + if (isOnlyGroupBy && aggOp.getLimit() < context.getSessionVariable().cboPushDownDistinctLimit()) { + localAggLimit = aggOp.getLimit(); + } + LogicalAggregationOperator local = new 
LogicalAggregationOperator.Builder().withOperator(aggOp) .setType(AggType.LOCAL) .setAggregations(createNormalAgg(AggType.LOCAL, newAggMap)) .setSplit() .setPredicate(null) - .setLimit(Operator.DEFAULT_LIMIT) + .setLimit(localAggLimit) .setProjection(null) .build(); OptExpression localOptExpression = OptExpression.create(local, input.getInputs()); diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java index d01820e0808c1..097a86d6d8ea0 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java @@ -1893,6 +1893,7 @@ public PlanFragment visitPhysicalHashAggregate(OptExpression optExpr, ExecPlan c hasColocateOlapScanChildInFragment(aggregationNode)) { aggregationNode.setColocate(!node.isWithoutColocateRequirement()); } + aggregationNode.setLimit(node.getLimit()); } else if (node.getType().isGlobal() || (node.getType().isLocal() && !node.isSplit())) { // Local && un-split aggregate meanings only execute local pre-aggregation, we need promise // output type match other node, so must use `update finalized` phase diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java index 22624e4b47d66..87f6a0e2bc141 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java @@ -2839,6 +2839,18 @@ public void testAvgDecimalScale() throws Exception { " | cardinality: 1"); } + @Test + public void testOnlyGroupByLimit() throws Exception { + FeConstants.runningUnitTest = true; + String sql = "select distinct v1 + v2 as vx from t0 limit 10"; + String plan = getFragmentPlan(sql); + assertContains(plan, " 2:AGGREGATE (update serialize)\n" + + " | STREAMING\n" + + " | group by: 4: expr\n" + + " | limit: 10"); + FeConstants.runningUnitTest = false; + } + @Test public void testHavingAggregate() throws Exception { String sql = "select * from (" +
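
The change above spans the FE and the BE. On the FE side, SplitTwoPhaseAggRule now keeps the query's LIMIT on the local (phase-1) aggregation when the aggregate has no aggregate functions (a pure DISTINCT/GROUP BY) and the limit is below the new session variable cbo_push_down_distinct_limit (default 4096), and PlanFragmentBuilder propagates that limit onto the local AggregationNode. On the BE side, AggregateDistinctStreamingSinkOperator switches its streaming pre-aggregation mode to FORCE_PREAGGREGATION (so distinct keys accumulate in the hash set where the limit can be observed) and stops consuming input once the hash set already holds `limit` keys. The sketch below is an illustrative, self-contained rendering of that BE early-exit, not code from this patch: DistinctLimitSink, its members, and the plain integer keys are hypothetical stand-ins for the aggregator, its hash set variant, and real chunk columns.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Simplified stand-in for the distinct streaming sink after this patch: when a
// pure DISTINCT/GROUP BY query carries a pushed-down LIMIT, the sink can stop
// consuming input as soon as `limit` distinct keys have been collected.
struct DistinctLimitSink {
    int64_t limit = -1;                // -1 means "no limit pushed down"
    std::unordered_set<int64_t> keys;  // stands in for the agg hash set variant
    bool finished = false;

    // Returns true if the chunk was consumed, false if the sink finished early.
    bool push_chunk(const std::vector<int64_t>& chunk) {
        if (finished) return false;
        // Mirrors the new check in push_chunk(): once the hash set already
        // holds `limit` distinct keys, finish the sink and drop further input
        // instead of hashing it.
        if (limit != -1 && static_cast<int64_t>(keys.size()) >= limit) {
            finished = true;
            return false;
        }
        for (int64_t v : chunk) keys.insert(v);
        return true;
    }
};

int main() {
    DistinctLimitSink sink{10};  // e.g. SELECT DISTINCT ... LIMIT 10
    sink.push_chunk({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
    // The next chunk is rejected: 10 distinct keys already satisfy the limit.
    bool consumed = sink.push_chunk({11, 12});
    assert(!consumed && sink.finished);
    return 0;
}
```

The new unit test in AggregateTest.java checks exactly this shape: the streaming phase-1 aggregation for `select distinct v1 + v2 as vx from t0 limit 10` now carries `limit: 10` instead of deferring the limit to the global phase.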