Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use specialized dictionary kernels (#1178) #2808

Merged
merged 2 commits into from
Jun 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 35 additions & 21 deletions datafusion/expr/src/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,12 @@ pub fn comparison_eq_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
// can't compare dictionaries directly due to
// https://github.com/apache/arrow-rs/issues/1201
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

if lhs_type == rhs_type && !is_dictionary(lhs_type) {
if lhs_type == rhs_type {
// same type => equality is possible
return Some(lhs_type.clone());
}
comparison_binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
Expand All @@ -173,15 +171,13 @@ fn comparison_order_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
// can't compare dictionaries directly due to
// https://github.com/apache/arrow-rs/issues/1201
if lhs_type == rhs_type && !is_dictionary(lhs_type) {
if lhs_type == rhs_type {
// same type => all good
return Some(lhs_type.clone());
}
comparison_binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
}
Expand Down Expand Up @@ -448,17 +444,24 @@ fn dictionary_value_coercion(
/// Coercion rules for Dictionaries: the type that both lhs and rhs
/// can be cast to for the purpose of a computation.
///
/// It would likely be preferable to cast primitive values to
/// dictionaries, and thus avoid unpacking dictionary as well as doing
/// faster comparisons. However, the arrow compute kernels (e.g. eq)
/// don't have DictionaryArray support yet, so fall back to unpacking
/// the dictionaries
fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
/// Not all operators support dictionaries; if `preserve_dictionaries` is true,
/// dictionaries will be preserved where possible.
fn dictionary_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
preserve_dictionaries: bool,
) -> Option<DataType> {
match (lhs_type, rhs_type) {
(
DataType::Dictionary(_lhs_index_type, lhs_value_type),
DataType::Dictionary(_rhs_index_type, rhs_value_type),
) => dictionary_value_coercion(lhs_value_type, rhs_value_type),
(d @ DataType::Dictionary(_, value_type), other_type)
| (other_type, d @ DataType::Dictionary(_, value_type))
if preserve_dictionaries && value_type.as_ref() == other_type =>
{
Some(d.clone())
}
(DataType::Dictionary(_index_type, value_type), _) => {
dictionary_value_coercion(value_type, rhs_type)
}
Expand Down Expand Up @@ -514,7 +517,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
/// This is a union of string coercion rules and dictionary coercion rules
fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
string_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, false))
.or_else(|| null_coercion(lhs_type, rhs_type))
}

Expand Down Expand Up @@ -616,7 +619,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
return Some(lhs_type.clone());
}
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
}
Expand Down Expand Up @@ -779,21 +782,32 @@ mod tests {
fn test_dictionary_type_coersion() {
use DataType::*;

// TODO: In the future, this would ideally return Dictionary types and avoid unpacking
let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32));
let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Int32));
assert_eq!(
dictionary_coercion(&lhs_type, &rhs_type, false),
Some(Int32)
);

let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None);
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), None);

let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
let rhs_type = Utf8;
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8));
assert_eq!(
dictionary_coercion(&lhs_type, &rhs_type, true),
Some(lhs_type.clone())
);

let lhs_type = Utf8;
let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8));
assert_eq!(
dictionary_coercion(&lhs_type, &rhs_type, true),
Some(rhs_type.clone())
);
}
}
27 changes: 22 additions & 5 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,11 +1009,28 @@ impl PhysicalExpr for BinaryExpr {
let left_data_type = left_value.data_type();
let right_data_type = right_value.data_type();

if left_data_type != right_data_type {
return Err(DataFusionError::Internal(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
self.op, left_data_type, right_data_type
)));
match (&left_value, &left_data_type, &right_value, &right_data_type) {
// Types are equal => valid
(_, l, _, r) if l == r => {}
// Allow comparing a dictionary value with its corresponding scalar value
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually necessary for correctness, in addition to being beneficial for performance, because `ScalarValue` does not have a way to encode a dictionary data type.

(
ColumnarValue::Array(_),
DataType::Dictionary(_, dict_t),
ColumnarValue::Scalar(_),
scalar_t,
)
| (
ColumnarValue::Scalar(_),
scalar_t,
ColumnarValue::Array(_),
DataType::Dictionary(_, dict_t),
) if dict_t.as_ref() == scalar_t => {}
_ => {
return Err(DataFusionError::Internal(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
self.op, left_data_type, right_data_type
)));
}
}

// Attempt to use special kernels if one input is scalar and the other is an array
Expand Down