From 97c34a7ccab4e9ca74bc2da36ab7ea84910c08e2 Mon Sep 17 00:00:00 2001 From: lutengda <18346072982@163.com> Date: Fri, 29 Nov 2024 15:10:55 +0800 Subject: [PATCH] add "can_be_pushed_down" in AggregateFunction --- datafusion/core/src/physical_planner.rs | 1 + .../core/tests/provider_aggregation_pushdown.rs | 2 ++ datafusion/expr/src/expr.rs | 5 +++++ datafusion/expr/src/expr_fn.rs | 12 ++++++++++++ datafusion/expr/src/tree_node/expr.rs | 2 ++ .../optimizer/src/analyzer/count_wildcard_rule.rs | 2 ++ datafusion/optimizer/src/analyzer/type_coercion.rs | 6 +++++- datafusion/optimizer/src/push_down_projection.rs | 1 + .../optimizer/src/single_distinct_to_groupby.rs | 3 +++ datafusion/proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/mod.rs | 3 +++ datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sql/src/expr/function.rs | 2 +- datafusion/sql/src/expr/mod.rs | 4 +++- datafusion/sql/src/utils.rs | 2 ++ datafusion/substrait/src/logical_plan/consumer.rs | 1 + datafusion/substrait/src/logical_plan/producer.rs | 2 +- 17 files changed, 46 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 97e873d204c8..c0384ee47754 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1650,6 +1650,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( args, filter, order_by, + .. }) => { let args = args .iter() diff --git a/datafusion/core/tests/provider_aggregation_pushdown.rs b/datafusion/core/tests/provider_aggregation_pushdown.rs index a3725f849bf1..3098ef234766 100644 --- a/datafusion/core/tests/provider_aggregation_pushdown.rs +++ b/datafusion/core/tests/provider_aggregation_pushdown.rs @@ -249,6 +249,7 @@ impl TableProvider for CustomAggregationProvider { distinct, filter, order_by, + can_be_pushed_down, }) => { let support_agg_func = match fun { aggregate_function::AggregateFunction::Count => true, @@ -263,6 +264,7 @@ impl TableProvider for CustomAggregationProvider { && !distinct && filter.is_none() && order_by.is_none() + && *can_be_pushed_down } _ => false, } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 55bfea7375eb..6128dd942da4 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -425,6 +425,8 @@ pub struct AggregateFunction { pub filter: Option>, /// Optional ordering pub order_by: Option>, + /// Whether it can be pushed down + pub can_be_pushed_down: bool, } impl AggregateFunction { @@ -434,6 +436,7 @@ impl AggregateFunction { distinct: bool, filter: Option>, order_by: Option>, + can_be_pushed_down: bool, ) -> Self { Self { fun, @@ -441,6 +444,7 @@ impl AggregateFunction { distinct, filter, order_by, + can_be_pushed_down, } } } @@ -1364,6 +1368,7 @@ fn create_name(e: &Expr) -> Result { args, filter, order_by, + .. }) => { let mut name = create_function_name(&fun.to_string(), *distinct, args)?; if let Some(fe) = filter { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 5b2ec735c0ac..98db25fe3392 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -111,6 +111,7 @@ pub fn min(expr: Expr) -> Expr { false, None, None, + false, )) } @@ -122,6 +123,7 @@ pub fn max(expr: Expr) -> Expr { false, None, None, + false, )) } @@ -133,6 +135,7 @@ pub fn sum(expr: Expr) -> Expr { false, None, None, + false, )) } @@ -144,6 +147,7 @@ pub fn avg(expr: Expr) -> Expr { false, None, None, + false, )) } @@ -155,6 +159,7 @@ pub fn count(expr: Expr) -> Expr { false, None, None, + true, )) } @@ -211,6 +216,7 @@ pub fn count_distinct(expr: Expr) -> Expr { true, None, None, + false, )) } @@ -263,6 +269,7 @@ pub fn approx_distinct(expr: Expr) -> Expr { false, None, None, + false, )) } @@ -274,6 +281,7 @@ pub fn median(expr: Expr) -> Expr { false, None, None, + false, )) } @@ -285,6 +293,7 @@ pub fn approx_median(expr: Expr) -> Expr { false, None, None, + false, )) } @@ -296,6 +305,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { false, None, None, + false, )) } @@ -311,6 +321,7 @@ pub fn approx_percentile_cont_with_weight( false, None, None, + false, )) } @@ -381,6 +392,7 @@ pub fn stddev(expr: Expr) -> Expr { false, None, None, + false, )) } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 98f489379386..59037337a535 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -299,12 +299,14 @@ impl TreeNode for Expr { distinct, filter, order_by, + can_be_pushed_down, }) => Expr::AggregateFunction(AggregateFunction::new( fun, transform_vec(args, &mut transform)?, distinct, transform_option_box(filter, &mut transform)?, transform_option_vec(order_by, &mut transform)?, + can_be_pushed_down, )), Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 51354cb66661..46ef6422f523 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -176,6 +176,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { distinct, filter, order_by, + can_be_pushed_down, }) if args.len() == 1 => match args[0] { Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { fun: aggregate_function::AggregateFunction::Count, @@ -183,6 +184,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { distinct, filter, order_by, + can_be_pushed_down, }), _ => old_expr, }, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e8785efa0959..cccd0b3efd06 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -409,6 +409,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { distinct, filter, order_by, + can_be_pushed_down, }) => { let new_expr = coerce_agg_exprs_for_signature( &fun, @@ -417,7 +418,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &aggregate_function::signature(&fun), )?; let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, + fun, new_expr, distinct, filter, order_by, can_be_pushed_down, )); Ok(expr) } @@ -993,6 +994,7 @@ mod test { false, None, None, + false, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: AVG(Int64(12))\n EmptyRelation"; @@ -1006,6 +1008,7 @@ mod test { false, None, None, + false, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); let expected = "Projection: AVG(a)\n EmptyRelation"; @@ -1023,6 +1026,7 @@ mod test { false, None, None, + false, )); let err = Projection::try_new(vec![agg_expr], empty).err().unwrap(); assert_eq!( diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index c2e65840ad5f..91649dd71886 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -1064,6 +1064,7 @@ mod tests { false, Some(Box::new(col("c").gt(lit(42)))), None, + false, )); let plan = LogicalPlanBuilder::from(table_scan) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index ba7e89094b0f..5d9ce7461c06 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -132,6 +132,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { args, filter, order_by, + can_be_pushed_down, .. }) => { // is_single_distinct_agg ensure args.len=1 @@ -146,6 +147,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { false, // intentional to remove distinct here filter.clone(), order_by.clone(), + can_be_pushed_down.clone(), ))) } _ => Ok(aggr_expr.clone()), @@ -402,6 +404,7 @@ mod tests { true, None, None, + false, )), ], )? diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 1d47ce3021b5..87399d2540d1 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1011,6 +1011,7 @@ pub fn parse_expr( expr.distinct, parse_optional_expr(expr.filter.as_deref(), registry)?.map(Box::new), parse_vec_expr(&expr.order_by, registry)?, + false, ))) } ExprType::Alias(alias) => Ok(Expr::Alias( diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 95e3fcb00c3a..8b27c0e9ccc7 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -2570,6 +2570,7 @@ mod roundtrip_tests { false, None, None, + false, )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2583,6 +2584,7 @@ mod roundtrip_tests { true, None, None, + false, )); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -2596,6 +2598,7 @@ mod roundtrip_tests { false, None, None, + false, )); let ctx = SessionContext::new(); diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7c8ccc67698b..32454b26bbf5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -632,6 +632,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref distinct, ref filter, ref order_by, + .. }) => { let aggr_function = match fun { AggregateFunction::ApproxDistinct => { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index e5c53caca5d7..801b919b0695 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -135,7 +135,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.function_args_to_expr(function.args, schema, planner_context)?; return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, None, order_by, + fun, args, distinct, None, order_by, true, ))); }; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 06102424f766..f697476172d2 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -365,7 +365,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, aggregate built-ins let fun = AggregateFunction::ArrayAgg; Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, args, distinct, None, order_by, + fun, args, distinct, None, order_by, false, ))) } @@ -500,6 +500,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { args, distinct, order_by, + can_be_pushed_down, .. }) => Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, @@ -511,6 +512,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?)), order_by, + can_be_pushed_down, ))), _ => Err(DataFusionError::Plan( "AggregateExpressionWithFilter expression was not an AggregateFunction" diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index da11c20e7f00..3e71d5790f6c 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -167,6 +167,7 @@ where distinct, filter, order_by, + can_be_pushed_down, }) => Ok(Expr::AggregateFunction(AggregateFunction::new( fun.clone(), args.iter() @@ -175,6 +176,7 @@ where *distinct, filter.clone(), order_by.clone(), + can_be_pushed_down.clone(), ))), Expr::WindowFunction(WindowFunction { fun, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b70191a9d30e..2336e4f9ad65 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -665,6 +665,7 @@ pub async fn from_substrait_agg_func( distinct, filter, order_by, + can_be_pushed_down: false, }))) } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index be236855774c..cf9caeed3dfb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -472,7 +472,7 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by }) => { + Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by, .. }) => { let sorts = if let Some(order_by) = order_by { order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? } else {