diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 21375a3ba100..9a218fbb7d92 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -103,6 +103,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { | Expr::InList { .. } | Expr::InSubquery { .. } | Expr::GetIndexedField { .. } + | Expr::GroupingSet(_) | Expr::Case { .. } => Recursion::Continue(self), Expr::ScalarFunction { fun, .. } => self.visit_volatility(fun.volatility()), diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index c27c3b1e334f..a51ab37b04c8 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -568,7 +568,7 @@ impl LogicalPlanBuilder { expr.extend(missing_exprs); let new_schema = DFSchema::new_with_metadata( - exprlist_to_fields(&expr, input_schema)?, + exprlist_to_fields(&expr, &input)?, input_schema.metadata().clone(), )?; @@ -640,7 +640,7 @@ impl LogicalPlanBuilder { .map(|f| Expr::Column(f.qualified_column())) .collect(); let new_schema = DFSchema::new_with_metadata( - exprlist_to_fields(&new_expr, schema)?, + exprlist_to_fields(&new_expr, &self.plan)?, schema.metadata().clone(), )?; @@ -870,8 +870,7 @@ impl LogicalPlanBuilder { let window_expr = normalize_cols(window_expr, &self.plan)?; let all_expr = window_expr.iter(); validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?; - let mut window_fields: Vec = - exprlist_to_fields(all_expr, self.plan.schema())?; + let mut window_fields: Vec = exprlist_to_fields(all_expr, &self.plan)?; window_fields.extend_from_slice(self.plan.schema().fields()); Ok(Self::from(LogicalPlan::Window(Window { input: Arc::new(self.plan.clone()), @@ -903,7 +902,7 @@ impl LogicalPlanBuilder { let all_expr = group_expr.iter().chain(aggr_expr.iter()); validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?; let aggr_schema = DFSchema::new_with_metadata( - exprlist_to_fields(all_expr, self.plan.schema())?, + exprlist_to_fields(all_expr, &self.plan)?, self.plan.schema().metadata().clone(), )?; Ok(Self::from(LogicalPlan::Aggregate(Aggregate { @@ -1180,13 +1179,14 @@ pub fn project_with_alias( } validate_unique_names("Projections", projected_expr.iter(), input_schema)?; let input_schema = DFSchema::new_with_metadata( - exprlist_to_fields(&projected_expr, input_schema)?, + exprlist_to_fields(&projected_expr, &plan)?, plan.schema().metadata().clone(), )?; let schema = match alias { Some(ref alias) => input_schema.replace_qualifier(alias.as_str()), None => input_schema, }; + Ok(LogicalPlan::Projection(Projection { expr: projected_expr, input: Arc::new(plan.clone()), diff --git a/datafusion/core/src/logical_plan/expr.rs b/datafusion/core/src/logical_plan/expr.rs index 2071ca4ef59b..300cf8d6740e 100644 --- a/datafusion/core/src/logical_plan/expr.rs +++ b/datafusion/core/src/logical_plan/expr.rs @@ -21,7 +21,9 @@ pub use super::Operator; use crate::error::Result; use crate::logical_plan::ExprSchemable; +use crate::logical_plan::LogicalPlan; use crate::logical_plan::{DFField, DFSchema}; +use crate::sql::utils::find_columns_referenced_by_expr; use arrow::datatypes::DataType; use datafusion_common::DataFusionError; pub use datafusion_common::{Column, ExprSchema}; @@ -251,9 +253,33 @@ pub fn create_udaf( /// Create field meta-data from an expression, for use in a result set schema pub fn exprlist_to_fields<'a>( expr: impl IntoIterator, - input_schema: &DFSchema, + plan: &LogicalPlan, ) -> Result> { - expr.into_iter().map(|e| e.to_field(input_schema)).collect() + match plan { + LogicalPlan::Aggregate(agg) => { + let group_expr: Vec = agg + .group_expr + .iter() + .flat_map(find_columns_referenced_by_expr) + .collect(); + let exprs: Vec = expr.into_iter().cloned().collect(); + let mut fields = vec![]; + for expr in &exprs { + match expr { + Expr::Column(c) if group_expr.iter().any(|x| x == c) => { + // resolve against schema of input to aggregate + fields.push(expr.to_field(agg.input.schema())?); + } + _ => fields.push(expr.to_field(plan.schema())?), + } + } + Ok(fields) + } + _ => { + let input_schema = &plan.schema(); + expr.into_iter().map(|e| e.to_field(input_schema)).collect() + } + } } /// Calls a named built in function diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs index 8e65962771b5..0d16d9674642 100644 --- a/datafusion/core/src/logical_plan/expr_rewriter.rs +++ b/datafusion/core/src/logical_plan/expr_rewriter.rs @@ -27,6 +27,7 @@ use crate::sql::utils::{ }; use datafusion_common::Column; use datafusion_common::Result; +use datafusion_expr::expr::GroupingSet; use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; @@ -256,6 +257,22 @@ impl ExprRewritable for Expr { fun, distinct, }, + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => { + Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?)) + } + GroupingSet::Cube(exprs) => { + Expr::GroupingSet(GroupingSet::Cube(rewrite_vec(exprs, rewriter)?)) + } + GroupingSet::GroupingSets(lists_of_exprs) => { + Expr::GroupingSet(GroupingSet::GroupingSets( + lists_of_exprs + .iter() + .map(|exprs| rewrite_vec(exprs.clone(), rewriter)) + .collect::>>()?, + )) + } + }, Expr::AggregateUDF { args, fun } => Expr::AggregateUDF { args: rewrite_vec(args, rewriter)?, fun, diff --git a/datafusion/core/src/logical_plan/expr_schema.rs b/datafusion/core/src/logical_plan/expr_schema.rs index 4f041f09a342..f7b4778adf7b 100644 --- a/datafusion/core/src/logical_plan/expr_schema.rs +++ b/datafusion/core/src/logical_plan/expr_schema.rs @@ -132,6 +132,10 @@ impl ExprSchemable for Expr { "QualifiedWildcard expressions are not valid in a logical query plan" .to_owned(), )), + Expr::GroupingSet(_) => { + // grouping sets do not really have a type and do not appear in projections + Ok(DataType::Null) + } Expr::GetIndexedField { ref expr, key } => { let data_type = expr.get_type(schema)?; @@ -212,6 +216,11 @@ impl ExprSchemable for Expr { let data_type = expr.get_type(input_schema)?; get_indexed_field(&data_type, key).map(|x| x.is_nullable()) } + Expr::GroupingSet(_) => { + // grouping sets do not really have the concept of nullable and do not appear + // in projections + Ok(true) + } } } diff --git a/datafusion/core/src/logical_plan/expr_visitor.rs b/datafusion/core/src/logical_plan/expr_visitor.rs index 9296848ea8ab..e0befe0ddbc2 100644 --- a/datafusion/core/src/logical_plan/expr_visitor.rs +++ b/datafusion/core/src/logical_plan/expr_visitor.rs @@ -19,6 +19,7 @@ use super::{Expr, Like}; use datafusion_common::Result; +use datafusion_expr::expr::GroupingSet; /// Controls how the visitor recursion should proceed. pub enum Recursion { @@ -106,6 +107,19 @@ impl ExprVisitable for Expr { let visitor = expr.accept(visitor)?; key.accept(visitor) } + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs + .iter() + .fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))), + Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs + .iter() + .fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))), + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + lists_of_exprs.iter().fold(Ok(visitor), |v, exprs| { + v.and_then(|v| { + exprs.iter().fold(Ok(v), |v, e| v.and_then(|v| e.accept(v))) + }) + }) + } Expr::Column(_) | Expr::OuterColumn(_, _) | Expr::ScalarVariable(_, _) diff --git a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs index 9fa278a8b688..dec9c08efbf1 100644 --- a/datafusion/core/src/optimizer/common_subexpr_eliminate.rs +++ b/datafusion/core/src/optimizer/common_subexpr_eliminate.rs @@ -29,6 +29,7 @@ use crate::optimizer::optimizer::OptimizerConfig; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; use arrow::datatypes::DataType; +use datafusion_expr::expr::GroupingSet; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -523,6 +524,33 @@ impl ExprIdentifierVisitor<'_> { desc.push_str("GetIndexedField-"); desc.push_str(&key.to_string()); } + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => { + desc.push_str("Rollup"); + for expr in exprs { + desc.push('-'); + desc.push_str(&Self::desc_expr(expr)); + } + } + GroupingSet::Cube(exprs) => { + desc.push_str("Cube"); + for expr in exprs { + desc.push('-'); + desc.push_str(&Self::desc_expr(expr)); + } + } + GroupingSet::GroupingSets(lists_of_exprs) => { + desc.push_str("GroupingSets"); + for exprs in lists_of_exprs { + desc.push('('); + for expr in exprs { + desc.push('-'); + desc.push_str(&Self::desc_expr(expr)); + } + desc.push(')'); + } + } + }, } desc diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs index 9c1cdb11bc9f..e7731ed84508 100644 --- a/datafusion/core/src/optimizer/projection_push_down.rs +++ b/datafusion/core/src/optimizer/projection_push_down.rs @@ -827,7 +827,7 @@ mod tests { // that the Column references are unqualified (e.g. their // relation is `None`). PlanBuilder resolves the expressions let expr = vec![col("a"), col("b")]; - let projected_fields = exprlist_to_fields(&expr, input_schema).unwrap(); + let projected_fields = exprlist_to_fields(&expr, &table_scan).unwrap(); let projected_schema = DFSchema::new_with_metadata( projected_fields, input_schema.metadata().clone(), diff --git a/datafusion/core/src/optimizer/simplify_expressions.rs b/datafusion/core/src/optimizer/simplify_expressions.rs index 32b0daabeb9b..ba285fc11b55 100644 --- a/datafusion/core/src/optimizer/simplify_expressions.rs +++ b/datafusion/core/src/optimizer/simplify_expressions.rs @@ -393,6 +393,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::Sort { .. } | Expr::InSubquery { .. } + | Expr::GroupingSet(_) | Expr::Wildcard | Expr::QualifiedWildcard { .. } => false, Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()), diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 0f0f8b3b6805..91ab7edfba17 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -36,6 +36,8 @@ use crate::{ error::{DataFusionError, Result}, logical_plan::ExpressionVisitor, }; +use datafusion_common::DFSchema; +use datafusion_expr::expr::GroupingSet; use std::{collections::HashSet, sync::Arc}; const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__"; @@ -91,6 +93,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> { | Expr::TableUDF { .. } | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } + | Expr::GroupingSet(_) | Expr::AggregateUDF { .. } | Expr::InList { .. } | Expr::InSubquery { .. } @@ -339,6 +342,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { | Expr::TableUDF { args, .. } | Expr::AggregateFunction { args, .. } | Expr::AggregateUDF { args, .. } => Ok(args.clone()), + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => Ok(exprs.clone()), + GroupingSet::Cube(exprs) => Ok(exprs.clone()), + GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan( + "GroupingSets are not supported yet".to_string(), + )), + }, Expr::WindowFunction { args, partition_by, @@ -517,6 +527,17 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions.to_vec(), }), + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(_exprs) => { + Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec()))) + } + GroupingSet::Cube(_exprs) => { + Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec()))) + } + GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan( + "GroupingSets are not supported yet".to_string(), + )), + }, Expr::Case { .. } => { let mut base_expr: Option> = None; let mut when_then: Vec<(Box, Box)> = vec![]; diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 76d003622913..6cd1cf91dccf 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -68,6 +68,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::{compute::can_cast_types, datatypes::DataType}; use async_trait::async_trait; use datafusion_common::OuterQueryCursor; +use datafusion_expr::expr::GroupingSet; use datafusion_expr::expr_fn::binary_expr; use datafusion_physical_expr::expressions::{any, OuterColumn}; use futures::future::BoxFuture; @@ -197,6 +198,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { } Ok(format!("{}({})", fun.name, names.join(","))) } + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => Ok(format!( + "ROLLUP ({})", + exprs + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()? + .join(", ") + )), + GroupingSet::Cube(exprs) => Ok(format!( + "CUBE ({})", + exprs + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()? + .join(", ") + )), + GroupingSet::GroupingSets(lists_of_exprs) => { + let mut strings = vec![]; + for exprs in lists_of_exprs { + let exprs_str = exprs + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()? + .join(", "); + strings.push(format!("({})", exprs_str)); + } + Ok(format!("GROUPING SETS ({})", strings.join(", "))) + } + }, + Expr::InList { expr, list, diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index f577f7713506..3ef2f8e6dd63 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -52,6 +52,12 @@ use datafusion_expr::{window_function::WindowFunction, BuiltinScalarFunction}; use hashbrown::HashMap; use log::warn; +<<<<<<< HEAD +======= +use datafusion_common::field_not_found; +use datafusion_expr::expr::GroupingSet; +use datafusion_expr::logical_plan::{Filter, Subquery}; +>>>>>>> 1fe038fbc (Add SQL planner support for `ROLLUP` and `CUBE` grouping set expressions (#2446)) use sqlparser::ast::{ ArrayAgg, BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Fetch, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, @@ -1262,11 +1268,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { group_by_exprs: Vec, aggr_exprs: Vec, ) -> Result<(LogicalPlan, Vec, Option)> { +<<<<<<< HEAD let aggr_projection_exprs = group_by_exprs .iter() .chain(aggr_exprs.iter()) .cloned() .collect::>(); +======= + // create the aggregate plan + let plan = LogicalPlanBuilder::from(input.clone()) + .aggregate(group_by_exprs.clone(), aggr_exprs.clone())? + .build()?; + + // in this next section of code we are re-writing the projection to refer to columns + // output by the aggregate plan. For example, if the projection contains the expression + // `SUM(a)` then we replace that with a reference to a column `#SUM(a)` produced by + // the aggregate plan. + + // combine the original grouping and aggregate expressions into one list (note that + // we do not add the "having" expression since that is not part of the projection) + let mut aggr_projection_exprs = vec![]; + for expr in &group_by_exprs { + match expr { + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + aggr_projection_exprs.extend_from_slice(exprs) + } + Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + aggr_projection_exprs.extend_from_slice(exprs) + } + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + for exprs in lists_of_exprs { + aggr_projection_exprs.extend_from_slice(exprs) + } + } + _ => aggr_projection_exprs.push(expr.clone()), + } + } + aggr_projection_exprs.extend_from_slice(&aggr_exprs); +>>>>>>> 1fe038fbc (Add SQL planner support for `ROLLUP` and `CUBE` grouping set expressions (#2446)) let plan = LogicalPlanBuilder::from(input.clone()) .aggregate(group_by_exprs, aggr_exprs)? @@ -2217,10 +2256,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { function.name.0[0].clone().value.to_ascii_lowercase() }; - // first, scalar built-in - if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { + // first, check SQL reserved words + if name == "rollup" { let args = self.function_args_to_expr(function.args, schema)?; + return Ok(Expr::GroupingSet(GroupingSet::Rollup(args))); + } else if name == "cube" { + let args = self.function_args_to_expr(function.args, schema)?; + return Ok(Expr::GroupingSet(GroupingSet::Cube(args))); + } + // next, scalar built-in + if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { + let args = self.function_args_to_expr(function.args, schema)?; return Ok(Expr::ScalarFunction { fun, args }); }; @@ -5233,6 +5280,7 @@ mod tests { quick_test(sql, expected); } +<<<<<<< HEAD #[test] fn test_offset_after_limit() { let sql = "select id from person where person.id > 100 LIMIT 5 OFFSET 3;"; @@ -5291,5 +5339,45 @@ mod tests { \n Filter: #person.id < Int64(100)\ \n TableScan: person projection=None"; quick_test(sql, expected); +======= + #[tokio::test] + async fn aggregate_with_rollup() { + let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)"; + let expected = "Projection: #person.id, #person.state, #person.age, #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[#person.id, ROLLUP (#person.state, #person.age)]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[tokio::test] + async fn aggregate_with_cube() { + let sql = + "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)"; + let expected = "Projection: #person.id, #person.state, #person.age, #COUNT(UInt8(1))\ + \n Aggregate: groupBy=[[#person.id, CUBE (#person.state, #person.age)]], aggr=[[COUNT(UInt8(1))]]\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + + #[ignore] // see https://github.com/apache/arrow-datafusion/issues/2469 + #[tokio::test] + async fn aggregate_with_grouping_sets() { + let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; + let expected = "TBD"; + quick_test(sql, expected); + } + + fn assert_field_not_found(err: DataFusionError, name: &str) { + match err { + DataFusionError::SchemaError { .. } => { + let msg = format!("{}", err); + let expected = format!("Schema error: No field named '{}'.", name); + if !msg.starts_with(&expected) { + panic!("error [{}] did not start with [{}]", msg, expected); + } + } + _ => panic!("assert_field_not_found wrong error type"), + } +>>>>>>> 1fe038fbc (Add SQL planner support for `ROLLUP` and `CUBE` grouping set expressions (#2446)) } } diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs index 833914ce4d9e..e0efb4ba1e86 100644 --- a/datafusion/core/src/sql/utils.rs +++ b/datafusion/core/src/sql/utils.rs @@ -28,6 +28,7 @@ use crate::{ error::{DataFusionError, Result}, logical_plan::{Column, ExpressionVisitor, Recursion}, }; +use datafusion_expr::expr::GroupingSet; use std::collections::HashMap; /// Collect all deeply nested `Expr::AggregateFunction` and @@ -90,6 +91,30 @@ where }) } +/// Recursively find all columns referenced by an expression +#[derive(Debug, Default)] +struct ColumnCollector { + exprs: Vec, +} + +impl ExpressionVisitor for ColumnCollector { + fn pre_visit(mut self, expr: &Expr) -> Result> { + if let Expr::Column(c) = expr { + self.exprs.push(c.clone()) + } + Ok(Recursion::Continue(self)) + } +} + +pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec { + // As the `ExpressionVisitor` impl above always returns Ok, this + // "can't" error + let ColumnCollector { exprs } = e + .accept(ColumnCollector::default()) + .expect("Unexpected error"); + exprs +} + // Visitor that find expressions that match a particular predicate struct Finder<'a, F> where @@ -191,8 +216,55 @@ pub(crate) fn can_columns_satisfy_exprs( "Expr::Column are required".to_string(), )), })?; +<<<<<<< HEAD Ok(find_column_exprs(exprs).iter().all(|c| columns.contains(c))) +======= + let column_exprs = find_column_exprs(exprs); + for e in &column_exprs { + match e { + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + for e in exprs { + check_column_satisfies_expr(columns, e, message_prefix)?; + } + } + Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + for e in exprs { + check_column_satisfies_expr(columns, e, message_prefix)?; + } + } + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + for exprs in lists_of_exprs { + for e in exprs { + check_column_satisfies_expr(columns, e, message_prefix)?; + } + } + } + _ => check_column_satisfies_expr(columns, e, message_prefix)?, + } + } + Ok(()) +} + +fn check_column_satisfies_expr( + columns: &[Expr], + expr: &Expr, + message_prefix: &str, +) -> Result<()> { + if !columns.contains(expr) { + return Err(DataFusionError::Plan(format!( + "{}: Expression {:?} could not be resolved from available columns: {}", + message_prefix, + expr, + columns + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + ))); + } + Ok(()) +>>>>>>> 1fe038fbc (Add SQL planner support for `ROLLUP` and `CUBE` grouping set expressions (#2446)) } /// Returns a cloned `Expr`, but any of the `Expr`'s in the tree may be @@ -447,6 +519,34 @@ where expr: Box::new(clone_with_replacement(expr.as_ref(), replacement_fn)?), key: Box::new(clone_with_replacement(key.as_ref(), replacement_fn)?), }), + Expr::GroupingSet(set) => match set { + GroupingSet::Rollup(exprs) => Ok(Expr::GroupingSet(GroupingSet::Rollup( + exprs + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + ))), + GroupingSet::Cube(exprs) => Ok(Expr::GroupingSet(GroupingSet::Cube( + exprs + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + ))), + GroupingSet::GroupingSets(lists_of_exprs) => { + let mut new_lists_of_exprs = vec![]; + for exprs in lists_of_exprs { + new_lists_of_exprs.push( + exprs + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + ); + } + Ok(Expr::GroupingSet(GroupingSet::GroupingSets( + new_lists_of_exprs, + ))) + } + }, }, } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 5c6297da9195..5284e2557973 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -266,6 +266,24 @@ pub enum Expr { Wildcard, /// Represents a reference to all fields in a specific schema. QualifiedWildcard { qualifier: String }, + /// List of grouping set expressions. Only valid in the context of an aggregate + /// GROUP BY expression list + GroupingSet(GroupingSet), +} + +/// Grouping sets +/// See https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-GROUPING-SETS +/// for Postgres definition. +/// See https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-groupby.html +/// for Apache Spark definition. +#[derive(Clone, PartialEq, Hash)] +pub enum GroupingSet { + /// Rollup grouping sets + Rollup(Vec), + /// Cube grouping sets + Cube(Vec), + /// User-defined grouping sets + GroupingSets(Vec>), } /// LIKE expression @@ -679,6 +697,51 @@ impl fmt::Debug for Expr { Expr::GetIndexedField { ref expr, key } => { write!(f, "({:?})[{}]", expr, key) } + Expr::GroupingSet(grouping_sets) => match grouping_sets { + GroupingSet::Rollup(exprs) => { + // ROLLUP (c0, c1, c2) + write!( + f, + "ROLLUP ({})", + exprs + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + ) + } + GroupingSet::Cube(exprs) => { + // CUBE (c0, c1, c2) + write!( + f, + "CUBE ({})", + exprs + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + ) + } + GroupingSet::GroupingSets(lists_of_exprs) => { + // GROUPING SETS ((c0), (c1, c2), (c3, c4)) + write!( + f, + "GROUPING SETS ({})", + lists_of_exprs + .iter() + .map(|exprs| format!( + "({})", + exprs + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + )) + .collect::>() + .join(", ") + ) + } + }, } } } @@ -902,6 +965,26 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { } Ok(format!("{}({})", fun.name, names.join(","))) } + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => Ok(format!( + "ROLLUP ({})", + create_names(exprs.as_slice(), input_schema)? + )), + GroupingSet::Cube(exprs) => Ok(format!( + "CUBE ({})", + create_names(exprs.as_slice(), input_schema)? + )), + GroupingSet::GroupingSets(lists_of_exprs) => { + let mut list_of_names = vec![]; + for exprs in lists_of_exprs { + list_of_names.push(format!( + "({})", + create_names(exprs.as_slice(), input_schema)? + )); + } + Ok(format!("GROUPING SETS ({})", list_of_names.join(", "))) + } + }, Expr::InList { expr, list, @@ -952,6 +1035,15 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { } } +/// Create a comma separated list of names from a list of expressions +fn create_names(exprs: &[Expr], input_schema: &DFSchema) -> Result { + Ok(exprs + .iter() + .map(|e| create_name(e, input_schema)) + .collect::>>()? + .join(", ")) +} + #[cfg(test)] mod test { use crate::expr_fn::col;