From f69148dfded14ec4a957512d4ebaab7a7a3a568f Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Thu, 12 Dec 2024 14:51:19 -0700
Subject: [PATCH] chore: Move remaining expressions to spark-expr crate + some
 minor refactoring (#1165)

* move CheckOverflow to spark-expr crate

* move NegativeExpr to spark-expr crate

* move UnboundColumn to spark-expr crate

* move ExpandExec from execution::datafusion::operators to execution::operators

* refactoring to remove datafusion subpackage

* update imports in benches

* fix

* fix
---
 Cargo.toml           |   1 +
 src/checkoverflow.rs | 173 ++++++++++++++++++++++++++++
 src/lib.rs           |  12 ++
 src/negative.rs      | 266 +++++++++++++++++++++++++++++++++++++++++++
 src/unbound.rs       | 110 ++++++++++++++++++
 5 files changed, 562 insertions(+)
 create mode 100644 src/checkoverflow.rs
 create mode 100644 src/negative.rs
 create mode 100644 src/unbound.rs

diff --git a/Cargo.toml b/Cargo.toml
index 65517431d2d9..d0bc2fd9dd53 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -29,6 +29,7 @@ edition = { workspace = true }
 [dependencies]
 arrow = { workspace = true }
 arrow-array = { workspace = true }
+arrow-buffer = { workspace = true }
 arrow-data = { workspace = true }
 arrow-schema = { workspace = true }
 chrono = { workspace = true }
diff --git a/src/checkoverflow.rs b/src/checkoverflow.rs
new file mode 100644
index 000000000000..e922171bd2b5
--- /dev/null
+++ b/src/checkoverflow.rs
@@ -0,0 +1,173 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{
    any::Any,
    fmt::{Display, Formatter},
    hash::{Hash, Hasher},
    sync::Arc,
};

use arrow::{
    array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
    datatypes::{Decimal128Type, DecimalType},
    record_batch::RecordBatch,
};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;

/// This comes from Spark's `CheckOverflow` expression. Spark's `CheckOverflow` rounds decimals
/// to the given scale and checks whether the decimals fit in the given precision. As the `cast`
/// kernel already rounds decimals, Comet's `CheckOverflow` only checks whether the decimals fit
/// in the precision.
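// A minimal sketch (not part of this commit) of the arrow-rs decimal kernels this
// expression builds on: `validate_decimal_precision` returns an error on overflow,
// while `null_if_overflow_precision` masks overflowing values with nulls.
//
//     use arrow::array::Decimal128Array;
//     // precision 3 cannot hold the value 1234
//     let a = Decimal128Array::from(vec![Some(123i128), Some(1234i128)])
//         .with_precision_and_scale(10, 0)?;
//     assert!(a.validate_decimal_precision(3).is_err());
//     assert!(a.null_if_overflow_precision(3).is_null(1));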
#[derive(Debug, Hash)]
pub struct CheckOverflow {
    pub child: Arc<dyn PhysicalExpr>,
    pub data_type: DataType,
    pub fail_on_error: bool,
}

impl CheckOverflow {
    pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, fail_on_error: bool) -> Self {
        Self {
            child,
            data_type,
            fail_on_error,
        }
    }
}

impl Display for CheckOverflow {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "CheckOverflow [datatype: {}, fail_on_error: {}, child: {}]",
            self.data_type, self.fail_on_error, self.child
        )
    }
}

impl PartialEq<dyn Any> for CheckOverflow {
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<CheckOverflow>()
            .map(|x| {
                self.child.eq(&x.child)
                    && self.data_type.eq(&x.data_type)
                    && self.fail_on_error.eq(&x.fail_on_error)
            })
            .unwrap_or(false)
    }
}

impl PhysicalExpr for CheckOverflow {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _: &Schema) -> datafusion_common::Result<DataType> {
        Ok(self.data_type.clone())
    }

    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
        Ok(true)
    }

    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        match arg {
            ColumnarValue::Array(array)
                if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
            {
                let (precision, scale) = match &self.data_type {
                    DataType::Decimal128(p, s) => (p, s),
                    dt => {
                        return Err(DataFusionError::Execution(format!(
                            "CheckOverflow expects only Decimal128, but got {:?}",
                            dt
                        )))
                    }
                };

                let decimal_array = as_primitive_array::<Decimal128Type>(&array);

                let casted_array = if self.fail_on_error {
                    // Return an error on overflow
                    decimal_array.validate_decimal_precision(*precision)?;
                    decimal_array
                } else {
                    // Overflowing values become null
                    &decimal_array.null_if_overflow_precision(*precision)
                };

                let new_array = Decimal128Array::from(casted_array.into_data())
                    .with_precision_and_scale(*precision, *scale)
                    .map(|a| Arc::new(a) as ArrayRef)?;

                Ok(ColumnarValue::Array(new_array))
            }
            ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
                // `fail_on_error` is only true when ANSI mode is enabled, which we don't support
                // yet (the Java side simply falls back to Spark when it is enabled)
                assert!(
                    !self.fail_on_error,
                    "fail_on_error (ANSI mode) is not supported yet"
                );

                let new_v: Option<i128> = v.and_then(|v| {
                    Decimal128Type::validate_decimal_precision(v, precision)
                        .map(|_| v)
                        .ok()
                });

                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
                    new_v, precision, scale,
                )))
            }
            v => Err(DataFusionError::Execution(format!(
                "CheckOverflow's child expression should be decimal array, but found {:?}",
                v
            ))),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(CheckOverflow::new(
            Arc::clone(&children[0]),
            self.data_type.clone(),
            self.fail_on_error,
        )))
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        let mut s = state;
        self.child.hash(&mut s);
        self.data_type.hash(&mut s);
        self.fail_on_error.hash(&mut s);
        self.hash(&mut s);
    }
}
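// Usage sketch (assumed, not part of this commit): wrap a decimal-producing child
// expression and evaluate it against a batch. With `fail_on_error = false`, values
// that do not fit in Decimal128(5, 2) come back as nulls rather than errors:
//
//     use datafusion_physical_expr::expressions::Column;
//     let child: Arc<dyn PhysicalExpr> = Arc::new(Column::new("amount", 0));
//     let expr = CheckOverflow::new(child, DataType::Decimal128(5, 2), false);
//     let value = expr.evaluate(&batch)?; // `batch` holds a Decimal128 column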
diff --git a/src/lib.rs b/src/lib.rs
index 5dff6e0b8f4e..8a5748058769 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -29,6 +29,8 @@ mod bitwise_not;
 pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
 mod avg_decimal;
 pub use avg_decimal::AvgDecimal;
+mod checkoverflow;
+pub use checkoverflow::CheckOverflow;
 mod correlation;
 pub use correlation::Correlation;
 mod covariance;
@@ -45,10 +47,14 @@ pub use stddev::Stddev;
 mod structs;
 mod sum_decimal;
 pub use sum_decimal::SumDecimal;
+mod negative;
+pub use negative::{create_negate_expr, NegativeExpr};
 mod normalize_nan;
 mod temporal;
 pub mod timezone;
 mod to_json;
+mod unbound;
+pub use unbound::UnboundColumn;
 pub mod utils;
 
 pub use normalize_nan::NormalizeNaNAndZero;
@@ -83,3 +89,9 @@ pub enum EvalMode {
     /// failing the entire query.
     Try,
 }
+
+pub(crate) fn arithmetic_overflow_error(from_type: &str) -> SparkError {
+    SparkError::ArithmeticOverflow {
+        from_type: from_type.to_string(),
+    }
+}
diff --git a/src/negative.rs b/src/negative.rs
new file mode 100644
index 000000000000..3d9063e7835f
--- /dev/null
+++ b/src/negative.rs
@@ -0,0 +1,266 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use super::arithmetic_overflow_error;
use crate::SparkError;
use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType};
use arrow_array::RecordBatch;
use arrow_buffer::IntervalDayTime;
use arrow_schema::{DataType, Schema};
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion::{
    logical_expr::{interval_arithmetic::Interval, ColumnarValue},
    physical_expr::PhysicalExpr,
};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::sort_properties::ExprProperties;
use std::{
    any::Any,
    hash::{Hash, Hasher},
    sync::Arc,
};

pub fn create_negate_expr(
    expr: Arc<dyn PhysicalExpr>,
    fail_on_error: bool,
) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
    Ok(Arc::new(NegativeExpr::new(expr, fail_on_error)))
}

/// Negative expression
#[derive(Debug, Hash)]
pub struct NegativeExpr {
    /// Input expression
    arg: Arc<dyn PhysicalExpr>,
    fail_on_error: bool,
}
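// Construction sketch (assumed usage, not part of this commit): the planner would
// wrap a child expression, enabling the ANSI overflow check via `fail_on_error`:
//
//     use datafusion_physical_expr::expressions::Column;
//     let arg: Arc<dyn PhysicalExpr> = Arc::new(Column::new("a", 0));
//     let neg = create_negate_expr(arg, /* fail_on_error */ true)?;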
macro_rules! check_overflow {
    ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{
        let typed_array = $array
            .as_any()
            .downcast_ref::<$array_type>()
            .expect(concat!(stringify!($array_type), " expected"));
        for i in 0..typed_array.len() {
            if typed_array.value(i) == $min_val {
                if $type_name == "byte" || $type_name == "short" {
                    let value = format!("{:?} caused", typed_array.value(i));
                    return Err(arithmetic_overflow_error(value.as_str()).into());
                }
                return Err(arithmetic_overflow_error($type_name).into());
            }
        }
    }};
}

impl NegativeExpr {
    /// Create a new NegativeExpr
    pub fn new(arg: Arc<dyn PhysicalExpr>, fail_on_error: bool) -> Self {
        Self { arg, fail_on_error }
    }

    /// Get the input expression
    pub fn arg(&self) -> &Arc<dyn PhysicalExpr> {
        &self.arg
    }
}

impl std::fmt::Display for NegativeExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "(- {})", self.arg)
    }
}
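// Why the pre-check above is needed (a sketch, not part of this commit): two's
// complement negation wraps at the type minimum, e.g.
//
//     assert_eq!(i32::MIN.wrapping_neg(), i32::MIN);
//
// so `neg_wrapping` would silently return i32::MIN for -(i32::MIN). Under ANSI
// mode Spark must raise an overflow error instead, hence the `check_overflow!`
// scan before calling `neg_wrapping` in `evaluate` below.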
impl PhysicalExpr for NegativeExpr {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.arg.data_type(input_schema)
    }

    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        self.arg.nullable(input_schema)
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        let arg = self.arg.evaluate(batch)?;

        // Overflow checks only apply in ANSI mode. The checked data types are
        // byte, short, integer, long, and the interval types.
        match arg {
            ColumnarValue::Array(array) => {
                if self.fail_on_error {
                    match array.data_type() {
                        DataType::Int8 => {
                            check_overflow!(array, arrow::array::Int8Array, i8::MIN, "byte")
                        }
                        DataType::Int16 => {
                            check_overflow!(array, arrow::array::Int16Array, i16::MIN, "short")
                        }
                        DataType::Int32 => {
                            check_overflow!(array, arrow::array::Int32Array, i32::MIN, "integer")
                        }
                        DataType::Int64 => {
                            check_overflow!(array, arrow::array::Int64Array, i64::MIN, "long")
                        }
                        DataType::Interval(value) => match value {
                            arrow::datatypes::IntervalUnit::YearMonth => check_overflow!(
                                array,
                                arrow::array::IntervalYearMonthArray,
                                i32::MIN,
                                "interval"
                            ),
                            arrow::datatypes::IntervalUnit::DayTime => check_overflow!(
                                array,
                                arrow::array::IntervalDayTimeArray,
                                IntervalDayTime::MIN,
                                "interval"
                            ),
                            arrow::datatypes::IntervalUnit::MonthDayNano => {
                                // Overflow checks are not supported
                            }
                        },
                        _ => {
                            // Overflow checks are not supported for other data types
                        }
                    }
                }
                let result = neg_wrapping(array.as_ref())?;
                Ok(ColumnarValue::Array(result))
            }
            ColumnarValue::Scalar(scalar) => {
                if self.fail_on_error {
                    match scalar {
                        ScalarValue::Int8(value) => {
                            if value == Some(i8::MIN) {
                                return Err(arithmetic_overflow_error(" caused").into());
                            }
                        }
                        ScalarValue::Int16(value) => {
                            if value == Some(i16::MIN) {
                                return Err(arithmetic_overflow_error(" caused").into());
                            }
                        }
                        ScalarValue::Int32(value) => {
                            if value == Some(i32::MIN) {
                                return Err(arithmetic_overflow_error("integer").into());
                            }
                        }
                        ScalarValue::Int64(value) => {
                            if value == Some(i64::MIN) {
                                return Err(arithmetic_overflow_error("long").into());
                            }
                        }
                        ScalarValue::IntervalDayTime(value) => {
                            let (days, ms) =
                                IntervalDayTimeType::to_parts(value.unwrap_or_default());
                            if days == i32::MIN || ms == i32::MIN {
                                return Err(arithmetic_overflow_error("interval").into());
                            }
                        }
                        ScalarValue::IntervalYearMonth(value) => {
                            if value == Some(i32::MIN) {
                                return Err(arithmetic_overflow_error("interval").into());
                            }
                        }
                        _ => {
                            // Overflow checks are not supported for other data types
                        }
                    }
                }
                Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()?))
            }
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.arg]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(NegativeExpr::new(
            Arc::clone(&children[0]),
            self.fail_on_error,
        )))
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        let mut s = state;
        self.hash(&mut s);
    }

    /// Given the child interval of a NegativeExpr, calculate the NegativeExpr's interval
    /// by negating the bounds (multiplying them by -1) and swapping them.
    /// Ex: `(a, b]` => `[-b, -a)`
    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
        Interval::try_new(
            children[0].upper().arithmetic_negate()?,
            children[0].lower().arithmetic_negate()?,
        )
    }

    /// Returns a new [`Interval`] for the child of this NegativeExpr, given that this
    /// expression's interval is known to be `interval` and the child's current interval
    /// is `children[0]`.
    fn propagate_constraints(
        &self,
        interval: &Interval,
        children: &[&Interval],
    ) -> Result<Option<Vec<Interval>>> {
        let child_interval = children[0];

        if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN))
            || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN))
            || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN))
            || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN))
        {
            return Err(SparkError::ArithmeticOverflow {
                from_type: "long".to_string(),
            }
            .into());
        }

        let negated_interval = Interval::try_new(
            interval.upper().arithmetic_negate()?,
            interval.lower().arithmetic_negate()?,
        )?;

        Ok(child_interval
            .intersect(negated_interval)?
            .map(|result| vec![result]))
    }

    /// The ordering of a [`NegativeExpr`] is simply the reverse of its child.
    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
        let properties = children[0].clone().with_order(children[0].sort_properties);
        Ok(properties)
    }
}

impl PartialEq<dyn Any> for NegativeExpr {
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<NegativeExpr>()
            .map(|x| self.arg.eq(&x.arg))
            .unwrap_or(false)
    }
}
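// Example of the bound flip implemented by `evaluate_bounds` above (a sketch,
// assuming DataFusion's interval API): negating [1, 5] yields [-5, -1].
//
//     use datafusion_expr::interval_arithmetic::Interval;
//     let child = Interval::make(Some(1i64), Some(5i64))?;
//     let negated = Interval::try_new(
//         child.upper().arithmetic_negate()?,
//         child.lower().arithmetic_negate()?,
//     )?;
//     assert_eq!(negated, Interval::make(Some(-5i64), Some(-1i64))?);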
diff --git a/src/unbound.rs b/src/unbound.rs
new file mode 100644
index 000000000000..a6babd0f7ef1
--- /dev/null
+++ b/src/unbound.rs
@@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::RecordBatch;
use arrow_schema::{DataType, Schema};
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{internal_err, Result};
use datafusion_physical_expr::PhysicalExpr;
use std::{
    any::Any,
    hash::{Hash, Hasher},
    sync::Arc,
};

/// This is similar to `UnKnownColumn` in DataFusion, but it carries a data type.
/// It is only used when a column cannot be bound to a schema, for example, for
/// the inputs to aggregate functions in a final aggregation. In that case, we
/// cannot bind the aggregate functions to the input schema, which in Spark
/// consists of grouping columns and aggregate buffer attributes (DataFusion has
/// a different design). But when creating certain aggregate functions, we need
/// to know their input data types. As `UnKnownColumn` doesn't have a data type,
/// we implement this `UnboundColumn` to carry one.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct UnboundColumn {
    name: String,
    datatype: DataType,
}

impl UnboundColumn {
    /// Create a new unbound column expression
    pub fn new(name: &str, datatype: DataType) -> Self {
        Self {
            name: name.to_owned(),
            datatype,
        }
    }

    /// Get the column name
    pub fn name(&self) -> &str {
        &self.name
    }
}

impl std::fmt::Display for UnboundColumn {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "{}, datatype: {}", self.name, self.datatype)
    }
}

impl PhysicalExpr for UnboundColumn {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Get the data type of this expression, given the schema of the input
    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(self.datatype.clone())
    }

    /// Decide whether this expression is nullable, given the schema of the input
    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(true)
    }

    /// Evaluate the expression
    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
        internal_err!("UnboundColumn::evaluate() should not be called")
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        Ok(self)
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        let mut s = state;
        self.hash(&mut s);
    }
}

impl PartialEq<dyn Any> for UnboundColumn {
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<UnboundColumn>()
            .map(|x| self == x)
            .unwrap_or(false)
    }
}
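// Usage sketch (assumed, not part of this commit): an UnboundColumn reports its
// data type without being bound to any input schema, but must never be evaluated.
//
//     use arrow_schema::{DataType, Schema};
//     let col = UnboundColumn::new("sum", DataType::Decimal128(20, 2));
//     assert_eq!(col.data_type(&Schema::empty())?, DataType::Decimal128(20, 2));
//     // col.evaluate(&batch) would return an internal error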