diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index b42f6e55901d..369eaecb1905 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -20,7 +20,8 @@ use arrow::datatypes::DataType; use datafusion_common::ExprSchema; use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; -use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::AggregateUDF; use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 66a325f21eaf..47014c7ce6fb 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,11 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::Int64Array; use arrow::array::{ - make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, + make_array, make_comparator, Array, BooleanArray, Capacities, Datum, + MutableArrayData, Scalar, StringArray, StructArray, }; +use arrow::array::{Int64Array, ListArray}; +use arrow::compute::SortOptions; use arrow::datatypes::DataType; +use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, @@ -187,13 +190,27 @@ impl ScalarUDFImpl for GetFieldFunc { fn process_map_array( array: Arc, - key_scalar: Scalar, + key_array: Arc, ) -> Result where K: Array + 'static, { let map_array = as_map_array(array.as_ref())?; - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + let keys = if key_array.data_type().is_nested() { + let comparator = make_comparator( + map_array.keys().as_ref(), + key_array.as_ref(), + SortOptions::default(), + )?; + let len = map_array.keys().len().min(key_array.len()); + let values = (0..len).map(|i| comparator(i, i).is_eq()).collect(); + let nulls = + NullBuffer::union(map_array.keys().nulls(), key_array.nulls()); + BooleanArray::new(values, nulls) + } else { + let be_compared = Scalar::new(key_array); + arrow::compute::kernels::cmp::eq(&be_compared, map_array.keys())? + }; let original_data = map_array.entries().column(1).to_data(); let capacity = Capacities::Array(original_data.len()); @@ -225,14 +242,28 @@ impl ScalarUDFImpl for GetFieldFunc { match (array.data_type(), name) { (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); - process_map_array::(array, key_scalar) + let key_array: Arc = Arc::new(StringArray::from(vec![k.clone()])); + process_map_array::(array, key_array) } (DataType::Map(_, _), ScalarValue::Int64(Some(k))) => { - let key_scalar = Scalar::new(Int64Array::from(vec![*k])); - process_map_array::(array, key_scalar) + let key_array: Arc = Arc::new(Int64Array::from(vec![*k])); + process_map_array::(array, key_array) + } + (DataType::Map(_, _), ScalarValue::List(arr)) => { + let key_array: Arc = Arc::new((**arr).clone()); + process_map_array::(array, key_array) + } + (DataType::Map(_, _), ScalarValue::Struct(arr)) => { + process_map_array::(array, Arc::new(arr.clone() as Arc)) + } + (DataType::Map(_, _), other) => { + let data_type = other.data_type(); + if data_type.is_nested() { + return exec_err!("unsupported type {:?} for map access", data_type); + } else { + process_map_array::>(array, other.to_array()?) + } } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = as_struct_array(&array)?; match as_struct_array.column_by_name(k) { diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 996d3f78adac..42a4ba621801 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -608,7 +608,26 @@ NULL NULL NULL +# test for negative scenario +query ? +SELECT column1[-1] FROM map_array_table_1; +---- +NULL +NULL +NULL +NULL +query ? +SELECT column1[1000] FROM map_array_table_1; +---- +NULL +NULL +NULL +NULL + + +query error DataFusion error: Arrow error: Invalid argument error +SELECT column1[NULL] FROM map_array_table_1; query ??? select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1; @@ -740,3 +759,28 @@ drop table map_array_table_1; statement ok drop table map_array_table_2; + + +statement ok +create table tt as values(MAP{[1,2,3]:1}, MAP {{'a':1, 'b':2}:2}, MAP{true: 3}); + +# accessing using an array +query I +select column1[make_array(1, 2, 3)] from tt; +---- +1 + +# accessing using a struct +query I +select column2[{a:1, b: 2}] from tt; +---- +2 + +# accessing using Bool +query I +select column3[true] from tt; +---- +3 + +statement ok +drop table tt;