From ac431090c4a5fa92e5441f75f512702d67febc03 Mon Sep 17 00:00:00 2001
From: Piotr Findeisen
Date: Wed, 3 Jul 2024 17:11:09 +0200
Subject: [PATCH] Infer count() aggregation is not null

`count([DISTINCT] [expr])` aggregate function never returns null. Infer
non-nullness of such aggregate expression. This allows elimination of
the HAVING filter for a query such as

    SELECT ... count(*) AS c
    FROM ...
    GROUP BY ...
    HAVING c IS NOT NULL
---
 datafusion/optimizer/src/infer_non_null.rs | 144 +++++++++++++++++++++
 datafusion/optimizer/src/lib.rs            |   1 +
 datafusion/optimizer/src/optimizer.rs      |   2 +
 3 files changed, 147 insertions(+)
 create mode 100644 datafusion/optimizer/src/infer_non_null.rs

diff --git a/datafusion/optimizer/src/infer_non_null.rs b/datafusion/optimizer/src/infer_non_null.rs
new file mode 100644
index 0000000000000..8f7f63850be1a
--- /dev/null
+++ b/datafusion/optimizer/src/infer_non_null.rs
@@ -0,0 +1,144 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`InferNonNull`] infers which columns are non-nullable
+
+use std::collections::HashSet;
+use std::ops::Deref;
+use std::sync::Arc;
+
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{DFSchema, Result};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::{Aggregate, Expr, LogicalPlan};
+
+/// Optimizer rule that marks output columns which can never be null as
+/// non-nullable in the plan schema.
+///
+/// Currently only [`LogicalPlan::Aggregate`] is handled: output columns
+/// produced by a `count` aggregate function are marked non-nullable.
+#[derive(Debug, Default)]
+pub struct InferNonNull {}
+
+impl InferNonNull {
+    /// Creates a new instance of this rule.
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl OptimizerRule for InferNonNull {
+    fn name(&self) -> &str {
+        "infer_non_null"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    /// Rebuilds an `Aggregate` node with a schema in which the outputs of
+    /// `count` are non-nullable. Any other plan is returned unchanged.
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        // TODO handle `LogicalPlan::Window` similarly to `Aggregate`
+        // TODO infer columns being not null from `LogicalPlan::Filter` predicates
+        if let LogicalPlan::Aggregate(ref aggregate) = plan {
+            let non_null_fields = nullable_count_fields(aggregate)?;
+            if !non_null_fields.is_empty() {
+                let new_schema = with_non_null_fields(plan.schema(), &non_null_fields)?;
+                // Only the schema changes; input and expressions are reused as is.
+                return Ok(Transformed::yes(LogicalPlan::Aggregate(
+                    Aggregate::try_new_with_schema(
+                        Arc::clone(&aggregate.input),
+                        aggregate.group_expr.clone(),
+                        aggregate.aggr_expr.clone(),
+                        Arc::new(new_schema),
+                    )?,
+                )));
+            }
+        }
+
+        Ok(Transformed::no(plan))
+    }
+}
+
+/// Returns the indices of the output fields of `aggregate` that are still
+/// marked nullable although they are produced by a `count` call.
+fn nullable_count_fields(aggregate: &Aggregate) -> Result<HashSet<usize>> {
+    // Aggregate expressions follow the grouping columns in the output schema.
+    let grouping_columns = aggregate.group_expr_len()?;
+    let fields = aggregate
+        .aggr_expr
+        .iter()
+        .enumerate()
+        .map(|(i, expr)| (grouping_columns + i, expr))
+        .filter(|(field_index, expr)| {
+            aggregate.schema.field(*field_index).is_nullable() && is_count(expr)
+        })
+        .map(|(field_index, _)| field_index)
+        .collect();
+    Ok(fields)
+}
+
+/// Returns whether `expr` is a possibly aliased `count` aggregate call with
+/// at most one argument.
+fn is_count(expr: &Expr) -> bool {
+    simple_aggregate_function(expr).map_or(false, |function| {
+        function.func_def.name() == "count" && function.args.len() <= 1
+    })
+}
+
+/// Returns a copy of `schema` in which the fields at the `non_null_fields`
+/// indices are non-nullable. Qualifiers and metadata are carried over.
+fn with_non_null_fields(
+    schema: &DFSchema,
+    non_null_fields: &HashSet<usize>,
+) -> Result<DFSchema> {
+    let fields: Vec<_> = schema
+        .iter()
+        .enumerate()
+        .map(|(i, (qualifier, field))| {
+            let field = if non_null_fields.contains(&i) {
+                Arc::new(field.deref().clone().with_nullable(false))
+            } else {
+                Arc::clone(field)
+            };
+            (qualifier.cloned(), field)
+        })
+        .collect();
+    DFSchema::new_with_metadata(fields, schema.metadata().clone())
+}
+
+/// Returns the aggregate function call that `expr` consists of, looking
+/// through any aliases, or `None` for every other kind of expression.
+fn simple_aggregate_function(expr: &Expr) -> Option<&AggregateFunction> {
+    match expr {
+        Expr::AggregateFunction(aggregate_function) => Some(aggregate_function),
+        Expr::Alias(alias) => simple_aggregate_function(alias.expr.as_ref()),
+        _ => None,
+    }
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index a6a9e5cf26eaf..2fb7777685ab4 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -43,6 +43,7 @@ pub mod eliminate_one_union;
 pub mod eliminate_outer_join;
 pub mod extract_equijoin_predicate;
 pub mod filter_null_join_keys;
+pub mod infer_non_null;
 pub mod optimize_projections;
 pub mod optimizer;
 pub mod propagate_empty_relation;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 14e5ac141eeb6..4b6003fadafa0 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -44,6 +44,7 @@ use crate::eliminate_one_union::EliminateOneUnion;
 use crate::eliminate_outer_join::EliminateOuterJoin;
 use crate::extract_equijoin_predicate::ExtractEquijoinPredicate;
 use crate::filter_null_join_keys::FilterNullJoinKeys;
+use crate::infer_non_null::InferNonNull;
 use crate::optimize_projections::OptimizeProjections;
 use crate::plan_signature::LogicalPlanSignature;
 use crate::propagate_empty_relation::PropagateEmptyRelation;
@@ -245,6 +246,7 @@ impl Optimizer {
             Arc::new(EliminateNestedUnion::new()),
             Arc::new(SimplifyExpressions::new()),
             Arc::new(UnwrapCastInComparison::new()),
+            Arc::new(InferNonNull::new()),
             Arc::new(ReplaceDistinctWithAggregate::new()),
             Arc::new(EliminateJoin::new()),
             Arc::new(DecorrelatePredicateSubquery::new()),