From 1afc410855825208e9ffdea1aeeef56e8f488369 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 3 Feb 2025 13:37:14 -0500 Subject: [PATCH] Update REGEXP_MATCH scalar function to support Utf8View (#14449) * Update REGEXP_MATCH scalar function to support Utf8View * Cargo fmt fix. --- datafusion/functions/benches/regx.rs | 59 ++++++++++++++++- datafusion/functions/src/regex/regexpmatch.rs | 66 +++++++++---------- datafusion/sqllogictest/test_files/regexp.slt | 60 +++++++++++++++-- .../test_files/string/string_view.slt | 2 +- 4 files changed, 144 insertions(+), 43 deletions(-) diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 468d3d548bcf..1f99cc3a5f0b 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::builder::StringBuilder; -use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray}; +use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray, StringViewArray}; use arrow::compute::cast; use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -141,6 +141,20 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("regexp_like_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap(); + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_like(&[Arc::clone(&data), Arc::clone(&regex), Arc::clone(&flags)]) + .expect("regexp_like should work on valid values"), + ) + }) + }); + c.bench_function("regexp_match_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; @@ -149,7 +163,25 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_match::<i32>(&[ + regexp_match(&[ + Arc::clone(&data), + Arc::clone(&regex), +
Arc::clone(&flags), + ]) + .expect("regexp_match should work on valid values"), + ) + }) + }); + + c.bench_function("regexp_match_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap(); + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_match(&[ Arc::clone(&data), Arc::clone(&regex), Arc::clone(&flags), @@ -180,6 +212,29 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + c.bench_function("regexp_replace_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap(); + // flags are not allowed to be utf8view according to the function + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + let replacement = Arc::new(StringViewArray::from_iter_values( + iter::repeat("XX").take(1000), + )); + + b.iter(|| { + black_box( + regexp_replace::<i32, _, _>( + data.as_string_view(), + regex.as_string_view(), + &replacement, + Some(&flags), + ) + .expect("regexp_replace should work on valid values"), + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 06b9a9d98b47..57207ecfdacd 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -16,16 +16,14 @@ // under the License. //!
Regex expressions -use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, AsArray}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use arrow::datatypes::Field; use datafusion_common::exec_err; use datafusion_common::ScalarValue; use datafusion_common::{arrow_datafusion_err, plan_err}; -use datafusion_common::{ - cast::as_generic_string_array, internal_err, DataFusionError, Result, -}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -86,11 +84,12 @@ impl RegexpMatchFunc { signature: Signature::one_of( vec![ // Planner attempts coercion to the target type starting with the most preferred candidate. - // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`. - // If that fails, it proceeds to `(LargeUtf8, Utf8)`. - // TODO: Native support Utf8View for regexp_match. + // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`. + // If that fails, it proceeds to `(Utf8, Utf8)`. 
+ TypeSignature::Exact(vec![Utf8View, Utf8View]), TypeSignature::Exact(vec![Utf8, Utf8]), TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8View]), TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), ], @@ -138,7 +137,7 @@ impl ScalarUDFImpl for RegexpMatchFunc { .map(|arg| arg.to_array(inferred_length)) .collect::<Result<Vec<_>>>()?; - let result = regexp_match_func(&args); + let result = regexp_match(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); @@ -153,33 +152,35 @@ impl ScalarUDFImpl for RegexpMatchFunc { } } -fn regexp_match_func(args: &[ArrayRef]) -> Result<ArrayRef> { - match args[0].data_type() { - DataType::Utf8 => regexp_match::<i32>(args), - DataType::LargeUtf8 => regexp_match::<i64>(args), - other => { - internal_err!("Unsupported data type {other:?} for function regexp_match") - } - } -} -pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { +pub fn regexp_match(args: &[ArrayRef]) -> Result<ArrayRef> { match args.len() { 2 => { - let values = as_generic_string_array::<T>(&args[0])?; - let regex = as_generic_string_array::<T>(&args[1])?; - regexp::regexp_match(values, regex, None) + regexp::regexp_match(&args[0], &args[1], None) .map_err(|e| arrow_datafusion_err!(e)) } 3 => { - let values = as_generic_string_array::<T>(&args[0])?; - let regex = as_generic_string_array::<T>(&args[1])?; - let flags = as_generic_string_array::<T>(&args[2])?; - - if flags.iter().any(|s| s == Some("g")) { - return plan_err!("regexp_match() does not support the \"global\" option"); + match args[2].data_type() { + DataType::Utf8View => { + if args[2].as_string_view().iter().any(|s| s == Some("g")) { + return plan_err!("regexp_match() does not support the \"global\" option"); + } + } + DataType::Utf8 => { + if args[2].as_string::<i32>().iter().any(|s| s == Some("g")) { + return plan_err!("regexp_match() does not support the \"global\"
option"); + } + } + DataType::LargeUtf8 => { + if args[2].as_string::<i64>().iter().any(|s| s == Some("g")) { + return plan_err!("regexp_match() does not support the \"global\" option"); + } + } + e => { + return plan_err!("regexp_match was called with unexpected data type {e:?}"); + } } - regexp::regexp_match(values, regex, Some(flags)) + regexp::regexp_match(&args[0], &args[1], Some(&args[2])) .map_err(|e| arrow_datafusion_err!(e)) } other => exec_err!( @@ -211,7 +212,7 @@ mod tests { expected_builder.append(false); let expected = expected_builder.finish(); - let re = regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + let re = regexp_match(&[Arc::new(values), Arc::new(patterns)]).unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -236,9 +237,8 @@ mod tests { expected_builder.append(false); let expected = expected_builder.finish(); - let re = - regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) - .unwrap(); + let re = regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -250,7 +250,7 @@ mod tests { let flags = StringArray::from(vec!["g"]); let re_err = - regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) .expect_err("unsupported flag should have failed"); assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_match() does not support the \"global\" option"); diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index 800026dd766d..80f94e21d1fe 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -193,6 +193,29 @@ NULL [Köln] [إسرائيل] +# test string view +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t; + +query ?
+SELECT regexp_match(str, pattern, flags) FROM t_stringview; +---- +[a] +[A] +[B] +NULL +NULL +NULL +[010] +[Düsseldorf] +[Москва] +[Köln] +[إسرائيل] + +statement ok +DROP TABLE t_stringview; + query ? SELECT regexp_match('foobarbequebaz', ''); ---- @@ -354,6 +377,29 @@ X X X +# test string view +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t; + +query T +SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview; +---- +Xbc +X +aXc +AbC +aBC +4000 +X +X +X +X +X + +statement ok +DROP TABLE t_stringview; + query T SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi'); ---- @@ -621,7 +667,7 @@ CREATE TABLE t_stringview AS SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; query I -SELECT regexp_count(str, '\w') from t; +SELECT regexp_count(str, '\w') from t_stringview; ---- 3 3 @@ -636,7 +682,7 @@ SELECT regexp_count(str, '\w') from t; 7 query I -SELECT regexp_count(str, '\w{2}', start) from t; +SELECT regexp_count(str, '\w{2}', start) from t_stringview; ---- 1 1 @@ -651,7 +697,7 @@ SELECT regexp_count(str, '\w{2}', start) from t; 3 query I -SELECT regexp_count(str, 'ab', 1, 'i') from t; +SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview; ---- 1 1 @@ -667,7 +713,7 @@ SELECT regexp_count(str, 'ab', 1, 'i') from t; query I -SELECT regexp_count(str, pattern) from t; +SELECT regexp_count(str, pattern) from t_stringview; ---- 1 1 @@ -682,7 +728,7 @@ SELECT regexp_count(str, pattern) from t; 1 query I -SELECT regexp_count(str, pattern, start) from t; +SELECT regexp_count(str, pattern, start) from t_stringview; ---- 1 1 @@ -697,7 +743,7 @@ SELECT regexp_count(str, pattern, start) from t; 1 query I -SELECT regexp_count(str, pattern, start, flags) from t; +SELECT regexp_count(str, pattern, start, 
flags) from t_stringview; ---- 1 1 @@ -713,7 +759,7 @@ SELECT regexp_count(str, pattern, start, flags) from t; # test type coercion query I -SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview; ---- 1 1 diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 435b4bc3c5a8..3b70e6de80be 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -783,7 +783,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: regexp_match(CAST(test.column1_utf8view AS Utf8), Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k +01)Projection: regexp_match(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$")) AS k 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for REGEXP_REPLACE