From f346b759c32b12ae3df4f5a3b8edf18cae43d7b6 Mon Sep 17 00:00:00 2001
From: Piotr Findeisen
Date: Wed, 3 Jul 2024 17:11:09 +0200
Subject: [PATCH] Infer count() aggregation is not null

`count([DISTINCT] [expr])` aggregate function never returns null. Infer
non-nullness of such aggregate expression. This allows elimination of
the HAVING filter for a query such as

    SELECT ... count(*) AS c
    FROM ...
    GROUP BY ...
    HAVING c IS NOT NULL
---
 datafusion/optimizer/src/infer_non_null.rs    | 131 ++++++++++++++++++
 datafusion/optimizer/src/lib.rs               |   1 +
 datafusion/optimizer/src/optimizer.rs         |   2 +
 .../sqllogictest/test_files/explain.slt       |   2 +
 4 files changed, 136 insertions(+)
 create mode 100644 datafusion/optimizer/src/infer_non_null.rs

diff --git a/datafusion/optimizer/src/infer_non_null.rs b/datafusion/optimizer/src/infer_non_null.rs
new file mode 100644
index 0000000000000..dc85fcbffdbc9
--- /dev/null
+++ b/datafusion/optimizer/src/infer_non_null.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`InferNonNull`] infers which columns are non-nullable
+
+use std::collections::HashSet;
+use std::ops::Deref;
+use std::sync::Arc;
+
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{DFSchema, Result};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::{Aggregate, Expr, LogicalPlan};
+
+#[derive(Default)]
+pub struct InferNonNull {}
+
+impl InferNonNull {
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl OptimizerRule for InferNonNull {
+    fn name(&self) -> &str {
+        "infer_non_null"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        if let LogicalPlan::Aggregate(aggregate) = plan {
+            let grouping_columns = aggregate.group_expr_len()?;
+            let new_non_null_fields: HashSet<_> = aggregate
+                .aggr_expr
+                .iter()
+                .enumerate()
+                .filter_map(|(i, expr)| {
+                    let field_index = grouping_columns + i;
+                    if !aggregate.schema.field(field_index).is_nullable() {
+                        // Already not nullable.
+                        return None;
+                    }
+                    simple_aggregate_function(expr)
+                        .filter(|function| {
+                            function.func_def.name() == "count"
+                                && function.args.len() <= 1
+                        })
+                        .map(|_| field_index)
+                })
+                .collect();
+
+            if !new_non_null_fields.is_empty() {
+                let new_schema = Arc::new(DFSchema::new_with_metadata(
+                    aggregate
+                        .schema
+                        .iter()
+                        .enumerate()
+                        .map(|(i, field)| {
+                            let mut field = (field.0.cloned(), field.1.clone());
+                            if new_non_null_fields.contains(&i) {
+                                field = (
+                                    field.0,
+                                    Arc::new(
+                                        field.1.deref().clone().with_nullable(false),
+                                    ),
+                                );
+                            }
+                            field
+                        })
+                        .collect(),
+                    aggregate.schema.metadata().clone(),
+                )?);
+
+                return Ok(Transformed::yes(LogicalPlan::Aggregate(
+                    Aggregate::try_new_with_schema(
+                        aggregate.input,
+                        aggregate.group_expr,
+                        aggregate.aggr_expr,
+                        new_schema,
+                    )?,
+                )));
+            }
+            return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
+        }
+
+        if let LogicalPlan::Window(_) = plan {
+            // TODO similar to Aggregate
+        }
+
+        if let LogicalPlan::Filter(_) = plan {
+            // TODO infer column being not null from filter predicates
+        }
+
+        Ok(Transformed::no(plan))
+    }
+}
+
+fn simple_aggregate_function(expr: &Expr) -> Option<&AggregateFunction> {
+    match expr {
+        Expr::AggregateFunction(ref aggregate_function) => Some(aggregate_function),
+        Expr::Alias(ref alias) => simple_aggregate_function(alias.expr.as_ref()),
+        _ => None,
+    }
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index a6a9e5cf26eaf..2fb7777685ab4 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -43,6 +43,7 @@ pub mod eliminate_one_union;
 pub mod eliminate_outer_join;
 pub mod extract_equijoin_predicate;
 pub mod filter_null_join_keys;
+pub mod infer_non_null;
 pub mod optimize_projections;
 pub mod optimizer;
 pub mod propagate_empty_relation;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 14e5ac141eeb6..4b6003fadafa0 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -44,6 +44,7 @@ use crate::eliminate_one_union::EliminateOneUnion;
 use crate::eliminate_outer_join::EliminateOuterJoin;
 use crate::extract_equijoin_predicate::ExtractEquijoinPredicate;
 use crate::filter_null_join_keys::FilterNullJoinKeys;
+use crate::infer_non_null::InferNonNull;
 use crate::optimize_projections::OptimizeProjections;
 use crate::plan_signature::LogicalPlanSignature;
 use crate::propagate_empty_relation::PropagateEmptyRelation;
@@ -245,6 +246,7 @@ impl Optimizer {
             Arc::new(EliminateNestedUnion::new()),
             Arc::new(SimplifyExpressions::new()),
             Arc::new(UnwrapCastInComparison::new()),
+            Arc::new(InferNonNull::new()),
             Arc::new(ReplaceDistinctWithAggregate::new()),
             Arc::new(EliminateJoin::new()),
             Arc::new(DecorrelatePredicateSubquery::new()),
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index b850760b8734a..e0de8ce298057 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -187,6 +187,7 @@ analyzed_logical_plan SAME TEXT AS ABOVE
 logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
 logical_plan after simplify_expressions SAME TEXT AS ABOVE
 logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
+logical_plan after infer_non_null SAME TEXT AS ABOVE
 logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
 logical_plan after eliminate_join SAME TEXT AS ABOVE
 logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
@@ -214,6 +215,7 @@ logical_plan after optimize_projections TableScan: simple_explain_test projectio
 logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
 logical_plan after simplify_expressions SAME TEXT AS ABOVE
 logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
+logical_plan after infer_non_null SAME TEXT AS ABOVE
 logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
 logical_plan after eliminate_join SAME TEXT AS ABOVE
 logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE