Skip to content

Commit

Permalink
Rebasing on main after several other aggregate functions were removed
Browse files Browse the repository at this point in the history
  • Loading branch information
edmondop committed Jul 14, 2024
1 parent c6eb03a commit d42fcc7
Show file tree
Hide file tree
Showing 16 changed files with 21 additions and 374 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/dataframe_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use arrow_schema::DataType;
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::logical_expr::test::function_stub::max;
use datafusion::functions_aggregate::average::avg;
use datafusion::logical_expr::test::function_stub::max;
use datafusion::prelude::*;
use datafusion::test_util::arrow_test_data;
use datafusion_common::ScalarValue;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ use datafusion_common::config::{CsvOptions, JsonOptions};
use datafusion_common::{
plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions,
};
use datafusion_expr::{case, is_null, lit};
use datafusion_expr::{
avg, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
use datafusion_expr::{case, is_null, lit};
use datafusion_functions_aggregate::expr_fn::{count, max, median, min, stddev, sum};

use async_trait::async_trait;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ fn take_optimizable_max(
match *num_rows {
0 => {
// MIN/MAX with 0 rows is always null
if let Some(casted_expr) = unwrap_max(agg_expr){
if let Some(casted_expr) = unwrap_max(agg_expr) {
if let Ok(max_data_type) =
ScalarValue::try_from(agg_expr.field().unwrap().data_type())
{
Expand All @@ -252,7 +252,7 @@ fn take_optimizable_max(
}
value if value > 0 => {
let col_stats = &stats.column_statistics;
if let Some(casted_expr) = unwrap_max(agg_expr){
if let Some(casted_expr) = unwrap_max(agg_expr) {
if casted_expr.expressions().len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) =
Expand Down
24 changes: 1 addition & 23 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ use strum_macros::EnumIter;
// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// Average
Avg,
/// Aggregation into an array
ArrayAgg,
}
Expand All @@ -43,7 +41,6 @@ impl AggregateFunction {
pub fn name(&self) -> &str {
use AggregateFunction::*;
match self {
Avg => "AVG",
ArrayAgg => "ARRAY_AGG",
}
}
Expand All @@ -59,11 +56,6 @@ impl FromStr for AggregateFunction {
type Err = DataFusionError;
fn from_str(name: &str) -> Result<AggregateFunction> {
Ok(match name {
// general
"avg" => AggregateFunction::Avg,
"bool_and" => AggregateFunction::BoolAnd,
"bool_or" => AggregateFunction::BoolOr,
"mean" => AggregateFunction::Avg,
"array_agg" => AggregateFunction::ArrayAgg,
_ => {
return plan_err!("There is no built-in function named {name}");
Expand Down Expand Up @@ -99,10 +91,6 @@ impl AggregateFunction {
})?;

match self {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
Expand All @@ -115,7 +103,6 @@ impl AggregateFunction {
/// nullability
pub fn nullable(&self) -> Result<bool> {
match self {
AggregateFunction::Max | AggregateFunction::Min => Ok(true),
AggregateFunction::ArrayAgg => Ok(true),
}
}
Expand All @@ -126,16 +113,7 @@ impl AggregateFunction {
pub fn signature(&self) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
match self {
AggregateFunction::Grouping | AggregateFunction::ArrayAgg => {
Signature::any(1, Volatility::Immutable)
}
AggregateFunction::Avg => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable),
}
}
}
Expand Down
6 changes: 0 additions & 6 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2540,12 +2540,6 @@ mod test {

#[test]
fn test_find_df_window_function() {
assert_eq!(
find_df_window_func("avg"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Avg
))
);
assert_eq!(
find_df_window_func("cume_dist"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/test/function_stub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ impl Default for Min {
impl Min {
pub fn new() -> Self {
Self {
aliases: vec!["count".to_string()],
aliases: vec!["min".to_string()],
signature: Signature::variadic_any(Volatility::Immutable),
}
}
Expand Down Expand Up @@ -412,7 +412,7 @@ impl Default for Max {
impl Max {
pub fn new() -> Self {
Self {
aliases: vec!["count".to_string()],
aliases: vec!["max".to_string()],
signature: Signature::variadic_any(Volatility::Immutable),
}
}
Expand Down
57 changes: 0 additions & 57 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ use arrow::datatypes::{

use datafusion_common::{internal_err, plan_err, Result};

use crate::{AggregateFunction, Signature, TypeSignature};

pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];

pub static SIGNED_INTEGERS: &[DataType] = &[
Expand Down Expand Up @@ -93,53 +91,8 @@ pub fn coerce_types(
) -> Result<Vec<DataType>> {
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Avg => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval
let v = match &input_types[0] {
Decimal128(p, s) => Decimal128(*p, *s),
Decimal256(p, s) => Decimal256(*p, *s),
d if d.is_numeric() => Float64,
Dictionary(_, v) => {
return coerce_types(agg_fun, &[v.as_ref().clone()], signature)
}
_ => {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
)
}
};
Ok(vec![v])
}
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
if !is_bool_and_or_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(input_types.to_vec())
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64, Float64])
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
}
}

Expand Down Expand Up @@ -374,16 +327,6 @@ mod tests {
use super::*;
#[test]
fn test_aggregate_coerce_types() {
let fun = AggregateFunction::Avg;
// test input args is invalid data type for avg
let input_types = vec![DataType::Utf8];
let signature = fun.signature();
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!(
"Error during planning: The function Avg does not support inputs of type Utf8.",
result.unwrap_err().strip_backtrace()
);

// test count, array_agg, approx_distinct.
// the coerced types are the same as the input types
let funs = vec![AggregateFunction::ArrayAgg];
Expand Down
3 changes: 0 additions & 3 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,3 @@ datafusion-physical-expr-common = { workspace = true }
log = { workspace = true }
paste = "1.0.14"
sqlparser = { workspace = true }

[dev-dependencies]
rand = { workspace = true }
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ mod tests {
let migrated_functions = vec!["count", "max", "min"];
for func in all_default_aggregate_functions() {
// TODO: remove this
// These functions are in intermidiate migration state, skip them
// These functions are in intermediate migration state, skip them
if migrated_functions.contains(&func.name().to_lowercase().as_str()) {
continue;
}
Expand Down
19 changes: 8 additions & 11 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,6 @@ impl<T: Clone + PartialOrd> MovingMax<T> {
}
}


make_udaf_expr_and_func!(
Max,
max,
Expand Down Expand Up @@ -961,10 +960,12 @@ impl AggregateUDFImpl for Max {
}
}

fn create_sliding_accumulator(&self, args:AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
fn create_sliding_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SlidingMaxAccumulator::try_new(args.data_type)?))
}

}

/// An accumulator to compute the maximum value
Expand Down Expand Up @@ -1161,11 +1162,12 @@ impl AggregateUDFImpl for Min {
}
}


fn create_sliding_accumulator(&self, args:AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
fn create_sliding_accumulator(
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SlidingMinAccumulator::try_new(args.data_type)?))
}

}
/// An accumulator to compute the minimum value
#[derive(Debug)]
Expand Down Expand Up @@ -1209,16 +1211,13 @@ impl Accumulator for MinAccumulator {
}
}



#[derive(Debug)]
pub struct SlidingMinAccumulator {
min: ScalarValue,
moving_min: MovingMin<ScalarValue>,
}

impl SlidingMinAccumulator {

pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
min: ScalarValue::try_from(datatype)?,
Expand Down Expand Up @@ -1372,7 +1371,6 @@ mod tests {
check(&mut max(), &[&[zero, neg_inf]], zero);
}


use datafusion_common::Result;
use rand::Rng;

Expand Down Expand Up @@ -1440,5 +1438,4 @@ mod tests {
moving_max_i32(100, 100)?;
Ok(())
}

}
Loading

0 comments on commit d42fcc7

Please sign in to comment.