Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SQL planner support for ROLLUP and CUBE grouping set expressions #2446

Merged
merged 10 commits into from
May 9, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> {
| Expr::InSubquery { .. }
| Expr::ScalarSubquery(_)
| Expr::GetIndexedField { .. }
| Expr::GroupingSet(_)
| Expr::Case { .. } => Recursion::Continue(self),

Expr::ScalarFunction { fun, .. } => self.visit_volatility(fun.volatility()),
Expand Down
15 changes: 8 additions & 7 deletions datafusion/core/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ use std::{
sync::Arc,
};

use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use super::{Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use crate::logical_plan::expr::exprlist_to_fields;
use crate::logical_plan::{
columnize_expr, normalize_col, normalize_cols, provider_as_source,
rewrite_sort_cols_by_aggs, Column, CrossJoin, DFField, DFSchema, DFSchemaRef, Limit,
Expand Down Expand Up @@ -557,7 +558,7 @@ impl LogicalPlanBuilder {
expr.extend(missing_exprs);

let new_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&expr, input_schema)?,
exprlist_to_fields(&expr, &input)?,
input_schema.metadata().clone(),
)?;

Expand Down Expand Up @@ -629,7 +630,7 @@ impl LogicalPlanBuilder {
.map(|f| Expr::Column(f.qualified_column()))
.collect();
let new_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&new_expr, schema)?,
exprlist_to_fields(&new_expr, &self.plan)?,
schema.metadata().clone(),
)?;

Expand Down Expand Up @@ -843,8 +844,7 @@ impl LogicalPlanBuilder {
let window_expr = normalize_cols(window_expr, &self.plan)?;
let all_expr = window_expr.iter();
validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?;
let mut window_fields: Vec<DFField> =
exprlist_to_fields(all_expr, self.plan.schema())?;
let mut window_fields: Vec<DFField> = exprlist_to_fields(all_expr, &self.plan)?;
window_fields.extend_from_slice(self.plan.schema().fields());
Ok(Self::from(LogicalPlan::Window(Window {
input: Arc::new(self.plan.clone()),
Expand All @@ -869,7 +869,7 @@ impl LogicalPlanBuilder {
let all_expr = group_expr.iter().chain(aggr_expr.iter());
validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
let aggr_schema = DFSchema::new_with_metadata(
exprlist_to_fields(all_expr, self.plan.schema())?,
exprlist_to_fields(all_expr, &self.plan)?,
self.plan.schema().metadata().clone(),
)?;
Ok(Self::from(LogicalPlan::Aggregate(Aggregate {
Expand Down Expand Up @@ -1126,13 +1126,14 @@ pub fn project_with_alias(
}
validate_unique_names("Projections", projected_expr.iter(), input_schema)?;
let input_schema = DFSchema::new_with_metadata(
exprlist_to_fields(&projected_expr, input_schema)?,
exprlist_to_fields(&projected_expr, &plan)?,
plan.schema().metadata().clone(),
)?;
let schema = match alias {
Some(ref alias) => input_schema.replace_qualifier(alias.as_str()),
None => input_schema,
};

Ok(LogicalPlan::Projection(Projection {
expr: projected_expr,
input: Arc::new(plan.clone()),
Expand Down
31 changes: 28 additions & 3 deletions datafusion/core/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ pub use super::Operator;
use crate::error::Result;
use crate::logical_plan::ExprSchemable;
use crate::logical_plan::{DFField, DFSchema};
use crate::sql::utils::find_columns_referenced_by_expr;
use arrow::datatypes::DataType;
pub use datafusion_common::{Column, ExprSchema};
pub use datafusion_expr::expr_fn::*;
use datafusion_expr::AccumulatorFunctionImplementation;
use datafusion_expr::BuiltinScalarFunction;
pub use datafusion_expr::Expr;
use datafusion_expr::StateTypeFunction;
pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
use datafusion_expr::{AccumulatorFunctionImplementation, LogicalPlan};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_expr::{
ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
Expand Down Expand Up @@ -138,9 +139,33 @@ pub fn create_udaf(
/// Create field meta-data from an expression, for use in a result set schema
pub fn exprlist_to_fields<'a>(
expr: impl IntoIterator<Item = &'a Expr>,
input_schema: &DFSchema,
plan: &LogicalPlan,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that we pass in a plan now instead of a schema

) -> Result<Vec<DFField>> {
expr.into_iter().map(|e| e.to_field(input_schema)).collect()
match plan {
LogicalPlan::Aggregate(agg) => {
let group_expr: Vec<Column> = agg
.group_expr
.iter()
.flat_map(find_columns_referenced_by_expr)
.collect();
let exprs: Vec<Expr> = expr.into_iter().cloned().collect();
let mut fields = vec![];
for expr in &exprs {
match expr {
Expr::Column(c) if group_expr.iter().any(|x| x == c) => {
// resolve against schema of input to aggregate
fields.push(expr.to_field(agg.input.schema())?);
}
_ => fields.push(expr.to_field(plan.schema())?),
}
}
Ok(fields)
}
_ => {
let input_schema = &plan.schema();
expr.into_iter().map(|e| e.to_field(input_schema)).collect()
}
}
}

/// Calls a named built in function
Expand Down
17 changes: 17 additions & 0 deletions datafusion/core/src/logical_plan/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::logical_plan::ExprSchemable;
use crate::logical_plan::LogicalPlan;
use datafusion_common::Column;
use datafusion_common::Result;
use datafusion_expr::expr::GroupingSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
Expand Down Expand Up @@ -215,6 +216,22 @@ impl ExprRewritable for Expr {
fun,
distinct,
},
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))
}
GroupingSet::Cube(exprs) => {
Expr::GroupingSet(GroupingSet::Cube(rewrite_vec(exprs, rewriter)?))
}
GroupingSet::GroupingSets(lists_of_exprs) => {
Expr::GroupingSet(GroupingSet::GroupingSets(
lists_of_exprs
.iter()
.map(|exprs| rewrite_vec(exprs.clone(), rewriter))
.collect::<Result<Vec<_>>>()?,
))
}
},
Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
args: rewrite_vec(args, rewriter)?,
fun,
Expand Down
14 changes: 14 additions & 0 deletions datafusion/core/src/logical_plan/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use super::Expr;
use datafusion_common::Result;
use datafusion_expr::expr::GroupingSet;

/// Controls how the visitor recursion should proceed.
pub enum Recursion<V: ExpressionVisitor> {
Expand Down Expand Up @@ -103,6 +104,19 @@ impl ExprVisitable for Expr {
| Expr::TryCast { expr, .. }
| Expr::Sort { expr, .. }
| Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
Expr::GroupingSet(GroupingSet::Rollup(exprs)) => exprs
.iter()
.fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs
.iter()
.fold(Ok(visitor), |v, e| v.and_then(|v| e.accept(v))),
Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
lists_of_exprs.iter().fold(Ok(visitor), |v, exprs| {
v.and_then(|v| {
exprs.iter().fold(Ok(v), |v, e| v.and_then(|v| e.accept(v)))
})
})
}
Expr::Column(_)
| Expr::ScalarVariable(_, _)
| Expr::Literal(_)
Expand Down
28 changes: 28 additions & 0 deletions datafusion/core/src/optimizer/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use crate::logical_plan::{
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use arrow::datatypes::DataType;
use datafusion_expr::expr::GroupingSet;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

Expand Down Expand Up @@ -482,6 +483,33 @@ impl ExprIdentifierVisitor<'_> {
desc.push_str("GetIndexedField-");
desc.push_str(&key.to_string());
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
desc.push_str("Rollup");
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
}
GroupingSet::Cube(exprs) => {
desc.push_str("Cube");
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
}
GroupingSet::GroupingSets(lists_of_exprs) => {
desc.push_str("GroupingSets");
for exprs in lists_of_exprs {
desc.push('(');
for expr in exprs {
desc.push('-');
desc.push_str(&Self::desc_expr(expr));
}
desc.push(')');
}
}
},
}

desc
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/optimizer/projection_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ mod tests {
// that the Column references are unqualified (e.g. their
// relation is `None`). PlanBuilder resolves the expressions
let expr = vec![col("a"), col("b")];
let projected_fields = exprlist_to_fields(&expr, input_schema).unwrap();
let projected_fields = exprlist_to_fields(&expr, &table_scan).unwrap();
let projected_schema = DFSchema::new_with_metadata(
projected_fields,
input_schema.metadata().clone(),
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/optimizer/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ impl<'a> ConstEvaluator<'a> {
| Expr::ScalarSubquery(_)
| Expr::WindowFunction { .. }
| Expr::Sort { .. }
| Expr::GroupingSet(_)
| Expr::Wildcard
| Expr::QualifiedWildcard { .. } => false,
Expr::ScalarFunction { fun, .. } => Self::volatility_ok(fun.volatility()),
Expand Down
20 changes: 20 additions & 0 deletions datafusion/core/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use crate::{
logical_plan::ExpressionVisitor,
};
use datafusion_common::DFSchema;
use datafusion_expr::expr::GroupingSet;
use std::{collections::HashSet, sync::Arc};

const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__";
Expand Down Expand Up @@ -83,6 +84,7 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
| Expr::ScalarUDF { .. }
| Expr::WindowFunction { .. }
| Expr::AggregateFunction { .. }
| Expr::GroupingSet(_)
| Expr::AggregateUDF { .. }
| Expr::InList { .. }
| Expr::Exists { .. }
Expand Down Expand Up @@ -323,6 +325,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
| Expr::ScalarUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => Ok(args.clone()),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(exprs.clone()),
GroupingSet::Cube(exprs) => Ok(exprs.clone()),
GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
"GroupingSets are not supported yet".to_string(),
)),
},
Expr::WindowFunction {
args,
partition_by,
Expand Down Expand Up @@ -458,6 +467,17 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
fun: fun.clone(),
args: expressions.to_vec(),
}),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(_exprs) => {
Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
}
GroupingSet::Cube(_exprs) => {
Ok(Expr::GroupingSet(GroupingSet::Rollup(expressions.to_vec())))
}
GroupingSet::GroupingSets(_) => Err(DataFusionError::Plan(
"GroupingSets are not supported yet".to_string(),
)),
},
Expr::Case { .. } => {
let mut base_expr: Option<Box<Expr>> = None;
let mut when_then: Vec<(Box<Expr>, Box<Expr>)> = vec![];
Expand Down
32 changes: 32 additions & 0 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::{compute::can_cast_types, datatypes::DataType};
use async_trait::async_trait;
use datafusion_expr::expr::GroupingSet;
use datafusion_physical_expr::expressions::DateIntervalExpr;
use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
Expand Down Expand Up @@ -174,6 +175,37 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
}
Ok(format!("{}({})", fun.name, names.join(",")))
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(format!(
"ROLLUP ({})",
exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ")
)),
GroupingSet::Cube(exprs) => Ok(format!(
"CUBE ({})",
exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ")
)),
GroupingSet::GroupingSets(lists_of_exprs) => {
let mut strings = vec![];
for exprs in lists_of_exprs {
let exprs_str = exprs
.iter()
.map(|e| create_physical_name(e, false))
.collect::<Result<Vec<_>>>()?
.join(", ");
strings.push(format!("({})", exprs_str));
}
Ok(format!("GROUPING SETS ({})", strings.join(", ")))
}
},

Expr::InList {
expr,
list,
Expand Down
Loading