Skip to content

Commit

Permalink
Refactor TreeNode and LogicalPlan apply, transform, transform_up,…
Browse files Browse the repository at this point in the history
… transform_down and transform_down_up APIs to accept owned closures
  • Loading branch information
peter-toth committed Apr 19, 2024
1 parent b2f6309 commit d781ca7
Show file tree
Hide file tree
Showing 49 changed files with 229 additions and 179 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/function_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
impl ScalarFunctionWrapper {
// replaces placeholders such as $1 with actual arguments (args[0])
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
let result = expr.clone().transform(&mut |e| {
let result = expr.clone().transform(|e| {
let r = match e {
Expr::Placeholder(placeholder) => {
let placeholder_position =
Expand Down
6 changes: 3 additions & 3 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {

impl MyAnalyzerRule {
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
plan.transform(&mut |plan| {
plan.transform(|plan| {
Ok(match plan {
LogicalPlan::Filter(filter) => {
let predicate = Self::analyze_expr(filter.predicate.clone())?;
Expand All @@ -107,7 +107,7 @@ impl MyAnalyzerRule {
}

fn analyze_expr(expr: Expr) -> Result<Expr> {
expr.transform(&mut |expr| {
expr.transform(|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Literal(ScalarValue::Int64(i)) => {
Expand Down Expand Up @@ -163,7 +163,7 @@ impl OptimizerRule for MyOptimizerRule {

/// use rewrite_expr to modify the expression tree.
fn my_rewrite(expr: Expr) -> Result<Expr> {
expr.transform(&mut |expr| {
expr.transform(|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Between(Between {
Expand Down
87 changes: 53 additions & 34 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,6 @@ macro_rules! handle_transform_recursion {
}};
}

// Top-down recursion step.
// `$F_DOWN` is an expression yielding a `Result`; its error is propagated with `?`.
// On success, `transform_children` is called on the result with a closure that
// maps `$F_CHILD` over the node's children via `map_children`.
macro_rules! handle_transform_recursion_down {
($F_DOWN:expr, $F_CHILD:expr) => {{
$F_DOWN?.transform_children(|n| n.map_children($F_CHILD))
}};
}

// Bottom-up recursion step.
// `$F_CHILD` is first mapped over the children of `$SELF` via `map_children`
// (errors are propagated with `?`); only afterwards is `$F_UP` applied to the
// node itself through `transform_parent`.
macro_rules! handle_transform_recursion_up {
($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{
$SELF.map_children($F_CHILD)?.transform_parent(|n| $F_UP(n))
}};
}

/// Defines a visitable and rewriteable tree node. This trait is implemented
/// for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well as expression
/// trees ([`PhysicalExpr`], [`Expr`]) in DataFusion.
Expand Down Expand Up @@ -137,17 +125,24 @@ pub trait TreeNode: Sized {
/// or run a check on the tree.
fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: &mut F,
mut f: F,
) -> Result<TreeNodeRecursion> {
f(self)?.visit_children(|| self.apply_children(|c| c.apply(f)))
// Recursive helper that visits `node` and then its children.
// It takes the closure as `&mut F` so the same closure can be handed on to
// every recursive call on a child.
fn apply_impl<N: TreeNode, F: FnMut(&N) -> Result<TreeNodeRecursion>>(
node: &N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
// `f` runs on this node first; an error from it is returned immediately.
// The resulting `TreeNodeRecursion` is handed the closure that recurses into
// the children via `visit_children` — presumably it skips that closure when
// the traversal should not descend; confirm against `TreeNodeRecursion`.
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
}

apply_impl(self, &mut f)
}

/// Convenience utility for writing optimizer rules: Recursively apply the
/// given function `f` to the tree in a bottom-up (post-order) fashion. When
/// `f` does not apply to a given node, it is left unchanged.
fn transform<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: &mut F,
f: F,
) -> Result<Transformed<Self>> {
self.transform_up(f)
}
Expand All @@ -157,9 +152,16 @@ pub trait TreeNode: Sized {
/// When `f` does not apply to a given node, it is left unchanged.
fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: &mut F,
mut f: F,
) -> Result<Transformed<Self>> {
handle_transform_recursion_down!(f(self), |c| c.transform_down(f))
// Recursive helper for a top-down (pre-order) rewrite.
// It takes the closure as `&mut F` so the same closure can be handed on to
// every recursive call on a child.
fn transform_down_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
node: N,
f: &mut F,
) -> Result<Transformed<N>> {
// `f` is applied to `node` first (errors propagate via `?`); the children of
// the result are then rewritten recursively through `map_children`.
f(node)?.transform_children(|n| n.map_children(|c| transform_down_impl(c, f)))
}

transform_down_impl(self, &mut f)
}

/// Convenience utility for writing optimizer rules: Recursively apply the
Expand All @@ -179,9 +181,17 @@ pub trait TreeNode: Sized {
/// left unchanged.
fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: &mut F,
mut f: F,
) -> Result<Transformed<Self>> {
handle_transform_recursion_up!(self, |c| c.transform_up(f), f)
// Recursive helper for a bottom-up (post-order) rewrite.
// It takes the closure as `&mut F` so the same closure can be handed on to
// every recursive call on a child.
fn transform_up_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
node: N,
f: &mut F,
) -> Result<Transformed<N>> {
// Children are rewritten recursively first (errors propagate via `?`);
// only then is `f` applied to the node itself through `transform_parent`.
node.map_children(|c| transform_up_impl(c, f))?
.transform_parent(f)
}

transform_up_impl(self, &mut f)
}

/// Convenience utility for writing optimizer rules: Recursively apply the
Expand Down Expand Up @@ -290,22 +300,34 @@ pub trait TreeNode: Sized {
FU: FnMut(Self) -> Result<Transformed<Self>>,
>(
self,
f_down: &mut FD,
f_up: &mut FU,
mut f_down: FD,
mut f_up: FU,
) -> Result<Transformed<Self>> {
handle_transform_recursion!(
f_down(self),
|c| c.transform_down_up(f_down, f_up),
f_up
)
// Recursive helper for a combined top-down / bottom-up rewrite.
// Both closures are taken as `&mut` so the same two closures can be handed on
// to every recursive call on a child.
fn transform_down_up_impl<
N: TreeNode,
FD: FnMut(N) -> Result<Transformed<N>>,
FU: FnMut(N) -> Result<Transformed<N>>,
>(
node: N,
f_down: &mut FD,
f_up: &mut FU,
) -> Result<Transformed<N>> {
// Delegates to the `handle_transform_recursion!` macro (defined outside this
// view), passing it the result of `f_down` on this node, a closure that
// recurses into a child, and `f_up`.
// NOTE(review): the exact ordering and short-circuit rules live in that macro — confirm there.
handle_transform_recursion!(
f_down(node),
|c| transform_down_up_impl(c, f_down, f_up),
f_up
)
}

transform_down_up_impl(self, &mut f_down, &mut f_up)
}

/// Returns true if `f` returns true for any node in the tree.
///
/// Stops recursion as soon as a matching node is found
fn exists<F: FnMut(&Self) -> bool>(&self, mut f: F) -> bool {
let mut found = false;
self.apply(&mut |n| {
self.apply(|n| {
Ok(if f(n) {
found = true;
TreeNodeRecursion::Stop
Expand Down Expand Up @@ -1362,7 +1384,7 @@ mod tests {
fn $NAME() -> Result<()> {
let tree = test_tree();
let mut visits = vec![];
tree.apply(&mut |node| {
tree.apply(|node| {
visits.push(format!("f_down({})", node.data));
$F(node)
})?;
Expand Down Expand Up @@ -1451,10 +1473,7 @@ mod tests {
#[test]
fn $NAME() -> Result<()> {
let tree = test_tree();
assert_eq!(
tree.transform_down_up(&mut $F_DOWN, &mut $F_UP,)?,
$EXPECTED_TREE
);
assert_eq!(tree.transform_down_up($F_DOWN, $F_UP,)?, $EXPECTED_TREE);

Ok(())
}
Expand All @@ -1466,7 +1485,7 @@ mod tests {
#[test]
fn $NAME() -> Result<()> {
let tree = test_tree();
assert_eq!(tree.transform_down(&mut $F)?, $EXPECTED_TREE);
assert_eq!(tree.transform_down($F)?, $EXPECTED_TREE);

Ok(())
}
Expand All @@ -1478,7 +1497,7 @@ mod tests {
#[test]
fn $NAME() -> Result<()> {
let tree = test_tree();
assert_eq!(tree.transform_up(&mut $F)?, $EXPECTED_TREE);
assert_eq!(tree.transform_up($F)?, $EXPECTED_TREE);

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use object_store::{ObjectMeta, ObjectStore};
/// was performed
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(&mut |expr| {
expr.apply(|expr| {
match expr {
Expr::Column(Column { ref name, .. }) => {
is_applicable &= col_names.contains(name);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/coalesce_batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl PhysicalOptimizerRule for CoalesceBatches {
}

let target_batch_size = config.execution.batch_size;
plan.transform_up(&mut |plan| {
plan.transform_up(|plan| {
let plan_any = plan.as_any();
// The goal here is to detect operators that could produce small batches and only
// wrap those ones with a CoalesceBatchesExec operator. An alternate approach here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_down(&mut |plan| {
plan.transform_down(|plan| {
let transformed =
plan.as_any()
.downcast_ref::<AggregateExec>()
Expand Down Expand Up @@ -179,7 +179,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
group_expr
.clone()
.transform(&mut |expr| {
.transform(|expr| {
let normalized_form: Option<Arc<dyn PhysicalExpr>> =
match expr.as_any().downcast_ref::<Column>() {
Some(column) => Some(Arc::new(Column::new(column.name(), 0))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder {
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_up(&mut get_common_requirement_of_aggregate_input)
plan.transform_up(get_common_requirement_of_aggregate_input)
.data()
}

Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/src/physical_optimizer/enforce_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,12 @@ impl PhysicalOptimizerRule for EnforceDistribution {
// Run a top-down process to adjust input key ordering recursively
let plan_requirements = PlanWithKeyRequirements::new_default(plan);
let adjusted = plan_requirements
.transform_down(&mut adjust_input_keys_ordering)
.transform_down(adjust_input_keys_ordering)
.data()?;
adjusted.plan
} else {
// Run a bottom-up process
plan.transform_up(&mut |plan| {
plan.transform_up(|plan| {
Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?))
})
.data()?
Expand All @@ -211,7 +211,7 @@ impl PhysicalOptimizerRule for EnforceDistribution {
let distribution_context = DistributionContext::new_default(adjusted);
// Distribution enforcement needs to be applied bottom-up.
let distribution_context = distribution_context
.transform_up(&mut |distribution_context| {
.transform_up(|distribution_context| {
ensure_distribution(distribution_context, config)
})
.data()?;
Expand Down Expand Up @@ -1768,22 +1768,22 @@ pub(crate) mod tests {
let plan_requirements =
PlanWithKeyRequirements::new_default($PLAN.clone());
let adjusted = plan_requirements
.transform_down(&mut adjust_input_keys_ordering)
.transform_down(adjust_input_keys_ordering)
.data()
.and_then(check_integrity)?;
// TODO: End state payloads will be checked here.
adjusted.plan
} else {
// Run reorder_join_keys_to_inputs rule
$PLAN.clone().transform_up(&mut |plan| {
$PLAN.clone().transform_up(|plan| {
Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?))
})
.data()?
};

// Then run ensure_distribution rule
DistributionContext::new_default(adjusted)
.transform_up(&mut |distribution_context| {
.transform_up(|distribution_context| {
ensure_distribution(distribution_context, &config)
})
.data()
Expand Down
20 changes: 9 additions & 11 deletions datafusion/core/src/physical_optimizer/enforce_sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,12 @@ impl PhysicalOptimizerRule for EnforceSorting {
let plan_requirements = PlanWithCorrespondingSort::new_default(plan);
// Execute a bottom-up traversal to enforce sorting requirements,
// remove unnecessary sorts, and optimize sort-sensitive operators:
let adjusted = plan_requirements.transform_up(&mut ensure_sorting)?.data;
let adjusted = plan_requirements.transform_up(ensure_sorting)?.data;
let new_plan = if config.optimizer.repartition_sorts {
let plan_with_coalesce_partitions =
PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan);
let parallel = plan_with_coalesce_partitions
.transform_up(&mut parallelize_sorts)
.transform_up(parallelize_sorts)
.data()?;
parallel.plan
} else {
Expand All @@ -174,7 +174,7 @@ impl PhysicalOptimizerRule for EnforceSorting {

let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan);
let updated_plan = plan_with_pipeline_fixer
.transform_up(&mut |plan_with_pipeline_fixer| {
.transform_up(|plan_with_pipeline_fixer| {
replace_with_order_preserving_variants(
plan_with_pipeline_fixer,
false,
Expand All @@ -188,13 +188,11 @@ impl PhysicalOptimizerRule for EnforceSorting {
// missed by the bottom-up traversal:
let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan);
assign_initial_requirements(&mut sort_pushdown);
let adjusted = sort_pushdown.transform_down(&mut pushdown_sorts)?.data;
let adjusted = sort_pushdown.transform_down(pushdown_sorts)?.data;

adjusted
.plan
.transform_up(&mut |plan| {
Ok(Transformed::yes(replace_with_partial_sort(plan)?))
})
.transform_up(|plan| Ok(Transformed::yes(replace_with_partial_sort(plan)?)))
.data()
}

Expand Down Expand Up @@ -683,7 +681,7 @@ mod tests {
{
let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone());
let adjusted = plan_requirements
.transform_up(&mut ensure_sorting)
.transform_up(ensure_sorting)
.data()
.and_then(check_integrity)?;
// TODO: End state payloads will be checked here.
Expand All @@ -692,7 +690,7 @@ mod tests {
let plan_with_coalesce_partitions =
PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan);
let parallel = plan_with_coalesce_partitions
.transform_up(&mut parallelize_sorts)
.transform_up(parallelize_sorts)
.data()
.and_then(check_integrity)?;
// TODO: End state payloads will be checked here.
Expand All @@ -703,7 +701,7 @@ mod tests {

let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan);
let updated_plan = plan_with_pipeline_fixer
.transform_up(&mut |plan_with_pipeline_fixer| {
.transform_up(|plan_with_pipeline_fixer| {
replace_with_order_preserving_variants(
plan_with_pipeline_fixer,
false,
Expand All @@ -718,7 +716,7 @@ mod tests {
let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan);
assign_initial_requirements(&mut sort_pushdown);
sort_pushdown
.transform_down(&mut pushdown_sorts)
.transform_down(pushdown_sorts)
.data()
.and_then(check_integrity)?;
// TODO: End state payloads will be checked here.
Expand Down
Loading

0 comments on commit d781ca7

Please sign in to comment.