Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cargo fmt does not handle mods defined in macros #676

Merged
merged 3 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 49 additions & 34 deletions kernel/src/engine/arrow_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ use crate::{DeltaResult, Error};

use arrow_array::cast::AsArray;
use arrow_array::types::{Int32Type, Int64Type};
use arrow_array::{Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, RecordBatch, StructArray};
use arrow_schema::{FieldRef, DataType as ArrowDataType};
use tracing::{debug};
use arrow_array::{
Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, RecordBatch, StructArray,
};
use arrow_schema::{DataType as ArrowDataType, FieldRef};
use tracing::debug;

use std::collections::{HashMap, HashSet};

Expand Down Expand Up @@ -138,14 +140,20 @@ impl EngineData for ArrowEngineData {
self.data.num_rows()
}

fn visit_rows(&self, leaf_columns: &[ColumnName], visitor: &mut dyn RowVisitor) -> DeltaResult<()> {
fn visit_rows(
&self,
leaf_columns: &[ColumnName],
visitor: &mut dyn RowVisitor,
) -> DeltaResult<()> {
// Make sure the caller passed the correct number of column names
let leaf_types = visitor.selected_column_names_and_types().1;
if leaf_types.len() != leaf_columns.len() {
return Err(Error::MissingColumn(format!(
"Visitor expected {} column names, but caller passed {}",
leaf_types.len(), leaf_columns.len()
)).with_backtrace());
leaf_types.len(),
leaf_columns.len()
))
.with_backtrace());
}

