From 9487ca057353370aa75895453c92bb40b9f33ac6 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 1 Apr 2024 19:19:17 +0800 Subject: [PATCH 1/5] use alias (#9894) Signed-off-by: jayzhan211 --- datafusion/sqllogictest/test_files/expr.slt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 2e0cbf50cab9..60ab4777883e 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2288,7 +2288,7 @@ select struct(time,load1,load2,host) from t1; # can have an aggregate function with an inner coalesce query TR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select t2.info['c3'] as host, sum(coalesce(t2.info)['c1']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host; ---- host1 1.1 host2 2.2 @@ -2296,7 +2296,7 @@ host3 3.3 # can have an aggregate function with an inner CASE WHEN query TR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select t2.info['c3'] as host, sum((case when t2.info['c3'] is not null then t2.info end)['c2']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host; ---- host1 101 
host2 202 @@ -2304,7 +2304,7 @@ host3 303 # can have 2 projections with aggr(short_circuited), with different short-circuited expr query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select t2.info['c3'] as host, sum(coalesce(t2.info)['c1']), sum((case when t2.info['c3'] is not null then t2.info end)['c2']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host; ---- host1 1.1 101 host2 2.2 202 @@ -2312,7 +2312,7 @@ host3 3.3 303 # can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. 
CASE WHEN) query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select t2.info['c3'] as host, sum((case when t2.info['c3'] is not null then t2.info end)['c1']), sum((case when t2.info['c3'] is not null then t2.info end)['c2']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host; ---- host1 1.1 101 host2 2.2 202 @@ -2320,7 +2320,7 @@ host3 3.3 303 # can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. 
coalesce) query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select t2.info['c3'] as host, sum(coalesce(t2.info)['c1']), sum(coalesce(t2.info)['c2']) from (select struct(time,load1,load2,host) as info from t1) t2 where t2.info['c3'] IS NOT NULL group by t2.info['c3'] order by host; ---- host1 1.1 101 host2 2.2 202 From f300168791b261e1162ac7fab47b329c9e5467f3 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Mon, 1 Apr 2024 23:36:14 +0800 Subject: [PATCH 2/5] fix: detect non-recursive CTEs in the recursive `WITH` clause (#9836) * move cte related logic to its own mod * fix check cte self reference * add tests * fix test * move test to slt --- datafusion/sql/src/cte.rs | 212 +++++++++++++++++++++ datafusion/sql/src/lib.rs | 1 + datafusion/sql/src/planner.rs | 5 + datafusion/sql/src/query.rs | 144 +------------- datafusion/sql/src/set_expr.rs | 81 ++++---- datafusion/sql/tests/sql_integration.rs | 10 - datafusion/sqllogictest/test_files/cte.slt | 88 +++++++++ 7 files changed, 356 insertions(+), 185 deletions(-) create mode 100644 datafusion/sql/src/cte.rs diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs new file mode 100644 index 000000000000..5b1f81e820a2 --- /dev/null +++ b/datafusion/sql/src/cte.rs @@ -0,0 +1,212 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + +use arrow::datatypes::Schema; +use datafusion_common::{ + not_impl_err, plan_err, + tree_node::{TreeNode, TreeNodeRecursion}, + Result, +}; +use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource}; +use sqlparser::ast::{Query, SetExpr, SetOperator, With}; + +impl<'a, S: ContextProvider> SqlToRel<'a, S> { + pub(super) fn plan_with_clause( + &self, + with: With, + planner_context: &mut PlannerContext, + ) -> Result<()> { + let is_recursive = with.recursive; + // Process CTEs from top to bottom + for cte in with.cte_tables { + // A `WITH` block can't use the same name more than once + let cte_name = self.normalizer.normalize(cte.alias.name.clone()); + if planner_context.contains_cte(&cte_name) { + return plan_err!( + "WITH query name {cte_name:?} specified more than once" + ); + } + + // Create a logical plan for the CTE + let cte_plan = if is_recursive { + self.recursive_cte(cte_name.clone(), *cte.query, planner_context)? + } else { + self.non_recursive_cte(*cte.query, planner_context)? + }; + + // Each `WITH` block can change the column names in the last + // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). 
+ let final_plan = self.apply_table_alias(cte_plan, cte.alias)?; + // Export the CTE to the outer query + planner_context.insert_cte(cte_name, final_plan); + } + Ok(()) + } + + fn non_recursive_cte( + &self, + cte_query: Query, + planner_context: &mut PlannerContext, + ) -> Result { + // CTE expr don't need extend outer_query_schema, + // so we clone a new planner_context here. + let mut cte_planner_context = planner_context.clone(); + self.query_to_plan(cte_query, &mut cte_planner_context) + } + + fn recursive_cte( + &self, + cte_name: String, + mut cte_query: Query, + planner_context: &mut PlannerContext, + ) -> Result { + if !self + .context_provider + .options() + .execution + .enable_recursive_ctes + { + return not_impl_err!("Recursive CTEs are not enabled"); + } + + let (left_expr, right_expr, set_quantifier) = match *cte_query.body { + SetExpr::SetOperation { + op: SetOperator::Union, + left, + right, + set_quantifier, + } => (left, right, set_quantifier), + other => { + // If the query is not a UNION, then it is not a recursive CTE + cte_query.body = Box::new(other); + return self.non_recursive_cte(cte_query, planner_context); + } + }; + + // Each recursive CTE consists of two parts in the logical plan: + // 1. A static term (the left hand side on the SQL, where + // referencing the same CTE is not allowed) + // + // 2. A recursive term (the right hand side, and the recursive + // part) + + // Since static term does not have any specific properties, it can + // be compiled as if it was a regular expression. This will + // allow us to infer the schema to be used in the recursive term. 
+ + // ---------- Step 1: Compile the static term ------------------ + let static_plan = + self.set_expr_to_plan(*left_expr, &mut planner_context.clone())?; + + // Since the recursive CTEs include a component that references a + // table with its name, like the example below: + // + // WITH RECURSIVE values(n) AS ( + // SELECT 1 as n -- static term + // UNION ALL + // SELECT n + 1 + // FROM values -- self reference + // WHERE n < 100 + // ) + // + // We need a temporary 'relation' to be referenced and used. PostgreSQL + // calls this a 'working table', but it is entirely an implementation + // detail and a 'real' table with that name might not even exist (as + // in the case of DataFusion). + // + // Since we can't simply register a table during planning stage (it is + // an execution problem), we'll use a relation object that preserves the + // schema of the input perfectly and also knows which recursive CTE it is + // bound to. + + // ---------- Step 2: Create a temporary relation ------------------ + // Step 2.1: Create a table source for the temporary relation + let work_table_source = self.context_provider.create_cte_work_table( + &cte_name, + Arc::new(Schema::from(static_plan.schema().as_ref())), + )?; + + // Step 2.2: Create a temporary relation logical plan that will be used + // as the input to the recursive term + let work_table_plan = LogicalPlanBuilder::scan( + cte_name.to_string(), + work_table_source.clone(), + None, + )? + .build()?; + + let name = cte_name.clone(); + + // Step 2.3: Register the temporary relation in the planning context + // For all the self references in the variadic term, we'll replace it + // with the temporary relation we created above by temporarily registering + // it as a CTE. This temporary relation in the planning context will be + // replaced by the actual CTE plan once we're done with the planning. 
+ planner_context.insert_cte(cte_name.clone(), work_table_plan); + + // ---------- Step 3: Compile the recursive term ------------------ + // this uses the named_relation we inserted above to resolve the + // relation. This ensures that the recursive term uses the named relation logical plan + // and thus the 'continuance' physical plan as its input and source + let recursive_plan = + self.set_expr_to_plan(*right_expr, &mut planner_context.clone())?; + + // Check if the recursive term references the CTE itself, + // if not, it is a non-recursive CTE + if !has_work_table_reference(&recursive_plan, &work_table_source) { + // Remove the work table plan from the context + planner_context.remove_cte(&cte_name); + // Compile it as a non-recursive CTE + return self.set_operation_to_plan( + SetOperator::Union, + static_plan, + recursive_plan, + set_quantifier, + ); + } + + // ---------- Step 4: Create the final plan ------------------ + // Step 4.1: Compile the final plan + let distinct = !Self::is_union_all(set_quantifier)?; + LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)? + .build() + } +} + +fn has_work_table_reference( + plan: &LogicalPlan, + work_table_source: &Arc, +) -> bool { + let mut has_reference = false; + plan.apply(&mut |node| { + if let LogicalPlan::TableScan(scan) = node { + if Arc::ptr_eq(&scan.source, work_table_source) { + has_reference = true; + return Ok(TreeNodeRecursion::Stop); + } + } + Ok(TreeNodeRecursion::Continue) + }) + // Closure always returns Ok + .unwrap(); + has_reference +} diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 12d6a4669634..1040cc61c702 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -28,6 +28,7 @@ //! [`SqlToRel`]: planner::SqlToRel //! 
[`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan +mod cte; mod expr; pub mod parser; pub mod planner; diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index f94c6ec4e8c9..d2182962b98e 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -213,6 +213,11 @@ impl PlannerContext { pub fn get_cte(&self, cte_name: &str) -> Option<&LogicalPlan> { self.ctes.get(cte_name).map(|cte| cte.as_ref()) } + + /// Remove the plan of CTE / Subquery for the specified name + pub(super) fn remove_cte(&mut self, cte_name: &str) { + self.ctes.remove(cte_name); + } } /// SQL query planner diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index eda8398c432b..ba876d052f5e 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,21 +19,15 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::datatypes::Schema; -use datafusion_common::{ - not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{plan_err, Constraints, Result, ScalarValue}; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, Operator, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, - SetQuantifier, Value, + Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, }; -use sqlparser::parser::ParserError::ParserError; - impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Generate a logical plan from an SQL query pub(crate) fn query_to_plan( @@ -54,139 +48,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { let set_expr = query.body; if let Some(with) = query.with { - // Process CTEs from top to bottom - let is_recursive = with.recursive; - - for cte in with.cte_tables { - // A `WITH` block can't use the same name more than once - let cte_name = 
self.normalizer.normalize(cte.alias.name.clone()); - if planner_context.contains_cte(&cte_name) { - return sql_err!(ParserError(format!( - "WITH query name {cte_name:?} specified more than once" - ))); - } - - if is_recursive { - if !self - .context_provider - .options() - .execution - .enable_recursive_ctes - { - return not_impl_err!("Recursive CTEs are not enabled"); - } - - match *cte.query.body { - SetExpr::SetOperation { - op: SetOperator::Union, - left, - right, - set_quantifier, - } => { - let distinct = set_quantifier != SetQuantifier::All; - - // Each recursive CTE consists from two parts in the logical plan: - // 1. A static term (the left hand side on the SQL, where the - // referencing to the same CTE is not allowed) - // - // 2. A recursive term (the right hand side, and the recursive - // part) - - // Since static term does not have any specific properties, it can - // be compiled as if it was a regular expression. This will - // allow us to infer the schema to be used in the recursive term. - - // ---------- Step 1: Compile the static term ------------------ - let static_plan = self - .set_expr_to_plan(*left, &mut planner_context.clone())?; - - // Since the recursive CTEs include a component that references a - // table with its name, like the example below: - // - // WITH RECURSIVE values(n) AS ( - // SELECT 1 as n -- static term - // UNION ALL - // SELECT n + 1 - // FROM values -- self reference - // WHERE n < 100 - // ) - // - // We need a temporary 'relation' to be referenced and used. PostgreSQL - // calls this a 'working table', but it is entirely an implementation - // detail and a 'real' table with that name might not even exist (as - // in the case of DataFusion). - // - // Since we can't simply register a table during planning stage (it is - // an execution problem), we'll use a relation object that preserves the - // schema of the input perfectly and also knows which recursive CTE it is - // bound to. 
- - // ---------- Step 2: Create a temporary relation ------------------ - // Step 2.1: Create a table source for the temporary relation - let work_table_source = - self.context_provider.create_cte_work_table( - &cte_name, - Arc::new(Schema::from(static_plan.schema().as_ref())), - )?; - - // Step 2.2: Create a temporary relation logical plan that will be used - // as the input to the recursive term - let work_table_plan = LogicalPlanBuilder::scan( - cte_name.to_string(), - work_table_source, - None, - )? - .build()?; - - let name = cte_name.clone(); - - // Step 2.3: Register the temporary relation in the planning context - // For all the self references in the variadic term, we'll replace it - // with the temporary relation we created above by temporarily registering - // it as a CTE. This temporary relation in the planning context will be - // replaced by the actual CTE plan once we're done with the planning. - planner_context.insert_cte(cte_name.clone(), work_table_plan); - - // ---------- Step 3: Compile the recursive term ------------------ - // this uses the named_relation we inserted above to resolve the - // relation. This ensures that the recursive term uses the named relation logical plan - // and thus the 'continuance' physical plan as its input and source - let recursive_plan = self - .set_expr_to_plan(*right, &mut planner_context.clone())?; - - // ---------- Step 4: Create the final plan ------------------ - // Step 4.1: Compile the final plan - let logical_plan = LogicalPlanBuilder::from(static_plan) - .to_recursive_query(name, recursive_plan, distinct)? - .build()?; - - let final_plan = - self.apply_table_alias(logical_plan, cte.alias)?; - - // Step 4.2: Remove the temporary relation from the planning context and replace it - // with the final plan. 
- planner_context.insert_cte(cte_name.clone(), final_plan); - } - _ => { - return Err(DataFusionError::SQL( - ParserError(format!("Unsupported CTE: {cte}")), - None, - )); - } - }; - } else { - // create logical plan & pass backreferencing CTEs - // CTE expr don't need extend outer_query_schema - let logical_plan = - self.query_to_plan(*cte.query, &mut planner_context.clone())?; - - // Each `WITH` block can change the column names in the last - // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). - let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; - - planner_context.insert_cte(cte_name, logical_plan); - } - } + self.plan_with_clause(with, planner_context)?; } let plan = self.set_expr_to_plan(*(set_expr.clone()), planner_context)?; let plan = self.order_by(plan, query.order_by, planner_context)?; diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index 2cbb68368f72..cbe41c33c729 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -35,45 +35,58 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { right, set_quantifier, } => { - let all = match set_quantifier { - SetQuantifier::All => true, - SetQuantifier::Distinct | SetQuantifier::None => false, - SetQuantifier::ByName => { - return not_impl_err!("UNION BY NAME not implemented"); - } - SetQuantifier::AllByName => { - return not_impl_err!("UNION ALL BY NAME not implemented") - } - SetQuantifier::DistinctByName => { - return not_impl_err!("UNION DISTINCT BY NAME not implemented") - } - }; - let left_plan = self.set_expr_to_plan(*left, planner_context)?; let right_plan = self.set_expr_to_plan(*right, planner_context)?; - match (op, all) { - (SetOperator::Union, true) => LogicalPlanBuilder::from(left_plan) - .union(right_plan)? - .build(), - (SetOperator::Union, false) => LogicalPlanBuilder::from(left_plan) - .union_distinct(right_plan)? 
- .build(), - (SetOperator::Intersect, true) => { - LogicalPlanBuilder::intersect(left_plan, right_plan, true) - } - (SetOperator::Intersect, false) => { - LogicalPlanBuilder::intersect(left_plan, right_plan, false) - } - (SetOperator::Except, true) => { - LogicalPlanBuilder::except(left_plan, right_plan, true) - } - (SetOperator::Except, false) => { - LogicalPlanBuilder::except(left_plan, right_plan, false) - } - } + self.set_operation_to_plan(op, left_plan, right_plan, set_quantifier) } SetExpr::Query(q) => self.query_to_plan(*q, planner_context), _ => not_impl_err!("Query {set_expr} not implemented yet"), } } + + pub(super) fn is_union_all(set_quantifier: SetQuantifier) -> Result { + match set_quantifier { + SetQuantifier::All => Ok(true), + SetQuantifier::Distinct | SetQuantifier::None => Ok(false), + SetQuantifier::ByName => { + not_impl_err!("UNION BY NAME not implemented") + } + SetQuantifier::AllByName => { + not_impl_err!("UNION ALL BY NAME not implemented") + } + SetQuantifier::DistinctByName => { + not_impl_err!("UNION DISTINCT BY NAME not implemented") + } + } + } + + pub(super) fn set_operation_to_plan( + &self, + op: SetOperator, + left_plan: LogicalPlan, + right_plan: LogicalPlan, + set_quantifier: SetQuantifier, + ) -> Result { + let all = Self::is_union_all(set_quantifier)?; + match (op, all) { + (SetOperator::Union, true) => LogicalPlanBuilder::from(left_plan) + .union(right_plan)? + .build(), + (SetOperator::Union, false) => LogicalPlanBuilder::from(left_plan) + .union_distinct(right_plan)? 
+ .build(), + (SetOperator::Intersect, true) => { + LogicalPlanBuilder::intersect(left_plan, right_plan, true) + } + (SetOperator::Intersect, false) => { + LogicalPlanBuilder::intersect(left_plan, right_plan, false) + } + (SetOperator::Except, true) => { + LogicalPlanBuilder::except(left_plan, right_plan, true) + } + (SetOperator::Except, false) => { + LogicalPlanBuilder::except(left_plan, right_plan, false) + } + } + } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 101c31039c7e..a34f8f07fe92 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2994,16 +2994,6 @@ fn join_with_aliases() { quick_test(sql, expected); } -#[test] -fn cte_use_same_name_multiple_times() { - let sql = - "with a as (select * from person), a as (select * from orders) select * from a;"; - let expected = - "SQL error: ParserError(\"WITH query name \\\"a\\\" specified more than once\")"; - let result = logical_plan(sql).err().unwrap(); - assert_eq!(result.strip_backtrace(), expected); -} - #[test] fn negative_interval_plus_interval_in_projection() { let sql = "select -interval '2 days' + interval '5 days';"; diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index e33dfabaf2ca..eec7eb0e3399 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -39,6 +39,37 @@ physical_plan ProjectionExec: expr=[1 as a, 2 as b, 3 as c] --PlaceholderRowExec +# cte_use_same_name_multiple_times +statement error DataFusion error: Error during planning: WITH query name "a" specified more than once +WITH a AS (SELECT 1), a AS (SELECT 2) SELECT * FROM a; + +# Test disabling recursive CTE +statement ok +set datafusion.execution.enable_recursive_ctes = false; + +query error DataFusion error: This feature is not implemented: Recursive CTEs are not enabled +WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + 
SELECT id + 1 as id + FROM nodes + WHERE id < 3 +) SELECT * FROM nodes + +statement ok +set datafusion.execution.enable_recursive_ctes = true; + + +# DISTINCT UNION is not supported +query error DataFusion error: This feature is not implemented: Recursive queries with a distinct 'UNION' \(in which the previous iteration's results will be de\-duplicated\) is not supported +WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION + SELECT id + 1 as id + FROM nodes + WHERE id < 3 +) SELECT * FROM nodes + # trivial recursive CTE works query I rowsort @@ -744,3 +775,60 @@ WITH RECURSIVE my_cte AS ( UNION ALL SELECT 'abc' FROM my_cte WHERE CAST(a AS text) !='abc' ) SELECT * FROM my_cte; + +# Define a non-recursive CTE in the recursive WITH clause. +# Test issue: https://github.com/apache/arrow-datafusion/issues/9804 +query I +WITH RECURSIVE cte AS ( + SELECT a FROM (VALUES(1)) AS t(a) WHERE a > 2 + UNION ALL + SELECT 2 +) SELECT * FROM cte; +---- +2 + +# Define a non-recursive CTE in the recursive WITH clause. +# UNION ALL +query I rowsort +WITH RECURSIVE cte AS ( + SELECT 1 + UNION ALL + SELECT 2 +) SELECT * FROM cte; +---- +1 +2 + +# Define a non-recursive CTE in the recursive WITH clause. +# DISTINCT UNION +query I +WITH RECURSIVE cte AS ( + SELECT 2 + UNION + SELECT 2 +) SELECT * FROM cte; +---- +2 + +# Define a non-recursive CTE in the recursive WITH clause. +# UNION is not present. +query I +WITH RECURSIVE cte AS ( + SELECT 1 +) SELECT * FROM cte; +---- +1 + +# Define a recursive CTE and a non-recursive CTE at the same time. 
+query II rowsort +WITH RECURSIVE +non_recursive_cte AS ( + SELECT 1 +), +recursive_cte AS ( + SELECT 1 AS a UNION ALL SELECT a+2 FROM recursive_cte WHERE a < 3 +) +SELECT * FROM non_recursive_cte, recursive_cte; +---- +1 1 +1 3 From c8584557cdfa7c138ab9039ceac31323f48a44d3 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 1 Apr 2024 11:36:52 -0400 Subject: [PATCH 3/5] Minor: Add SIGMOD paper reference to architecture guide (#9886) --- datafusion/core/src/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 5dc3e1ce7d3f..f6e2171d6b5f 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -167,6 +167,11 @@ //! overview of how DataFusion is organized and then link to other //! sections of the docs with more details --> //! +//! You can find a formal description of DataFusion's architecture in our +//! [SIGMOD 2024 Paper]. +//! +//! [SIGMOD 2024 Paper]: https://github.com/apache/arrow-datafusion/files/14789704/DataFusion_Query_Engine___SIGMOD_2024-FINAL.pdf +//! //! ## Overview Presentations //! //! 
The following presentations offer high level overviews of the From b698e5ffc43ebb0585339ef9899496beccc0a707 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 1 Apr 2024 23:38:56 +0800 Subject: [PATCH 4/5] refactor: add macro for the binary math function in `datafusion-function` (#9889) * refactor: macro for the binary math function in datafusion-function * Update datafusion/functions/src/macros.rs --------- Co-authored-by: Andrew Lamb --- datafusion/functions/src/macros.rs | 107 ++++++++++++++++++- datafusion/functions/src/math/atan2.rs | 140 ------------------------- datafusion/functions/src/math/mod.rs | 3 +- 3 files changed, 107 insertions(+), 143 deletions(-) delete mode 100644 datafusion/functions/src/math/atan2.rs diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 4907d74fe941..c92cb27ef5bb 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -251,7 +251,112 @@ macro_rules! make_math_unary_udf { }; } -#[macro_export] +/// Macro to create a binary math UDF. +/// +/// A binary math function takes two arguments of types Float32 or Float64, +/// applies a binary floating function to the arguments, and returns a value of the same type. +/// +/// $UDF: the name of the UDF struct that implements `ScalarUDFImpl` +/// $GNAME: a singleton instance of the UDF +/// $NAME: the name of the function +/// $BINARY_FUNC: the binary function to apply to the arguments +/// $MONOTONICITY: the monotonicity of the function +macro_rules! 
make_math_binary_udf { + ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $MONOTONICITY:expr) => { + make_udf_function!($NAME::$UDF, $GNAME, $NAME); + + mod $NAME { + use arrow::array::{ArrayRef, Float32Array, Float64Array}; + use arrow::datatypes::DataType; + use datafusion_common::{exec_err, DataFusionError, Result}; + use datafusion_expr::TypeSignature::*; + use datafusion_expr::{ + ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, + }; + use std::any::Any; + use std::sync::Arc; + + #[derive(Debug)] + pub struct $UDF { + signature: Signature, + } + + impl $UDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Float32, Float32]), + Exact(vec![Float64, Float64]), + ], + Volatility::Immutable, + ), + } + } + } + + impl ScalarUDFImpl for $UDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + stringify!($NAME) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + // For other types (possible values float64/null/int), use Float64 + _ => Ok(DataType::Float64), + } + } + + fn monotonicity(&self) -> Result> { + Ok($MONOTONICITY) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float64Array, + { f64::$BINARY_FUNC } + )), + + DataType::Float32 => Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float32Array, + { f32::$BINARY_FUNC } + )), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ) + } + }; + Ok(ColumnarValue::Array(arr)) + } + } + } + }; +} + macro_rules! 
make_function_inputs2 { ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); diff --git a/datafusion/functions/src/math/atan2.rs b/datafusion/functions/src/math/atan2.rs deleted file mode 100644 index b090c6c454fd..000000000000 --- a/datafusion/functions/src/math/atan2.rs +++ /dev/null @@ -1,140 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Math function: `atan2()`. 
- -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use datafusion_common::DataFusionError; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -use crate::make_function_inputs2; -use crate::utils::make_scalar_function; - -#[derive(Debug)] -pub(super) struct Atan2 { - signature: Signature, -} - -impl Atan2 { - pub fn new() -> Self { - use DataType::*; - Self { - signature: Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], - Volatility::Immutable, - ), - } - } -} - -impl ScalarUDFImpl for Atan2 { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "atan2" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use self::DataType::*; - match &arg_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - make_scalar_function(atan2, vec![])(args) - } -} - -/// Atan2 SQL function -pub fn atan2(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::atan2 } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::atan2 } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function atan2"), - } -} - -#[cfg(test)] -mod test { - use super::*; - use datafusion_common::cast::{as_float32_array, as_float64_array}; - - #[test] - fn test_atan2_f64() { - let args: Vec = vec![ - Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y - Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x - ]; - - let result = 
atan2(&args).expect("failed to initialize function atan2"); - let floats = - as_float64_array(&result).expect("failed to initialize function atan2"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); - assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); - assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); - assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); - } - - #[test] - fn test_atan2_f32() { - let args: Vec<ArrayRef> = vec![ - Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y - Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x - ]; - - let result = atan2(&args).expect("failed to initialize function atan2"); - let floats = - as_float32_array(&result).expect("failed to initialize function atan2"); - - assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); - assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); - assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); - assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); - } -} diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index 2ee1fffa1625..ee53fcf96a8b 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -18,13 +18,11 @@ //! 
"math" DataFusion functions mod abs; -mod atan2; mod nans; // Create UDFs make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(abs::AbsFunc, ABS, abs); -make_udf_function!(atan2::Atan2, ATAN2, atan2); make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); @@ -39,6 +37,7 @@ make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)])); make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); +make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, Some(vec![Some(true)])); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( From d8d521ac8b90002fa0ba1f91456051a9775ae193 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 1 Apr 2024 11:40:23 -0400 Subject: [PATCH 5/5] Add benchmark for substr_index (#9878) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Fixed missing trim() function. 
* Add benchmark for substr_index * Add missing required-features * Update datafusion/functions/benches/substr_index.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- datafusion/functions/Cargo.toml | 5 + datafusion/functions/benches/substr_index.rs | 103 +++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 datafusion/functions/benches/substr_index.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 425ac207c33e..ef7d2c9b1892 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -107,3 +107,8 @@ required-features = ["datetime_expressions"] harness = false name = "to_char" required-features = ["datetime_expressions"] + +[[bench]] +harness = false +name = "substr_index" +required-features = ["unicode_expressions"] diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs new file mode 100644 index 000000000000..bb9a5b809eee --- /dev/null +++ b/datafusion/functions/benches/substr_index.rs @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +extern crate criterion; + +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int64Array, StringArray}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::distributions::{Alphanumeric, Uniform}; +use rand::prelude::Distribution; +use rand::Rng; + +use datafusion_expr::ColumnarValue; +use datafusion_functions::unicode::substr_index; + +struct Filter<Dist, Test> { + dist: Dist, + test: Test, +} + +impl<T, Dist, Test> Distribution<T> for Filter<Dist, Test> +where + Dist: Distribution<T>, + Test: Fn(&T) -> bool, +{ + fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T { + loop { + let x = self.dist.sample(rng); + if (self.test)(&x) { + return x; + } + } + } +} + +fn data() -> (StringArray, StringArray, Int64Array) { + let dist = Filter { + dist: Uniform::new(-4, 5), + test: |x: &i64| x != &0, + }; + let mut rng = rand::thread_rng(); + let mut strings: Vec<String> = vec![]; + let mut delimiters: Vec<String> = vec![]; + let mut counts: Vec<i64> = vec![]; + + for _ in 0..1000 { + let length = rng.gen_range(20..50); + let text: String = (&mut rng) + .sample_iter(&Alphanumeric) + .take(length) + .map(char::from) + .collect(); + let char = rng.gen_range(0..text.len()); + let delimiter = &text.chars().nth(char).unwrap(); + let count = rng.sample(&dist); + + strings.push(text); + delimiters.push(delimiter.to_string()); + counts.push(count); + } + + ( + StringArray::from(strings), + StringArray::from(delimiters), + Int64Array::from(counts), + ) +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("substr_index_array_array_1000", |b| { + let (strings, delimiters, counts) = data(); + let strings = ColumnarValue::Array(Arc::new(strings) as ArrayRef); + let delimiters = ColumnarValue::Array(Arc::new(delimiters) as ArrayRef); + let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); + + let args = [strings, delimiters, counts]; + b.iter(|| { + black_box( + substr_index() + .invoke(&args) + .expect("substr_index should work on valid values"), + ) + }) + }); +} + +criterion_group!(benches, 
criterion_benchmark); +criterion_main!(benches);