diff --git a/datafusion-federation/src/sql/mod.rs b/datafusion-federation/src/sql/mod.rs
index d63d52d..98d90ab 100644
--- a/datafusion-federation/src/sql/mod.rs
+++ b/datafusion-federation/src/sql/mod.rs
@@ -14,8 +14,8 @@ use datafusion::{
             AggregateFunction, Alias, Exists, InList, InSubquery, PlannedReplaceSelectItem,
             ScalarFunction, Sort, Unnest, WildcardOptions, WindowFunction,
         },
-        Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, LogicalPlan, Subquery,
-        TryCast,
+        Between, BinaryExpr, Case, Cast, Expr, Extension, GroupingSet, Like, Limit, LogicalPlan,
+        Subquery, TryCast,
     },
     optimizer::{optimizer::Optimizer, OptimizerConfig, OptimizerRule},
     physical_expr::EquivalenceProperties,
@@ -165,6 +165,29 @@ fn rewrite_table_scans(
         .map(|i| rewrite_table_scans(i, known_rewrites))
         .collect::<Result<Vec<_>>>()?;
 
+    if let LogicalPlan::Limit(limit) = plan {
+        let rewritten_skip = limit
+            .skip
+            .as_ref()
+            .map(|skip| rewrite_table_scans_in_expr(*skip.clone(), known_rewrites).map(Box::new))
+            .transpose()?;
+
+        let rewritten_fetch = limit
+            .fetch
+            .as_ref()
+            .map(|fetch| rewrite_table_scans_in_expr(*fetch.clone(), known_rewrites).map(Box::new))
+            .transpose()?;
+
+        // explicitly set fetch and skip
+        let new_plan = LogicalPlan::Limit(Limit {
+            skip: rewritten_skip,
+            fetch: rewritten_fetch,
+            input: Arc::new(rewritten_inputs[0].clone()),
+        });
+
+        return Ok(new_plan);
+    }
+
     let mut new_expressions = vec![];
     for expression in plan.expressions() {
         let new_expr = rewrite_table_scans_in_expr(expression.clone(), known_rewrites)?;
@@ -994,4 +1017,54 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_rewrite_table_scans_limit_offset() -> Result<()> {
+        init_tracing();
+        let ctx = get_test_df_context();
+
+        let tests = vec![
+            // Basic LIMIT
+            (
+                "SELECT a FROM foo.df_table LIMIT 5",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 5"#,
+            ),
+            // Basic OFFSET
+            (
+                "SELECT a FROM foo.df_table OFFSET 5",
+                r#"SELECT remote_table.a FROM remote_table OFFSET 5"#,
+            ),
+            // OFFSET after LIMIT
+            (
+                "SELECT a FROM foo.df_table LIMIT 10 OFFSET 5",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
+            ),
+            // LIMIT after OFFSET
+            (
+                "SELECT a FROM foo.df_table OFFSET 5 LIMIT 10",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 10 OFFSET 5"#,
+            ),
+            // Zero OFFSET
+            (
+                "SELECT a FROM foo.df_table OFFSET 0",
+                r#"SELECT remote_table.a FROM remote_table OFFSET 0"#,
+            ),
+            // Zero LIMIT
+            (
+                "SELECT a FROM foo.df_table LIMIT 0",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 0"#,
+            ),
+            // Zero LIMIT and OFFSET
+            (
+                "SELECT a FROM foo.df_table LIMIT 0 OFFSET 0",
+                r#"SELECT remote_table.a FROM remote_table LIMIT 0 OFFSET 0"#,
+            ),
+        ];
+
+        for test in tests {
+            test_sql(&ctx, test.0, test.1).await?;
+        }
+
+        Ok(())
+    }
 }