Skip to content

Commit

Permalink
Rewrite CommonSubexprEliminate to avoid copies using TreeNode
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Apr 12, 2024
1 parent 60305ed commit c510c7c
Showing 1 changed file with 92 additions and 86 deletions.
178 changes: 92 additions & 86 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

use crate::utils::is_volatile_expression;
use crate::{utils, OptimizerConfig, OptimizerRule};
use crate::{OptimizerConfig, OptimizerRule};

use arrow::datatypes::{DataType, Field};
use datafusion_common::tree_node::{
Expand Down Expand Up @@ -192,21 +192,20 @@ impl CommonSubexprEliminate {
let rewrite_exprs =
self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?;

let mut new_input = self
.try_optimize(input, config)?
.unwrap_or_else(|| input.clone());
let mut new_input = self.rewrite(input.clone(), config)?.data;

if !affected_id.is_empty() {
new_input = build_common_expr_project_plan(new_input, affected_id, expr_set)?;
}

Ok((rewrite_exprs, new_input))
}

fn try_optimize_window(
fn optimize_window(
&self,
window: &Window,
config: &dyn OptimizerConfig,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let mut window_exprs = vec![];
let mut expr_set = ExprSet::default();

Expand Down Expand Up @@ -261,14 +260,14 @@ impl CommonSubexprEliminate {
plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?);
}

Ok(plan)
Ok(Transformed::yes(plan))
}

fn try_optimize_aggregate(
fn optimize_aggregate(
&self,
aggregate: &Aggregate,
config: &dyn OptimizerConfig,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let Aggregate {
group_expr,
aggr_expr,
Expand Down Expand Up @@ -317,8 +316,10 @@ impl CommonSubexprEliminate {
})
.collect::<Result<Vec<Expr>>>()?;
// Since group_expr changes, the schema changes as well. Use the try_new method.
Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr)
.map(LogicalPlan::Aggregate)
let new_plan =
Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr)
.map(LogicalPlan::Aggregate)?;
Ok(Transformed::yes(new_plan))
} else {
let mut agg_exprs = vec![];

Expand Down Expand Up @@ -364,18 +365,17 @@ impl CommonSubexprEliminate {
agg_exprs,
)?);

Ok(LogicalPlan::Projection(Projection::try_new(
proj_exprs,
Arc::new(agg),
)?))
let new_plan =
LogicalPlan::Projection(Projection::try_new(proj_exprs, Arc::new(agg))?);
Ok(Transformed::yes(new_plan))
}
}

fn try_unary_plan(
fn optimize_unary_plan(
&self,
plan: &LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let expr = plan.expressions();
let inputs = plan.inputs();
let input = inputs[0];
Expand All @@ -388,64 +388,73 @@ impl CommonSubexprEliminate {
let (mut new_expr, new_input) =
self.rewrite_expr(&[&expr], input, &expr_set, config)?;

plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])
let new_plan = plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])?;
Ok(Transformed::yes(new_plan))
}
}

