From ac7084c4ff19f81a72ee2a15f542c9ae1a99658d Mon Sep 17 00:00:00 2001 From: Mryange <59914473+Mryange@users.noreply.github.com> Date: Fri, 27 Sep 2024 11:52:20 +0800 Subject: [PATCH] [opt](function)Optimize the Percentile function. (#41206) ## Proposed changes Support batch functions and floating-point parameters. --- be/src/util/counts.h | 8 ++++++- .../aggregate_function_percentile.cpp | 2 +- .../aggregate_function_percentile.h | 23 ++++++++++++++++++- .../expressions/functions/agg/Percentile.java | 3 +++ .../test_aggregate_all_functions.out | 21 +++++++++++++++++ .../test_aggregate_all_functions.groovy | 3 +++ 6 files changed, 57 insertions(+), 3 deletions(-) diff --git a/be/src/util/counts.h b/be/src/util/counts.h index e479f04c6208ee..968dc00e2ae8b8 100644 --- a/be/src/util/counts.h +++ b/be/src/util/counts.h @@ -157,6 +157,12 @@ class Counts { } } + void increment(Ty key) { _nums.push_back(key); } + + void increment_batch(const vectorized::PaddedPODArray& keys) { + _nums.insert(keys.begin(), keys.end()); + } + void serialize(vectorized::BufferWritable& buf) { if (!_nums.empty()) { pdqsort(_nums.begin(), _nums.end()); @@ -234,7 +240,7 @@ class Counts { int array_index; int64_t element_index; - std::strong_ordering operator<=>(const Node& other) const { return value <=> other.value; } + auto operator<=>(const Node& other) const { return value <=> other.value; } }; void _convert_sorted_num_vec_to_nums() { diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp index b0da562bd73b6c..ac8e40d03124d6 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.cpp @@ -104,7 +104,7 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted( void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("percentile", - creator_with_integer_type::creator); + creator_with_numeric_type::creator); factory.register_function_both( "percentile_array", creator_with_integer_type::creator); diff --git a/be/src/vec/aggregate_functions/aggregate_function_percentile.h b/be/src/vec/aggregate_functions/aggregate_function_percentile.h index 1c8a12340d7096..0cec238846eba1 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_percentile.h +++ b/be/src/vec/aggregate_functions/aggregate_function_percentile.h @@ -631,10 +631,20 @@ struct PercentileState { } } for (int i = 0; i < arg_size; ++i) { - vec_counts[i].increment(source, 1); + vec_counts[i].increment(source); } } + void add_batch(const PaddedPODArray& source, const Float64& q) { + if (!inited_flag) { + inited_flag = true; + vec_counts.resize(1); + vec_quantile.resize(1); + vec_quantile[0] = q; + } + vec_counts[0].increment_batch(source); + } + void merge(const PercentileState& rhs) { if (!rhs.inited_flag) { return; @@ -692,6 +702,17 @@ class AggregateFunctionPercentile final quantile.get_data(), 1); } + void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, + Arena* arena) const override { + const auto& sources = + assert_cast(*columns[0]); + const auto& quantile = + assert_cast(*columns[1]); + DCHECK_EQ(sources.get_data().size(), batch_size); + AggregateFunctionPercentile::data(place).add_batch(sources.get_data(), + quantile.get_data()[0]); + } + void reset(AggregateDataPtr __restrict place) const override { AggregateFunctionPercentile::data(place).reset(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java index d8328baadf79c2..17f501b7f1333a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Percentile.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.LargeIntType; import org.apache.doris.nereids.types.SmallIntType; @@ -42,6 +43,8 @@ public class Percentile extends NullableAggregateFunction implements BinaryExpression, ExplicitlyCastableSignature { public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), + FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE), diff --git a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out index 5713247a09e6c7..75d9a18679f78a 100644 --- a/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out +++ b/regression-test/data/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.out @@ -129,6 +129,27 @@ beijing chengdu shanghai 5 29.0 6 101.0 +-- !select20_1 -- +1 10.1 +2 224.6 +3 10.1 +5 29.1 +6 101.1 + +-- !select21_1 -- +1 10.1 +2 246.25 +3 10.1 +5 29.1 +6 101.1 + +-- !select22_1 -- +1 10.1 +2 356.665 +3 10.1 +5 29.1 +6 101.1 + -- !select23 -- 1 10.0 2 224.5 diff --git a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy index 9cd9ff7e04d9a6..cdab9472e27dbd 100644 --- a/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy +++ b/regression-test/suites/query_p0/sql_functions/aggregate_functions/test_aggregate_all_functions.groovy @@ -283,6 +283,9 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") { qt_select20 "select id,percentile(level,0.5) from ${tableName_13} group by id order by id" qt_select21 "select id,percentile(level,0.55) from ${tableName_13} group by id order by id" qt_select22 "select id,percentile(level,0.805) from ${tableName_13} group by id order by id" + qt_select20_1 "select id,percentile(level + 0.1,0.5) from ${tableName_13} group by id order by id" + qt_select21_1 "select id,percentile(level + 0.1,0.55) from ${tableName_13} group by id order by id" + qt_select22_1 "select id,percentile(level + 0.1,0.805) from ${tableName_13} group by id order by id" sql "DROP TABLE IF EXISTS ${tableName_13}"