Skip to content

Commit

Permalink
Use modulus dyn kernels for arithmetic expressions (#5634)
Browse files Browse the repository at this point in the history
* Use modulus dyn kernels

* Update

* Update

* Add tests
  • Loading branch information
viirya authored Mar 20, 2023
1 parent 3ccf1ae commit d77ccc2
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 114 deletions.
295 changes: 203 additions & 92 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ mod adapter;
mod kernels;
mod kernels_arrow;

use std::convert::TryInto;
use std::{any::Any, sync::Arc};

use arrow::array::*;
use arrow::compute::kernels::arithmetic::{
add_dyn, add_scalar_dyn as add_dyn_scalar, divide_dyn_opt,
divide_scalar_dyn as divide_dyn_scalar, modulus, modulus_scalar, multiply_dyn,
divide_scalar_dyn as divide_dyn_scalar, modulus_dyn,
modulus_scalar_dyn as modulus_dyn_scalar, multiply_dyn,
multiply_scalar_dyn as multiply_dyn_scalar, subtract_dyn,
subtract_scalar_dyn as subtract_dyn_scalar,
};
Expand Down Expand Up @@ -64,7 +64,7 @@ use kernels_arrow::{
is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_f32,
is_not_distinct_from_f64, is_not_distinct_from_null, is_not_distinct_from_utf8,
modulus_decimal, modulus_decimal_scalar, multiply_decimal_dyn_scalar,
modulus_decimal_dyn_scalar, modulus_dyn_decimal, multiply_decimal_dyn_scalar,
multiply_dyn_decimal, subtract_decimal_dyn_scalar, subtract_dyn_decimal,
};

Expand All @@ -76,7 +76,7 @@ use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison};
use crate::intervals::{apply_operator, Interval};
use crate::physical_expr::down_cast_any_ref;
use crate::{analysis_expect, AnalysisContext, ExprBoundaries, PhysicalExpr};
use datafusion_common::cast::{as_boolean_array, as_decimal128_array};
use datafusion_common::cast::as_boolean_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::type_coercion::binary::binary_operator_data_type;
Expand Down Expand Up @@ -160,21 +160,6 @@ macro_rules! compute_decimal_op_dyn_scalar {
}};
}

macro_rules! compute_decimal_op_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = as_decimal128_array($LEFT).unwrap();
if let ScalarValue::Decimal128(Some(_), _, _) = $RIGHT {
Ok(Arc::new(paste::expr! {[<$OP _decimal_scalar>]}(
ll,
$RIGHT.try_into()?,
)?))
} else {
// when the $RIGHT is a NULL, generate a NULL array of LEFT's datatype
Ok(Arc::new(new_null_array($LEFT.data_type(), $LEFT.len())))
}
}};
}

macro_rules! compute_decimal_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT.as_any().downcast_ref::<$DT>().unwrap();
Expand Down Expand Up @@ -335,25 +320,6 @@ macro_rules! compute_bool_op {
}};
}

/// Invoke a compute kernel on a data array and a scalar value
/// LEFT is array, RIGHT is scalar value
macro_rules! compute_op_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
if $RIGHT.is_null() {
Ok(Arc::new(new_null_array($LEFT.data_type(), $LEFT.len())))
} else {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast left side array");
Ok(Arc::new(paste::expr! {[<$OP _scalar>]}(
&ll,
$RIGHT.try_into()?,
)?))
}
}};
}

/// Invoke a dyn compute kernel on a data array and a scalar value
/// LEFT is Primitive or Dictionary array of numeric values, RIGHT is scalar value
/// OP_TYPE is the return type of scalar function
Expand Down Expand Up @@ -448,31 +414,6 @@ macro_rules! binary_string_array_op {
}};
}

/// Invoke a compute kernel on a pair of arrays
/// The binary_primitive_array_op macro only evaluates for primitive types
/// like integers and floats.
macro_rules! binary_primitive_array_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
match $LEFT.data_type() {
DataType::Decimal128(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, Decimal128Array),
DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array),
DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array),
DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array),
DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
other => Err(DataFusionError::Internal(format!(
"Data type {:?} not supported for binary operation '{}' on primitive arrays",
other, stringify!($OP)
))),
}
}};
}

/// Invoke a compute kernel on a pair of arrays
/// The binary_primitive_array_op macro only evaluates for primitive types
/// like integers and floats.
Expand Down Expand Up @@ -525,32 +466,6 @@ macro_rules! binary_primitive_array_op_dyn_scalar {
}}
}

/// Invoke a compute kernel on an array and a scalar
/// The binary_primitive_array_op_scalar macro only evaluates for primitive
/// types like integers and floats.
macro_rules! binary_primitive_array_op_scalar {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
let result: Result<Arc<dyn Array>> = match $LEFT.data_type() {
DataType::Decimal128(_,_) => compute_decimal_op_scalar!($LEFT, $RIGHT, $OP, Decimal128Array),
DataType::Int8 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int8Array),
DataType::Int16 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int16Array),
DataType::Int32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int32Array),
DataType::Int64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Int64Array),
DataType::UInt8 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt8Array),
DataType::UInt16 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt16Array),
DataType::UInt32 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt32Array),
DataType::UInt64 => compute_op_scalar!($LEFT, $RIGHT, $OP, UInt64Array),
DataType::Float32 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float32Array),
DataType::Float64 => compute_op_scalar!($LEFT, $RIGHT, $OP, Float64Array),
other => Err(DataFusionError::Internal(format!(
"Data type {:?} not supported for scalar operation '{}' on primitive array",
other, stringify!($OP)
))),
};
Some(result)
}};
}

