From 0e07873d368fa17dfff11e28fd45531f1f388864 Mon Sep 17 00:00:00 2001
From: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Date: Wed, 31 Jul 2024 17:16:17 +0800
Subject: [PATCH] [SPARK-48977][SQL] Optimize string searching under UTF8_LCASE
 collation

### What changes were proposed in this pull request?
Modify string search under UTF8_LCASE collation by utilizing UTF8String character iterator to reduce one order of algorithmic complexity.

### Why are the changes needed?
Optimize implementation for `contains`, `startsWith`, `endsWith`, `locate` expressions.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47444 from uros-db/optimize-search.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
---
 .../util/CollationAwareUTF8String.java        | 83 ++++++++++++++++---
 1 file changed, 73 insertions(+), 10 deletions(-)

diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
index 430d1fb89832b..5b005f152c51a 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java
@@ -85,13 +85,44 @@ private static int lowercaseMatchLengthFrom(
       final UTF8String lowercasePattern,
       int startPos) {
     assert startPos >= 0;
-    for (int len = 0; len <= target.numChars() - startPos; ++len) {
-      if (lowerCaseCodePoints(target.substring(startPos, startPos + len))
-          .equals(lowercasePattern)) {
-        return len;
+    // Use code point iterators for efficient string search.
+    Iterator<Integer> targetIterator = target.codePointIterator();
+    Iterator<Integer> patternIterator = lowercasePattern.codePointIterator();
+    // Skip to startPos in the target string.
+    for (int i = 0; i < startPos; ++i) {
+      if (targetIterator.hasNext()) {
+        targetIterator.next();
+      } else {
+        return MATCH_NOT_FOUND;
       }
     }
-    return MATCH_NOT_FOUND;
+    // Compare the characters in the target and pattern strings.
+    int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint;
+    while (targetIterator.hasNext() && patternIterator.hasNext()) {
+      if (codePointBuffer != -1) {
+        targetCodePoint = codePointBuffer;
+        codePointBuffer = -1;
+      } else {
+        // Use buffered lowercase code point iteration to handle one-to-many case mappings.
+        targetCodePoint = getLowercaseCodePoint(targetIterator.next());
+        if (targetCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) {
+          targetCodePoint = CODE_POINT_LOWERCASE_I;
+          codePointBuffer = CODE_POINT_COMBINING_DOT;
+        }
+        ++matchLength;
+      }
+      patternCodePoint = patternIterator.next();
+      if (targetCodePoint != patternCodePoint) {
+        return MATCH_NOT_FOUND;
+      }
+    }
+    // If the pattern string has more characters, or the match is found at the middle of a
+    // character that maps to multiple characters in lowercase, then match is not found.
+    if (patternIterator.hasNext() || codePointBuffer != -1) {
+      return MATCH_NOT_FOUND;
+    }
+    // If all characters are equal, return the length of the match in the target string.
+    return matchLength;
   }
 
   /**
@@ -155,13 +186,45 @@ private static int lowercaseMatchLengthUntil(
       final UTF8String target,
       final UTF8String lowercasePattern,
       int endPos) {
-    assert endPos <= target.numChars();
-    for (int len = 0; len <= endPos; ++len) {
-      if (lowerCaseCodePoints(target.substring(endPos - len, endPos)).equals(lowercasePattern)) {
-        return len;
+    assert endPos >= 0;
+    // Use code point iterators for efficient string search.
+    Iterator<Integer> targetIterator = target.reverseCodePointIterator();
+    Iterator<Integer> patternIterator = lowercasePattern.reverseCodePointIterator();
+    // Skip to startPos in the target string.
+    for (int i = endPos; i < target.numChars(); ++i) {
+      if (targetIterator.hasNext()) {
+        targetIterator.next();
+      } else {
+        return MATCH_NOT_FOUND;
       }
     }
-    return MATCH_NOT_FOUND;
+    // Compare the characters in the target and pattern strings.
+    int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint;
+    while (targetIterator.hasNext() && patternIterator.hasNext()) {
+      if (codePointBuffer != -1) {
+        targetCodePoint = codePointBuffer;
+        codePointBuffer = -1;
+      } else {
+        // Use buffered lowercase code point iteration to handle one-to-many case mappings.
+        targetCodePoint = getLowercaseCodePoint(targetIterator.next());
+        if (targetCodePoint == CODE_POINT_COMBINED_LOWERCASE_I_DOT) {
+          targetCodePoint = CODE_POINT_COMBINING_DOT;
+          codePointBuffer = CODE_POINT_LOWERCASE_I;
+        }
+        ++matchLength;
+      }
+      patternCodePoint = patternIterator.next();
+      if (targetCodePoint != patternCodePoint) {
+        return MATCH_NOT_FOUND;
+      }
+    }
+    // If the pattern string has more characters, or the match is found at the middle of a
+    // character that maps to multiple characters in lowercase, then match is not found.
+    if (patternIterator.hasNext() || codePointBuffer != -1) {
+      return MATCH_NOT_FOUND;
+    }
+    // If all characters are equal, return the length of the match in the target string.
+    return matchLength;
   }
 
   /**