Skip to content

Commit

Permalink
feat: Add basic partition pruning support
Browse files Browse the repository at this point in the history
  • Loading branch information
scovich committed Feb 21, 2025
1 parent baa3fc3 commit c069278
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 23 deletions.
1 change: 0 additions & 1 deletion kernel/src/predicates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,6 @@ impl ResolveColumnAsScalar for EmptyColumnResolver {
}

// In testing, it is convenient to just build a hashmap of scalar values.
#[cfg(test)]
impl ResolveColumnAsScalar for std::collections::HashMap<ColumnName, Scalar> {
fn resolve_column(&self, col: &ColumnName) -> Option<Scalar> {
self.get(col).cloned()
Expand Down
104 changes: 85 additions & 19 deletions kernel/src/scan/log_replay.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use super::{ScanData, Transform};
use crate::actions::get_log_add_schema;
use crate::engine_data::{GetData, RowVisitor, TypedGetData as _};
use crate::expressions::{column_expr, column_name, ColumnName, Expression, ExpressionRef};
use crate::scan::{DeletionVectorDescriptor, TransformExpr};
use crate::predicates::{DefaultPredicateEvaluator, PredicateEvaluator as _};
use crate::scan::{DeletionVectorDescriptor, Scalar, TransformExpr};
use crate::schema::{ColumnNamesAndTypes, DataType, MapType, SchemaRef, StructField, StructType};
use crate::utils::require;
use crate::{DeltaResult, Engine, EngineData, Error, ExpressionEvaluator};
Expand All @@ -30,7 +31,8 @@ impl FileActionKey {
}

struct LogReplayScanner {
filter: Option<DataSkippingFilter>,
partition_filter: Option<ExpressionRef>,
data_skipping_filter: Option<DataSkippingFilter>,

/// A set of (data file path, dv_unique_id) pairs that have been seen thus
/// far in the log. This is used to filter out files with Remove actions as
Expand All @@ -47,6 +49,7 @@ struct AddRemoveDedupVisitor<'seen> {
selection_vector: Vec<bool>,
logical_schema: SchemaRef,
transform: Option<Arc<Transform>>,
partition_filter: Option<ExpressionRef>,
row_transform_exprs: Vec<Option<ExpressionRef>>,
is_log_batch: bool,
}
Expand Down Expand Up @@ -82,29 +85,54 @@ impl AddRemoveDedupVisitor<'_> {
}
}

fn parse_partition_value(
&self,
field_idx: usize,
partition_values: &HashMap<String, String>,
) -> DeltaResult<(usize, (String, Scalar))> {
let field = self.logical_schema.fields.get_index(field_idx);
let Some((_, field)) = field else {
return Err(Error::InternalError(format!(
"out of bounds partition column field index {field_idx}"
)));
};
let name = field.physical_name();
let partition_value =
super::parse_partition_value(partition_values.get(name), field.data_type())?;
Ok((field_idx, (name.to_string(), partition_value)))
}

/// Parses the string-typed partition values for every partition column referenced by
/// `transform`, keyed by that column's index in the logical schema. Static transform
/// entries carry no partition data and are skipped.
fn parse_partition_values(
    &self,
    transform: &Transform,
    partition_values: &HashMap<String, String>,
) -> DeltaResult<HashMap<usize, (String, Scalar)>> {
    let mut parsed = HashMap::new();
    for transform_expr in transform.iter() {
        if let TransformExpr::Partition(field_idx) = transform_expr {
            let (idx, name_and_value) = self.parse_partition_value(*field_idx, partition_values)?;
            parsed.insert(idx, name_and_value);
        }
    }
    Ok(parsed)
}

