Skip to content

Commit

Permalink
perf: Make Sum use PrimitiveGroupsAccumulator
Browse files Browse the repository at this point in the history
  • Loading branch information
srh committed Nov 26, 2024
1 parent 07cdc38 commit 64ae03e
Show file tree
Hide file tree
Showing 10 changed files with 1,549 additions and 133 deletions.
5 changes: 3 additions & 2 deletions datafusion/src/cube_ext/joinagg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ impl ExecutionPlan for CrossJoinAggExec {
&AggregateMode::Full,
self.group_expr.len(),
)?;
let mut accumulators = create_accumulation_state(&self.agg_expr)?;
let mut accumulators: hash_aggregate::AccumulationState =
create_accumulation_state(&self.agg_expr)?;
for partition in 0..self.join.right.output_partitioning().partition_count() {
let mut batches = self.join.right.execute(partition).await?;
while let Some(right) = batches.next().await {
Expand Down Expand Up @@ -273,7 +274,7 @@ impl ExecutionPlan for CrossJoinAggExec {
let out_schema = self.schema.clone();
let r = hash_aggregate::create_batch_from_map(
&AggregateMode::Full,
&accumulators,
accumulators,
self.group_expr.len(),
&out_schema,
)?;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ pub fn create_aggregate_expr(
))
}
(AggregateFunction::Sum, false) => {
Arc::new(expressions::Sum::new(arg, name, return_type))
Arc::new(expressions::Sum::new(arg, name, return_type, &arg_types[0]))
}
(AggregateFunction::Sum, true) => {
return Err(DataFusionError::NotImplemented(
Expand Down
180 changes: 168 additions & 12 deletions datafusion/src/physical_plan/expressions/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::sync::Arc;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
use crate::physical_plan::groups_accumulator_prim_op::PrimitiveGroupsAccumulator;
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
use crate::scalar::ScalarValue;
use arrow::compute;
Expand All @@ -49,6 +50,7 @@ use smallvec::SmallVec;
pub struct Sum {
// Name of this aggregate expression, as passed to `Sum::new`.
name: String,
// Data type of the sum's result. Per the note in `Sum::new`, the real caller passes
// `sum_return_type(input_data_type)` here.
data_type: DataType,
// Data type of the values being summed. Together with `data_type` it selects which
// primitive groups accumulator `create_groups_accumulator` builds.
input_data_type: DataType,
// Physical expression handed to `Sum::new`; presumably it produces the values to sum.
expr: Arc<dyn PhysicalExpr>,
// `Sum::new` always sets this to `true`.
nullable: bool,
}
Expand Down Expand Up @@ -80,11 +82,16 @@ impl Sum {
expr: Arc<dyn PhysicalExpr>,
name: impl Into<String>,
data_type: DataType,
input_data_type: &DataType,
) -> Self {
// Note: in the actual caller, data_type == sum_return_type(input_data_type), so the two type
// parameters are redundant. We keep all four parameters anyway, to break symmetry with other
// accumulators and with any code that assumes a 3-parameter constructor, such as the
// generic_test_op macro.
Self {
name: name.into(),
expr,
data_type,
input_data_type: input_data_type.clone(),
nullable: true,
}
}
Expand Down Expand Up @@ -127,12 +134,147 @@ impl AggregateExpr for Sum {
fn create_groups_accumulator(
&self,
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
let data_type = self.data_type.clone();
Ok(Some(Box::new(
GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(move || {
SumAccumulator::try_new(&data_type)
}),
)))
use arrow::datatypes::ArrowPrimitiveType;

// Expands to a boxed `PrimitiveGroupsAccumulator` whose accumulated value has Arrow type `$T`
// and whose second type parameter is `$U` (the type of the values fed to the first closure).
macro_rules! make_accumulator {
    ($T:ty, $U:ty) => {
        Box::new(PrimitiveGroupsAccumulator::<$T, $U, _, _>::new(
            &<$T as ArrowPrimitiveType>::DATA_TYPE,
            // Adds one `$U` value to the running total after casting it to `$T`'s native
            // type. Uses plain `+`, not a wrapping add.
            |total: &mut <$T as ArrowPrimitiveType>::Native,
             input: <$U as ArrowPrimitiveType>::Native| {
                let addend = input as <$T as ArrowPrimitiveType>::Native;
                *total = *total + addend;
            },
            // Adds a value that is already of `$T`'s native type to the running total
            // (presumably used when merging partial sums -- confirm against
            // `PrimitiveGroupsAccumulator`).
            |total: &mut <$T as ArrowPrimitiveType>::Native,
             partial: <$T as ArrowPrimitiveType>::Native| {
                *total = *total + partial;
            },
        ))
    };
}

// Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
// the current datafusion Sum accumulator implementation using native +. (That native +
// specifically is the one in the expressions *x = *x + ... above.)
Ok(Some(match (&self.data_type, &self.input_data_type) {
(DataType::Int64, DataType::Int64) => make_accumulator!(
arrow::datatypes::Int64Type,
arrow::datatypes::Int64Type
),
(DataType::Int64, DataType::Int32) => make_accumulator!(
arrow::datatypes::Int64Type,
arrow::datatypes::Int32Type
),
(DataType::Int64, DataType::Int16) => make_accumulator!(
arrow::datatypes::Int64Type,
arrow::datatypes::Int16Type
),
(DataType::Int64, DataType::Int8) => {
make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int8Type)
}

(DataType::Int96, DataType::Int96) => make_accumulator!(
arrow::datatypes::Int96Type,
arrow::datatypes::Int96Type
),

(DataType::Int64Decimal(0), DataType::Int64Decimal(0)) => make_accumulator!(
arrow::datatypes::Int64Decimal0Type,
arrow::datatypes::Int64Decimal0Type
),
(DataType::Int64Decimal(1), DataType::Int64Decimal(1)) => make_accumulator!(
arrow::datatypes::Int64Decimal1Type,
arrow::datatypes::Int64Decimal1Type
),
(DataType::Int64Decimal(2), DataType::Int64Decimal(2)) => make_accumulator!(
arrow::datatypes::Int64Decimal2Type,
arrow::datatypes::Int64Decimal2Type
),
(DataType::Int64Decimal(3), DataType::Int64Decimal(3)) => make_accumulator!(
arrow::datatypes::Int64Decimal3Type,
arrow::datatypes::Int64Decimal3Type
),
(DataType::Int64Decimal(4), DataType::Int64Decimal(4)) => make_accumulator!(
arrow::datatypes::Int64Decimal4Type,
arrow::datatypes::Int64Decimal4Type
),
(DataType::Int64Decimal(5), DataType::Int64Decimal(5)) => make_accumulator!(
arrow::datatypes::Int64Decimal5Type,
arrow::datatypes::Int64Decimal5Type
),
(DataType::Int64Decimal(10), DataType::Int64Decimal(10)) => {
make_accumulator!(
arrow::datatypes::Int64Decimal10Type,
arrow::datatypes::Int64Decimal10Type
)
}

(DataType::Int96Decimal(0), DataType::Int96Decimal(0)) => make_accumulator!(
arrow::datatypes::Int96Decimal0Type,
arrow::datatypes::Int96Decimal0Type
),
(DataType::Int96Decimal(1), DataType::Int96Decimal(1)) => make_accumulator!(
arrow::datatypes::Int96Decimal1Type,
arrow::datatypes::Int96Decimal1Type
),
(DataType::Int96Decimal(2), DataType::Int96Decimal(2)) => make_accumulator!(
arrow::datatypes::Int96Decimal2Type,
arrow::datatypes::Int96Decimal2Type
),
(DataType::Int96Decimal(3), DataType::Int96Decimal(3)) => make_accumulator!(
arrow::datatypes::Int96Decimal3Type,
arrow::datatypes::Int96Decimal3Type
),
(DataType::Int96Decimal(4), DataType::Int96Decimal(4)) => make_accumulator!(
arrow::datatypes::Int96Decimal4Type,
arrow::datatypes::Int96Decimal4Type
),
(DataType::Int96Decimal(5), DataType::Int96Decimal(5)) => make_accumulator!(
arrow::datatypes::Int96Decimal5Type,
arrow::datatypes::Int96Decimal5Type
),
(DataType::Int96Decimal(10), DataType::Int96Decimal(10)) => {
make_accumulator!(
arrow::datatypes::Int96Decimal10Type,
arrow::datatypes::Int96Decimal10Type
)
}

(DataType::UInt64, DataType::UInt64) => make_accumulator!(
arrow::datatypes::UInt64Type,
arrow::datatypes::UInt64Type
),
(DataType::UInt64, DataType::UInt32) => make_accumulator!(
arrow::datatypes::UInt64Type,
arrow::datatypes::UInt32Type
),
(DataType::UInt64, DataType::UInt16) => make_accumulator!(
arrow::datatypes::UInt64Type,
arrow::datatypes::UInt16Type
),
(DataType::UInt64, DataType::UInt8) => make_accumulator!(
arrow::datatypes::UInt64Type,
arrow::datatypes::UInt8Type
),

(DataType::Float32, DataType::Float32) => make_accumulator!(
arrow::datatypes::Float32Type,
arrow::datatypes::Float32Type
),
(DataType::Float64, DataType::Float64) => make_accumulator!(
arrow::datatypes::Float64Type,
arrow::datatypes::Float64Type
),

_ => {
// This case should never be reached because we've handled all sum_return_type
// arg_type values. Nonetheless:
let data_type = self.data_type.clone();

Box::new(GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(
move || SumAccumulator::try_new(&data_type),
))
}
}))
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -416,13 +558,27 @@ mod tests {
use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;

// Test-only stand-in exposing a three-argument `new`, so that `generic_test_op!` can keep
// constructing a `Sum` even though `Sum::new` now takes an additional input-type argument.
//
// NOTE(review): the same `data_type` is forwarded as both the result type and the input type.
// Presumably that is fine because these tests never build the groups accumulator -- confirm.
struct SumTestStandin;
impl SumTestStandin {
    fn new(
        expr: Arc<dyn PhysicalExpr>,
        name: impl Into<String>,
        data_type: DataType,
    ) -> Sum {
        let input_data_type = data_type.clone();
        Sum::new(expr, name, data_type, &input_data_type)
    }
}

#[test]
fn sum_i32() -> Result<()> {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

generic_test_op!(
a,
DataType::Int32,
Sum,
SumTestStandin,
ScalarValue::from(15i64),
DataType::Int64
)
Expand All @@ -440,7 +596,7 @@ mod tests {
generic_test_op!(
a,
DataType::Int32,
Sum,
SumTestStandin,
ScalarValue::from(13i64),
DataType::Int64
)
Expand All @@ -452,7 +608,7 @@ mod tests {
generic_test_op!(
a,
DataType::Int32,
Sum,
SumTestStandin,
ScalarValue::Int64(None),
DataType::Int64
)
Expand All @@ -465,7 +621,7 @@ mod tests {
generic_test_op!(
a,
DataType::UInt32,
Sum,
SumTestStandin,
ScalarValue::from(15u64),
DataType::UInt64
)
Expand All @@ -478,7 +634,7 @@ mod tests {
generic_test_op!(
a,
DataType::Float32,
Sum,
SumTestStandin,
ScalarValue::from(15_f32),
DataType::Float32
)
Expand All @@ -491,7 +647,7 @@ mod tests {
generic_test_op!(
a,
DataType::Float64,
Sum,
SumTestStandin,
ScalarValue::from(15_f64),
DataType::Float64
)
Expand Down
10 changes: 0 additions & 10 deletions datafusion/src/physical_plan/groups_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
//! Vectorized [`GroupsAccumulator`]
use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;
use arrow::array::{ArrayRef, BooleanArray};
use smallvec::SmallVec;

/// From upstream: This replaces a datafusion_common::{not_impl_err} import.
macro_rules! not_impl_err {
Expand Down Expand Up @@ -194,10 +192,6 @@ pub trait GroupsAccumulator: Send {
/// `n`. See [`EmitTo::First`] for more details.
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef>;

// TODO: Remove this?
/// evaluate for a particular group index.
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue>;

/// Returns the intermediate aggregate state for this accumulator,
/// used for multi-phase grouping, resetting its internal state.
///
Expand All @@ -216,10 +210,6 @@ pub trait GroupsAccumulator: Send {
/// [`Accumulator::state`]: crate::accumulator::Accumulator::state
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;

// TODO: Remove this?
/// Looks at the state for a particular group index.
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>>;

/// Merges intermediate state (the output from [`Self::state`])
/// into this accumulator's current state.
///
Expand Down
9 changes: 0 additions & 9 deletions datafusion/src/physical_plan/groups_accumulator_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ use arrow::{
compute,
datatypes::UInt32Type,
};
use smallvec::SmallVec;

/// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`]
///
Expand Down Expand Up @@ -345,10 +344,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
result
}

// Evaluates the accumulator stored for `group_index` through a shared reference.
// Indexing panics if `group_index` is out of bounds for `self.states`.
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
    let group_state = &self.states[group_index];
    group_state.accumulator.evaluate()
}

// filtered_null_mask(opt_filter, &values);
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let vec_size_pre = self.states.allocated_size();
Expand Down Expand Up @@ -385,10 +380,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
Ok(arrays)
}

// Returns the intermediate state of the accumulator stored for `group_index` through a
// shared reference. Indexing panics if `group_index` is out of bounds for `self.states`.
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
    let group_state = &self.states[group_index];
    group_state.accumulator.state()
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,6 @@ impl<AccumulatorType: Accumulator> GroupsAccumulator
result
}

// Evaluates the accumulator stored for `group_index` through a shared reference.
// Indexing panics if `group_index` is out of bounds for `self.accumulators`.
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
    let accumulator = &self.accumulators[group_index];
    accumulator.evaluate()
}

// filtered_null_mask(opt_filter, &values);
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
let vec_size_pre = self.accumulators.allocated_size();
Expand Down Expand Up @@ -428,10 +424,6 @@ impl<AccumulatorType: Accumulator> GroupsAccumulator
Ok(arrays)
}

// Returns the intermediate state of the accumulator stored for `group_index` through a
// shared reference. Indexing panics if `group_index` is out of bounds for
// `self.accumulators`.
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
    let accumulator = &self.accumulators[group_index];
    accumulator.state()
}

fn merge_batch(
&mut self,
values: &[ArrayRef],
Expand Down
Loading

0 comments on commit 64ae03e

Please sign in to comment.