From 6c40214f3d93907686ed731caa3d572a9fa93d53 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 26 Apr 2024 19:57:11 +0800 Subject: [PATCH] [SPARK-47350][SQL] Add collation support for SplitPart string expression ### What changes were proposed in this pull request? Introduce collation awareness for string expression: split_part. ### Why are the changes needed? Add collation support for built-in string function in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for built-in string function: split_part. ### How was this patch tested? Unit collation support tests and e2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46158 from uros-db/SPARK-47350. Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../sql/catalyst/util/CollationSupport.java | 58 +++++++++++++ .../unsafe/types/CollationSupportSuite.java | 86 +++++++++++++++++++ .../expressions/stringExpressions.scala | 15 ++-- .../sql/CollationStringExpressionsSuite.scala | 17 ++++ 4 files changed, 170 insertions(+), 6 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index 70a3f5bd61362..0c03faa0d23a9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -23,6 +23,8 @@ import org.apache.spark.unsafe.types.UTF8String; +import java.util.ArrayList; +import java.util.List; import java.util.regex.Pattern; /** @@ -36,6 +38,62 @@ public final class CollationSupport { * Collation-aware string expressions. 
*/ + public static class StringSplitSQL { + public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + if (collation.supportsBinaryEquality) { + return execBinary(s, d); + } else if (collation.supportsLowercaseEquality) { + return execLowercase(s, d); + } else { + return execICU(s, d, collationId); + } + } + public static String genCode(final String s, final String d, final int collationId) { + CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); + String expr = "CollationSupport.StringSplitSQL.exec"; + if (collation.supportsBinaryEquality) { + return String.format(expr + "Binary(%s, %s)", s, d); + } else if (collation.supportsLowercaseEquality) { + return String.format(expr + "Lowercase(%s, %s)", s, d); + } else { + return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId); + } + } + public static UTF8String[] execBinary(final UTF8String string, final UTF8String delimiter) { + return string.splitSQL(delimiter, -1); + } + public static UTF8String[] execLowercase(final UTF8String string, final UTF8String delimiter) { + if (delimiter.numBytes() == 0) return new UTF8String[] { string }; + if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 }; + Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()), + CollationSupport.lowercaseRegexFlags); + String[] splits = pattern.split(string.toString(), -1); + UTF8String[] res = new UTF8String[splits.length]; + for (int i = 0; i < res.length; i++) { + res[i] = UTF8String.fromString(splits[i]); + } + return res; + } + public static UTF8String[] execICU(final UTF8String string, final UTF8String delimiter, + final int collationId) { + if (delimiter.numBytes() == 0) return new UTF8String[] { string }; + if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 }; + List<UTF8String> strings = new ArrayList<>(); + String target 
= string.toString(), pattern = delimiter.toString(); + StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId); + int start = 0, end; + while ((end = stringSearch.next()) != StringSearch.DONE) { + strings.add(UTF8String.fromString(target.substring(start, end))); + start = end + stringSearch.getMatchLength(); + } + if (start <= target.length()) { + strings.add(UTF8String.fromString(target.substring(start))); + } + return strings.toArray(new UTF8String[0]); + } + } + public static class Contains { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index d59bd5c20e674..72edd3e06f9ce 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -295,6 +295,92 @@ public void testEndsWith() throws SparkException { assertEndsWith("The KKelvin", "KKelvin,", "UTF8_BINARY_LCASE", false); } + private void assertStringSplitSQL(String str, String delimiter, String collationName, + UTF8String[] expected) throws SparkException { + UTF8String s = UTF8String.fromString(str); + UTF8String d = UTF8String.fromString(delimiter); + int collationId = CollationFactory.collationNameToId(collationName); + assertArrayEquals(expected, CollationSupport.StringSplitSQL.exec(s, d, collationId)); + } + + @Test + public void testStringSplitSQL() throws SparkException { + // Possible splits + var empty_match = new UTF8String[] { UTF8String.fromString("") }; + var array_abc = new UTF8String[] { UTF8String.fromString("abc") }; + var array_1a2 = new UTF8String[] { UTF8String.fromString("1a2") }; + var array_AaXbB = new UTF8String[] { 
UTF8String.fromString("AaXbB") }; + var array_aBcDe = new UTF8String[] { UTF8String.fromString("aBcDe") }; + var array_special = new UTF8String[] { UTF8String.fromString("äb世De") }; + var array_abcde = new UTF8String[] { UTF8String.fromString("äbćδe") }; + var full_match = new UTF8String[] { UTF8String.fromString(""), UTF8String.fromString("") }; + var array_1_2 = new UTF8String[] { UTF8String.fromString("1"), UTF8String.fromString("2") }; + var array_A_B = new UTF8String[] { UTF8String.fromString("A"), UTF8String.fromString("B") }; + var array_a_e = new UTF8String[] { UTF8String.fromString("ä"), UTF8String.fromString("e") }; + var array_Aa_bB = new UTF8String[] { UTF8String.fromString("Aa"), UTF8String.fromString("bB") }; + // Edge cases + assertStringSplitSQL("", "", "UTF8_BINARY", empty_match); + assertStringSplitSQL("abc", "", "UTF8_BINARY", array_abc); + assertStringSplitSQL("", "abc", "UTF8_BINARY", empty_match); + assertStringSplitSQL("", "", "UNICODE", empty_match); + assertStringSplitSQL("abc", "", "UNICODE", array_abc); + assertStringSplitSQL("", "abc", "UNICODE", empty_match); + assertStringSplitSQL("", "", "UTF8_BINARY_LCASE", empty_match); + assertStringSplitSQL("abc", "", "UTF8_BINARY_LCASE", array_abc); + assertStringSplitSQL("", "abc", "UTF8_BINARY_LCASE", empty_match); + assertStringSplitSQL("", "", "UNICODE_CI", empty_match); + assertStringSplitSQL("abc", "", "UNICODE_CI", array_abc); + assertStringSplitSQL("", "abc", "UNICODE_CI", empty_match); + // Basic tests + assertStringSplitSQL("1a2", "a", "UTF8_BINARY", array_1_2); + assertStringSplitSQL("1a2", "A", "UTF8_BINARY", array_1a2); + assertStringSplitSQL("1a2", "b", "UTF8_BINARY", array_1a2); + assertStringSplitSQL("1a2", "1a2", "UNICODE", full_match); + assertStringSplitSQL("1a2", "1A2", "UNICODE", array_1a2); + assertStringSplitSQL("1a2", "3b4", "UNICODE", array_1a2); + assertStringSplitSQL("1a2", "A", "UTF8_BINARY_LCASE", array_1_2); + assertStringSplitSQL("1a2", "1A2", "UTF8_BINARY_LCASE", 
full_match); + assertStringSplitSQL("1a2", "X", "UTF8_BINARY_LCASE", array_1a2); + assertStringSplitSQL("1a2", "a", "UNICODE_CI", array_1_2); + assertStringSplitSQL("1a2", "A", "UNICODE_CI", array_1_2); + assertStringSplitSQL("1a2", "1A2", "UNICODE_CI", full_match); + assertStringSplitSQL("1a2", "123", "UNICODE_CI", array_1a2); + // Case variation + assertStringSplitSQL("AaXbB", "x", "UTF8_BINARY", array_AaXbB); + assertStringSplitSQL("AaXbB", "X", "UTF8_BINARY", array_Aa_bB); + assertStringSplitSQL("AaXbB", "axb", "UNICODE", array_AaXbB); + assertStringSplitSQL("AaXbB", "aXb", "UNICODE", array_A_B); + assertStringSplitSQL("AaXbB", "axb", "UTF8_BINARY_LCASE", array_A_B); + assertStringSplitSQL("AaXbB", "AXB", "UTF8_BINARY_LCASE", array_A_B); + assertStringSplitSQL("AaXbB", "axb", "UNICODE_CI", array_A_B); + assertStringSplitSQL("AaXbB", "AxB", "UNICODE_CI", array_A_B); + // Accent variation + assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY", array_aBcDe); + assertStringSplitSQL("aBcDe", "BćD", "UTF8_BINARY", array_aBcDe); + assertStringSplitSQL("aBcDe", "abćde", "UNICODE", array_aBcDe); + assertStringSplitSQL("aBcDe", "aBćDe", "UNICODE", array_aBcDe); + assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY_LCASE", array_aBcDe); + assertStringSplitSQL("aBcDe", "BĆD", "UTF8_BINARY_LCASE", array_aBcDe); + assertStringSplitSQL("aBcDe", "abćde", "UNICODE_CI", array_aBcDe); + assertStringSplitSQL("aBcDe", "AbĆdE", "UNICODE_CI", array_aBcDe); + // Variable byte length characters + assertStringSplitSQL("äb世De", "b世D", "UTF8_BINARY", array_a_e); + assertStringSplitSQL("äb世De", "B世d", "UTF8_BINARY", array_special); + assertStringSplitSQL("äbćδe", "bćδ", "UTF8_BINARY", array_a_e); + assertStringSplitSQL("äbćδe", "BcΔ", "UTF8_BINARY", array_abcde); + assertStringSplitSQL("äb世De", "äb世De", "UNICODE", full_match); + assertStringSplitSQL("äb世De", "äB世de", "UNICODE", array_special); + assertStringSplitSQL("äbćδe", "äbćδe", "UNICODE", full_match); + assertStringSplitSQL("äbćδe", 
"ÄBcΔÉ", "UNICODE", array_abcde); + assertStringSplitSQL("äb世De", "b世D", "UTF8_BINARY_LCASE", array_a_e); + assertStringSplitSQL("äb世De", "B世d", "UTF8_BINARY_LCASE", array_a_e); + assertStringSplitSQL("äbćδe", "bćδ", "UTF8_BINARY_LCASE", array_a_e); + assertStringSplitSQL("äbćδe", "BcΔ", "UTF8_BINARY_LCASE", array_abcde); + assertStringSplitSQL("äb世De", "ab世De", "UNICODE_CI", array_special); + assertStringSplitSQL("äb世De", "AB世dE", "UNICODE_CI", array_special); + assertStringSplitSQL("äbćδe", "ÄbćδE", "UNICODE_CI", full_match); + assertStringSplitSQL("äbćδe", "ÄBcΔÉ", "UNICODE_CI", array_abcde); + } private void assertUpper(String target, String collationName, String expected) throws SparkException { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index fd4fc7a542291..612082c56096f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -3187,13 +3187,14 @@ case class Sentences( case class StringSplitSQL( str: Expression, delimiter: Expression) extends BinaryExpression with NullIntolerant { - override def dataType: DataType = ArrayType(StringType, containsNull = false) + override def dataType: DataType = ArrayType(str.dataType, containsNull = false) + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def left: Expression = str override def right: Expression = delimiter override def nullSafeEval(string: Any, delimiter: Any): Any = { - val strings = string.asInstanceOf[UTF8String].splitSQL( - delimiter.asInstanceOf[UTF8String], -1); + val strings = CollationSupport.StringSplitSQL.exec(string.asInstanceOf[UTF8String], + delimiter.asInstanceOf[UTF8String], collationId) new GenericArrayData(strings.asInstanceOf[Array[Any]]) 
} @@ -3201,7 +3202,8 @@ case class StringSplitSQL( val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, delimiter) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"${ev.value} = new $arrayClass($str.splitSQL($delimiter,-1));" + s"${ev.value} = new $arrayClass(" + + s"${CollationSupport.StringSplitSQL.genCode(str, delimiter, collationId)});" }) } @@ -3239,10 +3241,11 @@ case class SplitPart ( partNum: Expression) extends RuntimeReplaceable with ImplicitCastInputTypes { override lazy val replacement: Expression = - ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", StringType)), + ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", str.dataType)), false) override def nodeName: String = "split_part" - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) def children: Seq[Expression] = Seq(str, delimiter, partNum) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = { copy(str = newChildren.apply(0), delimiter = newChildren.apply(1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 9c207df95dadb..2b6761475a43a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -88,6 +88,23 @@ class CollationStringExpressionsSuite assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") } + test("Support SplitPart string expression with collation") { + // Supported collations + case class SplitPartTestCase[R](s: String, d: String, p: Int, c: String, result: R) + val testCases = Seq( + 
SplitPartTestCase("1a2", "a", 2, "UTF8_BINARY", "2"), + SplitPartTestCase("1a2", "a", 2, "UNICODE", "2"), + SplitPartTestCase("1a2", "A", 2, "UTF8_BINARY_LCASE", "2"), + SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2") + ) + testCases.foreach(t => { + val query = s"SELECT split_part(collate('${t.s}','${t.c}'),collate('${t.d}','${t.c}'),${t.p})" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c))) + }) + } + test("Support Contains string expression with collation") { // Supported collations case class ContainsTestCase[R](l: String, r: String, c: String, result: R)