Skip to content

Commit

Permalink
Infer count() aggregation is not null
Browse files Browse the repository at this point in the history
The `count([DISTINCT [expr]])` aggregate function never returns null. Infer
the non-nullness of such aggregate expressions. This allows elimination of
the HAVING filter for a query such as

    SELECT ... count(*) AS c
    FROM ...
    GROUP BY ...
    HAVING c IS NOT NULL
  • Loading branch information
findepi committed Jul 3, 2024
1 parent 58f79e1 commit 0352eda
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 0 deletions.
185 changes: 185 additions & 0 deletions datafusion/optimizer/src/infer_non_null.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`InferNonNull`] infers which columns are non-nullable
use std::collections::HashSet;
use std::ops::Deref;
use std::sync::Arc;

use datafusion_common::tree_node::Transformed;
use datafusion_common::{DFSchema, Result};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::{Aggregate, Expr, LogicalPlan};

use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};

/// Optimizer rule that infers which output columns of a plan can never be null.
///
/// The rule carries no state; construct it with [`InferNonNull::new`] or
/// [`Default::default`].
#[derive(Debug, Default)]
pub struct InferNonNull {}

impl InferNonNull {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self::default()
    }
}

impl OptimizerRule for InferNonNull {
    /// Name identifying this rule.
    fn name(&self) -> &str {
        "infer_non_null"
    }

    /// Requests bottom-up application of the rule.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::BottomUp)
    }

    /// Signals that the owned-plan `rewrite` entry point is implemented.
    fn supports_rewrite(&self) -> bool {
        true
    }

    /// Marks the outputs of `count` aggregate expressions as non-nullable.
    ///
    /// For a [`LogicalPlan::Aggregate`], every aggregate expression that is a
    /// (possibly aliased) `count` call with at most one argument, and whose
    /// schema field is still nullable, gets its field rewritten to be
    /// non-nullable. If at least one field changed, the aggregate is rebuilt
    /// with the new schema and reported as transformed. Any other plan, or an
    /// aggregate with nothing to change, is returned untouched.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        let aggregate = match plan {
            LogicalPlan::Aggregate(aggregate) => aggregate,
            // TODO Window: similar to Aggregate
            // TODO Filter: infer column being not null from filter predicates
            other => return Ok(Transformed::no(other)),
        };

        // In the aggregate's schema, the aggregate outputs are indexed right
        // after the grouping columns.
        let aggregate_offset = aggregate.group_expr_len()?;
        let mut non_null_indices: HashSet<usize> = HashSet::new();
        for (position, expr) in aggregate.aggr_expr.iter().enumerate() {
            let index = aggregate_offset + position;
            if !aggregate.schema.field(index).is_nullable() {
                // Already not nullable.
                continue;
            }
            // TODO infer non-null for min, max on non-null input
            let is_count = matches!(
                simple_aggregate_function(expr),
                Some(function)
                    if function.func_def.name() == "count" && function.args.len() <= 1
            );
            if is_count {
                non_null_indices.insert(index);
            }
        }

        if non_null_indices.is_empty() {
            return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
        }

        // Copy the schema field by field, flipping nullability only for the
        // fields selected above.
        let fields = aggregate
            .schema
            .iter()
            .enumerate()
            .map(|(index, (qualifier, field))| {
                let field = if non_null_indices.contains(&index) {
                    Arc::new(field.deref().clone().with_nullable(false))
                } else {
                    Arc::clone(field)
                };
                (qualifier.cloned(), field)
            })
            .collect();
        let schema = Arc::new(DFSchema::new_with_metadata(
            fields,
            aggregate.schema.metadata().clone(),
        )?);

        let rebuilt = Aggregate::try_new_with_schema(
            aggregate.input,
            aggregate.group_expr,
            aggregate.aggr_expr,
            schema,
        )?;
        Ok(Transformed::yes(LogicalPlan::Aggregate(rebuilt)))
    }
}

/// Returns the aggregate function call that `expr` consists of, if any.
///
/// Any number of enclosing [`Expr::Alias`] layers are looked through. The
/// result is `None` when the innermost expression is not an
/// [`Expr::AggregateFunction`].
fn simple_aggregate_function(expr: &Expr) -> Option<&AggregateFunction> {
    // Peel off aliases until a non-alias expression is reached.
    let mut innermost = expr;
    while let Expr::Alias(alias) = innermost {
        innermost = alias.expr.as_ref();
    }

    match innermost {
        Expr::AggregateFunction(function) => Some(function),
        _ => None,
    }
}

