-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Coalesce
casting logic to follows what Postgres and DuckDB do. Introduce signature that do non-comparison coercion
#10268
Changes from 13 commits
c79156f
a36e6b2
bf16c92
407e3c7
03b9162
4abf29d
4965e8d
81f0235
bae996c
ddf9b1c
c2799ea
2574896
6a17e57
4cba8c5
f1cfb8d
d2e83d3
3a88ad7
03880a3
481f548
dfc4176
46a9060
d656645
5683447
b949fae
5aaeb5b
a968c0e
cf679c5
15471ab
e5cc46b
a810e85
cb16cda
53bedda
a37da2d
70239e0
be116f8
8f4e991
6a8fe6f
20e618e
4153593
030a519
5b797d5
b954479
829b5a2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -92,14 +92,22 @@ pub enum TypeSignature { | |||||||||||||||||
/// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` | ||||||||||||||||||
Variadic(Vec<DataType>), | ||||||||||||||||||
/// One or more arguments of an arbitrary but equal type. | ||||||||||||||||||
/// DataFusion attempts to coerce all argument types to match the first argument's type | ||||||||||||||||||
/// DataFusion attempts to coerce all argument types to match to the common type with comparision coercion. | ||||||||||||||||||
/// | ||||||||||||||||||
/// # Examples | ||||||||||||||||||
/// Given types in signature should be coercible to the same final type. | ||||||||||||||||||
/// A function such as `make_array` is `VariadicEqual`. | ||||||||||||||||||
/// | ||||||||||||||||||
/// `make_array(i32, i64) -> make_array(i64, i64)` | ||||||||||||||||||
VariadicEqual, | ||||||||||||||||||
This comment was marked as outdated.
Sorry, something went wrong. |
||||||||||||||||||
/// One or more arguments of an arbitrary but equal type or Null. | ||||||||||||||||||
/// Non-comparison coercion is attempted to match the signatures. | ||||||||||||||||||
/// | ||||||||||||||||||
/// Functions like `coalesce` is `VariadicEqual`. | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am a little confused about what "Non-comparison coercion" means in this situation. Specifically how comparison coercion and non comparision coercion differ 🤔 Does non comparison coercion mean "type union resolution" (aka Also, I think
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes. Actually, I think Ideally, we could shape At least I can tell the dict-coercion is different between these two now. |
||||||||||||||||||
// TODO: Temporary Signature, to differentiate existing VariadicEqual. | ||||||||||||||||||
// After we swtich `make_array` to VariadicEqualOrNull, | ||||||||||||||||||
// we can reuse VariadicEqual. | ||||||||||||||||||
VariadicEqualOrNull, | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we need a similar signature but an exact args number for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we could do something like this (basically to flavor the type signature) 🤔 pub enum TypeSignature {
...
/// Rather than the usual coercion rules, special type union rules are applied
Union(Box<TypeSignature>)
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice idea! |
||||||||||||||||||
/// One or more arguments with arbitrary types | ||||||||||||||||||
VariadicAny, | ||||||||||||||||||
/// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. | ||||||||||||||||||
|
@@ -193,6 +201,9 @@ impl TypeSignature { | |||||||||||||||||
TypeSignature::VariadicEqual => { | ||||||||||||||||||
vec!["CoercibleT, .., CoercibleT".to_string()] | ||||||||||||||||||
} | ||||||||||||||||||
TypeSignature::VariadicEqualOrNull => { | ||||||||||||||||||
vec!["CoercibleT or NULL, .., CoercibleT or NULL".to_string()] | ||||||||||||||||||
} | ||||||||||||||||||
TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], | ||||||||||||||||||
TypeSignature::OneOf(sigs) => { | ||||||||||||||||||
sigs.iter().flat_map(|s| s.to_string_repr()).collect() | ||||||||||||||||||
|
@@ -255,13 +266,20 @@ impl Signature { | |||||||||||||||||
volatility, | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
/// An arbitrary number of arguments of the same type. | ||||||||||||||||||
/// One or more number of arguments to the same type. | ||||||||||||||||||
pub fn variadic_equal(volatility: Volatility) -> Self { | ||||||||||||||||||
Self { | ||||||||||||||||||
type_signature: TypeSignature::VariadicEqual, | ||||||||||||||||||
volatility, | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
/// One or more number of arguments of the same type. | ||||||||||||||||||
pub fn variadic_equal_or_null(volatility: Volatility) -> Self { | ||||||||||||||||||
Self { | ||||||||||||||||||
type_signature: TypeSignature::VariadicEqualOrNull, | ||||||||||||||||||
volatility, | ||||||||||||||||||
} | ||||||||||||||||||
} | ||||||||||||||||||
/// An arbitrary number of arguments of any type. | ||||||||||||||||||
pub fn variadic_any(volatility: Volatility) -> Self { | ||||||||||||||||||
Self { | ||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,8 @@ use arrow::datatypes::{ | |
|
||
use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result}; | ||
|
||
use super::functions::coerced_from; | ||
|
||
/// The type signature of an instantiation of binary operator expression such as | ||
/// `lhs + rhs` | ||
/// | ||
|
@@ -289,7 +291,118 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option<DataT | |
} | ||
} | ||
|
||
#[derive(Debug, PartialEq, Eq)] | ||
enum TypeCategory { | ||
Array, | ||
Boolean, | ||
Numeric, | ||
String, | ||
DateTime, | ||
Composite, | ||
Unknown, | ||
} | ||
|
||
fn data_type_category(data_type: &DataType) -> TypeCategory { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This could also be something like impl From<&DataType> for TypeCategory {
... And then you create a let category = TypeCategory::from(&type) |
||
if data_type.is_numeric() { | ||
return TypeCategory::Numeric; | ||
} | ||
|
||
if matches!(data_type, DataType::Boolean) { | ||
return TypeCategory::Boolean; | ||
} | ||
|
||
if matches!( | ||
data_type, | ||
DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) | ||
) { | ||
return TypeCategory::Array; | ||
} | ||
|
||
if matches!(data_type, DataType::Utf8 | DataType::LargeUtf8) { | ||
return TypeCategory::String; | ||
} | ||
|
||
if matches!( | ||
data_type, | ||
DataType::Date32 | ||
| DataType::Date64 | ||
| DataType::Time32(_) | ||
| DataType::Time64(_) | ||
| DataType::Timestamp(_, _) | ||
| DataType::Interval(_) | ||
| DataType::Duration(_) | ||
) { | ||
return TypeCategory::DateTime; | ||
} | ||
|
||
if matches!( | ||
data_type, | ||
DataType::Dictionary(_, _) | DataType::Struct(_) | DataType::Union(_, _) | ||
) { | ||
return TypeCategory::Composite; | ||
} | ||
|
||
TypeCategory::Unknown | ||
} | ||
|
||
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of constructs including | ||
/// CASE, ARRAY, VALUES, and the GREATEST and LEAST functions. | ||
/// See <https://www.postgresql.org/docs/current/typeconv-union-case.html> for more information. | ||
pub fn type_resolution(data_types: &[DataType]) -> Option<DataType> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the biggest difference between comparison coercion is that we categorize types. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The doc string is a bit confusing here because there are no lhs_type and rhs_type. I assume that case would be Also, maybe this function name could reflect that it's finding a union type to satisfy a set of input types, for example |
||
if data_types.is_empty() { | ||
return None; | ||
} | ||
|
||
// if all the data_types is the same return first one | ||
if data_types.iter().all(|t| t == &data_types[0]) { | ||
return Some(data_types[0].clone()); | ||
} | ||
|
||
// if all the data_types are null, return string | ||
if data_types.iter().all(|t| t == &DataType::Null) { | ||
return Some(DataType::Utf8); | ||
} | ||
|
||
// Ignore Nulls, if any data_type category is not the same, return None | ||
let data_types_category: Vec<TypeCategory> = data_types | ||
.iter() | ||
.filter(|&t| t != &DataType::Null) | ||
.map(data_type_category) | ||
.collect(); | ||
if data_types_category | ||
.iter() | ||
.any(|t| t != &data_types_category[0]) | ||
{ | ||
return None; | ||
} | ||
|
||
// Ignore Nulls | ||
let mut candidate_type: Option<DataType> = None; | ||
for data_type in data_types.iter() { | ||
if data_type == &DataType::Null { | ||
continue; | ||
} | ||
if let Some(ref candidate_t) = candidate_type { | ||
// `coerced_from` is designed uni-directional for `can_coerced_from` so we need to check both directions | ||
if let Some(t) = coerced_from(data_type, candidate_t) { | ||
candidate_type = Some(t); | ||
} else if let Some(t) = coerced_from(candidate_t, data_type) { | ||
candidate_type = Some(t); | ||
} else { | ||
// Not coercible, return None | ||
return None; | ||
} | ||
} else { | ||
candidate_type = Some(data_type.clone()); | ||
} | ||
} | ||
|
||
candidate_type | ||
} | ||
|
||
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation | ||
/// Unlike [coerced_from], usually the coerced type is for comparison only. | ||
/// For example, compare with Dictionary and Dictionary, only value type is what we care about | ||
pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { | ||
if lhs_type == rhs_type { | ||
// same type => equality is possible | ||
|
@@ -375,20 +488,13 @@ pub(crate) fn comparison_binary_numeric_coercion( | |
return Some(lhs_type.clone()); | ||
} | ||
|
||
if let Some(t) = decimal_coercion(lhs_type, rhs_type) { | ||
return Some(t); | ||
} | ||
|
||
// these are ordered from most informative to least informative so | ||
// that the coercion does not lose information via truncation | ||
match (lhs_type, rhs_type) { | ||
// Prefer decimal data type over floating point for comparison operation | ||
(Decimal128(_, _), Decimal128(_, _)) => { | ||
get_wider_decimal_type(lhs_type, rhs_type) | ||
} | ||
(Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), | ||
(_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), | ||
(Decimal256(_, _), Decimal256(_, _)) => { | ||
get_wider_decimal_type(lhs_type, rhs_type) | ||
} | ||
(Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), | ||
(_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), | ||
(Float64, _) | (_, Float64) => Some(Float64), | ||
(_, Float32) | (Float32, _) => Some(Float32), | ||
// The following match arms encode the following logic: Given the two | ||
|
@@ -426,9 +532,31 @@ pub(crate) fn comparison_binary_numeric_coercion( | |
} | ||
} | ||
|
||
/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of | ||
/// a comparison operation where one is a decimal | ||
fn get_comparison_common_decimal_type( | ||
/// Decimal coercion rules. | ||
pub(crate) fn decimal_coercion( | ||
lhs_type: &DataType, | ||
rhs_type: &DataType, | ||
) -> Option<DataType> { | ||
use arrow::datatypes::DataType::*; | ||
|
||
match (lhs_type, rhs_type) { | ||
// Prefer decimal data type over floating point for comparison operation | ||
(Decimal128(_, _), Decimal128(_, _)) => { | ||
get_wider_decimal_type(lhs_type, rhs_type) | ||
} | ||
(Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), | ||
(_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), | ||
(Decimal256(_, _), Decimal256(_, _)) => { | ||
get_wider_decimal_type(lhs_type, rhs_type) | ||
} | ||
(Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), | ||
(_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type), | ||
(_, _) => None, | ||
} | ||
} | ||
|
||
/// Coerce `lhs_type` and `rhs_type` to a common type. | ||
fn get_common_decimal_type( | ||
decimal_type: &DataType, | ||
other_type: &DataType, | ||
) -> Option<DataType> { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ use std::sync::Arc; | |
use crate::signature::{ | ||
ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, | ||
}; | ||
use crate::type_coercion::binary::{decimal_coercion, type_resolution}; | ||
use crate::{Signature, TypeSignature}; | ||
use arrow::{ | ||
compute::can_cast_types, | ||
|
@@ -28,7 +29,7 @@ use arrow::{ | |
use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; | ||
use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; | ||
|
||
use super::binary::{comparison_binary_numeric_coercion, comparison_coercion}; | ||
use super::binary::comparison_coercion; | ||
|
||
/// Performs type coercion for function arguments. | ||
/// | ||
|
@@ -54,7 +55,6 @@ pub fn data_types( | |
} | ||
|
||
let valid_types = get_valid_types(&signature.type_signature, current_types)?; | ||
|
||
if valid_types | ||
.iter() | ||
.any(|data_type| data_type == current_types) | ||
|
@@ -184,6 +184,13 @@ fn get_valid_types( | |
.iter() | ||
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) | ||
.collect(), | ||
TypeSignature::VariadicEqualOrNull => { | ||
if let Some(common_type) = type_resolution(current_types) { | ||
vec![vec![common_type; current_types.len()]] | ||
} else { | ||
vec![] | ||
} | ||
} | ||
TypeSignature::VariadicEqual => { | ||
let new_type = current_types.iter().skip(1).try_fold( | ||
current_types.first().unwrap().clone(), | ||
|
@@ -307,11 +314,14 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { | |
false | ||
} | ||
|
||
fn coerced_from<'a>( | ||
/// Coerced_from implicitly casts between types. | ||
/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion. | ||
pub(crate) fn coerced_from<'a>( | ||
type_into: &'a DataType, | ||
type_from: &'a DataType, | ||
) -> Option<DataType> { | ||
use self::DataType::*; | ||
|
||
// match Dictionary first | ||
match (type_into, type_from) { | ||
// coerced dictionary first | ||
|
@@ -325,6 +335,14 @@ fn coerced_from<'a>( | |
{ | ||
Some(type_into.clone()) | ||
} | ||
(Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => { | ||
decimal_coercion(type_into, type_from) | ||
} | ||
(Decimal128(_, _) | Decimal256(_, _), _) | ||
if matches!(type_from, Int8 | Int16 | Int32 | Int64) => | ||
jayzhan211 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
decimal_coercion(type_into, type_from) | ||
} | ||
// coerced into type_into | ||
(Int8, _) if matches!(type_from, Null | Int8) => Some(type_into.clone()), | ||
(Int16, _) if matches!(type_from, Null | Int8 | Int16 | UInt8) => { | ||
|
@@ -429,7 +447,6 @@ fn coerced_from<'a>( | |
} | ||
_ => None, | ||
}, | ||
|
||
(Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => { | ||
match type_from { | ||
Timestamp(_, Some(from_tz)) => { | ||
|
@@ -450,19 +467,7 @@ fn coerced_from<'a>( | |
{ | ||
Some(type_into.clone()) | ||
} | ||
// More coerce rules. | ||
// Note that not all rules in `comparison_coercion` can be reused here. | ||
// For example, all numeric types can be coerced into Utf8 for comparison, | ||
// but not for function arguments. | ||
_ => comparison_binary_numeric_coercion(type_into, type_from).and_then( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic is introduced in #9459, so I think it is safe to remove together with this PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @viirya |
||
|coerced_type| { | ||
if *type_into == coerced_type { | ||
Some(coerced_type) | ||
} else { | ||
None | ||
} | ||
}, | ||
), | ||
_ => None, | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,8 +22,8 @@ use arrow::compute::kernels::zip::zip; | |
use arrow::compute::{and, is_not_null, is_null}; | ||
use arrow::datatypes::DataType; | ||
|
||
use datafusion_common::{exec_err, Result}; | ||
use datafusion_expr::type_coercion::functions::data_types; | ||
use datafusion_common::{exec_err, internal_err, Result}; | ||
use datafusion_expr::type_coercion::binary::type_resolution; | ||
use datafusion_expr::ColumnarValue; | ||
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; | ||
|
||
|
@@ -41,7 +41,7 @@ impl Default for CoalesceFunc { | |
impl CoalesceFunc { | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::variadic_equal(Volatility::Immutable), | ||
signature: Signature::variadic_equal_or_null(Volatility::Immutable), | ||
} | ||
} | ||
} | ||
|
@@ -60,9 +60,11 @@ impl ScalarUDFImpl for CoalesceFunc { | |
} | ||
|
||
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
// COALESCE has multiple args and they might get coerced, get a preview of this | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 ❤️ |
||
let coerced_types = data_types(arg_types, self.signature()); | ||
coerced_types.map(|types| types[0].clone()) | ||
if let Some(common_type) = type_resolution(arg_types) { | ||
Ok(common_type) | ||
} else { | ||
internal_err!("Error should be thrown via signature validation") | ||
} | ||
} | ||
|
||
/// coalesce evaluates to the first value which is not NULL | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we please link to https://docs.rs/datafusion/latest/datafusion/logical_expr/type_coercion/binary/fn.comparison_coercion.html that explains (however limited) what comparison coercion is?