diff --git a/sources/sql/src/rewrite/plan.rs b/sources/sql/src/rewrite/plan.rs index 0b945b0..4e18644 100644 --- a/sources/sql/src/rewrite/plan.rs +++ b/sources/sql/src/rewrite/plan.rs @@ -322,7 +322,7 @@ fn rewrite_column_name_in_expr( // Table name same as column name // Shouldn't rewrite in this case - if idx == 0 && start_pos == 0 { + if idx == 0 && table_ref_str.len() == col_name.len() { return None; } @@ -1367,14 +1367,21 @@ mod tests { } #[tokio::test] - async fn test_rewrite_same_column_table_name() -> Result<()> { + async fn test_rewrite_column_name_in_expr() -> Result<()> { init_tracing(); let ctx = get_test_df_context(); - let tests = vec![( - "SELECT app_table FROM (SELECT a app_table from app_table limit 100);", - r#"SELECT app_table FROM (SELECT remote_table.a AS app_table FROM remote_table LIMIT 100)"#, - )]; + let tests = vec![ + ( + // Column alias name same as table name + "SELECT app_table FROM (SELECT a app_table from app_table limit 100);", + r#"SELECT app_table FROM (SELECT remote_table.a AS app_table FROM remote_table LIMIT 100)"#, + ), + ( + "SELECT a - 1, COUNT(*) AS c FROM app_table GROUP BY a - 1;", + r#"SELECT (remote_table.a - 1), count(*) AS c FROM remote_table GROUP BY (remote_table.a - 1)"#, + ), + ]; for test in tests { test_sql(&ctx, test.0, test.1, false).await?; @@ -1391,8 +1398,6 @@ mod tests { ) -> Result<(), datafusion::error::DataFusionError> { let data_frame = ctx.sql(sql_query).await?; - // println!("before optimization: \n{:#?}", data_frame.logical_plan()); - let mut known_rewrites = HashMap::new(); let rewritten_plan = rewrite_table_scans( data_frame.logical_plan(), @@ -1401,12 +1406,8 @@ mod tests { &mut None, )?; - // println!("rewritten_plan: \n{:#?}", rewritten_plan); - let unparsed_sql = plan_to_sql(&rewritten_plan)?; - println!("unparsed_sql: \n{unparsed_sql}"); - assert_eq!( format!("{unparsed_sql}"), expected_sql,