Skip to content

Commit

Permalink
Fix incorrect searched CASE optimization
Browse files Browse the repository at this point in the history
There is an optimization for searched CASE where values are of boolean
type. It was converting the expression like

    CASE
        WHEN X THEN A
        WHEN Y THEN B
        ..
        [ ELSE D ]
    END

into

    (X AND A)
        OR (Y AND NOT X AND B)
        [ OR (NOT (X OR Y) AND D) ]

This had the following problems

- does not work for nullable conditions. If X is nullable, we cannot use
  NOT (X) to compliment it. We need to use `X IS DISTINCT FROM true`
- it does not work correctly when some conditions are nullable and other
  values are false. E.g. X=NULL, A=true, Y=NULL, B=true, D=false, the
  CASE should return false, but the boolean expression will simplify to
  `(NULL AND ..) OR (NULL AND ..) OR (false)` which is NULL, not false
  - thus we use `X` for truthness check of `X`, we need to test `X IS
    NOT DISTINCT FROM true`
- it did not work correctly when default D is missing, but conditions
  do not evaluate to NULL. CASE's result should be NULL but was false.

This commit fixes that optimization.
  • Loading branch information
findepi committed Jan 29, 2025
1 parent 66b4da2 commit 6fe2b99
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 28 deletions.
96 changes: 73 additions & 23 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1388,26 +1388,23 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
&& when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
&& info.is_boolean_type(&when_then_expr[0].1)? =>
{
// The disjunction of all the when predicates encountered so far
// String disjunction of all the when predicates encountered so far. Not nullable.
let mut filter_expr = lit(false);
// The disjunction of all the cases
let mut out_expr = lit(false);

for (when, then) in when_then_expr {
let case_expr = when
.as_ref()
.clone()
.and(filter_expr.clone().not())
.and(*then);
let when = is_exactly_true(*when, info)?;
let case_expr =
when.clone().and(filter_expr.clone().not()).and(*then);

out_expr = out_expr.or(case_expr);
filter_expr = filter_expr.or(*when);
filter_expr = filter_expr.or(when);
}

if let Some(else_expr) = else_expr {
let case_expr = filter_expr.not().and(*else_expr);
out_expr = out_expr.or(case_expr);
}
let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
let case_expr = filter_expr.not().and(else_expr);
out_expr = out_expr.or(case_expr);

// Do a first pass at simplification
out_expr.rewrite(self)?
Expand Down Expand Up @@ -1881,6 +1878,19 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
Ok(Expr::InList(l1))
}

/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
if !info.nullable(&expr)? {
Ok(expr)
} else {
Ok(Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr),
op: Operator::IsNotDistinctFrom,
right: Box::new(lit(true)),
}))
}
}

#[cfg(test)]
mod tests {
use crate::simplify_expressions::SimplifyContext;
Expand Down Expand Up @@ -3272,12 +3282,12 @@ mod tests {
simplify(Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("not_ok"))),
)],
Some(Box::new(col("c2").eq(lit(true)))),
Some(Box::new(col("c2_non_null").eq(lit(true)))),
))),
col("c2").not().and(col("c2")) // #1716
lit(false) // #1716
);

// CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
Expand All @@ -3292,12 +3302,12 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("ok"))),
)],
Some(Box::new(col("c2").eq(lit(true)))),
Some(Box::new(col("c2_non_null").eq(lit(true)))),
)))),
col("c2")
col("c2_non_null")
);

// CASE WHEN ISNULL(c2) THEN true ELSE c2
Expand Down Expand Up @@ -3328,12 +3338,12 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
],
Some(Box::new(lit(true))),
)))),
col("c1").or(col("c1").not().and(col("c2").not()))
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
);

// CASE WHEN c1 then true WHEN c2 then true ELSE false
Expand All @@ -3347,13 +3357,53 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
],
Some(Box::new(lit(true))),
)))),
col("c1").or(col("c1").not().and(col("c2").not()))
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
);

// CASE WHEN c > 0 THEN true END AS c1
assert_eq!(
simplify(simplify(Expr::Case(Case::new(
None,
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
None,
)))),
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
col("c3").gt(lit(0_i64)),
lit(true)
)
.and(lit_bool_null()))
);

// CASE WHEN c > 0 THEN true ELSE false END AS c1
assert_eq!(
simplify(simplify(Expr::Case(Case::new(
None,
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
Some(Box::new(lit(false))),
)))),
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
);
}

fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left.into()),
op: Operator::IsDistinctFrom,
right: Box::new(right.into()),
})
}

fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left.into()),
op: Operator::IsNotDistinctFrom,
right: Box::new(right.into()),
})
}

#[test]
Expand Down
20 changes: 15 additions & 5 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,22 @@ query B
select case when a=1 then false end from foo;
----
false
false
false
false
false
false
NULL
NULL
NULL
NULL
NULL

query IBB
SELECT c,
CASE WHEN c > 0 THEN true END AS c1,
CASE WHEN c > 0 THEN true ELSE false END AS c2
FROM (VALUES (1), (0), (-1), (NULL)) AS t(c)
----
1 true true
0 NULL false
-1 NULL false
NULL NULL false

statement ok
drop table foo

0 comments on commit 6fe2b99

Please sign in to comment.