diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index c22ee244fe286..74cf79f3088c1 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -29,7 +29,7 @@ use crate::{ }; use datafusion_expr_common::signature::{Signature, TypeSignature}; -use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; @@ -958,7 +958,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( /// Can this data type be used in hash join equal conditions?? /// Data types here come from function 'equal_rows', if more data types are supported -/// in equal_rows(hash join), add those data types here to generate join logical plan. +/// in create_hashes, add those data types here to generate join logical plan. pub fn can_hash(data_type: &DataType) -> bool { match data_type { DataType::Null => true, @@ -971,31 +971,38 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::UInt16 => true, DataType::UInt32 => true, DataType::UInt64 => true, + DataType::Float16 => true, DataType::Float32 => true, DataType::Float64 => true, - DataType::Timestamp(time_unit, _) => match time_unit { - TimeUnit::Second => true, - TimeUnit::Millisecond => true, - TimeUnit::Microsecond => true, - TimeUnit::Nanosecond => true, - }, + DataType::Decimal128(_, _) => true, + DataType::Decimal256(_, _) => true, + DataType::Timestamp(_, _) => true, DataType::Utf8 => true, DataType::LargeUtf8 => true, DataType::Utf8View => true, - DataType::Decimal128(_, _) => true, + DataType::Binary => true, + DataType::LargeBinary => true, + DataType::BinaryView => true, DataType::Date32 => true, DataType::Date64 => true, + DataType::Time32(_) => true, + DataType::Time64(_) => true, + DataType::Duration(_) => true, + DataType::Interval(_) => true, DataType::FixedSizeBinary(_) => true, - DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - DataType::is_dictionary_key_type(key_type) + DataType::Dictionary(key_type, value_type) => { + DataType::is_dictionary_key_type(key_type) && can_hash(&value_type) } - DataType::List(_) => true, - DataType::LargeList(_) => true, - DataType::FixedSizeList(_, _) => true, + DataType::List(value_type) => can_hash(value_type.data_type()), + DataType::LargeList(value_type) => can_hash(value_type.data_type()), + DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()), + DataType::Map(map_struct, true | false) => can_hash(&map_struct.data_type()), DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), - _ => false, + + DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::Union(_, _) + | DataType::RunEndEncoded(_, _) => false, } } @@ -1403,6 +1410,7 @@ mod tests { test::function_stub::max_udaf, test::function_stub::min_udaf, test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; + use arrow::datatypes::{UnionFields, UnionMode}; #[test] fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { @@ -1805,4 +1813,21 @@ mod tests { assert!(accum.contains(&Column::from_name("a"))); Ok(()) } + + #[test] + fn test_can_hash() { + let union_fields: UnionFields = [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + ] + .into_iter() + .collect(); + + let union_type = DataType::Union(union_fields, UnionMode::Sparse); + assert!(!can_hash(&union_type)); + + let list_union_type = + DataType::List(Arc::new(Field::new("my_union", union_type, true))); + assert!(!can_hash(&list_union_type)); + } }