Skip to content

Commit

Permalink
Fix ApproxPercentileCont signature (#8825)
Browse files Browse the repository at this point in the history
* Fix ApproxPercentileCont signature

The number of centroids must be an integer in `coerce_types`.
Reflect that in the type signature.

* Add a unit test for percentile signature error message
  • Loading branch information
joroKr21 authored Jan 14, 2024
1 parent acf0f78 commit af3d190
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 28 deletions.
27 changes: 16 additions & 11 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,23 @@ impl AggregateFunction {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => {
let mut variants =
Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
// Accept any numeric value paired with a float64 percentile
let with_tdigest_size = NUMERICS.iter().map(|t| {
TypeSignature::Exact(vec![t.clone(), DataType::Float64, t.clone()])
});
Signature::one_of(
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
.chain(with_tdigest_size)
.collect(),
Volatility::Immutable,
)
for num in NUMERICS {
variants
.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
// Additionally accept an integer number of centroids for T-Digest
for int in INTEGERS {
variants.push(TypeSignature::Exact(vec![
num.clone(),
DataType::Float64,
int.clone(),
]))
}
}

Signature::one_of(variants, Volatility::Immutable)
}
AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
// Accept any numeric value paired with a float64 percentile
Expand Down
22 changes: 6 additions & 16 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ pub fn coerce_types(
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => {
let valid_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat();
let valid_types = [NUMERICS.to_vec(), vec![Null]].concat();
let input_types_valid = // number of input already checked before
valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]);
if !input_types_valid {
Expand All @@ -243,15 +243,15 @@ pub fn coerce_types(
input_types[0]
);
}
if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) {
if input_types.len() == 3 && !input_types[2].is_integer() {
return plan_err!(
"The percentile sample points count for {:?} must be integer, not {:?}.",
agg_fun, input_types[2]
);
}
let mut result = input_types.to_vec();
if can_coerce_from(&DataType::Float64, &input_types[1]) {
result[1] = DataType::Float64;
if can_coerce_from(&Float64, &input_types[1]) {
result[1] = Float64;
} else {
return plan_err!(
"Could not coerce the percent argument for {:?} to Float64. Was {:?}.",
Expand All @@ -275,7 +275,7 @@ pub fn coerce_types(
input_types[1]
);
}
if !matches!(input_types[2], DataType::Float64) {
if !matches!(input_types[2], Float64) {
return plan_err!(
"The percentile argument for {:?} must be Float64, not {:?}.",
agg_fun,
Expand Down Expand Up @@ -560,17 +560,7 @@ pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
}

/// Return `true` if `arg_type` is one of the signed or unsigned integer
/// [`DataType`]s (8-, 16-, 32- or 64-bit wide).
pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
)
arg_type.is_integer()
}

/// Return `true` if `arg_type` is of a [`DataType`] that the
Expand Down
26 changes: 25 additions & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ mod test {
}

#[test]
fn agg_function_invalid_input() -> Result<()> {
fn agg_function_invalid_input_avg() -> Result<()> {
let empty = empty();
let fun: AggregateFunction = AggregateFunction::Avg;
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
Expand All @@ -984,6 +984,30 @@ mod test {
Ok(())
}

#[test]
fn agg_function_invalid_input_percentile() {
    let plan = empty();

    // The third argument (number of centroids) is deliberately a float
    // literal: the signature only accepts an integer in that position.
    let args = vec![lit(0.95), lit(42.0), lit(100.0)];
    let percentile = Expr::AggregateFunction(expr::AggregateFunction::new(
        AggregateFunction::ApproxPercentileCont,
        args,
        false,
        None,
        None,
    ));

    // Planning must fail; keep only the message, without any backtrace.
    let message = Projection::try_new(vec![percentile], plan)
        .err()
        .unwrap()
        .strip_backtrace();

    let expected_prefix = "Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'. You might need to add explicit type casts.\n\tCandidate functions:";

    // Whatever follows the prefix is the list of candidate signatures; the
    // rejected (Float64, Float64, Float64) combination must not be among them.
    let candidates = message.strip_prefix(expected_prefix).unwrap();
    let rejected = "APPROX_PERCENTILE_CONT(Float64, Float64, Float64)";
    assert!(!candidates.contains(rejected));
}

#[test]
fn binary_op_date32_op_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
Expand Down
3 changes: 3 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100
statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\.
SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100

statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64, Float64\)'\. You might need to add explicit type casts\.
SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100

# array agg can use order by
query ?
SELECT array_agg(c13 ORDER BY c13)
Expand Down

0 comments on commit af3d190

Please sign in to comment.