feat: support UDAF in substrait producer/consumer (#8119)
* feat: support UDAF in substrait producer/consumer

Signed-off-by: Ruihang Xia <[email protected]>

* Update datafusion/substrait/src/logical_plan/consumer.rs

Co-authored-by: Andrew Lamb <[email protected]>

* remove redundant to_lowercase

Signed-off-by: Ruihang Xia <[email protected]>

---------

Signed-off-by: Ruihang Xia <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
waynexia and alamb authored Nov 12, 2023
1 parent f67c20f commit 824bb66
Showing 3 changed files with 125 additions and 25 deletions.
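
In short, this change lets a user-defined aggregate function registered on the SessionContext survive a Substrait round trip: the producer records the UDAF's name as a function anchor, and the consumer resolves that anchor against the registry before falling back to built-in aggregates. A minimal sketch of that flow follows; it borrows the Dummy accumulator and the create_context() helper from the test added in this commit, and assumes from_substrait_plan takes &SessionContext, matching the roundtrip_with_ctx helper in that test.

use std::sync::Arc;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::logical_expr::Volatility;
use datafusion::prelude::*;
use datafusion_substrait::logical_plan::{consumer::from_substrait_plan, producer::to_substrait_plan};

async fn udaf_substrait_roundtrip() -> Result<()> {
    let ctx = create_context().await?; // test helper from this file (assumed to register the `data` table)
    // Register the UDAF; the consumer later resolves the Substrait function anchor by this name.
    ctx.register_udaf(create_udaf(
        "dummy_agg",
        vec![DataType::Int64],
        Arc::new(DataType::Int64),
        Volatility::Immutable,
        Arc::new(|_| Ok(Box::new(Dummy {}))), // `Dummy` accumulator as defined in the test below
        Arc::new(vec![DataType::Float64, DataType::UInt32]),
    ));

    // Producer: the UDAF's name (fun.name) is registered as the Substrait function anchor.
    let plan = ctx.sql("select dummy_agg(a) from data").await?.into_optimized_plan()?;
    let proto = to_substrait_plan(&plan, &ctx)?;

    // Consumer: the anchor is looked up via ctx.udaf() first, then built-in aggregate functions.
    let plan2 = from_substrait_plan(&ctx, &proto).await?;
    println!("{plan:?}\n{plan2:?}"); // the test below asserts these render identically after re-optimizing
    Ok(())
}
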
45 changes: 31 additions & 14 deletions datafusion/substrait/src/logical_plan/consumer.rs
@@ -19,6 +19,7 @@ use async_recursion::async_recursion;
use datafusion::arrow::datatypes::{DataType, Field, TimeUnit};
use datafusion::common::{not_impl_err, DFField, DFSchema, DFSchemaRef};

use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
aggregate_function, window_function::find_df_window_func, BinaryExpr,
BuiltinScalarFunction, Case, Expr, LogicalPlan, Operator,
@@ -365,6 +366,7 @@ pub async fn from_substrait_rel(
_ => false,
};
from_substrait_agg_func(
ctx,
f,
input.schema(),
extensions,
@@ -660,6 +662,7 @@ pub async fn from_substriat_func_args(

/// Convert Substrait AggregateFunction to DataFusion Expr
pub async fn from_substrait_agg_func(
ctx: &SessionContext,
f: &AggregateFunction,
input_schema: &DFSchema,
extensions: &HashMap<u32, &String>,
@@ -680,23 +683,37 @@ pub async fn from_substrait_agg_func(
args.push(arg_expr?.as_ref().clone());
}

let fun = match extensions.get(&f.function_reference) {
Some(function_name) => {
aggregate_function::AggregateFunction::from_str(function_name)
}
None => not_impl_err!(
"Aggregated function not found: function anchor = {:?}",
let Some(function_name) = extensions.get(&f.function_reference) else {
return plan_err!(
"Aggregate function not registered: function anchor = {:?}",
f.function_reference
),
);
};

Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
fun: fun.unwrap(),
args,
distinct,
filter,
order_by,
})))
// try udaf first, then built-in aggr fn.
if let Ok(fun) = ctx.udaf(function_name) {
Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF {
fun,
args,
filter,
order_by,
})))
} else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
{
Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction {
fun,
args,
distinct,
filter,
order_by,
})))
} else {
not_impl_err!(
"Aggregated function {} is not supported: function anchor = {:?}",
function_name,
f.function_reference
)
}
}

/// Convert Substrait Rex to DataFusion Expr
41 changes: 33 additions & 8 deletions datafusion/substrait/src/logical_plan/producer.rs
@@ -588,8 +588,7 @@ pub fn to_substrait_agg_measure(
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
}
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
let function_anchor = _register_function(fun.to_string(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
@@ -610,6 +609,34 @@
}
})
}
Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{
let sorts = if let Some(order_by) = order_by {
order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::<Result<Vec<_>>>()?
} else {
vec![]
};
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) });
}
let function_anchor = _register_function(fun.name.clone(), extension_info);
Ok(Measure {
measure: Some(AggregateFunction {
function_reference: function_anchor,
arguments,
sorts,
output_type: None,
invocation: AggregationInvocation::All as i32,
phase: AggregationPhase::Unspecified as i32,
args: vec![],
options: vec![],
}),
filter: match filter {
Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?),
None => None
}
})
},
Expr::Alias(Alias{expr,..})=> {
to_substrait_agg_measure(expr, schema, extension_info)
}
@@ -703,8 +730,8 @@ pub fn make_binary_op_scalar_func(
HashMap<String, u32>,
),
) -> Expression {
let function_name = operator_to_name(op).to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
let function_anchor =
_register_function(operator_to_name(op).to_string(), extension_info);
Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
@@ -807,8 +834,7 @@ pub fn to_substrait_rex(
)?)),
});
}
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
let function_anchor = _register_function(fun.to_string(), extension_info);
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
function_reference: function_anchor,
@@ -973,8 +999,7 @@ pub fn to_substrait_rex(
window_frame,
}) => {
// function reference
let function_name = fun.to_string().to_lowercase();
let function_anchor = _register_function(function_name, extension_info);
let function_anchor = _register_function(fun.to_string(), extension_info);
// arguments
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
64 changes: 61 additions & 3 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

use datafusion::arrow::array::ArrayRef;
use datafusion::physical_plan::Accumulator;
use datafusion::scalar::ScalarValue;
use datafusion_substrait::logical_plan::{
consumer::from_substrait_plan, producer::to_substrait_plan,
};
@@ -28,7 +31,9 @@ use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::SessionState;
use datafusion::execution::registry::SerializerRegistry;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::logical_expr::{Extension, LogicalPlan, UserDefinedLogicalNode};
use datafusion::logical_expr::{
Extension, LogicalPlan, UserDefinedLogicalNode, Volatility,
};
use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST;
use datafusion::prelude::*;

@@ -636,6 +641,56 @@ async fn extension_logical_plan() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn roundtrip_aggregate_udf() -> Result<()> {
#[derive(Debug)]
struct Dummy {}

impl Accumulator for Dummy {
fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> {
Ok(vec![])
}

fn update_batch(
&mut self,
_values: &[ArrayRef],
) -> datafusion::error::Result<()> {
Ok(())
}

fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> {
Ok(())
}

fn evaluate(&self) -> datafusion::error::Result<ScalarValue> {
Ok(ScalarValue::Float64(None))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}
}

let dummy_agg = create_udaf(
// the name; used to represent it in plan descriptions and in the registry, to use in SQL.
"dummy_agg",
// the input type; DataFusion guarantees that the first entry of `values` in `update` has this type.
vec![DataType::Int64],
// the return type; DataFusion expects this to match the type returned by `evaluate`.
Arc::new(DataType::Int64),
Volatility::Immutable,
// This is the accumulator factory; DataFusion uses it to create new accumulators.
Arc::new(|_| Ok(Box::new(Dummy {}))),
// This is the description of the state. `state()` must match the types here.
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);

let ctx = create_context().await?;
ctx.register_udaf(dummy_agg);

roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await
}

fn check_post_join_filters(rel: &Rel) -> Result<()> {
// search for target_rel and field value in proto
match &rel.rel_type {
@@ -772,8 +827,7 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> {
Ok(())
}

async fn roundtrip(sql: &str) -> Result<()> {
let ctx = create_context().await?;
async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> {
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
let proto = to_substrait_plan(&plan, &ctx)?;
@@ -789,6 +843,10 @@ async fn roundtrip(sql: &str) -> Result<()> {
Ok(())
}

async fn roundtrip(sql: &str) -> Result<()> {
roundtrip_with_ctx(sql, create_context().await?).await
}

async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> {
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;