From 18b657bab01d1c17468c61c93221a3ce5cf784cc Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 18 Nov 2023 20:45:34 +0800 Subject: [PATCH 1/4] squash for rebase Signed-off-by: jayzhan211 --- datafusion/common/src/scalar.rs | 31 ++++ datafusion/expr/src/built_in_function.rs | 41 +++-- datafusion/expr/src/expr_fn.rs | 7 + .../expr/src/type_coercion/aggregates.rs | 28 ++-- datafusion/physical-expr/src/aggregate/sum.rs | 147 +++++++++++++++++- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sql/src/expr/function.rs | 26 +++- datafusion/sqllogictest/test_files/array.slt | 46 +++++- docs/source/user-guide/expressions.md | 1 + .../source/user-guide/sql/scalar_functions.md | 21 +++ 14 files changed, 316 insertions(+), 41 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index e8dac2a7f486..1d56ddad01eb 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -682,6 +682,37 @@ impl ScalarValue { } } + /// Return a new `ScalarValue::List` given a `Vec` of primitive values + pub fn new_primitives( + values: Vec>, + d: &DataType, + ) -> Result { + if values.is_empty() { + return d.try_into(); + } + + let mut array = Vec::with_capacity(values.len()); + let mut nulls = Vec::with_capacity(values.len()); + + for a in values { + match a { + Some(v) => { + array.push(v); + nulls.push(true); + } + None => { + array.push(T::Native::default()); + nulls.push(false); + } + } + } + + let arr = PrimitiveArray::::new(array.into(), Some(NullBuffer::from(nulls))) + .with_data_type(d.clone()); + + Ok(ScalarValue::List(Arc::new(arr))) + } + /// Create a decimal Scalar from value/precision and scale. pub fn try_new_decimal128(value: i128, precision: u8, scale: i8) -> Result { // make sure the precision and scale is valid diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index e9030ebcc00f..f4a0cbfe398f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -128,6 +128,8 @@ pub enum BuiltinScalarFunction { Cot, // array functions + /// array_aggregate + ArrayAggregate, /// array_append ArrayAppend, /// array_concat @@ -389,6 +391,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Tanh => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::ArrayAppend => Volatility::Immutable, + BuiltinScalarFunction::ArrayAggregate => Volatility::Immutable, BuiltinScalarFunction::ArrayConcat => Volatility::Immutable, BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable, BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable, @@ -534,6 +537,7 @@ impl BuiltinScalarFunction { Ok(data_type) } BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayAggregate => unimplemented!("ArrayAggregate is based on Aggreation function, so no return value for it."), BuiltinScalarFunction::ArrayConcat => { let mut expr_type = Null; let mut max_dims = 0; @@ -893,23 +897,24 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), - BuiltinScalarFunction::ArrayHasAll + + BuiltinScalarFunction::ArrayAggregate + | BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny - | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayLength => { - Signature::variadic_any(self.volatility()) - } - BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), - BuiltinScalarFunction::ArrayPosition => { - Signature::variadic_any(self.volatility()) + | BuiltinScalarFunction::ArrayHas + | BuiltinScalarFunction::ArrayPositions + | BuiltinScalarFunction::ArrayPrepend + | BuiltinScalarFunction::ArrayRepeat + | BuiltinScalarFunction::ArrayRemove + | BuiltinScalarFunction::ArrayRemoveAll => { + Signature::any(2, self.volatility()) } - BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), - BuiltinScalarFunction::ArrayRemoveAll => Signature::any(2, self.volatility()), - BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), + + BuiltinScalarFunction::ArrayRemoveN + | BuiltinScalarFunction::ArrayReplace + | BuiltinScalarFunction::ArrayReplaceAll + | BuiltinScalarFunction::ArraySlice => Signature::any(3, self.volatility()), + BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), BuiltinScalarFunction::ArrayReplaceAll => { Signature::any(3, self.volatility()) @@ -1509,6 +1514,12 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], // array functions + BuiltinScalarFunction::ArrayAggregate => &[ + "array_aggregate", + "list_aggregate", + "array_aggr", + "list_aggr", + ], BuiltinScalarFunction::ArrayAppend => &[ "array_append", "list_append", diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 674d2a34df38..1616842a0f1b 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -597,6 +597,13 @@ scalar_expr!( "returns the array without the first element." ); +scalar_expr!( + ArrayAggregate, + array_aggregate, + array name, + "allows the execution of arbitrary existing aggregate functions `name` on the elements of a list" +); + nary_scalar_expr!(ArrayConcat, array_concat, "concatenates arrays."); scalar_expr!( ArrayHas, diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 7128b575978a..3474069a3d9e 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -16,12 +16,12 @@ // under the License. use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; use datafusion_common::{internal_err, plan_err, DataFusionError, Result}; -use std::ops::Deref; +use std::{ops::Deref, sync::Arc}; use crate::{AggregateFunction, Signature, TypeSignature}; @@ -118,6 +118,16 @@ pub fn coerce_types( Dictionary(_, v) => { return coerce_types(agg_fun, &[v.as_ref().clone()], signature) } + List(field) => { + let coerce_types = + coerce_types(agg_fun, &[field.data_type().clone()], signature)?; + let data_type = coerce_types[0].clone(); + List(Arc::new(Field::new( + field.name(), + data_type, + field.is_nullable(), + ))) + } _ => { return plan_err!( "The function {:?} does not support inputs of type {:?}.", @@ -411,6 +421,7 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal256(new_precision, *scale)) } + DataType::List(field) => sum_return_type(field.data_type()), other => plan_err!("SUM does not support type \"{other:?}\""), } } @@ -505,19 +516,6 @@ pub fn is_bool_and_or_support_arg_type(arg_type: &DataType) -> bool { matches!(arg_type, DataType::Boolean) } -pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { - match arg_type { - DataType::Dictionary(_, dict_value_type) => { - is_sum_support_arg_type(dict_value_type.as_ref()) - } - _ => matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) - ), - } -} - pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { match arg_type { DataType::Dictionary(_, dict_value_type) => { diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 03f666cc4e5d..674b24bd1feb 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -24,16 +24,21 @@ use super::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; +use arrow::array::ArrayRef; use arrow::compute::sum; use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; +use arrow::datatypes::Field; use arrow_array::cast::AsArray; use arrow_array::types::{ Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, }; use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType}; use arrow_buffer::ArrowNativeType; -use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::utils::wrap_into_list_array; +use datafusion_common::{ + internal_err, not_impl_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::type_coercion::aggregates::sum_return_type; use datafusion_expr::Accumulator; @@ -86,6 +91,27 @@ macro_rules! downcast_sum { } pub(crate) use downcast_sum; +// TODO: Replace with `downcast_sum` after most of the AggregateExpr differentiate `return_data_type` and `data_type` +// The reason we have this is because using the name `return_data_type` makes more much sense to me, +// instead of changing `data_type` to `return_data_type` all the AggregateExpr that have `downcast_sum`, introduce v2 is better. +macro_rules! downcast_sum_v2 { + ($s:ident, $helper:ident) => { + match $s.return_data_type { + DataType::UInt64 => $helper!(UInt64Type, $s.return_data_type), + DataType::Int64 => $helper!(Int64Type, $s.return_data_type), + DataType::Float64 => $helper!(Float64Type, $s.return_data_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.return_data_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.return_data_type), + _ => not_impl_err!( + "Sum not supported for {}: {}", + $s.name, + $s.return_data_type + ), + } + }; +} +pub(crate) use downcast_sum_v2; + impl AggregateExpr for Sum { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -101,12 +127,48 @@ impl AggregateExpr for Sum { } fn create_accumulator(&self) -> Result> { + // TODO: Rewrite `downcast_sum` that accepts `DataType` instead of `self` to extend support to `List` + if let DataType::List(field) = &self.data_type { + match field.data_type() { + DataType::Int64 => { + return Ok(Box::new(ArraySumAccumulator::::new( + self.return_data_type.clone(), + ))) + } + DataType::UInt64 => { + return Ok(Box::new(ArraySumAccumulator::::new( + self.return_data_type.clone(), + ))) + } + DataType::Float64 => { + return Ok(Box::new(ArraySumAccumulator::::new( + self.return_data_type.clone(), + ))) + } + DataType::Decimal128(_, _) => { + return Ok(Box::new(ArraySumAccumulator::::new( + self.return_data_type.clone(), + ))) + } + DataType::Decimal256(_, _) => { + return Ok(Box::new(ArraySumAccumulator::::new( + self.return_data_type.clone(), + ))) + } + _ => unimplemented!( + "Sum not supported for {}: {}", + self.name, + self.data_type + ), + } + } + macro_rules! helper { ($t:ty, $dt:expr) => { Ok(Box::new(SumAccumulator::<$t>::new($dt.clone()))) }; } - downcast_sum!(self, helper) + downcast_sum_v2!(self, helper) } fn state_fields(&self) -> Result> { @@ -138,7 +200,7 @@ impl AggregateExpr for Sum { ))) }; } - downcast_sum!(self, helper) + downcast_sum_v2!(self, helper) } fn reverse_expr(&self) -> Option> { @@ -151,7 +213,7 @@ impl AggregateExpr for Sum { Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) }; } - downcast_sum!(self, helper) + downcast_sum_v2!(self, helper) } } @@ -217,6 +279,81 @@ impl Accumulator for SumAccumulator { } } +/// This accumulator is specialized for `array_sum` or `array_aggregate('sum')` +struct ArraySumAccumulator { + // Each element in `Vec` represents the partial sum of respective row + sum: Vec>, + data_type: DataType, +} + +impl std::fmt::Debug for ArraySumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ArraySumAccumulator({})", self.data_type) + } +} + +impl ArraySumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + // Row number is unknown at the beginning, so we use an empty vector + sum: vec![], + data_type, + } + } +} + +impl Accumulator for ArraySumAccumulator { + fn state(&self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + // There are two kinds of input, PrimitiveArray and ListArray + // ListArray is for multiple-rows input, and PrimitiveArray is for single-row input + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Wrap single-row input into multiple-rows input and use the same logic as multiple-rows input + let list_values = match as_list_array(&values[0]) { + Ok(arr) => arr.to_owned(), + Err(_) => wrap_into_list_array(values[0].clone()), + }; + + let row_number = list_values.len(); + + if self.sum.is_empty() { + self.sum.resize(row_number, None); + } else if self.sum.len() < row_number { + return internal_err!("ArraySumAccumulator::update_batch only support consistent row number, got {} and {}", self.sum.len(), row_number); + } + + for (i, values) in list_values.iter().enumerate() { + if let Some(values) = values { + let values = as_primitive_array::(&values)?; + if let Some(x) = sum(values) { + let v = self.sum[i].get_or_insert(T::Native::usize_as(0)); + *v = v.add_wrapping(x); + } + } else { + return internal_err!( + "ArraySumAccumulator::update_batch got null values" + ); + } + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&self) -> Result { + ScalarValue::new_primitives::(self.sum.clone(), &self.data_type) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + /// This accumulator incrementally computes sums over a sliding window /// /// This is separate from [`SumAccumulator`] as requires additional state diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9d508078c705..2455da27392b 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -641,6 +641,7 @@ enum ScalarFunction { ArrayExcept = 123; ArrayPopFront = 124; Levenshtein = 125; + ArrayAggregate = 126; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0a8f415e20c5..49bb9732aa5a 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20849,6 +20849,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayExcept => "ArrayExcept", Self::ArrayPopFront => "ArrayPopFront", Self::Levenshtein => "Levenshtein", + Self::ArrayAggregate => "ArrayAggregate", }; serializer.serialize_str(variant) } @@ -20986,6 +20987,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept", "ArrayPopFront", "Levenshtein", + "ArrayAggregate", ]; struct GeneratedVisitor; @@ -21152,6 +21154,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayExcept" => Ok(ScalarFunction::ArrayExcept), "ArrayPopFront" => Ok(ScalarFunction::ArrayPopFront), "Levenshtein" => Ok(ScalarFunction::Levenshtein), + "ArrayAggregate" => Ok(ScalarFunction::ArrayAggregate), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 84fb84b9487e..e0b84ac93555 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2572,6 +2572,7 @@ pub enum ScalarFunction { ArrayExcept = 123, ArrayPopFront = 124, Levenshtein = 125, + ArrayAggregate = 126, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2706,6 +2707,7 @@ impl ScalarFunction { ScalarFunction::ArrayExcept => "ArrayExcept", ScalarFunction::ArrayPopFront => "ArrayPopFront", ScalarFunction::Levenshtein => "Levenshtein", + ScalarFunction::ArrayAggregate => "ArrayAggregate", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2837,6 +2839,7 @@ impl ScalarFunction { "ArrayExcept" => Some(Self::ArrayExcept), "ArrayPopFront" => Some(Self::ArrayPopFront), "Levenshtein" => Some(Self::Levenshtein), + "ArrayAggregate" => Some(Self::ArrayAggregate), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4ae45fa52162..12c5fa0af069 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -462,6 +462,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Ltrim => Self::Ltrim, ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::ToTimestamp => Self::ToTimestamp, + ScalarFunction::ArrayAggregate => Self::ArrayAggregate, ScalarFunction::ArrayAppend => Self::ArrayAppend, ScalarFunction::ArrayConcat => Self::ArrayConcat, ScalarFunction::ArrayEmpty => Self::ArrayEmpty, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index cf66e3ddd5b5..c8d78da8ff7e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1477,6 +1477,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Ltrim => Self::Ltrim, BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp, + BuiltinScalarFunction::ArrayAggregate => Self::ArrayAggregate, BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend, BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat, BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty, diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index c77ef64718bb..10212c50b5ab 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -17,7 +17,8 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, Result, + internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, DataFusionError, + Result, ScalarValue, }; use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; use datafusion_expr::function::suggest_valid_function; @@ -71,7 +72,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, scalar built-in if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { - let args = self.function_args_to_expr(args, schema, planner_context)?; + let args = + self.function_args_to_expr(args, schema, planner_context)?; + + // Translate array_aggregate to aggregate function with array argument. + if fun == BuiltinScalarFunction::ArrayAggregate { + let fun = match &args[1] { + Expr::Literal(ScalarValue::Utf8(Some(name))) => match name.as_str() { + "sum" => AggregateFunction::Sum, + _ => { + return not_impl_err!( + "Aggregate function {name} is not implemented" + ) + } + }, + _ => return internal_err!("Aggregate function name is not a string"), + }; + let args = vec![args[0].to_owned()]; + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( + fun, args, false, None, None, + ))); + } + return Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))); }; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 61f190e7baf6..cdf3b688deff 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -234,10 +234,10 @@ AS VALUES statement ok CREATE TABLE arrays_values_without_nulls AS VALUES - (make_array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 1, 1, ',', [2,3]), - (make_array(11, 12, 13, 14, 15, 16, 17, 18, 19, 20), 12, 2, '.', [4,5]), - (make_array(21, 22, 23, 24, 25, 26, 27, 28, 29, 30), 23, 3, '-', [6,7]), - (make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9]) + (make_array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), 1, 1, ',', [2,3], make_array(1.1, 2.2)), + (make_array(11, 12, 13, 14, 15, 16, 17, 18, 19, 20), 12, 2, '.', [4,5], make_array(3.1, -3.1)), + (make_array(21, 22, 23, 24, 25, 26, 27, 28, 29, 30), 23, 3, '-', [6,7], make_array(3.1, -1.1)), + (make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9], make_array(-1.2, -1.3)) ; statement ok @@ -3077,6 +3077,44 @@ select string_to_list(e, 'm') from values; [adipiscing] NULL +# array aggregate function +## array aggregate + +### sum +statement ok +set datafusion.execution.target_partitions = 1; + +query IRI +select +array_aggregate([1, 3, 5, 7], 'sum'), +array_aggregate([1.1, 2.2, 3.3], 'sum'), +array_aggregate([1, -1, 0, 23], 'sum'); +---- +16 6.6 23 + +# TODO: Support nulls in array. +# query error DataFusion error: This feature is not implemented: Arrays with different types are not supported: \{Null, Int64\} +# select array_aggregate([1, null, 3, null], 'sum'); + +query ?? +select column1, column6 from arrays_values_without_nulls; +---- +[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] [1.1, 2.2] +[11, 12, 13, 14, 15, 16, 17, 18, 19, 20] [3.1, -3.1] +[21, 22, 23, 24, 25, 26, 27, 28, 29, 30] [3.1, -1.1] +[31, 32, 33, 34, 35, 26, 37, 38, 39, 40] [-1.2, -1.3] + +# I need to set target partition to 1, otherwise the we will get NullArray for some of partition executions. +query IR +select array_aggregate(column1, 'sum'), + array_aggregate(column6, 'sum') +from arrays_values_without_nulls; +---- +55 3.3 +155 0 +255 2 +345 -2.5 + ### Delete tables statement ok diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 257c50dfa497..835c9c946f9b 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -209,6 +209,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | Syntax | Description | | ------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| array_aggregate(array, name) | Allows the execution of arbitrary existing aggregate functions on the elements of a list. `array_aggregate([1, 2, 3], 'sum') -> 6` | | array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | | array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | | array_has(array, element) | Returns true if the array contains the element `array_has([1,2,3], 1) -> true` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index eda46ef8a73b..beb21c1775cd 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1518,6 +1518,7 @@ from_unixtime(expression) ## Array Functions +- [array_aggregate](#array_aggregate) - [array_append](#array_append) - [array_cat](#array_cat) - [array_concat](#array_concat) @@ -1578,6 +1579,26 @@ from_unixtime(expression) - [trim_array](#trim_array) - [range](#range) +### `array_aggregate` + +Allows the execution of arbitrary existing aggregate function `name` on the elements of a list. + +``` +array_aggregate(array, name) +``` + +#### Arguments + +- **array**: Array expression. + Can be a constant, column, or function, and any combination of array operators. +- **name**: Aggregate function name. + +#### Aliases + +- list_aggregate +- array_aggr +- list_aggr + ### `array_append` Appends an element to the end of an array. From d3bd12ffe936d92bd2a1b600a430b782a338c3b5 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 18 Nov 2023 21:22:58 +0800 Subject: [PATCH 2/4] cleanup Signed-off-by: jayzhan211 --- datafusion/expr/src/built_in_function.rs | 34 +++++++-------- .../expr/src/type_coercion/aggregates.rs | 4 +- datafusion/physical-expr/src/aggregate/sum.rs | 41 +++++-------------- datafusion/physical-expr/src/functions.rs | 3 ++ .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/sql/src/expr/function.rs | 3 +- 6 files changed, 34 insertions(+), 52 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index f4a0cbfe398f..b04fb9d3a49a 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -886,7 +886,8 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayAggregate + | BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()), BuiltinScalarFunction::ArrayConcat => { @@ -897,24 +898,23 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayElement => Signature::any(2, self.volatility()), BuiltinScalarFunction::ArrayExcept => Signature::any(2, self.volatility()), BuiltinScalarFunction::Flatten => Signature::any(1, self.volatility()), - - BuiltinScalarFunction::ArrayAggregate - | BuiltinScalarFunction::ArrayHasAll + BuiltinScalarFunction::ArrayHasAll | BuiltinScalarFunction::ArrayHasAny - | BuiltinScalarFunction::ArrayHas - | BuiltinScalarFunction::ArrayPositions - | BuiltinScalarFunction::ArrayPrepend - | BuiltinScalarFunction::ArrayRepeat - | BuiltinScalarFunction::ArrayRemove - | BuiltinScalarFunction::ArrayRemoveAll => { - Signature::any(2, self.volatility()) + | BuiltinScalarFunction::ArrayHas => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayLength => { + Signature::variadic_any(self.volatility()) } - - BuiltinScalarFunction::ArrayRemoveN - | BuiltinScalarFunction::ArrayReplace - | BuiltinScalarFunction::ArrayReplaceAll - | BuiltinScalarFunction::ArraySlice => Signature::any(3, self.volatility()), - + BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()), + BuiltinScalarFunction::ArrayPosition => { + Signature::variadic_any(self.volatility()) + } + BuiltinScalarFunction::ArrayPositions => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayPrepend => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRepeat => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRemove => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayRemoveN => Signature::any(3, self.volatility()), + BuiltinScalarFunction::ArrayRemoveAll => Signature::any(2, self.volatility()), + BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), BuiltinScalarFunction::ArrayReplaceAll => { Signature::any(3, self.volatility()) diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 3474069a3d9e..8b0120ac127f 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -119,9 +119,9 @@ pub fn coerce_types( return coerce_types(agg_fun, &[v.as_ref().clone()], signature) } List(field) => { - let coerce_types = + let coerced_types = coerce_types(agg_fun, &[field.data_type().clone()], signature)?; - let data_type = coerce_types[0].clone(); + let data_type = coerced_types[0].clone(); List(Arc::new(Field::new( field.name(), data_type, diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 674b24bd1feb..26cba9a29e16 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -35,7 +35,7 @@ use arrow_array::types::{ use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType}; use arrow_buffer::ArrowNativeType; use datafusion_common::cast::{as_list_array, as_primitive_array}; -use datafusion_common::utils::wrap_into_list_array; +use datafusion_common::utils::array_into_list_array; use datafusion_common::{ internal_err, not_impl_err, DataFusionError, Result, ScalarValue, }; @@ -91,27 +91,6 @@ macro_rules! downcast_sum { } pub(crate) use downcast_sum; -// TODO: Replace with `downcast_sum` after most of the AggregateExpr differentiate `return_data_type` and `data_type` -// The reason we have this is because using the name `return_data_type` makes more much sense to me, -// instead of changing `data_type` to `return_data_type` all the AggregateExpr that have `downcast_sum`, introduce v2 is better. -macro_rules! downcast_sum_v2 { - ($s:ident, $helper:ident) => { - match $s.return_data_type { - DataType::UInt64 => $helper!(UInt64Type, $s.return_data_type), - DataType::Int64 => $helper!(Int64Type, $s.return_data_type), - DataType::Float64 => $helper!(Float64Type, $s.return_data_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.return_data_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.return_data_type), - _ => not_impl_err!( - "Sum not supported for {}: {}", - $s.name, - $s.return_data_type - ), - } - }; -} -pub(crate) use downcast_sum_v2; - impl AggregateExpr for Sum { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -132,27 +111,27 @@ impl AggregateExpr for Sum { match field.data_type() { DataType::Int64 => { return Ok(Box::new(ArraySumAccumulator::::new( - self.return_data_type.clone(), + self.return_type.clone(), ))) } DataType::UInt64 => { return Ok(Box::new(ArraySumAccumulator::::new( - self.return_data_type.clone(), + self.return_type.clone(), ))) } DataType::Float64 => { return Ok(Box::new(ArraySumAccumulator::::new( - self.return_data_type.clone(), + self.return_type.clone(), ))) } DataType::Decimal128(_, _) => { return Ok(Box::new(ArraySumAccumulator::::new( - self.return_data_type.clone(), + self.return_type.clone(), ))) } DataType::Decimal256(_, _) => { return Ok(Box::new(ArraySumAccumulator::::new( - self.return_data_type.clone(), + self.return_type.clone(), ))) } _ => unimplemented!( @@ -168,7 +147,7 @@ impl AggregateExpr for Sum { Ok(Box::new(SumAccumulator::<$t>::new($dt.clone()))) }; } - downcast_sum_v2!(self, helper) + downcast_sum!(self, helper) } fn state_fields(&self) -> Result> { @@ -200,7 +179,7 @@ impl AggregateExpr for Sum { ))) }; } - downcast_sum_v2!(self, helper) + downcast_sum!(self, helper) } fn reverse_expr(&self) -> Option> { @@ -213,7 +192,7 @@ impl AggregateExpr for Sum { Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) }; } - downcast_sum_v2!(self, helper) + downcast_sum!(self, helper) } } @@ -313,7 +292,7 @@ impl Accumulator for ArraySumAccumulator { // Wrap single-row input into multiple-rows input and use the same logic as multiple-rows input let list_values = match as_list_array(&values[0]) { Ok(arr) => arr.to_owned(), - Err(_) => wrap_into_list_array(values[0].clone()), + Err(_) => array_into_list_array(values[0].clone()), }; let row_number = list_values.len(); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5a1a68dd2127..24b953018812 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -326,6 +326,9 @@ pub fn create_physical_fun( } // array functions + BuiltinScalarFunction::ArrayAggregate => { + unimplemented!("ArrayAggregate reused the same function as AggregateExpr") + } BuiltinScalarFunction::ArrayAppend => { Arc::new(|args| make_scalar_function(array_expressions::array_append)(args)) } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 12c5fa0af069..6b590d0a8522 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1335,6 +1335,7 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry)) .collect::, _>>()?, )), + ScalarFunction::ArrayAggregate => unimplemented!("ArrayAggregate"), ScalarFunction::ArrayAppend => Ok(array_append( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 10212c50b5ab..9160652b4152 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -72,8 +72,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, scalar built-in if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { - let args = - self.function_args_to_expr(args, schema, planner_context)?; + let args = self.function_args_to_expr(args, schema, planner_context)?; // Translate array_aggregate to aggregate function with array argument. if fun == BuiltinScalarFunction::ArrayAggregate { From 24daa144b3490ccb039109784255166b29349070 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 19 Nov 2023 16:23:23 +0800 Subject: [PATCH 3/4] cleanup Signed-off-by: jayzhan211 --- datafusion/physical-expr/src/aggregate/sum.rs | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 26cba9a29e16..db32e8f6b227 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -34,7 +34,7 @@ use arrow_array::types::{ }; use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType}; use arrow_buffer::ArrowNativeType; -use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::cast::as_primitive_array; use datafusion_common::utils::array_into_list_array; use datafusion_common::{ internal_err, not_impl_err, DataFusionError, Result, ScalarValue, @@ -106,7 +106,6 @@ impl AggregateExpr for Sum { } fn create_accumulator(&self) -> Result> { - // TODO: Rewrite `downcast_sum` that accepts `DataType` instead of `self` to extend support to `List` if let DataType::List(field) = &self.data_type { match field.data_type() { DataType::Int64 => { @@ -134,11 +133,13 @@ impl AggregateExpr for Sum { self.return_type.clone(), ))) } - _ => unimplemented!( - "Sum not supported for {}: {}", - self.name, - self.data_type - ), + _ => { + return internal_err!( + "Sum not supported for {}: {}", + self.name, + self.data_type + ) + } } } @@ -290,9 +291,10 @@ impl Accumulator for ArraySumAccumulator { // ListArray is for multiple-rows input, and PrimitiveArray is for single-row input fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { // Wrap single-row input into multiple-rows input and use the same logic as multiple-rows input - let list_values = match as_list_array(&values[0]) { - Ok(arr) => arr.to_owned(), - Err(_) => array_into_list_array(values[0].clone()), + let list_values = if let Some(list_arr) = values[0].as_list_opt::() { + list_arr.to_owned() + } else { + array_into_list_array(values[0].clone()) }; let row_number = list_values.len(); From a7eeee55356f68f3ca4d9a8acc13ef60e8f1bd4b Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sun, 26 Nov 2023 20:13:58 +0800 Subject: [PATCH 4/4] bench Signed-off-by: jayzhan211 --- datafusion/core/Cargo.toml | 4 + datafusion/core/benches/aggregate_sum.rs | 126 +++++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 datafusion/core/benches/aggregate_sum.rs diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 0b7aa1509820..f561303cb53d 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -167,3 +167,7 @@ name = "sort" [[bench]] harness = false name = "topk_aggregate" + +[[bench]] +harness = false +name = "aggregate_sum" diff --git a/datafusion/core/benches/aggregate_sum.rs b/datafusion/core/benches/aggregate_sum.rs new file mode 100644 index 000000000000..6f3cbfbba276 --- /dev/null +++ b/datafusion/core/benches/aggregate_sum.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[macro_use] +extern crate criterion; +extern crate arrow; +extern crate datafusion; + +mod data_utils; +use crate::criterion::Criterion; +use arrow_array::{RecordBatch, ArrayRef, Float64Array, PrimitiveArray}; +use arrow_schema::{Schema, Field}; +use data_utils::create_table_provider; +use datafusion::error::Result; +use datafusion::execution::context::SessionContext; +use datafusion_common::ScalarValue; +use datafusion_expr::AggregateFunction; +use datafusion_expr::type_coercion::aggregates::coerce_types; +use datafusion_physical_expr::{expressions::{create_aggregate_expr, try_cast, col}, AggregateExpr}; +use parking_lot::Mutex; +use std::sync::Arc; +use tokio::runtime::Runtime; + +pub fn aggregate( + batch: &RecordBatch, + agg: Arc, +) -> Result { + let mut accum = agg.create_accumulator()?; + let expr = agg.expressions(); + let values = expr + .iter() + .map(|e| { + e.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) + }) + .collect::>>()?; + accum.update_batch(&values)?; + accum.evaluate() +} + +pub fn assert_aggregate( + array: ArrayRef, + function: AggregateFunction, + distinct: bool, + expected: ScalarValue, +) { + let data_type = array.data_type(); + let sig = function.signature(); + let coerced = coerce_types(&function, &[data_type.clone()], &sig).unwrap(); + + let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); + let batch = + RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); + + let input = try_cast( + col("a", &input_schema).unwrap(), + &input_schema, + coerced[0].clone(), + ) + .unwrap(); + + let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); + let agg = + create_aggregate_expr(&function, distinct, &[input], &[], &schema, "agg") + .unwrap(); + + let result = aggregate(&batch, agg).unwrap(); + assert_eq!(expected, result); +} + +fn query(ctx: Arc>, sql: &str) { + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); + criterion::black_box(rt.block_on(df.collect()).unwrap()); +} + +fn create_context( + partitions_len: usize, + array_len: usize, + batch_size: usize, +) -> Result>> { + let ctx = SessionContext::new(); + let provider = create_table_provider(partitions_len, array_len, batch_size)?; + ctx.register_table("t", provider)?; + Ok(Arc::new(Mutex::new(ctx))) +} + +fn criterion_benchmark(c: &mut Criterion) { + // let partitions_len = 8; + // let array_len = 32768 * 2; // 2^16 + // let batch_size = 2048; // 2^11 + // let ctx = create_context(partitions_len, array_len, batch_size).unwrap(); + + let n = 1000000000; + let vec_of_f64: Vec = (0..=n as usize).map(|x| 1 as f64).collect(); + let a: ArrayRef = Arc::new(Float64Array::from(vec_of_f64)); + // c.bench_function("sum 1e9", |b| b.iter(|| assert_aggregate(a.clone(), AggregateFunction::Sum, false, criterion::black_box(ScalarValue::List(Arc::new(Float64Array::from(vec![1000000001_f64]))))))); + c.bench_function("sum 1e9", |b| b.iter(|| assert_aggregate(a.clone(), AggregateFunction::Sum, false, criterion::black_box(ScalarValue::from(1000000001_f64))))); + + // c.bench_function("aggregate_query_no_group_by 15 12", |b| { + // b.iter(|| { + // query( + // ctx.clone(), + // "SELECT MIN(f64), AVG(f64), COUNT(f64) \ + // FROM t", + // ) + // }) + // }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches);