Skip to content

Commit

Permalink
Use specialized dictionary kernels (#1178) (#2808)
Browse files Browse the repository at this point in the history
* Use specialized dictionary kernels (#1178)

* Fix tests
  • Loading branch information
tustvold authored Jun 30, 2022
1 parent 7cc3ffd commit 6e0bb84
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
56 changes: 35 additions & 21 deletions datafusion/expr/src/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,12 @@ pub fn comparison_eq_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
// can't compare dictionaries directly due to
// https://github.com/apache/arrow-rs/issues/1201
if lhs_type == rhs_type && !is_dictionary(lhs_type) {
if lhs_type == rhs_type {
// same type => equality is possible
return Some(lhs_type.clone());
}
comparison_binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
Expand All @@ -173,15 +171,13 @@ fn comparison_order_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
// can't compare dictionaries directly due to
// https://github.com/apache/arrow-rs/issues/1201
if lhs_type == rhs_type && !is_dictionary(lhs_type) {
if lhs_type == rhs_type {
// same type => all good
return Some(lhs_type.clone());
}
comparison_binary_numeric_coercion(lhs_type, rhs_type)
.or_else(|| string_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
}
Expand Down Expand Up @@ -448,17 +444,24 @@ fn dictionary_value_coercion(
/// Coercion rules for Dictionaries: the type that both lhs and rhs
/// can be cast to for the purpose of a computation.
///
/// It would likely be preferable to cast primitive values to
/// dictionaries, and thus avoid unpacking dictionary as well as doing
/// faster comparisons. However, the arrow compute kernels (e.g. eq)
/// don't have DictionaryArray support yet, so fall back to unpacking
/// the dictionaries
fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
/// Not all operators support dictionaries. If `preserve_dictionaries` is true,
/// dictionaries will be preserved where possible.
fn dictionary_coercion(
lhs_type: &DataType,
rhs_type: &DataType,
preserve_dictionaries: bool,
) -> Option<DataType> {
match (lhs_type, rhs_type) {
(
DataType::Dictionary(_lhs_index_type, lhs_value_type),
DataType::Dictionary(_rhs_index_type, rhs_value_type),
) => dictionary_value_coercion(lhs_value_type, rhs_value_type),
(d @ DataType::Dictionary(_, value_type), other_type)
| (other_type, d @ DataType::Dictionary(_, value_type))
if preserve_dictionaries && value_type.as_ref() == other_type =>
{
Some(d.clone())
}
(DataType::Dictionary(_index_type, value_type), _) => {
dictionary_value_coercion(value_type, rhs_type)
}
Expand Down Expand Up @@ -514,7 +517,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
/// This is a union of string coercion rules and dictionary coercion rules
fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
string_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, false))
.or_else(|| null_coercion(lhs_type, rhs_type))
}

Expand Down Expand Up @@ -616,7 +619,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
return Some(lhs_type.clone());
}
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
.or_else(|| null_coercion(lhs_type, rhs_type))
}
Expand Down Expand Up @@ -779,21 +782,32 @@ mod tests {
fn test_dictionary_type_coersion() {
use DataType::*;

// TODO: In the future, this would ideally return Dictionary types and avoid unpacking
let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32));
let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Int32));
assert_eq!(
dictionary_coercion(&lhs_type, &rhs_type, false),
Some(Int32)
);

let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None);
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), None);

let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
let rhs_type = Utf8;
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8));
assert_eq!(
dictionary_coercion(&lhs_type, &rhs_type, true),
Some(lhs_type.clone())
);

let lhs_type = Utf8;
let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8));
assert_eq!(
dictionary_coercion(&lhs_type, &rhs_type, true),
Some(rhs_type.clone())
);
}
}
27 changes: 22 additions & 5 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,11 +1009,28 @@ impl PhysicalExpr for BinaryExpr {
let left_data_type = left_value.data_type();
let right_data_type = right_value.data_type();

if left_data_type != right_data_type {
return Err(DataFusionError::Internal(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
self.op, left_data_type, right_data_type
)));
match (&left_value, &left_data_type, &right_value, &right_data_type) {
// Types are equal => valid
(_, l, _, r) if l == r => {}
// Allow comparing a dictionary value with its corresponding scalar value
(
ColumnarValue::Array(_),
DataType::Dictionary(_, dict_t),
ColumnarValue::Scalar(_),
scalar_t,
)
| (
ColumnarValue::Scalar(_),
scalar_t,
ColumnarValue::Array(_),
DataType::Dictionary(_, dict_t),
) if dict_t.as_ref() == scalar_t => {}
_ => {
return Err(DataFusionError::Internal(format!(
"Cannot evaluate binary expression {:?} with types {:?} and {:?}",
self.op, left_data_type, right_data_type
)));
}
}

// Attempt to use special kernels if one input is scalar and the other is an array
Expand Down

0 comments on commit 6e0bb84

Please sign in to comment.