chore: Move remaining expressions to spark-expr crate + some minor refactoring (apache#1165)

* move CheckOverflow to spark-expr crate

* move NegativeExpr to spark-expr crate

* move UnboundColumn to spark-expr crate

* move ExpandExec from execution::datafusion::operators to execution::operators (import-path sketch after this list)

* refactoring to remove datafusion subpackage

* update imports in benches

* fix

* fix
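
For downstream code, these moves amount to import-path changes only. A hedged sketch using `ExpandExec`, with the paths taken from the bullet above and the `comet` crate prefix assumed for illustration:

```rust
// Old path, removed by this commit (per the bullets above; crate prefix assumed):
// use comet::execution::datafusion::operators::ExpandExec;

// New path after the refactor:
use comet::execution::operators::ExpandExec;
```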
andygrove authored Dec 12, 2024
1 parent 3859724 commit f69148d
Showing 5 changed files with 562 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -29,6 +29,7 @@ edition = { workspace = true }
[dependencies]
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
chrono = { workspace = true }
173 changes: 173 additions & 0 deletions src/checkoverflow.rs
@@ -0,0 +1,173 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::{
    any::Any,
    fmt::{Display, Formatter},
    hash::{Hash, Hasher},
    sync::Arc,
};

use arrow::{
    array::{as_primitive_array, Array, ArrayRef, Decimal128Array},
    datatypes::{Decimal128Type, DecimalType},
    record_batch::RecordBatch,
};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_physical_expr::PhysicalExpr;

/// This is from the Spark `CheckOverflow` expression. Spark's `CheckOverflow` expression rounds
/// decimals to a given scale and checks whether the decimals can fit in a given precision. As the
/// `cast` kernel rounds decimals already, Comet's `CheckOverflow` expression only checks whether
/// the decimals can fit in the precision.
#[derive(Debug, Hash)]
pub struct CheckOverflow {
    pub child: Arc<dyn PhysicalExpr>,
    pub data_type: DataType,
    pub fail_on_error: bool,
}

impl CheckOverflow {
    pub fn new(child: Arc<dyn PhysicalExpr>, data_type: DataType, fail_on_error: bool) -> Self {
        Self {
            child,
            data_type,
            fail_on_error,
        }
    }
}

impl Display for CheckOverflow {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "CheckOverflow [datatype: {}, fail_on_error: {}, child: {}]",
            self.data_type, self.fail_on_error, self.child
        )
    }
}

impl PartialEq<dyn Any> for CheckOverflow {
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<Self>()
            .map(|x| {
                self.child.eq(&x.child)
                    && self.data_type.eq(&x.data_type)
                    && self.fail_on_error.eq(&x.fail_on_error)
            })
            .unwrap_or(false)
    }
}

impl PhysicalExpr for CheckOverflow {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _: &Schema) -> datafusion_common::Result<DataType> {
        Ok(self.data_type.clone())
    }

    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
        Ok(true)
    }

    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
        let arg = self.child.evaluate(batch)?;
        match arg {
            ColumnarValue::Array(array)
                if matches!(array.data_type(), DataType::Decimal128(_, _)) =>
            {
                let (precision, scale) = match &self.data_type {
                    DataType::Decimal128(p, s) => (p, s),
                    dt => {
                        return Err(DataFusionError::Execution(format!(
                            "CheckOverflow expects only Decimal128, but got {:?}",
                            dt
                        )))
                    }
                };

                let decimal_array = as_primitive_array::<Decimal128Type>(&array);

                let casted_array = if self.fail_on_error {
                    // Returning error if overflow
                    decimal_array.validate_decimal_precision(*precision)?;
                    decimal_array
                } else {
                    // Overflowing gets null value
                    &decimal_array.null_if_overflow_precision(*precision)
                };

                let new_array = Decimal128Array::from(casted_array.into_data())
                    .with_precision_and_scale(*precision, *scale)
                    .map(|a| Arc::new(a) as ArrayRef)?;

                Ok(ColumnarValue::Array(new_array))
            }
            ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
                // `fail_on_error` is only true when ANSI is enabled, which we don't support yet
                // (Java side will simply fall back to Spark when it is enabled)
                assert!(
                    !self.fail_on_error,
                    "fail_on_error (ANSI mode) is not supported yet"
                );

                let new_v: Option<i128> = v.and_then(|v| {
                    Decimal128Type::validate_decimal_precision(v, precision)
                        .map(|_| v)
                        .ok()
                });

                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
                    new_v, precision, scale,
                )))
            }
            v => Err(DataFusionError::Execution(format!(
                "CheckOverflow's child expression should be decimal array, but found {:?}",
                v
            ))),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.child]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
        Ok(Arc::new(CheckOverflow::new(
            Arc::clone(&children[0]),
            self.data_type.clone(),
            self.fail_on_error,
        )))
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        let mut s = state;
        self.child.hash(&mut s);
        self.data_type.hash(&mut s);
        self.fail_on_error.hash(&mut s);
        self.hash(&mut s);
    }
}
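
As a quick illustration of the semantics described in the doc comment, here is a minimal, hypothetical usage sketch (the batch, the `amount` column, and the decimal values are invented for illustration; `CheckOverflow` is the expression from src/checkoverflow.rs above):

```rust
use std::sync::Arc;

use arrow::array::Decimal128Array;
use arrow::record_batch::RecordBatch;
use arrow_schema::{DataType, Field, Schema};
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;

fn check_overflow_example() -> datafusion_common::Result<()> {
    // A Decimal128(10, 2) column holding 12.34 and 99999.99 (raw i128 values
    // are unscaled, so 1234 => 12.34 at scale 2).
    let array = Decimal128Array::from(vec![1234_i128, 9999999_i128])
        .with_precision_and_scale(10, 2)?;
    let schema = Arc::new(Schema::new(vec![Field::new(
        "amount",
        DataType::Decimal128(10, 2),
        true,
    )]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(array)])?;

    // Narrow the target type to Decimal128(4, 2): 12.34 still fits, but
    // 99999.99 does not. With fail_on_error = false the overflowing value
    // becomes null instead of raising an error.
    let expr = CheckOverflow::new(
        Arc::new(Column::new("amount", 0)),
        DataType::Decimal128(4, 2),
        false,
    );
    let checked = expr.evaluate(&batch)?;
    println!("{checked:?}");
    Ok(())
}
```

With `fail_on_error = true` the same input would instead surface the precision violation as an error from `validate_decimal_precision`.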
12 changes: 12 additions & 0 deletions src/lib.rs
@@ -29,6 +29,8 @@ mod bitwise_not;
pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
mod avg_decimal;
pub use avg_decimal::AvgDecimal;
mod checkoverflow;
pub use checkoverflow::CheckOverflow;
mod correlation;
pub use correlation::Correlation;
mod covariance;
@@ -45,10 +47,14 @@ pub use stddev::Stddev;
mod structs;
mod sum_decimal;
pub use sum_decimal::SumDecimal;
mod negative;
pub use negative::{create_negate_expr, NegativeExpr};
mod normalize_nan;
mod temporal;
pub mod timezone;
mod to_json;
mod unbound;
pub use unbound::UnboundColumn;
pub mod utils;
pub use normalize_nan::NormalizeNaNAndZero;

@@ -83,3 +89,9 @@ pub enum EvalMode {
    /// failing the entire query.
    Try,
}

pub(crate) fn arithmetic_overflow_error(from_type: &str) -> SparkError {
    SparkError::ArithmeticOverflow {
        from_type: from_type.to_string(),
    }
}
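
The new `arithmetic_overflow_error` helper gives the crate a single place to construct the overflow variant of `SparkError` (an enum defined elsewhere in this crate, not shown in this diff). A hedged sketch of a call site, with the checked addition invented for illustration:

```rust
// Hypothetical call site inside the crate; `SparkError` and
// `arithmetic_overflow_error` come from the surrounding spark-expr crate.
fn add_checked(a: i32, b: i32) -> Result<i32, SparkError> {
    // checked_add returns None on i32 overflow, which we map to the
    // crate's ArithmeticOverflow error.
    a.checked_add(b)
        .ok_or_else(|| arithmetic_overflow_error("integer"))
}
```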