From 141d2b0523c8cb385c24dfdddab3f651c1e9b652 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Sep 2022 08:08:51 -0600 Subject: [PATCH 1/4] Add type coercion for scalar UDFs --- datafusion/optimizer/src/type_coercion.rs | 110 +++++++++++++++++++++- 1 file changed, 107 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index d9f16159926f..0a0cecec60f5 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -24,9 +24,10 @@ use datafusion_expr::binary_rule::coerce_types; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; use datafusion_expr::logical_plan::builder::build_join_schema; use datafusion_expr::logical_plan::JoinType; +use datafusion_expr::type_coercion::data_types; use datafusion_expr::utils::from_plan; -use datafusion_expr::ExprSchemable; use datafusion_expr::{Expr, LogicalPlan}; +use datafusion_expr::{ExprSchemable, Signature}; #[derive(Default)] pub struct TypeCoercion {} @@ -116,18 +117,61 @@ impl ExprRewriter for TypeCoercionRewriter { } } } + Expr::ScalarUDF { fun, args } => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + &fun.signature, + )?; + Ok(Expr::ScalarUDF { + fun: fun.clone(), + args: new_expr, + }) + } _ => Ok(expr), } } } +/// Returns `expressions` coerced to types compatible with +/// `signature`, if possible. +/// +/// See the module level documentation for more detail on coercion. +pub fn coerce_arguments_for_signature( + expressions: &[Expr], + schema: &DFSchema, + signature: &Signature, +) -> Result> { + if expressions.is_empty() { + return Ok(vec![]); + } + + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + let new_types = data_types(¤t_types, signature)?; + + expressions + .iter() + .enumerate() + .map(|(i, expr)| expr.clone().cast_to(&new_types[i], schema)) + .collect::>>() +} + #[cfg(test)] mod test { use crate::type_coercion::TypeCoercion; use crate::{OptimizerConfig, OptimizerRule}; + use arrow::datatypes::DataType; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::logical_plan::{EmptyRelation, Projection}; - use datafusion_expr::{lit, LogicalPlan}; + use datafusion_expr::{ + lit, + logical_plan::{EmptyRelation, Projection}, + Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, + Signature, Volatility, + }; use std::sync::Arc; #[test] @@ -167,4 +211,64 @@ mod test { \n EmptyRelation", &format!("{:?}", plan)); Ok(()) } + + #[test] + fn scalar_udf() -> Result<()> { + let empty = empty(); + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); + let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); + let udf = Expr::ScalarUDF { + fun: Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), + &return_type, + &fun, + )), + args: vec![lit(123_i32)], + }; + let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty, None)?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!( + "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation", + &format!("{:?}", plan) + ); + Ok(()) + } + + #[test] + fn scalar_udf_invalid_input() -> Result<()> { + let empty = empty(); + let return_type: ReturnTypeFunction = + Arc::new(move |_| Ok(Arc::new(DataType::Utf8))); + let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!()); + let udf = Expr::ScalarUDF { + fun: Arc::new(ScalarUDF::new( + "TestScalarUDF", + &Signature::uniform(1, vec![DataType::Int32], Volatility::Stable), + &return_type, + &fun, + )), + args: vec![lit("Apple")], + }; + let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty, None)?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config).err().unwrap(); + assert_eq!( + "Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.\")", + &format!("{:?}", plan) + ); + Ok(()) + } + + fn empty() -> Arc { + let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })); + empty + } } From e9bc400b67f2ab83916c762e085d85ae30f61ff1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Sep 2022 09:01:19 -0600 Subject: [PATCH 2/4] clippy --- datafusion/optimizer/src/type_coercion.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 0a0cecec60f5..3411453dd49b 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -265,10 +265,9 @@ mod test { } fn empty() -> Arc { - let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::new(DFSchema::empty()), - })); - empty + })) } } From 6b37258f8e86c0bf2ae98af2b1e2fa6f502cfe22 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 6 Sep 2022 12:07:56 -0600 Subject: [PATCH 3/4] clippy --- datafusion/optimizer/src/type_coercion.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 71612b60680f..c7f5905bc523 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -104,7 +104,7 @@ impl ExprRewriter for TypeCoercionRewriter { &fun.signature, )?; Ok(Expr::ScalarUDF { - fun: fun.clone(), + fun, args: new_expr, }) } From 55c8d1325b494d0a5d182b72bfccd6d1419f1f6a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 7 Sep 2022 07:58:51 -0600 Subject: [PATCH 4/4] make coerce_arguments_for_signature private --- datafusion/optimizer/src/type_coercion.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index c7f5905bc523..42c081af3dfe 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -117,7 +117,7 @@ impl ExprRewriter for TypeCoercionRewriter { /// `signature`, if possible. /// /// See the module level documentation for more detail on coercion. -pub fn coerce_arguments_for_signature( +fn coerce_arguments_for_signature( expressions: &[Expr], schema: &DFSchema, signature: &Signature,