Skip to content

Commit

Permalink
Make make_scalar_function private
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jan 15, 2024
1 parent e966a10 commit 256fa22
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 56 deletions.
53 changes: 34 additions & 19 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{
execution::registry::FunctionRegistry,
physical_plan::functions::make_scalar_function, test_util,
};
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
use datafusion_expr::{
Expand Down Expand Up @@ -87,12 +84,18 @@ async fn scalar_udf() -> Result<()> {

ctx.register_batch("t", batch)?;

let myfunc = |args: &[ArrayRef]| {
let l = as_int32_array(&args[0])?;
let r = as_int32_array(&args[1])?;
Ok(Arc::new(add(l, r)?) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(l) = &args[0] else {
panic!()
};
let ColumnarValue::Array(r) = &args[1] else {
panic!()
};

let l = as_int32_array(l)?;
let r = as_int32_array(r)?;
Ok(ColumnarValue::Array(Arc::new(add(l, r)?) as ArrayRef))
});

ctx.register_udf(create_udf(
"my_add",
Expand Down Expand Up @@ -163,11 +166,15 @@ async fn scalar_udf_zero_params() -> Result<()> {

ctx.register_batch("t", batch)?;
// create function just returns 100 regardless of inp
let myfunc = |args: &[ArrayRef]| {
let num_rows = args[0].len();
Ok(Arc::new((0..num_rows).map(|_| 100).collect::<Int32Array>()) as ArrayRef)
};
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
let num_rows = array.len();
Ok(ColumnarValue::Array(Arc::new(
(0..num_rows).map(|_| 100).collect::<Int32Array>(),
) as ArrayRef))
});

ctx.register_udf(create_udf(
"get_100",
Expand Down Expand Up @@ -307,8 +314,12 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::clone(array)))
});

ctx.register_udf(create_udf(
"MY_FUNC",
Expand Down Expand Up @@ -348,8 +359,12 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::clone(array)))
});

let udf = create_udf(
"dummy",
Expand Down
22 changes: 13 additions & 9 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,7 @@ mod tests {
assert_contains, cast::as_int32_array, plan_datafusion_err, DFField, ToDFSchema,
};
use datafusion_expr::{interval_arithmetic::Interval, *};
use datafusion_physical_expr::{
execution_props::ExecutionProps, functions::make_scalar_function,
};
use datafusion_physical_expr::execution_props::ExecutionProps;

use chrono::{DateTime, TimeZone, Utc};

Expand Down Expand Up @@ -1438,9 +1436,16 @@ mod tests {
let input_types = vec![DataType::Int32, DataType::Int32];
let return_type = Arc::new(DataType::Int32);

let fun = |args: &[ArrayRef]| {
let arg0 = as_int32_array(&args[0])?;
let arg1 = as_int32_array(&args[1])?;
let fun = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(arg0) = &args[0] else {
panic!()
};
let ColumnarValue::Array(arg1) = &args[1] else {
panic!()
};

let arg0 = as_int32_array(arg0)?;
let arg1 = as_int32_array(arg1)?;

// 2. perform the computation
let array = arg0
Expand All @@ -1456,10 +1461,9 @@ mod tests {
})
.collect::<Int32Array>();

Ok(Arc::new(array) as ArrayRef)
};
Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef))
});

let fun = make_scalar_function(fun);
Arc::new(create_udf(
"udf_add",
input_types,
Expand Down
6 changes: 4 additions & 2 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ pub(crate) enum Hint {
AcceptsSingular,
}

/// decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function
/// and vice-versa after evaluation.
pub fn make_scalar_function<F>(inner: F) -> ScalarFunctionImplementation
/// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar.
/// That said, its output will be the same for all input rows in a batch.
pub(crate) fn make_scalar_function<F>(inner: F) -> ScalarFunctionImplementation
where
F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
Expand Down
3 changes: 1 addition & 2 deletions datafusion/proto/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use crate::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
use crate::protobuf;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::{
create_udaf, create_udf, create_udwf, AggregateUDF, Expr, LogicalPlan, Volatility,
Expand Down Expand Up @@ -117,7 +116,7 @@ impl Serializeable for Expr {
vec![],
Arc::new(arrow::datatypes::DataType::Null),
Volatility::Immutable,
make_scalar_function(|_| unimplemented!()),
Arc::new(|_| unimplemented!()),
)))
}

