From 6e0bb8476d783c1caaf6bf011487c92ae9352f78 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 30 Jun 2022 02:08:49 +0100 Subject: [PATCH] Use specialized dictionary kernels (#1178) (#2808) * Use specialized dictionary kernels (#1178) * Fix tests --- datafusion/expr/src/binary_rule.rs | 56 ++++++++++++------- .../physical-expr/src/expressions/binary.rs | 27 +++++++-- 2 files changed, 57 insertions(+), 26 deletions(-) diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs index b7b2c57e83e5..5b404d8a2843 100644 --- a/datafusion/expr/src/binary_rule.rs +++ b/datafusion/expr/src/binary_rule.rs @@ -155,14 +155,12 @@ pub fn comparison_eq_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { - // can't compare dictionaries directly due to - // https://github.com/apache/arrow-rs/issues/1201 - if lhs_type == rhs_type && !is_dictionary(lhs_type) { + if lhs_type == rhs_type { // same type => equality is possible return Some(lhs_type.clone()); } comparison_binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, true)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| null_coercion(lhs_type, rhs_type)) @@ -173,15 +171,13 @@ fn comparison_order_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { - // can't compare dictionaries directly due to - // https://github.com/apache/arrow-rs/issues/1201 - if lhs_type == rhs_type && !is_dictionary(lhs_type) { + if lhs_type == rhs_type { // same type => all good return Some(lhs_type.clone()); } comparison_binary_numeric_coercion(lhs_type, rhs_type) .or_else(|| string_coercion(lhs_type, rhs_type)) - .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, true)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) .or_else(|| null_coercion(lhs_type, rhs_type)) } @@ -448,17 +444,24 @@ fn dictionary_value_coercion( /// Coercion rules for Dictionaries: the type that both lhs and rhs /// can be casted to for the purpose of a computation. /// -/// It would likely be preferable to cast primitive values to -/// dictionaries, and thus avoid unpacking dictionary as well as doing -/// faster comparisons. However, the arrow compute kernels (e.g. eq) -/// don't have DictionaryArray support yet, so fall back to unpacking -/// the dictionaries -fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +/// Not all operators support dictionaries, if `preserve_dictionaries` is true +/// dictionaries will be preserved if possible +fn dictionary_coercion( + lhs_type: &DataType, + rhs_type: &DataType, + preserve_dictionaries: bool, +) -> Option { match (lhs_type, rhs_type) { ( DataType::Dictionary(_lhs_index_type, lhs_value_type), DataType::Dictionary(_rhs_index_type, rhs_value_type), ) => dictionary_value_coercion(lhs_value_type, rhs_value_type), + (d @ DataType::Dictionary(_, value_type), other_type) + | (other_type, d @ DataType::Dictionary(_, value_type)) + if preserve_dictionaries && value_type.as_ref() == other_type => + { + Some(d.clone()) + } (DataType::Dictionary(_index_type, value_type), _) => { dictionary_value_coercion(value_type, rhs_type) } @@ -514,7 +517,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option /// This is a union of string coercion rules and dictionary coercion rules fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, false)) .or_else(|| null_coercion(lhs_type, rhs_type)) } @@ -616,7 +619,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { return Some(lhs_type.clone()); } numerical_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type, true)) .or_else(|| temporal_coercion(lhs_type, rhs_type)) .or_else(|| null_coercion(lhs_type, rhs_type)) } @@ -779,21 +782,32 @@ mod tests { fn test_dictionary_type_coersion() { use DataType::*; - // TODO: In the future, this would ideally return Dictionary types and avoid unpacking let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Int32)); + assert_eq!( + dictionary_coercion(&lhs_type, &rhs_type, false), + Some(Int32) + ); let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), None); let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Utf8; - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8)); + assert_eq!( + dictionary_coercion(&lhs_type, &rhs_type, true), + Some(lhs_type.clone()) + ); let lhs_type = Utf8; let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8)); + assert_eq!( + dictionary_coercion(&lhs_type, &rhs_type, true), + Some(rhs_type.clone()) + ); } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index c0876a722a0e..417306221b7b 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -1009,11 +1009,28 @@ impl PhysicalExpr for BinaryExpr { let left_data_type = left_value.data_type(); let right_data_type = right_value.data_type(); - if left_data_type != right_data_type { - return Err(DataFusionError::Internal(format!( - "Cannot evaluate binary expression {:?} with types {:?} and {:?}", - self.op, left_data_type, right_data_type - ))); + match (&left_value, &left_data_type, &right_value, &right_data_type) { + // Types are equal => valid + (_, l, _, r) if l == r => {} + // Allow comparing a dictionary value with its corresponding scalar value + ( + ColumnarValue::Array(_), + DataType::Dictionary(_, dict_t), + ColumnarValue::Scalar(_), + scalar_t, + ) + | ( + ColumnarValue::Scalar(_), + scalar_t, + ColumnarValue::Array(_), + DataType::Dictionary(_, dict_t), + ) if dict_t.as_ref() == scalar_t => {} + _ => { + return Err(DataFusionError::Internal(format!( + "Cannot evaluate binary expression {:?} with types {:?} and {:?}", + self.op, left_data_type, right_data_type + ))); + } } // Attempt to use special kernels if one input is scalar and the other is an array