Skip to content

Commit

Permalink
Improve avg/sum Aggregator performance for Decimal (#5866)
Browse files Browse the repository at this point in the history
* improve avg/sum Aggregator performance

* check type before cast

* fix Arithmetic Overflow bug

* fix clippy
  • Loading branch information
mingmwang authored Apr 11, 2023
1 parent 9377105 commit c97048d
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 72 deletions.
7 changes: 6 additions & 1 deletion datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2008,7 +2008,12 @@ mod tests {
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down
44 changes: 40 additions & 4 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::expressions::{Avg, CastExpr, Column, Sum};
use datafusion_physical_expr::{
expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr,
};
use std::any::Any;
use std::collections::HashMap;

use arrow::compute::DEFAULT_CAST_OPTIONS;
use std::sync::Arc;

mod no_grouping;
Expand Down Expand Up @@ -554,9 +555,44 @@ fn aggregate_expressions(
col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
match mode {
AggregateMode::Partial => {
Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect())
}
AggregateMode::Partial => Ok(aggr_expr
.iter()
.map(|agg| {
let pre_cast_type = if let Some(Sum {
data_type,
pre_cast_to_sum_type,
..
}) = agg.as_any().downcast_ref::<Sum>()
{
if *pre_cast_to_sum_type {
Some(data_type.clone())
} else {
None
}
} else if let Some(Avg {
sum_data_type,
pre_cast_to_sum_type,
..
}) = agg.as_any().downcast_ref::<Avg>()
{
if *pre_cast_to_sum_type {
Some(sum_data_type.clone())
} else {
None
}
} else {
None
};
agg.expressions()
.into_iter()
.map(|expr| {
pre_cast_type.clone().map_or(expr.clone(), |cast_type| {
Arc::new(CastExpr::new(expr, cast_type, DEFAULT_CAST_OPTIONS))
})
})
.collect::<Vec<_>>()
})
.collect()),
// in this mode, we build the merge expressions of the aggregation
AggregateMode::Final | AggregateMode::FinalPartitioned => {
let mut col_idx_base = col_idx_base;
Expand Down
7 changes: 6 additions & 1 deletion datafusion/core/tests/sql/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,12 @@ async fn simple_udaf() -> Result<()> {
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);

Expand Down
13 changes: 13 additions & 0 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ pub fn return_type(
}
}

/// Computes the data type that the AVG aggregate uses internally for its running sum.
///
/// The input expression types are first coerced against the AVG signature; the sum
/// type is then derived from the first coerced type via `avg_sum_type`.
pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> {
    // The type produced here *must* match the type that the corresponding physical
    // expression returns, otherwise execution panics.
    let avg = AggregateFunction::Avg;
    let avg_signature = signature(&avg);
    let coerced = crate::type_coercion::aggregates::coerce_types(
        &avg,
        input_expr_types,
        &avg_signature,
    )?;
    avg_sum_type(&coerced[0])
}

/// the signatures supported by the function `fun`.
pub fn signature(fun: &AggregateFunction) -> Signature {
// note: the physical expression must accept the type returned by this function or the execution panics.
Expand Down
15 changes: 15 additions & 0 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,21 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
}
}

/// Returns the data type AVG uses internally to accumulate its running sum.
///
/// * `Decimal128(p, s)` inputs are summed as
///   `Decimal128(min(DECIMAL128_MAX_PRECISION, p + 10), s)`: the precision is widened
///   by 10 digits (capped at the maximum) and the scale is kept, following Spark.
/// * Any other type contained in `NUMERICS` is summed as `Float64`.
///
/// # Errors
///
/// Returns a `DataFusionError::Plan` error when `arg_type` is neither a
/// `Decimal128` nor one of the `NUMERICS` types.
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
    match arg_type {
        DataType::Decimal128(precision, scale) => {
            // In Spark, the sum type of AVG is DECIMAL(min(38, precision + 10), scale).
            // `saturating_add` keeps an out-of-range precision from overflowing the
            // integer addition (panic in debug builds, silent wrap-around in release);
            // the result is still capped by `DECIMAL128_MAX_PRECISION`.
            let new_precision =
                DECIMAL128_MAX_PRECISION.min(precision.saturating_add(10));
            Ok(DataType::Decimal128(new_precision, *scale))
        }
        arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
        other => Err(DataFusionError::Plan(format!(
            "AVG does not support {other:?}"
        ))),
    }
}

pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
Expand Down
15 changes: 12 additions & 3 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,12 @@ mod test {
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?))),
Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
}),
Arc::new(vec![DataType::UInt64, DataType::Float64]),
);
let udaf = Expr::AggregateUDF {
Expand All @@ -886,8 +891,12 @@ mod test {
Arc::new(move |_| Ok(Arc::new(DataType::Float64)));
let state_type: StateTypeFunction =
Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64])));
let accumulator: AccumulatorFunctionImplementation =
Arc::new(|_| Ok(Box::new(AvgAccumulator::try_new(&DataType::Float64)?)));
let accumulator: AccumulatorFunctionImplementation = Arc::new(|_| {
Ok(Box::new(AvgAccumulator::try_new(
&DataType::Float64,
&DataType::Float64,
)?))
});
let my_avg = AggregateUDF::new(
"MY_AVG",
&Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable),
Expand Down
72 changes: 52 additions & 20 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use crate::aggregate::row_accumulator::{
};
use crate::aggregate::sum;
use crate::aggregate::sum::sum_batch;
use crate::aggregate::utils::calculate_result_decimal_for_avg;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use arrow::compute;
Expand All @@ -34,6 +35,7 @@ use arrow::{
array::{ArrayRef, UInt64Array},
datatypes::Field,
};
use arrow_array::Array;
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
Expand All @@ -44,25 +46,44 @@ use datafusion_row::accessor::RowAccessor;
pub struct Avg {
name: String,
expr: Arc<dyn PhysicalExpr>,
data_type: DataType,
pub sum_data_type: DataType,
rt_data_type: DataType,
pub pre_cast_to_sum_type: bool,
}

