Skip to content

Commit

Permalink
Implement native support StringView for Levenshtein (#11925)
Browse files Browse the repository at this point in the history
* Implement native support StringView for Levenshtein

Signed-off-by: Chojan Shang <[email protected]>

* Remove useless code

Signed-off-by: Chojan Shang <[email protected]>

* Minor fix

Signed-off-by: Chojan Shang <[email protected]>

---------

Signed-off-by: Chojan Shang <[email protected]>
  • Loading branch information
PsiACE authored Aug 12, 2024
1 parent 8deba02 commit b60cdc7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
37 changes: 30 additions & 7 deletions datafusion/functions/src/string/levenshtein.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait};
use arrow::datatypes::DataType;

use crate::utils::{make_scalar_function, utf8_to_int_type};
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
use datafusion_common::utils::datafusion_strsim;
use datafusion_common::{exec_err, Result};
use datafusion_expr::ColumnarValue;
Expand All @@ -42,10 +42,13 @@ impl Default for LevenshteinFunc {

impl LevenshteinFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
vec![
Exact(vec![DataType::Utf8View, DataType::Utf8View]),
Exact(vec![DataType::Utf8, DataType::Utf8]),
Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]),
],
Volatility::Immutable,
),
}
Expand All @@ -71,7 +74,9 @@ impl ScalarUDFImpl for LevenshteinFunc {

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(levenshtein::<i32>, vec![])(args),
DataType::Utf8View | DataType::Utf8 => {
make_scalar_function(levenshtein::<i32>, vec![])(args)
}
DataType::LargeUtf8 => make_scalar_function(levenshtein::<i64>, vec![])(args),
other => {
exec_err!("Unsupported data type {other:?} for function levenshtein")
Expand All @@ -89,10 +94,26 @@ pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
args.len()
);
}
let str1_array = as_generic_string_array::<T>(&args[0])?;
let str2_array = as_generic_string_array::<T>(&args[1])?;

match args[0].data_type() {
DataType::Utf8View => {
let str1_array = as_string_view_array(&args[0])?;
let str2_array = as_string_view_array(&args[1])?;
let result = str1_array
.iter()
.zip(str2_array.iter())
.map(|(string1, string2)| match (string1, string2) {
(Some(string1), Some(string2)) => {
Some(datafusion_strsim::levenshtein(string1, string2) as i32)
}
_ => None,
})
.collect::<Int32Array>();
Ok(Arc::new(result) as ArrayRef)
}
DataType::Utf8 => {
let str1_array = as_generic_string_array::<T>(&args[0])?;
let str2_array = as_generic_string_array::<T>(&args[1])?;
let result = str1_array
.iter()
.zip(str2_array.iter())
Expand All @@ -106,6 +127,8 @@ pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(result) as ArrayRef)
}
DataType::LargeUtf8 => {
let str1_array = as_generic_string_array::<T>(&args[0])?;
let str2_array = as_generic_string_array::<T>(&args[1])?;
let result = str1_array
.iter()
.zip(str2_array.iter())
Expand All @@ -120,7 +143,7 @@ pub fn levenshtein<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
}
other => {
exec_err!(
"levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8."
"levenshtein was called with {other} datatype arguments. It requires Utf8View, Utf8 or LargeUtf8."
)
}
}
Expand Down
6 changes: 2 additions & 4 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -629,17 +629,15 @@ logical_plan
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]

## Ensure no casts for LEVENSHTEIN
## TODO https://github.com/apache/datafusion/issues/11854
query TT
EXPLAIN SELECT
levenshtein(column1_utf8view, 'foo') as c1,
levenshtein(column1_utf8view, column2_utf8view) as c2
FROM test;
----
logical_plan
01)Projection: levenshtein(__common_expr_1, Utf8("foo")) AS c1, levenshtein(__common_expr_1, CAST(test.column2_utf8view AS Utf8)) AS c2
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view
03)----TableScan: test projection=[column1_utf8view, column2_utf8view]
01)Projection: levenshtein(test.column1_utf8view, Utf8View("foo")) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]

## Ensure no casts for LOWER
## TODO https://github.com/apache/datafusion/issues/11855
Expand Down

0 comments on commit b60cdc7

Please sign in to comment.