Skip to content

Commit

Permalink
Add type coercion for scalar UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Sep 6, 2022
1 parent 191d8b7 commit 141d2b0
Showing 1 changed file with 107 additions and 3 deletions.
110 changes: 107 additions & 3 deletions datafusion/optimizer/src/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ use datafusion_expr::binary_rule::coerce_types;
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::logical_plan::builder::build_join_schema;
use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
use datafusion_expr::ExprSchemable;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{ExprSchemable, Signature};

#[derive(Default)]
pub struct TypeCoercion {}
Expand Down Expand Up @@ -116,18 +117,61 @@ impl ExprRewriter for TypeCoercionRewriter {
}
}
}
Expr::ScalarUDF { fun, args } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
Ok(Expr::ScalarUDF {
fun: fun.clone(),
args: new_expr,
})
}
_ => Ok(expr),
}
}
}

/// Returns `expressions` coerced to types compatible with
/// `signature`, if possible.
///
/// See the module level documentation for more detail on coercion.
pub fn coerce_arguments_for_signature(
expressions: &[Expr],
schema: &DFSchema,
signature: &Signature,
) -> Result<Vec<Expr>> {
if expressions.is_empty() {
return Ok(vec![]);
}

let current_types = expressions
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;

let new_types = data_types(&current_types, signature)?;

expressions
.iter()
.enumerate()
.map(|(i, expr)| expr.clone().cast_to(&new_types[i], schema))
.collect::<Result<Vec<_>>>()
}

#[cfg(test)]
mod test {
use crate::type_coercion::TypeCoercion;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, Result};
use datafusion_expr::logical_plan::{EmptyRelation, Projection};
use datafusion_expr::{lit, LogicalPlan};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF,
Signature, Volatility,
};
use std::sync::Arc;

#[test]
Expand Down Expand Up @@ -167,4 +211,64 @@ mod test {
\n EmptyRelation", &format!("{:?}", plan));
Ok(())
}

#[test]
fn scalar_udf() -> Result<()> {
let empty = empty();
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
let udf = Expr::ScalarUDF {
fun: Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
&return_type,
&fun,
)),
args: vec![lit(123_i32)],
};
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

#[test]
fn scalar_udf_invalid_input() -> Result<()> {
let empty = empty();
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
let udf = Expr::ScalarUDF {
fun: Arc::new(ScalarUDF::new(
"TestScalarUDF",
&Signature::uniform(1, vec![DataType::Int32], Volatility::Stable),
&return_type,
&fun,
)),
args: vec![lit("Apple")],
};
let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config).err().unwrap();
assert_eq!(
"Plan(\"Coercion from [Utf8] to the signature Uniform(1, [Int32]) failed.\")",
&format!("{:?}", plan)
);
Ok(())
}

fn empty() -> Arc<LogicalPlan> {
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
}));
empty
}
}

0 comments on commit 141d2b0

Please sign in to comment.