From 6a3a19a702029cde64c270a48012281b85a67ee2 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sun, 19 May 2024 18:30:54 +0300 Subject: [PATCH] feat: eliminate group by constant optimizer rule --- .../src/eliminate_group_by_constant.rs | 105 ++++++++++++++++++ datafusion/optimizer/src/lib.rs | 1 + datafusion/optimizer/src/optimizer.rs | 2 + .../optimizer_group_by_constant.slt | 96 ++++++++++++++++ 4 files changed, 204 insertions(+) create mode 100644 datafusion/optimizer/src/eliminate_group_by_constant.rs create mode 100644 datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs new file mode 100644 index 0000000000000..52dc025bb865d --- /dev/null +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder}; + +/// Optimizer rule that removes constant expressions from `GROUP BY` clause +#[derive(Default)] +pub struct EliminateGroupByConstant {} + +impl EliminateGroupByConstant { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateGroupByConstant { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Aggregate(aggregate) => { + let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate + .group_expr + .iter() + .partition(|expr| is_constant_expression(expr)); + + if const_group_expr.is_empty() { + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + aggregate.input, + nonconst_group_expr.into_iter().cloned().collect(), + aggregate.aggr_expr.clone(), + )?); + + let projection_expr = + aggregate.group_expr.into_iter().chain(aggregate.aggr_expr); + + let projection = LogicalPlanBuilder::from(simplified_aggregate) + .project(projection_expr)? + .build()?; + + Ok(Transformed::yes(projection)) + } + _ => Ok(Transformed::no(plan)), + } + } + + fn try_optimize( + &self, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("Should have called EliminateGroupByConstant::rewrite") + } + + fn name(&self) -> &str { + "eliminate_group_by_constant" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } +} + +/// Checks if expression is constant, and can be eliminated from group by. +/// +fn is_constant_expression(expr: &Expr) -> bool { + match expr { + Expr::Alias(e) => is_constant_expression(&e.expr), + Expr::BinaryExpr(e) => { + is_constant_expression(&e.left) && is_constant_expression(&e.right) + } + Expr::Literal(_) => true, + Expr::ScalarFunction(e) => e.args.iter().all(|arg| is_constant_expression(arg)), + _ => false, + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 793c87f8bc0c7..c172d59797569 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -35,6 +35,7 @@ pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; pub mod eliminate_filter; +pub mod eliminate_group_by_constant; pub mod eliminate_join; pub mod eliminate_limit; pub mod eliminate_nested_union; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index e787f56587f7b..3d89255890a53 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -35,6 +35,7 @@ use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; use crate::eliminate_filter::EliminateFilter; +use crate::eliminate_group_by_constant::EliminateGroupByConstant; use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; use crate::eliminate_nested_union::EliminateNestedUnion; @@ -262,6 +263,7 @@ impl Optimizer { Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), + Arc::new(EliminateGroupByConstant::new()), Arc::new(OptimizeProjections::new()), ]; diff --git a/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt b/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt new file mode 100644 index 0000000000000..c223cc837c5eb --- /dev/null +++ b/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE test_table ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); + +statement ok +SET datafusion.execution.target_partitions = 1; + +statement ok +SET datafusion.explain.logical_plan_only = true; + +query TT +EXPLAIN +SELECT c1, 99999, c5 + c8, 'test', count(1) +FROM test_table t +GROUP BY 1, 2, 3, 4 +---- +logical_plan +01)Projection: t.c1, Int64(99999), t.c5 + t.c8, Utf8("test"), COUNT(Int64(1)) +02)--Aggregate: groupBy=[[t.c1, t.c5 + t.c8]], aggr=[[COUNT(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[c1, c5, c8] + +query TT +EXPLAIN +SELECT 123, 456, 789, count(1), avg(c12) +FROM test_table t +group by 1, 2, 3 +---- +logical_plan +01)Projection: Int64(123), Int64(456), Int64(789), COUNT(Int64(1)), AVG(t.c12) +02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)), AVG(t.c12)]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[c12] + +query TT +EXPLAIN +SELECT to_date('2023-05-04') as dt, extract(day from now()) < 1000 as today_filter, count(1) +FROM test_table t +GROUP BY 1, 2 +---- +logical_plan +01)Projection: Date32("19481") AS dt, Boolean(true) AS today_filter, COUNT(Int64(1)) +02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[] + +query TT +EXPLAIN +SELECT + not ( + cast( + extract(month from now()) AS INT + ) + between 50 and 60 + ), count(1) +FROM test_table t +GROUP BY 1 +---- +logical_plan +01)Projection: Boolean(true) AS NOT date_part(Utf8("MONTH"),now()) BETWEEN Int64(50) AND Int64(60), COUNT(Int64(1)) +02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[]