-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce user-defined signature #10439
Changes from 15 commits
ff441ea
8b6016b
36abe3e
0988efe
a64b813
5bbd2a0
6f0a90b
5cc047b
a515aad
17e6ec1
5eaacc8
a606581
40d5444
ca7f942
5cb0d0f
68fdf52
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,23 +20,124 @@ use std::sync::Arc; | |
use crate::signature::{ | ||
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, | ||
}; | ||
use crate::{Signature, TypeSignature}; | ||
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature}; | ||
use arrow::{ | ||
compute::can_cast_types, | ||
datatypes::{DataType, TimeUnit}, | ||
}; | ||
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; | ||
use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; | ||
use datafusion_common::{ | ||
exec_err, internal_datafusion_err, internal_err, plan_err, Result, | ||
}; | ||
|
||
use super::binary::{comparison_binary_numeric_coercion, comparison_coercion}; | ||
|
||
/// Performs type coercion for scalar function arguments. | ||
/// | ||
/// Returns the data types to which each argument must be coerced to | ||
/// match `signature`. | ||
/// | ||
/// For more details on coercion in general, please see the | ||
/// [`type_coercion`](crate::type_coercion) module. | ||
/// | ||
/// The signature is read from `func`; `func` itself is also handed to | ||
/// `get_valid_types_with_scalar_udf` when the candidate type lists are built. | ||
/// | ||
/// Resolution order: | ||
/// 1. Empty `current_types` is accepted only if the signature supports zero | ||
///    arguments. | ||
/// 2. If any candidate type list equals `current_types`, the input types are | ||
///    returned unchanged. | ||
/// 3. Otherwise the first candidate for which `maybe_data_types` returns | ||
///    `Some` wins. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns a plan error when `current_types` is empty but the signature does | ||
/// not support zero arguments, or when no candidate type list is accepted by | ||
/// `maybe_data_types`. Errors from `get_valid_types_with_scalar_udf` are | ||
/// propagated with `?`. | ||
pub fn data_types_with_scalar_udf( | ||
current_types: &[DataType], | ||
func: &ScalarUDF, | ||
) -> Result<Vec<DataType>> { | ||
let signature = func.signature(); | ||
|
||
// No arguments supplied: nothing to coerce, only check that the signature | ||
// allows a zero-argument call. | ||
if current_types.is_empty() { | ||
if signature.type_signature.supports_zero_argument() { | ||
return Ok(vec![]); | ||
} else { | ||
return plan_err!( | ||
"[data_types_with_scalar_udf] signature {:?} does not support zero arguments.", | ||
&signature.type_signature | ||
); | ||
} | ||
} | ||
|
||
let valid_types = | ||
get_valid_types_with_scalar_udf(&signature.type_signature, current_types, func)?; | ||
|
||
// Exact match with one of the candidates: return the input types as they are. | ||
if valid_types | ||
.iter() | ||
.any(|data_type| data_type == current_types) | ||
{ | ||
return Ok(current_types.to_vec()); | ||
} | ||
|
||
// Try and coerce the argument types to match the signature, returning the | ||
// coerced types from the first matching signature. | ||
for valid_types in valid_types { | ||
if let Some(types) = maybe_data_types(&valid_types, current_types) { | ||
return Ok(types); | ||
} | ||
} | ||
|
||
// none possible -> Error | ||
plan_err!( | ||
"[data_types_with_scalar_udf] Coercion from {:?} to the signature {:?} failed.", | ||
current_types, | ||
&signature.type_signature | ||
) | ||
} | ||
|
||
/// Performs type coercion for aggregate function arguments. | ||
/// | ||
/// Returns the data types to which each argument must be coerced to | ||
/// match the signature of `func`. | ||
/// | ||
/// The signature is read from `func`; `func` itself is also handed to | ||
/// `get_valid_types_with_aggregate_udf` when the candidate type lists are | ||
/// built. | ||
/// | ||
/// Resolution order: | ||
/// 1. Empty `current_types` is accepted only if the signature supports zero | ||
///    arguments. | ||
/// 2. If any candidate type list equals `current_types`, the input types are | ||
///    returned unchanged. | ||
/// 3. Otherwise the first candidate for which `maybe_data_types` returns | ||
///    `Some` wins. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns a plan error when `current_types` is empty but the signature does | ||
/// not support zero arguments, or when no candidate type list is accepted by | ||
/// `maybe_data_types`. Errors from `get_valid_types_with_aggregate_udf` are | ||
/// propagated with `?`. | ||
pub fn data_types_with_aggregate_udf( | ||
current_types: &[DataType], | ||
func: &AggregateUDF, | ||
) -> Result<Vec<DataType>> { | ||
let signature = func.signature(); | ||
|
||
// No arguments supplied: nothing to coerce, only check that the signature | ||
// allows a zero-argument call. The failure here reuses the generic | ||
// "Coercion ... failed" wording. | ||
if current_types.is_empty() { | ||
if signature.type_signature.supports_zero_argument() { | ||
return Ok(vec![]); | ||
} else { | ||
return plan_err!( | ||
"[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", | ||
current_types, | ||
&signature.type_signature | ||
); | ||
} | ||
} | ||
|
||
let valid_types = get_valid_types_with_aggregate_udf( | ||
&signature.type_signature, | ||
current_types, | ||
func, | ||
)?; | ||
// Exact match with one of the candidates: return the input types as they are. | ||
if valid_types | ||
.iter() | ||
.any(|data_type| data_type == current_types) | ||
{ | ||
return Ok(current_types.to_vec()); | ||
} | ||
|
||
// Try and coerce the argument types to match the signature, returning the | ||
// coerced types from the first matching signature. | ||
for valid_types in valid_types { | ||
if let Some(types) = maybe_data_types(&valid_types, current_types) { | ||
return Ok(types); | ||
} | ||
} | ||
|
||
// none possible -> Error | ||
plan_err!( | ||
"[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", | ||
current_types, | ||
&signature.type_signature | ||
) | ||
} | ||
|
||
/// Performs type coercion for function arguments. | ||
/// | ||
/// Returns the data types to which each argument must be coerced to | ||
/// match `signature`. | ||
/// | ||
/// For more details on coercion in general, please see the | ||
/// [`type_coercion`](crate::type_coercion) module. | ||
/// | ||
/// This function will be replaced with [data_types_with_scalar_udf], | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am not sure we want to replace this function over time -- I think having the basic, simple Signatures that handle the most common coercions makes sense in DataFusion core (even if it could be done purely in a UDF), as it will make creating UDFs easier for users. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agree with you. |
||
/// [data_types_with_aggregate_udf], and data_types_with_window_udf gradually. | ||
pub fn data_types( | ||
current_types: &[DataType], | ||
signature: &Signature, | ||
|
@@ -46,7 +147,7 @@ pub fn data_types( | |
return Ok(vec![]); | ||
} else { | ||
return plan_err!( | ||
"Coercion from {:?} to the signature {:?} failed.", | ||
"[data_types] Coercion from {:?} to the signature {:?} failed.", | ||
current_types, | ||
&signature.type_signature | ||
); | ||
|
@@ -72,12 +173,56 @@ pub fn data_types( | |
|
||
// none possible -> Error | ||
plan_err!( | ||
"Coercion from {:?} to the signature {:?} failed.", | ||
"[data_types] Coercion from {:?} to the signature {:?} failed.", | ||
current_types, | ||
&signature.type_signature | ||
) | ||
} | ||
|
||
/// Returns the candidate argument type lists for `signature`, consulting | ||
/// `func` where the signature requires it. | ||
/// | ||
/// - `TypeSignature::UserDefined`: calls `func.coerce_types(current_types)` | ||
///   and yields its result as the single candidate; a failure is reported | ||
///   through `exec_err!` with the underlying error in the message. | ||
/// - `TypeSignature::OneOf`: recurses into each alternative and concatenates | ||
///   the candidates. Alternatives that return an error are dropped (`.ok()`), | ||
///   so the result may be an empty list rather than an error. | ||
/// - Any other variant: delegates to `get_valid_types`, propagating its error. | ||
fn get_valid_types_with_scalar_udf( | ||
signature: &TypeSignature, | ||
current_types: &[DataType], | ||
func: &ScalarUDF, | ||
) -> Result<Vec<Vec<DataType>>> { | ||
let valid_types = match signature { | ||
TypeSignature::UserDefined => match func.coerce_types(current_types) { | ||
Ok(coerced_types) => vec![coerced_types], | ||
Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), | ||
}, | ||
TypeSignature::OneOf(signatures) => signatures | ||
.iter() | ||
// Errors from individual alternatives are discarded here, not propagated. | ||
.filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok()) | ||
.flatten() | ||
.collect::<Vec<_>>(), | ||
_ => get_valid_types(signature, current_types)?, | ||
}; | ||
|
||
Ok(valid_types) | ||
} | ||
|
||
/// Returns the candidate argument type lists for `signature`, consulting | ||
/// `func` where the signature requires it. | ||
/// | ||
/// - `TypeSignature::UserDefined`: calls `func.coerce_types(current_types)` | ||
///   and yields its result as the single candidate; a failure is reported | ||
///   through `exec_err!` with the underlying error in the message. | ||
/// - `TypeSignature::OneOf`: recurses into each alternative and concatenates | ||
///   the candidates. Alternatives that return an error are dropped (`.ok()`), | ||
///   so the result may be an empty list rather than an error. | ||
/// - Any other variant: delegates to `get_valid_types`, propagating its error. | ||
fn get_valid_types_with_aggregate_udf( | ||
signature: &TypeSignature, | ||
current_types: &[DataType], | ||
func: &AggregateUDF, | ||
) -> Result<Vec<Vec<DataType>>> { | ||
let valid_types = match signature { | ||
TypeSignature::UserDefined => match func.coerce_types(current_types) { | ||
Ok(coerced_types) => vec![coerced_types], | ||
Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), | ||
}, | ||
TypeSignature::OneOf(signatures) => signatures | ||
.iter() | ||
// Errors from individual alternatives are discarded here, not propagated. | ||
.filter_map(|t| { | ||
get_valid_types_with_aggregate_udf(t, current_types, func).ok() | ||
}) | ||
.flatten() | ||
.collect::<Vec<_>>(), | ||
_ => get_valid_types(signature, current_types)?, | ||
}; | ||
|
||
Ok(valid_types) | ||
} | ||
|
||
/// Returns a Vec of all possible valid argument types for the given signature. | ||
fn get_valid_types( | ||
signature: &TypeSignature, | ||
|
@@ -184,32 +329,14 @@ fn get_valid_types( | |
.iter() | ||
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) | ||
.collect(), | ||
TypeSignature::VariadicEqual => { | ||
let new_type = current_types.iter().skip(1).try_fold( | ||
current_types.first().unwrap().clone(), | ||
|acc, x| { | ||
// The coerced types found by `comparison_coercion` are not guaranteed to be | ||
// coercible for the arguments. `comparison_coercion` returns more loose | ||
// types that can be coerced to both `acc` and `x` for comparison purpose. | ||
// See `maybe_data_types` for the actual coercion. | ||
let coerced_type = comparison_coercion(&acc, x); | ||
if let Some(coerced_type) = coerced_type { | ||
Ok(coerced_type) | ||
} else { | ||
internal_err!("Coercion from {acc:?} to {x:?} failed.") | ||
} | ||
}, | ||
); | ||
|
||
match new_type { | ||
Ok(new_type) => vec![vec![new_type; current_types.len()]], | ||
Err(e) => return Err(e), | ||
} | ||
TypeSignature::UserDefined => { | ||
return internal_err!( | ||
"User-defined signature should be handled by function-specific coerce_types." | ||
) | ||
} | ||
TypeSignature::VariadicAny => { | ||
vec![current_types.to_vec()] | ||
} | ||
|
||
TypeSignature::Exact(valid_types) => vec![valid_types.clone()], | ||
TypeSignature::ArraySignature(ref function_signature) => match function_signature | ||
{ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keep the error for debugging.