Skip to content

Commit

Permalink
[SPARK-47350][SQL] Add collation support for SplitPart string expression
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Introduce collation awareness for string expression: split_part.

### Why are the changes needed?
Add collation support for built-in string function in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use collated strings within arguments for built-in string function: split_part.

### How was this patch tested?
Unit collation support tests and e2e sql tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46158 from uros-db/SPARK-47350.

Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
uros-db authored and cloud-fan committed Apr 26, 2024
1 parent b8b6d17 commit 6c40214
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import org.apache.spark.unsafe.types.UTF8String;

import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

/**
Expand All @@ -36,6 +38,62 @@ public final class CollationSupport {
* Collation-aware string expressions.
*/

public static class StringSplitSQL {
public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return execBinary(s, d);
} else if (collation.supportsLowercaseEquality) {
return execLowercase(s, d);
} else {
return execICU(s, d, collationId);
}
}
public static String genCode(final String s, final String d, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
String expr = "CollationSupport.StringSplitSQL.exec";
if (collation.supportsBinaryEquality) {
return String.format(expr + "Binary(%s, %s)", s, d);
} else if (collation.supportsLowercaseEquality) {
return String.format(expr + "Lowercase(%s, %s)", s, d);
} else {
return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId);
}
}
public static UTF8String[] execBinary(final UTF8String string, final UTF8String delimiter) {
return string.splitSQL(delimiter, -1);
}
public static UTF8String[] execLowercase(final UTF8String string, final UTF8String delimiter) {
if (delimiter.numBytes() == 0) return new UTF8String[] { string };
if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 };
Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()),
CollationSupport.lowercaseRegexFlags);
String[] splits = pattern.split(string.toString(), -1);
UTF8String[] res = new UTF8String[splits.length];
for (int i = 0; i < res.length; i++) {
res[i] = UTF8String.fromString(splits[i]);
}
return res;
}
public static UTF8String[] execICU(final UTF8String string, final UTF8String delimiter,
final int collationId) {
if (delimiter.numBytes() == 0) return new UTF8String[] { string };
if (string.numBytes() == 0) return new UTF8String[] { UTF8String.EMPTY_UTF8 };
List<UTF8String> strings = new ArrayList<>();
String target = string.toString(), pattern = delimiter.toString();
StringSearch stringSearch = CollationFactory.getStringSearch(target, pattern, collationId);
int start = 0, end;
while ((end = stringSearch.next()) != StringSearch.DONE) {
strings.add(UTF8String.fromString(target.substring(start, end)));
start = end + stringSearch.getMatchLength();
}
if (start <= target.length()) {
strings.add(UTF8String.fromString(target.substring(start)));
}
return strings.toArray(new UTF8String[0]);
}
}

