From edbdefe0fd20285aeac4a5dee1e8c3e87aa62706 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 3 Feb 2025 09:50:52 -0500 Subject: [PATCH] Support `array_concat` for `Utf8View` (#14378) * Add tests for concatenating differnet string types * clean up code * fmt --- .../expr-common/src/type_coercion/binary.rs | 52 +------------------ datafusion/functions-nested/src/concat.rs | 46 +++++++++------- datafusion/sqllogictest/test_files/array.slt | 41 +++++++++++++++ 3 files changed, 69 insertions(+), 70 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 3195218ea28e..de0608426e72 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -30,8 +30,8 @@ use arrow::datatypes::{ }; use datafusion_common::types::NativeType; use datafusion_common::{ - exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, - Diagnostic, Result, Span, Spans, + exec_err, internal_err, plan_datafusion_err, plan_err, Diagnostic, Result, Span, + Spans, }; use itertools::Itertools; @@ -928,54 +928,6 @@ fn get_wider_decimal_type( } } -/// Returns the wider type among arguments `lhs` and `rhs`. -/// The wider type is the type that can safely represent values from both types -/// without information loss. Returns an Error if types are incompatible. -pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { - use arrow::datatypes::DataType::*; - Ok(match (lhs, rhs) { - (lhs, rhs) if lhs == rhs => lhs.clone(), - // Right UInt is larger than left UInt. - (UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) | - // Right Int is larger than left Int. - (Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) | - // Right Float is larger than left Float. - (Float16, Float32 | Float64) | (Float32, Float64) | - // Right String is larger than left String. - (Utf8, LargeUtf8) | - // Any right type is wider than a left hand side Null. - (Null, _) => rhs.clone(), - // Left UInt is larger than right UInt. - (UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) | - // Left Int is larger than right Int. - (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) | - // Left Float is larger than right Float. - (Float32 | Float64, Float16) | (Float64, Float32) | - // Left String is larger than right String. - (LargeUtf8, Utf8) | - // Any left type is wider than a right hand side Null. - (_, Null) => lhs.clone(), - (List(lhs_field), List(rhs_field)) => { - let field_type = - get_wider_type(lhs_field.data_type(), rhs_field.data_type())?; - if lhs_field.name() != rhs_field.name() { - return Err(exec_datafusion_err!( - "There is no wider type that can represent both {lhs} and {rhs}." - )); - } - assert_eq!(lhs_field.name(), rhs_field.name()); - let field_name = lhs_field.name(); - let nullable = lhs_field.is_nullable() | rhs_field.is_nullable(); - List(Arc::new(Field::new(field_name, field_type, nullable))) - } - (_, _) => { - return Err(exec_datafusion_err!( - "There is no wider type that can represent both {lhs} and {rhs}." - )); - } - }) -} - /// Convert the numeric data type to the decimal data type. /// We support signed and unsigned integer types and floating-point type. fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option { diff --git a/datafusion/functions-nested/src/concat.rs b/datafusion/functions-nested/src/concat.rs index a6557e36da37..93305faad56f 100644 --- a/datafusion/functions-nested/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -29,8 +29,7 @@ use datafusion_common::{ cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims, }; use datafusion_expr::{ - type_coercion::binary::get_wider_type, ColumnarValue, Documentation, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -276,25 +275,32 @@ impl ScalarUDFImpl for ArrayConcat { let mut expr_type = DataType::Null; let mut max_dims = 0; for arg_type in arg_types { - match arg_type { - DataType::List(field) => { - if !field.data_type().equals_datatype(&DataType::Null) { - let dims = list_ndims(arg_type); - expr_type = match max_dims.cmp(&dims) { - Ordering::Greater => expr_type, - Ordering::Equal => get_wider_type(&expr_type, arg_type)?, - Ordering::Less => { - max_dims = dims; - arg_type.clone() - } - }; + let DataType::List(field) = arg_type else { + return plan_err!( + "The array_concat function can only accept list as the args." + ); + }; + if !field.data_type().equals_datatype(&DataType::Null) { + let dims = list_ndims(arg_type); + expr_type = match max_dims.cmp(&dims) { + Ordering::Greater => expr_type, + Ordering::Equal => { + if expr_type == DataType::Null { + arg_type.clone() + } else if !expr_type.equals_datatype(arg_type) { + return plan_err!( + "It is not possible to concatenate arrays of different types. Expected: {}, got: {}", expr_type, arg_type + ); + } else { + expr_type + } } - } - _ => { - return plan_err!( - "The array_concat function can only accept list as the args." - ) - } + + Ordering::Less => { + max_dims = dims; + arg_type.clone() + } + }; } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ce66acd670c6..ff701b55407c 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2870,6 +2870,47 @@ select array_concat([]); ---- [] +# Concatenating strings arrays +query ? +select array_concat( + ['1', '2'], + ['3'] +); +---- +[1, 2, 3] + +# Concatenating string arrays +query ? +select array_concat( + [arrow_cast('1', 'LargeUtf8'), arrow_cast('2', 'LargeUtf8')], + [arrow_cast('3', 'LargeUtf8')] +); +---- +[1, 2, 3] + +# Concatenating stringview +query ? +select array_concat( + [arrow_cast('1', 'Utf8View'), arrow_cast('2', 'Utf8View')], + [arrow_cast('3', 'Utf8View')] +); +---- +[1, 2, 3] + +# Concatenating Mixed types (doesn't work) +query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: LargeUtf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +select array_concat( + [arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], + [arrow_cast('3', 'LargeUtf8')] +); + +# Concatenating Mixed types (doesn't work) +query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +select array_concat( + [arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')], + [arrow_cast('3', 'Utf8View')] +); + # array_concat error query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\. select array_concat(1, 2);