diff --git a/velox/docs/functions/spark/string.rst b/velox/docs/functions/spark/string.rst index 3bced7be91b04..fd76762095a7b 100644 --- a/velox/docs/functions/spark/string.rst +++ b/velox/docs/functions/spark/string.rst @@ -195,14 +195,22 @@ Unless specified otherwise, all functions return NULL if at least one of the arg SELECT soundex('Miller'); -- "M460" -.. spark:function:: split(string, delimiter) -> array(string) - - Splits ``string`` on ``delimiter`` and returns an array. - The delimiter is any string matching regex, supported by re2. :: +.. spark:function:: split(string, delimiter[, limit]) -> array(string) + + Splits ``string`` around occurrences that match ``delimiter`` and returns an array + with a length of at most ``limit``. ``delimiter`` is a string representing a regular + expression. ``limit`` is an integer which controls the number of times the regex is + applied. By default, ``limit`` is -1. When ``limit`` > 0, the resulting array's + length will not be more than ``limit``, and the resulting array's last entry will + contain all input beyond the last matched regex. When ``limit`` <= 0, ``delimiter`` will + be applied as many times as possible, and the resulting array can be of any size. :: SELECT split('oneAtwoBthreeC', '[ABC]'); -- ["one","two","three",""] + SELECT split('oneAtwoBthreeC', '[ABC]', 2); -- ["one","twoBthreeC"] SELECT split('one', ''); -- ["o", "n", "e", ""] SELECT split('one', '1'); -- ["one"] + SELECT split('abcd', ''); -- ["a", "b", "c", "d"] + SELECT split('abcd', '', 3); -- ["a", "b", "c"] .. 
spark:function:: split(string, delimiter, limit) -> array(string) :noindex: diff --git a/velox/functions/sparksql/SplitFunctions.cpp b/velox/functions/sparksql/SplitFunctions.cpp index 1c97b6500f98c..dfc22bcb14690 100644 --- a/velox/functions/sparksql/SplitFunctions.cpp +++ b/velox/functions/sparksql/SplitFunctions.cpp @@ -54,18 +54,28 @@ class Split final : public exec::VectorFunction { const auto* rawStrings = strings->data(); const auto delim = delims->valueAt(0); rows.applyToSelected([&](vector_size_t row) { - applyInner(rawStrings[row], delim, limit, row, resultWriter); + if (delim.size() == 0) { + splitEmptyDelimer(rawStrings[row], limit, row, resultWriter); + } else { + splitInner(rawStrings[row], delim, limit, row, resultWriter); + } }); } else { // The rest of the cases are handled through this general path and no // direct access. rows.applyToSelected([&](vector_size_t row) { - applyInner( - strings->valueAt(row), - delims->valueAt(row), - limit, - row, - resultWriter); + const auto delim = delims->valueAt(row); + if (delim.size() == 0) { + splitEmptyDelimer( + strings->valueAt(row), limit, row, resultWriter); + } else { + splitInner( + strings->valueAt(row), + delim, + limit, + row, + resultWriter); + } }); } @@ -78,7 +88,40 @@ class Split final : public exec::VectorFunction { ->acquireSharedStringBuffers(strings->base()); } - inline void applyInner( + private: + mutable functions::detail::ReCache cache_; + + // When pattern is empty, split each character out. Since Spark 3.4, when + // delimiter is empty, the result does not include an empty tail string, e.g. + // split('abc', '') outputs ["a", "b", "c"] instead of ["a", "b", "c", ""]. + // The result does not include remaining string when limit is smaller than the + // string size, e.g. split('abc', '', 2) outputs ["a", "b"] instead of ["a", + // "bc"]. 
+ void splitEmptyDelimer( + const StringView current, + int64_t limit, + vector_size_t row, + exec::VectorWriter>& resultWriter) const { + resultWriter.setOffset(row); + auto& arrayWriter = resultWriter.current(); + if (current.size() == 0) { + arrayWriter.add_item().setNoCopy(StringView()); + resultWriter.commit(); + return; + } + + const char* const begin = current.begin(); + const char* const end = current.end(); + const char* pos = begin; + while (pos < end && pos < limit + begin) { + arrayWriter.add_item().setNoCopy(StringView(pos, 1)); + pos += 1; + } + resultWriter.commit(); + } + + // Split with a non-empty pattern. + void splitInner( StringView input, const StringView delim, int64_t limit, @@ -99,6 +142,6 @@ // adding them to the elements vector, until we reached the end of the // string or the limit. int32_t addedElements{0}; auto* re = cache_.findOrCompile(delim); const auto re2String = re2::StringPiece(input.data(), input.size()); size_t pos = 0; @@ -110,11 +154,6 @@ auto offset = fullMatch.data() - start; const auto size = fullMatch.size(); - if (size == 0) { - // delimer is empty string - offset += 1; - } - if (offset >= input.size()) { break; } @@ -135,9 +174,6 @@ StringView(input.data() + pos, input.size() - pos)); resultWriter.commit(); } - - private: - mutable functions::detail::ReCache cache_; }; std::shared_ptr createSplit( diff --git a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp index 6dd6b5e06d792..1de8b6df50bd4 100644 --- a/velox/functions/sparksql/tests/SplitFunctionsTest.cpp +++ b/velox/functions/sparksql/tests/SplitFunctionsTest.cpp @@ -254,6 +254,13 @@ TEST_F(SplitTest, split) { {""}, }); assertEqualVectors(expected, run(inputStrings, delim, "split(C0, C1)")); + auto 
expected2 = makeArrayVector({ + {"I", ","}, + {"o", "n"}, + {""}, + }); + assertEqualVectors( + expected2, run(inputStrings, delim, "split(C0, C1, C2)", 2)); // Non-ascii, flat strings, flat delimiter, no limit. delim = "లేదా";