public static class Contains {
public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) {
CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,92 @@ public void testEndsWith() throws SparkException {
assertEndsWith("The KKelvin", "KKelvin,", "UTF8_BINARY_LCASE", false);
}

private void assertStringSplitSQL(String str, String delimiter, String collationName,
UTF8String[] expected) throws SparkException {
UTF8String s = UTF8String.fromString(str);
UTF8String d = UTF8String.fromString(delimiter);
int collationId = CollationFactory.collationNameToId(collationName);
assertArrayEquals(expected, CollationSupport.StringSplitSQL.exec(s, d, collationId));
}

@Test
public void testStringSplitSQL() throws SparkException {
// Possible splits
var empty_match = new UTF8String[] { UTF8String.fromString("") };
var array_abc = new UTF8String[] { UTF8String.fromString("abc") };
var array_1a2 = new UTF8String[] { UTF8String.fromString("1a2") };
var array_AaXbB = new UTF8String[] { UTF8String.fromString("AaXbB") };
var array_aBcDe = new UTF8String[] { UTF8String.fromString("aBcDe") };
var array_special = new UTF8String[] { UTF8String.fromString("äb世De") };
var array_abcde = new UTF8String[] { UTF8String.fromString("äbćδe") };
var full_match = new UTF8String[] { UTF8String.fromString(""), UTF8String.fromString("") };
var array_1_2 = new UTF8String[] { UTF8String.fromString("1"), UTF8String.fromString("2") };
var array_A_B = new UTF8String[] { UTF8String.fromString("A"), UTF8String.fromString("B") };
var array_a_e = new UTF8String[] { UTF8String.fromString("ä"), UTF8String.fromString("e") };
var array_Aa_bB = new UTF8String[] { UTF8String.fromString("Aa"), UTF8String.fromString("bB") };
// Edge cases
assertStringSplitSQL("", "", "UTF8_BINARY", empty_match);
assertStringSplitSQL("abc", "", "UTF8_BINARY", array_abc);
assertStringSplitSQL("", "abc", "UTF8_BINARY", empty_match);
assertStringSplitSQL("", "", "UNICODE", empty_match);
assertStringSplitSQL("abc", "", "UNICODE", array_abc);
assertStringSplitSQL("", "abc", "UNICODE", empty_match);
assertStringSplitSQL("", "", "UTF8_BINARY_LCASE", empty_match);
assertStringSplitSQL("abc", "", "UTF8_BINARY_LCASE", array_abc);
assertStringSplitSQL("", "abc", "UTF8_BINARY_LCASE", empty_match);
assertStringSplitSQL("", "", "UNICODE_CI", empty_match);
assertStringSplitSQL("abc", "", "UNICODE_CI", array_abc);
assertStringSplitSQL("", "abc", "UNICODE_CI", empty_match);
// Basic tests
assertStringSplitSQL("1a2", "a", "UTF8_BINARY", array_1_2);
assertStringSplitSQL("1a2", "A", "UTF8_BINARY", array_1a2);
assertStringSplitSQL("1a2", "b", "UTF8_BINARY", array_1a2);
assertStringSplitSQL("1a2", "1a2", "UNICODE", full_match);
assertStringSplitSQL("1a2", "1A2", "UNICODE", array_1a2);
assertStringSplitSQL("1a2", "3b4", "UNICODE", array_1a2);
assertStringSplitSQL("1a2", "A", "UTF8_BINARY_LCASE", array_1_2);
assertStringSplitSQL("1a2", "1A2", "UTF8_BINARY_LCASE", full_match);
assertStringSplitSQL("1a2", "X", "UTF8_BINARY_LCASE", array_1a2);
assertStringSplitSQL("1a2", "a", "UNICODE_CI", array_1_2);
assertStringSplitSQL("1a2", "A", "UNICODE_CI", array_1_2);
assertStringSplitSQL("1a2", "1A2", "UNICODE_CI", full_match);
assertStringSplitSQL("1a2", "123", "UNICODE_CI", array_1a2);
// Case variation
assertStringSplitSQL("AaXbB", "x", "UTF8_BINARY", array_AaXbB);
assertStringSplitSQL("AaXbB", "X", "UTF8_BINARY", array_Aa_bB);
assertStringSplitSQL("AaXbB", "axb", "UNICODE", array_AaXbB);
assertStringSplitSQL("AaXbB", "aXb", "UNICODE", array_A_B);
assertStringSplitSQL("AaXbB", "axb", "UTF8_BINARY_LCASE", array_A_B);
assertStringSplitSQL("AaXbB", "AXB", "UTF8_BINARY_LCASE", array_A_B);
assertStringSplitSQL("AaXbB", "axb", "UNICODE_CI", array_A_B);
assertStringSplitSQL("AaXbB", "AxB", "UNICODE_CI", array_A_B);
// Accent variation
assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY", array_aBcDe);
assertStringSplitSQL("aBcDe", "BćD", "UTF8_BINARY", array_aBcDe);
assertStringSplitSQL("aBcDe", "abćde", "UNICODE", array_aBcDe);
assertStringSplitSQL("aBcDe", "aBćDe", "UNICODE", array_aBcDe);
assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY_LCASE", array_aBcDe);
assertStringSplitSQL("aBcDe", "BĆD", "UTF8_BINARY_LCASE", array_aBcDe);
assertStringSplitSQL("aBcDe", "abćde", "UNICODE_CI", array_aBcDe);
assertStringSplitSQL("aBcDe", "AbĆdE", "UNICODE_CI", array_aBcDe);
// Variable byte length characters
assertStringSplitSQL("äb世De", "b世D", "UTF8_BINARY", array_a_e);
assertStringSplitSQL("äb世De", "B世d", "UTF8_BINARY", array_special);
assertStringSplitSQL("äbćδe", "bćδ", "UTF8_BINARY", array_a_e);
assertStringSplitSQL("äbćδe", "BcΔ", "UTF8_BINARY", array_abcde);
assertStringSplitSQL("äb世De", "äb世De", "UNICODE", full_match);
assertStringSplitSQL("äb世De", "äB世de", "UNICODE", array_special);
assertStringSplitSQL("äbćδe", "äbćδe", "UNICODE", full_match);
assertStringSplitSQL("äbćδe", "ÄBcΔÉ", "UNICODE", array_abcde);
assertStringSplitSQL("äb世De", "b世D", "UTF8_BINARY_LCASE", array_a_e);
assertStringSplitSQL("äb世De", "B世d", "UTF8_BINARY_LCASE", array_a_e);
assertStringSplitSQL("äbćδe", "bćδ", "UTF8_BINARY_LCASE", array_a_e);
assertStringSplitSQL("äbćδe", "BcΔ", "UTF8_BINARY_LCASE", array_abcde);
assertStringSplitSQL("äb世De", "ab世De", "UNICODE_CI", array_special);
assertStringSplitSQL("äb世De", "AB世dE", "UNICODE_CI", array_special);
assertStringSplitSQL("äbćδe", "ÄbćδE", "UNICODE_CI", full_match);
assertStringSplitSQL("äbćδe", "ÄBcΔÉ", "UNICODE_CI", array_abcde);
}

private void assertUpper(String target, String collationName, String expected)
throws SparkException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3187,21 +3187,23 @@ case class Sentences(
case class StringSplitSQL(
str: Expression,
delimiter: Expression) extends BinaryExpression with NullIntolerant {
override def dataType: DataType = ArrayType(StringType, containsNull = false)
override def dataType: DataType = ArrayType(str.dataType, containsNull = false)
final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId
override def left: Expression = str
override def right: Expression = delimiter

override def nullSafeEval(string: Any, delimiter: Any): Any = {
val strings = string.asInstanceOf[UTF8String].splitSQL(
delimiter.asInstanceOf[UTF8String], -1);
val strings = CollationSupport.StringSplitSQL.exec(string.asInstanceOf[UTF8String],
delimiter.asInstanceOf[UTF8String], collationId)
new GenericArrayData(strings.asInstanceOf[Array[Any]])
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, delimiter) => {
// Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
s"${ev.value} = new $arrayClass($str.splitSQL($delimiter,-1));"
s"${ev.value} = new $arrayClass(" +
s"${CollationSupport.StringSplitSQL.genCode(str, delimiter, collationId)});"
})
}

Expand Down Expand Up @@ -3239,10 +3241,11 @@ case class SplitPart (
partNum: Expression)
extends RuntimeReplaceable with ImplicitCastInputTypes {
override lazy val replacement: Expression =
ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", StringType)),
ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", str.dataType)),
false)
override def nodeName: String = "split_part"
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
def children: Seq[Expression] = Seq(str, delimiter, partNum)
protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
copy(str = newChildren.apply(0), delimiter = newChildren.apply(1),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ class CollationStringExpressionsSuite
assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
}

test("Support SplitPart string expression with collation") {
// Supported collations
case class SplitPartTestCase[R](s: String, d: String, p: Int, c: String, result: R)
val testCases = Seq(
SplitPartTestCase("1a2", "a", 2, "UTF8_BINARY", "2"),
SplitPartTestCase("1a2", "a", 2, "UNICODE", "2"),
SplitPartTestCase("1a2", "A", 2, "UTF8_BINARY_LCASE", "2"),
SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2")
)
testCases.foreach(t => {
val query = s"SELECT split_part(collate('${t.s}','${t.c}'),collate('${t.d}','${t.c}'),${t.p})"
// Result & data type
checkAnswer(sql(query), Row(t.result))
assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
})
}

test("Support Contains string expression with collation") {
// Supported collations
case class ContainsTestCase[R](l: String, r: String, c: String, result: R)
Expand Down

0 comments on commit 6c40214

Please sign in to comment.