From 1384d4592ba94ec873c5f513946eb7b7fbfa36e5 Mon Sep 17 00:00:00 2001 From: congyi wang <58715567+wcy-fdu@users.noreply.github.com> Date: Sun, 26 Jan 2025 18:40:16 +0800 Subject: [PATCH] fix(parquet): handle nested data types correctly (#20156) --- e2e_test/s3/file_sink.py | 36 ++++-- src/connector/src/parser/parquet_parser.rs | 5 + .../source/iceberg/parquet_file_handler.rs | 108 +++++++++--------- 3 files changed, 85 insertions(+), 64 deletions(-) diff --git a/e2e_test/s3/file_sink.py b/e2e_test/s3/file_sink.py index 6eca0e2b9194f..1979b7f6606ab 100644 --- a/e2e_test/s3/file_sink.py +++ b/e2e_test/s3/file_sink.py @@ -17,6 +17,12 @@ def gen_data(file_num, item_num_per_file): assert item_num_per_file % 2 == 0, \ f'item_num_per_file should be even to ensure sum(mark) == 0: {item_num_per_file}' + + struct_type = pa.struct([ + ('field1', pa.int32()), + ('field2', pa.string()) + ]) + return [ [{ 'id': file_id * item_num_per_file + item_id, @@ -44,6 +50,7 @@ def gen_data(file_num, item_num_per_file): 'test_timestamptz_ms': pa.scalar(datetime.now().timestamp() * 1000, type=pa.timestamp('ms', tz='+00:00')), 'test_timestamptz_us': pa.scalar(datetime.now().timestamp() * 1000000, type=pa.timestamp('us', tz='+00:00')), 'test_timestamptz_ns': pa.scalar(datetime.now().timestamp() * 1000000000, type=pa.timestamp('ns', tz='+00:00')), + 'nested_struct': pa.scalar((item_id, f'struct_value_{item_id}'), type=struct_type), } for item_id in range(item_num_per_file)] for file_id in range(file_num) ] @@ -65,7 +72,7 @@ def _table(): print("test table function file scan") cur.execute(f''' SELECT - id, + id, name, sex, mark, @@ -89,7 +96,8 @@ def _table(): test_timestamptz_s, test_timestamptz_ms, test_timestamptz_us, - test_timestamptz_ns + test_timestamptz_ns, + nested_struct FROM file_scan( 'parquet', 's3', @@ -104,7 +112,6 @@ def _table(): except ValueError as e: print(f"cur.fetchone() got ValueError: {e}") - print("file scan test pass") # Execute a SELECT statement 
cur.execute(f'''CREATE TABLE {_table()}( @@ -132,8 +139,8 @@ def _table(): test_timestamptz_s timestamptz, test_timestamptz_ms timestamptz, test_timestamptz_us timestamptz, - test_timestamptz_ns timestamptz - + test_timestamptz_ns timestamptz, + nested_struct STRUCT<"field1" int, "field2" varchar> ) WITH ( connector = 's3', match_pattern = '*.parquet', @@ -213,7 +220,8 @@ def _table(): test_timestamptz_s, test_timestamptz_ms, test_timestamptz_us, - test_timestamptz_ns + test_timestamptz_ns, + nested_struct from {_table()} WITH ( connector = 's3', match_pattern = '*.parquet', @@ -230,7 +238,7 @@ def _table(): print('Sink into s3 in parquet encode...') # Execute a SELECT statement cur.execute(f'''CREATE TABLE test_parquet_sink_table( - id bigint primary key,\ + id bigint primary key, name TEXT, sex bigint, mark bigint, @@ -254,7 +262,8 @@ def _table(): test_timestamptz_s timestamptz, test_timestamptz_ms timestamptz, test_timestamptz_us timestamptz, - test_timestamptz_ns timestamptz + test_timestamptz_ns timestamptz, + nested_struct STRUCT<"field1" int, "field2" varchar>, ) WITH ( connector = 's3', match_pattern = 'test_parquet_sink/*.parquet', @@ -263,8 +272,8 @@ def _table(): s3.credentials.access = 'hummockadmin', s3.credentials.secret = 'hummockadmin', s3.endpoint_url = 'http://hummock001.127.0.0.1:9301', + refresh.interval.sec = 1, ) FORMAT PLAIN ENCODE PARQUET;''') - total_rows = file_num * item_num_per_file MAX_RETRIES = 40 for retry_no in range(MAX_RETRIES): @@ -305,7 +314,8 @@ def _table(): test_timestamptz_s, test_timestamptz_ms, test_timestamptz_us, - test_timestamptz_ns + test_timestamptz_ns, + nested_struct from {_table()} WITH ( connector = 'snowflake', match_pattern = '*.parquet', @@ -316,7 +326,8 @@ def _table(): s3.endpoint_url = 'http://hummock001.127.0.0.1:9301', s3.path = 'test_json_sink/', type = 'append-only', - force_append_only='true' + force_append_only='true', + refresh.interval.sec = 1, ) FORMAT PLAIN ENCODE 
JSON(force_append_only='true');''') print('Sink into s3 in json encode...') @@ -346,7 +357,8 @@ def _table(): test_timestamptz_s timestamptz, test_timestamptz_ms timestamptz, test_timestamptz_us timestamptz, - test_timestamptz_ns timestamptz + test_timestamptz_ns timestamptz, + nested_struct STRUCT<"field1" int, "field2" varchar> ) WITH ( connector = 's3', match_pattern = 'test_json_sink/*.json', diff --git a/src/connector/src/parser/parquet_parser.rs b/src/connector/src/parser/parquet_parser.rs index 5bef3b6310981..29148c9305ddc 100644 --- a/src/connector/src/parser/parquet_parser.rs +++ b/src/connector/src/parser/parquet_parser.rs @@ -26,6 +26,7 @@ use risingwave_common::util::tokio_util::compat::FuturesAsyncReadCompatExt; use crate::parser::ConnectorResult; use crate::source::filesystem::opendal_source::opendal_enumerator::OpendalEnumerator; use crate::source::filesystem::opendal_source::{OpendalGcs, OpendalPosixFs, OpendalS3}; +use crate::source::iceberg::is_parquet_schema_match_source_schema; use crate::source::reader::desc::SourceDesc; use crate::source::{ConnectorProperties, SourceColumnDesc}; /// `ParquetParser` is responsible for converting the incoming `record_batch_stream` @@ -109,6 +110,10 @@ impl ParquetParser { if let Some(parquet_column) = record_batch.column_by_name(rw_column_name) + && is_parquet_schema_match_source_schema( + parquet_column.data_type(), + rw_data_type, + ) { let arrow_field = IcebergArrowConvert .to_arrow_field(rw_column_name, rw_data_type)?; diff --git a/src/connector/src/source/iceberg/parquet_file_handler.rs b/src/connector/src/source/iceberg/parquet_file_handler.rs index 187140855c29b..6baf30a5795e8 100644 --- a/src/connector/src/source/iceberg/parquet_file_handler.rs +++ b/src/connector/src/source/iceberg/parquet_file_handler.rs @@ -36,7 +36,7 @@ use parquet::file::metadata::{FileMetaData, ParquetMetaData, ParquetMetaDataRead use risingwave_common::array::arrow::arrow_schema_udf::{DataType as ArrowDateType, IntervalUnit}; use 
risingwave_common::array::arrow::IcebergArrowConvert; use risingwave_common::array::StreamChunk; -use risingwave_common::catalog::ColumnId; +use risingwave_common::catalog::{ColumnDesc, ColumnId}; use risingwave_common::types::DataType as RwDataType; use risingwave_common::util::tokio_util::compat::FuturesAsyncReadCompatExt; use url::Url; @@ -217,55 +217,64 @@ pub async fn list_data_directory( } } -/// Extracts valid column indices from a Parquet file schema based on the user's requested schema. +/// Extracts a suitable `ProjectionMask` from a Parquet file schema based on the user's requested schema. /// -/// This function is used for column pruning of Parquet files. It calculates the intersection -/// between the columns in the currently read Parquet file and the schema provided by the user. -/// This is useful for reading a `RecordBatch` with the appropriate `ProjectionMask`, ensuring that -/// only the necessary columns are read. +/// This function is utilized for column pruning of Parquet files. It checks the user's requested schema +/// against the schema of the currently read Parquet file. If the provided `columns` are `None` +/// or if the Parquet file contains nested data types, it returns `ProjectionMask::all()`. Otherwise, +/// it returns only the columns where both the data type and column name match the requested schema, +/// facilitating efficient reading of the `RecordBatch`. /// /// # Parameters -/// - `columns`: A vector of `Column` representing the user's requested schema. +/// - `columns`: An optional vector of `Column` representing the user's requested schema. /// - `metadata`: A reference to `FileMetaData` containing the schema and metadata of the Parquet file. /// /// # Returns -/// - A `ConnectorResult<Vec<usize>>`, which contains the indices of the valid columns in the -/// Parquet file schema that match the requested schema. If an error occurs during processing, -/// it returns an appropriate error. 
-pub fn extract_valid_column_indices( - rw_columns: Vec<Column>, +/// - A `ConnectorResult<ProjectionMask>`, which represents the valid columns in the Parquet file schema +/// that correspond to the requested schema. If an error occurs during processing, it returns an +/// appropriate error. +pub fn get_project_mask( + columns: Option<Vec<Column>>, metadata: &FileMetaData, -) -> ConnectorResult<Vec<usize>> { - let parquet_column_names = metadata - .schema_descr() - .columns() - .iter() - .map(|c| c.name()) - .collect_vec(); +) -> ConnectorResult<ProjectionMask> { + match columns { + Some(rw_columns) => { + let root_column_names = metadata + .schema_descr() + .root_schema() + .get_fields() + .iter() + .map(|field| field.name()) + .collect_vec(); - let converted_arrow_schema = - parquet_to_arrow_schema(metadata.schema_descr(), metadata.key_value_metadata()) - .map_err(anyhow::Error::from)?; + let converted_arrow_schema = + parquet_to_arrow_schema(metadata.schema_descr(), metadata.key_value_metadata()) + .map_err(anyhow::Error::from)?; + let valid_column_indices: Vec<usize> = rw_columns + .iter() + .filter_map(|column| { + root_column_names + .iter() + .position(|&name| name == column.name) + .and_then(|pos| { + let arrow_data_type: &risingwave_common::array::arrow::arrow_schema_udf::DataType = converted_arrow_schema.field_with_name(&column.name).ok()?.data_type(); + let rw_data_type: &risingwave_common::types::DataType = &column.data_type; + if is_parquet_schema_match_source_schema(arrow_data_type, rw_data_type) { + Some(pos) + } else { + None + } + }) + }) + .collect(); - let valid_column_indices: Vec<usize> = rw_columns - .iter() - .filter_map(|column| { - parquet_column_names - .iter() - .position(|&name| name == column.name) - .and_then(|pos| { - let arrow_data_type: &risingwave_common::array::arrow::arrow_schema_udf::DataType = converted_arrow_schema.field(pos).data_type(); - let rw_data_type: &risingwave_common::types::DataType = &column.data_type; - - if is_parquet_schema_match_source_schema(arrow_data_type, rw_data_type) { - Some(pos) 
- } else { - None - } - }) - }) - .collect(); - Ok(valid_column_indices) + Ok(ProjectionMask::roots( + metadata.schema_descr(), + valid_column_indices, + )) + } + None => Ok(ProjectionMask::all()), + } } /// Reads a specified Parquet file and converts its content into a stream of chunks. @@ -289,13 +298,7 @@ pub async fn read_parquet_file( let parquet_metadata = reader.get_metadata().await.map_err(anyhow::Error::from)?; let file_metadata = parquet_metadata.file_metadata(); - let projection_mask = match rw_columns { - Some(columns) => { - let column_indices = extract_valid_column_indices(columns, file_metadata)?; - ProjectionMask::leaves(file_metadata.schema_descr(), column_indices) - } - None => ProjectionMask::all(), - }; + let projection_mask = get_project_mask(rw_columns, file_metadata)?; // For the Parquet format, we directly convert from a record batch to a stream chunk. // Therefore, the offset of the Parquet file represents the current position in terms of the number of rows read from the file. @@ -318,11 +321,12 @@ pub async fn read_parquet_file( .enumerate() .map(|(index, field_ref)| { let data_type = IcebergArrowConvert.type_from_field(field_ref).unwrap(); - SourceColumnDesc::simple( + let column_desc = ColumnDesc::named( field_ref.name().clone(), - data_type, ColumnId::new(index as i32), - ) + data_type, + ); + SourceColumnDesc::from(&column_desc) }) .collect(), }; @@ -367,7 +371,7 @@ pub async fn get_parquet_fields( /// - Arrow's `UInt32` matches with RisingWave's `Int64`. /// - Arrow's `UInt64` matches with RisingWave's `Decimal`. /// - Arrow's `Float16` matches with RisingWave's `Float32`. -fn is_parquet_schema_match_source_schema( +pub fn is_parquet_schema_match_source_schema( arrow_data_type: &ArrowDateType, rw_data_type: &RwDataType, ) -> bool {