diff --git a/datafusion/src/physical_optimizer/aggregate_statistics.rs b/datafusion/src/physical_optimizer/aggregate_statistics.rs
index 1b361dd54936..2732777de7da 100644
--- a/datafusion/src/physical_optimizer/aggregate_statistics.rs
+++ b/datafusion/src/physical_optimizer/aggregate_statistics.rs
@@ -57,7 +57,13 @@ impl PhysicalOptimizerRule for AggregateStatistics {
             let stats = partial_agg_exec.input().statistics();
             let mut projections = vec![];
             for expr in partial_agg_exec.aggr_expr() {
-                if let Some((num_rows, name)) = take_optimizable_count(&**expr, &stats) {
+                if let Some((non_null_rows, name)) =
+                    take_optimizable_column_count(&**expr, &stats)
+                {
+                    projections.push((expressions::lit(non_null_rows), name.to_owned()));
+                } else if let Some((num_rows, name)) =
+                    take_optimizable_table_count(&**expr, &stats)
+                {
                     projections.push((expressions::lit(num_rows), name.to_owned()));
                 } else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) {
                     projections.push((expressions::lit(min), name.to_owned()));
@@ -127,7 +133,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>>
 }
 
 /// If this agg_expr is a count that is defined in the statistics, return it
-fn take_optimizable_count(
+fn take_optimizable_table_count(
     agg_expr: &dyn AggregateExpr,
     stats: &Statistics,
 ) -> Option<(ScalarValue, &'static str)> {
@@ -144,7 +150,40 @@ fn take_optimizable_count(
             if lit_expr.value() == &ScalarValue::UInt8(Some(1)) {
                 return Some((
                     ScalarValue::UInt64(Some(num_rows as u64)),
-                    "COUNT(Uint8(1))",
+                    "COUNT(UInt8(1))",
+                ));
+            }
+        }
+    }
+    None
+}
+
+/// If this agg_expr is a count that can be derived from the statistics, return it
+fn take_optimizable_column_count(
+    agg_expr: &dyn AggregateExpr,
+    stats: &Statistics,
+) -> Option<(ScalarValue, String)> {
+    if let (Some(num_rows), Some(col_stats), Some(casted_expr)) = (
+        stats.num_rows,
+        &stats.column_statistics,
+        agg_expr.as_any().downcast_ref::<expressions::Count>(),
+    ) {
+        if casted_expr.expressions().len() == 1 {
+            // TODO optimize with exprs other than Column
+            if let Some(col_expr) = casted_expr.expressions()[0]
+                .as_any()
+                .downcast_ref::<expressions::Column>()
+            {
+                if let ColumnStatistics {
+                    null_count: Some(val),
+                    ..
+                } = &col_stats[col_expr.index()]
+                {
+                    let expr = format!("COUNT({})", col_expr.name());
+                    return Some((
+                        ScalarValue::UInt64(Some((num_rows - val) as u64)),
+                        expr,
                     ));
                 }
             }
@@ -237,8 +276,8 @@ mod tests {
         let batch = RecordBatch::try_new(
             Arc::clone(&schema),
             vec![
-                Arc::new(Int32Array::from(vec![1, 2, 3])),
-                Arc::new(Int32Array::from(vec![4, 5, 6])),
+                Arc::new(Int32Array::from(vec![Some(1), Some(2), None])),
+                Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])),
             ],
         )?;
 
@@ -250,20 +289,22 @@
     }
 
     /// Checks that the count optimization was applied and we still get the right result
-    async fn assert_count_optim_success(plan: HashAggregateExec) -> Result<()> {
+    async fn assert_count_optim_success(
+        plan: HashAggregateExec,
+        nulls: bool,
+    ) -> Result<()> {
         let conf = ExecutionConfig::new();
         let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?;
 
+        let (col, count) = match nulls {
+            false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 3),
+            true => (Field::new("COUNT(a)", DataType::UInt64, false), 2),
+        };
+
+        // A ProjectionExec is a sign that the count optimization was applied
        assert!(optimized.as_any().is::<ProjectionExec>());
         let result = common::collect(optimized.execute(0).await?).await?;
-        assert_eq!(
-            result[0].schema(),
-            Arc::new(Schema::new(vec![Field::new(
-                "COUNT(Uint8(1))",
-                DataType::UInt64,
-                false
-            )]))
-        );
+        assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col])));
         assert_eq!(
             result[0]
                 .column(0)
@@ -271,17 +312,18 @@
                 .downcast_ref::<UInt64Array>()
                 .unwrap()
                 .values(),
-            &[3]
+            &[count]
         );
         Ok(())
     }
 