impl Avg {
/// Create a new AVG aggregate function
pub fn new(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
sum_data_type: DataType,
) -> Self {
Self::new_with_pre_cast(expr, name, sum_data_type.clone(), sum_data_type, false)
}

pub fn new_with_pre_cast(
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
sum_data_type: DataType,
rt_data_type: DataType,
cast_to_sum_type: bool,
) -> Self {
// The internal sum data type of AVG only supports the Float64 and Decimal128 data types.
assert!(matches!(
sum_data_type,
DataType::Float64 | DataType::Decimal128(_, _)
));
// The result of AVG only supports the Float64 and Decimal128 data types.
assert!(matches!(
data_type,
rt_data_type,
DataType::Float64 | DataType::Decimal128(_, _)
));
Self {
name: name.into(),
expr,
data_type,
sum_data_type,
rt_data_type,
pre_cast_to_sum_type: cast_to_sum_type,
}
}
}
Expand All @@ -74,13 +95,14 @@ impl AggregateExpr for Avg {
}

fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, self.data_type.clone(), true))
Ok(Field::new(&self.name, self.rt_data_type.clone(), true))
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(
// avg is f64 or decimal
&self.data_type,
&self.sum_data_type,
&self.rt_data_type,
)?))
}

Expand All @@ -93,7 +115,7 @@ impl AggregateExpr for Avg {
),
Field::new(
format_state_name(&self.name, "sum"),
self.data_type.clone(),
self.sum_data_type.clone(),
true,
),
])
Expand All @@ -108,7 +130,7 @@ impl AggregateExpr for Avg {
}

fn row_accumulator_supported(&self) -> bool {
is_row_accumulator_support_dtype(&self.data_type)
is_row_accumulator_support_dtype(&self.sum_data_type)
}

fn supports_bounded_execution(&self) -> bool {
Expand All @@ -121,7 +143,7 @@ impl AggregateExpr for Avg {
) -> Result<Box<dyn RowAccumulator>> {
Ok(Box::new(AvgRowAccumulator::new(
start_index,
self.data_type.clone(),
self.sum_data_type.clone(),
)))
}

Expand All @@ -130,7 +152,10 @@ impl AggregateExpr for Avg {
}

fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(&self.data_type)?))
Ok(Box::new(AvgAccumulator::try_new(
&self.sum_data_type,
&self.rt_data_type,
)?))
}
}

Expand All @@ -139,14 +164,18 @@ impl AggregateExpr for Avg {
pub struct AvgAccumulator {
// sum is used for null
sum: ScalarValue,
sum_data_type: DataType,
return_data_type: DataType,
count: u64,
}

impl AvgAccumulator {
/// Creates a new `AvgAccumulator`
pub fn try_new(datatype: &DataType) -> Result<Self> {
pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result<Self> {
Ok(Self {
sum: ScalarValue::try_from(datatype)?,
sum_data_type: datatype.clone(),
return_data_type: return_data_type.clone(),
count: 0,
})
}
Expand All @@ -163,14 +192,14 @@ impl Accumulator for AvgAccumulator {
self.count += (values.len() - values.data().null_count()) as u64;
self.sum = self
.sum
.add(&sum::sum_batch(values, &self.sum.get_datatype())?)?;
.add(&sum::sum_batch(values, &self.sum_data_type)?)?;
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &values[0];
self.count -= (values.len() - values.data().null_count()) as u64;
let delta = sum_batch(values, &self.sum.get_datatype())?;
let delta = sum_batch(values, &self.sum_data_type)?;
self.sum = self.sum.sub(&delta)?;
Ok(())
}
Expand All @@ -183,7 +212,7 @@ impl Accumulator for AvgAccumulator {
// sums are summed
self.sum = self
.sum
.add(&sum::sum_batch(&states[1], &self.sum.get_datatype())?)?;
.add(&sum::sum_batch(&states[1], &self.sum_data_type)?)?;
Ok(())
}

Expand All @@ -195,12 +224,15 @@ impl Accumulator for AvgAccumulator {
ScalarValue::Decimal128(value, precision, scale) => {
Ok(match value {
None => ScalarValue::Decimal128(None, precision, scale),
// TODO add the checker for overflow the precision
Some(v) => ScalarValue::Decimal128(
Some(v / self.count as i128),
precision,
scale,
),
Some(value) => {
// The sum type and the return type may now differ, so the sum must be converted to the return type.
calculate_result_decimal_for_avg(
value,
self.count as i128,
scale,
&self.return_data_type,
)?
}
})
}
_ => Err(DataFusionError::Internal(
Expand Down
Loading

0 comments on commit c97048d

Please sign in to comment.