Skip to content

Commit

Permalink
Support aliases in ConstEvaluator (#14734)
Browse files Browse the repository at this point in the history
Not sure why they are not supported. It seems that if we're not careful,
some transformations can introduce aliases nested inside other expressions.
  • Loading branch information
joroKr21 authored Feb 19, 2025
1 parent 6a036ae commit c176533
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 26 deletions.
27 changes: 27 additions & 0 deletions datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,33 @@ fn test_const_evaluator() {
);
}

/// Exercises expressions whose operands are wrapped in `alias(..)`.
///
/// Each case builds an aliased input expression and hands it to
/// `test_evaluate` together with the expression it is expected to become.
#[test]
fn test_const_evaluator_alias() {
    // A lone aliased literal: true --> true
    let aliased_true = lit(true).alias("a");
    test_evaluate(aliased_true, lit(true));

    // Both operands of OR are aliased literals: true OR true --> true
    let both_true = lit(true).alias("a").or(lit(true).alias("b"));
    test_evaluate(both_true, lit(true));

    // Comparing two aliased, equal strings: "foo" = "foo" --> true
    let equal_strings = lit("foo").alias("a").eq(lit("foo").alias("b"));
    test_evaluate(equal_strings, lit(true));

    // c = 1 + 2 --> c = 3
    // The expected expression keeps alias `a` on the column and has a bare
    // literal on the right-hand side.
    let aliased_sum = lit(1).alias("b") + lit(2).alias("c");
    let column_vs_sum = col("c").alias("a").eq(aliased_sum);
    let column_vs_three = col("c").alias("a").eq(lit(3));
    test_evaluate(column_vs_sum, column_vs_three);

    // (foo != foo) OR (c = 1) --> c = 1
    // The expected expression keeps alias `d` on the column and the
    // outermost alias `f`; the aliases on the literals are gone.
    let never_true = lit("foo")
        .alias("a")
        .not_eq(lit("foo").alias("b"))
        .alias("c");
    let column_is_one = col("c").alias("d").eq(lit(1).alias("e"));
    let disjunction = never_true.or(column_is_one).alias("f");
    let remaining = col("c").alias("d").eq(lit(1)).alias("f");
    test_evaluate(disjunction, remaining);
}

#[test]
fn test_const_evaluator_scalar_functions() {
// concat("foo", "bar") --> "foobar"
Expand Down
41 changes: 18 additions & 23 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,15 @@ use datafusion_expr::{
};
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};

use super::inlist_simplifier::ShortenInListSimplifier;
use super::utils::*;
use crate::analyzer::type_coercion::TypeCoercionRewriter;
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
use crate::simplify_expressions::regex::simplify_regex_expr;
use crate::simplify_expressions::SimplifyInfo;
use indexmap::IndexSet;
use regex::Regex;

use super::inlist_simplifier::ShortenInListSimplifier;
use super::utils::*;

/// This structure handles API for expression simplification
///
/// Provides simplification information based on DFSchema and
Expand Down Expand Up @@ -515,30 +514,27 @@ impl TreeNodeRewriter for ConstEvaluator<'_> {

// NB: do not short circuit recursion even if we find a non
// evaluatable node (so we can fold other children, args to
// functions, etc)
// functions, etc.)
Ok(Transformed::no(expr))
}

fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
match self.can_evaluate.pop() {
// Certain expressions such as `CASE` and `COALESCE` are short circuiting
// and may not evaluate all their sub expressions. Thus if
// if any error is countered during simplification, return the original
// Certain expressions such as `CASE` and `COALESCE` are short-circuiting
// and may not evaluate all their subexpressions. Thus, if
// any error is encountered during simplification, return the original
// expression so that normal evaluation can occur
Some(true) => {
let result = self.evaluate_to_scalar(expr);
match result {
ConstSimplifyResult::Simplified(s) => {
Ok(Transformed::yes(Expr::Literal(s)))
}
ConstSimplifyResult::NotSimplified(s) => {
Ok(Transformed::no(Expr::Literal(s)))
}
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
Ok(Transformed::yes(expr))
}
Some(true) => match self.evaluate_to_scalar(expr) {
ConstSimplifyResult::Simplified(s) => {
Ok(Transformed::yes(Expr::Literal(s)))
}
}
ConstSimplifyResult::NotSimplified(s) => {
Ok(Transformed::no(Expr::Literal(s)))
}
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
Ok(Transformed::yes(expr))
}
},
Some(false) => Ok(Transformed::no(expr)),
_ => internal_err!("Failed to pop can_evaluate"),
}
Expand Down Expand Up @@ -586,9 +582,7 @@ impl<'a> ConstEvaluator<'a> {
// added they can be checked for their ability to be evaluated
// at plan time
match expr {
// Has no runtime cost, but needed during planning
Expr::Alias(..)
| Expr::AggregateFunction { .. }
Expr::AggregateFunction { .. }
| Expr::ScalarVariable(_, _)
| Expr::Column(_)
| Expr::OuterReferenceColumn(_, _)
Expand All @@ -603,6 +597,7 @@ impl<'a> ConstEvaluator<'a> {
Self::volatility_ok(func.signature().volatility)
}
Expr::Literal(_)
| Expr::Alias(..)
| Expr::Unnest(_)
| Expr::BinaryExpr { .. }
| Expr::Not(_)
Expand Down
6 changes: 3 additions & 3 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ query TT
explain SELECT t1_id, (SELECT count(*) as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1
----
logical_plan
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) AS _cnt ELSE __scalar_sq_1._cnt END AS cnt
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1._cnt END AS cnt
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
03)----TableScan: t1 projection=[t1_id, t1_int]
04)----SubqueryAlias: __scalar_sq_1
Expand All @@ -855,7 +855,7 @@ query TT
explain SELECT t1_id, (SELECT count(*) + 2 as _cnt FROM t2 WHERE t2.t2_int = t1.t1_int) from t1
----
logical_plan
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS _cnt ELSE __scalar_sq_1._cnt END AS _cnt
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) ELSE __scalar_sq_1._cnt END AS _cnt
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
03)----TableScan: t1 projection=[t1_id, t1_int]
04)----SubqueryAlias: __scalar_sq_1
Expand Down Expand Up @@ -922,7 +922,7 @@ query TT
explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1
----
logical_plan
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2
02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int
03)----TableScan: t1 projection=[t1_id, t1_int]
04)----SubqueryAlias: __scalar_sq_1
Expand Down

0 comments on commit c176533

Please sign in to comment.