Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite CommonSubexprEliminate to avoid copies using TreeNode #10067

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 92 additions & 86 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;

use crate::utils::is_volatile_expression;
use crate::{utils, OptimizerConfig, OptimizerRule};
use crate::{OptimizerConfig, OptimizerRule};

use arrow::datatypes::{DataType, Field};
use datafusion_common::tree_node::{
Expand Down Expand Up @@ -192,21 +192,20 @@ impl CommonSubexprEliminate {
let rewrite_exprs =
self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?;

let mut new_input = self
.try_optimize(input, config)?
.unwrap_or_else(|| input.clone());
let mut new_input = self.rewrite(input.clone(), config)?.data;

if !affected_id.is_empty() {
new_input = build_common_expr_project_plan(new_input, affected_id, expr_set)?;
}

Ok((rewrite_exprs, new_input))
}

fn try_optimize_window(
fn optimize_window(
&self,
window: &Window,
config: &dyn OptimizerConfig,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let mut window_exprs = vec![];
let mut expr_set = ExprSet::default();

Expand Down Expand Up @@ -261,14 +260,14 @@ impl CommonSubexprEliminate {
plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?);
}

Ok(plan)
Ok(Transformed::yes(plan))
}

fn try_optimize_aggregate(
fn optimize_aggregate(
&self,
aggregate: &Aggregate,
config: &dyn OptimizerConfig,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let Aggregate {
group_expr,
aggr_expr,
Expand Down Expand Up @@ -317,8 +316,10 @@ impl CommonSubexprEliminate {
})
.collect::<Result<Vec<Expr>>>()?;
// Since group_expr changes, the schema changes as well, so use the `try_new` method.
Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr)
.map(LogicalPlan::Aggregate)
let new_plan =
Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr)
.map(LogicalPlan::Aggregate)?;
Ok(Transformed::yes(new_plan))
} else {
let mut agg_exprs = vec![];

Expand Down Expand Up @@ -364,18 +365,17 @@ impl CommonSubexprEliminate {
agg_exprs,
)?);

Ok(LogicalPlan::Projection(Projection::try_new(
proj_exprs,
Arc::new(agg),
)?))
let new_plan =
LogicalPlan::Projection(Projection::try_new(proj_exprs, Arc::new(agg))?);
Ok(Transformed::yes(new_plan))
}
}

fn try_unary_plan(
fn optimize_unary_plan(
&self,
plan: &LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let expr = plan.expressions();
let inputs = plan.inputs();
let input = inputs[0];
Expand All @@ -388,64 +388,73 @@ impl CommonSubexprEliminate {
let (mut new_expr, new_input) =
self.rewrite_expr(&[&expr], input, &expr_set, config)?;

plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])
let new_plan = plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])?;
Ok(Transformed::yes(new_plan))
}
}

