Skip to content

Commit

Permalink
fix case of f(dict_array, dict_array) invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jan 14, 2025
1 parent 38caf97 commit 433445d
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 47 deletions.
46 changes: 37 additions & 9 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use std::sync::Arc;

use datafusion::arrow::array::{
Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
StringArray, StringViewArray, UInt64Array, UnionArray,
StringArray, StringViewArray, UInt64Array,
};
use datafusion::arrow::compute::take;
use datafusion::arrow::datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type,
ArrowDictionaryKeyType, ArrowNativeType, ArrowNativeTypeOp, DataType, Int64Type, UInt64Type,
};
use datafusion::arrow::downcast_dictionary_array;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
Expand Down Expand Up @@ -95,6 +95,7 @@ impl From<i64> for JsonPath<'_> {
}
}

#[derive(Debug)]
enum JsonPathArgs<'a> {
Array(&'a ArrayRef),
Scalars(Vec<JsonPath<'a>>),
Expand Down Expand Up @@ -152,7 +153,7 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
let path = JsonPathArgs::extract_path(path_args)?;
match (json_arg, path) {
(ColumnarValue::Array(json_array), JsonPathArgs::Array(path_array)) => {
invoke_array_array(json_array, path_array, to_array, jiter_find, return_dict).map(ColumnarValue::Array)
invoke_array_array(json_array, path_array, to_array, jiter_find).map(ColumnarValue::Array)
}
(ColumnarValue::Array(json_array), JsonPathArgs::Scalars(path)) => {
invoke_array_scalars(json_array, &path, to_array, jiter_find, return_dict).map(ColumnarValue::Array)
Expand All @@ -171,13 +172,38 @@ fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
path_array: &ArrayRef,
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
return_dict: bool,
) -> DataFusionResult<ArrayRef> {
downcast_dictionary_array!(
json_array => {
let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?;
post_process_dict(json_array, values, return_dict)
}
fn wrap_as_dictionary<K: ArrowDictionaryKeyType>(original: &DictionaryArray<K>, new_values: ArrayRef) -> DictionaryArray<K> {
assert_eq!(original.keys().len(), new_values.len());
let mut key = K::Native::ZERO;
let key_range = std::iter::from_fn(move || {
let next = key;
key = key.add_checked(K::Native::ONE).expect("keys exhausted");
Some(next)
}).take(new_values.len());
let mut keys = PrimitiveArray::<K>::from_iter_values(key_range);
if is_json_union(new_values.data_type()) {
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = new_values.as_union().type_ids();
keys = mask_dictionary_keys(&keys, type_ids);
}
DictionaryArray::<K>::new(keys, new_values)
}

// TODO: in theory if path_array is _also_ a dictionary we could work out the unique key
// combinations and do less work, but this can be left as a future optimization
let output = match json_array.values().data_type() {
DataType::Utf8 => zip_apply(json_array.downcast_dict::<StringArray>().unwrap(), path_array, to_array, jiter_find),
DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::<LargeStringArray>().unwrap(), path_array, to_array, jiter_find),
DataType::Utf8View => zip_apply(json_array.downcast_dict::<StringViewArray>().unwrap(), path_array, to_array, jiter_find),
other => exec_err!("unexpected json array type {:?}", other),
}?;

// ensure return is a dictionary to satisfy the declaration above in return_type_check
Ok(Arc::new(wrap_as_dictionary(json_array, output)))
},
DataType::Utf8 => zip_apply(json_array.as_string::<i32>().iter(), path_array, to_array, jiter_find),
DataType::LargeUtf8 => zip_apply(json_array.as_string::<i64>().iter(), path_array, to_array, jiter_find),
DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find),
Expand Down Expand Up @@ -239,6 +265,7 @@ fn invoke_scalar_array<C: FromIterator<Option<I>> + 'static, I>(
to_array,
jiter_find,
)
// FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
.map(ColumnarValue::Array)
}

Expand All @@ -250,6 +277,7 @@ fn invoke_scalar_scalars<I>(
) -> DataFusionResult<ColumnarValue> {
let s = extract_json_scalar(scalar)?;
let v = jiter_find(s, path).ok();
// FIXME edge cases where scalar is wrapped in a dictionary, should return a dictionary?
Ok(ColumnarValue::Scalar(to_scalar(v)))
}

Expand Down Expand Up @@ -321,7 +349,7 @@ fn post_process_dict<T: ArrowDictionaryKeyType>(
if return_dict {
if is_json_union(result_values.data_type()) {
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = result_values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
let type_ids = result_values.as_union().type_ids();
Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(dict_array.keys(), type_ids),
result_values,
Expand Down Expand Up @@ -413,7 +441,7 @@ impl From<Utf8Error> for GetError {
///
/// That said, doing this might also be an optimization for cases like null-checking without needing
/// to check the value union array.
fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
fn mask_dictionary_keys<K: ArrowDictionaryKeyType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
let mut null_mask = vec![true; keys.len()];
for (i, k) in keys.iter().enumerate() {
match k {
Expand Down
Loading

0 comments on commit 433445d

Please sign in to comment.