-    fn count_expr() -> Arc<dyn AggregateExpr> {
-        Arc::new(Count::new(
-            expressions::lit(ScalarValue::UInt8(Some(1))),
-            "my_count_alias",
-            DataType::UInt64,
-        ))
+    fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn AggregateExpr> {
+        // Return appropriate expr depending if COUNT is for col or table
+        let expr = match schema {
+            None => expressions::lit(ScalarValue::UInt8(Some(1))),
+            Some(s) => expressions::col(col.unwrap(), s).unwrap(),
+        };
+        Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64))
     }
 
     #[tokio::test]
@@ -293,7 +335,7 @@
         let partial_agg = HashAggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr()],
+            vec![count_expr(None, None)],
             source,
             Arc::clone(&schema),
         )?;
@@ -301,12 +343,39 @@
         let final_agg = HashAggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr()],
+            vec![count_expr(None, None)],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;
 
-        assert_count_optim_success(final_agg).await?;
+        assert_count_optim_success(final_agg, false).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_partial_with_nulls_direct_child() -> Result<()> {
+        // basic test case with the aggregation applied on a source with exact statistics
+        let source = mock_data()?;
+        let schema = source.schema();
+
+        let partial_agg = HashAggregateExec::try_new(
+            AggregateMode::Partial,
+            vec![],
+            vec![count_expr(Some(&schema), Some("a"))],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = HashAggregateExec::try_new(
+            AggregateMode::Final,
+            vec![],
+            vec![count_expr(Some(&schema), Some("a"))],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        assert_count_optim_success(final_agg, true).await?;
 
         Ok(())
     }
@@ -319,7 +388,36 @@
         let partial_agg = HashAggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr()],
+            vec![count_expr(None, None)],
+            source,
+            Arc::clone(&schema),
+        )?;
+
+        // We introduce an intermediate optimization step between the partial and final aggregtator
+        let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg));
+
+        let final_agg = HashAggregateExec::try_new(
+            AggregateMode::Final,
+            vec![],
+            vec![count_expr(None, None)],
+            Arc::new(coalesce),
+            Arc::clone(&schema),
+        )?;
+
+        assert_count_optim_success(final_agg, false).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_partial_with_nulls_indirect_child() -> Result<()> {
+        let source = mock_data()?;
+        let schema = source.schema();
+
+        let partial_agg = HashAggregateExec::try_new(
+            AggregateMode::Partial,
+            vec![],
+            vec![count_expr(Some(&schema), Some("a"))],
             source,
             Arc::clone(&schema),
         )?;
@@ -330,12 +428,12 @@
         let final_agg = HashAggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr()],
+            vec![count_expr(Some(&schema), Some("a"))],
             Arc::new(coalesce),
             Arc::clone(&schema),
         )?;
 
-        assert_count_optim_success(final_agg).await?;
+        assert_count_optim_success(final_agg, true).await?;
 
         Ok(())
     }
@@ -359,7 +457,49 @@
         let partial_agg = HashAggregateExec::try_new(
             AggregateMode::Partial,
             vec![],
-            vec![count_expr()],
+            vec![count_expr(None, None)],
+            filter,
+            Arc::clone(&schema),
+        )?;
+
+        let final_agg = HashAggregateExec::try_new(
+            AggregateMode::Final,
+            vec![],
+            vec![count_expr(None, None)],
+            Arc::new(partial_agg),
+            Arc::clone(&schema),
+        )?;
+
+        let conf = ExecutionConfig::new();
+        let optimized =
+            AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?;
+
+        // check that the original ExecutionPlan was not replaced
+        assert!(optimized.as_any().is::<HashAggregateExec>());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_count_with_nulls_inexact_stat() -> Result<()> {
+        let source = mock_data()?;
+        let schema = source.schema();
+
+        // adding a filter makes the statistics inexact
+        let filter = Arc::new(FilterExec::try_new(
+            expressions::binary(
+                expressions::col("a", &schema)?,
+                Operator::Gt,
+                expressions::lit(ScalarValue::from(1u32)),
+                &schema,
+            )?,
+            source,
+        )?);
+
+        let partial_agg = HashAggregateExec::try_new(
+            AggregateMode::Partial,
+            vec![],
+            vec![count_expr(Some(&schema), Some("a"))],
             filter,
             Arc::clone(&schema),
         )?;
@@ -367,7 +507,7 @@
         let final_agg = HashAggregateExec::try_new(
             AggregateMode::Final,
             vec![],
-            vec![count_expr()],
+            vec![count_expr(Some(&schema), Some("a"))],
             Arc::new(partial_agg),
             Arc::clone(&schema),
         )?;