From c97048d178594b10b813c6bcd1543f157db4ba3f Mon Sep 17 00:00:00 2001
From: mingmwang
Date: Tue, 11 Apr 2023 18:58:16 +0800
Subject: [PATCH] Improve avg/sum Aggregator performance for Decimal (#5866)

* improve avg/sum Aggregator performance

* check type before cast

* fix Arithmetic Overflow bug

* fix clippy
---
 datafusion/core/src/execution/context.rs      |  7 +-
 .../core/src/physical_plan/aggregates/mod.rs  | 44 ++++++++++-
 datafusion/core/tests/sql/udf.rs              |  7 +-
 datafusion/expr/src/aggregate_function.rs     | 13 ++++
 .../expr/src/type_coercion/aggregates.rs      | 15 ++++
 .../optimizer/src/analyzer/type_coercion.rs   | 15 +++-
 .../physical-expr/src/aggregate/average.rs    | 72 +++++++++++++-----
 .../physical-expr/src/aggregate/build_in.rs   | 76 ++++++++++---------
 datafusion/physical-expr/src/aggregate/sum.rs | 27 ++++++-
 .../physical-expr/src/aggregate/utils.rs      | 50 +++++++++++-
 datafusion/proto/src/physical_plan/mod.rs     | 13 ++--
 11 files changed, 267 insertions(+), 72 deletions(-)

diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 8a5ca3d023ad..c3adb4cc74dd 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -2008,7 +2008,12 @@ mod tests {
             DataType::Float64,
             Arc::new(DataType::Float64),
             Volatility::Immutable,
-            Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
+            Arc::new(|_| {
+                Ok(Box::new(AvgAccumulator::try_new(
+                    &DataType::Float64,
+                    &DataType::Float64,
+                )?))
+            }),
             Arc::new(vec![DataType::UInt64, DataType::Float64]),
         );
 
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index c41cc438c898..ade0fa0066ce 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -31,13 +31,14 @@ use arrow::datatypes::{Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
-use datafusion_physical_expr::expressions::Column;
+use datafusion_physical_expr::expressions::{Avg, CastExpr, Column, Sum};
 use datafusion_physical_expr::{
     expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr,
 };
 use std::any::Any;
 use std::collections::HashMap;
 
+use arrow::compute::DEFAULT_CAST_OPTIONS;
 use std::sync::Arc;
 
 mod no_grouping;
@@ -554,9 +555,44 @@ fn aggregate_expressions(
     col_idx_base: usize,
 ) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
     match mode {
-        AggregateMode::Partial => {
-            Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect())
-        }
+        AggregateMode::Partial => Ok(aggr_expr
+            .iter()
+            .map(|agg| {
+                let pre_cast_type = if let Some(Sum {
+                    data_type,
+                    pre_cast_to_sum_type,
+                    ..
+                }) = agg.as_any().downcast_ref::<Sum>()
+                {
+                    if *pre_cast_to_sum_type {
+                        Some(data_type.clone())
+                    } else {
+                        None
+                    }
+                } else if let Some(Avg {
+                    sum_data_type,
+                    pre_cast_to_sum_type,
+                    ..
+                }) = agg.as_any().downcast_ref::<Avg>()
+                {
+                    if *pre_cast_to_sum_type {
+                        Some(sum_data_type.clone())
+                    } else {
+                        None
+                    }
+                } else {
+                    None
+                };
+                agg.expressions()
+                    .into_iter()
+                    .map(|expr| {
+                        pre_cast_type.clone().map_or(expr.clone(), |cast_type| {
+                            Arc::new(CastExpr::new(expr, cast_type, DEFAULT_CAST_OPTIONS))
+                        })
+                    })
+                    .collect::<Vec<_>>()
+            })
+            .collect()),
         // in this mode, we build the merge expressions of the aggregation
         AggregateMode::Final | AggregateMode::FinalPartitioned => {
             let mut col_idx_base = col_idx_base;
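The Partial-mode branch above is the heart of the change: when a Sum or Avg aggregate is flagged with pre_cast_to_sum_type, its input expression is wrapped once in a CastExpr to the sum data type, so the per-batch cast inside sum_batch later becomes a no-op because the types already match. A minimal sketch of that wrapping, using the same CastExpr and DEFAULT_CAST_OPTIONS as the hunk above; the helper name and the sample column are illustrative, not part of the patch:

    // Sketch only: a standalone helper mirroring the Partial-mode rewrite above.
    use std::sync::Arc;

    use arrow::compute::DEFAULT_CAST_OPTIONS;
    use arrow::datatypes::DataType;
    use datafusion_physical_expr::expressions::{CastExpr, Column};
    use datafusion_physical_expr::PhysicalExpr;

    /// Wrap `expr` in a `CastExpr` only when a pre-cast target type is known.
    fn maybe_cast_to(
        expr: Arc<dyn PhysicalExpr>,
        pre_cast_type: Option<DataType>,
    ) -> Arc<dyn PhysicalExpr> {
        match pre_cast_type {
            // Cast once in the plan; the later per-batch cast in sum_batch is skipped.
            Some(cast_type) => Arc::new(CastExpr::new(expr, cast_type, DEFAULT_CAST_OPTIONS)),
            None => expr,
        }
    }

    fn main() {
        // e.g. a Decimal128(10, 2) column being summed into Decimal128(20, 2)
        let col: Arc<dyn PhysicalExpr> = Arc::new(Column::new("amount", 0));
        let casted = maybe_cast_to(col, Some(DataType::Decimal128(20, 2)));
        println!("{casted}");
    }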
diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs
index a1c48595605d..3f4402ec410a 100644
--- a/datafusion/core/tests/sql/udf.rs
+++ b/datafusion/core/tests/sql/udf.rs
@@ -204,7 +204,12 @@ async fn simple_udaf() -> Result<()> {
         DataType::Float64,
         Arc::new(DataType::Float64),
         Volatility::Immutable,
-        Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
+        Arc::new(|_| {
+            Ok(Box::new(AvgAccumulator::try_new(
+                &DataType::Float64,
+                &DataType::Float64,
+            )?))
+        }),
         Arc::new(vec![DataType::UInt64, DataType::Float64]),
     );
 
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index f1d5ea0092db..968fd26ab0f5 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -161,6 +161,19 @@ pub fn return_type(
     }
 }
 
+/// Returns the internal sum datatype of the avg aggregate function.
+pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> {
+    // Note that this function *must* return the same type that the respective physical expression returns
+    // or the execution panics.
+    let fun = AggregateFunction::Avg;
+    let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
+        &fun,
+        input_expr_types,
+        &signature(&fun),
+    )?;
+    avg_sum_type(&coerced_data_types[0])
+}
+
 /// the signatures supported by the function `fun`.
 pub fn signature(fun: &AggregateFunction) -> Signature {
     // note: the physical expression must accept the type returned by this function or the execution panics.
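The UDAF tests above now pass two types to AvgAccumulator::try_new: the type the running sum is kept in and the type evaluate() returns; for a plain Float64 average both are Float64. A hedged sketch of driving the changed accumulator directly, outside a plan; the input values and the printed result are illustrative:

    // Sketch only: exercising the two-type AvgAccumulator added in this patch.
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Float64Array};
    use arrow::datatypes::DataType;
    use datafusion_expr::Accumulator;
    use datafusion_physical_expr::expressions::AvgAccumulator;

    fn main() -> datafusion_common::Result<()> {
        // Sum type and return type are both Float64 for a plain f64 average.
        let mut acc = AvgAccumulator::try_new(&DataType::Float64, &DataType::Float64)?;
        let values: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]));
        acc.update_batch(&[values])?;
        // evaluate() divides the running sum by the number of non-null rows: 6.0 / 3.
        println!("{:?}", acc.evaluate()?);
        Ok(())
    }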
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 3d4b9646dc4c..efcd503cf42f 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -381,6 +381,21 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
     }
 }
 
+/// internal sum type of an average
+pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
+    match arg_type {
+        DataType::Decimal128(precision, scale) => {
+            // In Spark, the sum type of avg is DECIMAL(min(38, precision + 10), s)
+            let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
+            Ok(DataType::Decimal128(new_precision, *scale))
+        }
+        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
+        other => Err(DataFusionError::Plan(format!(
+            "AVG does not support {other:?}"
+        ))),
+    }
+}
+
 pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
     matches!(
         arg_type,
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index 9a210601f606..6e11907cd1ad 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -866,7 +866,12 @@ mod test {
             DataType::Float64,
             Arc::new(DataType::Float64),
             Volatility::Immutable,
-            Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
+            Arc::new(|_| {
+                Ok(Box::new(AvgAccumulator::try_new(
+                    &DataType::Float64,
+                    &DataType::Float64,
+                )?))
+            }),
             Arc::new(vec![DataType::UInt64, DataType::Float64]),
         );
         let udaf = Expr::AggregateUDF {
@@ -886,8 +891,12 @@ mod test {
             Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
         let state_type: StateTypeFunction =
             Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
-        let accumulator: AccumulatorFunctionImplementation =
-            Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
+        let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| {
+            Ok(Box::new(AvgAccumulator::try_new(
+                &DataType::Float64,
+                &DataType::Float64,
+            )?))
+        });
         let my_avg = AggregateUDF::new(
             "MY_AVG",
             &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
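The avg_sum_type added above deliberately differs from the pre-existing avg_return_type: the running sum keeps the input scale and gains ten digits of precision (capped at 38), while the returned average uses the existing return-type rule, which in this version widens precision and scale by four. An illustrative comparison for a DECIMAL(10, 4) input, assuming both helpers are reachable under datafusion_expr::type_coercion::aggregates as they are in this patch:

    // Illustrative only: expected results of the two type rules for a decimal input.
    use arrow::datatypes::DataType;
    use datafusion_expr::type_coercion::aggregates::{avg_return_type, avg_sum_type};

    fn main() -> datafusion_common::Result<()> {
        let input = DataType::Decimal128(10, 4);
        // Internal sum: +10 precision, same scale.
        assert_eq!(avg_sum_type(&input)?, DataType::Decimal128(20, 4));
        // Final result: the pre-existing rule, +4 precision and +4 scale in this version.
        assert_eq!(avg_return_type(&input)?, DataType::Decimal128(14, 8));
        Ok(())
    }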
diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
index de5f78f0a79f..7ad484ac4b92 100644
--- a/datafusion/physical-expr/src/aggregate/average.rs
+++ b/datafusion/physical-expr/src/aggregate/average.rs
@@ -26,6 +26,7 @@ use crate::aggregate::row_accumulator::{
 };
 use crate::aggregate::sum;
 use crate::aggregate::sum::sum_batch;
+use crate::aggregate::utils::calculate_result_decimal_for_avg;
 use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
 use arrow::compute;
@@ -34,6 +35,7 @@ use arrow::{
     array::{ArrayRef, UInt64Array},
     datatypes::Field,
 };
+use arrow_array::Array;
 use datafusion_common::{downcast_value, ScalarValue};
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Accumulator;
@@ -44,7 +46,9 @@ use datafusion_row::accessor::RowAccessor;
 pub struct Avg {
     name: String,
     expr: Arc<dyn PhysicalExpr>,
-    data_type: DataType,
+    pub sum_data_type: DataType,
+    rt_data_type: DataType,
+    pub pre_cast_to_sum_type: bool,
 }
 
 impl Avg {
@@ -52,17 +56,34 @@ impl Avg {
     pub fn new(
         expr: Arc<dyn PhysicalExpr>,
         name: impl Into<String>,
-        data_type: DataType,
+        sum_data_type: DataType,
     ) -> Self {
+        Self::new_with_pre_cast(expr, name, sum_data_type.clone(), sum_data_type, false)
+    }
+
+    pub fn new_with_pre_cast(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        sum_data_type: DataType,
+        rt_data_type: DataType,
+        cast_to_sum_type: bool,
+    ) -> Self {
+        // the internal sum data type of avg only supports Float64 and Decimal128.
+        assert!(matches!(
+            sum_data_type,
+            DataType::Float64 | DataType::Decimal128(_, _)
+        ));
         // the result of avg just support FLOAT64 and Decimal data type.
         assert!(matches!(
-            data_type,
+            rt_data_type,
             DataType::Float64 | DataType::Decimal128(_, _)
         ));
         Self {
             name: name.into(),
             expr,
-            data_type,
+            sum_data_type,
+            rt_data_type,
+            pre_cast_to_sum_type: cast_to_sum_type,
         }
     }
 }
@@ -74,13 +95,14 @@ impl AggregateExpr for Avg {
     }
 
     fn field(&self) -> Result<Field> {
-        Ok(Field::new(&self.name, self.data_type.clone(), true))
+        Ok(Field::new(&self.name, self.rt_data_type.clone(), true))
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
         Ok(Box::new(AvgAccumulator::try_new(
             // avg is f64 or decimal
-            &self.data_type,
+            &self.sum_data_type,
+            &self.rt_data_type,
         )?))
     }
 
@@ -93,7 +115,7 @@ impl AggregateExpr for Avg {
             ),
             Field::new(
                 format_state_name(&self.name, "sum"),
-                self.data_type.clone(),
+                self.sum_data_type.clone(),
                 true,
             ),
         ])
@@ -108,7 +130,7 @@ impl AggregateExpr for Avg {
     }
 
     fn row_accumulator_supported(&self) -> bool {
-        is_row_accumulator_support_dtype(&self.data_type)
+        is_row_accumulator_support_dtype(&self.sum_data_type)
     }
 
     fn supports_bounded_execution(&self) -> bool {
@@ -121,7 +143,7 @@ impl AggregateExpr for Avg {
     ) -> Result<Box<dyn RowAccumulator>> {
         Ok(Box::new(AvgRowAccumulator::new(
             start_index,
-            self.data_type.clone(),
+            self.sum_data_type.clone(),
         )))
     }
 
@@ -130,7 +152,10 @@ impl AggregateExpr for Avg {
     }
 
     fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?))
+        Ok(Box::new(AvgAccumulator::try_new(
+            &self.sum_data_type,
+            &self.rt_data_type,
+        )?))
     }
 }
 
@@ -139,14 +164,18 @@ impl AggregateExpr for Avg {
 pub struct AvgAccumulator {
     // sum is used for null
     sum: ScalarValue,
+    sum_data_type: DataType,
+    return_data_type: DataType,
     count: u64,
 }
 
 impl AvgAccumulator {
     /// Creates a new `AvgAccumulator`
-    pub fn try_new(datatype: &DataType) -> Result<Self> {
+    pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result<Self> {
         Ok(Self {
             sum: ScalarValue::try_from(datatype)?,
+            sum_data_type: datatype.clone(),
+            return_data_type: return_data_type.clone(),
             count: 0,
         })
     }
@@ -163,14 +192,14 @@ impl Accumulator for AvgAccumulator {
         self.count += (values.len() - values.data().null_count()) as u64;
         self.sum = self
             .sum
-            .add(&sum::sum_batch(values, &self.sum.get_datatype())?)?;
+            .add(&sum::sum_batch(values, &self.sum_data_type)?)?;
         Ok(())
     }
 
     fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         let values = &values[0];
         self.count -= (values.len() - values.data().null_count()) as u64;
-        let delta = sum_batch(values, &self.sum.get_datatype())?;
+        let delta = sum_batch(values, &self.sum_data_type)?;
         self.sum = self.sum.sub(&delta)?;
         Ok(())
     }
@@ -183,7 +212,7 @@ impl Accumulator for AvgAccumulator {
         // sums are summed
         self.sum = self
             .sum
-            .add(&sum::sum_batch(&states[1], &self.sum.get_datatype())?)?;
+            .add(&sum::sum_batch(&states[1], &self.sum_data_type)?)?;
         Ok(())
     }
 
@@ -195,12 +224,15 @@ impl Accumulator for AvgAccumulator {
             ScalarValue::Decimal128(value, precision, scale) => {
                 Ok(match value {
                     None => ScalarValue::Decimal128(None, precision, scale),
-                    // TODO add the checker for overflow the precision
-                    Some(v) => ScalarValue::Decimal128(
-                        Some(v / self.count as i128),
-                        precision,
-                        scale,
-                    ),
+                    Some(value) => {
+                        // the sum type and the return type may differ, so convert the sum to the return type
+                        calculate_result_decimal_for_avg(
+                            value,
+                            self.count as i128,
+                            scale,
+                            &self.return_data_type,
+                        )?
+                    }
                 })
             }
             _ => Err(DataFusionError::Internal(
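With the average.rs changes above, an Avg expression carries both types: state_fields() describes the partial state (a UInt64 count plus a sum in the widened sum type), field() uses the return type, and evaluate() converts between the two. A sketch of how that surfaces in the aggregate's schema; the column name and the concrete decimal types are illustrative:

    // Sketch only: how the two types show up in the aggregate's fields.
    use std::sync::Arc;

    use arrow::datatypes::DataType;
    use datafusion_physical_expr::expressions::{Avg, Column};
    use datafusion_physical_expr::AggregateExpr;

    fn main() -> datafusion_common::Result<()> {
        let avg = Avg::new_with_pre_cast(
            Arc::new(Column::new("amount", 0)),
            "AVG(amount)",
            DataType::Decimal128(20, 2), // sum type: +10 precision, same scale
            DataType::Decimal128(14, 6), // return type: +4 precision, +4 scale
            true,                        // pre-cast the input to the sum type
        );
        // The partial state keeps the count and a sum in the widened sum type ...
        for f in avg.state_fields()? {
            println!("state field: {} {:?}", f.name(), f.data_type());
        }
        // ... while the final output column uses the return type.
        let out = avg.field()?;
        println!("output field: {} {:?}", out.name(), out.data_type());
        Ok(())
    }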
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index b1e03fb5d9b1..91415b06866d 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -29,7 +29,7 @@ use crate::{expressions, AggregateExpr, PhysicalExpr};
 use arrow::datatypes::Schema;
 use datafusion_common::{DataFusionError, Result};
-use datafusion_expr::aggregate_function::return_type;
+use datafusion_expr::aggregate_function::{return_type, sum_type_of_avg};
 pub use datafusion_expr::AggregateFunction;
 use std::sync::Arc;
 
@@ -48,17 +48,13 @@ pub fn create_aggregate_expr(
         .iter()
         .map(|e| e.data_type(input_schema))
         .collect::<Result<Vec<_>>>()?;
-    let return_type = return_type(fun, &input_phy_types)?;
+    let rt_type = return_type(fun, &input_phy_types)?;
     let input_phy_exprs = input_phy_exprs.to_vec();
 
     Ok(match (fun, distinct) {
-        (AggregateFunction::Count, false) => {
-            Arc::new(expressions::Count::new_with_multiple_exprs(
-                input_phy_exprs,
-                name,
-                return_type,
-            ))
-        }
+        (AggregateFunction::Count, false) => Arc::new(
+            expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, rt_type),
+        ),
         (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new(
             input_phy_types[0].clone(),
             input_phy_exprs[0].clone(),
@@ -67,17 +63,21 @@ pub fn create_aggregate_expr(
         (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
             input_phy_exprs[0].clone(),
             name,
-            return_type,
-        )),
-        (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
-            input_phy_exprs[0].clone(),
-            name,
-            return_type,
+            rt_type,
         )),
+        (AggregateFunction::Sum, false) => {
+            let cast_to_sum_type = rt_type != input_phy_types[0];
+            Arc::new(expressions::Sum::new_with_pre_cast(
+                input_phy_exprs[0].clone(),
+                name,
+                rt_type,
+                cast_to_sum_type,
+            ))
+        }
         (AggregateFunction::Sum, true) => Arc::new(expressions::DistinctSum::new(
             vec![input_phy_exprs[0].clone()],
             name,
-            return_type,
+            rt_type,
         )),
         (AggregateFunction::ApproxDistinct, _) => {
             Arc::new(expressions::ApproxDistinct::new(
@@ -101,18 +101,24 @@ pub fn create_aggregate_expr(
         (AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
             input_phy_exprs[0].clone(),
             name,
-            return_type,
+            rt_type,
         )),
         (AggregateFunction::Max, _) => Arc::new(expressions::Max::new(
             input_phy_exprs[0].clone(),
             name,
-            return_type,
-        )),
-        (AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new(
-            input_phy_exprs[0].clone(),
-            name,
-            return_type,
+            rt_type,
         )),
+        (AggregateFunction::Avg, false) => {
+            let sum_type = sum_type_of_avg(&input_phy_types)?;
+            let cast_to_sum_type = sum_type != input_phy_types[0];
+            Arc::new(expressions::Avg::new_with_pre_cast(
+                input_phy_exprs[0].clone(),
+                name,
+                sum_type,
+                rt_type,
+                cast_to_sum_type,
+            ))
+        }
         (AggregateFunction::Avg, true) => {
             return Err(DataFusionError::NotImplemented(
                 "AVG(DISTINCT) aggregations are not available".to_string(),
@@ -121,7 +127,7 @@ pub fn create_aggregate_expr(
         (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new(
             input_phy_exprs[0].clone(),
             name,
-            return_type,
+            rt_type,
         )),
         (AggregateFunction::Variance, true) => {
             return Err(DataFusionError::NotImplemented(
             ));
         }
         (AggregateFunction::VariancePop, false) => Arc::new(
-            expressions::VariancePop::new(input_phy_exprs[0].clone(), name, return_type),
+            expressions::VariancePop::new(input_phy_exprs[0].clone(), name, rt_type),
         ),
         (AggregateFunction::VariancePop, true) => {
             return Err(DataFusionError::NotImplemented(
@@ -140,7 +146,7 @@ pub fn create_aggregate_expr(
             input_phy_exprs[0].clone(),
             input_phy_exprs[1].clone(),
             name,
-            return_type,
+            rt_type,
         )),
         (AggregateFunction::Covariance, true) => {
             return Err(DataFusionError::NotImplemented(
@@ -152,7 +158,7 @@ pub fn create_aggregate_expr(
             input_phy_exprs[0].clone(),
             input_phy_exprs[1].clone(),
             name,
-            return_type,
+            rt_type,
         ))
         }
         (AggregateFunction::CovariancePop, true) => {
@@ -163,7 +169,7 @@ pub fn create_aggregate_expr(
         (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new(
             input_phy_exprs[0].clone(),
             name,
-            return_type,
+            rt_type,
         )),
         (AggregateFunction::Stddev, true) => {
             return Err(DataFusionError::NotImplemented(
@@ -173,7 +179,7 @@ pub fn create_aggregate_expr(
         (AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new(
             input_phy_exprs[0].clone(),
             name,
-            return_type,
+            rt_type,
         )),
         (AggregateFunction::StddevPop, true) => {
             return Err(DataFusionError::NotImplemented(
@@ -185,7 +191,7 @@ pub fn create_aggregate_expr(
             input_phy_exprs[0].clone(),
             input_phy_exprs[1].clone(),
             name,
-            return_type,
+            rt_type,
         ))
         }
         (AggregateFunction::Correlation, true) => {
@@ -199,14 +205,14 @@ pub fn create_aggregate_expr(
                     // Pass in the desired percentile expr
                     input_phy_exprs,
                     name,
-                    return_type,
+                    rt_type,
                 )?)
             } else {
                 Arc::new(expressions::ApproxPercentileCont::new_with_max_size(
                     // Pass in the desired percentile expr
                     input_phy_exprs,
                     name,
-                    return_type,
+                    rt_type,
                 )?)
             }
         }
@@ -221,7 +227,7 @@ pub fn create_aggregate_expr(
                 // Pass in the desired percentile expr
                 input_phy_exprs,
                 name,
-                return_type,
+                rt_type,
             )?)
         }
         (AggregateFunction::ApproxPercentileContWithWeight, true) => {
@@ -234,7 +240,7 @@ pub fn create_aggregate_expr(
             Arc::new(expressions::ApproxMedian::try_new(
                 input_phy_exprs[0].clone(),
                 name,
-                return_type,
+                rt_type,
             )?)
         }
         (AggregateFunction::ApproxMedian, true) => {
@@ -245,7 +251,7 @@ pub fn create_aggregate_expr(
         (AggregateFunction::Median, false) => Arc::new(expressions::Median::new(
             input_phy_exprs[0].clone(),
             name,
-            return_type,
+            rt_type,
         )),
         (AggregateFunction::Median, true) => {
             return Err(DataFusionError::NotImplemented(
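In the build_in.rs changes above, create_aggregate_expr decides whether to request the plan-level pre-cast by comparing the coerced input type with the type the sum is accumulated in (for SUM, the return type doubles as the sum type). The decision itself is just an inequality; a tiny illustrative sketch, where the helper name is not part of the patch:

    // Sketch only: the pre-cast decision made above, pulled out as a tiny helper.
    use arrow::datatypes::DataType;

    /// A SUM/AVG input only needs the plan-level cast when its type
    /// differs from the type the running sum is kept in.
    fn needs_pre_cast(input_type: &DataType, sum_type: &DataType) -> bool {
        input_type != sum_type
    }

    fn main() {
        // SUM(Int32) accumulates as Int64, so the input is pre-cast once.
        assert!(needs_pre_cast(&DataType::Int32, &DataType::Int64));
        // AVG over Decimal128(10, 2) sums as Decimal128(20, 2): pre-cast as well.
        assert!(needs_pre_cast(
            &DataType::Decimal128(10, 2),
            &DataType::Decimal128(20, 2)
        ));
        // Float64 already matches its sum type, so nothing extra is inserted.
        assert!(!needs_pre_cast(&DataType::Float64, &DataType::Float64));
    }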
diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs
index a815a33c8c7f..0e302f332f9f 100644
--- a/datafusion/physical-expr/src/aggregate/sum.rs
+++ b/datafusion/physical-expr/src/aggregate/sum.rs
@@ -47,9 +47,10 @@ use datafusion_row::accessor::RowAccessor;
 #[derive(Debug, Clone)]
 pub struct Sum {
     name: String,
-    data_type: DataType,
+    pub data_type: DataType,
     expr: Arc<dyn PhysicalExpr>,
     nullable: bool,
+    pub pre_cast_to_sum_type: bool,
 }
 
 impl Sum {
@@ -64,6 +65,22 @@ impl Sum {
             expr,
             data_type,
             nullable: true,
+            pre_cast_to_sum_type: false,
+        }
+    }
+
+    pub fn new_with_pre_cast(
+        expr: Arc<dyn PhysicalExpr>,
+        name: impl Into<String>,
+        data_type: DataType,
+        pre_cast_to_sum_type: bool,
+    ) -> Self {
+        Self {
+            name: name.into(),
+            expr,
+            data_type,
+            nullable: true,
+            pre_cast_to_sum_type,
         }
     }
 }
@@ -169,7 +186,13 @@ fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: i8) -> Result<Sca
 pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result<ScalarValue> {
-    let values = &cast(values, sum_type)?;
+    // TODO refine the cast kernel in arrow-rs
+    let cast_values = if values.data_type() != sum_type {
+        Some(cast(values, sum_type)?)
+    } else {
+        None
+    };
+    let values = cast_values.as_ref().unwrap_or(values);
     Ok(match values.data_type() {
         DataType::Decimal128(precision, scale) => {
             sum_decimal_batch(values, *precision, *scale)?
diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs
index a63c5e208666..16fd8fd6c849 100644
--- a/datafusion/physical-expr/src/aggregate/utils.rs
+++ b/datafusion/physical-expr/src/aggregate/utils.rs
@@ -18,7 +18,9 @@
 //! Utilities used in aggregates
 
 use arrow::array::ArrayRef;
-use datafusion_common::Result;
+use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
+use arrow_schema::DataType;
+use datafusion_common::{DataFusionError, Result, ScalarValue};
 use datafusion_expr::Accumulator;
 
 /// Convert scalar values from an accumulator into arrays.
@@ -31,3 +33,49 @@ pub fn get_accum_scalar_values_as_arrays(
         .map(|s| s.to_array_of_size(1))
         .collect::<Vec<_>>())
 }
+
+pub fn calculate_result_decimal_for_avg(
+    lit_value: i128,
+    count: i128,
+    scale: i8,
+    target_type: &DataType,
+) -> Result<ScalarValue> {
+    match target_type {
+        DataType::Decimal128(p, s) => {
+            // Decimal128 values with different precisions can store different ranges of values.
+            // For example, with precision 3 the maximum value is `999` and the
+            // minimum value is `-999`.
+            let (target_mul, target_min, target_max) = (
+                10_i128.pow(*s as u32),
+                MIN_DECIMAL_FOR_EACH_PRECISION[*p as usize - 1],
+                MAX_DECIMAL_FOR_EACH_PRECISION[*p as usize - 1],
+            );
+            let lit_scale_mul = 10_i128.pow(scale as u32);
+            if target_mul >= lit_scale_mul {
+                if let Some(value) = lit_value.checked_mul(target_mul / lit_scale_mul) {
+                    let new_value = value / count;
+                    if new_value >= target_min && new_value <= target_max {
+                        Ok(ScalarValue::Decimal128(Some(new_value), *p, *s))
+                    } else {
+                        Err(DataFusionError::Internal(
+                            "Arithmetic Overflow in AvgAccumulator".to_string(),
+                        ))
+                    }
+                } else {
+                    // can't convert the lit decimal to the returned data type
+                    Err(DataFusionError::Internal(
+                        "Arithmetic Overflow in AvgAccumulator".to_string(),
+                    ))
+                }
+            } else {
+                // can't convert the lit decimal to the returned data type
+                Err(DataFusionError::Internal(
+                    "Arithmetic Overflow in AvgAccumulator".to_string(),
+                ))
+            }
+        }
+        other => Err(DataFusionError::Internal(format!(
+            "Error returned data type in AvgAccumulator {other:?}"
+        ))),
+    }
+}
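calculate_result_decimal_for_avg above rescales the decimal sum to the return type's scale with a checked multiplication, divides by the row count, and range-checks the result against the target precision, reporting overflow instead of silently wrapping. A worked example that mirrors that arithmetic with plain integers (no DataFusion APIs are called):

    // Worked example only: the numbers are hand-picked for illustration.
    fn main() {
        // Running sum: 123450 at scale 2, i.e. 1234.50, accumulated over 4 rows.
        let sum: i128 = 123_450;
        let count: i128 = 4;
        let sum_scale: u32 = 2;

        // Return type DECIMAL(14, 6): rescale by 10^(6 - 2), then divide by the count.
        let target_scale: u32 = 6;
        let rescaled = sum
            .checked_mul(10_i128.pow(target_scale) / 10_i128.pow(sum_scale))
            .expect("rescaling overflowed i128"); // the helper reports this as an error
        let avg = rescaled / count; // 308_625_000, i.e. 308.625000 at scale 6

        // The helper finally checks that `avg` fits within 14 digits before
        // wrapping it into ScalarValue::Decimal128(Some(avg), 14, 6).
        assert_eq!(avg, 308_625_000);
    }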
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 5bf82e423c68..d45eeaed6016 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -1378,11 +1378,14 @@ mod roundtrip_tests {
         let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
             vec![(col("a", &schema)?, "unused".to_string())];
 
-        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
-            col("b", &schema)?,
-            "AVG(b)".to_string(),
-            DataType::Float64,
-        ))];
+        let aggregates: Vec<Arc<dyn AggregateExpr>> =
+            vec![Arc::new(Avg::new_with_pre_cast(
+                col("b", &schema)?,
+                "AVG(b)".to_string(),
+                DataType::Float64,
+                DataType::Float64,
+                true,
+            ))];
 
         roundtrip_test(Arc::new(AggregateExec::try_new(
             AggregateMode::Final,