Skip to content

Commit

Permalink
Eliminate deprecated HirScalarExpr::visit_mut_pre
Browse files Browse the repository at this point in the history
  • Loading branch information
ggevay committed Jan 22, 2025
1 parent 45da845 commit 1c8b234
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 77 deletions.
10 changes: 0 additions & 10 deletions src/sql/src/plan/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3097,16 +3097,6 @@ impl HirScalarExpr {
mem::replace(self, HirScalarExpr::literal_null(ScalarType::String))
}

#[deprecated = "Use `Visit::visit_mut_pre` instead."]
pub fn visit_mut_pre<F>(&mut self, f: &mut F)
where
F: FnMut(&mut Self),
{
f(self);
#[allow(deprecated)]
self.visit1_mut(|e: &mut HirScalarExpr| e.visit_mut_pre(f));
}

#[deprecated = "Use `VisitChildren<HirScalarExpr>::visit_children` instead."]
pub fn visit1_mut<F>(&mut self, mut f: F)
where
Expand Down
2 changes: 1 addition & 1 deletion src/sql/src/plan/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ impl HirRelationExpr {
mut other => {
let mut id_gen = mz_ore::id_gen::IdGen::default();
transform_hir::split_subquery_predicates(&mut other)?;
transform_hir::try_simplify_quantified_comparisons(&mut other);
transform_hir::try_simplify_quantified_comparisons(&mut other)?;
transform_hir::fuse_window_functions(&mut other, &context)?;
MirRelationExpr::constant(vec![vec![]], RelationType::new(vec![])).let_in(
&mut id_gen,
Expand Down
143 changes: 77 additions & 66 deletions src/sql/src/plan/transform_hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,20 @@ pub fn split_subquery_predicates(expr: &mut HirRelationExpr) -> Result<(), Recur
///
/// See Section 3.5 of "Execution Strategies for SQL Subqueries" by
/// M. Elhemali, et al.
pub fn try_simplify_quantified_comparisons(expr: &mut HirRelationExpr) {
fn walk_relation(expr: &mut HirRelationExpr, outers: &[RelationType]) {
pub fn try_simplify_quantified_comparisons(
expr: &mut HirRelationExpr,
) -> Result<(), RecursionLimitError> {
fn walk_relation(
expr: &mut HirRelationExpr,
outers: &[RelationType],
) -> Result<(), RecursionLimitError> {
match expr {
HirRelationExpr::Map { scalars, input } => {
walk_relation(input, outers);
walk_relation(input, outers)?;
let mut outers = outers.to_vec();
outers.insert(0, input.typ(&outers, &NO_PARAMS));
for scalar in scalars {
walk_scalar(scalar, &outers, false);
walk_scalar(scalar, &outers, false)?;
let (inner, outers) = outers
.split_first_mut()
.expect("outers known to have at least one element");
Expand All @@ -190,95 +195,101 @@ pub fn try_simplify_quantified_comparisons(expr: &mut HirRelationExpr) {
}
}
HirRelationExpr::Filter { predicates, input } => {
walk_relation(input, outers);
walk_relation(input, outers)?;
let mut outers = outers.to_vec();
outers.insert(0, input.typ(&outers, &NO_PARAMS));
for pred in predicates {
walk_scalar(pred, &outers, true);
walk_scalar(pred, &outers, true)?;
}
}
HirRelationExpr::CallTable { exprs, .. } => {
let mut outers = outers.to_vec();
outers.insert(0, RelationType::empty());
for scalar in exprs {
walk_scalar(scalar, &outers, false);
walk_scalar(scalar, &outers, false)?;
}
}
HirRelationExpr::Join { left, right, .. } => {
walk_relation(left, outers);
walk_relation(left, outers)?;
let mut outers = outers.to_vec();
outers.insert(0, left.typ(&outers, &NO_PARAMS));
walk_relation(right, &outers);
walk_relation(right, &outers)?;
}
expr => {
#[allow(deprecated)]
let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), ()> {
walk_relation(expr, outers);
Ok(())
let _ = expr.visit1_mut(0, &mut |expr, _| -> Result<(), RecursionLimitError> {
walk_relation(expr, outers)
});
}
}
Ok(())
}

fn walk_scalar(expr: &mut HirScalarExpr, outers: &[RelationType], mut in_filter: bool) {
#[allow(deprecated)]
expr.visit_mut_pre(&mut |e| match e {
HirScalarExpr::Exists(input) => walk_relation(input, outers),
HirScalarExpr::Select(input) => {
walk_relation(input, outers);

// We're inside of a `(SELECT ...)` subquery. Now let's see if
// it has the form `(SELECT <any|all>(...) FROM <input>)`.
// Ideally we could do this with one pattern, but Rust's pattern
// matching engine is not powerful enough, so we have to do this
// in stages; the early returns avoid brutal nesting.

let (func, expr, input) = match &mut **input {
HirRelationExpr::Reduce {
group_key,
aggregates,
input,
expected_group_size: _,
} if group_key.is_empty() && aggregates.len() == 1 => {
let agg = &mut aggregates[0];
(&agg.func, &mut agg.expr, input)
}
_ => return,
};

if !in_filter && column_type(outers, input, expr).nullable {
// Unless we're directly inside of a WHERE, this
// transformation is only valid if the expression involved
// is non-nullable.
return;
}
fn walk_scalar(
expr: &mut HirScalarExpr,
outers: &[RelationType],
mut in_filter: bool,
) -> Result<(), RecursionLimitError> {
expr.try_visit_mut_pre(&mut |e| {
match e {
HirScalarExpr::Exists(input) => walk_relation(input, outers)?,
HirScalarExpr::Select(input) => {
walk_relation(input, outers)?;

// We're inside a `(SELECT ...)` subquery. Now let's see if
// it has the form `(SELECT <any|all>(...) FROM <input>)`.
// Ideally we could do this with one pattern, but Rust's pattern
// matching engine is not powerful enough, so we have to do this
// in stages; the early returns avoid brutal nesting.

let (func, expr, input) = match &mut **input {
HirRelationExpr::Reduce {
group_key,
aggregates,
input,
expected_group_size: _,
} if group_key.is_empty() && aggregates.len() == 1 => {
let agg = &mut aggregates[0];
(&agg.func, &mut agg.expr, input)
}
_ => return Ok(()),
};

match func {
AggregateFunc::Any => {
// Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
// `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
*e = input.take().filter(vec![expr.take()]).exists();
if !in_filter && column_type(outers, input, expr).nullable {
// Unless we're directly inside a WHERE, this
// transformation is only valid if the expression involved
// is non-nullable.
return Ok(());
}
AggregateFunc::All => {
// Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
// `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
//
// Note that negation of <expr> alone is insufficient.
// Consider that `WHERE <pred>` filters out rows if
// `<pred>` is false *or* null. To invert the test, we
// need `NOT <pred> OR <pred> IS NULL`.
let expr = expr.take();
let filter = expr.clone().not().or(expr.call_is_null());
*e = input.take().filter(vec![filter]).exists().not();

match func {
AggregateFunc::Any => {
// Found `(SELECT any(<expr>) FROM <input>)`. Rewrite to
// `EXISTS(SELECT 1 FROM <input> WHERE <expr>)`.
*e = input.take().filter(vec![expr.take()]).exists();
}
AggregateFunc::All => {
// Found `(SELECT all(<expr>) FROM <input>)`. Rewrite to
// `NOT EXISTS(SELECT 1 FROM <input> WHERE NOT <expr> OR <expr> IS NULL)`.
//
// Note that negation of <expr> alone is insufficient.
// Consider that `WHERE <pred>` filters out rows if
// `<pred>` is false *or* null. To invert the test, we
// need `NOT <pred> OR <pred> IS NULL`.
let expr = expr.take();
let filter = expr.clone().not().or(expr.call_is_null());
*e = input.take().filter(vec![filter]).exists().not();
}
_ => (),
}
_ => (),
}
_ => {
// As soon as we see *any* scalar expression, we are no longer
// directly inside a filter.
in_filter = false;
}
}
_ => {
// As soon as we see *any* scalar expression, we are no longer
// directly inside of a filter.
in_filter = false;
}
Ok(())
})
}

Expand Down

0 comments on commit 1c8b234

Please sign in to comment.