-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize count agg expr with null column statistics #1063
Changes from 15 commits
bf0be0a
ca4e1c8
472c80a
ed7c838
ad5311c
3e28a0f
597338b
fb683c6
6dd40ea
df032a2
c73d6bc
62e0eeb
a47e812
42c01d1
1c65b6a
5f8c9fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -57,7 +57,13 @@ impl PhysicalOptimizerRule for AggregateStatistics { | |||||
let stats = partial_agg_exec.input().statistics(); | ||||||
let mut projections = vec![]; | ||||||
for expr in partial_agg_exec.aggr_expr() { | ||||||
if let Some((num_rows, name)) = take_optimizable_count(&**expr, &stats) { | ||||||
if let Some((non_null_rows, name)) = | ||||||
take_optimizable_column_count(&**expr, &stats) | ||||||
{ | ||||||
projections.push((expressions::lit(non_null_rows), name.to_owned())); | ||||||
} else if let Some((num_rows, name)) = | ||||||
take_optimizable_table_count(&**expr, &stats) | ||||||
{ | ||||||
projections.push((expressions::lit(num_rows), name.to_owned())); | ||||||
} else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) { | ||||||
projections.push((expressions::lit(min), name.to_owned())); | ||||||
|
@@ -127,7 +133,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>> | |||||
} | ||||||
|
||||||
/// If this agg_expr is a count that is defined in the statistics, return it | ||||||
fn take_optimizable_count( | ||||||
fn take_optimizable_table_count( | ||||||
agg_expr: &dyn AggregateExpr, | ||||||
stats: &Statistics, | ||||||
) -> Option<(ScalarValue, &'static str)> { | ||||||
|
@@ -153,6 +159,39 @@ fn take_optimizable_count( | |||||
None | ||||||
} | ||||||
|
||||||
/// If this agg_expr is a count that can be derived from the statistics, return it | ||||||
fn take_optimizable_column_count( | ||||||
agg_expr: &dyn AggregateExpr, | ||||||
stats: &Statistics, | ||||||
) -> Option<(ScalarValue, String)> { | ||||||
if let (Some(num_rows), Some(col_stats), Some(casted_expr)) = ( | ||||||
stats.num_rows, | ||||||
&stats.column_statistics, | ||||||
agg_expr.as_any().downcast_ref::<expressions::Count>(), | ||||||
) { | ||||||
if casted_expr.expressions().len() == 1 { | ||||||
// TODO optimize with exprs other than Column | ||||||
if let Some(col_expr) = casted_expr.expressions()[0] | ||||||
.as_any() | ||||||
.downcast_ref::<expressions::Column>() | ||||||
{ | ||||||
if let ColumnStatistics { | ||||||
null_count: Some(val), | ||||||
.. | ||||||
} = &col_stats[col_expr.index()] | ||||||
{ | ||||||
let expr = format!("COUNT({})", col_expr.name()); | ||||||
return Some(( | ||||||
ScalarValue::UInt64(Some((num_rows - val) as u64)), | ||||||
expr, | ||||||
)); | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
None | ||||||
} | ||||||
|
||||||
/// If this agg_expr is a min that is defined in the statistics, return it | ||||||
fn take_optimizable_min( | ||||||
agg_expr: &dyn AggregateExpr, | ||||||
|
@@ -237,8 +276,8 @@ mod tests { | |||||
let batch = RecordBatch::try_new( | ||||||
Arc::clone(&schema), | ||||||
vec![ | ||||||
Arc::new(Int32Array::from(vec![1, 2, 3])), | ||||||
Arc::new(Int32Array::from(vec![4, 5, 6])), | ||||||
Arc::new(Int32Array::from(vec![Some(1), Some(2), None])), | ||||||
Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])), | ||||||
], | ||||||
)?; | ||||||
|
||||||
|
@@ -250,38 +289,41 @@ mod tests { | |||||
} | ||||||
|
||||||
/// Checks that the count optimization was applied and we still get the right result | ||||||
async fn assert_count_optim_success(plan: HashAggregateExec) -> Result<()> { | ||||||
async fn assert_count_optim_success( | ||||||
plan: HashAggregateExec, | ||||||
nulls: bool, | ||||||
) -> Result<()> { | ||||||
let conf = ExecutionConfig::new(); | ||||||
let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; | ||||||
|
||||||
let (col, count) = match nulls { | ||||||
false => (Field::new("COUNT(Uint8(1))", DataType::UInt64, false), 3), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand why the column name output is different for columns with NULLs (and columns that don't have nulls) I think the difference is if the aggregate is This can be seen in the datafusion-cli on master
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think this is just a case of a poorly named variable. it should really be something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
true => (Field::new("COUNT(a)", DataType::UInt64, false), 2), | ||||||
}; | ||||||
|
||||||
// A ProjectionExec is a sign that the count optimization was applied | ||||||
assert!(optimized.as_any().is::<ProjectionExec>()); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a comment here that the added There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure - added it. |
||||||
let result = common::collect(optimized.execute(0).await?).await?; | ||||||
assert_eq!( | ||||||
result[0].schema(), | ||||||
Arc::new(Schema::new(vec![Field::new( | ||||||
"COUNT(Uint8(1))", | ||||||
DataType::UInt64, | ||||||
false | ||||||
)])) | ||||||
); | ||||||
assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col]))); | ||||||
assert_eq!( | ||||||
result[0] | ||||||
.column(0) | ||||||
.as_any() | ||||||
.downcast_ref::<UInt64Array>() | ||||||
.unwrap() | ||||||
.values(), | ||||||
&[3] | ||||||
&[count] | ||||||
); | ||||||
Ok(()) | ||||||
} | ||||||
|
||||||
fn count_expr() -> Arc<dyn AggregateExpr> { | ||||||
Arc::new(Count::new( | ||||||
expressions::lit(ScalarValue::UInt8(Some(1))), | ||||||
"my_count_alias", | ||||||
DataType::UInt64, | ||||||
)) | ||||||
fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn AggregateExpr> { | ||||||
// Return appropriate expr depending if COUNT is for col or table | ||||||
let expr = match schema { | ||||||
None => expressions::lit(ScalarValue::UInt8(Some(1))), | ||||||
Some(s) => expressions::col(col.unwrap(), s).unwrap(), | ||||||
}; | ||||||
Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64)) | ||||||
} | ||||||
|
||||||
#[tokio::test] | ||||||
|
@@ -293,20 +335,47 @@ mod tests { | |||||
let partial_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Partial, | ||||||
vec![], | ||||||
vec![count_expr()], | ||||||
vec![count_expr(None, None)], | ||||||
source, | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
let final_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Final, | ||||||
vec![], | ||||||
vec![count_expr()], | ||||||
vec![count_expr(None, None)], | ||||||
Arc::new(partial_agg), | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
assert_count_optim_success(final_agg).await?; | ||||||
assert_count_optim_success(final_agg, false).await?; | ||||||
|
||||||
Ok(()) | ||||||
} | ||||||
|
||||||
#[tokio::test] | ||||||
async fn test_count_partial_with_nulls_direct_child() -> Result<()> { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is not testing the code that you have added, it tests that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, thx for picking that up. Looking into it. |
||||||
// basic test case with the aggregation applied on a source with exact statistics | ||||||
let source = mock_data()?; | ||||||
let schema = source.schema(); | ||||||
|
||||||
let partial_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Partial, | ||||||
vec![], | ||||||
vec![count_expr(Some(&schema), Some("a"))], | ||||||
source, | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
let final_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Final, | ||||||
vec![], | ||||||
vec![count_expr(Some(&schema), Some("a"))], | ||||||
Arc::new(partial_agg), | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
assert_count_optim_success(final_agg, true).await?; | ||||||
|
||||||
Ok(()) | ||||||
} | ||||||
|
@@ -319,7 +388,36 @@ mod tests { | |||||
let partial_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Partial, | ||||||
vec![], | ||||||
vec![count_expr()], | ||||||
vec![count_expr(None, None)], | ||||||
source, | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
// We introduce an intermediate optimization step between the partial and final aggregtator | ||||||
let coalesce = CoalescePartitionsExec::new(Arc::new(partial_agg)); | ||||||
|
||||||
let final_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Final, | ||||||
vec![], | ||||||
vec![count_expr(None, None)], | ||||||
Arc::new(coalesce), | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
assert_count_optim_success(final_agg, false).await?; | ||||||
|
||||||
Ok(()) | ||||||
} | ||||||
|
||||||
#[tokio::test] | ||||||
async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { | ||||||
let source = mock_data()?; | ||||||
let schema = source.schema(); | ||||||
|
||||||
let partial_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Partial, | ||||||
vec![], | ||||||
vec![count_expr(Some(&schema), Some("a"))], | ||||||
source, | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
@@ -330,12 +428,12 @@ mod tests { | |||||
let final_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Final, | ||||||
vec![], | ||||||
vec![count_expr()], | ||||||
vec![count_expr(Some(&schema), Some("a"))], | ||||||
Arc::new(coalesce), | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
assert_count_optim_success(final_agg).await?; | ||||||
assert_count_optim_success(final_agg, true).await?; | ||||||
|
||||||
Ok(()) | ||||||
} | ||||||
|
@@ -359,15 +457,57 @@ mod tests { | |||||
let partial_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Partial, | ||||||
vec![], | ||||||
vec![count_expr()], | ||||||
vec![count_expr(None, None)], | ||||||
filter, | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
let final_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Final, | ||||||
vec![], | ||||||
vec![count_expr(None, None)], | ||||||
Arc::new(partial_agg), | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
let conf = ExecutionConfig::new(); | ||||||
let optimized = | ||||||
AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; | ||||||
|
||||||
// check that the original ExecutionPlan was not replaced | ||||||
assert!(optimized.as_any().is::<HashAggregateExec>()); | ||||||
|
||||||
Ok(()) | ||||||
} | ||||||
|
||||||
#[tokio::test] | ||||||
async fn test_count_with_nulls_inexact_stat() -> Result<()> { | ||||||
let source = mock_data()?; | ||||||
let schema = source.schema(); | ||||||
|
||||||
// adding a filter makes the statistics inexact | ||||||
let filter = Arc::new(FilterExec::try_new( | ||||||
expressions::binary( | ||||||
expressions::col("a", &schema)?, | ||||||
Operator::Gt, | ||||||
expressions::lit(ScalarValue::from(1u32)), | ||||||
&schema, | ||||||
)?, | ||||||
source, | ||||||
)?); | ||||||
|
||||||
let partial_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Partial, | ||||||
vec![], | ||||||
vec![count_expr(Some(&schema), Some("a"))], | ||||||
filter, | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
||||||
let final_agg = HashAggregateExec::try_new( | ||||||
AggregateMode::Final, | ||||||
vec![], | ||||||
vec![count_expr()], | ||||||
vec![count_expr(Some(&schema), Some("a"))], | ||||||
Arc::new(partial_agg), | ||||||
Arc::clone(&schema), | ||||||
)?; | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks like this code handles
count(col)
whereas the code above only handlescount(*)
-- that seems strange -- perhaps we should update it so both can handlecount(col)
andcount(*)
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My understanding is that
COUNT(*)
doesnt need to have a separate handler for nulls - assuming we expect same behavior as psql. For example in psql when i do the following:Does it make sense to reframe these optimizations as the following:
take_optimizable_table_count
(currenttake_optimizable_count
)=> comes fromCOUNT(*)
and returnsnum_rows
take_optimizable_column_count
(currenttake_optimizable_count_with_nulls
) => comes fromCOUNT(col)
and returnnum_rows - null_count
forcol
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think those names make more sense to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok - ive updated. Let me know if anything else needed.