/// Compute an expression that will transform from physical to logical for a given Add file action
fn get_transform_expr<'a>(
fn get_transform_expr(
&self,
i: usize,
transform: &Transform,
getters: &[&'a dyn GetData<'a>],
mut partition_values: HashMap<usize, (String, Scalar)>,
) -> DeltaResult<ExpressionRef> {
let partition_values: HashMap<_, _> = getters[1].get(i, "add.partitionValues")?;
let transforms = transform
.iter()
.map(|transform_expr| match transform_expr {
TransformExpr::Partition(field_idx) => {
let field = self.logical_schema.fields.get_index(*field_idx);
let Some((_, field)) = field else {
return Err(Error::Generic(
format!("logical schema did not contain expected field at {field_idx}, can't transform data")
));
let Some((_, partition_value)) = partition_values.remove(field_idx) else {
return Err(Error::InternalError(format!(
"missing partition value for field index {field_idx}"
)));
};
let name = field.physical_name();
let partition_value = super::parse_partition_value(
partition_values.get(name),
field.data_type(),
)?;
Ok(partition_value.into())
}
TransformExpr::Static(field_expr) => Ok(field_expr.clone()),
Expand All @@ -113,6 +141,24 @@ impl AddRemoveDedupVisitor<'_> {
Ok(Arc::new(Expression::Struct(transforms)))
}

/// Returns `true` iff this file's partition values definitely fail the scan's
/// partition filter, meaning the file can be skipped entirely.
///
/// A file is never pruned when there is no partition filter or no parsed partition
/// values; an indeterminate filter result (`None`) also keeps the file, since only a
/// provable `false` justifies skipping.
fn is_file_partition_pruned(
    &self,
    partition_values: &HashMap<usize, (String, Scalar)>,
) -> bool {
    match &self.partition_filter {
        Some(partition_filter) if !partition_values.is_empty() => {
            // Resolve each partition column name to its scalar value so the default
            // predicate evaluator can evaluate the filter against this one file.
            let resolver: HashMap<_, _> = partition_values
                .values()
                .map(|(name, value)| (ColumnName::new([name]), value.clone()))
                .collect();
            let evaluator = DefaultPredicateEvaluator::from(resolver);
            evaluator.eval_sql_where(partition_filter) == Some(false)
        }
        _ => false,
    }
}

/// True if this row contains an Add action that should survive log replay. Skip it if the row
/// is not an Add action, or the file has already been seen previously.
fn is_valid_add<'a>(&mut self, i: usize, getters: &[&'a dyn GetData<'a>]) -> DeltaResult<bool> {
Expand All @@ -138,6 +184,24 @@ impl AddRemoveDedupVisitor<'_> {
None => None,
};

// Apply partition pruning (to adds only) before deduplication, so that we don't waste memory
// tracking pruned files. Removes don't get pruned and we'll still have to track them.
//
// WARNING: It's not safe to partition-prune removes (just like it's not safe to data skip
// removes), because they are needed to suppress earlier incompatible adds we might
// encounter if the table's schema was replaced after the most recent checkpoint.
let partition_values = match &self.transform {
Some(transform) if is_add => {
let partition_values = getters[1].get(i, "add.partitionValues")?;
let partition_values = self.parse_partition_values(transform, &partition_values)?;
if self.is_file_partition_pruned(&partition_values) {
return Ok(false);
}
partition_values
}
_ => Default::default(),
};

// Check both adds and removes (skipping already-seen), but only transform and return adds
let file_key = FileActionKey::new(path, dv_unique_id);
if self.check_and_record_seen(file_key) || !is_add {
Expand All @@ -146,7 +210,7 @@ impl AddRemoveDedupVisitor<'_> {
let transform = self
.transform
.as_ref()
.map(|transform| self.get_transform_expr(i, transform, getters))
.map(|transform| self.get_transform_expr(transform, partition_values))
.transpose()?;
if transform.is_some() {
// fill in any needed `None`s for previous rows
Expand Down Expand Up @@ -250,7 +314,8 @@ impl LogReplayScanner {
/// Create a new [`LogReplayScanner`] instance
fn new(engine: &dyn Engine, physical_predicate: Option<(ExpressionRef, SchemaRef)>) -> Self {
Self {
filter: DataSkippingFilter::new(engine, physical_predicate),
partition_filter: physical_predicate.as_ref().map(|(e, _)| e.clone()),
data_skipping_filter: DataSkippingFilter::new(engine, physical_predicate),
seen: Default::default(),
}
}
Expand All @@ -265,7 +330,7 @@ impl LogReplayScanner {
) -> DeltaResult<ScanData> {
// Apply data skipping to get back a selection vector for actions that passed skipping. We
// will update the vector below as log replay identifies duplicates that should be ignored.
let selection_vector = match &self.filter {
let selection_vector = match &self.data_skipping_filter {
Some(filter) => filter.apply(actions)?,
None => vec![true; actions.len()],
};
Expand All @@ -276,6 +341,7 @@ impl LogReplayScanner {
selection_vector,
logical_schema,
transform,
partition_filter: self.partition_filter.clone(),
row_transform_exprs: Vec::new(),
is_log_batch,
};
Expand Down
2 changes: 1 addition & 1 deletion kernel/src/scan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ impl Scan {
// for other transforms as we support them)
let static_transform = (self.have_partition_cols
|| self.snapshot.column_mapping_mode() != ColumnMappingMode::None)
.then_some(Arc::new(Scan::get_static_transform(&self.all_fields)));
.then(|| Arc::new(Scan::get_static_transform(&self.all_fields)));
let physical_predicate = match self.physical_predicate.clone() {
PhysicalPredicate::StaticSkipAll => return Ok(None.into_iter().flatten()),
PhysicalPredicate::Some(predicate, schema) => Some((predicate, schema)),
Expand Down
135 changes: 133 additions & 2 deletions kernel/tests/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,26 @@ fn table_for_numbers(nums: Vec<u32>) -> Vec<String> {
res
}

// Render the expected output of the basic_partitioned table restricted to the rows
// whose partition letter appears in `letters`. Header/footer match the pretty-printed
// arrow table format used by read_table_data.
fn table_for_letters(letters: &[char]) -> Vec<String> {
    let border = "+--------+--------+";
    let header = "| letter | number |";
    let mut lines = vec![border.to_string(), header.to_string(), border.to_string()];
    // Full table contents of basic_partitioned (excluding the null-letter row).
    let data = [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'a'), (5, 'e')];
    lines.extend(
        data.iter()
            .filter(|(_, letter)| letters.contains(letter))
            .map(|(num, letter)| format!("| {letter} | {num} |")),
    );
    lines.push(border.to_string());
    lines
}

#[test]
fn predicate_on_number() -> Result<(), Box<dyn std::error::Error>> {
let cases = vec![
Expand Down Expand Up @@ -614,6 +634,118 @@ fn predicate_on_number() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}

#[test]
fn predicate_on_letter() -> Result<(), Box<dyn std::error::Error>> {
    // Exercise partition-column pruning wiring on a single partition column. The
    // expression machinery itself is already well-tested elsewhere.
    let null_row_table: Vec<String> = [
        "+--------+--------+",
        "| letter | number |",
        "+--------+--------+",
        "| | 6 |",
        "+--------+--------+",
    ]
    .iter()
    .map(|s| s.to_string())
    .collect();

    let test_cases = [
        // IS NULL selects only the file whose partition value is missing.
        (column_expr!("letter").is_null(), null_row_table),
        (
            column_expr!("letter").is_not_null(),
            table_for_letters(&['a', 'b', 'c', 'e']),
        ),
        // Comparison operators against the partition value.
        (
            column_expr!("letter").lt("c"),
            table_for_letters(&['a', 'b']),
        ),
        (
            column_expr!("letter").le("c"),
            table_for_letters(&['a', 'b', 'c']),
        ),
        (column_expr!("letter").gt("c"), table_for_letters(&['e'])),
        (
            column_expr!("letter").ge("c"),
            table_for_letters(&['c', 'e']),
        ),
        (column_expr!("letter").eq("c"), table_for_letters(&['c'])),
        (
            column_expr!("letter").ne("c"),
            table_for_letters(&['a', 'b', 'e']),
        ),
    ];

    for (predicate, expected) in test_cases {
        read_table_data(
            "./tests/data/basic_partitioned",
            Some(&["letter", "number"]),
            Some(predicate),
            expected,
        )?;
    }
    Ok(())
}

#[test]
fn predicate_on_letter_and_number() -> Result<(), Box<dyn std::error::Error>> {
    // Partition skipping and file (data) skipping are currently implemented
    // separately. Combined under AND, each is evaluated independently; combined
    // under OR, neither can be applied and no files are skipped.
    let full_table: Vec<String> = [
        "+--------+--------+",
        "| letter | number |",
        "+--------+--------+",
        "| | 6 |",
        "| a | 1 |",
        "| a | 4 |",
        "| b | 2 |",
        "| c | 3 |",
        "| e | 5 |",
        "+--------+--------+",
    ]
    .iter()
    .map(|s| s.to_string())
    .collect();

    let test_cases = [
        (
            Expression::or(
                // No pruning power
                column_expr!("letter").gt("a"),
                column_expr!("number").gt(3i64),
            ),
            full_table,
        ),
        (
            Expression::and(
                column_expr!("letter").gt("a"), // numbers 2, 3, 5
                column_expr!("number").gt(3i64), // letters a, e
            ),
            table_for_letters(&['e']),
        ),
        (
            Expression::and(
                column_expr!("letter").gt("a"), // numbers 2, 3, 5
                Expression::or(
                    // No pruning power
                    column_expr!("letter").eq("c"),
                    column_expr!("number").eq(3i64),
                ),
            ),
            table_for_letters(&['b', 'c', 'e']),
        ),
    ];

    for (predicate, expected) in test_cases {
        read_table_data(
            "./tests/data/basic_partitioned",
            Some(&["letter", "number"]),
            Some(predicate),
            expected,
        )?;
    }
    Ok(())
}

#[test]
fn predicate_on_number_not() -> Result<(), Box<dyn std::error::Error>> {
let cases = vec![
Expand Down Expand Up @@ -960,8 +1092,7 @@ async fn predicate_on_non_nullable_partition_column() -> Result<(), Box<dyn std:
assert_eq!(&batch, &result_batch);
files_scanned += 1;
}
// Partition pruning is not yet implemented, so we still read the data for both partitions
assert_eq!(2, files_scanned);
assert_eq!(1, files_scanned);
Ok(())
}

Expand Down

0 comments on commit c069278

Please sign in to comment.