// Collect the names of all leaf columns we want to extract, along with their parents, to
Expand All @@ -154,20 +162,19 @@ impl EngineData for ArrowEngineData {
let mut mask = HashSet::new();
for column in leaf_columns {
for i in 0..column.len() {
mask.insert(&column[..i+1]);
mask.insert(&column[..i + 1]);
}
}
debug!("Column mask for selected columns {leaf_columns:?} is {mask:#?}");

let mut getters = vec![];
Self::extract_columns(&mut vec![], &mut getters, leaf_types, &mask, &self.data)?;
if getters.len() != leaf_columns.len() {
return Err(Error::MissingColumn(
format!(
"Visitor expected {} leaf columns, but only {} were found in the data",
leaf_columns.len(), getters.len()
)
));
return Err(Error::MissingColumn(format!(
"Visitor expected {} leaf columns, but only {} were found in the data",
leaf_columns.len(),
getters.len()
)));
}
visitor.visit(self.len(), &getters)
}
Expand All @@ -185,14 +192,11 @@ impl ArrowEngineData {
path.push(field.name().to_string());
if column_mask.contains(&path[..]) {
if let Some(struct_array) = column.as_struct_opt() {
debug!("Recurse into a struct array for {}", ColumnName::new(path.iter()));
Self::extract_columns(
path,
getters,
leaf_types,
column_mask,
struct_array,
)?;
debug!(
"Recurse into a struct array for {}",
ColumnName::new(path.iter())
);
Self::extract_columns(path, getters, leaf_types, column_mask, struct_array)?;
} else if column.data_type() == &ArrowDataType::Null {
debug!("Pushing a null array for {}", ColumnName::new(path.iter()));
getters.push(&());
Expand All @@ -215,16 +219,20 @@ impl ArrowEngineData {
col: &'a dyn Array,
) -> DeltaResult<&'a dyn GetData<'a>> {
use ArrowDataType::Utf8;
let col_as_list = || if let Some(array) = col.as_list_opt::<i32>() {
(array.value_type() == Utf8).then_some(array as _)
} else if let Some(array) = col.as_list_opt::<i64>() {
(array.value_type() == Utf8).then_some(array as _)
} else {
None
let col_as_list = || {
if let Some(array) = col.as_list_opt::<i32>() {
(array.value_type() == Utf8).then_some(array as _)
} else if let Some(array) = col.as_list_opt::<i64>() {
(array.value_type() == Utf8).then_some(array as _)
} else {
None
}
};
let col_as_map = || {
col.as_map_opt().and_then(|array| {
(array.key_type() == &Utf8 && array.value_type() == &Utf8).then_some(array as _)
})
};
let col_as_map = || col.as_map_opt().and_then(|array| {
(array.key_type() == &Utf8 && array.value_type() == &Utf8).then_some(array as _)
});
let result: Result<&'a dyn GetData<'a>, _> = match data_type {
&DataType::BOOLEAN => {
debug!("Pushing boolean array for {}", ColumnName::new(path));
Expand All @@ -236,11 +244,15 @@ impl ArrowEngineData {
}
&DataType::INTEGER => {
debug!("Pushing int32 array for {}", ColumnName::new(path));
col.as_primitive_opt::<Int32Type>().map(|a| a as _).ok_or("int")
col.as_primitive_opt::<Int32Type>()
.map(|a| a as _)
.ok_or("int")
}
&DataType::LONG => {
debug!("Pushing int64 array for {}", ColumnName::new(path));
col.as_primitive_opt::<Int64Type>().map(|a| a as _).ok_or("long")
col.as_primitive_opt::<Int64Type>()
.map(|a| a as _)
.ok_or("long")
}
DataType::Array(_) => {
debug!("Pushing list for {}", ColumnName::new(path));
Expand All @@ -252,14 +264,17 @@ impl ArrowEngineData {
}
data_type => {
return Err(Error::UnexpectedColumnType(format!(
"On {}: Unsupported type {data_type}", ColumnName::new(path)
"On {}: Unsupported type {data_type}",
ColumnName::new(path)
)));
}
};
result.map_err(|type_name| {
Error::UnexpectedColumnType(format!(
"Type mismatch on {}: expected {}, got {}",
ColumnName::new(path), type_name, col.data_type()
ColumnName::new(path),
type_name,
col.data_type()
))
})
}
Expand Down
58 changes: 29 additions & 29 deletions kernel/src/engine/ensure_data_types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
//! Helpers to ensure that kernel data types match arrow data types

use std::{collections::{HashMap, HashSet}, ops::Deref};
use std::{
collections::{HashMap, HashSet},
ops::Deref,
};

use arrow_schema::{DataType as ArrowDataType, Field as ArrowField};
use itertools::Itertools;
Expand Down Expand Up @@ -29,7 +32,9 @@ pub(crate) fn ensure_data_types(
arrow_type: &ArrowDataType,
check_nullability_and_metadata: bool,
) -> DeltaResult<DataTypeCompat> {
let check = EnsureDataTypes { check_nullability_and_metadata };
let check = EnsureDataTypes {
check_nullability_and_metadata,
};
check.ensure_data_types(kernel_type, arrow_type)
}

Expand Down Expand Up @@ -61,41 +66,32 @@ impl EnsureDataTypes {
}
// strings, bools, and binary aren't primitive in arrow
(&DataType::BOOLEAN, ArrowDataType::Boolean)
| (&DataType::STRING, ArrowDataType::Utf8)
| (&DataType::BINARY, ArrowDataType::Binary) => {
Ok(DataTypeCompat::Identical)
}
| (&DataType::STRING, ArrowDataType::Utf8)
| (&DataType::BINARY, ArrowDataType::Binary) => Ok(DataTypeCompat::Identical),
(DataType::Array(inner_type), ArrowDataType::List(arrow_list_field)) => {
self.ensure_nullability(
"List",
inner_type.contains_null,
arrow_list_field.is_nullable(),
)?;
self.ensure_data_types(
&inner_type.element_type,
arrow_list_field.data_type(),
)
self.ensure_data_types(&inner_type.element_type, arrow_list_field.data_type())
}
(DataType::Map(kernel_map_type), ArrowDataType::Map(arrow_map_type, _)) => {
let ArrowDataType::Struct(fields) = arrow_map_type.data_type() else {
return Err(make_arrow_error("Arrow map type wasn't a struct."));
};
let [key_type, value_type] = fields.deref() else {
return Err(make_arrow_error("Arrow map type didn't have expected key/value fields"));
return Err(make_arrow_error(
"Arrow map type didn't have expected key/value fields",
));
};
self.ensure_data_types(
&kernel_map_type.key_type,
key_type.data_type(),
)?;
self.ensure_data_types(&kernel_map_type.key_type, key_type.data_type())?;
self.ensure_nullability(
"Map",
kernel_map_type.value_contains_null,
value_type.is_nullable(),
)?;
self.ensure_data_types(
&kernel_map_type.value_type,
value_type.data_type(),
)?;
self.ensure_data_types(&kernel_map_type.value_type, value_type.data_type())?;
Ok(DataTypeCompat::Nested)
}
(DataType::Struct(kernel_fields), ArrowDataType::Struct(arrow_fields)) => {
Expand All @@ -109,10 +105,7 @@ impl EnsureDataTypes {
// ensure that for the fields that we found, the types match
for (kernel_field, arrow_field) in mapped_fields.zip(arrow_fields) {
self.ensure_nullability_and_metadata(kernel_field, arrow_field)?;
self.ensure_data_types(
&kernel_field.data_type,
arrow_field.data_type(),
)?;
self.ensure_data_types(&kernel_field.data_type, arrow_field.data_type())?;
found_fields += 1;
}

Expand Down Expand Up @@ -146,11 +139,12 @@ impl EnsureDataTypes {
kernel_field_is_nullable: bool,
arrow_field_is_nullable: bool,
) -> DeltaResult<()> {
if self.check_nullability_and_metadata && kernel_field_is_nullable != arrow_field_is_nullable {
if self.check_nullability_and_metadata
&& kernel_field_is_nullable != arrow_field_is_nullable
{
Err(Error::Generic(format!(
"{desc} has nullablily {} in kernel and {} in arrow",
kernel_field_is_nullable,
arrow_field_is_nullable,
kernel_field_is_nullable, arrow_field_is_nullable,
)))
} else {
Ok(())
Expand All @@ -160,10 +154,16 @@ impl EnsureDataTypes {
fn ensure_nullability_and_metadata(
&self,
kernel_field: &StructField,
arrow_field: &ArrowField
arrow_field: &ArrowField,
) -> DeltaResult<()> {
self.ensure_nullability(&kernel_field.name, kernel_field.nullable, arrow_field.is_nullable())?;
if self.check_nullability_and_metadata && !metadata_eq(&kernel_field.metadata, arrow_field.metadata()) {
self.ensure_nullability(
&kernel_field.name,
kernel_field.nullable,
arrow_field.is_nullable(),
)?;
if self.check_nullability_and_metadata
&& !metadata_eq(&kernel_field.metadata, arrow_field.metadata())
{
Err(Error::Generic(format!(
"Field {} has metadata {:?} in kernel and {:?} in arrow",
kernel_field.name,
Expand Down
24 changes: 9 additions & 15 deletions kernel/src/engine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,20 @@ pub(crate) mod arrow_conversion;
any(feature = "default-engine-base", feature = "sync-engine")
))]
pub mod arrow_expression;
#[cfg(feature = "arrow-expression")]
pub(crate) mod arrow_utils;

#[cfg(feature = "default-engine-base")]
pub mod default;

#[cfg(feature = "sync-engine")]
pub mod sync;

macro_rules! declare_modules {
( $(($vis:vis, $module:ident)),*) => {
$(
$vis mod $module;
)*
};
}

#[cfg(any(feature = "default-engine-base", feature = "sync-engine"))]
declare_modules!(
(pub, arrow_data),
(pub, parquet_row_group_skipping),
(pub(crate), arrow_get_data),
(pub(crate), arrow_utils),
(pub(crate), ensure_data_types)
);
pub mod arrow_data;
#[cfg(any(feature = "default-engine-base", feature = "sync-engine"))]
pub(crate) mod arrow_get_data;
#[cfg(any(feature = "default-engine-base", feature = "sync-engine"))]
pub(crate) mod ensure_data_types;
#[cfg(any(feature = "default-engine-base", feature = "sync-engine"))]
pub mod parquet_row_group_skipping;
8 changes: 6 additions & 2 deletions kernel/src/engine/parquet_row_group_skipping.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! An implementation of parquet row group skipping using data skipping predicates over footer stats.
use crate::expressions::{ColumnName, Expression, Scalar, UnaryExpression, BinaryExpression, VariadicExpression};
use crate::expressions::{
BinaryExpression, ColumnName, Expression, Scalar, UnaryExpression, VariadicExpression,
};
use crate::predicates::parquet_stats_skipping::ParquetStatsProvider;
use crate::schema::{DataType, PrimitiveType};
use chrono::{DateTime, Days};
Expand Down Expand Up @@ -231,7 +233,9 @@ pub(crate) fn compute_field_indices(
Column(name) => cols.extend([name.clone()]), // returns `()`, unlike `insert`
Struct(fields) => fields.iter().for_each(recurse),
Unary(UnaryExpression { expr, .. }) => recurse(expr),
Binary(BinaryExpression { left, right, .. }) => [left, right].iter().for_each(|e| recurse(e)),
Binary(BinaryExpression { left, right, .. }) => {
[left, right].iter().for_each(|e| recurse(e))
}
Variadic(VariadicExpression { exprs, .. }) => exprs.iter().for_each(recurse),
}
}
Expand Down
Loading
Loading