From 9ba3285a092cc62f35eebc0da1e5c0a2985dd46b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 21 Mar 2023 19:22:53 -0700 Subject: [PATCH] Fix --- .../core/src/datasource/listing/helpers.rs | 1 + datafusion/core/src/physical_plan/planner.rs | 14 +- datafusion/expr/src/expr.rs | 44 +++- datafusion/expr/src/expr_fn.rs | 15 ++ datafusion/expr/src/expr_rewriter.rs | 23 ++- datafusion/expr/src/expr_schema.rs | 48 ++++- datafusion/expr/src/expr_visitor.rs | 3 +- datafusion/expr/src/type_coercion/binary.rs | 76 ++++++- datafusion/expr/src/utils.rs | 3 +- datafusion/jit/src/ast.rs | 1 + .../optimizer/src/eliminate_cross_join.rs | 9 +- .../optimizer/src/eliminate_outer_join.rs | 4 +- .../src/extract_equijoin_predicate.rs | 1 + datafusion/optimizer/src/push_down_filter.rs | 3 + .../src/rewrite_disjunctive_predicate.rs | 4 +- .../optimizer/src/scalar_subquery_to_join.rs | 4 +- .../simplify_expressions/expr_simplifier.rs | 79 +++++++ .../src/simplify_expressions/regex.rs | 2 +- .../src/simplify_expressions/utils.rs | 20 +- datafusion/optimizer/src/type_coercion.rs | 30 ++- .../src/unwrap_cast_in_comparison.rs | 22 +- datafusion/optimizer/src/utils.rs | 22 +- .../physical-expr/src/expressions/binary.rs | 151 +++++++++++--- .../src/expressions/binary/kernels_arrow.rs | 193 +++++++++++++----- .../physical-expr/src/expressions/mod.rs | 4 +- .../src/expressions/promote_precision.rs | 90 ++++++++ datafusion/physical-expr/src/planner.rs | 20 +- datafusion/proto/src/logical_plan/to_proto.rs | 7 +- datafusion/sql/src/expr/mod.rs | 5 +- datafusion/sql/src/utils.rs | 25 ++- .../substrait/src/logical_plan/consumer.rs | 29 +-- .../substrait/src/logical_plan/producer.rs | 4 +- 32 files changed, 797 insertions(+), 159 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/promote_precision.rs diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 7ed6326a907aa..40ab7315f3de2 100644 --- 
a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -94,6 +94,7 @@ impl ExpressionVisitor for ApplicabilityVisitor<'_> { | Expr::IsNotUnknown(_) | Expr::Negative(_) | Expr::Cast { .. } + | Expr::PromotePrecision { .. } | Expr::TryCast { .. } | Expr::BinaryExpr { .. } | Expr::Between { .. } diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 51653450a6996..5fca65cf668dd 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -61,7 +61,7 @@ use async_trait::async_trait; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::{ self, AggregateFunction, Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, - Like, TryCast, WindowFunction, + Like, PromotePrecision, TryCast, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -111,7 +111,9 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Expr::Alias(_, name) => Ok(name.clone()), Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")), Expr::Literal(value) => Ok(format!("{value:?}")), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) => { let left = create_physical_name(left, false)?; let right = create_physical_name(right, false)?; Ok(format!("{left} {op} {right}")) @@ -134,6 +136,10 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { // CAST does not change the expression name create_physical_name(expr, false) } + Expr::PromotePrecision(PromotePrecision { expr }) => { + // PromotePrecision does not change the expression name + create_physical_name(expr, false) + } Expr::TryCast(TryCast { expr, .. 
}) => { // CAST does not change the expression name create_physical_name(expr, false) @@ -1924,7 +1930,7 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }"; + let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, data_type: None }"; assert!(format!("{exec_plan:?}").contains(expected)); Ok(()) } @@ -2170,7 +2176,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } } }"; + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, data_type: None }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, data_type: None }, data_type: None }"; let actual = format!("{execution_plan:?}"); assert!(actual.contains(expected), "{}", actual); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 2806683ab87bf..1ef92df78165f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -146,6 +146,8 @@ pub enum Expr { /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. Cast(Cast), + /// Wraps the child expression when promoting the precision of a decimal type, so that it is not promoted multiple times. 
+ PromotePrecision(PromotePrecision), /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast(TryCast), @@ -234,12 +236,33 @@ pub struct BinaryExpr { pub op: Operator, /// Right-hand side of the expression pub right: Box, + /// The data type of the expression, if known + pub data_type: Option, } impl BinaryExpr { /// Create a new binary expression pub fn new(left: Box, op: Operator, right: Box) -> Self { - Self { left, op, right } + Self { + left, + op, + right, + data_type: None, + } + } + + pub fn new_with_data_type( + left: Box, + op: Operator, + right: Box, + data_type: Option, + ) -> Self { + Self { + left, + op, + right, + data_type, + } } } @@ -385,6 +408,20 @@ impl Cast { } } +/// PromotePrecision expression: wraps a child expression whose decimal precision has been promoted +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct PromotePrecision { + /// The expression being promoted + pub expr: Box, +} + +impl PromotePrecision { + /// Create a new PromotePrecision expression + pub fn new(expr: Box) -> Self { + Self { expr } + } +} + /// TryCast Expression #[derive(Clone, PartialEq, Eq, Hash)] pub struct TryCast { @@ -569,6 +606,7 @@ impl Expr { Expr::BinaryExpr { .. } => "BinaryExpr", Expr::Case { .. } => "Case", Expr::Cast { .. } => "Cast", + Expr::PromotePrecision { .. } => "PromotePrecision", Expr::Column(..) => "Column", Expr::OuterReferenceColumn(_, _) => "Outer", Expr::Exists { .. 
} => "Exists", @@ -858,6 +896,9 @@ impl fmt::Debug for Expr { Expr::Cast(Cast { expr, data_type }) => { write!(f, "CAST({expr:?} AS {data_type:?})") } + Expr::PromotePrecision(PromotePrecision { expr }) => { + write!(f, "PROMOTE_PRECISION({expr:?})") + } Expr::TryCast(TryCast { expr, data_type }) => { write!(f, "TRY_CAST({expr:?} AS {data_type:?})") } @@ -1211,6 +1252,7 @@ fn create_name(e: &Expr) -> Result { // CAST does not change the expression name create_name(expr) } + Expr::PromotePrecision(PromotePrecision { expr }) => create_name(expr), Expr::TryCast(TryCast { expr, .. }) => { // CAST does not change the expression name create_name(expr) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b20629946b01d..b4edf5f7cc3c2 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -74,6 +74,21 @@ pub fn binary_expr(left: Expr, op: Operator, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new(Box::new(left), op, Box::new(right))) } +/// Return a new expression `left <op> right` that carries the given optional data type +pub fn binary_expr_with_data_type( + left: Expr, + op: Operator, + right: Expr, + data_type: Option, +) -> Expr { + Expr::BinaryExpr(BinaryExpr::new_with_data_type( + Box::new(left), + op, + Box::new(right), + data_type, + )) +} + /// Return a new expression with a logical AND pub fn and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( diff --git a/datafusion/expr/src/expr_rewriter.rs b/datafusion/expr/src/expr_rewriter.rs index b4e82be5781fd..73f86cacdfcfc 100644 --- a/datafusion/expr/src/expr_rewriter.rs +++ b/datafusion/expr/src/expr_rewriter.rs @@ -19,7 +19,7 @@ use crate::expr::{ AggregateFunction, Between, BinaryExpr, Case, Cast, GetIndexedField, GroupingSet, - Like, Sort, TryCast, WindowFunction, + Like, PromotePrecision, Sort, TryCast, WindowFunction, }; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; @@ -135,13 +135,17 @@ impl ExprRewritable for Expr { 
Expr::ScalarSubquery(_) => self.clone(), Expr::ScalarVariable(ty, names) => Expr::ScalarVariable(ty, names), Expr::Literal(value) => Expr::Literal(value), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - Expr::BinaryExpr(BinaryExpr::new( - rewrite_boxed(left, rewriter)?, - op, - rewrite_boxed(right, rewriter)?, - )) - } + Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + data_type, + }) => Expr::BinaryExpr(BinaryExpr::new_with_data_type( + rewrite_boxed(left, rewriter)?, + op, + rewrite_boxed(right, rewriter)?, + data_type, + )), Expr::Like(Like { negated, expr, @@ -218,6 +222,9 @@ impl ExprRewritable for Expr { Expr::Cast(Cast { expr, data_type }) => { Expr::Cast(Cast::new(rewrite_boxed(expr, rewriter)?, data_type)) } + Expr::PromotePrecision(PromotePrecision { expr }) => Expr::PromotePrecision( + PromotePrecision::new(rewrite_boxed(expr, rewriter)?), + ), Expr::TryCast(TryCast { expr, data_type }) => { Expr::TryCast(TryCast::new(rewrite_boxed(expr, rewriter)?, data_type)) } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index fafda79a6f61d..2f9c35526417e 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,7 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, BinaryExpr, Cast, GetIndexedField, Sort, TryCast, WindowFunction, + AggregateFunction, BinaryExpr, Cast, GetIndexedField, PromotePrecision, Sort, + TryCast, WindowFunction, }; use crate::field_util::get_indexed_field; use crate::type_coercion::binary::binary_operator_data_type; @@ -39,6 +40,13 @@ pub trait ExprSchemable { /// cast to a type with respect to a schema fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; + + /// promote to a type with respect to a schema + fn promote_to( + self, + promote_to_type: &DataType, + schema: &S, + ) -> Result; } impl ExprSchemable for Expr { @@ -71,6 +79,7 @@ impl ExprSchemable for Expr { Expr::Case(case) => 
case.when_then_expr[0].1.get_type(schema), Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), + Expr::PromotePrecision(PromotePrecision { expr }) => expr.get_type(schema), Expr::ScalarUDF { fun, args } => { let data_types = args .iter() @@ -126,11 +135,18 @@ impl ExprSchemable for Expr { ref left, ref right, ref op, - }) => binary_operator_data_type( - &left.get_type(schema)?, - op, - &right.get_type(schema)?, - ), + ref data_type, + }) => { + if let Some(dt) = data_type { + Ok(dt.clone()) + } else { + binary_operator_data_type( + &left.get_type(schema)?, + op, + &right.get_type(schema)?, + ) + } + } Expr::Like { .. } | Expr::ILike { .. } | Expr::SimilarTo { .. } => { Ok(DataType::Boolean) } @@ -195,6 +211,9 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), + Expr::PromotePrecision(PromotePrecision { expr }) => { + expr.nullable(input_schema) + } Expr::ScalarVariable(_, _) | Expr::TryCast { .. } | Expr::ScalarFunction { .. } @@ -284,6 +303,23 @@ impl ExprSchemable for Expr { ))) } } + + /// Wraps this expression in a promote precision to a target [arrow::datatypes::DataType]. + /// + /// # Errors + /// + /// This function errors when it is impossible to cast the + /// expression to the target [arrow::datatypes::DataType]. + fn promote_to( + self, + promote_to_type: &DataType, + schema: &S, + ) -> Result { + let casted = self.cast_to(promote_to_type, schema)?; + Ok(Expr::PromotePrecision(PromotePrecision::new(Box::new( + casted, + )))) + } } #[cfg(test)] diff --git a/datafusion/expr/src/expr_visitor.rs b/datafusion/expr/src/expr_visitor.rs index 84ca6f7ed9dfb..eabf262a936b6 100644 --- a/datafusion/expr/src/expr_visitor.rs +++ b/datafusion/expr/src/expr_visitor.rs @@ -17,7 +17,7 @@ //! 
Expression visitor -use crate::expr::{AggregateFunction, Cast, Sort, WindowFunction}; +use crate::expr::{AggregateFunction, Cast, PromotePrecision, Sort, WindowFunction}; use crate::{ expr::{BinaryExpr, GroupingSet, TryCast}, Between, Expr, GetIndexedField, Like, @@ -116,6 +116,7 @@ impl ExprVisitable for Expr { | Expr::IsNull(expr) | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) + | Expr::PromotePrecision(PromotePrecision { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery { expr, .. } => expr.accept(visitor), diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs index 5ee66837ec166..e40cc7574840a 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr/src/type_coercion/binary.rs @@ -35,9 +35,52 @@ pub fn binary_operator_data_type( op: &Operator, rhs_type: &DataType, ) -> Result { + let coerced_type = coerce_types(lhs_type, op, rhs_type)?; // validate that it is possible to perform the operation on incoming types. 
// (or the return datatype cannot be inferred) - let result_type = coerce_types(lhs_type, op, rhs_type)?; + let result_type = if !matches!(coerced_type, DataType::Decimal128(_, _)) { + coerced_type + } else { + let lhs_type = match lhs_type { + DataType::Decimal128(_, _) | DataType::Null => lhs_type.clone(), + DataType::Dictionary(_, value_type) + if matches!(**value_type, DataType::Decimal128(_, _)) => + { + lhs_type.clone() + } + _ => coerce_numeric_type_to_decimal(lhs_type).ok_or_else(|| { + DataFusionError::Internal(format!( + "Could not coerce numeric type to decimal: {:?}", + lhs_type + )) + })?, + }; + + let rhs_type = match rhs_type { + DataType::Decimal128(_, _) | DataType::Null => rhs_type.clone(), + DataType::Dictionary(_, value_type) + if matches!(**value_type, DataType::Decimal128(_, _)) => + { + rhs_type.clone() + } + _ => coerce_numeric_type_to_decimal(rhs_type).ok_or_else(|| { + DataFusionError::Internal(format!( + "Could not coerce numeric type to decimal: {:?}", + rhs_type + )) + })?, + }; + + match op { + // For Plus and Minus, the result type is the same as the input type which is already promoted + Operator::Plus | Operator::Minus => coerced_type, + Operator::Divide | Operator::Multiply | Operator::Modulo => { + decimal_op_mathematics_type(op, &lhs_type, &rhs_type) + .unwrap_or(coerced_type) + } + _ => coerced_type, + } + }; match op { // operators that return a boolean @@ -388,6 +431,8 @@ fn mathematics_numerical_coercion( if lhs_type == rhs_type && !(matches!(lhs_type, DataType::Dictionary(_, _)) || matches!(rhs_type, DataType::Dictionary(_, _))) + // For decimal, we always need to coerce/promote the decimal types. + && !matches!(lhs_type, DataType::Decimal128(_, _)) { return Some(lhs_type.clone()); } @@ -458,10 +503,39 @@ fn create_decimal_type(precision: u8, scale: i8) -> DataType { ) } +/// Returns the promotion type of applying mathematics operations on decimal types. 
+/// Two sides of the mathematics operation will be promoted to the same type. fn coercion_decimal_mathematics_type( mathematics_op: &Operator, left_decimal_type: &DataType, right_decimal_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + match (left_decimal_type, right_decimal_type) { + // The promotion rule from spark + // https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35 + (Decimal128(_, _), Decimal128(_, _)) => match mathematics_op { + Operator::Plus | Operator::Minus => decimal_op_mathematics_type( + mathematics_op, + left_decimal_type, + right_decimal_type, + ), + Operator::Multiply | Operator::Divide | Operator::Modulo => { + get_wider_decimal_type(left_decimal_type, right_decimal_type) + } + _ => None, + }, + _ => None, + } +} + +/// Returns the output type of applying mathematics operations on decimal types. +/// The rule is from spark. Note that this is different to the promoted type applied +/// to two sides of the arithmetic operation. +fn decimal_op_mathematics_type( + mathematics_op: &Operator, + left_decimal_type: &DataType, + right_decimal_type: &DataType, ) -> Option { use arrow::datatypes::DataType::*; match (left_decimal_type, right_decimal_type) { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c8bc8518078de..c3cd87cc1aa6a 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -295,6 +295,7 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Between { .. } | Expr::Case { .. } | Expr::Cast { .. } + | Expr::PromotePrecision { .. } | Expr::TryCast { .. } | Expr::Sort { .. } | Expr::ScalarFunction { .. } @@ -792,7 +793,7 @@ pub fn from_plan( let new_on:Vec<(Expr,Expr)> = expr.iter().take(equi_expr_count).map(|equi_expr| { // SimplifyExpression rule may add alias to the equi_expr. 
let unalias_expr = equi_expr.clone().unalias(); - if let Expr::BinaryExpr(BinaryExpr { left, op:Operator::Eq, right }) = unalias_expr { + if let Expr::BinaryExpr(BinaryExpr { left, op:Operator::Eq, right , .. }) = unalias_expr { Ok((*left, *right)) } else { Err(DataFusionError::Internal(format!( diff --git a/datafusion/jit/src/ast.rs b/datafusion/jit/src/ast.rs index 36741432ec257..e3f9a0f384806 100644 --- a/datafusion/jit/src/ast.rs +++ b/datafusion/jit/src/ast.rs @@ -156,6 +156,7 @@ impl TryFrom<(datafusion_expr::Expr, DFSchemaRef)> for Expr { left, op, right, + .. }) => { let op = match op { datafusion_expr::Operator::Eq => BinaryExpr::Eq, diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 533566a0bf695..d97de49198c01 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -260,7 +260,10 @@ fn intersect( /// Extract join keys from a WHERE clause fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> { - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { + if let Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) = expr + { match op { Operator::Eq => { // Ensure that we don't add the same Join keys multiple times @@ -298,7 +301,9 @@ fn remove_join_expressions( join_keys: &HashSet<(Expr, Expr)>, ) -> Result> { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. 
+ }) => match op { Operator::Eq => { if join_keys.contains(&(*left.clone(), *right.clone())) || join_keys.contains(&(*right.clone(), *left.clone())) diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 8dfdfae035a12..7fde5c767bef6 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -178,7 +178,9 @@ fn extract_non_nullable_columns( non_nullable_cols.push(col.clone()); Ok(()) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) => match op { // If one of the inputs are null for these operators, the results should be false. Operator::Eq | Operator::NotEq diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 2f7a20d6e230d..6a1b1be998ee2 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -114,6 +114,7 @@ fn split_eq_and_noneq_join_predicate( left, op: Operator::Eq, right, + .. }) => { let left = left.as_ref(); let right = right.as_ref(); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 55c77e51e2d3d..a00a5e253a77e 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -209,6 +209,7 @@ fn extract_or_clauses_for_join( left, op: Operator::Or, right, + .. 
}) = expr { let left_expr = extract_or_clause(left.as_ref(), &schema_columns); @@ -244,6 +245,7 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { let l_expr = extract_or_clause(l_expr, schema_columns); let r_expr = extract_or_clause(r_expr, schema_columns); @@ -256,6 +258,7 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { let l_expr = extract_or_clause(l_expr, schema_columns); let r_expr = extract_or_clause(r_expr, schema_columns); diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 57513fa4fff41..4ab841fde15eb 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -161,7 +161,9 @@ enum Predicate { fn predicate(expr: &Expr) -> Result { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) => match op { Operator::And => { let args = vec![predicate(left)?, predicate(right)?]; Ok(Predicate::And { args }) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index df0b9245faec8..79b23ef7d4268 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -59,7 +59,9 @@ impl ScalarSubqueryToJoin { let mut others = vec![]; for it in filters.iter() { match it { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. 
+ }) => { let l_query = Subquery::try_from_expr(left); let r_query = Subquery::try_from_expr(right); if l_query.is_err() && r_query.is_err() { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index bfac8da643dba..051716b78668f 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -284,6 +284,7 @@ impl<'a> ConstEvaluator<'a> { | Expr::SimilarTo { .. } | Expr::Case(_) | Expr::Cast { .. } + | Expr::PromotePrecision { .. } | Expr::TryCast { .. } | Expr::InList { .. } | Expr::GetIndexedField { .. } => true, @@ -360,6 +361,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Eq, right, + .. }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { match as_bool_lit(*left)? { Some(true) => *right, @@ -374,6 +376,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Eq, right, + .. }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { match as_bool_lit(*right)? { Some(true) => *left, @@ -450,6 +453,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: NotEq, right, + .. }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { match as_bool_lit(*left)? { Some(true) => Expr::Not(right), @@ -464,6 +468,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: NotEq, right, + .. }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { match as_bool_lit(*right)? { Some(true) => Expr::Not(left), @@ -481,30 +486,35 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Or, right: _, + .. }) if is_true(&left) => *left, // false OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, + .. }) if is_false(&left) => *right, // A OR true --> true (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: Or, right, + .. 
}) if is_true(&right) => *right, // A OR false --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, + .. }) if is_false(&right) => *left, // A OR !A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, + .. }) if is_not_of(&right, &left) && !info.nullable(&left)? => { Expr::Literal(ScalarValue::Boolean(Some(true))) } @@ -513,6 +523,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Or, right, + .. }) if is_not_of(&left, &right) && !info.nullable(&right)? => { Expr::Literal(ScalarValue::Boolean(Some(true))) } @@ -521,24 +532,28 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Or, right, + .. }) if expr_contains(&left, &right, Or) => *left, // A OR (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, + .. }) if expr_contains(&right, &left, Or) => *right, // A OR (A AND B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, + .. }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => *left, // (A AND B) OR A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, + .. }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => *right, // @@ -550,30 +565,35 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: And, right, + .. }) if is_true(&left) => *right, // false AND A --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: And, right: _, + .. }) if is_false(&left) => *left, // A AND true --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, + .. }) if is_true(&right) => *left, // A AND false --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: And, right, + .. }) if is_false(&right) => *right, // A AND !A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: And, right, + .. }) if is_not_of(&right, &left) && !info.nullable(&left)? 
=> { Expr::Literal(ScalarValue::Boolean(Some(false))) } @@ -582,6 +602,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: And, right, + .. }) if is_not_of(&left, &right) && !info.nullable(&right)? => { Expr::Literal(ScalarValue::Boolean(Some(false))) } @@ -590,24 +611,28 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: And, right, + .. }) if expr_contains(&left, &right, And) => *left, // A AND (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, + .. }) if expr_contains(&right, &left, And) => *right, // A AND (A OR B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, + .. }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => *left, // (A OR B) AND A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, + .. }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => *right, // @@ -619,24 +644,28 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Multiply, right, + .. }) if is_one(&right) => *left, // 1 * A --> A Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, + .. }) if is_one(&left) => *right, // A * null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Multiply, right, + .. }) if is_null(&right) => *right, // null * A --> null Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right: _, + .. }) if is_null(&left) => *left, // A * 0 --> 0 (if A is not null) @@ -644,12 +673,14 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Multiply, right, + .. }) if !info.nullable(&left)? && is_zero(&right) => *right, // 0 * A --> 0 (if A is not null) Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, + .. }) if !info.nullable(&right)? && is_zero(&left) => *left, // @@ -661,24 +692,28 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Divide, right, + .. 
}) if is_one(&right) => *left, // null / A --> null Expr::BinaryExpr(BinaryExpr { left, op: Divide, right: _, + .. }) if is_null(&left) => *left, // A / null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Divide, right, + .. }) if is_null(&right) => *right, // 0 / 0 -> null Expr::BinaryExpr(BinaryExpr { left, op: Divide, right, + .. }) if is_zero(&left) && is_zero(&right) => { Expr::Literal(ScalarValue::Int32(None)) } @@ -687,6 +722,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: Divide, right, + .. }) if !info.nullable(&left)? && is_zero(&right) => { return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); } @@ -700,24 +736,28 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left: _, op: Modulo, right, + .. }) if is_null(&right) => *right, // null % A --> null Expr::BinaryExpr(BinaryExpr { left, op: Modulo, right: _, + .. }) if is_null(&left) => *left, // A % 1 --> 0 Expr::BinaryExpr(BinaryExpr { left, op: Modulo, right, + .. }) if !info.nullable(&left)? && is_one(&right) => lit(0), // A % 0 --> DivideByZero Error Expr::BinaryExpr(BinaryExpr { left, op: Modulo, right, + .. }) if !info.nullable(&left)? && is_zero(&right) => { return Err(DataFusionError::ArrowError(ArrowError::DivideByZero)); } @@ -731,6 +771,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left: _, op: BitwiseAnd, right, + .. }) if is_null(&right) => *right, // null & A -> null @@ -738,6 +779,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right: _, + .. }) if is_null(&left) => *left, // A & 0 -> 0 (if A not nullable) @@ -745,6 +787,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, + .. }) if !info.nullable(&left)? && is_zero(&right) => *right, // 0 & A -> 0 (if A not nullable) @@ -752,6 +795,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, + .. }) if !info.nullable(&right)? 
&& is_zero(&left) => *left, // !A & A -> 0 (if A not nullable) @@ -759,6 +803,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, + .. }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) } @@ -768,6 +813,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, + .. }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) } @@ -777,6 +823,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, + .. }) if expr_contains(&left, &right, BitwiseAnd) => *left, // A & (..A..) --> (..A..) @@ -784,6 +831,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, + .. }) if expr_contains(&right, &left, BitwiseAnd) => *right, // A & (A | B) --> A (if B not null) @@ -791,6 +839,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, + .. }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { *left } @@ -800,6 +849,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, + .. }) if !info.nullable(&left)? && is_op_with(BitwiseOr, &left, &right) => { *right } @@ -813,6 +863,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left: _, op: BitwiseOr, right, + .. }) if is_null(&right) => *right, // null | A -> null @@ -820,6 +871,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right: _, + .. }) if is_null(&left) => *left, // A | 0 -> A (even if A is null) @@ -827,6 +879,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, + .. 
}) if is_zero(&right) => *left, // 0 | A -> A (even if A is null) @@ -834,6 +887,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, + .. }) if is_zero(&left) => *right, // !A | A -> -1 (if A not nullable) @@ -841,6 +895,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, + .. }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) } @@ -850,6 +905,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, + .. }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) } @@ -859,6 +915,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, + .. }) if expr_contains(&left, &right, BitwiseOr) => *left, // A | (..A..) --> (..A..) @@ -866,6 +923,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, + .. }) if expr_contains(&right, &left, BitwiseOr) => *right, // A | (A & B) --> A (if B not null) @@ -873,6 +931,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, + .. }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { *left } @@ -882,6 +941,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, + .. }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { *right } @@ -895,6 +955,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left: _, op: BitwiseXor, right, + .. }) if is_null(&right) => *right, // null ^ A -> null @@ -902,6 +963,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right: _, + .. 
}) if is_null(&left) => *left, // A ^ 0 -> A (if A not nullable) @@ -909,6 +971,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right, + .. }) if !info.nullable(&left)? && is_zero(&right) => *left, // 0 ^ A -> A (if A not nullable) @@ -916,6 +979,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right, + .. }) if !info.nullable(&right)? && is_zero(&left) => *right, // !A ^ A -> -1 (if A not nullable) @@ -923,6 +987,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right, + .. }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) } @@ -932,6 +997,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right, + .. }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) } @@ -941,6 +1007,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right, + .. }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); if expr == *right { @@ -955,6 +1022,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseXor, right, + .. }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); if expr == *left { @@ -973,6 +1041,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftRight, right, + .. }) if is_null(&right) => *right, // null >> A -> null @@ -980,6 +1049,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseShiftRight, right: _, + .. 
}) if is_null(&left) => *left, // A >> 0 -> A (even if A is null) @@ -987,6 +1057,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseShiftRight, right, + .. }) if is_zero(&right) => *left, // @@ -998,6 +1069,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftLeft, right, + .. }) if is_null(&right) => *right, // null << A -> null @@ -1005,6 +1077,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseShiftLeft, right: _, + .. }) if is_null(&left) => *left, // A << 0 -> A (even if A is null) @@ -1012,6 +1085,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: BitwiseShiftLeft, right, + .. }) if is_zero(&right) => *left, // @@ -1113,6 +1187,7 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> { left, op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), right, + .. }) => simplify_regex_expr(left, op, right)?, // no additional rewrites possible @@ -2380,6 +2455,7 @@ mod tests { left: Box::new(left), op: Operator::RegexMatch, right: Box::new(right), + data_type: None, }) } @@ -2388,6 +2464,7 @@ mod tests { left: Box::new(left), op: Operator::RegexNotMatch, right: Box::new(right), + data_type: None, }) } @@ -2396,6 +2473,7 @@ mod tests { left: Box::new(left), op: Operator::RegexIMatch, right: Box::new(right), + data_type: None, }) } @@ -2404,6 +2482,7 @@ mod tests { left: Box::new(left), op: Operator::RegexNotIMatch, right: Box::new(right), + data_type: None, }) } diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 13d170fd886f8..ae36c2bb9ab6d 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -55,7 +55,7 @@ pub fn simplify_regex_expr( } // leave untouched if optimization didn't work - Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })) + 
Ok(Expr::BinaryExpr(BinaryExpr::new(left, op, right))) } struct OperatorMode { diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 8b3f437dc233e..1bf975a18f665 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -69,7 +69,9 @@ pub static POWS_OF_TEN: [i128; 38] = [ /// expressions. Such as: (A AND B) AND C pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) if *op == search_op => { expr_contains(left, needle, search_op) || expr_contains(right, needle, search_op) } @@ -87,9 +89,9 @@ pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> xor_counter: &mut i32, ) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if *op == Operator::BitwiseXor => - { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) if *op == Operator::BitwiseXor => { let left_expr = recursive_delete_xor_in_expr(left, needle, xor_counter); let right_expr = recursive_delete_xor_in_expr(right, needle, xor_counter); if left_expr == *needle { @@ -206,7 +208,7 @@ pub fn is_false(expr: &Expr) -> bool { /// returns true if `haystack` looks like (needle OP X) or (X OP needle) pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { - matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref())) + matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right, .. 
}) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref())) } /// returns true if `not_expr` is !`expr` (not) @@ -246,7 +248,9 @@ pub fn as_bool_lit(expr: Expr) -> Result> { /// For others, use Not clause pub fn negate_clause(expr: Expr) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) => { if let Some(negated_op) = op.negate() { return Expr::BinaryExpr(BinaryExpr::new(left, negated_op, right)); } @@ -321,7 +325,9 @@ pub fn negate_clause(expr: Expr) -> Expr { /// For others, use Negative clause pub fn distribute_negation(expr: Expr) -> Expr { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) => { match op { // ~(A & B) ===> ~A | ~B Operator::BitwiseAnd => { diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 437d9cd47d0af..36cfb1f532d07 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -28,7 +28,7 @@ use datafusion_expr::expr::{self, Between, BinaryExpr, Case, Like, WindowFunctio use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ - coerce_types, comparison_coercion, like_coercion, + binary_operator_data_type, coerce_types, comparison_coercion, like_coercion, }; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::type_coercion::other::{ @@ -235,6 +235,7 @@ impl ExprRewriter for TypeCoercionRewriter { ref left, op, ref right, + .. 
}) => { let left_type = left.get_type(&self.schema)?; let right_type = right.get_type(&self.schema)?; @@ -246,6 +247,33 @@ impl ExprRewriter for TypeCoercionRewriter { // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419 Ok(expr.clone()) } + (DataType::Decimal128(_, _), _) | (_, DataType::Decimal128(_, _)) => { + if !matches!(left.as_ref(), &Expr::PromotePrecision(_)) + && !matches!(right.as_ref(), &Expr::PromotePrecision(_)) + { + // Promote both decimal operands only if neither side is already promoted + let coerced_type = + coerce_types(&left_type, &op, &right_type)?; + let result_type = + binary_operator_data_type(&left_type, &op, &right_type)?; + let expr = Expr::BinaryExpr(BinaryExpr::new_with_data_type( + Box::new( + left.clone() + .promote_to(&coerced_type, &self.schema)?, + ), + op, + Box::new( + right + .clone() + .promote_to(&coerced_type, &self.schema)?, + ), + Some(result_type), + )); + Ok(expr) + } else { + Ok(expr.clone()) + } + } _ => { let coerced_type = coerce_types(&left_type, &op, &right_type)?; let expr = Expr::BinaryExpr(BinaryExpr::new( diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 4c2a24f055152..b9d65b07db0eb 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -30,7 +30,7 @@ use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion}; use datafusion_expr::utils::from_plan; use datafusion_expr::{ - binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, + binary_expr_with_data_type, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; use std::cmp::Ordering; use std::sync::Arc; @@ -132,7 +132,12 @@ impl ExprRewriter for UnwrapCastExprRewriter { // For case: // try_cast/cast(expr as data_type) op literal // literal op try_cast/cast(expr as data_type) - Expr::BinaryExpr(BinaryExpr {
left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + data_type, + }) => { let left = left.as_ref().clone(); let right = right.as_ref().clone(); let left_type = left.get_type(&self.schema)?; @@ -155,10 +160,11 @@ impl ExprRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( + return Ok(binary_expr_with_data_type( lit(value), *op, expr.as_ref().clone(), + data_type.clone(), )); } } @@ -174,10 +180,11 @@ impl ExprRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( + return Ok(binary_expr_with_data_type( expr.as_ref().clone(), *op, lit(value), + data_type.clone(), )); } } @@ -187,7 +194,12 @@ impl ExprRewriter for UnwrapCastExprRewriter { }; } // return the new binary op - Ok(binary_expr(left, *op, right)) + Ok(binary_expr_with_data_type( + left, + *op, + right, + data_type.clone(), + )) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index ae38a216088df..77d935f16bbe1 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -75,6 +75,7 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& right, op: Operator::And, left, + .. }) => { let exprs = split_conjunction_impl(left, exprs); split_conjunction_impl(right, exprs) @@ -144,7 +145,9 @@ fn split_binary_owned_impl( mut exprs: Vec, ) -> Vec { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => { + Expr::BinaryExpr(BinaryExpr { + right, op, left, .. 
+ }) if op == operator => { let exprs = split_binary_owned_impl(*left, operator, exprs); split_binary_owned_impl(*right, operator, exprs) } @@ -169,7 +172,9 @@ fn split_binary_impl<'a>( mut exprs: Vec<&'a Expr>, ) -> Vec<&'a Expr> { match expr { - Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => { + Expr::BinaryExpr(BinaryExpr { + right, op, left, .. + }) if *op == operator => { let exprs = split_binary_impl(left, operator, exprs); split_binary_impl(right, operator, exprs) } @@ -242,6 +247,7 @@ pub fn verify_not_disjunction(predicates: &[&Expr]) -> Result<()> { left: _, op: Operator::Or, right: _, + .. }) => { plan_err!("Optimizing disjunctions not supported!") } @@ -299,9 +305,9 @@ pub fn find_join_exprs( } else { // TODO remove the logic let (left, op, right) = match filter { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - (*left.clone(), *op, *right.clone()) - } + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) => (*left.clone(), *op, *right.clone()), _ => { others.push((*filter).clone()); continue; @@ -370,9 +376,9 @@ pub fn exprs_to_join_cols( let mut others: Vec = vec![]; for filter in exprs.iter() { let (left, op, right) = match filter { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - (*left.clone(), *op, *right.clone()) - } + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. 
+ }) => (*left.clone(), *op, *right.clone()), _ => plan_err!("Invalid correlation expression!")?, }; match op { diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 2a19f62104e0e..97e3bd809fe3e 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -89,6 +89,7 @@ pub struct BinaryExpr { left: Arc, op: Operator, right: Arc, + data_type: Option, } impl BinaryExpr { @@ -98,7 +99,27 @@ impl BinaryExpr { op: Operator, right: Arc, ) -> Self { - Self { left, op, right } + Self { + left, + op, + right, + data_type: None, + } + } + + /// Create new binary expression that stores the supplied optional `data_type` + pub fn new_with_data_type( + left: Arc, + op: Operator, + right: Arc, + data_type: Option, + ) -> Self { + Self { + left, + op, + right, + data_type, + } } /// Get the left side of the binary expression @@ -366,12 +387,14 @@ macro_rules! compute_primitive_op_dyn_scalar { /// LEFT is Decimal or Dictionary array of decimal values, RIGHT is scalar value /// OP_TYPE is the return type of scalar function macro_rules! compute_primitive_decimal_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $OP_TYPE:expr, $RET_TYPE:expr) => {{ // generate the scalar function name, such as add_decimal_dyn_scalar, // from the $OP parameter (which could have a value of add) and the // suffix _decimal_dyn_scalar if let Some(value) = $RIGHT { - Ok(paste::expr! {[<$OP _decimal_dyn_scalar>]}($LEFT, value)?) + Ok(paste::expr! {[<$OP _decimal_dyn_scalar>]}( + $LEFT, value, $RET_TYPE, + )?) } else { // when the $RIGHT is a NULL, generate a NULL array of $OP_TYPE Ok(Arc::new(new_null_array($OP_TYPE, $LEFT.len()))) @@ -419,15 +442,15 @@ macro_rules! binary_string_array_op { /// The binary_primitive_array_op macro only evaluates for primitive types /// like integers and floats. macro_rules!
binary_primitive_array_op_dyn { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $RET_TYPE:expr) => {{ match $LEFT.data_type() { DataType::Decimal128(_, _) => { - Ok(paste::expr! {[<$OP _decimal>]}(&$LEFT, &$RIGHT)?) + Ok(paste::expr! {[<$OP _decimal>]}(&$LEFT, &$RIGHT, $RET_TYPE)?) } DataType::Dictionary(_, value_type) if matches!(value_type.as_ref(), &DataType::Decimal128(_, _)) => { - Ok(paste::expr! {[<$OP _decimal>]}(&$LEFT, &$RIGHT)?) + Ok(paste::expr! {[<$OP _decimal>]}(&$LEFT, &$RIGHT, $RET_TYPE)?) } _ => Ok(Arc::new( $OP(&$LEFT, &$RIGHT).map_err(|err| DataFusionError::ArrowError(err))?, @@ -440,13 +463,13 @@ macro_rules! binary_primitive_array_op_dyn { /// The binary_primitive_array_op_dyn_scalar macro only evaluates for primitive /// types like integers and floats. macro_rules! binary_primitive_array_op_dyn_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + ($LEFT:expr, $RIGHT:expr, $OP:ident, $RET_TYPE:expr) => {{ // unwrap underlying (non dictionary) value let right = unwrap_dict_value($RIGHT); let op_type = $LEFT.data_type(); let result: Result> = match right { - ScalarValue::Decimal128(v, _, _) => compute_primitive_decimal_op_dyn_scalar!($LEFT, v, $OP, op_type), + ScalarValue::Decimal128(v, _, _) => compute_primitive_decimal_op_dyn_scalar!($LEFT, v, $OP, op_type, $RET_TYPE), ScalarValue::Int8(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Int8Type), ScalarValue::Int16(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Int16Type), ScalarValue::Int32(v) => compute_primitive_op_dyn_scalar!($LEFT, v, $OP, op_type, Int32Type), @@ -626,11 +649,15 @@ impl PhysicalExpr for BinaryExpr { } fn data_type(&self, input_schema: &Schema) -> Result { - binary_operator_data_type( - &self.left.data_type(input_schema)?, - &self.op, - &self.right.data_type(input_schema)?, - ) + if let Some(data_type) = &self.data_type { + Ok(data_type.clone()) + } else { + binary_operator_data_type( +
&self.left.data_type(input_schema)?, + &self.op, + &self.right.data_type(input_schema)?, + ) + } } fn nullable(&self, input_schema: &Schema) -> Result { @@ -1012,6 +1039,7 @@ impl BinaryExpr { scalar: ScalarValue, ) -> Result>> { let bool_type = &DataType::Boolean; + let result_type = &self.data_type; let scalar_result = match &self.op { Operator::Lt => { binary_array_op_dyn_scalar!(array, scalar, lt, bool_type) @@ -1032,19 +1060,29 @@ impl BinaryExpr { binary_array_op_dyn_scalar!(array, scalar, neq, bool_type) } Operator::Plus => { - binary_primitive_array_op_dyn_scalar!(array, scalar, add) + binary_primitive_array_op_dyn_scalar!(array, scalar, add, result_type) } Operator::Minus => { - binary_primitive_array_op_dyn_scalar!(array, scalar, subtract) + binary_primitive_array_op_dyn_scalar!( + array, + scalar, + subtract, + result_type + ) } Operator::Multiply => { - binary_primitive_array_op_dyn_scalar!(array, scalar, multiply) + binary_primitive_array_op_dyn_scalar!( + array, + scalar, + multiply, + result_type + ) } Operator::Divide => { - binary_primitive_array_op_dyn_scalar!(array, scalar, divide) + binary_primitive_array_op_dyn_scalar!(array, scalar, divide, result_type) } Operator::Modulo => { - binary_primitive_array_op_dyn_scalar!(array, scalar, modulus) + binary_primitive_array_op_dyn_scalar!(array, scalar, modulus, result_type) } Operator::RegexMatch => binary_string_array_flag_op_scalar!( array, @@ -1126,6 +1164,7 @@ impl BinaryExpr { right: Arc, right_data_type: &DataType, ) -> Result { + let result_type = &self.data_type; match &self.op { Operator::Lt => lt_dyn(&left, &right), Operator::LtEq => lt_eq_dyn(&left, &right), @@ -1146,16 +1185,20 @@ impl BinaryExpr { Operator::IsNotDistinctFrom => { binary_array_op!(left, right, is_not_distinct_from) } - Operator::Plus => binary_primitive_array_op_dyn!(left, right, add_dyn), - Operator::Minus => binary_primitive_array_op_dyn!(left, right, subtract_dyn), + Operator::Plus => { + 
binary_primitive_array_op_dyn!(left, right, add_dyn, result_type) + } + Operator::Minus => { + binary_primitive_array_op_dyn!(left, right, subtract_dyn, result_type) + } Operator::Multiply => { - binary_primitive_array_op_dyn!(left, right, multiply_dyn) + binary_primitive_array_op_dyn!(left, right, multiply_dyn, result_type) } Operator::Divide => { - binary_primitive_array_op_dyn!(left, right, divide_dyn_opt) + binary_primitive_array_op_dyn!(left, right, divide_dyn_opt, result_type) } Operator::Modulo => { - binary_primitive_array_op_dyn!(left, right, modulus_dyn) + binary_primitive_array_op_dyn!(left, right, modulus_dyn, result_type) } Operator::And => { if left_data_type == &DataType::Boolean { @@ -1229,6 +1272,28 @@ pub fn binary( Ok(Arc::new(BinaryExpr::new(lhs, op, rhs))) } +/// Create a binary expression that carries the supplied optional result `data_type`. +/// The arguments must already have identical data types: this function +/// returns an internal error when the left and right types differ. +pub fn binary_with_data_type( + lhs: Arc, + op: Operator, + rhs: Arc, + input_schema: &Schema, + data_type: Option, +) -> Result> { + let lhs_type = &lhs.data_type(input_schema)?; + let rhs_type = &rhs.data_type(input_schema)?; + if !lhs_type.eq(rhs_type) { + return Err(DataFusionError::Internal(format!( + "The type of {lhs_type} {op:?} {rhs_type} of binary physical should be same" + ))); + } + Ok(Arc::new(BinaryExpr::new_with_data_type( + lhs, op, rhs, data_type, + ))) +} + #[cfg(test)] mod tests { use super::*; @@ -1247,8 +1312,9 @@ mod tests { op: Operator, r: Arc, input_schema: &Schema, + x: &DataType, ) -> Arc { - binary(l, op, r, input_schema).unwrap() + binary_with_data_type(l, op, r, input_schema, Some(x.clone())).unwrap() } #[test] @@ -1266,6 +1332,7 @@ mod tests { Operator::Lt, col("b", &schema)?, &schema, + &DataType::Boolean, ); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; @@ -1299,6 +1366,7 @@ mod tests { Operator::Lt, col("b", &schema)?,
&schema, + &DataType::Boolean, ), Operator::Or, binary_simple( @@ -1306,8 +1374,10 @@ mod tests { Operator::Eq, col("b", &schema)?, &schema, + &DataType::Boolean, ), &schema, + &DataType::Boolean, ); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?; @@ -2763,8 +2833,13 @@ mod tests { op: Operator, expected: PrimitiveArray, ) -> Result<()> { - let arithmetic_op = - binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema); + let arithmetic_op = binary_simple( + col("a", &schema)?, + op, + col("b", &schema)?, + &schema, + expected.data_type(), + ); let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -2780,7 +2855,8 @@ mod tests { expected: ArrayRef, ) -> Result<()> { let lit = Arc::new(Literal::new(literal)); - let arithmetic_op = binary_simple(col("a", &schema)?, op, lit, &schema); + let arithmetic_op = + binary_simple(col("a", &schema)?, op, lit, &schema, expected.data_type()); let batch = RecordBatch::try_new(schema, data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -2801,7 +2877,8 @@ mod tests { let left_expr = try_cast(col("a", schema)?, schema, result_type.clone())?; let right_expr = try_cast(col("b", schema)?, schema, result_type)?; - let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); + let arithmetic_op = + binary_simple(left_expr, op, right_expr, schema, &DataType::Boolean); let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -2831,7 +2908,8 @@ mod tests { try_cast(col("a", schema)?, schema, op_type)? 
}; - let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); + let arithmetic_op = + binary_simple(left_expr, op, right_expr, schema, &DataType::Boolean); let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); assert_eq!(result.as_ref(), expected); @@ -2860,7 +2938,8 @@ mod tests { try_cast(col("a", schema)?, schema, op_type)? }; - let arithmetic_op = binary_simple(left_expr, op, right_expr, schema); + let arithmetic_op = + binary_simple(left_expr, op, right_expr, schema, &DataType::Boolean); let batch = RecordBatch::try_new(Arc::clone(schema), vec![Arc::clone(arr)])?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); assert_eq!(result.as_ref(), expected); @@ -3428,7 +3507,7 @@ mod tests { let tree_depth: i32 = 100; let expr = (0..tree_depth) .map(|_| col("a", schema.as_ref()).unwrap()) - .reduce(|l, r| binary_simple(l, Operator::Plus, r, &schema)) + .reduce(|l, r| binary_simple(l, Operator::Plus, r, &schema, &DataType::Int32)) .unwrap(); let result = expr @@ -3935,7 +4014,13 @@ mod tests { schema.field(1).is_nullable(), ), ]); - let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema); + let arithmetic_op = binary_simple( + left_expr, + op, + right_expr, + &coerced_schema, + expected.data_type(), + ); let data: Vec = vec![left.clone(), right.clone()]; let batch = RecordBatch::try_new(schema.clone(), data)?; let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); @@ -4597,6 +4682,7 @@ mod tests { Operator::GtEq, lit(ScalarValue::from(25)), &schema, + &DataType::Boolean, ); let context = AnalysisContext::from_statistics(&schema, &statistics); @@ -4626,6 +4712,7 @@ mod tests { Operator::GtEq, a.clone(), &schema, + &DataType::Boolean, ); let context = AnalysisContext::from_statistics(&schema, &statistics); diff --git a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs 
b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs index 772aa0b397b89..2d305b7da97fd 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs @@ -25,10 +25,11 @@ use arrow::compute::{ }; use arrow::datatypes::Decimal128Type; use arrow::{array::*, datatypes::ArrowNumericType, downcast_dictionary_array}; -use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE}; +use arrow_schema::DataType; use datafusion_common::cast::as_decimal128_array; use datafusion_common::{DataFusionError, Result}; -use std::cmp::min; +use datafusion_expr::type_coercion::binary::binary_operator_data_type; +use datafusion_expr::Operator; use std::sync::Arc; // Simple (low performance) kernels until optimized kernels are added to arrow @@ -258,14 +259,22 @@ pub(crate) fn is_not_distinct_from_decimal( .collect()) } -pub(crate) fn add_dyn_decimal(left: &dyn Array, right: &dyn Array) -> Result { - let (precision, scale) = get_precision_scale(left)?; +pub(crate) fn add_dyn_decimal( + left: &dyn Array, + right: &dyn Array, + result_type: &Option, +) -> Result { + let (precision, scale) = get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let array = add_dyn(left, right)?; decimal_array_with_precision_scale(array, precision, scale) } -pub(crate) fn add_decimal_dyn_scalar(left: &dyn Array, right: i128) -> Result { - let (precision, scale) = get_precision_scale(left)?; +pub(crate) fn add_decimal_dyn_scalar( + left: &dyn Array, + right: i128, + result_type: &Option, +) -> Result { + let (precision, scale) = get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let array = add_scalar_dyn::(left, right)?; decimal_array_with_precision_scale(array, precision, scale) @@ -274,25 +283,28 @@ pub(crate) fn add_decimal_dyn_scalar(left: &dyn Array, right: i128) -> Result, ) -> Result { - let (precision, scale) = get_precision_scale(left)?; + let (precision, scale) =
get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let array = subtract_scalar_dyn::(left, right)?; decimal_array_with_precision_scale(array, precision, scale) } -fn get_precision_scale(left: &dyn Array) -> Result<(u8, i8)> { - match left.data_type() { +fn get_precision_scale(data_type: &DataType) -> Result<(u8, i8)> { + match data_type { DataType::Decimal128(precision, scale) => Ok((*precision, *scale)), DataType::Dictionary(_, value_type) => match value_type.as_ref() { DataType::Decimal128(precision, scale) => Ok((*precision, *scale)), - _ => Err(DataFusionError::Internal( - "Unexpected data type".to_string(), - )), + _ => Err(DataFusionError::Internal(format!( + "Unexpected data type: {}", + data_type + ))), }, - _ => Err(DataFusionError::Internal( - "Unexpected data type".to_string(), - )), + _ => Err(DataFusionError::Internal(format!( + "Unexpected data type: {}", + data_type + ))), } } @@ -334,23 +346,21 @@ fn decimal_array_with_precision_scale( pub(crate) fn multiply_decimal_dyn_scalar( left: &dyn Array, right: i128, + result_type: &Option, ) -> Result { - let (precision, scale) = get_precision_scale(left)?; - + let (precision, scale) = get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let array = multiply_scalar_dyn::(left, right)?; - let divide = 10_i128.pow(scale as u32); let array = divide_scalar_dyn::(&array, divide)?; - decimal_array_with_precision_scale(array, precision, scale) } pub(crate) fn divide_decimal_dyn_scalar( left: &dyn Array, right: i128, + result_type: &Option, ) -> Result { - let (precision, scale) = get_precision_scale(left)?; - + let (precision, scale) = get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let mul = 10_i128.pow(scale as u32); let array = multiply_scalar_dyn::(left, mul)?; @@ -361,8 +371,9 @@ pub(crate) fn divide_decimal_dyn_scalar( pub(crate) fn subtract_dyn_decimal( left: &dyn Array, right: &dyn Array, + result_type: &Option, ) -> Result { - let (precision, scale) = get_precision_scale(left)?; + let (precision,
scale) = get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let array = subtract_dyn(left, right)?; decimal_array_with_precision_scale(array, precision, scale) } @@ -370,36 +381,50 @@ pub(crate) fn subtract_dyn_decimal( pub(crate) fn multiply_dyn_decimal( left: &dyn Array, right: &dyn Array, + result_type: &Option, ) -> Result { - let (left_precision, left_scale) = get_precision_scale(left)?; - let (right_precision, right_scale) = get_precision_scale(right)?; - let product_precision = min( - left_precision + right_precision + 1, - DECIMAL128_MAX_PRECISION, - ); - let product_scale = min(left_scale + right_scale, DECIMAL128_MAX_SCALE); + let op_type = binary_operator_data_type( + left.data_type(), + &Operator::Multiply, + right.data_type(), + )?; + let (_, op_scale) = get_precision_scale(&op_type)?; + + let (precision, scale) = get_precision_scale(result_type.as_ref().unwrap_or(&op_type))?; + let array = multiply_dyn(left, right)?; - decimal_array_with_precision_scale(array, product_precision, product_scale) + + if op_scale > scale { + let div = 10_i128.pow((op_scale - scale) as u32); + let array = divide_scalar_dyn::(&array, div)?; + decimal_array_with_precision_scale(array, precision, scale) + } else { + decimal_array_with_precision_scale(array, precision, scale) + } } pub(crate) fn divide_dyn_opt_decimal( left: &dyn Array, right: &dyn Array, + result_type: &Option, ) -> Result { - let (precision, scale) = get_precision_scale(left)?; + let (precision, scale) = get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let mul = 10_i128.pow(scale as u32); + let array = multiply_scalar_dyn::(left, mul)?; let array = decimal_array_with_precision_scale(array, precision, scale)?; let array = divide_dyn_opt(&array, right)?; + decimal_array_with_precision_scale(array, precision, scale) } pub(crate) fn modulus_dyn_decimal( left: &dyn Array, right: &dyn Array, + result_type: &Option, ) -> Result { - let (precision, scale) = get_precision_scale(left)?; + let (precision, scale) =
get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let array = modulus_dyn(left, right)?; decimal_array_with_precision_scale(array, precision, scale) } @@ -407,9 +432,9 @@ pub(crate) fn modulus_dyn_decimal( pub(crate) fn modulus_decimal_dyn_scalar( left: &dyn Array, right: i128, + result_type: &Option, ) -> Result { - let (precision, scale) = get_precision_scale(left)?; - + let (precision, scale) = get_precision_scale(result_type.as_ref().unwrap_or_else(|| left.data_type()))?; let array = modulus_scalar_dyn::(left, right)?; decimal_array_with_precision_scale(array, precision, scale) } @@ -507,33 +532,66 @@ mod tests { 3, ); // add - let result = add_dyn_decimal(&left_decimal_array, &right_decimal_array)?; + let result_type = Some( + binary_operator_data_type( + left_decimal_array.data_type(), + &Operator::Plus, + right_decimal_array.data_type(), + ) + .unwrap(), + ); + let result = + add_dyn_decimal(&left_decimal_array, &right_decimal_array, &result_type)?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25, 3); assert_eq!(&expect, result); - let result = add_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = add_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(133), None, Some(132), Some(134)], 25, 3); assert_eq!(&expect, result); // subtract - let result = subtract_dyn_decimal(&left_decimal_array, &right_decimal_array)?; + let result_type = Some( + binary_operator_data_type( + left_decimal_array.data_type(), + &Operator::Minus, + right_decimal_array.data_type(), + ) + .unwrap(), + ); + let result = subtract_dyn_decimal( + &left_decimal_array, + &right_decimal_array, + &result_type, + )?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 25, 3); assert_eq!(&expect, result); - let result = subtract_decimal_dyn_scalar(&left_decimal_array, 10)?; + let
result = subtract_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(113), None, Some(112), Some(114)], 25, 3); assert_eq!(&expect, result); // multiply - let result = multiply_dyn_decimal(&left_decimal_array, &right_decimal_array)?; + let result_type = Some( + binary_operator_data_type( + left_decimal_array.data_type(), + &Operator::Multiply, + right_decimal_array.data_type(), + ) + .unwrap(), + ); + let result = multiply_dyn_decimal( + &left_decimal_array, + &right_decimal_array, + &result_type, + )?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(15129), None, Some(15006), Some(15252)], 38, 6); assert_eq!(&expect, result); - let result = multiply_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = multiply_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(1), None, Some(1), Some(1)], 25, 3); assert_eq!(&expect, result); @@ -554,7 +612,19 @@ mod tests { 25, 3, ); - let result = divide_dyn_opt_decimal(&left_decimal_array, &right_decimal_array)?; + let result_type = Some( + binary_operator_data_type( + left_decimal_array.data_type(), + &Operator::Divide, + right_decimal_array.data_type(), + ) + .unwrap(), + ); + let result = divide_dyn_opt_decimal( + &left_decimal_array, + &right_decimal_array, + &result_type, + )?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array( &[Some(123456700), None, Some(22446672), Some(-10037130), None], @@ -562,7 +632,7 @@ mod tests { 3, ); assert_eq!(&expect, result); - let result = divide_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = divide_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array( &[ @@ -576,12 +646,22 @@ mod tests { 3, ); assert_eq!(&expect, result); - let result = 
modulus_dyn_decimal(&left_decimal_array, &right_decimal_array)?; + // modulus + let result_type = Some( + binary_operator_data_type( + left_decimal_array.data_type(), + &Operator::Modulo, + right_decimal_array.data_type(), + ) + .unwrap(), + ); + let result = + modulus_dyn_decimal(&left_decimal_array, &right_decimal_array, &result_type)?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16), None], 25, 3); assert_eq!(&expect, result); - let result = modulus_decimal_dyn_scalar(&left_decimal_array, 10)?; + let result = modulus_decimal_dyn_scalar(&left_decimal_array, 10, &result_type)?; let result = as_decimal128_array(&result)?; let expect = create_decimal_array(&[Some(7), None, Some(7), Some(7), Some(7)], 25, 3); @@ -595,12 +675,31 @@ mod tests { let left_decimal_array = create_decimal_array(&[Some(101)], 10, 1); let right_decimal_array = create_decimal_array(&[Some(0)], 1, 1); - let err = divide_decimal_dyn_scalar(&left_decimal_array, 0).unwrap_err(); + let result_type = Some( + binary_operator_data_type( + left_decimal_array.data_type(), + &Operator::Divide, + right_decimal_array.data_type(), + ) + .unwrap(), + ); + let err = + divide_decimal_dyn_scalar(&left_decimal_array, 0, &result_type).unwrap_err(); assert_eq!("Arrow error: Divide by zero error", err.to_string()); + let result_type = Some( + binary_operator_data_type( + left_decimal_array.data_type(), + &Operator::Modulo, + right_decimal_array.data_type(), + ) + .unwrap(), + ); let err = - modulus_dyn_decimal(&left_decimal_array, &right_decimal_array).unwrap_err(); + modulus_dyn_decimal(&left_decimal_array, &right_decimal_array, &result_type) + .unwrap_err(); assert_eq!("Arrow error: Divide by zero error", err.to_string()); - let err = modulus_decimal_dyn_scalar(&left_decimal_array, 0).unwrap_err(); + let err = + modulus_decimal_dyn_scalar(&left_decimal_array, 0, &result_type).unwrap_err(); assert_eq!("Arrow error: Divide by zero error", 
err.to_string()); } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 63fb7b7d37ad5..3f94d4fff6275 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -33,6 +33,7 @@ mod negative; mod no_op; mod not; mod nullif; +mod promote_precision; mod try_cast; /// Module with some convenient methods used in expression building @@ -72,7 +73,7 @@ pub use crate::window::rank::{dense_rank, percent_rank, rank}; pub use crate::window::rank::{Rank, RankType}; pub use crate::window::row_number::RowNumber; -pub use binary::{binary, BinaryExpr}; +pub use binary::{binary, binary_with_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; pub use cast::{ cast, cast_column, cast_with_options, CastExpr, DEFAULT_DATAFUSION_CAST_OPTIONS, @@ -89,6 +90,7 @@ pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use nullif::nullif_func; +pub use promote_precision::promote_precision; pub use try_cast::{try_cast, TryCastExpr}; /// returns the name of the state diff --git a/datafusion/physical-expr/src/expressions/promote_precision.rs b/datafusion/physical-expr/src/expressions/promote_precision.rs new file mode 100644 index 0000000000000..0d9a7144e5f4c --- /dev/null +++ b/datafusion/physical-expr/src/expressions/promote_precision.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::physical_expr::down_cast_any_ref; +use crate::PhysicalExpr; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema}; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use std::any::Any; +use std::fmt; +use std::sync::Arc; + +/// PromotePrecision expression wraps an expression which was promoted to a specific data type +#[derive(Debug)] +pub struct PromotePrecisionExpr { + /// The expression to be promoted + expr: Arc, +} + +impl PromotePrecisionExpr { + /// Create a new PromotePrecisionExpr + pub fn new(expr: Arc) -> Self { + Self { expr } + } +} + +impl fmt::Display for PromotePrecisionExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "PROMOTE_PRECISION({})", self.expr) + } +} + +impl PhysicalExpr for PromotePrecisionExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + self.expr.data_type(_input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.expr.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + self.expr.evaluate(batch) + } + + fn children(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(PromotePrecisionExpr::new(children[0].clone()))) + } +} + +impl PartialEq for PromotePrecisionExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| 
self.expr.eq(&x.expr)) + .unwrap_or(false) + } +} + +/// Creates a unary expression PromotePrecisionExpr +pub fn promote_precision(arg: Arc) -> Result> { + Ok(Arc::new(PromotePrecisionExpr::new(arg))) +} diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 1fbd73b3ba01c..35656780ff3c0 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::expressions::binary_with_data_type; use crate::var_provider::is_system_variables; use crate::{ execution_props::ExecutionProps, @@ -27,7 +28,7 @@ use crate::{ }; use arrow::datatypes::{DataType, Schema}; use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::Cast; +use datafusion_expr::expr::{Cast, PromotePrecision}; use datafusion_expr::{ binary_expr, Between, BinaryExpr, Expr, GetIndexedField, Like, Operator, TryCast, }; @@ -169,7 +170,12 @@ pub fn create_physical_expr( execution_props, ) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + data_type, + }) => { let lhs = create_physical_expr( left, input_dfschema, @@ -215,7 +221,7 @@ pub fn create_physical_expr( // // There should be no coercion during physical // planning. - binary(lhs, *op, rhs, input_schema) + binary_with_data_type(lhs, *op, rhs, input_schema, data_type.clone()) } } } @@ -340,6 +346,14 @@ pub fn create_physical_expr( input_schema, data_type.clone(), ), + Expr::PromotePrecision(PromotePrecision { expr }) => { + expressions::promote_precision(create_physical_expr( + expr, + input_dfschema, + input_schema, + execution_props, + )?) 
+ } Expr::TryCast(TryCast { expr, data_type }) => expressions::try_cast( create_physical_expr(expr, input_dfschema, input_schema, execution_props)?, input_schema, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index e8570cf3c7e44..cb21d018b83f7 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -464,7 +464,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { expr_type: Some(ExprType::Literal(pb_value)), } } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { left, op, right, .. }) => { // Try to linerize a nested binary expression tree of the same operator // into a flat vector of expressions. let mut exprs = vec![right.as_ref()]; @@ -473,6 +473,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { left, op: current_op, right, + .. }) = current_expr { if current_op == op { @@ -945,8 +946,8 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } } - Expr::QualifiedWildcard { .. } => return Err(Error::General( - "Proto serialization error: Expr::QualifiedWildcard { .. } not supported" + Expr::QualifiedWildcard { .. } | Expr::PromotePrecision { .. } => return Err(Error::General( + "Proto serialization error: Expr::QualifiedWildcard { .. } | Expr::PromotePrecision { .. } not supported" .to_string(), )), }; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 441c29775c77d..848509d2ab824 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -523,7 +523,10 @@ fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Resu fn infer_placeholder_types(expr: Expr, schema: &DFSchema) -> Result { rewrite_expr(expr, |mut expr| { // Default to assuming the arguments are the same type - if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { + if let Expr::BinaryExpr(BinaryExpr { + left, op: _, right, .. 
+ }) = &mut expr + { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; rewrite_placeholder(right.as_mut(), left.as_ref(), schema)?; }; diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 91cef6d4712e7..61017d1f6d3ec 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -23,7 +23,7 @@ use sqlparser::ast::Ident; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::expr::{ AggregateFunction, Between, BinaryExpr, Case, GetIndexedField, GroupingSet, Like, - WindowFunction, + PromotePrecision, WindowFunction, }; use datafusion_expr::expr::{Cast, Sort}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; @@ -232,13 +232,17 @@ where .collect::>>()?, negated: *negated, }), - Expr::BinaryExpr(BinaryExpr { left, right, op }) => { - Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(clone_with_replacement(left, replacement_fn)?), - *op, - Box::new(clone_with_replacement(right, replacement_fn)?), - ))) - } + Expr::BinaryExpr(BinaryExpr { + left, + right, + op, + data_type, + }) => Ok(Expr::BinaryExpr(BinaryExpr::new_with_data_type( + Box::new(clone_with_replacement(left, replacement_fn)?), + *op, + Box::new(clone_with_replacement(right, replacement_fn)?), + data_type.clone(), + ))), Expr::Like(Like { negated, expr, @@ -344,6 +348,11 @@ where Box::new(clone_with_replacement(expr, replacement_fn)?), data_type.clone(), ))), + Expr::PromotePrecision(PromotePrecision { expr }) => { + Ok(Expr::PromotePrecision(PromotePrecision::new(Box::new( + clone_with_replacement(expr, replacement_fn)?, + )))) + } Expr::TryCast(TryCast { expr: nested_expr, data_type, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 767c4a39375a5..86965fa4aa003 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -310,22 +310,22 @@ pub async fn from_substrait_rel( let 
join_exprs: Vec<(Column, Column, bool)> = predicates .iter() .map(|p| match p { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match (left.as_ref(), right.as_ref()) { - (Expr::Column(l), Expr::Column(r)) => match op { - Operator::Eq => Ok((l.clone(), r.clone(), false)), - Operator::IsNotDistinctFrom => { - Ok((l.clone(), r.clone(), true)) - } - _ => Err(DataFusionError::Internal( - "invalid join condition op".to_string(), - )), - }, + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) => match (left.as_ref(), right.as_ref()) { + (Expr::Column(l), Expr::Column(r)) => match op { + Operator::Eq => Ok((l.clone(), r.clone(), false)), + Operator::IsNotDistinctFrom => { + Ok((l.clone(), r.clone(), true)) + } _ => Err(DataFusionError::Internal( - "invalid join condition expresssion".to_string(), + "invalid join condition op".to_string(), )), - } - } + }, + _ => Err(DataFusionError::Internal( + "invalid join condition expresssion".to_string(), + )), + }, _ => Err(DataFusionError::Internal( "Non-binary expression is not supported in join condition" .to_string(), @@ -674,6 +674,7 @@ pub async fn from_substrait_rex( .as_ref() .clone(), ), + data_type: None, }))) } (l, r) => Err(DataFusionError::NotImplemented(format!( diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index ecb322edb70e8..c7c1780a2f6f4 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -581,7 +581,9 @@ pub fn to_substrait_rex( let index = schema.index_of_column(col)?; substrait_field_ref(index) } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + Expr::BinaryExpr(BinaryExpr { + left, op, right, .. + }) => { let l = to_substrait_rex(left, schema, extension_info)?; let r = to_substrait_rex(right, schema, extension_info)?;