From 9b3b80510e7fa56933a4f8a7a82c26f136a9e182 Mon Sep 17 00:00:00 2001 From: RT_Enzyme <58059931+RTEnzyme@users.noreply.github.com> Date: Wed, 12 Jun 2024 16:34:52 +0800 Subject: [PATCH 01/34] replace and(.., not(...)) with and_not(..) (#10885) Co-authored-by: velosearch --- datafusion/physical-expr/src/expressions/case.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index c56229e07a63..08d8cd441334 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -26,7 +26,7 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, is_null, not, nullif, or, prep_null_mask_filter}; +use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; @@ -168,7 +168,7 @@ impl CaseExpr { } }; - remainder = and(&remainder, ¬(&when_match)?)?; + remainder = and_not(&remainder, &when_match)?; } if let Some(e) = &self.else_expr { @@ -241,7 +241,7 @@ impl CaseExpr { // Succeed tuples should be filtered out for short-circuit evaluation, // null values for the current when expr should be kept - remainder = and(&remainder, ¬(&when_value)?)?; + remainder = and_not(&remainder, &when_value)?; } if let Some(e) = &self.else_expr { From 7f6fc07577f882d39db72e44ebabe0442a7bf016 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Wed, 12 Jun 2024 09:53:51 -0400 Subject: [PATCH 02/34] Disabling test for semi join with filters (#10887) --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 8c2e24de56b9..7dbbfb25bf78 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -179,7 +179,8 @@ async fn test_semi_join_1k() { .run_test() .await } - +// See https://github.com/apache/datafusion/issues/10886 +#[ignore] #[tokio::test] async fn test_semi_join_1k_filtered() { JoinFuzzTestCase::new( From 73381fe35738ef2f5a06e9f55626f08855e8a852 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 12 Jun 2024 11:17:14 -0400 Subject: [PATCH 03/34] Minor: Update `min_statistics` and `max_statistics` to be helpers, update docs (#10866) --- .../physical_plan/parquet/statistics.rs | 50 +++++++++++-------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index a4a919f20d0f..c0d36f1fc4d7 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`min_statistics`] and [`max_statistics`] convert statistics in parquet format to arrow [`ArrayRef`]. +//! [`StatisticsConverter`] to convert statistics in parquet format to arrow [`ArrayRef`]. 
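For the `and_not` rewrite in the first patch above, a minimal standalone sketch of the equivalence it relies on, assuming only the arrow `and`, `not`, and `and_not` boolean kernels already imported in that diff (the input arrays here are illustrative):

```rust
use arrow::array::BooleanArray;
use arrow::compute::{and, and_not, not};
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let remainder = BooleanArray::from(vec![true, true, false, false]);
    let when_match = BooleanArray::from(vec![true, false, true, false]);

    // Before: two kernel calls, materializing an intermediate negated array.
    let two_step = and(&remainder, &not(&when_match)?)?;
    // After: a single fused kernel call, no intermediate allocation.
    let fused = and_not(&remainder, &when_match)?;

    assert_eq!(two_step, fused);
    Ok(())
}
```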
// TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 @@ -542,8 +542,11 @@ pub(crate) fn parquet_column<'a>( Some((parquet_idx, field)) } -/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] -pub(crate) fn min_statistics<'a, I: Iterator>>( +/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an +/// [`ArrayRef`] +/// +/// This is an internal helper -- see [`StatisticsConverter`] for public API +fn min_statistics<'a, I: Iterator>>( data_type: &DataType, iterator: I, ) -> Result { @@ -551,7 +554,9 @@ pub(crate) fn min_statistics<'a, I: Iterator>>( +/// +/// This is an internal helper -- see [`StatisticsConverter`] for public API +fn max_statistics<'a, I: Iterator>>( data_type: &DataType, iterator: I, ) -> Result { @@ -1425,9 +1430,10 @@ mod test { assert_eq!(idx, 2); let row_groups = metadata.row_groups(); - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); + let converter = + StatisticsConverter::try_new("int_col", &schema, parquet_schema).unwrap(); - let min = min_statistics(&DataType::Int32, iter.clone()).unwrap(); + let min = converter.row_group_mins(row_groups.iter()).unwrap(); assert_eq!( &min, &expected_min, @@ -1435,7 +1441,7 @@ mod test { DisplayStats(row_groups) ); - let max = max_statistics(&DataType::Int32, iter).unwrap(); + let max = converter.row_group_maxes(row_groups.iter()).unwrap(); assert_eq!( &max, &expected_max, @@ -1623,22 +1629,23 @@ mod test { continue; } - let (idx, f) = - parquet_column(parquet_schema, &schema, field.name()).unwrap(); - assert_eq!(f, field); + let converter = + StatisticsConverter::try_new(field.name(), &schema, parquet_schema) + .unwrap(); - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); - let min = min_statistics(f.data_type(), iter.clone()).unwrap(); + assert_eq!(converter.arrow_field, field.as_ref()); + + let mins = converter.row_group_mins(row_groups.iter()).unwrap(); assert_eq!( - &min, + &mins, &expected_min, "Min. Statistics\n\n{}\n\n", DisplayStats(row_groups) ); - let max = max_statistics(f.data_type(), iter).unwrap(); + let maxes = converter.row_group_maxes(row_groups.iter()).unwrap(); assert_eq!( - &max, + &maxes, &expected_max, "Max. Statistics\n\n{}\n\n", DisplayStats(row_groups) @@ -1705,7 +1712,7 @@ mod test { self } - /// Reads the specified parquet file and validates that the exepcted min/max + /// Reads the specified parquet file and validates that the expected min/max /// values for the specified columns are as expected. 
fn run(self) { let path = PathBuf::from(parquet_test_data()).join(self.file_name); @@ -1723,14 +1730,13 @@ mod test { expected_max, } = expected_column; - let (idx, field) = - parquet_column(parquet_schema, arrow_schema, name).unwrap(); - - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); - let actual_min = min_statistics(field.data_type(), iter.clone()).unwrap(); + let converter = + StatisticsConverter::try_new(name, arrow_schema, parquet_schema) + .unwrap(); + let actual_min = converter.row_group_mins(row_groups.iter()).unwrap(); assert_eq!(&expected_min, &actual_min, "column {name}"); - let actual_max = max_statistics(field.data_type(), iter).unwrap(); + let actual_max = converter.row_group_maxes(row_groups.iter()).unwrap(); assert_eq!(&expected_max, &actual_max, "column {name}"); } } From 87d826703bfe05df292649adf6c30b2528c83ab2 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Wed, 12 Jun 2024 18:34:22 +0200 Subject: [PATCH 04/34] chore: remove interval test (#10888) --- .../core/tests/parquet/arrow_statistics.rs | 81 +------------------ datafusion/core/tests/parquet/mod.rs | 80 +----------------- 2 files changed, 2 insertions(+), 159 deletions(-) diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 0e23e6824027..2ea18d7cf823 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -30,8 +30,7 @@ use arrow::datatypes::{ use arrow_array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, Float16Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, - IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, RecordBatch, StringArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, @@ -1061,84 +1060,6 @@ async fn test_dates_64_diff_rg_sizes() { .run(); } -#[tokio::test] -#[should_panic] -// Currently this test `should_panic` since statistics for `Intervals` -// are not supported and `IntervalMonthDayNano` cannot be written -// to parquet yet. 
-// Refer to issue: https://github.com/apache/arrow-rs/issues/5847 -// and https://github.com/apache/arrow-rs/blob/master/parquet/src/arrow/arrow_writer/mod.rs#L747 -async fn test_interval_diff_rg_sizes() { - // This creates a parquet files of 3 columns: - // "year_month" --> IntervalYearMonthArray - // "day_time" --> IntervalDayTimeArray - // "month_day_nano" --> IntervalMonthDayNanoArray - // - // The file is created by 4 record batches (each has a null row) - // each has 5 rows but then will be split into 2 row groups with size 13, 7 - let reader = TestReader { - scenario: Scenario::Interval, - row_per_group: 13, - } - .build() - .await; - - // TODO: expected values need to be changed once issue is resolved - // expected_min: Arc::new(IntervalYearMonthArray::from(vec![ - // IntervalYearMonthType::make_value(1, 10), - // IntervalYearMonthType::make_value(4, 13), - // ])), - // expected_max: Arc::new(IntervalYearMonthArray::from(vec![ - // IntervalYearMonthType::make_value(6, 51), - // IntervalYearMonthType::make_value(8, 53), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalYearMonthArray::from(vec![None, None])), - expected_max: Arc::new(IntervalYearMonthArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "year_month", - } - .run(); - - // expected_min: Arc::new(IntervalDayTimeArray::from(vec![ - // IntervalDayTimeType::make_value(1, 10), - // IntervalDayTimeType::make_value(4, 13), - // ])), - // expected_max: Arc::new(IntervalDayTimeArray::from(vec![ - // IntervalDayTimeType::make_value(6, 51), - // IntervalDayTimeType::make_value(8, 53), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalDayTimeArray::from(vec![None, None])), - expected_max: Arc::new(IntervalDayTimeArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "day_time", - } - .run(); - - // expected_min: Arc::new(IntervalMonthDayNanoArray::from(vec![ - // IntervalMonthDayNanoType::make_value(1, 10, 100), - // IntervalMonthDayNanoType::make_value(4, 13, 103), - // ])), - // expected_max: Arc::new(IntervalMonthDayNanoArray::from(vec![ - // IntervalMonthDayNanoType::make_value(6, 51, 501), - // IntervalMonthDayNanoType::make_value(8, 53, 503), - // ])), - Test { - reader: &reader, - expected_min: Arc::new(IntervalMonthDayNanoArray::from(vec![None, None])), - expected_max: Arc::new(IntervalMonthDayNanoArray::from(vec![None, None])), - expected_null_counts: UInt64Array::from(vec![2, 2]), - expected_row_counts: UInt64Array::from(vec![13, 7]), - column_name: "month_day_nano", - } - .run(); -} - #[tokio::test] async fn test_uint() { // This creates a parquet files of 4 columns named "u8", "u16", "u32", "u64" diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 5ab268beb92f..9546ab30c9e0 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -18,9 +18,7 @@ //! 
Parquet integration tests use crate::parquet::utils::MetricsFinder; use arrow::array::Decimal128Array; -use arrow::datatypes::{ - i256, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, -}; +use arrow::datatypes::i256; use arrow::{ array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, @@ -36,10 +34,6 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; -use arrow_array::{ - IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, -}; -use arrow_schema::IntervalUnit; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{provider_as_source, TableProvider}, @@ -92,7 +86,6 @@ enum Scenario { Time32Millisecond, Time64Nanosecond, Time64Microsecond, - Interval, /// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64 /// -MIN, -100, -1, 0, 1, 100, MAX NumericLimits, @@ -921,71 +914,6 @@ fn make_dict_batch() -> RecordBatch { .unwrap() } -fn make_interval_batch(offset: i32) -> RecordBatch { - let schema = Schema::new(vec![ - Field::new( - "year_month", - DataType::Interval(IntervalUnit::YearMonth), - true, - ), - Field::new("day_time", DataType::Interval(IntervalUnit::DayTime), true), - Field::new( - "month_day_nano", - DataType::Interval(IntervalUnit::MonthDayNano), - true, - ), - ]); - let schema = Arc::new(schema); - - let ym_arr = IntervalYearMonthArray::from(vec![ - Some(IntervalYearMonthType::make_value(1 + offset, 10 + offset)), - Some(IntervalYearMonthType::make_value(2 + offset, 20 + offset)), - Some(IntervalYearMonthType::make_value(3 + offset, 30 + offset)), - None, - Some(IntervalYearMonthType::make_value(5 + offset, 50 + offset)), - ]); - - let dt_arr = IntervalDayTimeArray::from(vec![ - Some(IntervalDayTimeType::make_value(1 + offset, 10 + offset)), - Some(IntervalDayTimeType::make_value(2 + offset, 20 + offset)), - Some(IntervalDayTimeType::make_value(3 + offset, 30 + offset)), - None, - Some(IntervalDayTimeType::make_value(5 + offset, 50 + offset)), - ]); - - // Not yet implemented, refer to: - // https://github.com/apache/arrow-rs/blob/master/parquet/src/arrow/arrow_writer/mod.rs#L747 - let mdn_arr = IntervalMonthDayNanoArray::from(vec![ - Some(IntervalMonthDayNanoType::make_value( - 1 + offset, - 10 + offset, - 100 + (offset as i64), - )), - Some(IntervalMonthDayNanoType::make_value( - 2 + offset, - 20 + offset, - 200 + (offset as i64), - )), - Some(IntervalMonthDayNanoType::make_value( - 3 + offset, - 30 + offset, - 300 + (offset as i64), - )), - None, - Some(IntervalMonthDayNanoType::make_value( - 5 + offset, - 50 + offset, - 500 + (offset as i64), - )), - ]); - - RecordBatch::try_new( - schema, - vec![Arc::new(ym_arr), Arc::new(dt_arr), Arc::new(mdn_arr)], - ) - .unwrap() -} - fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Boolean => { @@ -1407,12 +1335,6 @@ fn create_data_batch(scenario: Scenario) -> Vec { ]), ] } - Scenario::Interval => vec![ - make_interval_batch(0), - make_interval_batch(1), - make_interval_batch(2), - make_interval_batch(3), - ], } } From dfdda7cb04f7f9b640da4f297ce1a16b08f3bf7b Mon Sep 17 00:00:00 2001 From: Arttu Date: Wed, 12 Jun 2024 18:40:40 +0200 Subject: [PATCH 05/34] fix: Ignore nullability of list elements when consuming Substrait (#10874) * Ignore nullability of list elements when consuming Substrait DataFusion (= Arrow) is quite strict about nullability, specifically, when using e.g. LogicalPlan::Values, the given schema must match the given literals exactly - including nullability. 
This is non-trivial to do when converting schema and literals separately. The existing implementation for from_substrait_literal already creates lists that are always nullable (see ScalarValue::new_list => array_into_list_array). This reverts part of https://github.com/apache/datafusion/pull/10640 to align from_substrait_type with that behavior. This is the error I was hitting: ``` ArrowError(InvalidArgumentError("column types must match schema types, expected List(Field { name: \"item\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }) but found List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) at column index 0"), None) ``` * use `Field::new_list_field` in `array_into_(large_)list_array` just for consistency, to reduce the places where "item" is written out * add a test for non-nullable lists --- datafusion/common/src/utils/mod.rs | 14 ++-- .../substrait/src/logical_plan/consumer.rs | 4 +- .../substrait/src/logical_plan/producer.rs | 14 ++-- .../substrait/tests/cases/logical_plans.rs | 32 +++++++-- .../non_nullable_lists.substrait.json | 71 +++++++++++++++++++ 5 files changed, 114 insertions(+), 21 deletions(-) create mode 100644 datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index ae444c2cb285..a0e4d1a76c03 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -354,7 +354,7 @@ pub fn longest_consecutive_prefix>( pub fn array_into_list_array(arr: ArrayRef) -> ListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); ListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), offsets, arr, None, @@ -366,7 +366,7 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray { pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); LargeListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), offsets, arr, None, @@ -379,7 +379,7 @@ pub fn array_into_fixed_size_list_array( ) -> FixedSizeListArray { let list_size = list_size as i32; FixedSizeListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), list_size, arr, None, @@ -420,7 +420,7 @@ pub fn arrays_into_list_array( let data_type = arr[0].data_type().to_owned(); let values = arr.iter().map(|x| x.as_ref()).collect::>(); Ok(ListArray::new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), OffsetBuffer::from_lengths(lens), arrow::compute::concat(values.as_slice())?, None, @@ -435,7 +435,7 @@ pub fn arrays_into_list_array( /// use datafusion_common::utils::base_type; /// use std::sync::Arc; /// -/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// assert_eq!(base_type(&data_type), DataType::Int32); /// /// let data_type = DataType::Int32; @@ -458,10 +458,10 @@ pub fn base_type(data_type: &DataType) -> DataType { /// use datafusion_common::utils::coerced_type_with_base_type_only; /// use std::sync::Arc; /// -/// let data_type = DataType::List(Arc::new(Field::new("item", 
DataType::Int32, true))); +/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// let base_type = DataType::Float64; /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); -/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 648a281832e1..3f9a895d951c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1395,7 +1395,9 @@ fn from_substrait_type( })?; let field = Arc::new(Field::new_list_field( from_substrait_type(inner_type, dfs_names, name_idx)?, - is_substrait_type_nullable(inner_type)?, + // We ignore Substrait's nullability here to match to_substrait_literal + // which always creates nullable lists + true, )); match list.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 88dc894eccd2..c0469d333164 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2309,14 +2309,12 @@ mod test { round_trip_type(DataType::Decimal128(10, 2))?; round_trip_type(DataType::Decimal256(30, 2))?; - for nullable in [true, false] { - round_trip_type(DataType::List( - Field::new_list_field(DataType::Int32, nullable).into(), - ))?; - round_trip_type(DataType::LargeList( - Field::new_list_field(DataType::Int32, nullable).into(), - ))?; - } + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; round_trip_type(DataType::Struct( vec![ diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs index 994a932c30e0..94572e098b2c 100644 --- a/datafusion/substrait/tests/cases/logical_plans.rs +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -20,6 +20,7 @@ #[cfg(test)] mod tests { use datafusion::common::Result; + use datafusion::dataframe::DataFrame; use datafusion::prelude::{CsvReadOptions, SessionContext}; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; use std::fs::File; @@ -38,11 +39,7 @@ mod tests { // File generated with substrait-java's Isthmus: // ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)" - let path = "tests/testdata/select_not_bool.substrait.json"; - let proto = serde_json::from_reader::<_, Plan>(BufReader::new( - File::open(path).expect("file not found"), - )) - .expect("failed to parse json"); + let proto = read_json("tests/testdata/select_not_bool.substrait.json"); let plan = from_substrait_plan(&ctx, &proto).await?; @@ -54,6 +51,31 @@ mod tests { Ok(()) } + #[tokio::test] + async fn non_nullable_lists() -> Result<()> { + // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. + // That's because implementing the non-nullability consistently is non-trivial. + // This test confirms that reading a plan with non-nullable lists works as expected. 
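A self-contained sketch of the failure mode this guards against, assuming only arrow-rs (the literal values are illustrative): a list built the way `array_into_list_array` builds it always carries a nullable item field, so a schema declaring the items non-nullable fails batch validation with exactly the "column types must match schema types" error quoted in the commit message above.

```rust
use std::sync::Arc;
use arrow::array::{ArrayRef, Int32Array, ListArray};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() {
    // Build a one-row list column the way `array_into_list_array` does:
    // the item field is always nullable.
    let values: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
    let list: ArrayRef = Arc::new(ListArray::new(
        Arc::new(Field::new_list_field(DataType::Int32, true)),
        OffsetBuffer::from_lengths([2]),
        values,
        None,
    ));

    // A schema that declares the list items non-nullable does not match,
    // so constructing the batch is rejected during validation.
    let schema = Arc::new(Schema::new(vec![Field::new(
        "col",
        DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false))),
        false,
    )]));
    assert!(RecordBatch::try_new(schema, vec![list]).is_err());
}
```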
+ let ctx = create_context().await?; + let proto = read_json("tests/testdata/non_nullable_lists.substrait.json"); + + let plan = from_substrait_plan(&ctx, &proto).await?; + + assert_eq!(format!("{:?}", &plan), "Values: (List([1, 2]))"); + + // Need to trigger execution to ensure that Arrow has validated the plan + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } + + fn read_json(path: &str) -> Plan { + serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json") + } + async fn create_context() -> datafusion::common::Result { let ctx = SessionContext::new(); ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new()) diff --git a/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json new file mode 100644 index 000000000000..e1c5574f8bec --- /dev/null +++ b/datafusion/substrait/tests/testdata/non_nullable_lists.substrait.json @@ -0,0 +1,71 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "col" + ], + "struct": { + "types": [ + { + "list": { + "type": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { + "fields": [ + { + "list": { + "values": [ + { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + }, + { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + ] + }, + "nullable": false, + "typeVariationReference": 0 + } + ] + } + ] + } + } + }, + "names": [ + "col" + ] + } + } + ], + "expectedTypeUrls": [] +} From 908a3a1d2feea1b1ae8c6220dcdb9e8264dd27ad Mon Sep 17 00:00:00 2001 From: Oleks V Date: Wed, 12 Jun 2024 14:46:50 -0700 Subject: [PATCH 06/34] Minor: SMJ fuzz tests fix for rowcounts (#10891) * Fix: Sort Merge Join crashes on TPCH Q21 * Fix LeftAnti SMJ join when the join filter is set * rm dbg * Minor: Fix fuzz testing row counts --- datafusion/core/tests/fuzz_cases/join_fuzz.rs | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 7dbbfb25bf78..a893e780581f 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -55,7 +55,7 @@ async fn test_inner_join_1k() { .await } -fn less_than_10_join_filter(schema1: Arc, _schema2: Arc) -> JoinFilter { +fn less_than_100_join_filter(schema1: Arc, _schema2: Arc) -> JoinFilter { let less_than_100 = Arc::new(BinaryExpr::new( Arc::new(Column::new("a", 0)), Operator::Lt, @@ -77,7 +77,7 @@ async fn test_inner_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Inner, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -113,7 +113,7 @@ async fn test_left_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Left, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -138,7 +138,7 @@ async fn test_right_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), 
JoinType::Right, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -162,7 +162,7 @@ async fn test_full_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Full, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -179,15 +179,14 @@ async fn test_semi_join_1k() { .run_test() .await } -// See https://github.com/apache/datafusion/issues/10886 -#[ignore] + #[tokio::test] async fn test_semi_join_1k_filtered() { JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftSemi, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -213,7 +212,7 @@ async fn test_anti_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftAnti, - Some(Box::new(less_than_10_join_filter)), + Some(Box::new(less_than_100_join_filter)), ) .run_test() .await @@ -392,6 +391,15 @@ impl JoinFuzzTestCase { let hj = self.hash_join(); let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + // Get actual row counts(without formatting overhead) for HJ and SMJ + let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); + let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); + + assert_eq!( + hj_rows, smj_rows, + "SortMergeJoinExec and HashJoinExec produced different row counts" + ); + let nlj = self.nested_loop_join(); let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); @@ -414,21 +422,20 @@ impl JoinFuzzTestCase { nlj_formatted.trim().lines().collect(); nlj_formatted_sorted.sort_unstable(); - assert_eq!( - smj_formatted_sorted.len(), - hj_formatted_sorted.len(), - "SortMergeJoinExec and HashJoinExec produced different row counts" - ); - for (i, (smj_line, hj_line)) in smj_formatted_sorted - .iter() - .zip(&hj_formatted_sorted) - .enumerate() - { - assert_eq!( - (i, smj_line), - (i, hj_line), - "SortMergeJoinExec and HashJoinExec produced different results" - ); + // row level compare if any of joins returns the result + // the reason is different formatting when there is no rows + if smj_rows > 0 || hj_rows > 0 { + for (i, (smj_line, hj_line)) in smj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!( + (i, smj_line), + (i, hj_line), + "SortMergeJoinExec and HashJoinExec produced different results" + ); + } } for (i, (nlj_line, hj_line)) in nlj_formatted_sorted From 8f718dd3ce291c9f5688144ca6c9d7d854dc4b0b Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Thu, 13 Jun 2024 07:54:39 +0800 Subject: [PATCH 07/34] Move `Count` to `functions-aggregate`, update MSRV to rust 1.75 (#10484) * mv accumulate indices Signed-off-by: jayzhan211 * complete udaf Signed-off-by: jayzhan211 * register Signed-off-by: jayzhan211 * fix expr Signed-off-by: jayzhan211 * filter distinct count Signed-off-by: jayzhan211 * todo: need to move count distinct too Signed-off-by: jayzhan211 * move code around Signed-off-by: jayzhan211 * move distinct to aggr-crate Signed-off-by: jayzhan211 * replace Signed-off-by: jayzhan211 * backup Signed-off-by: jayzhan211 * fix function name and physical expr Signed-off-by: jayzhan211 * fix physical optimizer Signed-off-by: jayzhan211 * fix all slt Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix with args Signed-off-by: jayzhan211 * add label Signed-off-by: jayzhan211 * revert builtin 
related code back Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * fix substrait Signed-off-by: jayzhan211 * fix doc Signed-off-by: jayzhan211 * fmy Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 * fix udaf macro for distinct but not apply Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix count distinct and use workspace Signed-off-by: jayzhan211 * add reverse Signed-off-by: jayzhan211 * remove old code Signed-off-by: jayzhan211 * backup Signed-off-by: jayzhan211 * use macro Signed-off-by: jayzhan211 * expr builder Signed-off-by: jayzhan211 * introduce expr builder Signed-off-by: jayzhan211 * add example Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * clean agg sta Signed-off-by: jayzhan211 * combine agg Signed-off-by: jayzhan211 * limit distinct and fmt Signed-off-by: jayzhan211 * cleanup name Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fix window Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix merged Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 * fix rebase Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * use std Signed-off-by: jayzhan211 * update mrsv Signed-off-by: jayzhan211 * upd msrv Signed-off-by: jayzhan211 * revert test Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * downgrade to 1.75 Signed-off-by: jayzhan211 * 1.76 Signed-off-by: jayzhan211 * ahas Signed-off-by: jayzhan211 * revert to 1.75 Signed-off-by: jayzhan211 * rm count Signed-off-by: jayzhan211 * fix merge Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 * rm sum in test_no_duplicate_name Signed-off-by: jayzhan211 * fix Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- Cargo.toml | 4 +- datafusion-cli/Cargo.lock | 2 + datafusion-cli/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/core/src/dataframe/mod.rs | 13 +- .../aggregate_statistics.rs | 79 +- .../combine_partial_final_agg.rs | 47 +- .../limited_distinct_aggregation.rs | 16 +- .../core/src/physical_optimizer/test_utils.rs | 5 +- datafusion/core/src/physical_planner.rs | 1 - .../provider_filter_pushdown.rs | 1 + datafusion/core/tests/dataframe/mod.rs | 11 +- .../core/tests/fuzz_cases/window_fuzz.rs | 5 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_fn.rs | 2 + datafusion/functions-aggregate/src/count.rs | 562 ++++++++++++++ datafusion/functions-aggregate/src/lib.rs | 8 +- datafusion/optimizer/src/decorrelate.rs | 10 +- .../src/single_distinct_to_groupby.rs | 3 +- datafusion/physical-expr-common/Cargo.toml | 2 + .../src/aggregate/count_distinct/bytes.rs | 6 +- .../src/aggregate/count_distinct/mod.rs | 23 + .../src/aggregate/count_distinct/native.rs | 23 +- .../physical-expr-common/src/aggregate/mod.rs | 1 + .../src/binary_map.rs | 21 +- datafusion/physical-expr-common/src/lib.rs | 1 + .../physical-expr/src/aggregate/build_in.rs | 92 +-- .../physical-expr/src/aggregate/count.rs | 348 --------- .../src/aggregate/count_distinct/mod.rs | 718 ------------------ .../src/aggregate/groups_accumulator/mod.rs | 2 +- datafusion/physical-expr/src/aggregate/mod.rs | 2 - .../physical-expr/src/expressions/mod.rs | 2 - datafusion/physical-expr/src/lib.rs | 4 +- .../src/aggregates/group_values/bytes.rs | 2 +- .../physical-plan/src/aggregates/mod.rs | 19 +- .../src/windows/bounded_window_agg_exec.rs | 7 +- datafusion/physical-plan/src/windows/mod.rs | 4 +- datafusion/proto-common/Cargo.toml | 2 +- 
datafusion/proto-common/gen/Cargo.toml | 2 +- datafusion/proto/Cargo.toml | 2 +- datafusion/proto/gen/Cargo.toml | 2 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 17 + datafusion/proto/src/generated/prost.rs | 2 + .../proto/src/logical_plan/from_proto.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 + .../proto/src/physical_plan/to_proto.rs | 15 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + .../tests/cases/roundtrip_physical_plan.rs | 33 +- datafusion/sqllogictest/test_files/errors.slt | 4 +- datafusion/substrait/Cargo.toml | 2 +- .../substrait/src/logical_plan/consumer.rs | 12 +- 52 files changed, 822 insertions(+), 1329 deletions(-) create mode 100644 datafusion/functions-aggregate/src/count.rs rename datafusion/{physical-expr => physical-expr-common}/src/aggregate/count_distinct/bytes.rs (93%) create mode 100644 datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs rename datafusion/{physical-expr => physical-expr-common}/src/aggregate/count_distinct/native.rs (93%) rename datafusion/{physical-expr => physical-expr-common}/src/binary_map.rs (98%) delete mode 100644 datafusion/physical-expr/src/aggregate/count.rs delete mode 100644 datafusion/physical-expr/src/aggregate/count_distinct/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 65ef191d7421..aa1ba1f214d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.73" +rust-version = "1.75" version = "39.0.0" [workspace.dependencies] @@ -107,7 +107,7 @@ doc-comment = "0.3" env_logger = "0.11" futures = "0.3" half = { version = "2.2.1", default-features = false } -hashbrown = { version = "0.14", features = ["raw"] } +hashbrown = { version = "0.14.5", features = ["raw"] } indexmap = "2.0.0" itertools = "0.12" log = "^0.4" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 932f44d98486..c5b34df4f1cf 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1376,9 +1376,11 @@ dependencies = [ name = "datafusion-physical-expr-common" version = "39.0.0" dependencies = [ + "ahash", "arrow", "datafusion-common", "datafusion-expr", + "hashbrown 0.14.5", "rand", ] diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 5e393246b958..8f4b3cd81f36 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -26,7 +26,7 @@ license = "Apache-2.0" homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" readme = "README.md" [dependencies] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 7533e2cff198..45617d88dc0c 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.73" +rust-version = "1.75" [lints] workspace = true diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 06a85d303687..950cb7ddb2d3 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -50,12 +50,11 @@ 
use datafusion_common::{ }; use datafusion_expr::lit; use datafusion_expr::{ - avg, count, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, + avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, }; use datafusion_expr::{case, is_null}; -use datafusion_functions_aggregate::expr_fn::sum; -use datafusion_functions_aggregate::expr_fn::{median, stddev}; +use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum}; use async_trait::async_trait; @@ -854,10 +853,7 @@ impl DataFrame { /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate( - vec![], - vec![datafusion_expr::count(Expr::Literal(COUNT_STAR_EXPANSION))], - )? + .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? .collect() .await?; let len = *rows @@ -1594,9 +1590,10 @@ mod tests { use datafusion_common::{Constraint, Constraints}; use datafusion_common_runtime::SpawnedTask; use datafusion_expr::{ - array_agg, cast, count_distinct, create_udf, expr, lit, BuiltInWindowFunction, + array_agg, cast, create_udf, expr, lit, BuiltInWindowFunction, ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::expr_fn::count_distinct; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 05f05d95b8db..eeacc48b85db 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -170,38 +170,6 @@ fn take_optimizable_column_and_table_count( } } } - // TODO: Remove this after revmoing Builtin Count - else if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( - &stats.num_rows, - agg_expr.as_any().downcast_ref::(), - ) { - // TODO implementing Eq on PhysicalExpr would help a lot here - if casted_expr.expressions().len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - casted_expr.name().to_string(), - )); - } - } else if let Some(lit_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - casted_expr.name().to_owned(), - )); - } - } - } - } None } @@ -307,13 +275,12 @@ fn take_optimizable_max( #[cfg(test)] pub(crate) mod tests { - use super::*; + use crate::logical_expr::Operator; use crate::physical_plan::aggregates::PhysicalGroupBy; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; - use crate::physical_plan::expressions::Count; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::memory::MemoryExec; use crate::prelude::SessionContext; @@ -322,8 +289,10 @@ pub(crate) mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_int64_array; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::cast; use datafusion_physical_expr::PhysicalExpr; + use 
datafusion_physical_expr_common::aggregate::create_aggregate_expr; use datafusion_physical_plan::aggregates::AggregateMode; /// Mock data using a MemoryExec which has an exact count statistic @@ -414,13 +383,19 @@ pub(crate) mod tests { Self::ColumnA(schema.clone()) } - /// Return appropriate expr depending if COUNT is for col or table (*) - pub(crate) fn count_expr(&self) -> Arc { - Arc::new(Count::new( - self.column(), + // Return appropriate expr depending if COUNT is for col or table (*) + pub(crate) fn count_expr(&self, schema: &Schema) -> Arc { + create_aggregate_expr( + &count_udaf(), + &[self.column()], + &[], + &[], + schema, self.column_name(), - DataType::Int64, - )) + false, + false, + ) + .unwrap() } /// what argument would this aggregate need in the plan? @@ -458,7 +433,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -467,7 +442,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -488,7 +463,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -497,7 +472,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -517,7 +492,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -529,7 +504,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -549,7 +524,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], source, Arc::clone(&schema), @@ -561,7 +536,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -592,7 +567,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], filter, Arc::clone(&schema), @@ -601,7 +576,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -637,7 +612,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], filter, Arc::clone(&schema), @@ -646,7 +621,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, 
PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![agg.count_expr(&schema)], vec![None], Arc::new(partial_agg), Arc::clone(&schema), diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 3ad61e52c82e..38b92959e841 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -206,8 +206,9 @@ mod tests { use crate::physical_plan::{displayable, Partitioning}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; - use datafusion_physical_expr::expressions::{col, Count}; + use datafusion_physical_expr::expressions::col; use datafusion_physical_plan::udaf::create_aggregate_expr; /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected @@ -303,15 +304,31 @@ mod tests { ) } + // Return appropriate expr depending if COUNT is for col or table (*) + fn count_expr( + expr: Arc, + name: &str, + schema: &Schema, + ) -> Arc { + create_aggregate_expr( + &count_udaf(), + &[expr], + &[], + &[], + schema, + name, + false, + false, + ) + .unwrap() + } + #[test] fn aggregations_not_combined() -> Result<()> { let schema = schema(); - let aggr_expr = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + let plan = final_aggregate_exec( repartition_exec(partial_aggregate_exec( parquet_exec(&schema), @@ -330,16 +347,8 @@ mod tests { ]; assert_optimized!(expected, plan); - let aggr_expr1 = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; - let aggr_expr2 = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(2)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr1 = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + let aggr_expr2 = vec![count_expr(lit(1i8), "COUNT(2)", &schema)]; let plan = final_aggregate_exec( partial_aggregate_exec( @@ -365,11 +374,7 @@ mod tests { #[test] fn aggregations_combined() -> Result<()> { let schema = schema(); - let aggr_expr = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; + let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; let plan = final_aggregate_exec( partial_aggregate_exec( diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 1274fbe50a5f..f9d5a4c186ee 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -517,10 +517,10 @@ mod tests { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr()], /* aggr_expr */ - vec![None], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![agg.count_expr(&schema)], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), @@ -554,10 +554,10 @@ mod tests { let single_agg = AggregateExec::try_new( AggregateMode::Single, build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr()], /* aggr_expr */ - 
vec![filter_expr], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ + vec![agg.count_expr(&schema)], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ )?; let limit_exec = LocalLimitExec::new( Arc::new(single_agg), diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 5895c39a5f87..154e77cd23ae 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -43,7 +43,8 @@ use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; +use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use datafusion_physical_plan::displayable; @@ -240,7 +241,7 @@ pub fn bounded_window_exec( Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 79033643cf37..4f9187595018 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2181,7 +2181,6 @@ impl DefaultPhysicalPlanner { expr: &[Expr], ) -> Result> { let input_schema = input.as_ref().schema(); - let physical_exprs = expr .iter() .map(|e| { diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 8c9cffcf08d1..068383b20031 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -35,6 +35,7 @@ use datafusion::scalar::ScalarValue; use datafusion_common::cast::as_primitive_array; use datafusion_common::{internal_err, not_impl_err}; use datafusion_expr::expr::{BinaryExpr, Cast}; +use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; use async_trait::async_trait; diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index befd98d04302..fa364c5f2a65 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -31,6 +31,7 @@ use arrow::{ }; use arrow_array::Float32Array; use arrow_schema::ArrowError; +use datafusion_functions_aggregate::count::count_udaf; use object_store::local::LocalFileSystem; use std::fs; use std::sync::Arc; @@ -51,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, when, wildcard, AggregateFunction, Expr, ExprSchemable, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + array_agg, avg, cast, col, exists, expr, 
in_subquery, lit, max, out_ref_col, + placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion_functions_aggregate::expr_fn::sum; +use datafusion_functions_aggregate::expr_fn::{count, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { @@ -178,7 +179,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index b85f6376c3f2..4358691ee5a5 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -38,6 +38,7 @@ use datafusion_expr::{ AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; @@ -165,7 +166,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), // its name "COUNT", // window function argument @@ -350,7 +351,7 @@ fn get_random_function( window_fn_map.insert( "count", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![arg.clone()], ), ); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 98ab8ec251f4..57f5414c13bd 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1861,6 +1861,7 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { null_treatment, }) => { write_function_name(w, &fun.to_string(), false, args)?; + if let Some(nt) = null_treatment { w.write_str(" ")?; write!(w, "{}", nt)?; @@ -1885,7 +1886,6 @@ fn write_name(w: &mut W, e: &Expr) -> Result<()> { null_treatment, }) => { write_function_name(w, func_def.name(), *distinct, args)?; - if let Some(fe) = filter { write!(w, " FILTER (WHERE {fe})")?; }; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 420312050870..1fafc63e9665 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -193,6 +193,7 @@ pub fn avg(expr: Expr) -> Expr { } /// Create an expression to represent the count() aggregate function +// TODO: Remove this and use `expr_fn::count` instead pub fn count(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, @@ -250,6 +251,7 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { } /// Create an expression to represent the count(distinct) aggregate function +// TODO: Remove this and use `expr_fn::count_distinct` instead pub fn count_distinct(expr: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs new file mode 100644 index 
000000000000..cfd56619537b --- /dev/null +++ b/datafusion/functions-aggregate/src/count.rs @@ -0,0 +1,562 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ahash::RandomState; +use std::collections::HashSet; +use std::ops::BitAnd; +use std::{fmt::Debug, sync::Arc}; + +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{ + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, +}; + +use arrow::{ + array::{Array, BooleanArray, Int64Array, PrimitiveArray}, + buffer::BooleanBuffer, +}; +use datafusion_common::{ + downcast_value, internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{ + function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, + EmitTo, GroupsAccumulator, Signature, Volatility, +}; +use datafusion_expr::{Expr, ReversedUDAF}; +use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::accumulate_indices; +use datafusion_physical_expr_common::{ + aggregate::count_distinct::{ + BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, + PrimitiveDistinctCountAccumulator, + }, + binary_map::OutputType, +}; + +make_udaf_expr_and_func!( + Count, + count, + expr, + "Count the number of non-null values in the column", + count_udaf +); + +pub fn count_distinct(expr: Expr) -> datafusion_expr::Expr { + datafusion_expr::Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new_udf( + count_udaf(), + vec![expr], + true, + None, + None, + None, + ), + ) +} + +pub struct Count { + signature: Signature, + aliases: Vec, +} + +impl Debug for Count { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Count") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + +impl Count { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Count { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "COUNT" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + Ok(vec![Field::new_list( 
+
+impl AggregateUDFImpl for Count {
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "COUNT"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Int64)
+    }
+
+    fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
+        if args.is_distinct {
+            Ok(vec![Field::new_list(
+                format_state_name(args.name, "count distinct"),
+                Field::new("item", args.input_type.clone(), true),
+                false,
+            )])
+        } else {
+            Ok(vec![Field::new(
+                format_state_name(args.name, "count"),
+                DataType::Int64,
+                true,
+            )])
+        }
+    }
+
+    fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
+        if !acc_args.is_distinct {
+            return Ok(Box::new(CountAccumulator::new()));
+        }
+
+        let data_type = acc_args.input_type;
+        Ok(match data_type {
+            // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator
+            DataType::Int8 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Int8Type>::new(data_type),
+            ),
+            DataType::Int16 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Int16Type>::new(data_type),
+            ),
+            DataType::Int32 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Int32Type>::new(data_type),
+            ),
+            DataType::Int64 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Int64Type>::new(data_type),
+            ),
+            DataType::UInt8 => Box::new(
+                PrimitiveDistinctCountAccumulator::<UInt8Type>::new(data_type),
+            ),
+            DataType::UInt16 => Box::new(
+                PrimitiveDistinctCountAccumulator::<UInt16Type>::new(data_type),
+            ),
+            DataType::UInt32 => Box::new(
+                PrimitiveDistinctCountAccumulator::<UInt32Type>::new(data_type),
+            ),
+            DataType::UInt64 => Box::new(
+                PrimitiveDistinctCountAccumulator::<UInt64Type>::new(data_type),
+            ),
+            DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
+                Decimal128Type,
+            >::new(data_type)),
+            DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::<
+                Decimal256Type,
+            >::new(data_type)),
+
+            DataType::Date32 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Date32Type>::new(data_type),
+            ),
+            DataType::Date64 => Box::new(
+                PrimitiveDistinctCountAccumulator::<Date64Type>::new(data_type),
+            ),
+            DataType::Time32(TimeUnit::Millisecond) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Time32MillisecondType>::new(
+                    data_type,
+                ),
+            ),
+            DataType::Time32(TimeUnit::Second) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Time32SecondType>::new(data_type),
+            ),
+            DataType::Time64(TimeUnit::Microsecond) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Time64MicrosecondType>::new(
+                    data_type,
+                ),
+            ),
+            DataType::Time64(TimeUnit::Nanosecond) => Box::new(
+                PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new(data_type),
+            ),
+            DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampMicrosecondType>::new(
+                    data_type,
+                ),
+            ),
+            DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampMillisecondType>::new(
+                    data_type,
+                ),
+            ),
+            DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>::new(
+                    data_type,
+                ),
+            ),
+            DataType::Timestamp(TimeUnit::Second, _) => Box::new(
+                PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new(data_type),
+            ),
+
+            DataType::Float16 => {
+                Box::new(FloatDistinctCountAccumulator::<Float16Type>::new())
+            }
+            DataType::Float32 => {
+                Box::new(FloatDistinctCountAccumulator::<Float32Type>::new())
+            }
+            DataType::Float64 => {
+                Box::new(FloatDistinctCountAccumulator::<Float64Type>::new())
+            }
+
+            DataType::Utf8 => {
+                Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
+            }
+            DataType::LargeUtf8 => {
+                Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
+            }
+            DataType::Binary => Box::new(BytesDistinctCountAccumulator::<i32>::new(
+                OutputType::Binary,
+            )),
+            DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::<i64>::new(
+                OutputType::Binary,
+            )),
+
+            // Use the generic accumulator based on `ScalarValue` for all other types
+            _ => Box::new(DistinctCountAccumulator {
+                values: HashSet::default(),
+                state_data_type: data_type.clone(),
+            }),
+        })
+    }
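+
+    // Dispatch sketch (illustrative): `COUNT(DISTINCT ts)` over a
+    // Timestamp(Nanosecond, _) column hashes native i64 values through
+    // PrimitiveDistinctCountAccumulator::<TimestampNanosecondType>, while a
+    // type with no specialization above (e.g. a Struct column) falls back to
+    // the ScalarValue-based DistinctCountAccumulator defined later in this file.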
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        // groups accumulator only supports `COUNT(c1)`, not
+        // `COUNT(c1, c2)`, etc
+        if args.is_distinct {
+            return false;
+        }
+        args.args_num == 1
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        _args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        // instantiate specialized accumulator
+        Ok(Box::new(CountGroupsAccumulator::new()))
+    }
+
+    fn reverse_expr(&self) -> ReversedUDAF {
+        ReversedUDAF::Identical
+    }
+}
+
+#[derive(Debug)]
+struct CountAccumulator {
+    count: i64,
+}
+
+impl CountAccumulator {
+    /// new count accumulator
+    pub fn new() -> Self {
+        Self { count: 0 }
+    }
+}
+
+impl Accumulator for CountAccumulator {
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::Int64(Some(self.count))])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array = &values[0];
+        self.count += (array.len() - null_count_for_multiple_cols(values)) as i64;
+        Ok(())
+    }
+
+    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let array = &values[0];
+        self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64;
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        let counts = downcast_value!(states[0], Int64Array);
+        let delta = &arrow::compute::sum(counts);
+        if let Some(d) = delta {
+            self.count += *d;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.count)))
+    }
+
+    fn supports_retract_batch(&self) -> bool {
+        true
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
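+
+// Sliding-window sketch (illustrative, not part of this patch): because
+// `CountAccumulator` supports retraction, a window frame moving from rows
+// [0, 3) to [1, 4) can be maintained incrementally instead of recomputed:
+//
+//   acc.update_batch(&[rows_0_to_3])?; // count non-nulls in [0, 3)
+//   acc.retract_batch(&[row_0])?;      // frame start advances
+//   acc.update_batch(&[row_3])?;       // frame end advances
+//
+// `rows_0_to_3`, `row_0` and `row_3` are hypothetical ArrayRef slices.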
+
+/// An accumulator to compute the counts of [`PrimitiveArray<T>`].
+/// Counts are stored as native `i64` values.
+///
+/// Unlike most other accumulators, COUNT never produces NULLs. If no
+/// non-null values are seen in any group the output is 0. Thus, this
+/// accumulator has no additional null or seen filter tracking.
+#[derive(Debug)]
+struct CountGroupsAccumulator {
+    /// Count per group.
+    ///
+    /// Note this is an i64 and not a u64 (or usize) because the
+    /// output type of count is `DataType::Int64`. Thus by using `i64`
+    /// for the counts, the output [`Int64Array`] can be created
+    /// without copy.
+    counts: Vec<i64>,
+}
+
+impl CountGroupsAccumulator {
+    pub fn new() -> Self {
+        Self { counts: vec![] }
+    }
+}
+
+impl GroupsAccumulator for CountGroupsAccumulator {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = &values[0];
+
+        // Add one to each group's counter for each non null, non
+        // filtered value
+        self.counts.resize(total_num_groups, 0);
+        accumulate_indices(
+            group_indices,
+            values.logical_nulls().as_ref(),
+            opt_filter,
+            |group_index| {
+                self.counts[group_index] += 1;
+            },
+        );
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "one argument to merge_batch");
+        // the single input array is the partial counts emitted by `Self::state`
+        let partial_counts = values[0].as_primitive::<Int64Type>();
+
+        // intermediate counts are always created as non null
+        assert_eq!(partial_counts.null_count(), 0);
+        let partial_counts = partial_counts.values();
+
+        // Adds the counts with the partial counts
+        self.counts.resize(total_num_groups, 0);
+        match opt_filter {
+            Some(filter) => filter
+                .iter()
+                .zip(group_indices.iter())
+                .zip(partial_counts.iter())
+                .for_each(|((filter_value, &group_index), partial_count)| {
+                    if let Some(true) = filter_value {
+                        self.counts[group_index] += partial_count;
+                    }
+                }),
+            None => group_indices.iter().zip(partial_counts.iter()).for_each(
+                |(&group_index, partial_count)| {
+                    self.counts[group_index] += partial_count;
+                },
+            ),
+        }
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let counts = emit_to.take_needed(&mut self.counts);
+
+        // Count is always non null (null inputs just don't contribute to the overall values)
+        let nulls = None;
+        let array = PrimitiveArray::<Int64Type>::new(counts.into(), nulls);
+
+        Ok(Arc::new(array))
+    }
+
+    // return arrays for counts
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        let counts = emit_to.take_needed(&mut self.counts);
+        let counts: PrimitiveArray<Int64Type> = Int64Array::from(counts); // zero copy, no nulls
+        Ok(vec![Arc::new(counts) as ArrayRef])
+    }
+
+    fn size(&self) -> usize {
+        self.counts.capacity() * std::mem::size_of::<usize>()
+    }
+}
+
+/// Counts the null values across multiple columns: a row contributes to the
+/// null count if any of its column values is null
+fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize {
+    if values.len() > 1 {
+        let result_bool_buf: Option<BooleanBuffer> = values
+            .iter()
+            .map(|a| a.logical_nulls())
+            .fold(None, |acc, b| match (acc, b) {
+                (Some(acc), Some(b)) => Some(acc.bitand(b.inner())),
+                (Some(acc), None) => Some(acc),
+                (None, Some(b)) => Some(b.into_inner()),
+                _ => None,
+            });
+        result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits())
+    } else {
+        values[0]
+            .logical_nulls()
+            .map_or(0, |nulls| nulls.null_count())
+    }
+}
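+
+// Worked example (illustrative) for the multi-column case above, e.g.
+// `COUNT(c1, c2)` over three rows:
+//
+//   c1: [1,    NULL, 3   ]
+//   c2: [10,   20,   NULL]
+//
+// rows 1 and 2 each contain a null, so `null_count_for_multiple_cols`
+// returns 2 and `CountAccumulator::update_batch` adds 3 - 2 = 1 to the count.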
+
+/// General purpose distinct accumulator that works for any DataType by using
+/// [`ScalarValue`].
+///
+/// It stores intermediate results as a `ListArray`
+///
+/// Note that many types have specialized accumulators that are (much)
+/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
+/// [`BytesDistinctCountAccumulator`]
+#[derive(Debug)]
+struct DistinctCountAccumulator {
+    values: HashSet<ScalarValue, RandomState>,
+    state_data_type: DataType,
+}
+
+impl DistinctCountAccumulator {
+    // Estimates the size for fixed-length values as the size of the first
+    // value times the number of values. This method is faster than
+    // `full_size()`, however it is not suitable for variable-length values
+    // like strings or complex types
+    fn fixed_size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+            + self
+                .values
+                .iter()
+                .next()
+                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
+                .unwrap_or(0)
+            + std::mem::size_of::<DataType>()
+    }
+
+    // calculates the size as accurately as possible. Note that calling this
+    // method is expensive
+    fn full_size(&self) -> usize {
+        std::mem::size_of_val(self)
+            + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+            + self
+                .values
+                .iter()
+                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
+                .sum::<usize>()
+            + std::mem::size_of::<DataType>()
+    }
+}
+
+impl Accumulator for DistinctCountAccumulator {
+    /// Returns the distinct values seen so far as (one element) ListArray.
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
+        let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
+        Ok(vec![ScalarValue::List(arr)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let arr = &values[0];
+        if arr.data_type() == &DataType::Null {
+            return Ok(());
+        }
+
+        (0..arr.len()).try_for_each(|index| {
+            if !arr.is_null(index) {
+                let scalar = ScalarValue::try_from_array(arr, index)?;
+                self.values.insert(scalar);
+            }
+            Ok(())
+        })
+    }
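+
+    // State round-trip sketch (illustrative): two partial accumulators
+    // holding {1, 2} and {2, 3} each emit a single-row ListArray from
+    // `state`; after `merge_batch` below has seen both rows the merged
+    // set is {1, 2, 3}, so `evaluate` returns Int64(3).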
+
+    /// Merges multiple sets of distinct values into the current set.
+    ///
+    /// The input to this function is a `ListArray` with **multiple** rows,
+    /// where each row contains the values from a partial aggregate's phase (e.g.
+    /// the result of calling `Self::state` on multiple accumulators).
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(states.len(), 1, "count_distinct states must be singleton!");
+        let array = &states[0];
+        let list_array = array.as_list::<i32>();
+        for inner_array in list_array.iter() {
+            let Some(inner_array) = inner_array else {
+                return internal_err!(
+                    "Intermediate results of COUNT DISTINCT should always be non null"
+                );
+            };
+            self.update_batch(&[inner_array])?;
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        match &self.state_data_type {
+            DataType::Boolean | DataType::Null => self.fixed_size(),
+            d if d.is_primitive() => self.fixed_size(),
+            _ => self.full_size(),
+        }
+    }
+}
diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs
index 2d062cf2cb9b..56fc1305bb59 100644
--- a/datafusion/functions-aggregate/src/lib.rs
+++ b/datafusion/functions-aggregate/src/lib.rs
@@ -56,6 +56,7 @@ pub mod macros;
 pub mod approx_distinct;
+pub mod count;
 pub mod covariance;
 pub mod first_last;
 pub mod hyperloglog;
@@ -77,6 +78,8 @@ use std::sync::Arc;
 pub mod expr_fn {
     pub use super::approx_distinct;
     pub use super::approx_median::approx_median;
+    pub use super::count::count;
+    pub use super::count::count_distinct;
     pub use super::covariance::covar_pop;
     pub use super::covariance::covar_samp;
     pub use super::first_last::first_value;
@@ -98,6 +101,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
         sum::sum_udaf(),
         covariance::covar_pop_udaf(),
         median::median_udaf(),
+        count::count_udaf(),
         variance::var_samp_udaf(),
         variance::var_pop_udaf(),
         stddev::stddev_udaf(),
@@ -133,8 +137,8 @@ mod tests {
         let mut names = HashSet::new();
         for func in all_default_aggregate_functions() {
             // TODO: remove this
-            // sum is in intermidiate migration state, skip this
-            if func.name().to_lowercase() == "sum" {
+            // These functions are in intermediate migration state, skip them
+            if func.name().to_lowercase() == "count" {
                 continue;
             }
             assert!(
diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs
index b55b1a7f8f2d..e14ee763a3c0 100644
--- a/datafusion/optimizer/src/decorrelate.rs
+++ b/datafusion/optimizer/src/decorrelate.rs
@@ -441,8 +441,14 @@ fn agg_exprs_evaluation_result_on_empty_batch(
                     Transformed::yes(Expr::Literal(ScalarValue::Null))
                 }
             }
-            AggregateFunctionDefinition::UDF { ..
} => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) + AggregateFunctionDefinition::UDF(fun) => { + if fun.name() == "COUNT" { + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::yes(Expr::Literal(ScalarValue::Null)) + } } }, _ => Transformed::no(expr), diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 32b6703bcae5..e738209eb4fd 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -361,8 +361,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr; - use datafusion_expr::expr::GroupingSet; + use datafusion_expr::expr::{self, GroupingSet}; use datafusion_expr::test::function_stub::{sum, sum_udaf}; use datafusion_expr::{ count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min, diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index 637b8775112e..3ef2d5345533 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -36,7 +36,9 @@ name = "datafusion_physical_expr_common" path = "src/lib.rs" [dependencies] +ahash = { workspace = true } arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +hashbrown = { workspace = true } rand = { workspace = true } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs similarity index 93% rename from datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs rename to datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs index 2ed9b002c841..5c888ca66caa 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs @@ -18,7 +18,7 @@ //! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values use crate::binary_map::{ArrowBytesSet, OutputType}; -use arrow_array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array; use datafusion_common::ScalarValue; @@ -35,10 +35,10 @@ use std::sync::Arc; /// [`BinaryArray`]: arrow::array::BinaryArray /// [`LargeBinaryArray`]: arrow::array::LargeBinaryArray #[derive(Debug)] -pub(super) struct BytesDistinctCountAccumulator(ArrowBytesSet); +pub struct BytesDistinctCountAccumulator(ArrowBytesSet); impl BytesDistinctCountAccumulator { - pub(super) fn new(output_type: OutputType) -> Self { + pub fn new(output_type: OutputType) -> Self { Self(ArrowBytesSet::new(output_type)) } } diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs new file mode 100644 index 000000000000..f216406d0dd7 --- /dev/null +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bytes; +mod native; + +pub use bytes::BytesDistinctCountAccumulator; +pub use native::FloatDistinctCountAccumulator; +pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs similarity index 93% rename from datafusion/physical-expr/src/aggregate/count_distinct/native.rs rename to datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs index 0e7483d4a1cd..72b83676e81d 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/native.rs @@ -26,10 +26,10 @@ use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; +use arrow::array::types::ArrowPrimitiveType; use arrow::array::ArrayRef; -use arrow_array::types::ArrowPrimitiveType; -use arrow_array::PrimitiveArray; -use arrow_schema::DataType; +use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; use datafusion_common::utils::array_into_list_array; @@ -40,7 +40,7 @@ use datafusion_expr::Accumulator; use crate::aggregate::utils::Hashable; #[derive(Debug)] -pub(super) struct PrimitiveDistinctCountAccumulator +pub struct PrimitiveDistinctCountAccumulator where T: ArrowPrimitiveType + Send, T::Native: Eq + Hash, @@ -54,7 +54,7 @@ where T: ArrowPrimitiveType + Send, T::Native: Eq + Hash, { - pub(super) fn new(data_type: &DataType) -> Self { + pub fn new(data_type: &DataType) -> Self { Self { values: HashSet::default(), data_type: data_type.clone(), @@ -125,7 +125,7 @@ where } #[derive(Debug)] -pub(super) struct FloatDistinctCountAccumulator +pub struct FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send, { @@ -136,13 +136,22 @@ impl FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send, { - pub(super) fn new() -> Self { + pub fn new() -> Self { Self { values: HashSet::default(), } } } +impl Default for FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + fn default() -> Self { + Self::new() + } +} + impl Accumulator for FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send + Debug, diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index ec02df57b82d..21884f840dbd 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+pub mod count_distinct; pub mod groups_accumulator; pub mod stats; pub mod tdigest; diff --git a/datafusion/physical-expr/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs similarity index 98% rename from datafusion/physical-expr/src/binary_map.rs rename to datafusion/physical-expr-common/src/binary_map.rs index 0923fcdaeb91..6d5ba737a1df 100644 --- a/datafusion/physical-expr/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -19,17 +19,16 @@ //! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray. use ahash::RandomState; -use arrow_array::cast::AsArray; -use arrow_array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; -use arrow_array::{ - Array, ArrayRef, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, +use arrow::array::cast::AsArray; +use arrow::array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; +use arrow::array::{ + Array, ArrayRef, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray, + GenericStringArray, OffsetSizeTrait, }; -use arrow_buffer::{ - BooleanBufferBuilder, BufferBuilder, NullBuffer, OffsetBuffer, ScalarBuffer, -}; -use arrow_schema::DataType; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::DataType; use datafusion_common::hash_utils::create_hashes; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; use std::mem; @@ -605,8 +604,8 @@ where #[cfg(test)] mod tests { use super::*; - use arrow_array::{BinaryArray, LargeBinaryArray, StringArray}; - use hashbrown::HashMap; + use arrow::array::{BinaryArray, LargeBinaryArray, StringArray}; + use std::collections::HashMap; #[test] fn string_set_empty() { diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index f335958698ab..0ddb84141a07 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -16,6 +16,7 @@ // under the License. pub mod aggregate; +pub mod binary_map; pub mod expressions; pub mod physical_expr; pub mod sort_expr; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index ac24dd2e7603..aee7bca3b88f 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -30,12 +30,13 @@ use std::sync::Arc; use arrow::datatypes::Schema; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; +use datafusion_expr::AggregateFunction; + use crate::aggregate::average::Avg; use crate::aggregate::regr::RegrType; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; -use datafusion_common::{exec_err, not_impl_err, Result}; -use datafusion_expr::AggregateFunction; /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. 
pub fn create_aggregate_expr( @@ -60,14 +61,9 @@ pub fn create_aggregate_expr( .collect::>>()?; let input_phy_exprs = input_phy_exprs.to_vec(); Ok(match (fun, distinct) { - (AggregateFunction::Count, false) => Arc::new( - expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, data_type), - ), - (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( - data_type, - input_phy_exprs[0].clone(), - name, - )), + (AggregateFunction::Count, _) => { + return internal_err!("Builtin Count will be removed"); + } (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, @@ -320,7 +316,7 @@ mod tests { use super::*; use crate::expressions::{ try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, - BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, Min, + BoolOr, DistinctArrayAgg, Max, Min, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -328,8 +324,8 @@ mod tests { use datafusion_expr::{type_coercion, Signature}; #[test] - fn test_count_arragg_approx_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Count, AggregateFunction::ArrayAgg]; + fn test_approx_expr() -> Result<()> { + let funcs = vec![AggregateFunction::ArrayAgg]; let data_types = vec![ DataType::UInt32, DataType::Int32, @@ -352,29 +348,18 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Count => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::ArrayAgg { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new_list( + "c1", + Field::new("item", data_type.clone(), true), + true, + ), + result_agg_phy_exprs.field().unwrap() + ); + } let result_distinct = create_physical_agg_expr_for_test( &fun, @@ -383,29 +368,18 @@ mod tests { &input_schema, "c1", )?; - match fun { - AggregateFunction::Count => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_distinct.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; + if fun == AggregateFunction::ArrayAgg { + assert!(result_distinct.as_any().is::()); + assert_eq!("c1", result_distinct.name()); + assert_eq!( + Field::new_list( + "c1", + Field::new("item", data_type.clone(), true), + true, + ), + result_agg_phy_exprs.field().unwrap() + ); + } } } Ok(()) diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs deleted file mode 100644 index aad18a82ab87..000000000000 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ /dev/null @@ -1,348 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::fmt::Debug; -use std::ops::BitAnd; -use std::sync::Arc; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, Int64Array}; -use arrow::compute; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_array::cast::AsArray; -use arrow_array::types::Int64Type; -use arrow_array::PrimitiveArray; -use arrow_buffer::BooleanBuffer; -use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; - -use crate::expressions::format_state_name; - -use super::groups_accumulator::accumulate::accumulate_indices; - -/// COUNT aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug, Clone)] -pub struct Count { - name: String, - data_type: DataType, - nullable: bool, - /// Input exprs - /// - /// For `COUNT(c1)` this is `[c1]` - /// For `COUNT(c1, c2)` this is `[c1, c2]` - exprs: Vec>, -} - -impl Count { - /// Create a new COUNT aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs: vec![expr], - data_type, - nullable: true, - } - } - - pub fn new_with_multiple_exprs( - exprs: Vec>, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs, - data_type, - nullable: true, - } - } -} - -/// An accumulator to compute the counts of [`PrimitiveArray`]. -/// Stores values as native types, and does overflow checking -/// -/// Unlike most other accumulators, COUNT never produces NULLs. If no -/// non-null values are seen in any group the output is 0. Thus, this -/// accumulator has no additional null or seen filter tracking. -#[derive(Debug)] -struct CountGroupsAccumulator { - /// Count per group. - /// - /// Note this is an i64 and not a u64 (or usize) because the - /// output type of count is `DataType::Int64`. Thus by using `i64` - /// for the counts, the output [`Int64Array`] can be created - /// without copy. 
- counts: Vec, -} - -impl CountGroupsAccumulator { - pub fn new() -> Self { - Self { counts: vec![] } - } -} - -impl GroupsAccumulator for CountGroupsAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &values[0]; - - // Add one to each group's counter for each non null, non - // filtered value - self.counts.resize(total_num_groups, 0); - accumulate_indices( - group_indices, - values.logical_nulls().as_ref(), - opt_filter, - |group_index| { - self.counts[group_index] += 1; - }, - ); - - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "one argument to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - - // intermediate counts are always created as non null - assert_eq!(partial_counts.null_count(), 0); - let partial_counts = partial_counts.values(); - - // Adds the counts with the partial counts - self.counts.resize(total_num_groups, 0); - match opt_filter { - Some(filter) => filter - .iter() - .zip(group_indices.iter()) - .zip(partial_counts.iter()) - .for_each(|((filter_value, &group_index), partial_count)| { - if let Some(true) = filter_value { - self.counts[group_index] += partial_count; - } - }), - None => group_indices.iter().zip(partial_counts.iter()).for_each( - |(&group_index, partial_count)| { - self.counts[group_index] += partial_count; - }, - ), - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - - // Count is always non null (null inputs just don't contribute to the overall values) - let nulls = None; - let array = PrimitiveArray::::new(counts.into(), nulls); - - Ok(Arc::new(array)) - } - - // return arrays for counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let counts = emit_to.take_needed(&mut self.counts); - let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls - Ok(vec![Arc::new(counts) as ArrayRef]) - } - - fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - } -} - -/// count null values for multiple columns -/// for each row if one column value is null, then null_count + 1 -fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { - if values.len() > 1 { - let result_bool_buf: Option = values - .iter() - .map(|a| a.logical_nulls()) - .fold(None, |acc, b| match (acc, b) { - (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), - (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.into_inner()), - _ => None, - }); - result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) - } else { - values[0] - .logical_nulls() - .map_or(0, |nulls| nulls.null_count()) - } -} - -impl AggregateExpr for Count { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, self.nullable)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "count"), - DataType::Int64, - true, - )]) - } - - fn expressions(&self) -> Vec> { - self.exprs.clone() - } - - fn create_accumulator(&self) -> Result> { - 
Ok(Box::new(CountAccumulator::new())) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - // groups accumulator only supports `COUNT(c1)`, not - // `COUNT(c1, c2)`, etc - self.exprs.len() == 1 - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(CountAccumulator::new())) - } - - fn create_groups_accumulator(&self) -> Result> { - // instantiate specialized accumulator - Ok(Box::new(CountGroupsAccumulator::new())) - } - - fn with_new_expressions( - &self, - args: Vec>, - order_by_exprs: Vec>, - ) -> Option> { - debug_assert_eq!(self.exprs.len(), args.len()); - debug_assert!(order_by_exprs.is_empty()); - Some(Arc::new(Count { - name: self.name.clone(), - data_type: self.data_type.clone(), - nullable: self.nullable, - exprs: args, - })) - } -} - -impl PartialEq for Count { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.exprs.len() == x.exprs.len() - && self - .exprs - .iter() - .zip(x.exprs.iter()) - .all(|(expr1, expr2)| expr1.eq(expr2)) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct CountAccumulator { - count: i64, -} - -impl CountAccumulator { - /// new count accumulator - pub fn new() -> Self { - Self { count: 0 } - } -} - -impl Accumulator for CountAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Int64(Some(self.count))]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], Int64Array); - let delta = &compute::sum(counts); - if let Some(d) = delta { - self.count += *d; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.count))) - } - - fn supports_retract_batch(&self) -> bool { - true - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs deleted file mode 100644 index 52f1c5c0f9a0..000000000000 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ /dev/null @@ -1,718 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -mod bytes; -mod native; - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use ahash::RandomState; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field, TimeUnit}; -use arrow_array::cast::AsArray; -use arrow_array::types::{ - Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; - -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; - -use crate::aggregate::count_distinct::bytes::BytesDistinctCountAccumulator; -use crate::aggregate::count_distinct::native::{ - FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, -}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::binary_map::OutputType; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -/// Expression for a `COUNT(DISTINCT)` aggregation. -#[derive(Debug)] -pub struct DistinctCount { - /// Column name - name: String, - /// The DataType used to hold the state for each input - state_data_type: DataType, - /// The input arguments - expr: Arc, -} - -impl DistinctCount { - /// Create a new COUNT(DISTINCT) aggregate function. - pub fn new( - input_data_type: DataType, - expr: Arc, - name: impl Into, - ) -> Self { - Self { - name: name.into(), - state_data_type: input_data_type, - expr, - } - } -} - -impl AggregateExpr for DistinctCount { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, true)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "count distinct"), - Field::new("item", self.state_data_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - use DataType::*; - use TimeUnit::*; - - let data_type = &self.state_data_type; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - Date32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Date64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32MillisecondType, - 
>::new(data_type)), - Time32(Second) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32SecondType, - >::new(data_type)), - Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64MicrosecondType, - >::new(data_type)), - Time64(Nanosecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64NanosecondType, - >::new(data_type)), - Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMicrosecondType, - >::new(data_type)), - Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMillisecondType, - >::new(data_type)), - Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampNanosecondType, - >::new(data_type)), - Timestamp(Second, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampSecondType, - >::new(data_type)), - - Float16 => Box::new(FloatDistinctCountAccumulator::::new()), - Float32 => Box::new(FloatDistinctCountAccumulator::::new()), - Float64 => Box::new(FloatDistinctCountAccumulator::::new()), - - Utf8 => Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)), - LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - Binary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - }), - }) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctCount { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.state_data_type == x.state_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. -/// -/// It stores intermediate results as a `ListArray` -/// -/// Note that many types have specialized accumulators that are (much) -/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and -/// [`BytesDistinctCountAccumulator`] -#[derive(Debug)] -struct DistinctCountAccumulator { - values: HashSet, - state_data_type: DataType, -} - -impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * - // number of batches This method is faster than .full_size(), however it is - // not suitable for variable length values like strings or complex types - fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .next() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .unwrap_or(0) - + std::mem::size_of::() - } - - // calculates the size as accurately as possible. Note that calling this - // method is expensive - fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .sum::() - + std::mem::size_of::() - } -} - -impl Accumulator for DistinctCountAccumulator { - /// Returns the distinct values seen so far as (one element) ListArray. 
- fn state(&mut self) -> Result> { - let scalars = self.values.iter().cloned().collect::>(); - let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); - Ok(vec![ScalarValue::List(arr)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let arr = &values[0]; - if arr.data_type() == &DataType::Null { - return Ok(()); - } - - (0..arr.len()).try_for_each(|index| { - if !arr.is_null(index) { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.insert(scalar); - } - Ok(()) - }) - } - - /// Merges multiple sets of distinct values into the current set. - /// - /// The input to this function is a `ListArray` with **multiple** rows, - /// where each row contains the values from a partial aggregate's phase (e.g. - /// the result of calling `Self::state` on multiple accumulators). - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert_eq!(states.len(), 1, "array_agg states must be singleton!"); - let array = &states[0]; - let list_array = array.as_list::(); - for inner_array in list_array.iter() { - let Some(inner_array) = inner_array else { - return internal_err!( - "Intermediate results of COUNT DISTINCT should always be non null" - ); - }; - self.update_batch(&[inner_array])?; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.values.len() as i64))) - } - - fn size(&self) -> usize { - match &self.state_data_type { - DataType::Boolean | DataType::Null => self.fixed_size(), - d if d.is_primitive() => self.fixed_size(), - _ => self.full_size(), - } - } -} - -#[cfg(test)] -mod tests { - use arrow::array::{ - BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }; - use arrow_array::Decimal256Array; - use arrow_buffer::i256; - - use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; - - use crate::expressions::NoOp; - - use super::*; - - macro_rules! state_to_vec_primitive { - ($LIST:expr, $DATA_TYPE:ident) => {{ - let arr = ScalarValue::raw_data($LIST).unwrap(); - let list_arr = as_list_array(&arr).unwrap(); - let arr = list_arr.values(); - let arr = as_primitive_array::<$DATA_TYPE>(arr)?; - arr.values().iter().cloned().collect::>() - }}; - } - - macro_rules! 
test_count_distinct_update_batch_numeric { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(1), - Some(1), - None, - Some(3), - Some(2), - None, - Some(2), - Some(3), - Some(1), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![1, 2, 3]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - fn state_to_vec_bool(sv: &ScalarValue) -> Result> { - let arr = ScalarValue::raw_data(sv)?; - let list_arr = as_list_array(&arr)?; - let arr = list_arr.values(); - let bool_arr = as_boolean_array(arr)?; - Ok(bool_arr.iter().flatten().collect()) - } - - fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - arrays[0].data_type().clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - accum.update_batch(arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - fn run_update( - data_types: &[DataType], - rows: &[Vec], - ) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - data_types[0].clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - - let cols = (0..rows[0].len()) - .map(|i| { - rows.iter() - .map(|inner| inner[i].clone()) - .collect::>() - }) - .collect::>(); - - let arrays: Vec = cols - .iter() - .map(|c| ScalarValue::iter_to_array(c.clone())) - .collect::>>()?; - - accum.update_batch(&arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - // Used trait to create associated constant for f32 and f64 - trait SubNormal: 'static { - const SUBNORMAL: Self; - } - - impl SubNormal for f64 { - const SUBNORMAL: Self = 1.0e-308_f64; - } - - impl SubNormal for f32 { - const SUBNORMAL: Self = 1.0e-38_f32; - } - - macro_rules! test_count_distinct_update_batch_floating_point { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(<$PRIM_TYPE>::INFINITY), - Some(<$PRIM_TYPE>::NAN), - Some(1.0), - Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), - Some(1.0), - Some(<$PRIM_TYPE>::INFINITY), - None, - Some(3.0), - Some(-4.5), - Some(2.0), - None, - Some(2.0), - Some(3.0), - Some(<$PRIM_TYPE>::NEG_INFINITY), - Some(1.0), - Some(<$PRIM_TYPE>::NAN), - Some(<$PRIM_TYPE>::NEG_INFINITY), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - - dbg!(&state_vec); - state_vec.sort_by(|a, b| match (a, b) { - (lhs, rhs) => lhs.total_cmp(rhs), - }); - - let nan_idx = state_vec.len() - 1; - assert_eq!(states.len(), 1); - assert_eq!( - &state_vec[..nan_idx], - vec![ - <$PRIM_TYPE>::NEG_INFINITY, - -4.5, - <$PRIM_TYPE as SubNormal>::SUBNORMAL, - 1.0, - 2.0, - 3.0, - <$PRIM_TYPE>::INFINITY - ] - ); - assert!(state_vec[nan_idx].is_nan()); - assert_eq!(result, ScalarValue::Int64(Some(8))); - - Ok(()) - }}; - } - - macro_rules! 
test_count_distinct_update_batch_bigint { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(i256::from(1)), - Some(i256::from(1)), - None, - Some(i256::from(3)), - Some(i256::from(2)), - None, - Some(i256::from(2)), - Some(i256::from(3)), - Some(i256::from(1)), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - #[test] - fn count_distinct_update_batch_i8() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) - } - - #[test] - fn count_distinct_update_batch_i16() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16) - } - - #[test] - fn count_distinct_update_batch_i32() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32) - } - - #[test] - fn count_distinct_update_batch_i64() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64) - } - - #[test] - fn count_distinct_update_batch_u8() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8) - } - - #[test] - fn count_distinct_update_batch_u16() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16) - } - - #[test] - fn count_distinct_update_batch_u32() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32) - } - - #[test] - fn count_distinct_update_batch_u64() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64) - } - - #[test] - fn count_distinct_update_batch_f32() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32) - } - - #[test] - fn count_distinct_update_batch_f64() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) - } - - #[test] - fn count_distinct_update_batch_i256() -> Result<()> { - test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) - } - - #[test] - fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { - let arrays = vec![Arc::new(data) as ArrayRef]; - let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = state_to_vec_bool(&states[0])?; - state_vec.sort(); - - let count = match result { - ScalarValue::Int64(c) => c.ok_or_else(|| { - DataFusionError::Internal("Found None count".to_string()) - }), - scalar => { - internal_err!("Found non int64 scalar value from count: {scalar}") - } - }?; - Ok((state_vec, count)) - }; - - let zero_count_values = BooleanArray::from(Vec::::new()); - - let one_count_values = BooleanArray::from(vec![false, false]); - let one_count_values_with_null = - BooleanArray::from(vec![Some(true), Some(true), None, None]); - - let two_count_values = BooleanArray::from(vec![true, false, true, false, true]); - let two_count_values_with_null = BooleanArray::from(vec![ - Some(true), - Some(false), - None, - None, - Some(true), - Some(false), - ]); - - assert_eq!(get_count(zero_count_values)?, (Vec::::new(), 0)); - assert_eq!(get_count(one_count_values)?, (vec![false], 1)); - assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1)); - 
assert_eq!(get_count(two_count_values)?, (vec![false, true], 2)); - assert_eq!( - get_count(two_count_values_with_null)?, - (vec![false, true], 2) - ); - Ok(()) - } - - #[test] - fn count_distinct_update_batch_all_nulls() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from( - vec![None, None, None, None] as Vec> - )) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update_batch_empty() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - Ok(()) - } - - #[test] - fn count_distinct_update_with_nulls() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(None)], - vec![ScalarValue::Int32(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. 
- vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(None)], - vec![ScalarValue::UInt64(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs index 65227b727be7..a6946e739c97 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs @@ -20,7 +20,7 @@ pub use adapter::GroupsAccumulatorAdapter; // Backward compatibility pub(crate) mod accumulate { - pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::{accumulate_indices, NullState}; + pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; } pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 7a6c5f9d0e24..01105c8559c9 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -26,8 +26,6 @@ pub(crate) mod average; pub(crate) mod bit_and_or_xor; pub(crate) mod bool_and_or; pub(crate) mod correlation; -pub(crate) mod count; -pub(crate) mod count_distinct; pub(crate) mod covariance; pub(crate) mod grouping; pub(crate) mod nth_value; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index a96d02173018..123ada6d7c86 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -47,8 +47,6 @@ pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr, BitXor, DistinctBitXor pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr}; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; -pub use crate::aggregate::count::Count; -pub use crate::aggregate::count_distinct::DistinctCount; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 72f5f2d50cb8..b764e81a95d1 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -17,7 +17,9 @@ pub mod aggregate; pub mod analysis; -pub mod binary_map; +pub mod binary_map { + pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; +} pub mod equivalence; pub mod expressions; pub mod functions; diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index d073c8995a9b..f789af8b8a02 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -18,7 +18,7 @@ use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; -use datafusion_physical_expr::binary_map::{ArrowBytesMap, OutputType}; +use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// diff --git a/datafusion/physical-plan/src/aggregates/mod.rs 
b/datafusion/physical-plan/src/aggregates/mod.rs index 79abbdb52ca2..b6fc70be7cbc 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1194,12 +1194,14 @@ mod tests { use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::expr::Sort; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::median::median_udaf; use datafusion_physical_expr::expressions::{ - lit, Count, FirstValue, LastValue, OrderSensitiveArrayAgg, + lit, FirstValue, LastValue, OrderSensitiveArrayAgg, }; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::aggregate::create_aggregate_expr; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1334,11 +1336,16 @@ mod tests { ], }; - let aggregates: Vec> = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - ))]; + let aggregates = vec![create_aggregate_expr( + &count_udaf(), + &[lit(1i8)], + &[], + &[], + &input_schema, + "COUNT(1)", + false, + false, + )?]; let task_ctx = if spill { new_spill_ctx(4, 1000) diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 48f1bee59bbf..56d780e51394 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1194,9 +1194,9 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; use datafusion_physical_expr::window::{ BuiltInWindowExpr, BuiltInWindowFunctionExpr, @@ -1298,8 +1298,7 @@ mod tests { order_by: &str, ) -> Result> { let schema = input.schema(); - let window_fn = - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count); + let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; let args = vec![col_expr]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 9b392d941ef4..63ce473fc57e 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -597,7 +597,6 @@ pub fn get_window_mode( #[cfg(test)] mod tests { use super::*; - use crate::aggregates::AggregateFunction; use crate::collect; use crate::expressions::col; use crate::streaming::StreamingTableExec; @@ -607,6 +606,7 @@ mod tests { use arrow::compute::SortOptions; use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; use futures::FutureExt; use InputOrderMode::{Linear, PartiallySorted, Sorted}; @@ -749,7 +749,7 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col("a", &schema)?], &[], diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml index 
97568fb5f678..66ce7cbd838f 100644 --- a/datafusion/proto-common/Cargo.toml +++ b/datafusion/proto-common/Cargo.toml @@ -26,7 +26,7 @@ homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } authors = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto-common/gen/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml index 49884c48b3cc..9f8f03de6dc9 100644 --- a/datafusion/proto-common/gen/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen-common" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index 358ba7e3eb94..b1897aa58e7d 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,7 +27,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index b6993f6c040b..eabaf7ba8e14 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.75" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index b401ff8810db..2bb3ec793d7f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -520,6 +520,7 @@ message AggregateExprNode { message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + bool distinct = 5; LogicalExprNode filter = 3; repeated LogicalExprNode order_by = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d6632c77d8da..59b7861a6ef1 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -886,6 +886,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.args.is_empty() { len += 1; } + if self.distinct { + len += 1; + } if self.filter.is_some() { len += 1; } @@ -899,6 +902,9 @@ impl serde::Serialize for AggregateUdfExprNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } if let Some(v) = self.filter.as_ref() { struct_ser.serialize_field("filter", v)?; } @@ -918,6 +924,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { "fun_name", "funName", "args", + "distinct", "filter", "order_by", "orderBy", @@ -927,6 +934,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { enum GeneratedField { FunName, Args, + Distinct, Filter, OrderBy, } @@ -952,6 +960,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { match value { "funName" | "fun_name" => Ok(GeneratedField::FunName), "args" => Ok(GeneratedField::Args), + "distinct" => 
Ok(GeneratedField::Distinct), "filter" => Ok(GeneratedField::Filter), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), @@ -975,6 +984,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { { let mut fun_name__ = None; let mut args__ = None; + let mut distinct__ = None; let mut filter__ = None; let mut order_by__ = None; while let Some(k) = map_.next_key()? { @@ -991,6 +1001,12 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { } args__ = Some(map_.next_value()?); } + GeneratedField::Distinct => { + if distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); + } + distinct__ = Some(map_.next_value()?); + } GeneratedField::Filter => { if filter__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); @@ -1008,6 +1024,7 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { Ok(AggregateUdfExprNode { fun_name: fun_name__.unwrap_or_default(), args: args__.unwrap_or_default(), + distinct: distinct__.unwrap_or_default(), filter: filter__, order_by: order_by__.unwrap_or_default(), }) diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0aca5ef1ffb8..0861c287fcfa 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -767,6 +767,8 @@ pub struct AggregateUdfExprNode { pub fun_name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "5")] + pub distinct: bool, #[prost(message, optional, boxed, tag = "3")] pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "4")] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3ad5973380ed..2ad40d883fe6 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -642,7 +642,7 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, parse_exprs(&pb.args, registry, codec)?, - false, + pb.distinct, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_vec_expr(&pb.order_by, registry, codec)?, None, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d42470f198e3..6a275ed7a1b8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -456,6 +456,7 @@ pub fn serialize_expr( protobuf::AggregateUdfExprNode { fun_name: fun.name().to_string(), args: serialize_exprs(args, codec)?, + distinct: *distinct, filter: match filter { Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), None => None, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 5258bdd11d86..e25447b023d8 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -25,10 +25,10 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, - Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, - NthValue, NthValueAgg, 
Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, - RowNumber, StringAgg, TryCastExpr, WindowShift, + CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, + IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, + OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, StringAgg, + TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -240,12 +240,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let aggr_expr = expr.as_any(); let mut distinct = false; - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Count - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::Count - } else if aggr_expr.downcast_ref::().is_some() { + let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Grouping } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::BitAnd diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 699697dd2f2c..d9736da69d42 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -649,6 +649,8 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), + count(lit(1)), + count_distinct(lit(1)), first_value(lit(1), None), first_value(lit(1), Some(vec![lit(2).sort(true, true)])), covar_samp(lit(1.5), lit(2.2)), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 9cf686dbd3d6..e517482f1db0 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -38,7 +38,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; -use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg}; +use datafusion::physical_expr::expressions::{Max, NthValueAgg}; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -47,8 +47,8 @@ use datafusion::physical_plan::aggregates::{ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ - binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, - NotExpr, NthValue, PhysicalSortExpr, StringAgg, + binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr, NthValue, + PhysicalSortExpr, StringAgg, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::DataSinkExec; @@ -806,7 +806,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))], + vec![Arc::new(Max::new(udf_expr, "max", DataType::Int64))], vec![None], window, schema.clone(), @@ -818,31 +818,6 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { Ok(()) } -#[test] -fn 
roundtrip_distinct_count() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - - let aggregates: Vec> = vec![Arc::new(DistinctCount::new( - DataType::Int64, - col("b", &schema)?, - "COUNT(DISTINCT b)".to_string(), - ))]; - - let groups: Vec<(Arc, String)> = - vec![(col("a", &schema)?, "unused".to_string())]; - - roundtrip_test(Arc::new(AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::new_single(groups), - aggregates.clone(), - vec![None], - Arc::new(EmptyExec::new(schema.clone())), - schema, - )?)) -} - #[test] fn roundtrip_like() -> Result<()> { let schema = Schema::new(vec![ diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index e930af107f77..c7b9808c249d 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -46,7 +46,7 @@ statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c SELECT CAST(c1 AS INT) FROM aggregate_test_100 # aggregation_with_bad_arguments -statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +query error SELECT COUNT(DISTINCT) FROM aggregate_test_100 # query_cte_incorrect @@ -104,7 +104,7 @@ SELECT power(1, 2, 3); # # AggregateFunction with wrong number of arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tCOUNT\(Any, \.\., Any\) +query error select count(); # AggregateFunction with wrong number of arguments diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index ee96ffa67044..d934dba4cfea 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,7 +26,7 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.75" [lints] workspace = true diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 3f9a895d951c..93f197885c0a 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -982,18 +982,16 @@ pub async fn from_substrait_agg_func( let function_name = substrait_fun_name((**function_name).as_str()); // try udaf first, then built-in aggr fn. 
if let Ok(fun) = ctx.udaf(function_name) {
+        // handle the case where count(*) has no arguments
+        if fun.name() == "COUNT" && args.is_empty() {
+            args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
+        }
+
         Ok(Arc::new(Expr::AggregateFunction(
             expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None),
         )))
     } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name)
     {
-        match &fun {
-            // deal with situation that count(*) got no arguments
-            aggregate_function::AggregateFunction::Count if args.is_empty() => {
-                args.push(Expr::Literal(ScalarValue::Int64(Some(1))));
-            }
-            _ => {}
-        }
         Ok(Arc::new(Expr::AggregateFunction(
             expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None),
         )))

From ea21b08e477cc48f458917c132f79c4980c957c1 Mon Sep 17 00:00:00 2001
From: Nga Tran
Date: Thu, 13 Jun 2024 07:38:37 -0400
Subject: [PATCH 08/34] refactor: fetch statistics for a given ParquetMetaData
 (#10880)

* refactor: fetch statistics for a given ParquetMetaData

* test: add tests for fetch_statistics_from_parquet_meta

* Rename function and improve docs

* Simplify the test

---------

Co-authored-by: Andrew Lamb
---
 .../src/datasource/file_format/parquet.rs     | 73 +++++++++++++++++++
 1 file changed, 73 insertions(+)
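Fetching the footer and converting it to statistics are now two separate steps, so a caller that already holds a `ParquetMetaData` (for example, one kept in an application-level metadata cache) can derive `Statistics` without a second object-store read. A minimal sketch of such a caller, relying only on the public signature added in the diff below; the helper function and the cache are illustrative assumptions, not part of the patch:

    use arrow_schema::SchemaRef;
    use datafusion::datasource::file_format::parquet::statistics_from_parquet_meta;
    use datafusion_common::stats::Precision;
    use datafusion_common::{Result, Statistics};
    use parquet::file::metadata::ParquetMetaData;

    // Hypothetical helper: derive table statistics from an already-fetched
    // (e.g. cached) ParquetMetaData instead of re-reading the file footer.
    async fn stats_from_cached_meta(
        cached: &ParquetMetaData,
        table_schema: SchemaRef,
    ) -> Result<Statistics> {
        let stats = statistics_from_parquet_meta(cached, table_schema).await?;
        // Row counts come straight from the row group metadata, so they are
        // reported as exact values.
        debug_assert!(matches!(stats.num_rows, Precision::Exact(_)));
        Ok(stats)
    }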
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 99c38d3f0980..572904254fd7 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -455,6 +455,8 @@ async fn fetch_schema(
 }

 /// Read and parse the statistics of the Parquet file at location `path`
+///
+/// See [`statistics_from_parquet_meta`] for more details
 async fn fetch_statistics(
     store: &dyn ObjectStore,
     table_schema: SchemaRef,
@@ -462,6 +464,17 @@ async fn fetch_statistics(
     metadata_size_hint: Option<usize>,
 ) -> Result<Statistics> {
     let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?;
+    statistics_from_parquet_meta(&metadata, table_schema).await
+}
+
+/// Convert statistics in [`ParquetMetaData`] into [`Statistics`]
+///
+/// The statistics are calculated for each column in the table schema
+/// using the row group statistics in the parquet metadata.
+pub async fn statistics_from_parquet_meta(
+    metadata: &ParquetMetaData,
+    table_schema: SchemaRef,
+) -> Result<Statistics> {
     let file_metadata = metadata.file_metadata();

     let file_schema = parquet_to_arrow_schema(
@@ -1402,6 +1415,66 @@ mod tests {
         Ok(())
     }

+    #[tokio::test]
+    async fn test_statistics_from_parquet_metadata() -> Result<()> {
+        // Data for column c1: ["Foo", null, "bar"]
+        let c1: ArrayRef =
+            Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")]));
+        let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap();
+
+        // Data for column c2: [1, 2, null]
+        let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None]));
+        let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap();
+
+        // Use store_parquet to write each batch to its own file
+        // . batch1 written into first file and includes:
+        //    - column c1 that has 3 rows with one null. Stats min and max of string column is missing for this test even though the column has values
+        // . batch2 written into second file and includes:
+        //    - column c2 that has 3 rows with one null. Stats min and max of int are available and 1 and 2 respectively
+        let store = Arc::new(LocalFileSystem::new()) as _;
+        let (files, _file_names) = store_parquet(vec![batch1, batch2], false).await?;
+
+        let state = SessionContext::new().state();
+        let format = ParquetFormat::default();
+        let schema = format.infer_schema(&state, &store, &files).await.unwrap();
+
+        let null_i64 = ScalarValue::Int64(None);
+        let null_utf8 = ScalarValue::Utf8(None);
+
+        // Fetch statistics for first file
+        let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?;
+        let stats = statistics_from_parquet_meta(&pq_meta, schema.clone()).await?;
+        //
+        assert_eq!(stats.num_rows, Precision::Exact(3));
+        // column c1
+        let c1_stats = &stats.column_statistics[0];
+        assert_eq!(c1_stats.null_count, Precision::Exact(1));
+        assert_eq!(c1_stats.max_value, Precision::Absent);
+        assert_eq!(c1_stats.min_value, Precision::Absent);
+        // column c2: missing from the file so the table treats all 3 rows as null
+        let c2_stats = &stats.column_statistics[1];
+        assert_eq!(c2_stats.null_count, Precision::Exact(3));
+        assert_eq!(c2_stats.max_value, Precision::Exact(null_i64.clone()));
+        assert_eq!(c2_stats.min_value, Precision::Exact(null_i64.clone()));
+
+        // Fetch statistics for second file
+        let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[1], None).await?;
+        let stats = statistics_from_parquet_meta(&pq_meta, schema.clone()).await?;
+        assert_eq!(stats.num_rows, Precision::Exact(3));
+        // column c1: missing from the file so the table treats all 3 rows as null
+        let c1_stats = &stats.column_statistics[0];
+        assert_eq!(c1_stats.null_count, Precision::Exact(3));
+        assert_eq!(c1_stats.max_value, Precision::Exact(null_utf8.clone()));
+        assert_eq!(c1_stats.min_value, Precision::Exact(null_utf8.clone()));
+        // column c2
+        let c2_stats = &stats.column_statistics[1];
+        assert_eq!(c2_stats.null_count, Precision::Exact(1));
+        assert_eq!(c2_stats.max_value, Precision::Exact(2i64.into()));
+        assert_eq!(c2_stats.min_value, Precision::Exact(1i64.into()));
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn read_small_batches() -> Result<()> {
         let config = SessionConfig::new().with_batch_size(2);

From 1aa205d06bb73937571b742178c74de29f4f7eba Mon Sep 17 00:00:00 2001
From: Georgi Krastev
Date: Thu, 13 Jun 2024 16:59:18 +0300
Subject: [PATCH 09/34] Move FileSinkExec::metrics to the correct place (#239)
 (#10901)

---
 datafusion/physical-plan/src/insert.rs | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs
index fa30141a1934..30c3353d4b71 100644
--- a/datafusion/physical-plan/src/insert.rs
+++ b/datafusion/physical-plan/src/insert.rs
@@ -175,11 +175,6 @@ impl DataSinkExec {
         &self.sort_order
     }

-    /// Returns the metrics of the underlying [DataSink]
-    pub fn metrics(&self) -> Option<MetricsSet> {
-        self.sink.metrics()
-    }
-
     fn create_schema(
         input: &Arc<dyn ExecutionPlan>,
         schema: SchemaRef,
@@ -289,6 +284,11 @@ impl ExecutionPlan for DataSinkExec {
             stream,
         )))
     }
+
+    /// Returns the metrics of the underlying [DataSink]
+    fn metrics(&self) -> Option<MetricsSet> {
+        self.sink.metrics()
+    }
 }

 /// Create a output record batch with a count

From 1fc5f915b9d84b3a77047f5655f0ef3c5b188b1a Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Thu, 13 Jun 2024 10:37:56 -0400
Subject: [PATCH 10/34] Refine ParquetAccessPlan comments and tests (#10896)

---
 .../src/datasource/physical_plan/parquet/access_plan.rs  | 6 +++---
 datafusion/core/src/datasource/physical_plan/parquet/mod.rs | 5 ++---
.../core/src/datasource/physical_plan/parquet/opener.rs | 2 ++ 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs index f51f2c49e896..e15e907cd9b8 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs @@ -384,7 +384,7 @@ mod test { let access_plan = ParquetAccessPlan::new(vec![ RowGroupAccess::Scan, RowGroupAccess::Selection( - // select / skip all 20 rows in row group 1 + // specifies all 20 rows in row group 1 vec![ RowSelector::select(5), RowSelector::skip(7), @@ -463,7 +463,7 @@ mod test { fn test_invalid_too_few() { let access_plan = ParquetAccessPlan::new(vec![ RowGroupAccess::Scan, - // select 12 rows, but row group 1 has 20 + // specify only 12 rows in selection, but row group 1 has 20 RowGroupAccess::Selection( vec![RowSelector::select(5), RowSelector::skip(7)].into(), ), @@ -484,7 +484,7 @@ mod test { fn test_invalid_too_many() { let access_plan = ParquetAccessPlan::new(vec![ RowGroupAccess::Scan, - // select 22 rows, but row group 1 has only 20 + // specify 22 rows in selection, but row group 1 has only 20 RowGroupAccess::Selection( vec![ RowSelector::select(10), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 5e5cc93bc54f..ec21c5504c69 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -156,9 +156,8 @@ pub use writer::plan_to_parquet; /// used to implement external indexes on top of parquet files and select only /// portions of the files. /// -/// The `ParquetExec` will try and further reduce any provided -/// `ParquetAccessPlan` further based on the contents of `ParquetMetadata` and -/// other settings. +/// The `ParquetExec` will try and reduce any provided `ParquetAccessPlan` +/// further based on the contents of `ParquetMetadata` and other settings. 
/// /// ## Example of providing a ParquetAccessPlan /// diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs index 8557c6d5f950..36335863032c 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -238,6 +238,8 @@ fn create_initial_plan( // check row group count matches the plan return Ok(access_plan.clone()); + } else { + debug!("ParquetExec Ignoring unknown extension specified for {file_name}"); } } From 2d2685914d70e3be542d957c927a65459d901906 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Thu, 13 Jun 2024 23:55:51 +0800 Subject: [PATCH 11/34] ci: fix clippy failures on main (#10903) * ci: fix clippy * retry ci --- datafusion/expr/src/utils.rs | 2 +- datafusion/physical-expr/src/window/nth_value.rs | 1 - datafusion/physical-expr/src/window/window_expr.rs | 1 - datafusion/physical-plan/src/joins/hash_join.rs | 2 +- datafusion/physical-plan/src/joins/stream_join_utils.rs | 1 - datafusion/physical-plan/src/joins/symmetric_hash_join.rs | 2 +- datafusion/physical-plan/src/joins/test_utils.rs | 1 - datafusion/physical-plan/src/joins/utils.rs | 1 - 8 files changed, 3 insertions(+), 8 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 71a3a5fe7309..3ab0c180dcba 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -818,7 +818,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( } } Expr::Literal(_) => { - indexes.push(std::usize::MAX); + indexes.push(usize::MAX); } _ => {} } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 55d112e1f6e0..4bd40066ff34 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -125,7 +125,6 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_evaluator(&self) -> Result> { let state = NthValueState { - range: Default::default(), finalized_result: None, kind: self.kind, }; diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 065371d9e43e..3cf68379d72b 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -559,7 +559,6 @@ pub enum NthValueKind { #[derive(Debug, Clone)] pub struct NthValueState { - pub range: Range, // In certain cases, we can finalize the result early. 
Consider this usage: // ``` // FIRST_VALUE(increasing_col) OVER window AS my_first_value diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 784584f03f0f..cd66ab093f88 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -21,7 +21,7 @@ use std::fmt; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use std::{any::Any, usize, vec}; +use std::{any::Any, vec}; use super::{ utils::{OnceAsync, OnceFut}, diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 0a01d84141e7..46d3ac5acf1e 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -20,7 +20,6 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Arc; -use std::usize; use crate::joins::utils::{JoinFilter, JoinHashMapType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 7b4d790479b1..e11e6dd2f627 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -29,7 +29,7 @@ use std::any::Any; use std::fmt::{self, Debug}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::{usize, vec}; +use std::vec; use crate::common::SharedMemoryReservation; use crate::handle_state; diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 9598ed83aa58..7e05ded6f69d 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -18,7 +18,6 @@ //! 
This file has test utils for hash joins

 use std::sync::Arc;
-use std::usize;

 use crate::joins::utils::{JoinFilter, JoinOn};
 use crate::joins::{
diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs
index 0d99d7a16356..c08b0e3d091c 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -23,7 +23,6 @@ use std::future::Future;
 use std::ops::{IndexMut, Range};
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use std::usize;

 use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder};
 use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics};

From b7d2aea1dd4bb4a3abe3163dae936d7bfa5b32c9 Mon Sep 17 00:00:00 2001
From: Oleks V
Date: Thu, 13 Jun 2024 09:49:11 -0700
Subject: [PATCH 12/34] Minor: disable flaky fuzz test (#10904)

* Minor: disable fuzz test to avoid spontaneous CI failures

---
 datafusion/core/tests/fuzz_cases/join_fuzz.rs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index a893e780581f..516749e82a53 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -180,6 +180,9 @@ async fn test_semi_join_1k() {
         .await
 }

+// The test is flaky
+// https://github.com/apache/datafusion/issues/10886
+#[ignore]
 #[tokio::test]
 async fn test_semi_join_1k_filtered() {
     JoinFuzzTestCase::new(

From b627ca3e78d35cd12a850a7ef181fd8862dbf50f Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Fri, 14 Jun 2024 07:23:23 +0800
Subject: [PATCH 13/34] Remove builtin count (#10893)

* rm expr fn

Signed-off-by: jayzhan211

* rm function

Signed-off-by: jayzhan211

* fix query and fmt

Signed-off-by: jayzhan211

* fix example

Signed-off-by: jayzhan211

* Update datafusion/expr/src/test/function_stub.rs

Co-authored-by: Andrew Lamb

---------

Signed-off-by: jayzhan211
Co-authored-by: Andrew Lamb
---
 datafusion/expr/src/aggregate_function.rs     |  6 --
 datafusion/expr/src/expr.rs                   | 13 ---
 datafusion/expr/src/expr_fn.rs                | 26 ------
 datafusion/expr/src/logical_plan/plan.rs      |  4 +-
 datafusion/expr/src/test/function_stub.rs     | 86 ++++++++++++++++++-
 .../expr/src/type_coercion/aggregates.rs      |  2 -
 datafusion/optimizer/Cargo.toml               |  1 +
 .../src/analyzer/count_wildcard_rule.rs       | 42 +++------
 datafusion/optimizer/src/decorrelate.rs       | 10 +--
 .../src/eliminate_group_by_constant.rs        |  6 +-
 .../optimizer/src/optimize_projections/mod.rs | 20 ++---
 .../src/single_distinct_to_groupby.rs         | 52 +++++------
 .../optimizer/tests/optimizer_integration.rs  |  8 +-
 .../physical-expr/src/aggregate/build_in.rs   | 19 +---
 datafusion/proto/Cargo.toml                   |  1 +
 datafusion/proto/proto/datafusion.proto       |  2 +-
 datafusion/proto/src/generated/pbjson.rs      |  3 -
 datafusion/proto/src/generated/prost.rs       |  4 +-
 .../proto/src/logical_plan/from_proto.rs      |  1 -
 datafusion/proto/src/logical_plan/to_proto.rs |  2 -
 .../tests/cases/roundtrip_logical_plan.rs     | 35 +++-----
 datafusion/sql/examples/sql.rs                | 10 ++-
 datafusion/sql/src/unparser/expr.rs           | 45 ++++------
 datafusion/sql/src/utils.rs                   |  3 +-
 datafusion/sql/tests/cases/plan_to_sql.rs     |  6 +-
 datafusion/sql/tests/common/mod.rs            |  3 +-
 datafusion/sql/tests/sql_integration.rs       |  7 +-
 .../sqllogictest/test_files/functions.slt     |  2 +-
 28 files changed, 200 insertions(+), 219 deletions(-)
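The migration pattern repeated across the diffs below: plain counts move to the `count`/`count_distinct` helpers in `datafusion-functions-aggregate`, and anything that needs DISTINCT, FILTER, or ORDER BY goes through the `AggregateExt` builder on the count UDAF. A compact sketch of the replacement API, assembled from the changed call sites; it is illustrative and not itself part of the diff:

    use datafusion_expr::{col, lit, AggregateExt, Expr};
    use datafusion_functions_aggregate::count::count_udaf;
    use datafusion_functions_aggregate::expr_fn::{count, count_distinct};

    fn migrated_count_exprs() -> datafusion_common::Result<Vec<Expr>> {
        Ok(vec![
            // COUNT(a), formerly built through the AggregateFunction::Count enum
            count(col("a")),
            // COUNT(DISTINCT a)
            count_distinct(col("a")),
            // COUNT(DISTINCT a) FILTER (WHERE a > 5): options beyond the plain
            // helpers are set through the builder methods on the UDAF call
            count_udaf()
                .call(vec![col("a")])
                .distinct()
                .filter(col("a").gt(lit(5)))
                .build()?,
        ])
    }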
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index e3d2e6555d5c..5899cc927703 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -33,8 +33,6 @@ use strum_macros::EnumIter;
 // https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
 pub enum AggregateFunction {
-    /// Count
-    Count,
     /// Minimum
     Min,
     /// Maximum
@@ -89,7 +87,6 @@ impl AggregateFunction {
     pub fn name(&self) -> &str {
         use AggregateFunction::*;
         match self {
-            Count => "COUNT",
             Min => "MIN",
             Max => "MAX",
             Avg => "AVG",
@@ -135,7 +132,6 @@ impl FromStr for AggregateFunction {
             "bit_xor" => AggregateFunction::BitXor,
             "bool_and" => AggregateFunction::BoolAnd,
             "bool_or" => AggregateFunction::BoolOr,
-            "count" => AggregateFunction::Count,
             "max" => AggregateFunction::Max,
             "mean" => AggregateFunction::Avg,
             "min" => AggregateFunction::Min,
@@ -190,7 +186,6 @@ impl AggregateFunction {
             })?;

         match self {
-            AggregateFunction::Count => Ok(DataType::Int64),
             AggregateFunction::Max | AggregateFunction::Min => {
                 // For min and max agg function, the returned type is same as input type.
                 // The coerced_data_types is same with input_types.
@@ -249,7 +244,6 @@ impl AggregateFunction {
     pub fn signature(&self) -> Signature {
         // note: the physical expression must accept the type returned by this function or the execution panics.
         match self {
-            AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable),
             AggregateFunction::Grouping | AggregateFunction::ArrayAgg => {
                 Signature::any(1, Volatility::Immutable)
             }
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 57f5414c13bd..9ba866a4c919 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -2135,18 +2135,6 @@ mod test {

     use super::*;

-    #[test]
-    fn test_count_return_type() -> Result<()> {
-        let fun = find_df_window_func("count").unwrap();
-        let observed = fun.return_type(&[DataType::Utf8])?;
-        assert_eq!(DataType::Int64, observed);
-
-        let observed = fun.return_type(&[DataType::UInt64])?;
-        assert_eq!(DataType::Int64, observed);
-
-        Ok(())
-    }
-
     #[test]
     fn test_first_value_return_type() -> Result<()> {
         let fun = find_df_window_func("first_value").unwrap();
@@ -2250,7 +2238,6 @@ mod test {
             "nth_value",
             "min",
             "max",
-            "count",
             "avg",
         ];
         for name in names {
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index 1fafc63e9665..fb5b3991ecd8 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -192,19 +192,6 @@ pub fn avg(expr: Expr) -> Expr {
     ))
 }

-/// Create an expression to represent the count() aggregate function
-// TODO: Remove this and use `expr_fn::count` instead
-pub fn count(expr: Expr) -> Expr {
-    Expr::AggregateFunction(AggregateFunction::new(
-        aggregate_function::AggregateFunction::Count,
-        vec![expr],
-        false,
-        None,
-        None,
-        None,
-    ))
-}
-
 /// Return a new expression with bitwise AND
 pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
     Expr::BinaryExpr(BinaryExpr::new(
@@ -250,19 +237,6 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr {
     ))
 }

-/// Create an expression to represent the count(distinct) aggregate function
-// TODO: Remove this and use `expr_fn::count_distinct` instead
-pub fn count_distinct(expr: Expr) -> Expr {
-    Expr::AggregateFunction(AggregateFunction::new(
-        aggregate_function::AggregateFunction::Count,
-        vec![expr],
-        true,
-        None,
-        None,
-        None,
-    ))
-}
-
 /// Create an in_list expression
 pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
     Expr::InList(InList::new(Box::new(expr), list, negated))
 }

diff
--git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9ea2abe64ede..02378ab3fc1b 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2965,11 +2965,13 @@ mod tests { use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; + use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet}; use datafusion_common::tree_node::TreeNodeVisitor; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; + use crate::test::function_stub::count; + fn employee_schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs index b9aa1e636d94..ac98ee9747cc 100644 --- a/datafusion/expr/src/test/function_stub.rs +++ b/datafusion/expr/src/test/function_stub.rs @@ -31,7 +31,7 @@ use crate::{ use arrow::datatypes::{ DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{exec_err, not_impl_err, Result}; macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { @@ -69,6 +69,19 @@ pub fn sum(expr: Expr) -> Expr { )) } +create_func!(Count, count_udaf); + +pub fn count(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + count_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + /// Stub `sum` used for optimizer testing #[derive(Debug)] pub struct Sum { @@ -189,3 +202,74 @@ impl AggregateUDFImpl for Sum { AggregateOrderSensitivity::Insensitive } } + +/// Testing stub implementation of COUNT aggregate +pub struct Count { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Count { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Count") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + +impl Count { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Count { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "COUNT" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index ab7deaff9885..2c76407cdfe2 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -96,7 +96,6 @@ pub fn coerce_types( check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; match agg_fun { - AggregateFunction::Count => Ok(input_types.to_vec()), AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), 
AggregateFunction::Min | AggregateFunction::Max => { // min and max support the dictionary data type @@ -525,7 +524,6 @@ mod tests { // test count, array_agg, approx_distinct, min, max. // the coerced types is same with input types let funs = vec![ - AggregateFunction::Count, AggregateFunction::ArrayAgg, AggregateFunction::Min, AggregateFunction::Max, diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index cb14f6bdd4a3..1a9e9630c076 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -56,5 +56,6 @@ regex-syntax = "0.8.0" [dev-dependencies] arrow-buffer = { workspace = true } ctor = { workspace = true } +datafusion-functions-aggregate = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index af1c99c52390..de2af520053a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -25,9 +25,7 @@ use datafusion_expr::expr::{ AggregateFunction, AggregateFunctionDefinition, WindowFunction, }; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{ - aggregate_function, lit, Expr, LogicalPlan, WindowFunctionDefinition, -}; +use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// @@ -56,37 +54,19 @@ fn is_wildcard(expr: &Expr) -> bool { } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - match aggregate_function { + matches!(aggregate_function, AggregateFunction { func_def: AggregateFunctionDefinition::UDF(udf), args, .. - } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => true, - AggregateFunction { - func_def: - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ), - args, - .. 
- } if args.len() == 1 && is_wildcard(&args[0]) => true, - _ => false, - } + } if udf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { let args = &window_function.args; - match window_function.fun { - WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Count, - ) if args.len() == 1 && is_wildcard(&args[0]) => true, + matches!(window_function.fun, WindowFunctionDefinition::AggregateUDF(ref udaf) - if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0]) => - { - true - } - _ => false, - } + if udaf.name() == "COUNT" && args.len() == 1 && is_wildcard(&args[0])) } fn analyze_internal(plan: LogicalPlan) -> Result> { @@ -121,14 +101,16 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; - use datafusion_expr::test::function_stub::sum; use datafusion_expr::{ - col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, wildcard, AggregateFunction, WindowFrame, - WindowFrameBound, WindowFrameUnits, + col, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, + out_ref_col, scalar_subquery, wildcard, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; + use datafusion_functions_aggregate::count::count_udaf; use std::sync::Arc; + use datafusion_functions_aggregate::expr_fn::{count, sum}; + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_analyzed_plan_eq_display_indent( Arc::new(CountWildcardRule::new()), @@ -239,7 +221,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], vec![], vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index e14ee763a3c0..e949e1921b97 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -432,14 +432,8 @@ fn agg_exprs_evaluation_result_on_empty_batch( Expr::AggregateFunction(expr::AggregateFunction { func_def, .. 
}) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( - 0, - )))) - } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } + AggregateFunctionDefinition::BuiltIn(_fun) => { + Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::UDF(fun) => { if fun.name() == "COUNT" { diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index cef226d67b6c..7a8dd7aac249 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -129,10 +129,12 @@ mod tests { use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ - col, count, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, - Signature, TypeSignature, + col, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, + TypeSignature, }; + use datafusion_functions_aggregate::expr_fn::count; + use std::sync::Arc; #[derive(Debug)] diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index af51814c9686..11540d3e162e 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -818,10 +818,11 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; + use datafusion_expr::AggregateExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, - col, count, + col, expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, @@ -830,6 +831,9 @@ mod tests { WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::count; + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -1886,16 +1890,10 @@ mod tests { #[test] fn aggregate_filter_pushdown() -> Result<()> { let table_scan = test_table_scan()?; - - let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("b")], - false, - Some(Box::new(col("c").gt(lit(42)))), - None, - None, - )); - + let aggr_with_filter = count_udaf() + .call(vec![col("b")]) + .filter(col("c").gt(lit(42))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index e738209eb4fd..d3d22eb53f39 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -362,11 +362,13 @@ mod tests { use super::*; use crate::test::*; use datafusion_expr::expr::{self, GroupingSet}; - use datafusion_expr::test::function_stub::{sum, sum_udaf}; + use datafusion_expr::AggregateExt; use datafusion_expr::{ - count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min, - AggregateFunction, + lit, logical_plan::builder::LogicalPlanBuilder, max, min, AggregateFunction, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, sum}; + use 
datafusion_functions_aggregate::sum::sum_udaf; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -679,14 +681,11 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - None, - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; @@ -725,19 +724,16 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - None, - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -748,19 +744,17 @@ mod tests { let table_scan = test_table_scan()?; // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - Some(vec![col("a")]), - None, - )); + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? 
.build()?;

         // Do nothing
-        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
+        let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64;N]\
        \n  TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

         assert_optimized_plan_equal(plan, expected)
diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs
index b3501cca9efa..f60bf6609005 100644
--- a/datafusion/optimizer/tests/optimizer_integration.rs
+++ b/datafusion/optimizer/tests/optimizer_integration.rs
@@ -25,6 +25,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::{plan_err, Result};
 use datafusion_expr::test::function_stub::sum_udaf;
 use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF};
+use datafusion_functions_aggregate::count::count_udaf;
 use datafusion_optimizer::analyzer::Analyzer;
 use datafusion_optimizer::optimizer::Optimizer;
 use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule};
@@ -323,7 +324,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
     let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
     let ast: Vec<Statement> = Parser::parse_sql(&dialect, sql).unwrap();
     let statement = &ast[0];
-    let context_provider = MyContextProvider::default().with_udaf(sum_udaf());
+    let context_provider = MyContextProvider::default()
+        .with_udaf(sum_udaf())
+        .with_udaf(count_udaf());
     let sql_to_rel = SqlToRel::new(&context_provider);
     let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

@@ -345,7 +348,8 @@ struct MyContextProvider {

 impl MyContextProvider {
     fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
-        self.udafs.insert(udaf.name().to_string(), udaf);
+        // TODO: change to to_string() once all function names are converted to lowercase
+        self.udafs.insert(udaf.name().to_lowercase(), udaf);
         self
     }
 }
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index aee7bca3b88f..75f2e12320bf 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -30,7 +30,7 @@ use std::sync::Arc;

 use arrow::datatypes::Schema;

-use datafusion_common::{exec_err, internal_err, not_impl_err, Result};
+use datafusion_common::{exec_err, not_impl_err, Result};
 use datafusion_expr::AggregateFunction;

 use crate::aggregate::average::Avg;
@@ -61,9 +61,6 @@ pub fn create_aggregate_expr(
         .collect::<Result<Vec<_>>>()?;
     let input_phy_exprs = input_phy_exprs.to_vec();
     Ok(match (fun, distinct) {
-        (AggregateFunction::Count, _) => {
-            return internal_err!("Builtin Count will be removed");
-        }
         (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
            input_phy_exprs[0].clone(),
            name,
@@ -642,20 +639,6 @@ mod tests {
         Ok(())
     }

-    #[test]
-    fn test_count_return_type() -> Result<()> {
-        let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?;
-        assert_eq!(DataType::Int64, observed);
-
-        let observed = AggregateFunction::Count.return_type(&[DataType::Int8])?;
-        assert_eq!(DataType::Int64, observed);
-
-        let observed =
AggregateFunction::Count.return_type(&[DataType::Decimal128(28, 13)])?; - assert_eq!(DataType::Int64, observed); - Ok(()) - } - #[test] fn test_avg_return_type() -> Result<()> { let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index b1897aa58e7d..aa8d0e55b68f 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -59,6 +59,7 @@ serde_json = { workspace = true, optional = true } [dev-dependencies] datafusion-functions = { workspace = true, default-features = true } +datafusion-functions-aggregate = { workspace = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2bb3ec793d7f..31cb0d1da9d5 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -476,7 +476,7 @@ enum AggregateFunction { MAX = 1; // SUM = 2; AVG = 3; - COUNT = 4; + // COUNT = 4; // APPROX_DISTINCT = 5; ARRAY_AGG = 6; // VARIANCE = 7; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 59b7861a6ef1..503f83af65f2 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -535,7 +535,6 @@ impl serde::Serialize for AggregateFunction { Self::Min => "MIN", Self::Max => "MAX", Self::Avg => "AVG", - Self::Count => "COUNT", Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", @@ -571,7 +570,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN", "MAX", "AVG", - "COUNT", "ARRAY_AGG", "CORRELATION", "APPROX_PERCENTILE_CONT", @@ -636,7 +634,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "MIN" => Ok(AggregateFunction::Min), "MAX" => Ok(AggregateFunction::Max), "AVG" => Ok(AggregateFunction::Avg), - "COUNT" => Ok(AggregateFunction::Count), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0861c287fcfa..2c0ea62466b4 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1930,7 +1930,7 @@ pub enum AggregateFunction { Max = 1, /// SUM = 2; Avg = 3, - Count = 4, + /// COUNT = 4; /// APPROX_DISTINCT = 5; ArrayAgg = 6, /// VARIANCE = 7; @@ -1972,7 +1972,6 @@ impl AggregateFunction { AggregateFunction::Min => "MIN", AggregateFunction::Max => "MAX", AggregateFunction::Avg => "AVG", - AggregateFunction::Count => "COUNT", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", @@ -2004,7 +2003,6 @@ impl AggregateFunction { "MIN" => Some(Self::Min), "MAX" => Some(Self::Max), "AVG" => Some(Self::Avg), - "COUNT" => Some(Self::Count), "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2ad40d883fe6..54a59485c836 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ 
b/datafusion/proto/src/logical_plan/from_proto.rs @@ -145,7 +145,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::BitXor => Self::BitXor, protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, protobuf::AggregateFunction::BoolOr => Self::BoolOr, - protobuf::AggregateFunction::Count => Self::Count, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6a275ed7a1b8..80ce05d151ee 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -116,7 +116,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::BitXor => Self::BitXor, AggregateFunction::BoolAnd => Self::BoolAnd, AggregateFunction::BoolOr => Self::BoolOr, - AggregateFunction::Count => Self::Count, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::RegrSlope => Self::RegrSlope, @@ -406,7 +405,6 @@ pub fn serialize_expr( AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d9736da69d42..d0f1c4aade5e 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,6 +26,7 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; +use datafusion_functions_aggregate::count::count_udaf; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -35,8 +36,8 @@ use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::approx_median::approx_median; use datafusion::functions_aggregate::expr_fn::{ - covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, var_pop, - var_sample, + count, count_distinct, covar_pop, covar_samp, first_value, median, stddev, + stddev_pop, sum, var_pop, var_sample, }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -53,10 +54,10 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateFunction, ColumnarValue, ExprSchemable, LogicalPlan, Operator, - PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, - WindowUDFImpl, + Accumulator, AggregateExt, AggregateFunction, ColumnarValue, ExprSchemable, + LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, + TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, @@ -1782,28 +1783,18 @@ fn 
roundtrip_similar_to() { #[test] fn roundtrip_count() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - false, - None, - None, - None, - )); + let test_expr = count(col("bananas")); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } #[test] fn roundtrip_count_distinct() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - true, - None, - None, - None, - )); + let test_expr = count_udaf() + .call(vec![col("bananas")]) + .distinct() + .build() + .unwrap(); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 893db018c8af..aee4cf5a38ed 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -18,11 +18,12 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; -use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::WindowUDF; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, }; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_sql::{ planner::{ContextProvider, SqlToRel}, sqlparser::{dialect::GenericDialect, parser::Parser}, @@ -50,7 +51,9 @@ fn main() { let statement = &ast[0]; // create a logical query plan - let context_provider = MyContextProvider::new().with_udaf(sum_udaf()); + let context_provider = MyContextProvider::new() + .with_udaf(sum_udaf()) + .with_udaf(count_udaf()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); @@ -66,7 +69,8 @@ struct MyContextProvider { impl MyContextProvider { fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self { - self.udafs.insert(udaf.name().to_string(), udaf); + // TODO: change to to_string() once all function names are converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); self } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index dc25a6c33ece..12c48054f1a7 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -960,13 +960,14 @@ mod tests { use arrow_schema::DataType::Int8; use datafusion_common::TableReference; + use datafusion_expr::AggregateExt; use datafusion_expr::{ - case, col, cube, exists, - expr::{AggregateFunction, AggregateFunctionDefinition}, - grouping_set, lit, not, not_exists, out_ref_col, placeholder, rollup, table_scan, - try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, - Volatility, WindowFrame, WindowFunctionDefinition, + case, col, cube, exists, grouping_set, lit, not, not_exists, out_ref_col, + placeholder, rollup, table_scan, try_cast, when, wildcard, ColumnarValue, + ScalarUDF, ScalarUDFImpl, Signature, Volatility, WindowFrame, + WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::sum; use crate::unparser::dialect::CustomDialect; @@ -1127,29 +1128,19 @@ mod tests { ), (sum(col("a")), r#"sum(a)"#), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, - ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: true, - filter: None, - order_by: 
None, - null_treatment: None, - }), + count_udaf() + .call(vec![Expr::Wildcard { qualifier: None }]) + .distinct() + .build() + .unwrap(), "COUNT(DISTINCT *)", ), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, - ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: false, - filter: Some(Box::new(lit(true))), - order_by: None, - null_treatment: None, - }), + count_udaf() + .call(vec![Expr::Wildcard { qualifier: None }]) + .filter(lit(true)) + .build() + .unwrap(), "COUNT(*) FILTER (WHERE true)", ), ( @@ -1167,9 +1158,7 @@ mod tests { ), ( Expr::WindowFunction(WindowFunction { - fun: WindowFunctionDefinition::AggregateFunction( - datafusion_expr::AggregateFunction::Count, - ), + fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), args: vec![wildcard()], partition_by: vec![], order_by: vec![Expr::Sort(Sort::new( diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 51bacb5f702b..bc27d25cf216 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -350,7 +350,8 @@ mod tests { use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; use arrow_schema::Fields; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{col, count, lit, unnest, EmptyRelation, LogicalPlan}; + use datafusion_expr::{col, lit, unnest, EmptyRelation, LogicalPlan}; + use datafusion_functions_aggregate::expr_fn::count; use crate::utils::{recursive_transform_unnest, resolve_positions_to_exprs}; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 72018371a5f1..33e28e7056b9 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -19,7 +19,7 @@ use std::vec; use arrow_schema::*; use datafusion_common::{DFSchema, Result, TableReference}; -use datafusion_expr::test::function_stub::sum_udaf; +use datafusion_expr::test::function_stub::{count_udaf, sum_udaf}; use datafusion_expr::{col, table_scan}; use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ @@ -153,7 +153,9 @@ fn roundtrip_statement() -> Result<()> { .try_with_sql(query)? 
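        // (descriptive aside: `try_with_sql` binds the SQL text to the parser; the
        // next call yields the sqlparser AST statement that is then planned below)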
.parse_statement()?; - let context = MockContextProvider::default().with_udaf(sum_udaf()); + let context = MockContextProvider::default() + .with_udaf(sum_udaf()) + .with_udaf(count_udaf()); let sql_to_rel = SqlToRel::new(&context); let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs index d91c09ae1287..893678d6b374 100644 --- a/datafusion/sql/tests/common/mod.rs +++ b/datafusion/sql/tests/common/mod.rs @@ -46,7 +46,8 @@ impl MockContextProvider { } pub(crate) fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self { - self.udafs.insert(udaf.name().to_string(), udaf); + // TODO: change to to_string() once all function names are converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); self } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 7b9d39a2b51e..8eb2a2b609e7 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -37,7 +37,9 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; -use datafusion_functions_aggregate::approx_median::approx_median_udaf; +use datafusion_functions_aggregate::{ + approx_median::approx_median_udaf, count::count_udaf, +}; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -2702,7 +2704,8 @@ fn logical_plan_with_dialect_and_options( )) .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64)) .with_udaf(sum_udaf()) - .with_udaf(approx_median_udaf()); + .with_udaf(approx_median_udaf()) + .with_udaf(count_udaf()); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index f04d76822124..df6295d63b81 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -487,7 +487,7 @@ statement error Did you mean 'to_timestamp_seconds'? SELECT to_TIMESTAMPS_second(v2) from test; # Aggregate function -statement error Did you mean 'COUNT'? 
+query error DataFusion error: Error during planning: Invalid function 'counter' SELECT counter(*) from test; # Aggregate function From cc60278f50eac33f9c8ea3509c171b4f2b008b27 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Fri, 14 Jun 2024 01:32:55 +0200 Subject: [PATCH 14/34] Move Regr_* functions to use UDAF (#10898) * Move Regr_* functions to use UDAF Closes #10883 and is part of #8708 * Format and regen * tweak error check --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/aggregate_function.rs | 56 +------- .../expr/src/type_coercion/aggregates.rs | 21 --- datafusion/functions-aggregate/src/lib.rs | 19 +++ datafusion/functions-aggregate/src/macros.rs | 14 +- .../src}/regr.rs | 127 +++++++++++------- .../physical-expr/src/aggregate/build_in.rs | 78 ----------- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 18 +-- datafusion/proto/src/generated/pbjson.rs | 27 ---- datafusion/proto/src/generated/prost.rs | 36 ++--- .../proto/src/logical_plan/from_proto.rs | 9 -- datafusion/proto/src/logical_plan/to_proto.rs | 24 ---- .../proto/src/physical_plan/to_proto.rs | 16 +-- datafusion/sqllogictest/test_files/errors.slt | 4 +- 15 files changed, 135 insertions(+), 316 deletions(-) rename datafusion/{physical-expr/src/aggregate => functions-aggregate/src}/regr.rs (84%) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 5899cc927703..81562bf12476 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -45,24 +45,6 @@ pub enum AggregateFunction { NthValue, /// Correlation Correlation, - /// Slope from linear regression - RegrSlope, - /// Intercept from linear regression - RegrIntercept, - /// Number of input rows in which both expressions are not null - RegrCount, - /// R-squared value from linear regression - RegrR2, - /// Average of the independent variable - RegrAvgx, - /// Average of the dependent variable - RegrAvgy, - /// Sum of squares of the independent variable - RegrSXX, - /// Sum of squares of the dependent variable - RegrSYY, - /// Sum of products of pairs of numbers - RegrSXY, /// Approximate continuous percentile function ApproxPercentileCont, /// Approximate continuous percentile function with weight @@ -93,15 +75,6 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", Correlation => "CORR", - RegrSlope => "REGR_SLOPE", - RegrIntercept => "REGR_INTERCEPT", - RegrCount => "REGR_COUNT", - RegrR2 => "REGR_R2", - RegrAvgx => "REGR_AVGX", - RegrAvgy => "REGR_AVGY", - RegrSXX => "REGR_SXX", - RegrSYY => "REGR_SYY", - RegrSXY => "REGR_SXY", ApproxPercentileCont => "APPROX_PERCENTILE_CONT", ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Grouping => "GROUPING", @@ -140,15 +113,6 @@ impl FromStr for AggregateFunction { "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, - "regr_slope" => AggregateFunction::RegrSlope, - "regr_intercept" => AggregateFunction::RegrIntercept, - "regr_count" => AggregateFunction::RegrCount, - "regr_r2" => AggregateFunction::RegrR2, - "regr_avgx" => AggregateFunction::RegrAvgx, - "regr_avgy" => AggregateFunction::RegrAvgy, - "regr_sxx" => AggregateFunction::RegrSXX, - "regr_syy" => AggregateFunction::RegrSYY, - "regr_sxy" => AggregateFunction::RegrSXY, // approximate "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, 
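            // (aside: the regr_* name mappings removed above are gone from this
            // builtin lookup on purpose; those names now resolve through the UDAF
            // registry via datafusion_functions_aggregate::regr, added later in
            // this patch)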
"approx_percentile_cont_with_weight" => { @@ -200,15 +164,6 @@ impl AggregateFunction { AggregateFunction::Correlation => { correlation_return_type(&coerced_data_types[0]) } - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => Ok(DataType::Float64), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", @@ -272,16 +227,7 @@ impl AggregateFunction { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::Correlation - | AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { + AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::ApproxPercentileCont => { diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 2c76407cdfe2..6c9a71bab46a 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -158,27 +158,6 @@ pub fn coerce_types( } Ok(vec![Float64, Float64]) } - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { - let valid_types = [NUMERICS.to_vec(), vec![Null]].concat(); - let input_types_valid = // number of input already checked before - valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]); - if !input_types_valid { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } AggregateFunction::ApproxPercentileCont => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { return plan_err!( diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 56fc1305bb59..fabe15e416f4 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -61,6 +61,7 @@ pub mod covariance; pub mod first_last; pub mod hyperloglog; pub mod median; +pub mod regr; pub mod stddev; pub mod sum; pub mod variance; @@ -85,6 +86,15 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::median::median; + pub use super::regr::regr_avgx; + pub use super::regr::regr_avgy; + pub use super::regr::regr_count; + pub use super::regr::regr_intercept; + pub use super::regr::regr_r2; + pub use super::regr::regr_slope; + pub use super::regr::regr_sxx; + pub use super::regr::regr_sxy; + pub use super::regr::regr_syy; pub use super::stddev::stddev; pub use super::stddev::stddev_pop; pub use super::sum::sum; @@ -102,6 +112,15 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_pop_udaf(), median::median_udaf(), count::count_udaf(), + regr::regr_slope_udaf(), + 
regr::regr_intercept_udaf(), + regr::regr_count_udaf(), + regr::regr_r2_udaf(), + regr::regr_avgx_udaf(), + regr::regr_avgy_udaf(), + regr::regr_sxx_udaf(), + regr::regr_syy_udaf(), + regr::regr_sxy_udaf(), variance::var_samp_udaf(), variance::var_pop_udaf(), stddev::stddev_udaf(), diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 75bb9dc54719..cae72cf35223 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -32,8 +32,8 @@ // specific language governing permissions and limitations // under the License. -macro_rules! make_udaf_expr_and_func { - ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { +macro_rules! make_udaf_expr { + ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN( @@ -48,7 +48,12 @@ macro_rules! make_udaf_expr_and_func { None, )) } + }; +} +macro_rules! make_udaf_expr_and_func { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); create_func!($UDAF, $AGGREGATE_UDF_FN); }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { @@ -73,6 +78,9 @@ macro_rules! make_udaf_expr_and_func { macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { + create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); + }; + ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => { paste::paste! { /// Singleton instance of [$UDAF], ensures the UDAF is only created once /// named STATIC_$(UDAF). For example `STATIC_FirstValue` @@ -86,7 +94,7 @@ macro_rules! create_func { pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDAF >] .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default())) + std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE)) }) .clone() } diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/functions-aggregate/src/regr.rs similarity index 84% rename from datafusion/physical-expr/src/aggregate/regr.rs rename to datafusion/functions-aggregate/src/regr.rs index 36e7b7c9b3e4..8d04ae87157d 100644 --- a/datafusion/physical-expr/src/aggregate/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -18,9 +18,8 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::sync::Arc; +use std::fmt::Debug; -use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -28,13 +27,56 @@ use arrow::{ datatypes::DataType, datatypes::Field, }; -use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; +use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; + +macro_rules! 
make_regr_udaf_expr_and_func { + ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { + make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN); + create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN))); + } +} + +make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope); +make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept); +make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count); +make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2); +make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX); +make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY); +make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); +make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); +make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); + +pub struct Regr { + signature: Signature, + regr_type: RegrType, + func_name: &'static str, +} -impl Debug for Regr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("regr") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} +impl Regr { + pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + regr_type, + func_name, + } + } +} + +/* #[derive(Debug)] pub struct Regr { name: String, regr_type: RegrType, expr_y: Arc<dyn PhysicalExpr>, expr_x: Arc<dyn PhysicalExpr>, } impl Regr { pub fn regr_type(&self) -> RegrType { self.regr_type.clone() } } +*/ #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] pub enum RegrType { @@ -92,86 +135,75 @@ pub enum RegrType { SXY, } -impl Regr { - pub fn new( - expr_y: Arc<dyn PhysicalExpr>, - expr_x: Arc<dyn PhysicalExpr>, - name: impl Into<String>, - regr_type: RegrType, - return_type: DataType, - ) -> Self { - // the result of regr_slope only support FLOAT64 data type. 
- assert!(matches!(return_type, DataType::Float64)); - Self { - name: name.into(), - regr_type, - expr_y, - expr_x, - } - } -} - -impl AggregateExpr for Regr { +impl AggregateUDFImpl for Regr { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result<Field> { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + self.func_name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + if !arg_types[0].is_numeric() { + return plan_err!("Covariance requires numeric input types"); + } + + Ok(DataType::Float64) } - fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> { + fn create_sliding_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result<Box<dyn Accumulator>> { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn state_fields(&self) -> Result<Vec<Field>> { + fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> { Ok(vec![ Field::new( - format_state_name(&self.name, "count"), + format_state_name(args.name, "count"), DataType::UInt64, true, ), Field::new( - format_state_name(&self.name, "mean_x"), + format_state_name(args.name, "mean_x"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "mean_y"), + format_state_name(args.name, "mean_y"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "m2_x"), + format_state_name(args.name, "m2_x"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "m2_y"), + format_state_name(args.name, "m2_y"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "algo_const"), + format_state_name(args.name, "algo_const"), DataType::Float64, true, ), ]) } - - fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { - vec![self.expr_y.clone(), self.expr_x.clone()] - } - - fn name(&self) -> &str { - &self.name - } } +/* impl PartialEq<dyn Any> for Regr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) @@ -184,6 +216,7 @@ impl PartialEq<dyn Any> for Regr { .unwrap_or(false) } } +*/ /// `RegrAccumulator` is used to compute linear regression aggregate functions /// by maintaining statistics needed to compute them in an online fashion. @@ -305,6 +338,10 @@ impl Accumulator for RegrAccumulator { Ok(()) } + fn supports_retract_batch(&self) -> bool { + true + } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values_y = &cast(&values[0], &DataType::Float64)?; let values_x = &cast(&values[1], &DataType::Float64)?; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 75f2e12320bf..df87a2e261a1 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -34,7 +34,6 @@ use datafusion_common::{exec_err, not_impl_err, Result}; use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; -use crate::aggregate::regr::RegrType; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; /// Create a physical aggregation expression. 
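// A minimal sketch (not part of the patch itself) of the replacement call path:
// with the builtin regr arms removed from this file, callers build these
// aggregates through the expr_fn helpers generated by regr.rs above. Assuming
// `col` from datafusion_expr and the expr_fn re-exports added in lib.rs:
//
//     use datafusion_expr::col;
//     use datafusion_functions_aggregate::expr_fn::regr_slope;
//
//     // builds a logical `Expr` equivalent to the SQL aggregate
//     // `regr_slope(y, x)`; "y" and "x" are illustrative column names
//     let slope = regr_slope(col("y"), col("x"));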
@@ -155,83 +154,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Correlation, true) => { return not_impl_err!("CORR(DISTINCT) aggregations are not available"); } - (AggregateFunction::RegrSlope, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Slope, - data_type, - )), - (AggregateFunction::RegrIntercept, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Intercept, - data_type, - )), - (AggregateFunction::RegrCount, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Count, - data_type, - )), - (AggregateFunction::RegrR2, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::R2, - data_type, - )), - (AggregateFunction::RegrAvgx, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgX, - data_type, - )), - (AggregateFunction::RegrAvgy, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgY, - data_type, - )), - (AggregateFunction::RegrSXX, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXX, - data_type, - )), - (AggregateFunction::RegrSYY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SYY, - data_type, - )), - (AggregateFunction::RegrSXY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXY, - data_type, - )), - ( - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY, - true, - ) => { - return not_impl_err!("{}(DISTINCT) aggregations are not available", fun); - } (AggregateFunction::ApproxPercentileCont, false) => { if input_phy_exprs.len() == 2 { Arc::new(expressions::ApproxPercentileCont::new( diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 01105c8559c9..9079a81e6241 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -33,7 +33,6 @@ pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; -pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod variance; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 123ada6d7c86..beba25740501 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -50,7 +50,6 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; -pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::string_agg::StringAgg; pub use crate::window::cume_dist::{cume_dist, CumeDist}; diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto index 31cb0d1da9d5..83223a04d023 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -496,15 +496,15 @@ enum AggregateFunction { BIT_XOR = 21; BOOL_AND = 22; BOOL_OR = 23; - REGR_SLOPE = 26; - REGR_INTERCEPT = 27; - REGR_COUNT = 28; - REGR_R2 = 29; - REGR_AVGX = 30; - REGR_AVGY = 31; - REGR_SXX = 32; - REGR_SYY = 33; - REGR_SXY = 34; + // REGR_SLOPE = 26; + // REGR_INTERCEPT = 27; + // REGR_COUNT = 28; + // REGR_R2 = 29; + // REGR_AVGX = 30; + // REGR_AVGY = 31; + // REGR_SXX = 32; + // REGR_SYY = 33; + // REGR_SXY = 34; STRING_AGG = 35; NTH_VALUE_AGG = 36; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 503f83af65f2..f298dd241abf 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -545,15 +545,6 @@ impl serde::Serialize for AggregateFunction { Self::BitXor => "BIT_XOR", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", - Self::RegrSlope => "REGR_SLOPE", - Self::RegrIntercept => "REGR_INTERCEPT", - Self::RegrCount => "REGR_COUNT", - Self::RegrR2 => "REGR_R2", - Self::RegrAvgx => "REGR_AVGX", - Self::RegrAvgy => "REGR_AVGY", - Self::RegrSxx => "REGR_SXX", - Self::RegrSyy => "REGR_SYY", - Self::RegrSxy => "REGR_SXY", Self::StringAgg => "STRING_AGG", Self::NthValueAgg => "NTH_VALUE_AGG", }; @@ -580,15 +571,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR", "BOOL_AND", "BOOL_OR", - "REGR_SLOPE", - "REGR_INTERCEPT", - "REGR_COUNT", - "REGR_R2", - "REGR_AVGX", - "REGR_AVGY", - "REGR_SXX", - "REGR_SYY", - "REGR_SXY", "STRING_AGG", "NTH_VALUE_AGG", ]; @@ -644,15 +626,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR" => Ok(AggregateFunction::BitXor), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), - "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope), - "REGR_INTERCEPT" => Ok(AggregateFunction::RegrIntercept), - "REGR_COUNT" => Ok(AggregateFunction::RegrCount), - "REGR_R2" => Ok(AggregateFunction::RegrR2), - "REGR_AVGX" => Ok(AggregateFunction::RegrAvgx), - "REGR_AVGY" => Ok(AggregateFunction::RegrAvgy), - "REGR_SXX" => Ok(AggregateFunction::RegrSxx), - "REGR_SYY" => Ok(AggregateFunction::RegrSyy), - "REGR_SXY" => Ok(AggregateFunction::RegrSxy), "STRING_AGG" => Ok(AggregateFunction::StringAgg), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2c0ea62466b4..fa0217e9ef4f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1950,15 +1950,15 @@ pub enum AggregateFunction { BitXor = 21, BoolAnd = 22, BoolOr = 23, - RegrSlope = 26, - RegrIntercept = 27, - RegrCount = 28, - RegrR2 = 29, - RegrAvgx = 30, - RegrAvgy = 31, - RegrSxx = 32, - RegrSyy = 33, - RegrSxy = 34, + /// REGR_SLOPE = 26; + /// REGR_INTERCEPT = 27; + /// REGR_COUNT = 28; + /// REGR_R2 = 29; + /// REGR_AVGX = 30; + /// REGR_AVGY = 31; + /// REGR_SXX = 32; + /// REGR_SYY = 33; + /// REGR_SXY = 34; StringAgg = 35, NthValueAgg = 36, } @@ -1984,15 +1984,6 @@ impl AggregateFunction { AggregateFunction::BitXor => "BIT_XOR", AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", - AggregateFunction::RegrSlope => "REGR_SLOPE", - AggregateFunction::RegrIntercept => "REGR_INTERCEPT", - AggregateFunction::RegrCount => 
"REGR_COUNT", - AggregateFunction::RegrR2 => "REGR_R2", - AggregateFunction::RegrAvgx => "REGR_AVGX", - AggregateFunction::RegrAvgy => "REGR_AVGY", - AggregateFunction::RegrSxx => "REGR_SXX", - AggregateFunction::RegrSyy => "REGR_SYY", - AggregateFunction::RegrSxy => "REGR_SXY", AggregateFunction::StringAgg => "STRING_AGG", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } @@ -2015,15 +2006,6 @@ impl AggregateFunction { "BIT_XOR" => Some(Self::BitXor), "BOOL_AND" => Some(Self::BoolAnd), "BOOL_OR" => Some(Self::BoolOr), - "REGR_SLOPE" => Some(Self::RegrSlope), - "REGR_INTERCEPT" => Some(Self::RegrIntercept), - "REGR_COUNT" => Some(Self::RegrCount), - "REGR_R2" => Some(Self::RegrR2), - "REGR_AVGX" => Some(Self::RegrAvgx), - "REGR_AVGY" => Some(Self::RegrAvgy), - "REGR_SXX" => Some(Self::RegrSxx), - "REGR_SYY" => Some(Self::RegrSyy), - "REGR_SXY" => Some(Self::RegrSxy), "STRING_AGG" => Some(Self::StringAgg), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 54a59485c836..ed7b0129cc48 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -147,15 +147,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::BoolOr => Self::BoolOr, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, - protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, - protobuf::AggregateFunction::RegrIntercept => Self::RegrIntercept, - protobuf::AggregateFunction::RegrCount => Self::RegrCount, - protobuf::AggregateFunction::RegrR2 => Self::RegrR2, - protobuf::AggregateFunction::RegrAvgx => Self::RegrAvgx, - protobuf::AggregateFunction::RegrAvgy => Self::RegrAvgy, - protobuf::AggregateFunction::RegrSxx => Self::RegrSXX, - protobuf::AggregateFunction::RegrSyy => Self::RegrSYY, - protobuf::AggregateFunction::RegrSxy => Self::RegrSXY, protobuf::AggregateFunction::ApproxPercentileCont => { Self::ApproxPercentileCont } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 80ce05d151ee..04f7b596fea8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -118,15 +118,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::BoolOr => Self::BoolOr, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, - AggregateFunction::RegrSlope => Self::RegrSlope, - AggregateFunction::RegrIntercept => Self::RegrIntercept, - AggregateFunction::RegrCount => Self::RegrCount, - AggregateFunction::RegrR2 => Self::RegrR2, - AggregateFunction::RegrAvgx => Self::RegrAvgx, - AggregateFunction::RegrAvgy => Self::RegrAvgy, - AggregateFunction::RegrSXX => Self::RegrSxx, - AggregateFunction::RegrSYY => Self::RegrSyy, - AggregateFunction::RegrSXY => Self::RegrSxy, AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight @@ -408,21 +399,6 @@ pub fn serialize_expr( AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - 
AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e25447b023d8..ef462ac94b9a 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -27,8 +27,8 @@ use datafusion::physical_plan::expressions::{ BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, StringAgg, - TryCastExpr, WindowShift, + OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr, + WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -270,18 +270,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result<AggrFn> { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::<Correlation>().is_some() { protobuf::AggregateFunction::Correlation - } else if let Some(regr_expr) = aggr_expr.downcast_ref::<Regr>() { - match regr_expr.get_regr_type() { - RegrType::Slope => protobuf::AggregateFunction::RegrSlope, - RegrType::Intercept => protobuf::AggregateFunction::RegrIntercept, - RegrType::Count => protobuf::AggregateFunction::RegrCount, - RegrType::R2 => protobuf::AggregateFunction::RegrR2, - RegrType::AvgX => protobuf::AggregateFunction::RegrAvgx, - RegrType::AvgY => protobuf::AggregateFunction::RegrAvgy, - RegrType::SXX => protobuf::AggregateFunction::RegrSxx, - RegrType::SYY => protobuf::AggregateFunction::RegrSyy, - RegrType::SXY => protobuf::AggregateFunction::RegrSxy, - } - } else if aggr_expr.downcast_ref::<ApproxPercentileCont>().is_some() { protobuf::AggregateFunction::ApproxPercentileCont } else if aggr_expr diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index c7b9808c249d..d51c69496d46 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -112,11 +112,11 @@ statement error DataFusion error: Error during planning: No function matches the select avg(c1, c12) from aggregate_test_100; # AggregateFunction with wrong argument type -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Int64, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +statement error Coercion select regr_slope(1, '2'); # WindowFunction using AggregateFunction wrong signature -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Float32, Utf8\)'\. 
You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +statement error Coercion select c9, regr_slope(c11, '2') over () as min1 From 8f76ac553a4a05705cb92eb4fe46158b4bfa354b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 14 Jun 2024 11:54:31 -0400 Subject: [PATCH 15/34] Docs: clarify when the reader will read from object store when using cached metadata (#10909) --- .../src/datasource/physical_plan/parquet/reader.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/reader.rs b/datafusion/core/src/datasource/physical_plan/parquet/reader.rs index 265fb9d570cc..8a4ba136fc96 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/reader.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/reader.rs @@ -16,7 +16,7 @@ // under the License. //! [`ParquetFileReaderFactory`] and [`DefaultParquetFileReaderFactory`] for -//! creating parquet file readers +//! low level control of parquet file readers use crate::datasource::physical_plan::{FileMeta, ParquetFileMetrics}; use bytes::Bytes; @@ -33,12 +33,19 @@ use std::sync::Arc; /// /// The combined implementations of [`ParquetFileReaderFactory`] and /// [`AsyncFileReader`] can be used to provide custom data access operations -/// such as pre-cached data, I/O coalescing, etc. +/// such as pre-cached metadata, I/O coalescing, etc. /// /// See [`DefaultParquetFileReaderFactory`] for a simple implementation. pub trait ParquetFileReaderFactory: Debug + Send + Sync + 'static { /// Provides an `AsyncFileReader` for reading data from a parquet file specified /// + /// # Notes + /// + /// If the resulting [`AsyncFileReader`] returns `ParquetMetaData` without + /// page index information, the reader will load it on demand. Thus it is important + /// to ensure that the returned `ParquetMetaData` has the necessary information + /// if you wish to avoid a subsequent I/O + /// /// # Arguments /// * partition_index - Index of the partition (for reporting metrics) /// * file_meta - The file to be read From 4dd41219011476ff45b9a73d7befbc7a1cbcccc6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 14 Jun 2024 11:55:38 -0400 Subject: [PATCH 16/34] Minor: Fix `bench.sh tpch data` (#10905) --- benchmarks/bench.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 77779a12c450..62c0e925db74 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -302,7 +302,7 @@ data_tpch() { else echo " creating parquet files using benchmark binary ..." 
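    # note: the `convert` subcommand only rewrites the generated .tbl data to
    # parquet; as the fix below suggests, it does not accept the
    # prefer-hash-join flag, hence that flag's removal from the invocation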
pushd "${SCRIPT_DIR}" > /dev/null - $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --prefer_hash_join ${PREFER_HASH_JOIN} --output "${TPCH_DIR}" --format parquet + $CARGO_COMMAND --bin tpch -- convert --input "${TPCH_DIR}" --output "${TPCH_DIR}" --format parquet popd > /dev/null fi } From 38bd8932fdebfaf203ea4249f4ee84f859c953f2 Mon Sep 17 00:00:00 2001 From: tmi Date: Fri, 14 Jun 2024 17:57:15 +0200 Subject: [PATCH 17/34] Minor: use venv in benchmark compare (#10894) Co-authored-by: Andrew Lamb --- benchmarks/.gitignore | 3 ++- benchmarks/bench.sh | 16 ++++++++++++++-- benchmarks/compare.py | 2 +- benchmarks/requirements.txt | 18 ++++++++++++++++++ 4 files changed, 35 insertions(+), 4 deletions(-) create mode 100644 benchmarks/requirements.txt diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore index 2c574ff30d12..c35b1a7c1944 100644 --- a/benchmarks/.gitignore +++ b/benchmarks/.gitignore @@ -1,2 +1,3 @@ data -results \ No newline at end of file +results +venv diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 62c0e925db74..903fcb940b3e 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -37,6 +37,7 @@ DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} #CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --profile release-nonlto"} # for faster iterations PREFER_HASH_JOIN=${PREFER_HASH_JOIN:-true} +VIRTUAL_ENV=${VIRTUAL_ENV:-$SCRIPT_DIR/venv} usage() { echo " @@ -46,6 +47,7 @@ Usage: $0 data [benchmark] $0 run [benchmark] $0 compare +$0 venv ********** Examples: @@ -62,6 +64,7 @@ DATAFUSION_DIR=/source/datafusion ./bench.sh run tpch data: Generates or downloads data needed for benchmarking run: Runs the named benchmark compare: Compares results from benchmark runs +venv: Creates new venv (unless already exists) and installs compare's requirements into it ********** * Benchmarks @@ -84,7 +87,8 @@ DATA_DIR directory to store datasets CARGO_COMMAND command that runs the benchmark binary DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) RESULTS_NAME folder where the benchmark files are stored -PREFER_HASH_JOIN Prefer hash join algorithm(default true) +PREFER_HASH_JOIN Prefer hash join algorithm (default true) +VENV_PATH Python venv to use for compare and venv commands (default ./venv, override by /bin/activate) " exit 1 } @@ -243,6 +247,9 @@ main() { compare) compare_benchmarks "$ARG2" "$ARG3" ;; + venv) + setup_venv + ;; "") usage ;; @@ -448,7 +455,7 @@ compare_benchmarks() { echo "--------------------" echo "Benchmark ${bench}" echo "--------------------" - python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" + PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" else echo "Note: Skipping ${RESULTS_FILE1} as ${RESULTS_FILE2} does not exist" fi @@ -456,5 +463,10 @@ compare_benchmarks() { } +setup_venv() { + python3 -m venv $VIRTUAL_ENV + PATH=$VIRTUAL_ENV/bin:$PATH python3 -m pip install -r requirements.txt +} + # And start the process up main diff --git a/benchmarks/compare.py b/benchmarks/compare.py index ec2b28fa0556..2574c0735ca8 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -29,7 +29,7 @@ from rich.console import Console from rich.table import Table except ImportError: - print("Try `pip install rich` for using this script.") + print("Couldn't import modules -- run `./bench.sh venv` first") raise diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 000000000000..20a5a2bddbf2 --- 
/dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +rich From 0203a1a21d2fe5c53a039ff59461e66e4d5fc6dd Mon Sep 17 00:00:00 2001 From: Duong Cong Toai <35887761+duongcongtoai@users.noreply.github.com> Date: Fri, 14 Jun 2024 22:30:44 +0200 Subject: [PATCH 18/34] Support explicit type and name during table creation (#10273) * temp cargo commit * chore: update test * fmt * update cli cargo lock --- datafusion/sql/src/planner.rs | 20 +++++++++++++- datafusion/sqllogictest/test_files/struct.slt | 27 +++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 0f04281aa23b..a92e64597e82 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -439,6 +439,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLDataType::Bytea => Ok(DataType::Binary), SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + SQLDataType::Struct(fields) => { + let fields = fields + .iter() + .enumerate() + .map(|(idx, field)| { + let data_type = self.convert_data_type(&field.field_type)?; + let field_name = match &field.field_name{ + Some(ident) => ident.clone(), + None => Ident::new(format!("c{idx}")) + }; + Ok(Arc::new(Field::new( + self.normalizer.normalize(field_name), + data_type, + true, + ))) + }) + .collect::<Result<Vec<_>>>()?; + Ok(DataType::Struct(Fields::from(fields))) + } // Explicitly list all other types so that if sqlparser // adds/changes the `SQLDataType` the compiler will tell us on upgrade // and avoid bugs like https://github.com/apache/datafusion/issues/3059 @@ -472,7 +491,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Bytes(_) | SQLDataType::Int64 | SQLDataType::Float64 - | SQLDataType::Struct(_) | SQLDataType::JSONB | SQLDataType::Unspecified => not_impl_err!( diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 46a08709c3a3..749daa7e20e7 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -31,6 +31,33 @@ CREATE TABLE values( (3, 3.3, 'c', NULL) ; + +# named and nameless struct fields +statement ok +CREATE TABLE struct_values ( + s1 struct<INT>, + s2 struct<a INT, b VARCHAR> +) AS VALUES + (struct(1), struct(1, 'string1')), + (struct(2), struct(2, 'string2')), + (struct(3), struct(3, 'string3')) +; + +query ?? 
+select * from struct_values; ---- {c0: 1} {a: 1, b: string1} {c0: 2} {a: 2, b: string2} {c0: 3} {a: 3, b: string3} query TT select arrow_typeof(s1), arrow_typeof(s2) from struct_values; ---- Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) # struct[i] query IRT select struct(1, 3.14, 'h')['c0'], struct(3, 2.55, 'b')['c1'], struct(2, 6.43, 'a')['c2']; From 9ab597b2511b23dc951c9f35d00d38c2071e9fea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Fri, 14 Jun 2024 23:33:19 +0300 Subject: [PATCH 19/34] Simplify Join Partition Rules (#10911) * clean-up * add asym case * fix import errors * Update utils.rs * Update hash_join.rs --- datafusion/core/src/dataframe/mod.rs | 11 +-- .../physical-plan/src/joins/hash_join.rs | 38 +++-------- .../src/joins/nested_loop_join.rs | 32 +++------ .../src/joins/sort_merge_join.rs | 37 +++++----- .../src/joins/symmetric_hash_join.rs | 12 +--- datafusion/physical-plan/src/joins/utils.rs | 67 +++++++++++++------ 6 files changed, 90 insertions(+), 107 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 950cb7ddb2d3..b5c58eff577c 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -3100,10 +3100,7 @@ mod tests { let join_schema = physical_plan.schema(); match join_type { - JoinType::Inner - | JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti => { + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { let left_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![ Arc::new(Column::new_with_schema("c1", &join_schema)?), Arc::new(Column::new_with_schema("c2", &join_schema)?), ]; assert_eq!( &Partitioning::Hash(left_exprs, default_partition_count) ); } - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinType::Inner + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti => { let right_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![ Arc::new(Column::new_with_schema("c2_c1", &join_schema)?), Arc::new(Column::new_with_schema("c2_c2", &join_schema)?), ]; assert_eq!( &Partitioning::Hash(right_exprs, default_partition_count) ); } Ok(()) } + #[tokio::test] async fn nested_explain_should_fail() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index cd66ab093f88..5353092d5c45 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use std::task::Poll; use std::{any::Any, vec}; +use 
super::utils::asymmetric_join_output_partitioning; use super::{ utils::{OnceAsync, OnceFut}, PartitionMode, @@ -34,10 +35,10 @@ use crate::{ execution_mode_from_children, handle_state, hash_utils::create_hashes, joins::utils::{ - adjust_indices_by_join_type, adjust_right_output_partitioning, - apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, estimate_join_statistics, get_final_indices_from_bit_map, - need_produce_result_in_final, partitioned_join_output_partitioning, + adjust_indices_by_join_type, apply_join_filter_to_indices, + build_batch_from_indices, build_join_schema, check_join_is_valid, + estimate_join_statistics, get_final_indices_from_bit_map, + need_produce_result_in_final, symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMap, JoinHashMapOffset, JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, }, @@ -490,33 +491,16 @@ impl HashJoinExec { on, ); - // Get output partitioning: - let left_columns_len = left.schema().fields.len(); let mut output_partitioning = match mode { - PartitionMode::CollectLeft => match join_type { - JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( - right.output_partitioning(), - left_columns_len, - ), - JoinType::RightSemi | JoinType::RightAnti => { - right.output_partitioning().clone() - } - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::Full => Partitioning::UnknownPartitioning( - right.output_partitioning().partition_count(), - ), - }, - PartitionMode::Partitioned => partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ), + PartitionMode::CollectLeft => { + asymmetric_join_output_partitioning(left, right, &join_type) + } PartitionMode::Auto => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), + PartitionMode::Partitioned => { + symmetric_join_output_partitioning(left, right, &join_type) + } }; // Determine execution mode by checking whether this join is pipeline diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 18518600ef2f..6be124cce06f 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -25,18 +25,19 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; +use super::utils::{asymmetric_join_output_partitioning, need_produce_result_in_final}; use crate::coalesce_batches::concat_batches; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::utils::{ - adjust_indices_by_join_type, adjust_right_output_partitioning, - apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, estimate_join_statistics, get_final_indices_from_bit_map, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut, + adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, + build_join_schema, check_join_is_valid, estimate_join_statistics, + get_final_indices_from_bit_map, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + OnceAsync, OnceFut, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, + ExecutionMode, ExecutionPlan, ExecutionPlanProperties, 
PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; @@ -55,8 +56,6 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; -use super::utils::need_produce_result_in_final; - /// Shared bitmap for visited left-side indices type SharedBitmapBuilder = Mutex; /// Left (build-side) data @@ -228,21 +227,8 @@ impl NestedLoopJoinExec { &[], ); - // Get output partitioning, - let output_partitioning = match join_type { - JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( - right.output_partitioning(), - left.schema().fields().len(), - ), - JoinType::RightSemi | JoinType::RightAnti => { - right.output_partitioning().clone() - } - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full => { - Partitioning::UnknownPartitioning( - right.output_partitioning().partition_count(), - ) - } - }; + let output_partitioning = + asymmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: let mut mode = execution_mode_from_children([left, right]); @@ -673,7 +659,7 @@ mod tests { use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; - use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr::{Partitioning, PhysicalExpr}; fn build_table( a: (&str, &Vec), diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 8da345cdfca6..01abb30181d0 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -30,12 +30,22 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::expressions::PhysicalSortExpr; +use crate::joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, +}; +use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::{ + execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; + use arrow::array::*; use arrow::compute::{self, concat_batches, take, SortOptions}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; -use futures::{Stream, StreamExt}; -use hashbrown::HashSet; use datafusion_common::{ internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, @@ -45,17 +55,8 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, -}; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; +use futures::{Stream, StreamExt}; +use hashbrown::HashSet; /// join execution plan executes partitions in parallel and combines them 
into a set of /// partitions. @@ -220,14 +221,8 @@ impl SortMergeJoinExec { join_on, ); - // Get output partitioning: - let left_columns_len = left.schema().fields.len(); - let output_partitioning = partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ); + let output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: let mode = execution_mode_from_children([left, right]); diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index e11e6dd2f627..813f670147bc 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -42,7 +42,7 @@ use crate::joins::stream_join_utils::{ }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, partitioned_join_output_partitioning, ColumnIndex, JoinFilter, + check_join_is_valid, symmetric_join_output_partitioning, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, }; use crate::{ @@ -271,14 +271,8 @@ impl SymmetricHashJoinExec { join_on, ); - // Get output partitioning: - let left_columns_len = left.schema().fields.len(); - let output_partitioning = partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ); + let output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: let mode = execution_mode_from_children([left, right]); diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index c08b0e3d091c..dfa1fd4763f4 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -25,7 +25,9 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; +use crate::{ + ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics, +}; use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, @@ -428,27 +430,6 @@ fn check_join_set_is_valid( Ok(()) } -/// Calculate the OutputPartitioning for Partitioned Join -pub fn partitioned_join_output_partitioning( - join_type: JoinType, - left_partitioning: &Partitioning, - right_partitioning: &Partitioning, - left_columns_len: usize, -) -> Partitioning { - match join_type { - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - left_partitioning.clone() - } - JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), - JoinType::Right => { - adjust_right_output_partitioning(right_partitioning, left_columns_len) - } - JoinType::Full => { - Partitioning::UnknownPartitioning(right_partitioning.partition_count()) - } - } -} - /// Adjust the right out partitioning to new Column Index pub fn adjust_right_output_partitioning( right_partitioning: &Partitioning, @@ -1539,6 +1520,48 @@ pub enum StatefulStreamResult { Continue, } +pub(crate) fn symmetric_join_output_partitioning( + left: &Arc, + right: &Arc, + join_type: &JoinType, +) -> Partitioning { + let left_columns_len = left.schema().fields.len(); + let left_partitioning = left.output_partitioning(); + let right_partitioning = 
right.output_partitioning(); + match join_type { + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { + left_partitioning.clone() + } + JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), + JoinType::Inner | JoinType::Right => { + adjust_right_output_partitioning(right_partitioning, left_columns_len) + } + JoinType::Full => { + // We could also use left partition count as they are necessarily equal. + Partitioning::UnknownPartitioning(right_partitioning.partition_count()) + } + } +} + +pub(crate) fn asymmetric_join_output_partitioning( + left: &Arc, + right: &Arc, + join_type: &JoinType, +) -> Partitioning { + match join_type { + JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( + right.output_partitioning(), + left.schema().fields().len(), + ), + JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full => { + Partitioning::UnknownPartitioning( + right.output_partitioning().partition_count(), + ) + } + } +} + #[cfg(test)] mod tests { use std::pin::Pin; From e711775f08cccaed797dbda77f9d7448d1568b06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Sat, 15 Jun 2024 08:01:48 +0800 Subject: [PATCH 20/34] Move Literal to physical-expr-common (#10910) --- .../src/expressions/literal.rs | 3 +-- datafusion/physical-expr-common/src/expressions/mod.rs | 2 ++ datafusion/physical-expr/src/expressions/mod.rs | 3 +-- 3 files changed, 4 insertions(+), 4 deletions(-) rename datafusion/{physical-expr => physical-expr-common}/src/expressions/literal.rs (98%) diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr-common/src/expressions/literal.rs similarity index 98% rename from datafusion/physical-expr/src/expressions/literal.rs rename to datafusion/physical-expr-common/src/expressions/literal.rs index fcaf229af0a8..b3cff1ef69ba 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr-common/src/expressions/literal.rs @@ -21,8 +21,7 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; -use crate::PhysicalExpr; +use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; use arrow::{ datatypes::{DataType, Schema}, diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/physical-expr-common/src/expressions/mod.rs index 4b5965e164b5..ea21c8e9a92b 100644 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ b/datafusion/physical-expr-common/src/expressions/mod.rs @@ -17,5 +17,7 @@ mod cast; pub mod column; +mod literal; pub use cast::{cast, cast_with_options, CastExpr}; +pub use literal::{lit, Literal}; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index beba25740501..592393f800d0 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -26,7 +26,6 @@ mod in_list; mod is_not_null; mod is_null; mod like; -mod literal; mod negative; mod no_op; mod not; @@ -67,11 +66,11 @@ pub use datafusion_expr::utils::format_state_name; pub use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; +pub use datafusion_physical_expr_common::expressions::{lit, Literal}; pub use 
in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; pub use like::{like, LikeExpr}; -pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; From ebca68109dc05f163ca1faf111e90fb9eebd5083 Mon Sep 17 00:00:00 2001 From: Jeffrey Smith II Date: Fri, 14 Jun 2024 20:46:16 -0400 Subject: [PATCH 21/34] chore: update some error messages for clarity (#10916) * chore: update some error messages for clarity * chore: update error messages in tests --- datafusion/expr/src/expr_schema.rs | 6 +++--- datafusion/expr/src/type_coercion/functions.rs | 14 +++----------- datafusion/sqllogictest/test_files/array.slt | 4 ++-- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 7ea0313bf776..986f85adebaa 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -141,7 +141,7 @@ impl ExprSchemable for Expr { // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| { plan_datafusion_err!( - "{} and {}", + "{} {}", err, utils::generate_signature_error_msg( func.name(), @@ -164,7 +164,7 @@ impl ExprSchemable for Expr { WindowFunctionDefinition::AggregateUDF(udf) => { let new_types = data_types_with_aggregate_udf(&data_types, udf).map_err(|err| { plan_datafusion_err!( - "{} and {}", + "{} {}", err, utils::generate_signature_error_msg( fun.name(), @@ -192,7 +192,7 @@ impl ExprSchemable for Expr { AggregateFunctionDefinition::UDF(fun) => { let new_types = data_types_with_aggregate_udf(&data_types, fun).map_err(|err| { plan_datafusion_err!( - "{} and {}", + "{} {}", err, utils::generate_signature_error_msg( fun.name(), diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 4dd8d6371934..5f060a4a4f16 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -49,10 +49,7 @@ pub fn data_types_with_scalar_udf( if signature.type_signature.supports_zero_argument() { return Ok(vec![]); } else { - return plan_err!( - "[data_types_with_scalar_udf] signature {:?} does not support zero arguments.", - &signature.type_signature - ); + return plan_err!("{} does not support zero arguments.", func.name()); } } @@ -79,11 +76,7 @@ pub fn data_types_with_aggregate_udf( if signature.type_signature.supports_zero_argument() { return Ok(vec![]); } else { - return plan_err!( - "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", - current_types, - &signature.type_signature - ); + return plan_err!("{} does not support zero arguments.", func.name()); } } @@ -118,8 +111,7 @@ pub fn data_types( return Ok(vec![]); } else { return plan_err!( - "[data_types] Coercion from {:?} to the signature {:?} failed.", - current_types, + "signature {:?} does not support zero arguments.", &signature.type_signature ); } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c62c1ce29c06..55a430767c76 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1137,7 +1137,7 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) # Testing with empty arguments should result in an error 
-query error DataFusion error: Error during planning: Error during planning: \[data_types_with_scalar_udf\] signature ArraySignature\(ArrayAndIndex\) does not support zero arguments. +query error DataFusion error: Error during planning: Error during planning: array_element does not support zero arguments. select array_element(); # array_element error @@ -1979,7 +1979,7 @@ select array_slice(a, -1, 2, 1), array_slice(a, -1, 2), [6.0] [6.0] [] [] # Testing with empty arguments should result in an error -query error DataFusion error: Error during planning: Error during planning: \[data_types_with_scalar_udf\] signature VariadicAny does not support zero arguments. +query error DataFusion error: Error during planning: Error during planning: array_slice does not support zero arguments. select array_slice(); From 2f4347647172f6997448b2e24d322b50c856f3a0 Mon Sep 17 00:00:00 2001 From: Marvin Lanhenke <62298609+marvinlanhenke@users.noreply.github.com> Date: Sat, 15 Jun 2024 17:58:44 +0200 Subject: [PATCH 22/34] Initial Extract parquet data page statistics API (#10852) * feat: enable page statistics * feat: prototype int64 data_page_min * feat: prototype MinInt64DataPageStatsIterator * feat: add make_data_page_stats_iterator macro * feat: add get_data_page_statistics macro * feat: add MaxInt64DataPageStatsIterator * feat: add test_data_page_stats param * chore: add testcase int64_with_nulls * feat: add data page null_counts * fix: clippy * chore: rename column_page_index * feat: add data page row counts * feat: add num_data_pages to iterator * chore: update docs * fix: use colum_offset len in data_page_null_counts * fix: docs * tweak comments * update test helper * Add explicit multi-data page tests to statistics test * Add explicit data page test * remove duplicate test * update coverage --------- Co-authored-by: Andrew Lamb --- .../physical_plan/parquet/statistics.rs | 315 +++++++++++- .../core/tests/parquet/arrow_statistics.rs | 479 +++++++++++++----- datafusion/core/tests/parquet/mod.rs | 3 +- 3 files changed, 657 insertions(+), 140 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index c0d36f1fc4d7..a2e0d8fa66be 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -33,7 +33,8 @@ use arrow_array::{ use arrow_schema::{Field, FieldRef, Schema, TimeUnit}; use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; use half::f16; -use parquet::file::metadata::RowGroupMetaData; +use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex, RowGroupMetaData}; +use parquet::file::page_index::index::Index; use parquet::file::statistics::Statistics as ParquetStatistics; use parquet::schema::types::SchemaDescriptor; use paste::paste; @@ -517,6 +518,74 @@ macro_rules! get_statistics { }}} } +macro_rules! 
make_data_page_stats_iterator { + ($iterator_type: ident, $func: ident, $index_type: path, $stat_value_type: ty) => { + struct $iterator_type<'a, I> + where + I: Iterator, + { + iter: I, + } + + impl<'a, I> $iterator_type<'a, I> + where + I: Iterator, + { + fn new(iter: I) -> Self { + Self { iter } + } + } + + impl<'a, I> Iterator for $iterator_type<'a, I> + where + I: Iterator, + { + type Item = Vec>; + + fn next(&mut self) -> Option { + let next = self.iter.next(); + match next { + Some((len, index)) => match index { + $index_type(native_index) => Some( + native_index + .indexes + .iter() + .map(|x| x.$func) + .collect::>(), + ), + // No matching `Index` found; + // thus no statistics that can be extracted. + // We return vec![None; len] to effectively + // create an arrow null-array with the length + // corresponding to the number of entries in + // `ParquetOffsetIndex` per row group per column. + _ => Some(vec![None; len]), + }, + _ => None, + } + } + + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } + } + }; +} + +make_data_page_stats_iterator!(MinInt64DataPageStatsIterator, min, Index::INT64, i64); +make_data_page_stats_iterator!(MaxInt64DataPageStatsIterator, max, Index::INT64, i64); + +macro_rules! get_data_page_statistics { + ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { + paste! { + match $data_type { + Some(DataType::Int64) => Ok(Arc::new(Int64Array::from_iter([<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten()))), + _ => unimplemented!() + } + } + } +} + /// Lookups up the parquet column by name /// /// Returns the parquet column index and the corresponding arrow field @@ -563,6 +632,51 @@ fn max_statistics<'a, I: Iterator>>( get_statistics!(Max, data_type, iterator) } +/// Extracts the min statistics from an iterator +/// of parquet page [`Index`]'es to an [`ArrayRef`] +pub(crate) fn min_page_statistics<'a, I>( + data_type: Option<&DataType>, + iterator: I, +) -> Result +where + I: Iterator, +{ + get_data_page_statistics!(Min, data_type, iterator) +} + +/// Extracts the max statistics from an iterator +/// of parquet page [`Index`]'es to an [`ArrayRef`] +pub(crate) fn max_page_statistics<'a, I>( + data_type: Option<&DataType>, + iterator: I, +) -> Result +where + I: Iterator, +{ + get_data_page_statistics!(Max, data_type, iterator) +} + +/// Extracts the null count statistics from an iterator +/// of parquet page [`Index`]'es to an [`ArrayRef`] +/// +/// The returned Array is an [`UInt64Array`] +pub(crate) fn null_counts_page_statistics<'a, I>(iterator: I) -> Result +where + I: Iterator, +{ + let iter = iterator.flat_map(|(len, index)| match index { + Index::NONE => vec![None; len], + Index::INT64(native_index) => native_index + .indexes + .iter() + .map(|x| x.null_count.map(|x| x as u64)) + .collect::>(), + _ => unimplemented!(), + }); + + Ok(Arc::new(UInt64Array::from_iter(iter))) +} + /// Extracts Parquet statistics as Arrow arrays /// /// This is used to convert Parquet statistics to Arrow arrays, with proper type @@ -771,10 +885,205 @@ impl<'a> StatisticsConverter<'a> { Ok(Arc::new(UInt64Array::from_iter(null_counts))) } + /// Extract the minimum values from Data Page statistics. + /// + /// In Parquet files, in addition to the Column Chunk level statistics + /// (stored for each column for each row group) there are also + /// optional statistics stored for each data page, as part of + /// the [`ParquetColumnIndex`]. 
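+ /// (The page index has two parts: a `ColumnIndex` storing min/max/null
+ /// counts for every data page, and an `OffsetIndex` recording where each
+ /// page starts and its first row index.)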
+ ///
+ /// Since a single Column Chunk is stored as one or more pages,
+ /// page level statistics can prune at a finer granularity.
+ ///
+ /// However, since they are stored in a separate metadata
+ /// structure ([`Index`]) there is different code to extract them as
+ /// compared to arrow statistics.
+ ///
+ /// # Parameters:
+ ///
+ /// * `column_page_index`: The parquet column page indices, read from
+ /// `ParquetMetaData` column_index
+ ///
+ /// * `column_offset_index`: The parquet column offset indices, read from
+ /// `ParquetMetaData` offset_index
+ ///
+ /// * `row_group_indices`: The indices of the row groups that are used to
+ /// extract the column page index and offset index on a per row group
+ /// per column basis.
+ ///
+ /// # Return Value
+ ///
+ /// The returned array contains 1 value for each `NativeIndex`
+ /// in the underlying `Index`es, in the same order as they appear
+ /// in `row_group_indices`.
+ ///
+ /// For example, if there are two `Index`es in `row_group_indices`:
+ /// 1. the first having `3` `PageIndex` entries
+ /// 2. the second having `2` `PageIndex` entries
+ ///
+ /// The returned array would have 5 rows.
+ ///
+ /// Each value is either:
+ /// * the minimum value for the page
+ /// * a null value, if the statistics can not be extracted
+ ///
+ /// Note that a null value does NOT mean the min value was actually
+ /// `null`; it means the requested statistic is unknown
+ ///
+ /// # Errors
+ ///
+ /// Reasons for not being able to extract the statistics include:
+ /// * the column is not present in the parquet file
+ /// * statistics for the pages are not present in the row group
+ /// * the stored statistic value can not be converted to the requested type
+ pub fn data_page_mins(
+ &self,
+ column_page_index: &ParquetColumnIndex,
+ column_offset_index: &ParquetOffsetIndex,
+ row_group_indices: I,
+ ) -> Result
+ where
+ I: IntoIterator,
+ {
+ let data_type = self.arrow_field.data_type();
+
+ let Some(parquet_index) = self.parquet_index else {
+ return Ok(self.make_null_array(data_type, row_group_indices));
+ };
+
+ let iter = row_group_indices.into_iter().map(|rg_index| {
+ let column_page_index_per_row_group_per_column =
+ &column_page_index[*rg_index][parquet_index];
+ let num_data_pages = &column_offset_index[*rg_index][parquet_index].len();
+
+ (*num_data_pages, column_page_index_per_row_group_per_column)
+ });
+
+ min_page_statistics(Some(data_type), iter)
+ }
+
+ /// Extract the maximum values from Data Page statistics.
+ ///
+ /// See docs on [`Self::data_page_mins`] for details.
+ pub fn data_page_maxes(
+ &self,
+ column_page_index: &ParquetColumnIndex,
+ column_offset_index: &ParquetOffsetIndex,
+ row_group_indices: I,
+ ) -> Result
+ where
+ I: IntoIterator,
+ {
+ let data_type = self.arrow_field.data_type();
+
+ let Some(parquet_index) = self.parquet_index else {
+ return Ok(self.make_null_array(data_type, row_group_indices));
+ };
+
+ let iter = row_group_indices.into_iter().map(|rg_index| {
+ let column_page_index_per_row_group_per_column =
+ &column_page_index[*rg_index][parquet_index];
+ let num_data_pages = &column_offset_index[*rg_index][parquet_index].len();
+
+ (*num_data_pages, column_page_index_per_row_group_per_column)
+ });
+
+ max_page_statistics(Some(data_type), iter)
+ }
+
+ /// Extract the null counts from Data Page statistics.
+ ///
+ /// The returned Array is an [`UInt64Array`]
+ ///
+ /// See docs on [`Self::data_page_mins`] for details.
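+ ///
+ /// # Example
+ ///
+ /// A minimal sketch of the call pattern (mirroring the tests added
+ /// below), assuming `reader` was opened with the page index enabled and
+ /// `converter` is a [`StatisticsConverter`] for the target column:
+ ///
+ /// ```ignore
+ /// let column_page_index = reader.metadata().column_index().unwrap();
+ /// let column_offset_index = reader.metadata().offset_index().unwrap();
+ /// // one value per data page of this column in row group 0
+ /// let null_counts = converter
+ ///     .data_page_null_counts(column_page_index, column_offset_index, &[0])
+ ///     .unwrap();
+ /// ```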
+ pub fn data_page_null_counts( + &self, + column_page_index: &ParquetColumnIndex, + column_offset_index: &ParquetOffsetIndex, + row_group_indices: I, + ) -> Result + where + I: IntoIterator, + { + let data_type = self.arrow_field.data_type(); + + let Some(parquet_index) = self.parquet_index else { + return Ok(self.make_null_array(data_type, row_group_indices)); + }; + + let iter = row_group_indices.into_iter().map(|rg_index| { + let column_page_index_per_row_group_per_column = + &column_page_index[*rg_index][parquet_index]; + let num_data_pages = &column_offset_index[*rg_index][parquet_index].len(); + + (*num_data_pages, column_page_index_per_row_group_per_column) + }); + null_counts_page_statistics(iter) + } + + /// Returns an [`ArrayRef`] with row counts for each row group. + /// + /// This function iterates over the given row group indexes and computes + /// the row count for each page in the specified column. + /// + /// # Parameters: + /// + /// * `column_offset_index`: The parquet column offset indices, read from + /// `ParquetMetaData` offset_index + /// + /// * `row_group_metadatas`: The metadata slice of the row groups, read + /// from `ParquetMetaData` row_groups + /// + /// * `row_group_indices`: The indices of the row groups, that are used to + /// extract the column offset index on a per row group per column basis. + /// + /// See docs on [`Self::data_page_mins`] for details. + pub fn data_page_row_counts( + &self, + column_offset_index: &ParquetOffsetIndex, + row_group_metadatas: &[RowGroupMetaData], + row_group_indices: I, + ) -> Result + where + I: IntoIterator, + { + let data_type = self.arrow_field.data_type(); + + let Some(parquet_index) = self.parquet_index else { + return Ok(self.make_null_array(data_type, row_group_indices)); + }; + + // `offset_index[row_group_number][column_number][page_number]` holds + // the [`PageLocation`] corresponding to page `page_number` of column + // `column_number`of row group `row_group_number`. + let mut row_count_total = Vec::new(); + for rg_idx in row_group_indices { + let page_locations = &column_offset_index[*rg_idx][parquet_index]; + + let row_count_per_page = page_locations.windows(2).map(|loc| { + Some(loc[1].first_row_index as u64 - loc[0].first_row_index as u64) + }); + + let num_rows_in_row_group = &row_group_metadatas[*rg_idx].num_rows(); + + // append the last page row count + let row_count_per_page = row_count_per_page + .chain(std::iter::once(Some( + *num_rows_in_row_group as u64 + - page_locations.last().unwrap().first_row_index as u64, + ))) + .collect::>(); + + row_count_total.extend(row_count_per_page); + } + + Ok(Arc::new(UInt64Array::from_iter(row_count_total))) + } + /// Returns a null array of data_type with one element per row group - fn make_null_array(&self, data_type: &DataType, metadatas: I) -> ArrayRef + fn make_null_array(&self, data_type: &DataType, metadatas: I) -> ArrayRef where - I: IntoIterator, + I: IntoIterator, { // column was in the arrow schema but not in the parquet schema, so return a null array let num_row_groups = metadatas.into_iter().count(); diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 2ea18d7cf823..3c812800e2b7 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -18,6 +18,7 @@ //! This file contains an end to end test of extracting statitics from parquet files. //! 
It writes data into a parquet file, reads statistics and verifies they are correct +use std::default::Default; use std::fs::File; use std::sync::Arc; @@ -39,102 +40,102 @@ use arrow_array::{ use arrow_schema::{DataType, Field, Schema}; use datafusion::datasource::physical_plan::parquet::StatisticsConverter; use half::f16; -use parquet::arrow::arrow_reader::{ArrowReaderBuilder, ParquetRecordBatchReaderBuilder}; +use parquet::arrow::arrow_reader::{ + ArrowReaderBuilder, ArrowReaderOptions, ParquetRecordBatchReaderBuilder, +}; use parquet::arrow::ArrowWriter; use parquet::file::properties::{EnabledStatistics, WriterProperties}; use super::make_test_file_rg; -// TEST HELPERS - -/// Return a record batch with i64 with Null values -fn make_int64_batches_with_null( +#[derive(Debug, Default, Clone)] +struct Int64Case { + /// Number of nulls in the column null_values: usize, + /// Non null values in the range `[no_null_values_start, + /// no_null_values_end]`, one value for each row no_null_values_start: i64, no_null_values_end: i64, -) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new("i64", DataType::Int64, true)])); - - let v64: Vec = (no_null_values_start as _..no_null_values_end as _).collect(); - - RecordBatch::try_new( - schema, - vec![make_array( - Int64Array::from_iter( - v64.into_iter() - .map(Some) - .chain(std::iter::repeat(None).take(null_values)), - ) - .to_data(), - )], - ) - .unwrap() -} - -// Create a parquet file with one column for data type i64 -// Data of the file include -// . Number of null rows is the given num_null -// . There are non-null values in the range [no_null_values_start, no_null_values_end], one value each row -// . The file is divided into row groups of size row_per_group -pub fn parquet_file_one_column( - num_null: usize, - no_null_values_start: i64, - no_null_values_end: i64, + /// Number of rows per row group row_per_group: usize, -) -> ParquetRecordBatchReaderBuilder { - parquet_file_one_column_stats( - num_null, - no_null_values_start, - no_null_values_end, - row_per_group, - EnabledStatistics::Chunk, - ) + /// if specified, overrides default statistics settings + enable_stats: Option, + /// If specified, the number of values in each page + data_page_row_count_limit: Option, } -// Create a parquet file with one column for data type i64 -// Data of the file include -// . Number of null rows is the given num_null -// . There are non-null values in the range [no_null_values_start, no_null_values_end], one value each row -// . The file is divided into row groups of size row_per_group -// . 
Statistics are enabled/disabled based on the given enable_stats -pub fn parquet_file_one_column_stats( - num_null: usize, - no_null_values_start: i64, - no_null_values_end: i64, - row_per_group: usize, - enable_stats: EnabledStatistics, -) -> ParquetRecordBatchReaderBuilder { - let mut output_file = tempfile::Builder::new() - .prefix("parquert_statistics_test") - .suffix(".parquet") - .tempfile() - .expect("tempfile creation"); - - let props = WriterProperties::builder() - .set_max_row_group_size(row_per_group) - .set_statistics_enabled(enable_stats) - .build(); - - let batches = vec![make_int64_batches_with_null( - num_null, - no_null_values_start, - no_null_values_end, - )]; - - let schema = batches[0].schema(); - - let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); +impl Int64Case { + /// Return a record batch with i64 with Null values + /// The first no_null_values_end - no_null_values_start values + /// are non-null with the specified range, the rest are null + fn make_int64_batches_with_null(&self) -> RecordBatch { + let schema = + Arc::new(Schema::new(vec![Field::new("i64", DataType::Int64, true)])); + + let v64: Vec = + (self.no_null_values_start as _..self.no_null_values_end as _).collect(); + + RecordBatch::try_new( + schema, + vec![make_array( + Int64Array::from_iter( + v64.into_iter() + .map(Some) + .chain(std::iter::repeat(None).take(self.null_values)), + ) + .to_data(), + )], + ) + .unwrap() + } + + // Create a parquet file with the specified settings + pub fn build(&self) -> ParquetRecordBatchReaderBuilder { + let mut output_file = tempfile::Builder::new() + .prefix("parquert_statistics_test") + .suffix(".parquet") + .tempfile() + .expect("tempfile creation"); + + let mut builder = + WriterProperties::builder().set_max_row_group_size(self.row_per_group); + if let Some(enable_stats) = self.enable_stats { + builder = builder.set_statistics_enabled(enable_stats); + } + if let Some(data_page_row_count_limit) = self.data_page_row_count_limit { + builder = builder.set_data_page_row_count_limit(data_page_row_count_limit); + } + let props = builder.build(); + + let batches = vec![self.make_int64_batches_with_null()]; + + let schema = batches[0].schema(); + + let mut writer = + ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); + + // if we have a datapage limit send the batches in one at a time to give + // the writer a chance to be split into multiple pages + if self.data_page_row_count_limit.is_some() { + for batch in batches { + for i in 0..batch.num_rows() { + writer.write(&batch.slice(i, 1)).expect("writing batch"); + } + } + } else { + for batch in batches { + writer.write(&batch).expect("writing batch"); + } + } + + // close file + let _file_meta = writer.close().unwrap(); - for batch in batches { - writer.write(&batch).expect("writing batch"); + // open the file & get the reader + let file = output_file.reopen().unwrap(); + let options = ArrowReaderOptions::new().with_page_index(true); + ArrowReaderBuilder::try_new_with_options(file, options).unwrap() } - - // close file - let _file_meta = writer.close().unwrap(); - - // open the file & get the reader - let file = output_file.reopen().unwrap(); - ArrowReaderBuilder::try_new(file).unwrap() } /// Defines what data to create in a parquet file @@ -158,7 +159,8 @@ impl TestReader { // open the file & get the reader let file = file.reopen().unwrap(); - ArrowReaderBuilder::try_new(file).unwrap() + let options = ArrowReaderOptions::new().with_page_index(true); + 
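+ // NB: the page index is only read when requested via `with_page_index`;
+ // without it the data page statistics extracted in these tests would be
+ // missing.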
ArrowReaderBuilder::try_new_with_options(file, options).unwrap() } } @@ -172,6 +174,9 @@ struct Test<'a> { expected_row_counts: UInt64Array, /// Which column to extract statistics from column_name: &'static str, + /// If true, extracts and compares data page statistics rather than row + /// group statistics + test_data_page_statistics: bool, } impl<'a> Test<'a> { @@ -183,6 +188,7 @@ impl<'a> Test<'a> { expected_null_counts, expected_row_counts, column_name, + test_data_page_statistics, } = self; let converter = StatisticsConverter::try_new( @@ -193,36 +199,103 @@ impl<'a> Test<'a> { .unwrap(); let row_groups = reader.metadata().row_groups(); - let min = converter.row_group_mins(row_groups).unwrap(); - - assert_eq!( - &min, &expected_min, - "{column_name}: Mismatch with expected minimums" - ); - let max = converter.row_group_maxes(row_groups).unwrap(); - assert_eq!( - &max, &expected_max, - "{column_name}: Mismatch with expected maximum" - ); - - let null_counts = converter.row_group_null_counts(row_groups).unwrap(); - let expected_null_counts = Arc::new(expected_null_counts) as ArrayRef; - assert_eq!( - &null_counts, &expected_null_counts, - "{column_name}: Mismatch with expected null counts. \ - Actual: {null_counts:?}. Expected: {expected_null_counts:?}" - ); - - let row_counts = StatisticsConverter::row_group_row_counts( - reader.metadata().row_groups().iter(), - ) - .unwrap(); - assert_eq!( - row_counts, expected_row_counts, - "{column_name}: Mismatch with expected row counts. \ - Actual: {row_counts:?}. Expected: {expected_row_counts:?}" - ); + if test_data_page_statistics { + let column_page_index = reader + .metadata() + .column_index() + .expect("File should have column page indices"); + + let column_offset_index = reader + .metadata() + .offset_index() + .expect("File should have column offset indices"); + + let row_group_indices = row_groups + .iter() + .enumerate() + .map(|(i, _)| i) + .collect::>(); + + let min = converter + .data_page_mins( + column_page_index, + column_offset_index, + &row_group_indices, + ) + .unwrap(); + assert_eq!( + &min, &expected_min, + "{column_name}: Mismatch with expected data page minimums" + ); + + let max = converter + .data_page_maxes( + column_page_index, + column_offset_index, + &row_group_indices, + ) + .unwrap(); + assert_eq!( + &max, &expected_max, + "{column_name}: Mismatch with expected data page maximum" + ); + + let null_counts = converter + .data_page_null_counts( + column_page_index, + column_offset_index, + &row_group_indices, + ) + .unwrap(); + + let expected_null_counts = Arc::new(expected_null_counts) as ArrayRef; + assert_eq!( + &null_counts, &expected_null_counts, + "{column_name}: Mismatch with expected data page null counts. \ + Actual: {null_counts:?}. Expected: {expected_null_counts:?}" + ); + + let row_counts = converter + .data_page_row_counts(column_offset_index, row_groups, &row_group_indices) + .unwrap(); + let expected_row_counts = Arc::new(expected_row_counts) as ArrayRef; + assert_eq!( + &row_counts, &expected_row_counts, + "{column_name}: Mismatch with expected row counts. \ + Actual: {row_counts:?}. 
Expected: {expected_row_counts:?}" + ); + } else { + let min = converter.row_group_mins(row_groups).unwrap(); + assert_eq!( + &min, &expected_min, + "{column_name}: Mismatch with expected minimums" + ); + + let max = converter.row_group_maxes(row_groups).unwrap(); + assert_eq!( + &max, &expected_max, + "{column_name}: Mismatch with expected maximum" + ); + + let null_counts = converter.row_group_null_counts(row_groups).unwrap(); + let expected_null_counts = Arc::new(expected_null_counts) as ArrayRef; + assert_eq!( + &null_counts, &expected_null_counts, + "{column_name}: Mismatch with expected null counts. \ + Actual: {null_counts:?}. Expected: {expected_null_counts:?}" + ); + + let row_counts = StatisticsConverter::row_group_row_counts( + reader.metadata().row_groups().iter(), + ) + .unwrap(); + assert_eq!( + row_counts, expected_row_counts, + "{column_name}: Mismatch with expected row counts. \ + Actual: {row_counts:?}. Expected: {expected_row_counts:?}" + ); + } } /// Run the test and expect a column not found error @@ -234,6 +307,7 @@ impl<'a> Test<'a> { expected_null_counts: _, expected_row_counts: _, column_name, + .. } = self; let converter = StatisticsConverter::try_new( @@ -254,8 +328,15 @@ impl<'a> Test<'a> { #[tokio::test] async fn test_one_row_group_without_null() { - let row_per_group = 20; - let reader = parquet_file_one_column(0, 4, 7, row_per_group); + let reader = Int64Case { + null_values: 0, + no_null_values_start: 4, + no_null_values_end: 7, + row_per_group: 20, + ..Default::default() + } + .build(); + Test { reader: &reader, // min is 4 @@ -267,14 +348,21 @@ async fn test_one_row_group_without_null() { // 3 rows expected_row_counts: UInt64Array::from(vec![3]), column_name: "i64", + test_data_page_statistics: false, } .run() } #[tokio::test] async fn test_one_row_group_with_null_and_negative() { - let row_per_group = 20; - let reader = parquet_file_one_column(2, -1, 5, row_per_group); + let reader = Int64Case { + null_values: 2, + no_null_values_start: -1, + no_null_values_end: 5, + row_per_group: 20, + ..Default::default() + } + .build(); Test { reader: &reader, @@ -287,14 +375,21 @@ async fn test_one_row_group_with_null_and_negative() { // 8 rows expected_row_counts: UInt64Array::from(vec![8]), column_name: "i64", + test_data_page_statistics: false, } .run() } #[tokio::test] async fn test_two_row_group_with_null() { - let row_per_group = 10; - let reader = parquet_file_one_column(2, 4, 17, row_per_group); + let reader = Int64Case { + null_values: 2, + no_null_values_start: 4, + no_null_values_end: 17, + row_per_group: 10, + ..Default::default() + } + .build(); Test { reader: &reader, @@ -307,14 +402,21 @@ async fn test_two_row_group_with_null() { // row counts are [10, 5] expected_row_counts: UInt64Array::from(vec![10, 5]), column_name: "i64", + test_data_page_statistics: false, } .run() } #[tokio::test] async fn test_two_row_groups_with_all_nulls_in_one() { - let row_per_group = 5; - let reader = parquet_file_one_column(4, -2, 2, row_per_group); + let reader = Int64Case { + null_values: 4, + no_null_values_start: -2, + no_null_values_end: 2, + row_per_group: 5, + ..Default::default() + } + .build(); Test { reader: &reader, @@ -327,6 +429,38 @@ async fn test_two_row_groups_with_all_nulls_in_one() { // row counts are [5, 3] expected_row_counts: UInt64Array::from(vec![5, 3]), column_name: "i64", + test_data_page_statistics: false, + } + .run() +} + +#[tokio::test] +async fn test_multiple_data_pages_nulls_and_negatives() { + let reader = Int64Case { + null_values: 3, + 
no_null_values_start: -1, + no_null_values_end: 10, + row_per_group: 20, + // limit page row count to 4 + data_page_row_count_limit: Some(4), + enable_stats: Some(EnabledStatistics::Page), + } + .build(); + + // Data layout looks like this: + // + // page 0: [-1, 0, 1, 2] + // page 1: [3, 4, 5, 6] + // page 2: [7, 8, 9, null] + // page 3: [null, null] + Test { + reader: &reader, + expected_min: Arc::new(Int64Array::from(vec![Some(-1), Some(3), Some(7), None])), + expected_max: Arc::new(Int64Array::from(vec![Some(2), Some(6), Some(9), None])), + expected_null_counts: UInt64Array::from(vec![0, 0, 1, 2]), + expected_row_counts: UInt64Array::from(vec![4, 4, 4, 2]), + column_name: "i64", + test_data_page_statistics: true, } .run() } @@ -347,19 +481,23 @@ async fn test_int_64() { .build() .await; - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Int64Array::from(vec![-5, -4, 0, 5])), - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Int64Array::from(vec![-1, 0, 4, 9])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), - column_name: "i64", + // since each row has only one data page, the statistics are the same + for test_data_page_statistics in [true, false] { + Test { + reader: &reader, + // mins are [-5, -4, 0, 5] + expected_min: Arc::new(Int64Array::from(vec![-5, -4, 0, 5])), + // maxes are [-1, 0, 4, 9] + expected_max: Arc::new(Int64Array::from(vec![-1, 0, 4, 9])), + // nulls are [0, 0, 0, 0] + expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), + // row counts are [5, 5, 5, 5] + expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), + column_name: "i64", + test_data_page_statistics, + } + .run(); } - .run(); } #[tokio::test] @@ -383,6 +521,7 @@ async fn test_int_32() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i32", + test_data_page_statistics: false, } .run(); } @@ -423,6 +562,7 @@ async fn test_int_16() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i16", + test_data_page_statistics: false, } .run(); } @@ -451,6 +591,7 @@ async fn test_int_8() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i8", + test_data_page_statistics: false, } .run(); } @@ -500,6 +641,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "nanos", + test_data_page_statistics: false, } .run(); @@ -528,6 +670,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "nanos_timezoned", + test_data_page_statistics: false, } .run(); @@ -549,6 +692,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "micros", + test_data_page_statistics: false, } .run(); @@ -577,6 +721,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "micros_timezoned", + test_data_page_statistics: false, } .run(); @@ -598,6 +743,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "millis", + test_data_page_statistics: false, } .run(); @@ -626,6 
+772,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "millis_timezoned", + test_data_page_statistics: false, } .run(); @@ -647,6 +794,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "seconds", + test_data_page_statistics: false, } .run(); @@ -675,6 +823,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "seconds_timezoned", + test_data_page_statistics: false, } .run(); } @@ -720,6 +869,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "nanos", + test_data_page_statistics: false, } .run(); @@ -746,6 +896,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "nanos_timezoned", + test_data_page_statistics: false, } .run(); @@ -765,6 +916,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "micros", + test_data_page_statistics: false, } .run(); @@ -791,6 +943,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "micros_timezoned", + test_data_page_statistics: false, } .run(); @@ -810,6 +963,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "millis", + test_data_page_statistics: false, } .run(); @@ -836,6 +990,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "millis_timezoned", + test_data_page_statistics: false, } .run(); @@ -855,6 +1010,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "seconds", + test_data_page_statistics: false, } .run(); @@ -881,6 +1037,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "seconds_timezoned", + test_data_page_statistics: false, } .run(); } @@ -918,6 +1075,7 @@ async fn test_dates_32_diff_rg_sizes() { // row counts are [13, 7] expected_row_counts: UInt64Array::from(vec![13, 7]), column_name: "date32", + test_data_page_statistics: false, } .run(); } @@ -940,6 +1098,7 @@ async fn test_time32_second_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "second", + test_data_page_statistics: false, } .run(); } @@ -966,6 +1125,7 @@ async fn test_time32_millisecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "millisecond", + test_data_page_statistics: false, } .run(); } @@ -998,6 +1158,7 @@ async fn test_time64_microsecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 
4, 4, 4]), column_name: "microsecond", + test_data_page_statistics: false, } .run(); } @@ -1030,6 +1191,7 @@ async fn test_time64_nanosecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "nanosecond", + test_data_page_statistics: false, } .run(); } @@ -1056,6 +1218,7 @@ async fn test_dates_64_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![2, 2]), expected_row_counts: UInt64Array::from(vec![13, 7]), column_name: "date64", + test_data_page_statistics: false, } .run(); } @@ -1083,6 +1246,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u8", + test_data_page_statistics: false, } .run(); @@ -1093,6 +1257,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u16", + test_data_page_statistics: false, } .run(); @@ -1103,6 +1268,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u32", + test_data_page_statistics: false, } .run(); @@ -1113,6 +1279,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u64", + test_data_page_statistics: false, } .run(); } @@ -1135,6 +1302,7 @@ async fn test_int32_range() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: UInt64Array::from(vec![4]), column_name: "i", + test_data_page_statistics: false, } .run(); } @@ -1157,6 +1325,7 @@ async fn test_uint32_range() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: UInt64Array::from(vec![4]), column_name: "u", + test_data_page_statistics: false, } .run(); } @@ -1178,6 +1347,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u8", + test_data_page_statistics: false, } .run(); @@ -1188,6 +1358,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u16", + test_data_page_statistics: false, } .run(); @@ -1198,6 +1369,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u32", + test_data_page_statistics: false, } .run(); @@ -1208,6 +1380,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u64", + test_data_page_statistics: false, } .run(); } @@ -1229,6 +1402,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i8", + test_data_page_statistics: false, } .run(); @@ -1239,6 +1413,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i16", + test_data_page_statistics: false, } .run(); @@ -1249,6 +1424,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: 
UInt64Array::from(vec![5, 2]), column_name: "i32", + test_data_page_statistics: false, } .run(); @@ -1259,6 +1435,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i64", + test_data_page_statistics: false, } .run(); } @@ -1280,6 +1457,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f32", + test_data_page_statistics: false, } .run(); @@ -1290,6 +1468,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f64", + test_data_page_statistics: false, } .run(); @@ -1300,6 +1479,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f32_nan", + test_data_page_statistics: false, } .run(); @@ -1310,6 +1490,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f64_nan", + test_data_page_statistics: false, } .run(); } @@ -1332,6 +1513,7 @@ async fn test_float64() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "f", + test_data_page_statistics: false, } .run(); } @@ -1364,6 +1546,7 @@ async fn test_float16() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "f", + test_data_page_statistics: false, } .run(); } @@ -1394,6 +1577,7 @@ async fn test_decimal() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "decimal_col", + test_data_page_statistics: false, } .run(); } @@ -1431,6 +1615,7 @@ async fn test_decimal_256() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "decimal256_col", + test_data_page_statistics: false, } .run(); } @@ -1450,6 +1635,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "string_dict_i8", + test_data_page_statistics: false, } .run(); @@ -1460,6 +1646,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "string_dict_i32", + test_data_page_statistics: false, } .run(); @@ -1470,6 +1657,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "int_dict_i8", + test_data_page_statistics: false, } .run(); } @@ -1507,6 +1695,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "name", + test_data_page_statistics: false, } .run(); @@ -1526,6 +1715,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_string", + test_data_page_statistics: false, } .run(); @@ -1544,6 +1734,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_binary", + 
test_data_page_statistics: false, } .run(); @@ -1564,6 +1755,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_fixedsize", + test_data_page_statistics: false, } .run(); @@ -1584,6 +1776,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_large_binary", + test_data_page_statistics: false, } .run(); } @@ -1616,6 +1809,7 @@ async fn test_period_in_column_names() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "name", + test_data_page_statistics: false, } .run(); @@ -1629,6 +1823,7 @@ async fn test_period_in_column_names() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service.name", + test_data_page_statistics: false, } .run(); } @@ -1652,6 +1847,7 @@ async fn test_boolean() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 5]), column_name: "bool", + test_data_page_statistics: false, } .run(); } @@ -1678,6 +1874,7 @@ async fn test_struct() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: UInt64Array::from(vec![3]), column_name: "struct", + test_data_page_statistics: false, } .run(); } @@ -1700,6 +1897,7 @@ async fn test_utf8() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 5]), column_name: "utf8", + test_data_page_statistics: false, } .run(); @@ -1711,6 +1909,7 @@ async fn test_utf8() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 5]), column_name: "large_utf8", + test_data_page_statistics: false, } .run(); } @@ -1719,9 +1918,15 @@ async fn test_utf8() { #[tokio::test] async fn test_missing_statistics() { - let row_per_group = 5; - let reader = - parquet_file_one_column_stats(0, 4, 7, row_per_group, EnabledStatistics::None); + let reader = Int64Case { + null_values: 0, + no_null_values_start: 4, + no_null_values_end: 7, + row_per_group: 5, + enable_stats: Some(EnabledStatistics::None), + ..Default::default() + } + .build(); Test { reader: &reader, @@ -1730,6 +1935,7 @@ async fn test_missing_statistics() { expected_null_counts: UInt64Array::from(vec![None]), expected_row_counts: UInt64Array::from(vec![3]), // stil has row count statistics column_name: "i64", + test_data_page_statistics: false, } .run(); } @@ -1751,6 +1957,7 @@ async fn test_column_not_found() { expected_null_counts: UInt64Array::from(vec![2, 2]), expected_row_counts: UInt64Array::from(vec![13, 7]), column_name: "not_a_column", + test_data_page_statistics: false, } .run_col_not_found(); } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 9546ab30c9e0..0434a271c32e 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -43,7 +43,7 @@ use datafusion::{ use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use half::f16; use parquet::arrow::ArrowWriter; -use parquet::file::properties::WriterProperties; +use parquet::file::properties::{EnabledStatistics, WriterProperties}; use std::sync::Arc; use tempfile::NamedTempFile; @@ -1349,6 +1349,7 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem let props = WriterProperties::builder() 
.set_max_row_group_size(row_per_group) .set_bloom_filter_enabled(true) + .set_statistics_enabled(EnabledStatistics::Page) .build(); let batches = create_data_batch(scenario); From 87aea143099b7220037bd54629c38d450b894de1 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Sat, 15 Jun 2024 09:45:57 -0700 Subject: [PATCH 23/34] Add contains function, and support in datafusion substrait consumer (#10879) * adding new function contains * adding substrait test * adding doc * adding doc * Update docs/source/user-guide/sql/scalar_functions.md Co-authored-by: Alex Huang * adding entry --------- Co-authored-by: Alex Huang --- datafusion/functions/src/string/contains.rs | 143 ++++++++++++++++++ datafusion/functions/src/string/mod.rs | 8 +- .../sqllogictest/test_files/functions.slt | 18 +++ .../substrait/tests/cases/function_test.rs | 58 +++++++ datafusion/substrait/tests/cases/mod.rs | 1 + .../testdata/contains_plan.substrait.json | 133 ++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 14 ++ 7 files changed, 373 insertions(+), 2 deletions(-) create mode 100644 datafusion/functions/src/string/contains.rs create mode 100644 datafusion/substrait/tests/cases/function_test.rs create mode 100644 datafusion/substrait/tests/testdata/contains_plan.substrait.json diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs new file mode 100644 index 000000000000..faf979f80614 --- /dev/null +++ b/datafusion/functions/src/string/contains.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
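+
+//! Usage sketch (mirroring the `functions.slt` cases added later in this
+//! patch): `contains(string, search_string)` returns `true` when
+//! `search_string` occurs within `string`, e.g.
+//! `SELECT contains('alphabet', 'pha')` evaluates to `true`.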
+
+use crate::utils::make_scalar_function;
+use arrow::array::{ArrayRef, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+use arrow::datatypes::DataType::Boolean;
+use datafusion_common::cast::as_generic_string_array;
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_common::{arrow_datafusion_err, exec_err};
+use datafusion_expr::ScalarUDFImpl;
+use datafusion_expr::TypeSignature::Exact;
+use datafusion_expr::{ColumnarValue, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+#[derive(Debug)]
+pub struct ContainsFunc {
+    signature: Signature,
+}
+
+impl Default for ContainsFunc {
+    fn default() -> Self {
+        ContainsFunc::new()
+    }
+}
+
+impl ContainsFunc {
+    pub fn new() -> Self {
+        use DataType::*;
+        Self {
+            signature: Signature::one_of(
+                vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for ContainsFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "contains"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
+        Ok(Boolean)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        match args[0].data_type() {
+            DataType::Utf8 => make_scalar_function(contains::<i32>, vec![])(args),
+            DataType::LargeUtf8 => make_scalar_function(contains::<i64>, vec![])(args),
+            other => {
+                exec_err!("unsupported data type {other:?} for function contains")
+            }
+        }
+    }
+}
+
+/// use the arrow `regexp_is_match_utf8` kernel to do the calculation for contains
+pub fn contains<T: OffsetSizeTrait>(
+    args: &[ArrayRef],
+) -> Result<ArrayRef, DataFusionError> {
+    let mod_str = as_generic_string_array::<T>(&args[0])?;
+    let match_str = as_generic_string_array::<T>(&args[1])?;
+    let res = arrow::compute::kernels::comparison::regexp_is_match_utf8(
+        mod_str, match_str, None,
+    )
+    .map_err(|e| arrow_datafusion_err!(e))?;
+
+    Ok(Arc::new(res) as ArrayRef)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::string::contains::ContainsFunc;
+    use crate::utils::test::test_function;
+    use arrow::array::Array;
+    use arrow::{array::BooleanArray, datatypes::DataType::Boolean};
+    use datafusion_common::Result;
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::ColumnarValue;
+    use datafusion_expr::ScalarUDFImpl;
+    #[test]
+    fn test_functions() -> Result<()> {
+        test_function!(
+            ContainsFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from("alph")),
+            ],
+            Ok(Some(true)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+        test_function!(
+            ContainsFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from("dddddd")),
+            ],
+            Ok(Some(false)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+        test_function!(
+            ContainsFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from("pha")),
+            ],
+            Ok(Some(true)),
+            bool,
+            Boolean,
+            BooleanArray
+        );
+        Ok(())
+    }
+}
diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs
index 219ef8b5a50f..5bf372c29f2d 100644
--- a/datafusion/functions/src/string/mod.rs
+++ b/datafusion/functions/src/string/mod.rs
@@ -28,6 +28,7 @@ pub mod chr;
 pub mod common;
 pub mod concat;
 pub mod concat_ws;
+pub mod contains;
 pub mod ends_with;
 pub mod initcap;
 pub mod levenshtein;
@@ -43,7 +44,6 @@ pub mod starts_with;
 pub mod to_hex;
 pub mod upper;
 pub mod uuid;
-
 // create UDFs
 make_udf_function!(ascii::AsciiFunc, ASCII, ascii);
 make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length);
@@ -66,7 +66,7 @@ make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part);
 make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex);
 make_udf_function!(upper::UpperFunc, UPPER, upper);
 make_udf_function!(uuid::UuidFunc, UUID, uuid);
-
+make_udf_function!(contains::ContainsFunc, CONTAINS, contains);
 pub mod expr_fn {
     use datafusion_expr::Expr;

@@ -149,6 +149,9 @@ pub mod expr_fn {
     ),(
         uuid,
         "returns uuid v4 as a string value",
+    ), (
+        contains,
+        "Return true if search_string is found within string; the search is treated as a regular-expression match",
     ));

     #[doc = "Removes all characters, spaces by default, from both sides of a string"]
@@ -188,5 +191,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         to_hex(),
         upper(),
         uuid(),
+        contains(),
     ]
 }
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index df6295d63b81..c3dd791f6ca8 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -1158,3 +1158,21 @@ drop table uuid_table

 statement ok
 drop table t
+
+
+# test for contains
+
+query B
+select contains('alphabet', 'pha');
+----
+true
+
+query B
+select contains('alphabet', 'dddd');
+----
+false
+
+query B
+select contains('', '');
+----
+true
diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs
new file mode 100644
index 000000000000..b4c5659a3a49
--- /dev/null
+++ b/datafusion/substrait/tests/cases/function_test.rs
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Tests for Function Compatibility
+
+#[cfg(test)]
+mod tests {
+    use datafusion::common::Result;
+    use datafusion::prelude::{CsvReadOptions, SessionContext};
+    use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
+    use std::fs::File;
+    use std::io::BufReader;
+    use substrait::proto::Plan;
+
+    #[tokio::test]
+    async fn contains_function_test() -> Result<()> {
+        let ctx = create_context().await?;
+
+        let path = "tests/testdata/contains_plan.substrait.json";
+        let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
+            File::open(path).expect("file not found"),
+        ))
+        .expect("failed to parse json");
+
+        let plan = from_substrait_plan(&ctx, &proto).await?;
+
+        let plan_str = format!("{:?}", plan);
+
+        assert_eq!(
+            plan_str,
+            "Projection: nation.b AS n_name\
+            \n  Filter: contains(nation.b, Utf8(\"IA\"))\
+            \n    TableScan: nation projection=[a, b, c, d, e, f]"
+        );
+        Ok(())
+    }
+
+    async fn create_context() -> datafusion::common::Result<SessionContext> {
+        let ctx = SessionContext::new();
+        ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new())
+            .await?;
+        Ok(ctx)
+    }
+}
diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs
index a31f93087d83..d3ea7695e4b9 100644
--- a/datafusion/substrait/tests/cases/mod.rs
+++ b/datafusion/substrait/tests/cases/mod.rs
@@ -16,6 +16,7 @@
 // under the License.

 mod consumer_integration;
+mod function_test;
 mod logical_plans;
 mod roundtrip_logical_plan;
 mod roundtrip_physical_plan;
diff --git a/datafusion/substrait/tests/testdata/contains_plan.substrait.json b/datafusion/substrait/tests/testdata/contains_plan.substrait.json
new file mode 100644
index 000000000000..76edde34e3b0
--- /dev/null
+++ b/datafusion/substrait/tests/testdata/contains_plan.substrait.json
@@ -0,0 +1,133 @@
+{
+  "extensionUris": [
+    {
+      "extensionUriAnchor": 1,
+      "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml"
+    }
+  ],
+  "extensions": [
+    {
+      "extensionFunction": {
+        "extensionUriReference": 1,
+        "functionAnchor": 1,
+        "name": "contains:str_str"
+      }
+    }
+  ],
+  "relations": [
+    {
+      "root": {
+        "input": {
+          "project": {
+            "common": {
+              "emit": {
+                "outputMapping": [
+                  4
+                ]
+              }
+            },
+            "input": {
+              "filter": {
+                "input": {
+                  "read": {
+                    "common": {
+                      "direct": {}
+                    },
+                    "baseSchema": {
+                      "names": [
+                        "n_nationkey",
+                        "n_name",
+                        "n_regionkey",
+                        "n_comment"
+                      ],
+                      "struct": {
+                        "types": [
+                          {
+                            "i32": {
+                              "nullability": "NULLABILITY_REQUIRED"
+                            }
+                          },
+                          {
+                            "string": {
+                              "nullability": "NULLABILITY_REQUIRED"
+                            }
+                          },
+                          {
+                            "i32": {
+                              "nullability": "NULLABILITY_REQUIRED"
+                            }
+                          },
+                          {
+                            "string": {
+                              "nullability": "NULLABILITY_REQUIRED"
+                            }
+                          }
+                        ],
+                        "nullability": "NULLABILITY_REQUIRED"
+                      }
+                    },
+                    "namedTable": {
+                      "names": [
+                        "nation"
+                      ]
+                    }
+                  }
+                },
+                "condition": {
+                  "scalarFunction": {
+                    "functionReference": 1,
+                    "outputType": {
+                      "bool": {
+                        "nullability": "NULLABILITY_NULLABLE"
+                      }
+                    },
+                    "arguments": [
+                      {
+                        "value": {
+                          "selection": {
+                            "directReference": {
+                              "structField": {
+                                "field": 1
+                              }
+                            },
+                            "rootReference": {}
+                          }
+                        }
+                      },
+                      {
+                        "value": {
+                          "literal": {
+                            "string": "IA"
+                          }
+                        }
+                      }
+                    ]
+                  }
+                }
+              }
+            },
+            "expressions": [
+              {
+                "selection": {
+                  "directReference": {
+                    "structField": {
+                      "field": 1
+                    }
+                  },
+                  "rootReference": {}
+                }
+              }
+            ]
+          }
+        },
+        "names": [
+          "n_name"
+        ]
+      }
+    }
+  ],
+  "version": {
+    "minorNumber": 38,
+    "producer": "ibis-substrait"
+  }
+}
\ No newline at end of file
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index 10c52bc5de9e..ec34dbf9ba6c 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -681,6 +681,7 @@ _Alias of [nvl](#nvl)._
 - [substr_index](#substr_index)
 - [find_in_set](#find_in_set)
 - [position](#position)
+- [contains](#contains)

 ### `ascii`

@@ -1443,6 +1444,19 @@ position(substr in origstr)
 - **substr**: The pattern string.
 - **origstr**: The model string.

+### `contains`
+
+Return true if search_string is found within string.
+
+```
+contains(string, search_string)
+```
+
+#### Arguments
+
+- **string**: The string to search in.
+- **search_string**: The string to search for.
+
 ## Time and Date Functions

 - [now](#now)

From 648c20c388427fa82d219dc9829db6c1c9b119a0 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Sun, 16 Jun 2024 07:37:40 -0400
Subject: [PATCH 24/34] Minor: Improve arrow_statistics tests (#10927)

---
 .../core/tests/parquet/arrow_statistics.rs   | 213 ++++++++++--------
 1 file changed, 121 insertions(+), 92 deletions(-)

diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs
index 3c812800e2b7..6b8705441d12 100644
--- a/datafusion/core/tests/parquet/arrow_statistics.rs
+++ b/datafusion/core/tests/parquet/arrow_statistics.rs
@@ -164,6 +164,36 @@ impl TestReader {
     }
 }

+/// Which statistics should we check?
+#[derive(Clone, Debug, Copy)]
+enum Check {
+    /// Extract and check row group statistics
+    RowGroup,
+    /// Extract and check data page statistics
+    DataPage,
+    /// Extract and check both row group and data page statistics.
+    ///
+    /// Note if a row group contains a single data page,
+    /// the statistics for row groups and data pages are the same.
+    Both,
+}
+
+impl Check {
+    fn row_group(&self) -> bool {
+        match self {
+            Self::RowGroup | Self::Both => true,
+            Self::DataPage => false,
+        }
+    }
+
+    fn data_page(&self) -> bool {
+        match self {
+            Self::DataPage | Self::Both => true,
+            Self::RowGroup => false,
+        }
+    }
+}
+
 /// Defines a test case for statistics extraction
 struct Test<'a> {
     /// The parquet file reader
@@ -174,9 +204,8 @@ struct Test<'a> {
     expected_row_counts: UInt64Array,
     /// Which column to extract statistics from
     column_name: &'static str,
-    /// If true, extracts and compares data page statistics rather than row
-    /// group statistics
-    test_data_page_statistics: bool,
+    /// What statistics should be checked?
+    check: Check,
 }

 impl<'a> Test<'a> {
@@ -188,7 +217,7 @@ impl<'a> Test<'a> {
             expected_null_counts,
             expected_row_counts,
             column_name,
-            test_data_page_statistics,
+            check,
         } = self;

         let converter = StatisticsConverter::try_new(
@@ -199,8 +228,9 @@ impl<'a> Test<'a> {
         .unwrap();

         let row_groups = reader.metadata().row_groups();
+        let expected_null_counts = Arc::new(expected_null_counts) as ArrayRef;

-        if test_data_page_statistics {
+        if check.data_page() {
             let column_page_index = reader
                 .metadata()
                 .column_index()
@@ -249,7 +279,6 @@ impl<'a> Test<'a> {
             )
             .unwrap();

-            let expected_null_counts = Arc::new(expected_null_counts) as ArrayRef;
             assert_eq!(
                 &null_counts, &expected_null_counts,
                 "{column_name}: Mismatch with expected data page null counts. 
\ @@ -259,13 +288,16 @@ impl<'a> Test<'a> { let row_counts = converter .data_page_row_counts(column_offset_index, row_groups, &row_group_indices) .unwrap(); - let expected_row_counts = Arc::new(expected_row_counts) as ArrayRef; + // https://github.com/apache/datafusion/issues/10926 + let expected_row_counts: ArrayRef = Arc::new(expected_row_counts.clone()); assert_eq!( &row_counts, &expected_row_counts, "{column_name}: Mismatch with expected row counts. \ Actual: {row_counts:?}. Expected: {expected_row_counts:?}" ); - } else { + } + + if check.row_group() { let min = converter.row_group_mins(row_groups).unwrap(); assert_eq!( &min, &expected_min, @@ -279,7 +311,6 @@ impl<'a> Test<'a> { ); let null_counts = converter.row_group_null_counts(row_groups).unwrap(); - let expected_null_counts = Arc::new(expected_null_counts) as ArrayRef; assert_eq!( &null_counts, &expected_null_counts, "{column_name}: Mismatch with expected null counts. \ @@ -348,7 +379,7 @@ async fn test_one_row_group_without_null() { // 3 rows expected_row_counts: UInt64Array::from(vec![3]), column_name: "i64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run() } @@ -375,7 +406,7 @@ async fn test_one_row_group_with_null_and_negative() { // 8 rows expected_row_counts: UInt64Array::from(vec![8]), column_name: "i64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run() } @@ -402,7 +433,7 @@ async fn test_two_row_group_with_null() { // row counts are [10, 5] expected_row_counts: UInt64Array::from(vec![10, 5]), column_name: "i64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run() } @@ -429,7 +460,7 @@ async fn test_two_row_groups_with_all_nulls_in_one() { // row counts are [5, 3] expected_row_counts: UInt64Array::from(vec![5, 3]), column_name: "i64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run() } @@ -460,7 +491,7 @@ async fn test_multiple_data_pages_nulls_and_negatives() { expected_null_counts: UInt64Array::from(vec![0, 0, 1, 2]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 2]), column_name: "i64", - test_data_page_statistics: true, + check: Check::DataPage, } .run() } @@ -482,22 +513,20 @@ async fn test_int_64() { .await; // since each row has only one data page, the statistics are the same - for test_data_page_statistics in [true, false] { - Test { - reader: &reader, - // mins are [-5, -4, 0, 5] - expected_min: Arc::new(Int64Array::from(vec![-5, -4, 0, 5])), - // maxes are [-1, 0, 4, 9] - expected_max: Arc::new(Int64Array::from(vec![-1, 0, 4, 9])), - // nulls are [0, 0, 0, 0] - expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), - // row counts are [5, 5, 5, 5] - expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), - column_name: "i64", - test_data_page_statistics, - } - .run(); + Test { + reader: &reader, + // mins are [-5, -4, 0, 5] + expected_min: Arc::new(Int64Array::from(vec![-5, -4, 0, 5])), + // maxes are [-1, 0, 4, 9] + expected_max: Arc::new(Int64Array::from(vec![-1, 0, 4, 9])), + // nulls are [0, 0, 0, 0] + expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), + // row counts are [5, 5, 5, 5] + expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), + column_name: "i64", + check: Check::Both, } + .run(); } #[tokio::test] @@ -521,7 +550,7 @@ async fn test_int_32() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i32", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -562,7 +591,7 @@ async fn test_int_16() { // row counts are [5, 
5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i16", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -591,7 +620,7 @@ async fn test_int_8() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i8", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -641,7 +670,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "nanos", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -670,7 +699,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "nanos_timezoned", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -692,7 +721,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "micros", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -721,7 +750,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "micros_timezoned", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -743,7 +772,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "millis", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -772,7 +801,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "millis_timezoned", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -794,7 +823,7 @@ async fn test_timestamp() { expected_null_counts: UInt64Array::from(vec![1, 1, 1, 1]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "seconds", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -823,7 +852,7 @@ async fn test_timestamp() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "seconds_timezoned", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -869,7 +898,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "nanos", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -896,7 +925,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "nanos_timezoned", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -916,7 +945,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "micros", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -943,7 +972,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "micros_timezoned", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -963,7 +992,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), 
column_name: "millis", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -990,7 +1019,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "millis_timezoned", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1010,7 +1039,7 @@ async fn test_timestamp_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![1, 2, 1]), expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "seconds", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1037,7 +1066,7 @@ async fn test_timestamp_diff_rg_sizes() { // row counts are [8, 8, 4] expected_row_counts: UInt64Array::from(vec![8, 8, 4]), column_name: "seconds_timezoned", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1075,7 +1104,7 @@ async fn test_dates_32_diff_rg_sizes() { // row counts are [13, 7] expected_row_counts: UInt64Array::from(vec![13, 7]), column_name: "date32", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1098,7 +1127,7 @@ async fn test_time32_second_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "second", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1125,7 +1154,7 @@ async fn test_time32_millisecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "millisecond", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1158,7 +1187,7 @@ async fn test_time64_microsecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "microsecond", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1191,7 +1220,7 @@ async fn test_time64_nanosecond_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), // Assuming 1 null per row group for simplicity expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4]), column_name: "nanosecond", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1218,7 +1247,7 @@ async fn test_dates_64_diff_rg_sizes() { expected_null_counts: UInt64Array::from(vec![2, 2]), expected_row_counts: UInt64Array::from(vec![13, 7]), column_name: "date64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1246,7 +1275,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u8", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1257,7 +1286,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u16", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1268,7 +1297,7 @@ async fn test_uint() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u32", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1279,7 +1308,7 @@ async fn test_uint() { 
expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![4, 4, 4, 4, 4]), column_name: "u64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1302,7 +1331,7 @@ async fn test_int32_range() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: UInt64Array::from(vec![4]), column_name: "i", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1325,7 +1354,7 @@ async fn test_uint32_range() { expected_null_counts: UInt64Array::from(vec![0]), expected_row_counts: UInt64Array::from(vec![4]), column_name: "u", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1347,7 +1376,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u8", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1358,7 +1387,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u16", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1369,7 +1398,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u32", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1380,7 +1409,7 @@ async fn test_numeric_limits_unsigned() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "u64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1402,7 +1431,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i8", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1413,7 +1442,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i16", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1424,7 +1453,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i32", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1435,7 +1464,7 @@ async fn test_numeric_limits_signed() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "i64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1457,7 +1486,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f32", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1468,7 +1497,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f64", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1479,7 +1508,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f32_nan", - test_data_page_statistics: false, + check: Check::RowGroup, 
} .run(); @@ -1490,7 +1519,7 @@ async fn test_numeric_limits_float() { expected_null_counts: UInt64Array::from(vec![0, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "f64_nan", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1513,7 +1542,7 @@ async fn test_float64() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "f", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1546,7 +1575,7 @@ async fn test_float16() { expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "f", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1577,7 +1606,7 @@ async fn test_decimal() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "decimal_col", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1615,7 +1644,7 @@ async fn test_decimal_256() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "decimal256_col", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1635,7 +1664,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "string_dict_i8", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1646,7 +1675,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "string_dict_i32", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1657,7 +1686,7 @@ async fn test_dictionary() { expected_null_counts: UInt64Array::from(vec![1, 0]), expected_row_counts: UInt64Array::from(vec![5, 2]), column_name: "int_dict_i8", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1695,7 +1724,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "name", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1715,7 +1744,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_string", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1734,7 +1763,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_binary", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1755,7 +1784,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_fixedsize", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); @@ -1776,7 +1805,7 @@ async fn test_byte() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: UInt64Array::from(vec![5, 5, 5]), column_name: "service_large_binary", - test_data_page_statistics: false, + check: Check::RowGroup, } .run(); } @@ -1809,7 +1838,7 @@ async fn test_period_in_column_names() { expected_null_counts: UInt64Array::from(vec![0, 0, 0]), expected_row_counts: 
UInt64Array::from(vec![5, 5, 5]),
         column_name: "name",
-        test_data_page_statistics: false,
+        check: Check::RowGroup,
     }
     .run();

@@ -1823,7 +1852,7 @@ async fn test_period_in_column_names() {
         expected_null_counts: UInt64Array::from(vec![0, 0, 0]),
         expected_row_counts: UInt64Array::from(vec![5, 5, 5]),
         column_name: "service.name",
-        test_data_page_statistics: false,
+        check: Check::RowGroup,
     }
     .run();
 }

@@ -1847,7 +1876,7 @@ async fn test_boolean() {
         expected_null_counts: UInt64Array::from(vec![1, 0]),
         expected_row_counts: UInt64Array::from(vec![5, 5]),
         column_name: "bool",
-        test_data_page_statistics: false,
+        check: Check::RowGroup,
     }
     .run();
 }

@@ -1874,7 +1903,7 @@ async fn test_struct() {
         expected_null_counts: UInt64Array::from(vec![0]),
         expected_row_counts: UInt64Array::from(vec![3]),
         column_name: "struct",
-        test_data_page_statistics: false,
+        check: Check::RowGroup,
     }
     .run();
 }

@@ -1897,7 +1926,7 @@ async fn test_utf8() {
         expected_null_counts: UInt64Array::from(vec![1, 0]),
         expected_row_counts: UInt64Array::from(vec![5, 5]),
         column_name: "utf8",
-        test_data_page_statistics: false,
+        check: Check::RowGroup,
     }
     .run();

@@ -1909,7 +1938,7 @@ async fn test_utf8() {
         expected_null_counts: UInt64Array::from(vec![1, 0]),
         expected_row_counts: UInt64Array::from(vec![5, 5]),
         column_name: "large_utf8",
-        test_data_page_statistics: false,
+        check: Check::RowGroup,
     }
     .run();
 }

@@ -1935,7 +1964,7 @@ async fn test_missing_statistics() {
         expected_null_counts: UInt64Array::from(vec![None]),
         expected_row_counts: UInt64Array::from(vec![3]), // still has row count statistics
         column_name: "i64",
-        test_data_page_statistics: false,
+        check: Check::RowGroup,
     }
     .run();
 }

@@ -1957,7 +1986,7 @@ async fn test_column_not_found() {
         expected_null_counts: UInt64Array::from(vec![2, 2]),
         expected_row_counts: UInt64Array::from(vec![13, 7]),
         column_name: "not_a_column",
-        test_data_page_statistics: false,
+        check: Check::RowGroup,
     }
     .run_col_not_found();
 }

From d175163ef6442056d8210de9b0e28e264c39ca2c Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Sun, 16 Jun 2024 19:39:13 +0800
Subject: [PATCH 25/34] rm env (#10933)

Signed-off-by: jayzhan211

---
 benchmarks/bench.sh | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh
index 903fcb940b3e..efd56b17c7cb 100755
--- a/benchmarks/bench.sh
+++ b/benchmarks/bench.sh
@@ -412,7 +412,7 @@ run_clickbench_1() {
     RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
     echo "Running clickbench (1 file) benchmark..."
-    $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --prefer_hash_join ${PREFER_HASH_JOIN} --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE}
+    $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE}
 }

 # Runs the clickbench benchmark with the partitioned parquet files
 run_clickbench_partitioned() {
     RESULTS_FILE="${RESULTS_DIR}/clickbench_partitioned.json"
     echo "RESULTS_FILE: ${RESULTS_FILE}"
     echo "Running clickbench (partitioned, 100 files) benchmark..."
- $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --prefer_hash_join ${PREFER_HASH_JOIN} --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} } # Runs the clickbench "extended" benchmark with a single large parquet file @@ -428,7 +428,7 @@ run_clickbench_extended() { RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) extended benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --prefer_hash_join ${PREFER_HASH_JOIN} --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o ${RESULTS_FILE} + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o ${RESULTS_FILE} } compare_benchmarks() { From c884bdb692020d8feb9599c9e455a406b98a6f46 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Sun, 16 Jun 2024 22:42:44 +0800 Subject: [PATCH 26/34] Convert ApproxPercentileCont and ApproxPercentileContWithWeight to UDAF (#10917) * pass logical expr of arguments for udaf * implement approx_percentile_cont udaf * register udaf * remove ApproxPercentileCont * convert with_wegiht to udaf and remove original * fix conflict * fix compile check * fix doc and testing * evaluate args through physical plan * public use Literal * fix tests * rollback the experimental tests * remove unused import * rename args and inline code * remove unnecessary partial eq trait * fix error message --- .../aggregate_statistics.rs | 1 + .../combine_partial_final_agg.rs | 2 + .../core/src/physical_optimizer/test_utils.rs | 1 + datafusion/core/src/physical_planner.rs | 16 +- .../tests/dataframe/dataframe_functions.rs | 6 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 1 + .../core/tests/fuzz_cases/window_fuzz.rs | 4 + datafusion/expr/src/aggregate_function.rs | 50 +--- datafusion/expr/src/expr_fn.rs | 28 -- datafusion/expr/src/function.rs | 4 +- .../expr/src/type_coercion/aggregates.rs | 82 ------ .../functions-aggregate/src/approx_median.rs | 10 - .../src/approx_percentile_cont.rs | 235 ++++++++++++++++- .../approx_percentile_cont_with_weight.rs | 159 ++++++----- datafusion/functions-aggregate/src/count.rs | 2 +- datafusion/functions-aggregate/src/lib.rs | 7 + datafusion/functions-aggregate/src/stddev.rs | 4 +- .../optimizer/src/analyzer/type_coercion.rs | 25 -- .../physical-expr-common/src/aggregate/mod.rs | 13 +- .../src/expressions/mod.rs | 2 +- datafusion/physical-expr-common/src/utils.rs | 26 +- .../src/aggregate/approx_percentile_cont.rs | 249 ------------------ .../physical-expr/src/aggregate/build_in.rs | 101 +------ datafusion/physical-expr/src/aggregate/mod.rs | 2 - .../physical-expr/src/expressions/mod.rs | 4 +- .../physical-plan/src/aggregates/mod.rs | 8 +- .../src/windows/bounded_window_agg_exec.rs | 6 +- datafusion/physical-plan/src/windows/mod.rs | 3 + datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 6 - datafusion/proto/src/logical_plan/to_proto.rs | 10 - .../proto/src/physical_plan/from_proto.rs | 4 +- datafusion/proto/src/physical_plan/mod.rs | 5 +- .../proto/src/physical_plan/to_proto.rs | 18 +- 
.../tests/cases/roundtrip_logical_plan.rs | 25 +- .../tests/cases/roundtrip_physical_plan.rs | 2 + .../sqllogictest/test_files/aggregate.slt | 14 +- 39 files changed, 443 insertions(+), 714 deletions(-) rename datafusion/{physical-expr/src/aggregate => functions-aggregate/src}/approx_percentile_cont_with_weight.rs (51%) delete mode 100644 datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index eeacc48b85db..ca1582bcb34f 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -390,6 +390,7 @@ pub(crate) mod tests { &[self.column()], &[], &[], + &[], schema, self.column_name(), false, diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 38b92959e841..b57f36f728d7 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -315,6 +315,7 @@ mod tests { &[expr], &[], &[], + &[], schema, name, false, @@ -404,6 +405,7 @@ mod tests { &[col("b", &schema)?], &[], &[], + &[], &schema, "Sum(b)", false, diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 154e77cd23ae..5320938d2eb8 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -245,6 +245,7 @@ pub fn bounded_window_exec( "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], + &[], &sort_exprs, Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 4f9187595018..404bcbb2e7d4 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1766,7 +1766,8 @@ pub fn create_window_expr_with_name( window_frame, null_treatment, }) => { - let args = create_physical_exprs(args, logical_schema, execution_props)?; + let physical_args = + create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = create_physical_exprs(partition_by, logical_schema, execution_props)?; let order_by = @@ -1780,13 +1781,13 @@ pub fn create_window_expr_with_name( } let window_frame = Arc::new(window_frame.clone()); - let ignore_nulls = null_treatment - .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; windows::create_window_expr( fun, name, - &args, + &physical_args, + args, &partition_by, &order_by, window_frame, @@ -1837,7 +1838,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( order_by, null_treatment, }) => { - let args = + let physical_args = create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { Some(e) => Some(create_physical_expr( @@ -1867,7 +1868,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let agg_expr = aggregates::create_aggregate_expr( fun, *distinct, - &args, + &physical_args, &ordering_reqs, physical_input_schema, name, @@ -1889,7 +1890,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( physical_sort_exprs.clone().unwrap_or(vec![]); let agg_expr = udaf::create_aggregate_expr( fun, - &args, + &physical_args, + args, 
&sort_exprs, &ordering_reqs, physical_input_schema, diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index b05769a6ce9d..1c55c48fea40 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -33,7 +33,7 @@ use datafusion::assert_batches_eq; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; -use datafusion_functions_aggregate::expr_fn::approx_median; +use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -363,7 +363,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { let expected = [ "+---------------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.b,Float64(0.5)) |", + "| approx_percentile_cont(test.b,Float64(0.5)) |", "+---------------------------------------------+", "| 10 |", "+---------------------------------------------+", @@ -384,7 +384,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { let df = create_test_table().await?; let expected = [ "+--------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.b,arg_2) |", + "| approx_percentile_cont(test.b,arg_2) |", "+--------------------------------------+", "| 10 |", "+--------------------------------------+", diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index c76c1fc2c736..a04f4f349122 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -108,6 +108,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str &[col("d", &schema).unwrap()], &[], &[], + &[], &schema, "sum1", false, diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 4358691ee5a5..5bd19850cacc 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -252,6 +252,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { let partitionby_exprs = vec![]; let orderby_exprs = vec![]; + let logical_exprs = vec![]; // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -283,6 +284,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { &window_fn, fn_name.to_string(), &args, + &logical_exprs, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame), @@ -699,6 +701,7 @@ async fn run_window_test( &window_fn, fn_name.clone(), &args, + &[], &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), @@ -717,6 +720,7 @@ async fn run_window_test( &window_fn, fn_name, &args, + &[], &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 81562bf12476..441e8953dffc 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use std::{fmt, str::FromStr}; use crate::utils; -use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility}; +use crate::{type_coercion::aggregates::*, Signature, Volatility}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; @@ -45,10 +45,6 @@ pub 
enum AggregateFunction {
     NthValue,
     /// Correlation
     Correlation,
-    /// Approximate continuous percentile function
-    ApproxPercentileCont,
-    /// Approximate continuous percentile function with weight
-    ApproxPercentileContWithWeight,
     /// Grouping
     Grouping,
     /// Bit And
     BitAnd,
@@ -75,8 +71,6 @@ impl AggregateFunction {
             ArrayAgg => "ARRAY_AGG",
             NthValue => "NTH_VALUE",
             Correlation => "CORR",
-            ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
-            ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
             Grouping => "GROUPING",
             BitAnd => "BIT_AND",
             BitOr => "BIT_OR",
@@ -113,11 +107,6 @@ impl FromStr for AggregateFunction {
             "string_agg" => AggregateFunction::StringAgg,
             // statistical
             "corr" => AggregateFunction::Correlation,
-            // approximate
-            "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
-            "approx_percentile_cont_with_weight" => {
-                AggregateFunction::ApproxPercentileContWithWeight
-            }
             // other
             "grouping" => AggregateFunction::Grouping,
             _ => {
@@ -170,10 +159,6 @@ impl AggregateFunction {
                 coerced_data_types[0].clone(),
                 true,
             )))),
-            AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
-            AggregateFunction::ApproxPercentileContWithWeight => {
-                Ok(coerced_data_types[0].clone())
-            }
             AggregateFunction::Grouping => Ok(DataType::Int32),
             AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
             AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
@@ -230,39 +215,6 @@ impl AggregateFunction {
             AggregateFunction::Correlation => {
                 Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
             }
-            AggregateFunction::ApproxPercentileCont => {
-                let mut variants =
-                    Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
-                // Accept any numeric value paired with a float64 percentile
-                for num in NUMERICS {
-                    variants
-                        .push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
-                    // Additionally accept an integer number of centroids for T-Digest
-                    for int in INTEGERS {
-                        variants.push(TypeSignature::Exact(vec![
-                            num.clone(),
-                            DataType::Float64,
-                            int.clone(),
-                        ]))
-                    }
-                }
-
-                Signature::one_of(variants, Volatility::Immutable)
-            }
-            AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
-                // Accept any numeric value paired with a float64 percentile
-                NUMERICS
-                    .iter()
-                    .map(|t| {
-                        TypeSignature::Exact(vec![
-                            t.clone(),
-                            t.clone(),
-                            DataType::Float64,
-                        ])
-                    })
-                    .collect(),
-                Volatility::Immutable,
-            ),
             AggregateFunction::StringAgg => {
                 Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
             }
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index fb5b3991ecd8..099851aece46 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -242,34 +242,6 @@ pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
     Expr::InList(InList::new(Box::new(expr), list, negated))
 }

-/// Calculate an approximation of the specified `percentile` for `expr`.
-pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
-    Expr::AggregateFunction(AggregateFunction::new(
-        aggregate_function::AggregateFunction::ApproxPercentileCont,
-        vec![expr, percentile],
-        false,
-        None,
-        None,
-        None,
-    ))
-}
-
-/// Calculate an approximation of the specified `percentile` for `expr` and `weight_expr`.
-pub fn approx_percentile_cont_with_weight(
-    expr: Expr,
-    weight_expr: Expr,
-    percentile: Expr,
-) -> Expr {
-    Expr::AggregateFunction(AggregateFunction::new(
-        aggregate_function::AggregateFunction::ApproxPercentileContWithWeight,
-        vec![expr, weight_expr, percentile],
-        false,
-        None,
-        None,
-        None,
-    ))
-}
-
 /// Create an EXISTS subquery expression
 pub fn exists(subquery: Arc<Subquery>) -> Expr {
     let outer_ref_columns = subquery.all_out_ref_exprs();
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index c06f177510e7..169436145aae 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -83,8 +83,8 @@ pub struct AccumulatorArgs<'a> {
     /// The input type of the aggregate function.
     pub input_type: &'a DataType,

-    /// The number of arguments the aggregate function takes.
-    pub args_num: usize,
+    /// The logical expression of arguments the aggregate function takes.
+    pub input_exprs: &'a [Expr],
 }

 /// [`StateFieldsArgs`] contains information about the fields that an
diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs
index 6c9a71bab46a..98324ed6120b 100644
--- a/datafusion/expr/src/type_coercion/aggregates.rs
+++ b/datafusion/expr/src/type_coercion/aggregates.rs
@@ -17,7 +17,6 @@

 use std::ops::Deref;

-use super::functions::can_coerce_from;
 use crate::{AggregateFunction, Signature, TypeSignature};

 use arrow::datatypes::{
@@ -158,55 +157,6 @@ pub fn coerce_types(
             }
             Ok(vec![Float64, Float64])
         }
-        AggregateFunction::ApproxPercentileCont => {
-            if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
-                return plan_err!(
-                    "The function {:?} does not support inputs of type {:?}.",
-                    agg_fun,
-                    input_types[0]
-                );
-            }
-            if input_types.len() == 3 && !input_types[2].is_integer() {
-                return plan_err!(
-                    "The percentile sample points count for {:?} must be integer, not {:?}.",
-                    agg_fun, input_types[2]
-                );
-            }
-            let mut result = input_types.to_vec();
-            if can_coerce_from(&Float64, &input_types[1]) {
-                result[1] = Float64;
-            } else {
-                return plan_err!(
-                    "Could not coerce the percent argument for {:?} to Float64. Was {:?}.",
-                    agg_fun, input_types[1]
-                );
-            }
-            Ok(result)
-        }
-        AggregateFunction::ApproxPercentileContWithWeight => {
-            if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
-                return plan_err!(
-                    "The function {:?} does not support inputs of type {:?}.",
-                    agg_fun,
-                    input_types[0]
-                );
-            }
-            if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
-                return plan_err!(
-                    "The weight argument for {:?} does not support inputs of type {:?}.",
-                    agg_fun,
-                    input_types[1]
-                );
-            }
-            if !matches!(input_types[2], Float64) {
-                return plan_err!(
-                    "The percentile argument for {:?} must be Float64, not {:?}.",
-                    agg_fun,
-                    input_types[2]
-                );
-            }
-            Ok(input_types.to_vec())
-        }
         AggregateFunction::NthValue => Ok(input_types.to_vec()),
         AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
         AggregateFunction::StringAgg => {
@@ -459,15 +409,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
     arg_type.is_integer()
 }

-/// Return `true` if `arg_type` is of a [`DataType`] that the
-/// [`AggregateFunction::ApproxPercentileCont`] aggregation can operate on.
-pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool {
-    matches!(
-        arg_type,
-        arg_type if NUMERICS.contains(arg_type)
-    )
-}
-
 /// Return `true` if `arg_type` is of a [`DataType`] that the
 /// [`AggregateFunction::StringAgg`] aggregation can operate on.
 pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
@@ -532,29 +473,6 @@ mod tests {
         assert_eq!(r[0], DataType::Decimal128(20, 3));
         let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap();
         assert_eq!(r[0], DataType::Decimal256(20, 3));
-
-        // ApproxPercentileCont input types
-        let input_types = vec![
-            vec![DataType::Int8, DataType::Float64],
-            vec![DataType::Int16, DataType::Float64],
-            vec![DataType::Int32, DataType::Float64],
-            vec![DataType::Int64, DataType::Float64],
-            vec![DataType::UInt8, DataType::Float64],
-            vec![DataType::UInt16, DataType::Float64],
-            vec![DataType::UInt32, DataType::Float64],
-            vec![DataType::UInt64, DataType::Float64],
-            vec![DataType::Float32, DataType::Float64],
-            vec![DataType::Float64, DataType::Float64],
-        ];
-        for input_type in &input_types {
-            let signature = AggregateFunction::ApproxPercentileCont.signature();
-            let result = coerce_types(
-                &AggregateFunction::ApproxPercentileCont,
-                input_type,
-                &signature,
-            );
-            assert_eq!(*input_type, result.unwrap());
-        }
     }

     #[test]
diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs
index b8b86d30557a..bc723c862953 100644
--- a/datafusion/functions-aggregate/src/approx_median.rs
+++ b/datafusion/functions-aggregate/src/approx_median.rs
@@ -28,7 +28,6 @@ use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
 use datafusion_expr::type_coercion::aggregates::NUMERICS;
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
-use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref;

 use crate::approx_percentile_cont::ApproxPercentileAccumulator;

@@ -118,12 +117,3 @@ impl AggregateUDFImpl for ApproxMedian {
         )))
     }
 }
-
-impl PartialEq<dyn Any> for ApproxMedian {
-    fn eq(&self, other: &dyn Any) -> bool {
-        down_cast_any_ref(other)
-            .downcast_ref::<ApproxMedian>()
-            .map(|x| self.signature == x.signature)
-            .unwrap_or(false)
-    }
-}
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
index e75417efc684..5ae5684d9cab 100644
--- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -15,6 +15,11 @@
 // specific language governing permissions and limitations
 // under the License.
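+
+// Sketch of the accepted call shapes, as assembled in `ApproxPercentileCont::new`
+// and validated by the helpers below: `approx_percentile_cont(expr, percentile)`
+// where `percentile` is a Float32/Float64 literal in [0.0, 1.0], optionally
+// followed by a positive integer literal `max_size` controlling the T-Digest size.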
+use std::any::Any;
+use std::fmt::{Debug, Formatter};
+use std::sync::Arc;
+
+use arrow::array::RecordBatch;
 use arrow::{
     array::{
         ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
     },
     datatypes::DataType,
 };
+use arrow_schema::{Field, Schema};

-use datafusion_common::{downcast_value, internal_err, DataFusionError, ScalarValue};
-use datafusion_expr::Accumulator;
+use datafusion_common::{
+    downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, ScalarValue,
+};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS};
+use datafusion_expr::utils::format_state_name;
+use datafusion_expr::{
+    Accumulator, AggregateUDFImpl, ColumnarValue, Expr, Signature, TypeSignature,
+    Volatility,
+};
 use datafusion_physical_expr_common::aggregate::tdigest::{
     TDigest, TryIntoF64, DEFAULT_MAX_SIZE,
 };
+use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr;
+
+make_udaf_expr_and_func!(
+    ApproxPercentileCont,
+    approx_percentile_cont,
+    expression percentile,
+    "Computes the approximate percentile continuous of a set of numbers",
+    approx_percentile_cont_udaf
+);
+
+pub struct ApproxPercentileCont {
+    signature: Signature,
+}
+
+impl Debug for ApproxPercentileCont {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        f.debug_struct("ApproxPercentileCont")
+            .field("name", &self.name())
+            .field("signature", &self.signature)
+            .finish()
+    }
+}
+
+impl Default for ApproxPercentileCont {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl ApproxPercentileCont {
+    /// Create a new [`ApproxPercentileCont`] aggregate function.
+    pub fn new() -> Self {
+        let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
+        // Accept any numeric value paired with a float64 percentile
+        for num in NUMERICS {
+            variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
+            // Additionally accept an integer number of centroids for T-Digest
+            for int in INTEGERS {
+                variants.push(TypeSignature::Exact(vec![
+                    num.clone(),
+                    DataType::Float64,
+                    int.clone(),
+                ]))
+            }
+        }
+        Self {
+            signature: Signature::one_of(variants, Volatility::Immutable),
+        }
+    }
+
+    pub(crate) fn create_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> datafusion_common::Result<ApproxPercentileAccumulator> {
+        let percentile = validate_input_percentile_expr(&args.input_exprs[1])?;
+        let tdigest_max_size = if args.input_exprs.len() == 3 {
+            Some(validate_input_max_size_expr(&args.input_exprs[2])?)
+        } else {
+            None
+        };
+
+        let accumulator: ApproxPercentileAccumulator = match args.input_type {
+            t @ (DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::Float32
+            | DataType::Float64) => {
+                if let Some(max_size) = tdigest_max_size {
+                    ApproxPercentileAccumulator::new_with_max_size(percentile, t.clone(), max_size)
+                } else {
+                    ApproxPercentileAccumulator::new(percentile, t.clone())
+                }
+            }
+            other => {
+                return not_impl_err!(
+                    "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented"
+                )
+            }
+        };
+
+        Ok(accumulator)
+    }
+}
+
+fn get_lit_value(expr: &Expr) -> datafusion_common::Result<ScalarValue> {
+    let empty_schema = Arc::new(Schema::empty());
+    let empty_batch = RecordBatch::new_empty(Arc::clone(&empty_schema));
+    let expr = limited_convert_logical_expr_to_physical_expr(expr, &empty_schema)?;
+    let result = expr.evaluate(&empty_batch)?;
+    match result {
+        ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!(
+            "The expr {:?} can't be evaluated to scalar value",
+            expr
+        ))),
+        ColumnarValue::Scalar(scalar_value) => Ok(scalar_value),
+    }
+}
+
+fn validate_input_percentile_expr(expr: &Expr) -> datafusion_common::Result<f64> {
+    let lit = get_lit_value(expr)?;
+    let percentile = match &lit {
+        ScalarValue::Float32(Some(q)) => *q as f64,
+        ScalarValue::Float64(Some(q)) => *q,
+        got => return not_impl_err!(
+            "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
+            got.data_type()
+        )
+    };
+
+    // Ensure the percentile is between 0 and 1.
+    if !(0.0..=1.0).contains(&percentile) {
+        return plan_err!(
+            "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid"
+        );
+    }
+    Ok(percentile)
+}
+
+fn validate_input_max_size_expr(expr: &Expr) -> datafusion_common::Result<usize> {
+    let lit = get_lit_value(expr)?;
+    let max_size = match &lit {
+        ScalarValue::UInt8(Some(q)) => *q as usize,
+        ScalarValue::UInt16(Some(q)) => *q as usize,
+        ScalarValue::UInt32(Some(q)) => *q as usize,
+        ScalarValue::UInt64(Some(q)) => *q as usize,
+        ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize,
+        ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize,
+        ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize,
+        ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize,
+        got => return not_impl_err!(
+            "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
+            got.data_type()
+        )
+    };
+    Ok(max_size)
+}
+
+impl AggregateUDFImpl for ApproxPercentileCont {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    #[allow(rustdoc::private_intra_doc_links)]
+    /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised
+    /// state.
+ fn state_fields( + &self, + args: StateFieldsArgs, + ) -> datafusion_common::Result> { + Ok(vec![ + Field::new( + format_state_name(args.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + format_state_name(args.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "count"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "max"), + DataType::Float64, + false, + ), + Field::new( + format_state_name(args.name, "min"), + DataType::Float64, + false, + ), + Field::new_list( + format_state_name(args.name, "centroids"), + Field::new("item", DataType::Float64, true), + false, + ), + ]) + } + + fn name(&self) -> &str { + "approx_percentile_cont" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + #[inline] + fn accumulator( + &self, + acc_args: AccumulatorArgs, + ) -> datafusion_common::Result> { + Ok(Box::new(self.create_accumulator(acc_args)?)) + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + if !arg_types[0].is_numeric() { + return plan_err!("approx_percentile_cont requires numeric input types"); + } + if arg_types.len() == 3 && !arg_types[2].is_integer() { + return plan_err!( + "approx_percentile_cont requires integer max_size input types" + ); + } + Ok(arg_types[0].clone()) + } +} #[derive(Debug)] pub struct ApproxPercentileAccumulator { diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs similarity index 51% rename from datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs rename to datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs index 07c2aff3437f..a64218c606c4 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -15,105 +15,140 @@ // specific language governing permissions and limitations // under the License. 
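The `one_of` signature built in `ApproxPercentileCont::new` pairs each numeric input type with a `Float64` percentile and, optionally, an integer t-digest `max_size`. Both call shapes in a small sketch, assuming `ctx` is a `SessionContext` with a registered table `t` holding a numeric column `c`:

use datafusion::error::Result;
use datafusion::prelude::*;

async fn percentile_queries(ctx: &SessionContext) -> Result<()> {
    // Two-argument form: percentile only, default t-digest size
    ctx.sql("SELECT approx_percentile_cont(c, 0.95) FROM t").await?.show().await?;
    // Three-argument form: an explicit, positive integer t-digest max_size
    ctx.sql("SELECT approx_percentile_cont(c, 0.95, 200) FROM t").await?.show().await?;
    Ok(())
}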
-use crate::expressions::ApproxPercentileCont;
-use crate::{AggregateExpr, PhysicalExpr};
+use std::any::Any;
+use std::fmt::{Debug, Formatter};
+
 use arrow::{
     array::ArrayRef,
     datatypes::{DataType, Field},
 };
-use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
+
+use datafusion_common::ScalarValue;
+use datafusion_common::{not_impl_err, plan_err, Result};
+use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
+use datafusion_expr::type_coercion::aggregates::NUMERICS;
+use datafusion_expr::Volatility::Immutable;
+use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, TypeSignature};
 use datafusion_physical_expr_common::aggregate::tdigest::{
     Centroid, TDigest, DEFAULT_MAX_SIZE,
 };
-use datafusion_common::Result;
-use datafusion_common::ScalarValue;
-use datafusion_expr::Accumulator;
+use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont};
-use crate::aggregate::utils::down_cast_any_ref;
-use std::{any::Any, sync::Arc};
+make_udaf_expr_and_func!(
+    ApproxPercentileContWithWeight,
+    approx_percentile_cont_with_weight,
+    expression weight percentile,
+    "Computes the approximate percentile continuous with weight of a set of numbers",
+    approx_percentile_cont_with_weight_udaf
+);

 /// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression
-#[derive(Debug)]
 pub struct ApproxPercentileContWithWeight {
+    signature: Signature,
     approx_percentile_cont: ApproxPercentileCont,
-    column_expr: Arc<dyn PhysicalExpr>,
-    weight_expr: Arc<dyn PhysicalExpr>,
-    percentile_expr: Arc<dyn PhysicalExpr>,
+}
+
+impl Debug for ApproxPercentileContWithWeight {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ApproxPercentileContWithWeight")
+            .field("signature", &self.signature)
+            .finish()
+    }
+}
+
+impl Default for ApproxPercentileContWithWeight {
+    fn default() -> Self {
+        Self::new()
+    }
 }

 impl ApproxPercentileContWithWeight {
     /// Create a new [`ApproxPercentileContWithWeight`] aggregate function.
-    pub fn new(
-        expr: Vec<Arc<dyn PhysicalExpr>>,
-        name: impl Into<String>,
-        return_type: DataType,
-    ) -> Result<Self> {
-        // Arguments should be [ColumnExpr, WeightExpr, DesiredPercentileLiteral]
-        debug_assert_eq!(expr.len(), 3);
-
-        let sub_expr = vec![expr[0].clone(), expr[2].clone()];
-        let approx_percentile_cont =
-            ApproxPercentileCont::new(sub_expr, name, return_type)?;
-
-        Ok(Self {
-            approx_percentile_cont,
-            column_expr: expr[0].clone(),
-            weight_expr: expr[1].clone(),
-            percentile_expr: expr[2].clone(),
-        })
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::one_of(
+                // Accept any numeric value paired with a float64 percentile
+                NUMERICS
+                    .iter()
+                    .map(|t| {
+                        TypeSignature::Exact(vec![
+                            t.clone(),
+                            t.clone(),
+                            DataType::Float64,
+                        ])
+                    })
+                    .collect(),
+                Immutable,
+            ),
+            approx_percentile_cont: ApproxPercentileCont::new(),
+        }
     }
 }

-impl AggregateExpr for ApproxPercentileContWithWeight {
+impl AggregateUDFImpl for ApproxPercentileContWithWeight {
     fn as_any(&self) -> &dyn Any {
         self
     }

-    fn field(&self) -> Result<Field> {
-        self.approx_percentile_cont.field()
+    fn name(&self) -> &str {
+        "approx_percentile_cont_with_weight"
     }

-    #[allow(rustdoc::private_intra_doc_links)]
-    /// See [`TDigest::to_scalar_state()`] for a description of the serialised
-    /// state.
- fn state_fields(&self) -> Result> { - self.approx_percentile_cont.state_fields() + fn signature(&self) -> &Signature { + &self.signature } - fn expressions(&self) -> Vec> { - vec![ - self.column_expr.clone(), - self.weight_expr.clone(), - self.percentile_expr.clone(), - ] + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!( + "approx_percentile_cont_with_weight requires numeric input types" + ); + } + if !arg_types[1].is_numeric() { + return plan_err!( + "approx_percentile_cont_with_weight requires numeric weight input types" + ); + } + if arg_types[2] != DataType::Float64 { + return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); + } + Ok(arg_types[0].clone()) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!( + "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" + ); + } + + if acc_args.input_exprs.len() != 3 { + return plan_err!( + "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + ); + } + + let sub_args = AccumulatorArgs { + input_exprs: &[ + acc_args.input_exprs[0].clone(), + acc_args.input_exprs[2].clone(), + ], + ..acc_args + }; let approx_percentile_cont_accumulator = - self.approx_percentile_cont.create_plain_accumulator()?; + self.approx_percentile_cont.create_accumulator(sub_args)?; let accumulator = ApproxPercentileWithWeightAccumulator::new( approx_percentile_cont_accumulator, ); Ok(Box::new(accumulator)) } - fn name(&self) -> &str { - self.approx_percentile_cont.name() - } -} - -impl PartialEq for ApproxPercentileContWithWeight { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.approx_percentile_cont == x.approx_percentile_cont - && self.column_expr.eq(&x.column_expr) - && self.weight_expr.eq(&x.weight_expr) - && self.percentile_expr.eq(&x.percentile_expr) - }) - .unwrap_or(false) + #[allow(rustdoc::private_intra_doc_links)] + /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// state. 
+ fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.approx_percentile_cont.state_fields(args) } } diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index cfd56619537b..062e148975bf 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -258,7 +258,7 @@ impl AggregateUDFImpl for Count { if args.is_distinct { return false; } - args.args_num == 1 + args.input_exprs.len() == 1 } fn create_groups_accumulator( diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index fabe15e416f4..daddb9d93f78 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -68,7 +68,10 @@ pub mod variance; pub mod approx_median; pub mod approx_percentile_cont; +pub mod approx_percentile_cont_with_weight; +use crate::approx_percentile_cont::approx_percentile_cont_udaf; +use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::AggregateUDF; @@ -79,6 +82,8 @@ use std::sync::Arc; pub mod expr_fn { pub use super::approx_distinct; pub use super::approx_median::approx_median; + pub use super::approx_percentile_cont::approx_percentile_cont; + pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; pub use super::count::count; pub use super::count::count_distinct; pub use super::covariance::covar_pop; @@ -127,6 +132,8 @@ pub fn all_default_aggregate_functions() -> Vec> { stddev::stddev_pop_udaf(), approx_median::approx_median_udaf(), approx_distinct::approx_distinct_udaf(), + approx_percentile_cont_udaf(), + approx_percentile_cont_with_weight_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs index 4c3effe7650a..42cf44f65d8f 100644 --- a/datafusion/functions-aggregate/src/stddev.rs +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -332,7 +332,7 @@ mod tests { name: "a", is_distinct: false, input_type: &DataType::Float64, - args_num: 1, + input_exprs: &[datafusion_expr::col("a")], }; let args2 = AccumulatorArgs { @@ -343,7 +343,7 @@ mod tests { name: "a", is_distinct: false, input_type: &DataType::Float64, - args_num: 1, + input_exprs: &[datafusion_expr::col("a")], }; let mut accum1 = agg1.accumulator(args1)?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 0c8e4ae34a90..acc21f14f44d 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -1055,31 +1055,6 @@ mod test { Ok(()) } - #[test] - fn agg_function_invalid_input_percentile() { - let empty = empty(); - let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - vec![lit(0.95), lit(42.0), lit(100.0)], - false, - None, - None, - None, - )); - - let err = Projection::try_new(vec![agg_expr], empty) - .err() - .unwrap() - .strip_backtrace(); - - let prefix = "Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'. 
You might need to add explicit type casts.\n\tCandidate functions:"; - assert!(!err - .strip_prefix(prefix) - .unwrap() - .contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)")); - } - #[test] fn binary_op_date32_op_interval() -> Result<()> { // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...") diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 21884f840dbd..432267e045b2 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -46,6 +46,7 @@ use datafusion_expr::utils::AggregateOrderSensitivity; pub fn create_aggregate_expr( fun: &AggregateUDF, input_phy_exprs: &[Arc], + input_exprs: &[Expr], sort_exprs: &[Expr], ordering_req: &[PhysicalSortExpr], schema: &Schema, @@ -76,6 +77,7 @@ pub fn create_aggregate_expr( Ok(Arc::new(AggregateFunctionExpr { fun: fun.clone(), args: input_phy_exprs.to_vec(), + logical_args: input_exprs.to_vec(), data_type: fun.return_type(&input_exprs_types)?, name: name.into(), schema: schema.clone(), @@ -231,6 +233,7 @@ pub struct AggregatePhysicalExpressions { pub struct AggregateFunctionExpr { fun: AggregateUDF, args: Vec>, + logical_args: Vec, /// Output / return type of this aggregate data_type: DataType, name: String, @@ -293,7 +296,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; @@ -308,7 +311,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; @@ -378,7 +381,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; self.fun.groups_accumulator_supported(args) @@ -392,7 +395,7 @@ impl AggregateExpr for AggregateFunctionExpr { sort_exprs: &self.sort_exprs, is_distinct: self.is_distinct, input_type: &self.input_type, - args_num: self.args.len(), + input_exprs: &self.logical_args, name: &self.name, }; self.fun.create_groups_accumulator(args) @@ -434,6 +437,7 @@ impl AggregateExpr for AggregateFunctionExpr { create_aggregate_expr( &updated_fn, &self.args, + &self.logical_args, &self.sort_exprs, &self.ordering_req, &self.schema, @@ -468,6 +472,7 @@ impl AggregateExpr for AggregateFunctionExpr { let reverse_aggr = create_aggregate_expr( &reverse_udf, &self.args, + &self.logical_args, &reverse_sort_exprs, &reverse_ordering_req, &self.schema, diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/physical-expr-common/src/expressions/mod.rs index ea21c8e9a92b..dd534cc07d20 100644 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ b/datafusion/physical-expr-common/src/expressions/mod.rs @@ -17,7 +17,7 @@ mod cast; pub mod column; -mod literal; +pub mod literal; pub use cast::{cast, cast_with_options, CastExpr}; pub use literal::{lit, Literal}; diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index f661400fcb10..d5cd3c6f4af0 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -17,18 +17,21 @@ use std::sync::Arc; -use 
crate::expressions::{self, CastExpr}; -use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::PhysicalSortExpr; -use crate::tree_node::ExprContext; - use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::Schema; + use datafusion_common::{exec_err, Result}; +use datafusion_expr::expr::Alias; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::Expr; +use crate::expressions::literal::Literal; +use crate::expressions::{self, CastExpr}; +use crate::physical_expr::PhysicalExpr; +use crate::sort_expr::PhysicalSortExpr; +use crate::tree_node::ExprContext; + /// Represents a [`PhysicalExpr`] node with associated properties (order and /// range) in a context where properties are tracked. pub type ExprPropertiesNode = ExprContext; @@ -115,6 +118,9 @@ pub fn limited_convert_logical_expr_to_physical_expr( schema: &Schema, ) -> Result> { match expr { + Expr::Alias(Alias { expr, .. }) => { + Ok(limited_convert_logical_expr_to_physical_expr(expr, schema)?) + } Expr::Column(col) => expressions::column::col(&col.name, schema), Expr::Cast(cast_expr) => Ok(Arc::new(CastExpr::new( limited_convert_logical_expr_to_physical_expr( @@ -124,10 +130,7 @@ pub fn limited_convert_logical_expr_to_physical_expr( cast_expr.data_type.clone(), None, ))), - Expr::Alias(alias_expr) => limited_convert_logical_expr_to_physical_expr( - alias_expr.expr.as_ref(), - schema, - ), + Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))), _ => exec_err!( "Unsupported expression: {expr} for conversion to Arc" ), @@ -138,11 +141,12 @@ pub fn limited_convert_logical_expr_to_physical_expr( mod tests { use std::sync::Arc; - use super::*; - use arrow::array::Int32Array; + use datafusion_common::cast::{as_boolean_array, as_int32_array}; + use super::*; + #[test] fn scatter_int() -> Result<()> { let truthy = Arc::new(Int32Array::from(vec![1, 10, 11, 100])); diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs deleted file mode 100644 index f2068bbc92cc..000000000000 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ /dev/null @@ -1,249 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
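With `Alias` unwrapping and `Literal` support in place, `limited_convert_logical_expr_to_physical_expr` can lower the literal percentile and max_size arguments that the `approx_percentile_cont` validators evaluate. A sketch of that conversion (the alias name `p` is illustrative):

use arrow::datatypes::{DataType, Schema};
use datafusion_expr::lit;
use datafusion_physical_expr_common::utils::limited_convert_logical_expr_to_physical_expr;

fn lower_percentile_literal() -> datafusion_common::Result<()> {
    // Literal-only expressions reference no columns, so an empty schema suffices
    let schema = Schema::empty();
    let expr = lit(0.95_f64).alias("p");
    let physical = limited_convert_logical_expr_to_physical_expr(&expr, &schema)?;
    // The alias is stripped; what remains is a physical Float64 literal
    assert_eq!(physical.data_type(&schema)?, DataType::Float64);
    Ok(())
}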
- -use std::{any::Any, sync::Arc}; - -use arrow::datatypes::{DataType, Field}; -use arrow_array::RecordBatch; -use arrow_schema::Schema; - -use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, ColumnarValue}; -use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -/// APPROX_PERCENTILE_CONT aggregate expression -#[derive(Debug)] -pub struct ApproxPercentileCont { - name: String, - input_data_type: DataType, - expr: Vec>, - percentile: f64, - tdigest_max_size: Option, -} - -impl ApproxPercentileCont { - /// Create a new [`ApproxPercentileCont`] aggregate function. - pub fn new( - expr: Vec>, - name: impl Into, - input_data_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, DesiredPercentileLiteral] - debug_assert_eq!(expr.len(), 2); - - let percentile = validate_input_percentile_expr(&expr[1])?; - - Ok(Self { - name: name.into(), - input_data_type, - // The physical expr to evaluate during accumulation - expr, - percentile, - tdigest_max_size: None, - }) - } - - /// Create a new [`ApproxPercentileCont`] aggregate function. - pub fn new_with_max_size( - expr: Vec>, - name: impl Into, - input_data_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, DesiredPercentileLiteral, TDigestMaxSize] - debug_assert_eq!(expr.len(), 3); - let percentile = validate_input_percentile_expr(&expr[1])?; - let max_size = validate_input_max_size_expr(&expr[2])?; - Ok(Self { - name: name.into(), - input_data_type, - // The physical expr to evaluate during accumulation - expr, - percentile, - tdigest_max_size: Some(max_size), - }) - } - - pub(crate) fn create_plain_accumulator(&self) -> Result { - let accumulator: ApproxPercentileAccumulator = match &self.input_data_type { - t @ (DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64) => { - if let Some(max_size) = self.tdigest_max_size { - ApproxPercentileAccumulator::new_with_max_size(self.percentile, t.clone(), max_size) - - }else{ - ApproxPercentileAccumulator::new(self.percentile, t.clone()) - - } - } - other => { - return not_impl_err!( - "Support for 'APPROX_PERCENTILE_CONT' for data type {other} is not implemented" - ) - } - }; - Ok(accumulator) - } -} - -impl PartialEq for ApproxPercentileCont { - fn eq(&self, other: &ApproxPercentileCont) -> bool { - self.name == other.name - && self.input_data_type == other.input_data_type - && self.percentile == other.percentile - && self.tdigest_max_size == other.tdigest_max_size - && self.expr.len() == other.expr.len() - && self - .expr - .iter() - .zip(other.expr.iter()) - .all(|(this, other)| this.eq(other)) - } -} - -fn get_lit_value(expr: &Arc) -> Result { - let empty_schema = Schema::empty(); - let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema)); - let result = expr.evaluate(&empty_batch)?; - match result { - ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( - "The expr {:?} can't be evaluated to scalar value", - expr - ))), - ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), - } -} - -fn validate_input_percentile_expr(expr: &Arc) -> Result { - let lit = get_lit_value(expr)?; - let percentile = match &lit { - ScalarValue::Float32(Some(q)) => *q as f64, - 
ScalarValue::Float64(Some(q)) => *q, - got => return not_impl_err!( - "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", - got.data_type() - ) - }; - - // Ensure the percentile is between 0 and 1. - if !(0.0..=1.0).contains(&percentile) { - return plan_err!( - "Percentile value must be between 0.0 and 1.0 inclusive, {percentile} is invalid" - ); - } - Ok(percentile) -} - -fn validate_input_max_size_expr(expr: &Arc) -> Result { - let lit = get_lit_value(expr)?; - let max_size = match &lit { - ScalarValue::UInt8(Some(q)) => *q as usize, - ScalarValue::UInt16(Some(q)) => *q as usize, - ScalarValue::UInt32(Some(q)) => *q as usize, - ScalarValue::UInt64(Some(q)) => *q as usize, - ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize, - got => return not_impl_err!( - "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", - got.data_type() - ) - }; - Ok(max_size) -} - -impl AggregateExpr for ApproxPercentileCont { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), false)) - } - - #[allow(rustdoc::private_intra_doc_links)] - /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised - /// state. - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "max_size"), - DataType::UInt64, - false, - ), - Field::new( - format_state_name(&self.name, "sum"), - DataType::Float64, - false, - ), - Field::new( - format_state_name(&self.name, "count"), - DataType::Float64, - false, - ), - Field::new( - format_state_name(&self.name, "max"), - DataType::Float64, - false, - ), - Field::new( - format_state_name(&self.name, "min"), - DataType::Float64, - false, - ), - Field::new_list( - format_state_name(&self.name, "centroids"), - Field::new("item", DataType::Float64, true), - false, - ), - ]) - } - - fn expressions(&self) -> Vec> { - self.expr.clone() - } - - fn create_accumulator(&self) -> Result> { - let accumulator = self.create_plain_accumulator()?; - Ok(Box::new(accumulator)) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ApproxPercentileCont { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.eq(x)) - .unwrap_or(false) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index df87a2e261a1..a1f5f153a9ff 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -36,6 +36,7 @@ use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; + /// Create a physical aggregation expression. /// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. 
pub fn create_aggregate_expr( @@ -154,41 +155,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Correlation, true) => { return not_impl_err!("CORR(DISTINCT) aggregations are not available"); } - (AggregateFunction::ApproxPercentileCont, false) => { - if input_phy_exprs.len() == 2 { - Arc::new(expressions::ApproxPercentileCont::new( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } else { - Arc::new(expressions::ApproxPercentileCont::new_with_max_size( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } - } - (AggregateFunction::ApproxPercentileCont, true) => { - return not_impl_err!( - "approx_percentile_cont(DISTINCT) aggregations are not available" - ); - } - (AggregateFunction::ApproxPercentileContWithWeight, false) => { - Arc::new(expressions::ApproxPercentileContWithWeight::new( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } - (AggregateFunction::ApproxPercentileContWithWeight, true) => { - return not_impl_err!( - "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" - ); - } (AggregateFunction::NthValue, _) => { let expr = &input_phy_exprs[0]; let Some(n) = input_phy_exprs[1] @@ -232,15 +198,15 @@ pub fn create_aggregate_expr( mod tests { use arrow::datatypes::{DataType, Field}; - use super::*; + use datafusion_common::plan_err; + use datafusion_expr::{type_coercion, Signature}; + use crate::expressions::{ - try_cast, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, - BoolOr, DistinctArrayAgg, Max, Min, + try_cast, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, + DistinctArrayAgg, Max, Min, }; - use datafusion_common::{plan_err, DataFusionError, ScalarValue}; - use datafusion_expr::type_coercion::aggregates::NUMERICS; - use datafusion_expr::{type_coercion, Signature}; + use super::*; #[test] fn test_approx_expr() -> Result<()> { @@ -304,59 +270,6 @@ mod tests { Ok(()) } - #[test] - fn test_agg_approx_percentile_phy_expr() { - for data_type in NUMERICS { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &AggregateFunction::ApproxPercentileCont, - false, - &input_phy_exprs[..], - &input_schema, - "c1", - ) - .expect("failed to create aggregate expr"); - - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), false), - result_agg_phy_exprs.field().unwrap() - ); - } - } - - #[test] - fn test_agg_approx_percentile_invalid_phy_expr() { - for data_type in NUMERICS { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), - ]; - let err = create_physical_agg_expr_for_test( - &AggregateFunction::ApproxPercentileCont, - false, - &input_phy_exprs[..], - &input_schema, - "c1", - ) - .expect_err("should fail due to invalid percentile"); - - assert!(matches!(err, DataFusionError::Plan(_))); - } - } - #[test] fn test_min_max_expr() -> Result<()> { let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; 
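The removed built-in arms and tests are covered by the UDAF construction path, where `create_aggregate_expr` now receives the logical expressions alongside the physical ones. A sketch of the replacement call, assuming a single numeric column `c1` (the name string is illustrative):

use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::ScalarValue;
use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use datafusion_physical_expr_common::expressions::column::col;
use datafusion_physical_expr_common::expressions::literal::Literal;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

fn build_approx_percentile_expr() -> datafusion_common::Result<()> {
    let schema = Schema::new(vec![Field::new("c1", DataType::Float64, true)]);
    let percentile: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Float64(Some(0.95))));
    let agg = create_aggregate_expr(
        &approx_percentile_cont_udaf(),
        &[col("c1", &schema)?, percentile],
        // the logical twins of the physical arguments above
        &[datafusion_expr::col("c1"), datafusion_expr::lit(0.95)],
        &[], // sort_exprs
        &[], // ordering_req
        &schema,
        "approx_percentile_cont(c1, 0.95)",
        false, // ignore_nulls
        false, // is_distinct
    )?;
    let _ = agg;
    Ok(())
}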
diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 9079a81e6241..c20902c11b86 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -17,8 +17,6 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub(crate) mod approx_percentile_cont; -pub(crate) mod approx_percentile_cont_with_weight; pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 592393f800d0..b9a159b21e3d 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -35,8 +35,6 @@ mod try_cast; pub mod helpers { pub use crate::aggregate::min_max::{max, min}; } -pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont; -pub use crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight; pub use crate::aggregate::array_agg::ArrayAgg; pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; @@ -65,8 +63,8 @@ pub use column::UnKnownColumn; pub use datafusion_expr::utils::format_state_name; pub use datafusion_functions_aggregate::first_last::{FirstValue, LastValue}; pub use datafusion_physical_expr_common::expressions::column::{col, Column}; +pub use datafusion_physical_expr_common::expressions::literal::{lit, Literal}; pub use datafusion_physical_expr_common::expressions::{cast, CastExpr}; -pub use datafusion_physical_expr_common::expressions::{lit, Literal}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index b6fc70be7cbc..b7d8d60f4f35 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1339,6 +1339,7 @@ mod tests { let aggregates = vec![create_aggregate_expr( &count_udaf(), &[lit(1i8)], + &[datafusion_expr::lit(1i8)], &[], &[], &input_schema, @@ -1787,6 +1788,7 @@ mod tests { &args, &[], &[], + &[], schema, "MEDIAN(a)", false, @@ -1975,10 +1977,12 @@ mod tests { options: sort_options, }]; let args = vec![col("b", schema)?]; + let logical_args = vec![datafusion_expr::col("b")]; let func = datafusion_expr::AggregateUDF::new_from_impl(FirstValue::new()); datafusion_physical_expr_common::aggregate::create_aggregate_expr( &func, &args, + &logical_args, &sort_exprs, &ordering_req, schema, @@ -2005,10 +2009,12 @@ mod tests { options: sort_options, }]; let args = vec![col("b", schema)?]; + let logical_args = vec![datafusion_expr::col("b")]; let func = datafusion_expr::AggregateUDF::new_from_impl(LastValue::new()); - datafusion_physical_expr_common::aggregate::create_aggregate_expr( + create_aggregate_expr( &func, &args, + &logical_args, &sort_exprs, &ordering_req, schema, diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 56d780e51394..fc60ab997375 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -1194,7 +1194,7 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - WindowFrame, 
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + Expr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; @@ -1301,7 +1301,10 @@ mod tests { let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; + let log_expr = + Expr::Column(datafusion_common::Column::from(schema.fields[0].name())); let args = vec![col_expr]; + let log_args = vec![log_expr]; let partitionby_exprs = vec![col(hash, &schema)?]; let orderby_exprs = vec![PhysicalSortExpr { expr: col(order_by, &schema)?, @@ -1322,6 +1325,7 @@ mod tests { &window_fn, fn_name, &args, + &log_args, &partitionby_exprs, &orderby_exprs, Arc::new(window_frame.clone()), diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 63ce473fc57e..ecfe123a43af 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -90,6 +90,7 @@ pub fn create_window_expr( fun: &WindowFunctionDefinition, name: String, args: &[Arc], + logical_args: &[Expr], partition_by: &[Arc], order_by: &[PhysicalSortExpr], window_frame: Arc, @@ -144,6 +145,7 @@ pub fn create_window_expr( let aggregate = udaf::create_aggregate_expr( fun.as_ref(), args, + logical_args, &sort_exprs, order_by, input_schema, @@ -754,6 +756,7 @@ mod tests { &[col("a", &schema)?], &[], &[], + &[], Arc::new(WindowFrame::new(None)), schema.as_ref(), false, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 83223a04d023..e5578ae62f3e 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -486,9 +486,9 @@ enum AggregateFunction { // STDDEV = 11; // STDDEV_POP = 12; CORRELATION = 13; - APPROX_PERCENTILE_CONT = 14; + // APPROX_PERCENTILE_CONT = 14; // APPROX_MEDIAN = 15; - APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; + // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; GROUPING = 17; // MEDIAN = 18; BIT_AND = 19; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f298dd241abf..4a7b9610e5bc 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -537,8 +537,6 @@ impl serde::Serialize for AggregateFunction { Self::Avg => "AVG", Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", - Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Self::Grouping => "GROUPING", Self::BitAnd => "BIT_AND", Self::BitOr => "BIT_OR", @@ -563,8 +561,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "AVG", "ARRAY_AGG", "CORRELATION", - "APPROX_PERCENTILE_CONT", - "APPROX_PERCENTILE_CONT_WITH_WEIGHT", "GROUPING", "BIT_AND", "BIT_OR", @@ -618,8 +614,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "AVG" => Ok(AggregateFunction::Avg), "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), - "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => Ok(AggregateFunction::ApproxPercentileContWithWeight), "GROUPING" => Ok(AggregateFunction::Grouping), "BIT_AND" => Ok(AggregateFunction::BitAnd), "BIT_OR" => Ok(AggregateFunction::BitOr), diff --git a/datafusion/proto/src/generated/prost.rs 
b/datafusion/proto/src/generated/prost.rs index fa0217e9ef4f..ffaef445d668 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1940,9 +1940,9 @@ pub enum AggregateFunction { /// STDDEV = 11; /// STDDEV_POP = 12; Correlation = 13, - ApproxPercentileCont = 14, + /// APPROX_PERCENTILE_CONT = 14; /// APPROX_MEDIAN = 15; - ApproxPercentileContWithWeight = 16, + /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; Grouping = 17, /// MEDIAN = 18; BitAnd = 19, @@ -1974,10 +1974,6 @@ impl AggregateFunction { AggregateFunction::Avg => "AVG", AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", - AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - AggregateFunction::ApproxPercentileContWithWeight => { - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" - } AggregateFunction::Grouping => "GROUPING", AggregateFunction::BitAnd => "BIT_AND", AggregateFunction::BitOr => "BIT_OR", @@ -1996,10 +1992,6 @@ impl AggregateFunction { "AVG" => Some(Self::Avg), "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), - "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => { - Some(Self::ApproxPercentileContWithWeight) - } "GROUPING" => Some(Self::Grouping), "BIT_AND" => Some(Self::BitAnd), "BIT_OR" => Some(Self::BitOr), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ed7b0129cc48..25b7413a984a 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -147,12 +147,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::BoolOr => Self::BoolOr, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, - protobuf::AggregateFunction::ApproxPercentileCont => { - Self::ApproxPercentileCont - } - protobuf::AggregateFunction::ApproxPercentileContWithWeight => { - Self::ApproxPercentileContWithWeight - } protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, protobuf::AggregateFunction::StringAgg => Self::StringAgg, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 04f7b596fea8..d9548325dac3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -118,10 +118,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::BoolOr => Self::BoolOr, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, - AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, - AggregateFunction::ApproxPercentileContWithWeight => { - Self::ApproxPercentileContWithWeight - } AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, AggregateFunction::StringAgg => Self::StringAgg, @@ -381,12 +377,6 @@ pub fn serialize_expr( }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, 
AggregateFunction::Max => protobuf::AggregateFunction::Max, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 0a91df568a1d..b636c77641c7 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -126,7 +126,6 @@ pub fn parse_physical_window_expr( ) -> Result> { let window_node_expr = parse_physical_exprs(&proto.args, registry, input_schema, codec)?; - let partition_by = parse_physical_exprs(&proto.partition_by, registry, input_schema, codec)?; @@ -178,10 +177,13 @@ pub fn parse_physical_window_expr( // TODO: Remove extended_schema if functions are all UDAF let extended_schema = schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?; + // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. + let logical_exprs = &[]; create_window_expr( &fun, name, &window_node_expr, + logical_exprs, &partition_by, &order_by, Arc::new(window_frame), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index d0011e4917bf..8a488d30cf24 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -496,11 +496,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { let agg_udf = registry.udaf(udaf_name)?; + // TODO: 'logical_exprs' is not supported for UDAF yet. + // approx_percentile_cont and approx_percentile_cont_weight are not supported for UDAF from protobuf yet. + let logical_exprs = &[]; // TODO: `order by` is not supported for UDAF yet let sort_exprs = &[]; let ordering_req = &[]; let ignore_nulls = false; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) + udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, logical_exprs, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls, false) } } }).transpose()?.ok_or_else(|| { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ef462ac94b9a..3a4c35a93e16 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,12 +23,11 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, - BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, - CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, - IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr, - WindowShift, + ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, + CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, + InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, + NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, + StringAgg, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, 
PlainAggregateWindowExpr}; @@ -270,13 +269,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Correlation - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ApproxPercentileCont - } else if aggr_expr - .downcast_ref::() - .is_some() - { - protobuf::AggregateFunction::ApproxPercentileContWithWeight } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::StringAgg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index d0f1c4aade5e..a496e226855a 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -26,7 +26,6 @@ use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, }; -use datafusion_functions_aggregate::count::count_udaf; use prost::Message; use datafusion::datasource::provider::TableProviderFactory; @@ -34,10 +33,11 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; -use datafusion::functions_aggregate::approx_median::approx_median; +use datafusion::functions_aggregate::count::count_udaf; use datafusion::functions_aggregate::expr_fn::{ - count, count_distinct, covar_pop, covar_samp, first_value, median, stddev, - stddev_pop, sum, var_pop, var_sample, + approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, + count_distinct, covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, + var_pop, var_sample, }; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -663,6 +663,8 @@ async fn roundtrip_expr_api() -> Result<()> { stddev(lit(2.2)), stddev_pop(lit(2.2)), approx_median(lit(2)), + approx_percentile_cont(lit(2), lit(0.5)), + approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), ]; // ensure expressions created with the expr api can be round tripped @@ -1799,21 +1801,6 @@ fn roundtrip_count_distinct() { roundtrip_expr_test(test_expr, ctx); } -#[test] -fn roundtrip_approx_percentile_cont() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::ApproxPercentileCont, - vec![col("bananas"), lit(0.42_f32)], - false, - None, - None, - None, - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); -} - #[test] fn roundtrip_aggregate_udf() { #[derive(Debug)] diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index e517482f1db0..7f66cdbf7663 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -303,6 +303,7 @@ fn roundtrip_window() -> Result<()> { &args, &[], &[], + &[], &schema, "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", false, @@ -458,6 +459,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { &[col("b", &schema)?], &[], &[], + &[], &schema, "example_agg", false, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 7ba1893bb11a..0a6def3d6f27 100644 --- 
a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -76,26 +76,26 @@ statement error DataFusion error: Schema error: Schema contains duplicate unqual SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_weight -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Utf8, Int8, Float64\)'. You might need to add explicit type casts. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Utf8, Float64\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Int8, Utf8\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_histogram_bins -statement error This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\). +statement error DataFusion error: External error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\)\. SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Utf8\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64, Float64\)'\. You might need to add explicit type casts\. 
+statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Float64, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 # array agg can use order by From d4228feca341cd707a3a26372cae71a94a93b4fd Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Sun, 16 Jun 2024 18:54:11 -0700 Subject: [PATCH 27/34] refactor: remove extra default in max rows (#10941) --- datafusion-cli/src/main.rs | 2 +- docs/source/user-guide/cli/usage.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 187f856894b2..f2b29fe78690 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -133,7 +133,7 @@ struct Args { #[clap( long, - help = "The max number of rows to display for 'Table' format\n[default: 40] [possible values: numbers(0/10/...), inf(no limit)]", + help = "The max number of rows to display for 'Table' format\n[possible values: numbers(0/10/...), inf(no limit)]", default_value = "40" )] maxrows: MaxRows, diff --git a/docs/source/user-guide/cli/usage.md b/docs/source/user-guide/cli/usage.md index 617b462875c7..6a620fc69252 100644 --- a/docs/source/user-guide/cli/usage.md +++ b/docs/source/user-guide/cli/usage.md @@ -52,7 +52,7 @@ OPTIONS: --maxrows The max number of rows to display for 'Table' format - [default: 40] [possible values: numbers(0/10/...), inf(no limit)] + [possible values: numbers(0/10/...), inf(no limit)] [default: 40] --mem-pool-type Specify the memory pool type 'greedy' or 'fair', default to 'greedy' From 378b9eecd4a77386a59953209f75fc5c192d7af4 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 17 Jun 2024 17:43:20 +0800 Subject: [PATCH 28/34] chore: Improve performance of Parquet statistics conversion (#10932) --- .../physical_plan/parquet/statistics.rs | 32 +++---------------- 1 file changed, 4 insertions(+), 28 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index a2e0d8fa66be..327a516f1af1 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -303,24 +303,12 @@ macro_rules! get_statistics { ))), DataType::Int8 => Ok(Arc::new(Int8Array::from_iter( [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| { - if let Ok(v) = i8::try_from(*x) { - Some(v) - } else { - None - } - }) + x.and_then(|x| i8::try_from(*x).ok()) }), ))), DataType::Int16 => Ok(Arc::new(Int16Array::from_iter( [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| { - if let Ok(v) = i16::try_from(*x) { - Some(v) - } else { - None - } - }) + x.and_then(|x| i16::try_from(*x).ok()) }), ))), DataType::Int32 => Ok(Arc::new(Int32Array::from_iter( @@ -331,24 +319,12 @@ macro_rules! 
get_statistics { ))), DataType::UInt8 => Ok(Arc::new(UInt8Array::from_iter( [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| { - if let Ok(v) = u8::try_from(*x) { - Some(v) - } else { - None - } - }) + x.and_then(|x| u8::try_from(*x).ok()) }), ))), DataType::UInt16 => Ok(Arc::new(UInt16Array::from_iter( [<$stat_type_prefix Int32StatsIterator>]::new($iterator).map(|x| { - x.and_then(|x| { - if let Ok(v) = u16::try_from(*x) { - Some(v) - } else { - None - } - }) + x.and_then(|x| u16::try_from(*x).ok()) }), ))), DataType::UInt32 => Ok(Arc::new(UInt32Array::from_iter( From c4fd7545ba7719d6d12473694fcdf6f34d25b8cb Mon Sep 17 00:00:00 2001 From: Leonardo Yvens Date: Mon, 17 Jun 2024 12:17:58 +0100 Subject: [PATCH 29/34] Add catalog::resolve_table_references (#10876) * resolve information_schema references only when necessary * add `catalog::resolve_table_references` as a public utility * collect CTEs separately in resolve_table_references * test CTE name shadowing * handle CTE name shadowing in resolve_table_references * handle unions, recursive and nested CTEs in resolve_table_references --- datafusion/core/src/catalog/mod.rs | 239 +++++++++++++++++- .../core/src/execution/session_state.rs | 96 +------ datafusion/sqllogictest/test_files/cte.slt | 7 + 3 files changed, 256 insertions(+), 86 deletions(-) diff --git a/datafusion/core/src/catalog/mod.rs b/datafusion/core/src/catalog/mod.rs index 209d9b2af297..53b133339924 100644 --- a/datafusion/core/src/catalog/mod.rs +++ b/datafusion/core/src/catalog/mod.rs @@ -27,6 +27,8 @@ use crate::catalog::schema::SchemaProvider; use dashmap::DashMap; use datafusion_common::{exec_err, not_impl_err, Result}; use std::any::Any; +use std::collections::BTreeSet; +use std::ops::ControlFlow; use std::sync::Arc; /// Represent a list of named [`CatalogProvider`]s. @@ -157,11 +159,11 @@ impl CatalogProviderList for MemoryCatalogProviderList { /// access required to read table details (e.g. statistics). /// /// The pattern that DataFusion itself uses to plan SQL queries is to walk over -/// the query to [find all schema / table references in an `async` function], +/// the query to [find all table references], /// performing required remote catalog in parallel, and then plans the query /// using that snapshot. /// -/// [find all schema / table references in an `async` function]: crate::execution::context::SessionState::resolve_table_references +/// [find all table references]: resolve_table_references /// /// # Example Catalog Implementations /// @@ -295,6 +297,182 @@ impl CatalogProvider for MemoryCatalogProvider { } } +/// Collects all tables and views referenced in the SQL statement. CTEs are collected separately. +/// This can be used to determine which tables need to be in the catalog for a query to be planned. +/// +/// # Returns +/// +/// A `(table_refs, ctes)` tuple, the first element contains table and view references and the second +/// element contains any CTE aliases that were defined and possibly referenced. 
+///
+/// ## Example
+///
+/// ```
+/// # use datafusion_sql::parser::DFParser;
+/// # use datafusion::catalog::resolve_table_references;
+/// let query = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
+/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+/// assert_eq!(table_refs.len(), 2);
+/// assert_eq!(table_refs[0].to_string(), "bar");
+/// assert_eq!(table_refs[1].to_string(), "foo");
+/// assert_eq!(ctes.len(), 0);
+/// ```
+///
+/// ## Example with CTEs
+///
+/// ```
+/// # use datafusion_sql::parser::DFParser;
+/// # use datafusion::catalog::resolve_table_references;
+/// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;";
+/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+/// assert_eq!(table_refs.len(), 0);
+/// assert_eq!(ctes.len(), 1);
+/// assert_eq!(ctes[0].to_string(), "my_cte");
+/// ```
+pub fn resolve_table_references(
+    statement: &datafusion_sql::parser::Statement,
+    enable_ident_normalization: bool,
+) -> datafusion_common::Result<(Vec<TableReference>, Vec<TableReference>)> {
+    use crate::sql::planner::object_name_to_table_reference;
+    use datafusion_sql::parser::{
+        CopyToSource, CopyToStatement, Statement as DFStatement,
+    };
+    use information_schema::INFORMATION_SCHEMA;
+    use information_schema::INFORMATION_SCHEMA_TABLES;
+    use sqlparser::ast::*;
+
+    struct RelationVisitor {
+        relations: BTreeSet<ObjectName>,
+        all_ctes: BTreeSet<ObjectName>,
+        ctes_in_scope: Vec<ObjectName>,
+    }
+
+    impl RelationVisitor {
+        /// Record the reference to `relation`, if it's not a CTE reference.
+        fn insert_relation(&mut self, relation: &ObjectName) {
+            if !self.relations.contains(relation)
+                && !self.ctes_in_scope.contains(relation)
+            {
+                self.relations.insert(relation.clone());
+            }
+        }
+    }
+
+    impl Visitor for RelationVisitor {
+        type Break = ();
+
+        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
+            self.insert_relation(relation);
+            ControlFlow::Continue(())
+        }
+
+        fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
+            if let Some(with) = &q.with {
+                for cte in &with.cte_tables {
+                    // The non-recursive CTE name is not in scope when evaluating the CTE itself, so this is valid:
+                    // `WITH t AS (SELECT * FROM t) SELECT * FROM t`
+                    // Where the first `t` refers to a predefined table. So we are careful here
+                    // to visit the CTE first, before putting it in scope.
+                    if !with.recursive {
+                        // This is a bit hackish as the CTE will be visited again as part of visiting `q`,
+                        // but thankfully `insert_relation` is idempotent.
+                        cte.visit(self);
+                    }
+                    self.ctes_in_scope
+                        .push(ObjectName(vec![cte.alias.name.clone()]));
+                }
+            }
+            ControlFlow::Continue(())
+        }
+
+        fn post_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
+            if let Some(with) = &q.with {
+                for _ in &with.cte_tables {
+                    // Unwrap: We just pushed these in `pre_visit_query`
+                    self.all_ctes.insert(self.ctes_in_scope.pop().unwrap());
+                }
+            }
+            ControlFlow::Continue(())
+        }
+
+        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
+            if let Statement::ShowCreate {
+                obj_type: ShowCreateObject::Table | ShowCreateObject::View,
+                obj_name,
+            } = statement
+            {
+                self.insert_relation(obj_name)
+            }
+
+            // SHOW statements will later be rewritten into a SELECT from the information_schema
+            let requires_information_schema = matches!(
+                statement,
+                Statement::ShowFunctions { .. }
+                    | Statement::ShowVariable { .. }
+                    | Statement::ShowStatus { .. }
+                    | Statement::ShowVariables { .. }
+                    | Statement::ShowCreate { .. }
+                    | Statement::ShowColumns { .. }
+                    | Statement::ShowTables { .. }
+                    | Statement::ShowCollation { .. }
+            );
+            if requires_information_schema {
+                for s in INFORMATION_SCHEMA_TABLES {
+                    self.relations.insert(ObjectName(vec![
+                        Ident::new(INFORMATION_SCHEMA),
+                        Ident::new(*s),
+                    ]));
+                }
+            }
+            ControlFlow::Continue(())
+        }
+    }
+
+    let mut visitor = RelationVisitor {
+        relations: BTreeSet::new(),
+        all_ctes: BTreeSet::new(),
+        ctes_in_scope: vec![],
+    };
+
+    fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
+        match statement {
+            DFStatement::Statement(s) => {
+                let _ = s.as_ref().visit(visitor);
+            }
+            DFStatement::CreateExternalTable(table) => {
+                visitor
+                    .relations
+                    .insert(ObjectName(vec![Ident::from(table.name.as_str())]));
+            }
+            DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
+                CopyToSource::Relation(table_name) => {
+                    visitor.insert_relation(table_name);
+                }
+                CopyToSource::Query(query) => {
+                    query.visit(visitor);
+                }
+            },
+            DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor),
+        }
+    }
+
+    visit_statement(statement, &mut visitor);
+
+    let table_refs = visitor
+        .relations
+        .into_iter()
+        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
+        .collect::<datafusion_common::Result<_>>()?;
+    let ctes = visitor
+        .all_ctes
+        .into_iter()
+        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
+        .collect::<datafusion_common::Result<_>>()?;
+    Ok((table_refs, ctes))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -363,4 +541,61 @@ mod tests {
         let cat = Arc::new(MemoryCatalogProvider::new()) as Arc<dyn CatalogProvider>;
         assert!(cat.deregister_schema("foo", false).unwrap().is_none());
     }
+
+    #[test]
+    fn resolve_table_references_shadowed_cte() {
+        use datafusion_sql::parser::DFParser;
+
+        // An interesting edge case where the `t` name is used both as an ordinary table reference
+        // and as a CTE reference.
+        let query = "WITH t AS (SELECT * FROM t) SELECT * FROM t";
+        let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+        assert_eq!(table_refs.len(), 1);
+        assert_eq!(ctes.len(), 1);
+        assert_eq!(ctes[0].to_string(), "t");
+        assert_eq!(table_refs[0].to_string(), "t");
+
+        // UNION is a special case where the CTE is not in scope for the second branch.
+        let query = "(with t as (select 1) select * from t) union (select * from t)";
+        let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+        assert_eq!(table_refs.len(), 1);
+        assert_eq!(ctes.len(), 1);
+        assert_eq!(ctes[0].to_string(), "t");
+        assert_eq!(table_refs[0].to_string(), "t");
+
+        // Nested CTEs are also handled.
+        // Here the first `u` is a CTE, but the second `u` is a table reference.
+        // While `t` is always a CTE.
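+        // (The inner `u` goes out of scope once the body of `t` has been
+        // visited, so the outer `u` can only resolve to a table reference.)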
+ let query = "(with t as (with u as (select 1) select * from u) select * from u cross join t)"; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + assert_eq!(table_refs.len(), 1); + assert_eq!(ctes.len(), 2); + assert_eq!(ctes[0].to_string(), "t"); + assert_eq!(ctes[1].to_string(), "u"); + assert_eq!(table_refs[0].to_string(), "u"); + } + + #[test] + fn resolve_table_references_recursive_cte() { + use datafusion_sql::parser::DFParser; + + let query = " + WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + SELECT id + 1 as id + FROM nodes + WHERE id < 10 + ) + SELECT * FROM nodes + "; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + assert_eq!(table_refs.len(), 0); + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "nodes"); + } } diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index fed101bd239b..1df77a1f9e0b 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -66,15 +66,12 @@ use datafusion_optimizer::{ use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; -use datafusion_sql::parser::{CopyToSource, CopyToStatement, DFParser, Statement}; -use datafusion_sql::planner::{ - object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel, -}; +use datafusion_sql::parser::{DFParser, Statement}; +use datafusion_sql::planner::{ContextProvider, ParserOptions, SqlToRel}; use sqlparser::dialect::dialect_from_str; use std::collections::hash_map::Entry; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; -use std::ops::ControlFlow; use std::sync::Arc; use url::Url; use uuid::Uuid; @@ -493,91 +490,22 @@ impl SessionState { Ok(statement) } - /// Resolve all table references in the SQL statement. + /// Resolve all table references in the SQL statement. Does not include CTE references. + /// + /// See [`catalog::resolve_table_references`] for more information. 
+    ///
+    /// [`catalog::resolve_table_references`]: crate::catalog::resolve_table_references
     pub fn resolve_table_references(
         &self,
         statement: &datafusion_sql::parser::Statement,
     ) -> datafusion_common::Result<Vec<TableReference>> {
-        use crate::catalog::information_schema::INFORMATION_SCHEMA_TABLES;
-        use datafusion_sql::parser::Statement as DFStatement;
-        use sqlparser::ast::*;
-
-        // Getting `TableProviders` is async but planing is not -- thus pre-fetch
-        // table providers for all relations referenced in this query
-        let mut relations = hashbrown::HashSet::with_capacity(10);
-
-        struct RelationVisitor<'a>(&'a mut hashbrown::HashSet<ObjectName>);
-
-        impl<'a> RelationVisitor<'a> {
-            /// Record that `relation` was used in this statement
-            fn insert(&mut self, relation: &ObjectName) {
-                self.0.get_or_insert_with(relation, |_| relation.clone());
-            }
-        }
-
-        impl<'a> Visitor for RelationVisitor<'a> {
-            type Break = ();
-
-            fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
-                self.insert(relation);
-                ControlFlow::Continue(())
-            }
-
-            fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
-                if let Statement::ShowCreate {
-                    obj_type: ShowCreateObject::Table | ShowCreateObject::View,
-                    obj_name,
-                } = statement
-                {
-                    self.insert(obj_name)
-                }
-                ControlFlow::Continue(())
-            }
-        }
-
-        let mut visitor = RelationVisitor(&mut relations);
-        fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor<'_>) {
-            match statement {
-                DFStatement::Statement(s) => {
-                    let _ = s.as_ref().visit(visitor);
-                }
-                DFStatement::CreateExternalTable(table) => {
-                    visitor
-                        .0
-                        .insert(ObjectName(vec![Ident::from(table.name.as_str())]));
-                }
-                DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
-                    CopyToSource::Relation(table_name) => {
-                        visitor.insert(table_name);
-                    }
-                    CopyToSource::Query(query) => {
-                        query.visit(visitor);
-                    }
-                },
-                DFStatement::Explain(explain) => {
-                    visit_statement(&explain.statement, visitor)
-                }
-            }
-        }
-
-        visit_statement(statement, &mut visitor);
-
-        // Always include information_schema if available
-        if self.config.information_schema() {
-            for s in INFORMATION_SCHEMA_TABLES {
-                relations.insert(ObjectName(vec![
-                    Ident::new(INFORMATION_SCHEMA),
-                    Ident::new(*s),
-                ]));
-            }
-        }
-
         let enable_ident_normalization =
             self.config.options().sql_parser.enable_ident_normalization;
-        relations
-            .into_iter()
-            .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
-            .collect::<datafusion_common::Result<_>>()
+        let (table_refs, _) = crate::catalog::resolve_table_references(
+            statement,
+            enable_ident_normalization,
+        )?;
+        Ok(table_refs)
     }
 
     /// Convert an AST Statement into a LogicalPlan
diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt
index 1ff108cf6c5f..d8eaa51fc88a 100644
--- a/datafusion/sqllogictest/test_files/cte.slt
+++ b/datafusion/sqllogictest/test_files/cte.slt
@@ -828,3 +828,10 @@ SELECT * FROM non_recursive_cte, recursive_cte;
 ----
 1 1 1 3
+
+# Name shadowing:
+# The first `t` refers to the table, the second to the CTE.
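+# `resolve_table_references` reports this `t` both as a table reference (from
+# inside the CTE body) and as a CTE alias for the outer query.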
+query I +WITH t AS (SELECT * FROM t where t.a < 2) SELECT * FROM t +---- +1 \ No newline at end of file From a923c659cf932f6369f2d5257e5b99128b67091a Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Mon, 17 Jun 2024 19:22:55 +0800 Subject: [PATCH 30/34] feat: Add support for Int8 and Int16 data types in data page statistics (#10931) --- .../physical_plan/parquet/statistics.rs | 30 +++++++++++++++++++ .../core/tests/parquet/arrow_statistics.rs | 24 ++------------- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 327a516f1af1..a2f17ca9b7a7 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -548,6 +548,8 @@ macro_rules! make_data_page_stats_iterator { }; } +make_data_page_stats_iterator!(MinInt32DataPageStatsIterator, min, Index::INT32, i32); +make_data_page_stats_iterator!(MaxInt32DataPageStatsIterator, max, Index::INT32, i32); make_data_page_stats_iterator!(MinInt64DataPageStatsIterator, min, Index::INT64, i64); make_data_page_stats_iterator!(MaxInt64DataPageStatsIterator, max, Index::INT64, i64); @@ -555,6 +557,29 @@ macro_rules! get_data_page_statistics { ($stat_type_prefix: ident, $data_type: ident, $iterator: ident) => { paste! { match $data_type { + Some(DataType::Int8) => Ok(Arc::new( + Int8Array::from_iter( + [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) + .map(|x| { + x.into_iter().filter_map(|x| { + x.and_then(|x| i8::try_from(x).ok()) + }) + }) + .flatten() + ) + )), + Some(DataType::Int16) => Ok(Arc::new( + Int16Array::from_iter( + [<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator) + .map(|x| { + x.into_iter().filter_map(|x| { + x.and_then(|x| i16::try_from(x).ok()) + }) + }) + .flatten() + ) + )), + Some(DataType::Int32) => Ok(Arc::new(Int32Array::from_iter([<$stat_type_prefix Int32DataPageStatsIterator>]::new($iterator).flatten()))), Some(DataType::Int64) => Ok(Arc::new(Int64Array::from_iter([<$stat_type_prefix Int64DataPageStatsIterator>]::new($iterator).flatten()))), _ => unimplemented!() } @@ -642,6 +667,11 @@ where { let iter = iterator.flat_map(|(len, index)| match index { Index::NONE => vec![None; len], + Index::INT32(native_index) => native_index + .indexes + .iter() + .map(|x| x.null_count.map(|x| x as u64)) + .collect::>(), Index::INT64(native_index) => native_index .indexes .iter() diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index 6b8705441d12..87bd1372225f 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -550,16 +550,11 @@ async fn test_int_32() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i32", - check: Check::RowGroup, + check: Check::Both, } .run(); } -// BUG: ignore this test for now -// https://github.com/apache/datafusion/issues/10585 -// Note that the file has 4 columns named "i8", "i16", "i32", "i64". -// - The tests on column i32 and i64 passed. -// - The tests on column i8 and i16 failed. 
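+// The i8/i16 tests below were previously flagged as buggy
+// (https://github.com/apache/datafusion/issues/10585); they now verify both
+// row group and data page statistics via `Check::Both`.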
#[tokio::test] async fn test_int_16() { // This creates a parquet files of 4 columns named "i8", "i16", "i32", "i64" @@ -573,16 +568,6 @@ async fn test_int_16() { Test { reader: &reader, // mins are [-5, -4, 0, 5] - // BUG: not sure why this returns same data but in Int32Array type even though I debugged and the columns name is "i16" an its data is Int16 - // My debugging tells me the bug is either at: - // 1. The new code to get "iter". See the code in this PR with - // // Get an iterator over the column statistics - // let iter = row_groups - // .iter() - // .map(|x| x.column(parquet_idx).statistics()); - // OR - // 2. in the function (and/or its marco) `pub(crate) fn min_statistics<'a, I: Iterator>>` here - // https://github.com/apache/datafusion/blob/ea023e2d4878240eece870cf4b346c7a0667aeed/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs#L179 expected_min: Arc::new(Int16Array::from(vec![-5, -4, 0, 5])), // panic here because the actual data is Int32Array // maxes are [-1, 0, 4, 9] expected_max: Arc::new(Int16Array::from(vec![-1, 0, 4, 9])), @@ -591,13 +576,11 @@ async fn test_int_16() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i16", - check: Check::RowGroup, + check: Check::Both, } .run(); } -// BUG (same as above): ignore this test for now -// https://github.com/apache/datafusion/issues/10585 #[tokio::test] async fn test_int_8() { // This creates a parquet files of 4 columns named "i8", "i16", "i32", "i64" @@ -611,7 +594,6 @@ async fn test_int_8() { Test { reader: &reader, // mins are [-5, -4, 0, 5] - // BUG: not sure why this returns same data but in Int32Array even though I debugged and the columns name is "i8" an its data is Int8 expected_min: Arc::new(Int8Array::from(vec![-5, -4, 0, 5])), // panic here because the actual data is Int32Array // maxes are [-1, 0, 4, 9] expected_max: Arc::new(Int8Array::from(vec![-1, 0, 4, 9])), @@ -620,7 +602,7 @@ async fn test_int_8() { // row counts are [5, 5, 5, 5] expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), column_name: "i8", - check: Check::RowGroup, + check: Check::Both, } .run(); } From 2daadb75230e2c197d2915257a9637913fa2c2e6 Mon Sep 17 00:00:00 2001 From: Dharan Aditya Date: Mon, 17 Jun 2024 18:36:16 +0530 Subject: [PATCH 31/34] Convert BitAnd, BitOr, BitXor to UDAF (#10930) * remove bit and or xor from expr * remove bit and or xor from physical expr and proto * add proto regen changes * impl BitAnd, BitOr, BitXor UADF * add support for float * removing support for float * refactor helper macros * clippy'fy * simplify Bitwise operation * add documentation * formatting * fix lint issue * remove XorDistinct * update roundtrip_expr_api test * linting * support groups accumulator --- datafusion/expr/src/aggregate_function.rs | 20 - .../expr/src/type_coercion/aggregates.rs | 18 - .../functions-aggregate/src/bit_and_or_xor.rs | 458 ++++++++++++ datafusion/functions-aggregate/src/lib.rs | 7 + .../src/aggregate/bit_and_or_xor.rs | 695 ------------------ .../physical-expr/src/aggregate/build_in.rs | 78 +- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 3 - datafusion/proto/src/logical_plan/to_proto.rs | 6 - .../proto/src/physical_plan/to_proto.rs | 19 +- .../tests/cases/roundtrip_logical_plan.rs | 4 + 15 files changed, 481 
insertions(+), 856 deletions(-) create mode 100644 datafusion/functions-aggregate/src/bit_and_or_xor.rs delete mode 100644 datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 441e8953dffc..a7fbf26febb1 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -47,12 +47,6 @@ pub enum AggregateFunction { Correlation, /// Grouping Grouping, - /// Bit And - BitAnd, - /// Bit Or - BitOr, - /// Bit Xor - BitXor, /// Bool And BoolAnd, /// Bool Or @@ -72,9 +66,6 @@ impl AggregateFunction { NthValue => "NTH_VALUE", Correlation => "CORR", Grouping => "GROUPING", - BitAnd => "BIT_AND", - BitOr => "BIT_OR", - BitXor => "BIT_XOR", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", StringAgg => "STRING_AGG", @@ -94,9 +85,6 @@ impl FromStr for AggregateFunction { Ok(match name { // general "avg" => AggregateFunction::Avg, - "bit_and" => AggregateFunction::BitAnd, - "bit_or" => AggregateFunction::BitOr, - "bit_xor" => AggregateFunction::BitXor, "bool_and" => AggregateFunction::BoolAnd, "bool_or" => AggregateFunction::BoolOr, "max" => AggregateFunction::Max, @@ -144,9 +132,6 @@ impl AggregateFunction { // The coerced_data_types is same with input_types. Ok(coerced_data_types[0].clone()) } - AggregateFunction::BitAnd - | AggregateFunction::BitOr - | AggregateFunction::BitXor => Ok(coerced_data_types[0].clone()), AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { Ok(DataType::Boolean) } @@ -199,11 +184,6 @@ impl AggregateFunction { .collect::>(); Signature::uniform(1, valid, Volatility::Immutable) } - AggregateFunction::BitAnd - | AggregateFunction::BitOr - | AggregateFunction::BitXor => { - Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable) - } AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable) } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 98324ed6120b..a216c98899fe 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -121,20 +121,6 @@ pub fn coerce_types( }; Ok(vec![v]) } - AggregateFunction::BitAnd - | AggregateFunction::BitOr - | AggregateFunction::BitXor => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval. - if !is_bit_and_or_xor_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(input_types.to_vec()) - } AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. 
@@ -350,10 +336,6 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result { } } -pub fn is_bit_and_or_xor_support_arg_type(arg_type: &DataType) -> bool { - NUMERICS.contains(arg_type) -} - pub fn is_bool_and_or_support_arg_type(arg_type: &DataType) -> bool { matches!(arg_type, DataType::Boolean) } diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs new file mode 100644 index 000000000000..19e24f547d8a --- /dev/null +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -0,0 +1,458 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines `BitAnd`, `BitOr`, `BitXor` and `BitXor DISTINCT` aggregate accumulators + +use std::any::Any; +use std::collections::HashSet; +use std::fmt::{Display, Formatter}; + +use ahash::RandomState; +use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; +use arrow::datatypes::{ + ArrowNativeType, ArrowNumericType, DataType, Int16Type, Int32Type, Int64Type, + Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_schema::Field; + +use datafusion_common::cast::as_list_array; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::INTEGERS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, GroupsAccumulator, ReversedUDAF, Signature, Volatility, +}; + +use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign}; + +/// This macro helps create group accumulators based on bitwise operations typically used internally +/// and might not be necessary for users to call directly. +macro_rules! group_accumulator_helper { + ($t:ty, $dt:expr, $opr:expr) => { + match $opr { + BitwiseOperationType::And => Ok(Box::new( + PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitand_assign(y)) + .with_starting_value(!0), + )), + BitwiseOperationType::Or => Ok(Box::new( + PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitor_assign(y)), + )), + BitwiseOperationType::Xor => Ok(Box::new( + PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitxor_assign(y)), + )), + } + }; +} + +/// `accumulator_helper` is a macro accepting (ArrowPrimitiveType, BitwiseOperationType, bool) +macro_rules! 
accumulator_helper { + ($t:ty, $opr:expr, $is_distinct: expr) => { + match $opr { + BitwiseOperationType::And => Ok(Box::>::default()), + BitwiseOperationType::Or => Ok(Box::>::default()), + BitwiseOperationType::Xor => { + if $is_distinct { + Ok(Box::>::default()) + } else { + Ok(Box::>::default()) + } + } + } + }; +} + +/// AND, OR and XOR only supports a subset of numeric types +/// +/// `args` is [AccumulatorArgs] +/// `opr` is [BitwiseOperationType] +/// `is_distinct` is boolean value indicating whether the operation is distinct or not. +macro_rules! downcast_bitwise_accumulator { + ($args:ident, $opr:expr, $is_distinct: expr) => { + match $args.data_type { + DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct), + DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct), + DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct), + DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct), + DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct), + DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct), + DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct), + DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct), + _ => { + not_impl_err!( + "{} not supported for {}: {}", + stringify!($opr), + $args.name, + $args.data_type + ) + } + } + }; +} + +/// Simplifies the creation of User-Defined Aggregate Functions (UDAFs) for performing bitwise operations in a declarative manner. +/// +/// `EXPR_FN` identifier used to name the generated expression function. +/// `AGGREGATE_UDF_FN` is an identifier used to name the underlying UDAF function. +/// `OPR_TYPE` is an expression that evaluates to the type of bitwise operation to be performed. +macro_rules! make_bitwise_udaf_expr_and_func { + ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr) => { + make_udaf_expr!( + $EXPR_FN, + expr_x, + concat!( + "Returns the bitwise", + stringify!($OPR_TYPE), + "of a group of values" + ), + $AGGREGATE_UDF_FN + ); + create_func!( + $EXPR_FN, + $AGGREGATE_UDF_FN, + BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN)) + ); + }; +} + +make_bitwise_udaf_expr_and_func!(bit_and, bit_and_udaf, BitwiseOperationType::And); +make_bitwise_udaf_expr_and_func!(bit_or, bit_or_udaf, BitwiseOperationType::Or); +make_bitwise_udaf_expr_and_func!(bit_xor, bit_xor_udaf, BitwiseOperationType::Xor); + +/// The different types of bitwise operations that can be performed. +#[derive(Debug, Clone, Eq, PartialEq)] +enum BitwiseOperationType { + And, + Or, + Xor, +} + +impl Display for BitwiseOperationType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +/// [BitwiseOperation] struct encapsulates information about a bitwise operation. +#[derive(Debug)] +struct BitwiseOperation { + signature: Signature, + /// `operation` indicates the type of bitwise operation to be performed. 
+ operation: BitwiseOperationType, + func_name: &'static str, +} + +impl BitwiseOperation { + pub fn new(operator: BitwiseOperationType, func_name: &'static str) -> Self { + Self { + operation: operator, + signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable), + func_name, + } + } +} + +impl AggregateUDFImpl for BitwiseOperation { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.func_name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + if !arg_type.is_integer() { + return exec_err!( + "[return_type] {} not supported for {}", + self.name(), + arg_type + ); + } + Ok(arg_type.clone()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if self.operation == BitwiseOperationType::Xor && args.is_distinct { + Ok(vec![Field::new_list( + format_state_name( + args.name, + format!("{} distinct", self.name()).as_str(), + ), + Field::new("item", args.return_type.clone(), true), + false, + )]) + } else { + Ok(vec![Field::new( + format_state_name(args.name, self.name()), + args.return_type.clone(), + true, + )]) + } + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let data_type = args.data_type; + let operation = &self.operation; + downcast_integer! { + data_type => (group_accumulator_helper, data_type, operation), + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + data_type + ), + } + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +struct BitAndAccumulator { + value: Option, +} + +impl std::fmt::Debug for BitAndAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BitAndAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for BitAndAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitAndAccumulator +where + T::Native: std::ops::BitAnd, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::()) { + let v = self.value.get_or_insert(x); + *v = *v & x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +struct BitOrAccumulator { + value: Option, +} + +impl std::fmt::Debug for BitOrAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BitOrAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for BitOrAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitOrAccumulator +where + T::Native: std::ops::BitOr, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::()) { + let v = self.value.get_or_insert(T::Native::usize_as(0)); + *v = *v | x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) + 
} + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +struct BitXorAccumulator { + value: Option, +} + +impl std::fmt::Debug for BitXorAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "BitXorAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for BitXorAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitXorAccumulator +where + T::Native: std::ops::BitXor, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::()) { + let v = self.value.get_or_insert(T::Native::usize_as(0)); + *v = *v ^ x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +struct DistinctBitXorAccumulator { + values: HashSet, +} + +impl std::fmt::Debug for DistinctBitXorAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for DistinctBitXorAccumulator { + fn default() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for DistinctBitXorAccumulator +where + T::Native: std::ops::BitXor + std::hash::Hash + Eq, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(array.value(idx)); + } + } + None => array.values().iter().for_each(|x| { + self.values.insert(*x); + }), + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc ^ *distinct_value; + } + let v = (!self.values.is_empty()).then_some(acc); + ScalarValue::new_primitive::(v, &T::DATA_TYPE) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.values.capacity() * std::mem::size_of::() + } + + fn state(&mut self) -> Result> { + // 1. Stores aggregate state in `ScalarValue::List` + // 2. 
Constructs `ScalarValue::List` state from distinct numeric stored in hash set + let state_out = { + let values = self + .values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) + .collect::>>()?; + + let arr = ScalarValue::new_list(&values, &T::DATA_TYPE); + vec![ScalarValue::List(arr)] + }; + Ok(state_out) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if let Some(state) = states.first() { + let list_arr = as_list_array(state)?; + for arr in list_arr.iter().flatten() { + self.update_batch(&[arr])?; + } + } + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index daddb9d93f78..990303bd1de3 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -69,6 +69,7 @@ pub mod variance; pub mod approx_median; pub mod approx_percentile_cont; pub mod approx_percentile_cont_with_weight; +pub mod bit_and_or_xor; use crate::approx_percentile_cont::approx_percentile_cont_udaf; use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; @@ -84,6 +85,9 @@ pub mod expr_fn { pub use super::approx_median::approx_median; pub use super::approx_percentile_cont::approx_percentile_cont; pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::bit_and_or_xor::bit_and; + pub use super::bit_and_or_xor::bit_or; + pub use super::bit_and_or_xor::bit_xor; pub use super::count::count; pub use super::count::count_distinct; pub use super::covariance::covar_pop; @@ -134,6 +138,9 @@ pub fn all_default_aggregate_functions() -> Vec> { approx_distinct::approx_distinct_udaf(), approx_percentile_cont_udaf(), approx_percentile_cont_with_weight_udaf(), + bit_and_or_xor::bit_and_udaf(), + bit_and_or_xor::bit_or_udaf(), + bit_and_or_xor::bit_xor_udaf(), ] } diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs deleted file mode 100644 index 3fa225c5e479..000000000000 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ /dev/null @@ -1,695 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! 
Defines BitAnd, BitOr, and BitXor Aggregate accumulators - -use ahash::RandomState; -use datafusion_common::cast::as_list_array; -use std::any::Any; -use std::sync::Arc; - -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; -use std::collections::HashSet; - -use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::compute::{bit_and, bit_or, bit_xor}; -use arrow_array::cast::AsArray; -use arrow_array::{downcast_integer, ArrowNumericType}; -use arrow_buffer::ArrowNativeType; - -/// BIT_AND aggregate expression -#[derive(Debug, Clone)] -pub struct BitAnd { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BitAnd { - /// Create a new BIT_AND aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BitAnd { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty) => { - Ok(Box::>::default()) - }; - } - downcast_integer! { - &self.data_type => (helper), - _ => Err(DataFusionError::NotImplemented(format!( - "BitAndAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bit_and"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - use std::ops::BitAndAssign; - - // Note the default value for BitAnd should be all set, i.e. `!0` - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new( - PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| { - x.bitand_assign(y) - }) - .with_starting_value(!0), - )) - }; - } - - let data_type = &self.data_type; - downcast_integer! 
{ - data_type => (helper, data_type), - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } -} - -impl PartialEq for BitAnd { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -struct BitAndAccumulator { - value: Option, -} - -impl std::fmt::Debug for BitAndAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "BitAndAccumulator({})", T::DATA_TYPE) - } -} - -impl Default for BitAndAccumulator { - fn default() -> Self { - Self { value: None } - } -} - -impl Accumulator for BitAndAccumulator -where - T::Native: std::ops::BitAnd, -{ - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(x) = bit_and(values[0].as_primitive::()) { - let v = self.value.get_or_insert(x); - *v = *v & x; - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// BIT_OR aggregate expression -#[derive(Debug, Clone)] -pub struct BitOr { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BitOr { - /// Create a new BIT_OR aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BitOr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty) => { - Ok(Box::>::default()) - }; - } - downcast_integer! { - &self.data_type => (helper), - _ => Err(DataFusionError::NotImplemented(format!( - "BitOrAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bit_or"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - use std::ops::BitOrAssign; - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( - $dt, - |x, y| x.bitor_assign(y), - ))) - }; - } - - let data_type = &self.data_type; - downcast_integer! 
{ - data_type => (helper, data_type), - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } -} - -impl PartialEq for BitOr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -struct BitOrAccumulator { - value: Option, -} - -impl std::fmt::Debug for BitOrAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "BitOrAccumulator({})", T::DATA_TYPE) - } -} - -impl Default for BitOrAccumulator { - fn default() -> Self { - Self { value: None } - } -} - -impl Accumulator for BitOrAccumulator -where - T::Native: std::ops::BitOr, -{ - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(x) = bit_or(values[0].as_primitive::()) { - let v = self.value.get_or_insert(T::Native::usize_as(0)); - *v = *v | x; - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// BIT_XOR aggregate expression -#[derive(Debug, Clone)] -pub struct BitXor { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BitXor { - /// Create a new BIT_XOR aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BitXor { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty) => { - Ok(Box::>::default()) - }; - } - downcast_integer! { - &self.data_type => (helper), - _ => Err(DataFusionError::NotImplemented(format!( - "BitXor not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bit_xor"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - use std::ops::BitXorAssign; - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( - $dt, - |x, y| x.bitxor_assign(y), - ))) - }; - } - - let data_type = &self.data_type; - downcast_integer! 
{ - data_type => (helper, data_type), - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } -} - -impl PartialEq for BitXor { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -struct BitXorAccumulator { - value: Option, -} - -impl std::fmt::Debug for BitXorAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "BitXorAccumulator({})", T::DATA_TYPE) - } -} - -impl Default for BitXorAccumulator { - fn default() -> Self { - Self { value: None } - } -} - -impl Accumulator for BitXorAccumulator -where - T::Native: std::ops::BitXor, -{ - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(x) = bit_xor(values[0].as_primitive::()) { - let v = self.value.get_or_insert(T::Native::usize_as(0)); - *v = *v ^ x; - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// Expression for a BIT_XOR(DISTINCT) aggregation. -#[derive(Debug, Clone)] -pub struct DistinctBitXor { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl DistinctBitXor { - /// Create a new DistinctBitXor aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for DistinctBitXor { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty) => { - Ok(Box::>::default()) - }; - } - downcast_integer! { - &self.data_type => (helper), - _ => Err(DataFusionError::NotImplemented(format!( - "DistinctBitXorAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - // State field is a List which stores items to rebuild hash set. 
- Ok(vec![Field::new_list( - format_state_name(&self.name, "bit_xor distinct"), - Field::new("item", self.data_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctBitXor { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -struct DistinctBitXorAccumulator { - values: HashSet, -} - -impl std::fmt::Debug for DistinctBitXorAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE) - } -} - -impl Default for DistinctBitXorAccumulator { - fn default() -> Self { - Self { - values: HashSet::default(), - } - } -} - -impl Accumulator for DistinctBitXorAccumulator -where - T::Native: std::ops::BitXor + std::hash::Hash + Eq, -{ - fn state(&mut self) -> Result> { - // 1. Stores aggregate state in `ScalarValue::List` - // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set - let state_out = { - let values = self - .values - .iter() - .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) - .collect::>>()?; - - let arr = ScalarValue::new_list(&values, &T::DATA_TYPE); - vec![ScalarValue::List(arr)] - }; - Ok(state_out) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let array = values[0].as_primitive::(); - match array.nulls().filter(|x| x.null_count() > 0) { - Some(n) => { - for idx in n.valid_indices() { - self.values.insert(array.value(idx)); - } - } - None => array.values().iter().for_each(|x| { - self.values.insert(*x); - }), - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if let Some(state) = states.first() { - let list_arr = as_list_array(state)?; - for arr in list_arr.iter().flatten() { - self.update_batch(&[arr])?; - } - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let mut acc = T::Native::usize_as(0); - for distinct_value in self.values.iter() { - acc = acc ^ *distinct_value; - } - let v = (!self.values.is_empty()).then_some(acc); - ScalarValue::new_primitive::(v, &T::DATA_TYPE) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index a1f5f153a9ff..6c01decdbf95 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -66,26 +66,6 @@ pub fn create_aggregate_expr( name, data_type, )), - (AggregateFunction::BitAnd, _) => Arc::new(expressions::BitAnd::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BitOr, _) => Arc::new(expressions::BitOr::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BitXor, false) => Arc::new(expressions::BitXor::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BitXor, true) => Arc::new(expressions::DistinctBitXor::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), (AggregateFunction::BoolAnd, _) => Arc::new(expressions::BoolAnd::new( input_phy_exprs[0].clone(), name, @@ -202,12 +182,10 @@ mod tests { use datafusion_expr::{type_coercion, 
Signature}; use crate::expressions::{ - try_cast, ArrayAgg, Avg, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, - DistinctArrayAgg, Max, Min, + try_cast, ArrayAgg, Avg, BoolAnd, BoolOr, DistinctArrayAgg, Max, Min, }; use super::*; - #[test] fn test_approx_expr() -> Result<()> { let funcs = vec![AggregateFunction::ArrayAgg]; @@ -319,60 +297,6 @@ mod tests { Ok(()) } - #[test] - fn test_bit_and_or_xor_expr() -> Result<()> { - let funcs = vec![ - AggregateFunction::BitAnd, - AggregateFunction::BitOr, - AggregateFunction::BitXor, - ]; - let data_types = vec![DataType::UInt64, DataType::Int64]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::BitAnd => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::BitOr => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::BitXor => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } - #[test] fn test_bool_and_or_expr() -> Result<()> { let funcs = vec![AggregateFunction::BoolAnd, AggregateFunction::BoolOr]; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index c20902c11b86..0b1f5f577435 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -21,7 +21,6 @@ pub(crate) mod array_agg; pub(crate) mod array_agg_distinct; pub(crate) mod array_agg_ordered; pub(crate) mod average; -pub(crate) mod bit_and_or_xor; pub(crate) mod bool_and_or; pub(crate) mod correlation; pub(crate) mod covariance; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index b9a159b21e3d..bffaafd7dac2 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -40,7 +40,6 @@ pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; pub use crate::aggregate::average::Avg; pub use crate::aggregate::average::AvgAccumulator; -pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr, BitXor, DistinctBitXor}; pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr}; pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index e5578ae62f3e..ae4445eaa8ce 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -491,9 +491,9 @@ enum AggregateFunction { // APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; GROUPING = 17; // MEDIAN = 18; - BIT_AND = 19; - BIT_OR = 20; - BIT_XOR = 21; + // BIT_AND = 19; + // 
BIT_OR = 20; + // BIT_XOR = 21; BOOL_AND = 22; BOOL_OR = 23; // REGR_SLOPE = 26; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 4a7b9610e5bc..243c75435f8d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -538,9 +538,6 @@ impl serde::Serialize for AggregateFunction { Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", Self::Grouping => "GROUPING", - Self::BitAnd => "BIT_AND", - Self::BitOr => "BIT_OR", - Self::BitXor => "BIT_XOR", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", Self::StringAgg => "STRING_AGG", @@ -562,9 +559,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG", "CORRELATION", "GROUPING", - "BIT_AND", - "BIT_OR", - "BIT_XOR", "BOOL_AND", "BOOL_OR", "STRING_AGG", @@ -615,9 +609,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), "GROUPING" => Ok(AggregateFunction::Grouping), - "BIT_AND" => Ok(AggregateFunction::BitAnd), - "BIT_OR" => Ok(AggregateFunction::BitOr), - "BIT_XOR" => Ok(AggregateFunction::BitXor), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), "STRING_AGG" => Ok(AggregateFunction::StringAgg), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ffaef445d668..1172eccb90fd 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1945,9 +1945,9 @@ pub enum AggregateFunction { /// APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; Grouping = 17, /// MEDIAN = 18; - BitAnd = 19, - BitOr = 20, - BitXor = 21, + /// BIT_AND = 19; + /// BIT_OR = 20; + /// BIT_XOR = 21; BoolAnd = 22, BoolOr = 23, /// REGR_SLOPE = 26; @@ -1975,9 +1975,6 @@ impl AggregateFunction { AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::Grouping => "GROUPING", - AggregateFunction::BitAnd => "BIT_AND", - AggregateFunction::BitOr => "BIT_OR", - AggregateFunction::BitXor => "BIT_XOR", AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", AggregateFunction::StringAgg => "STRING_AGG", @@ -1993,9 +1990,6 @@ impl AggregateFunction { "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), "GROUPING" => Some(Self::Grouping), - "BIT_AND" => Some(Self::BitAnd), - "BIT_OR" => Some(Self::BitOr), - "BIT_XOR" => Some(Self::BitXor), "BOOL_AND" => Some(Self::BoolAnd), "BOOL_OR" => Some(Self::BoolOr), "STRING_AGG" => Some(Self::StringAgg), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 25b7413a984a..43cc352f98dd 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -140,9 +140,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Min => Self::Min, protobuf::AggregateFunction::Max => Self::Max, protobuf::AggregateFunction::Avg => Self::Avg, - protobuf::AggregateFunction::BitAnd => Self::BitAnd, - protobuf::AggregateFunction::BitOr => Self::BitOr, - protobuf::AggregateFunction::BitXor => Self::BitXor, protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, protobuf::AggregateFunction::BoolOr => Self::BoolOr, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 
d9548325dac3..33a58daeaf0a 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -111,9 +111,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Min => Self::Min, AggregateFunction::Max => Self::Max, AggregateFunction::Avg => Self::Avg, - AggregateFunction::BitAnd => Self::BitAnd, - AggregateFunction::BitOr => Self::BitOr, - AggregateFunction::BitXor => Self::BitXor, AggregateFunction::BoolAnd => Self::BoolAnd, AggregateFunction::BoolOr => Self::BoolOr, AggregateFunction::ArrayAgg => Self::ArrayAgg, @@ -380,9 +377,6 @@ pub fn serialize_expr( AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, AggregateFunction::Avg => protobuf::AggregateFunction::Avg, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 3a4c35a93e16..886179bf5627 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,11 +23,11 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, - CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, - NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, RowNumber, - StringAgg, TryCastExpr, WindowShift, + ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, + CumeDist, DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, + Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, + OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr, + WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -241,15 +241,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { let inner = if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Grouping - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BitAnd - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BitOr - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BitXor - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::BitXor } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::BoolAnd } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index a496e226855a..52696a106183 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ 
b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -59,6 +59,7 @@ use datafusion_expr::{ TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; +use datafusion_functions_aggregate::expr_fn::{bit_and, bit_or, bit_xor}; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, @@ -665,6 +666,9 @@ async fn roundtrip_expr_api() -> Result<()> { approx_median(lit(2)), approx_percentile_cont(lit(2), lit(0.5)), approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + bit_and(lit(2)), + bit_or(lit(2)), + bit_xor(lit(2)), ]; // ensure expressions created with the expr api can be round tripped From 9b1bb68e37688330fa47c73bdd733b206be8759e Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Mon, 17 Jun 2024 08:52:26 -0700 Subject: [PATCH 32/34] refactor: improve PoolType argument handling for CLI (#10940) * refactor: dont include fallback in match on mem_pool_type * refactor: improve PoolType argument handling --- datafusion-cli/src/lib.rs | 1 + datafusion-cli/src/main.rs | 29 +++++--------------- datafusion-cli/src/pool_type.rs | 48 +++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 23 deletions(-) create mode 100644 datafusion-cli/src/pool_type.rs diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index 139a60b8cf16..5081436aa6c5 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -25,5 +25,6 @@ pub mod functions; pub mod helper; pub mod highlighter; pub mod object_storage; +pub mod pool_type; pub mod print_format; pub mod print_options; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index f2b29fe78690..f469fda4f960 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -19,7 +19,6 @@ use std::collections::HashMap; use std::env; use std::path::Path; use std::process::ExitCode; -use std::str::FromStr; use std::sync::{Arc, OnceLock}; use datafusion::error::{DataFusionError, Result}; @@ -31,6 +30,7 @@ use datafusion_cli::catalog::DynamicFileCatalog; use datafusion_cli::functions::ParquetMetadataFunc; use datafusion_cli::{ exec, + pool_type::PoolType, print_format::PrintFormat, print_options::{MaxRows, PrintOptions}, DATAFUSION_CLI_VERSION, @@ -42,24 +42,6 @@ use mimalloc::MiMalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; -#[derive(PartialEq, Debug)] -enum PoolType { - Greedy, - Fair, -} - -impl FromStr for PoolType { - type Err = String; - - fn from_str(s: &str) -> Result { - match s { - "Greedy" | "greedy" => Ok(PoolType::Greedy), - "Fair" | "fair" => Ok(PoolType::Fair), - _ => Err(format!("Invalid memory pool type '{}'", s)), - } - } -} - #[derive(Debug, Parser, PartialEq)] #[clap(author, version, about, long_about= None)] struct Args { @@ -127,9 +109,10 @@ struct Args { #[clap( long, - help = "Specify the memory pool type 'greedy' or 'fair', default to 'greedy'" + help = "Specify the memory pool type 'greedy' or 'fair'", + default_value_t = PoolType::Greedy )] - mem_pool_type: Option, + mem_pool_type: PoolType, #[clap( long, @@ -181,9 +164,9 @@ async fn main_inner() -> Result<()> { let memory_limit = extract_memory_pool_size(&memory_limit).unwrap(); // set memory pool type match args.mem_pool_type { - Some(PoolType::Fair) => rt_config + PoolType::Fair => rt_config .with_memory_pool(Arc::new(FairSpillPool::new(memory_limit))), - _ => rt_config + PoolType::Greedy => rt_config 
.with_memory_pool(Arc::new(GreedyMemoryPool::new(memory_limit))) } } else { diff --git a/datafusion-cli/src/pool_type.rs b/datafusion-cli/src/pool_type.rs new file mode 100644 index 000000000000..25763eba5c8c --- /dev/null +++ b/datafusion-cli/src/pool_type.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + fmt::{self, Display, Formatter}, + str::FromStr, +}; + +#[derive(PartialEq, Debug)] +pub enum PoolType { + Greedy, + Fair, +} + +impl FromStr for PoolType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "Greedy" | "greedy" => Ok(PoolType::Greedy), + "Fair" | "fair" => Ok(PoolType::Fair), + _ => Err(format!("Invalid memory pool type '{}'", s)), + } + } +} + +impl Display for PoolType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + PoolType::Greedy => write!(f, "greedy"), + PoolType::Fair => write!(f, "fair"), + } + } +} From 861a2364bdf04854482384c29d9b64962da377fe Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Mon, 17 Jun 2024 10:16:20 -0700 Subject: [PATCH 33/34] feat: add CliSessionContext trait for cli (#10890) use CliSessionContext trait for cli --- .../examples/cli-session-context.rs | 97 ++++++++++++++++++ datafusion-cli/src/cli_context.rs | 98 +++++++++++++++++++ datafusion-cli/src/command.rs | 4 +- datafusion-cli/src/exec.rs | 28 +++--- datafusion-cli/src/lib.rs | 1 + datafusion-cli/src/object_storage.rs | 55 ++--------- 6 files changed, 220 insertions(+), 63 deletions(-) create mode 100644 datafusion-cli/examples/cli-session-context.rs create mode 100644 datafusion-cli/src/cli_context.rs diff --git a/datafusion-cli/examples/cli-session-context.rs b/datafusion-cli/examples/cli-session-context.rs new file mode 100644 index 000000000000..8da52ed84a5f --- /dev/null +++ b/datafusion-cli/examples/cli-session-context.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shows an example of a custom session context that unions the input plan with itself. +//! 
To run this example, use `cargo run --example cli-session-context` from within the `datafusion-cli` directory. + +use std::sync::Arc; + +use datafusion::{ + dataframe::DataFrame, + error::DataFusionError, + execution::{context::SessionState, TaskContext}, + logical_expr::{LogicalPlan, LogicalPlanBuilder}, + prelude::SessionContext, +}; +use datafusion_cli::{ + cli_context::CliSessionContext, exec::exec_from_repl, print_options::PrintOptions, +}; +use object_store::ObjectStore; + +/// This is a toy example of a custom session context that unions the input plan with itself. +struct MyUnionerContext { + ctx: SessionContext, +} + +impl Default for MyUnionerContext { + fn default() -> Self { + Self { + ctx: SessionContext::new(), + } + } +} + +#[async_trait::async_trait] +impl CliSessionContext for MyUnionerContext { + fn task_ctx(&self) -> Arc { + self.ctx.task_ctx() + } + + fn session_state(&self) -> SessionState { + self.ctx.state() + } + + fn register_object_store( + &self, + url: &url::Url, + object_store: Arc, + ) -> Option> { + self.ctx.register_object_store(url, object_store) + } + + fn register_table_options_extension_from_scheme(&self, _scheme: &str) { + unimplemented!() + } + + async fn execute_logical_plan( + &self, + plan: LogicalPlan, + ) -> Result { + let new_plan = LogicalPlanBuilder::from(plan.clone()) + .union(plan.clone())? + .build()?; + + self.ctx.execute_logical_plan(new_plan).await + } +} + +#[tokio::main] +/// Runs the example. +pub async fn main() { + let mut my_ctx = MyUnionerContext::default(); + + let mut print_options = PrintOptions { + format: datafusion_cli::print_format::PrintFormat::Automatic, + quiet: false, + maxrows: datafusion_cli::print_options::MaxRows::Unlimited, + color: true, + }; + + exec_from_repl(&mut my_ctx, &mut print_options) + .await + .unwrap(); +} diff --git a/datafusion-cli/src/cli_context.rs b/datafusion-cli/src/cli_context.rs new file mode 100644 index 000000000000..516929ebacf1 --- /dev/null +++ b/datafusion-cli/src/cli_context.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion::{ + dataframe::DataFrame, + error::DataFusionError, + execution::{context::SessionState, TaskContext}, + logical_expr::LogicalPlan, + prelude::SessionContext, +}; +use object_store::ObjectStore; + +use crate::object_storage::{AwsOptions, GcpOptions}; + +#[async_trait::async_trait] +/// The CLI session context trait provides a way to have a session context that can be used with datafusion's CLI code. +pub trait CliSessionContext { + /// Get an atomic reference counted task context. + fn task_ctx(&self) -> Arc; + + /// Get the session state. + fn session_state(&self) -> SessionState; + + /// Register an object store with the session context. 
+ fn register_object_store( + &self, + url: &url::Url, + object_store: Arc, + ) -> Option>; + + /// Register table options extension from scheme. + fn register_table_options_extension_from_scheme(&self, scheme: &str); + + /// Execute a logical plan and return a DataFrame. + async fn execute_logical_plan( + &self, + plan: LogicalPlan, + ) -> Result; +} + +#[async_trait::async_trait] +impl CliSessionContext for SessionContext { + fn task_ctx(&self) -> Arc { + self.task_ctx() + } + + fn session_state(&self) -> SessionState { + self.state() + } + + fn register_object_store( + &self, + url: &url::Url, + object_store: Arc, + ) -> Option> { + self.register_object_store(url, object_store) + } + + fn register_table_options_extension_from_scheme(&self, scheme: &str) { + match scheme { + // For Amazon S3 or Alibaba Cloud OSS + "s3" | "oss" | "cos" => { + // Register AWS specific table options in the session context: + self.register_table_options_extension(AwsOptions::default()) + } + // For Google Cloud Storage + "gs" | "gcs" => { + // Register GCP specific table options in the session context: + self.register_table_options_extension(GcpOptions::default()) + } + // For unsupported schemes, do nothing: + _ => {} + } + } + + async fn execute_logical_plan( + &self, + plan: LogicalPlan, + ) -> Result { + self.execute_logical_plan(plan).await + } +} diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index be6393351aed..1a6c023d3b50 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -17,6 +17,7 @@ //! Command within CLI +use crate::cli_context::CliSessionContext; use crate::exec::{exec_and_print, exec_from_lines}; use crate::functions::{display_all_functions, Function}; use crate::print_format::PrintFormat; @@ -28,7 +29,6 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::exec_err; use datafusion::common::instant::Instant; use datafusion::error::{DataFusionError, Result}; -use datafusion::prelude::SessionContext; use std::fs::File; use std::io::BufReader; use std::str::FromStr; @@ -55,7 +55,7 @@ pub enum OutputFormat { impl Command { pub async fn execute( &self, - ctx: &mut SessionContext, + ctx: &mut dyn CliSessionContext, print_options: &mut PrintOptions, ) -> Result<()> { match self { diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 855d6a7cbbc9..c4c92be1525d 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -23,12 +23,13 @@ use std::io::prelude::*; use std::io::BufReader; use std::str::FromStr; +use crate::cli_context::CliSessionContext; use crate::helper::split_from_semicolon; use crate::print_format::PrintFormat; use crate::{ command::{Command, OutputFormat}, helper::{unescape_input, CliHelper}, - object_storage::{get_object_store, register_options}, + object_storage::get_object_store, print_options::{MaxRows, PrintOptions}, }; @@ -38,7 +39,6 @@ use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties}; -use datafusion::prelude::SessionContext; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; @@ -50,7 +50,7 @@ use tokio::signal; /// run and execute SQL statements and commands, against a context with the given print options pub async fn exec_from_commands( - ctx: &mut SessionContext, + ctx: &mut dyn CliSessionContext, 
commands: Vec, print_options: &PrintOptions, ) -> Result<()> { @@ -63,7 +63,7 @@ pub async fn exec_from_commands( /// run and execute SQL statements and commands from a file, against a context with the given print options pub async fn exec_from_lines( - ctx: &mut SessionContext, + ctx: &mut dyn CliSessionContext, reader: &mut BufReader, print_options: &PrintOptions, ) -> Result<()> { @@ -103,7 +103,7 @@ pub async fn exec_from_lines( } pub async fn exec_from_files( - ctx: &mut SessionContext, + ctx: &mut dyn CliSessionContext, files: Vec, print_options: &PrintOptions, ) -> Result<()> { @@ -122,7 +122,7 @@ pub async fn exec_from_files( /// run and execute SQL statements and commands against a context with the given print options pub async fn exec_from_repl( - ctx: &mut SessionContext, + ctx: &mut dyn CliSessionContext, print_options: &mut PrintOptions, ) -> rustyline::Result<()> { let mut rl = Editor::new()?; @@ -205,7 +205,7 @@ pub async fn exec_from_repl( } pub(super) async fn exec_and_print( - ctx: &mut SessionContext, + ctx: &mut dyn CliSessionContext, print_options: &PrintOptions, sql: String, ) -> Result<()> { @@ -292,10 +292,10 @@ impl AdjustedPrintOptions { } async fn create_plan( - ctx: &mut SessionContext, + ctx: &mut dyn CliSessionContext, statement: Statement, ) -> Result { - let mut plan = ctx.state().statement_to_plan(statement).await?; + let mut plan = ctx.session_state().statement_to_plan(statement).await?; // Note that cmd is a mutable reference so that create_external_table function can remove all // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion @@ -354,7 +354,7 @@ async fn create_plan( /// alteration fails, or if the object store cannot be retrieved and registered /// successfully. pub(crate) async fn register_object_store_and_config_extensions( - ctx: &SessionContext, + ctx: &dyn CliSessionContext, location: &String, options: &HashMap, format: Option, @@ -369,17 +369,18 @@ pub(crate) async fn register_object_store_and_config_extensions( let url = table_path.as_ref(); // Register the options based on the scheme extracted from the location - register_options(ctx, scheme); + ctx.register_table_options_extension_from_scheme(scheme); // Clone and modify the default table options based on the provided options - let mut table_options = ctx.state().default_table_options().clone(); + let mut table_options = ctx.session_state().default_table_options().clone(); if let Some(format) = format { table_options.set_file_format(format); } table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options - let store = get_object_store(&ctx.state(), scheme, url, &table_options).await?; + let store = + get_object_store(&ctx.session_state(), scheme, url, &table_options).await?; // Register the retrieved object store in the session context's runtime environment ctx.register_object_store(url, store); @@ -394,6 +395,7 @@ mod tests { use datafusion::common::config::FormatOptions; use datafusion::common::plan_err; + use datafusion::prelude::SessionContext; use url::Url; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index 5081436aa6c5..fbfc9242a61d 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -19,6 +19,7 @@ pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION"); pub mod catalog; +pub mod cli_context; pub mod command; pub mod exec; pub 
mod functions; diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 85e0009bd267..87eb04d113de 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -25,7 +25,6 @@ use datafusion::common::config::{ use datafusion::common::{config_err, exec_datafusion_err, exec_err}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; -use datafusion::prelude::SessionContext; use async_trait::async_trait; use aws_credential_types::provider::ProvideCredentials; @@ -392,48 +391,6 @@ impl ConfigExtension for GcpOptions { const PREFIX: &'static str = "gcp"; } -/// Registers storage options for different cloud storage schemes in a given -/// session context. -/// -/// This function is responsible for extending the session context with specific -/// options based on the storage scheme being used. These options are essential -/// for handling interactions with different cloud storage services such as Amazon -/// S3, Alibaba Cloud OSS, Google Cloud Storage, etc. -/// -/// # Parameters -/// -/// * `ctx` - A mutable reference to the session context where table options are -/// to be registered. The session context holds configuration and environment -/// for the current session. -/// * `scheme` - A string slice that represents the cloud storage scheme. This -/// determines which set of options will be registered in the session context. -/// -/// # Supported Schemes -/// -/// * `s3` or `oss` - Registers `AwsOptions` which are configurations specific to -/// Amazon S3 and Alibaba Cloud OSS. -/// * `gs` or `gcs` - Registers `GcpOptions` which are configurations specific to -/// Google Cloud Storage. -/// -/// NOTE: This function will not perform any action when given an unsupported scheme. 
-pub(crate) fn register_options(ctx: &SessionContext, scheme: &str) { - // Match the provided scheme against supported cloud storage schemes: - match scheme { - // For Amazon S3 or Alibaba Cloud OSS - "s3" | "oss" | "cos" => { - // Register AWS specific table options in the session context: - ctx.register_table_options_extension(AwsOptions::default()) - } - // For Google Cloud Storage - "gs" | "gcs" => { - // Register GCP specific table options in the session context: - ctx.register_table_options_extension(GcpOptions::default()) - } - // For unsupported schemes, do nothing: - _ => {} - } -} - pub(crate) async fn get_object_store( state: &SessionState, scheme: &str, @@ -498,6 +455,8 @@ pub(crate) async fn get_object_store( #[cfg(test)] mod tests { + use crate::cli_context::CliSessionContext; + use super::*; use datafusion::common::plan_err; @@ -534,7 +493,7 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); + ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options().clone(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); @@ -579,7 +538,7 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); + ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options().clone(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); @@ -605,7 +564,7 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); + ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options().clone(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); @@ -633,7 +592,7 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); + ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options().clone(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); @@ -670,7 +629,7 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); + ctx.register_table_options_extension_from_scheme(scheme); let mut table_options = ctx.state().default_table_options().clone(); table_options.alter_with_string_hash_map(&cmd.options)?; let gcp_options = table_options.extensions.get::().unwrap(); From e1cfb48215ee91a183e06cfee602e42d2c23f429 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 17 Jun 2024 13:18:04 -0400 Subject: [PATCH 34/34] Minor: remove string copy from Column::from_qualified_name (#10947) --- datafusion/common/src/column.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 911ff079def1..3e2bc0ad7c3a 100644 --- 
a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -109,21 +109,21 @@ impl Column { /// `foo.BAR` would be parsed to a reference to relation `foo`, column name `bar` (lower case) /// where `"foo.BAR"` would be parsed to a reference to column named `foo.BAR` pub fn from_qualified_name(flat_name: impl Into) -> Self { - let flat_name: &str = &flat_name.into(); - Self::from_idents(&mut parse_identifiers_normalized(flat_name, false)) + let flat_name = flat_name.into(); + Self::from_idents(&mut parse_identifiers_normalized(&flat_name, false)) .unwrap_or_else(|| Self { relation: None, - name: flat_name.to_owned(), + name: flat_name, }) } /// Deserialize a fully qualified name string into a column preserving column text case pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { - let flat_name: &str = &flat_name.into(); - Self::from_idents(&mut parse_identifiers_normalized(flat_name, true)) + let flat_name = flat_name.into(); + Self::from_idents(&mut parse_identifiers_normalized(&flat_name, true)) .unwrap_or_else(|| Self { relation: None, - name: flat_name.to_owned(), + name: flat_name, }) }
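
The net effect of the BIT_AND/BIT_OR/BIT_XOR removal in the patch above is that the bitwise aggregates are now reached through the expression API of the datafusion-functions-aggregate crate, exactly as the updated roundtrip test does, rather than through the protobuf AggregateFunction enum. A minimal sketch of that call path follows; the crate and import names are taken from the test itself, while the surrounding main function and Cargo dependencies (datafusion-expr, datafusion-functions-aggregate at this revision) are assumptions for illustration:

    use datafusion_expr::lit;
    use datafusion_functions_aggregate::expr_fn::{bit_and, bit_or, bit_xor};

    fn main() {
        // Each helper builds an `Expr` backed by the UDAF implementation,
        // so no built-in enum variant (and no protobuf tag) is involved.
        let exprs = vec![bit_and(lit(2)), bit_or(lit(2)), bit_xor(lit(2))];
        for expr in &exprs {
            println!("{expr}");
        }
    }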
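
The PoolType refactor works because clap's `default_value_t` renders the default through `Display` (for `--help` output) while argument parsing goes through `FromStr`; the new pool_type.rs supplies both, which is what lets the `Option<PoolType>` wrapper and the `_ =>` fallback arm in main.rs disappear. A small round-trip sketch of the two impls, assuming datafusion-cli is pulled in as a library dependency (its lib.rs exports `pool_type` publicly):

    use std::str::FromStr;

    use datafusion_cli::pool_type::PoolType;

    fn main() {
        // `FromStr` accepts either capitalization; `Display` always
        // prints lower case.
        let pool = PoolType::from_str("Fair").unwrap();
        assert_eq!(pool, PoolType::Fair);
        assert_eq!(pool.to_string(), "fair");

        // Unknown names fail with a readable error instead of silently
        // falling back to the greedy pool.
        assert!(PoolType::from_str("bounded").is_err());
    }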
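
Because `SessionContext` itself implements the new `CliSessionContext` trait, every existing caller keeps working while custom contexts (like the unioning example in the patch) can be swapped in behind `&dyn CliSessionContext`. The sketch below drives the trait directly; the trait methods and `create_logical_plan` signature come from the diff above, while the tokio runtime, dependency versions, and the `plan_and_run` helper name are illustrative assumptions:

    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;
    use datafusion_cli::cli_context::CliSessionContext;

    // CLI-style entry points can be written against the trait object...
    async fn plan_and_run(ctx: &dyn CliSessionContext, sql: &str) -> Result<()> {
        // `session_state` and `execute_logical_plan` delegate to the
        // inherent methods in the `impl CliSessionContext for
        // SessionContext` shown in the patch.
        let plan = ctx.session_state().create_logical_plan(sql).await?;
        let df = ctx.execute_logical_plan(plan).await?;
        df.show().await
    }

    #[tokio::main]
    async fn main() -> Result<()> {
        // ...while the stock `SessionContext` still satisfies it unchanged.
        let ctx = SessionContext::new();
        plan_and_run(&ctx, "SELECT 1 AS x").await
    }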
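
The `Column::from_qualified_name` change is purely an allocation saving: `flat_name` is now moved into the fallback branch instead of being converted to `&str` and re-copied with `to_owned()`, so observable behavior is unchanged. A small check of that behavior, taken from the doc comments in the diff (the normalized and case-preserving variants differ only in how they treat unquoted upper-case identifiers):

    use datafusion_common::Column;

    fn main() {
        // Unquoted identifiers are normalized to lower case...
        let c = Column::from_qualified_name("foo.BAR");
        assert_eq!(c.name, "bar");

        // ...while the ignore-case variant preserves the original text.
        let c = Column::from_qualified_name_ignore_case("foo.BAR");
        assert_eq!(c.name, "BAR");
    }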