Skip to content

Commit

Permalink
Adding type coercion
Browse files Browse the repository at this point in the history
  • Loading branch information
edmondop committed Jul 27, 2024
1 parent 87e945f commit 07910bb
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 5,538 deletions.
27 changes: 26 additions & 1 deletion datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
// under the License.

use crate::TypeSignature;

use arrow::datatypes::{
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};
use std::ops::Deref;

use datafusion_common::{internal_err, plan_err, Result};

Expand Down Expand Up @@ -142,6 +142,22 @@ pub fn check_arg_count(
Ok(())
}

pub fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
// make sure that the input types only has one element.
assert_eq!(input_types.len(), 1);
// min and max support the dictionary data type
// unpack the dictionary to get the value
match &input_types[0] {
DataType::Dictionary(_, dict_value_type) => {
// TODO add checker, if the value type is complex data type
Ok(vec![dict_value_type.deref().clone()])
}
// TODO add checker for datatype which min and max supported
// For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
_ => Ok(input_types.to_vec()),
}
}

/// function return type of a sum
pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
Expand Down Expand Up @@ -311,6 +327,15 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> {
let data_type =
DataType::Dictionary(Box::new(DataType::Utf8), Box::new((DataType::Int32)));
let result = get_min_max_result_type(&[data_type])?;
assert_eq!(result, vec![DataType::Int32]);
Ok(())
}

#[test]
fn test_variance_return_data_type() -> Result<()> {
let data_type = DataType::Float64;
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ fn get_valid_types_with_aggregate_udf(
_ => get_valid_types(signature, current_types)?,
};

println!("current types {:?}, valid_types: {:?}", current_types, valid_types);
Ok(valid_types)
}

Expand Down
177 changes: 160 additions & 17 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use arrow::datatypes::{
DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type,
Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_schema::IntervalUnit;
use datafusion_common::{downcast_value, internal_err, DataFusionError, Result};
use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use std::fmt::Debug;
Expand All @@ -62,31 +63,111 @@ use datafusion_common::ScalarValue;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_expr::{Expr, GroupsAccumulator};

use datafusion_expr::{type_coercion, Expr, GroupsAccumulator};



pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];

pub static SIGNED_INTEGERS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
];

pub static UNSIGNED_INTEGERS: &[DataType] = &[
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
];

pub static INTEGERS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
];

pub static NUMERICS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
DataType::Float32,
DataType::Float64,
];

pub static TIMESTAMPS: &[DataType] = &[
DataType::Timestamp(TimeUnit::Second, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Nanosecond, None),
];

pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];

pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary];

pub static TIMES: &[DataType] = &[
DataType::Time32(TimeUnit::Second),
DataType::Time32(TimeUnit::Millisecond),
DataType::Time64(TimeUnit::Microsecond),
DataType::Time64(TimeUnit::Nanosecond),
];

pub static TIMES_INTERVALS: &[DataType] = &[
DataType::Interval(IntervalUnit::DayTime),
DataType::Interval(IntervalUnit::YearMonth),
DataType::Interval(IntervalUnit::MonthDayNano),
];
// Min/max aggregation can take Dictionary encode input but always produces unpacked
// (aka non Dictionary) output. We need to adjust the output data type to reflect this.
// The reason min/max aggregate produces unpacked output because there is only one
// min/max value per group; there is no needs to keep them Dictionary encode
fn min_max_aggregate_data_type(input_type: DataType) -> DataType {
fn min_max_aggregate_data_type<'a>(input_type: &'a DataType) -> &'a DataType {
if let DataType::Dictionary(_, value_type) = input_type {
*value_type
value_type
} else {
input_type
}
}

fn min_max_signature() -> Signature {
let valid = STRINGS
.iter()
.chain(NUMERICS.iter())
.chain(TIMESTAMPS.iter())
.chain(DATES.iter())
.chain(TIMES.iter())
.chain(BINARYS.iter())
.chain(TIMES_INTERVALS.iter())
.cloned()
.collect::<Vec<_>>();
Signature::uniform(1, valid, Volatility::Immutable)
}

// MAX aggregate UDF
#[derive(Debug)]
pub struct Max {
signature: Signature,
aliases: Vec<String>,
signature: Signature
}

