fix(parquet): handle nested data types correctly (#20156)
wcy-fdu authored Jan 26, 2025
1 parent fa4c463 commit 1384d45
Showing 3 changed files with 85 additions and 64 deletions.
36 changes: 24 additions & 12 deletions e2e_test/s3/file_sink.py
@@ -17,6 +17,12 @@
def gen_data(file_num, item_num_per_file):
assert item_num_per_file % 2 == 0, \
f'item_num_per_file should be even to ensure sum(mark) == 0: {item_num_per_file}'

struct_type = pa.struct([
('field1', pa.int32()),
('field2', pa.string())
])

return [
[{
'id': file_id * item_num_per_file + item_id,
@@ -44,6 +50,7 @@ def gen_data(file_num, item_num_per_file):
'test_timestamptz_ms': pa.scalar(datetime.now().timestamp() * 1000, type=pa.timestamp('ms', tz='+00:00')),
'test_timestamptz_us': pa.scalar(datetime.now().timestamp() * 1000000, type=pa.timestamp('us', tz='+00:00')),
'test_timestamptz_ns': pa.scalar(datetime.now().timestamp() * 1000000000, type=pa.timestamp('ns', tz='+00:00')),
'nested_struct': pa.scalar((item_id, f'struct_value_{item_id}'), type=struct_type),
} for item_id in range(item_num_per_file)]
for file_id in range(file_num)
]
@@ -65,7 +72,7 @@ def _table():
print("test table function file scan")
cur.execute(f'''
SELECT
id,
id,
name,
sex,
mark,
@@ -89,7 +96,8 @@ def _table():
test_timestamptz_s,
test_timestamptz_ms,
test_timestamptz_us,
test_timestamptz_ns
test_timestamptz_ns,
nested_struct
FROM file_scan(
'parquet',
's3',
@@ -104,7 +112,6 @@ def _table():
except ValueError as e:
print(f"cur.fetchone() got ValueError: {e}")


print("file scan test pass")
# Execute a SELECT statement
cur.execute(f'''CREATE TABLE {_table()}(
@@ -132,8 +139,8 @@ def _table():
test_timestamptz_s timestamptz,
test_timestamptz_ms timestamptz,
test_timestamptz_us timestamptz,
test_timestamptz_ns timestamptz
test_timestamptz_ns timestamptz,
nested_struct STRUCT<"field1" int, "field2" varchar>
) WITH (
connector = 's3',
match_pattern = '*.parquet',
Expand Down Expand Up @@ -213,7 +220,8 @@ def _table():
test_timestamptz_s,
test_timestamptz_ms,
test_timestamptz_us,
test_timestamptz_ns
test_timestamptz_ns,
nested_struct
from {_table()} WITH (
connector = 's3',
match_pattern = '*.parquet',
@@ -230,7 +238,7 @@ def _table():
print('Sink into s3 in parquet encode...')
# Execute a SELECT statement
cur.execute(f'''CREATE TABLE test_parquet_sink_table(
id bigint primary key,\
id bigint primary key,
name TEXT,
sex bigint,
mark bigint,
@@ -254,7 +262,8 @@ def _table():
test_timestamptz_s timestamptz,
test_timestamptz_ms timestamptz,
test_timestamptz_us timestamptz,
test_timestamptz_ns timestamptz
test_timestamptz_ns timestamptz,
nested_struct STRUCT<"field1" int, "field2" varchar>,
) WITH (
connector = 's3',
match_pattern = 'test_parquet_sink/*.parquet',
@@ -263,8 +272,8 @@ def _table():
s3.credentials.access = 'hummockadmin',
s3.credentials.secret = 'hummockadmin',
s3.endpoint_url = 'http://hummock001.127.0.0.1:9301',
refresh.interval.sec = 1,
) FORMAT PLAIN ENCODE PARQUET;''')

total_rows = file_num * item_num_per_file
MAX_RETRIES = 40
for retry_no in range(MAX_RETRIES):
@@ -305,7 +314,8 @@ def _table():
test_timestamptz_s,
test_timestamptz_ms,
test_timestamptz_us,
test_timestamptz_ns
test_timestamptz_ns,
nested_struct
from {_table()} WITH (
connector = 'snowflake',
match_pattern = '*.parquet',
@@ -316,7 +326,8 @@ def _table():
s3.endpoint_url = 'http://hummock001.127.0.0.1:9301',
s3.path = 'test_json_sink/',
type = 'append-only',
force_append_only='true'
force_append_only='true',
refresh.interval.sec = 1,
) FORMAT PLAIN ENCODE JSON(force_append_only='true');''')

print('Sink into s3 in json encode...')
@@ -346,7 +357,8 @@ def _table():
test_timestamptz_s timestamptz,
test_timestamptz_ms timestamptz,
test_timestamptz_us timestamptz,
test_timestamptz_ns timestamptz
test_timestamptz_ns timestamptz,
nested_struct STRUCT<"field1" int, "field2" varchar>
) WITH (
connector = 's3',
match_pattern = 'test_json_sink/*.json',
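The new nested_struct column gives the e2e test coverage for nested Parquet types across the file_scan table function, the S3 source table, and the Parquet and JSON sinks. For reference, a minimal standalone sketch of the same round trip with pyarrow (assumptions: pyarrow is installed, a local file is used instead of the MinIO endpoint the test writes to, and the file name is made up for illustration):

import pyarrow as pa
import pyarrow.parquet as pq

# Build a struct column matching the test's struct_type and write it to Parquet.
struct_type = pa.struct([('field1', pa.int32()), ('field2', pa.string())])
table = pa.table({
    'id': pa.array([0, 1, 2], type=pa.int64()),
    'nested_struct': pa.array(
        [(i, f'struct_value_{i}') for i in range(3)], type=struct_type),
})
pq.write_table(table, 'nested_struct_demo.parquet')

# Read it back and confirm the struct type survives the round trip.
read_back = pq.read_table('nested_struct_demo.parquet')
assert read_back.schema.field('nested_struct').type == struct_type
print(read_back.to_pylist())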
5 changes: 5 additions & 0 deletions src/connector/src/parser/parquet_parser.rs
@@ -26,6 +26,7 @@ use risingwave_common::util::tokio_util::compat::FuturesAsyncReadCompatExt;
use crate::parser::ConnectorResult;
use crate::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator;
use crate::source::filesystem::opendal_source::{OpendalGcs, OpendalPosixFs, OpendalS3};
use crate::source::iceberg::is_parquet_schema_match_source_schema;
use crate::source::reader::desc::SourceDesc;
use crate::source::{ConnectorProperties, SourceColumnDesc};
/// `ParquetParser` is responsible for converting the incoming `record_batch_stream`
@@ -109,6 +110,10 @@ impl ParquetParser {

if let Some(parquet_column) =
record_batch.column_by_name(rw_column_name)
&& is_parquet_schema_match_source_schema(
parquet_column.data_type(),
rw_data_type,
)
{
let arrow_field = IcebergArrowConvert
.to_arrow_field(rw_column_name, rw_data_type)?;
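The condition added above means a Parquet column is only consumed when it both exists by name and passes the is_parquet_schema_match_source_schema check against the declared source column type. A rough Python analogue of that guard, which falls back to NULLs when the check fails (the helper name and the NULL fallback are illustrative assumptions, not necessarily what the Rust parser does; pyarrow types stand in for the Arrow/RisingWave types on the Rust side):

import pyarrow as pa

def column_if_type_matches(batch: pa.RecordBatch, name: str, expected: pa.DataType):
    # The column must exist by name AND its type must match the expected
    # type; otherwise fall back to a NULL column of the expected type.
    idx = batch.schema.get_field_index(name)
    if idx != -1 and batch.schema.field(idx).type == expected:
        return batch.column(idx)
    return pa.nulls(batch.num_rows, type=expected)

batch = pa.record_batch({'id': pa.array([1, 2], type=pa.int64())})
print(column_if_type_matches(batch, 'id', pa.int64()))       # present, type matches
print(column_if_type_matches(batch, 'id', pa.string()))      # type mismatch -> nulls
print(column_if_type_matches(batch, 'missing', pa.int32()))  # absent -> nulls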
108 changes: 56 additions & 52 deletions src/connector/src/source/iceberg/parquet_file_handler.rs
@@ -36,7 +36,7 @@ use parquet::file::metadata::{FileMetaData, ParquetMetaData, ParquetMetaDataRead
use risingwave_common::array::arrow::arrow_schema_udf::{DataType as ArrowDateType, IntervalUnit};
use risingwave_common::array::arrow::IcebergArrowConvert;
use risingwave_common::array::StreamChunk;
use risingwave_common::catalog::ColumnId;
use risingwave_common::catalog::{ColumnDesc, ColumnId};
use risingwave_common::types::DataType as RwDataType;
use risingwave_common::util::tokio_util::compat::FuturesAsyncReadCompatExt;
use url::Url;
@@ -217,55 +217,64 @@ pub async fn list_data_directory(
}
}

/// Extracts valid column indices from a Parquet file schema based on the user's requested schema.
/// Extracts a suitable `ProjectionMask` from a Parquet file schema based on the user's requested schema.
///
/// This function is used for column pruning of Parquet files. It calculates the intersection
/// between the columns in the currently read Parquet file and the schema provided by the user.
/// This is useful for reading a `RecordBatch` with the appropriate `ProjectionMask`, ensuring that
/// only the necessary columns are read.
/// This function is utilized for column pruning of Parquet files. It checks the user's requested schema
/// against the schema of the currently read Parquet file. If the provided `columns` are `None`
/// or if the Parquet file contains nested data types, it returns `ProjectionMask::all()`. Otherwise,
/// it returns only the columns where both the data type and column name match the requested schema,
/// facilitating efficient reading of the `RecordBatch`.
///
/// # Parameters
/// - `columns`: A vector of `Column` representing the user's requested schema.
/// - `columns`: An optional vector of `Column` representing the user's requested schema.
/// - `metadata`: A reference to `FileMetaData` containing the schema and metadata of the Parquet file.
///
/// # Returns
/// - A `ConnectorResult<Vec<usize>>`, which contains the indices of the valid columns in the
/// Parquet file schema that match the requested schema. If an error occurs during processing,
/// it returns an appropriate error.
pub fn extract_valid_column_indices(
rw_columns: Vec<Column>,
/// - A `ConnectorResult<ProjectionMask>`, which represents the valid columns in the Parquet file schema
/// that correspond to the requested schema. If an error occurs during processing, it returns an
/// appropriate error.
pub fn get_project_mask(
columns: Option<Vec<Column>>,
metadata: &FileMetaData,
) -> ConnectorResult<Vec<usize>> {
let parquet_column_names = metadata
.schema_descr()
.columns()
.iter()
.map(|c| c.name())
.collect_vec();
) -> ConnectorResult<ProjectionMask> {
match columns {
Some(rw_columns) => {
let root_column_names = metadata
.schema_descr()
.root_schema()
.get_fields()
.iter()
.map(|field| field.name())
.collect_vec();

let converted_arrow_schema =
parquet_to_arrow_schema(metadata.schema_descr(), metadata.key_value_metadata())
.map_err(anyhow::Error::from)?;
let converted_arrow_schema =
parquet_to_arrow_schema(metadata.schema_descr(), metadata.key_value_metadata())
.map_err(anyhow::Error::from)?;
let valid_column_indices: Vec<usize> = rw_columns
.iter()
.filter_map(|column| {
root_column_names
.iter()
.position(|&name| name == column.name)
.and_then(|pos| {
let arrow_data_type: &risingwave_common::array::arrow::arrow_schema_udf::DataType = converted_arrow_schema.field_with_name(&column.name).ok()?.data_type();
let rw_data_type: &risingwave_common::types::DataType = &column.data_type;
if is_parquet_schema_match_source_schema(arrow_data_type, rw_data_type) {
Some(pos)
} else {
None
}
})
})
.collect();

let valid_column_indices: Vec<usize> = rw_columns
.iter()
.filter_map(|column| {
parquet_column_names
.iter()
.position(|&name| name == column.name)
.and_then(|pos| {
let arrow_data_type: &risingwave_common::array::arrow::arrow_schema_udf::DataType = converted_arrow_schema.field(pos).data_type();
let rw_data_type: &risingwave_common::types::DataType = &column.data_type;

if is_parquet_schema_match_source_schema(arrow_data_type, rw_data_type) {
Some(pos)
} else {
None
}
})
})
.collect();
Ok(valid_column_indices)
Ok(ProjectionMask::roots(
metadata.schema_descr(),
valid_column_indices,
))
}
None => Ok(ProjectionMask::all()),
}
}

/// Reads a specified Parquet file and converts its content into a stream of chunks.
@@ -289,13 +298,7 @@ pub async fn read_parquet_file(
let parquet_metadata = reader.get_metadata().await.map_err(anyhow::Error::from)?;

let file_metadata = parquet_metadata.file_metadata();
let projection_mask = match rw_columns {
Some(columns) => {
let column_indices = extract_valid_column_indices(columns, file_metadata)?;
ProjectionMask::leaves(file_metadata.schema_descr(), column_indices)
}
None => ProjectionMask::all(),
};
let projection_mask = get_project_mask(rw_columns, file_metadata)?;

// For the Parquet format, we directly convert from a record batch to a stream chunk.
// Therefore, the offset of the Parquet file represents the current position in terms of the number of rows read from the file.
@@ -318,11 +321,12 @@ pub async fn read_parquet_file(
.enumerate()
.map(|(index, field_ref)| {
let data_type = IcebergArrowConvert.type_from_field(field_ref).unwrap();
SourceColumnDesc::simple(
let column_desc = ColumnDesc::named(
field_ref.name().clone(),
data_type,
ColumnId::new(index as i32),
)
data_type,
);
SourceColumnDesc::from(&column_desc)
})
.collect(),
};
@@ -367,7 +371,7 @@ pub async fn get_parquet_fields(
/// - Arrow's `UInt32` matches with RisingWave's `Int64`.
/// - Arrow's `UInt64` matches with RisingWave's `Decimal`.
/// - Arrow's `Float16` matches with RisingWave's `Float32`.
fn is_parquet_schema_match_source_schema(
pub fn is_parquet_schema_match_source_schema(
arrow_data_type: &ArrowDateType,
rw_data_type: &RwDataType,
) -> bool {
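The key point of the change to this file: the Parquet physical schema flattens nested types into one leaf column per primitive field, while the converted Arrow schema has one field per root column. The old extract_valid_column_indices matched source column names against the leaf columns and then indexed the Arrow schema with those leaf positions, which goes wrong as soon as a struct is present; the new get_project_mask matches against the root fields and builds ProjectionMask::roots instead of ProjectionMask::leaves. A small pyarrow illustration of the leaf-vs-root distinction (assumes the demo file written in the earlier sketch; names are illustrative):

import pyarrow.parquet as pq

pf = pq.ParquetFile('nested_struct_demo.parquet')

# Leaf (physical) columns: the struct contributes one entry per primitive
# field, e.g. 'nested_struct.field1' and 'nested_struct.field2'.
print(pf.schema.names)

# Root (logical) fields: one entry per top-level column.
print(pf.schema_arrow.names)

# Projecting by root column name reads the whole struct back.
print(pq.read_table('nested_struct_demo.parquet',
                    columns=['id', 'nested_struct']).schema)

Because a roots-based mask is defined in terms of top-level fields, selecting a struct keeps all of its leaves, which is what the nested_struct e2e column relies on.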
