Skip to content

Commit

Permalink
fix: add ArrayAndElementAndOptionalIndex for proper casting in `arr…
Browse files Browse the repository at this point in the history
…ay_position` (apache#9233)

* fix: use `array_and_element` for proper casting in array_position

* fix: fix typo

* feat: add ArrayAndElementAndOptionalIndex

* refactor: cleanup

* docs: add docs to enum variants

* doc: fix cargo doc formatting snafu

* test: add a couple of tests

* refactor: update names, early exit logic

* test: add null test for array_position
  • Loading branch information
tshauck authored Feb 19, 2024
1 parent 497cb9d commit 60ee91e
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
2 changes: 1 addition & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()),
BuiltinScalarFunction::ArrayPosition => {
Signature::variadic_any(self.volatility())
Signature::array_and_element_and_optional_index(self.volatility())
}
BuiltinScalarFunction::ArrayPositions => {
Signature::array_and_element(self.volatility())
Expand Down
17 changes: 16 additions & 1 deletion datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub enum TypeSignature {
/// DataFusion attempts to coerce all argument types to match the first argument's type
///
/// # Examples
/// Given types in signature should be coericible to the same final type.
/// Given types in signature should be coercible to the same final type.
/// A function such as `make_array` is `VariadicEqual`.
///
/// `make_array(i32, i64) -> make_array(i64, i64)`
Expand Down Expand Up @@ -132,7 +132,10 @@ pub enum ArrayFunctionSignature {
/// The first argument should be non-list or list, and the second argument should be List/LargeList.
/// The first argument's list dimension should be one dimension less than the second argument's list dimension.
ElementAndArray,
/// Specialized Signature for Array functions of the form (List/LargeList, Index)
ArrayAndIndex,
/// Specialized Signature for Array functions of the form (List/LargeList, Element, Optional Index)
ArrayAndElementAndOptionalIndex,
}

impl std::fmt::Display for ArrayFunctionSignature {
Expand All @@ -141,6 +144,9 @@ impl std::fmt::Display for ArrayFunctionSignature {
ArrayFunctionSignature::ArrayAndElement => {
write!(f, "array, element")
}
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
write!(f, "array, element, [index]")
}
ArrayFunctionSignature::ElementAndArray => {
write!(f, "element, array")
}
Expand Down Expand Up @@ -292,6 +298,15 @@ impl Signature {
volatility,
}
}
/// Specialized Signature for Array functions with an optional index
pub fn array_and_element_and_optional_index(volatility: Volatility) -> Self {
Signature {
type_signature: TypeSignature::ArraySignature(
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex,
),
volatility,
}
}
/// Specialized Signature for ArrayPrepend and similar functions
pub fn element_and_array(volatility: Volatility) -> Self {
Signature {
Expand Down
32 changes: 32 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,35 @@ fn get_valid_types(
_ => Ok(vec![vec![]]),
}
}
fn array_element_and_optional_index(
current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
// make sure there's 2 or 3 arguments
if !(current_types.len() == 2 || current_types.len() == 3) {
return Ok(vec![vec![]]);
}

let first_two_types = &current_types[0..2];
let mut valid_types = array_append_or_prepend_valid_types(first_two_types, true)?;

// Early return if there are only 2 arguments
if current_types.len() == 2 {
return Ok(valid_types);
}

let valid_types_with_index = valid_types
.iter()
.map(|t| {
let mut t = t.clone();
t.push(DataType::Int64);
t
})
.collect::<Vec<_>>();

valid_types.extend(valid_types_with_index);

Ok(valid_types)
}
fn array_and_index(current_types: &[DataType]) -> Result<Vec<Vec<DataType>>> {
if current_types.len() != 2 {
return Ok(vec![vec![]]);
Expand Down Expand Up @@ -184,6 +213,9 @@ fn get_valid_types(
ArrayFunctionSignature::ArrayAndElement => {
return array_append_or_prepend_valid_types(current_types, true)
}
ArrayFunctionSignature::ArrayAndElementAndOptionalIndex => {
return array_element_and_optional_index(current_types)
}
ArrayFunctionSignature::ArrayAndIndex => {
return array_and_index(current_types)
}
Expand Down
25 changes: 25 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2603,6 +2603,31 @@ select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4,
----
2 2

query I
SELECT array_position(arrow_cast([5, 2, 3, 4, 5], 'List(Int32)'), 5)
----
1

query I
SELECT array_position(arrow_cast([5, 2, 3, 4, 5], 'List(Int32)'), 5, 2)
----
5

query I
SELECT array_position(arrow_cast([1, 1, 100, 1, 1], 'LargeList(Int32)'), 100)
----
3

query I
SELECT array_position([1, 2, 3], 'foo')
----
NULL

query I
SELECT array_position([1, 2, 3], 'foo', 2)
----
NULL

# list_position scalar function #5 (function alias `array_position`)
query III
select list_position(['h', 'e', 'l', 'l', 'o'], 'l'), list_position([1, 2, 3, 4, 5], 5), list_position([1, 1, 1], 1);
Expand Down

0 comments on commit 60ee91e

Please sign in to comment.