diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index 3eac62a4df08..b602a9cba4f4 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -246,3 +246,119 @@ fn calculate_selectivity( acc * cardinality_ratio(&initial.interval, &target.interval) }) } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::{assert_contains, DFSchema}; + use datafusion_expr::{ + col, execution_props::ExecutionProps, interval_arithmetic::Interval, lit, Expr, + }; + + use crate::{create_physical_expr, AnalysisContext}; + + use super::{analyze, ExprBoundaries}; + + fn make_field(name: &str, data_type: DataType) -> Field { + let nullable = false; + Field::new(name, data_type, nullable) + } + + #[test] + fn test_analyze_boundary_exprs() { + let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)])); + + /// Test case containing (expression tree, lower bound, upper bound) + type TestCase = (Expr, Option, Option); + + let test_cases: Vec = vec![ + // a > 10 + (col("a").gt(lit(10)), Some(11), None), + // a < 20 + (col("a").lt(lit(20)), None, Some(19)), + // a > 10 AND a < 20 + ( + col("a").gt(lit(10)).and(col("a").lt(lit(20))), + Some(11), + Some(19), + ), + // a >= 10 + (col("a").gt_eq(lit(10)), Some(10), None), + // a <= 20 + (col("a").lt_eq(lit(20)), None, Some(20)), + // a >= 10 AND a <= 20 + ( + col("a").gt_eq(lit(10)).and(col("a").lt_eq(lit(20))), + Some(10), + Some(20), + ), + // a > 10 AND a < 20 AND a < 15 + ( + col("a") + .gt(lit(10)) + .and(col("a").lt(lit(20))) + .and(col("a").lt(lit(15))), + Some(11), + Some(14), + ), + // (a > 10 AND a < 20) AND (a > 15 AND a < 25) + ( + col("a") + .gt(lit(10)) + .and(col("a").lt(lit(20))) + .and(col("a").gt(lit(15))) + .and(col("a").lt(lit(25))), + Some(16), + Some(19), + ), + // (a > 10 AND a < 20) AND (a > 20 AND a < 30) + ( + col("a") + .gt(lit(10)) + .and(col("a").lt(lit(20))) + .and(col("a").gt(lit(20))) + .and(col("a").lt(lit(30))), + None, + None, + ), + ]; + for (expr, lower, upper) in test_cases { + let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap(); + let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); + let physical_expr = + create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap(); + let analysis_result = analyze( + &physical_expr, + AnalysisContext::new(boundaries), + df_schema.as_ref(), + ) + .unwrap(); + let actual = &analysis_result.boundaries[0].interval; + let expected = Interval::make(lower, upper).unwrap(); + assert_eq!( + &expected, actual, + "did not get correct interval for SQL expression: {expr:?}" + ); + } + } + + #[test] + fn test_analyze_invalid_boundary_exprs() { + let schema = Arc::new(Schema::new(vec![make_field("a", DataType::Int32)])); + let expr = col("a").lt(lit(10)).or(col("a").gt(lit(20))); + let expected_error = "Interval arithmetic does not support the operator OR"; + let boundaries = ExprBoundaries::try_new_unbounded(&schema).unwrap(); + let df_schema = DFSchema::try_from(Arc::clone(&schema)).unwrap(); + let physical_expr = + create_physical_expr(&expr, &df_schema, &ExecutionProps::new()).unwrap(); + let analysis_error = analyze( + &physical_expr, + AnalysisContext::new(boundaries), + df_schema.as_ref(), + ) + .unwrap_err(); + assert_contains!(analysis_error.to_string(), expected_error); + } +}