Expand Down
16 changes: 9 additions & 7 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion::parquet::file::properties::{WriterProperties, WriterVersion};
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext};
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::file_options::csv_writer::CsvWriterOptions;
Expand All @@ -53,9 +52,9 @@ use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore};
use datafusion_expr::{
col, create_udaf, lit, Accumulator, AggregateFunction,
BuiltinScalarFunction::{Sqrt, Substr},
Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility,
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF,
WindowUDFImpl,
ColumnarValue, Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast,
Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition, WindowUDF, WindowUDFImpl,
};
use datafusion_proto::bytes::{
logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec,
Expand Down Expand Up @@ -1592,9 +1591,12 @@ fn roundtrip_aggregate_udf() {

#[test]
fn roundtrip_scalar_udf() {
let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);

let scalar_fn = make_scalar_function(fn_impl);
let scalar_fn = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef))
});

let udf = create_udf(
"dummy",
Expand Down
14 changes: 8 additions & 6 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ use datafusion::physical_plan::expressions::{
GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, Sum,
};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::physical_plan::insert::FileSinkExec;
use datafusion::physical_plan::joins::{
HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode,
Expand All @@ -73,8 +72,8 @@ use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::stats::Precision;
use datafusion_common::{FileTypeWriterOptions, Result};
use datafusion_expr::{
Accumulator, AccumulatorFactoryFunction, AggregateUDF, Signature, SimpleAggregateUDF,
WindowFrame, WindowFrameBound,
Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature,
SimpleAggregateUDF, WindowFrame, WindowFrameBound,
};
use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
use datafusion_proto::protobuf;
Expand Down Expand Up @@ -568,9 +567,12 @@ fn roundtrip_scalar_udf() -> Result<()> {

let input = Arc::new(EmptyExec::new(schema.clone()));

let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);

let scalar_fn = make_scalar_function(fn_impl);
let scalar_fn = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef))
});

let udf = create_udf(
"dummy",
Expand Down
12 changes: 7 additions & 5 deletions datafusion/proto/tests/cases/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@ use arrow::array::ArrayRef;
use arrow::datatypes::DataType;

use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion::prelude::SessionContext;
use datafusion_expr::{col, create_udf, lit};
use datafusion_expr::{col, create_udf, lit, ColumnarValue};
use datafusion_expr::{Expr, Volatility};
use datafusion_proto::bytes::Serializeable;

Expand Down Expand Up @@ -226,9 +225,12 @@ fn roundtrip_deeply_nested() {

/// return a `SessionContext` with a `dummy` function registered as a UDF
fn context_with_udf() -> SessionContext {
let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef);

let scalar_fn = make_scalar_function(fn_impl);
let scalar_fn = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(array) = &args[0] else {
panic!()
};
Ok(ColumnarValue::Array(Arc::new(array.clone()) as ArrayRef))
});

let udf = create_udf(
"dummy",
Expand Down
18 changes: 12 additions & 6 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ use arrow::array::{
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::record_batch::RecordBatch;
use datafusion::execution::context::SessionState;
use datafusion::logical_expr::{create_udf, Expr, ScalarUDF, Volatility};
use datafusion::physical_expr::functions::make_scalar_function;
use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionConfig;
use datafusion::{
Expand Down Expand Up @@ -356,9 +355,16 @@ pub async fn register_metadata_tables(ctx: &SessionContext) {
/// Create a UDF function named "example". See the `sample_udf.rs` example
/// file for an explanation of the API.
fn create_example_udf() -> ScalarUDF {
let adder = make_scalar_function(|args: &[ArrayRef]| {
let lhs = as_float64_array(&args[0]).expect("cast failed");
let rhs = as_float64_array(&args[1]).expect("cast failed");
let adder = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Array(lhs) = &args[0] else {
panic!()
};
let ColumnarValue::Array(rhs) = &args[1] else {
panic!()
};

let lhs = as_float64_array(lhs).expect("cast failed");
let rhs = as_float64_array(rhs).expect("cast failed");
let array = lhs
.iter()
.zip(rhs.iter())
Expand All @@ -367,7 +373,7 @@ fn create_example_udf() -> ScalarUDF {
_ => None,
})
.collect::<Float64Array>();
Ok(Arc::new(array) as ArrayRef)
Ok(ColumnarValue::Array(Arc::new(array) as ArrayRef))
});
create_udf(
"example",
Expand Down

0 comments on commit 256fa22

Please sign in to comment.