diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index bcf4fec071d4..cafd61d9ea9e 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -265,7 +265,7 @@ mod tests { use crate::error::Result; use crate::logical_plan::Operator; - use crate::physical_plan::aggregates::AggregateExec; + use crate::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; use crate::physical_plan::expressions::Count; @@ -407,7 +407,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], source, Arc::clone(&schema), @@ -415,7 +415,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), @@ -435,7 +435,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], source, Arc::clone(&schema), @@ -443,7 +443,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), @@ -462,7 +462,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], source, Arc::clone(&schema), @@ -473,7 +473,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], Arc::new(coalesce), Arc::clone(&schema), @@ -492,7 +492,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], source, Arc::clone(&schema), @@ -503,7 +503,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], Arc::new(coalesce), Arc::clone(&schema), @@ -533,7 +533,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], filter, Arc::clone(&schema), @@ -541,7 +541,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), @@ -576,7 +576,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], filter, Arc::clone(&schema), @@ -584,7 +584,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, - vec![], + PhysicalGroupBy::default(), vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), diff --git a/datafusion/core/src/physical_optimizer/repartition.rs b/datafusion/core/src/physical_optimizer/repartition.rs index b3b7ba9486a9..e9e14abf6394 100644 --- a/datafusion/core/src/physical_optimizer/repartition.rs +++ b/datafusion/core/src/physical_optimizer/repartition.rs @@ -242,7 +242,9 @@ mod tests { use super::*; use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; - use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; + use crate::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; use crate::physical_plan::expressions::{col, PhysicalSortExpr}; use crate::physical_plan::file_format::{FileScanConfig, ParquetExec}; use crate::physical_plan::filter::FilterExec; @@ -305,12 +307,12 @@ mod tests { Arc::new( AggregateExec::try_new( AggregateMode::Final, - vec![], + PhysicalGroupBy::default(), vec![], Arc::new( AggregateExec::try_new( AggregateMode::Partial, - vec![], + PhysicalGroupBy::default(), vec![], input, schema.clone(), diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs index 45719260ccf5..ddf9af18fd78 100644 --- a/datafusion/core/src/physical_plan/aggregates/hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/hash.rs @@ -29,7 +29,7 @@ use futures::{ use crate::error::Result; use crate::physical_plan::aggregates::{ - evaluate, evaluate_many, AccumulatorItem, AggregateMode, + evaluate_group_by, evaluate_many, AccumulatorItem, AggregateMode, PhysicalGroupBy, }; use crate::physical_plan::hash_utils::create_hashes; use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; @@ -81,7 +81,7 @@ pub(crate) struct GroupedHashAggregateStream { aggregate_expressions: Vec>>, aggr_expr: Vec>, - group_expr: Vec>, + group_by: PhysicalGroupBy, baseline_metrics: BaselineMetrics, random_state: RandomState, @@ -93,7 +93,7 @@ impl GroupedHashAggregateStream { pub fn new( mode: AggregateMode, schema: SchemaRef, - group_expr: Vec>, + group_by: PhysicalGroupBy, aggr_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, @@ -104,7 +104,7 @@ impl GroupedHashAggregateStream { // Assume create_schema() always put group columns in front of aggr columns, we set // col_idx_base to group expression count. let aggregate_expressions = - aggregates::aggregate_expressions(&aggr_expr, &mode, group_expr.len())?; + aggregates::aggregate_expressions(&aggr_expr, &mode, group_by.expr.len())?; timer.done(); @@ -113,7 +113,7 @@ impl GroupedHashAggregateStream { mode, input, aggr_expr, - group_expr, + group_by, baseline_metrics, aggregate_expressions, accumulators: Default::default(), @@ -144,7 +144,7 @@ impl Stream for GroupedHashAggregateStream { let result = group_aggregate_batch( &this.mode, &this.random_state, - &this.group_expr, + &this.group_by, &this.aggr_expr, batch, &mut this.accumulators, @@ -165,7 +165,7 @@ impl Stream for GroupedHashAggregateStream { let result = create_batch_from_map( &this.mode, &this.accumulators, - this.group_expr.len(), + this.group_by.expr.len(), &this.schema, ) .record_output(&this.baseline_metrics); @@ -191,152 +191,154 @@ impl RecordBatchStream for GroupedHashAggregateStream { fn group_aggregate_batch( mode: &AggregateMode, random_state: &RandomState, - group_expr: &[Arc], + group_by: &PhysicalGroupBy, aggr_expr: &[Arc], batch: RecordBatch, accumulators: &mut Accumulators, aggregate_expressions: &[Vec>], ) -> Result<()> { // evaluate the grouping expressions - let group_values = evaluate(group_expr, &batch)?; + let group_by_values = evaluate_group_by(group_by, &batch)?; // evaluate the aggregation expressions. // We could evaluate them after the `take`, but since we need to evaluate all // of them anyways, it is more performant to do it while they are together. let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?; - // 1.1 construct the key from the group values - // 1.2 construct the mapping key if it does not exist - // 1.3 add the row' index to `indices` - - // track which entries in `accumulators` have rows in this batch to aggregate - let mut groups_with_rows = vec![]; - - // 1.1 Calculate the group keys for the group values - let mut batch_hashes = vec![0; batch.num_rows()]; - create_hashes(&group_values, random_state, &mut batch_hashes)?; - - for (row, hash) in batch_hashes.into_iter().enumerate() { - let Accumulators { map, group_states } = accumulators; - - let entry = map.get_mut(hash, |(_hash, group_idx)| { - // verify that a group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - let group_state = &group_states[*group_idx]; - group_values - .iter() - .zip(group_state.group_by_values.iter()) - .all(|(array, scalar)| scalar.eq_array(array, row)) - }); - - match entry { - // Existing entry for this group value - Some((_hash, group_idx)) => { - let group_state = &mut group_states[*group_idx]; - // 1.3 - if group_state.indices.is_empty() { - groups_with_rows.push(*group_idx); - }; - group_state.indices.push(row as u32); // remember this row - } - // 1.2 Need to create new entry - None => { - let accumulator_set = aggregates::create_accumulators(aggr_expr)?; + for grouping_set_values in group_by_values { + // 1.1 construct the key from the group values + // 1.2 construct the mapping key if it does not exist + // 1.3 add the row' index to `indices` + + // track which entries in `accumulators` have rows in this batch to aggregate + let mut groups_with_rows = vec![]; + + // 1.1 Calculate the group keys for the group values + let mut batch_hashes = vec![0; batch.num_rows()]; + create_hashes(&grouping_set_values, random_state, &mut batch_hashes)?; - // Copy group values out of arrays into `ScalarValue`s - let group_by_values = group_values + for (row, hash) in batch_hashes.into_iter().enumerate() { + let Accumulators { map, group_states } = accumulators; + + let entry = map.get_mut(hash, |(_hash, group_idx)| { + // verify that a group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + let group_state = &group_states[*group_idx]; + grouping_set_values .iter() - .map(|col| ScalarValue::try_from_array(col, row)) - .collect::>>()?; - - // Add new entry to group_states and save newly created index - let group_state = GroupState { - group_by_values: group_by_values.into_boxed_slice(), - accumulator_set, - indices: vec![row as u32], // 1.3 - }; - let group_idx = group_states.len(); - group_states.push(group_state); - groups_with_rows.push(group_idx); - - // for hasher function, use precomputed hash value - map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash); - } - }; - } + .zip(group_state.group_by_values.iter()) + .all(|(array, scalar)| scalar.eq_array(array, row)) + }); + + match entry { + // Existing entry for this group value + Some((_hash, group_idx)) => { + let group_state = &mut group_states[*group_idx]; + // 1.3 + if group_state.indices.is_empty() { + groups_with_rows.push(*group_idx); + }; + group_state.indices.push(row as u32); // remember this row + } + // 1.2 Need to create new entry + None => { + let accumulator_set = aggregates::create_accumulators(aggr_expr)?; + + // Copy group values out of arrays into `ScalarValue`s + let group_by_values = grouping_set_values + .iter() + .map(|col| ScalarValue::try_from_array(col, row)) + .collect::>>()?; + + // Add new entry to group_states and save newly created index + let group_state = GroupState { + group_by_values: group_by_values.into_boxed_slice(), + accumulator_set, + indices: vec![row as u32], // 1.3 + }; + let group_idx = group_states.len(); + group_states.push(group_state); + groups_with_rows.push(group_idx); + + // for hasher function, use precomputed hash value + map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash); + } + }; + } - // Collect all indices + offsets based on keys in this vec - let mut batch_indices: UInt32Builder = UInt32Builder::new(0); - let mut offsets = vec![0]; - let mut offset_so_far = 0; - for group_idx in groups_with_rows.iter() { - let indices = &accumulators.group_states[*group_idx].indices; - batch_indices.append_slice(indices)?; - offset_so_far += indices.len(); - offsets.push(offset_so_far); - } - let batch_indices = batch_indices.finish(); + // Collect all indices + offsets based on keys in this vec + let mut batch_indices: UInt32Builder = UInt32Builder::new(0); + let mut offsets = vec![0]; + let mut offset_so_far = 0; + for group_idx in groups_with_rows.iter() { + let indices = &accumulators.group_states[*group_idx].indices; + batch_indices.append_slice(indices)?; + offset_so_far += indices.len(); + offsets.push(offset_so_far); + } + let batch_indices = batch_indices.finish(); - // `Take` all values based on indices into Arrays - let values: Vec>> = aggr_input_values - .iter() - .map(|array| { - array - .iter() - .map(|array| { - compute::take( - array.as_ref(), - &batch_indices, - None, // None: no index check - ) - .unwrap() - }) - .collect() - // 2.3 - }) - .collect(); - - // 2.1 for each key in this batch - // 2.2 for each aggregation - // 2.3 `slice` from each of its arrays the keys' values - // 2.4 update / merge the accumulator with the values - // 2.5 clear indices - groups_with_rows - .iter() - .zip(offsets.windows(2)) - .try_for_each(|(group_idx, offsets)| { - let group_state = &mut accumulators.group_states[*group_idx]; - // 2.2 - group_state - .accumulator_set - .iter_mut() - .zip(values.iter()) - .map(|(accumulator, aggr_array)| { - ( - accumulator, - aggr_array - .iter() - .map(|array| { - // 2.3 - array.slice(offsets[0], offsets[1] - offsets[0]) - }) - .collect::>(), - ) - }) - .try_for_each(|(accumulator, values)| match mode { - AggregateMode::Partial => accumulator.update_batch(&values), - AggregateMode::FinalPartitioned | AggregateMode::Final => { - // note: the aggregation here is over states, not values, thus the merge - accumulator.merge_batch(&values) - } - }) - // 2.5 - .and({ - group_state.indices.clear(); - Ok(()) - }) - })?; + // `Take` all values based on indices into Arrays + let values: Vec>> = aggr_input_values + .iter() + .map(|array| { + array + .iter() + .map(|array| { + compute::take( + array.as_ref(), + &batch_indices, + None, // None: no index check + ) + .unwrap() + }) + .collect() + // 2.3 + }) + .collect(); + + // 2.1 for each key in this batch + // 2.2 for each aggregation + // 2.3 `slice` from each of its arrays the keys' values + // 2.4 update / merge the accumulator with the values + // 2.5 clear indices + groups_with_rows + .iter() + .zip(offsets.windows(2)) + .try_for_each(|(group_idx, offsets)| { + let group_state = &mut accumulators.group_states[*group_idx]; + // 2.2 + group_state + .accumulator_set + .iter_mut() + .zip(values.iter()) + .map(|(accumulator, aggr_array)| { + ( + accumulator, + aggr_array + .iter() + .map(|array| { + // 2.3 + array.slice(offsets[0], offsets[1] - offsets[0]) + }) + .collect::>(), + ) + }) + .try_for_each(|(accumulator, values)| match mode { + AggregateMode::Partial => accumulator.update_batch(&values), + AggregateMode::FinalPartitioned | AggregateMode::Final => { + // note: the aggregation here is over states, not values, thus the merge + accumulator.merge_batch(&values) + } + }) + // 2.5 + .and({ + group_state.indices.clear(); + Ok(()) + }) + })?; + } Ok(()) } diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index abe20cdcbc94..657b6281a559 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -37,6 +37,7 @@ use datafusion_physical_expr::{ expressions, AggregateExpr, PhysicalExpr, PhysicalSortExpr, }; use std::any::Any; + use std::sync::Arc; mod hash; @@ -65,13 +66,93 @@ pub enum AggregateMode { FinalPartitioned, } +/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) +/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] +/// and a single group [false, false]. +/// In the case of `GROUP BY GROUPING SET/CUBE/ROLLUP` the planner will expand the expression +/// into multiple groups, using null expressions to align each group. +/// For example, with a group by clause `GROUP BY GROUPING SET ((a,b),(a),(b))` the planner should +/// create a `PhysicalGroupBy` like +/// PhysicalGroupBy { +/// expr: [(col(a), a), (col(b), b)], +/// null_expr: [(NULL, a), (NULL, b)], +/// groups: [ +/// [false, false], // (a,b) +/// [false, true], // (a) <=> (a, NULL) +/// [true, false] // (b) <=> (NULL, b) +/// ] +/// } +#[derive(Clone, Debug, Default)] +pub struct PhysicalGroupBy { + /// Distinct (Physical Expr, Alias) in the grouping set + expr: Vec<(Arc, String)>, + /// Corresponding NULL expressions for expr + null_expr: Vec<(Arc, String)>, + /// Null mask for each group in this grouping set. Each group is + /// composed of either one of the group expressions in expr or a null + /// expression in null_expr. If groups[i][j] is true, then the the + /// j-th expression in the i-th group is NULL, otherwise it is expr[j]. + groups: Vec>, +} + +impl PhysicalGroupBy { + /// Create a new `PhysicalGroupBy` + pub fn new( + expr: Vec<(Arc, String)>, + null_expr: Vec<(Arc, String)>, + groups: Vec>, + ) -> Self { + Self { + expr, + null_expr, + groups, + } + } + + /// Create a GROUPING SET with only a single group. This is the "standard" + /// case when building a plan from an expression such as `GROUP BY a,b,c` + pub fn new_single(expr: Vec<(Arc, String)>) -> Self { + let num_exprs = expr.len(); + Self { + expr, + null_expr: vec![], + groups: vec![vec![false; num_exprs]], + } + } + + /// Returns true if this GROUP BY contains NULL expressions + pub fn contains_null(&self) -> bool { + self.groups.iter().flatten().any(|is_null| *is_null) + } + + /// Returns the group expressions + pub fn expr(&self) -> &[(Arc, String)] { + &self.expr + } + + /// Returns the null expressions + pub fn null_expr(&self) -> &[(Arc, String)] { + &self.null_expr + } + + /// Returns the group null masks + pub fn groups(&self) -> &[Vec] { + &self.groups + } + + /// Returns true if this `PhysicalGroupBy` has no group expressions + pub fn is_empty(&self) -> bool { + self.expr.is_empty() + } +} + /// Hash aggregate execution plan #[derive(Debug)] pub struct AggregateExec { /// Aggregation mode (full, partial) mode: AggregateMode, - /// Grouping expressions - group_expr: Vec<(Arc, String)>, + /// Group by expressions + group_by: PhysicalGroupBy, /// Aggregate expressions aggr_expr: Vec>, /// Input plan, could be a partial aggregate or the input to the aggregate @@ -90,18 +171,24 @@ impl AggregateExec { /// Create a new hash aggregate execution plan pub fn try_new( mode: AggregateMode, - group_expr: Vec<(Arc, String)>, + group_by: PhysicalGroupBy, aggr_expr: Vec>, input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema(&input.schema(), &group_expr, &aggr_expr, mode)?; + let schema = create_schema( + &input.schema(), + &group_by.expr, + &aggr_expr, + group_by.contains_null(), + mode, + )?; let schema = Arc::new(schema); Ok(AggregateExec { mode, - group_expr, + group_by, aggr_expr, input, schema, @@ -116,15 +203,16 @@ impl AggregateExec { } /// Grouping expressions - pub fn group_expr(&self) -> &[(Arc, String)] { - &self.group_expr + pub fn group_expr(&self) -> &PhysicalGroupBy { + &self.group_by } /// Grouping expressions as they occur in the output schema pub fn output_group_expr(&self) -> Vec> { // Update column indices. Since the group by columns come first in the output schema, their // indices are simply 0..self.group_expr(len). - self.group_expr + self.group_by + .expr() .iter() .enumerate() .map(|(index, (_col, name))| { @@ -149,7 +237,7 @@ impl AggregateExec { } fn row_aggregate_supported(&self) -> bool { - let group_schema = group_schema(&self.schema, self.group_expr.len()); + let group_schema = group_schema(&self.schema, self.group_by.expr.len()); row_supported(&group_schema, RowType::Compact) && accumulator_v2_supported(&self.aggr_expr) } @@ -178,7 +266,7 @@ impl ExecutionPlan for AggregateExec { match &self.mode { AggregateMode::Partial => Distribution::UnspecifiedDistribution, AggregateMode::FinalPartitioned => Distribution::HashPartitioned( - self.group_expr.iter().map(|x| x.0.clone()).collect(), + self.group_by.expr.iter().map(|x| x.0.clone()).collect(), ), AggregateMode::Final => Distribution::SinglePartition, } @@ -198,7 +286,7 @@ impl ExecutionPlan for AggregateExec { ) -> Result> { Ok(Arc::new(AggregateExec::try_new( self.mode, - self.group_expr.clone(), + self.group_by.clone(), self.aggr_expr.clone(), children[0].clone(), self.input_schema.clone(), @@ -211,11 +299,10 @@ impl ExecutionPlan for AggregateExec { context: Arc, ) -> Result { let input = self.input.execute(partition, context)?; - let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect(); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - if self.group_expr.is_empty() { + if self.group_by.expr.is_empty() { Ok(Box::pin(AggregateStream::new( self.mode, self.schema.clone(), @@ -227,7 +314,7 @@ impl ExecutionPlan for AggregateExec { Ok(Box::pin(GroupedHashAggregateStreamV2::new( self.mode, self.schema.clone(), - group_expr, + self.group_by.clone(), self.aggr_expr.clone(), input, baseline_metrics, @@ -236,7 +323,7 @@ impl ExecutionPlan for AggregateExec { Ok(Box::pin(GroupedHashAggregateStream::new( self.mode, self.schema.clone(), - group_expr, + self.group_by.clone(), self.aggr_expr.clone(), input, baseline_metrics, @@ -256,18 +343,53 @@ impl ExecutionPlan for AggregateExec { match t { DisplayFormatType::Default => { write!(f, "AggregateExec: mode={:?}", self.mode)?; - let g: Vec = self - .group_expr - .iter() - .map(|(e, alias)| { - let e = e.to_string(); - if &e != alias { - format!("{} as {}", e, alias) - } else { - e - } - }) - .collect(); + let g: Vec = if self.group_by.groups.len() == 1 { + self.group_by + .expr + .iter() + .map(|(e, alias)| { + let e = e.to_string(); + if &e != alias { + format!("{} as {}", e, alias) + } else { + e + } + }) + .collect() + } else { + self.group_by + .groups + .iter() + .map(|group| { + let terms = group + .iter() + .enumerate() + .map(|(idx, is_null)| { + if *is_null { + let (e, alias) = &self.group_by.null_expr[idx]; + let e = e.to_string(); + if &e != alias { + format!("{} as {}", e, alias) + } else { + e + } + } else { + let (e, alias) = &self.group_by.expr[idx]; + let e = e.to_string(); + if &e != alias { + format!("{} as {}", e, alias) + } else { + e + } + } + }) + .collect::>() + .join(", "); + format!("({})", terms) + }) + .collect() + }; + write!(f, ", gby=[{}]", g.join(", "))?; let a: Vec = self @@ -289,7 +411,7 @@ impl ExecutionPlan for AggregateExec { // - aggregations somtimes also preserve invariants such as min, max... match self.mode { AggregateMode::Final | AggregateMode::FinalPartitioned - if self.group_expr.is_empty() => + if self.group_by.expr.is_empty() => { Statistics { num_rows: Some(1), @@ -306,6 +428,7 @@ fn create_schema( input_schema: &Schema, group_expr: &[(Arc, String)], aggr_expr: &[Arc], + contains_null_expr: bool, mode: AggregateMode, ) -> datafusion_common::Result { let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); @@ -313,7 +436,10 @@ fn create_schema( fields.push(Field::new( name, expr.data_type(input_schema)?, - expr.nullable(input_schema)?, + // In cases where we have multiple grouping sets, we will use NULL expressions in + // order to align the grouping sets. So the field must be nullable even if the underlying + // schema field is not. + contains_null_expr || expr.nullable(input_schema)?, )) } @@ -469,11 +595,54 @@ fn evaluate_many( .collect::>>() } +fn evaluate_group_by( + group_by: &PhysicalGroupBy, + batch: &RecordBatch, +) -> Result>> { + let exprs: Vec = group_by + .expr + .iter() + .map(|(expr, _)| { + let value = expr.evaluate(batch)?; + Ok(value.into_array(batch.num_rows())) + }) + .collect::>>()?; + + let null_exprs: Vec = group_by + .null_expr + .iter() + .map(|(expr, _)| { + let value = expr.evaluate(batch)?; + Ok(value.into_array(batch.num_rows())) + }) + .collect::>>()?; + + Ok(group_by + .groups + .iter() + .map(|group| { + group + .iter() + .enumerate() + .map(|(idx, is_null)| { + if *is_null { + null_exprs[idx].clone() + } else { + exprs[idx].clone() + } + }) + .collect() + }) + .collect()) +} + #[cfg(test)] mod tests { use crate::execution::context::TaskContext; use crate::from_slice::FromSlice; - use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; + use crate::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; use crate::physical_plan::expressions::{col, Avg}; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; @@ -482,7 +651,8 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; - use datafusion_common::{DataFusionError, Result}; + use datafusion_common::{DataFusionError, Result, ScalarValue}; + use datafusion_physical_expr::expressions::{lit, Count}; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; use futures::{FutureExt, Stream}; use std::any::Any; @@ -528,12 +698,129 @@ mod tests { ) } + async fn check_grouping_sets(input: Arc) -> Result<()> { + let input_schema = input.schema(); + + let grouping_set = PhysicalGroupBy { + expr: vec![ + (col("a", &input_schema)?, "a".to_string()), + (col("b", &input_schema)?, "b".to_string()), + ], + null_expr: vec![ + (lit(ScalarValue::UInt32(None)), "a".to_string()), + (lit(ScalarValue::Float64(None)), "b".to_string()), + ], + groups: vec![ + vec![false, true], // (a, NULL) + vec![true, false], // (NULL, b) + vec![false, false], // (a,b) + ], + }; + + let aggregates: Vec> = vec![Arc::new(Count::new( + lit(ScalarValue::Int8(Some(1))), + "COUNT(1)".to_string(), + DataType::Int64, + ))]; + + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + let partial_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + grouping_set.clone(), + aggregates.clone(), + input, + input_schema.clone(), + )?); + + let result = + common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; + + let expected = vec![ + "+---+---+-----------------+", + "| a | b | COUNT(1)[count] |", + "+---+---+-----------------+", + "| | 1 | 2 |", + "| | 2 | 2 |", + "| | 3 | 2 |", + "| | 4 | 2 |", + "| 2 | | 2 |", + "| 2 | 1 | 2 |", + "| 3 | | 3 |", + "| 3 | 2 | 2 |", + "| 3 | 3 | 1 |", + "| 4 | | 3 |", + "| 4 | 3 | 1 |", + "| 4 | 4 | 2 |", + "+---+---+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &result); + + let groups = partial_aggregate.group_expr().expr().to_vec(); + + let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); + + let final_group: Vec<(Arc, String)> = groups + .iter() + .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) + .collect::>()?; + + let final_grouping_set = PhysicalGroupBy::new_single(final_group); + + let merged_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + final_grouping_set, + aggregates, + merge, + input_schema, + )?); + + let result = + common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?; + assert_eq!(result.len(), 1); + + let batch = &result[0]; + assert_eq!(batch.num_columns(), 3); + assert_eq!(batch.num_rows(), 12); + + let expected = vec![ + "+---+---+----------+", + "| a | b | COUNT(1) |", + "+---+---+----------+", + "| | 1 | 2 |", + "| | 2 | 2 |", + "| | 3 | 2 |", + "| | 4 | 2 |", + "| 2 | | 2 |", + "| 2 | 1 | 2 |", + "| 3 | | 3 |", + "| 3 | 2 | 2 |", + "| 3 | 3 | 1 |", + "| 4 | | 3 |", + "| 4 | 3 | 1 |", + "| 4 | 4 | 2 |", + "+---+---+----------+", + ]; + + assert_batches_sorted_eq!(&expected, &result); + + let metrics = merged_aggregate.metrics().unwrap(); + let output_rows = metrics.output_rows().unwrap(); + assert_eq!(12, output_rows); + + Ok(()) + } + /// build the aggregates on the data from some_data() and check the results async fn check_aggregates(input: Arc) -> Result<()> { let input_schema = input.schema(); - let groups: Vec<(Arc, String)> = - vec![(col("a", &input_schema)?, "a".to_string())]; + let grouping_set = PhysicalGroupBy { + expr: vec![(col("a", &input_schema)?, "a".to_string())], + null_expr: vec![], + groups: vec![vec![false]], + }; let aggregates: Vec> = vec![Arc::new(Avg::new( col("b", &input_schema)?, @@ -546,7 +833,7 @@ mod tests { let partial_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Partial, - groups.clone(), + grouping_set.clone(), aggregates.clone(), input, input_schema.clone(), @@ -568,17 +855,17 @@ mod tests { let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec> = (0..groups.len()) - .map(|i| col(&groups[i].1, &input_schema)) + let final_group: Vec<(Arc, String)> = grouping_set + .expr + .iter() + .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) .collect::>()?; + let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let merged_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, - final_group - .iter() - .enumerate() - .map(|(i, expr)| (expr.clone(), groups[i].1.clone())) - .collect(), + final_grouping_set, aggregates, merge, input_schema, @@ -719,6 +1006,14 @@ mod tests { check_aggregates(input).await } + #[tokio::test] + async fn aggregate_grouping_sets_source_not_yielding() -> Result<()> { + let input: Arc = + Arc::new(TestYieldingExec { yield_first: false }); + + check_grouping_sets(input).await + } + #[tokio::test] async fn aggregate_source_with_yielding() -> Result<()> { let input: Arc = @@ -727,6 +1022,14 @@ mod tests { check_aggregates(input).await } + #[tokio::test] + async fn aggregate_grouping_sets_with_yielding() -> Result<()> { + let input: Arc = + Arc::new(TestYieldingExec { yield_first: true }); + + check_grouping_sets(input).await + } + #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { let session_ctx = SessionContext::new(); @@ -734,7 +1037,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); - let groups = vec![]; + let groups = PhysicalGroupBy::default(); let aggregates: Vec> = vec![Arc::new(Avg::new( col("a", &schema)?, @@ -771,8 +1074,8 @@ mod tests { Field::new("b", DataType::Float32, true), ])); - let groups: Vec<(Arc, String)> = - vec![(col("a", &schema)?, "a".to_string())]; + let groups = + PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); let aggregates: Vec> = vec![Arc::new(Avg::new( col("b", &schema)?, @@ -784,7 +1087,7 @@ mod tests { let refs = blocking_exec.refs(); let aggregate_exec = Arc::new(AggregateExec::try_new( AggregateMode::Partial, - groups.clone(), + groups, aggregates.clone(), blocking_exec, schema, diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index e364048e75fd..5353bc745c1b 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -29,7 +29,8 @@ use futures::{ use crate::error::Result; use crate::physical_plan::aggregates::{ - evaluate, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode, + evaluate_group_by, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode, + PhysicalGroupBy, }; use crate::physical_plan::hash_utils::create_row_hashes; use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; @@ -75,7 +76,7 @@ pub(crate) struct GroupedHashAggregateStreamV2 { aggr_state: AggregationState, aggregate_expressions: Vec>>, - group_expr: Vec>, + group_by: PhysicalGroupBy, accumulators: Vec, group_schema: SchemaRef, @@ -100,7 +101,7 @@ impl GroupedHashAggregateStreamV2 { pub fn new( mode: AggregateMode, schema: SchemaRef, - group_expr: Vec>, + group_by: PhysicalGroupBy, aggr_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, @@ -111,11 +112,11 @@ impl GroupedHashAggregateStreamV2 { // Assume create_schema() always put group columns in front of aggr columns, we set // col_idx_base to group expression count. let aggregate_expressions = - aggregates::aggregate_expressions(&aggr_expr, &mode, group_expr.len())?; + aggregates::aggregate_expressions(&aggr_expr, &mode, group_by.expr.len())?; let accumulators = aggregates::create_accumulators_v2(&aggr_expr)?; - let group_schema = group_schema(&schema, group_expr.len()); + let group_schema = group_schema(&schema, group_by.expr.len()); let aggr_schema = aggr_state_schema(&aggr_expr)?; let aggr_layout = Arc::new(RowLayout::new(&aggr_schema, RowType::WordAligned)); @@ -125,7 +126,7 @@ impl GroupedHashAggregateStreamV2 { schema, mode, input, - group_expr, + group_by, accumulators, group_schema, aggr_schema, @@ -160,7 +161,7 @@ impl Stream for GroupedHashAggregateStreamV2 { let result = group_aggregate_batch( &this.mode, &this.random_state, - &this.group_expr, + &this.group_by, &mut this.accumulators, &this.group_schema, this.aggr_layout.clone(), @@ -212,7 +213,7 @@ impl RecordBatchStream for GroupedHashAggregateStreamV2 { fn group_aggregate_batch( mode: &AggregateMode, random_state: &RandomState, - group_expr: &[Arc], + grouping_set: &PhysicalGroupBy, accumulators: &mut [AccumulatorItemV2], group_schema: &Schema, state_layout: Arc, @@ -221,142 +222,145 @@ fn group_aggregate_batch( aggregate_expressions: &[Vec>], ) -> Result<()> { // evaluate the grouping expressions - let group_values = evaluate(group_expr, &batch)?; - let group_rows: Vec> = create_group_rows(group_values, group_schema); - - // evaluate the aggregation expressions. - // We could evaluate them after the `take`, but since we need to evaluate all - // of them anyways, it is more performant to do it while they are together. - let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?; - - // 1.1 construct the key from the group values - // 1.2 construct the mapping key if it does not exist - // 1.3 add the row' index to `indices` - - // track which entries in `aggr_state` have rows in this batch to aggregate - let mut groups_with_rows = vec![]; - - // 1.1 Calculate the group keys for the group values - let mut batch_hashes = vec![0; batch.num_rows()]; - create_row_hashes(&group_rows, random_state, &mut batch_hashes)?; - - for (row, hash) in batch_hashes.into_iter().enumerate() { - let AggregationState { map, group_states } = aggr_state; - - let entry = map.get_mut(hash, |(_hash, group_idx)| { - // verify that a group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - let group_state = &group_states[*group_idx]; - group_rows[row] == group_state.group_by_values - }); - - match entry { - // Existing entry for this group value - Some((_hash, group_idx)) => { - let group_state = &mut group_states[*group_idx]; - // 1.3 - if group_state.indices.is_empty() { - groups_with_rows.push(*group_idx); - }; - group_state.indices.push(row as u32); // remember this row - } - // 1.2 Need to create new entry - None => { - // Add new entry to group_states and save newly created index - let group_state = RowGroupState { - group_by_values: group_rows[row].clone(), - aggregation_buffer: vec![0; state_layout.fixed_part_width()], - indices: vec![row as u32], // 1.3 - }; - let group_idx = group_states.len(); - group_states.push(group_state); - groups_with_rows.push(group_idx); - - // for hasher function, use precomputed hash value - map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash); - } - }; - } - - // Collect all indices + offsets based on keys in this vec - let mut batch_indices: UInt32Builder = UInt32Builder::new(0); - let mut offsets = vec![0]; - let mut offset_so_far = 0; - for group_idx in groups_with_rows.iter() { - let indices = &aggr_state.group_states[*group_idx].indices; - batch_indices.append_slice(indices)?; - offset_so_far += indices.len(); - offsets.push(offset_so_far); - } - let batch_indices = batch_indices.finish(); + let grouping_by_values = evaluate_group_by(grouping_set, &batch)?; + + for group_values in grouping_by_values { + let group_rows: Vec> = create_group_rows(group_values, group_schema); + + // evaluate the aggregation expressions. + // We could evaluate them after the `take`, but since we need to evaluate all + // of them anyways, it is more performant to do it while they are together. + let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?; + + // 1.1 construct the key from the group values + // 1.2 construct the mapping key if it does not exist + // 1.3 add the row' index to `indices` + + // track which entries in `aggr_state` have rows in this batch to aggregate + let mut groups_with_rows = vec![]; + + // 1.1 Calculate the group keys for the group values + let mut batch_hashes = vec![0; batch.num_rows()]; + create_row_hashes(&group_rows, random_state, &mut batch_hashes)?; + + for (row, hash) in batch_hashes.into_iter().enumerate() { + let AggregationState { map, group_states } = aggr_state; + + let entry = map.get_mut(hash, |(_hash, group_idx)| { + // verify that a group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + let group_state = &group_states[*group_idx]; + group_rows[row] == group_state.group_by_values + }); + + match entry { + // Existing entry for this group value + Some((_hash, group_idx)) => { + let group_state = &mut group_states[*group_idx]; + // 1.3 + if group_state.indices.is_empty() { + groups_with_rows.push(*group_idx); + }; + group_state.indices.push(row as u32); // remember this row + } + // 1.2 Need to create new entry + None => { + // Add new entry to group_states and save newly created index + let group_state = RowGroupState { + group_by_values: group_rows[row].clone(), + aggregation_buffer: vec![0; state_layout.fixed_part_width()], + indices: vec![row as u32], // 1.3 + }; + let group_idx = group_states.len(); + group_states.push(group_state); + groups_with_rows.push(group_idx); + + // for hasher function, use precomputed hash value + map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash); + } + }; + } - // `Take` all values based on indices into Arrays - let values: Vec>> = aggr_input_values - .iter() - .map(|array| { - array - .iter() - .map(|array| { - compute::take( - array.as_ref(), - &batch_indices, - None, // None: no index check - ) - .unwrap() - }) - .collect() - // 2.3 - }) - .collect(); - - // 2.1 for each key in this batch - // 2.2 for each aggregation - // 2.3 `slice` from each of its arrays the keys' values - // 2.4 update / merge the accumulator with the values - // 2.5 clear indices - groups_with_rows - .iter() - .zip(offsets.windows(2)) - .try_for_each(|(group_idx, offsets)| { - let group_state = &mut aggr_state.group_states[*group_idx]; - // 2.2 - accumulators - .iter_mut() - .zip(values.iter()) - .map(|(accumulator, aggr_array)| { - ( - accumulator, - aggr_array - .iter() - .map(|array| { - // 2.3 - array.slice(offsets[0], offsets[1] - offsets[0]) - }) - .collect::>(), - ) - }) - .try_for_each(|(accumulator, values)| { - let mut state_accessor = - RowAccessor::new_from_layout(state_layout.clone()); - state_accessor - .point_to(0, group_state.aggregation_buffer.as_mut_slice()); - match mode { - AggregateMode::Partial => { - accumulator.update_batch(&values, &mut state_accessor) - } - AggregateMode::FinalPartitioned | AggregateMode::Final => { - // note: the aggregation here is over states, not values, thus the merge - accumulator.merge_batch(&values, &mut state_accessor) + // Collect all indices + offsets based on keys in this vec + let mut batch_indices: UInt32Builder = UInt32Builder::new(0); + let mut offsets = vec![0]; + let mut offset_so_far = 0; + for group_idx in groups_with_rows.iter() { + let indices = &aggr_state.group_states[*group_idx].indices; + batch_indices.append_slice(indices)?; + offset_so_far += indices.len(); + offsets.push(offset_so_far); + } + let batch_indices = batch_indices.finish(); + + // `Take` all values based on indices into Arrays + let values: Vec>> = aggr_input_values + .iter() + .map(|array| { + array + .iter() + .map(|array| { + compute::take( + array.as_ref(), + &batch_indices, + None, // None: no index check + ) + .unwrap() + }) + .collect() + // 2.3 + }) + .collect(); + + // 2.1 for each key in this batch + // 2.2 for each aggregation + // 2.3 `slice` from each of its arrays the keys' values + // 2.4 update / merge the accumulator with the values + // 2.5 clear indices + groups_with_rows + .iter() + .zip(offsets.windows(2)) + .try_for_each(|(group_idx, offsets)| { + let group_state = &mut aggr_state.group_states[*group_idx]; + // 2.2 + accumulators + .iter_mut() + .zip(values.iter()) + .map(|(accumulator, aggr_array)| { + ( + accumulator, + aggr_array + .iter() + .map(|array| { + // 2.3 + array.slice(offsets[0], offsets[1] - offsets[0]) + }) + .collect::>(), + ) + }) + .try_for_each(|(accumulator, values)| { + let mut state_accessor = + RowAccessor::new_from_layout(state_layout.clone()); + state_accessor + .point_to(0, group_state.aggregation_buffer.as_mut_slice()); + match mode { + AggregateMode::Partial => { + accumulator.update_batch(&values, &mut state_accessor) + } + AggregateMode::FinalPartitioned | AggregateMode::Final => { + // note: the aggregation here is over states, not values, thus the merge + accumulator.merge_batch(&values, &mut state_accessor) + } } - } - }) - // 2.5 - .and({ - group_state.indices.clear(); - Ok(()) - }) - })?; + }) + // 2.5 + .and({ + group_state.indices.clear(); + Ok(()) + }) + })?; + } Ok(()) } diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 60cc3b8de088..14cdee3016ac 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -37,7 +37,7 @@ use crate::logical_plan::{ use crate::logical_plan::{Limit, Values}; use crate::physical_expr::create_physical_expr; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; -use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; +use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::cross_join::CrossJoinExec; use crate::physical_plan::explain::ExplainExec; use crate::physical_plan::expressions::{Column, PhysicalSortExpr}; @@ -58,10 +58,13 @@ use arrow::compute::SortOptions; use arrow::datatypes::DataType; use arrow::datatypes::{Schema, SchemaRef}; use async_trait::async_trait; +use datafusion_common::ScalarValue; use datafusion_expr::{expr::GroupingSet, utils::expr_to_columns}; +use datafusion_physical_expr::expressions::Literal; use datafusion_sql::utils::window_expr_common_partition_keys; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; +use itertools::Itertools; use log::{debug, trace}; use std::collections::{HashMap, HashSet}; use std::fmt::Write; @@ -535,20 +538,12 @@ impl DefaultPhysicalPlanner { let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); - let groups = group_expr - .iter() - .map(|e| { - tuple_err(( - self.create_physical_expr( - e, - logical_input_schema, - &physical_input_schema, - session_state, - ), - physical_name(e), - )) - }) - .collect::>>()?; + let groups = self.create_grouping_physical_expr( + group_expr, + logical_input_schema, + &physical_input_schema, + session_state)?; + let aggregates = aggr_expr .iter() .map(|e| { @@ -574,6 +569,7 @@ impl DefaultPhysicalPlanner { // TODO: dictionary type not yet supported in Hash Repartition let contains_dict = groups + .expr() .iter() .flat_map(|x| x.0.data_type(physical_input_schema.as_ref())) .any(|x| matches!(x, DataType::Dictionary(_, _))); @@ -603,13 +599,17 @@ impl DefaultPhysicalPlanner { (initial_aggr, AggregateMode::Final) }; - Ok(Arc::new(AggregateExec::try_new( - next_partition_mode, + let final_grouping_set = PhysicalGroupBy::new_single( final_group .iter() .enumerate() - .map(|(i, expr)| (expr.clone(), groups[i].1.clone())) - .collect(), + .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone())) + .collect() + ); + + Ok(Arc::new(AggregateExec::try_new( + next_partition_mode, + final_grouping_set, aggregates, initial_aggr, physical_input_schema.clone(), @@ -1001,6 +1001,261 @@ impl DefaultPhysicalPlanner { exec_plan }.boxed() } + + fn create_grouping_physical_expr( + &self, + group_expr: &[Expr], + input_dfschema: &DFSchema, + input_schema: &Schema, + session_state: &SessionState, + ) -> Result { + if group_expr.len() == 1 { + match &group_expr[0] { + Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => { + merge_grouping_set_physical_expr( + grouping_sets, + input_dfschema, + input_schema, + session_state, + ) + } + Expr::GroupingSet(GroupingSet::Cube(exprs)) => create_cube_physical_expr( + exprs, + input_dfschema, + input_schema, + session_state, + ), + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + create_rollup_physical_expr( + exprs, + input_dfschema, + input_schema, + session_state, + ) + } + expr => Ok(PhysicalGroupBy::new_single(vec![tuple_err(( + self.create_physical_expr( + expr, + input_dfschema, + input_schema, + session_state, + ), + physical_name(expr), + ))?])), + } + } else { + Ok(PhysicalGroupBy::new_single( + group_expr + .iter() + .map(|e| { + tuple_err(( + self.create_physical_expr( + e, + input_dfschema, + input_schema, + session_state, + ), + physical_name(e), + )) + }) + .collect::>>()?, + )) + } + } +} + +/// Expand and align a GROUPING SET expression. +/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS) +/// +/// This will take a list of grouping sets and ensure that each group is +/// properly aligned for the physical execution plan. We do this by +/// identifying all unique expression in each group and conforming each +/// group to the same set of expression types and ordering. +/// For example, if we have something like `GROUPING SETS ((a,b,c),(a),(b),(b,c))` +/// we would expand this to `GROUPING SETS ((a,b,c),(a,NULL,NULL),(NULL,b,NULL),(NULL,b,c)) +/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS) +fn merge_grouping_set_physical_expr( + grouping_sets: &[Vec], + input_dfschema: &DFSchema, + input_schema: &Schema, + session_state: &SessionState, +) -> Result { + let num_groups = grouping_sets.len(); + let mut all_exprs: Vec = vec![]; + let mut grouping_set_expr: Vec<(Arc, String)> = vec![]; + let mut null_exprs: Vec<(Arc, String)> = vec![]; + + for expr in grouping_sets.iter().flatten() { + if !all_exprs.contains(expr) { + all_exprs.push(expr.clone()); + + grouping_set_expr.push(get_physical_expr_pair( + expr, + input_dfschema, + input_schema, + session_state, + )?); + + null_exprs.push(get_null_physical_expr_pair( + expr, + input_dfschema, + input_schema, + session_state, + )?); + } + } + + let mut merged_sets: Vec> = Vec::with_capacity(num_groups); + + for expr_group in grouping_sets.iter() { + let group: Vec = all_exprs + .iter() + .map(|expr| !expr_group.contains(expr)) + .collect(); + + merged_sets.push(group) + } + + Ok(PhysicalGroupBy::new( + grouping_set_expr, + null_exprs, + merged_sets, + )) +} + +/// Expand and align a CUBE expression. This is a special case of GROUPING SETS +/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS) +fn create_cube_physical_expr( + exprs: &[Expr], + input_dfschema: &DFSchema, + input_schema: &Schema, + session_state: &SessionState, +) -> Result { + let num_of_exprs = exprs.len(); + let num_groups = num_of_exprs * num_of_exprs; + + let mut null_exprs: Vec<(Arc, String)> = + Vec::with_capacity(num_of_exprs); + let mut all_exprs: Vec<(Arc, String)> = + Vec::with_capacity(num_of_exprs); + + for expr in exprs { + null_exprs.push(get_null_physical_expr_pair( + expr, + input_dfschema, + input_schema, + session_state, + )?); + + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + input_schema, + session_state, + )?) + } + + let mut groups: Vec> = Vec::with_capacity(num_groups); + + groups.push(vec![false; num_of_exprs]); + + for null_count in 1..=num_of_exprs { + for null_idx in (0..num_of_exprs).combinations(null_count) { + let mut next_group: Vec = vec![false; num_of_exprs]; + null_idx.into_iter().for_each(|i| next_group[i] = true); + groups.push(next_group); + } + } + + Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) +} + +/// Expand and align a ROLLUP expression. This is a special case of GROUPING SETS +/// (see https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS) +fn create_rollup_physical_expr( + exprs: &[Expr], + input_dfschema: &DFSchema, + input_schema: &Schema, + session_state: &SessionState, +) -> Result { + let num_of_exprs = exprs.len(); + + let mut null_exprs: Vec<(Arc, String)> = + Vec::with_capacity(num_of_exprs); + let mut all_exprs: Vec<(Arc, String)> = + Vec::with_capacity(num_of_exprs); + + let mut groups: Vec> = Vec::with_capacity(num_of_exprs + 1); + + for expr in exprs { + null_exprs.push(get_null_physical_expr_pair( + expr, + input_dfschema, + input_schema, + session_state, + )?); + + all_exprs.push(get_physical_expr_pair( + expr, + input_dfschema, + input_schema, + session_state, + )?) + } + + for total in 0..=num_of_exprs { + let mut group: Vec = Vec::with_capacity(num_of_exprs); + + for index in 0..num_of_exprs { + if index < total { + group.push(false); + } else { + group.push(true); + } + } + + groups.push(group) + } + + Ok(PhysicalGroupBy::new(all_exprs, null_exprs, groups)) +} + +/// For a given logical expr, get a properly typed NULL ScalarValue physical expression +fn get_null_physical_expr_pair( + expr: &Expr, + input_dfschema: &DFSchema, + input_schema: &Schema, + session_state: &SessionState, +) -> Result<(Arc, String)> { + let physical_expr = create_physical_expr( + expr, + input_dfschema, + input_schema, + &session_state.execution_props, + )?; + let physical_name = physical_name(&expr.clone())?; + + let data_type = physical_expr.data_type(input_schema)?; + let null_value: ScalarValue = (&data_type).try_into()?; + + let null_value = Literal::new(null_value); + Ok((Arc::new(null_value), physical_name)) +} + +fn get_physical_expr_pair( + expr: &Expr, + input_dfschema: &DFSchema, + input_schema: &Schema, + session_state: &SessionState, +) -> Result<(Arc, String)> { + let physical_expr = create_physical_expr( + expr, + input_dfschema, + input_schema, + &session_state.execution_props, + )?; + let physical_name = physical_name(expr)?; + Ok((physical_expr, physical_name)) } /// Create a window expression with a name from a logical expression @@ -1303,6 +1558,7 @@ mod tests { }; use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{DFField, DFSchema, DFSchemaRef}; + use datafusion_expr::expr::GroupingSet; use datafusion_expr::sum; use datafusion_expr::{col, lit}; use fmt::Debug; @@ -1346,6 +1602,60 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_create_cube_expr() -> Result<()> { + let logical_plan = test_csv_scan().await?.build()?; + + let plan = plan(&logical_plan).await?; + + let exprs = vec![col("c1"), col("c2"), col("c3")]; + + let physical_input_schema = plan.schema(); + let physical_input_schema = physical_input_schema.as_ref(); + let logical_input_schema = logical_plan.schema(); + let session_state = make_session_state(); + + let cube = create_cube_physical_expr( + &exprs, + logical_input_schema, + physical_input_schema, + &session_state, + ); + + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[false, false, false], [true, false, false], [false, true, false], [false, false, true], [true, true, false], [true, false, true], [false, true, true], [true, true, true]] })"#; + + assert_eq!(format!("{:?}", cube), expected); + + Ok(()) + } + + #[tokio::test] + async fn test_create_rollup_expr() -> Result<()> { + let logical_plan = test_csv_scan().await?.build()?; + + let plan = plan(&logical_plan).await?; + + let exprs = vec![col("c1"), col("c2"), col("c3")]; + + let physical_input_schema = plan.schema(); + let physical_input_schema = physical_input_schema.as_ref(); + let logical_input_schema = logical_plan.schema(); + let session_state = make_session_state(); + + let rollup = create_rollup_physical_expr( + &exprs, + logical_input_schema, + physical_input_schema, + &session_state, + ); + + let expected = r#"Ok(PhysicalGroupBy { expr: [(Column { name: "c1", index: 0 }, "c1"), (Column { name: "c2", index: 1 }, "c2"), (Column { name: "c3", index: 2 }, "c3")], null_expr: [(Literal { value: Utf8(NULL) }, "c1"), (Literal { value: Int64(NULL) }, "c2"), (Literal { value: Int64(NULL) }, "c3")], groups: [[true, true, true], [false, true, true], [false, false, true], [false, false, false]] })"#; + + assert_eq!(format!("{:?}", rollup), expected); + + Ok(()) + } + #[tokio::test] async fn test_create_not() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); @@ -1620,6 +1930,34 @@ mod tests { Ok(()) } + #[tokio::test] + async fn hash_agg_grouping_set_input_schema() -> Result<()> { + let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("c1")], + vec![col("c2")], + vec![col("c1"), col("c2")], + ])); + let logical_plan = test_csv_scan_with_name("aggregate_test_100") + .await? + .aggregate(vec![grouping_set_expr], vec![sum(col("c3"))])? + .build()?; + + let execution_plan = plan(&logical_plan).await?; + let final_hash_agg = execution_plan + .as_any() + .downcast_ref::() + .expect("hash aggregate"); + assert_eq!( + "SUM(aggregate_test_100.c3)", + final_hash_agg.schema().field(2).name() + ); + // we need access to the input to the partial aggregate so that other projects can + // implement serde + assert_eq!("c3", final_hash_agg.input_schema().field(2).name()); + + Ok(()) + } + #[tokio::test] async fn hash_agg_group_by_partitioned() -> Result<()> { let logical_plan = test_csv_scan() @@ -1637,6 +1975,28 @@ mod tests { Ok(()) } + #[tokio::test] + async fn hash_agg_grouping_set_by_partitioned() -> Result<()> { + let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("c1")], + vec![col("c2")], + vec![col("c1"), col("c2")], + ])); + let logical_plan = test_csv_scan() + .await? + .aggregate(vec![grouping_set_expr], vec![sum(col("c3"))])? + .build()?; + + let execution_plan = plan(&logical_plan).await?; + let formatted = format!("{:?}", execution_plan); + + // Make sure the plan contains a FinalPartitioned, which means it will not use the Final + // mode in Aggregate (which is slower) + assert!(formatted.contains("FinalPartitioned")); + + Ok(()) + } + #[tokio::test] async fn test_explain() { let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 38f54a2a78b8..b25e83cb7eba 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -24,11 +24,14 @@ use datafusion::from_slice::FromSlice; use std::sync::Arc; use datafusion::assert_batches_eq; +use datafusion::dataframe::DataFrame; use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::logical_plan::{col, Expr}; +use datafusion::prelude::CsvReadOptions; use datafusion::{datasource::MemTable, prelude::JoinType}; -use datafusion_expr::lit; +use datafusion_expr::expr::GroupingSet; +use datafusion_expr::{avg, count, lit, sum}; #[tokio::test] async fn join() -> Result<()> { @@ -207,3 +210,217 @@ async fn select_with_alias_overwrite() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_grouping_sets() -> Result<()> { + let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("a")], + vec![col("b")], + vec![col("a"), col("b")], + ])); + + let df = create_test_table()? + .aggregate(vec![grouping_set_expr], vec![count(col("a"))])? + .sort(vec![ + Expr::Sort { + expr: Box::new(col("a")), + asc: false, + nulls_first: true, + }, + Expr::Sort { + expr: Box::new(col("b")), + asc: false, + nulls_first: true, + }, + ])?; + + let results = df.collect().await?; + + let expected = vec![ + "+-----------+-----+---------------+", + "| a | b | COUNT(test.a) |", + "+-----------+-----+---------------+", + "| | 100 | 1 |", + "| | 10 | 2 |", + "| | 1 | 1 |", + "| abcDEF | | 1 |", + "| abcDEF | 1 | 1 |", + "| abc123 | | 1 |", + "| abc123 | 10 | 1 |", + "| CBAdef | | 1 |", + "| CBAdef | 10 | 1 |", + "| 123AbcDef | | 1 |", + "| 123AbcDef | 100 | 1 |", + "+-----------+-----+---------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn test_grouping_sets_count() -> Result<()> { + let ctx = SessionContext::new(); + + let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("c1")], + vec![col("c2")], + ])); + + let df = aggregates_table(&ctx) + .await? + .aggregate(vec![grouping_set_expr], vec![count(lit(1))])? + .sort(vec![ + Expr::Sort { + expr: Box::new(col("c1")), + asc: false, + nulls_first: true, + }, + Expr::Sort { + expr: Box::new(col("c2")), + asc: false, + nulls_first: true, + }, + ])?; + + let results = df.collect().await?; + + let expected = vec![ + "+----+----+-----------------+", + "| c1 | c2 | COUNT(Int32(1)) |", + "+----+----+-----------------+", + "| | 5 | 14 |", + "| | 4 | 23 |", + "| | 3 | 19 |", + "| | 2 | 22 |", + "| | 1 | 22 |", + "| e | | 21 |", + "| d | | 18 |", + "| c | | 21 |", + "| b | | 19 |", + "| a | | 21 |", + "+----+----+-----------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { + let ctx = SessionContext::new(); + + let grouping_set_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("c1")], + vec![col("c2")], + vec![col("c1"), col("c2")], + ])); + + let df = aggregates_table(&ctx) + .await? + .aggregate( + vec![grouping_set_expr], + vec![ + sum(col("c3")).alias("sum_c3"), + avg(col("c3")).alias("avg_c3"), + ], + )? + .sort(vec![ + Expr::Sort { + expr: Box::new(col("c1")), + asc: false, + nulls_first: true, + }, + Expr::Sort { + expr: Box::new(col("c2")), + asc: false, + nulls_first: true, + }, + ])?; + + let results = df.collect().await?; + + let expected = vec![ + "+----+----+--------+---------------------+", + "| c1 | c2 | sum_c3 | avg_c3 |", + "+----+----+--------+---------------------+", + "| | 5 | -194 | -13.857142857142858 |", + "| | 4 | 29 | 1.2608695652173914 |", + "| | 3 | 395 | 20.789473684210527 |", + "| | 2 | 184 | 8.363636363636363 |", + "| | 1 | 367 | 16.681818181818183 |", + "| e | | 847 | 40.333333333333336 |", + "| e | 5 | -22 | -11 |", + "| e | 4 | 261 | 37.285714285714285 |", + "| e | 3 | 192 | 48 |", + "| e | 2 | 189 | 37.8 |", + "| e | 1 | 227 | 75.66666666666667 |", + "| d | | 458 | 25.444444444444443 |", + "| d | 5 | -99 | -49.5 |", + "| d | 4 | 162 | 54 |", + "| d | 3 | 124 | 41.333333333333336 |", + "| d | 2 | 328 | 109.33333333333333 |", + "| d | 1 | -57 | -8.142857142857142 |", + "| c | | -28 | -1.3333333333333333 |", + "| c | 5 | 24 | 12 |", + "| c | 4 | -43 | -10.75 |", + "| c | 3 | 190 | 47.5 |", + "| c | 2 | -389 | -55.57142857142857 |", + "| c | 1 | 190 | 47.5 |", + "| b | | -111 | -5.842105263157895 |", + "| b | 5 | -1 | -0.2 |", + "| b | 4 | -223 | -44.6 |", + "| b | 3 | -84 | -42 |", + "| b | 2 | 102 | 25.5 |", + "| b | 1 | 95 | 31.666666666666668 |", + "| a | | -385 | -18.333333333333332 |", + "| a | 5 | -96 | -32 |", + "| a | 4 | -128 | -32 |", + "| a | 3 | -27 | -4.5 |", + "| a | 2 | -46 | -15.333333333333334 |", + "| a | 1 | -88 | -17.6 |", + "+----+----+--------+---------------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} + +fn create_test_table() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(StringArray::from_slice(&[ + "abcDEF", + "abc123", + "CBAdef", + "123AbcDef", + ])), + Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])), + ], + )?; + + let ctx = SessionContext::new(); + + let table = MemTable::try_new(schema, vec![vec![batch]])?; + + ctx.register_table("test", Arc::new(table))?; + + ctx.table("test") +} + +async fn aggregates_table(ctx: &SessionContext) -> Result> { + let testdata = datafusion::test_util::arrow_test_data(); + + ctx.read_csv( + format!("{}/csv/aggregate_test_100.csv", testdata), + CsvReadOptions::default(), + ) + .await +} diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 08ccbe453042..61b1a1afac64 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -476,6 +476,205 @@ async fn csv_query_approx_percentile_cont() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_cube_avg() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; + + let sql = "SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+----------------------------+", + "| c1 | c2 | AVG(aggregate_test_100.c3) |", + "+----+----+----------------------------+", + "| a | 1 | -17.6 |", + "| a | 2 | -15.333333333333334 |", + "| a | 3 | -4.5 |", + "| a | 4 | -32 |", + "| a | 5 | -32 |", + "| a | | -18.333333333333332 |", + "| b | 1 | 31.666666666666668 |", + "| b | 2 | 25.5 |", + "| b | 3 | -42 |", + "| b | 4 | -44.6 |", + "| b | 5 | -0.2 |", + "| b | | -5.842105263157895 |", + "| c | 1 | 47.5 |", + "| c | 2 | -55.57142857142857 |", + "| c | 3 | 47.5 |", + "| c | 4 | -10.75 |", + "| c | 5 | 12 |", + "| c | | -1.3333333333333333 |", + "| d | 1 | -8.142857142857142 |", + "| d | 2 | 109.33333333333333 |", + "| d | 3 | 41.333333333333336 |", + "| d | 4 | 54 |", + "| d | 5 | -49.5 |", + "| d | | 25.444444444444443 |", + "| e | 1 | 75.66666666666667 |", + "| e | 2 | 37.8 |", + "| e | 3 | 48 |", + "| e | 4 | 37.285714285714285 |", + "| e | 5 | -11 |", + "| e | | 40.333333333333336 |", + "| | 1 | 16.681818181818183 |", + "| | 2 | 8.363636363636363 |", + "| | 3 | 20.789473684210527 |", + "| | 4 | 1.2608695652173914 |", + "| | 5 | -13.857142857142858 |", + "| | | 7.81 |", + "+----+----+----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_rollup_avg() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; + + let sql = "SELECT c1, c2, c3, AVG(c4) FROM aggregate_test_100 GROUP BY ROLLUP (c1, c2, c3) ORDER BY c1, c2, c3"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+------+----------------------------+", + "| c1 | c2 | c3 | AVG(aggregate_test_100.c4) |", + "+----+----+------+----------------------------+", + "| a | 1 | -85 | -15154 |", + "| a | 1 | -56 | 8692 |", + "| a | 1 | -25 | 15295 |", + "| a | 1 | -5 | 12636 |", + "| a | 1 | 83 | -14704 |", + "| a | 1 | | 1353 |", + "| a | 2 | -48 | -18025 |", + "| a | 2 | -43 | 13080 |", + "| a | 2 | 45 | 15673 |", + "| a | 2 | | 3576 |", + "| a | 3 | -72 | -11122 |", + "| a | 3 | -12 | -9168 |", + "| a | 3 | 13 | 22338.5 |", + "| a | 3 | 14 | 28162 |", + "| a | 3 | 17 | -22796 |", + "| a | 3 | | 4958.833333333333 |", + "| a | 4 | -101 | 11640 |", + "| a | 4 | -54 | -2376 |", + "| a | 4 | -38 | 20744 |", + "| a | 4 | 65 | -28462 |", + "| a | 4 | | 386.5 |", + "| a | 5 | -101 | -12484 |", + "| a | 5 | -31 | -12907 |", + "| a | 5 | 36 | -16974 |", + "| a | 5 | | -14121.666666666666 |", + "| a | | | 306.04761904761904 |", + "| b | 1 | 12 | 7652 |", + "| b | 1 | 29 | -18218 |", + "| b | 1 | 54 | -18410 |", + "| b | 1 | | -9658.666666666666 |", + "| b | 2 | -60 | -21739 |", + "| b | 2 | 31 | 23127 |", + "| b | 2 | 63 | 21456 |", + "| b | 2 | 68 | 15874 |", + "| b | 2 | | 9679.5 |", + "| b | 3 | -101 | -13217 |", + "| b | 3 | 17 | 14457 |", + "| b | 3 | | 620 |", + "| b | 4 | -117 | 19316 |", + "| b | 4 | -111 | -1967 |", + "| b | 4 | -59 | 25286 |", + "| b | 4 | 17 | -28070 |", + "| b | 4 | 47 | 20690 |", + "| b | 4 | | 7051 |", + "| b | 5 | -82 | 22080 |", + "| b | 5 | -44 | 15788 |", + "| b | 5 | -5 | 24896 |", + "| b | 5 | 62 | 16337 |", + "| b | 5 | 68 | 21576 |", + "| b | 5 | | 20135.4 |", + "| b | | | 7732.315789473684 |", + "| c | 1 | -24 | -24085 |", + "| c | 1 | 41 | -4667 |", + "| c | 1 | 70 | 27752 |", + "| c | 1 | 103 | -22186 |", + "| c | 1 | | -5796.5 |", + "| c | 2 | -117 | -30187 |", + "| c | 2 | -107 | -2904 |", + "| c | 2 | -106 | -1114 |", + "| c | 2 | -60 | -16312 |", + "| c | 2 | -29 | 25305 |", + "| c | 2 | 1 | 18109 |", + "| c | 2 | 29 | -3855 |", + "| c | 2 | | -1565.4285714285713 |", + "| c | 3 | -2 | -18655 |", + "| c | 3 | 22 | 13741 |", + "| c | 3 | 73 | -9565 |", + "| c | 3 | 97 | 29106 |", + "| c | 3 | | 3656.75 |", + "| c | 4 | -90 | -2935 |", + "| c | 4 | -79 | 5281 |", + "| c | 4 | 3 | -30508 |", + "| c | 4 | 123 | 16620 |", + "| c | 4 | | -2885.5 |", + "| c | 5 | -94 | -15880 |", + "| c | 5 | 118 | 19208 |", + "| c | 5 | | 1664 |", + "| c | | | -1320.5238095238096 |", + "| d | 1 | -99 | 5613 |", + "| d | 1 | -98 | 13630 |", + "| d | 1 | -72 | 25590 |", + "| d | 1 | -8 | 27138 |", + "| d | 1 | 38 | 18384 |", + "| d | 1 | 57 | 28781 |", + "| d | 1 | 125 | 31106 |", + "| d | 1 | | 21463.14285714286 |", + "| d | 2 | 93 | -12642 |", + "| d | 2 | 113 | 3917 |", + "| d | 2 | 122 | 10130 |", + "| d | 2 | | 468.3333333333333 |", + "| d | 3 | -76 | 8809 |", + "| d | 3 | 77 | 15091 |", + "| d | 3 | 123 | 29533 |", + "| d | 3 | | 17811 |", + "| d | 4 | 5 | -7688 |", + "| d | 4 | 55 | -1471 |", + "| d | 4 | 102 | -24558 |", + "| d | 4 | | -11239 |", + "| d | 5 | -59 | 2045 |", + "| d | 5 | -40 | 22614 |", + "| d | 5 | | 12329.5 |", + "| d | | | 10890.111111111111 |", + "| e | 1 | 36 | -21481 |", + "| e | 1 | 71 | -5479 |", + "| e | 1 | 120 | 10837 |", + "| e | 1 | | -5374.333333333333 |", + "| e | 2 | -61 | -2888 |", + "| e | 2 | 49 | 24495 |", + "| e | 2 | 52 | 5666 |", + "| e | 2 | 97 | 18167 |", + "| e | 2 | | 10221.2 |", + "| e | 3 | -95 | 13611 |", + "| e | 3 | 71 | 194 |", + "| e | 3 | 104 | -25136 |", + "| e | 3 | 112 | -6823 |", + "| e | 3 | | -4538.5 |", + "| e | 4 | -56 | -31500 |", + "| e | 4 | -53 | 13788 |", + "| e | 4 | 30 | -16110 |", + "| e | 4 | 73 | -22501 |", + "| e | 4 | 74 | -12612 |", + "| e | 4 | 96 | -30336 |", + "| e | 4 | 97 | -13181 |", + "| e | 4 | | -16064.57142857143 |", + "| e | 5 | -86 | 32514 |", + "| e | 5 | 64 | -26526 |", + "| e | 5 | | 2994 |", + "| e | | | -4268.333333333333 |", + "| | | | 2319.97 |", + "+----+----+------+----------------------------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> { let ctx = SessionContext::new(); @@ -583,6 +782,200 @@ async fn csv_query_sum_crossjoin() { assert_batches_eq!(expected, &actual); } +#[tokio::test] +async fn csv_query_cube_sum_crossjoin() { + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; + let sql = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY CUBE (a.c1, b.c1) ORDER BY a.c1, b.c1"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+-----------+", + "| c1 | c1 | SUM(a.c2) |", + "+----+----+-----------+", + "| a | a | 1260 |", + "| a | b | 1140 |", + "| a | c | 1260 |", + "| a | d | 1080 |", + "| a | e | 1260 |", + "| a | | 6000 |", + "| b | a | 1302 |", + "| b | b | 1178 |", + "| b | c | 1302 |", + "| b | d | 1116 |", + "| b | e | 1302 |", + "| b | | 6200 |", + "| c | a | 1176 |", + "| c | b | 1064 |", + "| c | c | 1176 |", + "| c | d | 1008 |", + "| c | e | 1176 |", + "| c | | 5600 |", + "| d | a | 924 |", + "| d | b | 836 |", + "| d | c | 924 |", + "| d | d | 792 |", + "| d | e | 924 |", + "| d | | 4400 |", + "| e | a | 1323 |", + "| e | b | 1197 |", + "| e | c | 1323 |", + "| e | d | 1134 |", + "| e | e | 1323 |", + "| e | | 6300 |", + "| | a | 5985 |", + "| | b | 5415 |", + "| | c | 5985 |", + "| | d | 5130 |", + "| | e | 5985 |", + "| | | 28500 |", + "+----+----+-----------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_cube_distinct_count() { + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; + let sql = "SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY CUBE (c1,c2) ORDER BY c1,c2"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+---------------------------------------+", + "| c1 | c2 | COUNT(DISTINCT aggregate_test_100.c3) |", + "+----+----+---------------------------------------+", + "| a | 1 | 5 |", + "| a | 2 | 3 |", + "| a | 3 | 5 |", + "| a | 4 | 4 |", + "| a | 5 | 3 |", + "| a | | 19 |", + "| b | 1 | 3 |", + "| b | 2 | 4 |", + "| b | 3 | 2 |", + "| b | 4 | 5 |", + "| b | 5 | 5 |", + "| b | | 17 |", + "| c | 1 | 4 |", + "| c | 2 | 7 |", + "| c | 3 | 4 |", + "| c | 4 | 4 |", + "| c | 5 | 2 |", + "| c | | 21 |", + "| d | 1 | 7 |", + "| d | 2 | 3 |", + "| d | 3 | 3 |", + "| d | 4 | 3 |", + "| d | 5 | 2 |", + "| d | | 18 |", + "| e | 1 | 3 |", + "| e | 2 | 4 |", + "| e | 3 | 4 |", + "| e | 4 | 7 |", + "| e | 5 | 2 |", + "| e | | 18 |", + "| | 1 | 22 |", + "| | 2 | 20 |", + "| | 3 | 17 |", + "| | 4 | 23 |", + "| | 5 | 14 |", + "| | | 80 |", + "+----+----+---------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_rollup_distinct_count() { + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; + let sql = "SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY ROLLUP (c1,c2) ORDER BY c1,c2"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+---------------------------------------+", + "| c1 | c2 | COUNT(DISTINCT aggregate_test_100.c3) |", + "+----+----+---------------------------------------+", + "| a | 1 | 5 |", + "| a | 2 | 3 |", + "| a | 3 | 5 |", + "| a | 4 | 4 |", + "| a | 5 | 3 |", + "| a | | 19 |", + "| b | 1 | 3 |", + "| b | 2 | 4 |", + "| b | 3 | 2 |", + "| b | 4 | 5 |", + "| b | 5 | 5 |", + "| b | | 17 |", + "| c | 1 | 4 |", + "| c | 2 | 7 |", + "| c | 3 | 4 |", + "| c | 4 | 4 |", + "| c | 5 | 2 |", + "| c | | 21 |", + "| d | 1 | 7 |", + "| d | 2 | 3 |", + "| d | 3 | 3 |", + "| d | 4 | 3 |", + "| d | 5 | 2 |", + "| d | | 18 |", + "| e | 1 | 3 |", + "| e | 2 | 4 |", + "| e | 3 | 4 |", + "| e | 4 | 7 |", + "| e | 5 | 2 |", + "| e | | 18 |", + "| | | 80 |", + "+----+----+---------------------------------------+", + ]; + assert_batches_eq!(expected, &actual); +} + +#[tokio::test] +async fn csv_query_rollup_sum_crossjoin() { + let ctx = SessionContext::new(); + register_aggregate_csv_by_sql(&ctx).await; + let sql = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY ROLLUP (a.c1, b.c1) ORDER BY a.c1, b.c1"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+-----------+", + "| c1 | c1 | SUM(a.c2) |", + "+----+----+-----------+", + "| a | a | 1260 |", + "| a | b | 1140 |", + "| a | c | 1260 |", + "| a | d | 1080 |", + "| a | e | 1260 |", + "| a | | 6000 |", + "| b | a | 1302 |", + "| b | b | 1178 |", + "| b | c | 1302 |", + "| b | d | 1116 |", + "| b | e | 1302 |", + "| b | | 6200 |", + "| c | a | 1176 |", + "| c | b | 1064 |", + "| c | c | 1176 |", + "| c | d | 1008 |", + "| c | e | 1176 |", + "| c | | 5600 |", + "| d | a | 924 |", + "| d | b | 836 |", + "| d | c | 924 |", + "| d | d | 792 |", + "| d | e | 924 |", + "| d | | 4400 |", + "| e | a | 1323 |", + "| e | b | 1197 |", + "| e | c | 1323 |", + "| e | d | 1134 |", + "| e | e | 1323 |", + "| e | | 6300 |", + "| | | 28500 |", + "+----+----+-----------+", + ]; + assert_batches_eq!(expected, &actual); +} + #[tokio::test] async fn query_count_without_from() -> Result<()> { let ctx = SessionContext::new(); @@ -675,6 +1068,59 @@ async fn csv_query_array_agg_with_overflow() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_array_cube_agg_with_overflow() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = + "select c1, c2, sum(c3) sum_c3, avg(c3) avg_c3, max(c3) max_c3, min(c3) min_c3, count(c3) count_c3 from aggregate_test_100 group by CUBE (c1,c2) order by c1, c2"; + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+----+----+--------+---------------------+--------+--------+----------+", + "| c1 | c2 | sum_c3 | avg_c3 | max_c3 | min_c3 | count_c3 |", + "+----+----+--------+---------------------+--------+--------+----------+", + "| a | 1 | -88 | -17.6 | 83 | -85 | 5 |", + "| a | 2 | -46 | -15.333333333333334 | 45 | -48 | 3 |", + "| a | 3 | -27 | -4.5 | 17 | -72 | 6 |", + "| a | 4 | -128 | -32 | 65 | -101 | 4 |", + "| a | 5 | -96 | -32 | 36 | -101 | 3 |", + "| a | | -385 | -18.333333333333332 | 83 | -101 | 21 |", + "| b | 1 | 95 | 31.666666666666668 | 54 | 12 | 3 |", + "| b | 2 | 102 | 25.5 | 68 | -60 | 4 |", + "| b | 3 | -84 | -42 | 17 | -101 | 2 |", + "| b | 4 | -223 | -44.6 | 47 | -117 | 5 |", + "| b | 5 | -1 | -0.2 | 68 | -82 | 5 |", + "| b | | -111 | -5.842105263157895 | 68 | -117 | 19 |", + "| c | 1 | 190 | 47.5 | 103 | -24 | 4 |", + "| c | 2 | -389 | -55.57142857142857 | 29 | -117 | 7 |", + "| c | 3 | 190 | 47.5 | 97 | -2 | 4 |", + "| c | 4 | -43 | -10.75 | 123 | -90 | 4 |", + "| c | 5 | 24 | 12 | 118 | -94 | 2 |", + "| c | | -28 | -1.3333333333333333 | 123 | -117 | 21 |", + "| d | 1 | -57 | -8.142857142857142 | 125 | -99 | 7 |", + "| d | 2 | 328 | 109.33333333333333 | 122 | 93 | 3 |", + "| d | 3 | 124 | 41.333333333333336 | 123 | -76 | 3 |", + "| d | 4 | 162 | 54 | 102 | 5 | 3 |", + "| d | 5 | -99 | -49.5 | -40 | -59 | 2 |", + "| d | | 458 | 25.444444444444443 | 125 | -99 | 18 |", + "| e | 1 | 227 | 75.66666666666667 | 120 | 36 | 3 |", + "| e | 2 | 189 | 37.8 | 97 | -61 | 5 |", + "| e | 3 | 192 | 48 | 112 | -95 | 4 |", + "| e | 4 | 261 | 37.285714285714285 | 97 | -56 | 7 |", + "| e | 5 | -22 | -11 | 64 | -86 | 2 |", + "| e | | 847 | 40.333333333333336 | 120 | -95 | 21 |", + "| | 1 | 367 | 16.681818181818183 | 125 | -99 | 22 |", + "| | 2 | 184 | 8.363636363636363 | 122 | -117 | 22 |", + "| | 3 | 395 | 20.789473684210527 | 123 | -101 | 19 |", + "| | 4 | 29 | 1.2608695652173914 | 123 | -117 | 23 |", + "| | 5 | -194 | -13.857142857142858 | 118 | -101 | 14 |", + "| | | 781 | 7.81 | 125 | -117 | 100 |", + "+----+----+--------+---------------------+--------+--------+----------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { let ctx = SessionContext::new(); @@ -1223,6 +1669,79 @@ async fn count_aggregated() -> Result<()> { Ok(()) } +#[tokio::test] +async fn count_aggregated_cube() -> Result<()> { + let results = execute_with_partition( + "SELECT c1, c2, COUNT(c3) FROM test GROUP BY CUBE (c1, c2) ORDER BY c1, c2", + 4, + ) + .await?; + + let expected = vec![ + "+----+----+----------------+", + "| c1 | c2 | COUNT(test.c3) |", + "+----+----+----------------+", + "| | | 40 |", + "| | 1 | 4 |", + "| | 10 | 4 |", + "| | 2 | 4 |", + "| | 3 | 4 |", + "| | 4 | 4 |", + "| | 5 | 4 |", + "| | 6 | 4 |", + "| | 7 | 4 |", + "| | 8 | 4 |", + "| | 9 | 4 |", + "| 0 | | 10 |", + "| 0 | 1 | 1 |", + "| 0 | 10 | 1 |", + "| 0 | 2 | 1 |", + "| 0 | 3 | 1 |", + "| 0 | 4 | 1 |", + "| 0 | 5 | 1 |", + "| 0 | 6 | 1 |", + "| 0 | 7 | 1 |", + "| 0 | 8 | 1 |", + "| 0 | 9 | 1 |", + "| 1 | | 10 |", + "| 1 | 1 | 1 |", + "| 1 | 10 | 1 |", + "| 1 | 2 | 1 |", + "| 1 | 3 | 1 |", + "| 1 | 4 | 1 |", + "| 1 | 5 | 1 |", + "| 1 | 6 | 1 |", + "| 1 | 7 | 1 |", + "| 1 | 8 | 1 |", + "| 1 | 9 | 1 |", + "| 2 | | 10 |", + "| 2 | 1 | 1 |", + "| 2 | 10 | 1 |", + "| 2 | 2 | 1 |", + "| 2 | 3 | 1 |", + "| 2 | 4 | 1 |", + "| 2 | 5 | 1 |", + "| 2 | 6 | 1 |", + "| 2 | 7 | 1 |", + "| 2 | 8 | 1 |", + "| 2 | 9 | 1 |", + "| 3 | | 10 |", + "| 3 | 1 | 1 |", + "| 3 | 10 | 1 |", + "| 3 | 2 | 1 |", + "| 3 | 3 | 1 |", + "| 3 | 4 | 1 |", + "| 3 | 5 | 1 |", + "| 3 | 6 | 1 |", + "| 3 | 7 | 1 |", + "| 3 | 8 | 1 |", + "| 3 | 9 | 1 |", + "+----+----+----------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + Ok(()) +} + #[tokio::test] async fn simple_avg() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7cf161697247..202605b4d7a9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -270,6 +270,27 @@ pub enum GroupingSet { GroupingSets(Vec>), } +impl GroupingSet { + /// Return all distinct exprs in the grouping set. For `CUBE` and `ROLLUP` this + /// is just the underlying list of exprs. For `GROUPING SET` we need to deduplicate + /// the exprs in the underlying sets. + pub fn distinct_expr(&self) -> Vec { + match self { + GroupingSet::Rollup(exprs) => exprs.clone(), + GroupingSet::Cube(exprs) => exprs.clone(), + GroupingSet::GroupingSets(groups) => { + let mut exprs: Vec = vec![]; + for exp in groups.iter().flatten() { + if !exprs.contains(exp) { + exprs.push(exp.clone()); + } + } + exprs + } + } + } +} + /// Fixed seed for the hashing so that Ords are consistent across runs const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 14eeb2c82551..76bd9a9753f4 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -17,6 +17,7 @@ //! Functions for creating logical expressions +use crate::expr::GroupingSet; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit, logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF, @@ -226,6 +227,21 @@ pub fn scalar_subquery(subquery: Arc) -> Expr { Expr::ScalarSubquery(Subquery { subquery }) } +/// Create a grouping set +pub fn grouping_set(exprs: Vec>) -> Expr { + Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) +} + +/// Create a grouping set for all combination of `exprs` +pub fn cube(exprs: Vec) -> Expr { + Expr::GroupingSet(GroupingSet::Cube(exprs)) +} + +/// Create a grouping set for rollup +pub fn rollup(exprs: Vec) -> Expr { + Expr::GroupingSet(GroupingSet::Rollup(exprs)) +} + // TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many // varying arity functions /// Create an convenience function representing a unary scalar function diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 8d58241eaa65..083b66c325ae 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -18,7 +18,9 @@ //! This module provides a builder for creating LogicalPlans use crate::expr_rewriter::{normalize_col, normalize_cols, rewrite_sort_cols_by_aggs}; -use crate::utils::{columnize_expr, exprlist_to_fields, from_plan}; +use crate::utils::{ + columnize_expr, exprlist_to_fields, from_plan, grouping_set_to_exprlist, +}; use crate::{and, binary_expr, Operator}; use crate::{ logical_plan::{ @@ -694,7 +696,10 @@ impl LogicalPlanBuilder { ) -> Result { let group_expr = normalize_cols(group_expr, &self.plan)?; let aggr_expr = normalize_cols(aggr_expr, &self.plan)?; - let all_expr = group_expr.iter().chain(aggr_expr.iter()); + + let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; + + let all_expr = grouping_expr.iter().chain(aggr_expr.iter()); validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?; let aggr_schema = DFSchema::new_with_metadata( exprlist_to_fields(all_expr, &self.plan)?, diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 2120acaed615..a85a817a89a8 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -45,6 +45,22 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result Ok(()) } +/// Find all distinct exprs in a list of group by expressions. If the +/// first element is a `GroupingSet` expression then it must be the only expr. +pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { + if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { + if group_expr.len() > 1 { + return Err(DataFusionError::Plan( + "Invalid group by expressions, GroupingSet must be the only expression" + .to_string(), + )); + } + Ok(grouping_set.distinct_expr()) + } else { + Ok(group_expr.to_vec()) + } +} + /// Recursively walk an expression tree, collecting the unique set of column names /// referenced in the expression struct ColumnNameVisitor<'a> { diff --git a/datafusion/optimizer/src/projection_push_down.rs b/datafusion/optimizer/src/projection_push_down.rs index c9aee1e03d3e..ae2cc4fce87d 100644 --- a/datafusion/optimizer/src/projection_push_down.rs +++ b/datafusion/optimizer/src/projection_push_down.rs @@ -24,6 +24,7 @@ use arrow::error::Result as ArrowResult; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema, }; +use datafusion_expr::utils::grouping_set_to_exprlist; use datafusion_expr::{ logical_plan::{ builder::{build_join_schema, LogicalPlanBuilder}, @@ -314,7 +315,10 @@ fn optimize_plan( // * remove any aggregate expression that is not required // * construct the new set of required columns - exprlist_to_columns(group_expr, &mut new_required_columns)?; + // Find distinct group by exprs in the case where we have a grouping set + let all_group_expr: Vec = grouping_set_to_exprlist(group_expr)?; + + exprlist_to_columns(&all_group_expr, &mut new_required_columns)?; // Gather all columns needed for expressions in this Aggregate let mut new_aggr_expr = Vec::new(); diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index c508b9772c34..80214f302cb5 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -19,6 +19,7 @@ use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::utils::grouping_set_to_exprlist; use datafusion_expr::{ col, logical_plan::{Aggregate, LogicalPlan, Projection}, @@ -62,9 +63,11 @@ fn optimize(plan: &LogicalPlan) -> Result { schema, group_expr, }) => { - if is_single_distinct_agg(plan) { + if is_single_distinct_agg(plan) && !contains_grouping_set(group_expr) { let mut group_fields_set = HashSet::new(); - let mut all_group_args = group_expr.clone(); + let base_group_expr = grouping_set_to_exprlist(group_expr)?; + let mut all_group_args: Vec = group_expr.clone(); + // remove distinct and collection args let new_aggr_expr = aggr_expr .iter() @@ -87,7 +90,9 @@ fn optimize(plan: &LogicalPlan) -> Result { }) .collect::>(); - let all_field = all_group_args + let all_group_expr = grouping_set_to_exprlist(&all_group_args)?; + + let all_field = all_group_expr .iter() .map(|expr| expr.to_field(input.schema()).unwrap()) .collect::>(); @@ -106,7 +111,7 @@ fn optimize(plan: &LogicalPlan) -> Result { let grouped_agg = optimize_children(&grouped_agg); let final_agg_schema = Arc::new( DFSchema::new_with_metadata( - group_expr + base_group_expr .iter() .chain(new_aggr_expr.iter()) .map(|expr| expr.to_field(&grouped_schema).unwrap()) @@ -115,18 +120,12 @@ fn optimize(plan: &LogicalPlan) -> Result { ) .unwrap(), ); - let final_agg = LogicalPlan::Aggregate(Aggregate { - input: Arc::new(grouped_agg.unwrap()), - group_expr: group_expr.clone(), - aggr_expr: new_aggr_expr, - schema: final_agg_schema.clone(), - }); // so the aggregates are displayed in the same way even after the rewrite let mut alias_expr: Vec = Vec::new(); - final_agg - .expressions() + base_group_expr .iter() + .chain(new_aggr_expr.iter()) .enumerate() .for_each(|(i, field)| { alias_expr.push(columnize_expr( @@ -135,11 +134,18 @@ fn optimize(plan: &LogicalPlan) -> Result { )); }); + let final_agg = LogicalPlan::Aggregate(Aggregate { + input: Arc::new(grouped_agg.unwrap()), + group_expr: group_expr.clone(), + aggr_expr: new_aggr_expr, + schema: final_agg_schema, + }); + Ok(LogicalPlan::Projection(Projection { expr: alias_expr, input: Arc::new(final_agg), schema: schema.clone(), - alias: Option::None, + alias: None, })) } else { optimize_children(plan) @@ -185,6 +191,10 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> bool { } } +fn contains_grouping_set(expr: &[Expr]) -> bool { + matches!(expr.first(), Some(Expr::GroupingSet(_))) +} + impl OptimizerRule for SingleDistinctToGroupBy { fn optimize( &self, @@ -202,6 +212,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mod tests { use super::*; use crate::test::*; + use datafusion_expr::expr::GroupingSet; use datafusion_expr::{ col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, AggregateFunction, @@ -212,6 +223,7 @@ mod tests { let optimized_plan = rule .optimize(plan, &OptimizerConfig::new()) .expect("failed to optimize plan"); + let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); assert_eq!(formatted_plan, expected); } @@ -250,6 +262,69 @@ mod tests { Ok(()) } + // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET + #[test] + fn single_distinct_and_grouping_set() -> Result<()> { + let table_scan = test_table_scan()?; + + let grouping_set = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("a")], + vec![col("b")], + ])); + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])? + .build()?; + + // Should not be optimized + let expected = "Aggregate: groupBy=[[GROUPING SETS ((#test.a), (#test.b))]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET + #[test] + fn single_distinct_and_cube() -> Result<()> { + let table_scan = test_table_scan()?; + + let grouping_set = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])? + .build()?; + + println!("{:?}", plan); + + // Should not be optimized + let expected = "Aggregate: groupBy=[[CUBE (#test.a, #test.b)]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + + // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET + #[test] + fn single_distinct_and_rollup() -> Result<()> { + let table_scan = test_table_scan()?; + + let grouping_set = + Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![grouping_set], vec![count_distinct(col("c"))])? + .build()?; + + // Should not be optimized + let expected = "Aggregate: groupBy=[[ROLLUP (#test.a, #test.b)]], aggr=[[COUNT(DISTINCT #test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ + \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } + #[test] fn single_distinct_expr() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 4522cd63ec2a..dffc8ec2f68f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -300,9 +300,33 @@ message LogicalExprNode { ScalarUDFExprNode scalar_udf_expr = 20; GetIndexedField get_indexed_field = 21; + + GroupingSetNode grouping_set = 22; + + CubeNode cube = 23; + + RollupNode rollup = 24; } } +message LogicalExprList { + repeated LogicalExprNode expr = 1; +} + +message GroupingSetNode { + repeated LogicalExprList expr = 1; +} + +message CubeNode { + repeated LogicalExprNode expr = 1; +} + +message RollupNode { + repeated LogicalExprNode expr = 1; +} + + + message GetIndexedField { LogicalExprNode expr = 1; ScalarValue key = 2; diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index c684b785e6fc..279cb8e40203 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -20,12 +20,17 @@ use crate::protobuf::plan_type::PlanTypeEnum::{ FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, OptimizedPhysicalPlan, }; -use crate::protobuf::{OptimizedLogicalPlanType, OptimizedPhysicalPlanType}; +use crate::protobuf::{ + CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, + RollupNode, +}; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode}; use datafusion::logical_plan::FunctionRegistry; use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue, }; +use datafusion_expr::expr::GroupingSet; +use datafusion_expr::expr::GroupingSet::GroupingSets; use datafusion_expr::{ abs, acos, array, ascii, asin, atan, bit_length, btrim, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_part, date_trunc, digest, exp, @@ -1290,6 +1295,32 @@ pub fn parse_expr( .collect::, Error>>()?, }) } + + ExprType::GroupingSet(GroupingSetNode { expr }) => { + Ok(Expr::GroupingSet(GroupingSets( + expr.iter() + .map(|expr_list| { + expr_list + .expr + .iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, Error>>() + }) + .collect::, Error>>()?, + ))) + } + ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( + expr.iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, Error>>()?, + ))), + ExprType::Rollup(RollupNode { expr }) => { + Ok(Expr::GroupingSet(GroupingSet::Rollup( + expr.iter() + .map(|expr| parse_expr(expr, registry)) + .collect::, Error>>()?, + ))) + } } } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 6fe1aac68624..f08a00b49374 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -62,6 +62,7 @@ mod roundtrip_tests { use datafusion::physical_plan::functions::make_scalar_function; use datafusion::prelude::{create_udf, CsvReadOptions, SessionContext}; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; + use datafusion_expr::expr::GroupingSet; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; use datafusion_expr::{ col, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, Expr, @@ -1001,4 +1002,32 @@ mod roundtrip_tests { roundtrip_expr_test!(test_expr, ctx); } + + #[test] + fn roundtrip_grouping_sets() { + let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("a")], + vec![col("b")], + vec![col("a"), col("b")], + ])); + + let ctx = SessionContext::new(); + roundtrip_expr_test!(test_expr, ctx); + } + + #[test] + fn roundtrip_rollup() { + let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); + + let ctx = SessionContext::new(); + roundtrip_expr_test!(test_expr, ctx); + } + + #[test] + fn roundtrip_cube() { + let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); + + let ctx = SessionContext::new(); + roundtrip_expr_test!(test_expr, ctx); + } } diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 8df8ff0dd91b..afe24ea892bb 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -25,12 +25,14 @@ use crate::protobuf::{ FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, - EmptyMessage, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, + CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, + OptimizedPhysicalPlanType, RollupNode, }; use arrow::datatypes::{ DataType, Field, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionMode, }; use datafusion_common::{Column, DFField, DFSchemaRef, ScalarValue}; +use datafusion_expr::expr::GroupingSet; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFrameBound, @@ -718,9 +720,42 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { }, ))), }, - Expr::QualifiedWildcard { .. } - | Expr::TryCast { .. } - | Expr::GroupingSet(_) => unimplemented!(), + + Expr::GroupingSet(GroupingSet::Cube(exprs)) => Self { + expr_type: Some(ExprType::Cube(CubeNode { + expr: exprs.iter().map(|expr| expr.try_into()).collect::, + Self::Error, + >>( + )?, + })), + }, + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => Self { + expr_type: Some(ExprType::Rollup(RollupNode { + expr: exprs.iter().map(|expr| expr.try_into()).collect::, + Self::Error, + >>( + )?, + })), + }, + Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => Self { + expr_type: Some(ExprType::GroupingSet(GroupingSetNode { + expr: exprs + .iter() + .map(|expr_list| { + Ok(LogicalExprList { + expr: expr_list + .iter() + .map(|expr| expr.try_into()) + .collect::, Self::Error>>()?, + }) + }) + .collect::, Self::Error>>()?, + })), + }, + + Expr::QualifiedWildcard { .. } | Expr::TryCast { .. } => unimplemented!(), }; Ok(expr_node)