Skip to content

Commit

Permalink
equivalence classes: fix projection
Browse files Browse the repository at this point in the history
This patch fixes the logic that projects equivalence classes:
when iterating over the projection mapping to find new equivalent expressions,
we need to normalize each source expression first.
  • Loading branch information
askalt committed Jan 29, 2025
1 parent 09a0844 commit 13265a1
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 2 deletions.
64 changes: 62 additions & 2 deletions datafusion/physical-expr/src/equivalence/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,18 @@ impl EquivalenceGroup {
.collect::<Vec<_>>();
(new_class.len() > 1).then_some(EquivalenceClass::new(new_class))
});

// The key is the normalized source expression and the value is the EquivalenceClass that collects the target expressions of every source expression normalizing to that key.
let mut new_classes: IndexMap<Arc<dyn PhysicalExpr>, EquivalenceClass> =
IndexMap::new();
mapping.iter().for_each(|(source, target)| {
// We need to find equivalent projected expressions.
// e.g. table with columns [a,b,c] and a == b, projection: [a+c, b+c].
// To conclude that a + c == b + c we first normalize all source expressions
// in the mapping, then merge all equivalent expressions into the classes.
let normalized_expr = self.normalize_expr(Arc::clone(source));
new_classes
.entry(Arc::clone(source))
.entry(normalized_expr)
.or_insert_with(EquivalenceClass::new_empty)
.push(Arc::clone(target));
});
Expand Down Expand Up @@ -752,8 +758,9 @@ mod tests {

use super::*;
use crate::equivalence::tests::create_test_params;
use crate::expressions::{lit, BinaryExpr, Literal};
use crate::expressions::{binary, col, lit, BinaryExpr, Literal};

use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Operator;

Expand Down Expand Up @@ -1038,4 +1045,57 @@ mod tests {

Ok(())
}

#[test]
fn test_project_classes() -> Result<()> {
    // Scenario:
    // - the input has columns [a, b, c], with "a" and "b" known to be equal;
    // - the projection produces [a + c, b + c];
    // - after projecting the equivalence group, both output columns must
    //   normalize to the same expression, i.e. share an equivalence class.
    let input_schema = Arc::new(Schema::new(
        ["a", "b", "c"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Int32, false))
            .collect::<Vec<_>>(),
    ));

    // Register the `a == b` condition on an initially empty group.
    let mut group = EquivalenceGroup::empty();
    let col_a = col("a", &input_schema)?;
    let col_b = col("b", &input_schema)?;
    group.add_equal_conditions(&col_a, &col_b);

    let output_schema = Arc::new(Schema::new(
        ["a+c", "b+c"]
            .into_iter()
            .map(|name| Field::new(name, DataType::Int32, false))
            .collect::<Vec<_>>(),
    ));

    // Builds the source expression `<name> + c` over the input schema.
    let plus_c = |name: &str| -> Result<Arc<dyn PhysicalExpr>> {
        binary(
            col(name, &input_schema)?,
            Operator::Plus,
            col("c", &input_schema)?,
            &input_schema,
        )
    };

    // Each entry maps a source expression to its projected output column.
    let mapping = ProjectionMapping {
        map: vec![
            (plus_c("a")?, col("a+c", &output_schema)?),
            (plus_c("b")?, col("b+c", &output_schema)?),
        ],
    };

    let projected = group.project(&mapping);
    assert!(!projected.is_empty());

    // Both projected columns must collapse to one representative expression.
    let lhs = projected.normalize_expr(col("a+c", &output_schema)?);
    let rhs = projected.normalize_expr(col("b+c", &output_schema)?);
    assert!(lhs.eq(&rhs));

    Ok(())
}
}
76 changes: 76 additions & 0 deletions datafusion/sqllogictest/test_files/join.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1312,3 +1312,79 @@ SELECT a+b*2,

statement ok
drop table t1;

# Test that equivalence classes are projected correctly.

statement ok
create table pairs(x int, y int) as values (1,1), (2,2), (3,3);

statement ok
create table f(a int) as values (1), (2), (3);

statement ok
create table s(b int) as values (1), (2), (3);

statement ok
set datafusion.optimizer.repartition_joins = true;

statement ok
set datafusion.execution.target_partitions = 16;

# After applying the filter (x = y), we can join by both x and y
# while partitioning only once.

query TT
explain
SELECT * FROM
(SELECT x+1 AS col0, y+1 AS col1 FROM PAIRS WHERE x == y)
JOIN f
ON col0 = f.a
JOIN s
ON col1 = s.b
----
logical_plan
01)Inner Join: col1 = CAST(s.b AS Int64)
02)--Inner Join: col0 = CAST(f.a AS Int64)
03)----Projection: CAST(pairs.x AS Int64) + Int64(1) AS col0, CAST(pairs.y AS Int64) + Int64(1) AS col1
04)------Filter: pairs.y = pairs.x
05)--------TableScan: pairs projection=[x, y]
06)----TableScan: f projection=[a]
07)--TableScan: s projection=[b]
physical_plan
01)CoalesceBatchesExec: target_batch_size=8192
02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col1@1, CAST(s.b AS Int64)@1)], projection=[col0@0, col1@1, a@2, b@3]
03)----ProjectionExec: expr=[col0@1 as col0, col1@2 as col1, a@0 as a]
04)------CoalesceBatchesExec: target_batch_size=8192
05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(CAST(f.a AS Int64)@1, col0@0)], projection=[a@0, col0@2, col1@3]
06)----------CoalesceBatchesExec: target_batch_size=8192
07)------------RepartitionExec: partitioning=Hash([CAST(f.a AS Int64)@1], 16), input_partitions=1
08)--------------ProjectionExec: expr=[a@0 as a, CAST(a@0 AS Int64) as CAST(f.a AS Int64)]
09)----------------MemoryExec: partitions=1, partition_sizes=[1]
10)----------CoalesceBatchesExec: target_batch_size=8192
11)------------RepartitionExec: partitioning=Hash([col0@0], 16), input_partitions=16
12)--------------ProjectionExec: expr=[CAST(x@0 AS Int64) + 1 as col0, CAST(y@1 AS Int64) + 1 as col1]
13)----------------RepartitionExec: partitioning=RoundRobinBatch(16), input_partitions=1
14)------------------CoalesceBatchesExec: target_batch_size=8192
15)--------------------FilterExec: y@1 = x@0
16)----------------------MemoryExec: partitions=1, partition_sizes=[1]
17)----CoalesceBatchesExec: target_batch_size=8192
18)------RepartitionExec: partitioning=Hash([CAST(s.b AS Int64)@1], 16), input_partitions=1
19)--------ProjectionExec: expr=[b@0 as b, CAST(b@0 AS Int64) as CAST(s.b AS Int64)]
20)----------MemoryExec: partitions=1, partition_sizes=[1]

statement ok
drop table pairs;

statement ok
drop table f;

statement ok
drop table s;

# Reset the configs to old values.
statement ok
set datafusion.execution.target_partitions = 4;

statement ok
set datafusion.optimizer.repartition_joins = false;

0 comments on commit 13265a1

Please sign in to comment.