Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce user-defined signature #10439

Merged
merged 16 commits into from
May 11, 2024
7 changes: 4 additions & 3 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::expr::{
};
use crate::field_util::GetFieldAccessSchema;
use crate::type_coercion::binary::get_result_type;
use crate::type_coercion::functions::data_types;
use crate::type_coercion::functions::data_types_with_scalar_udf;
use crate::{utils, LogicalPlan, Projection, Subquery};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
Expand Down Expand Up @@ -139,9 +139,10 @@ impl ExprSchemable for Expr {
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
// verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
data_types(&arg_data_types, func.signature()).map_err(|_| {
data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| {
plan_datafusion_err!(
"{}",
"{} and {}",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keep error for debugging

err,
utils::generate_signature_error_msg(
func.name(),
func.signature().clone(),
Expand Down
23 changes: 10 additions & 13 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,12 @@ pub enum TypeSignature {
/// # Examples
/// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])`
Variadic(Vec<DataType>),
/// One or more arguments of an arbitrary but equal type.
/// DataFusion attempts to coerce all argument types to match the first argument's type
/// The accepted argument types, and the rules for coercing arguments to them,
/// are specific to the function itself. If this signature is specified,
/// DataFusion will call [`ScalarUDFImpl::coerce_types`] to prepare argument types.
///
/// # Examples
/// Given types in signature should be coercible to the same final type.
/// A function such as `make_array` is `VariadicEqual`.
///
/// `make_array(i32, i64) -> make_array(i64, i64)`
VariadicEqual,
/// [`ScalarUDFImpl::coerce_types`]: crate::udf::ScalarUDFImpl::coerce_types
UserDefined,
/// One or more arguments with arbitrary types
VariadicAny,
/// Fixed number of arguments of an arbitrary but equal type out of a list of valid types.
Expand Down Expand Up @@ -190,8 +187,8 @@ impl TypeSignature {
.collect::<Vec<&str>>()
.join(", ")]
}
TypeSignature::VariadicEqual => {
vec!["CoercibleT, .., CoercibleT".to_string()]
TypeSignature::UserDefined => {
vec!["UserDefined".to_string()]
}
TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()],
TypeSignature::OneOf(sigs) => {
Expand Down Expand Up @@ -255,10 +252,10 @@ impl Signature {
volatility,
}
}
/// An arbitrary number of arguments of the same type.
pub fn variadic_equal(volatility: Volatility) -> Self {
/// User-defined coercion rules for the function.
pub fn user_defined(volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::VariadicEqual,
type_signature: TypeSignature::UserDefined,
volatility,
}
}
Expand Down
179 changes: 153 additions & 26 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,124 @@ use std::sync::Arc;
use crate::signature::{
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD,
};
use crate::{Signature, TypeSignature};
use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature};
use arrow::{
compute::can_cast_types,
datatypes::{DataType, TimeUnit},
};
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims};
use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err, Result,
};

use super::binary::{comparison_binary_numeric_coercion, comparison_coercion};

/// Resolves the argument types a scalar UDF invocation has to be coerced to.
///
/// On success the returned vector holds, position by position, the data type
/// each argument must have so that the call matches the signature of `func`.
///
/// Resolution proceeds in this order:
/// 1. An empty argument list is accepted only if the signature supports
///    zero arguments.
/// 2. If `current_types` already equals one of the valid type lists, it is
///    returned unchanged.
/// 3. Otherwise the first valid type list that `current_types` can be
///    coerced to wins.
///
/// For more details on coercion in general, please see the
/// [`type_coercion`](crate::type_coercion) module.
///
/// # Errors
/// Returns a plan error when no valid type list matches, and propagates any
/// error raised while computing the valid type lists.
pub fn data_types_with_scalar_udf(
    current_types: &[DataType],
    func: &ScalarUDF,
) -> Result<Vec<DataType>> {
    let type_signature = &func.signature().type_signature;

    // A call without arguments is only legal if the signature says so.
    if current_types.is_empty() {
        return if type_signature.supports_zero_argument() {
            Ok(vec![])
        } else {
            plan_err!(
                "[data_types_with_scalar_udf] signature {:?} does not support zero arguments.",
                type_signature
            )
        };
    }

    let candidates =
        get_valid_types_with_scalar_udf(type_signature, current_types, func)?;

    // Nothing to coerce when the arguments already match a candidate exactly.
    let already_valid = candidates
        .iter()
        .any(|candidate| candidate == current_types);
    if already_valid {
        return Ok(current_types.to_vec());
    }

    // Otherwise take the coerced types of the first candidate that the
    // arguments can be coerced to.
    let coerced = candidates
        .into_iter()
        .find_map(|candidate| maybe_data_types(&candidate, current_types));

    match coerced {
        Some(types) => Ok(types),
        // none possible -> Error
        None => plan_err!(
            "[data_types_with_scalar_udf] Coercion from {:?} to the signature {:?} failed.",
            current_types,
            type_signature
        ),
    }
}

/// Resolves the argument types an aggregate UDF invocation has to be coerced to.
///
/// On success the returned vector holds, position by position, the data type
/// each argument must have so that the call matches the signature of `func`.
///
/// Resolution proceeds in this order:
/// 1. An empty argument list is accepted only if the signature supports
///    zero arguments.
/// 2. If `current_types` already equals one of the valid type lists, it is
///    returned unchanged.
/// 3. Otherwise the first valid type list that `current_types` can be
///    coerced to wins.
///
/// # Errors
/// Returns a plan error when no valid type list matches, and propagates any
/// error raised while computing the valid type lists.
pub fn data_types_with_aggregate_udf(
    current_types: &[DataType],
    func: &AggregateUDF,
) -> Result<Vec<DataType>> {
    let type_signature = &func.signature().type_signature;

    // A call without arguments is only legal if the signature says so.
    if current_types.is_empty() {
        return if type_signature.supports_zero_argument() {
            Ok(vec![])
        } else {
            plan_err!(
                "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.",
                current_types,
                type_signature
            )
        };
    }

    let candidates =
        get_valid_types_with_aggregate_udf(type_signature, current_types, func)?;

    // Nothing to coerce when the arguments already match a candidate exactly.
    let already_valid = candidates
        .iter()
        .any(|candidate| candidate == current_types);
    if already_valid {
        return Ok(current_types.to_vec());
    }

    // Otherwise take the coerced types of the first candidate that the
    // arguments can be coerced to.
    let coerced = candidates
        .into_iter()
        .find_map(|candidate| maybe_data_types(&candidate, current_types));

    match coerced {
        Some(types) => Ok(types),
        // none possible -> Error
        None => plan_err!(
            "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.",
            current_types,
            type_signature
        ),
    }
}

/// Performs type coercion for function arguments.
///
/// Returns the data types to which each argument must be coerced to
/// match `signature`.
///
/// For more details on coercion in general, please see the
/// [`type_coercion`](crate::type_coercion) module.
///
/// This function will be replaced with [data_types_with_scalar_udf],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure we want to replace this function over time -- I think having the basic simple Signatures that handle most common coercions makes sense to have in DataFusion core (even if it could be done purely in a udf) as it will make creating UDFs easier for users

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with you

/// [data_types_with_aggregate_udf], and data_types_with_window_udf gradually.
pub fn data_types(
current_types: &[DataType],
signature: &Signature,
Expand All @@ -46,7 +147,7 @@ pub fn data_types(
return Ok(vec![]);
} else {
return plan_err!(
"Coercion from {:?} to the signature {:?} failed.",
"[data_types] Coercion from {:?} to the signature {:?} failed.",
current_types,
&signature.type_signature
);
Expand All @@ -72,12 +173,56 @@ pub fn data_types(

// none possible -> Error
plan_err!(
"Coercion from {:?} to the signature {:?} failed.",
"[data_types] Coercion from {:?} to the signature {:?} failed.",
current_types,
&signature.type_signature
)
}

/// Returns every valid list of argument types for `signature`, consulting
/// the scalar UDF itself where the signature requires it.
///
/// * `UserDefined`: the single candidate is whatever `func.coerce_types`
///   returns; a failure there is reported as an execution error.
/// * `OneOf`: the candidates of all alternatives are concatenated;
///   alternatives that produce an error are skipped.
/// * Any other variant is delegated to `get_valid_types`.
fn get_valid_types_with_scalar_udf(
    signature: &TypeSignature,
    current_types: &[DataType],
    func: &ScalarUDF,
) -> Result<Vec<Vec<DataType>>> {
    match signature {
        TypeSignature::UserDefined => match func.coerce_types(current_types) {
            Ok(coerced) => Ok(vec![coerced]),
            Err(e) => exec_err!("User-defined coercion failed with {:?}", e),
        },
        TypeSignature::OneOf(signatures) => {
            let mut collected = Vec::new();
            for alternative in signatures {
                // Best effort: an alternative that cannot be resolved simply
                // contributes no candidates.
                if let Ok(types) =
                    get_valid_types_with_scalar_udf(alternative, current_types, func)
                {
                    collected.extend(types);
                }
            }
            Ok(collected)
        }
        _ => get_valid_types(signature, current_types),
    }
}

/// Returns every valid list of argument types for `signature`, consulting
/// the aggregate UDF itself where the signature requires it.
///
/// * `UserDefined`: the single candidate is whatever `func.coerce_types`
///   returns; a failure there is reported as an execution error.
/// * `OneOf`: the candidates of all alternatives are concatenated;
///   alternatives that produce an error are skipped.
/// * Any other variant is delegated to `get_valid_types`.
fn get_valid_types_with_aggregate_udf(
    signature: &TypeSignature,
    current_types: &[DataType],
    func: &AggregateUDF,
) -> Result<Vec<Vec<DataType>>> {
    match signature {
        TypeSignature::UserDefined => match func.coerce_types(current_types) {
            Ok(coerced) => Ok(vec![coerced]),
            Err(e) => exec_err!("User-defined coercion failed with {:?}", e),
        },
        TypeSignature::OneOf(signatures) => {
            let mut collected = Vec::new();
            for alternative in signatures {
                // Best effort: an alternative that cannot be resolved simply
                // contributes no candidates.
                if let Ok(types) =
                    get_valid_types_with_aggregate_udf(alternative, current_types, func)
                {
                    collected.extend(types);
                }
            }
            Ok(collected)
        }
        _ => get_valid_types(signature, current_types),
    }
}

/// Returns a Vec of all possible valid argument types for the given signature.
fn get_valid_types(
signature: &TypeSignature,
Expand Down Expand Up @@ -184,32 +329,14 @@ fn get_valid_types(
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::VariadicEqual => {
let new_type = current_types.iter().skip(1).try_fold(
current_types.first().unwrap().clone(),
|acc, x| {
// The coerced types found by `comparison_coercion` are not guaranteed to be
// coercible for the arguments. `comparison_coercion` returns more loose
// types that can be coerced to both `acc` and `x` for comparison purpose.
// See `maybe_data_types` for the actual coercion.
let coerced_type = comparison_coercion(&acc, x);
if let Some(coerced_type) = coerced_type {
Ok(coerced_type)
} else {
internal_err!("Coercion from {acc:?} to {x:?} failed.")
}
},
);

match new_type {
Ok(new_type) => vec![vec![new_type; current_types.len()]],
Err(e) => return Err(e),
}
TypeSignature::UserDefined => {
return internal_err!(
"User-defined signature should be handled by function-specific coerce_types."
)
}
TypeSignature::VariadicAny => {
vec![current_types.to_vec()]
}

TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
TypeSignature::ArraySignature(ref function_signature) => match function_signature
{
Expand Down
4 changes: 4 additions & 0 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ impl AggregateUDF {
pub fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
self.inner.create_groups_accumulator()
}

/// Coerces the argument types of a call to this aggregate function.
///
/// This is currently a placeholder: `_args` is ignored and the call always
/// returns the error built by `not_impl_err!`, whose message includes this
/// function's name.
pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
}
}

impl<F> From<F> for AggregateUDF
Expand Down
29 changes: 29 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ impl ScalarUDF {
pub fn short_circuits(&self) -> bool {
self.inner.short_circuits()
}

/// Returns the types the arguments of a call to this function should be
/// coerced to, by delegating to the wrapped implementation's `coerce_types`.
///
/// See [`ScalarUDFImpl::coerce_types`] for more details.
pub fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.inner.coerce_types(arg_types)
}
}

impl<F> From<F> for ScalarUDF
Expand Down Expand Up @@ -420,6 +425,29 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
fn short_circuits(&self) -> bool {
false
}

/// Coerce arguments of a function call to types that the function can evaluate.
///
/// This function is only called if [`ScalarUDFImpl::signature`] returns
/// [`crate::TypeSignature::UserDefined`]. Most UDFs should return one of the
/// other variants of `TypeSignature`, which handle common cases.
///
/// See the [type coercion module](crate::type_coercion)
/// documentation for more details on type coercion.
///
/// For example, if your function requires a floating point argument, but the user calls
/// it like `my_func(1::int)` (aka with `1` as an integer), `coerce_types` could return
/// `[DataType::Float64]` to ensure the argument is cast to `1::double`.
///
/// # Parameters
/// * `arg_types`: The types of the arguments this function is called with
///
/// # Return value
/// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call
/// arguments to these specific types.
///
/// # Default implementation
/// Ignores `arg_types` and returns a not-implemented error whose message
/// names this function.
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("Function {} does not implement coerce_types", self.name())
}
}

/// ScalarUDF that adds an alias to the underlying function. It is better to
Expand All @@ -446,6 +474,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
self.inner.name()
}
Expand Down
Loading