From 433445d113f8ca87672374488d25eaa5cac15a53 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 14 Jan 2025 17:15:55 +0000 Subject: [PATCH 1/3] fix case of f(dict_array, dict_array) invocation --- src/common.rs | 46 ++++++++++++++++++++++------ tests/main.rs | 84 ++++++++++++++++++++++++++++----------------------- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/src/common.rs b/src/common.rs index fdba8d2..37056e2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -3,11 +3,11 @@ use std::sync::Arc; use datafusion::arrow::array::{ Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, - StringArray, StringViewArray, UInt64Array, UnionArray, + StringArray, StringViewArray, UInt64Array, }; use datafusion::arrow::compute::take; use datafusion::arrow::datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type, + ArrowDictionaryKeyType, ArrowNativeType, ArrowNativeTypeOp, DataType, Int64Type, UInt64Type, }; use datafusion::arrow::downcast_dictionary_array; use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue}; @@ -95,6 +95,7 @@ impl From for JsonPath<'_> { } } +#[derive(Debug)] enum JsonPathArgs<'a> { Array(&'a ArrayRef), Scalars(Vec>), @@ -152,7 +153,7 @@ pub fn invoke> + 'static, I>( let path = JsonPathArgs::extract_path(path_args)?; match (json_arg, path) { (ColumnarValue::Array(json_array), JsonPathArgs::Array(path_array)) => { - invoke_array_array(json_array, path_array, to_array, jiter_find, return_dict).map(ColumnarValue::Array) + invoke_array_array(json_array, path_array, to_array, jiter_find).map(ColumnarValue::Array) } (ColumnarValue::Array(json_array), JsonPathArgs::Scalars(path)) => { invoke_array_scalars(json_array, &path, to_array, jiter_find, return_dict).map(ColumnarValue::Array) @@ -171,13 +172,38 @@ fn invoke_array_array> + 'static, I>( path_array: &ArrayRef, to_array: impl Fn(C) -> DataFusionResult, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, - return_dict: bool, ) -> DataFusionResult { downcast_dictionary_array!( json_array => { - let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?; - post_process_dict(json_array, values, return_dict) - } + fn wrap_as_dictionary(original: &DictionaryArray, new_values: ArrayRef) -> DictionaryArray { + assert_eq!(original.keys().len(), new_values.len()); + let mut key = K::Native::ZERO; + let key_range = std::iter::from_fn(move || { + let next = key; + key = key.add_checked(K::Native::ONE).expect("keys exhausted"); + Some(next) + }).take(new_values.len()); + let mut keys = PrimitiveArray::::from_iter_values(key_range); + if is_json_union(new_values.data_type()) { + // JSON union: post-process the array to set keys to null where the union member is null + let type_ids = new_values.as_union().type_ids(); + keys = mask_dictionary_keys(&keys, type_ids); + } + DictionaryArray::::new(keys, new_values) + } + + // TODO: in theory if path_array is _also_ a dictionary we could work out the unique key + // combinations and do less work, but this can be left as a future optimization + let output = match json_array.values().data_type() { + DataType::Utf8 => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), + DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), + DataType::Utf8View => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), + other => exec_err!("unexpected json array type {:?}", other), + }?; + + // ensure return is a dictionary to satisfy the declaration above in return_type_check + Ok(Arc::new(wrap_as_dictionary(json_array, output))) + }, DataType::Utf8 => zip_apply(json_array.as_string::().iter(), path_array, to_array, jiter_find), DataType::LargeUtf8 => zip_apply(json_array.as_string::().iter(), path_array, to_array, jiter_find), DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find), @@ -239,6 +265,7 @@ fn invoke_scalar_array> + 'static, I>( to_array, jiter_find, ) + // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary? .map(ColumnarValue::Array) } @@ -250,6 +277,7 @@ fn invoke_scalar_scalars( ) -> DataFusionResult { let s = extract_json_scalar(scalar)?; let v = jiter_find(s, path).ok(); + // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary? Ok(ColumnarValue::Scalar(to_scalar(v))) } @@ -321,7 +349,7 @@ fn post_process_dict( if return_dict { if is_json_union(result_values.data_type()) { // JSON union: post-process the array to set keys to null where the union member is null - let type_ids = result_values.as_any().downcast_ref::().unwrap().type_ids(); + let type_ids = result_values.as_union().type_ids(); Ok(Arc::new(DictionaryArray::new( mask_dictionary_keys(dict_array.keys(), type_ids), result_values, @@ -413,7 +441,7 @@ impl From for GetError { /// /// That said, doing this might also be an optimization for cases like null-checking without needing /// to check the value union array. -fn mask_dictionary_keys(keys: &PrimitiveArray, type_ids: &[i8]) -> PrimitiveArray { +fn mask_dictionary_keys(keys: &PrimitiveArray, type_ids: &[i8]) -> PrimitiveArray { let mut null_mask = vec![true; keys.len()]; for (i, k) in keys.iter().enumerate() { match k { diff --git a/tests/main.rs b/tests/main.rs index 54a3f9b..e2eb5e5 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1600,34 +1600,38 @@ async fn test_json_object_keys_nested() { #[tokio::test] async fn test_lookup_literal_column_matrix() { let sql = r#" -WITH attr_names AS ( - -- this is deliberately a different length to json_columns - SELECT unnest(['a', 'b', 'c']) as attr_name -), json_columns AS ( +WITH json_columns AS ( SELECT unnest(['{"a": 1}', '{"b": 2}']) as json_column +), attr_names AS ( + -- this is deliberately a different length to json_columns + SELECT + unnest(['a', 'b', 'c']) as attr_name, + arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name_dict ) SELECT attr_name, json_column, 'a' = attr_name, - json_get('{"a": 1}', attr_name), -- literal lookup with column - json_get('{"a": 1}', 'a'), -- literal lookup with literal - json_get(json_column, attr_name), -- column lookup with column - json_get(json_column, 'a') -- column lookup with literal -FROM attr_names, json_columns + json_get('{"a": 1}', attr_name), -- literal lookup with column + json_get('{"a": 1}', attr_name_dict), -- literal lookup with dict column + json_get('{"a": 1}', 'a'), -- literal lookup with literal + json_get(json_column, attr_name), -- column lookup with column + json_get(json_column, attr_name_dict), -- column lookup with dict column + json_get(json_column, 'a') -- column lookup with literal +FROM json_columns, attr_names "#; let expected = [ - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", - "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,Utf8(\"a\")) |", - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", - "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} |", - "| a | {\"b\": 2} | true | {int=1} | {int=1} | {null=} | {null=} |", - "| b | {\"a\": 1} | false | {null=} | {int=1} | {null=} | {int=1} |", - "| b | {\"b\": 2} | false | {null=} | {int=1} | {int=2} | {null=} |", - "| c | {\"a\": 1} | false | {null=} | {int=1} | {null=} | {int=1} |", - "| c | {\"b\": 2} | false | {null=} | {int=1} | {null=} | {null=} |", - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", + "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", + "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} |", + "| a | {\"b\": 2} | true | {int=1} | {int=1} | {int=1} | {null=} | {null=} | {null=} |", + "| b | {\"a\": 1} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {int=1} |", + "| b | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {int=2} | {int=2} | {null=} |", + "| c | {\"a\": 1} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {int=1} |", + "| c | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {null=} |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", ]; let batches = run_query(sql).await.unwrap(); @@ -1637,36 +1641,40 @@ FROM attr_names, json_columns #[tokio::test] async fn test_lookup_literal_column_matrix_dictionaries() { let sql = r#" -WITH attr_names AS ( - -- this is deliberately a different length to json_columns - SELECT arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name -), json_columns AS ( +WITH json_columns AS ( SELECT arrow_cast(unnest(['{"a": 1}', '{"b": 2}']), 'Dictionary(Int32, Utf8)') as json_column +), attr_names AS ( + -- this is deliberately a different length to json_columns + SELECT + unnest(['a', 'b', 'c']) as attr_name, + arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name_dict ) SELECT attr_name, json_column, 'a' = attr_name, - json_get('{"a": 1}', attr_name), -- literal lookup with column - json_get('{"a": 1}', 'a'), -- literal lookup with literal - json_get(json_column, attr_name), -- column lookup with column - json_get(json_column, 'a') -- column lookup with literal -FROM attr_names, json_columns + json_get('{"a": 1}', attr_name), -- literal lookup with column + json_get('{"a": 1}', attr_name_dict), -- literal lookup with dict column + json_get('{"a": 1}', 'a'), -- literal lookup with literal + json_get(json_column, attr_name), -- column lookup with column + json_get(json_column, attr_name_dict), -- column lookup with dict column + json_get(json_column, 'a') -- column lookup with literal +FROM json_columns, attr_names "#; // NB as compared to the non-dictionary case, we null out the dictionary keys if the return // value is a dict, which is why we get true nulls instead of {null=} let expected = [ - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", - "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,Utf8(\"a\")) |", - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", - "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} |", - "| a | {\"b\": 2} | true | {int=1} | {int=1} | | |", - "| b | {\"a\": 1} | false | {null=} | {int=1} | | {int=1} |", - "| b | {\"b\": 2} | false | {null=} | {int=1} | {int=2} | |", - "| c | {\"a\": 1} | false | {null=} | {int=1} | | {int=1} |", - "| c | {\"b\": 2} | false | {null=} | {int=1} | | |", - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", + "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", + "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} |", + "| b | {\"a\": 1} | false | {null=} | {null=} | {int=1} | | | {int=1} |", + "| c | {\"a\": 1} | false | {null=} | {null=} | {int=1} | | | {int=1} |", + "| a | {\"b\": 2} | true | {int=1} | {int=1} | {int=1} | | | |", + "| b | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {int=2} | {int=2} | |", + "| c | {\"b\": 2} | false | {null=} | {null=} | {int=1} | | | |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", ]; let batches = run_query(sql).await.unwrap(); From 122f83c91aecd0af10eec6f35ecb5a28aa52da00 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 14 Jan 2025 17:58:07 +0000 Subject: [PATCH 2/3] fix tests --- src/common.rs | 26 +++++++++++++++++++++----- src/common_union.rs | 6 +++++- tests/main.rs | 4 ++-- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/common.rs b/src/common.rs index 37056e2..0893e97 100644 --- a/src/common.rs +++ b/src/common.rs @@ -14,7 +14,9 @@ use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarV use datafusion::logical_expr::ColumnarValue; use jiter::{Jiter, JiterError, Peek}; -use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL}; +use crate::common_union::{ + is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL, +}; /// General implementation of `ScalarUDFImpl::return_type`. /// @@ -153,7 +155,7 @@ pub fn invoke> + 'static, I>( let path = JsonPathArgs::extract_path(path_args)?; match (json_arg, path) { (ColumnarValue::Array(json_array), JsonPathArgs::Array(path_array)) => { - invoke_array_array(json_array, path_array, to_array, jiter_find).map(ColumnarValue::Array) + invoke_array_array(json_array, path_array, to_array, jiter_find, return_dict).map(ColumnarValue::Array) } (ColumnarValue::Array(json_array), JsonPathArgs::Scalars(path)) => { invoke_array_scalars(json_array, &path, to_array, jiter_find, return_dict).map(ColumnarValue::Array) @@ -172,6 +174,7 @@ fn invoke_array_array> + 'static, I>( path_array: &ArrayRef, to_array: impl Fn(C) -> DataFusionResult, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, + return_dict: bool, ) -> DataFusionResult { downcast_dictionary_array!( json_array => { @@ -198,11 +201,24 @@ fn invoke_array_array> + 'static, I>( DataType::Utf8 => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), DataType::Utf8View => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), - other => exec_err!("unexpected json array type {:?}", other), + other => if let Some(child_array) = nested_json_array_ref(json_array.values(), is_object_lookup_array(path_array.data_type())) { + // Horrible case: dict containing union as input with array for paths, figure + // out from the path type which union members we should access, repack the + // dictionary and then recurse. + // + // Use direct return because if return_dict applies, the recursion will handle it. + return invoke_array_array(&(Arc::new(json_array.with_values(child_array.clone())) as _), path_array, to_array, jiter_find, return_dict) + } else { + exec_err!("unexpected json array type {:?}", other) + } }?; - // ensure return is a dictionary to satisfy the declaration above in return_type_check - Ok(Arc::new(wrap_as_dictionary(json_array, output))) + if return_dict { + // ensure return is a dictionary to satisfy the declaration above in return_type_check + Ok(Arc::new(wrap_as_dictionary(json_array, output))) + } else { + Ok(output) + } }, DataType::Utf8 => zip_apply(json_array.as_string::().iter(), path_array, to_array, jiter_find), DataType::LargeUtf8 => zip_apply(json_array.as_string::().iter(), path_array, to_array, jiter_find), diff --git a/src/common_union.rs b/src/common_union.rs index 947086a..955f9af 100644 --- a/src/common_union.rs +++ b/src/common_union.rs @@ -22,9 +22,13 @@ pub fn is_json_union(data_type: &DataType) -> bool { /// * `object_lookup` - If `true`, extract from the "object" member of the union, /// otherwise extract from the "array" member pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> { + nested_json_array_ref(array, object_lookup).map(|a| a.as_string()) +} + +pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> { let union_array: &UnionArray = array.as_any().downcast_ref::()?; let type_id = if object_lookup { TYPE_ID_OBJECT } else { TYPE_ID_ARRAY }; - union_array.child(type_id).as_any().downcast_ref() + Some(union_array.child(type_id)) } /// Extract a JSON string from a `JsonUnion` scalar diff --git a/tests/main.rs b/tests/main.rs index e2eb5e5..d71a910 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1626,10 +1626,10 @@ FROM json_columns, attr_names "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |", "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} |", - "| a | {\"b\": 2} | true | {int=1} | {int=1} | {int=1} | {null=} | {null=} | {null=} |", "| b | {\"a\": 1} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {int=1} |", - "| b | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {int=2} | {int=2} | {null=} |", "| c | {\"a\": 1} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {int=1} |", + "| a | {\"b\": 2} | true | {int=1} | {int=1} | {int=1} | {null=} | {null=} | {null=} |", + "| b | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {int=2} | {int=2} | {null=} |", "| c | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {null=} |", "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", ]; From add403852c0e153502243e12ab95bcd430374602 Mon Sep 17 00:00:00 2001 From: David Hewitt Date: Tue, 14 Jan 2025 18:05:20 +0000 Subject: [PATCH 3/3] clippy --- src/common_union.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common_union.rs b/src/common_union.rs index 955f9af..52b13e3 100644 --- a/src/common_union.rs +++ b/src/common_union.rs @@ -22,7 +22,7 @@ pub fn is_json_union(data_type: &DataType) -> bool { /// * `object_lookup` - If `true`, extract from the "object" member of the union, /// otherwise extract from the "array" member pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> { - nested_json_array_ref(array, object_lookup).map(|a| a.as_string()) + nested_json_array_ref(array, object_lookup).map(AsArray::as_string) } pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> {