From f0e96c670108ba0ffdebb9dd9e764bba4d2dca8c Mon Sep 17 00:00:00 2001 From: Adam Curtis Date: Tue, 7 May 2024 06:43:34 -0400 Subject: [PATCH] feat: run expression simplifier in a loop until a fixedpoint or 3 cycles (#10358) * feat: run expression simplifier in a loop * change max_simplifier_iterations to u32 * use simplify_inner to explicitly test iteration count * refactor simplify_inner loop * const evaluator should return transformed=false on literals * update tests * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs Co-authored-by: Andrew Lamb * run shorten_in_list_simplifier once at the end of the loop * move UDF test case to core integration tests * documentation and naming updates * documentation and naming updates * remove unused import and minor doc formatting change * Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs --------- Co-authored-by: Andrew Lamb --- datafusion/core/tests/simplification.rs | 31 ++++ .../simplify_expressions/expr_simplifier.rs | 175 +++++++++++++++--- 2 files changed, 182 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/simplification.rs index 880c294bb7aa..bb4192983426 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/simplification.rs @@ -508,6 +508,29 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) { "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" ); } +fn test_simplify_with_cycle_count( + input_expr: Expr, + expected_expr: Expr, + expected_count: u32, +) { + let info: MyInfo = MyInfo { + schema: expr_test_schema(), + execution_props: ExecutionProps::new(), + }; + let simplifier = ExprSimplifier::new(info); + let (simplified_expr, count) = simplifier + .simplify_with_cycle_count(input_expr.clone()) + .expect("successfully evaluated"); + + assert_eq!( + simplified_expr, expected_expr, + "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n 
Got:{simplified_expr}" + ); + assert_eq!( + count, expected_count, + "Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}" + ); +} #[test] fn test_simplify_log() { @@ -658,3 +681,11 @@ fn test_simplify_concat() { let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]); test_simplify(expr, expected) } +#[test] +fn test_simplify_cycles() { + // cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX + let expr = cast(now(), DataType::Int64) + .lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX)); + let expected = lit(true); + test_simplify_with_cycle_count(expr, expected, 3); +} diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 4d7a207afb1b..0f711d6a2c6d 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -92,9 +92,12 @@ pub struct ExprSimplifier { /// Should expressions be canonicalized before simplification? Defaults to /// true canonicalize: bool, + /// Maximum number of simplifier cycles + max_simplifier_cycles: u32, } pub const THRESHOLD_INLINE_INLIST: usize = 3; +pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3; impl ExprSimplifier { /// Create a new `ExprSimplifier` with the given `info` such as an @@ -107,10 +110,11 @@ impl ExprSimplifier { info, guarantees: vec![], canonicalize: true, + max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES, } } - /// Simplifies this [`Expr`]`s as much as possible, evaluating + /// Simplifies this [`Expr`] as much as possible, evaluating /// constants and applying algebraic simplifications. 
/// /// The types of the expression must match what operators expect, @@ -171,7 +175,18 @@ impl ExprSimplifier { /// let expr = simplifier.simplify(expr).unwrap(); /// assert_eq!(expr, b_lt_2); /// ``` - pub fn simplify(&self, mut expr: Expr) -> Result { + pub fn simplify(&self, expr: Expr) -> Result { + Ok(self.simplify_with_cycle_count(expr)?.0) + } + + /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating + /// constants and applying algebraic simplifications. Additionally returns a `u32` + /// representing the number of simplification cycles performed, which can be useful for testing + /// optimizations. + /// + /// See [Self::simplify] for details and usage examples. + /// + pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); @@ -181,24 +196,26 @@ impl ExprSimplifier { expr = expr.rewrite(&mut Canonicalizer::new()).data()? } - // TODO iterate until no changes are made during rewrite - // (evaluating constants can enable new simplifications and - // simplifications can enable new constant evaluation) - // https://github.com/apache/datafusion/issues/1160 - expr.rewrite(&mut const_evaluator) - .data()? - .rewrite(&mut simplifier) - .data()? - .rewrite(&mut guarantee_rewriter) - .data()? - // run both passes twice to try an minimize simplifications that we missed - .rewrite(&mut const_evaluator) - .data()? - .rewrite(&mut simplifier) - .data()? - // shorten inlist should be started after other inlist rules are applied - .rewrite(&mut shorten_in_list_simplifier) - .data() + // Evaluating constants can enable new simplifications and + // simplifications can enable new constant evaluation + // see `Self::with_max_cycles` + let mut num_cycles = 0; + loop { + let Transformed { + data, transformed, .. 
+ } = expr + .rewrite(&mut const_evaluator)? + .transform_data(|expr| expr.rewrite(&mut simplifier))? + .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; + expr = data; + num_cycles += 1; + if !transformed || num_cycles >= self.max_simplifier_cycles { + break; + } + } + // shorten inlist should be started after other inlist rules are applied + expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; + Ok((expr, num_cycles)) } /// Apply type coercion to an [`Expr`] so that it can be @@ -323,6 +340,63 @@ impl ExprSimplifier { self.canonicalize = canonicalize; self } + + /// Specifies the maximum number of simplification cycles to run. + /// + /// The simplifier can perform multiple passes of simplification. This is + /// because the output of one simplification step can allow more optimizations + /// in another simplification step. For example, constant evaluation can allow more + /// expression simplifications, and expression simplifications can allow more constant + /// evaluations. + /// + /// This method specifies the maximum number of allowed iteration cycles before the simplifier + /// returns an [Expr] output. However, it does not always perform the maximum number of cycles. + /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification + /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals. + /// + /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used + /// instead. 
+ /// + /// ```rust + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; + /// use datafusion_expr::execution_props::ExecutionProps; + /// use datafusion_expr::simplify::SimplifyContext; + /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; + /// + /// let schema = Schema::new(vec![ + /// Field::new("a", DataType::Int64, false), + /// ]) + /// .to_dfschema_ref().unwrap(); + /// + /// // Create the simplifier + /// let props = ExecutionProps::new(); + /// let context = SimplifyContext::new(&props) + /// .with_schema(schema); + /// let simplifier = ExprSimplifier::new(context); + /// + /// // Expression: a IS NOT NULL + /// let expr = col("a").is_not_null(); + /// + /// // When using default maximum cycles, 2 cycles will be performed. + /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap(); + /// assert_eq!(simplified_expr, lit(true)); + /// // 2 cycles were executed, but only 1 was needed + /// assert_eq!(count, 2); + /// + /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1. 
+ /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap(); + /// // Expression has been rewritten to: true + /// assert_eq!(simplified_expr, lit(true)); + /// // Only 1 cycle was executed + /// assert_eq!(count, 1); + /// + /// ``` + pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self { + self.max_simplifier_cycles = max_simplifier_cycles; + self + } } /// Canonicalize any BinaryExprs that are not in canonical form @@ -404,6 +478,8 @@ struct ConstEvaluator<'a> { enum ConstSimplifyResult { // Expr was simplifed and contains the new expression Simplified(ScalarValue), + // Expr was not simplified and original value is returned + NotSimplified(ScalarValue), // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } @@ -450,6 +526,9 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { ConstSimplifyResult::Simplified(s) => { Ok(Transformed::yes(Expr::Literal(s))) } + ConstSimplifyResult::NotSimplified(s) => { + Ok(Transformed::no(Expr::Literal(s))) + } ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { Ok(Transformed::yes(expr)) } @@ -548,7 +627,7 @@ impl<'a> ConstEvaluator<'a> { /// Internal helper to evaluates an Expr pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult { if let Expr::Literal(s) = expr { - return ConstSimplifyResult::Simplified(s); + return ConstSimplifyResult::NotSimplified(s); } let phys_expr = @@ -1672,15 +1751,14 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { #[cfg(test)] mod tests { + use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; + use datafusion_expr::{interval_arithmetic::Interval, *}; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, sync::Arc, }; - use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; - use datafusion_expr::{interval_arithmetic::Interval, *}; - use crate::simplify_expressions::SimplifyContext; 
use crate::test::test_table_scan_with_name; @@ -2868,6 +2946,19 @@ mod tests { try_simplify(expr).unwrap() } + fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(schema), + ); + simplifier.simplify_with_cycle_count(expr) + } + + fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) { + try_simplify_with_cycle_count(expr).unwrap() + } + fn simplify_with_guarantee( expr: Expr, guarantees: Vec<(Expr, NullableInterval)>, @@ -3575,4 +3666,40 @@ mod tests { assert_eq!(simplify(expr), expected); } + + #[test] + fn test_simplify_cycles() { + // TRUE + let expr = lit(true); + let expected = lit(true); + let (expr, num_iter) = simplify_with_cycle_count(expr); + assert_eq!(expr, expected); + assert_eq!(num_iter, 1); + + // (true != NULL) OR (5 > 10) + let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10))); + let expected = lit_bool_null(); + let (expr, num_iter) = simplify_with_cycle_count(expr); + assert_eq!(expr, expected); + assert_eq!(num_iter, 2); + + // NOTE: this currently does not simplify + // (((c4 - 10) + 10) *100) / 100 + let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100); + let expected = expr.clone(); + let (expr, num_iter) = simplify_with_cycle_count(expr); + assert_eq!(expr, expected); + assert_eq!(num_iter, 1); + + // ((c4<1 or c3<2) and c3_non_null<3) and false + let expr = col("c4") + .lt(lit(1)) + .or(col("c3").lt(lit(2))) + .and(col("c3_non_null").lt(lit(3))) + .and(lit(false)); + let expected = lit(false); + let (expr, num_iter) = simplify_with_cycle_count(expr); + assert_eq!(expr, expected); + assert_eq!(num_iter, 2); + } }