From 43e8e92110313a5e93b6ad0da8ed29410ecdeb20 Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Fri, 3 May 2024 12:32:21 +0200 Subject: [PATCH] Optimized push down filter #10291 --- datafusion/optimizer/src/push_down_filter.rs | 144 +++++++++++-------- 1 file changed, 82 insertions(+), 62 deletions(-) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 8462cf86f154e..6c2cba20301dd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -17,8 +17,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; +use itertools::Itertools; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, @@ -29,6 +28,7 @@ use datafusion_common::{ }; use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::replace_col; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, }; @@ -38,7 +38,8 @@ use datafusion_expr::{ ScalarFunctionDefinition, TableProviderFilterPushDown, }; -use itertools::Itertools; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. 
@@ -407,7 +408,7 @@ fn push_down_all_join( right: &LogicalPlan, on_filter: Vec, is_inner_join: bool, -) -> Result { +) -> Result> { let on_filter_empty = on_filter.is_empty(); // Get pushable predicates from current optimizer state let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?; @@ -502,44 +503,45 @@ fn push_down_all_join( exprs.extend(join_conditions.into_iter().reduce(Expr::and)); let plan = join_plan.with_new_exprs(exprs, vec![left, right])?; - // wrap the join on the filter whose predicates must be kept match conjunction(keep_predicates) { Some(predicate) => { - Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter) + let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?; + Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan))) } - None => Ok(plan), + None => Ok(Transformed::no(plan)), } } fn push_down_join( - plan: &LogicalPlan, + plan: LogicalPlan, join: &Join, parent_predicate: Option<&Expr>, -) -> Result> { - let predicates = match parent_predicate { - Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()), - None => vec![], - }; +) -> Result> { + // Split the parent predicate into individual conjunctive parts. + let predicates = parent_predicate + .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone())); - // Convert JOIN ON predicate to Predicates + // Extract conjunctions from the JOIN's ON filter, if present. + let on_filters = join .filter .as_ref() - .map(|e| split_conjunction_owned(e.clone())) - .unwrap_or_default(); + .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone())); let mut is_inner_join = false; let infer_predicates = if join.join_type == JoinType::Inner { is_inner_join = true; + // Only infer predicates when both join keys are plain columns.
let join_col_keys = join .on .iter() - .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { - (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), - _ => None, + .filter_map(|(l, r)| { + let left_col = l.try_into_col().ok()?; + let right_col = r.try_into_col().ok()?; + Some((left_col, right_col)) }) .collect::>(); + // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down // For inner joins, duplicate filters for joined columns so filters can be pushed down // to both sides. Take the following query as an example: @@ -559,6 +561,7 @@ fn push_down_join( .chain(on_filters.iter()) .filter_map(|predicate| { let mut join_cols_to_replace = HashMap::new(); + let columns = match predicate.to_columns() { Ok(columns) => columns, Err(e) => return Some(Err(e)), @@ -596,20 +599,32 @@ fn push_down_join( }; if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() { - return Ok(None); + return Ok(Transformed::no(plan.clone())); } - Ok(Some(push_down_all_join( + + match push_down_all_join( predicates, infer_predicates, - plan, + &plan, &join.left, &join.right, on_filters, is_inner_join, - )?)) + ) { + Ok(plan) => Ok(Transformed::yes(plan.data)), + Err(e) => Err(e), + } } impl OptimizerRule for PushDownFilter { + fn try_optimize( + &self, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("Should have called PushDownFilter::rewrite") + } + fn name(&self) -> &str { "push_down_filter" } @@ -618,21 +633,26 @@ impl OptimizerRule for PushDownFilter { Some(ApplyOrder::TopDown) } - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { let filter = match plan { - LogicalPlan::Filter(filter) => filter, - // we also need to pushdown filter in Join. 
- LogicalPlan::Join(join) => return push_down_join(plan, join, None), - _ => return Ok(None), + LogicalPlan::Filter(ref filter) => filter, + LogicalPlan::Join(ref join) => { + return push_down_join(plan.clone(), join, None) + } + _ => return Ok(Transformed::no(plan)), }; - let child_plan = filter.input.as_ref(); + let child_plan = unwrap_arc(filter.clone().input); let new_plan = match child_plan { - LogicalPlan::Filter(child_filter) => { + LogicalPlan::Filter(ref child_filter) => { let parents_predicates = split_conjunction(&filter.predicate); let set: HashSet<&&Expr> = parents_predicates.iter().collect(); @@ -652,20 +672,18 @@ impl OptimizerRule for PushDownFilter { new_predicate, child_filter.input.clone(), )?); - self.try_optimize(&new_filter, _config)? - .unwrap_or(new_filter) + self.rewrite(new_filter, _config)?.data } LogicalPlan::Repartition(_) | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) => { - // commutable let new_filter = plan.with_new_exprs( plan.expressions(), vec![child_plan.inputs()[0].clone()], )?; child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? } - LogicalPlan::SubqueryAlias(subquery_alias) => { + LogicalPlan::SubqueryAlias(ref subquery_alias) => { let mut replace_map = HashMap::new(); for (i, (qualifier, field)) in subquery_alias.input.schema().iter().enumerate() @@ -685,7 +703,7 @@ impl OptimizerRule for PushDownFilter { )?); child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? } - LogicalPlan::Projection(projection) => { + LogicalPlan::Projection(ref projection) => { // A projection is filter-commutable if it do not contain volatile predicates or contain volatile // predicates that are not used in the filter. However, we should re-writes all predicate expressions. // collect projection. 
@@ -742,10 +760,10 @@ impl OptimizerRule for PushDownFilter { } } } - None => return Ok(None), + None => return Ok(Transformed::no(plan)), } } - LogicalPlan::Union(union) => { + LogicalPlan::Union(ref union) => { let mut inputs = Vec::with_capacity(union.inputs.len()); for input in &union.inputs { let mut replace_map = HashMap::new(); @@ -770,7 +788,7 @@ impl OptimizerRule for PushDownFilter { schema: plan.schema().clone(), }) } - LogicalPlan::Aggregate(agg) => { + LogicalPlan::Aggregate(ref agg) => { // We can push down Predicate which in groupby_expr. let group_expr_columns = agg .group_expr @@ -821,13 +839,11 @@ impl OptimizerRule for PushDownFilter { None => new_agg, } } - LogicalPlan::Join(join) => { - match push_down_join(&filter.input, join, Some(&filter.predicate))? { - Some(optimized_plan) => optimized_plan, - None => return Ok(None), - } + LogicalPlan::Join(ref join) => { + let unwrapped_plan = unwrap_arc(filter.clone().input); + push_down_join(unwrapped_plan, join, Some(&filter.predicate))?.data } - LogicalPlan::CrossJoin(cross_join) => { + LogicalPlan::CrossJoin(ref cross_join) => { let predicates = split_conjunction_owned(filter.predicate.clone()); let join = convert_cross_join_to_inner_join(cross_join.clone())?; let join_plan = LogicalPlan::Join(join); @@ -843,9 +859,9 @@ impl OptimizerRule for PushDownFilter { vec![], true, )?; - convert_to_cross_join_if_beneficial(plan)? + convert_to_cross_join_if_beneficial(plan.data)? 
} - LogicalPlan::TableScan(scan) => { + LogicalPlan::TableScan(ref scan) => { let filter_predicates = split_conjunction(&filter.predicate); let results = scan .source @@ -892,7 +908,7 @@ impl OptimizerRule for PushDownFilter { None => new_scan, } } - LogicalPlan::Extension(extension_plan) => { + LogicalPlan::Extension(ref extension_plan) => { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); @@ -935,9 +951,10 @@ impl OptimizerRule for PushDownFilter { None => new_extension, } } - _ => return Ok(None), + _ => return Ok(Transformed::no(plan)), }; - Ok(Some(new_plan)) + + Ok(Transformed::yes(new_plan)) } } @@ -1024,16 +1041,12 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { #[cfg(test)] mod tests { - use super::*; use std::any::Any; use std::fmt::{Debug, Formatter}; - use crate::optimizer::Optimizer; - use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; - use crate::test::*; - use crate::OptimizerContext; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use async_trait::async_trait; + use datafusion_common::ScalarValue; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::table_scan; @@ -1043,7 +1056,13 @@ mod tests { Volatility, }; - use async_trait::async_trait; + use crate::optimizer::Optimizer; + use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; + use crate::test::*; + use crate::OptimizerContext; + + use super::*; + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -2298,9 +2317,9 @@ mod tests { table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; let optimized_plan = PushDownFilter::new() - .try_optimize(&plan, &OptimizerContext::new()) + .rewrite(plan, &OptimizerContext::new()) .expect("failed to optimize plan") - .unwrap(); + .data; let expected = "\ Filter: a = Int64(1)\ @@ -2667,8 +2686,9 @@ Projection: a, b // Originally global 
state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. let optimized_plan = PushDownFilter::new() - .try_optimize(&plan, &OptimizerContext::new())? - .expect("failed to optimize plan"); + .rewrite(plan, &OptimizerContext::new()) + .expect("failed to optimize plan") + .data; assert_optimized_plan_eq(optimized_plan, expected) }