diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs index 0fffd84b7047..5e0fe2d7b85f 100644 --- a/datafusion/expr-common/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -18,7 +18,7 @@ //! Signature module contains foundational types that are used to represent signatures, types, //! and return types of functions in DataFusion. -use crate::type_coercion::aggregates::{NUMERICS, STRINGS}; +use crate::type_coercion::aggregates::NUMERICS; use arrow::datatypes::DataType; use datafusion_common::types::{LogicalTypeRef, NativeType}; use itertools::Itertools; @@ -113,6 +113,15 @@ pub enum TypeSignature { /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]` /// since i32 and f32 can be casted to f64 Coercible(Vec), + /// The arguments will be coerced to a single type based on the comparison rules. + /// For example, i32 and i64 are coerced to Int64. + /// + /// Note: + /// - When a numeric type is compared with a string type, the numeric type is preferred. For example, the arguments of nullif('2', 1) are coerced to Int64. + /// - If the resulting type is Null, it is coerced to a string type (Utf8View). + /// + /// See `comparison_coercion_numeric` for more details. + Comparable(usize), /// Fixed number of arguments of arbitrary types, number should be larger than 0 Any(usize), /// Matches exactly one of a list of [`TypeSignature`]s. 
Coercion is attempted to match @@ -138,6 +147,13 @@ pub enum TypeSignature { NullAry, } +impl TypeSignature { + #[inline] + pub fn is_one_of(&self) -> bool { + matches!(self, TypeSignature::OneOf(_)) + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { /// Specialized Signature for ArrayAppend and similar functions @@ -210,6 +226,9 @@ impl TypeSignature { TypeSignature::Numeric(num) => { vec![format!("Numeric({num})")] } + TypeSignature::Comparable(num) => { + vec![format!("Comparable({num})")] + } TypeSignature::Coercible(types) => { vec![Self::join_types(types, ", ")] } @@ -284,13 +303,13 @@ impl TypeSignature { .cloned() .map(|numeric_type| vec![numeric_type; *arg_count]) .collect(), - TypeSignature::String(arg_count) => STRINGS - .iter() - .cloned() - .map(|string_type| vec![string_type; *arg_count]) - .collect(), + TypeSignature::String(arg_count) => get_data_types(&NativeType::String) + .into_iter() + .map(|dt| vec![dt; *arg_count]) + .collect::>(), // TODO: Implement for other types TypeSignature::Any(_) + | TypeSignature::Comparable(_) | TypeSignature::NullAry | TypeSignature::VariadicAny | TypeSignature::ArraySignature(_) @@ -412,6 +431,14 @@ impl Signature { } } + /// Used for functions that expect comparable data types; all argument types are coerced into a single common type. 
+ pub fn comparable(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::Comparable(arg_count), + volatility, + } + } + pub fn nullary(volatility: Volatility) -> Self { Signature { type_signature: TypeSignature::NullAry, diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 31fe6a59baee..39ccf202574f 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -28,6 +28,7 @@ use arrow::datatypes::{ DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; +use datafusion_common::types::NativeType; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result, }; @@ -643,6 +644,21 @@ pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + if lhs_type == rhs_type { + // same type => equality is possible + return Some(lhs_type.clone()); + } + binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| null_coercion(lhs_type, rhs_type)) + .or_else(|| string_numeric_coercion_as_numeric(lhs_type, rhs_type)) +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation /// where one is numeric and one is `Utf8`/`LargeUtf8`. 
fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { @@ -656,6 +672,24 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + let lhs_logical_type = NativeType::from(lhs_type); + let rhs_logical_type = NativeType::from(rhs_type); + if lhs_logical_type.is_numeric() && rhs_logical_type == NativeType::String { + return Some(lhs_type.to_owned()); + } + if rhs_logical_type.is_numeric() && lhs_logical_type == NativeType::String { + return Some(rhs_type.to_owned()); + } + + None +} + /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation /// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`. /// diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 6836713d8016..5f52c7ccc20e 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -29,7 +29,7 @@ use datafusion_common::{ }; use datafusion_expr_common::{ signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, - type_coercion::binary::string_coercion, + type_coercion::binary::{comparison_coercion_numeric, string_coercion}, }; use std::sync::Arc; @@ -182,6 +182,7 @@ fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { | TypeSignature::Coercible(_) | TypeSignature::Any(_) | TypeSignature::NullAry + | TypeSignature::Comparable(_) ) } @@ -194,13 +195,18 @@ fn try_coerce_types( // Well-supported signature that returns exact valid types. 
if !valid_types.is_empty() && is_well_supported_signature(type_signature) { - // exact valid types - assert_eq!(valid_types.len(), 1); + // There may be many valid types if valid signature is OneOf + // Otherwise, there should be only one valid type + if !type_signature.is_one_of() { + assert_eq!(valid_types.len(), 1); + } + let valid_types = valid_types.swap_remove(0); if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) { return Ok(t); } } else { + // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already) // Try and coerce the argument types to match the signature, returning the // coerced types from the first matching signature. for valid_types in valid_types { @@ -515,6 +521,23 @@ fn get_valid_types( vec![vec![valid_type; *number]] } + TypeSignature::Comparable(num) => { + function_length_check(current_types.len(), *num)?; + let mut target_type = current_types[0].to_owned(); + for data_type in current_types.iter().skip(1) { + if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) { + target_type = dt; + } else { + return plan_err!("{target_type} and {data_type} is not comparable"); + } + } + // Convert null to String type. + if target_type.is_null() { + vec![vec![DataType::Utf8View; *num]] + } else { + vec![vec![target_type; *num]] + } + } TypeSignature::Coercible(target_types) => { function_length_check(current_types.len(), target_types.len())?; diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index 801a80201946..05af8d3f589e 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -32,26 +32,6 @@ pub struct NullIfFunc { signature: Signature, } -/// Currently supported types by the nullif function. 
-/// The order of these types correspond to the order on which coercion applies -/// This should thus be from least informative to most informative -static SUPPORTED_NULLIF_TYPES: &[DataType] = &[ - DataType::Boolean, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - DataType::Utf8View, - DataType::Utf8, - DataType::LargeUtf8, -]; - impl Default for NullIfFunc { fn default() -> Self { Self::new() @@ -61,11 +41,20 @@ impl Default for NullIfFunc { impl NullIfFunc { pub fn new() -> Self { Self { - signature: Signature::uniform( - 2, - SUPPORTED_NULLIF_TYPES.to_vec(), - Volatility::Immutable, - ), + // The Postgres documentation states: + // The result has the same type as the first argument — but there is a subtlety. + // What is actually returned is the first argument of the implied = operator, + // and in some cases that will have been promoted to match the second argument's type. + // For example, NULLIF(1, 2.2) yields numeric, because there is no integer = numeric operator, only numeric = numeric. + // + // We don't strictly follow Postgres or DuckDB for **simplicity**. + // This function coerces both arguments to a common data type so they can be compared. Unlike DuckDB, + // we return the final coerced type rather than the **original** type of the first argument. + // + // In Postgres, nullif('2', 2) returns Null but nullif('2'::varchar, 2) returns an error. + // In DuckDB both queries return Null. We follow DuckDB here because the two queries are equivalent and should + // produce the same result. 
+ signature: Signature::comparable(2, Volatility::Immutable), } } } @@ -83,14 +72,7 @@ impl ScalarUDFImpl for NullIfFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - // NULLIF has two args and they might get coerced, get a preview of this - let coerced_types = datafusion_expr::type_coercion::functions::data_types( - arg_types, - &self.signature, - ); - coerced_types - .map(|typs| typs[0].clone()) - .map_err(|e| e.context("Failed to coerce arguments for NULLIF")) + Ok(arg_types[0].to_owned()) } fn invoke(&self, args: &[ColumnarValue]) -> Result { diff --git a/datafusion/sqllogictest/test_files/nullif.slt b/datafusion/sqllogictest/test_files/nullif.slt index a5060077fe77..18642f6971ca 100644 --- a/datafusion/sqllogictest/test_files/nullif.slt +++ b/datafusion/sqllogictest/test_files/nullif.slt @@ -97,11 +97,54 @@ SELECT NULLIF(1, 3); ---- 1 -query I +query T SELECT NULLIF(NULL, NULL); ---- NULL +query R +select nullif(1, 1.2); +---- +1 + +query R +select nullif(1.0, 2); +---- +1 + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +select nullif(2, 'a'); + +query T +select nullif('2', '3'); +---- +2 + +query I +select nullif(2, '1'); +---- +2 + +query I +select nullif('2', 2); +---- +NULL + +query I +select nullif('1', 2); +---- +1 + +statement ok +create table t(a varchar, b int) as values ('1', 2), ('2', 2), ('3', 2); + +query I +select nullif(a, b) from t; +---- +1 +NULL +3 + query T SELECT NULLIF(arrow_cast('a', 'Utf8View'), 'a'); ---- @@ -130,4 +173,4 @@ NULL query T SELECT NULLIF(arrow_cast('a', 'Utf8View'), null); ---- -a \ No newline at end of file +a