diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 74d2ce0b6be9..e186a329de67 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -1388,26 +1388,23 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> { && when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number && info.is_boolean_type(&when_then_expr[0].1)? => { - // The disjunction of all the when predicates encountered so far + // The disjunction of all the when predicates encountered so far. Not nullable. let mut filter_expr = lit(false); // The disjunction of all the cases let mut out_expr = lit(false); for (when, then) in when_then_expr { - let case_expr = when - .as_ref() - .clone() - .and(filter_expr.clone().not()) - .and(*then); + let when = is_exactly_true(*when, info)?; + let case_expr = + when.clone().and(filter_expr.clone().not()).and(*then); out_expr = out_expr.or(case_expr); - filter_expr = filter_expr.or(*when); + filter_expr = filter_expr.or(when); } - if let Some(else_expr) = else_expr { - let case_expr = filter_expr.not().and(*else_expr); - out_expr = out_expr.or(case_expr); - } + let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null); + let case_expr = filter_expr.not().and(else_expr); + out_expr = out_expr.or(case_expr); // Do a first pass at simplification out_expr.rewrite(self)? @@ -1881,6 +1878,19 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> { Ok(Expr::InList(l1)) } +/// Returns an expression testing a boolean `expr` for being exactly `true` (not `false` or NULL). +fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> { + if !info.nullable(&expr)? 
{ + Ok(expr) + } else { + Ok(Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr), + op: Operator::IsNotDistinctFrom, + right: Box::new(lit(true)), + })) + } +} + #[cfg(test)] mod tests { use crate::simplify_expressions::SimplifyContext; @@ -3272,12 +3282,12 @@ mod tests { simplify(Expr::Case(Case::new( None, vec![( - Box::new(col("c2").not_eq(lit(false))), + Box::new(col("c2_non_null").not_eq(lit(false))), Box::new(lit("ok").eq(lit("not_ok"))), )], - Some(Box::new(col("c2").eq(lit(true)))), + Some(Box::new(col("c2_non_null").eq(lit(true)))), ))), - col("c2").not().and(col("c2")) // #1716 + lit(false) // #1716 ); // CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2 @@ -3292,12 +3302,12 @@ mod tests { simplify(simplify(Expr::Case(Case::new( None, vec![( - Box::new(col("c2").not_eq(lit(false))), + Box::new(col("c2_non_null").not_eq(lit(false))), Box::new(lit("ok").eq(lit("ok"))), )], - Some(Box::new(col("c2").eq(lit(true)))), + Some(Box::new(col("c2_non_null").eq(lit(true)))), )))), - col("c2") + col("c2_non_null") ); // CASE WHEN ISNULL(c2) THEN true ELSE c2 @@ -3328,12 +3338,12 @@ mod tests { simplify(simplify(Expr::Case(Case::new( None, vec![ - (Box::new(col("c1")), Box::new(lit(true)),), - (Box::new(col("c2")), Box::new(lit(false)),), + (Box::new(col("c1_non_null")), Box::new(lit(true)),), + (Box::new(col("c2_non_null")), Box::new(lit(false)),), ], Some(Box::new(lit(true))), )))), - col("c1").or(col("c1").not().and(col("c2").not())) + col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not())) ); // CASE WHEN c1 then true WHEN c2 then true ELSE false @@ -3347,13 +3357,53 @@ mod tests { simplify(simplify(Expr::Case(Case::new( None, vec![ - (Box::new(col("c1")), Box::new(lit(true)),), - (Box::new(col("c2")), Box::new(lit(false)),), + (Box::new(col("c1_non_null")), Box::new(lit(true)),), + (Box::new(col("c2_non_null")), Box::new(lit(false)),), ], Some(Box::new(lit(true))), )))), - col("c1").or(col("c1").not().and(col("c2").not())) + 
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not())) + ); + + // CASE WHEN c > 0 THEN true END AS c1 + assert_eq!( + simplify(simplify(Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))], + None, + )))), + not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from( + col("c3").gt(lit(0_i64)), + lit(true) + ) + .and(lit_bool_null())) ); + + // CASE WHEN c > 0 THEN true ELSE false END AS c2 + assert_eq!( + simplify(simplify(Expr::Case(Case::new( + None, + vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))], + Some(Box::new(lit(false))), + )))), + not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)) + ); + } + + fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(left.into()), + op: Operator::IsDistinctFrom, + right: Box::new(right.into()), + }) + } + + fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(left.into()), + op: Operator::IsNotDistinctFrom, + right: Box::new(right.into()), + }) } #[test] diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 157bfb8a02aa..a339c2aa037e 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -289,12 +289,22 @@ query B select case when a=1 then false end from foo; ---- false -false -false -false -false -false +NULL +NULL +NULL +NULL +NULL +query IBB +SELECT c, + CASE WHEN c > 0 THEN true END AS c1, + CASE WHEN c > 0 THEN true ELSE false END AS c2 +FROM (VALUES (1), (0), (-1), (NULL)) AS t(c) +---- +1 true true +0 NULL false +-1 NULL false +NULL NULL false statement ok drop table foo