From 0dc5ad074ee8c2d0c7b466ebfbf64a11fc6b62fc Mon Sep 17 00:00:00 2001 From: sgrebnov Date: Wed, 8 Jan 2025 23:55:03 +0300 Subject: [PATCH] Correlated subquery support in Join filter --- .../src/decorrelate_predicate_subquery.rs | 250 +++++++++++++++--- 1 file changed, 210 insertions(+), 40 deletions(-) diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 7fdad5ba4b6e..697f0354e352 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -34,7 +34,7 @@ use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, - LogicalPlan, LogicalPlanBuilder, Operator, + Join, LogicalPlan, LogicalPlanBuilder, Operator, }; use log::debug; @@ -66,54 +66,185 @@ impl OptimizerRule for DecorrelatePredicateSubquery { })? 
.data; - let LogicalPlan::Filter(filter) = plan else { - return Ok(Transformed::no(plan)); - }; + match plan { + LogicalPlan::Filter(filter) => { + if !has_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); + + if with_subqueries.is_empty() { + return internal_err!( + "can not find expected subqueries in DecorrelatePredicateSubquery" + ); + } - if !has_subquery(&filter.predicate) { - return Ok(Transformed::no(LogicalPlan::Filter(filter))); - } + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top( + &subquery, + &cur_input, + config.alias_generator(), + )? 
{ + Some(plan) => cur_input = plan, + // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter + None => other_exprs.push(subquery.expr()), + } + } + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); + } + } + } + + let expr = conjunction(other_exprs); + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + cur_input = LogicalPlan::Filter(new_filter); + } + Ok(Transformed::yes(cur_input)) + } - let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = - split_conjunction_owned(filter.predicate) - .into_iter() - .partition(has_subquery); + LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }) => { + let Some(filter) = filter else { + return Ok(Transformed::no(LogicalPlan::Join(Join { + left, + right, + on, + filter, + join_type, + join_constraint, + schema, + null_equals_null, + }))); + }; + + if !has_subquery(&filter) { + return Ok(Transformed::no(LogicalPlan::Join(Join { + left, + right, + on, + filter: Some(filter), + join_type, + join_constraint, + schema, + null_equals_null, + }))); + } - if with_subqueries.is_empty() { - return internal_err!( - "can not find expected subqueries in DecorrelatePredicateSubquery" - ); - } + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter) + .into_iter() + .partition(has_subquery); - // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = Arc::unwrap_or_clone(filter.input); - for subquery_expr in with_subqueries { - match extract_subquery_info(subquery_expr) { - // The subquery expression is at the top level of the filter - SubqueryPredicate::Top(subquery) => { - match 
build_join_top(&subquery, &cur_input, config.alias_generator())? - { - Some(plan) => cur_input = plan, - // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - None => other_exprs.push(subquery.expr()), - } + if with_subqueries.is_empty() { + return internal_err!( + "can not find expected subqueries in DecorrelatePredicateSubquery" + ); } - // The subquery expression is embedded within another expression - SubqueryPredicate::Embedded(expr) => { - let (plan, expr_without_subqueries) = - rewrite_inner_subqueries(cur_input, expr, config)?; - cur_input = plan; - other_exprs.push(expr_without_subqueries); + + let mut left: LogicalPlan = Arc::unwrap_or_clone(left); + let mut right: LogicalPlan = Arc::unwrap_or_clone(right); + + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + SubqueryPredicate::Top(subquery) => { + let outer_ref_columns = &subquery.query.outer_ref_columns; + match ( + has_reference_to_plan(&left, outer_ref_columns), + has_reference_to_plan(&right, outer_ref_columns), + ) { + (true, false) => { + match build_join_top( + &subquery, + &left, + config.alias_generator(), + )? { + Some(plan) => left = plan, + None => other_exprs.push(subquery.expr()), + } + } + (false, true) => { + match build_join_top( + &subquery, + &right, + config.alias_generator(), + )? 
{ + Some(plan) => right = plan, + None => other_exprs.push(subquery.expr()), + } + } + _ => { + return internal_err!( + "Unsupported subquery expressions as part of Join filter expressions" + ); + } + }; + } + SubqueryPredicate::Embedded(expr) => { + match ( + // TODO - get rid of clone + has_reference_to_plan(&left, &vec![expr.clone()]), + has_reference_to_plan(&right, &vec![expr.clone()]), + ) { + (true, false) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(left, expr, config)?; + left = plan; + other_exprs.push(expr_without_subqueries); + } + (false, true) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(right, expr, config)?; + right = plan; + other_exprs.push(expr_without_subqueries); + } + _ => { + return internal_err!( + "Unsupported subquery expressions as part of Join filter expressions" + ); + } + }; + } + } } + + let expr = conjunction(other_exprs); + + return Ok(Transformed::yes(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(right), + on, + filter: expr, + join_type, + join_constraint, + schema, + null_equals_null, + }))); } - } - let expr = conjunction(other_exprs); - if let Some(expr) = expr { - let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; - cur_input = LogicalPlan::Filter(new_filter); + _ => return Ok(Transformed::no(plan)), } - Ok(Transformed::yes(cur_input)) } fn name(&self) -> &str { @@ -125,6 +256,45 @@ impl OptimizerRule for DecorrelatePredicateSubquery { } } +/// Checks if any of the `outer_ref_columns` reference the given `LogicalPlan`. +/// +/// # Returns +/// +/// `true` if any column in `outer_ref_columns` refers to the schema of the `LogicalPlan`, +/// `false` otherwise. 
+fn has_reference_to_plan(plan: &LogicalPlan, outer_ref_columns: &Vec<Expr>) -> bool { + let schema = plan.schema(); + for col_expr in outer_ref_columns { + if let Expr::OuterReferenceColumn(_, column) = col_expr { + if schema.has_column(column) { + return true; + } + } else { + // complex expr, find correlated subquery expressions recursively + if col_expr + .exists(|expr| match expr { + Expr::Exists(subquery) => { + return Ok(has_reference_to_plan( + plan, + &subquery.subquery.outer_ref_columns, + )) + } + Expr::InSubquery(subquery) => { + return Ok(has_reference_to_plan( + plan, + &subquery.subquery.outer_ref_columns, + )) + } + _ => Ok(false), + }) + .unwrap_or(false) + { + return true; + } + } + } + false +} fn rewrite_inner_subqueries( outer: LogicalPlan, expr: Expr,