diff --git a/src/common.rs b/src/common.rs
index fdba8d2..0893e97 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -3,18 +3,20 @@ use std::sync::Arc;
 
 use datafusion::arrow::array::{
     Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
-    StringArray, StringViewArray, UInt64Array, UnionArray,
+    StringArray, StringViewArray, UInt64Array,
 };
 use datafusion::arrow::compute::take;
 use datafusion::arrow::datatypes::{
-    ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type,
+    ArrowDictionaryKeyType, ArrowNativeType, ArrowNativeTypeOp, DataType, Int64Type, UInt64Type,
 };
 use datafusion::arrow::downcast_dictionary_array;
 use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
 use datafusion::logical_expr::ColumnarValue;
 use jiter::{Jiter, JiterError, Peek};
 
-use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};
+use crate::common_union::{
+    is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
+};
 
 /// General implementation of `ScalarUDFImpl::return_type`.
 ///
@@ -95,6 +97,7 @@ impl From for JsonPath<'_> {
     }
 }
 
+#[derive(Debug)]
 enum JsonPathArgs<'a> {
     Array(&'a ArrayRef),
     Scalars(Vec<JsonPath<'a>>),
 }
@@ -175,9 +178,48 @@ fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
 ) -> DataFusionResult<ArrayRef> {
     downcast_dictionary_array!(
         json_array => {
-            let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?;
-            post_process_dict(json_array, values, return_dict)
-        }
+            fn wrap_as_dictionary<K: ArrowDictionaryKeyType>(original: &DictionaryArray<K>, new_values: ArrayRef) -> DictionaryArray<K> {
+                assert_eq!(original.keys().len(), new_values.len());
+                let mut key = K::Native::ZERO;
+                let key_range = std::iter::from_fn(move || {
+                    let next = key;
+                    key = key.add_checked(K::Native::ONE).expect("keys exhausted");
+                    Some(next)
+                }).take(new_values.len());
+                let mut keys = PrimitiveArray::<K>::from_iter_values(key_range);
+                if is_json_union(new_values.data_type()) {
+                    // JSON union: post-process the array to set keys to null where the union member is null
+                    let type_ids = new_values.as_union().type_ids();
+                    keys = mask_dictionary_keys(&keys, type_ids);
+                }
+                DictionaryArray::<K>::new(keys, new_values)
+            }
+
+            // TODO: in theory if path_array is _also_ a dictionary we could work out the unique key
+            // combinations and do less work, but this can be left as a future optimization
+            let output = match json_array.values().data_type() {
+                DataType::Utf8 => zip_apply(json_array.downcast_dict::<StringArray>().unwrap(), path_array, to_array, jiter_find),
+                DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::<LargeStringArray>().unwrap(), path_array, to_array, jiter_find),
+                DataType::Utf8View => zip_apply(json_array.downcast_dict::<StringViewArray>().unwrap(), path_array, to_array, jiter_find),
+                other => if let Some(child_array) = nested_json_array_ref(json_array.values(), is_object_lookup_array(path_array.data_type())) {
+                    // Horrible case: dict containing union as input with array for paths, figure
+                    // out from the path type which union members we should access, repack the
+                    // dictionary and then recurse.
+                    //
+                    // Use direct return because if return_dict applies, the recursion will handle it.
+                    return invoke_array_array(&(Arc::new(json_array.with_values(child_array.clone())) as _), path_array, to_array, jiter_find, return_dict)
+                } else {
+                    exec_err!("unexpected json array type {:?}", other)
+                }
+            }?;
+
+            if return_dict {
+                // ensure return is a dictionary to satisfy the declaration above in return_type_check
+                Ok(Arc::new(wrap_as_dictionary(json_array, output)))
+            } else {
+                Ok(output)
+            }
+        },
         DataType::Utf8 => zip_apply(json_array.as_string::<i32>().iter(), path_array, to_array, jiter_find),
         DataType::LargeUtf8 => zip_apply(json_array.as_string::<i64>().iter(), path_array, to_array, jiter_find),
         DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find),
@@ -239,6 +281,7 @@ fn invoke_scalar_array<C: FromIterator<Option<I>> + 'static, I>(
         to_array,
         jiter_find,
     )
+    // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
     .map(ColumnarValue::Array)
 }
 
@@ -250,6 +293,7 @@ fn invoke_scalar_scalars(
 ) -> DataFusionResult<ColumnarValue> {
     let s = extract_json_scalar(scalar)?;
     let v = jiter_find(s, path).ok();
+    // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
     Ok(ColumnarValue::Scalar(to_scalar(v)))
 }
 
@@ -321,7 +365,7 @@ fn post_process_dict(
     if return_dict {
         if is_json_union(result_values.data_type()) {
             // JSON union: post-process the array to set keys to null where the union member is null
-            let type_ids = result_values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
+            let type_ids = result_values.as_union().type_ids();
             Ok(Arc::new(DictionaryArray::new(
                 mask_dictionary_keys(dict_array.keys(), type_ids),
                 result_values,
@@ -413,7 +457,7 @@ impl From for GetError {
 ///
 /// That said, doing this might also be an optimization for cases like null-checking without needing
 /// to check the value union array.
-fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
+fn mask_dictionary_keys<K: ArrowDictionaryKeyType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
     let mut null_mask = vec![true; keys.len()];
     for (i, k) in keys.iter().enumerate() {
         match k {
diff --git a/src/common_union.rs b/src/common_union.rs
index 947086a..52b13e3 100644
--- a/src/common_union.rs
+++ b/src/common_union.rs
@@ -22,9 +22,13 @@ pub fn is_json_union(data_type: &DataType) -> bool {
 /// * `object_lookup` - If `true`, extract from the "object" member of the union,
 /// otherwise extract from the "array" member
 pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> {
+    nested_json_array_ref(array, object_lookup).map(AsArray::as_string)
+}
+
+pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> {
     let union_array: &UnionArray = array.as_any().downcast_ref::<UnionArray>()?;
     let type_id = if object_lookup { TYPE_ID_OBJECT } else { TYPE_ID_ARRAY };
-    union_array.child(type_id).as_any().downcast_ref()
+    Some(union_array.child(type_id))
 }
 
 /// Extract a JSON string from a `JsonUnion` scalar
diff --git a/tests/main.rs b/tests/main.rs
index 54a3f9b..d71a910 100644
--- a/tests/main.rs
+++ b/tests/main.rs
@@ -1600,34 +1600,38 @@ async fn test_json_object_keys_nested() {
 #[tokio::test]
 async fn test_lookup_literal_column_matrix() {
     let sql = r#"
-WITH attr_names AS (
-    -- this is deliberately a different length to json_columns
-    SELECT unnest(['a', 'b', 'c']) as attr_name
-), json_columns AS (
+WITH json_columns AS (
     SELECT unnest(['{"a": 1}', '{"b": 2}']) as json_column
+), attr_names AS (
+    -- this is deliberately a different length to json_columns
+    SELECT
+        unnest(['a', 'b', 'c']) as attr_name,
+        arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name_dict
 )
 SELECT
     attr_name,
     json_column,
     'a' = attr_name,
-    json_get('{"a": 1}', attr_name),   -- literal lookup with column
-    json_get('{"a": 1}', 'a'),         -- literal lookup with literal
-    json_get(json_column, attr_name),  -- column lookup with column
-    json_get(json_column, 'a')         -- column lookup with literal
-FROM attr_names, json_columns
+    json_get('{"a": 1}', attr_name),        -- literal lookup with column
+    json_get('{"a": 1}', attr_name_dict),   -- literal lookup with dict column
+    json_get('{"a": 1}', 'a'),              -- literal lookup with literal
+    json_get(json_column, attr_name),       -- column lookup with column
+    json_get(json_column, attr_name_dict),  -- column lookup with dict column
+    json_get(json_column, 'a')              -- column lookup with literal
+FROM json_columns, attr_names
 "#;
 
     let expected = [
-        "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+",
-        "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,Utf8(\"a\")) |",
-        "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+",
-        "| a         | {\"a\": 1}    | true                             | {int=1}                                         | {int=1}                              | {int=1}                                                 | {int=1}                                      |",
-        "| a         | {\"b\": 2}    | true                             | {int=1}                                         | {int=1}                              | {null=}                                                 | {null=}                                      |",
-        "| b         | {\"a\": 1}    | false                            | {null=}                                         | {int=1}                              | {null=}                                                 | {int=1}                                      |",
-        "| b         | {\"b\": 2}    | false                            | {null=}                                         | {int=1}                              | {int=2}                                                 | {null=}                                      |",
-        "| c         | {\"a\": 1}    | false                            | {null=}                                         | {int=1}                              | {null=}                                                 | {int=1}                                      |",
-        "| c         | {\"b\": 2}    | false                            | {null=}                                         | {int=1}                              | {null=}                                                 | {null=}                                      |",
-        "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+",
+        "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+",
+        "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |",
+        "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+",
+        "| a         | {\"a\": 1}    | true                             | {int=1}                                         | {int=1}                                              | {int=1}                              | {int=1}                                                 | {int=1}                                                      | {int=1}                                      |",
+        "| b         | {\"a\": 1}    | false                            | {null=}                                         | {null=}                                              | {int=1}                              | {null=}                                                 | {null=}                                                      | {int=1}                                      |",
+        "| c         | {\"a\": 1}    | false                            | {null=}                                         | {null=}                                              | {int=1}                              | {null=}                                                 | {null=}                                                      | {int=1}                                      |",
+        "| a         | {\"b\": 2}    | true                             | {int=1}                                         | {int=1}                                              | {int=1}                              | {null=}                                                 | {null=}                                                      | {null=}                                      |",
+        "| b         | {\"b\": 2}    | false                            | {null=}                                         | {null=}                                              | {int=1}                              | {int=2}                                                 | {int=2}                                                      | {null=}                                      |",
+        "| c         | {\"b\": 2}    | false                            | {null=}                                         | {null=}                                              | {int=1}                              | {null=}                                                 | {null=}                                                      | {null=}                                      |",
+        "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+",
     ];
 
     let batches = run_query(sql).await.unwrap();
@@ -1637,36 +1641,40 @@
 #[tokio::test]
 async fn test_lookup_literal_column_matrix_dictionaries() {
     let sql = r#"
-WITH attr_names AS (
-    -- this is deliberately a different length to json_columns
-    SELECT arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name
-), json_columns AS (
+WITH json_columns AS (
     SELECT arrow_cast(unnest(['{"a": 1}', '{"b": 2}']), 'Dictionary(Int32, Utf8)') as json_column
+), attr_names AS (
+    -- this is deliberately a different length to json_columns
+    SELECT
+        unnest(['a', 'b', 'c']) as attr_name,
+        arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name_dict
 )
 SELECT
     attr_name,
     json_column,
     'a' = attr_name,
-    json_get('{"a": 1}', attr_name),   -- literal lookup with column
-    json_get('{"a": 1}', 'a'),         -- literal lookup with literal
-    json_get(json_column, attr_name),  -- column lookup with column
-    json_get(json_column, 'a')         -- column lookup with literal
-FROM attr_names, json_columns
+    json_get('{"a": 1}', attr_name),        -- literal lookup with column
+    json_get('{"a": 1}', attr_name_dict),   -- literal lookup with dict column
+    json_get('{"a": 1}', 'a'),              -- literal lookup with literal
+    json_get(json_column, attr_name),       -- column lookup with column
+    json_get(json_column, attr_name_dict),  -- column lookup with dict column
+    json_get(json_column, 'a')              -- column lookup with literal
+FROM json_columns, attr_names
 "#;
 
     // NB as compared to the non-dictionary case, we null out the dictionary keys if the return
    // value is a dict, which is why we get true nulls instead of {null=}
     let expected = [
-        "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+",
-        "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,Utf8(\"a\")) |",
-        "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+",
-        "| a         | {\"a\": 1}    | true                             | {int=1}                                         | {int=1}                              | {int=1}                                                 | {int=1}                                      |",
-        "| a         | {\"b\": 2}    | true                             | {int=1}                                         | {int=1}                              |                                                         |                                              |",
-        "| b         | {\"a\": 1}    | false                            | {null=}                                         | {int=1}                              |                                                         | {int=1}                                      |",
-        "| b         | {\"b\": 2}    | false                            | {null=}                                         | {int=1}                              | {int=2}                                                 |                                              |",
-        "| c         | {\"a\": 1}    | false                            | {null=}                                         | {int=1}                              |                                                         | {int=1}                                      |",
-        "| c         | {\"b\": 2}    | false                            | {null=}                                         | {int=1}                              |                                                         |                                              |",
-        "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+",
+        "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+",
+        "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |",
+        "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+",
+        "| a         | {\"a\": 1}    | true                             | {int=1}                                         | {int=1}                                              | {int=1}                              | {int=1}                                                 | {int=1}                                                      | {int=1}                                      |",
+        "| b         | {\"a\": 1}    | false                            | {null=}                                         | {null=}                                              | {int=1}                              |                                                         |                                                              | {int=1}                                      |",
+        "| c         | {\"a\": 1}    | false                            | {null=}                                         | {null=}                                              | {int=1}                              |                                                         |                                                              | {int=1}                                      |",
+        "| a         | {\"b\": 2}    | true                             | {int=1}                                         | {int=1}                                              | {int=1}                              |                                                         |                                                              |                                              |",
+        "| b         | {\"b\": 2}    | false                            | {null=}                                         | {null=}                                              | {int=1}                              | {int=2}                                                 | {int=2}                                                      |                                              |",
+        "| c         | {\"b\": 2}    | false                            | {null=}                                         | {null=}                                              | {int=1}                              |                                                         |                                                              |                                              |",
+        "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+",
     ];
 
     let batches = run_query(sql).await.unwrap();
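Not part of the patch above, but a useful way to see what `wrap_as_dictionary` does: the lookup results computed for a dictionary-encoded input are wrapped back into a `DictionaryArray` with identity keys (key `i` points at value `i`), and any key whose value is null is itself nulled so the dictionary reports a true null. Below is a minimal standalone sketch of that idea; it assumes the `arrow` crate is used directly (the patch goes through the `datafusion::arrow` re-exports) and uses a plain `StringArray` as a stand-in for the JSON union values.

// Standalone illustration only -- not part of the diff above.
// Identity-key dictionary wrapping: keys 0..len, nulled out where the value is null.
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, DictionaryArray, Int32Array, StringArray};
use arrow::datatypes::Int32Type;

fn main() {
    // Stand-in for the per-row lookup results; None plays the role of a null union member.
    let values: ArrayRef = Arc::new(StringArray::from(vec![Some("1"), None, Some("2")]));

    // Identity keys 0..len, with the key nulled wherever the value is null, so the
    // dictionary exposes a true null instead of a key pointing at a null slot.
    let keys: Int32Array = (0..values.len() as i32)
        .map(|i| (!values.is_null(i as usize)).then_some(i))
        .collect();

    let dict = DictionaryArray::<Int32Type>::new(keys, values);
    assert_eq!(dict.len(), 3);
    assert!(dict.is_null(1)); // row 1 is a dictionary-level null
}

In the patch itself the null positions come from the union's type ids via `mask_dictionary_keys` rather than from the values' null buffer, which is also why the dictionary-input test above shows true nulls where the non-dictionary test shows `{null=}`.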