diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index f431e6264367..8c8628a4b4ff 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -1746,7 +1746,7 @@ impl ScalarValue {
     }
 
     /// Converts `Vec<ScalarValue>` where each element has type corresponding to
-    /// `data_type`, to a [`ListArray`].
+    /// `data_type`, to a single element [`ListArray`].
     ///
     /// Example
     /// ```
@@ -4453,7 +4453,8 @@ mod tests {
 
         // The alignment requirements differ across architectures and
         // thus the size of the enum appears to as well
-        assert_eq!(std::mem::size_of::<ScalarValue>(), 48);
+        // The exact value may change depending on the Rust version
+        assert_eq!(std::mem::size_of::<ScalarValue>(), 64);
     }
 
     #[test]
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 59905d859dc8..8df16e7944d2 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -46,7 +46,7 @@ use tokio::task::JoinSet;
 /// same results
 #[tokio::test(flavor = "multi_thread")]
 async fn streaming_aggregate_test() {
-    let test_cases = vec![
+    let test_cases = [
         vec!["a"],
         vec!["b", "a"],
         vec!["c", "a"],
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 9cc7b4e855cb..a763a58379b7 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -26,9 +26,7 @@ use std::sync::{Arc, OnceLock};
 use crate::signature::TIMEZONE_WILDCARD;
 use crate::type_coercion::binary::get_wider_type;
 use crate::type_coercion::functions::data_types;
-use crate::{
-    conditional_expressions, FuncMonotonicity, Signature, TypeSignature, Volatility,
-};
+use crate::{FuncMonotonicity, Signature, TypeSignature, Volatility};
 
 use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, TimeUnit};
 use datafusion_common::{exec_err, plan_err, DataFusionError, Result};
@@ -899,10 +897,9 @@ impl BuiltinScalarFunction {
             | BuiltinScalarFunction::ConcatWithSeparator => {
                 Signature::variadic(vec![Utf8], self.volatility())
             }
-            BuiltinScalarFunction::Coalesce => Signature::variadic(
-                conditional_expressions::SUPPORTED_COALESCE_TYPES.to_vec(),
-                self.volatility(),
-            ),
+            BuiltinScalarFunction::Coalesce => {
+                Signature::variadic_equal(self.volatility())
+            }
             BuiltinScalarFunction::SHA224
             | BuiltinScalarFunction::SHA256
             | BuiltinScalarFunction::SHA384
@@ -1575,4 +1572,13 @@ mod tests {
             assert_eq!(func_from_str, *func_original);
         }
     }
+
+    #[test]
+    fn test_coalesce_return_types() {
+        let coalesce = BuiltinScalarFunction::Coalesce;
+        let return_type = coalesce
+            .return_type(&[DataType::Date32, DataType::Date32])
+            .unwrap();
+        assert_eq!(return_type, DataType::Date32);
+    }
 }
diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs
index 1346825f054d..7a2bf4b6c44a 100644
--- a/datafusion/expr/src/conditional_expressions.rs
+++ b/datafusion/expr/src/conditional_expressions.rs
@@ -22,25 +22,6 @@ use arrow::datatypes::DataType;
 use datafusion_common::{plan_err, DFSchema, Result};
 use std::collections::HashSet;
 
-/// Currently supported types by the coalesce function.
-/// The order of these types correspond to the order on which coercion applies
-/// This should thus be from least informative to most informative
-pub static SUPPORTED_COALESCE_TYPES: &[DataType] = &[
-    DataType::Boolean,
-    DataType::UInt8,
-    DataType::UInt16,
-    DataType::UInt32,
-    DataType::UInt64,
-    DataType::Int8,
-    DataType::Int16,
-    DataType::Int32,
-    DataType::Int64,
-    DataType::Float32,
-    DataType::Float64,
-    DataType::Utf8,
-    DataType::LargeUtf8,
-];
-
 /// Helper struct for building [Expr::Case]
 pub struct CaseBuilder {
     expr: Option<Box<Expr>>,
diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs
index cd9a8344dec4..3fb0485c002e 100644
--- a/datafusion/expr/src/expr_rewriter/mod.rs
+++ b/datafusion/expr/src/expr_rewriter/mod.rs
@@ -397,7 +397,7 @@ mod test {
         let expr = col("a") + col("b");
         let schema_a =
             make_schema_with_empty_metadata(vec![make_field("\"tableA\"", "a")]);
-        let schemas = vec![schema_a];
+        let schemas = [schema_a];
         let schemas = schemas.iter().collect::<Vec<_>>();
 
         let error =
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index ca021c4bfc28..bbf0274ef4dc 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2415,7 +2415,7 @@ impl DistinctOn {
 
 /// Aggregates its input based on a set of grouping and aggregate
 /// expressions (e.g. SUM).
-#[derive(Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 // mark non_exhaustive to encourage use of try_new/new()
 #[non_exhaustive]
 pub struct Aggregate {
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 118844e4b266..e41934354806 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -361,7 +361,7 @@ fn string_temporal_coercion(
 
 /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
 /// where one both are numeric
-fn comparison_binary_numeric_coercion(
+pub(crate) fn comparison_binary_numeric_coercion(
     lhs_type: &DataType,
     rhs_type: &DataType,
 ) -> Option<DataType> {
diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs
index fb09be44c4c7..d4095a72fe3e 100644
--- a/datafusion/expr/src/type_coercion/functions.rs
+++ b/datafusion/expr/src/type_coercion/functions.rs
@@ -28,7 +28,7 @@ use arrow::{
 use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
 use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};
 
-use super::binary::comparison_coercion;
+use super::binary::{comparison_binary_numeric_coercion, comparison_coercion};
 
 /// Performs type coercion for function arguments.
 ///
@@ -187,6 +187,10 @@ fn get_valid_types(
             let new_type = current_types.iter().skip(1).try_fold(
                 current_types.first().unwrap().clone(),
                 |acc, x| {
+                    // The coerced types found by `comparison_coercion` are not guaranteed to be
+                    // coercible for the arguments. `comparison_coercion` returns looser
+                    // types that can be coerced to both `acc` and `x` for comparison purposes.
+                    // See `maybe_data_types` for the actual coercion.
                    let coerced_type = comparison_coercion(&acc, x);
                    if let Some(coerced_type) = coerced_type {
                        Ok(coerced_type)
@@ -276,9 +280,9 @@ fn maybe_data_types(
         if current_type == valid_type {
             new_type.push(current_type.clone())
         } else {
-            // attempt to coerce
-            if let Some(valid_type) = coerced_from(valid_type, current_type) {
-                new_type.push(valid_type)
+            // attempt to coerce.
+            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
+                new_type.push(coerced_type)
             } else {
                 // not possible
                 return None;
@@ -427,8 +431,19 @@ fn coerced_from<'a>(
             Some(type_into.clone())
         }
 
-        // cannot coerce
-        _ => None,
+        // More coercion rules.
+        // Note that not all rules in `comparison_coercion` can be reused here.
+        // For example, all numeric types can be coerced into Utf8 for comparison,
+        // but not for function arguments.
+        _ => comparison_binary_numeric_coercion(type_into, type_from).and_then(
+            |coerced_type| {
+                if *type_into == coerced_type {
+                    Some(coerced_type)
+                } else {
+                    None
+                }
+            },
+        ),
     }
 }
diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs
index 153cd4efe2c8..9a98d9b76809 100644
--- a/datafusion/functions/benches/regx.rs
+++ b/datafusion/functions/benches/regx.rs
@@ -44,7 +44,7 @@ fn data(rng: &mut ThreadRng) -> StringArray {
 }
 
 fn regex(rng: &mut ThreadRng) -> StringArray {
-    let samples = vec![
+    let samples = [
         ".*([A-Z]{1}).*".to_string(),
         "^(A).*".to_string(),
         r#"[\p{Letter}-]+"#.to_string(),
@@ -60,7 +60,7 @@ fn regex(rng: &mut ThreadRng) -> StringArray {
 }
 
 fn flags(rng: &mut ThreadRng) -> StringArray {
-    let samples = vec![Some("i".to_string()), Some("im".to_string()), None];
+    let samples = [Some("i".to_string()), Some("im".to_string()), None];
     let mut sb = StringBuilder::new();
     for _ in 0..1000 {
         let sample = samples.choose(rng).unwrap();
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 30c184a28e33..9f5376cad513 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -17,6 +17,7 @@
 
 //! Eliminate common sub-expression.
 
+use std::collections::hash_map::Entry;
 use std::collections::{BTreeSet, HashMap};
 use std::sync::Arc;
 
@@ -29,8 +30,7 @@ use datafusion_common::tree_node::{
     TreeNodeVisitor,
 };
 use datafusion_common::{
-    internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef,
-    DataFusionError, Result,
+    internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result,
 };
 use datafusion_expr::expr::Alias;
 use datafusion_expr::logical_plan::{
@@ -38,14 +38,80 @@ use datafusion_expr::logical_plan::{
 };
 use datafusion_expr::{col, Expr, ExprSchemable};
 
-/// A map from expression's identifier to tuple including
-/// - the expression itself (cloned)
-/// - counter
-/// - DataType of this expression.
-type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
+/// Set of expressions generated by the [`ExprIdentifierVisitor`]
+/// and consumed by the [`CommonSubexprRewriter`].
+#[derive(Default)]
+struct ExprSet {
+    /// A map from expression's identifier (stringified expr) to tuple including:
+    /// - the expression itself (cloned)
+    /// - counter
+    /// - DataType of this expression.
+    /// - symbol used as the identifier in the alias.
+    map: HashMap<Identifier, (Expr, usize, DataType, Identifier)>,
+}
+
+impl ExprSet {
+    fn expr_identifier(expr: &Expr) -> Identifier {
+        format!("{expr}")
+    }
 
-/// Identifier type. Current implementation use describe of a expression (type String) as
-/// Identifier.
+    fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> {
+        self.map.get(key)
+    }
+
+    fn entry(
+        &mut self,
+        key: Identifier,
+    ) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> {
+        self.map.entry(key)
+    }
+
+    fn populate_expr_set(
+        &mut self,
+        expr: &[Expr],
+        input_schema: DFSchemaRef,
+        expr_mask: ExprMask,
+    ) -> Result<()> {
+        expr.iter().try_for_each(|e| {
+            self.expr_to_identifier(e, Arc::clone(&input_schema), expr_mask)?;
+
+            Ok(())
+        })
+    }
+
+    /// Go through an expression tree and generate identifier for every node in this tree.
+    fn expr_to_identifier(
+        &mut self,
+        expr: &Expr,
+        input_schema: DFSchemaRef,
+        expr_mask: ExprMask,
+    ) -> Result<()> {
+        expr.visit(&mut ExprIdentifierVisitor {
+            expr_set: self,
+            input_schema,
+            visit_stack: vec![],
+            node_count: 0,
+            expr_mask,
+        })?;
+
+        Ok(())
+    }
+}
+
+impl From<Vec<(Identifier, (Expr, usize, DataType, Identifier))>> for ExprSet {
+    fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self {
+        let mut expr_set = Self::default();
+        entries.into_iter().for_each(|(k, v)| {
+            expr_set.map.insert(k, v);
+        });
+        expr_set
+    }
+}
+
+/// Identifier for each subexpression.
+///
+/// Note that the current implementation uses the `Display` of an expression
+/// (a `String`) as `Identifier`.
 ///
 /// A Identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no
 /// collision (as low as possible)"
@@ -65,21 +131,16 @@ impl CommonSubexprEliminate {
     fn rewrite_exprs_list(
         &self,
         exprs_list: &[&[Expr]],
-        arrays_list: &[&[Vec<(usize, String)>]],
         expr_set: &ExprSet,
         affected_id: &mut BTreeSet<Identifier>,
     ) -> Result<Vec<Vec<Expr>>> {
         exprs_list
             .iter()
-            .zip(arrays_list.iter())
-            .map(|(exprs, arrays)| {
+            .map(|exprs| {
                 exprs
                     .iter()
                     .cloned()
-                    .zip(arrays.iter())
-                    .map(|(expr, id_array)| {
-                        replace_common_expr(expr, id_array, expr_set, affected_id)
-                    })
+                    .map(|expr| replace_common_expr(expr, expr_set, affected_id))
                     .collect::<Result<Vec<_>>>()
             })
             .collect::<Result<Vec<_>>>()
@@ -88,7 +149,6 @@ impl CommonSubexprEliminate {
     fn rewrite_expr(
         &self,
         exprs_list: &[&[Expr]],
-        arrays_list: &[&[Vec<(usize, String)>]],
         input: &LogicalPlan,
         expr_set: &ExprSet,
         config: &dyn OptimizerConfig,
@@ -96,7 +156,7 @@ impl CommonSubexprEliminate {
         let mut affected_id = BTreeSet::<Identifier>::new();
 
         let rewrite_exprs =
-            self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?;
+            self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?;
 
         let mut new_input = self
             .try_optimize(input, config)?
@@ -115,13 +175,13 @@ impl CommonSubexprEliminate {
     ) -> Result<Projection> {
         let Projection { expr, input, .. } = projection;
         let input_schema = Arc::clone(input.schema());
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
         // Visit expr list and build expr identifier to occurring count map (`expr_set`).
-        let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+        expr_set.populate_expr_set(expr, input_schema, ExprMask::Normal)?;
 
         let (mut new_expr, new_input) =
-            self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?;
+            self.rewrite_expr(&[expr], input, &expr_set, config)?;
 
         // Since projection expr changes, schema changes also. Use try_new method.
        Projection::try_new(pop_expr(&mut new_expr)?, Arc::new(new_input))
@@ -133,25 +193,13 @@ impl CommonSubexprEliminate {
         filter: &Filter,
         config: &dyn OptimizerConfig,
     ) -> Result<LogicalPlan> {
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
         let predicate = &filter.predicate;
         let input_schema = Arc::clone(filter.input.schema());
-        let mut id_array = vec![];
-        expr_to_identifier(
-            predicate,
-            &mut expr_set,
-            &mut id_array,
-            input_schema,
-            ExprMask::Normal,
-        )?;
+        expr_set.expr_to_identifier(predicate, input_schema, ExprMask::Normal)?;
 
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &[&[predicate.clone()]],
-            &[&[id_array]],
-            &filter.input,
-            &expr_set,
-            config,
-        )?;
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&[&[predicate.clone()]], &filter.input, &expr_set, config)?;
 
         if let Some(predicate) = pop_expr(&mut new_expr)?.pop() {
             Ok(LogicalPlan::Filter(Filter::try_new(
@@ -169,8 +217,7 @@ impl CommonSubexprEliminate {
         config: &dyn OptimizerConfig,
     ) -> Result<LogicalPlan> {
         let mut window_exprs = vec![];
-        let mut arrays_per_window = vec![];
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
         // Get all window expressions inside the consecutive window operators.
         // Consecutive window expressions may refer to same complex expression.
@@ -189,30 +236,18 @@ impl CommonSubexprEliminate {
             plan = input.as_ref().clone();
 
             let input_schema = Arc::clone(input.schema());
-            let arrays =
-                to_arrays(&window_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+            expr_set.populate_expr_set(&window_expr, input_schema, ExprMask::Normal)?;
 
             window_exprs.push(window_expr);
-            arrays_per_window.push(arrays);
         }
 
         let mut window_exprs = window_exprs
             .iter()
             .map(|expr| expr.as_slice())
             .collect::<Vec<_>>();
-        let arrays_per_window = arrays_per_window
-            .iter()
-            .map(|arrays| arrays.as_slice())
-            .collect::<Vec<_>>();
 
-        assert_eq!(window_exprs.len(), arrays_per_window.len());
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &window_exprs,
-            &arrays_per_window,
-            &plan,
-            &expr_set,
-            config,
-        )?;
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&window_exprs, &plan, &expr_set, config)?;
         assert_eq!(window_exprs.len(), new_expr.len());
 
         // Construct consecutive window operator, with their corresponding new window expressions.
@@ -249,46 +284,36 @@ impl CommonSubexprEliminate {
             input,
             ..
         } = aggregate;
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
-        // rewrite inputs
+        // build expr_set, with groupby and aggr
         let input_schema = Arc::clone(input.schema());
-        let group_arrays = to_arrays(
+        expr_set.populate_expr_set(
             group_expr,
             Arc::clone(&input_schema),
-            &mut expr_set,
             ExprMask::Normal,
         )?;
-        let aggr_arrays =
-            to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+        expr_set.populate_expr_set(aggr_expr, input_schema, ExprMask::Normal)?;
 
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &[group_expr, aggr_expr],
-            &[&group_arrays, &aggr_arrays],
-            input,
-            &expr_set,
-            config,
-        )?;
+        // rewrite inputs
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&[group_expr, aggr_expr], input, &expr_set, config)?;
         // note the reversed pop order.
        let new_aggr_expr = pop_expr(&mut new_expr)?;
        let new_group_expr = pop_expr(&mut new_expr)?;
 
        // create potential projection on top
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
        let new_input_schema = Arc::clone(new_input.schema());
-        let aggr_arrays = to_arrays(
+        expr_set.populate_expr_set(
            &new_aggr_expr,
            new_input_schema.clone(),
-            &mut expr_set,
            ExprMask::NormalAndAggregates,
        )?;
+
        let mut affected_id = BTreeSet::<Identifier>::new();
-        let mut rewritten = self.rewrite_exprs_list(
-            &[&new_aggr_expr],
-            &[&aggr_arrays],
-            &expr_set,
-            &mut affected_id,
-        )?;
+        let mut rewritten =
+            self.rewrite_exprs_list(&[&new_aggr_expr], &expr_set, &mut affected_id)?;
        let rewritten = pop_expr(&mut rewritten)?;
 
        if affected_id.is_empty() {
@@ -308,9 +333,9 @@ impl CommonSubexprEliminate {
 
            for id in affected_id {
                match expr_set.get(&id) {
-                    Some((expr, _, _)) => {
+                    Some((expr, _, _, symbol)) => {
                        // todo: check `nullable`
-                        agg_exprs.push(expr.clone().alias(&id));
+                        agg_exprs.push(expr.clone().alias(symbol.as_str()));
                    }
                    _ => {
                        return internal_err!("expr_set invalid state");
@@ -328,8 +353,7 @@ impl CommonSubexprEliminate {
                        agg_exprs.push(expr.alias(&name));
                        proj_exprs.push(Expr::Column(Column::from_name(name)));
                    } else {
-                        let id =
-                            ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten);
+                        let id = ExprSet::expr_identifier(&expr_rewritten);
                        let out_name =
                            expr_rewritten.to_field(&new_input_schema)?.qualified_name();
                        agg_exprs.push(expr_rewritten.alias(&id));
@@ -360,13 +384,14 @@ impl CommonSubexprEliminate {
        config: &dyn OptimizerConfig,
    ) -> Result<LogicalPlan> {
        let Sort { expr, input, fetch } = sort;
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
        let input_schema = Arc::clone(input.schema());
-        let arrays = to_arrays(expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+        // Visit expr list and build expr identifier to occurring count map (`expr_set`).
+        expr_set.populate_expr_set(expr, input_schema, ExprMask::Normal)?;
 
        let (mut new_expr, new_input) =
-            self.rewrite_expr(&[expr], &[&arrays], input, &expr_set, config)?;
+            self.rewrite_expr(&[&expr], input, &expr_set, config)?;
 
        Ok(LogicalPlan::Sort(Sort {
            expr: pop_expr(&mut new_expr)?,
@@ -460,28 +485,6 @@ fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
        .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string()))
 }
 
-fn to_arrays(
-    expr: &[Expr],
-    input_schema: DFSchemaRef,
-    expr_set: &mut ExprSet,
-    expr_mask: ExprMask,
-) -> Result<Vec<Vec<(usize, Identifier)>>> {
-    expr.iter()
-        .map(|e| {
-            let mut id_array = vec![];
-            expr_to_identifier(
-                e,
-                expr_set,
-                &mut id_array,
-                Arc::clone(&input_schema),
-                expr_mask,
-            )?;
-
-            Ok(id_array)
-        })
-        .collect::<Result<Vec<_>>>()
-}
-
 /// Build the "intermediate" projection plan that evaluates the extracted common expressions.
 fn build_common_expr_project_plan(
     input: LogicalPlan,
@@ -493,11 +496,11 @@ fn build_common_expr_project_plan(
 
    for id in affected_id {
        match expr_set.get(&id) {
-            Some((expr, _, data_type)) => {
+            Some((expr, _, data_type, symbol)) => {
                // todo: check `nullable`
                let field = DFField::new_unqualified(&id, data_type.clone(), true);
                fields_set.insert(field.name().to_owned());
-                project_exprs.push(expr.clone().alias(&id));
+                project_exprs.push(expr.clone().alias(symbol.as_str()));
            }
            _ => {
                return internal_err!("expr_set invalid state");
@@ -597,15 +600,15 @@ impl ExprMask {
 /// This visitor implementation use a stack `visit_stack` to track traversal, which
 /// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called
/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack.
-/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem`
+/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem`
 /// before the first `EnterMark` is considered to be sub-tree of the leaving node.
 ///
 /// This visitor also records identifier in `id_array`. Makes the following traverse
 /// pass can get the identifier of a node without recalculate it. We assign each node
 /// in the expr tree a series number, start from 1, maintained by `series_number`.
-/// Series number represents the order we left (`post_visit`) a node. Has the property
+/// Series number represents the order we left (`f_up()`) a node. Has the property
 /// that child node's series number always smaller than parent's. While `id_array` is
-/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to
+/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to
 /// get the index of `id_array` for each node.
 ///
 /// `Expr` without sub-expr (column, literal etc.) will not have identifier
@@ -613,17 +616,13 @@ struct ExprIdentifierVisitor<'a> {
     // param
     expr_set: &'a mut ExprSet,
-    /// series number (usize) and identifier.
-    id_array: &'a mut Vec<(usize, Identifier)>,
     /// input schema for the node that we're optimizing, so we can determine the correct datatype
     /// for each subexpression
     input_schema: DFSchemaRef,
     // inner states
     visit_stack: Vec<VisitRecord>,
-    /// increased in pre_visit, start from 0.
+    /// increased in f_down, start from 0.
     node_count: usize,
-    /// increased in post_visit, start from 1.
-    series_number: usize,
     /// which expression should be skipped?
     expr_mask: ExprMask,
 }
@@ -633,31 +632,29 @@ enum VisitRecord {
     /// `usize` is the monotone increasing series number assigned in pre_visit().
     /// Starts from 0. Is used to index the identifier array `id_array` in post_visit().
     EnterMark(usize),
+    /// the node's children were skipped => jump to `f_up` on the same node
+    JumpMark(usize),
     /// Accumulated identifier of sub expression.
     ExprItem(Identifier),
 }
 
 impl ExprIdentifierVisitor<'_> {
-    fn desc_expr(expr: &Expr) -> String {
-        format!("{expr}")
-    }
-
     /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
     /// before it.
-    fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> {
+    fn pop_enter_mark(&mut self) -> (usize, Identifier) {
         let mut desc = String::new();
 
         while let Some(item) = self.visit_stack.pop() {
             match item {
-                VisitRecord::EnterMark(idx) => {
-                    return Some((idx, desc));
+                VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => {
+                    return (idx, desc);
                 }
-                VisitRecord::ExprItem(s) => {
-                    desc.push_str(&s);
+                VisitRecord::ExprItem(id) => {
+                    desc.push_str(&id);
                 }
             }
         }
-        None
+        unreachable!("EnterMark should be paired with a node number");
     }
 }
 
@@ -668,81 +665,51 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
         // related to https://github.com/apache/arrow-datafusion/issues/8814
         // If the expr contain volatile expression or is a short-circuit expression, skip it.
         if expr.short_circuits() || is_volatile_expression(expr)?
        {
-            return Ok(TreeNodeRecursion::Jump);
+            self.visit_stack
+                .push(VisitRecord::JumpMark(self.node_count));
+            return Ok(TreeNodeRecursion::Jump); // go to f_up
        }
+
        self.visit_stack
            .push(VisitRecord::EnterMark(self.node_count));
        self.node_count += 1;
-        // put placeholder
-        self.id_array.push((0, "".to_string()));
+
        Ok(TreeNodeRecursion::Continue)
    }
 
    fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
-        self.series_number += 1;
+        let (_idx, sub_expr_identifier) = self.pop_enter_mark();
 
-        let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else {
-            return Ok(TreeNodeRecursion::Continue);
-        };
        // skip exprs that should not be recognized.
        if self.expr_mask.ignores(expr) {
-            self.id_array[idx].0 = self.series_number;
-            let desc = Self::desc_expr(expr);
-            self.visit_stack.push(VisitRecord::ExprItem(desc));
+            let curr_expr_identifier = ExprSet::expr_identifier(expr);
+            self.visit_stack
+                .push(VisitRecord::ExprItem(curr_expr_identifier));
            return Ok(TreeNodeRecursion::Continue);
        }
-        let mut desc = Self::desc_expr(expr);
-        desc.push_str(&sub_expr_desc);
+        let curr_expr_identifier = ExprSet::expr_identifier(expr);
+        let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}");
 
-        self.id_array[idx] = (self.series_number, desc.clone());
-        self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
+        self.visit_stack
+            .push(VisitRecord::ExprItem(alias_symbol.clone()));
 
        let data_type = expr.get_type(&self.input_schema)?;
 
        self.expr_set
-            .entry(desc)
-            .or_insert_with(|| (expr.clone(), 0, data_type))
+            .entry(curr_expr_identifier)
+            .or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol))
            .1 += 1;
        Ok(TreeNodeRecursion::Continue)
    }
 }
 
-/// Go through an expression tree and generate identifier for every node in this tree.
-fn expr_to_identifier(
-    expr: &Expr,
-    expr_set: &mut ExprSet,
-    id_array: &mut Vec<(usize, Identifier)>,
-    input_schema: DFSchemaRef,
-    expr_mask: ExprMask,
-) -> Result<()> {
-    expr.visit(&mut ExprIdentifierVisitor {
-        expr_set,
-        id_array,
-        input_schema,
-        visit_stack: vec![],
-        node_count: 0,
-        series_number: 0,
-        expr_mask,
-    })?;
-
-    Ok(())
-}
-
 /// Rewrite expression by replacing detected common sub-expression with
 /// the corresponding temporary column name. That column contains the
 /// evaluate result of replaced expression.
 struct CommonSubexprRewriter<'a> {
    expr_set: &'a ExprSet,
-    id_array: &'a [(usize, Identifier)],
    /// Which identifier is replaced.
    affected_id: &'a mut BTreeSet<Identifier>,
-
-    /// the max series number we have rewritten. Other expression nodes
-    /// with smaller series number is already replaced and shouldn't
-    /// do anything with them.
-    max_series_number: usize,
-    /// current node's information's index in `id_array`.
-    curr_index: usize,
 }
 
 impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
@@ -755,88 +722,42 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
        if expr.short_circuits() || is_volatile_expression(&expr)? {
            return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
        }
-        if self.curr_index >= self.id_array.len()
-            || self.max_series_number > self.id_array[self.curr_index].0
-        {
-            return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
-        }
 
-        let curr_id = &self.id_array[self.curr_index].1;
-        // skip `Expr`s without identifier (empty identifier).
-        if curr_id.is_empty() {
-            self.curr_index += 1;
-            return Ok(Transformed::no(expr));
-        }
+        let curr_id = &ExprSet::expr_identifier(&expr);
+
+        // lookup previously visited expression
        match self.expr_set.get(curr_id) {
-            Some((_, counter, _)) => {
+            Some((_, counter, _, symbol)) => {
+                // if the expression is commonly used (i.e. used more than once)
                if *counter > 1 {
                    self.affected_id.insert(curr_id.clone());
 
-                    // This expr tree is finished.
-                    if self.curr_index >= self.id_array.len() {
-                        return Ok(Transformed::new(
-                            expr,
-                            false,
-                            TreeNodeRecursion::Jump,
-                        ));
-                    }
-
-                    let (series_number, id) = &self.id_array[self.curr_index];
-                    self.curr_index += 1;
-                    // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr.
-                    let expr_set_item = self.expr_set.get(id).ok_or_else(|| {
-                        internal_datafusion_err!("expr_set invalid state")
-                    })?;
-                    if *series_number < self.max_series_number
-                        || id.is_empty()
-                        || expr_set_item.1 <= 1
-                    {
-                        return Ok(Transformed::new(
-                            expr,
-                            false,
-                            TreeNodeRecursion::Jump,
-                        ));
-                    }
-
-                    self.max_series_number = *series_number;
-                    // step index to skip all sub-node (which has smaller series number).
-                    while self.curr_index < self.id_array.len()
-                        && *series_number > self.id_array[self.curr_index].0
-                    {
-                        self.curr_index += 1;
-                    }
-
                    let expr_name = expr.display_name()?;
                    // Alias this `Column` expr to it original "expr name",
                    // `projection_push_down` optimizer use "expr name" to eliminate useless
                    // projections.
                    Ok(Transformed::new(
-                        col(id).alias(expr_name),
+                        col(symbol).alias(expr_name),
                        true,
                        TreeNodeRecursion::Jump,
                    ))
                } else {
-                    self.curr_index += 1;
                    Ok(Transformed::no(expr))
                }
            }
-            _ => internal_err!("expr_set invalid state"),
+            None => Ok(Transformed::no(expr)),
        }
    }
 }
 
 fn replace_common_expr(
    expr: Expr,
-    id_array: &[(usize, Identifier)],
    expr_set: &ExprSet,
    affected_id: &mut BTreeSet<Identifier>,
 ) -> Result<Expr> {
    expr.rewrite(&mut CommonSubexprRewriter {
        expr_set,
-        id_array,
        affected_id,
-        max_series_number: 0,
-        curr_index: 0,
    })
    .data()
 }
@@ -872,73 +793,6 @@ mod test {
        assert_eq!(expected, formatted_plan);
    }
 
-    #[test]
-    fn id_array_visitor() -> Result<()> {
-        let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2);
-
-        let schema = Arc::new(DFSchema::new_with_metadata(
-            vec![
-                DFField::new_unqualified("a", DataType::Int64, false),
-                DFField::new_unqualified("c", DataType::Int64, false),
-            ],
-            Default::default(),
-        )?);
-
-        // skip aggregates
-        let mut id_array = vec![];
-        expr_to_identifier(
-            &expr,
-            &mut HashMap::new(),
-            &mut id_array,
-            Arc::clone(&schema),
-            ExprMask::Normal,
-        )?;
-
-        let expected = vec![
-            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
-            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
-            (4, ""),
-            (3, "a + Int32(1)Int32(1)a"),
-            (1, ""),
-            (2, ""),
-            (6, ""),
-            (5, ""),
-            (8, "")
-        ]
-        .into_iter()
-        .map(|(number, id)| (number, id.into()))
-        .collect::<Vec<_>>();
-        assert_eq!(expected, id_array);
-
-        // include aggregates
-        let mut id_array = vec![];
-        expr_to_identifier(
-            &expr,
-            &mut HashMap::new(),
-            &mut id_array,
-            Arc::clone(&schema),
-            ExprMask::NormalAndAggregates,
-        )?;
-
-        let expected = vec![
-            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (3, "a + Int32(1)Int32(1)a"),
-            (1, ""),
-            (2, ""),
-            (6, "AVG(c)c"),
-            (5, ""),
-            (8, "")
-        ]
-        .into_iter()
-        .map(|(number, id)| (number, id.into()))
-        .collect::<Vec<_>>();
-        assert_eq!(expected, id_array);
-
-        Ok(())
-    }
-
     #[test]
     fn tpch_q1_simplified() -> Result<()> {
         // SQL:
@@ -1183,24 +1037,28 @@ mod test {
         let table_scan = test_table_scan().unwrap();
         let affected_id: BTreeSet<Identifier> =
             ["c+a".to_string(), "b+a".to_string()].into_iter().collect();
-        let expr_set_1 = [
+        let expr_set_1 = vec![
             (
                 "c+a".to_string(),
-                (col("c") + col("a"), 1, DataType::UInt32),
+                (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()),
             ),
             (
                 "b+a".to_string(),
-                (col("b") + col("a"), 1, DataType::UInt32),
+                (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()),
             ),
         ]
-        .into_iter()
-        .collect();
-        let expr_set_2 = [
-            ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
-            ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)),
+        .into();
+        let expr_set_2 = vec![
+            (
+                "c+a".to_string(),
+                (col("c+a"), 1, DataType::UInt32, "c+a".to_string()),
+            ),
+            (
+                "b+a".to_string(),
+                (col("b+a"), 1, DataType::UInt32, "b+a".to_string()),
+            ),
         ]
-        .into_iter()
-        .collect();
+        .into();
         let project =
             build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1)
                 .unwrap();
@@ -1226,30 +1084,48 @@ mod test {
             ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()]
                 .into_iter()
                 .collect();
-        let expr_set_1 = [
+        let expr_set_1 = vec![
             (
                 "test1.c+test1.a".to_string(),
-                (col("test1.c") + col("test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.c") + col("test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.c+test1.a".to_string(),
+                ),
             ),
             (
                 "test1.b+test1.a".to_string(),
-                (col("test1.b") + col("test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.b") + col("test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.b+test1.a".to_string(),
+                ),
             ),
         ]
-        .into_iter()
-        .collect();
-        let expr_set_2 = [
+        .into();
+        let expr_set_2 = vec![
             (
                 "test1.c+test1.a".to_string(),
-                (col("test1.c+test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.c+test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.c+test1.a".to_string(),
+                ),
             ),
             (
                 "test1.b+test1.a".to_string(),
-                (col("test1.b+test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.b+test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.b+test1.a".to_string(),
+                ),
             ),
         ]
-        .into_iter()
-        .collect();
+        .into();
         let project =
             build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1)
                 .unwrap();
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index ef034a5ed711..ddc7d1256f75 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -404,11 +404,12 @@ struct ConstEvaluator<'a> {
     input_batch: RecordBatch,
 }
 
+#[allow(dead_code)]
 /// The simplify result of ConstEvaluator
 enum ConstSimplifyResult {
     // Expr was simplifed and contains the new expression
     Simplified(ScalarValue),
-    // Evalaution encountered an error, contains the original expression
+    // Evaluation encountered an error, contains the original expression
     SimplifyRuntimeError(DataFusionError, Expr),
 }
 
diff --git a/datafusion/physical-expr/benches/to_char.rs b/datafusion/physical-expr/benches/to_char.rs
index 3d08a02bc231..5b9415ae1df7 100644
--- a/datafusion/physical-expr/benches/to_char.rs
+++ b/datafusion/physical-expr/benches/to_char.rs
@@ -64,7 +64,7 @@ fn data(rng: &mut ThreadRng) -> Date32Array {
 }
 
 fn patterns(rng: &mut ThreadRng) -> StringArray {
-    let samples = vec![
+    let samples = [
         "%Y:%m:%d".to_string(),
         "%d-%m-%Y".to_string(),
"%d%m%Y".to_string(), diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 71782fcc5f9b..fb5e7710496c 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -47,7 +47,7 @@ use crate::binary_map::OutputType; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -/// Expression for a COUNT(DISTINCT) aggregation. +/// Expression for a `COUNT(DISTINCT)` aggregation. #[derive(Debug)] pub struct DistinctCount { /// Column name @@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount { use TimeUnit::*; Ok(match &self.state_data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new()), @@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount { OutputType::Binary, )), + // Use the generic accumulator based on `ScalarValue` for all other types _ => Box::new(DistinctCountAccumulator { values: HashSet::default(), state_data_type: self.state_data_type.clone(), @@ -183,7 +185,11 @@ impl PartialEq for DistinctCount { } /// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. Some types have specialized accumulators that are (much) +/// [`ScalarValue`]. +/// +/// It stores intermediate results as a `ListArray` +/// +/// Note that many types have specialized accumulators that are (much) /// more efficient such as [`PrimitiveDistinctCountAccumulator`] and /// [`BytesDistinctCountAccumulator`] #[derive(Debug)] @@ -193,8 +199,9 @@ struct DistinctCountAccumulator { } impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * number of batches - // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types + // calculating the size for fixed length values, taking first batch size * + // number of batches This method is faster than .full_size(), however it is + // not suitable for variable length values like strings or complex types fn fixed_size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.values.capacity()) @@ -207,7 +214,8 @@ impl DistinctCountAccumulator { + std::mem::size_of::() } - // calculates the size as accurate as possible, call to this method is expensive + // calculates the size as accurately as possible. Note that calling this + // method is expensive fn full_size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.values.capacity()) @@ -221,6 +229,7 @@ impl DistinctCountAccumulator { } impl Accumulator for DistinctCountAccumulator { + /// Returns the distinct values seen so far as (one element) ListArray. fn state(&mut self) -> Result> { let scalars = self.values.iter().cloned().collect::>(); let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); @@ -246,6 +255,11 @@ impl Accumulator for DistinctCountAccumulator { }) } + /// Merges multiple sets of distinct values into the current set. + /// + /// The input to this function is a `ListArray` with **multiple** rows, + /// where each row contains the values from a partial aggregate's phase (e.g. + /// the result of calling `Self::state` on multiple accumulators). 
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(states.len(), 1, "array_agg states must be singleton!");
        let array = &states[0];
        let list_array = array.as_list::<i32>();
-        let inner_array = list_array.value(0);
-        self.update_batch(&[inner_array])
+        for inner_array in list_array.iter() {
+            let inner_array = inner_array
+                .expect("counts are always non null, so are intermediate results");
+            self.update_batch(&[inner_array])?;
+        }
+        Ok(())
    }
 
    fn evaluate(&mut self) -> Result<ScalarValue> {
diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs
index 280535f5e6be..58519c61cf1f 100644
--- a/datafusion/physical-expr/src/equivalence/class.rs
+++ b/datafusion/physical-expr/src/equivalence/class.rs
@@ -535,7 +535,7 @@ mod tests {
 
     #[test]
     fn test_remove_redundant_entries_eq_group() -> Result<()> {
-        let entries = vec![
+        let entries = [
             EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]),
             // This group is meaningless should be removed
             EquivalenceClass::new(vec![lit(3), lit(3)]),
@@ -543,11 +543,11 @@ mod tests {
         ];
         // Given equivalences classes are not in succinct form.
         // Expected form is the most plain representation that is functionally same.
-        let expected = vec![
+        let expected = [
             EquivalenceClass::new(vec![lit(1), lit(2)]),
             EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
         ];
-        let mut eq_groups = EquivalenceGroup::new(entries);
+        let mut eq_groups = EquivalenceGroup::new(entries.to_vec());
         eq_groups.remove_redundant_entries();
 
         let eq_groups = eq_groups.classes;
diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs
index c7cb9e5f530e..1364d3a8c028 100644
--- a/datafusion/physical-expr/src/equivalence/ordering.rs
+++ b/datafusion/physical-expr/src/equivalence/ordering.rs
@@ -746,7 +746,7 @@ mod tests {
             // Generate a data that satisfies properties given
             let table_data_with_properties =
                 generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
-            let col_exprs = vec![
+            let col_exprs = [
                 col("a", &test_schema)?,
                 col("b", &test_schema)?,
                 col("c", &test_schema)?,
@@ -815,7 +815,7 @@ mod tests {
             Operator::Plus,
             col("b", &test_schema)?,
         )) as Arc<dyn PhysicalExpr>;
-        let exprs = vec![
+        let exprs = [
             col("a", &test_schema)?,
             col("b", &test_schema)?,
             col("c", &test_schema)?,
diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs
index 890d0b49687a..9f1998f70a7d 100644
--- a/datafusion/physical-expr/src/equivalence/properties.rs
+++ b/datafusion/physical-expr/src/equivalence/properties.rs
@@ -1793,7 +1793,7 @@ mod tests {
             Operator::Plus,
             col("b", &test_schema)?,
         )) as Arc<dyn PhysicalExpr>;
-        let exprs = vec![
+        let exprs = [
             col("a", &test_schema)?,
             col("b", &test_schema)?,
             col("c", &test_schema)?,
diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs
index 98a05dff5386..a8c115ba3a82 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -25,6 +25,7 @@ use std::sync::Arc;
 use arrow::array::ArrayRef;
 use arrow::array::{BooleanArray, Float32Array, Float64Array, Int64Array};
 use arrow::datatypes::DataType;
+use arrow_array::Array;
 use rand::{thread_rng, Rng};
 
 use datafusion_common::ScalarValue::{Float32, Int64};
@@ -92,8 +93,9 @@ macro_rules! downcast_arg {
downcast_arg { ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { DataFusionError::Internal(format!( - "could not cast {} to {}", + "could not cast {} from {} to {}", $NAME, + $ARG.data_type(), type_name::<$ARRAY_TYPE>() )) })? diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 500df6153fdb..2acb881246a4 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -578,7 +578,7 @@ mod tests { #[tokio::test] async fn test_partial_sort2() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); - let source_tables = vec![ + let source_tables = [ test::build_table_scan_i32( ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]), ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]), diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 7eaac74a5449..64322bd5f101 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -740,7 +740,7 @@ mod tests { let col_e = &col("e", &schema)?; let col_f = &col("f", &schema)?; let options = SortOptions::default(); - let test_cases = vec![ + let test_cases = [ //-----------TEST CASE 1----------// ( // First child orderings diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 002aade2528e..af7bf5cb16e8 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -280,3 +280,70 @@ ORDER BY 2023-12-20T01:20:00 1000 f2 foo 2023-12-20T01:30:00 1000 f1 32.0 2023-12-20T01:30:00 1000 f2 foo + +# Cleanup +statement ok +drop view m1; + +statement ok +drop view m2; + +###### +# Create a table using UNION ALL to get 2 partitions (very important) +###### +statement ok +create table m3_source as + select * from (values('foo', 'bar', 1)) + UNION ALL + select * from (values('foo', 'baz', 1)); + +###### +# Now, create a table with the same data, but column2 has type `Dictionary(Int32)` to trigger the fallback code +###### +statement ok +create table m3 as + select + column1, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2", + column3 +from m3_source; + +# there are two values in column2 +query T?I rowsort +SELECT * +FROM m3; +---- +foo bar 1 +foo baz 1 + +# There is 1 distinct value in column1 +query I +SELECT count(distinct column1) +FROM m3 +GROUP BY column3; +---- +1 + +# There are 2 distinct values in column2 +query I +SELECT count(distinct column2) +FROM m3 +GROUP BY column3; +---- +2 + +# Should still get the same results when querying in the same query +query II +SELECT count(distinct column1), count(distinct column2) +FROM m3 +GROUP BY column3; +---- +1 2 + + +# Cleanup +statement ok +drop table m3; + +statement ok +drop table m3_source; diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 9e4e3aa8185d..43c3144fb80e 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -1910,3 +1910,102 @@ false true false true NULL NULL NULL NULL false false true true false false true false + + +############# +## Common Subexpr Eliminate Tests +############# + +statement ok +CREATE TABLE doubles ( + f64 DOUBLE +) as VALUES + (10.1) +; + +# common subexpr with alias +query RRR rowsort +select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) from doubles; +---- +10.1 0 1.570796326795 + 
+
+# common subexpr with coalesce (short-circuited)
+query RRR rowsort
+select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles;
+----
+10.1 0.09900990099 1.471623942989
+
+# common subexpr with coalesce (short-circuited) and alias
+query RRR rowsort
+select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles;
+----
+10.1 0.09900990099 1.471623942989
+
+# common subexpr with case (short-circuited)
+query RRR rowsort
+select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles;
+----
+10.1 0.09900990099 1.471623942989
+
+
+statement ok
+CREATE TABLE t1(
+  time TIMESTAMP,
+  load1 DOUBLE,
+  load2 DOUBLE,
+  host VARCHAR
+) AS VALUES
+  (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'),
+  (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'),
+  (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'),
+  (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL)
+;
+
+# struct scalar function with columns
+query ?
+select struct(time,load1,load2,host) from t1;
+----
+{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: host1}
+{c0: 2018-05-22T19:53:26, c1: 2.2, c2: 202.0, c3: host2}
+{c0: 2018-05-22T19:53:26, c1: 3.3, c2: 303.0, c3: host3}
+{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: }
+
+# can have an aggregate function with an inner coalesce
+query TR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1
+host2 2.2
+host3 3.3
+
+# can have an aggregate function with an inner CASE WHEN
+query TR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 101
+host2 202
+host3 303
+
+# can have 2 projections with aggr(short_circuited), with different short-circuited expr
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
+
+# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
+
+# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt
index 9619696679d2..6c6e242ff589 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -1782,7 +1782,7 @@ AS VALUES
 ('BB', 6, 1),
 ('BB', 6, 1);
 
-query TIR
+query TII
 select col1, col2, coalesce(sum_col3, 0) as sum_col3
 from (select distinct col2 from tbl) AS q1
 cross join (select distinct col1 from tbl) AS q2
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index 5ff253c1a34a..a64fcbbdbca2 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -1841,6 +1841,51 @@ SELECT COALESCE(c1 * c2, 0) FROM test
 statement ok
 drop table test
 
+# coalesce date32
+
+statement ok
+CREATE TABLE test(
+  d1_date DATE,
+  d2_date DATE,
+  d3_date DATE
+) as VALUES
+  ('2022-12-12','2022-12-12','2022-12-12'),
+  (NULL,'2022-12-11','2022-12-12'),
+  ('2022-12-12','2022-12-10','2022-12-12'),
+  ('2022-12-12',NULL,'2022-12-12'),
+  ('2022-12-12','2022-12-8','2022-12-12'),
+  ('2022-12-12','2022-12-7',NULL),
+  ('2022-12-12',NULL,'2022-12-12'),
+  (NULL,'2022-12-5','2022-12-12')
+;
+
+query D
+SELECT COALESCE(d1_date, d2_date, d3_date) FROM test
+----
+2022-12-12
+2022-12-11
+2022-12-12
+2022-12-12
+2022-12-12
+2022-12-12
+2022-12-12
+2022-12-05
+
+query T
+SELECT arrow_typeof(COALESCE(d1_date, d2_date, d3_date)) FROM test
+----
+Date32
+Date32
+Date32
+Date32
+Date32
+Date32
+Date32
+Date32
+
+statement ok
+drop table test
+
 statement ok
 CREATE TABLE test(
   i32 INT,
diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs
index e8698253edb5..6b81e33dfc37 100644
--- a/datafusion/substrait/src/serializer.rs
+++ b/datafusion/substrait/src/serializer.rs
@@ -27,6 +27,7 @@ use substrait::proto::Plan;
 use std::fs::OpenOptions;
 use std::io::{Read, Write};
 
+#[allow(clippy::suspicious_open_options)]
 pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> {
     let protobuf_out = serialize_bytes(sql, ctx).await;
     let mut file = OpenOptions::new().create(true).write(true).open(path)?;
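
Reviewer note: a minimal sketch (not part of the patch) of the user-visible effect of switching `COALESCE` from the fixed `SUPPORTED_COALESCE_TYPES` list to `Signature::variadic_equal`. The assertions mirror the `test_coalesce_return_types` unit test and the `scalar.slt` cases added above; `BuiltinScalarFunction::return_type` is the API used by the new test, and the behavior for types outside the old list is an assumption based on the new `coerced_from` fallback to `comparison_binary_numeric_coercion`.

```rust
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::BuiltinScalarFunction;

fn main() -> Result<()> {
    let coalesce = BuiltinScalarFunction::Coalesce;

    // Before this change, Date32 was rejected during coercion because it was
    // not in SUPPORTED_COALESCE_TYPES; with variadic_equal the arguments are
    // coerced to a common type via comparison_coercion instead.
    let return_type = coalesce.return_type(&[DataType::Date32, DataType::Date32])?;
    assert_eq!(return_type, DataType::Date32);
    Ok(())
}
```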