Skip to content

Commit

Permalink
Support array_concat for Utf8View (#14378)
Browse files Browse the repository at this point in the history
* Add tests for concatenating differnet string types

* clean up code

* fmt
  • Loading branch information
alamb authored Feb 3, 2025
1 parent 67bc04c commit edbdefe
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 70 deletions.
52 changes: 2 additions & 50 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use arrow::datatypes::{
};
use datafusion_common::types::NativeType;
use datafusion_common::{
exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err,
Diagnostic, Result, Span, Spans,
exec_err, internal_err, plan_datafusion_err, plan_err, Diagnostic, Result, Span,
Spans,
};
use itertools::Itertools;

Expand Down Expand Up @@ -928,54 +928,6 @@ fn get_wider_decimal_type(
}
}

/// Returns the wider type among arguments `lhs` and `rhs`.
/// The wider type is the type that can safely represent values from both types
/// without information loss. Returns an Error if types are incompatible.
pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result<DataType> {
use arrow::datatypes::DataType::*;
Ok(match (lhs, rhs) {
(lhs, rhs) if lhs == rhs => lhs.clone(),
// Right UInt is larger than left UInt.
(UInt8, UInt16 | UInt32 | UInt64) | (UInt16, UInt32 | UInt64) | (UInt32, UInt64) |
// Right Int is larger than left Int.
(Int8, Int16 | Int32 | Int64) | (Int16, Int32 | Int64) | (Int32, Int64) |
// Right Float is larger than left Float.
(Float16, Float32 | Float64) | (Float32, Float64) |
// Right String is larger than left String.
(Utf8, LargeUtf8) |
// Any right type is wider than a left hand side Null.
(Null, _) => rhs.clone(),
// Left UInt is larger than right UInt.
(UInt16 | UInt32 | UInt64, UInt8) | (UInt32 | UInt64, UInt16) | (UInt64, UInt32) |
// Left Int is larger than right Int.
(Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) |
// Left Float is larger than right Float.
(Float32 | Float64, Float16) | (Float64, Float32) |
// Left String is larger than right String.
(LargeUtf8, Utf8) |
// Any left type is wider than a right hand side Null.
(_, Null) => lhs.clone(),
(List(lhs_field), List(rhs_field)) => {
let field_type =
get_wider_type(lhs_field.data_type(), rhs_field.data_type())?;
if lhs_field.name() != rhs_field.name() {
return Err(exec_datafusion_err!(
"There is no wider type that can represent both {lhs} and {rhs}."
));
}
assert_eq!(lhs_field.name(), rhs_field.name());
let field_name = lhs_field.name();
let nullable = lhs_field.is_nullable() | rhs_field.is_nullable();
List(Arc::new(Field::new(field_name, field_type, nullable)))
}
(_, _) => {
return Err(exec_datafusion_err!(
"There is no wider type that can represent both {lhs} and {rhs}."
));
}
})
}

/// Convert the numeric data type to the decimal data type.
/// We support signed and unsigned integer types and floating-point type.
fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
Expand Down
46 changes: 26 additions & 20 deletions datafusion/functions-nested/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ use datafusion_common::{
cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims,
};
use datafusion_expr::{
type_coercion::binary::get_wider_type, ColumnarValue, Documentation, ScalarUDFImpl,
Signature, Volatility,
ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_macros::user_doc;

Expand Down Expand Up @@ -276,25 +275,32 @@ impl ScalarUDFImpl for ArrayConcat {
let mut expr_type = DataType::Null;
let mut max_dims = 0;
for arg_type in arg_types {
match arg_type {
DataType::List(field) => {
if !field.data_type().equals_datatype(&DataType::Null) {
let dims = list_ndims(arg_type);
expr_type = match max_dims.cmp(&dims) {
Ordering::Greater => expr_type,
Ordering::Equal => get_wider_type(&expr_type, arg_type)?,
Ordering::Less => {
max_dims = dims;
arg_type.clone()
}
};
let DataType::List(field) = arg_type else {
return plan_err!(
"The array_concat function can only accept list as the args."
);
};
if !field.data_type().equals_datatype(&DataType::Null) {
let dims = list_ndims(arg_type);
expr_type = match max_dims.cmp(&dims) {
Ordering::Greater => expr_type,
Ordering::Equal => {
if expr_type == DataType::Null {
arg_type.clone()
} else if !expr_type.equals_datatype(arg_type) {
return plan_err!(
"It is not possible to concatenate arrays of different types. Expected: {}, got: {}", expr_type, arg_type
);
} else {
expr_type
}
}
}
_ => {
return plan_err!(
"The array_concat function can only accept list as the args."
)
}

Ordering::Less => {
max_dims = dims;
arg_type.clone()
}
};
}
}

Expand Down
41 changes: 41 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2870,6 +2870,47 @@ select array_concat([]);
----
[]

# Concatenating strings arrays
query ?
select array_concat(
['1', '2'],
['3']
);
----
[1, 2, 3]

# Concatenating string arrays
query ?
select array_concat(
[arrow_cast('1', 'LargeUtf8'), arrow_cast('2', 'LargeUtf8')],
[arrow_cast('3', 'LargeUtf8')]
);
----
[1, 2, 3]

# Concatenating stringview
query ?
select array_concat(
[arrow_cast('1', 'Utf8View'), arrow_cast('2', 'Utf8View')],
[arrow_cast('3', 'Utf8View')]
);
----
[1, 2, 3]

# Concatenating Mixed types (doesn't work)
query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: LargeUtf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
select array_concat(
[arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')],
[arrow_cast('3', 'LargeUtf8')]
);

# Concatenating Mixed types (doesn't work)
query error DataFusion error: Error during planning: It is not possible to concatenate arrays of different types\. Expected: List\(Field \{ name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), got: List\(Field \{ name: "item", data_type: Utf8View, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)
select array_concat(
[arrow_cast('1', 'Utf8'), arrow_cast('2', 'Utf8')],
[arrow_cast('3', 'Utf8View')]
);

# array_concat error
query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\.
select array_concat(1, 2);
Expand Down

0 comments on commit edbdefe

Please sign in to comment.