diff --git a/sources/sql/src/lib.rs b/sources/sql/src/lib.rs index f668774..654d70b 100644 --- a/sources/sql/src/lib.rs +++ b/sources/sql/src/lib.rs @@ -22,7 +22,7 @@ use datafusion::{ Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, LogicalPlan, LogicalPlanBuilder, Projection, Subquery, TryCast, }, - optimizer::analyzer::{subquery, Analyzer, AnalyzerRule}, + optimizer::analyzer::{Analyzer, AnalyzerRule}, physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, @@ -265,7 +265,8 @@ fn rewrite_unnest_plan( let updated_unnest_inner_projection = Projection::try_new(new_expressions, Arc::clone(&projection.input))?; - let unnest_options = rewrite_unnest_options(&unnest.options, known_rewrites); + let unnest_options = + rewrite_unnest_options(&unnest.options, known_rewrites, subquery_table_scans); // reconstruct the unnest plan with updated projection and rewritten column names let new_plan = @@ -281,17 +282,22 @@ fn rewrite_unnest_plan( fn rewrite_unnest_options( options: &UnnestOptions, known_rewrites: &HashMap, + subquery_table_scans: &mut Option>, ) -> UnnestOptions { let mut new_options = options.clone(); new_options .recursions .iter_mut() .for_each(|x: &mut RecursionUnnestOption| { - if let Some(new_name) = rewrite_column_name(&x.input_column.name, known_rewrites) { + if let Some(new_name) = + rewrite_column_name(&x.input_column.name, known_rewrites, subquery_table_scans) + { x.input_column.name = new_name; } - if let Some(new_name) = rewrite_column_name(&x.output_column.name, known_rewrites) { + if let Some(new_name) = + rewrite_column_name(&x.output_column.name, known_rewrites, subquery_table_scans) + { x.output_column.name = new_name; } }); @@ -304,17 +310,25 @@ fn rewrite_unnest_options( fn rewrite_column_name( col_name: &str, known_rewrites: &HashMap, + subquery_table_scans: &mut Option>, ) -> Option { let (new_col_name, was_rewritten) = known_rewrites.iter().fold( (col_name.to_string(), false), - |(col_name, was_rewritten), (table_ref, rewrite)| match rewrite_column_name_in_expr( - &col_name, - &table_ref.to_string(), - &rewrite.to_string(), - 0, - ) { - Some(new_name) => (new_name, true), - None => (col_name, was_rewritten), + |(col_name, was_rewritten), (table_ref, rewrite)| { + if let Some(subquery_reference) = subquery_table_scans { + if subquery_reference.get(table_ref).is_some() { + return (col_name, was_rewritten); + } + } + match rewrite_column_name_in_expr( + &col_name, + &table_ref.to_string(), + &rewrite.to_string(), + 0, + ) { + Some(new_name) => (new_name, true), + None => (col_name, was_rewritten), + } }, ); @@ -484,7 +498,9 @@ fn rewrite_table_scans_in_expr( // Check if any of the rewrites match any substring in col.name, and replace that part of the string if so. // This will handles cases like "MAX(foo.df_table.a)" -> "MAX(remote_table.a)" - if let Some(new_name) = rewrite_column_name(&col.name, known_rewrites) { + if let Some(new_name) = + rewrite_column_name(&col.name, known_rewrites, subquery_table_scans) + { Ok(Expr::Column(Column::new(col.relation.take(), new_name))) } else { Ok(Expr::Column(col))