diff --git a/src/avg.rs b/src/avg.rs deleted file mode 100644 index 816440ac9ade..000000000000 --- a/src/avg.rs +++ /dev/null @@ -1,341 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::compute::sum; -use arrow_array::{ - builder::PrimitiveBuilder, - cast::AsArray, - types::{Float64Type, Int64Type}, - Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray, -}; -use arrow_schema::{DataType, Field}; -use datafusion::logical_expr::{ - type_coercion::aggregates::avg_return_type, Accumulator, EmitTo, GroupsAccumulator, Signature, -}; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_physical_expr::expressions::format_state_name; -use std::{any::Any, sync::Arc}; - -use arrow_array::ArrowNativeTypeOp; -use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::Volatility::Immutable; -use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; -use DataType::*; - -/// AVG aggregate expression -#[derive(Debug, Clone)] -pub struct Avg { - name: String, - signature: Signature, - // expr: Arc, - input_data_type: DataType, - result_data_type: DataType, -} - -impl Avg { - /// Create a new AVG aggregate function - pub fn new(name: impl Into, data_type: DataType) -> Self { - let result_data_type = avg_return_type("avg", &data_type).unwrap(); - - Self { - name: name.into(), - signature: Signature::user_defined(Immutable), - input_data_type: data_type, - result_data_type, - } - } -} - -impl AggregateUDFImpl for Avg { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { - // instantiate specialized accumulator based for the type - match (&self.input_data_type, &self.result_data_type) { - (Float64, Float64) => Ok(Box::::default()), - _ => not_impl_err!( - "AvgAccumulator for ({} --> {})", - self.input_data_type, - self.result_data_type - ), - } - } - - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "sum"), - self.input_data_type.clone(), - true, - ), - Field::new( - format_state_name(&self.name, "count"), - DataType::Int64, - true, - ), - ]) - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> ReversedUDAF { - ReversedUDAF::Identical - } - - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - true - } - - fn create_groups_accumulator( - &self, - _args: AccumulatorArgs, - ) -> Result> { - // instantiate specialized accumulator based for the type - match (&self.input_data_type, &self.result_data_type) { - (Float64, Float64) => Ok(Box::new(AvgGroupsAccumulator::::new( - &self.input_data_type, - |sum: f64, count: i64| Ok(sum / count as f64), - ))), 
- - _ => not_impl_err!( - "AvgGroupsAccumulator for ({} --> {})", - self.input_data_type, - self.result_data_type - ), - } - } - - fn default_value(&self, _data_type: &DataType) -> Result { - Ok(ScalarValue::Float64(None)) - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - avg_return_type(self.name(), &arg_types[0]) - } -} - -/// An accumulator to compute the average -#[derive(Debug, Default)] -pub struct AvgAccumulator { - sum: Option, - count: i64, -} - -impl Accumulator for AvgAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::Float64(self.sum), - ScalarValue::from(self.count), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - self.count += (values.len() - values.null_count()) as i64; - let v = self.sum.get_or_insert(0.); - if let Some(x) = sum(values) { - *v += x; - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // counts are summed - self.count += sum(states[1].as_primitive::()).unwrap_or_default(); - - // sums are summed - if let Some(x) = sum(states[0].as_primitive::()) { - let v = self.sum.get_or_insert(0.); - *v += x; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - if self.count == 0 { - // If all input are nulls, count will be 0 and we will get null after the division. - // This is consistent with Spark Average implementation. - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64( - self.sum.map(|f| f / self.count as f64), - )) - } - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// An accumulator to compute the average of `[PrimitiveArray]`. -/// Stores values as native types, and does overflow checking -/// -/// F: Function that calculates the average value from a sum of -/// T::Native and a total count -#[derive(Debug)] -struct AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send, -{ - /// The type of the returned average - return_data_type: DataType, - - /// Count per group (use i64 to make Int64Array) - counts: Vec, - - /// Sums per group, stored as the native type - sums: Vec, - - /// Function that computes the final average (value / count) - avg_fn: F, -} - -impl AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send, -{ - pub fn new(return_data_type: &DataType, avg_fn: F) -> Self { - Self { - return_data_type: return_data_type.clone(), - counts: vec![], - sums: vec![], - avg_fn, - } - } -} - -impl GroupsAccumulator for AvgGroupsAccumulator -where - T: ArrowNumericType + Send, - F: Fn(T::Native, i64) -> Result + Send, -{ - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - _opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = values[0].as_primitive::(); - let data = values.values(); - - // increment counts, update sums - self.counts.resize(total_num_groups, 0); - self.sums.resize(total_num_groups, T::default_value()); - - let iter = group_indices.iter().zip(data.iter()); - if values.null_count() == 0 { - for (&group_index, &value) in iter { - let sum = &mut self.sums[group_index]; - *sum = (*sum).add_wrapping(value); - self.counts[group_index] += 1; - } - } else { - for (idx, (&group_index, &value)) in iter.enumerate() { - if values.is_null(idx) { - continue; - } - let sum = 
&mut self.sums[group_index]; - *sum = (*sum).add_wrapping(value); - - self.counts[group_index] += 1; - } - } - - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - _opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is partial sums, second is counts - let partial_sums = values[0].as_primitive::(); - let partial_counts = values[1].as_primitive::(); - // update counts with partial counts - self.counts.resize(total_num_groups, 0); - let iter1 = group_indices.iter().zip(partial_counts.values().iter()); - for (&group_index, &partial_count) in iter1 { - self.counts[group_index] += partial_count; - } - - // update sums - self.sums.resize(total_num_groups, T::default_value()); - let iter2 = group_indices.iter().zip(partial_sums.values().iter()); - for (&group_index, &new_value) in iter2 { - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(new_value); - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - let sums = emit_to.take_needed(&mut self.sums); - let mut builder = PrimitiveBuilder::::with_capacity(sums.len()); - let iter = sums.into_iter().zip(counts); - - for (sum, count) in iter { - if count != 0 { - builder.append_value((self.avg_fn)(sum, count)?) - } else { - builder.append_null(); - } - } - let array: PrimitiveArray = builder.finish(); - - Ok(Arc::new(array)) - } - - // return arrays for sums and counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let counts = emit_to.take_needed(&mut self.counts); - let counts = Int64Array::new(counts.into(), None); - - let sums = emit_to.take_needed(&mut self.sums); - let sums = PrimitiveArray::::new(sums.into(), None) - .with_data_type(self.return_data_type.clone()); - - Ok(vec![ - Arc::new(sums) as ArrayRef, - Arc::new(counts) as ArrayRef, - ]) - } - - fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.sums.capacity() * std::mem::size_of::() - } -} diff --git a/src/avg_decimal.rs b/src/avg_decimal.rs deleted file mode 100644 index 05fc28e58341..000000000000 --- a/src/avg_decimal.rs +++ /dev/null @@ -1,522 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
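The `avg_decimal.rs` code below accumulates a Decimal128 sum and an i64 count per group, then divides using Spark's HALF_UP rounding in the `avg` helper near the end of the file. To make that rounding concrete, here is a minimal, std-only sketch; the `avg_half_up` name and the simplified bounds handling are illustrative and not part of the deleted code, which additionally guards precision via `is_valid_decimal_precision`.

```rust
/// Sketch of Spark-style decimal AVG: rescale the sum, divide by the count,
/// and round halves away from zero (HALF_UP). Returns None when the rescale
/// overflows or the result leaves the target precision's value range.
fn avg_half_up(sum: i128, count: i128, target_min: i128, target_max: i128, scaler: i128) -> Option<i128> {
    let value = sum.checked_mul(scaler)?;
    // i128 `/` truncates toward zero and `%` keeps the dividend's sign,
    // matching the div_rem used by the real implementation
    let (div, rem) = (value / count, value % count);
    let half = (count + 1) / 2; // ceil(count / 2); count is positive here
    let rounded = match value >= 0 {
        true if rem >= half => div + 1,
        false if rem <= -half => div - 1,
        _ => div,
    };
    (target_min..=target_max).contains(&rounded).then_some(rounded)
}

fn main() {
    // avg of [1.00, 2.00, 2.00] at scale 2: sum = 500, count = 3
    // 500 / 3 = 166.67 -> rounds HALF_UP to 167, i.e. 1.67
    assert_eq!(avg_half_up(500, 3, i128::MIN, i128::MAX, 1), Some(167));
}
```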
- -use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer, compute::sum}; -use arrow_array::{ - builder::PrimitiveBuilder, - cast::AsArray, - types::{Decimal128Type, Int64Type}, - Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray, -}; -use arrow_schema::{DataType, Field}; -use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator, Signature}; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_physical_expr::expressions::format_state_name; -use std::{any::Any, sync::Arc}; - -use crate::utils::is_valid_decimal_precision; -use arrow_array::ArrowNativeTypeOp; -use arrow_data::decimal::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION}; -use datafusion::logical_expr::Volatility::Immutable; -use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::avg_return_type; -use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; -use num::{integer::div_ceil, Integer}; -use DataType::*; - -/// AVG aggregate expression -#[derive(Debug, Clone)] -pub struct AvgDecimal { - signature: Signature, - sum_data_type: DataType, - result_data_type: DataType, -} - -impl AvgDecimal { - /// Create a new AVG aggregate function - pub fn new(result_type: DataType, sum_type: DataType) -> Self { - Self { - signature: Signature::user_defined(Immutable), - result_data_type: result_type, - sum_data_type: sum_type, - } - } -} - -impl AggregateUDFImpl for AvgDecimal { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { - match (&self.sum_data_type, &self.result_data_type) { - (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => { - Ok(Box::new(AvgDecimalAccumulator::new( - *sum_scale, - *sum_precision, - *target_precision, - *target_scale, - ))) - } - _ => not_impl_err!( - "AvgDecimalAccumulator for ({} --> {})", - self.sum_data_type, - self.result_data_type - ), - } - } - - fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> { - Ok(vec![ - Field::new( - format_state_name(self.name(), "sum"), - self.sum_data_type.clone(), - true, - ), - Field::new( - format_state_name(self.name(), "count"), - DataType::Int64, - true, - ), - ]) - } - - fn name(&self) -> &str { - "avg" - } - - fn reverse_expr(&self) -> ReversedUDAF { - ReversedUDAF::Identical - } - - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - true - } - - fn create_groups_accumulator( - &self, - _args: AccumulatorArgs, - ) -> Result<Box<dyn GroupsAccumulator>> { - // instantiate a specialized accumulator based on the type - match (&self.sum_data_type, &self.result_data_type) { - (Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale)) => { - Ok(Box::new(AvgDecimalGroupsAccumulator::new( - &self.result_data_type, - &self.sum_data_type, - *target_precision, - *target_scale, - *sum_precision, - *sum_scale, - ))) - } - _ => not_impl_err!( - "AvgDecimalGroupsAccumulator for ({} --> {})", - self.sum_data_type, - self.result_data_type - ), - } - } - - fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> { - match &self.result_data_type { - Decimal128(target_precision, target_scale) => { - Ok(make_decimal128(None, *target_precision, *target_scale)) - } - _ => not_impl_err!( - "The result_data_type of AvgDecimal should be Decimal128 but got {}", - self.result_data_type - ), - } - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { - avg_return_type(self.name(), &arg_types[0]) - } -} - -/// An accumulator to compute the average for decimals -#[derive(Debug)] -struct AvgDecimalAccumulator { - sum: Option<i128>, - count: i64, - is_empty: bool, - is_not_null: bool, - sum_scale: i8, - sum_precision: u8, - target_precision: u8, - target_scale: i8, -} - -impl AvgDecimalAccumulator { - pub fn new(sum_scale: i8, sum_precision: u8, target_precision: u8, target_scale: i8) -> Self { - Self { - sum: None, - count: 0, - is_empty: true, - is_not_null: true, - sum_scale, - sum_precision, - target_precision, - target_scale, - } - } - - fn update_single(&mut self, values: &Decimal128Array, idx: usize) { - let v = unsafe { values.value_unchecked(idx) }; - let (new_sum, is_overflow) = match self.sum { - Some(sum) => sum.overflowing_add(v), - None => (v, false), - }; - - if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) { - // Overflow: set buffer accumulator to null - self.is_not_null = false; - return; - } - - self.sum = Some(new_sum); - - if let Some(new_count) = self.count.checked_add(1) { - self.count = new_count; - } else { - self.is_not_null = false; - return; - } - - self.is_not_null = true; - } -} - -fn make_decimal128(value: Option<i128>, precision: u8, scale: i8) -> ScalarValue { - ScalarValue::Decimal128(value, precision, scale) -} - -impl Accumulator for AvgDecimalAccumulator { - fn state(&mut self) -> Result<Vec<ScalarValue>> { - Ok(vec![ - ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale), - ScalarValue::from(self.count), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if !self.is_empty && !self.is_not_null { - // This means there's an overflow in decimal, so we will just skip the rest - // of the computation - return Ok(()); - } - - let values = &values[0]; - let data = values.as_primitive::<Decimal128Type>(); - - self.is_empty = self.is_empty && values.len() == values.null_count(); - - if values.null_count() == 0 { - for i in 0..data.len() { - self.update_single(data, i); - } - } else { - for i in 0..data.len() { - if data.is_null(i) { - continue; - } - self.update_single(data, i); - } - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // counts are summed - self.count += sum(states[1].as_primitive::<Int64Type>()).unwrap_or_default(); - - // sums are summed - if let Some(x) = sum(states[0].as_primitive::<Decimal128Type>()) { - let v = self.sum.get_or_insert(0); - let (result, overflowed) = v.overflowing_add(x); - if overflowed { - // Set to None if overflow happens - self.sum = None; - } else { - *v = result; - } - } - Ok(()) - } - - fn evaluate(&mut self) -> Result<ScalarValue> { - let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32); - let target_min = MIN_DECIMAL_FOR_EACH_PRECISION[self.target_precision as usize - 1]; - let target_max = MAX_DECIMAL_FOR_EACH_PRECISION[self.target_precision as usize - 1]; - - let result = self - .sum - .map(|v| avg(v, self.count as i128, target_min, target_max, scaler)); - - match result { - Some(value) => Ok(make_decimal128( - value, - self.target_precision, - self.target_scale, - )), - _ => Ok(make_decimal128( - None, - self.target_precision, - self.target_scale, - )), - } - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -#[derive(Debug)] -struct AvgDecimalGroupsAccumulator { - /// Tracks if the value is null - is_not_null: BooleanBufferBuilder, - - // Tracks if the value is empty - is_empty: BooleanBufferBuilder, - - /// The type of the avg return type -
return_data_type: DataType, - target_precision: u8, - target_scale: i8, - - /// Count per group (use i64 to make Int64Array) - counts: Vec, - - /// Sums per group, stored as i128 - sums: Vec, - - /// The type of the sum - sum_data_type: DataType, - /// This is input_precision + 10 to be consistent with Spark - sum_precision: u8, - sum_scale: i8, -} - -impl AvgDecimalGroupsAccumulator { - pub fn new( - return_data_type: &DataType, - sum_data_type: &DataType, - target_precision: u8, - target_scale: i8, - sum_precision: u8, - sum_scale: i8, - ) -> Self { - Self { - is_not_null: BooleanBufferBuilder::new(0), - is_empty: BooleanBufferBuilder::new(0), - return_data_type: return_data_type.clone(), - target_precision, - target_scale, - sum_data_type: sum_data_type.clone(), - sum_precision, - sum_scale, - counts: vec![], - sums: vec![], - } - } - - fn is_overflow(&self, index: usize) -> bool { - !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index) - } - - fn update_single(&mut self, group_index: usize, value: i128) { - if self.is_overflow(group_index) { - // This means there's a overflow in decimal, so we will just skip the rest - // of the computation - return; - } - - self.is_empty.set_bit(group_index, false); - let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value); - self.counts[group_index] += 1; - - if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) { - // Overflow: set buffer accumulator to null - self.is_not_null.set_bit(group_index, false); - return; - } - - self.sums[group_index] = new_sum; - self.is_not_null.set_bit(group_index, true) - } -} - -fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) { - if builder.len() < capacity { - let additional = capacity - builder.len(); - builder.append_n(additional, true); - } -} - -impl GroupsAccumulator for AvgDecimalGroupsAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - _opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = values[0].as_primitive::(); - let data = values.values(); - - // increment counts, update sums - self.counts.resize(total_num_groups, 0); - self.sums.resize(total_num_groups, 0); - ensure_bit_capacity(&mut self.is_empty, total_num_groups); - ensure_bit_capacity(&mut self.is_not_null, total_num_groups); - - let iter = group_indices.iter().zip(data.iter()); - if values.null_count() == 0 { - for (&group_index, &value) in iter { - self.update_single(group_index, value); - } - } else { - for (idx, (&group_index, &value)) in iter.enumerate() { - if values.is_null(idx) { - continue; - } - self.update_single(group_index, value); - } - } - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - _opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is partial sums, second is counts - let partial_sums = values[0].as_primitive::(); - let partial_counts = values[1].as_primitive::(); - // update counts with partial counts - self.counts.resize(total_num_groups, 0); - let iter1 = group_indices.iter().zip(partial_counts.values().iter()); - for (&group_index, &partial_count) in iter1 { - self.counts[group_index] += partial_count; - } - - // update sums - self.sums.resize(total_num_groups, 0); - let iter2 = 
group_indices.iter().zip(partial_sums.values().iter()); - for (&group_index, &new_value) in iter2 { - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(new_value); - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - let sums = emit_to.take_needed(&mut self.sums); - - let mut builder = PrimitiveBuilder::::with_capacity(sums.len()) - .with_data_type(self.return_data_type.clone()); - let iter = sums.into_iter().zip(counts); - - let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32); - let target_min = MIN_DECIMAL_FOR_EACH_PRECISION[self.target_precision as usize - 1]; - let target_max = MAX_DECIMAL_FOR_EACH_PRECISION[self.target_precision as usize - 1]; - - for (sum, count) in iter { - if count != 0 { - match avg(sum, count as i128, target_min, target_max, scaler) { - Some(value) => { - builder.append_value(value); - } - _ => { - builder.append_null(); - } - } - } else { - builder.append_null(); - } - } - let array: PrimitiveArray = builder.finish(); - - Ok(Arc::new(array)) - } - - // return arrays for sums and counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let nulls = self.is_not_null.finish(); - let nulls = Some(NullBuffer::new(nulls)); - - let counts = emit_to.take_needed(&mut self.counts); - let counts = Int64Array::new(counts.into(), nulls.clone()); - - let sums = emit_to.take_needed(&mut self.sums); - let sums = - Decimal128Array::new(sums.into(), nulls).with_data_type(self.sum_data_type.clone()); - - Ok(vec![ - Arc::new(sums) as ArrayRef, - Arc::new(counts) as ArrayRef, - ]) - } - - fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.sums.capacity() * std::mem::size_of::() - } -} - -/// Returns the `sum`/`count` as a i128 Decimal128 with -/// target_scale and target_precision and return None if overflows. -/// -/// * sum: The total sum value stored as Decimal128 with sum_scale -/// * count: total count, stored as a i128 (*NOT* a Decimal128 value) -/// * target_min: The minimum output value possible to represent with the target precision -/// * target_max: The maximum output value possible to represent with the target precision -/// * scaler: scale factor for avg -#[inline(always)] -fn avg(sum: i128, count: i128, target_min: i128, target_max: i128, scaler: i128) -> Option { - if let Some(value) = sum.checked_mul(scaler) { - // `sum / count` with ROUND_HALF_UP - let (div, rem) = value.div_rem(&count); - let half = div_ceil(count, 2); - let half_neg = half.neg_wrapping(); - let new_value = match value >= 0 { - true if rem >= half => div.add_wrapping(1), - false if rem <= half_neg => div.sub_wrapping(1), - _ => div, - }; - if new_value >= target_min && new_value <= target_max { - Some(new_value) - } else { - None - } - } else { - None - } -} diff --git a/src/covariance.rs b/src/covariance.rs deleted file mode 100644 index fa3563cdea55..000000000000 --- a/src/covariance.rs +++ /dev/null @@ -1,306 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::any::Any; - -use arrow::{ - array::{ArrayRef, Float64Array}, - compute::cast, - datatypes::{DataType, Field}, -}; -use datafusion::logical_expr::Accumulator; -use datafusion_common::{ - downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, -}; -use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::type_coercion::aggregates::NUMERICS; -use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; -use datafusion_physical_expr::expressions::format_state_name; -use datafusion_physical_expr::expressions::StatsType; - -/// COVAR_SAMP and COVAR_POP aggregate expression -/// The implementation mostly is the same as the DataFusion's implementation. The reason -/// we have our own implementation is that DataFusion has UInt64 for state_field count, -/// while Spark has Double for count. -#[derive(Debug, Clone)] -pub struct Covariance { - name: String, - signature: Signature, - stats_type: StatsType, - null_on_divide_by_zero: bool, -} - -impl Covariance { - /// Create a new COVAR aggregate function - pub fn new( - name: impl Into, - data_type: DataType, - stats_type: StatsType, - null_on_divide_by_zero: bool, - ) -> Self { - // the result of covariance just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), - stats_type, - null_on_divide_by_zero, - } - } -} - -impl AggregateUDFImpl for Covariance { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - &self.name - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) - } - fn default_value(&self, _data_type: &DataType) -> Result { - Ok(ScalarValue::Float64(None)) - } - - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { - Ok(Box::new(CovarianceAccumulator::try_new( - self.stats_type, - self.null_on_divide_by_zero, - )?)) - } - - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "mean1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "mean2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "algo_const"), - DataType::Float64, - true, - ), - ]) - } -} - -/// An accumulator to compute covariance -#[derive(Debug)] -pub struct CovarianceAccumulator { - algo_const: f64, - mean1: f64, - mean2: f64, - count: f64, - stats_type: StatsType, - null_on_divide_by_zero: bool, -} - -impl CovarianceAccumulator { - /// Creates a new `CovarianceAccumulator` - pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result { - Ok(Self { - algo_const: 0_f64, - mean1: 0_f64, - mean2: 0_f64, - count: 0_f64, - stats_type: s_type, - null_on_divide_by_zero, - }) - } - - pub fn get_count(&self) -> f64 { - self.count - } - - pub fn 
get_mean1(&self) -> f64 { - self.mean1 - } - - pub fn get_mean2(&self) -> f64 { - self.mean2 - } - - pub fn get_algo_const(&self) -> f64 { - self.algo_const - } -} - -impl Accumulator for CovarianceAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::from(self.mean1), - ScalarValue::from(self.mean2), - ScalarValue::from(self.algo_const), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); - let new_count = self.count + 1.0; - let delta1 = value1 - self.mean1; - let new_mean1 = delta1 / new_count + self.mean1; - let delta2 = value2 - self.mean2; - let new_mean2 = delta2 / new_count + self.mean2; - let new_c = delta1 * (value2 - new_mean2) + self.algo_const; - - self.count += 1.0; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); - - let new_count = self.count - 1.0; - let delta1 = self.mean1 - value1; - let new_mean1 = delta1 / new_count + self.mean1; - let delta2 = self.mean2 - value2; - let new_mean2 = delta2 / new_count + self.mean2; - let new_c = self.algo_const - delta1 * (new_mean2 - value2); - - self.count -= 1.0; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], Float64Array); - let means1 = downcast_value!(states[1], Float64Array); - let means2 = downcast_value!(states[2], Float64Array); - let cs = downcast_value!(states[3], Float64Array); - - for i in 0..counts.len() { - let c = counts.value(i); - if c == 0.0 { - continue; - } - let new_count = self.count + c; - let new_mean1 = self.mean1 * self.count / new_count + means1.value(i) * c / new_count; - let new_mean2 = self.mean2 * self.count / new_count + means2.value(i) * c / new_count; - let delta1 = self.mean1 - means1.value(i); - let delta2 = self.mean2 - means2.value(i); - let new_c = - self.algo_const + cs.value(i) + delta1 * delta2 * self.count * c / new_count; - - self.count = new_count; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - if self.count == 0.0 { - return Ok(ScalarValue::Float64(None)); - } 
- - let count = match self.stats_type { - StatsType::Population => self.count, - StatsType::Sample if self.count > 1.0 => self.count - 1.0, - StatsType::Sample => { - // self.count == 1.0 - return if self.null_on_divide_by_zero { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(Some(f64::NAN))) - }; - } - }; - - Ok(ScalarValue::Float64(Some(self.algo_const / count))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} diff --git a/src/strings.rs b/src/strings.rs deleted file mode 100644 index c2706b589652..000000000000 --- a/src/strings.rs +++ /dev/null @@ -1,290 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#![allow(deprecated)] - -use crate::kernels::strings::{string_space, substring}; -use arrow::{ - compute::{ - contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, like_dyn, - like_utf8_scalar_dyn, starts_with_dyn, starts_with_utf8_scalar_dyn, - }, - record_batch::RecordBatch, -}; -use arrow_schema::{DataType, Schema}; -use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{DataFusionError, ScalarValue::Utf8}; -use datafusion_physical_expr::PhysicalExpr; -use std::{ - any::Any, - fmt::{Display, Formatter}, - hash::Hash, - sync::Arc, -}; - -macro_rules! 
make_predicate_function { - ($name: ident, $kernel: ident, $str_scalar_kernel: ident) => { - #[derive(Debug, Eq)] - pub struct $name { - left: Arc<dyn PhysicalExpr>, - right: Arc<dyn PhysicalExpr>, - } - - impl $name { - pub fn new(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> Self { - Self { left, right } - } - } - - impl Display for $name { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{} [left: {}, right: {}]", stringify!($name), self.left, self.right) - } - } - - impl Hash for $name { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.left.hash(state); - self.right.hash(state); - } - } - - impl PartialEq for $name { - fn eq(&self, other: &Self) -> bool { - self.left.eq(&other.left) && self.right.eq(&other.right) - } - } - - impl PhysicalExpr for $name { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, _: &Schema) -> datafusion_common::Result<DataType> { - Ok(DataType::Boolean) - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> { - let left_arg = self.left.evaluate(batch)?; - let right_arg = self.right.evaluate(batch)?; - - let array = match (left_arg, right_arg) { - // array (op) scalar - (ColumnarValue::Array(array), ColumnarValue::Scalar(Utf8(Some(string)))) => { - $str_scalar_kernel(&array, string.as_str()) - } - (ColumnarValue::Array(_), ColumnarValue::Scalar(other)) => { - return Err(DataFusionError::Execution(format!( - "Should be String but got: {:?}", - other - ))) - } - // array (op) array - (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => { - $kernel(&array1, &array2) - } - // scalar (op) scalar should be folded at Spark optimizer - _ => { - return Err(DataFusionError::Execution( - "Predicate on two literals should be folded at Spark".to_string(), - )) - } - }?; - - Ok(ColumnarValue::Array(Arc::new(array))) - } - - fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { - vec![&self.left, &self.right] - } - - fn with_new_children( - self: Arc<Self>, - children: Vec<Arc<dyn PhysicalExpr>>, - ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> { - Ok(Arc::new($name::new( - children[0].clone(), - children[1].clone(), - ))) - } - } - }; -} - -make_predicate_function!(Like, like_dyn, like_utf8_scalar_dyn); - -make_predicate_function!(StartsWith, starts_with_dyn, starts_with_utf8_scalar_dyn); - -make_predicate_function!(EndsWith, ends_with_dyn, ends_with_utf8_scalar_dyn); - -make_predicate_function!(Contains, contains_dyn, contains_utf8_scalar_dyn); - -#[derive(Debug, Eq)] -pub struct SubstringExpr { - pub child: Arc<dyn PhysicalExpr>, - pub start: i64, - pub len: u64, -} - -impl Hash for SubstringExpr { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.child.hash(state); - self.start.hash(state); - self.len.hash(state); - } -} - -impl PartialEq for SubstringExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.start.eq(&other.start) && self.len.eq(&other.len) - } -} -#[derive(Debug, Eq)] -pub struct StringSpaceExpr { - pub child: Arc<dyn PhysicalExpr>, -} - -impl Hash for StringSpaceExpr { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.child.hash(state); - } -} - -impl PartialEq for StringSpaceExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) - } -} - -impl SubstringExpr { - pub fn new(child: Arc<dyn PhysicalExpr>, start: i64, len: u64) -> Self { - Self { child, start, len } - } -} - -impl StringSpaceExpr { - pub fn new(child: Arc<dyn PhysicalExpr>) -> Self { - Self { child } - } -} - -impl Display for SubstringExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Substring [start: {}, len: {}, child: {}]", - self.start, self.len, self.child - ) - } -} - -impl Display for StringSpaceExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "StringSpace [child: {}] ", self.child) - } -} - -impl PhysicalExpr for SubstringExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> { - self.child.data_type(input_schema) - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let result = substring(&array, self.start, self.len)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Substring(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { - vec![&self.child] - } - - fn with_new_children( - self: Arc<Self>, - children: Vec<Arc<dyn PhysicalExpr>>, - ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> { - Ok(Arc::new(SubstringExpr::new( - Arc::clone(&children[0]), - self.start, - self.len, - ))) - } -} - -impl PhysicalExpr for StringSpaceExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> { - match self.child.data_type(input_schema)? { - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Utf8))) - } - _ => Ok(DataType::Utf8), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let result = string_space(&array)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "StringSpace(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { - vec![&self.child] - } - - fn with_new_children( - self: Arc<Self>, - children: Vec<Arc<dyn PhysicalExpr>>, - ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> { - Ok(Arc::new(StringSpaceExpr::new(Arc::clone(&children[0])))) - } -} diff --git a/src/sum_decimal.rs b/src/sum_decimal.rs deleted file mode 100644 index f3f34d9bfa9d..000000000000 --- a/src/sum_decimal.rs +++ /dev/null @@ -1,555 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License.
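The `sum_decimal.rs` code below tracks two flags per group, `is_empty` and `is_not_null`, so that (per the comments in the file) an overflow in non-ANSI mode surfaces as a null result rather than an error. Here is a minimal sketch of that rule with an `Option` standing in for the two flags; `max_for_precision` and `add_checked` are illustrative stand-ins for the real `is_valid_decimal_precision` check, not part of the deleted code.

```rust
/// Largest absolute value representable at the given decimal precision,
/// e.g. precision 5 admits [-99_999, 99_999].
fn max_for_precision(precision: u32) -> i128 {
    10_i128.pow(precision) - 1
}

/// Add one value to a running decimal sum; None means "already overflowed,
/// stay null", mirroring Spark's non-ANSI SUM semantics.
fn add_checked(sum: Option<i128>, value: i128, precision: u32) -> Option<i128> {
    let new_sum = sum?.checked_add(value)?;
    let bound = max_for_precision(precision);
    (-bound..=bound).contains(&new_sum).then_some(new_sum)
}

fn main() {
    let s = add_checked(Some(99_000), 900, 5);
    assert_eq!(s, Some(99_900));
    // the next step exceeds precision 5, so the sum goes (and stays) null
    assert_eq!(add_checked(s, 200, 5), None);
}
```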
- -use crate::utils::{is_valid_decimal_precision, unlikely}; -use arrow::{ - array::BooleanBufferBuilder, - buffer::{BooleanBuffer, NullBuffer}, -}; -use arrow_array::{ - cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, -}; -use arrow_schema::{DataType, Field}; -use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; -use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue}; -use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; -use datafusion_expr::Volatility::Immutable; -use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature}; -use std::{any::Any, ops::BitAnd, sync::Arc}; - -#[derive(Debug)] -pub struct SumDecimal { - /// Aggregate function signature - signature: Signature, - /// The data type of the SUM result. This will always be a decimal type - /// with the same precision and scale as specified in this struct - result_type: DataType, - /// Decimal precision - precision: u8, - /// Decimal scale - scale: i8, -} - -impl SumDecimal { - pub fn try_new(data_type: DataType) -> DFResult { - // The `data_type` is the SUM result type passed from Spark side - let (precision, scale) = match data_type { - DataType::Decimal128(p, s) => (p, s), - _ => { - return Err(DataFusionError::Internal( - "Invalid data type for SumDecimal".into(), - )) - } - }; - Ok(Self { - signature: Signature::user_defined(Immutable), - result_type: data_type, - precision, - scale, - }) - } -} - -impl AggregateUDFImpl for SumDecimal { - fn as_any(&self) -> &dyn Any { - self - } - - fn accumulator(&self, _args: AccumulatorArgs) -> DFResult> { - Ok(Box::new(SumDecimalAccumulator::new( - self.precision, - self.scale, - ))) - } - - fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { - let fields = vec![ - Field::new(self.name(), self.result_type.clone(), self.is_nullable()), - Field::new("is_empty", DataType::Boolean, false), - ]; - Ok(fields) - } - - fn name(&self) -> &str { - "sum" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> DFResult { - Ok(self.result_type.clone()) - } - - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - true - } - - fn create_groups_accumulator( - &self, - _args: AccumulatorArgs, - ) -> DFResult> { - Ok(Box::new(SumDecimalGroupsAccumulator::new( - self.result_type.clone(), - self.precision, - ))) - } - - fn default_value(&self, _data_type: &DataType) -> DFResult { - ScalarValue::new_primitive::( - None, - &DataType::Decimal128(self.precision, self.scale), - ) - } - - fn reverse_expr(&self) -> ReversedUDAF { - ReversedUDAF::Identical - } - - fn is_nullable(&self) -> bool { - // SumDecimal is always nullable because overflows can cause null values - true - } -} - -#[derive(Debug)] -struct SumDecimalAccumulator { - sum: i128, - is_empty: bool, - is_not_null: bool, - - precision: u8, - scale: i8, -} - -impl SumDecimalAccumulator { - fn new(precision: u8, scale: i8) -> Self { - Self { - sum: 0, - is_empty: true, - is_not_null: true, - precision, - scale, - } - } - - fn update_single(&mut self, values: &Decimal128Array, idx: usize) { - let v = unsafe { values.value_unchecked(idx) }; - let (new_sum, is_overflow) = self.sum.overflowing_add(v); - - if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { - // Overflow: set buffer accumulator to null - self.is_not_null = false; - return; - } - - self.sum = new_sum; - self.is_not_null = true; - } -} - -impl Accumulator for SumDecimalAccumulator { - 
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { - assert_eq!( - values.len(), - 1, - "Expect only one element in 'values' but found {}", - values.len() - ); - - if !self.is_empty && !self.is_not_null { - // This means there's a overflow in decimal, so we will just skip the rest - // of the computation - return Ok(()); - } - - let values = &values[0]; - let data = values.as_primitive::(); - - self.is_empty = self.is_empty && values.len() == values.null_count(); - - if values.null_count() == 0 { - for i in 0..data.len() { - self.update_single(data, i); - } - } else { - for i in 0..data.len() { - if data.is_null(i) { - continue; - } - self.update_single(data, i); - } - } - - Ok(()) - } - - fn evaluate(&mut self) -> DFResult { - // For each group: - // 1. if `is_empty` is true, it means either there is no value or all values for the group - // are null, in this case we'll return null - // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In - // non-ANSI mode Spark returns null. - if self.is_empty || !self.is_not_null { - ScalarValue::new_primitive::( - None, - &DataType::Decimal128(self.precision, self.scale), - ) - } else { - ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale) - } - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } - - fn state(&mut self) -> DFResult> { - let sum = if self.is_not_null { - ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)? - } else { - ScalarValue::new_primitive::( - None, - &DataType::Decimal128(self.precision, self.scale), - )? - }; - Ok(vec![sum, ScalarValue::from(self.is_empty)]) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { - assert_eq!( - states.len(), - 2, - "Expect two element in 'states' but found {}", - states.len() - ); - assert_eq!(states[0].len(), 1); - assert_eq!(states[1].len(), 1); - - let that_sum = states[0].as_primitive::(); - let that_is_empty = states[1].as_any().downcast_ref::().unwrap(); - - let this_overflow = !self.is_empty && !self.is_not_null; - let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0); - - self.is_not_null = !this_overflow && !that_overflow; - self.is_empty = self.is_empty && that_is_empty.value(0); - - if self.is_not_null { - self.sum += that_sum.value(0); - } - - Ok(()) - } -} - -struct SumDecimalGroupsAccumulator { - // Whether aggregate buffer for a particular group is null. True indicates it is not null. 
- is_not_null: BooleanBufferBuilder, - is_empty: BooleanBufferBuilder, - sum: Vec, - result_type: DataType, - precision: u8, -} - -impl SumDecimalGroupsAccumulator { - fn new(result_type: DataType, precision: u8) -> Self { - Self { - is_not_null: BooleanBufferBuilder::new(0), - is_empty: BooleanBufferBuilder::new(0), - sum: Vec::new(), - result_type, - precision, - } - } - - fn is_overflow(&self, index: usize) -> bool { - !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index) - } - - fn update_single(&mut self, group_index: usize, value: i128) { - if unlikely(self.is_overflow(group_index)) { - // This means there's a overflow in decimal, so we will just skip the rest - // of the computation - return; - } - - self.is_empty.set_bit(group_index, false); - let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value); - - if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { - // Overflow: set buffer accumulator to null - self.is_not_null.set_bit(group_index, false); - return; - } - - self.sum[group_index] = new_sum; - self.is_not_null.set_bit(group_index, true) - } -} - -fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) { - if builder.len() < capacity { - let additional = capacity - builder.len(); - builder.append_n(additional, true); - } -} - -/// Build a boolean buffer from the state and reset the state, based on the emit_to -/// strategy. -fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer { - let bool_state: BooleanBuffer = state.finish(); - - match emit_to { - EmitTo::All => bool_state, - EmitTo::First(n) => { - // split off the first N values in bool_state - let first_n_bools: BooleanBuffer = bool_state.iter().take(*n).collect(); - // reset the existing seen buffer - for seen in bool_state.iter().skip(*n) { - state.append(seen); - } - first_n_bools - } - } -} - -impl GroupsAccumulator for SumDecimalGroupsAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> DFResult<()> { - assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - assert_eq!(values.len(), 1); - let values = values[0].as_primitive::(); - let data = values.values(); - - // Update size for the accumulate states - self.sum.resize(total_num_groups, 0); - ensure_bit_capacity(&mut self.is_empty, total_num_groups); - ensure_bit_capacity(&mut self.is_not_null, total_num_groups); - - let iter = group_indices.iter().zip(data.iter()); - if values.null_count() == 0 { - for (&group_index, &value) in iter { - self.update_single(group_index, value); - } - } else { - for (idx, (&group_index, &value)) in iter.enumerate() { - if values.is_null(idx) { - continue; - } - self.update_single(group_index, value); - } - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { - // For each group: - // 1. if `is_empty` is true, it means either there is no value or all values for the group - // are null, in this case we'll return null - // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In - // non-ANSI mode Spark returns null. 
- let nulls = build_bool_state(&mut self.is_not_null, &emit_to); - let is_empty = build_bool_state(&mut self.is_empty, &emit_to); - let x = (!&is_empty).bitand(&nulls); - - let result = emit_to.take_needed(&mut self.sum); - let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x))) - .with_data_type(self.result_type.clone()); - - Ok(Arc::new(result)) - } - - fn state(&mut self, emit_to: EmitTo) -> DFResult> { - let nulls = build_bool_state(&mut self.is_not_null, &emit_to); - let nulls = Some(NullBuffer::new(nulls)); - - let sum = emit_to.take_needed(&mut self.sum); - let sum = Decimal128Array::new(sum.into(), nulls.clone()) - .with_data_type(self.result_type.clone()); - - let is_empty = build_bool_state(&mut self.is_empty, &emit_to); - let is_empty = BooleanArray::new(is_empty, None); - - Ok(vec![ - Arc::new(sum) as ArrayRef, - Arc::new(is_empty) as ArrayRef, - ]) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&BooleanArray>, - total_num_groups: usize, - ) -> DFResult<()> { - assert_eq!( - values.len(), - 2, - "Expected two arrays: 'sum' and 'is_empty', but found {}", - values.len() - ); - assert!(opt_filter.is_none(), "opt_filter is not supported yet"); - - // Make sure we have enough capacity for the additional groups - self.sum.resize(total_num_groups, 0); - ensure_bit_capacity(&mut self.is_empty, total_num_groups); - ensure_bit_capacity(&mut self.is_not_null, total_num_groups); - - let that_sum = &values[0]; - let that_sum = that_sum.as_primitive::(); - let that_is_empty = &values[1]; - let that_is_empty = that_is_empty - .as_any() - .downcast_ref::() - .unwrap(); - - group_indices - .iter() - .enumerate() - .for_each(|(idx, &group_index)| unsafe { - let this_overflow = self.is_overflow(group_index); - let that_is_empty = that_is_empty.value_unchecked(idx); - let that_overflow = !that_is_empty && that_sum.is_null(idx); - let is_overflow = this_overflow || that_overflow; - - // This part follows the logic in Spark: - // `org.apache.spark.sql.catalyst.expressions.aggregate.Sum` - self.is_not_null.set_bit(group_index, !is_overflow); - self.is_empty.set_bit( - group_index, - self.is_empty.get_bit(group_index) && that_is_empty, - ); - if !is_overflow { - // .. otherwise, the sum value for this particular index must not be null, - // and thus we merge both values and update this sum. 
- self.sum[group_index] += that_sum.value_unchecked(idx); - } - }); - - Ok(()) - } - - fn size(&self) -> usize { - self.sum.capacity() * std::mem::size_of::() - + self.is_empty.capacity() / 8 - + self.is_not_null.capacity() / 8 - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::datatypes::*; - use arrow_array::builder::{Decimal128Builder, StringBuilder}; - use arrow_array::RecordBatch; - use datafusion::execution::TaskContext; - use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; - use datafusion::physical_plan::memory::MemoryExec; - use datafusion::physical_plan::ExecutionPlan; - use datafusion_common::Result; - use datafusion_expr::AggregateUDF; - use datafusion_physical_expr::aggregate::AggregateExprBuilder; - use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr::PhysicalExpr; - use futures::StreamExt; - - #[test] - fn invalid_data_type() { - assert!(SumDecimal::try_new(DataType::Int32).is_err()); - } - - #[tokio::test] - async fn sum_no_overflow() -> Result<()> { - let num_rows = 8192; - let batch = create_record_batch(num_rows); - let mut batches = Vec::new(); - for _ in 0..10 { - batches.push(batch.clone()); - } - let partitions = &[batches]; - let c0: Arc = Arc::new(Column::new("c0", 0)); - let c1: Arc = Arc::new(Column::new("c1", 1)); - - let data_type = DataType::Decimal128(8, 2); - let schema = Arc::clone(&partitions[0][0].schema()); - let scan: Arc = - Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), None).unwrap()); - - let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( - data_type.clone(), - )?)); - - let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) - .schema(Arc::clone(&schema)) - .alias("sum") - .with_ignore_nulls(false) - .with_distinct(false) - .build()?; - - let aggregate = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, - PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), - vec![aggr_expr.into()], - vec![None], // no filter expressions - scan, - Arc::clone(&schema), - )?); - - let mut stream = aggregate - .execute(0, Arc::new(TaskContext::default())) - .unwrap(); - while let Some(batch) = stream.next().await { - let _batch = batch?; - } - - Ok(()) - } - - fn create_record_batch(num_rows: usize) -> RecordBatch { - let mut decimal_builder = Decimal128Builder::with_capacity(num_rows); - let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); - for i in 0..num_rows { - decimal_builder.append_value(i as i128); - string_builder.append_value(format!("this is string #{}", i % 1024)); - } - let decimal_array = Arc::new(decimal_builder.finish()); - let string_array = Arc::new(string_builder.finish()); - - let mut fields = vec![]; - let mut columns: Vec = vec![]; - - // string column - fields.push(Field::new("c0", DataType::Utf8, false)); - columns.push(string_array); - - // decimal column - fields.push(Field::new("c1", DataType::Decimal128(38, 10), false)); - columns.push(decimal_array); - - let schema = Schema::new(fields); - RecordBatch::try_new(Arc::new(schema), columns).unwrap() - } -} diff --git a/src/temporal.rs b/src/temporal.rs deleted file mode 100644 index fb549f9ce818..000000000000 --- a/src/temporal.rs +++ /dev/null @@ -1,510 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::utils::array_with_timezone; -use arrow::{ - compute::{date_part, DatePart}, - record_batch::RecordBatch, -}; -use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; -use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{DataFusionError, ScalarValue::Utf8}; -use datafusion_physical_expr::PhysicalExpr; -use std::hash::Hash; -use std::{ - any::Any, - fmt::{Debug, Display, Formatter}, - sync::Arc, -}; - -use crate::kernels::temporal::{ - date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn, -}; - -#[derive(Debug, Eq)] -pub struct HourExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - timezone: String, -} - -impl Hash for HourExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for HourExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.timezone.eq(&other.timezone) - } -} - -impl HourExpr { - pub fn new(child: Arc, timezone: String) -> Self { - HourExpr { child, timezone } - } -} - -impl Display for HourExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Hour [timezone:{}, child: {}]", - self.timezone, self.child - ) - } -} - -impl PhysicalExpr for HourExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema).unwrap() { - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) - } - _ => Ok(DataType::Int32), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let array = array_with_timezone( - array, - self.timezone.clone(), - Some(&DataType::Timestamp( - Microsecond, - Some(self.timezone.clone().into()), - )), - )?; - let result = date_part(&array, DatePart::Hour)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Hour(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(HourExpr::new( - Arc::clone(&children[0]), - self.timezone.clone(), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct MinuteExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - timezone: String, -} - -impl Hash for MinuteExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for MinuteExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.timezone.eq(&other.timezone) - } -} - -impl MinuteExpr { - pub fn 
-#[derive(Debug, Eq)]
-pub struct MinuteExpr {
-    /// An array with DataType::Timestamp(TimeUnit::Microsecond, None)
-    child: Arc<dyn PhysicalExpr>,
-    timezone: String,
-}
-
-impl Hash for MinuteExpr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.child.hash(state);
-        self.timezone.hash(state);
-    }
-}
-impl PartialEq for MinuteExpr {
-    fn eq(&self, other: &Self) -> bool {
-        self.child.eq(&other.child) && self.timezone.eq(&other.timezone)
-    }
-}
-
-impl MinuteExpr {
-    pub fn new(child: Arc<dyn PhysicalExpr>, timezone: String) -> Self {
-        MinuteExpr { child, timezone }
-    }
-}
-
-impl Display for MinuteExpr {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "Minute [timezone:{}, child: {}]",
-            self.timezone, self.child
-        )
-    }
-}
-
-impl PhysicalExpr for MinuteExpr {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
-        match self.child.data_type(input_schema)? {
-            DataType::Dictionary(key_type, _) => {
-                Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32)))
-            }
-            _ => Ok(DataType::Int32),
-        }
-    }
-
-    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
-        Ok(true)
-    }
-
-    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
-        let arg = self.child.evaluate(batch)?;
-        match arg {
-            ColumnarValue::Array(array) => {
-                let array = array_with_timezone(
-                    array,
-                    self.timezone.clone(),
-                    Some(&DataType::Timestamp(
-                        Microsecond,
-                        Some(self.timezone.clone().into()),
-                    )),
-                )?;
-                let result = date_part(&array, DatePart::Minute)?;
-
-                Ok(ColumnarValue::Array(result))
-            }
-            _ => Err(DataFusionError::Execution(
-                "Minute(scalar) should be folded on the Spark JVM side.".to_string(),
-            )),
-        }
-    }
-
-    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.child]
-    }
-
-    fn with_new_children(
-        self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
-    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
-        Ok(Arc::new(MinuteExpr::new(
-            Arc::clone(&children[0]),
-            self.timezone.clone(),
-        )))
-    }
-}
-
-#[derive(Debug, Eq)]
-pub struct SecondExpr {
-    /// An array with DataType::Timestamp(TimeUnit::Microsecond, None)
-    child: Arc<dyn PhysicalExpr>,
-    timezone: String,
-}
-
-impl Hash for SecondExpr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.child.hash(state);
-        self.timezone.hash(state);
-    }
-}
-impl PartialEq for SecondExpr {
-    fn eq(&self, other: &Self) -> bool {
-        self.child.eq(&other.child) && self.timezone.eq(&other.timezone)
-    }
-}
-
-impl SecondExpr {
-    pub fn new(child: Arc<dyn PhysicalExpr>, timezone: String) -> Self {
-        SecondExpr { child, timezone }
-    }
-}
-
-impl Display for SecondExpr {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "Second [timezone:{}, child: {}]",
-            self.timezone, self.child
-        )
-    }
-}
-
-impl PhysicalExpr for SecondExpr {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
-        match self.child.data_type(input_schema)? {
-            DataType::Dictionary(key_type, _) => {
-                Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32)))
-            }
-            _ => Ok(DataType::Int32),
-        }
-    }
-
-    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
-        Ok(true)
-    }
-
-    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
-        let arg = self.child.evaluate(batch)?;
-        match arg {
-            ColumnarValue::Array(array) => {
-                let array = array_with_timezone(
-                    array,
-                    self.timezone.clone(),
-                    Some(&DataType::Timestamp(
-                        Microsecond,
-                        Some(self.timezone.clone().into()),
-                    )),
-                )?;
-                let result = date_part(&array, DatePart::Second)?;
-
-                Ok(ColumnarValue::Array(result))
-            }
-            _ => Err(DataFusionError::Execution(
-                "Second(scalar) should be folded on the Spark JVM side.".to_string(),
-            )),
-        }
-    }
-
-    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.child]
-    }
-
-    fn with_new_children(
-        self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
-    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
-        Ok(Arc::new(SecondExpr::new(
-            Arc::clone(&children[0]),
-            self.timezone.clone(),
-        )))
-    }
-}
-
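// MinuteExpr and SecondExpr are driven exactly like HourExpr; a sketch reusing
// the hypothetical `batch` of microsecond timestamps from the Hour example:
let minute = MinuteExpr::new(Arc::new(Column::new("ts", 0)), "UTC".to_string());
let second = SecondExpr::new(Arc::new(Column::new("ts", 0)), "UTC".to_string());
let minutes = minute.evaluate(&batch)?; // Int32 minute-of-hour per row
let seconds = second.evaluate(&batch)?; // Int32 second-of-minute per row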
-#[derive(Debug, Eq)]
-pub struct DateTruncExpr {
-    /// An array with DataType::Date32
-    child: Arc<dyn PhysicalExpr>,
-    /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
-    format: Arc<dyn PhysicalExpr>,
-}
-
-impl Hash for DateTruncExpr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.child.hash(state);
-        self.format.hash(state);
-    }
-}
-impl PartialEq for DateTruncExpr {
-    fn eq(&self, other: &Self) -> bool {
-        self.child.eq(&other.child) && self.format.eq(&other.format)
-    }
-}
-
-impl DateTruncExpr {
-    pub fn new(child: Arc<dyn PhysicalExpr>, format: Arc<dyn PhysicalExpr>) -> Self {
-        DateTruncExpr { child, format }
-    }
-}
-
-impl Display for DateTruncExpr {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "DateTrunc [child:{}, format: {}]",
-            self.child, self.format
-        )
-    }
-}
-
-impl PhysicalExpr for DateTruncExpr {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
-        self.child.data_type(input_schema)
-    }
-
-    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
-        Ok(true)
-    }
-
-    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
-        let date = self.child.evaluate(batch)?;
-        let format = self.format.evaluate(batch)?;
-        match (date, format) {
-            (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => {
-                let result = date_trunc_dyn(&date, format)?;
-                Ok(ColumnarValue::Array(result))
-            }
-            (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => {
-                let result = date_trunc_array_fmt_dyn(&date, &formats)?;
-                Ok(ColumnarValue::Array(result))
-            }
-            _ => Err(DataFusionError::Execution(
-                "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \
-                 (PrimitiveArray, StringArray)"
-                    .to_string(),
-            )),
-        }
-    }
-
-    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.child]
-    }
-
-    fn with_new_children(
-        self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
-    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
-        Ok(Arc::new(DateTruncExpr::new(
-            Arc::clone(&children[0]),
-            Arc::clone(&self.format),
-        )))
-    }
-}
-
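// A sketch for DateTruncExpr; the Date32 batch and the "MM" format literal are
// illustrative assumptions ("MM" truncates to the first day of the month in
// Spark's trunc()). Assumes the same Result-returning context as above.
use arrow_array::Date32Array;
use datafusion_common::ScalarValue;
use datafusion_physical_expr::expressions::Literal;

let schema = Arc::new(Schema::new(vec![Field::new("d", DataType::Date32, true)]));
let dates = Date32Array::from(vec![Some(19_000), None]); // days since the epoch
let batch = RecordBatch::try_new(schema, vec![Arc::new(dates)])?;

let fmt = Arc::new(Literal::new(ScalarValue::Utf8(Some("MM".to_string()))));
let trunc = DateTruncExpr::new(Arc::new(Column::new("d", 0)), fmt);
let truncated = trunc.evaluate(&batch)?; // Date32 array, one value per row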
-#[derive(Debug, Eq)]
-pub struct TimestampTruncExpr {
-    /// An array with DataType::Timestamp(TimeUnit::Microsecond, None)
-    child: Arc<dyn PhysicalExpr>,
-    /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc
-    format: Arc<dyn PhysicalExpr>,
-    /// String containing a timezone name. The name must be found in the standard timezone
-    /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is
-    /// later parsed into a chrono::TimeZone.
-    /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros)
-    /// along with a single value for the associated TimeZone. The timezone offset is applied
-    /// just before any operations on the timestamp.
-    timezone: String,
-}
-
-impl Hash for TimestampTruncExpr {
-    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.child.hash(state);
-        self.format.hash(state);
-        self.timezone.hash(state);
-    }
-}
-impl PartialEq for TimestampTruncExpr {
-    fn eq(&self, other: &Self) -> bool {
-        self.child.eq(&other.child)
-            && self.format.eq(&other.format)
-            && self.timezone.eq(&other.timezone)
-    }
-}
-
-impl TimestampTruncExpr {
-    pub fn new(
-        child: Arc<dyn PhysicalExpr>,
-        format: Arc<dyn PhysicalExpr>,
-        timezone: String,
-    ) -> Self {
-        TimestampTruncExpr {
-            child,
-            format,
-            timezone,
-        }
-    }
-}
-
-impl Display for TimestampTruncExpr {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        write!(
-            f,
-            "TimestampTrunc [child:{}, format:{}, timezone: {}]",
-            self.child, self.format, self.timezone
-        )
-    }
-}
-
-impl PhysicalExpr for TimestampTruncExpr {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result<DataType> {
-        match self.child.data_type(input_schema)? {
-            DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary(
-                key_type,
-                Box::new(DataType::Timestamp(Microsecond, None)),
-            )),
-            _ => Ok(DataType::Timestamp(Microsecond, None)),
-        }
-    }
-
-    fn nullable(&self, _: &Schema) -> datafusion_common::Result<bool> {
-        Ok(true)
-    }
-
-    fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result<ColumnarValue> {
-        let timestamp = self.child.evaluate(batch)?;
-        let format = self.format.evaluate(batch)?;
-        let tz = self.timezone.clone();
-        match (timestamp, format) {
-            (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => {
-                let ts = array_with_timezone(
-                    ts,
-                    tz.clone(),
-                    Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
-                )?;
-                let result = timestamp_trunc_dyn(&ts, format)?;
-                Ok(ColumnarValue::Array(result))
-            }
-            (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => {
-                let ts = array_with_timezone(
-                    ts,
-                    tz.clone(),
-                    Some(&DataType::Timestamp(Microsecond, Some(tz.into()))),
-                )?;
-                let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?;
-                Ok(ColumnarValue::Array(result))
-            }
-            _ => Err(DataFusionError::Execution(
-                "Invalid input to function TimestampTrunc. \
-                 Expected (PrimitiveArray, Scalar, String) or \
-                 (PrimitiveArray, StringArray, String)"
-                    .to_string(),
-            )),
-        }
-    }
-
-    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
-        vec![&self.child]
-    }
-
-    fn with_new_children(
-        self: Arc<Self>,
-        children: Vec<Arc<dyn PhysicalExpr>>,
-    ) -> Result<Arc<dyn PhysicalExpr>, DataFusionError> {
-        Ok(Arc::new(TimestampTruncExpr::new(
-            Arc::clone(&children[0]),
-            Arc::clone(&self.format),
-            self.timezone.clone(),
-        )))
-    }
-}
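// A sketch for TimestampTruncExpr, reusing the hypothetical microsecond
// timestamp `batch` from the Hour example above; stored values stay in UTC and
// array_with_timezone applies the session timezone before truncation runs:
let fmt = Arc::new(Literal::new(ScalarValue::Utf8(Some("HOUR".to_string()))));
let trunc = TimestampTruncExpr::new(
    Arc::new(Column::new("ts", 0)),
    fmt,
    "America/Los_Angeles".to_string(),
);
let truncated = trunc.evaluate(&batch)?; // timestamps aligned to hour starts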