From 6903259068f11917f8b445002593a094bf3662f8 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 13 Dec 2024 14:21:02 +0100 Subject: [PATCH 1/4] Fix get_type for higher-order array functions --- .../expr/src/type_coercion/functions.rs | 10 +-- datafusion/functions-nested/src/extract.rs | 81 +++++++++++++++++++ 2 files changed, 85 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 9d15d9693992..9fbd46e37f7c 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -24,7 +24,7 @@ use arrow::{ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, types::{LogicalType, NativeType}, - utils::{coerced_fixed_size_list_to_list, list_ndims}, + utils::list_ndims, Result, }; use datafusion_expr_common::{ @@ -416,11 +416,9 @@ fn get_valid_types( } fn array(array_type: &DataType) -> Option<DataType> { match array_type { - DataType::List(_) - | DataType::LargeList(_) - | DataType::FixedSizeList(_, _) => { - let array_type = coerced_fixed_size_list_to_list(array_type); - Some(array_type) + DataType::List(_) => Some(array_type.clone()), + DataType::LargeList(field) | DataType::FixedSizeList(field, _) => { + Some(DataType::List(Arc::clone(field))) } _ => None, } diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index fc35f0076330..4fdac6d9730a 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -993,3 +993,84 @@ where let data = mutable.freeze(); Ok(arrow::array::make_array(data)) } + +#[cfg(test)] +mod tests { + use super::array_element_udf; + use arrow_schema::{DataType, Field}; + use datafusion_common::{Column, DFSchema, ScalarValue}; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{cast, Expr, ExprSchemable}; + use std::collections::HashMap; + + #[test] + fn
test_array_element_return_type() { + let complex_type = DataType::FixedSizeList( + Field::new("some_arbitrary_test_field", DataType::Int32, false).into(), + 13, + ); + let array_type = + DataType::List(Field::new_list_field(complex_type.clone(), true).into()); + let index_type = DataType::Int64; + + let schema = DFSchema::from_unqualified_fields( + vec![ + Field::new("my_array", array_type.clone(), false), + Field::new("my_index", index_type.clone(), false), + ] + .into(), + HashMap::default(), + ) + .unwrap(); + + let udf = array_element_udf(); + + // ScalarUDFImpl::return_type + assert_eq!( + udf.return_type(&[array_type.clone(), index_type.clone()]) + .unwrap(), + complex_type + ); + + // ScalarUDFImpl::return_type_from_exprs with typed exprs + assert_eq!( + udf.return_type_from_exprs( + &[ + cast(Expr::Literal(ScalarValue::Null), array_type.clone()), + cast(Expr::Literal(ScalarValue::Null), index_type.clone()), + ], + &schema, + &[array_type.clone(), index_type.clone()] + ) + .unwrap(), + complex_type + ); + + // ScalarUDFImpl::return_type_from_exprs with exprs not carrying type + assert_eq!( + udf.return_type_from_exprs( + &[ + Expr::Column(Column::new_unqualified("my_array")), + Expr::Column(Column::new_unqualified("my_index")), + ], + &schema, + &[array_type.clone(), index_type.clone()] + ) + .unwrap(), + complex_type + ); + + // Via ExprSchemable::get_type (e.g. 
SimplifyInfo) + let udf_expr = Expr::ScalarFunction(ScalarFunction { + func: array_element_udf(), + args: vec![ + Expr::Column(Column::new_unqualified("my_array")), + Expr::Column(Column::new_unqualified("my_index")), + ], + }); + assert_eq!( + ExprSchemable::get_type(&udf_expr, &schema).unwrap(), + complex_type + ); + } +} From 6d81418cc4804fb9a9ebfffd6354fac38181ac7b Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 13 Dec 2024 14:52:54 +0100 Subject: [PATCH 2/4] Fix recursive flatten The fix is covered by recursive flatten test case in array.slt --- datafusion/expr-common/src/signature.rs | 6 ++++++ .../expr/src/type_coercion/functions.rs | 21 +++++++++++++++++++ datafusion/functions-nested/src/flatten.rs | 11 ++++++++-- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 32cbb6d0aecb..69960acc57e1 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -175,6 +175,9 @@ pub enum ArrayFunctionSignature { /// The function takes a single argument that must be a List/LargeList/FixedSizeList /// or something that can be coerced to one of those types. Array, + /// A function takes a single argument that must be a List/LargeList/FixedSizeList + /// which gets coerced to List, with element type recursively coerced to List too if it is list-like. 
+ RecursiveArray, /// Specialized Signature for MapArray /// The function takes a single argument that must be a MapArray MapArray, @@ -198,6 +201,9 @@ impl std::fmt::Display for ArrayFunctionSignature { ArrayFunctionSignature::Array => { write!(f, "array") } + ArrayFunctionSignature::RecursiveArray => { + write!(f, "recursive_array") + } ArrayFunctionSignature::MapArray => { write!(f, "map_array") } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 9fbd46e37f7c..cbc1ce428c2f 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,6 +21,7 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; +use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, plan_err, types::{LogicalType, NativeType}, @@ -414,6 +415,7 @@ fn get_valid_types( _ => Ok(vec![vec![]]), } } + fn array(array_type: &DataType) -> Option<DataType> { match array_type { DataType::List(_) => Some(array_type.clone()), @@ -424,6 +426,18 @@ fn get_valid_types( } } + fn recursive_array(array_type: &DataType) -> Option<DataType> { + match array_type { + DataType::List(_) + | DataType::LargeList(_) + | DataType::FixedSizeList(_, _) => { + let array_type = coerced_fixed_size_list_to_list(array_type); + Some(array_type) + } + _ => None, + } + } + fn function_length_check(length: usize, expected_length: usize) -> Result<()> { if length < 1 { return plan_err!( @@ -651,6 +665,13 @@ fn get_valid_types( array(&current_types[0]) .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) } + ArrayFunctionSignature::RecursiveArray => { + if current_types.len() != 1 { + return Ok(vec![vec![]]); + } + recursive_array(&current_types[0]) + .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) + } ArrayFunctionSignature::MapArray => { if current_types.len() != 1 { return Ok(vec![vec![]]); diff --git
a/datafusion/functions-nested/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs index 9d2cb8a3f667..7cb52ae4c5c9 100644 --- a/datafusion/functions-nested/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -28,7 +28,8 @@ use datafusion_common::cast::{ use datafusion_common::{exec_err, Result}; use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, }; use std::any::Any; use std::sync::{Arc, OnceLock}; @@ -56,7 +57,13 @@ impl Default for Flatten { impl Flatten { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), + signature: Signature { + // TODO (https://github.com/apache/datafusion/issues/13757) flatten should be single-step, not recursive + type_signature: TypeSignature::ArraySignature( + ArrayFunctionSignature::RecursiveArray, + ), + volatility: Volatility::Immutable, + }, aliases: vec![], } } From 038a01571457e7932e1cb7192cefc654c9259dc3 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Fri, 13 Dec 2024 22:41:43 +0100 Subject: [PATCH 3/4] Restore "keep LargeList" in Array signature --- datafusion/expr/src/type_coercion/functions.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index cbc1ce428c2f..199f649c37e8 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -418,10 +418,8 @@ fn get_valid_types( fn array(array_type: &DataType) -> Option<DataType> { match array_type { - DataType::List(_) => Some(array_type.clone()), - DataType::LargeList(field) | DataType::FixedSizeList(field, _) => { - Some(DataType::List(Arc::clone(field))) - } + DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()), +
DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))), _ => None, } } From 69fcf249a673831dcbcd9f42b837e488ac1716e4 Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Tue, 17 Dec 2024 21:59:45 +0100 Subject: [PATCH 4/4] clarify naming in the test --- datafusion/functions-nested/src/extract.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/datafusion/functions-nested/src/extract.rs b/datafusion/functions-nested/src/extract.rs index 4fdac6d9730a..f972597bbf84 100644 --- a/datafusion/functions-nested/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -1003,14 +1003,16 @@ mod tests { use datafusion_expr::{cast, Expr, ExprSchemable}; use std::collections::HashMap; + // Regression test for https://github.com/apache/datafusion/issues/13755 #[test] - fn test_array_element_return_type() { - let complex_type = DataType::FixedSizeList( + fn test_array_element_return_type_fixed_size_list() { + let fixed_size_list_type = DataType::FixedSizeList( Field::new("some_arbitrary_test_field", DataType::Int32, false).into(), 13, ); - let array_type = - DataType::List(Field::new_list_field(complex_type.clone(), true).into()); + let array_type = DataType::List( + Field::new_list_field(fixed_size_list_type.clone(), true).into(), + ); let index_type = DataType::Int64; let schema = DFSchema::from_unqualified_fields( @@ -1029,7 +1031,7 @@ mod tests { assert_eq!( udf.return_type(&[array_type.clone(), index_type.clone()]) .unwrap(), - complex_type + fixed_size_list_type ); // ScalarUDFImpl::return_type_from_exprs with typed exprs @@ -1043,7 +1045,7 @@ mod tests { &[array_type.clone(), index_type.clone()] ) .unwrap(), - complex_type + fixed_size_list_type ); // ScalarUDFImpl::return_type_from_exprs with exprs not carrying type @@ -1057,7 +1059,7 @@ mod tests { &[array_type.clone(), index_type.clone()] ) .unwrap(), - complex_type + fixed_size_list_type ); // Via ExprSchemable::get_type (e.g. 
SimplifyInfo) @@ -1070,7 +1072,7 @@ mod tests { }); assert_eq!( ExprSchemable::get_type(&udf_expr, &schema).unwrap(), - complex_type + fixed_size_list_type ); } }