diff --git a/datafusion/optimizer/src/infer_non_null.rs b/datafusion/optimizer/src/infer_non_null.rs
new file mode 100644
index 0000000000000..e0336dccf1c3b
--- /dev/null
+++ b/datafusion/optimizer/src/infer_non_null.rs
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`InferNonNull`] infers which columns are non-nullable
+
+use std::collections::HashSet;
+use std::ops::Deref;
+use std::sync::Arc;
+
+use datafusion_common::tree_node::Transformed;
+use datafusion_common::{DFSchema, Result};
+use datafusion_expr::expr::AggregateFunction;
+use datafusion_expr::{Aggregate, Expr, LogicalPlan};
+
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+
+/// Optimizer rule that marks aggregate output columns as non-nullable when the
+/// aggregate function producing them is known to never return NULL.
+///
+/// Currently only `count` called with at most one argument is recognized.
+#[derive(Debug, Default)]
+pub struct InferNonNull {}
+
+impl InferNonNull {
+    /// Creates a new instance of the rule.
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl OptimizerRule for InferNonNull {
+    fn name(&self) -> &str {
+        "infer_non_null"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::BottomUp)
+    }
+
+    fn supports_rewrite(&self) -> bool {
+        true
+    }
+
+    /// Rewrites an `Aggregate` node so that the schema fields of recognized
+    /// aggregate expressions are marked non-nullable. Any other plan, or an
+    /// `Aggregate` with nothing to change, is returned untransformed.
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        _config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let aggregate = match plan {
+            LogicalPlan::Aggregate(aggregate) => aggregate,
+            // TODO handle `LogicalPlan::Window` similarly to `Aggregate`
+            // TODO infer column being not null from `LogicalPlan::Filter` predicates
+            other => return Ok(Transformed::no(other)),
+        };
+
+        // Aggregate output fields are indexed with the grouping columns first,
+        // followed by one field per aggregate expression.
+        let grouping_columns = aggregate.group_expr_len()?;
+        let new_non_null_fields: HashSet<_> = aggregate
+            .aggr_expr
+            .iter()
+            .enumerate()
+            .filter_map(|(i, expr)| {
+                let field_index = grouping_columns + i;
+                if !aggregate.schema.field(field_index).is_nullable() {
+                    // Already not nullable.
+                    return None;
+                }
+                simple_aggregate_function(expr)
+                    .filter(|function| {
+                        // TODO infer non-null for min, max on non-null input
+                        function.func_def.name() == "count"
+                            && function.args.len() <= 1
+                    })
+                    .map(|_| field_index)
+            })
+            .collect();
+
+        if new_non_null_fields.is_empty() {
+            return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
+        }
+
+        // Rebuild the schema, flipping only the selected fields to non-nullable.
+        let new_schema = Arc::new(DFSchema::new_with_metadata(
+            aggregate
+                .schema
+                .iter()
+                .enumerate()
+                .map(|(i, (qualifier, field))| {
+                    let field = if new_non_null_fields.contains(&i) {
+                        Arc::new(field.deref().clone().with_nullable(false))
+                    } else {
+                        Arc::clone(field)
+                    };
+                    (qualifier.cloned(), field)
+                })
+                .collect(),
+            aggregate.schema.metadata().clone(),
+        )?);
+
+        Ok(Transformed::yes(LogicalPlan::Aggregate(
+            Aggregate::try_new_with_schema(
+                aggregate.input,
+                aggregate.group_expr,
+                aggregate.aggr_expr,
+                new_schema,
+            )?,
+        )))
+    }
+}
+
+/// Returns the aggregate function behind `expr`, looking through any number of
+/// aliases, or `None` if `expr` is any other kind of expression.
+fn simple_aggregate_function(expr: &Expr) -> Option<&AggregateFunction> {
+    match expr {
+        Expr::AggregateFunction(aggregate_function) => Some(aggregate_function),
+        Expr::Alias(alias) => simple_aggregate_function(alias.expr.as_ref()),
+        _ => None,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use datafusion_common::Result;
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::expr_fn::max;
+    use datafusion_expr::expr_fn::min;
+    use datafusion_expr::expr_fn::wildcard;
+    use datafusion_expr::{col, lit, LogicalPlanBuilder};
+    use datafusion_functions_aggregate::expr_fn::count;
+    use datafusion_functions_aggregate::expr_fn::sum;
+
+    use crate::infer_non_null::InferNonNull;
+    use crate::{OptimizerContext, OptimizerRule};
+
+    #[test]
+    fn test_aggregate() -> Result<()> {
+        let plan = LogicalPlanBuilder::values(vec![
+            vec![lit(42), lit(100)],
+            vec![lit(42), lit(ScalarValue::Int64(None))],
+        ])?
+        .aggregate(
+            vec![col("column1")],
+            vec![
+                min(col("column2")).alias("min"),
+                max(col("column2")).alias("max"),
+                count(wildcard()).alias("count_all"),
+                count(col("column2")).alias("count_non_null"),
+                sum(col("column2")).alias("sum"),
+            ],
+        )?
+        .build()?;
+
+        let transformed = InferNonNull::new().rewrite(plan, &OptimizerContext::new())?;
+        assert!(transformed.transformed);
+        let new_schema = transformed.data.schema();
+        assert!(new_schema
+            .field_with_unqualified_name("column1")?
+            .is_nullable());
+        assert!(new_schema.field_with_unqualified_name("min")?.is_nullable());
+        assert!(new_schema.field_with_unqualified_name("max")?.is_nullable());
+        assert!(!new_schema
+            .field_with_unqualified_name("count_all")?
+            .is_nullable());
+        assert!(!new_schema
+            .field_with_unqualified_name("count_non_null")?
+            .is_nullable());
+        assert!(new_schema.field_with_unqualified_name("sum")?.is_nullable());
+
+        Ok(())
+    }
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index a6a9e5cf26eaf..2fb7777685ab4 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -43,6 +43,7 @@ pub mod eliminate_one_union;
 pub mod eliminate_outer_join;
 pub mod extract_equijoin_predicate;
 pub mod filter_null_join_keys;
+pub mod infer_non_null;
 pub mod optimize_projections;
 pub mod optimizer;
 pub mod propagate_empty_relation;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 14e5ac141eeb6..4b6003fadafa0 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -44,6 +44,7 @@ use crate::eliminate_one_union::EliminateOneUnion;
 use crate::eliminate_outer_join::EliminateOuterJoin;
 use crate::extract_equijoin_predicate::ExtractEquijoinPredicate;
 use crate::filter_null_join_keys::FilterNullJoinKeys;
+use crate::infer_non_null::InferNonNull;
 use crate::optimize_projections::OptimizeProjections;
 use crate::plan_signature::LogicalPlanSignature;
 use crate::propagate_empty_relation::PropagateEmptyRelation;
@@ -245,6 +246,7 @@ impl Optimizer {
             Arc::new(EliminateNestedUnion::new()),
             Arc::new(SimplifyExpressions::new()),
             Arc::new(UnwrapCastInComparison::new()),
+            Arc::new(InferNonNull::new()),
             Arc::new(ReplaceDistinctWithAggregate::new()),
             Arc::new(EliminateJoin::new()),
             Arc::new(DecorrelatePredicateSubquery::new()),
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index b850760b8734a..e0de8ce298057 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -187,6 +187,7 @@ analyzed_logical_plan SAME TEXT AS ABOVE
 logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
 logical_plan after simplify_expressions SAME TEXT AS ABOVE
 logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
+logical_plan after infer_non_null SAME TEXT AS ABOVE
 logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
 logical_plan after eliminate_join SAME TEXT AS ABOVE
 logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE
@@ -214,6 +215,7 @@ logical_plan after optimize_projections TableScan: simple_explain_test projectio
 logical_plan after eliminate_nested_union SAME TEXT AS ABOVE
 logical_plan after simplify_expressions SAME TEXT AS ABOVE
 logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE
+logical_plan after infer_non_null SAME TEXT AS ABOVE
 logical_plan after replace_distinct_aggregate SAME TEXT AS ABOVE
 logical_plan after eliminate_join SAME TEXT AS ABOVE
 logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE