Skip to content

Commit

Permalink
feat: Add Exponentiate operator
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Mar 20, 2024
1 parent 4ede2d0 commit 9ba8150
Show file tree
Hide file tree
Showing 15 changed files with 65 additions and 19 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion datafusion-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ repository = "https://github.com/apache/arrow-datafusion"
rust-version = "1.59"

[dependencies]
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4" }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9" }
clap = { version = "3", features = ["derive", "cargo"] }
datafusion = { path = "../datafusion/core", version = "7.0.0" }
dirs = "4.0.0"
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ path = "examples/avro_sql.rs"
required-features = ["datafusion/avro"]

[dev-dependencies]
arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4" }
arrow-flight = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9" }
async-trait = "0.1.41"
datafusion = { path = "../datafusion/core" }
futures = "0.3"
Expand Down
4 changes: 2 additions & 2 deletions datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ jit = ["cranelift-module"]
pyarrow = ["pyo3"]

[dependencies]
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4", features = ["prettyprint"] }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9", features = ["prettyprint"] }
avro-rs = { version = "0.13", features = ["snappy"], optional = true }
cranelift-module = { version = "0.82.0", optional = true }
ordered-float = "2.10"
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4", features = ["arrow"], optional = true }
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9", features = ["arrow"], optional = true }
pyo3 = { version = "0.16", optional = true }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
4 changes: 2 additions & 2 deletions datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ unicode_expressions = ["datafusion-physical-expr/regex_expressions"]

[dependencies]
ahash = { version = "0.7", default-features = false }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4", features = ["prettyprint"] }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9", features = ["prettyprint"] }
async-trait = "0.1.41"
avro-rs = { version = "0.13", features = ["snappy"], optional = true }
chrono = { version = "0.4", default-features = false }
Expand All @@ -73,7 +73,7 @@ num-traits = { version = "0.2", optional = true }
num_cpus = "1.13.0"
ordered-float = "2.10"
parking_lot = "0.12"
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4", features = ["arrow"] }
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9", features = ["arrow"] }
paste = "^1.0"
pin-project-lite= "^0.2.7"
pyo3 = { version = "0.16", optional = true }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/fuzz-utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4", features = ["prettyprint"] }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9", features = ["prettyprint"] }
env_logger = "0.9.0"
rand = "0.8"
3 changes: 3 additions & 0 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight),
BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft),
BinaryOperator::StringConcat => Ok(Operator::StringConcat),
// TODO: PGExponentiation needs to be introduced, but DF doesn't pass dialect
// so using BitwiseXor is safe for now since it's not implemented anyway
BinaryOperator::BitwiseXor => Ok(Operator::Exponentiate),
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported SQL binary operator {:?}",
op
Expand Down
2 changes: 1 addition & 1 deletion datafusion/cube_ext/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ name = "cube_ext"
path = "src/lib.rs"

[dependencies]
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4", features = ["prettyprint"] }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9", features = ["prettyprint"] }
chrono = { version = "0.4.16", package = "chrono", default-features = false, features = ["clock"] }
datafusion-common = { path = "../common", version = "7.0.0" }
datafusion-expr = { path = "../expr", version = "7.0.0" }
2 changes: 1 addition & 1 deletion datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ path = "src/lib.rs"

[dependencies]
ahash = { version = "0.7", default-features = false }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4", features = ["prettyprint"] }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9", features = ["prettyprint"] }
datafusion-common = { path = "../common", version = "7.0.0" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
10 changes: 9 additions & 1 deletion datafusion/expr/src/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ pub fn binary_operator_data_type(
| Operator::Minus
| Operator::Divide
| Operator::Multiply
| Operator::Modulo => Ok(result_type),
| Operator::Modulo
| Operator::Exponentiate => Ok(result_type),
// string operations return the same values as the common coerced type
Operator::StringConcat => Ok(result_type),
}
Expand Down Expand Up @@ -117,6 +118,8 @@ pub fn coerce_types(
Operator::Modulo | Operator::Divide | Operator::Multiply => {
mathematics_numerical_coercion(op, lhs_type, rhs_type)
}
// Exponentiate is fixed type, handled inside function
Operator::Exponentiate => mathematics_numerical_coercion(op, lhs_type, rhs_type),
Operator::RegexMatch
| Operator::RegexIMatch
| Operator::RegexNotMatch
Expand Down Expand Up @@ -318,6 +321,11 @@ fn mathematics_numerical_coercion(
return None;
};

// exponentiation is always Float64
if mathematics_op == &Operator::Exponentiate {
return Some(Float64);
}

// same type => all good
if lhs_type == rhs_type {
return Some(lhs_type.clone());
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ pub enum Operator {
Divide,
/// Remainder operator, like `%`
Modulo,
/// Exponentiation operator, like `^`
Exponentiate,
/// Logical AND, like `&&`
And,
/// Logical OR, like `||`
Expand Down Expand Up @@ -97,6 +99,7 @@ impl fmt::Display for Operator {
Operator::Multiply => "*",
Operator::Divide => "/",
Operator::Modulo => "%",
Operator::Exponentiate => "^",
Operator::And => "AND",
Operator::Or => "OR",
Operator::Like => "LIKE",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/jit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ path = "src/lib.rs"
jit = []

[dependencies]
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4" }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9" }
cranelift = "0.82.0"
cranelift-jit = "0.82.0"
cranelift-module = "0.82.0"
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ unicode_expressions = ["unicode-segmentation"]

[dependencies]
ahash = { version = "0.7", default-features = false }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "70e71d1829e6302333e965fde9d0ee56d6415ef4", features = ["prettyprint"] }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "d9c12d71b655d356c5a287226a763638417972e9", features = ["prettyprint"] }
blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { version = "0.4.20", default-features = false }
Expand Down
36 changes: 34 additions & 2 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow::array::TimestampMillisecondArray;
use arrow::array::*;
use arrow::compute::kernels::arithmetic::{
add, add_scalar, divide, divide_scalar, modulus, modulus_scalar, multiply,
multiply_scalar, subtract, subtract_scalar,
multiply_scalar, powf, subtract, subtract_scalar,
};
use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene};
use arrow::compute::kernels::comparison::{
Expand Down Expand Up @@ -1533,6 +1533,7 @@ impl BinaryExpr {
Operator::Multiply => binary_primitive_array_op!(left, right, multiply),
Operator::Divide => binary_primitive_array_op!(left, right, divide),
Operator::Modulo => binary_primitive_array_op!(left, right, modulus),
Operator::Exponentiate => compute_op!(left, right, powf, Float64Array),
Operator::And => {
if left_data_type == &DataType::Boolean {
boolean_op!(left, right, and_kleene)
Expand Down Expand Up @@ -1694,7 +1695,7 @@ pub fn binary(
mod tests {
use super::*;
use crate::expressions::{col, lit};
use arrow::datatypes::{ArrowNumericType, Field, Int32Type, SchemaRef};
use arrow::datatypes::{ArrowNumericType, Field, Float64Type, Int32Type, SchemaRef};
use arrow::util::display::array_value_to_string;
use datafusion_common::Result;

Expand Down Expand Up @@ -1884,6 +1885,18 @@ mod tests {
DataType::Float32,
vec![2f32]
);
test_coercion!(
Int64Array,
DataType::Int64,
vec![2i64],
Int64Array,
DataType::Int64,
vec![3i64],
Operator::Exponentiate,
Float64Array,
DataType::Float64,
vec![8f64]
);
test_coercion!(
StringArray,
DataType::Utf8,
Expand Down Expand Up @@ -2269,6 +2282,25 @@ mod tests {
Ok(())
}

#[test]
fn exponentiate_op() -> Result<()> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Float64, false),
]);
let a = Int32Array::from(vec![2, 3, 4, -5, 16]);
let b = Float64Array::from(vec![2.0, 5.0, -0.5, 2.0, 1.5]);

apply_arithmetic::<Float64Type>(
Arc::new(schema),
vec![Arc::new(a), Arc::new(b)],
Operator::Exponentiate,
Float64Array::from(vec![4.0, 243.0, 0.5, 25.0, 64.0]),
)?;

Ok(())
}

fn apply_arithmetic<T: ArrowNumericType>(
schema: SchemaRef,
data: Vec<ArrayRef>,
Expand Down

0 comments on commit 9ba8150

Please sign in to comment.