From c510c7c023e73a2b1181a66d5db314342a1b9886 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 12 Apr 2024 15:28:28 -0400 Subject: [PATCH] Rewrite `CommonSubexprEliminate` to avoid copies using TreeNode --- .../optimizer/src/common_subexpr_eliminate.rs | 178 +++++++++--------- 1 file changed, 92 insertions(+), 86 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 2fabd5de9282..b5147725c7b3 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -22,7 +22,7 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; use crate::utils::is_volatile_expression; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{DataType, Field}; use datafusion_common::tree_node::{ @@ -192,9 +192,8 @@ impl CommonSubexprEliminate { let rewrite_exprs = self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?; - let mut new_input = self - .try_optimize(input, config)? - .unwrap_or_else(|| input.clone()); + let mut new_input = self.rewrite(input.clone(), config)?.data; + if !affected_id.is_empty() { new_input = build_common_expr_project_plan(new_input, affected_id, expr_set)?; } @@ -202,11 +201,11 @@ impl CommonSubexprEliminate { Ok((rewrite_exprs, new_input)) } - fn try_optimize_window( + fn optimize_window( &self, window: &Window, config: &dyn OptimizerConfig, - ) -> Result { + ) -> Result> { let mut window_exprs = vec![]; let mut expr_set = ExprSet::default(); @@ -261,14 +260,14 @@ impl CommonSubexprEliminate { plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?); } - Ok(plan) + Ok(Transformed::yes(plan)) } - fn try_optimize_aggregate( + fn optimize_aggregate( &self, aggregate: &Aggregate, config: &dyn OptimizerConfig, - ) -> Result { + ) -> Result> { let Aggregate { group_expr, aggr_expr, @@ -317,8 +316,10 @@ impl CommonSubexprEliminate { }) .collect::>>()?; // Since group_epxr changes, schema changes also. Use try_new method. - Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) - .map(LogicalPlan::Aggregate) + let new_plan = + Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) + .map(LogicalPlan::Aggregate)?; + Ok(Transformed::yes(new_plan)) } else { let mut agg_exprs = vec![]; @@ -364,18 +365,17 @@ impl CommonSubexprEliminate { agg_exprs, )?); - Ok(LogicalPlan::Projection(Projection::try_new( - proj_exprs, - Arc::new(agg), - )?)) + let new_plan = + LogicalPlan::Projection(Projection::try_new(proj_exprs, Arc::new(agg))?); + Ok(Transformed::yes(new_plan)) } } - fn try_unary_plan( + fn optimize_unary_plan( &self, plan: &LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result { + ) -> Result> { let expr = plan.expressions(); let inputs = plan.inputs(); let input = inputs[0]; @@ -388,64 +388,73 @@ impl CommonSubexprEliminate { let (mut new_expr, new_input) = self.rewrite_expr(&[&expr], input, &expr_set, config)?; - plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) + let new_plan = plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])?; + Ok(Transformed::yes(new_plan)) } } impl OptimizerRule for CommonSubexprEliminate { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { - let optimized_plan = match plan { - LogicalPlan::Projection(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, config)?), - LogicalPlan::Window(window) => { - Some(self.try_optimize_window(window, config)?) - } - LogicalPlan::Aggregate(aggregate) => { - Some(self.try_optimize_aggregate(aggregate, config)?) - } - LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Union(_) - | LogicalPlan::TableScan(_) - | LogicalPlan::Values(_) - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::Statement(_) - | LogicalPlan::DescribeTable(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Extension(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Prepare(_) => { - // apply the optimization to all inputs of the plan - utils::optimize_children(self, plan, config)? - } - }; + internal_err!("Should call CommonSubexprEliminate::rewrite instead") + } - let original_schema = plan.schema().clone(); - match optimized_plan { - Some(optimized_plan) if optimized_plan.schema() != &original_schema => { - // add an additional projection if the output schema changed. - Ok(Some(build_recover_project_plan( - &original_schema, - optimized_plan, - )?)) - } - plan => Ok(plan), - } + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result, DataFusionError> { + // Note it needs to be down as this rule introduces new plans + plan.transform_down_with_subqueries(&|plan| { + let original_schema = plan.schema().clone(); + + let transformed_plan = match plan { + LogicalPlan::Projection(_) + | LogicalPlan::Sort(_) + | LogicalPlan::Filter(_) => self.optimize_unary_plan(&plan, config), + LogicalPlan::Window(window) => self.optimize_window(&window, config), + LogicalPlan::Aggregate(aggregate) => { + self.optimize_aggregate(&aggregate, config) + } + LogicalPlan::Join(_) + | LogicalPlan::CrossJoin(_) + | LogicalPlan::Repartition(_) + | LogicalPlan::Union(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::Values(_) + | LogicalPlan::EmptyRelation(_) + | LogicalPlan::Subquery(_) + | LogicalPlan::SubqueryAlias(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Ddl(_) + | LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::Statement(_) + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Distinct(_) + | LogicalPlan::Extension(_) + | LogicalPlan::Dml(_) + | LogicalPlan::Copy(_) + | LogicalPlan::Unnest(_) + | LogicalPlan::RecursiveQuery(_) + | LogicalPlan::Prepare(_) => Ok(Transformed::no(plan)), + }?; + // If schema has changed, add a projection to recover the original schema + transformed_plan.map_data(|transformed_plan| { + if transformed_plan.schema() != &original_schema { + build_recover_project_plan(&original_schema, transformed_plan) + } else { + Ok(transformed_plan) + } + }) + }) } fn name(&self) -> &str { @@ -786,12 +795,12 @@ mod test { use super::*; - fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) { + fn assert_optimized_plan_eq(expected: &str, plan: LogicalPlan) { let optimizer = CommonSubexprEliminate {}; let optimized_plan = optimizer - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + .rewrite(plan, &OptimizerContext::new()) + .expect("failed to optimize plan") + .data; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(expected, formatted_plan); } @@ -822,7 +831,7 @@ mod test { \n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -875,7 +884,7 @@ mod test { \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -894,7 +903,7 @@ mod test { \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: transformation before aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -911,7 +920,7 @@ mod test { \n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -928,7 +937,7 @@ mod test { \n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -950,7 +959,7 @@ mod test { \n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -977,7 +986,7 @@ mod test { \n Projection: UInt32(1) + table.test.col.a AS UInt32(1) + table.test.col.atable.test.col.aUInt32(1), table.test.col.a\ \n TableScan: table.test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -997,7 +1006,7 @@ mod test { \n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1013,7 +1022,7 @@ mod test { let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1031,7 +1040,7 @@ mod test { \n Projection: Int32(1) + test.a\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) } @@ -1162,10 +1171,7 @@ mod test { .build() .unwrap(); let rule = CommonSubexprEliminate {}; - let optimized_plan = rule - .try_optimize(&plan, &OptimizerContext::new()) - .unwrap() - .unwrap(); + let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap().data; let schema = optimized_plan.schema(); let fields_with_datatypes: Vec<_> = schema @@ -1204,7 +1210,7 @@ mod test { \n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan); Ok(()) }