diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs new file mode 100644 index 0000000000000..72a2c2afbd645 --- /dev/null +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -0,0 +1,316 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility}; + +/// Optimizer rule that removes constant expressions from `GROUP BY` clause +#[derive(Default)] +pub struct EliminateGroupByConstant {} + +impl EliminateGroupByConstant { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateGroupByConstant { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Aggregate(aggregate) => { + let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate + .group_expr + .iter() + .partition(|expr| is_constant_expression(expr)); + + // If no constant expressions found (nothing to optimize) or + // constant expression is the only expression in aggregate, + // optimization is skipped + if const_group_expr.is_empty() + || (!const_group_expr.is_empty() + && nonconst_group_expr.is_empty() + && aggregate.aggr_expr.is_empty()) + { + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + aggregate.input, + nonconst_group_expr.into_iter().cloned().collect(), + aggregate.aggr_expr.clone(), + )?); + + let projection_expr = + aggregate.group_expr.into_iter().chain(aggregate.aggr_expr); + + let projection = LogicalPlanBuilder::from(simplified_aggregate) + .project(projection_expr)? + .build()?; + + Ok(Transformed::yes(projection)) + } + _ => Ok(Transformed::no(plan)), + } + } + + fn try_optimize( + &self, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("Should have called EliminateGroupByConstant::rewrite") + } + + fn name(&self) -> &str { + "eliminate_group_by_constant" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } +} + +/// Checks if expression is constant, and can be eliminated from group by. +/// +/// Intended to be used only within this rule, helper function, which heavily +/// reiles on `SimplifyExpressions` result. +fn is_constant_expression(expr: &Expr) -> bool { + match expr { + Expr::Alias(e) => is_constant_expression(&e.expr), + Expr::BinaryExpr(e) => { + is_constant_expression(&e.left) && is_constant_expression(&e.right) + } + Expr::Literal(_) => true, + Expr::ScalarFunction(e) => { + matches!( + e.func.signature().volatility, + Volatility::Immutable | Volatility::Stable + ) && e.args.iter().all(is_constant_expression) + } + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + + use arrow::datatypes::DataType; + use datafusion_common::Result; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{ + col, count, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, + Signature, TypeSignature, + }; + + use std::sync::Arc; + + #[derive(Debug)] + struct ScalarUDFMock { + signature: Signature, + } + + impl ScalarUDFMock { + fn new_with_volatility(volatility: Volatility) -> Self { + Self { + signature: Signature::new(TypeSignature::Any(1), volatility), + } + } + } + + impl ScalarUDFImpl for ScalarUDFMock { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &str { + "scalar_fn_mock" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } + + #[test] + fn test_eliminate_gby_literal() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Projection: test.a, UInt32(1), COUNT(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_eliminate_constant() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Projection: Utf8(\"test\"), UInt32(123), COUNT(test.c)\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_no_op_no_constants() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![col("a"), col("b")], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[test.a, test.b]], aggr=[[COUNT(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_no_op_only_constant() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![lit(123u32)], Vec::::new())? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_eliminate_constant_with_alias() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate( + vec![lit(123u32).alias("const"), col("a")], + vec![count(col("c"))], + )? + .build()?; + + let expected = "\ + Projection: UInt32(123) AS const, test.a, COUNT(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_eliminate_scalar_fn_with_constant_arg() -> Result<()> { + let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility( + Volatility::Immutable, + )); + let udf_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Projection: scalar_fn_mock(UInt32(123)), test.a, COUNT(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> { + let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility( + Volatility::Volatile, + )); + let udf_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[COUNT(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 793c87f8bc0c7..c172d59797569 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -35,6 +35,7 @@ pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; pub mod eliminate_filter; +pub mod eliminate_group_by_constant; pub mod eliminate_join; pub mod eliminate_limit; pub mod eliminate_nested_union; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index e787f56587f7b..3d89255890a53 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -35,6 +35,7 @@ use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; use crate::eliminate_filter::EliminateFilter; +use crate::eliminate_group_by_constant::EliminateGroupByConstant; use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; use crate::eliminate_nested_union::EliminateNestedUnion; @@ -262,6 +263,7 @@ impl Optimizer { Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), + Arc::new(EliminateGroupByConstant::new()), Arc::new(OptimizeProjections::new()), ]; diff --git a/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt b/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt new file mode 100644 index 0000000000000..d54e82de07e43 --- /dev/null +++ b/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE test_table ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); + +statement ok +SET datafusion.execution.target_partitions = 1; + +statement ok +SET datafusion.explain.logical_plan_only = true; + +query TT +EXPLAIN +SELECT c1, 99999, c5 + c8, 'test', count(1) +FROM test_table t +GROUP BY 1, 2, 3, 4 +---- +logical_plan +01)Projection: t.c1, Int64(99999), t.c5 + t.c8, Utf8("test"), COUNT(Int64(1)) +02)--Aggregate: groupBy=[[t.c1, t.c5 + t.c8]], aggr=[[COUNT(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[c1, c5, c8] + +query TT +EXPLAIN +SELECT 123, 456, 789, count(1), avg(c12) +FROM test_table t +group by 1, 2, 3 +---- +logical_plan +01)Projection: Int64(123), Int64(456), Int64(789), COUNT(Int64(1)), AVG(t.c12) +02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)), AVG(t.c12)]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[c12] + +query TT +EXPLAIN +SELECT to_date('2023-05-04') as dt, extract(day from now()) < 1000 as today_filter, count(1) +FROM test_table t +GROUP BY 1, 2 +---- +logical_plan +01)Projection: Date32("19481") AS dt, Boolean(true) AS today_filter, COUNT(Int64(1)) +02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[] + +query TT +EXPLAIN +SELECT + not ( + cast( + extract(month from now()) AS INT + ) + between 50 and 60 + ), count(1) +FROM test_table t +GROUP BY 1 +---- +logical_plan +01)Projection: Boolean(true) AS NOT date_part(Utf8("MONTH"),now()) BETWEEN Int64(50) AND Int64(60), COUNT(Int64(1)) +02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[] + +query TT +EXPLAIN +SELECT 123 +FROM test_table t +GROUP BY 1 +---- +logical_plan +01)Aggregate: groupBy=[[Int64(123)]], aggr=[[]] +02)--SubqueryAlias: t +03)----TableScan: test_table projection=[] + +query TT +EXPLAIN +SELECT random() +FROM test_table t +GROUP BY 1 +---- +logical_plan +01)Aggregate: groupBy=[[random()]], aggr=[[]] +02)--SubqueryAlias: t +03)----TableScan: test_table projection=[]