Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve TreeNode and LogicalPlan APIs to accept owned closures, deprecate transform_down_mut() and transform_up_mut() #10126

Merged
merged 2 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-examples/examples/function_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
impl ScalarFunctionWrapper {
// replaces placeholders such as $1 with actual arguments (args[0]
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
let result = expr.clone().transform(&|e| {
let result = expr.clone().transform(|e| {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think this is a really nice improvement in usability -- to not have to put a & in front of the closure is 💯

let r = match e {
Expr::Placeholder(placeholder) => {
let placeholder_position =
Expand Down
6 changes: 3 additions & 3 deletions datafusion-examples/examples/rewrite_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule {

impl MyAnalyzerRule {
fn analyze_plan(plan: LogicalPlan) -> Result<LogicalPlan> {
plan.transform(&|plan| {
plan.transform(|plan| {
Ok(match plan {
LogicalPlan::Filter(filter) => {
let predicate = Self::analyze_expr(filter.predicate.clone())?;
Expand All @@ -107,7 +107,7 @@ impl MyAnalyzerRule {
}

fn analyze_expr(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform(|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Literal(ScalarValue::Int64(i)) => {
Expand Down Expand Up @@ -163,7 +163,7 @@ impl OptimizerRule for MyOptimizerRule {

/// use rewrite_expr to modify the expression tree.
fn my_rewrite(expr: Expr) -> Result<Expr> {
expr.transform(&|expr| {
expr.transform(|expr| {
// closure is invoked for all sub expressions
Ok(match expr {
Expr::Between(Between {
Expand Down
105 changes: 62 additions & 43 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,6 @@ macro_rules! handle_transform_recursion {
}};
}

macro_rules! handle_transform_recursion_down {
($F_DOWN:expr, $F_CHILD:expr) => {{
$F_DOWN?.transform_children(|n| n.map_children($F_CHILD))
}};
}

macro_rules! handle_transform_recursion_up {
($SELF:expr, $F_CHILD:expr, $F_UP:expr) => {{
$SELF.map_children($F_CHILD)?.transform_parent(|n| $F_UP(n))
}};
}

/// Defines a visitable and rewriteable tree node. This trait is implemented
/// for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well as expression
/// trees ([`PhysicalExpr`], [`Expr`]) in DataFusion.
Expand Down Expand Up @@ -137,61 +125,85 @@ pub trait TreeNode: Sized {
/// or run a check on the tree.
fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: &mut F,
mut f: F,
) -> Result<TreeNodeRecursion> {
f(self)?.visit_children(|| self.apply_children(|c| c.apply(f)))
fn apply_impl<N: TreeNode, F: FnMut(&N) -> Result<TreeNodeRecursion>>(
node: &N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
}

apply_impl(self, &mut f)
}

/// Convenience utility for writing optimizer rules: Recursively apply the
/// given function `f` to the tree in a bottom-up (post-order) fashion. When
/// `f` does not apply to a given node, it is left unchanged.
fn transform<F: Fn(Self) -> Result<Transformed<Self>>>(
fn transform<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: &F,
f: F,
) -> Result<Transformed<Self>> {
self.transform_up(f)
}

/// Convenience utility for writing optimizer rules: Recursively apply the
/// given function `f` to a node and then to its children (pre-order traversal).
/// When `f` does not apply to a given node, it is left unchanged.
fn transform_down<F: Fn(Self) -> Result<Transformed<Self>>>(
fn transform_down<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: &F,
mut f: F,
) -> Result<Transformed<Self>> {
handle_transform_recursion_down!(f(self), |c| c.transform_down(f))
fn transform_down_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
node: N,
f: &mut F,
) -> Result<Transformed<N>> {
f(node)?.transform_children(|n| n.map_children(|c| transform_down_impl(c, f)))
}

transform_down_impl(self, &mut f)
}

/// Convenience utility for writing optimizer rules: Recursively apply the
/// given mutable function `f` to a node and then to its children (pre-order
/// traversal). When `f` does not apply to a given node, it is left unchanged.
#[deprecated(since = "38.0.0", note = "Use `transform_down` instead")]
fn transform_down_mut<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: &mut F,
) -> Result<Transformed<Self>> {
handle_transform_recursion_down!(f(self), |c| c.transform_down_mut(f))
self.transform_down(f)
}

/// Convenience utility for writing optimizer rules: Recursively apply the
/// given function `f` to all children of a node, and then to the node itself
/// (post-order traversal). When `f` does not apply to a given node, it is
/// left unchanged.
fn transform_up<F: Fn(Self) -> Result<Transformed<Self>>>(
fn transform_up<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: &F,
mut f: F,
) -> Result<Transformed<Self>> {
handle_transform_recursion_up!(self, |c| c.transform_up(f), f)
fn transform_up_impl<N: TreeNode, F: FnMut(N) -> Result<Transformed<N>>>(
node: N,
f: &mut F,
) -> Result<Transformed<N>> {
node.map_children(|c| transform_up_impl(c, f))?
.transform_parent(f)
}

transform_up_impl(self, &mut f)
}

/// Convenience utility for writing optimizer rules: Recursively apply the
/// given mutable function `f` to all children of a node, and then to the
/// node itself (post-order traversal). When `f` does not apply to a given
/// node, it is left unchanged.
#[deprecated(since = "38.0.0", note = "Use `transform_up` instead")]
fn transform_up_mut<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: &mut F,
) -> Result<Transformed<Self>> {
handle_transform_recursion_up!(self, |c| c.transform_up_mut(f), f)
self.transform_up(f)
}

/// Transforms the tree using `f_down` while traversing the tree top-down
Expand All @@ -200,8 +212,8 @@ pub trait TreeNode: Sized {
///
/// Use this method if you want to start the `f_up` process right where `f_down` jumps.
/// This can make the whole process faster by reducing the number of `f_up` steps.
/// If you don't need this, it's just like using `transform_down_mut` followed by
/// `transform_up_mut` on the same tree.
/// If you don't need this, it's just like using `transform_down` followed by
/// `transform_up` on the same tree.
///
/// Consider the following tree structure:
/// ```text
Expand Down Expand Up @@ -288,22 +300,34 @@ pub trait TreeNode: Sized {
FU: FnMut(Self) -> Result<Transformed<Self>>,
>(
self,
f_down: &mut FD,
f_up: &mut FU,
mut f_down: FD,
mut f_up: FU,
) -> Result<Transformed<Self>> {
handle_transform_recursion!(
f_down(self),
|c| c.transform_down_up(f_down, f_up),
f_up
)
fn transform_down_up_impl<
N: TreeNode,
FD: FnMut(N) -> Result<Transformed<N>>,
FU: FnMut(N) -> Result<Transformed<N>>,
>(
node: N,
f_down: &mut FD,
f_up: &mut FU,
) -> Result<Transformed<N>> {
handle_transform_recursion!(
f_down(node),
|c| transform_down_up_impl(c, f_down, f_up),
f_up
)
}

transform_down_up_impl(self, &mut f_down, &mut f_up)
}

/// Returns true if `f` returns true for node in the tree.
///
/// Stops recursion as soon as a matching node is found
fn exists<F: FnMut(&Self) -> bool>(&self, mut f: F) -> bool {
let mut found = false;
self.apply(&mut |n| {
self.apply(|n| {
Ok(if f(n) {
found = true;
TreeNodeRecursion::Stop
Expand Down Expand Up @@ -439,9 +463,7 @@ impl TreeNodeRecursion {
/// This struct is used by tree transformation APIs such as
/// - [`TreeNode::rewrite`],
/// - [`TreeNode::transform_down`],
/// - [`TreeNode::transform_down_mut`],
/// - [`TreeNode::transform_up`],
/// - [`TreeNode::transform_up_mut`],
/// - [`TreeNode::transform_down_up`]
///
/// to control the transformation and return the transformed result.
Expand Down Expand Up @@ -1362,7 +1384,7 @@ mod tests {
fn $NAME() -> Result<()> {
let tree = test_tree();
let mut visits = vec![];
tree.apply(&mut |node| {
tree.apply(|node| {
visits.push(format!("f_down({})", node.data));
$F(node)
})?;
Expand Down Expand Up @@ -1451,10 +1473,7 @@ mod tests {
#[test]
fn $NAME() -> Result<()> {
let tree = test_tree();
assert_eq!(
tree.transform_down_up(&mut $F_DOWN, &mut $F_UP,)?,
$EXPECTED_TREE
);
assert_eq!(tree.transform_down_up($F_DOWN, $F_UP,)?, $EXPECTED_TREE);

Ok(())
}
Expand All @@ -1466,7 +1485,7 @@ mod tests {
#[test]
fn $NAME() -> Result<()> {
let tree = test_tree();
assert_eq!(tree.transform_down_mut(&mut $F)?, $EXPECTED_TREE);
assert_eq!(tree.transform_down($F)?, $EXPECTED_TREE);

Ok(())
}
Expand All @@ -1478,7 +1497,7 @@ mod tests {
#[test]
fn $NAME() -> Result<()> {
let tree = test_tree();
assert_eq!(tree.transform_up_mut(&mut $F)?, $EXPECTED_TREE);
assert_eq!(tree.transform_up($F)?, $EXPECTED_TREE);

Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use object_store::{ObjectMeta, ObjectStore};
/// was performed
pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
let mut is_applicable = true;
expr.apply(&mut |expr| {
expr.apply(|expr| {
match expr {
Expr::Column(Column { ref name, .. }) => {
is_applicable &= col_names.contains(name);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/coalesce_batches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl PhysicalOptimizerRule for CoalesceBatches {
}

let target_batch_size = config.execution.batch_size;
plan.transform_up(&|plan| {
plan.transform_up(|plan| {
let plan_any = plan.as_any();
// The goal here is to detect operators that could produce small batches and only
// wrap those ones with a CoalesceBatchesExec operator. An alternate approach here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate {
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_down(&|plan| {
plan.transform_down(|plan| {
let transformed =
plan.as_any()
.downcast_ref::<AggregateExec>()
Expand Down Expand Up @@ -179,7 +179,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs {
fn discard_column_index(group_expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
group_expr
.clone()
.transform(&|expr| {
.transform(|expr| {
let normalized_form: Option<Arc<dyn PhysicalExpr>> =
match expr.as_any().downcast_ref::<Column>() {
Some(column) => Some(Arc::new(Column::new(column.name(), 0))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl PhysicalOptimizerRule for OptimizeAggregateOrder {
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_up(&get_common_requirement_of_aggregate_input)
plan.transform_up(get_common_requirement_of_aggregate_input)
.data()
}

Expand Down
12 changes: 6 additions & 6 deletions datafusion/core/src/physical_optimizer/enforce_distribution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,12 @@ impl PhysicalOptimizerRule for EnforceDistribution {
// Run a top-down process to adjust input key ordering recursively
let plan_requirements = PlanWithKeyRequirements::new_default(plan);
let adjusted = plan_requirements
.transform_down(&adjust_input_keys_ordering)
.transform_down(adjust_input_keys_ordering)
.data()?;
adjusted.plan
} else {
// Run a bottom-up process
plan.transform_up(&|plan| {
plan.transform_up(|plan| {
Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?))
})
.data()?
Expand All @@ -211,7 +211,7 @@ impl PhysicalOptimizerRule for EnforceDistribution {
let distribution_context = DistributionContext::new_default(adjusted);
// Distribution enforcement needs to be applied bottom-up.
let distribution_context = distribution_context
.transform_up(&|distribution_context| {
.transform_up(|distribution_context| {
ensure_distribution(distribution_context, config)
})
.data()?;
Expand Down Expand Up @@ -1768,22 +1768,22 @@ pub(crate) mod tests {
let plan_requirements =
PlanWithKeyRequirements::new_default($PLAN.clone());
let adjusted = plan_requirements
.transform_down(&adjust_input_keys_ordering)
.transform_down(adjust_input_keys_ordering)
.data()
.and_then(check_integrity)?;
// TODO: End state payloads will be checked here.
adjusted.plan
} else {
// Run reorder_join_keys_to_inputs rule
$PLAN.clone().transform_up(&|plan| {
$PLAN.clone().transform_up(|plan| {
Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?))
})
.data()?
};

// Then run ensure_distribution rule
DistributionContext::new_default(adjusted)
.transform_up(&|distribution_context| {
.transform_up(|distribution_context| {
ensure_distribution(distribution_context, &config)
})
.data()
Expand Down
Loading