#[cfg(test)]
mod tests {
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::expr_fn::{max, min, wildcard};
    use datafusion_expr::{col, lit, LogicalPlanBuilder};
    use datafusion_functions_aggregate::expr_fn::{count, sum};

    use crate::infer_non_null::InferNonNull;
    use crate::{OptimizerContext, OptimizerRule};

    /// Only the `count` aggregates become non-nullable; the grouping column
    /// and the `min`/`max`/`sum` aggregates keep their nullable fields.
    #[test]
    fn test_aggregate() -> Result<()> {
        let rows = vec![
            vec![lit(42), lit(100)],
            vec![lit(42), lit(ScalarValue::Int64(None))],
        ];
        let aggregates = vec![
            min(col("column2")).alias("min"),
            max(col("column2")).alias("max"),
            count(wildcard()).alias("count_all"),
            count(col("column2")).alias("count_non_null"),
            sum(col("column2")).alias("sum"),
        ];
        let plan = LogicalPlanBuilder::values(rows)?
            .aggregate(vec![col("column1")], aggregates)?
            .build()?;

        let optimized = InferNonNull::new().rewrite(plan, &OptimizerContext::new())?;
        assert!(optimized.transformed);

        let schema = optimized.data.schema();
        for name in ["column1", "min", "max", "sum"] {
            assert!(schema.field_with_unqualified_name(name)?.is_nullable());
        }
        for name in ["count_all", "count_non_null"] {
            assert!(!schema.field_with_unqualified_name(name)?.is_nullable());
        }

        Ok(())
    }
}
1 change: 1 addition & 0 deletions datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub mod eliminate_one_union;
pub mod eliminate_outer_join;
pub mod extract_equijoin_predicate;
pub mod filter_null_join_keys;
pub mod infer_non_null;
pub mod optimize_projections;
pub mod optimizer;
pub mod propagate_empty_relation;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use crate::eliminate_one_union::EliminateOneUnion;
use crate::eliminate_outer_join::EliminateOuterJoin;
use crate::extract_equijoin_predicate::ExtractEquijoinPredicate;
use crate::filter_null_join_keys::FilterNullJoinKeys;
use crate::infer_non_null::InferNonNull;
use crate::optimize_projections::OptimizeProjections;
use crate::plan_signature::LogicalPlanSignature;
use crate::propagate_empty_relation::PropagateEmptyRelation;
Expand Down Expand Up @@ -245,6 +246,7 @@ impl Optimizer {
Arc::new(EliminateNestedUnion::new()),
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(InferNonNull::new()),
Arc::new(ReplaceDistinctWithAggregate::new()),
Arc::new(EliminateJoin::new()),
Arc::new(DecorrelatePredicateSubquery::new()),
Expand Down
15 changes: 15 additions & 0 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,21 @@ fn eliminate_nested_filters() {
assert_eq!(expected, format!("{plan:?}"));
}

// Checks the optimized plan for a GROUP BY query whose HAVING clause tests
// `count(*)` with IS NOT NULL: the expected plan below consists only of
// Projection, Aggregate and TableScan nodes, i.e. it contains no Filter node.
#[test]
fn eliminate_redundant_null_check() {
let sql = "\
SELECT col_int32, count(*) c
FROM test
GROUP BY col_int32
HAVING c IS NOT NULL";
let plan = test_sql(sql).unwrap();
// Expected debug rendering of the optimized plan.
let expected = "\
Projection: test.col_int32, count(*) AS c\
\n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1)) AS count(*)]]\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan:?}"));
}

#[test]
fn test_propagate_empty_relation_inner_join_and_unions() {
let sql = "\
Expand Down
2 changes: 2 additions & 0 deletions datafusion/sqllogictest/test_files/explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ analyzed_logical_plan SAME TEXT AS ABOVE
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
logical_plan after infer_non_null SAME TEXT AS ABOVE
logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
logical_plan after eliminate_join SAME TEXT AS ABOVE
logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
Expand Down Expand Up @@ -214,6 +215,7 @@ logical_plan after optimize_projections TableScan: simple_explain_test projectio
logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
logical_plan after simplify_expressions SAME TEXT AS ABOVE
logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
logical_plan after infer_non_null SAME TEXT AS ABOVE
logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
logical_plan after eliminate_join SAME TEXT AS ABOVE
logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
Expand Down

0 comments on commit 0352eda

Please sign in to comment.