Skip to content

Commit

Permalink
[opt](function)Optimize the Percentile function. (apache#41206)
Browse files Browse the repository at this point in the history
## Proposed changes
Support batch functions and floating-point parameters.

<!--Describe your changes.-->
  • Loading branch information
Mryange authored Sep 27, 2024
1 parent ba6102b commit ac7084c
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 3 deletions.
8 changes: 7 additions & 1 deletion be/src/util/counts.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ class Counts {
}
}

void increment(Ty key) { _nums.push_back(key); }

void increment_batch(const vectorized::PaddedPODArray<Ty>& keys) {
_nums.insert(keys.begin(), keys.end());
}

void serialize(vectorized::BufferWritable& buf) {
if (!_nums.empty()) {
pdqsort(_nums.begin(), _nums.end());
Expand Down Expand Up @@ -234,7 +240,7 @@ class Counts {
int array_index;
int64_t element_index;

std::strong_ordering operator<=>(const Node& other) const { return value <=> other.value; }
auto operator<=>(const Node& other) const { return value <=> other.value; }
};

void _convert_sorted_num_vec_to_nums() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ AggregateFunctionPtr create_aggregate_function_percentile_approx_weighted(

void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("percentile",
creator_with_integer_type::creator<AggregateFunctionPercentile>);
creator_with_numeric_type::creator<AggregateFunctionPercentile>);
factory.register_function_both(
"percentile_array",
creator_with_integer_type::creator<AggregateFunctionPercentileArray>);
Expand Down
23 changes: 22 additions & 1 deletion be/src/vec/aggregate_functions/aggregate_function_percentile.h
Original file line number Diff line number Diff line change
Expand Up @@ -631,10 +631,20 @@ struct PercentileState {
}
}
for (int i = 0; i < arg_size; ++i) {
vec_counts[i].increment(source, 1);
vec_counts[i].increment(source);
}
}

void add_batch(const PaddedPODArray<T>& source, const Float64& q) {
if (!inited_flag) {
inited_flag = true;
vec_counts.resize(1);
vec_quantile.resize(1);
vec_quantile[0] = q;
}
vec_counts[0].increment_batch(source);
}

void merge(const PercentileState& rhs) {
if (!rhs.inited_flag) {
return;
Expand Down Expand Up @@ -692,6 +702,17 @@ class AggregateFunctionPercentile final
quantile.get_data(), 1);
}

void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
Arena* arena) const override {
const auto& sources =
assert_cast<const ColVecType&, TypeCheckOnRelease::DISABLE>(*columns[0]);
const auto& quantile =
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1]);
DCHECK_EQ(sources.get_data().size(), batch_size);
AggregateFunctionPercentile::data(place).add_batch(sources.get_data(),
quantile.get_data()[0]);
}

void reset(AggregateDataPtr __restrict place) const override {
AggregateFunctionPercentile::data(place).reset();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.SmallIntType;
Expand All @@ -42,6 +43,8 @@ public class Percentile extends NullableAggregateFunction
implements BinaryExpression, ExplicitlyCastableSignature {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(LargeIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, DoubleType.INSTANCE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,27 @@ beijing chengdu shanghai
5 29.0
6 101.0

-- !select20_1 --
1 10.1
2 224.6
3 10.1
5 29.1
6 101.1

-- !select21_1 --
1 10.1
2 246.25
3 10.1
5 29.1
6 101.1

-- !select22_1 --
1 10.1
2 356.665
3 10.1
5 29.1
6 101.1

-- !select23 --
1 10.0
2 224.5
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ suite("test_aggregate_all_functions", "arrow_flight_sql") {
qt_select20 "select id,percentile(level,0.5) from ${tableName_13} group by id order by id"
qt_select21 "select id,percentile(level,0.55) from ${tableName_13} group by id order by id"
qt_select22 "select id,percentile(level,0.805) from ${tableName_13} group by id order by id"
qt_select20_1 "select id,percentile(level + 0.1,0.5) from ${tableName_13} group by id order by id"
qt_select21_1 "select id,percentile(level + 0.1,0.55) from ${tableName_13} group by id order by id"
qt_select22_1 "select id,percentile(level + 0.1,0.805) from ${tableName_13} group by id order by id"

sql "DROP TABLE IF EXISTS ${tableName_13}"

Expand Down

0 comments on commit ac7084c

Please sign in to comment.