Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: df patched upgrade to 2024-03-05, requiring new DF fixes #3

Closed
Closed
69 changes: 50 additions & 19 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,15 @@ use datafusion_expr::logical_plan::{
use datafusion_expr::{col, Expr, ExprSchemable};

/// A map from an expression's identifier to a tuple including
///
/// key == Identifier created with only the current node (& subtree)
///
/// values:
/// - the expression itself (cloned)
/// - counter
/// - DataType of this expression.
type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
/// - symbol used as the identifier in the alias
type ExprSet = HashMap<Identifier, (Expr, usize, DataType, Identifier)>;

/// Identifier for each subexpression.
///
Expand Down Expand Up @@ -278,9 +283,9 @@ impl CommonSubexprEliminate {

for id in affected_id {
match expr_set.get(&id) {
Some((expr, _, _)) => {
Some((expr, _, _, symbol)) => {
// todo: check `nullable`
agg_exprs.push(expr.clone().alias(&id));
agg_exprs.push(expr.clone().alias(symbol.as_str()));
}
_ => {
return internal_err!("expr_set invalid state");
Expand Down Expand Up @@ -455,11 +460,11 @@ fn build_common_expr_project_plan(

for id in affected_id {
match expr_set.get(&id) {
Some((expr, _, data_type)) => {
Some((expr, _, data_type, symbol)) => {
// todo: check `nullable`
let field = DFField::new_unqualified(&id, data_type.clone(), true);
fields_set.insert(field.name().to_owned());
project_exprs.push(expr.clone().alias(&id));
project_exprs.push(expr.clone().alias(symbol.as_str()));
}
_ => {
return internal_err!("expr_set invalid state");
Expand Down Expand Up @@ -650,16 +655,16 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
.push(VisitRecord::ExprItem(curr_expr_identifier));
return Ok(TreeNodeRecursion::Continue);
}
let mut desc = Self::expr_identifier(expr);
desc.push_str(&sub_expr_identifier);
let curr_expr_identifier = Self::expr_identifier(expr);
let desc = format!("{curr_expr_identifier}{sub_expr_identifier}");

self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));

let data_type = expr.get_type(&self.input_schema)?;

self.expr_set
.entry(desc)
.or_insert_with(|| (expr.clone(), 0, data_type))
.entry(curr_expr_identifier)
.or_insert_with(|| (expr.clone(), 0, data_type, desc))
.1 += 1;
Ok(TreeNodeRecursion::Continue)
}
Expand Down Expand Up @@ -713,7 +718,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {

// lookup previously visited expression
match self.expr_set.get(curr_id) {
Some((_, counter, _)) => {
Some((_, counter, _, symbol)) => {
// if this is a commonly used expr (i.e. used more than once)
if *counter > 1 {
self.affected_id.insert(curr_id.clone());
Expand All @@ -723,7 +728,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
// `projection_push_down` optimizer use "expr name" to eliminate useless
// projections.
Ok(Transformed::new(
col(curr_id).alias(expr_name),
col(symbol).alias(expr_name),
true,
TreeNodeRecursion::Jump,
))
Expand Down Expand Up @@ -1026,18 +1031,24 @@ mod test {
let expr_set_1 = [
(
"c+a".to_string(),
(col("c") + col("a"), 1, DataType::UInt32),
(col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()),
),
(
"b+a".to_string(),
(col("b") + col("a"), 1, DataType::UInt32),
(col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()),
),
]
.into_iter()
.collect();
let expr_set_2 = [
("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)),
(
"c+a".to_string(),
(col("c+a"), 1, DataType::UInt32, "c+a".to_string()),
),
(
"b+a".to_string(),
(col("b+a"), 1, DataType::UInt32, "b+a".to_string()),
),
]
.into_iter()
.collect();
Expand Down Expand Up @@ -1069,23 +1080,43 @@ mod test {
let expr_set_1 = [
(
"test1.c+test1.a".to_string(),
(col("test1.c") + col("test1.a"), 1, DataType::UInt32),
(
col("test1.c") + col("test1.a"),
1,
DataType::UInt32,
"test1.c+test1.a".to_string(),
),
),
(
"test1.b+test1.a".to_string(),
(col("test1.b") + col("test1.a"), 1, DataType::UInt32),
(
col("test1.b") + col("test1.a"),
1,
DataType::UInt32,
"test1.b+test1.a".to_string(),
),
),
]
.into_iter()
.collect();
let expr_set_2 = [
(
"test1.c+test1.a".to_string(),
(col("test1.c+test1.a"), 1, DataType::UInt32),
(
col("test1.c+test1.a"),
1,
DataType::UInt32,
"test1.c+test1.a".to_string(),
),
),
(
"test1.b+test1.a".to_string(),
(col("test1.b+test1.a"), 1, DataType::UInt32),
(
col("test1.b+test1.a"),
1,
DataType::UInt32,
"test1.b+test1.a".to_string(),
),
),
]
.into_iter()
Expand Down
Loading