impl Max {
pub fn new() -> Self {
Self {
signature: Signature::numeric(1, Volatility::Immutable),
aliases: vec!["max".to_owned()],
signature: min_max_signature(),
}
}
}
Expand Down Expand Up @@ -147,20 +228,31 @@ impl AggregateUDFImpl for Max {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(min_max_aggregate_data_type(arg_types[0].clone()))
type_coercion::aggregates::get_min_max_result_type(arg_types)?
.into_iter()
.next()
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Expected at one input type for MAX aggregate function"
))
})
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(MaxAccumulator::try_new(acc_args.input_type)?))
// let data_type = &min_max_aggregate_data_type(acc_args.data_type);
let data_type = acc_args.input_type;
Ok(Box::new(MaxAccumulator::try_new(data_type)?))
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
// let data_type = min_max_aggregate_data_type(args.data_type);
let data_type = args.input_type;
matches!(
args.input_type,
data_type,
DataType::Int8
| DataType::Int16
| DataType::Int32
Expand All @@ -187,6 +279,7 @@ impl AggregateUDFImpl for Max {
) -> Result<Box<dyn GroupsAccumulator>> {
use DataType::*;
use TimeUnit::*;
// let data_type = min_max_aggregate_data_type(args.data_type);
let data_type = args.input_type;
match data_type {
Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type),
Expand Down Expand Up @@ -248,12 +341,20 @@ impl AggregateUDFImpl for Max {
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SlidingMaxAccumulator::try_new(args.input_type)?))
let data_type = min_max_aggregate_data_type(args.data_type);
Ok(Box::new(SlidingMaxAccumulator::try_new(data_type)?))
}

fn get_minmax_desc(&self) -> Option<bool> {
Some(true)
}
fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
type_coercion::aggregates::get_min_max_result_type(arg_types)
}
}

// Statically-typed version of min/max(array) -> ScalarValue for string types
Expand Down Expand Up @@ -824,7 +925,7 @@ pub struct Min {
impl Min {
pub fn new() -> Self {
Self {
signature: Signature::numeric(1, Volatility::Immutable),
signature: min_max_signature(),
aliases: vec!["min".to_owned()],
}
}
Expand All @@ -850,20 +951,33 @@ impl AggregateUDFImpl for Min {
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(min_max_aggregate_data_type(arg_types[0].clone()))
let return_type = type_coercion::aggregates::get_min_max_result_type(arg_types)?
.into_iter()
.next()
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Expected at one input type for MAX aggregate function"
))
});

println!("Return type for min {:?}", return_type);
return_type
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(MinAccumulator::try_new(acc_args.input_type)?))
Ok(Box::new(MinAccumulator::try_new(
&min_max_aggregate_data_type(acc_args.data_type),
)?))
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
let data_type = min_max_aggregate_data_type(args.data_type);
matches!(
args.input_type,
data_type,
DataType::Int8
| DataType::Int16
| DataType::Int32
Expand All @@ -890,7 +1004,7 @@ impl AggregateUDFImpl for Min {
) -> Result<Box<dyn GroupsAccumulator>> {
use DataType::*;
use TimeUnit::*;
let data_type = args.input_type;
let data_type = min_max_aggregate_data_type(args.data_type);
match data_type {
Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type),
Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type),
Expand Down Expand Up @@ -951,12 +1065,21 @@ impl AggregateUDFImpl for Min {
&self,
args: AccumulatorArgs,
) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SlidingMinAccumulator::try_new(args.input_type)?))
let data_type = min_max_aggregate_data_type(args.data_type);
Ok(Box::new(SlidingMinAccumulator::try_new(data_type)?))
}

fn get_minmax_desc(&self) -> Option<bool> {
Some(false)
}

fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
type_coercion::aggregates::get_min_max_result_type(arg_types)
}
}
/// An accumulator to compute the minimum value
#[derive(Debug)]
Expand All @@ -965,7 +1088,7 @@ pub struct MinAccumulator {
}

impl MinAccumulator {
/// new max accumulator
/// new min accumulator
pub fn try_new(datatype: &DataType) -> Result<Self> {
Ok(Self {
min: ScalarValue::try_from(datatype)?,
Expand Down Expand Up @@ -1062,6 +1185,7 @@ impl Accumulator for SlidingMinAccumulator {
std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size()
}
}

//
// Moving min and moving max
// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs.
Expand Down Expand Up @@ -1454,4 +1578,23 @@ mod tests {
moving_max_i32(100, 100)?;
Ok(())
}

#[test]
fn test_min_max_coerce_types() {
// the coerced types is same with input types
let funs: Vec<Box<dyn AggregateUDFImpl>> =
vec![Box::new(Min::new()), Box::new(Max::new())];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Decimal128(10, 2)],
vec![DataType::Decimal256(1, 1)],
vec![DataType::Utf8],
];
for fun in funs {
for input_type in &input_types {
let result = fun.coerce_types(input_type);
assert_eq!(*input_type, result.unwrap());
}
}
}
}
Loading

0 comments on commit 07910bb

Please sign in to comment.