[Enhancement] support push down agg distinct limit (#55455)
Signed-off-by: stdpain <[email protected]>
(cherry picked from commit 4f45265)

# Conflicts:
#	be/src/exec/aggregate/agg_hash_set.h
#	be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.cpp
#	fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java
stdpain authored and mergify[bot] committed Jan 27, 2025
1 parent 67a367d commit 7929173
Showing 12 changed files with 116 additions and 12 deletions.
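In short: for a query like `select distinct v1 + v2 from t0 limit 10` (exercised by the new `testOnlyGroupByLimit` plan test at the end of this diff), the FE now copies a small enough LIMIT onto the local aggregation, the BE distinct streaming sink stops consuming input once that many distinct groups exist, and agg spill is skipped when the limit is below the chunk size. The cutoff is governed by the new session variable `cbo_push_down_distinct_limit` (default 4096).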
50 changes: 45 additions & 5 deletions be/src/exec/aggregate/agg_hash_set.h
@@ -18,12 +18,18 @@
#include "column/column_helper.h"
#include "column/hash_set.h"
#include "column/type_traits.h"
#include "column/vectorized_fwd.h"
#include "exec/aggregate/agg_profile.h"
#include "gutil/casts.h"
#include "runtime/mem_pool.h"
#include "runtime/runtime_state.h"
#include "util/fixed_hash_map.h"
#include "util/hash_util.hpp"
#include "util/phmap/phmap.h"
#include "util/runtime_profile.h"

namespace starrocks {

@@ -70,9 +76,10 @@ using SliceAggTwoLevelHashSet =

template <typename HashSet, typename Impl>
struct AggHashSet {
AggHashSet() = default;
AggHashSet(size_t chunk_size, AggStatistics* agg_stat_) : agg_stat(agg_stat_) {}
using HHashSetType = HashSet;
HashSet hash_set;
AggStatistics* agg_stat;

////// Common Methods ////////
void build_hash_set(size_t chunk_size, const Columns& key_columns, MemPool* pool) {
@@ -88,14 +95,16 @@ struct AggHashSet {
// handle one number hash key
template <LogicalType logical_type, typename HashSet>
struct AggHashSetOfOneNumberKey : public AggHashSet<HashSet, AggHashSetOfOneNumberKey<logical_type, HashSet>> {
using Base = AggHashSet<HashSet, AggHashSetOfOneNumberKey<logical_type, HashSet>>;
using KeyType = typename HashSet::key_type;
using Iterator = typename HashSet::iterator;
using ColumnType = RunTimeColumnType<logical_type>;
using ResultVector = typename ColumnType::Container;
using FieldType = RunTimeCppType<logical_type>;
static_assert(sizeof(FieldType) <= sizeof(KeyType), "hash set key size needs to be at least as large as the actual element");

AggHashSetOfOneNumberKey(int32_t chunk_size) {}
template <class... Args>
AggHashSetOfOneNumberKey(Args&&... args) : Base(std::forward<Args>(args)...) {}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -132,6 +141,7 @@ struct AggHashSetOfOneNumberKey : public AggHashSet<HashSet, AggHashSetOfOneNumb
template <LogicalType logical_type, typename HashSet>
struct AggHashSetOfOneNullableNumberKey
: public AggHashSet<HashSet, AggHashSetOfOneNullableNumberKey<logical_type, HashSet>> {
using Base = AggHashSet<HashSet, AggHashSetOfOneNullableNumberKey<logical_type, HashSet>>;
using KeyType = typename HashSet::key_type;
using Iterator = typename HashSet::iterator;
using ColumnType = RunTimeColumnType<logical_type>;
@@ -140,7 +150,8 @@ struct AggHashSetOfOneNullableNumberKey

static_assert(sizeof(FieldType) <= sizeof(KeyType), "hash set key size needs to be at least as large as the actual element");

AggHashSetOfOneNullableNumberKey(int32_t chunk_size) {}
template <class... Args>
AggHashSetOfOneNullableNumberKey(Args&&... args) : Base(std::forward<Args>(args)...) {}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -200,11 +211,13 @@ struct AggHashSetOfOneNullableNumberKey

template <typename HashSet>
struct AggHashSetOfOneStringKey : public AggHashSet<HashSet, AggHashSetOfOneStringKey<HashSet>> {
using Base = AggHashSet<HashSet, AggHashSetOfOneStringKey<HashSet>>;
using Iterator = typename HashSet::iterator;
using KeyType = typename HashSet::key_type;
using ResultVector = typename std::vector<Slice>;

AggHashSetOfOneStringKey(int32_t chunk_size) {}
template <class... Args>
AggHashSetOfOneStringKey(Args&&... args) : Base(std::forward<Args>(args)...) {}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -248,11 +261,17 @@ struct AggHashSetOfOneStringKey : public AggHashSet<HashSet, AggHashSetOfOneStri

template <typename HashSet>
struct AggHashSetOfOneNullableStringKey : public AggHashSet<HashSet, AggHashSetOfOneNullableStringKey<HashSet>> {
using Base = AggHashSet<HashSet, AggHashSetOfOneNullableStringKey<HashSet>>;
using Iterator = typename HashSet::iterator;
using KeyType = typename HashSet::key_type;
using ResultVector = Buffer<Slice>;

AggHashSetOfOneNullableStringKey(int32_t chunk_size) {}
template <class... Args>
AggHashSetOfOneNullableStringKey(Args&&... args) : Base(std::forward<Args>(args)...) {}

// When compute_and_allocate=false:
// Elements queried in HashSet will be added to HashSet
@@ -329,13 +348,25 @@ struct AggHashSetOfOneNullableStringKey : public AggHashSet<HashSet, AggHashSetO

template <typename HashSet>
struct AggHashSetOfSerializedKey : public AggHashSet<HashSet, AggHashSetOfSerializedKey<HashSet>> {
using Base = AggHashSet<HashSet, AggHashSetOfSerializedKey<HashSet>>;
using Iterator = typename HashSet::iterator;
using ResultVector = Buffer<Slice>;
using KeyType = typename HashSet::key_type;

template <class... Args>
AggHashSetOfSerializedKey(int32_t chunk_size, Args&&... args)
: Base(chunk_size, std::forward<Args>(args)...),
_mem_pool(std::make_unique<MemPool>()),
_buffer(_mem_pool->allocate(max_one_row_size * chunk_size + SLICE_MEMEQUAL_OVERFLOW_PADDING)),
_chunk_size(chunk_size) {}

// When compute_and_allocate=false:
@@ -422,6 +453,7 @@ struct AggHashSetOfSerializedKey : public AggHashSet<HashSet, AggHashSetOfSerial

template <typename HashSet>
struct AggHashSetOfSerializedKeyFixedSize : public AggHashSet<HashSet, AggHashSetOfSerializedKeyFixedSize<HashSet>> {
using Base = AggHashSet<HashSet, AggHashSetOfSerializedKeyFixedSize<HashSet>>;
using Iterator = typename HashSet::iterator;
using KeyType = typename HashSet::key_type;
using FixedSizeSliceKey = typename HashSet::key_type;
@@ -431,9 +463,17 @@
int fixed_byte_size = -1; // unset state
static constexpr size_t max_fixed_size = sizeof(FixedSizeSliceKey);

template <class... Args>
AggHashSetOfSerializedKeyFixedSize(int32_t chunk_size, Args&&... args)
: Base(chunk_size, std::forward<Args>(args)...),
_mem_pool(std::make_unique<MemPool>()),
buffer(_mem_pool->allocate(max_fixed_size * chunk_size + SLICE_MEMEQUAL_OVERFLOW_PADDING)),
_chunk_size(chunk_size) {
memset(buffer, 0x0, max_fixed_size * _chunk_size);
}
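The mechanical change repeated across this file: the CRTP base `AggHashSet` now owns the `AggStatistics*`, and each derived key-specific set swaps its `(int32_t chunk_size)` constructor for a variadic one that forwards to the base, so new base members don't require touching every derived signature. A minimal sketch of that forwarding pattern, with simplified names and an empty `AggStatistics` stand-in:

```cpp
#include <cstddef>
#include <utility>

struct AggStatistics {}; // stand-in for the real profiling struct

template <typename Impl>
struct AggHashSetBase {
    AggHashSetBase(size_t chunk_size, AggStatistics* stat) : agg_stat(stat) {}
    AggStatistics* agg_stat;
};

// Derived key-specific sets forward (chunk_size, extra args...) to the base,
// so adding a member like agg_stat does not change any derived signature.
struct OneNumberKeySet : AggHashSetBase<OneNumberKeySet> {
    template <class... Args>
    OneNumberKeySet(Args&&... args) : AggHashSetBase(std::forward<Args>(args)...) {}
};

int main() {
    AggStatistics stat;
    OneNumberKeySet set(/*chunk_size=*/4096, &stat);
    return set.agg_stat == &stat ? 0 : 1;
}
```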
5 changes: 3 additions & 2 deletions be/src/exec/aggregate/agg_hash_variant.cpp
@@ -241,11 +241,12 @@ size_t AggHashMapVariant::allocated_memory_usage(const MemPool* pool) const {

void AggHashSetVariant::init(RuntimeState* state, Type type, AggStatistics* agg_stat) {
_type = type;
_agg_stat = agg_stat;
switch (_type) {
#define M(NAME) \
case Type::NAME: \
hash_set_with_key = std::make_unique<detail::AggHashSetVariantTypeTraits<Type::NAME>::HashSetWithKeyType>( \
state->chunk_size()); \
state->chunk_size(), _agg_stat); \
break;
APPLY_FOR_AGG_VARIANT_ALL(M)
#undef M
@@ -255,7 +256,7 @@ void AggHashSetVariant::init(RuntimeState* state, Type type, AggStatistics* agg_
#define CONVERT_TO_TWO_LEVEL_SET(DST, SRC) \
if (_type == AggHashSetVariant::Type::SRC) { \
auto dst = std::make_unique<detail::AggHashSetVariantTypeTraits<Type::DST>::HashSetWithKeyType>( \
state->chunk_size()); \
state->chunk_size(), _agg_stat); \
std::visit( \
[&](auto& hash_set_with_key) { \
if constexpr (std::is_same_v<typename decltype(hash_set_with_key->hash_set)::key_type, \
1 change: 1 addition & 0 deletions be/src/exec/aggregate/agg_hash_variant.h
@@ -592,6 +592,7 @@ struct AggHashSetVariant {

private:
Type _type = Type::phase1_slice;
AggStatistics* _agg_stat = nullptr;
};

} // namespace starrocks
7 changes: 6 additions & 1 deletion be/src/exec/aggregate/aggregate_blocking_node.cpp
@@ -291,7 +291,12 @@ pipeline::OpFactories AggregateBlockingNode::decompose_to_pipeline(pipeline::Pip
_decompose_to_pipeline<StreamingAggregatorFactory, SortedAggregateStreamingSourceOperatorFactory,
SortedAggregateStreamingSinkOperatorFactory>(ops_with_sink, context, false);
} else {
if (runtime_state()->enable_spill() && runtime_state()->enable_agg_spill() && has_group_by_keys) {
// disable spill when group by with a small limit
bool enable_agg_spill = runtime_state()->enable_spill() && runtime_state()->enable_agg_spill();
if (limit() != -1 && limit() < runtime_state()->chunk_size()) {
enable_agg_spill = false;
}
if (enable_agg_spill && has_group_by_keys) {
ops_with_source = _decompose_to_pipeline<AggregatorFactory, SpillableAggregateBlockingSourceOperatorFactory,
SpillableAggregateBlockingSinkOperatorFactory>(ops_with_sink,
context, false);
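The chunk-size threshold reads: if the LIMIT fits inside a single chunk, the aggregation produces at most one output chunk and finishes almost immediately, so the setup cost of the spillable operator path cannot pay off (`limit() == -1` still means no limit, so unbounded queries keep the spill path).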
@@ -57,12 +57,12 @@ class AggregateBlockingSinkOperator : public Operator {
// - reffed at constructor() of both sink and source operator,
// - unreffed at close() of both sink and source operator.
AggregatorPtr _aggregator = nullptr;
bool _agg_group_by_with_limit = false;

private:
// Whether prev operator has no output
std::atomic_bool _is_finished = false;
// whether enable aggregate group by limit optimize
bool _agg_group_by_with_limit = false;
std::atomic<int64_t>& _shared_limit_countdown;
};

be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.cpp
@@ -26,6 +26,11 @@ Status AggregateDistinctStreamingSinkOperator::prepare(RuntimeState* state) {
if (_aggregator->streaming_preaggregation_mode() == TStreamingPreaggregationMode::LIMITED_MEM) {
_limited_mem_state.limited_memory_size = config::streaming_agg_limited_memory_size;
}
_aggregator->streaming_preaggregation_mode() = TStreamingPreaggregationMode::FORCE_PREAGGREGATION;
_aggregator->attach_sink_observer(state, this->_observer);
return _aggregator->open(state);
}

@@ -37,7 +42,17 @@ void AggregateDistinctStreamingSinkOperator::close(RuntimeState* state) {
}

Status AggregateDistinctStreamingSinkOperator::set_finishing(RuntimeState* state) {
if (_is_finished) return Status::OK();
ONCE_DETECT(_set_finishing_once);
auto notify = _aggregator->defer_notify_source();
auto defer = DeferOp([this]() {
_aggregator->sink_complete();
_is_finished = true;
});

// skip processing if cancelled
if (state->is_cancelled()) {
@@ -48,7 +63,6 @@ Status AggregateDistinctStreamingSinkOperator::set_finishing(RuntimeState* state
_aggregator->set_ht_eos();
}

_aggregator->sink_complete();
return Status::OK();
}

@@ -68,7 +82,14 @@ Status AggregateDistinctStreamingSinkOperator::push_chunk(RuntimeState* state, c

_aggregator->update_num_input_rows(chunk_size);
COUNTER_SET(_aggregator->input_row_count(), _aggregator->num_input_rows());

bool limit_with_no_agg = _aggregator->limit() != -1;
if (limit_with_no_agg) {
auto size = _aggregator->hash_set_variant().size();
if (size >= _aggregator->limit()) {
(void)set_finishing(state);
return Status::OK();
}
}
RETURN_IF_ERROR(_aggregator->evaluate_groupby_exprs(chunk.get()));

if (_aggregator->streaming_preaggregation_mode() == TStreamingPreaggregationMode::FORCE_STREAMING) {
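The `push_chunk` guard above is the BE half of the push-down: once the distinct hash set already holds `limit` keys, the sink finishes itself rather than consuming more input. A rough self-contained model of that early exit, assuming a plain `std::unordered_set` in place of the real `hash_set_variant`:

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

// Toy model: a distinct sink that stops consuming once `limit` keys are seen.
struct DistinctSinkModel {
    int64_t limit;                     // -1 means "no limit"
    std::unordered_set<int64_t> keys;  // stands in for the agg hash set
    bool finished = false;

    void push_chunk(const std::vector<int64_t>& chunk) {
        if (finished) return;
        // Mirror the operator: check the set size before processing the chunk.
        if (limit != -1 && static_cast<int64_t>(keys.size()) >= limit) {
            finished = true; // set_finishing() in the real operator
            return;
        }
        keys.insert(chunk.begin(), chunk.end());
    }
};

int main() {
    DistinctSinkModel sink{/*limit=*/3};
    for (int64_t v = 0; v < 100 && !sink.finished; ++v) sink.push_chunk({v});
    return sink.finished ? 0 : 1; // returns 0: sink self-finished at 3 keys
}
```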
be/src/exec/pipeline/aggregate/aggregate_distinct_streaming_sink_operator.h
@@ -72,6 +72,7 @@ class AggregateDistinctStreamingSinkOperator : public Operator {
// Whether prev operator has no output
bool _is_finished = false;
LimitedMemAggState _limited_mem_state;
DECLARE_ONCE_DETECTOR(_set_finishing_once);
};

class AggregateDistinctStreamingSinkOperatorFactory final : public OperatorFactory {
@@ -114,6 +114,8 @@ Status SpillableAggregateBlockingSinkOperator::prepare(RuntimeState* state) {
_peak_revocable_mem_bytes = _unique_metrics->AddHighWaterMarkCounter(
"PeakRevocableMemoryBytes", TUnit::BYTES, RuntimeProfile::Counter::create_strategy(TUnit::BYTES));
_hash_table_spill_times = ADD_COUNTER(_unique_metrics.get(), "HashTableSpillTimes", TUnit::UNIT);
_agg_group_by_with_limit = false;
_aggregator->params()->enable_pipeline_share_limit = false;

return Status::OK();
}
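Note the spillable variant deliberately opts out: `_agg_group_by_with_limit` is forced to false and the aggregator's `enable_pipeline_share_limit` is cleared, presumably because the early-termination trick does not compose with spilling; this is consistent with the planner change above, which avoids the spill path entirely for small limits.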
14 changes: 14 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java
@@ -651,7 +651,11 @@ public static MaterializedViewRewriteMode parse(String str) {

public static final String CBO_PUSHDOWN_TOPN_LIMIT = "cbo_push_down_topn_limit";

public static final String ENABLE_CONSTANT_UNION_TO_VALUES = "enable_constant_union_to_values";

public static final String CBO_PUSHDOWN_DISTINCT_LIMIT = "cbo_push_down_distinct_limit";

public static final String ENABLE_AGGREGATION_PIPELINE_SHARE_LIMIT = "enable_aggregation_pipeline_share_limit";

@@ -1344,8 +1348,13 @@ public static MaterializedViewRewriteMode parse(String str) {
@VarAttr(name = CBO_PUSHDOWN_TOPN_LIMIT)
private long cboPushDownTopNLimit = 1000;

@VarAttr(name = ENABLE_CONSTANT_UNION_TO_VALUES, flag = VariableMgr.INVISIBLE)
private boolean enableConstantUnionToValues = true;

@VarAttr(name = CBO_PUSHDOWN_DISTINCT_LIMIT)
private long cboPushDownDistinctLimit = 4096;

@VarAttr(name = ENABLE_AGGREGATION_PIPELINE_SHARE_LIMIT, flag = VariableMgr.INVISIBLE)
private boolean enableAggregationPipelineShareLimit = true;
Expand All @@ -1371,8 +1380,13 @@ public long getCboPushDownTopNLimit() {
return cboPushDownTopNLimit;
}

public boolean isEnableConstantUnionToValues() {
return enableConstantUnionToValues;
}

public long cboPushDownDistinctLimit() {
return cboPushDownDistinctLimit;
}

public void setCboPushDownTopNLimit(long cboPushDownTopNLimit) {
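`cbo_push_down_distinct_limit` is the FE-side cutoff consulted by the optimizer rule below: only a distinct (group-by-only) aggregation whose LIMIT is under this value, 4096 by default, gets that limit copied onto its local phase. For example, after `set cbo_push_down_distinct_limit = 1024;` a `limit 2000` distinct query is planned exactly as before.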
@@ -92,12 +92,18 @@ public List<OptExpression> transform(OptExpression input, OptimizerContext conte
}
}

long localAggLimit = Operator.DEFAULT_LIMIT;
boolean isOnlyGroupBy = aggOp.getAggregations().isEmpty();
if (isOnlyGroupBy && aggOp.getLimit() < context.getSessionVariable().cboPushDownDistinctLimit()) {
localAggLimit = aggOp.getLimit();
}

LogicalAggregationOperator local = new LogicalAggregationOperator.Builder().withOperator(aggOp)
.setType(AggType.LOCAL)
.setAggregations(createNormalAgg(AggType.LOCAL, newAggMap))
.setSplit()
.setPredicate(null)
.setLimit(Operator.DEFAULT_LIMIT)
.setLimit(localAggLimit)
.setProjection(null)
.build();
OptExpression localOptExpression = OptExpression.create(local, input.getInputs());
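Worked through with defaults: `select distinct ... limit 10` has no aggregate functions, so `isOnlyGroupBy` holds and `10 < 4096`, and the local aggregation is built with `setLimit(10)` instead of `Operator.DEFAULT_LIMIT`. The `visitPhysicalHashAggregate` change below then propagates that limit onto the streaming pre-aggregation node; that is the `limit: 10` the new plan test asserts.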
@@ -1893,6 +1893,7 @@ public PlanFragment visitPhysicalHashAggregate(OptExpression optExpr, ExecPlan c
hasColocateOlapScanChildInFragment(aggregationNode)) {
aggregationNode.setColocate(!node.isWithoutColocateRequirement());
}
aggregationNode.setLimit(node.getLimit());
} else if (node.getType().isGlobal() || (node.getType().isLocal() && !node.isSplit())) {
// Local && un-split aggregate meanings only execute local pre-aggregation, we need promise
// output type match other node, so must use `update finalized` phase
12 changes: 12 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/sql/plan/AggregateTest.java
@@ -2839,6 +2839,18 @@ public void testAvgDecimalScale() throws Exception {
" | cardinality: 1");
}

@Test
public void testOnlyGroupByLimit() throws Exception {
FeConstants.runningUnitTest = true;
String sql = "select distinct v1 + v2 as vx from t0 limit 10";
String plan = getFragmentPlan(sql);
assertContains(plan, " 2:AGGREGATE (update serialize)\n" +
" | STREAMING\n" +
" | group by: 4: expr\n" +
" | limit: 10");
FeConstants.runningUnitTest = false;
}

@Test
public void testHavingAggregate() throws Exception {
String sql = "select * from (" +
