Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix case of f(dict_array, dict_array) invocation #64

Merged
merged 3 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ use std::sync::Arc;

use datafusion::arrow::array::{
Array, ArrayAccessor, ArrayRef, AsArray, DictionaryArray, Int64Array, LargeStringArray, PrimitiveArray,
StringArray, StringViewArray, UInt64Array, UnionArray,
StringArray, StringViewArray, UInt64Array,
};
use datafusion::arrow::compute::take;
use datafusion::arrow::datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType, Int64Type, UInt64Type,
ArrowDictionaryKeyType, ArrowNativeType, ArrowNativeTypeOp, DataType, Int64Type, UInt64Type,
};
use datafusion::arrow::downcast_dictionary_array;
use datafusion::common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use jiter::{Jiter, JiterError, Peek};

use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array, TYPE_ID_NULL};
use crate::common_union::{
is_json_union, json_from_union_scalar, nested_json_array, nested_json_array_ref, TYPE_ID_NULL,
};

/// General implementation of `ScalarUDFImpl::return_type`.
///
Expand Down Expand Up @@ -95,6 +97,7 @@ impl From<i64> for JsonPath<'_> {
}
}

#[derive(Debug)]
enum JsonPathArgs<'a> {
Array(&'a ArrayRef),
Scalars(Vec<JsonPath<'a>>),
Expand Down Expand Up @@ -175,9 +178,48 @@ fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
) -> DataFusionResult<ArrayRef> {
downcast_dictionary_array!(
json_array => {
let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?;
post_process_dict(json_array, values, return_dict)
}
/// Wraps `new_values` in a dictionary array that uses the same key type as `original`.
///
/// The generated keys are the sequence `0, 1, 2, ...`, one per element of `new_values`, so
/// every key points at its own value. When `new_values` is a JSON union, the keys are passed
/// through `mask_dictionary_keys` together with the union's type ids so that keys are set to
/// null where the union member is null.
///
/// # Panics
///
/// Panics if `original.keys()` and `new_values` have different lengths, or if `new_values`
/// has more elements than `K::Native` can index ("keys exhausted").
fn wrap_as_dictionary<K: ArrowDictionaryKeyType>(original: &DictionaryArray<K>, new_values: ArrayRef) -> DictionaryArray<K> {
    assert_eq!(original.keys().len(), new_values.len());
    // `upcoming` is `None` once the key type has overflowed. The overflow is only reported
    // when another key is actually requested, so an array whose last index is exactly
    // the maximum key value is still accepted instead of panicking one step too early.
    let mut upcoming = Some(K::Native::ZERO);
    let key_range = std::iter::from_fn(move || {
        let current = upcoming.expect("keys exhausted");
        upcoming = current.add_checked(K::Native::ONE).ok();
        Some(current)
    })
    .take(new_values.len());
    let mut keys = PrimitiveArray::<K>::from_iter_values(key_range);
    if is_json_union(new_values.data_type()) {
        // JSON union: post-process the array to set keys to null where the union member is null
        let type_ids = new_values.as_union().type_ids();
        keys = mask_dictionary_keys(&keys, type_ids);
    }
    DictionaryArray::<K>::new(keys, new_values)
}

// TODO: in theory if path_array is _also_ a dictionary we could work out the unique key
// combinations and do less work, but this can be left as a future optimization
let output = match json_array.values().data_type() {
DataType::Utf8 => zip_apply(json_array.downcast_dict::<StringArray>().unwrap(), path_array, to_array, jiter_find),
DataType::LargeUtf8 => zip_apply(json_array.downcast_dict::<LargeStringArray>().unwrap(), path_array, to_array, jiter_find),
DataType::Utf8View => zip_apply(json_array.downcast_dict::<StringViewArray>().unwrap(), path_array, to_array, jiter_find),
other => if let Some(child_array) = nested_json_array_ref(json_array.values(), is_object_lookup_array(path_array.data_type())) {
// Horrible case: dict containing union as input with array for paths, figure
// out from the path type which union members we should access, repack the
// dictionary and then recurse.
//
// Use direct return because if return_dict applies, the recursion will handle it.
return invoke_array_array(&(Arc::new(json_array.with_values(child_array.clone())) as _), path_array, to_array, jiter_find, return_dict)
} else {
exec_err!("unexpected json array type {:?}", other)
}
}?;

if return_dict {
// ensure return is a dictionary to satisfy the declaration above in return_type_check
Ok(Arc::new(wrap_as_dictionary(json_array, output)))
} else {
Ok(output)
}
},
DataType::Utf8 => zip_apply(json_array.as_string::<i32>().iter(), path_array, to_array, jiter_find),
DataType::LargeUtf8 => zip_apply(json_array.as_string::<i64>().iter(), path_array, to_array, jiter_find),
DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find),
Expand Down Expand Up @@ -239,6 +281,7 @@ fn invoke_scalar_array<C: FromIterator<Option<I>> + 'static, I>(
to_array,
jiter_find,
)
// FIXME: edge case where the scalar is wrapped in a dictionary; should this return a dictionary too?
.map(ColumnarValue::Array)
}

Expand All @@ -250,6 +293,7 @@ fn invoke_scalar_scalars<I>(
) -> DataFusionResult<ColumnarValue> {
let s = extract_json_scalar(scalar)?;
let v = jiter_find(s, path).ok();
// FIXME: edge case where the scalar is wrapped in a dictionary; should this return a dictionary too?
Ok(ColumnarValue::Scalar(to_scalar(v)))
}

Expand Down Expand Up @@ -321,7 +365,7 @@ fn post_process_dict<T: ArrowDictionaryKeyType>(
if return_dict {
if is_json_union(result_values.data_type()) {
// JSON union: post-process the array to set keys to null where the union member is null
let type_ids = result_values.as_any().downcast_ref::<UnionArray>().unwrap().type_ids();
let type_ids = result_values.as_union().type_ids();
Ok(Arc::new(DictionaryArray::new(
mask_dictionary_keys(dict_array.keys(), type_ids),
result_values,
Expand Down Expand Up @@ -413,7 +457,7 @@ impl From<Utf8Error> for GetError {
///
/// That said, doing this might also be an optimization for cases like null-checking without needing
/// to check the value union array.
fn mask_dictionary_keys<K: ArrowPrimitiveType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
fn mask_dictionary_keys<K: ArrowDictionaryKeyType>(keys: &PrimitiveArray<K>, type_ids: &[i8]) -> PrimitiveArray<K> {
let mut null_mask = vec![true; keys.len()];
for (i, k) in keys.iter().enumerate() {
match k {
Expand Down
6 changes: 5 additions & 1 deletion src/common_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ pub fn is_json_union(data_type: &DataType) -> bool {
/// * `object_lookup` - If `true`, extract from the "object" member of the union,
/// otherwise extract from the "array" member
pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> {
    // Select the requested union child first; bail out with `None` if there is none.
    let child = nested_json_array_ref(array, object_lookup)?;
    // View the selected child as a string array.
    Some(child.as_string())
}

pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> {
let union_array: &UnionArray = array.as_any().downcast_ref::<UnionArray>()?;
let type_id = if object_lookup { TYPE_ID_OBJECT } else { TYPE_ID_ARRAY };
union_array.child(type_id).as_any().downcast_ref()
Some(union_array.child(type_id))
}

/// Extract a JSON string from a `JsonUnion` scalar
Expand Down
Loading
Loading