Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update REGEXP_MATCH scalar function to support Utf8View #14449

Merged
merged 2 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions datafusion/functions/benches/regx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
extern crate criterion;

use arrow::array::builder::StringBuilder;
use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray};
use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray, StringViewArray};
use arrow::compute::cast;
use arrow::datatypes::DataType;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
Expand Down Expand Up @@ -141,6 +141,20 @@ fn criterion_benchmark(c: &mut Criterion) {
})
});

c.bench_function("regexp_like_1000 utf8view", |b| {
let mut rng = rand::thread_rng();
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap();

b.iter(|| {
black_box(
regexp_like(&[Arc::clone(&data), Arc::clone(&regex), Arc::clone(&flags)])
.expect("regexp_like should work on valid values"),
)
})
});

c.bench_function("regexp_match_1000", |b| {
let mut rng = rand::thread_rng();
let data = Arc::new(data(&mut rng)) as ArrayRef;
Expand All @@ -149,7 +163,25 @@ fn criterion_benchmark(c: &mut Criterion) {

b.iter(|| {
black_box(
regexp_match::<i32>(&[
regexp_match(&[
Arc::clone(&data),
Arc::clone(&regex),
Arc::clone(&flags),
])
.expect("regexp_match should work on valid values"),
)
})
});

c.bench_function("regexp_match_1000 utf8view", |b| {
let mut rng = rand::thread_rng();
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap();

b.iter(|| {
black_box(
regexp_match(&[
Arc::clone(&data),
Arc::clone(&regex),
Arc::clone(&flags),
Expand Down Expand Up @@ -180,6 +212,29 @@ fn criterion_benchmark(c: &mut Criterion) {
)
})
});

c.bench_function("regexp_replace_1000 utf8view", |b| {
let mut rng = rand::thread_rng();
let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap();
let regex = cast(&regex(&mut rng), &DataType::Utf8View).unwrap();
// flags are not allowed to be utf8view according to the function
let flags = Arc::new(flags(&mut rng)) as ArrayRef;
let replacement = Arc::new(StringViewArray::from_iter_values(
iter::repeat("XX").take(1000),
));

b.iter(|| {
black_box(
regexp_replace::<i32, _, _>(
data.as_string_view(),
regex.as_string_view(),
&replacement,
Some(&flags),
)
.expect("regexp_replace should work on valid values"),
)
})
});
}

criterion_group!(benches, criterion_benchmark);
Expand Down
66 changes: 33 additions & 33 deletions datafusion/functions/src/regex/regexpmatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@
// under the License.

//! Regex expressions
use arrow::array::{Array, ArrayRef, OffsetSizeTrait};
use arrow::array::{Array, ArrayRef, AsArray};
use arrow::compute::kernels::regexp;
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use datafusion_common::exec_err;
use datafusion_common::ScalarValue;
use datafusion_common::{arrow_datafusion_err, plan_err};
use datafusion_common::{
cast::as_generic_string_array, internal_err, DataFusionError, Result,
};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{ColumnarValue, Documentation, TypeSignature};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use datafusion_macros::user_doc;
Expand Down Expand Up @@ -86,11 +84,12 @@ impl RegexpMatchFunc {
signature: Signature::one_of(
vec![
// Planner attempts coercion to the target type starting with the most preferred candidate.
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`.
// If that fails, it proceeds to `(LargeUtf8, Utf8)`.
// TODO: Native support Utf8View for regexp_match.
// For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8View, Utf8View)`.
// If that fails, it proceeds to `(Utf8, Utf8)`.
TypeSignature::Exact(vec![Utf8View, Utf8View]),
TypeSignature::Exact(vec![Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]),
TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8View]),
TypeSignature::Exact(vec![Utf8, Utf8, Utf8]),
TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]),
],
Expand Down Expand Up @@ -138,7 +137,7 @@ impl ScalarUDFImpl for RegexpMatchFunc {
.map(|arg| arg.to_array(inferred_length))
.collect::<Result<Vec<_>>>()?;

let result = regexp_match_func(&args);
let result = regexp_match(&args);
if is_scalar {
// If all inputs are scalar, keeps output as scalar
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
Expand All @@ -153,33 +152,35 @@ impl ScalarUDFImpl for RegexpMatchFunc {
}
}

fn regexp_match_func(args: &[ArrayRef]) -> Result<ArrayRef> {
match args[0].data_type() {
DataType::Utf8 => regexp_match::<i32>(args),
DataType::LargeUtf8 => regexp_match::<i64>(args),
other => {
internal_err!("Unsupported data type {other:?} for function regexp_match")
}
}
}
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
pub fn regexp_match(args: &[ArrayRef]) -> Result<ArrayRef> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree technically speaking this is an API change but I also think it is small and is ok. I will flag this PR as API change but I think it will be minimally disruptive

match args.len() {
2 => {
let values = as_generic_string_array::<T>(&args[0])?;
let regex = as_generic_string_array::<T>(&args[1])?;
regexp::regexp_match(values, regex, None)
regexp::regexp_match(&args[0], &args[1], None)
.map_err(|e| arrow_datafusion_err!(e))
}
3 => {
let values = as_generic_string_array::<T>(&args[0])?;
let regex = as_generic_string_array::<T>(&args[1])?;
let flags = as_generic_string_array::<T>(&args[2])?;

if flags.iter().any(|s| s == Some("g")) {
return plan_err!("regexp_match() does not support the \"global\" option");
match args[2].data_type() {
DataType::Utf8View => {
if args[2].as_string_view().iter().any(|s| s == Some("g")) {
return plan_err!("regexp_match() does not support the \"global\" option");
}
}
DataType::Utf8 => {
if args[2].as_string::<i32>().iter().any(|s| s == Some("g")) {
return plan_err!("regexp_match() does not support the \"global\" option");
}
}
DataType::LargeUtf8 => {
if args[2].as_string::<i64>().iter().any(|s| s == Some("g")) {
return plan_err!("regexp_match() does not support the \"global\" option");
}
}
e => {
return plan_err!("regexp_match was called with unexpected data type {e:?}");
}
}

regexp::regexp_match(values, regex, Some(flags))
regexp::regexp_match(&args[0], &args[1], Some(&args[2]))
.map_err(|e| arrow_datafusion_err!(e))
}
other => exec_err!(
Expand Down Expand Up @@ -211,7 +212,7 @@ mod tests {
expected_builder.append(false);
let expected = expected_builder.finish();

let re = regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap();
let re = regexp_match(&[Arc::new(values), Arc::new(patterns)]).unwrap();

assert_eq!(re.as_ref(), &expected);
}
Expand All @@ -236,9 +237,8 @@ mod tests {
expected_builder.append(false);
let expected = expected_builder.finish();

let re =
regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.unwrap();
let re = regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.unwrap();

assert_eq!(re.as_ref(), &expected);
}
Expand All @@ -250,7 +250,7 @@ mod tests {
let flags = StringArray::from(vec!["g"]);

let re_err =
regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
regexp_match(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
.expect_err("unsupported flag should have failed");

assert_eq!(re_err.strip_backtrace(), "Error during planning: regexp_match() does not support the \"global\" option");
Expand Down
60 changes: 53 additions & 7 deletions datafusion/sqllogictest/test_files/regexp.slt
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,29 @@ NULL
[Köln]
[إسرائيل]

# test string view
statement ok
CREATE TABLE t_stringview AS
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t;

query ?
SELECT regexp_match(str, pattern, flags) FROM t_stringview;
----
[a]
[A]
[B]
NULL
NULL
NULL
[010]
[Düsseldorf]
[Москва]
[Köln]
[إسرائيل]

statement ok
DROP TABLE t_stringview;

query ?
SELECT regexp_match('foobarbequebaz', '');
----
Expand Down Expand Up @@ -354,6 +377,29 @@ X
X
X

# test string view
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed in #11911 (comment) it would be great to move these tests into string.slt but we can totally do it as a follow on as well

statement ok
CREATE TABLE t_stringview AS
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(flags, 'Utf8View') as flags FROM t;

query T
SELECT regexp_replace(str, pattern, 'X', concat('g', flags)) FROM t_stringview;
----
Xbc
X
aXc
AbC
aBC
4000
X
X
X
X
X

statement ok
DROP TABLE t_stringview;

query T
SELECT regexp_replace('ABCabcABC', '(abc)', 'X', 'gi');
----
Expand Down Expand Up @@ -621,7 +667,7 @@ CREATE TABLE t_stringview AS
SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t;

query I
SELECT regexp_count(str, '\w') from t;
SELECT regexp_count(str, '\w') from t_stringview;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 this looks like a driveby fix

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it was. Couldn't resist.

----
3
3
Expand All @@ -636,7 +682,7 @@ SELECT regexp_count(str, '\w') from t;
7

query I
SELECT regexp_count(str, '\w{2}', start) from t;
SELECT regexp_count(str, '\w{2}', start) from t_stringview;
----
1
1
Expand All @@ -651,7 +697,7 @@ SELECT regexp_count(str, '\w{2}', start) from t;
3

query I
SELECT regexp_count(str, 'ab', 1, 'i') from t;
SELECT regexp_count(str, 'ab', 1, 'i') from t_stringview;
----
1
1
Expand All @@ -667,7 +713,7 @@ SELECT regexp_count(str, 'ab', 1, 'i') from t;


query I
SELECT regexp_count(str, pattern) from t;
SELECT regexp_count(str, pattern) from t_stringview;
----
1
1
Expand All @@ -682,7 +728,7 @@ SELECT regexp_count(str, pattern) from t;
1

query I
SELECT regexp_count(str, pattern, start) from t;
SELECT regexp_count(str, pattern, start) from t_stringview;
----
1
1
Expand All @@ -697,7 +743,7 @@ SELECT regexp_count(str, pattern, start) from t;
1

query I
SELECT regexp_count(str, pattern, start, flags) from t;
SELECT regexp_count(str, pattern, start, flags) from t_stringview;
----
1
1
Expand All @@ -713,7 +759,7 @@ SELECT regexp_count(str, pattern, start, flags) from t;

# test type coercion
query I
SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t;
SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t_stringview;
----
1
1
Expand Down
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/string/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: regexp_match(CAST(test.column1_utf8view AS Utf8), Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k
01)Projection: regexp_match(test.column1_utf8view, Utf8View("^https?://(?:www\.)?([^/]+)/.*$")) AS k
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉

02)--TableScan: test projection=[column1_utf8view]

## Ensure no casts for REGEXP_REPLACE
Expand Down