diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index fe56fc22ea8c..2d4203464300 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -267,6 +267,26 @@ async fn test_fn_initcap() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_instr() -> Result<()> { + let expr = instr(col("a"), lit("b")); + + let expected = [ + "+-------------------------+", + "| instr(test.a,Utf8(\"b\")) |", + "+-------------------------+", + "| 2 |", + "| 2 |", + "| 0 |", + "| 5 |", + "+-------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_left() -> Result<()> { @@ -634,6 +654,26 @@ async fn test_fn_starts_with() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_ends_with() -> Result<()> { + let expr = ends_with(col("a"), lit("DEF")); + + let expected = [ + "+-------------------------------+", + "| ends_with(test.a,Utf8(\"DEF\")) |", + "+-------------------------------+", + "| true |", + "| false |", + "| false |", + "| false |", + "+-------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_strpos() -> Result<()> { diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index fc1b66efed87..c8cac5dd9b7a 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -25,138 +25,164 @@ use itertools::Itertools; use crate::parquet::Unit::RowGroup; use crate::parquet::{ContextWithParquet, Scenario}; use datafusion_expr::{col, lit}; - -async fn test_row_group_prune( - case_data_type: Scenario, - sql: &str, +struct RowGroupPruningTest { + scenario: Scenario, + query: String, expected_errors: Option<usize>, expected_row_group_pruned_by_statistics: Option<usize>, expected_row_group_pruned_by_bloom_filter: Option<usize>, expected_results: usize, -) { - let output = ContextWithParquet::new(case_data_type, RowGroup) - .await - .query(sql) - .await; - - println!("{}", output.description()); - assert_eq!(output.predicate_evaluation_errors(), expected_errors); - assert_eq!( - output.row_groups_pruned_statistics(), - expected_row_group_pruned_by_statistics - ); - assert_eq!( - output.row_groups_pruned_bloom_filter(), - expected_row_group_pruned_by_bloom_filter - ); - assert_eq!( - output.result_rows, - expected_results, - "{}", - output.description() - ); } - -/// check row group pruning by bloom filter and statistics independently -async fn test_prune_verbose( - case_data_type: Scenario, - sql: &str, - expected_errors: Option<usize>, - expected_row_group_pruned_sbbf: Option<usize>, - expected_row_group_pruned_statistics: Option<usize>, - expected_results: usize, -) { - let output = ContextWithParquet::new(case_data_type, RowGroup) - .await - .query(sql) - .await; - - println!("{}", output.description()); - assert_eq!(output.predicate_evaluation_errors(), expected_errors); - assert_eq!( - output.row_groups_pruned_bloom_filter(), - expected_row_group_pruned_sbbf - ); - assert_eq!( - output.row_groups_pruned_statistics(), - expected_row_group_pruned_statistics - ); - assert_eq!( - output.result_rows, - expected_results, - "{}", - output.description() - ); +impl RowGroupPruningTest { + // Start building the test configuration + fn new() -> Self { + Self { + scenario: 
Scenario::Timestamps, // or another default + query: String::new(), + expected_errors: None, + expected_row_group_pruned_by_statistics: None, + expected_row_group_pruned_by_bloom_filter: None, + expected_results: 0, + } + } + + // Set the scenario for the test + fn with_scenario(mut self, scenario: Scenario) -> Self { + self.scenario = scenario; + self + } + + // Set the SQL query for the test + fn with_query(mut self, query: &str) -> Self { + self.query = query.to_string(); + self + } + + // Set the expected errors for the test + fn with_expected_errors(mut self, errors: Option<usize>) -> Self { + self.expected_errors = errors; + self + } + + // Set the expected pruned row groups by statistics + fn with_pruned_by_stats(mut self, pruned_by_stats: Option<usize>) -> Self { + self.expected_row_group_pruned_by_statistics = pruned_by_stats; + self + } + + // Set the expected pruned row groups by bloom filter + fn with_pruned_by_bloom_filter(mut self, pruned_by_bf: Option<usize>) -> Self { + self.expected_row_group_pruned_by_bloom_filter = pruned_by_bf; + self + } + + // Set the expected rows for the test + fn with_expected_rows(mut self, rows: usize) -> Self { + self.expected_results = rows; + self + } + + // Execute the test with the current configuration + async fn test_row_group_prune(self) { + let output = ContextWithParquet::new(self.scenario, RowGroup) + .await + .query(&self.query) + .await; + + println!("{}", output.description()); + assert_eq!(output.predicate_evaluation_errors(), self.expected_errors); + assert_eq!( + output.row_groups_pruned_statistics(), + self.expected_row_group_pruned_by_statistics + ); + assert_eq!( + output.row_groups_pruned_bloom_filter(), + self.expected_row_group_pruned_by_bloom_filter + ); + assert_eq!( + output.result_rows, + self.expected_results, + "{}", + output.description() + ); + } } #[tokio::test] async fn prune_timestamps_nanos() { - test_row_group_prune( - Scenario::Timestamps, - "SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')", - Some(0), - Some(1), - Some(0), - 10, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Timestamps) + .with_query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(10) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_timestamps_micros() { - test_row_group_prune( - Scenario::Timestamps, - "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')", - Some(0), - Some(1), - Some(0), - 10, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Timestamps) + .with_query( + "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')", + ) + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(10) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_timestamps_millis() { - test_row_group_prune( - Scenario::Timestamps, - "SELECT * FROM t where millis < to_timestamp_millis('2020-01-02 01:01:11Z')", - Some(0), - Some(1), - Some(0), - 10, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Timestamps) + .with_query( + "SELECT * FROM t where millis < to_timestamp_millis('2020-01-02 01:01:11Z')", + ) + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(10) + .test_row_group_prune() + .await; } #[tokio::test] async fn 
prune_timestamps_seconds() { - test_row_group_prune( - Scenario::Timestamps, - "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')", - Some(0), - Some(1), - Some(0), - 10, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Timestamps) + .with_query( + "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')", + ) + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(10) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_date32() { - test_row_group_prune( - Scenario::Dates, - "SELECT * FROM t where date32 < cast('2020-01-02' as date)", - Some(0), - Some(3), - Some(0), - 1, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Dates) + .with_query("SELECT * FROM t where date32 < cast('2020-01-02' as date)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_date64() { // work around for not being able to cast Date32 to Date64 automatically + let date = "2020-01-02" .parse::<chrono::NaiveDate>() .unwrap() @@ -181,15 +207,15 @@ async fn prune_date64() { #[tokio::test] async fn prune_disabled() { - test_row_group_prune( - Scenario::Timestamps, - "SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')", - Some(0), - Some(1), - Some(0), - 10, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Timestamps) + .with_query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(10) + .test_row_group_prune() + .await; // test without pruning let query = "SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')"; @@ -215,232 +241,233 @@ #[tokio::test] async fn prune_int32_lt() { - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where i < 1", - Some(0), - Some(1), - Some(0), - 11, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where i < 1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + // result of sql "SELECT * FROM t where i < 1" is same as // "SELECT * FROM t where -i > -1" - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where -i > -1", - Some(0), - Some(1), - Some(0), - 11, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where -i > -1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_eq() { - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where i = 1", - Some(0), - Some(3), - Some(0), - 1, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where i = 1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_scalar_fun_and_eq() { - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where abs(i) = 1 and i = 1", - Some(0), - Some(3), - Some(0), - 1, - ) - 
.await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where abs(i) = 1 and i = 1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_scalar_fun() { - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where abs(i) = 1", - Some(0), - Some(0), - Some(0), - 3, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where abs(i) = 1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_complex_expr() { - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where i+1 = 1", - Some(0), - Some(0), - Some(0), - 2, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where i+1 = 1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_complex_expr_subtract() { - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where 1-i > 1", - Some(0), - Some(0), - Some(0), - 9, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where 1-i > 1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(9) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_f64_lt() { - test_row_group_prune( - Scenario::Float64, - "SELECT * FROM t where f < 1", - Some(0), - Some(1), - Some(0), - 11, - ) - .await; - test_row_group_prune( - Scenario::Float64, - "SELECT * FROM t where -f > -1", - Some(0), - Some(1), - Some(0), - 11, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Float64) + .with_query("SELECT * FROM t where f < 1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Float64) + .with_query("SELECT * FROM t where -f > -1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_f64_scalar_fun_and_gt() { // result of sql "SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1" // only use "f >= 0" to prune - test_row_group_prune( - Scenario::Float64, - "SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1", - Some(0), - Some(2), - Some(0), - 1, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Float64) + .with_query("SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(2)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_f64_scalar_fun() { // result of sql "SELECT * FROM t where abs(f-1) <= 0.000001" is not supported - test_row_group_prune( - Scenario::Float64, - "SELECT * FROM t where abs(f-1) <= 0.000001", - Some(0), - Some(0), - Some(0), - 1, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Float64) + .with_query("SELECT * FROM t 
where abs(f-1) <= 0.000001") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_f64_complex_expr() { // result of sql "SELECT * FROM t where f+1 > 1.1"" is not supported - test_row_group_prune( - Scenario::Float64, - "SELECT * FROM t where f+1 > 1.1", - Some(0), - Some(0), - Some(0), - 9, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Float64) + .with_query("SELECT * FROM t where f+1 > 1.1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(9) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_f64_complex_expr_subtract() { // result of sql "SELECT * FROM t where 1-f > 1" is not supported - test_row_group_prune( - Scenario::Float64, - "SELECT * FROM t where 1-f > 1", - Some(0), - Some(0), - Some(0), - 9, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Float64) + .with_query("SELECT * FROM t where 1-f > 1") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(9) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_eq_in_list() { // result of sql "SELECT * FROM t where in (1)" - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where i in (1)", - Some(0), - Some(3), - Some(0), - 1, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where i in (1)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(1) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_eq_in_list_2() { // result of sql "SELECT * FROM t where in (1000)", prune all // test whether statistics works - test_prune_verbose( - Scenario::Int32, - "SELECT * FROM t where i in (1000)", - Some(0), - Some(0), - Some(4), - 0, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where i in (1000)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(0) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_eq_large_in_list() { // result of sql "SELECT * FROM t where i in (2050...2582)", prune all - // test whether sbbf works - test_prune_verbose( - Scenario::Int32Range, - format!( - "SELECT * FROM t where i in ({})", - (200050..200082).join(",") + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32Range) + .with_query( + format!( + "SELECT * FROM t where i in ({})", + (200050..200082).join(",") + ) + .as_str(), ) - .as_str(), - Some(0), - Some(1), - // we don't support pruning by statistics for in_list with more than 20 elements currently - Some(0), - 0, - ) - .await; + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; } #[tokio::test] async fn prune_int32_eq_in_list_negated() { // result of sql "SELECT * FROM t where not in (1)" prune nothing - test_row_group_prune( - Scenario::Int32, - "SELECT * FROM t where i not in (1)", - Some(0), - Some(0), - Some(0), - 19, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Int32) + .with_query("SELECT * FROM t where i not in (1)") + .with_expected_errors(Some(0)) + 
.with_pruned_by_stats(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(19) + .test_row_group_prune() + .await; } #[tokio::test] @@ -448,46 +475,42 @@ async fn prune_decimal_lt() { // The data type of decimal_col is decimal(9,2) // There are three row groups: // [1.00, 6.00], [-5.00,6.00], [20.00,60.00] - test_row_group_prune( - Scenario::Decimal, - "SELECT * FROM t where decimal_col < 4", - Some(0), - Some(1), - Some(0), - 6, - ) - .await; - // compare with the casted decimal value - test_row_group_prune( - Scenario::Decimal, - "SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))", - Some(0), - Some(1), - Some(0), - 8, - ) - .await; - - // The data type of decimal_col is decimal(38,2) - test_row_group_prune( - Scenario::DecimalLargePrecision, - "SELECT * FROM t where decimal_col < 4", - Some(0), - Some(1), - Some(0), - 6, - ) - .await; - // compare with the casted decimal value - test_row_group_prune( - Scenario::DecimalLargePrecision, - "SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))", - Some(0), - Some(1), - Some(0), - 8, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Decimal) + .with_query("SELECT * FROM t where decimal_col < 4") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(6) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Decimal) + .with_query("SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(8) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalLargePrecision) + .with_query("SELECT * FROM t where decimal_col < 4") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(6) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalLargePrecision) + .with_query("SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(8) + .test_row_group_prune() + .await; } #[tokio::test] @@ -495,44 +518,44 @@ async fn prune_decimal_eq() { // The data type of decimal_col is decimal(9,2) // There are three row groups: // [1.00, 6.00], [-5.00,6.00], [20.00,60.00] - test_row_group_prune( - Scenario::Decimal, - "SELECT * FROM t where decimal_col = 4", - Some(0), - Some(1), - Some(0), - 2, - ) - .await; - test_row_group_prune( - Scenario::Decimal, - "SELECT * FROM t where decimal_col = 4.00", - Some(0), - Some(1), - Some(0), - 2, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Decimal) + .with_query("SELECT * FROM t where decimal_col = 4") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Decimal) + .with_query("SELECT * FROM t where decimal_col = 4.00") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalLargePrecision) + .with_query("SELECT * FROM t where decimal_col = 4") + .with_expected_errors(Some(0)) + 
.with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalLargePrecision) + .with_query("SELECT * FROM t where decimal_col = 4.00") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; - // The data type of decimal_col is decimal(38,2) - test_row_group_prune( - Scenario::DecimalLargePrecision, - "SELECT * FROM t where decimal_col = 4", - Some(0), - Some(1), - Some(0), - 2, - ) - .await; - test_row_group_prune( - Scenario::DecimalLargePrecision, - "SELECT * FROM t where decimal_col = 4.00", - Some(0), - Some(1), - Some(0), - 2, - ) - .await; } #[tokio::test] @@ -540,44 +563,42 @@ async fn prune_decimal_in_list() { // The data type of decimal_col is decimal(9,2) // There are three row groups: // [1.00, 6.00], [-5.00,6.00], [20.00,60.00] - test_row_group_prune( - Scenario::Decimal, - "SELECT * FROM t where decimal_col in (4,3,2,123456789123)", - Some(0), - Some(1), - Some(0), - 5, - ) - .await; - test_row_group_prune( - Scenario::Decimal, - "SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)", - Some(0), - Some(1), - Some(0), - 6, - ) - .await; - - // The data type of decimal_col is decimal(38,2) - test_row_group_prune( - Scenario::DecimalLargePrecision, - "SELECT * FROM t where decimal_col in (4,3,2,123456789123)", - Some(0), - Some(1), - Some(0), - 5, - ) - .await; - test_row_group_prune( - Scenario::DecimalLargePrecision, - "SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)", - Some(0), - Some(1), - Some(0), - 6, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Decimal) + .with_query("SELECT * FROM t where decimal_col in (4,3,2,123456789123)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(5) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::Decimal) + .with_query("SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(6) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalLargePrecision) + .with_query("SELECT * FROM t where decimal_col in (4,3,2,123456789123)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(5) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::DecimalLargePrecision) + .with_query("SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(6) + .test_row_group_prune() + .await; } #[tokio::test] @@ -586,32 +607,31 @@ async fn prune_periods_in_column_names() { // name = "HTTP GET / DISPATCH", service.name = ['frontend', 'frontend'], // name = "HTTP PUT / DISPATCH", service.name = ['backend', 'frontend'], // name = "HTTP GET / DISPATCH", service.name = ['backend', 'backend' ], - test_row_group_prune( - Scenario::PeriodsInColumnNames, - // use double quotes to use column named "service.name" - "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend'", - Some(0), - Some(1), // prune out last row group - Some(0), - 7, - ) - .await; - 
test_row_group_prune( - Scenario::PeriodsInColumnNames, - "SELECT \"name\", \"service.name\" FROM t WHERE \"name\" != 'HTTP GET / DISPATCH'", - Some(0), - Some(2), // prune out first and last row group - Some(0), - 5, - ) - .await; - test_row_group_prune( - Scenario::PeriodsInColumnNames, - "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend' AND \"name\" != 'HTTP GET / DISPATCH'", - Some(0), - Some(2), // prune out middle and last row group - Some(0), - 2, - ) - .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::PeriodsInColumnNames) + .with_query("SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend'") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(1)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(7) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::PeriodsInColumnNames) + .with_query("SELECT \"name\", \"service.name\" FROM t WHERE \"name\" != 'HTTP GET / DISPATCH'") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(2)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(5) + .test_row_group_prune() + .await; + RowGroupPruningTest::new() + .with_scenario(Scenario::PeriodsInColumnNames) + .with_query("SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend' AND \"name\" != 'HTTP GET / DISPATCH'") + .with_expected_errors(Some(0)) + .with_pruned_by_stats(Some(2)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; } diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 81c8f67cc67b..e86d6172cecd 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -221,8 +221,12 @@ pub enum BuiltinScalarFunction { DateTrunc, /// date_bin DateBin, + /// ends_with + EndsWith, /// initcap InitCap, + /// instr + InStr, /// left Left, /// lpad @@ -446,7 +450,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::DatePart => Volatility::Immutable, BuiltinScalarFunction::DateTrunc => Volatility::Immutable, BuiltinScalarFunction::DateBin => Volatility::Immutable, + BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, + BuiltinScalarFunction::InStr => Volatility::Immutable, BuiltinScalarFunction::Left => Volatility::Immutable, BuiltinScalarFunction::Lpad => Volatility::Immutable, BuiltinScalarFunction::Lower => Volatility::Immutable, @@ -708,6 +714,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } + BuiltinScalarFunction::InStr => { + utf8_to_int_type(&input_expr_types[0], "instr") + } BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), BuiltinScalarFunction::Lower => { utf8_to_str_type(&input_expr_types[0], "lower") @@ -795,6 +804,7 @@ impl BuiltinScalarFunction { true, )))), BuiltinScalarFunction::StartsWith => Ok(Boolean), + BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos") } @@ -1211,17 +1221,19 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Strpos | BuiltinScalarFunction::StartsWith => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - self.volatility(), - ) - } + + BuiltinScalarFunction::EndsWith + | 
BuiltinScalarFunction::InStr + | BuiltinScalarFunction::Strpos + | BuiltinScalarFunction::StartsWith => Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + self.volatility(), + ), BuiltinScalarFunction::Substr => Signature::one_of( vec![ @@ -1473,7 +1485,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Concat => &["concat"], BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::Chr => &["chr"], + BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], + BuiltinScalarFunction::InStr => &["instr"], BuiltinScalarFunction::Left => &["left"], BuiltinScalarFunction::Lower => &["lower"], BuiltinScalarFunction::Lpad => &["lpad"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1d45fa4facd0..006b5f10f10d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -798,6 +798,7 @@ scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input scalar_expr!(Encode, encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex"); scalar_expr!(Decode, decode, input encoding, "decode the`input`, using the `encoding`. encoding can be base64 or hex"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); +scalar_expr!(InStr, instr, string substring, "returns the position of the first occurrence of `substring` in `string`"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Lower, lower, string, "convert the string to lower case"); scalar_expr!( @@ -830,6 +831,7 @@ scalar_expr!(SHA512, sha512, string, "SHA-512 hash"); scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); scalar_expr!(StringToArray, string_to_array, string delimiter null_string, "splits a `string` based on a `delimiter` and returns an array of parts. 
Any parts matching the optional `null_string` will be replaced with `NULL`"); scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`"); +scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); @@ -1372,6 +1374,7 @@ mod test { test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); + test_scalar_expr!(InStr, instr, string, substring); test_scalar_expr!(Left, left, string, count); test_scalar_expr!(Lower, lower, string); test_nary_scalar_expr!(Lpad, lpad, string, count); @@ -1410,6 +1413,7 @@ mod test { test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); test_scalar_expr!(StringToArray, string_to_array, expr, delimiter, null_value); test_scalar_expr!(StartsWith, starts_with, string, characters); + test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); test_scalar_expr!(Substr, substring, string, position, count); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index fadb524bbd3e..25301da444cf 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -603,6 +603,15 @@ pub fn create_physical_fun( internal_err!("Unsupported data type {other:?} for function initcap") } }), + BuiltinScalarFunction::InStr => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::instr::<i32>)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::instr::<i64>)(args) + } + other => internal_err!("Unsupported data type {other:?} for function instr"), + }), BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); @@ -825,6 +834,17 @@ pub fn create_physical_fun( internal_err!("Unsupported data type {other:?} for function starts_with") } }), + BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(string_expressions::ends_with::<i32>)(args) + } + DataType::LargeUtf8 => { + make_scalar_function(string_expressions::ends_with::<i64>)(args) + } + other => { + internal_err!("Unsupported data type {other:?} for function ends_with") + } + }), BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -1047,7 +1067,7 @@ mod tests { use arrow::{ array::{ Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, - Int32Array, StringArray, UInt64Array, + Int32Array, Int64Array, StringArray, UInt64Array, }, datatypes::Field, record_batch::RecordBatch, @@ -1439,6 +1459,95 @@ Utf8, StringArray ); + test_function!( + InStr, + &[lit("abc"), lit("b")], + Ok(Some(2)), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[lit("abc"), lit("c")], + Ok(Some(3)), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[lit("abc"), lit("d")], + Ok(Some(0)), + i32, + 
Int32, + Int32Array + ); + test_function!( + InStr, + &[lit("abc"), lit("")], + Ok(Some(1)), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[lit("Helloworld"), lit("world")], + Ok(Some(6)), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[lit("Helloworld"), lit(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[lit(ScalarValue::Utf8(None)), lit("Hello")], + Ok(None), + i32, + Int32, + Int32Array + ); + test_function!( + InStr, + &[ + lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))), + lit(ScalarValue::LargeUtf8(Some("world".to_string()))) + ], + Ok(Some(6)), + i64, + Int64, + Int64Array + ); + test_function!( + InStr, + &[ + lit(ScalarValue::LargeUtf8(None)), + lit(ScalarValue::LargeUtf8(Some("world".to_string()))) + ], + Ok(None), + i64, + Int64, + Int64Array + ); + test_function!( + InStr, + &[ + lit(ScalarValue::LargeUtf8(Some("Helloworld".to_string()))), + lit(ScalarValue::LargeUtf8(None)) + ], + Ok(None), + i64, + Int64, + Int64Array + ); #[cfg(feature = "unicode_expressions")] test_function!( Left, @@ -2557,6 +2666,38 @@ mod tests { Boolean, BooleanArray ); + test_function!( + EndsWith, + &[lit("alphabet"), lit("alph"),], + Ok(Some(false)), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWith, + &[lit("alphabet"), lit("bet"),], + Ok(Some(true)), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWith, + &[lit(ScalarValue::Utf8(None)), lit("alph"),], + Ok(None), + bool, + Boolean, + BooleanArray + ); + test_function!( + EndsWith, + &[lit("alphabet"), lit(ScalarValue::Utf8(None)),], + Ok(None), + bool, + Boolean, + BooleanArray + ); #[cfg(feature = "unicode_expressions")] test_function!( Strpos, diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 7d9fecf61407..d5344773cfbc 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -23,8 +23,8 @@ use arrow::{ array::{ - Array, ArrayRef, BooleanArray, GenericStringArray, Int32Array, Int64Array, - OffsetSizeTrait, StringArray, + Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, OffsetSizeTrait, + StringArray, }, datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, }; @@ -296,6 +296,50 @@ pub fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { Ok(Arc::new(result) as ArrayRef) } +/// Returns the position of the first occurrence of substring in string. +/// The position is counted from 1. If the substring is not found, returns 0. +/// For example, instr('Helloworld', 'world') = 6. 
+pub fn instr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + let string_array = as_generic_string_array::<T>(&args[0])?; + let substr_array = as_generic_string_array::<T>(&args[1])?; + + match args[0].data_type() { + DataType::Utf8 => { + let result = string_array + .iter() + .zip(substr_array.iter()) + .map(|(string, substr)| match (string, substr) { + (Some(string), Some(substr)) => string + .find(substr) + .map_or(Some(0), |index| Some((index + 1) as i32)), + _ => None, + }) + .collect::<Int32Array>(); + + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = string_array + .iter() + .zip(substr_array.iter()) + .map(|(string, substr)| match (string, substr) { + (Some(string), Some(substr)) => string + .find(substr) + .map_or(Some(0), |index| Some((index + 1) as i64)), + _ => None, + }) + .collect::<Int64Array>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + internal_err!( + "instr was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + ) + } + } +} + /// Converts the string to all lower case. /// lower('TOM') = 'tom' pub fn lower(args: &[ColumnarValue]) -> Result<ColumnarValue> { @@ -461,17 +505,21 @@ pub fn split_part<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' pub fn starts_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { - let string_array = as_generic_string_array::<T>(&args[0])?; - let prefix_array = as_generic_string_array::<T>(&args[1])?; + let left = as_generic_string_array::<T>(&args[0])?; + let right = as_generic_string_array::<T>(&args[1])?; - let result = string_array - .iter() - .zip(prefix_array.iter()) - .map(|(string, prefix)| match (string, prefix) { - (Some(string), Some(prefix)) => Some(string.starts_with(prefix)), - _ => None, - }) - .collect::<BooleanArray>(); + let result = arrow::compute::kernels::comparison::starts_with(left, right)?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns true if string ends with suffix. 
+/// ends_with('alphabet', 'abet') = 't' +pub fn ends_with<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + let left = as_generic_string_array::<T>(&args[0])?; + let right = as_generic_string_array::<T>(&args[1])?; + + let result = arrow::compute::kernels::comparison::ends_with(left, right)?; Ok(Arc::new(result) as ArrayRef) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d79879e57a7d..66c1271e65c1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -669,6 +669,8 @@ enum ScalarFunction { ArraySort = 128; ArrayDistinct = 129; ArrayResize = 130; + EndsWith = 131; + InStr = 132; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index d7ad6fb03c92..39a8678ef250 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22423,6 +22423,8 @@ impl serde::Serialize for ScalarFunction { Self::ArraySort => "ArraySort", Self::ArrayDistinct => "ArrayDistinct", Self::ArrayResize => "ArrayResize", + Self::EndsWith => "EndsWith", + Self::InStr => "InStr", }; serializer.serialize_str(variant) } @@ -22565,6 +22567,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArraySort", "ArrayDistinct", "ArrayResize", + "EndsWith", + "InStr", ]; struct GeneratedVisitor; @@ -22736,6 +22740,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArraySort" => Ok(ScalarFunction::ArraySort), "ArrayDistinct" => Ok(ScalarFunction::ArrayDistinct), "ArrayResize" => Ok(ScalarFunction::ArrayResize), + "EndsWith" => Ok(ScalarFunction::EndsWith), + "InStr" => Ok(ScalarFunction::InStr), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d594da90879c..7bf1d8ed0450 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2764,6 +2764,8 @@ pub enum ScalarFunction { ArraySort = 128, ArrayDistinct = 129, ArrayResize = 130, + EndsWith = 131, + InStr = 132, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2903,6 +2905,8 @@ impl ScalarFunction { ScalarFunction::ArraySort => "ArraySort", ScalarFunction::ArrayDistinct => "ArrayDistinct", ScalarFunction::ArrayResize => "ArrayResize", + ScalarFunction::EndsWith => "EndsWith", + ScalarFunction::InStr => "InStr", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -3039,6 +3043,8 @@ impl ScalarFunction { "ArraySort" => Some(Self::ArraySort), "ArrayDistinct" => Some(Self::ArrayDistinct), "ArrayResize" => Some(Self::ArrayResize), + "EndsWith" => Some(Self::EndsWith), + "InStr" => Some(Self::InStr), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 8db5ccdfd604..42d39b5c5139 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -56,10 +56,10 @@ use datafusion_expr::{ ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest, - encode, exp, + encode, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, initcap, - isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, + instr, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, power, radians, random, regexp_match, regexp_replace, repeat, replace, reverse, right, @@ -529,7 +529,9 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::CharacterLength => Self::CharacterLength, ScalarFunction::Chr => Self::Chr, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, + ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, + ScalarFunction::InStr => Self::InStr, ScalarFunction::Left => Self::Left, ScalarFunction::Lpad => Self::Lpad, ScalarFunction::Random => Self::Random, @@ -1586,6 +1588,10 @@ pub fn parse_expr( } ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry)?)), ScalarFunction::InitCap => Ok(initcap(parse_expr(&args[0], registry)?)), + ScalarFunction::InStr => Ok(instr( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::Gcd => Ok(gcd( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, @@ -1665,6 +1671,10 @@ pub fn parse_expr( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, )), + ScalarFunction::EndsWith => Ok(ends_with( + parse_expr(&args[0], registry)?, + parse_expr(&args[1], registry)?, + )), ScalarFunction::Strpos => Ok(strpos( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7eef3da9519f..dbb52eced36c 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1525,7 +1525,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::CharacterLength => Self::CharacterLength, BuiltinScalarFunction::Chr => Self::Chr, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, + BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, + BuiltinScalarFunction::InStr => Self::InStr, BuiltinScalarFunction::Left => Self::Left, BuiltinScalarFunction::Lpad => Self::Lpad, BuiltinScalarFunction::Random => Self::Random, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index a098c8de0d3c..e9c92f53e0fa 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ 
b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2515,198 +2515,6 @@ false true NULL -# TopK aggregation -statement ok -CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES -(NULL, 0, 0), -('a', NULL, NULL), -('a', 1, 1), -('a', -1, -1), -('b', 0, 0), -('c', 1, 1), -('c', 2, 2), -('b', 3, 3); - -statement ok -set datafusion.optimizer.enable_topk_aggregation = false; - -query TT -explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; ----- -logical_plan -Limit: skip=0, fetch=4 ---Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 -----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] -------TableScan: traces projection=[trace_id, timestamp] -physical_plan -GlobalLimitExec: skip=0, fetch=4 ---SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 -----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] -------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -----------------MemoryExec: partitions=1, partition_sizes=[1] - - -query TI -select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; ----- -b 3 -c 2 -a 1 -NULL 0 - -query TI -select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; ----- -a -1 -NULL 0 -b 0 -c 1 - -query TII -select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; ----- -a -1 -1 -b 0 0 -NULL 0 0 -c 1 1 - -query TII -select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; ----- -a -1 -1 -NULL 0 0 -b 0 0 -c 1 1 - -statement ok -set datafusion.optimizer.enable_topk_aggregation = true; - -query TT -explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; ----- -logical_plan -Limit: skip=0, fetch=4 ---Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 -----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] -------TableScan: traces projection=[trace_id, timestamp] -physical_plan -GlobalLimitExec: skip=0, fetch=4 ---SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 -----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] -------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] -----------------MemoryExec: partitions=1, partition_sizes=[1] - -query TT -explain select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) desc limit 4; ----- -logical_plan -Limit: skip=0, fetch=4 ---Sort: MIN(traces.timestamp) DESC NULLS FIRST, fetch=4 -----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MIN(traces.timestamp)]] -------TableScan: traces projection=[trace_id, timestamp] -physical_plan 
-GlobalLimitExec: skip=0, fetch=4 ---SortPreservingMergeExec: [MIN(traces.timestamp)@1 DESC], fetch=4 -----SortExec: TopK(fetch=4), expr=[MIN(traces.timestamp)@1 DESC] -------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] -----------------MemoryExec: partitions=1, partition_sizes=[1] - -query TT -explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) asc limit 4; ----- -logical_plan -Limit: skip=0, fetch=4 ---Sort: MAX(traces.timestamp) ASC NULLS LAST, fetch=4 -----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] -------TableScan: traces projection=[trace_id, timestamp] -physical_plan -GlobalLimitExec: skip=0, fetch=4 ---SortPreservingMergeExec: [MAX(traces.timestamp)@1 ASC NULLS LAST], fetch=4 -----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 ASC NULLS LAST] -------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -----------------MemoryExec: partitions=1, partition_sizes=[1] - -query TT -explain select trace_id, MAX(timestamp) from traces group by trace_id order by trace_id asc limit 4; ----- -logical_plan -Limit: skip=0, fetch=4 ---Sort: traces.trace_id ASC NULLS LAST, fetch=4 -----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] -------TableScan: traces projection=[trace_id, timestamp] -physical_plan -GlobalLimitExec: skip=0, fetch=4 ---SortPreservingMergeExec: [trace_id@0 ASC NULLS LAST], fetch=4 -----SortExec: TopK(fetch=4), expr=[trace_id@0 ASC NULLS LAST] -------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] ---------CoalesceBatchesExec: target_batch_size=8192 -----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 ---------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -----------------MemoryExec: partitions=1, partition_sizes=[1] - -query TI -select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; ----- -b 3 -c 2 -a 1 -NULL 0 - -query TI -select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; ----- -a -1 -NULL 0 -b 0 -c 1 - -query TI -select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 3; ----- -b 3 -c 2 -a 1 - -query TI -select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 3; ----- -a -1 -NULL 0 -b 0 - -query TII -select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; ----- -a -1 -1 -b 0 0 -NULL 0 0 -c 1 1 - -query TII -select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) 
limit 4; ----- -a -1 -1 -NULL 0 0 -b 0 0 -c 1 1 - # # Push limit into distinct group-by aggregation tests # diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt new file mode 100644 index 000000000000..6b6204e09f40 --- /dev/null +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +####### +# Setup test data table +####### + +# TopK aggregation +statement ok +CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES +(NULL, 0, 0), +('a', NULL, NULL), +('a', 1, 1), +('a', -1, -1), +('b', 0, 0), +('c', 1, 1), +('c', 2, 2), +('b', 3, 3); + +statement ok +set datafusion.optimizer.enable_topk_aggregation = false; + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +----------------MemoryExec: partitions=1, partition_sizes=[1] + + +query TI +select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +b 3 +c 2 +a 1 +NULL 0 + +query TI +select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; +---- +a -1 +NULL 0 +b 0 +c 1 + +query TII +select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; +---- +a -1 -1 +b 0 0 +NULL 0 0 +c 1 1 + +query TII +select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; +---- +a -1 -1 +NULL 0 0 +b 0 0 +c 1 1 + +statement ok +set datafusion.optimizer.enable_topk_aggregation = true; + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] 
+------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) desc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MIN(traces.timestamp) DESC NULLS FIRST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MIN(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MIN(traces.timestamp)@1 DESC], fetch=4 +----SortExec: TopK(fetch=4), expr=[MIN(traces.timestamp)@1 DESC] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) asc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: MAX(traces.timestamp) ASC NULLS LAST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [MAX(traces.timestamp)@1 ASC NULLS LAST], fetch=4 +----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +explain select trace_id, MAX(timestamp) from traces group by trace_id order by trace_id asc limit 4; +---- +logical_plan +Limit: skip=0, fetch=4 +--Sort: traces.trace_id ASC NULLS LAST, fetch=4 +----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] +------TableScan: traces projection=[trace_id, timestamp] +physical_plan +GlobalLimitExec: skip=0, fetch=4 +--SortPreservingMergeExec: [trace_id@0 ASC NULLS LAST], fetch=4 +----SortExec: TopK(fetch=4), expr=[trace_id@0 ASC NULLS LAST] +------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +--------CoalesceBatchesExec: target_batch_size=8192 +----------RepartitionExec: partitioning=Hash([trace_id@0], 
4), input_partitions=4 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] +----------------MemoryExec: partitions=1, partition_sizes=[1] + +query TI +select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +---- +b 3 +c 2 +a 1 +NULL 0 + +query TI +select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; +---- +a -1 +NULL 0 +b 0 +c 1 + +query TI +select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 3; +---- +b 3 +c 2 +a 1 + +query TI +select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 3; +---- +a -1 +NULL 0 +b 0 + +query TII +select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; +---- +a -1 -1 +b 0 0 +NULL 0 0 +c 1 1 + +query TII +select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; +---- +a -1 -1 +NULL 0 0 +b 0 0 +c 1 1 diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 7bd60a3a154b..d3f81cc61e95 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -647,6 +647,21 @@ SELECT initcap(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) ---- Foo +query I +SELECT instr('foobarbar', 'bar') +---- +4 + +query I +SELECT instr('foobarbar', 'aa') +---- +0 + +query I +SELECT instr('foobarbar', '') +---- +1 + query T SELECT lower('FOObar') ---- @@ -727,6 +742,26 @@ SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) ---- bar +query B +SELECT starts_with('foobar', 'foo') +---- +true + +query B +SELECT starts_with('foobar', 'bar') +---- +false + +query B +SELECT ends_with('foobar', 'bar') +---- +true + +query B +SELECT ends_with('foobar', 'foo') +---- +false + query T SELECT trim(' foo ') ---- diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 50e1cbc3d622..c3def3f89b5b 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -613,7 +613,9 @@ nullif(expression1, expression2) - [concat](#concat) - [concat_ws](#concat_ws) - [chr](#chr) +- [ends_with](#ends_with) - [initcap](#initcap) +- [instr](#instr) - [left](#left) - [length](#length) - [lower](#lower) @@ -756,6 +758,20 @@ chr(expression) **Related functions**: [ascii](#ascii) +### `ends_with` + +Tests if a string ends with a substring. + +``` +ends_with(str, substr) +``` + +#### Arguments + +- **str**: String expression to test. + Can be a constant, column, or function, and any combination of string operators. +- **substr**: Substring to test for. + ### `initcap` Capitalizes the first character in each word in the input string. @@ -774,6 +790,22 @@ initcap(str) [lower](#lower), [upper](#upper) +### `instr` + +Returns the position where substr first appears in str (counting from 1). +If substr does not appear in str, returns 0. + +``` +instr(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. + Can be a constant, column, or function, and any combination of string operators. +- **substr**: Substring expression to search for. + Can be a constant, column, or function, and any combination of string operators. 
+ ### `left` Returns a specified number of characters from the left side of a string.