diff --git a/kernel/src/engine/arrow_data.rs b/kernel/src/engine/arrow_data.rs index 7c2dd5f40..50a627e5c 100644 --- a/kernel/src/engine/arrow_data.rs +++ b/kernel/src/engine/arrow_data.rs @@ -4,9 +4,11 @@ use crate::{DeltaResult, Error}; use arrow_array::cast::AsArray; use arrow_array::types::{Int32Type, Int64Type}; -use arrow_array::{Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, RecordBatch, StructArray}; -use arrow_schema::{FieldRef, DataType as ArrowDataType}; -use tracing::{debug}; +use arrow_array::{ + Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, RecordBatch, StructArray, +}; +use arrow_schema::{DataType as ArrowDataType, FieldRef}; +use tracing::debug; use std::collections::{HashMap, HashSet}; @@ -138,14 +140,20 @@ impl EngineData for ArrowEngineData { self.data.num_rows() } - fn visit_rows(&self, leaf_columns: &[ColumnName], visitor: &mut dyn RowVisitor) -> DeltaResult<()> { + fn visit_rows( + &self, + leaf_columns: &[ColumnName], + visitor: &mut dyn RowVisitor, + ) -> DeltaResult<()> { // Make sure the caller passed the correct number of column names let leaf_types = visitor.selected_column_names_and_types().1; if leaf_types.len() != leaf_columns.len() { return Err(Error::MissingColumn(format!( "Visitor expected {} column names, but caller passed {}", - leaf_types.len(), leaf_columns.len() - )).with_backtrace()); + leaf_types.len(), + leaf_columns.len() + )) + .with_backtrace()); } // Collect the names of all leaf columns we want to extract, along with their parents, to @@ -154,7 +162,7 @@ impl EngineData for ArrowEngineData { let mut mask = HashSet::new(); for column in leaf_columns { for i in 0..column.len() { - mask.insert(&column[..i+1]); + mask.insert(&column[..i + 1]); } } debug!("Column mask for selected columns {leaf_columns:?} is {mask:#?}"); @@ -162,12 +170,11 @@ impl EngineData for ArrowEngineData { let mut getters = vec![]; Self::extract_columns(&mut vec![], &mut getters, leaf_types, &mask, &self.data)?; if getters.len() != leaf_columns.len() { - return Err(Error::MissingColumn( - format!( - "Visitor expected {} leaf columns, but only {} were found in the data", - leaf_columns.len(), getters.len() - ) - )); + return Err(Error::MissingColumn(format!( + "Visitor expected {} leaf columns, but only {} were found in the data", + leaf_columns.len(), + getters.len() + ))); } visitor.visit(self.len(), &getters) } @@ -185,14 +192,11 @@ impl ArrowEngineData { path.push(field.name().to_string()); if column_mask.contains(&path[..]) { if let Some(struct_array) = column.as_struct_opt() { - debug!("Recurse into a struct array for {}", ColumnName::new(path.iter())); - Self::extract_columns( - path, - getters, - leaf_types, - column_mask, - struct_array, - )?; + debug!( + "Recurse into a struct array for {}", + ColumnName::new(path.iter()) + ); + Self::extract_columns(path, getters, leaf_types, column_mask, struct_array)?; } else if column.data_type() == &ArrowDataType::Null { debug!("Pushing a null array for {}", ColumnName::new(path.iter())); getters.push(&()); @@ -215,16 +219,20 @@ impl ArrowEngineData { col: &'a dyn Array, ) -> DeltaResult<&'a dyn GetData<'a>> { use ArrowDataType::Utf8; - let col_as_list = || if let Some(array) = col.as_list_opt::() { - (array.value_type() == Utf8).then_some(array as _) - } else if let Some(array) = col.as_list_opt::() { - (array.value_type() == Utf8).then_some(array as _) - } else { - None + let col_as_list = || { + if let Some(array) = col.as_list_opt::() { + (array.value_type() == Utf8).then_some(array as _) + } else if let Some(array) = col.as_list_opt::() { + (array.value_type() == Utf8).then_some(array as _) + } else { + None + } + }; + let col_as_map = || { + col.as_map_opt().and_then(|array| { + (array.key_type() == &Utf8 && array.value_type() == &Utf8).then_some(array as _) + }) }; - let col_as_map = || col.as_map_opt().and_then(|array| { - (array.key_type() == &Utf8 && array.value_type() == &Utf8).then_some(array as _) - }); let result: Result<&'a dyn GetData<'a>, _> = match data_type { &DataType::BOOLEAN => { debug!("Pushing boolean array for {}", ColumnName::new(path)); @@ -236,11 +244,15 @@ impl ArrowEngineData { } &DataType::INTEGER => { debug!("Pushing int32 array for {}", ColumnName::new(path)); - col.as_primitive_opt::().map(|a| a as _).ok_or("int") + col.as_primitive_opt::() + .map(|a| a as _) + .ok_or("int") } &DataType::LONG => { debug!("Pushing int64 array for {}", ColumnName::new(path)); - col.as_primitive_opt::().map(|a| a as _).ok_or("long") + col.as_primitive_opt::() + .map(|a| a as _) + .ok_or("long") } DataType::Array(_) => { debug!("Pushing list for {}", ColumnName::new(path)); @@ -252,14 +264,17 @@ impl ArrowEngineData { } data_type => { return Err(Error::UnexpectedColumnType(format!( - "On {}: Unsupported type {data_type}", ColumnName::new(path) + "On {}: Unsupported type {data_type}", + ColumnName::new(path) ))); } }; result.map_err(|type_name| { Error::UnexpectedColumnType(format!( "Type mismatch on {}: expected {}, got {}", - ColumnName::new(path), type_name, col.data_type() + ColumnName::new(path), + type_name, + col.data_type() )) }) } diff --git a/kernel/src/engine/ensure_data_types.rs b/kernel/src/engine/ensure_data_types.rs index 88ff01626..b6f186671 100644 --- a/kernel/src/engine/ensure_data_types.rs +++ b/kernel/src/engine/ensure_data_types.rs @@ -1,6 +1,9 @@ //! Helpers to ensure that kernel data types match arrow data types -use std::{collections::{HashMap, HashSet}, ops::Deref}; +use std::{ + collections::{HashMap, HashSet}, + ops::Deref, +}; use arrow_schema::{DataType as ArrowDataType, Field as ArrowField}; use itertools::Itertools; @@ -29,7 +32,9 @@ pub(crate) fn ensure_data_types( arrow_type: &ArrowDataType, check_nullability_and_metadata: bool, ) -> DeltaResult { - let check = EnsureDataTypes { check_nullability_and_metadata }; + let check = EnsureDataTypes { + check_nullability_and_metadata, + }; check.ensure_data_types(kernel_type, arrow_type) } @@ -61,41 +66,32 @@ impl EnsureDataTypes { } // strings, bools, and binary aren't primitive in arrow (&DataType::BOOLEAN, ArrowDataType::Boolean) - | (&DataType::STRING, ArrowDataType::Utf8) - | (&DataType::BINARY, ArrowDataType::Binary) => { - Ok(DataTypeCompat::Identical) - } + | (&DataType::STRING, ArrowDataType::Utf8) + | (&DataType::BINARY, ArrowDataType::Binary) => Ok(DataTypeCompat::Identical), (DataType::Array(inner_type), ArrowDataType::List(arrow_list_field)) => { self.ensure_nullability( "List", inner_type.contains_null, arrow_list_field.is_nullable(), )?; - self.ensure_data_types( - &inner_type.element_type, - arrow_list_field.data_type(), - ) + self.ensure_data_types(&inner_type.element_type, arrow_list_field.data_type()) } (DataType::Map(kernel_map_type), ArrowDataType::Map(arrow_map_type, _)) => { let ArrowDataType::Struct(fields) = arrow_map_type.data_type() else { return Err(make_arrow_error("Arrow map type wasn't a struct.")); }; let [key_type, value_type] = fields.deref() else { - return Err(make_arrow_error("Arrow map type didn't have expected key/value fields")); + return Err(make_arrow_error( + "Arrow map type didn't have expected key/value fields", + )); }; - self.ensure_data_types( - &kernel_map_type.key_type, - key_type.data_type(), - )?; + self.ensure_data_types(&kernel_map_type.key_type, key_type.data_type())?; self.ensure_nullability( "Map", kernel_map_type.value_contains_null, value_type.is_nullable(), )?; - self.ensure_data_types( - &kernel_map_type.value_type, - value_type.data_type(), - )?; + self.ensure_data_types(&kernel_map_type.value_type, value_type.data_type())?; Ok(DataTypeCompat::Nested) } (DataType::Struct(kernel_fields), ArrowDataType::Struct(arrow_fields)) => { @@ -109,10 +105,7 @@ impl EnsureDataTypes { // ensure that for the fields that we found, the types match for (kernel_field, arrow_field) in mapped_fields.zip(arrow_fields) { self.ensure_nullability_and_metadata(kernel_field, arrow_field)?; - self.ensure_data_types( - &kernel_field.data_type, - arrow_field.data_type(), - )?; + self.ensure_data_types(&kernel_field.data_type, arrow_field.data_type())?; found_fields += 1; } @@ -146,11 +139,12 @@ impl EnsureDataTypes { kernel_field_is_nullable: bool, arrow_field_is_nullable: bool, ) -> DeltaResult<()> { - if self.check_nullability_and_metadata && kernel_field_is_nullable != arrow_field_is_nullable { + if self.check_nullability_and_metadata + && kernel_field_is_nullable != arrow_field_is_nullable + { Err(Error::Generic(format!( "{desc} has nullablily {} in kernel and {} in arrow", - kernel_field_is_nullable, - arrow_field_is_nullable, + kernel_field_is_nullable, arrow_field_is_nullable, ))) } else { Ok(()) @@ -160,10 +154,16 @@ impl EnsureDataTypes { fn ensure_nullability_and_metadata( &self, kernel_field: &StructField, - arrow_field: &ArrowField + arrow_field: &ArrowField, ) -> DeltaResult<()> { - self.ensure_nullability(&kernel_field.name, kernel_field.nullable, arrow_field.is_nullable())?; - if self.check_nullability_and_metadata && !metadata_eq(&kernel_field.metadata, arrow_field.metadata()) { + self.ensure_nullability( + &kernel_field.name, + kernel_field.nullable, + arrow_field.is_nullable(), + )?; + if self.check_nullability_and_metadata + && !metadata_eq(&kernel_field.metadata, arrow_field.metadata()) + { Err(Error::Generic(format!( "Field {} has metadata {:?} in kernel and {:?} in arrow", kernel_field.name, diff --git a/kernel/src/engine/mod.rs b/kernel/src/engine/mod.rs index 284844cd1..8ea07384a 100644 --- a/kernel/src/engine/mod.rs +++ b/kernel/src/engine/mod.rs @@ -10,6 +10,8 @@ pub(crate) mod arrow_conversion; any(feature = "default-engine-base", feature = "sync-engine") ))] pub mod arrow_expression; +#[cfg(feature = "arrow-expression")] +pub(crate) mod arrow_utils; #[cfg(feature = "default-engine-base")] pub mod default; @@ -17,19 +19,11 @@ pub mod default; #[cfg(feature = "sync-engine")] pub mod sync; -macro_rules! declare_modules { - ( $(($vis:vis, $module:ident)),*) => { - $( - $vis mod $module; - )* - }; -} - #[cfg(any(feature = "default-engine-base", feature = "sync-engine"))] -declare_modules!( - (pub, arrow_data), - (pub, parquet_row_group_skipping), - (pub(crate), arrow_get_data), - (pub(crate), arrow_utils), - (pub(crate), ensure_data_types) -); +pub mod arrow_data; +#[cfg(any(feature = "default-engine-base", feature = "sync-engine"))] +pub(crate) mod arrow_get_data; +#[cfg(any(feature = "default-engine-base", feature = "sync-engine"))] +pub(crate) mod ensure_data_types; +#[cfg(any(feature = "default-engine-base", feature = "sync-engine"))] +pub mod parquet_row_group_skipping; diff --git a/kernel/src/engine/parquet_row_group_skipping.rs b/kernel/src/engine/parquet_row_group_skipping.rs index 0adae6c4b..79c87d923 100644 --- a/kernel/src/engine/parquet_row_group_skipping.rs +++ b/kernel/src/engine/parquet_row_group_skipping.rs @@ -1,5 +1,7 @@ //! An implementation of parquet row group skipping using data skipping predicates over footer stats. -use crate::expressions::{ColumnName, Expression, Scalar, UnaryExpression, BinaryExpression, VariadicExpression}; +use crate::expressions::{ + BinaryExpression, ColumnName, Expression, Scalar, UnaryExpression, VariadicExpression, +}; use crate::predicates::parquet_stats_skipping::ParquetStatsProvider; use crate::schema::{DataType, PrimitiveType}; use chrono::{DateTime, Days}; @@ -231,7 +233,9 @@ pub(crate) fn compute_field_indices( Column(name) => cols.extend([name.clone()]), // returns `()`, unlike `insert` Struct(fields) => fields.iter().for_each(recurse), Unary(UnaryExpression { expr, .. }) => recurse(expr), - Binary(BinaryExpression { left, right, .. }) => [left, right].iter().for_each(|e| recurse(e)), + Binary(BinaryExpression { left, right, .. }) => { + [left, right].iter().for_each(|e| recurse(e)) + } Variadic(VariadicExpression { exprs, .. }) => exprs.iter().for_each(recurse), } } diff --git a/kernel/src/engine/parquet_row_group_skipping/tests.rs b/kernel/src/engine/parquet_row_group_skipping/tests.rs index 39a9c2ab5..37a3bb1b0 100644 --- a/kernel/src/engine/parquet_row_group_skipping/tests.rs +++ b/kernel/src/engine/parquet_row_group_skipping/tests.rs @@ -1,6 +1,6 @@ use super::*; +use crate::expressions::{column_expr, column_name}; use crate::predicates::DataSkippingPredicateEvaluator as _; -use crate::expressions::{column_name, column_expr}; use crate::Expression; use parquet::arrow::arrow_reader::ArrowReaderMetadata; use std::fs::File; @@ -63,7 +63,10 @@ fn test_get_stat_values() { // Only the BOOL column has any nulls assert_eq!(filter.get_nullcount_stat(&column_name!("bool")), Some(3)); - assert_eq!(filter.get_nullcount_stat(&column_name!("varlen.utf8")), Some(0)); + assert_eq!( + filter.get_nullcount_stat(&column_name!("varlen.utf8")), + Some(0) + ); assert_eq!( filter.get_min_stat(&column_name!("varlen.utf8"), &DataType::STRING), @@ -106,27 +109,18 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat( - &column_name!("numeric.floats.float64"), - &DataType::DOUBLE - ), + filter.get_min_stat(&column_name!("numeric.floats.float64"), &DataType::DOUBLE), Some(1147f64.into()) ); // type widening! assert_eq!( - filter.get_min_stat( - &column_name!("numeric.floats.float32"), - &DataType::DOUBLE - ), + filter.get_min_stat(&column_name!("numeric.floats.float32"), &DataType::DOUBLE), Some(139f64.into()) ); assert_eq!( - filter.get_min_stat( - &column_name!("numeric.floats.float32"), - &DataType::FLOAT - ), + filter.get_min_stat(&column_name!("numeric.floats.float32"), &DataType::FLOAT), Some(139f32.into()) ); @@ -216,10 +210,7 @@ fn test_get_stat_values() { // CHEAT: Interpret the timestamp_ntz column as a normal timestamp assert_eq!( - filter.get_min_stat( - &column_name!("chrono.timestamp_ntz"), - &DataType::TIMESTAMP - ), + filter.get_min_stat(&column_name!("chrono.timestamp_ntz"), &DataType::TIMESTAMP), Some( PrimitiveType::Timestamp .parse_scalar("1970-01-02 00:00:00.000000") @@ -241,10 +232,7 @@ fn test_get_stat_values() { // type widening! assert_eq!( - filter.get_min_stat( - &column_name!("chrono.date32"), - &DataType::TIMESTAMP_NTZ - ), + filter.get_min_stat(&column_name!("chrono.date32"), &DataType::TIMESTAMP_NTZ), Some( PrimitiveType::TimestampNtz .parse_scalar("1971-01-01 00:00:00.000000") @@ -293,27 +281,18 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat( - &column_name!("numeric.floats.float64"), - &DataType::DOUBLE - ), + filter.get_max_stat(&column_name!("numeric.floats.float64"), &DataType::DOUBLE), Some(1125899906842747f64.into()) ); // type widening! assert_eq!( - filter.get_max_stat( - &column_name!("numeric.floats.float32"), - &DataType::DOUBLE - ), + filter.get_max_stat(&column_name!("numeric.floats.float32"), &DataType::DOUBLE), Some(1048699f64.into()) ); assert_eq!( - filter.get_max_stat( - &column_name!("numeric.floats.float32"), - &DataType::FLOAT - ), + filter.get_max_stat(&column_name!("numeric.floats.float32"), &DataType::FLOAT), Some(1048699f32.into()) ); @@ -403,10 +382,7 @@ fn test_get_stat_values() { // CHEAT: Interpret the timestamp_ntz column as a normal timestamp assert_eq!( - filter.get_max_stat( - &column_name!("chrono.timestamp_ntz"), - &DataType::TIMESTAMP - ), + filter.get_max_stat(&column_name!("chrono.timestamp_ntz"), &DataType::TIMESTAMP), Some( PrimitiveType::Timestamp .parse_scalar("1970-01-02 00:04:00.000000") @@ -428,10 +404,7 @@ fn test_get_stat_values() { // type widening! assert_eq!( - filter.get_max_stat( - &column_name!("chrono.date32"), - &DataType::TIMESTAMP_NTZ - ), + filter.get_max_stat(&column_name!("chrono.date32"), &DataType::TIMESTAMP_NTZ), Some( PrimitiveType::TimestampNtz .parse_scalar("1971-01-05 00:00:00.000000") diff --git a/kernel/src/predicates/tests.rs b/kernel/src/predicates/tests.rs index 4185b60ec..fdeda8305 100644 --- a/kernel/src/predicates/tests.rs +++ b/kernel/src/predicates/tests.rs @@ -609,7 +609,7 @@ fn test_sql_where() { expect_eq!(null_filter.eval_sql_where(expr), Some(true), "{expr}"); expect_eq!(empty_filter.eval_sql_where(expr), Some(true), "{expr}"); - // Constrast normal vs SQL WHERE semantics - comparison + // Contrast normal vs SQL WHERE semantics - comparison let expr = &Expr::lt(col.clone(), VAL); expect_eq!(null_filter.eval(expr), None, "{expr}"); expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); @@ -631,7 +631,7 @@ fn test_sql_where() { expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); expect_eq!(empty_filter.eval_sql_where(expr), None, "{expr}"); - // Constrast normal vs SQL WHERE semantics - comparison inside AND + // Contrast normal vs SQL WHERE semantics - comparison inside AND let expr = &Expr::and(TRUE, Expr::lt(col.clone(), VAL)); expect_eq!(null_filter.eval(expr), None, "{expr}"); expect_eq!(null_filter.eval_sql_where(expr), Some(false), "{expr}"); diff --git a/kernel/src/schema/compare.rs b/kernel/src/schema/compare.rs index e465f1618..eb65540cf 100644 --- a/kernel/src/schema/compare.rs +++ b/kernel/src/schema/compare.rs @@ -56,7 +56,7 @@ pub(crate) type SchemaComparisonResult = Result<(), Error>; /// Represents a schema compatibility check for the type. If `self` can be read as `read_type`, /// this function returns `Ok(())`. Otherwise, this function returns `Err`. /// -/// TODO (Oussama): Remove the `allow(unsued)` once this is used in CDF. +/// TODO (Oussama): Remove the `allow(unused)` once this is used in CDF. #[allow(unused)] pub(crate) trait SchemaComparison { fn can_read_as(&self, read_type: &Self) -> SchemaComparisonResult;