From 83d194e84a112d6e1fca41fd8571a6038efbef57 Mon Sep 17 00:00:00 2001 From: blaginin Date: Tue, 29 Oct 2024 21:40:11 +0000 Subject: [PATCH] Remove macros in favour of `LegacyRewriter` --- datafusion/common/src/tree_node.rs | 283 +++++++++++++++++------------ 1 file changed, 166 insertions(+), 117 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 09a4ad619a48..250d25a8e98d 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -18,6 +18,7 @@ //! [`TreeNode`] for visiting and rewriting expression and plan trees use crate::Result; +use std::marker::PhantomData; use std::sync::Arc; /// These macros are used to determine continuation during transforming traversals. @@ -912,104 +913,6 @@ macro_rules! map_until_stop_and_collect { }} } -macro_rules! rewrite_recursive { - ($START:ident, $NAME:ident, $TRANSFORM_UP:expr, $TRANSFORM_DOWN:expr) => { - let mut queue = vec![ProcessingState::NotStarted($START)]; - - while let Some(item) = queue.pop() { - match item { - ProcessingState::NotStarted($NAME) => { - let node = $TRANSFORM_DOWN?; - - queue.push(match node.tnr { - TreeNodeRecursion::Continue => { - ProcessingState::ProcessingChildren { - non_processed_children: node - .data - .arc_children() - .into_iter() - .cloned() - .rev() - .collect(), - item: node, - processed_children: vec![], - } - } - TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( - node.with_tnr(TreeNodeRecursion::Continue), - ), - TreeNodeRecursion::Stop => { - ProcessingState::ProcessedAllChildren(node) - } - }) - } - ProcessingState::ProcessingChildren { - mut item, - mut non_processed_children, - mut processed_children, - } => match item.tnr { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { - if let Some(non_processed_item) = non_processed_children.pop() { - queue.push(ProcessingState::ProcessingChildren { - item, - non_processed_children, - processed_children, - }); - queue.push(ProcessingState::NotStarted(non_processed_item)); - } else { - item.transformed |= - processed_children.iter().any(|item| item.transformed); - item.data = item.data.with_new_arc_children( - processed_children.into_iter().map(|c| c.data).collect(), - )?; - queue.push(ProcessingState::ProcessedAllChildren(item)) - } - } - TreeNodeRecursion::Stop => { - processed_children.extend( - non_processed_children - .into_iter() - .rev() - .map(Transformed::no), - ); - item.transformed |= - processed_children.iter().any(|item| item.transformed); - item.data = item.data.with_new_arc_children( - processed_children.into_iter().map(|c| c.data).collect(), - )?; - queue.push(ProcessingState::ProcessedAllChildren(item)); - } - }, - ProcessingState::ProcessedAllChildren(node) => { - let node = node.transform_parent(|$NAME| $TRANSFORM_UP)?; - - if let Some(ProcessingState::ProcessingChildren { - item: mut parent_node, - non_processed_children, - mut processed_children, - .. - }) = queue.pop() - { - parent_node.tnr = node.tnr; - processed_children.push(node); - - queue.push(ProcessingState::ProcessingChildren { - item: parent_node, - non_processed_children, - processed_children, - }) - } else { - debug_assert_eq!(queue.len(), 0); - return Ok(node); - } - } - } - } - - unreachable!(); - }; -} - /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. /// /// # Example @@ -1063,6 +966,59 @@ pub trait DynTreeNode { ) -> Result>; } +pub struct LegacyRewriter< + FD: FnMut(Node) -> Result>, + FU: FnMut(Node) -> Result>, + Node: TreeNode, +> { + f_down_func: FD, + f_up_func: FU, + _node: PhantomData, +} + +impl< + FD: FnMut(Node) -> Result>, + FU: FnMut(Node) -> Result>, + Node: TreeNode, + > LegacyRewriter +{ + pub fn new(f_down_func: FD, f_up_func: FU) -> Self { + Self { + f_down_func, + f_up_func, + _node: PhantomData, + } + } +} +impl< + FD: FnMut(Node) -> Result>, + FU: FnMut(Node) -> Result>, + Node: TreeNode, + > TreeNodeRewriter for LegacyRewriter +{ + type Node = Node; + + fn f_down(&mut self, node: Self::Node) -> Result> { + (self.f_down_func)(node) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + (self.f_up_func)(node) + } +} + +macro_rules! update_rec_node { + ($NAME:ident, $CHILDREN:ident) => {{ + $NAME.transformed |= $CHILDREN.iter().any(|item| item.transformed); + + $NAME.data = $NAME + .data + .with_new_arc_children($CHILDREN.into_iter().map(|c| c.data).collect())?; + + $NAME + }}; +} + /// Blanket implementation for any `Arc` where `T` implements [`DynTreeNode`] /// (such as [`Arc`]). impl TreeNode for Arc { @@ -1102,43 +1058,134 @@ impl TreeNode for Arc { FU: FnMut(Self) -> Result>, >( self, - mut f_down: FD, - mut f_up: FU, + f_down: FD, + f_up: FU, ) -> Result> { - rewrite_recursive!(self, node, f_up(node), f_down(node)); + self.rewrite(&mut LegacyRewriter::new(f_down, f_up)) } fn transform_down Result>>( self, f: F, ) -> Result> { - self.transform_down_up(f, |node| Ok(Transformed::no(node))) + self.rewrite(&mut LegacyRewriter::new(f, |node| { + Ok(Transformed::no(node)) + })) } fn transform_up Result>>( self, f: F, ) -> Result> { - self.transform_down_up(|node| Ok(Transformed::no(node)), f) + self.rewrite(&mut LegacyRewriter::new( + |node| Ok(Transformed::no(node)), + f, + )) } fn rewrite>( self, rewriter: &mut R, ) -> Result> { - rewrite_recursive!(self, node, rewriter.f_up(node), rewriter.f_down(node)); + let mut stack = vec![ProcessingState::NotStarted(self)]; + + while let Some(item) = stack.pop() { + match item { + ProcessingState::NotStarted(node) => { + let node = rewriter.f_down(node)?; + + stack.push(match node.tnr { + TreeNodeRecursion::Continue => { + ProcessingState::ProcessingChildren { + non_processed_children: node + .data + .arc_children() + .into_iter() + .cloned() + .rev() + .collect(), + item: node, + processed_children: vec![], + } + } + TreeNodeRecursion::Jump => ProcessingState::ProcessedAllChildren( + node.with_tnr(TreeNodeRecursion::Continue), + ), + TreeNodeRecursion::Stop => { + ProcessingState::ProcessedAllChildren(node) + } + }) + } + ProcessingState::ProcessingChildren { + mut item, + mut non_processed_children, + mut processed_children, + } => match item.tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + if let Some(non_processed_item) = non_processed_children.pop() { + stack.push(ProcessingState::ProcessingChildren { + item, + non_processed_children, + processed_children, + }); + stack.push(ProcessingState::NotStarted(non_processed_item)); + } else { + stack.push(ProcessingState::ProcessedAllChildren( + update_rec_node!(item, processed_children), + )) + } + } + TreeNodeRecursion::Stop => { + processed_children.extend( + non_processed_children + .into_iter() + .rev() + .map(Transformed::no), + ); + stack.push(ProcessingState::ProcessedAllChildren( + update_rec_node!(item, processed_children), + )); + } + }, + ProcessingState::ProcessedAllChildren(node) => { + let node = node.transform_parent(|n| rewriter.f_up(n))?; + + if let Some(ProcessingState::ProcessingChildren { + item: mut parent_node, + non_processed_children, + mut processed_children, + .. + }) = stack.pop() + { + parent_node.tnr = node.tnr; + processed_children.push(node); + + stack.push(ProcessingState::ProcessingChildren { + item: parent_node, + non_processed_children, + processed_children, + }) + } else { + debug_assert_eq!(stack.len(), 0); + return Ok(node); + } + } + } + } + + unreachable!(); } fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>( &'n self, visitor: &mut V, ) -> Result { - let mut queue = vec![VisitingState::NotStarted(self)]; + let mut stack = vec![VisitingState::NotStarted(self)]; - while let Some(item) = queue.pop() { + while let Some(item) = stack.pop() { match item { VisitingState::NotStarted(item) => { let tnr = visitor.f_down(item)?; - queue.push(match tnr { + stack.push(match tnr { TreeNodeRecursion::Continue => VisitingState::VisitingChildren { non_processed_children: item .arc_children() @@ -1165,14 +1212,14 @@ impl TreeNode for Arc { } => match tnr { TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { if let Some(non_processed_item) = non_processed_children.pop() { - queue.push(VisitingState::VisitingChildren { + stack.push(VisitingState::VisitingChildren { item, non_processed_children, tnr, }); - queue.push(VisitingState::NotStarted(non_processed_item)); + stack.push(VisitingState::NotStarted(non_processed_item)); } else { - queue.push(VisitingState::VisitedAllChildren { item, tnr }); + stack.push(VisitingState::VisitedAllChildren { item, tnr }); } } TreeNodeRecursion::Stop => { @@ -1186,15 +1233,15 @@ impl TreeNode for Arc { item, non_processed_children, .. - }) = queue.pop() + }) = stack.pop() { - queue.push(VisitingState::VisitingChildren { + stack.push(VisitingState::VisitingChildren { item, non_processed_children, tnr, }); } else { - debug_assert_eq!(queue.len(), 0); + debug_assert_eq!(stack.len(), 0); return Ok(tnr); } } @@ -1208,30 +1255,32 @@ impl TreeNode for Arc { #[derive(Debug)] enum ProcessingState { NotStarted(T), - // f_down is called + // ← at this point, f_down is called ProcessingChildren { item: Transformed, non_processed_children: Vec, processed_children: Vec>, }, + // ← at this point, all children are processed ProcessedAllChildren(Transformed), - // f_up is called + // ← at this point, f_up is called } #[derive(Debug)] enum VisitingState<'a, T> { NotStarted(&'a T), - // f_down is called + // ← at this point, f_down is called VisitingChildren { item: &'a T, non_processed_children: Vec<&'a T>, tnr: TreeNodeRecursion, }, + // ← at this point, all children are visited VisitedAllChildren { item: &'a T, tnr: TreeNodeRecursion, }, - // f_up is called + // ← at this point, f_up is called } /// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for