/// The binary_array_op macro includes types that extend beyond the primitive,
/// such as Utf8 strings.
#[macro_export]
Expand Down Expand Up @@ -1128,8 +1043,7 @@ impl BinaryExpr {
binary_primitive_array_op_dyn_scalar!(array, scalar, divide)
}
Operator::Modulo => {
// todo: change to binary_primitive_array_op_dyn_scalar! once modulo is implemented
binary_primitive_array_op_scalar!(array, scalar, modulus)
binary_primitive_array_op_dyn_scalar!(array, scalar, modulus)
}
Operator::RegexMatch => binary_string_array_flag_op_scalar!(
array,
Expand Down Expand Up @@ -1239,7 +1153,9 @@ impl BinaryExpr {
Operator::Divide => {
binary_primitive_array_op_dyn!(left, right, divide_dyn_opt)
}
Operator::Modulo => binary_primitive_array_op!(left, right, modulus),
Operator::Modulo => {
binary_primitive_array_op_dyn!(left, right, modulus_dyn)
}
Operator::And => {
if left_data_type == &DataType::Boolean {
boolean_op!(&left, &right, and_kleene)
Expand Down Expand Up @@ -2638,6 +2554,201 @@ mod tests {
Ok(())
}

#[test]
#[cfg(feature = "dictionary_expressions")]
fn modulus_op_dict() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
"a",
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
true,
),
Field::new(
"b",
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
true,
),
]);

let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();

dict_builder.append(1)?;
dict_builder.append_null();
dict_builder.append(2)?;
dict_builder.append(5)?;
dict_builder.append(0)?;

let a = dict_builder.finish();

let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
let keys = Int8Array::from(vec![0, 1, 1, 2, 1]);
let b = DictionaryArray::try_new(&keys, &b)?;

apply_arithmetic::<Int32Type>(
Arc::new(schema),
vec![Arc::new(a), Arc::new(b)],
Operator::Modulo,
Int32Array::from(vec![Some(0), None, Some(0), Some(1), Some(0)]),
)?;

Ok(())
}

#[test]
#[cfg(feature = "dictionary_expressions")]
fn modulus_op_dict_decimal() -> Result<()> {
let schema = Schema::new(vec![
Field::new(
"a",
DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Decimal128(10, 0)),
),
true,
),
Field::new(
"b",
DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Decimal128(10, 0)),
),
true,
),
]);

let value = 123;
let decimal_array = Arc::new(create_decimal_array(
&[
Some(value),
Some(value + 2),
Some(value - 1),
Some(value + 1),
],
10,
0,
)) as ArrayRef;

let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]);
let a = DictionaryArray::try_new(&keys, &decimal_array)?;

let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]);
let decimal_array = create_decimal_array(
&[
Some(value + 1),
Some(value + 3),
Some(value),
Some(value + 2),
],
10,
0,
);
let b = DictionaryArray::try_new(&keys, &decimal_array)?;

apply_arithmetic(
Arc::new(schema),
vec![Arc::new(a), Arc::new(b)],
Operator::Modulo,
create_decimal_array(&[Some(123), None, None, Some(1), Some(0)], 10, 0),
)?;

Ok(())
}

#[test]
fn modulus_op_scalar() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);

apply_arithmetic_scalar(
Arc::new(schema),
vec![Arc::new(a)],
Operator::Modulo,
ScalarValue::Int32(Some(2)),
Arc::new(Int32Array::from(vec![1, 0, 1, 0, 1])),
)?;

Ok(())
}

#[test]
fn modules_op_dict_scalar() -> Result<()> {
let schema = Schema::new(vec![Field::new(
"a",
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)),
true,
)]);

let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();

dict_builder.append(1)?;
dict_builder.append_null();
dict_builder.append(2)?;
dict_builder.append(5)?;

let a = dict_builder.finish();

let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();

dict_builder.append(1)?;
dict_builder.append_null();
dict_builder.append(0)?;
dict_builder.append(1)?;
let expected = dict_builder.finish();

apply_arithmetic_scalar(
Arc::new(schema),
vec![Arc::new(a)],
Operator::Modulo,
ScalarValue::Dictionary(
Box::new(DataType::Int8),
Box::new(ScalarValue::Int32(Some(2))),
),
Arc::new(expected),
)?;

Ok(())
}

#[test]
fn modulus_op_dict_scalar_decimal() -> Result<()> {
let schema = Schema::new(vec![Field::new(
"a",
DataType::Dictionary(
Box::new(DataType::Int8),
Box::new(DataType::Decimal128(10, 0)),
),
true,
)]);

let value = 123;
let decimal_array = Arc::new(create_decimal_array(
&[Some(value), None, Some(value - 1), Some(value + 1)],
10,
0,
)) as ArrayRef;

let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(&keys, &decimal_array)?;

let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array =
create_decimal_array(&[Some(1), None, Some(0), Some(0)], 10, 0);
let expected = DictionaryArray::try_new(&keys, &decimal_array)?;

apply_arithmetic_scalar(
Arc::new(schema),
vec![Arc::new(a)],
Operator::Modulo,
ScalarValue::Dictionary(
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
),
Arc::new(expected),
)?;

Ok(())
}

fn apply_arithmetic<T: ArrowNumericType>(
schema: SchemaRef,
data: Vec<ArrayRef>,
Expand Down
Loading

0 comments on commit d77ccc2

Please sign in to comment.