diff --git a/src/common.rs b/src/common.rs index fdba8d2..37056e2 100644 --- a/src/common.rs +++ b/src/common.rs @@ -3,11 +3,11 @@ use std::sync::Arc; use datafusion::arrow::array::{ Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray, - StringArray, StringViewArray, UInt64Array, UnionArray, + StringArray, StringViewArray, UInt64Array, }; use datafusion::arrow::compute::take; use datafusion::arrow::datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type, + ArrowDictionaryKeyType, ArrowNativeType, ArrowNativeTypeOp, DataType, Int64Type, UInt64Type, }; use datafusion::arrow::downcast_dictionary_array; use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue}; @@ -95,6 +95,7 @@ impl From for JsonPath<'_> { } } +#[derive(Debug)] enum JsonPathArgs<'a> { Array(&'a ArrayRef), Scalars(Vec>), @@ -152,7 +153,7 @@ pub fn invoke> + 'static, I>( let path = JsonPathArgs::extract_path(path_args)?; match (json_arg, path) { (ColumnarValue::Array(json_array), JsonPathArgs::Array(path_array)) => { - invoke_array_array(json_array, path_array, to_array, jiter_find, return_dict).map(ColumnarValue::Array) + invoke_array_array(json_array, path_array, to_array, jiter_find).map(ColumnarValue::Array) } (ColumnarValue::Array(json_array), JsonPathArgs::Scalars(path)) => { invoke_array_scalars(json_array, &path, to_array, jiter_find, return_dict).map(ColumnarValue::Array) @@ -171,13 +172,38 @@ fn invoke_array_array> + 'static, I>( path_array: &ArrayRef, to_array: impl Fn(C) -> DataFusionResult, jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result, - return_dict: bool, ) -> DataFusionResult { downcast_dictionary_array!( json_array => { - let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?; - post_process_dict(json_array, values, return_dict) - } + fn wrap_as_dictionary(original: &DictionaryArray, new_values: ArrayRef) -> DictionaryArray { + assert_eq!(original.keys().len(), new_values.len()); + let mut key = K::Native::ZERO; + let key_range = std::iter::from_fn(move || { + let next = key; + key = key.add_checked(K::Native::ONE).expect("keys exhausted"); + Some(next) + }).take(new_values.len()); + let mut keys = PrimitiveArray::::from_iter_values(key_range); + if is_json_union(new_values.data_type()) { + // JSON union: post-process the array to set keys to null where the union member is null + let type_ids = new_values.as_union().type_ids(); + keys = mask_dictionary_keys(&keys, type_ids); + } + DictionaryArray::::new(keys, new_values) + } + + // TODO: in theory if path_array is _also_ a dictionary we could work out the unique key + // combinations and do less work, but this can be left as a future optimization + let output = match json_array.values().data_type() { + DataType::Utf8 => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), + DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), + DataType::Utf8View => zip_apply(json_array.downcast_dict::().unwrap(), path_array, to_array, jiter_find), + other => exec_err!("unexpected json array type {:?}", other), + }?; + + // ensure return is a dictionary to satisfy the declaration above in return_type_check + Ok(Arc::new(wrap_as_dictionary(json_array, output))) + }, DataType::Utf8 => zip_apply(json_array.as_string::().iter(), path_array, to_array, jiter_find), DataType::LargeUtf8 => zip_apply(json_array.as_string::().iter(), path_array, to_array, jiter_find), DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find), @@ -239,6 +265,7 @@ fn invoke_scalar_array> + 'static, I>( to_array, jiter_find, ) + // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary? .map(ColumnarValue::Array) } @@ -250,6 +277,7 @@ fn invoke_scalar_scalars( ) -> DataFusionResult { let s = extract_json_scalar(scalar)?; let v = jiter_find(s, path).ok(); + // FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary? Ok(ColumnarValue::Scalar(to_scalar(v))) } @@ -321,7 +349,7 @@ fn post_process_dict( if return_dict { if is_json_union(result_values.data_type()) { // JSON union: post-process the array to set keys to null where the union member is null - let type_ids = result_values.as_any().downcast_ref::().unwrap().type_ids(); + let type_ids = result_values.as_union().type_ids(); Ok(Arc::new(DictionaryArray::new( mask_dictionary_keys(dict_array.keys(), type_ids), result_values, @@ -413,7 +441,7 @@ impl From for GetError { /// /// That said, doing this might also be an optimization for cases like null-checking without needing /// to check the value union array. -fn mask_dictionary_keys(keys: &PrimitiveArray, type_ids: &[i8]) -> PrimitiveArray { +fn mask_dictionary_keys(keys: &PrimitiveArray, type_ids: &[i8]) -> PrimitiveArray { let mut null_mask = vec![true; keys.len()]; for (i, k) in keys.iter().enumerate() { match k { diff --git a/tests/main.rs b/tests/main.rs index 54a3f9b..e2eb5e5 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1600,34 +1600,38 @@ async fn test_json_object_keys_nested() { #[tokio::test] async fn test_lookup_literal_column_matrix() { let sql = r#" -WITH attr_names AS ( - -- this is deliberately a different length to json_columns - SELECT unnest(['a', 'b', 'c']) as attr_name -), json_columns AS ( +WITH json_columns AS ( SELECT unnest(['{"a": 1}', '{"b": 2}']) as json_column +), attr_names AS ( + -- this is deliberately a different length to json_columns + SELECT + unnest(['a', 'b', 'c']) as attr_name, + arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name_dict ) SELECT attr_name, json_column, 'a' = attr_name, - json_get('{"a": 1}', attr_name), -- literal lookup with column - json_get('{"a": 1}', 'a'), -- literal lookup with literal - json_get(json_column, attr_name), -- column lookup with column - json_get(json_column, 'a') -- column lookup with literal -FROM attr_names, json_columns + json_get('{"a": 1}', attr_name), -- literal lookup with column + json_get('{"a": 1}', attr_name_dict), -- literal lookup with dict column + json_get('{"a": 1}', 'a'), -- literal lookup with literal + json_get(json_column, attr_name), -- column lookup with column + json_get(json_column, attr_name_dict), -- column lookup with dict column + json_get(json_column, 'a') -- column lookup with literal +FROM json_columns, attr_names "#; let expected = [ - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", - "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,Utf8(\"a\")) |", - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", - "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} |", - "| a | {\"b\": 2} | true | {int=1} | {int=1} | {null=} | {null=} |", - "| b | {\"a\": 1} | false | {null=} | {int=1} | {null=} | {int=1} |", - "| b | {\"b\": 2} | false | {null=} | {int=1} | {int=2} | {null=} |", - "| c | {\"a\": 1} | false | {null=} | {int=1} | {null=} | {int=1} |", - "| c | {\"b\": 2} | false | {null=} | {int=1} | {null=} | {null=} |", - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", + "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", + "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} |", + "| a | {\"b\": 2} | true | {int=1} | {int=1} | {int=1} | {null=} | {null=} | {null=} |", + "| b | {\"a\": 1} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {int=1} |", + "| b | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {int=2} | {int=2} | {null=} |", + "| c | {\"a\": 1} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {int=1} |", + "| c | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {null=} | {null=} | {null=} |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", ]; let batches = run_query(sql).await.unwrap(); @@ -1637,36 +1641,40 @@ FROM attr_names, json_columns #[tokio::test] async fn test_lookup_literal_column_matrix_dictionaries() { let sql = r#" -WITH attr_names AS ( - -- this is deliberately a different length to json_columns - SELECT arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name -), json_columns AS ( +WITH json_columns AS ( SELECT arrow_cast(unnest(['{"a": 1}', '{"b": 2}']), 'Dictionary(Int32, Utf8)') as json_column +), attr_names AS ( + -- this is deliberately a different length to json_columns + SELECT + unnest(['a', 'b', 'c']) as attr_name, + arrow_cast(unnest(['a', 'b', 'c']), 'Dictionary(Int32, Utf8)') as attr_name_dict ) SELECT attr_name, json_column, 'a' = attr_name, - json_get('{"a": 1}', attr_name), -- literal lookup with column - json_get('{"a": 1}', 'a'), -- literal lookup with literal - json_get(json_column, attr_name), -- column lookup with column - json_get(json_column, 'a') -- column lookup with literal -FROM attr_names, json_columns + json_get('{"a": 1}', attr_name), -- literal lookup with column + json_get('{"a": 1}', attr_name_dict), -- literal lookup with dict column + json_get('{"a": 1}', 'a'), -- literal lookup with literal + json_get(json_column, attr_name), -- column lookup with column + json_get(json_column, attr_name_dict), -- column lookup with dict column + json_get(json_column, 'a') -- column lookup with literal +FROM json_columns, attr_names "#; // NB as compared to the non-dictionary case, we null out the dictionary keys if the return // value is a dict, which is why we get true nulls instead of {null=} let expected = [ - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", - "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,Utf8(\"a\")) |", - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", - "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} |", - "| a | {\"b\": 2} | true | {int=1} | {int=1} | | |", - "| b | {\"a\": 1} | false | {null=} | {int=1} | | {int=1} |", - "| b | {\"b\": 2} | false | {null=} | {int=1} | {int=2} | |", - "| c | {\"a\": 1} | false | {null=} | {int=1} | | {int=1} |", - "| c | {\"b\": 2} | false | {null=} | {int=1} | | |", - "+-----------+-------------+----------------------------------+-------------------------------------------------+--------------------------------------+---------------------------------------------------------+----------------------------------------------+", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", + "| attr_name | json_column | Utf8(\"a\") = attr_names.attr_name | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name) | json_get(Utf8(\"{\"a\": 1}\"),attr_names.attr_name_dict) | json_get(Utf8(\"{\"a\": 1}\"),Utf8(\"a\")) | json_get(json_columns.json_column,attr_names.attr_name) | json_get(json_columns.json_column,attr_names.attr_name_dict) | json_get(json_columns.json_column,Utf8(\"a\")) |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", + "| a | {\"a\": 1} | true | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} | {int=1} |", + "| b | {\"a\": 1} | false | {null=} | {null=} | {int=1} | | | {int=1} |", + "| c | {\"a\": 1} | false | {null=} | {null=} | {int=1} | | | {int=1} |", + "| a | {\"b\": 2} | true | {int=1} | {int=1} | {int=1} | | | |", + "| b | {\"b\": 2} | false | {null=} | {null=} | {int=1} | {int=2} | {int=2} | |", + "| c | {\"b\": 2} | false | {null=} | {null=} | {int=1} | | | |", + "+-----------+-------------+----------------------------------+-------------------------------------------------+------------------------------------------------------+--------------------------------------+---------------------------------------------------------+--------------------------------------------------------------+----------------------------------------------+", ]; let batches = run_query(sql).await.unwrap();