Skip to content

Commit

Permalink
Define an ExpressionTransform trait (#530)
Browse files Browse the repository at this point in the history
## What changes are proposed in this pull request?

Similar to the existing `SchemaTransform` trait, an
`ExpressionTransform` trait can make it a lot easier to recursively
manipulate expressions. Define the trait and introduce an expression
depth checker to test it.

### This PR affects the following public APIs

Because enum tuple variants cannot be passed as function arguments, we
factor out `UnaryExpression`, `BinaryExpression` and
`VariadicExpression` structs which the corresponding expression variants
can then use. This requires updating match arms throughout the code
base.

## How was this change tested?

New unit test.
  • Loading branch information
scovich authored Nov 27, 2024
1 parent ac8dcdc commit 953ceed
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 59 deletions.
9 changes: 5 additions & 4 deletions ffi/src/expressions/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use std::ffi::c_void;

use crate::{handle::Handle, kernel_string_slice, KernelStringSlice};
use delta_kernel::expressions::{
ArrayData, BinaryOperator, Expression, Scalar, StructData, UnaryOperator, VariadicOperator,
ArrayData, BinaryExpression, BinaryOperator, Expression, Scalar, StructData, UnaryExpression,
UnaryOperator, VariadicExpression, VariadicOperator,
};

/// Free the memory the passed SharedExpression
Expand Down Expand Up @@ -330,7 +331,7 @@ pub unsafe extern "C" fn visit_expression(
Expression::Struct(exprs) => {
visit_expression_struct_expr(visitor, exprs, sibling_list_id)
}
Expression::BinaryOperation { op, left, right } => {
Expression::Binary(BinaryExpression { op, left, right }) => {
let child_list_id = call!(visitor, make_field_list, 2);
visit_expression_impl(visitor, left, child_list_id);
visit_expression_impl(visitor, right, child_list_id);
Expand All @@ -351,7 +352,7 @@ pub unsafe extern "C" fn visit_expression(
};
op(visitor.data, sibling_list_id, child_list_id);
}
Expression::UnaryOperation { op, expr } => {
Expression::Unary(UnaryExpression { op, expr }) => {
let child_id_list = call!(visitor, make_field_list, 1);
visit_expression_impl(visitor, expr, child_id_list);
let op = match op {
Expand All @@ -360,7 +361,7 @@ pub unsafe extern "C" fn visit_expression(
};
op(visitor.data, sibling_list_id, child_id_list);
}
Expression::VariadicOperation { op, exprs } => {
Expression::Variadic(VariadicExpression { op, exprs }) => {
visit_expression_variadic(visitor, op, exprs, sibling_list_id)
}
}
Expand Down
21 changes: 12 additions & 9 deletions kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use crate::engine::arrow_data::ArrowEngineData;
use crate::engine::arrow_utils::prim_array_cmp;
use crate::engine::ensure_data_types::ensure_data_types;
use crate::error::{DeltaResult, Error};
use crate::expressions::{BinaryOperator, Expression, Scalar, UnaryOperator, VariadicOperator};
use crate::expressions::{
BinaryExpression, BinaryOperator, Expression, Scalar, UnaryExpression, UnaryOperator,
VariadicExpression, VariadicOperator,
};
use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, Schema, SchemaRef, StructField};
use crate::{EngineData, ExpressionEvaluator, ExpressionHandler};

Expand Down Expand Up @@ -218,19 +221,19 @@ fn evaluate_expression(
(Struct(_), _) => Err(Error::generic(
"Data type is required to evaluate struct expressions",
)),
(UnaryOperation { op, expr }, _) => {
(Unary(UnaryExpression { op, expr }), _) => {
let arr = evaluate_expression(expr.as_ref(), batch, None)?;
Ok(match op {
UnaryOperator::Not => Arc::new(not(downcast_to_bool(&arr)?)?),
UnaryOperator::IsNull => Arc::new(is_null(&arr)?),
})
}
(
BinaryOperation {
Binary(BinaryExpression {
op: In,
left,
right,
},
}),
_,
) => match (left.as_ref(), right.as_ref()) {
(Literal(_), Column(_)) => {
Expand Down Expand Up @@ -287,11 +290,11 @@ fn evaluate_expression(
))),
},
(
BinaryOperation {
Binary(BinaryExpression {
op: NotIn,
left,
right,
},
}),
_,
) => {
let reverse_op = Expression::binary(In, *left.clone(), *right.clone());
Expand All @@ -300,7 +303,7 @@ fn evaluate_expression(
.map(wrap_comparison_result)
.map_err(Error::generic_err)
}
(BinaryOperation { op, left, right }, _) => {
(Binary(BinaryExpression { op, left, right }), _) => {
let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
let right_arr = evaluate_expression(right.as_ref(), batch, None)?;

Expand All @@ -323,7 +326,7 @@ fn evaluate_expression(

eval(&left_arr, &right_arr).map_err(Error::generic_err)
}
(VariadicOperation { op, exprs }, None | Some(&DataType::BOOLEAN)) => {
(Variadic(VariadicExpression { op, exprs }), None | Some(&DataType::BOOLEAN)) => {
type Operation = fn(&BooleanArray, &BooleanArray) -> Result<BooleanArray, ArrowError>;
let (reducer, default): (Operation, _) = match op {
VariadicOperator::And => (and_kleene, true),
Expand All @@ -340,7 +343,7 @@ fn evaluate_expression(
evaluate_expression(&Expression::literal(default), batch, result_type)
})
}
(VariadicOperation { .. }, _) => {
(Variadic(_), _) => {
// NOTE: Update this error message if we add support for variadic operations on other types
Err(Error::Generic(format!(
"Variadic {expression:?} is expected to return boolean results, got {result_type:?}"
Expand Down
8 changes: 4 additions & 4 deletions kernel/src/engine/parquet_row_group_skipping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::engine::parquet_stats_skipping::{
ParquetStatsProvider, ParquetStatsSkippingFilter as _,
};
use crate::expressions::{ColumnName, Expression, Scalar};
use crate::expressions::{ColumnName, Expression, Scalar, UnaryExpression, BinaryExpression, VariadicExpression};
use crate::schema::{DataType, PrimitiveType};
use chrono::{DateTime, Days};
use parquet::arrow::arrow_reader::ArrowReaderBuilder;
Expand Down Expand Up @@ -231,9 +231,9 @@ pub(crate) fn compute_field_indices(
Literal(_) => {}
Column(name) => cols.extend([name.clone()]), // returns `()`, unlike `insert`
Struct(fields) => fields.iter().for_each(recurse),
UnaryOperation { expr, .. } => recurse(expr),
BinaryOperation { left, right, .. } => [left, right].iter().for_each(|e| recurse(e)),
VariadicOperation { exprs, .. } => exprs.iter().for_each(recurse),
Unary(UnaryExpression { expr, .. }) => recurse(expr),
Binary(BinaryExpression { left, right, .. }) => [left, right].iter().for_each(|e| recurse(e)),
Variadic(VariadicExpression { exprs, .. }) => exprs.iter().for_each(recurse),
}
}

Expand Down
8 changes: 4 additions & 4 deletions kernel/src/engine/parquet_stats_skipping.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! An implementation of data skipping that leverages parquet stats from the file footer.
use crate::expressions::{
BinaryOperator, ColumnName, Expression as Expr, Scalar, UnaryOperator, VariadicOperator,
BinaryOperator, ColumnName, Expression as Expr, Scalar, UnaryOperator, VariadicOperator, VariadicExpression, BinaryExpression,
};
use crate::predicates::{
DataSkippingPredicateEvaluator, PredicateEvaluator, PredicateEvaluatorDefaults,
Expand Down Expand Up @@ -155,9 +155,9 @@ pub(crate) trait ParquetStatsSkippingFilter {

impl<T: DataSkippingPredicateEvaluator<Output = bool>> ParquetStatsSkippingFilter for T {
fn eval_sql_where(&self, filter: &Expr) -> Option<bool> {
use Expr::{BinaryOperation, VariadicOperation};
use Expr::{Binary, Variadic};
match filter {
VariadicOperation { op: VariadicOperator::And, exprs } => {
Variadic(VariadicExpression { op: VariadicOperator::And, exprs }) => {
let exprs: Vec<_> = exprs
.iter()
.map(|expr| self.eval_sql_where(expr))
Expand All @@ -168,7 +168,7 @@ impl<T: DataSkippingPredicateEvaluator<Output = bool>> ParquetStatsSkippingFilte
.collect();
self.eval_variadic(VariadicOperator::And, &exprs, false)
}
BinaryOperation { op, left, right } => self.eval_binary_nullsafe(*op, left, right),
Binary(BinaryExpression { op, left, right }) => self.eval_binary_nullsafe(*op, left, right),
_ => self.eval_expr(filter, false),
}
}
Expand Down
Loading

0 comments on commit 953ceed

Please sign in to comment.