From d8697414cd59b4d6fce64a9899d7bcc6b26f27c9 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Mon, 13 Jan 2025 16:38:20 +0900 Subject: [PATCH] feat: fix OuterReferenceColumns not being rewritten correctly (take 2) (#39) * feat: fix OuterReferenceColumns not being rewritten correctly (#38) * Fix OuterReferenceColumns not being rewritten correctly * wip * Handle collecting rewrites from expressions * Add more unit tests for collect rewrite * Fix comments --- sources/sql/src/rewrite/plan.rs | 401 +++++++++++++++++++++++++++++++- 1 file changed, 391 insertions(+), 10 deletions(-) diff --git a/sources/sql/src/rewrite/plan.rs b/sources/sql/src/rewrite/plan.rs index d0f9c4f..5c799a9 100644 --- a/sources/sql/src/rewrite/plan.rs +++ b/sources/sql/src/rewrite/plan.rs @@ -21,12 +21,192 @@ use datafusion_federation::{get_table_source, table_reference::MultiPartTableRef use crate::SQLTableSource; +fn collect_known_rewrites_from_plan( + plan: &LogicalPlan, + known_rewrites: &mut HashMap, +) -> Result<()> { + if let LogicalPlan::TableScan(table_scan) = plan { + let original_table_name = table_scan.table_name.clone(); + + if let Some(federated_source) = get_table_source(&table_scan.source)? { + if let Some(sql_table_source) = + federated_source.as_any().downcast_ref::() + { + let remote_table_name = sql_table_source.table_name(); + known_rewrites.insert(original_table_name, remote_table_name.clone()); + } + } + } + + // Recursively collect from all inputs + for input in plan.inputs() { + collect_known_rewrites_from_plan(input, known_rewrites)?; + } + + for expr in plan.expressions() { + collect_known_rewrites_from_expr(expr, known_rewrites)?; + } + + Ok(()) +} + +fn collect_known_rewrites_from_expr( + expr: Expr, + known_rewrites: &mut HashMap, +) -> Result<()> { + match expr { + Expr::Column(_) => Ok(()), // Column references don't have any table scans + Expr::ScalarSubquery(subquery) => { + collect_known_rewrites_from_plan(&subquery.subquery, known_rewrites) + } + Expr::BinaryExpr(binary_expr) => { + collect_known_rewrites_from_expr(*binary_expr.left, known_rewrites)?; + collect_known_rewrites_from_expr(*binary_expr.right, known_rewrites)?; + Ok(()) + } + Expr::Alias(alias) => collect_known_rewrites_from_expr(*alias.expr, known_rewrites), + Expr::Like(like) => { + collect_known_rewrites_from_expr(*like.expr, known_rewrites)?; + collect_known_rewrites_from_expr(*like.pattern, known_rewrites)?; + Ok(()) + } + Expr::SimilarTo(similar_to) => { + collect_known_rewrites_from_expr(*similar_to.expr, known_rewrites)?; + collect_known_rewrites_from_expr(*similar_to.pattern, known_rewrites)?; + Ok(()) + } + Expr::Not(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::IsNotNull(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::IsNull(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::IsTrue(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::IsFalse(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::IsUnknown(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::IsNotTrue(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::IsNotFalse(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::IsNotUnknown(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::Negative(e) => collect_known_rewrites_from_expr(*e, known_rewrites), + Expr::Between(between) => { + collect_known_rewrites_from_expr(*between.expr, known_rewrites)?; + collect_known_rewrites_from_expr(*between.low, known_rewrites)?; + collect_known_rewrites_from_expr(*between.high, known_rewrites)?; + Ok(()) + } + Expr::Case(case) => { + if let Some(expr) = case.expr { + collect_known_rewrites_from_expr(*expr, known_rewrites)?; + } + if let Some(else_expr) = case.else_expr { + collect_known_rewrites_from_expr(*else_expr, known_rewrites)?; + } + for (when, then) in case.when_then_expr { + collect_known_rewrites_from_expr(*when, known_rewrites)?; + collect_known_rewrites_from_expr(*then, known_rewrites)?; + } + Ok(()) + } + Expr::Cast(cast) => collect_known_rewrites_from_expr(*cast.expr, known_rewrites), + Expr::TryCast(try_cast) => collect_known_rewrites_from_expr(*try_cast.expr, known_rewrites), + Expr::ScalarFunction(sf) => { + for arg in sf.args { + collect_known_rewrites_from_expr(arg, known_rewrites)?; + } + Ok(()) + } + Expr::AggregateFunction(af) => { + for arg in af.args { + collect_known_rewrites_from_expr(arg, known_rewrites)?; + } + if let Some(filter) = af.filter { + collect_known_rewrites_from_expr(*filter, known_rewrites)?; + } + if let Some(order_by) = af.order_by { + for sort in order_by { + collect_known_rewrites_from_expr(sort.expr, known_rewrites)?; + } + } + Ok(()) + } + Expr::WindowFunction(wf) => { + for arg in wf.args { + collect_known_rewrites_from_expr(arg, known_rewrites)?; + } + for expr in wf.partition_by { + collect_known_rewrites_from_expr(expr, known_rewrites)?; + } + for sort in wf.order_by { + collect_known_rewrites_from_expr(sort.expr, known_rewrites)?; + } + Ok(()) + } + Expr::InList(il) => { + collect_known_rewrites_from_expr(*il.expr, known_rewrites)?; + for expr in il.list { + collect_known_rewrites_from_expr(expr, known_rewrites)?; + } + Ok(()) + } + Expr::Exists(exists) => { + collect_known_rewrites_from_plan(&exists.subquery.subquery, known_rewrites)?; + for expr in exists.subquery.outer_ref_columns { + collect_known_rewrites_from_expr(expr, known_rewrites)?; + } + Ok(()) + } + Expr::InSubquery(is) => { + collect_known_rewrites_from_expr(*is.expr, known_rewrites)?; + collect_known_rewrites_from_plan(&is.subquery.subquery, known_rewrites)?; + for expr in is.subquery.outer_ref_columns { + collect_known_rewrites_from_expr(expr, known_rewrites)?; + } + Ok(()) + } + Expr::Wildcard { .. } => Ok(()), // Wildcard expressions don't have any table scans + Expr::GroupingSet(gs) => match gs { + GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => { + for expr in exprs { + collect_known_rewrites_from_expr(expr, known_rewrites)?; + } + Ok(()) + } + GroupingSet::GroupingSets(vec_exprs) => { + for exprs in vec_exprs { + for expr in exprs { + collect_known_rewrites_from_expr(expr, known_rewrites)?; + } + } + Ok(()) + } + }, + Expr::OuterReferenceColumn(_, _) => Ok(()), // Outer reference columns don't have any table scans + Expr::Unnest(unnest) => collect_known_rewrites_from_expr(*unnest.expr, known_rewrites), + Expr::ScalarVariable(_, _) | Expr::Literal(_) | Expr::Placeholder(_) => Ok(()), + } +} + /// Rewrite table scans to use the original federated table name. pub(crate) fn rewrite_table_scans( plan: &LogicalPlan, known_rewrites: &mut HashMap, subquery_uses_partial_path: bool, subquery_table_scans: &mut Option>, +) -> Result { + // First pass: collect all known rewrites + collect_known_rewrites_from_plan(plan, known_rewrites)?; + + // Second pass: do the actual rewriting with complete known_rewrites + rewrite_plan_with_known_rewrites( + plan, + known_rewrites, + subquery_uses_partial_path, + subquery_table_scans, + ) +} + +fn rewrite_plan_with_known_rewrites( + plan: &LogicalPlan, + known_rewrites: &HashMap, + subquery_uses_partial_path: bool, + subquery_table_scans: &mut Option>, ) -> Result { if plan.inputs().is_empty() { if let LogicalPlan::TableScan(table_scan) = plan { @@ -41,7 +221,6 @@ pub(crate) fn rewrite_table_scans( match federated_source.as_any().downcast_ref::() { Some(sql_table_source) => { let remote_table_name = sql_table_source.table_name(); - known_rewrites.insert(original_table_name.clone(), remote_table_name.clone()); // If the remote table name is a MultiPartTableReference, we will not rewrite it here, but rewrite it after the final unparsing on the AST directly. let MultiPartTableReference::TableReference(remote_table_name) = @@ -91,7 +270,7 @@ pub(crate) fn rewrite_table_scans( .inputs() .into_iter() .map(|i| { - rewrite_table_scans( + rewrite_plan_with_known_rewrites( i, known_rewrites, subquery_uses_partial_path, @@ -172,7 +351,7 @@ pub(crate) fn rewrite_table_scans( fn rewrite_unnest_plan( unnest: &logical_expr::Unnest, mut rewritten_inputs: Vec, - known_rewrites: &mut HashMap, + known_rewrites: &HashMap, subquery_uses_partial_path: bool, subquery_table_scans: &mut Option>, ) -> Result { @@ -391,14 +570,14 @@ fn rewrite_column_name_in_expr( fn rewrite_table_scans_in_expr( expr: Expr, - known_rewrites: &mut HashMap, + known_rewrites: &HashMap, subquery_uses_partial_path: bool, subquery_table_scans: &mut Option>, ) -> Result { match expr { Expr::ScalarSubquery(subquery) => { let new_subquery = if subquery_table_scans.is_some() || !subquery_uses_partial_path { - rewrite_table_scans( + rewrite_plan_with_known_rewrites( &subquery.subquery, known_rewrites, subquery_uses_partial_path, @@ -406,7 +585,7 @@ fn rewrite_table_scans_in_expr( )? } else { let mut scans = Some(HashSet::new()); - rewrite_table_scans( + rewrite_plan_with_known_rewrites( &subquery.subquery, known_rewrites, subquery_uses_partial_path, @@ -888,7 +1067,7 @@ fn rewrite_table_scans_in_expr( } Expr::Exists(exists) => { let subquery_plan = if subquery_table_scans.is_some() || !subquery_uses_partial_path { - rewrite_table_scans( + rewrite_plan_with_known_rewrites( &exists.subquery.subquery, known_rewrites, subquery_uses_partial_path, @@ -896,7 +1075,7 @@ fn rewrite_table_scans_in_expr( )? } else { let mut scans = Some(HashSet::new()); - rewrite_table_scans( + rewrite_plan_with_known_rewrites( &exists.subquery.subquery, known_rewrites, subquery_uses_partial_path, @@ -930,7 +1109,7 @@ fn rewrite_table_scans_in_expr( subquery_table_scans, )?; let subquery_plan = if subquery_table_scans.is_some() || !subquery_uses_partial_path { - rewrite_table_scans( + rewrite_plan_with_known_rewrites( &is.subquery.subquery, known_rewrites, subquery_uses_partial_path, @@ -938,7 +1117,7 @@ fn rewrite_table_scans_in_expr( )? } else { let mut scans = Some(HashSet::new()); - rewrite_table_scans( + rewrite_plan_with_known_rewrites( &is.subquery.subquery, known_rewrites, subquery_uses_partial_path, @@ -1380,6 +1559,22 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_rewrite_outer_ref_columns() -> Result<()> { + init_tracing(); + let ctx = get_test_df_context(); + let tests = vec![( + "SELECT foo.df_table.a FROM bar JOIN foo.df_table ON foo.df_table.a = (SELECT bar.a FROM bar WHERE bar.a > foo.df_table.a)", + r#"SELECT remote_table.a FROM remote_db.remote_schema.remote_table JOIN remote_table ON (remote_table.a = (SELECT a FROM remote_db.remote_schema.remote_table WHERE (remote_table.a > remote_table.a)))"#, + true, + )]; + for test in tests { + test_sql(&ctx, test.0, test.1, test.2).await?; + } + + Ok(()) + } + #[tokio::test] async fn test_rewrite_column_name_in_expr() -> Result<()> { init_tracing(); @@ -1479,3 +1674,189 @@ mod tests { Ok(()) } } + +#[cfg(test)] +mod collect_rewrites_tests { + use crate::{SQLExecutor, SQLFederationProvider}; + + use super::*; + use async_trait::async_trait; + use datafusion::{ + arrow::datatypes::{DataType, Field, Schema, SchemaRef}, + common::DFSchema, + datasource::DefaultTableSource, + execution::SendableRecordBatchStream, + sql::unparser::dialect::{DefaultDialect, Dialect}, + }; + use datafusion_federation::FederatedTableProviderAdaptor; + + struct TestSQLExecutor {} + + #[async_trait] + impl SQLExecutor for TestSQLExecutor { + fn name(&self) -> &str { + "test_sql_table_source" + } + + fn compute_context(&self) -> Option { + None + } + + fn dialect(&self) -> Arc { + Arc::new(DefaultDialect {}) + } + + fn execute(&self, _query: &str, _schema: SchemaRef) -> Result { + Err(DataFusionError::NotImplemented( + "execute not implemented".to_string(), + )) + } + + async fn table_names(&self) -> Result> { + Err(DataFusionError::NotImplemented( + "table inference not implemented".to_string(), + )) + } + + async fn get_table_schema(&self, _table_name: &str) -> Result { + Err(DataFusionError::NotImplemented( + "table inference not implemented".to_string(), + )) + } + } + + fn create_test_table_scan() -> LogicalPlan { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, false), + ])); + + let sql_federation_provider = + Arc::new(SQLFederationProvider::new(Arc::new(TestSQLExecutor {}))); + let table_source = Arc::new( + SQLTableSource::new_with_schema( + sql_federation_provider, + "remote_table".to_string(), + schema.clone(), + ) + .expect("to have a valid SQLTableSource"), + ); + let source = Arc::new(DefaultTableSource::new(Arc::new( + FederatedTableProviderAdaptor::new(table_source), + ))); + + let df_schema = + DFSchema::try_from(schema.as_ref().clone()).expect("to have a valid DFSchema"); + + LogicalPlan::TableScan(logical_expr::TableScan { + table_name: TableReference::from("foo.df_table"), + source, + projection: None, + projected_schema: df_schema.into(), + filters: vec![], + fetch: None, + }) + } + + #[test] + fn test_collect_from_table_scan() -> Result<()> { + let plan = create_test_table_scan(); + let mut known_rewrites = HashMap::new(); + + collect_known_rewrites_from_plan(&plan, &mut known_rewrites)?; + + assert_eq!(known_rewrites.len(), 1); + assert_eq!( + known_rewrites.get(&TableReference::from("foo.df_table")), + Some(&MultiPartTableReference::TableReference( + TableReference::from("remote_table") + )) + ); + Ok(()) + } + + #[test] + fn test_collect_from_scalar_subquery() -> Result<()> { + let table_scan = create_test_table_scan(); + let subquery = Expr::ScalarSubquery(Subquery { + subquery: Arc::new(table_scan), + outer_ref_columns: vec![], + }); + + let mut known_rewrites = HashMap::new(); + collect_known_rewrites_from_expr(subquery, &mut known_rewrites)?; + + assert_eq!(known_rewrites.len(), 1); + assert_eq!( + known_rewrites.get(&TableReference::from("foo.df_table")), + Some(&MultiPartTableReference::TableReference( + TableReference::from("remote_table") + )) + ); + Ok(()) + } + + #[test] + fn test_collect_from_binary_expr() -> Result<()> { + let left = Expr::Column(Column::from_qualified_name("foo.df_table.a")); + let right = Expr::Column(Column::from_qualified_name("foo.df_table.b")); + let binary = Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + datafusion::logical_expr::Operator::Eq, + Box::new(right), + )); + + let mut known_rewrites = HashMap::new(); + collect_known_rewrites_from_expr(binary, &mut known_rewrites)?; + + // Column expressions don't generate rewrites on their own + assert_eq!(known_rewrites.len(), 0); + Ok(()) + } + + #[test] + fn test_collect_from_case_expression() -> Result<()> { + let col = Expr::Column(Column::from_qualified_name("foo.df_table.a")); + let case = Expr::Case(Case::new( + Some(Box::new(col.clone())), + vec![( + Box::new(Expr::Literal(datafusion::scalar::ScalarValue::Int64(Some( + 1, + )))), + Box::new(col.clone()), + )], + Some(Box::new(col)), + )); + + let mut known_rewrites = HashMap::new(); + collect_known_rewrites_from_expr(case, &mut known_rewrites)?; + + // Column expressions don't generate rewrites on their own + assert_eq!(known_rewrites.len(), 0); + Ok(()) + } + + #[test] + fn test_collect_from_exists_subquery() -> Result<()> { + let table_scan = create_test_table_scan(); + let exists = Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(table_scan), + outer_ref_columns: vec![], + }, + false, + )); + + let mut known_rewrites = HashMap::new(); + collect_known_rewrites_from_expr(exists, &mut known_rewrites)?; + + assert_eq!(known_rewrites.len(), 1); + assert_eq!( + known_rewrites.get(&TableReference::from("foo.df_table")), + Some(&MultiPartTableReference::TableReference( + TableReference::from("remote_table") + )) + ); + Ok(()) + } +}