From 26b44f4dade6aea129bf9825a67268ec40872349 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 26 May 2024 18:02:32 +0800 Subject: [PATCH] Move Median to `functions-aggregate` and Introduce Numeric signature (#10644) * introduce median udaf Signed-off-by: jayzhan211 * rm agg median Signed-off-by: jayzhan211 * rm old median Signed-off-by: jayzhan211 * introduce numeric signature Signed-off-by: jayzhan211 * address comment Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * add proto roundtrip Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 1 + datafusion/core/src/dataframe/mod.rs | 3 +- datafusion/expr/src/aggregate_function.rs | 9 +- datafusion/expr/src/expr_fn.rs | 12 -- datafusion/expr/src/function.rs | 3 + datafusion/expr/src/signature.rs | 14 ++ datafusion/expr/src/tree_node.rs | 2 +- .../expr/src/type_coercion/aggregates.rs | 10 +- .../expr/src/type_coercion/functions.rs | 32 ++++ datafusion/expr/src/udaf.rs | 3 +- datafusion/functions-aggregate/Cargo.toml | 1 + datafusion/functions-aggregate/src/lib.rs | 3 + .../src}/median.rs | 164 +++++++++--------- .../optimizer/src/analyzer/type_coercion.rs | 2 +- .../src/single_distinct_to_groupby.rs | 50 ++++++ .../physical-expr-common/src/aggregate/mod.rs | 2 + .../physical-expr/src/aggregate/build_in.rs | 6 - datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - .../physical-plan/src/aggregates/mod.rs | 27 ++- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 2 - .../proto/src/physical_plan/to_proto.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + .../sqllogictest/test_files/aggregate.slt | 27 +++ 28 files changed, 258 insertions(+), 133 deletions(-) rename datafusion/{physical-expr/src/aggregate => 
functions-aggregate/src}/median.rs (78%) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 99af80bf9df2..e659e62d7ac7 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1287,6 +1287,7 @@ name = "datafusion-functions-aggregate" version = "38.0.0" dependencies = [ "arrow", + "arrow-schema", "datafusion-common", "datafusion-execution", "datafusion-expr", diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index d4626134acbf..e1656a22b1a4 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -50,10 +50,11 @@ use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg, count, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, + avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_expr::{case, is_null, sum}; +use datafusion_functions_aggregate::expr_fn::median; use async_trait::async_trait; diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 0a7607498c61..f251969ca618 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -43,8 +43,6 @@ pub enum AggregateFunction { Max, /// Average Avg, - /// Median - Median, /// Approximate distinct function ApproxDistinct, /// Aggregation into an array @@ -114,7 +112,6 @@ impl AggregateFunction { Min => "MIN", Max => "MAX", Avg => "AVG", - Median => "MEDIAN", ApproxDistinct => "APPROX_DISTINCT", ArrayAgg => "ARRAY_AGG", FirstValue => "FIRST_VALUE", @@ -168,7 +165,6 @@ impl FromStr for AggregateFunction { "count" => AggregateFunction::Count, "max" => AggregateFunction::Max, "mean" => AggregateFunction::Avg, - "median" => AggregateFunction::Median, "min" => AggregateFunction::Min, "sum" => AggregateFunction::Sum, "array_agg" => AggregateFunction::ArrayAgg, @@ -275,9 +271,7 @@ impl AggregateFunction { 
AggregateFunction::ApproxPercentileContWithWeight => { Ok(coerced_data_types[0].clone()) } - AggregateFunction::ApproxMedian | AggregateFunction::Median => { - Ok(coerced_data_types[0].clone()) - } + AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::FirstValue | AggregateFunction::LastValue @@ -335,7 +329,6 @@ impl AggregateFunction { | AggregateFunction::VariancePop | AggregateFunction::Stddev | AggregateFunction::StddevPop - | AggregateFunction::Median | AggregateFunction::ApproxMedian | AggregateFunction::FirstValue | AggregateFunction::LastValue => { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 2a2bb75f1884..8c9d3c7885b0 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -296,18 +296,6 @@ pub fn approx_distinct(expr: Expr) -> Expr { )) } -/// Calculate the median for `expr`. -pub fn median(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Median, - vec![expr], - false, - None, - None, - None, - )) -} - /// Calculate an approximation of the median for `expr`. pub fn approx_median(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 714cfa1af671..eb748ed2711a 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -82,6 +82,9 @@ pub struct AccumulatorArgs<'a> { /// The number of arguments the aggregate function takes. 
pub args_num: usize, + + /// The name of the expression + pub name: &'a str, } /// [`StateFieldsArgs`] contains information about the fields that an diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index 63b030f0b748..33f643eb2dc2 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -119,6 +119,9 @@ pub enum TypeSignature { OneOf(Vec), /// Specifies Signatures for array functions ArraySignature(ArrayFunctionSignature), + /// Fixed number of arguments of numeric types. + /// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric + Numeric(usize), } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -178,6 +181,9 @@ impl TypeSignature { .collect::>() .join(", ")] } + TypeSignature::Numeric(num) => { + vec![format!("Numeric({})", num)] + } TypeSignature::Exact(types) => { vec![Self::join_types(types, ", ")] } @@ -259,6 +265,14 @@ impl Signature { volatility, } } + + pub fn numeric(num: usize, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::Numeric(num), + volatility, + } + } + /// An arbitrary number of arguments of any type.
pub fn variadic_any(volatility: Volatility) -> Self { Self { diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 31ca4c40942b..c5f1694c1138 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -332,7 +332,7 @@ impl TreeNode for Expr { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( fun, new_args, - false, + distinct, new_filter, new_order_by, null_treatment, diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 57c0b6f4edc5..ce4a2a709842 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -283,9 +283,9 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } - AggregateFunction::Median - | AggregateFunction::FirstValue - | AggregateFunction::LastValue => Ok(input_types.to_vec()), + AggregateFunction::FirstValue | AggregateFunction::LastValue => { + Ok(input_types.to_vec()) + } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), AggregateFunction::StringAgg => { @@ -355,6 +355,10 @@ pub fn check_arg_count( ); } } + TypeSignature::UserDefined | TypeSignature::Numeric(_) => { + // User-defined signature is validated in `coerce_types` + // Numeric signature is validated in `get_valid_types` + } _ => { return internal_err!( "Aggregate functions do not support this {signature:?}" diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 583d75e1ccfc..b41ec109103d 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -322,6 +322,38 @@ fn get_valid_types( .iter() .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), + TypeSignature::Numeric(number) => { + if *number < 1 { + return plan_err!( + "The signature expected at least one argument but received 
{}", + current_types.len() + ); + } + if *number != current_types.len() { + return plan_err!( + "The signature expected {} arguments but received {}", + number, + current_types.len() + ); + } + + let mut valid_type = current_types.first().unwrap().clone(); + for t in current_types.iter().skip(1) { + if let Some(coerced_type) = + comparison_binary_numeric_coercion(&valid_type, t) + { + valid_type = coerced_type; + } else { + return plan_err!( + "{} and {} are not coercible to a common numeric type", + valid_type, + t + ); + } + } + + vec![vec![valid_type; *number]] + } TypeSignature::Uniform(number, valid_types) => valid_types .iter() .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 4fd8d51679f0..b620a897bcc9 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -193,9 +193,10 @@ impl AggregateUDF { self.inner.create_groups_accumulator() } - pub fn coerce_types(&self, _args: &[DataType]) -> Result> { + pub fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) } + /// Do the function rewrite /// /// See [`AggregateUDFImpl::simplify`] for more details. 
diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index f97647565364..696bbaece9e6 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -39,6 +39,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } +arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index e76a43e39899..3e80174eec33 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -57,6 +57,7 @@ pub mod macros; pub mod covariance; pub mod first_last; +pub mod median; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; @@ -68,6 +69,7 @@ use std::sync::Arc; pub mod expr_fn { pub use super::covariance::covar_samp; pub use super::first_last::first_value; + pub use super::median::median; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -76,6 +78,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { first_last::first_value_udaf(), covariance::covar_samp_udaf(), covariance::covar_pop_udaf(), + median::median_udaf(), ]; functions.into_iter().try_for_each(|udf| { diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/functions-aggregate/src/median.rs similarity index 78% rename from datafusion/physical-expr/src/aggregate/median.rs rename to datafusion/functions-aggregate/src/median.rs index ee0fce3fabe7..b3fb05d7fcf7 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/functions-aggregate/src/median.rs @@ -15,22 +15,38 @@ // specific language governing permissions and limitations // under the License. -//! 
# Median - -use crate::aggregate::utils::{down_cast_any_ref, Hashable}; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; -use arrow_array::{downcast_integer, ArrowNativeTypeOp, ArrowNumericType}; -use arrow_buffer::ArrowNativeType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::Accumulator; -use std::any::Any; use std::collections::HashSet; use std::fmt::Formatter; -use std::sync::Arc; +use std::{fmt::Debug, sync::Arc}; + +use arrow::array::{downcast_integer, ArrowNumericType}; +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{ + DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type, + Float64Type, + }, +}; + +use arrow::array::Array; +use arrow::array::ArrowNativeTypeOp; +use arrow::datatypes::ArrowNativeType; + +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{ + function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, + Signature, Volatility, +}; +use datafusion_physical_expr_common::aggregate::utils::Hashable; + +make_udaf_expr_and_func!( + Median, + median, + expression, + "Computes the median of a set of numbers", + median_udaf +); /// MEDIAN aggregate expression. If using the non-distinct variation, then this uses a /// lot of memory because all values need to be stored in memory before a result can be @@ -40,46 +56,72 @@ use std::sync::Arc; /// If using the distinct variation, the memory usage will be similarly high if the /// cardinality is high as it stores all distinct values in memory before computing the /// result, but if cardinality is low then memory usage will also be lower. 
-#[derive(Debug)] pub struct Median { - name: String, - expr: Arc, - data_type: DataType, - distinct: bool, + signature: Signature, + aliases: Vec, +} + +impl Debug for Median { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Median") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Median { + fn default() -> Self { + Self::new() + } } impl Median { - /// Create a new MEDIAN aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - distinct: bool, - ) -> Self { + pub fn new() -> Self { Self { - name: name.into(), - expr, - data_type, - distinct, + aliases: vec!["median".to_string()], + signature: Signature::numeric(1, Volatility::Immutable), } } } -impl AggregateExpr for Median { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { +impl AggregateUDFImpl for Median { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) + fn name(&self) -> &str { + "MEDIAN" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) } - fn create_accumulator(&self) -> Result> { - use arrow_array::types::*; + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + //Intermediate state is a list of the elements we have collected so far + let field = Field::new("item", args.input_type.clone(), true); + let state_name = if args.is_distinct { + "distinct_median" + } else { + "median" + }; + + Ok(vec![Field::new( + format_state_name(args.name, state_name), + DataType::List(Arc::new(field)), + true, + )]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { macro_rules! 
helper { ($t:ty, $dt:expr) => { - if self.distinct { + if acc_args.is_distinct { Ok(Box::new(DistinctMedianAccumulator::<$t> { data_type: $dt.clone(), distinct_values: HashSet::new(), @@ -92,7 +134,8 @@ impl AggregateExpr for Median { } }; } - let dt = &self.data_type; + + let dt = acc_args.input_type; downcast_integer! { dt => (helper, dt), DataType::Float16 => helper!(Float16Type, dt), @@ -102,49 +145,14 @@ impl AggregateExpr for Median { DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), _ => Err(DataFusionError::NotImplemented(format!( "MedianAccumulator not supported for {} with {}", - self.name(), - self.data_type + acc_args.name, + dt, ))), } } - fn state_fields(&self) -> Result> { - //Intermediate state is a list of the elements we have collected so far - let field = Field::new("item", self.data_type.clone(), true); - let data_type = DataType::List(Arc::new(field)); - let state_name = if self.distinct { - "distinct_median" - } else { - "median" - }; - - Ok(vec![Field::new( - format_state_name(&self.name, state_name), - data_type, - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Median { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - && self.distinct == x.distinct - }) - .unwrap_or(false) + fn aliases(&self) -> &[String] { + &self.aliases } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3d08bd6c7e42..69be344cb753 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -402,7 +402,7 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { expr::AggregateFunction::new_udf( fun, new_expr, - false, + distinct, filter, order_by, null_treatment, diff --git 
a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 4b1f9a0d1401..27449c8dd5c4 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -257,6 +257,56 @@ impl OptimizerRule for SingleDistinctToGroupBy { ))) } } + Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(udf), + args, + distinct, + .. + }) => { + if distinct { + if args.len() != 1 { + return internal_err!("DISTINCT aggregate should have exactly one argument"); + } + let mut args = args; + let arg = args.swap_remove(0); + + if group_fields_set.insert(arg.display_name()?) { + inner_group_exprs + .push(arg.alias(SINGLE_DISTINCT_ALIAS)); + } + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + udf, + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + None, + None, + None, + ))) + // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation + } else { + index += 1; + let alias_str = format!("alias{}", index); + inner_aggr_exprs.push( + Expr::AggregateFunction(AggregateFunction::new_udf( + udf.clone(), + args, + false, + None, + None, + None, + )) + .alias(&alias_str), + ); + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + udf.clone(), + vec![col(&alias_str)], + false, + None, + None, + None, + ))) + } + } _ => Ok(aggr_expr), }) .collect::>>()?; diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 4ef0d58046f8..4e9414bc5a11 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -219,6 +219,7 @@ impl AggregateExpr for AggregateFunctionExpr { is_distinct: self.is_distinct, input_type: &self.input_type, args_num: self.args.len(), + name: &self.name, }; self.fun.accumulator(acc_args) @@ -292,6 +293,7 @@ impl 
AggregateExpr for AggregateFunctionExpr { is_distinct: self.is_distinct, input_type: &self.input_type, args_num: self.args.len(), + name: &self.name, }; self.fun.groups_accumulator_supported(args) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 145e7feadf8c..18252ea370eb 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -332,12 +332,6 @@ pub fn create_aggregate_expr( "APPROX_MEDIAN(DISTINCT) aggregations are not available" ); } - (AggregateFunction::Median, distinct) => Arc::new(expressions::Median::new( - input_phy_exprs[0].clone(), - name, - data_type, - distinct, - )), (AggregateFunction::FirstValue, _) => Arc::new( expressions::FirstValue::new( input_phy_exprs[0].clone(), diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 93ecf0655e51..039c8814e987 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -39,7 +39,6 @@ pub(crate) mod count; pub(crate) mod count_distinct; pub(crate) mod covariance; pub(crate) mod grouping; -pub(crate) mod median; pub(crate) mod nth_value; pub(crate) mod string_agg; #[macro_use] diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 980297b8b433..a7921800fccd 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -53,7 +53,6 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::grouping::Grouping; -pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; diff --git 
a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 21608db40d56..cf31c2990b7d 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1203,8 +1203,9 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_functions_aggregate::median::median_udaf; use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, + lit, ApproxDistinct, Count, FirstValue, LastValue, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::{reverse_order_bys, PhysicalSortExpr}; @@ -1773,6 +1774,22 @@ mod tests { check_grouping_sets(input, true).await } + // Median(a) + fn test_median_agg_expr(schema: &Schema) -> Result> { + let args = vec![col("a", schema)?]; + let fun = median_udaf(); + datafusion_physical_expr_common::aggregate::create_aggregate_expr( + &fun, + &args, + &[], + &[], + schema, + "MEDIAN(a)", + false, + false, + ) + } + #[tokio::test] async fn test_oom() -> Result<()> { let input: Arc = Arc::new(TestYieldingExec::new(true)); @@ -1792,12 +1809,8 @@ mod tests { }; // something that allocates within the aggregator - let aggregates_v0: Vec> = vec![Arc::new(Median::new( - col("a", &input_schema)?, - "MEDIAN(a)".to_string(), - DataType::UInt32, - false, - ))]; + let aggregates_v0: Vec> = + vec![test_median_agg_expr(&input_schema)?]; // use slow-path in `hash.rs` let aggregates_v1: Vec> = diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 73e751c616ac..434ec9f81f15 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -551,7 +551,7 @@ enum AggregateFunction { APPROX_MEDIAN = 15; APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; GROUPING = 17; - MEDIAN = 18; + // MEDIAN = 
18; BIT_AND = 19; BIT_OR = 20; BIT_XOR = 21; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 77ba0808fb77..86a5975c8bb8 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -437,7 +437,6 @@ impl serde::Serialize for AggregateFunction { Self::ApproxMedian => "APPROX_MEDIAN", Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Self::Grouping => "GROUPING", - Self::Median => "MEDIAN", Self::BitAnd => "BIT_AND", Self::BitOr => "BIT_OR", Self::BitXor => "BIT_XOR", @@ -483,7 +482,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "APPROX_MEDIAN", "APPROX_PERCENTILE_CONT_WITH_WEIGHT", "GROUPING", - "MEDIAN", "BIT_AND", "BIT_OR", "BIT_XOR", @@ -558,7 +556,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "APPROX_MEDIAN" => Ok(AggregateFunction::ApproxMedian), "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => Ok(AggregateFunction::ApproxPercentileContWithWeight), "GROUPING" => Ok(AggregateFunction::Grouping), - "MEDIAN" => Ok(AggregateFunction::Median), "BIT_AND" => Ok(AggregateFunction::BitAnd), "BIT_OR" => Ok(AggregateFunction::BitOr), "BIT_XOR" => Ok(AggregateFunction::BitXor), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index a175987f1994..cb2de710075a 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2848,7 +2848,7 @@ pub enum AggregateFunction { ApproxMedian = 15, ApproxPercentileContWithWeight = 16, Grouping = 17, - Median = 18, + /// MEDIAN = 18; BitAnd = 19, BitOr = 20, BitXor = 21, @@ -2895,7 +2895,6 @@ impl AggregateFunction { "APPROX_PERCENTILE_CONT_WITH_WEIGHT" } AggregateFunction::Grouping => "GROUPING", - AggregateFunction::Median => "MEDIAN", AggregateFunction::BitAnd => "BIT_AND", AggregateFunction::BitOr => "BIT_OR", AggregateFunction::BitXor => "BIT_XOR", @@ -2937,7 +2936,6 @@ impl AggregateFunction { 
Some(Self::ApproxPercentileContWithWeight) } "GROUPING" => Some(Self::Grouping), - "MEDIAN" => Some(Self::Median), "BIT_AND" => Some(Self::BitAnd), "BIT_OR" => Some(Self::BitOr), "BIT_XOR" => Some(Self::BitXor), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b6f72f6773a2..00c62fc32b98 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -450,7 +450,6 @@ impl From for AggregateFunction { } protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::Grouping => Self::Grouping, - protobuf::AggregateFunction::Median => Self::Median, protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, protobuf::AggregateFunction::LastValueAgg => Self::LastValue, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 91f7411e911a..f2ee679ac129 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -386,7 +386,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { } AggregateFunction::ApproxMedian => Self::ApproxMedian, AggregateFunction::Grouping => Self::Grouping, - AggregateFunction::Median => Self::Median, AggregateFunction::FirstValue => Self::FirstValueAgg, AggregateFunction::LastValue => Self::LastValueAgg, AggregateFunction::NthValue => Self::NthValueAgg, @@ -697,7 +696,6 @@ pub fn serialize_expr( protobuf::AggregateFunction::ApproxMedian } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, AggregateFunction::FirstValue => { protobuf::AggregateFunction::FirstValueAgg } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index c6b94a934f23..d3badee3efff 100644 --- 
a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -27,7 +27,7 @@ use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, InListExpr, IsNotNullExpr, - IsNullExpr, LastValue, Literal, Max, Median, Min, NegativeExpr, NotExpr, NthValue, + IsNullExpr, LastValue, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, Variance, VariancePop, WindowShift, @@ -318,8 +318,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::ApproxPercentileContWithWeight } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ApproxMedian - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Median } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::FirstValueAgg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 6e819ef5bf46..d83d6cd1c297 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -33,6 +33,7 @@ use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; use datafusion::functions_aggregate::expr_fn::first_value; +use datafusion::functions_aggregate::median::median; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; @@ -624,6 +625,7 @@ async fn 
roundtrip_expr_api() -> Result<()> { first_value(vec![lit(1)], false, None, None, None), covar_samp(lit(1.5), lit(2.2)), covar_pop(lit(1.5), lit(2.2)), + median(lit(2)), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index c2478e543735..2a220ea0a89d 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -871,6 +871,33 @@ select median(distinct c) from t; statement ok drop table t; +# optimize distinct median to group by +statement ok +create table t(c int) as values (1), (1), (1), (1), (2), (2), (3), (3); + +query TT +explain select median(distinct c) from t; +---- +logical_plan +01)Projection: MEDIAN(alias1) AS MEDIAN(DISTINCT t.c) +02)--Aggregate: groupBy=[[]], aggr=[[MEDIAN(alias1)]] +03)----Aggregate: groupBy=[[t.c AS alias1]], aggr=[[]] +04)------TableScan: t projection=[c] +physical_plan +01)ProjectionExec: expr=[MEDIAN(alias1)@0 as MEDIAN(DISTINCT t.c)] +02)--AggregateExec: mode=Final, gby=[], aggr=[MEDIAN(alias1)] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[MEDIAN(alias1)] +05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------RepartitionExec: partitioning=Hash([alias1@0], 4), input_partitions=4 +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------AggregateExec: mode=Partial, gby=[c@0 as alias1], aggr=[] +10)------------------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +drop table t; + # median_multi # test case for https://github.com/apache/datafusion/issues/3105 # has an intermediate grouping