impl OptimizerRule for CommonSubexprEliminate {
fn try_optimize(
&self,
plan: &LogicalPlan,
config: &dyn OptimizerConfig,
_plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
let optimized_plan = match plan {
LogicalPlan::Projection(_)
| LogicalPlan::Sort(_)
| LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, config)?),
LogicalPlan::Window(window) => {
Some(self.try_optimize_window(window, config)?)
}
LogicalPlan::Aggregate(aggregate) => {
Some(self.try_optimize_aggregate(aggregate, config)?)
}
LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_)
| LogicalPlan::Repartition(_)
| LogicalPlan::Union(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Values(_)
| LogicalPlan::EmptyRelation(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Limit(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Statement(_)
| LogicalPlan::DescribeTable(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Extension(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
| LogicalPlan::Unnest(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Prepare(_) => {
// apply the optimization to all inputs of the plan
utils::optimize_children(self, plan, config)?
}
};
internal_err!("Should call CommonSubexprEliminate::rewrite instead")
}

let original_schema = plan.schema().clone();
match optimized_plan {
Some(optimized_plan) if optimized_plan.schema() != &original_schema => {
// add an additional projection if the output schema changed.
Ok(Some(build_recover_project_plan(
&original_schema,
optimized_plan,
)?))
}
plan => Ok(plan),
}
/// Always returns `true`: this rule implements `rewrite`, which takes the
/// plan by value.
///
/// NOTE(review): presumably the optimizer driver consults this flag to decide
/// whether to call `rewrite` instead of `try_optimize` — confirm against the
/// `OptimizerRule` trait documentation.
fn supports_rewrite(&self) -> bool {
true
}

/// Rewrites `plan` (taken by value) to eliminate common subexpressions.
///
/// The plan is walked with `transform_down_with_subqueries`. For each node:
/// - `Projection`, `Sort` and `Filter` are handled by `optimize_unary_plan`;
/// - `Window` is handled by `optimize_window`;
/// - `Aggregate` is handled by `optimize_aggregate`;
/// - every other node type is returned unchanged via `Transformed::no`.
///
/// If the schema of the node produced by one of the helpers differs from the
/// schema the node had before the rewrite, `build_recover_project_plan` is
/// applied to the result together with the original schema.
///
/// # Errors
/// Propagates any error returned by the per-node helpers or by
/// `build_recover_project_plan`.
fn rewrite(
&self,
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
// Note: the traversal needs to be top-down ("down"), as this rule introduces new plans
plan.transform_down_with_subqueries(&|plan| {
// Capture the schema before `plan` is consumed by the `match` below, so
// the rewritten node's schema can be compared against it afterwards.
let original_schema = plan.schema().clone();

let transformed_plan = match plan {
LogicalPlan::Projection(_)
| LogicalPlan::Sort(_)
| LogicalPlan::Filter(_) => self.optimize_unary_plan(&plan, config),
LogicalPlan::Window(window) => self.optimize_window(&window, config),
LogicalPlan::Aggregate(aggregate) => {
self.optimize_aggregate(&aggregate, config)
}
// These node types are not rewritten by this closure; they are handed
// back as-is and flagged as not transformed.
LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_)
| LogicalPlan::Repartition(_)
| LogicalPlan::Union(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Values(_)
| LogicalPlan::EmptyRelation(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Limit(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Statement(_)
| LogicalPlan::DescribeTable(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Extension(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
| LogicalPlan::Unnest(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Prepare(_) => Ok(Transformed::no(plan)),
}?;
// If schema has changed, add a projection to recover the original schema
transformed_plan.map_data(|transformed_plan| {
if transformed_plan.schema() != &original_schema {
build_recover_project_plan(&original_schema, transformed_plan)
} else {
Ok(transformed_plan)
}
})
})
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -786,12 +795,12 @@ mod test {

use super::*;

fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) {
fn assert_optimized_plan_eq(expected: &str, plan: LogicalPlan) {
let optimizer = CommonSubexprEliminate {};
let optimized_plan = optimizer
.try_optimize(plan, &OptimizerContext::new())
.unwrap()
.expect("failed to optimize plan");
.rewrite(plan, &OptimizerContext::new())
.expect("failed to optimize plan")
.data;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(expected, formatted_plan);
}
Expand Down Expand Up @@ -822,7 +831,7 @@ mod test {
\n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand Down Expand Up @@ -875,7 +884,7 @@ mod test {
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

// test: trafo after aggregate
let plan = LogicalPlanBuilder::from(table_scan.clone())
Expand All @@ -894,7 +903,7 @@ mod test {
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a]]\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

// test: transformation before aggregate
let plan = LogicalPlanBuilder::from(table_scan.clone())
Expand All @@ -911,7 +920,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

// test: common between agg and group
let plan = LogicalPlanBuilder::from(table_scan.clone())
Expand All @@ -928,7 +937,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

// test: all mixed
let plan = LogicalPlanBuilder::from(table_scan)
Expand All @@ -950,7 +959,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand All @@ -977,7 +986,7 @@ mod test {
\n Projection: UInt32(1) + table.test.col.a AS UInt32(1) + table.test.col.atable.test.col.aUInt32(1), table.test.col.a\
\n TableScan: table.test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand All @@ -997,7 +1006,7 @@ mod test {
\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand All @@ -1013,7 +1022,7 @@ mod test {
let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand All @@ -1031,7 +1040,7 @@ mod test {
\n Projection: Int32(1) + test.a\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);
Ok(())
}

Expand Down Expand Up @@ -1162,10 +1171,7 @@ mod test {
.build()
.unwrap();
let rule = CommonSubexprEliminate {};
let optimized_plan = rule
.try_optimize(&plan, &OptimizerContext::new())
.unwrap()
.unwrap();
let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap().data;

let schema = optimized_plan.schema();
let fields_with_datatypes: Vec<_> = schema
Expand Down Expand Up @@ -1204,7 +1210,7 @@ mod test {
\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand Down

0 comments on commit c510c7c

Please sign in to comment.