impl OptimizerRule for CommonSubexprEliminate {
fn try_optimize(
&self,
plan: &LogicalPlan,
config: &dyn OptimizerConfig,
_plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
let optimized_plan = match plan {
LogicalPlan::Projection(_)
| LogicalPlan::Sort(_)
| LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, config)?),
LogicalPlan::Window(window) => {
Some(self.try_optimize_window(window, config)?)
}
LogicalPlan::Aggregate(aggregate) => {
Some(self.try_optimize_aggregate(aggregate, config)?)
}
LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_)
| LogicalPlan::Repartition(_)
| LogicalPlan::Union(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Values(_)
| LogicalPlan::EmptyRelation(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Limit(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Statement(_)
| LogicalPlan::DescribeTable(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Extension(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
| LogicalPlan::Unnest(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Prepare(_) => {
// apply the optimization to all inputs of the plan
utils::optimize_children(self, plan, config)?
}
};
internal_err!("Should call CommonSubexprEliminate::rewrite instead")
}

let original_schema = plan.schema().clone();
match optimized_plan {
Some(optimized_plan) if optimized_plan.schema() != &original_schema => {
// add an additional projection if the output schema changed.
Ok(Some(build_recover_project_plan(
&original_schema,
optimized_plan,
)?))
}
plan => Ok(plan),
}
/// Reports that this rule implements [`Self::rewrite`].
///
/// Always returns `true`; `try_optimize` on this rule only returns an
/// internal error telling the caller to use `rewrite` instead.
fn supports_rewrite(&self) -> bool {
true
}

/// Applies common-subexpression elimination to an owned `plan`.
///
/// The closure below is run over the plan with
/// `transform_down_with_subqueries`. For each visited node:
///
/// * `Projection`, `Sort` and `Filter` are handled by `optimize_unary_plan`,
///   `Window` by `optimize_window`, and `Aggregate` by `optimize_aggregate`.
/// * Every other `LogicalPlan` variant is returned unchanged, wrapped in
///   `Transformed::no`.
/// * If the resulting node's schema differs from the schema the node had
///   before the rewrite, the node is passed to `build_recover_project_plan`
///   together with the original schema.
///
/// # Errors
///
/// Propagates any error returned by the `optimize_*` helpers or by
/// `build_recover_project_plan`.
fn rewrite(
&self,
plan: LogicalPlan,
config: &dyn OptimizerConfig,
) -> Result<Transformed<LogicalPlan>, DataFusionError> {
// The traversal must be top-down because this rule introduces new plan nodes.
// NOTE(review): the `optimize_*` helpers appear to rewrite their input
// themselves by calling `self.rewrite` on it; confirm that the descent
// performed here does not process the same children a second time.
plan.transform_down_with_subqueries(&|plan| {
// Remember the schema before rewriting so a change can be detected below.
let original_schema = plan.schema().clone();

let transformed_plan = match plan {
LogicalPlan::Projection(_)
| LogicalPlan::Sort(_)
| LogicalPlan::Filter(_) => self.optimize_unary_plan(&plan, config),
LogicalPlan::Window(window) => self.optimize_window(&window, config),
LogicalPlan::Aggregate(aggregate) => {
self.optimize_aggregate(&aggregate, config)
}
// All remaining variants are listed explicitly (no `_` wildcard) and
// are handed back as-is, marked as not transformed.
LogicalPlan::Join(_)
| LogicalPlan::CrossJoin(_)
| LogicalPlan::Repartition(_)
| LogicalPlan::Union(_)
| LogicalPlan::TableScan(_)
| LogicalPlan::Values(_)
| LogicalPlan::EmptyRelation(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Limit(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Statement(_)
| LogicalPlan::DescribeTable(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Extension(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
| LogicalPlan::Unnest(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Prepare(_) => Ok(Transformed::no(plan)),
}?;
// If the schema has changed, add a projection to recover the original schema.
transformed_plan.map_data(|transformed_plan| {
if transformed_plan.schema() != &original_schema {
build_recover_project_plan(&original_schema, transformed_plan)
} else {
Ok(transformed_plan)
}
})
})
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -786,12 +795,12 @@ mod test {

use super::*;

fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) {
fn assert_optimized_plan_eq(expected: &str, plan: LogicalPlan) {
let optimizer = CommonSubexprEliminate {};
let optimized_plan = optimizer
.try_optimize(plan, &OptimizerContext::new())
.unwrap()
.expect("failed to optimize plan");
.rewrite(plan, &OptimizerContext::new())
.expect("failed to optimize plan")
.data;
let formatted_plan = format!("{optimized_plan:?}");
assert_eq!(expected, formatted_plan);
}
Expand Down Expand Up @@ -822,7 +831,7 @@ mod test {
\n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand Down Expand Up @@ -875,7 +884,7 @@ mod test {
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

// test: transformation after aggregate
let plan = LogicalPlanBuilder::from(table_scan.clone())
Expand All @@ -894,7 +903,7 @@ mod test {
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a]]\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

// test: transformation before aggregate
let plan = LogicalPlanBuilder::from(table_scan.clone())
Expand All @@ -911,7 +920,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

// test: common between agg and group
let plan = LogicalPlanBuilder::from(table_scan.clone())
Expand All @@ -928,7 +937,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

// test: all mixed
let plan = LogicalPlanBuilder::from(table_scan)
Expand All @@ -950,7 +959,7 @@ mod test {
\n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand All @@ -977,7 +986,7 @@ mod test {
\n Projection: UInt32(1) + table.test.col.a AS UInt32(1) + table.test.col.atable.test.col.aUInt32(1), table.test.col.a\
\n TableScan: table.test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand All @@ -997,7 +1006,7 @@ mod test {
\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand All @@ -1013,7 +1022,7 @@ mod test {
let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand All @@ -1031,7 +1040,7 @@ mod test {
\n Projection: Int32(1) + test.a\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);
Ok(())
}

Expand Down Expand Up @@ -1162,10 +1171,7 @@ mod test {
.build()
.unwrap();
let rule = CommonSubexprEliminate {};
let optimized_plan = rule
.try_optimize(&plan, &OptimizerContext::new())
.unwrap()
.unwrap();
let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap().data;

let schema = optimized_plan.schema();
let fields_with_datatypes: Vec<_> = schema
Expand Down Expand Up @@ -1204,7 +1210,7 @@ mod test {
\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, &plan);
assert_optimized_plan_eq(expected, plan);

Ok(())
}
Expand Down
Loading