Skip to content

Commit

Permalink
use a single row_count column during predicate pruning instead of one…
Browse files Browse the repository at this point in the history
… per column (apache#14295)

* use a single row_count column during predicate pruning instead of one per column

* fix tests

* fix conflicts and test:

* lint

* fix assertions
  • Loading branch information
adriangb authored Feb 9, 2025
1 parent 0f9773a commit df0d966
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1687,7 +1687,7 @@ mod tests {

assert_contains!(
&display,
"pruning_predicate=c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != bar OR bar != c1_max@1)"
"pruning_predicate=c1_null_count@2 != row_count@3 AND (c1_min@0 != bar OR bar != c1_max@1)"
);

assert_contains!(&display, r#"predicate=c1@0 != bar"#);
Expand Down
125 changes: 82 additions & 43 deletions datafusion/physical-optimizer/src/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -807,11 +807,22 @@ impl RequiredColumns {
column: &phys_expr::Column,
statistics_type: StatisticsType,
) -> Option<usize> {
self.columns
.iter()
.enumerate()
.find(|(_i, (c, t, _f))| c == column && t == &statistics_type)
.map(|(i, (_c, _t, _f))| i)
match statistics_type {
StatisticsType::RowCount => {
// Use the first row count we find, if any
self.columns
.iter()
.enumerate()
.find(|(_i, (_c, t, _f))| t == &statistics_type)
.map(|(i, (_c, _t, _f))| i)
}
_ => self
.columns
.iter()
.enumerate()
.find(|(_i, (c, t, _f))| c == column && t == &statistics_type)
.map(|(i, (_c, _t, _f))| i),
}
}

/// Rewrites column_expr so that all appearances of column
Expand All @@ -834,15 +845,15 @@ impl RequiredColumns {
None => (self.columns.len(), true),
};

let suffix = match stat_type {
StatisticsType::Min => "min",
StatisticsType::Max => "max",
StatisticsType::NullCount => "null_count",
StatisticsType::RowCount => "row_count",
let column_name = column.name();
let stat_column_name = match stat_type {
StatisticsType::Min => format!("{column_name}_min"),
StatisticsType::Max => format!("{column_name}_max"),
StatisticsType::NullCount => format!("{column_name}_null_count"),
StatisticsType::RowCount => "row_count".to_string(),
};

let stat_column =
phys_expr::Column::new(&format!("{}_{}", column.name(), suffix), idx);
let stat_column = phys_expr::Column::new(&stat_column_name, idx);

// only add statistics column if not previously added
if need_to_insert {
Expand Down Expand Up @@ -2189,6 +2200,38 @@ mod tests {
}
}

/// A predicate that touches several columns must share one `row_count` statistics
/// column: every null-count guard refers to the same `row_count` index instead of
/// each column getting its own copy.
#[test]
fn test_unique_row_count_field_and_column() {
    // Predicate under test: c1 = 100 AND c2 = 200
    let input_fields = vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Int32, true),
    ];
    let schema: SchemaRef = Arc::new(Schema::new(input_fields));
    let logical_expr = col("c1").eq(lit(100)).and(col("c2").eq(lit(200)));
    let physical_expr = logical2physical(&logical_expr, &schema);
    let pruning_predicate =
        PruningPredicate::try_new(physical_expr, Arc::clone(&schema)).unwrap();

    // `row_count` shows up in both null-count guards (c1 and c2), but both
    // occurrences resolve to the same statistics column, index 3.
    let expected = "c1_null_count@2 != row_count@3 AND c1_min@0 <= 100 AND 100 <= c1_max@1 AND c2_null_count@6 != row_count@3 AND c2_min@4 <= 200 AND 200 <= c2_max@5";
    let actual = pruning_predicate.predicate_expr.to_string();
    assert_eq!(expected, actual);

    // Every field of the required-statistics schema has to be distinct; a
    // repeated field name would make building the statistics batch fail.
    let mut seen_fields = HashSet::new();
    for (_, _, field) in pruning_predicate.required_columns().iter() {
        assert!(
            seen_fields.insert(field),
            "Duplicate field in required schema: {:?}. Previous fields:\n{:#?}",
            field,
            seen_fields
        );
    }
}

#[test]
fn prune_all_rows_null_counts() {
// if null_count = row_count then we should prune the container for i = 0
Expand Down Expand Up @@ -2475,7 +2518,7 @@ mod tests {
fn row_group_predicate_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr =
"c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1";
"c1_null_count@2 != row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1";

// test column on the left
let expr = col("c1").eq(lit(1));
Expand All @@ -2496,7 +2539,7 @@ mod tests {
fn row_group_predicate_not_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr =
"c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1)";
"c1_null_count@2 != row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1)";

// test column on the left
let expr = col("c1").not_eq(lit(1));
Expand All @@ -2516,7 +2559,7 @@ mod tests {
#[test]
fn row_group_predicate_gt() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_max@0 > 1";
let expected_expr = "c1_null_count@1 != row_count@2 AND c1_max@0 > 1";

// test column on the left
let expr = col("c1").gt(lit(1));
Expand All @@ -2536,7 +2579,7 @@ mod tests {
#[test]
fn row_group_predicate_gt_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_max@0 >= 1";
let expected_expr = "c1_null_count@1 != row_count@2 AND c1_max@0 >= 1";

// test column on the left
let expr = col("c1").gt_eq(lit(1));
Expand All @@ -2555,7 +2598,7 @@ mod tests {
#[test]
fn row_group_predicate_lt() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < 1";
let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 < 1";

// test column on the left
let expr = col("c1").lt(lit(1));
Expand All @@ -2575,7 +2618,7 @@ mod tests {
#[test]
fn row_group_predicate_lt_eq() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 <= 1";
let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 <= 1";

// test column on the left
let expr = col("c1").lt_eq(lit(1));
Expand All @@ -2600,7 +2643,7 @@ mod tests {
]);
// test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression
let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3")));
let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < 1";
let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 < 1";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down Expand Up @@ -2666,7 +2709,7 @@ mod tests {
#[test]
fn row_group_predicate_lt_bool() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]);
let expected_expr = "c1_null_count@1 != c1_row_count@2 AND c1_min@0 < true";
let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 < true";

// DF doesn't support arithmetic on boolean columns so
// this predicate will error when evaluated
Expand All @@ -2689,16 +2732,12 @@ mod tests {
let expr = col("c1")
.lt(lit(1))
.and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3))));
let expected_expr = "c1_null_count@1 != c1_row_count@2 \
AND c1_min@0 < 1 AND (\
c2_null_count@5 != c2_row_count@6 \
AND c2_min@3 <= 2 AND 2 <= c2_max@4 OR \
c2_null_count@5 != c2_row_count@6 AND c2_min@3 <= 3 AND 3 <= c2_max@4\
)";
let expected_expr = "c1_null_count@1 != row_count@2 AND c1_min@0 < 1 AND (c2_null_count@5 != row_count@2 AND c2_min@3 <= 2 AND 2 <= c2_max@4 OR c2_null_count@5 != row_count@2 AND c2_min@3 <= 3 AND 3 <= c2_max@4)";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut required_columns);
assert_eq!(predicate_expr.to_string(), expected_expr);
// c1 < 1 should add c1_min
println!("required_columns: {:#?}", required_columns); // for debugging assertions below
// c1 < 1 should add c1_min
let c1_min_field = Field::new("c1_min", DataType::Int32, false);
assert_eq!(
required_columns.columns[0],
Expand All @@ -2718,14 +2757,14 @@ mod tests {
c1_null_count_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c1 < 1 should add c1_row_count
let c1_row_count_field = Field::new("c1_row_count", DataType::UInt64, false);
// c1 < 1 should add row_count
let row_count_field = Field::new("row_count", DataType::UInt64, false);
assert_eq!(
required_columns.columns[2],
(
phys_expr::Column::new("c1", 0),
StatisticsType::RowCount,
c1_row_count_field.with_nullable(true) // could be nullable if stats are not present
row_count_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c2 = 2 should add c2_min and c2_max
Expand Down Expand Up @@ -2757,18 +2796,18 @@ mod tests {
c2_null_count_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c2 = 2 should add c2_row_count
let c2_row_count_field = Field::new("c2_row_count", DataType::UInt64, false);
// c2 = 2 should not add another row count column; it reuses the row_count entry already added for c1
let row_count_field = Field::new("row_count", DataType::UInt64, false);
assert_eq!(
required_columns.columns[6],
required_columns.columns[2],
(
phys_expr::Column::new("c2", 1),
phys_expr::Column::new("c1", 0),
StatisticsType::RowCount,
c2_row_count_field.with_nullable(true) // could be nullable if stats are not present
row_count_field.with_nullable(true) // could be nullable if stats are not present
)
);
// c2 = 3 shouldn't add any new statistics fields
assert_eq!(required_columns.columns.len(), 7);
assert_eq!(required_columns.columns.len(), 6);

Ok(())
}
Expand All @@ -2785,7 +2824,7 @@ mod tests {
vec![lit(1), lit(2), lit(3)],
false,
));
let expected_expr = "c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 3 AND 3 <= c1_max@1";
let expected_expr = "c1_null_count@2 != row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_null_count@2 != row_count@3 AND c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_null_count@2 != row_count@3 AND c1_min@0 <= 3 AND 3 <= c1_max@1";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down Expand Up @@ -2821,7 +2860,7 @@ mod tests {
vec![lit(1), lit(2), lit(3)],
true,
));
let expected_expr = "c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1) AND c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 2 OR 2 != c1_max@1) AND c1_null_count@2 != c1_row_count@3 AND (c1_min@0 != 3 OR 3 != c1_max@1)";
let expected_expr = "c1_null_count@2 != row_count@3 AND (c1_min@0 != 1 OR 1 != c1_max@1) AND c1_null_count@2 != row_count@3 AND (c1_min@0 != 2 OR 2 != c1_max@1) AND c1_null_count@2 != row_count@3 AND (c1_min@0 != 3 OR 3 != c1_max@1)";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down Expand Up @@ -2867,7 +2906,7 @@ mod tests {
// test c1 in(1, 2) and c2 BETWEEN 4 AND 5
let expr3 = expr1.and(expr2);

let expected_expr = "(c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_null_count@2 != c1_row_count@3 AND c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_null_count@5 != c2_row_count@6 AND c2_max@4 >= 4 AND c2_null_count@5 != c2_row_count@6 AND c2_min@7 <= 5";
let expected_expr = "(c1_null_count@2 != row_count@3 AND c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_null_count@2 != row_count@3 AND c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_null_count@5 != row_count@3 AND c2_max@4 >= 4 AND c2_null_count@5 != row_count@3 AND c2_min@6 <= 5";
let predicate_expr =
test_build_predicate_expression(&expr3, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand All @@ -2894,7 +2933,7 @@ mod tests {
#[test]
fn row_group_predicate_cast() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr = "c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)";
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)";

// test cast(c1 as int64) = 1
// test column on the left
Expand All @@ -2910,7 +2949,7 @@ mod tests {
assert_eq!(predicate_expr.to_string(), expected_expr);

let expected_expr =
"c1_null_count@1 != c1_row_count@2 AND TRY_CAST(c1_max@0 AS Int64) > 1";
"c1_null_count@1 != row_count@2 AND TRY_CAST(c1_max@0 AS Int64) > 1";

// test column on the left
let expr =
Expand Down Expand Up @@ -2942,7 +2981,7 @@ mod tests {
],
false,
));
let expected_expr = "c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR c1_null_count@2 != c1_row_count@3 AND CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)";
let expected_expr = "c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR c1_null_count@2 != row_count@3 AND CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand All @@ -2956,7 +2995,7 @@ mod tests {
],
true,
));
let expected_expr = "c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) AND c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) AND c1_null_count@2 != c1_row_count@3 AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))";
let expected_expr = "c1_null_count@2 != row_count@3 AND (CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) AND c1_null_count@2 != row_count@3 AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) AND c1_null_count@2 != row_count@3 AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))";
let predicate_expr =
test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
assert_eq!(predicate_expr.to_string(), expected_expr);
Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/parquet.slt
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ SELECT * FROM test_table ORDER BY int_col;
5 eee 500 1970-01-06
6 fff 600 1970-01-07

# Check output plan, expect no "output_ordering" clause in the physical_plan -> DataSourceExec:
# Check output plan, expect no "output_ordering" clause in the physical_plan -> ParquetExec:
query TT
EXPLAIN SELECT int_col, string_col
FROM test_table
Expand Down Expand Up @@ -109,7 +109,7 @@ STORED AS PARQUET
WITH ORDER (string_col ASC NULLS LAST, int_col ASC NULLS LAST)
LOCATION 'test_files/scratch/parquet/test_table';

# Check output plan, expect an "output_ordering" clause in the physical_plan -> DataSourceExec:
# Check output plan, expect an "output_ordering" clause in the physical_plan -> ParquetExec:
query TT
EXPLAIN SELECT int_col, string_col
FROM test_table
Expand All @@ -130,7 +130,7 @@ STORED AS PARQUET;
----
3

# Check output plan again, expect no "output_ordering" clause in the physical_plan -> DataSourceExec,
# Check output plan again, expect no "output_ordering" clause in the physical_plan -> ParquetExec,
# due to there being more files than partitions:
query TT
EXPLAIN SELECT int_col, string_col
Expand Down Expand Up @@ -625,7 +625,7 @@ physical_plan
01)CoalesceBatchesExec: target_batch_size=8192
02)--FilterExec: column1@0 LIKE f%
03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/foo.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 LIKE f%, pruning_predicate=column1_null_count@2 != column1_row_count@3 AND column1_min@0 <= g AND f <= column1_max@1, required_guarantees=[]
04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/foo.parquet]]}, projection=[column1], file_type=parquet, predicate=column1@0 LIKE f%, pruning_predicate=column1_null_count@2 != row_count@3 AND column1_min@0 <= g AND f <= column1_max@1, required_guarantees=[]

statement ok
drop table foo
Loading

0 comments on commit df0d966

Please sign in to comment.