diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index e798751b33532..b9c2a3ff90929 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use datafusion::error::Result; use datafusion::functions_aggregate::average::avg; +use datafusion::logical_expr::test::function_stub::max; use datafusion::prelude::*; use datafusion::test_util::arrow_test_data; use datafusion_common::ScalarValue; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index fb28b5c1ab470..01b0fb98faee4 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -53,9 +53,11 @@ use datafusion_common::{ }; use datafusion_expr::{case, is_null, lit}; use datafusion_expr::{ - max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, + utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, +}; +use datafusion_functions_aggregate::expr_fn::{ + avg, count, max, median, min, stddev, sum, }; -use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum}; use async_trait::async_trait; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index e412d814239d1..bf68756dfe3da 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -283,7 +283,7 @@ fn is_min(agg_expr: &dyn AggregateExpr) -> bool { } if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { - if agg_expr.fun().name() == "min" { + if agg_expr.fun().name().to_lowercase() == "min" { return true; } } @@ -299,7 +299,7 @@ fn is_max(agg_expr: &dyn AggregateExpr) -> bool { } if let Some(agg_expr) = agg_expr.as_any().downcast_ref::() { - if agg_expr.fun().name() == "max" { + if 
agg_expr.fun().name().to_lowercase() == "max" { return true; } } diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index d68b80691917c..3a5db471bd280 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -54,11 +54,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder, + cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, sum}; +use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, max, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 5bd19850cacc8..ddfa940975d8a 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -35,10 +35,11 @@ use datafusion_common_runtime::SpawnedTask; use datafusion_expr::type_coercion::aggregates::coerce_types; use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; use datafusion_expr::{ - AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, 
PhysicalSortExpr}; @@ -360,14 +361,14 @@ fn get_random_function( window_fn_map.insert( "min", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![arg.clone()], ), ); window_fn_map.insert( "max", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![arg.clone()], ), ); diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 39b3b4ed3b5a4..ac5ef8aced74f 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -33,10 +33,6 @@ use strum_macros::EnumIter; // https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum AggregateFunction { - /// Minimum - Min, - /// Maximum - Max, /// Aggregation into an array ArrayAgg, } @@ -45,8 +41,6 @@ impl AggregateFunction { pub fn name(&self) -> &str { use AggregateFunction::*; match self { - Min => "MIN", - Max => "MAX", ArrayAgg => "ARRAY_AGG", } } @@ -62,9 +56,6 @@ impl FromStr for AggregateFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name { - // general - "max" => AggregateFunction::Max, - "min" => AggregateFunction::Min, "array_agg" => AggregateFunction::ArrayAgg, _ => { return plan_err!("There is no built-in function named {name}"); @@ -100,11 +91,6 @@ impl AggregateFunction { })?; match self { - AggregateFunction::Max | AggregateFunction::Min => { - // For min and max agg function, the returned type is same as input type. - // The coerced_data_types is same with input_types. 
- Ok(coerced_data_types[0].clone()) - } AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", coerced_data_types[0].clone(), @@ -117,7 +103,6 @@ impl AggregateFunction { /// nullability pub fn nullable(&self) -> Result { match self { - AggregateFunction::Max | AggregateFunction::Min => Ok(true), AggregateFunction::ArrayAgg => Ok(true), } } @@ -129,18 +114,6 @@ impl AggregateFunction { // note: the physical expression must accept the type returned by this function or the execution panics. match self { AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), - AggregateFunction::Min | AggregateFunction::Max => { - let valid = STRINGS - .iter() - .chain(NUMERICS.iter()) - .chain(TIMESTAMPS.iter()) - .chain(DATES.iter()) - .chain(TIMES.iter()) - .chain(BINARYS.iter()) - .cloned() - .collect::>(); - Signature::uniform(1, valid, Volatility::Immutable) - } } } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index e3620501d9a8f..cf50df999aee6 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2524,8 +2524,6 @@ mod test { "first_value", "last_value", "nth_value", - "min", - "max", ]; for name in names { let fun = find_df_window_func(name).unwrap(); @@ -2542,18 +2540,6 @@ mod test { #[test] fn test_find_df_window_function() { - assert_eq!( - find_df_window_func("max"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Max - )) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Min - )) - ); assert_eq!( find_df_window_func("cume_dist"), Some(WindowFunctionDefinition::BuiltInWindowFunction( diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9187e83522052..57b781d6732e1 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -18,17 +18,17 @@ //! 
Functions for creating logical expressions use crate::expr::{ - AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - Placeholder, TryCast, Unnest, + BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, Placeholder, TryCast, + Unnest, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, StateFieldsArgs, }; use crate::{ - aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + conditional_expressions::CaseBuilder, logical_plan::Subquery, AggregateUDF, Expr, + LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, Signature, + Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::compute::kernels::cast_utils::{ @@ -147,30 +147,6 @@ pub fn not(expr: Expr) -> Expr { expr.not() } -/// Create an expression to represent the min() aggregate function -pub fn min(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Min, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the max() aggregate function -pub fn max(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Max, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 4b56ca3d1c2e0..2efdcae1a790c 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -156,11 +156,13 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - cast, col, lit, 
logical_plan::builder::LogicalTableSource, min, - test::function_stub::avg, try_cast, LogicalPlanBuilder, + cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast, + LogicalPlanBuilder, }; use super::*; + use crate::test::function_stub::avg; + use crate::test::function_stub::min; #[test] fn rewrite_sort_cols_by_agg() { diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index 14a6522ebe91e..19822c92d6908 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -291,6 +291,174 @@ impl AggregateUDFImpl for Count { } } +create_func!(Min, min_udaf); + +pub fn min(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + min_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of Min aggregate +pub struct Min { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Min { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Min") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl Min { + pub fn new() -> Self { + Self { + aliases: vec!["min".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MIN" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> 
Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +create_func!(Max, max_udaf); + +pub fn max(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + max_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of MAX aggregate +pub struct Max { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Max { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Max") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} + +impl Max { + pub fn new() -> Self { + Self { + aliases: vec!["max".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MAX" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +/// Testing stub implementation of avg aggregate +#[derive(Debug)] +pub struct Avg { + signature: Signature, + aliases: Vec, +} + diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index fbec6e2f8024d..adad003d98f88 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -15,7 +15,7 @@ //
specific language governing permissions and limitations // under the License. -use std::ops::Deref; +use crate::{AggregateFunction, Signature, TypeSignature}; use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, @@ -24,8 +24,6 @@ use arrow::datatypes::{ use datafusion_common::{internal_err, plan_err, Result}; -use crate::{AggregateFunction, Signature, TypeSignature}; - pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; pub static SIGNED_INTEGERS: &[DataType] = &[ @@ -93,14 +91,8 @@ pub fn coerce_types( ) -> Result> { // Validate input_types matches (at least one of) the func signature. check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; - match agg_fun { AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), - AggregateFunction::Min | AggregateFunction::Max => { - // min and max support the dictionary data type - // unpack the dictionary to get the value - get_min_max_result_type(input_types) - } } } @@ -164,22 +156,6 @@ pub fn check_arg_count( Ok(()) } -fn get_min_max_result_type(input_types: &[DataType]) -> Result> { - // make sure that the input types only has one element. 
- assert_eq!(input_types.len(), 1); - // min and max support the dictionary data type - // unpack the dictionary to get the value - match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { - // TODO add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) - } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function - _ => Ok(input_types.to_vec()), - } -} - /// function return type of a sum pub fn sum_return_type(arg_type: &DataType) -> Result { match arg_type { @@ -351,20 +327,9 @@ mod tests { use super::*; #[test] fn test_aggregate_coerce_types() { - // test input args with error number input types - let fun = AggregateFunction::Min; - let input_types = vec![DataType::Int64, DataType::Int32]; - let signature = fun.signature(); - let result = coerce_types(&fun, &input_types, &signature); - assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace()); - - // test count, array_agg, approx_distinct, min, max. + // test count, array_agg, approx_distinct. 
// the coerced types is same with input types - let funs = vec![ - AggregateFunction::ArrayAgg, - AggregateFunction::Min, - AggregateFunction::Max, - ]; + let funs = vec![AggregateFunction::ArrayAgg]; let input_types = vec![ vec![DataType::Int32], vec![DataType::Decimal128(10, 2)], diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 1657e034fbe2b..7f074fe375a85 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -537,6 +537,21 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { self.signature().hash(hasher); hasher.finish() } + + /// Returns a flag beneficial for the aggregate statistics and optimizer + fn is_non_distinct_count(&self) -> bool { + false + } + + /// Returns a flag beneficial for the aggregate statistics and optimizer + fn is_min(&self) -> bool { + false + } + + /// Returns a flag beneficial for the aggregate statistics and optimizer + fn is_max(&self) -> bool { + false + } } pub enum ReversedUDAF { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 889aa0952e51e..fcb31ccbde70d 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1253,8 +1253,8 @@ mod tests { use super::*; use crate::{ col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, - test::function_stub::sum_udaf, AggregateFunction, Cast, WindowFrame, - WindowFunctionDefinition, + test::function_stub::max_udaf, test::function_stub::min_udaf, + test::function_stub::sum_udaf, Cast, WindowFrame, WindowFunctionDefinition, }; #[test] @@ -1268,7 +1268,7 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![], @@ -1276,7 +1276,7 @@ mod tests { None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - 
WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![], @@ -1284,7 +1284,7 @@ mod tests { None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], vec![], vec![], @@ -1315,7 +1315,7 @@ mod tests { let created_at_desc = Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], @@ -1323,7 +1323,7 @@ mod tests { None, )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![], @@ -1331,7 +1331,7 @@ mod tests { None, )); let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], vec![], vec![age_asc.clone(), name_desc.clone()], @@ -1371,7 +1371,7 @@ mod tests { fn test_find_sort_exprs() -> Result<()> { let exprs = &[ Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], vec![], vec![ diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 26630a0352d58..43ddd37cfb6ff 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -48,3 +48,6 @@ datafusion-physical-expr-common = { workspace = true } log = { workspace = true } 
paste = "1.0.14" sqlparser = { workspace = true } + +[dev-dependencies] +rand = { workspace = true } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b39b1955bb07b..e64be2012aa49 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -65,6 +65,7 @@ pub mod covariance; pub mod first_last; pub mod hyperloglog; pub mod median; +pub mod min_max; pub mod regr; pub mod stddev; pub mod sum; @@ -110,7 +111,8 @@ pub mod expr_fn { pub use super::first_last::last_value; pub use super::grouping::grouping; pub use super::median::median; - pub use super::nth_value::nth_value; + pub use super::min_max::max; + pub use super::min_max::min; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; @@ -137,6 +139,8 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_pop_udaf(), correlation::corr_udaf(), sum::sum_udaf(), + min_max::max_udaf(), + min_max::min_udaf(), median::median_udaf(), count::count_udaf(), regr::regr_slope_udaf(), @@ -192,11 +196,11 @@ mod tests { #[test] fn test_no_duplicate_name() -> Result<()> { let mut names = HashSet::new(); + let migrated_functions = vec!["array_agg", "count", "max", "min"]; for func in all_default_aggregate_functions() { // TODO: remove this - // These functions are in intermidiate migration state, skip them - let name_lower_case = func.name().to_lowercase(); - if name_lower_case == "count" || name_lower_case == "array_agg" { + // These functions are in intermediate migration state, skip them + if migrated_functions.contains(&func.name().to_lowercase().as_str()) { continue; } assert!( diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs new file mode 100644 index 0000000000000..4a03cc3203739 --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -0,0 +1,1426 @@ +// Licensed to the Apache Software Foundation 
(ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function + +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Defines `MAX` aggregate accumulators + +use arrow::array::{ + ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, + Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, LargeBinaryArray, LargeStringArray, StringArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::compute; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use datafusion_common::{downcast_value, internal_err, DataFusionError, Result}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use std::fmt::Debug; + +use arrow::datatypes::{ + Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; + +use arrow::datatypes::i256; + +use datafusion_common::ScalarValue; +use datafusion_expr::GroupsAccumulator; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, +}; + +macro_rules! typed_min_max_float { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => match a.total_cmp(b) { + choose_min_max!($OP) => Some(*b), + _ => Some(*a), + }, + }) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for binary types. +macro_rules!
typed_min_max_batch_binary { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_vec())); + ScalarValue::$SCALAR(value) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +macro_rules! typed_min_max_batch { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) + }}; +} +// min/max of two scalar string values. +macro_rules! typed_min_max_string { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for string types. +macro_rules! typed_min_max_batch_string { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_string())); + ScalarValue::$SCALAR(value) + }}; +} + +macro_rules! 
min_max { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + Ok(match ($VALUE, $DELTA) { + ( + lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { + typed_min_max!(lhs, rhs, Boolean, $OP) + } + (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { + typed_min_max_float!(lhs, rhs, Float64, $OP) + } + (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { + typed_min_max_float!(lhs, rhs, Float32, $OP) + } + (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { + typed_min_max!(lhs, rhs, UInt64, $OP) + } + (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { + typed_min_max!(lhs, rhs, UInt32, $OP) + } + (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { + typed_min_max!(lhs, rhs, UInt16, $OP) + } + (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { + typed_min_max!(lhs, rhs, UInt8, $OP) + } + (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { + typed_min_max!(lhs, rhs, Int64, $OP) + } + (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { + typed_min_max!(lhs, rhs, Int32, $OP) + } + (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { + typed_min_max!(lhs, rhs, Int16, $OP) + } + (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { + typed_min_max!(lhs, rhs, Int8, $OP) + } + (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { + 
typed_min_max_string!(lhs, rhs, Utf8, $OP) + } + (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) + } + (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { + typed_min_max_string!(lhs, rhs, Binary, $OP) + } + (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeBinary, $OP) + } + (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { + typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMillisecond(lhs, l_tz), + ScalarValue::TimestampMillisecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMicrosecond(lhs, l_tz), + ScalarValue::TimestampMicrosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) + } + ( + ScalarValue::TimestampNanosecond(lhs, l_tz), + ScalarValue::TimestampNanosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) + } + ( + ScalarValue::Date32(lhs), + ScalarValue::Date32(rhs), + ) => { + typed_min_max!(lhs, rhs, Date32, $OP) + } + ( + ScalarValue::Date64(lhs), + ScalarValue::Date64(rhs), + ) => { + typed_min_max!(lhs, rhs, Date64, $OP) + } + ( + ScalarValue::Time32Second(lhs), + ScalarValue::Time32Second(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Second, $OP) + } + ( + ScalarValue::Time32Millisecond(lhs), + ScalarValue::Time32Millisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Millisecond, $OP) + } + ( + ScalarValue::Time64Microsecond(lhs), + ScalarValue::Time64Microsecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Microsecond, $OP) + } + ( + ScalarValue::Time64Nanosecond(lhs), + ScalarValue::Time64Nanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) + } + ( + ScalarValue::IntervalYearMonth(lhs), + ScalarValue::IntervalYearMonth(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) + } + ( 
+ ScalarValue::IntervalMonthDayNano(lhs), + ScalarValue::IntervalMonthDayNano(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) + } + ( + ScalarValue::IntervalDayTime(lhs), + ScalarValue::IntervalDayTime(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalDayTime, $OP) + } + ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalMonthDayNano(_), + ) | ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalMonthDayNano(_), + ) => { + interval_min_max!($OP, $VALUE, $DELTA) + } + ( + ScalarValue::DurationSecond(lhs), + ScalarValue::DurationSecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationSecond, $OP) + } + ( + ScalarValue::DurationMillisecond(lhs), + ScalarValue::DurationMillisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMillisecond, $OP) + } + ( + ScalarValue::DurationMicrosecond(lhs), + ScalarValue::DurationMicrosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) + } + ( + ScalarValue::DurationNanosecond(lhs), + ScalarValue::DurationNanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationNanosecond, $OP) + } + e => { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + e + ) + } + }) + }}; +} + +macro_rules! choose_min_max { + (min) => { + std::cmp::Ordering::Greater + }; + (max) => { + std::cmp::Ordering::Less + }; +} + +macro_rules! 
interval_min_max { + ($OP:tt, $LHS:expr, $RHS:expr) => {{ + match $LHS.partial_cmp(&$RHS) { + Some(choose_min_max!($OP)) => $RHS.clone(), + Some(_) => $LHS.clone(), + None => { + return internal_err!("Comparison error while computing interval min/max") + } + } + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +// this is a macro to support both operations (min and max). +macro_rules! min_max_batch { + ($VALUES:expr, $OP:ident) => {{ + match $VALUES.data_type() { + DataType::Decimal128(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal128Array, + Decimal128, + $OP, + precision, + scale + ) + } + DataType::Decimal256(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal256Array, + Decimal256, + $OP, + precision, + scale + ) + } + // all types that have a natural order + DataType::Float64 => { + typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) + } + DataType::Float32 => { + typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) + } + DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), + DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), + DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), + DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), + DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), + DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), + DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), + DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_min_max_batch!( + $VALUES, + TimestampSecondArray, + TimestampSecond, + $OP, + tz_opt + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMillisecondArray, + TimestampMillisecond, + $OP, + tz_opt + ), + 
DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMicrosecondArray, + TimestampMicrosecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampNanosecondArray, + TimestampNanosecond, + $OP, + tz_opt + ), + DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), + DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), + DataType::Time32(TimeUnit::Second) => { + typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) + } + DataType::Time32(TimeUnit::Millisecond) => { + typed_min_max_batch!( + $VALUES, + Time32MillisecondArray, + Time32Millisecond, + $OP + ) + } + DataType::Time64(TimeUnit::Microsecond) => { + typed_min_max_batch!( + $VALUES, + Time64MicrosecondArray, + Time64Microsecond, + $OP + ) + } + DataType::Time64(TimeUnit::Nanosecond) => { + typed_min_max_batch!( + $VALUES, + Time64NanosecondArray, + Time64Nanosecond, + $OP + ) + } + other => { + // This should have been handled before + return internal_err!( + "Min/Max accumulator not implemented for type {:?}", + other + ); + } + } + }}; +} + +/// dynamically-typed max(array) -> ScalarValue +fn max_batch(values: &ArrayRef) -> Result { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, max_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + max_binary + ) + } + _ => min_max_batch!(values, max), + }) +} + +fn min_batch(values: &ArrayRef) -> Result { + Ok(match values.data_type() { + DataType::Utf8 => { + 
typed_min_max_batch_string!(values, StringArray, Utf8, min_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + min_binary + ) + } + _ => min_max_batch!(values, min), + }) +} +// min/max of two non-string scalar values. +macro_rules! typed_min_max { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $($EXTRA_ARGS.clone()),* + ) + }}; +} + +// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. + +// Keep track of the minimum or maximum value in a sliding window. +// +// `moving min max` provides one data structure for keeping track of the +// minimum value and one for keeping track of the maximum value in a sliding +// window. +// +// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, +// push to this stack all elements popped from first stack while updating their current min/max. Now pop from +// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, +// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. 
+// +// The complexity of the operations are +// - O(1) for getting the minimum/maximum +// - O(1) for push +// - amortized O(1) for pop + +/// ``` +/// # use datafusion_physical_expr::aggregate::moving_min_max::MovingMin; +/// let mut moving_min = MovingMin::::new(); +/// moving_min.push(2); +/// moving_min.push(1); +/// moving_min.push(3); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(2)); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(1)); +/// +/// assert_eq!(moving_min.min(), Some(&3)); +/// assert_eq!(moving_min.pop(), Some(3)); +/// +/// assert_eq!(moving_min.min(), None); +/// assert_eq!(moving_min.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMin { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMin { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMin { + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window with `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the minimum of the sliding window or `None` if the window is + /// empty. + #[inline] + pub fn min(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, min)), None) => Some(min), + (None, Some((_, min))) => Some(min), + (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. 
+ #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, min)) => { + if val > *min { + (val, min.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let min = if last.1 < val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), min); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. 
+ #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} +/// ``` +/// # use datafusion_physical_expr::aggregate::moving_min_max::MovingMax; +/// let mut moving_max = MovingMax::::new(); +/// moving_max.push(2); +/// moving_max.push(3); +/// moving_max.push(1); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(2)); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(3)); +/// +/// assert_eq!(moving_max.max(), Some(&1)); +/// assert_eq!(moving_max.pop(), Some(1)); +/// +/// assert_eq!(moving_max.max(), None); +/// assert_eq!(moving_max.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMax { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMax { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMax { + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with + /// `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the maximum of the sliding window or `None` if the window is empty. + #[inline] + pub fn max(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, max)), None) => Some(max), + (None, Some((_, max))) => Some(max), + (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. 
+ #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, max)) => { + if val < *max { + (val, max.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let max = if last.1 > val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), max); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. 
+ #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +make_udaf_expr_and_func!( + Max, + max, + expression, + "Returns the maximum of a group of values.", + max_udaf +); + +make_udaf_expr_and_func!( + Min, + min, + expression, + "Returns the minimum of a group of values.", + min_udaf +); + +fn min_max_aggregate_data_type(input_type: DataType) -> DataType { + if let DataType::Dictionary(_, value_type) = input_type { + *value_type + } else { + input_type + } +} + +#[derive(Debug)] +pub struct Max { + signature: Signature, + aliases: Vec, +} + +impl Max { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + aliases: vec!["max".to_owned()], + } + } +} + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MAX" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(min_max_aggregate_data_type(arg_types[0].clone())) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MaxAccumulator::try_new(acc_args.data_type)?)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + matches!( + _args.data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) + | DataType::Timestamp(TimeUnit::Second, _) + | DataType::Timestamp(TimeUnit::Millisecond, _) + | DataType::Timestamp(TimeUnit::Microsecond, _) + | 
DataType::Timestamp(TimeUnit::Nanosecond, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + use DataType::*; + use TimeUnit::*; + let data_type = args.data_type; + macro_rules! helper { + ($NATIVE:ident, $PRIMTYPE:ident) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( + data_type, + |cur, new| { + if *cur < new { + *cur = new + } + }, + ) + .with_starting_value($NATIVE::MIN), + )) + }}; + } + + match args.data_type { + Int8 => helper!(i8, Int8Type), + Int16 => helper!(i16, Int16Type), + Int32 => helper!(i32, Int32Type), + Int64 => helper!(i64, Int64Type), + UInt8 => helper!(u8, UInt8Type), + UInt16 => helper!(u16, UInt16Type), + UInt32 => helper!(u32, UInt32Type), + UInt64 => helper!(u64, UInt64Type), + Float32 => { + helper!(f32, Float32Type) + } + Float64 => { + helper!(f64, Float64Type) + } + Date32 => helper!(i32, Date32Type), + Date64 => helper!(i64, Date64Type), + Time32(Second) => { + helper!(i32, Time32SecondType) + } + Time32(Millisecond) => { + helper!(i32, Time32MillisecondType) + } + Time64(Microsecond) => { + helper!(i64, Time64MicrosecondType) + } + Time64(Nanosecond) => { + helper!(i64, Time64NanosecondType) + } + Timestamp(Second, _) => { + helper!(i64, TimestampSecondType) + } + Timestamp(Millisecond, _) => { + helper!(i64, TimestampMillisecondType) + } + Timestamp(Microsecond, _) => { + helper!(i64, TimestampMicrosecondType) + } + Timestamp(Nanosecond, _) => { + helper!(i64, TimestampNanosecondType) + } + Decimal128(_, _) => { + helper!(i128, Decimal128Type) + } + Decimal256(_, _) => { + helper!(i256, Decimal256Type) + } + + // It would be nice to have a fast implementation for Strings as well + // https://github.com/apache/datafusion/issues/6906 + + // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), + } + } + + fn 
create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SlidingMaxAccumulator::try_new(args.data_type)?)) + } +} + +/// An accumulator to compute the maximum value +#[derive(Debug)] +pub struct MaxAccumulator { + max: ScalarValue, +} + +impl MaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MaxAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &max_batch(values)?; + let new_max: Result = + min_max!(&self.max, delta, max); + self.max = new_max?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + } +} + +#[derive(Debug)] +pub struct Min { + signature: Signature, + aliases: Vec, +} + +impl Min { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + aliases: vec!["min".to_owned()], + } + } +} +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "MIN" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(min_max_aggregate_data_type(arg_types[0].clone())) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MinAccumulator::try_new(acc_args.data_type)?)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + matches!( + _args.data_type, + DataType::Int8 + | 
DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(TimeUnit::Second) + | DataType::Time32(TimeUnit::Millisecond) + | DataType::Time64(TimeUnit::Microsecond) + | DataType::Time64(TimeUnit::Nanosecond) + | DataType::Timestamp(TimeUnit::Second, _) + | DataType::Timestamp(TimeUnit::Millisecond, _) + | DataType::Timestamp(TimeUnit::Microsecond, _) + | DataType::Timestamp(TimeUnit::Nanosecond, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + use DataType::*; + use TimeUnit::*; + let data_type = args.data_type; + macro_rules! helper { + ($NATIVE:ident, $PRIMTYPE:ident) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( + data_type, + |cur, new| { + if *cur > new { + *cur = new + } + }, + ) + .with_starting_value($NATIVE::MAX), + )) + }}; + } + + match args.data_type { + Int8 => helper!(i8, Int8Type), + Int16 => helper!(i16, Int16Type), + Int32 => helper!(i32, Int32Type), + Int64 => helper!(i64, Int64Type), + UInt8 => helper!(u8, UInt8Type), + UInt16 => helper!(u16, UInt16Type), + UInt32 => helper!(u32, UInt32Type), + UInt64 => helper!(u64, UInt64Type), + Float32 => { + helper!(f32, Float32Type) + } + Float64 => { + helper!(f64, Float64Type) + } + Date32 => helper!(i32, Date32Type), + Date64 => helper!(i64, Date64Type), + Time32(Second) => { + helper!(i32, Time32SecondType) + } + Time32(Millisecond) => { + helper!(i32, Time32MillisecondType) + } + Time64(Microsecond) => { + helper!(i64, Time64MicrosecondType) + } + Time64(Nanosecond) => { + helper!(i64, Time64NanosecondType) + } + Timestamp(Second, _) => { + helper!(i64, TimestampSecondType) + } + Timestamp(Millisecond, _) => { + helper!(i64, TimestampMillisecondType) + } + Timestamp(Microsecond, _) => 
{ + helper!(i64, TimestampMicrosecondType) + } + Timestamp(Nanosecond, _) => { + helper!(i64, TimestampNanosecondType) + } + Decimal128(_, _) => { + helper!(i128, Decimal128Type) + } + Decimal256(_, _) => { + helper!(i256, Decimal256Type) + } + + // It would be nice to have a fast implementation for Strings as well + // https://github.com/apache/datafusion/issues/6906 + + // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), + } + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SlidingMinAccumulator::try_new(args.data_type)?)) + } +} +/// An accumulator to compute the minimum value +#[derive(Debug)] +pub struct MinAccumulator { + min: ScalarValue, +} + +impl MinAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &min_batch(values)?; + let new_min: Result = + min_max!(&self.min, delta, min); + self.min = new_min?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + } +} + +#[derive(Debug)] +pub struct SlidingMinAccumulator { + min: ScalarValue, + moving_min: MovingMin, +} + +impl SlidingMinAccumulator { + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + moving_min: MovingMin::::new(), + }) + } +} + +impl Accumulator for SlidingMinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.min.clone()]) + } + + 
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + if !val.is_null() { + self.moving_min.push(val); + } + } + if let Some(res) = self.moving_min.min() { + self.min = res.clone(); + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + if !val.is_null() { + (self.moving_min).pop(); + } + } + if let Some(res) = self.moving_min.min() { + self.min = res.clone(); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() + } +} + +#[derive(Debug)] +pub struct SlidingMaxAccumulator { + max: ScalarValue, + moving_max: MovingMax, +} + +impl SlidingMaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + moving_max: MovingMax::::new(), + }) + } +} + +impl Accumulator for SlidingMaxAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + self.moving_max.push(val); + } + if let Some(res) = self.moving_max.max() { + self.max = res.clone(); + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for _idx in 0..values[0].len() { + (self.moving_max).pop(); + } + if let Some(res) = self.moving_max.max() { + self.max = res.clone(); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.max.clone()]) + } + + fn 
evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + #[test] + fn float_min_max_with_nans() { + let pos_nan = f32::NAN; + let zero = 0_f32; + let neg_inf = f32::NEG_INFINITY; + + let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| { + for batch in values.iter() { + let batch = + Arc::new(Float32Array::from_iter_values(batch.iter().copied())); + acc.update_batch(&[batch]).unwrap(); + } + let result = acc.evaluate().unwrap(); + assert_eq!(result, ScalarValue::Float32(Some(expected))); + }; + + // This test checks both comparison between batches (which uses the min_max macro + // defined above) and within a batch (which uses the arrow min/max compute function + // and verifies both respect the total order comparison for floats) + + let min = || MinAccumulator::try_new(&DataType::Float32).unwrap(); + let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap(); + + check(&mut min(), &[&[zero], &[pos_nan]], zero); + check(&mut min(), &[&[zero, pos_nan]], zero); + check(&mut min(), &[&[zero], &[neg_inf]], neg_inf); + check(&mut min(), &[&[zero, neg_inf]], neg_inf); + check(&mut max(), &[&[zero], &[pos_nan]], pos_nan); + check(&mut max(), &[&[zero, pos_nan]], pos_nan); + check(&mut max(), &[&[zero], &[neg_inf]], zero); + check(&mut max(), &[&[zero, neg_inf]], zero); + } + + use datafusion_common::Result; + use rand::Rng; + + fn get_random_vec_i32(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input = Vec::with_capacity(len); + for _i in 0..len { + input.push(rng.gen_range(0..100)); + } + input + } + + fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_min = 
MovingMin::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().min().unwrap()); + + moving_min.push(data[i]); + if i > n_sliding_window { + moving_min.pop(); + } + res.push(*moving_min.min().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_max = MovingMax::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().max().unwrap()); + + moving_max.push(data[i]); + if i > n_sliding_window { + moving_max.pop(); + } + res.push(*moving_max.max().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + #[test] + fn moving_min_tests() -> Result<()> { + moving_min_i32(100, 10)?; + moving_min_i32(100, 20)?; + moving_min_i32(100, 50)?; + moving_min_i32(100, 100)?; + Ok(()) + } + + #[test] + fn moving_max_tests() -> Result<()> { + moving_max_i32(100, 10)?; + moving_max_i32(100, 20)?; + moving_max_i32(100, 50)?; + moving_max_i32(100, 100)?; + Ok(()) + } +} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index fa8aeb86ed31e..1262f8f347500 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -102,11 +102,11 @@ mod tests { use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; use datafusion_expr::{ - col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, - WindowFrameUnits, + col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, + scalar_subquery, wildcard, WindowFrame, WindowFrameBound, 
WindowFrameUnits, }; use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::max; use std::sync::Arc; use datafusion_functions_aggregate::expr_fn::{count, sum}; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index 58c1ae297b02e..414cedc4a1ac4 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -814,13 +814,13 @@ mod tests { expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, - max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFrame, - WindowFunctionDefinition, + not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, + Projection, UserDefinedLogicalNodeCore, WindowFrame, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::expr_fn::count; + use datafusion_functions_aggregate::expr_fn::{count, max, min}; + use datafusion_functions_aggregate::min_max::max_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) @@ -1917,7 +1917,7 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], vec![col("test.b")], vec![], @@ -1926,7 +1926,7 @@ mod tests { )); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], vec![], vec![], diff --git a/datafusion/optimizer/src/push_down_limit.rs 
b/datafusion/optimizer/src/push_down_limit.rs index cd2e0b6f5ba2e..15c25ad0116fb 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -330,8 +330,8 @@ mod test { use super::*; use crate::test::*; - - use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder, max}; + use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_functions_aggregate::expr_fn::max; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 35691847fb8e9..51a34b4861b04 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -393,8 +393,10 @@ mod tests { use crate::test::*; use arrow::datatypes::DataType; + // TODO: stubs or real functions use datafusion_expr::test::function_stub::sum; - use datafusion_expr::{col, lit, max, min, out_ref_col, scalar_subquery, Between}; + use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; + use datafusion_functions_aggregate::expr_fn::{max, min}; /// Test multiple correlated subqueries #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index e650d4c09c23f..e44f60d1df220 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -160,6 +160,7 @@ mod tests { ExprSchemable, JoinType, }; use datafusion_expr::{or, BinaryExpr, Cast, Operator}; + use datafusion_functions_aggregate::expr_fn::{max, min}; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -395,10 +396,7 @@ mod tests { .project(vec![col("a"), col("c"), 
col("b")])? .aggregate( vec![col("a"), col("c")], - vec![ - datafusion_expr::max(col("b").eq(lit(true))), - datafusion_expr::min(col("b")), - ], + vec![max(col("b").eq(lit(true))), min(col("b"))], )? .build()?; diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index f2b4abdd6cbd5..d651397278040 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -28,7 +28,6 @@ use datafusion_common::{ use datafusion_expr::builder::project; use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ - aggregate_function::AggregateFunction::{Max, Min}, col, expr::AggregateFunction, logical_plan::{Aggregate, LogicalPlan}, @@ -71,7 +70,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), + func_def: AggregateFunctionDefinition::BuiltIn(_fun), distinct, args, filter, @@ -87,7 +86,7 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } - } else if !matches!(fun, Min | Max) { + } else { return Ok(false); } } else if let Expr::AggregateFunction(AggregateFunction { @@ -107,7 +106,10 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for e in args { fields_set.insert(e); } - } else if fun.name() != "sum" && fun.name() != "MIN" && fun.name() != "MAX" { + } else if fun.name() != "sum" + && fun.name().to_lowercase() != "min" + && fun.name().to_lowercase() != "max" + { return Ok(false); } } else { @@ -173,6 +175,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { // // First aggregate(from bottom) refers to `test.a` column. // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. 
+ // If we were to write plan above as below without alias // // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ @@ -355,11 +358,9 @@ mod tests { use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; use datafusion_expr::AggregateExt; - use datafusion_expr::{ - lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, - }; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum}; use datafusion_functions_aggregate::sum::sum_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -520,17 +521,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], - vec![ - count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), - ], + vec![count_distinct(col("b")), max(col("b"))], )? .build()?; // Should work @@ -584,18 +575,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], - vec![ - sum(col("c")), - count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), - ], + vec![sum(col("c")), count_distinct(col("b")), max(col("b"))], )? 
.build()?; // Should work diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 9c270561f37d2..0de34dcf423ee 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -36,6 +36,8 @@ use datafusion_expr::AggregateFunction; use crate::expressions::{self}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; +/// Create a physical aggregation expression. +/// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. pub fn create_aggregate_expr( @@ -77,154 +79,5 @@ pub fn create_aggregate_expr( )) } } - (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( - Arc::clone(&input_phy_exprs[0]), - name, - data_type, - )), - (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( - Arc::clone(&input_phy_exprs[0]), - name, - data_type, - )), }) } - -#[cfg(test)] -mod tests { - use arrow::datatypes::{DataType, Field}; - - use datafusion_common::plan_err; - use datafusion_expr::{type_coercion, Signature}; - - use crate::expressions::{try_cast, Max, Min}; - - use super::*; - - #[test] - fn test_min_max_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - 
&input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::Min => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::Max => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } - - #[test] - fn test_min_max() -> Result<()> { - let observed = AggregateFunction::Min.return_type(&[DataType::Utf8], &[true])?; - assert_eq!(DataType::Utf8, observed); - - let observed = AggregateFunction::Max.return_type(&[DataType::Int32], &[true])?; - assert_eq!(DataType::Int32, observed); - - // test decimal for min - let observed = AggregateFunction::Min - .return_type(&[DataType::Decimal128(10, 6)], &[true])?; - assert_eq!(DataType::Decimal128(10, 6), observed); - - // test decimal for max - let observed = AggregateFunction::Max - .return_type(&[DataType::Decimal128(28, 13)], &[true])?; - assert_eq!(DataType::Decimal128(28, 13), observed); - - Ok(()) - } - - // Helper function - // Create aggregate expr with type coercion - fn create_physical_agg_expr_for_test( - fun: &AggregateFunction, - distinct: bool, - input_phy_exprs: &[Arc], - input_schema: &Schema, - name: impl Into, - ) -> Result> { - let name = name.into(); - let coerced_phy_exprs = - coerce_exprs_for_test(fun, input_phy_exprs, input_schema, &fun.signature())?; - if coerced_phy_exprs.is_empty() { - return plan_err!( - "Invalid or wrong number of arguments passed to aggregate: '{name}'" - ); - } - create_aggregate_expr( - fun, - distinct, - &coerced_phy_exprs, - &[], - input_schema, - name, - false, - ) - } - - // Returns the coerced exprs for each `input_exprs`. 
- // Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the - // data type of `input_exprs` need to be coerced. - fn coerce_exprs_for_test( - agg_fun: &AggregateFunction, - input_exprs: &[Arc], - schema: &Schema, - signature: &Signature, - ) -> Result>> { - if input_exprs.is_empty() { - return Ok(vec![]); - } - let input_types = input_exprs - .iter() - .map(|e| e.data_type(schema)) - .collect::>>()?; - - // get the coerced data types - let coerced_types = - type_coercion::aggregates::coerce_types(agg_fun, &input_types, signature)?; - - // try cast if need - input_exprs - .iter() - .zip(coerced_types) - .map(|(expr, coerced_type)| try_cast(Arc::clone(expr), schema, coerced_type)) - .collect::>>() - } -} diff --git a/datafusion/proto/gen/src/main.rs b/datafusion/proto/gen/src/main.rs index d38a41a01ac23..3ede12aea2078 100644 --- a/datafusion/proto/gen/src/main.rs +++ b/datafusion/proto/gen/src/main.rs @@ -33,6 +33,7 @@ fn main() -> Result<(), String> { .file_descriptor_set_path(&descriptor_path) .out_dir(out_dir) .compile_well_known_types() + .protoc_arg("--experimental_allow_proto3_optional") .extern_path(".google.protobuf", "::pbjson_types") .compile_protos(&[proto_path], &["proto"]) .map_err(|e| format!("protobuf compilation failed: {e}"))?; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index dc551778c5fb2..eeed5d40a473b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -466,8 +466,8 @@ message InListNode { } enum AggregateFunction { - MIN = 0; - MAX = 1; + UNUSED = 0; + // MAX = 1; // SUM = 2; // AVG = 3; // COUNT = 4; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 8f77c24bd9117..98559eb8059b8 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -532,8 +532,7 @@ impl serde::Serialize for AggregateFunction { S: 
serde::Serializer, { let variant = match self { - Self::Min => "MIN", - Self::Max => "MAX", + Self::Unused => "UNUSED", Self::ArrayAgg => "ARRAY_AGG", }; serializer.serialize_str(variant) @@ -546,8 +545,8 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "MIN", - "MAX", + "UNUSED", + "AVG", "ARRAY_AGG", ]; @@ -589,8 +588,7 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { E: serde::de::Error, { match value { - "MIN" => Ok(AggregateFunction::Min), - "MAX" => Ok(AggregateFunction::Max), + "UNUSED" => Ok(AggregateFunction::Unused), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 605c56fa946a3..30687d1b372eb 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1938,8 +1938,8 @@ pub struct PartitionStats { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum AggregateFunction { - Min = 0, - Max = 1, + Unused = 0, + /// MAX = 1; /// SUM = 2; /// AVG = 3; /// COUNT = 4; @@ -1982,16 +1982,14 @@ impl AggregateFunction { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - AggregateFunction::Min => "MIN", - AggregateFunction::Max => "MAX", + AggregateFunction::Unused => "UNUSED", AggregateFunction::ArrayAgg => "ARRAY_AGG", } } /// Creates an enum from field names used in the ProtoBuf definition. 
pub fn from_str_name(value: &str) -> ::core::option::Option { match value { - "MIN" => Some(Self::Min), - "MAX" => Some(Self::Max), + "UNUSED" => Some(Self::Unused), "ARRAY_AGG" => Some(Self::ArrayAgg), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b6b556a8ed6b2..71d363e1e67f8 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -142,9 +142,8 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { impl From for AggregateFunction { fn from(agg_fun: protobuf::AggregateFunction) -> Self { match agg_fun { - protobuf::AggregateFunction::Min => Self::Min, - protobuf::AggregateFunction::Max => Self::Max, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, + protobuf::AggregateFunction::Unused => panic!("This should never happen, we are retiring this but protobuf doesn't support enum with no 0 values"), } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9607b918eb895..c2d61e5236f6e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -114,8 +114,6 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { impl From<&AggregateFunction> for protobuf::AggregateFunction { fn from(value: &AggregateFunction) -> Self { match value { - AggregateFunction::Min => Self::Min, - AggregateFunction::Max => Self::Max, AggregateFunction::ArrayAgg => Self::ArrayAgg, } } @@ -387,8 +385,6 @@ pub fn serialize_expr( AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs 
b/datafusion/proto/src/physical_plan/to_proto.rs index e9a90fce2663f..2428f59978ccd 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -24,7 +24,7 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, CumeDist, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, Ntile, + IsNullExpr, Literal, NegativeExpr, NotExpr, NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; @@ -263,10 +263,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { // TODO: remove OrderSensitiveArrayAgg let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Min - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Max } else { return not_impl_err!("Aggregate function not supported: {expr:?}"); }; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 11945f39589a7..25e7e555d8e57 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -41,9 +41,10 @@ use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, - count_distinct, covar_pop, covar_samp, first_value, grouping, median, stddev, - stddev_pop, sum, var_pop, var_sample, + count_distinct, covar_pop, covar_samp, first_value, grouping, max, median, min, + stddev, stddev_pop, sum, 
var_pop, var_sample, }; +use datafusion::functions_aggregate::min_max::max_udaf; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::TableOptions; @@ -59,7 +60,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateExt, AggregateFunction, AggregateUDF, ColumnarValue, + Accumulator, AggregateExt, AggregateUDF, ColumnarValue, ExprSchemable, Literal, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, @@ -687,7 +688,9 @@ async fn roundtrip_expr_api() -> Result<()> { covar_pop(lit(1.5), lit(2.2)), corr(lit(1.5), lit(2.2)), sum(lit(1)), + max(lit(1)), median(lit(2)), + min(lit(2)), var_sample(lit(2.2)), var_pop(lit(2.2)), stddev(lit(2.2)), @@ -2087,7 +2090,7 @@ fn roundtrip_window() { ); let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], vec![col("col1")], vec![col("col2")],