From 87e1f6b89a5839fa21f489a3db942e1d964efd54 Mon Sep 17 00:00:00 2001 From: arash andishgar Date: Wed, 19 Feb 2025 16:47:29 +0330 Subject: [PATCH 1/2] add extract_regex_span function --- cpp/src/arrow/compute/api_scalar.cc | 11 +- cpp/src/arrow/compute/api_scalar.h | 10 + .../compute/kernels/scalar_string_ascii.cc | 197 +++++++++++++++--- .../compute/kernels/scalar_string_test.cc | 49 ++++- 4 files changed, 232 insertions(+), 35 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 61a16f5f5eb9b..e6606ba53eda8 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -30,7 +30,6 @@ #include "arrow/type.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" - namespace arrow { namespace internal { @@ -325,6 +324,9 @@ static auto kElementWiseAggregateOptionsType = DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls)); static auto kExtractRegexOptionsType = GetFunctionOptionsType( DataMember("pattern", &ExtractRegexOptions::pattern)); +static auto kExtractRegexSpanOptionsType = + GetFunctionOptionsType( + DataMember("pattern", &ExtractRegexSpanOptions::pattern)); static auto kJoinOptionsType = GetFunctionOptionsType( DataMember("null_handling", &JoinOptions::null_handling), DataMember("null_replacement", &JoinOptions::null_replacement)); @@ -438,6 +440,12 @@ ExtractRegexOptions::ExtractRegexOptions(std::string pattern) ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {} constexpr char ExtractRegexOptions::kTypeName[]; +ExtractRegexSpanOptions::ExtractRegexSpanOptions(std::string pattern) + : FunctionOptions(internal::kExtractRegexSpanOptionsType), + pattern(std::move(pattern)) {} +ExtractRegexSpanOptions::ExtractRegexSpanOptions() : ExtractRegexSpanOptions("") {} +constexpr char ExtractRegexSpanOptions::kTypeName[]; + JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement) : FunctionOptions(internal::kJoinOptionsType), null_handling(null_handling), @@ -684,6 +692,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType)); + DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexSpanOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kListSliceOptionsType)); DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType)); diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 0e5a388b1074f..3e299ac134f79 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -264,7 +264,17 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { /// Regular expression with named capture fields std::string pattern; }; +class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions { + public: + explicit ExtractRegexSpanOptions(std::string pattern); + ExtractRegexSpanOptions(); + static constexpr char const kTypeName[] = "ExtractRegexSpanOptions"; + /// Regular expression with named capture fields + std::string pattern; + + /// Shows the matched string +}; /// Options for IsIn and IndexIn functions class ARROW_EXPORT SetLookupOptions : public FunctionOptions { public: diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index e58f7b065a8e5..535bc7979a380 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -22,6 +22,7 @@ #include #include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" #include "arrow/compute/kernels/scalar_string_internal.h" #include "arrow/result.h" #include "arrow/util/config.h" @@ -2185,51 +2186,61 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) { using ExtractRegexState = OptionsWrapper; // TODO cache this once per ExtractRegexOptions -struct ExtractRegexData { - // Use unique_ptr<> because RE2 is non-movable (for ARROW_ASSIGN_OR_RAISE) - std::unique_ptr regex; - std::vector group_names; - +class ExtractRegexData { + public: static Result Make(const ExtractRegexOptions& options, bool is_utf8 = true) { ExtractRegexData data(options.pattern, is_utf8); - RETURN_NOT_OK(RegexStatus(*data.regex)); - - const int group_count = data.regex->NumberOfCapturingGroups(); - const auto& name_map = data.regex->CapturingGroupNames(); - data.group_names.reserve(group_count); - - for (int i = 0; i < group_count; i++) { - auto item = name_map.find(i + 1); // re2 starts counting from 1 - if (item == name_map.end()) { - // XXX should we instead just create fields with an empty name? - return Status::Invalid("Regular expression contains unnamed groups"); - } - data.group_names.emplace_back(item->second); - } + ARROW_RETURN_NOT_OK(data.Init()); return data; } Result ResolveOutputType(const std::vector& types) const { const DataType* input_type = types[0].type; - if (input_type == nullptr) { + // as mentioned here + // https://arrow.apache.org/docs/developers/cpp/development.html#code-style-linting-and-ci + // nullptr should not be used + if (input_type == NULLPTR) { // No input type specified - return nullptr; + return NULLPTR; } // Input type is either [Large]Binary or [Large]String and is also the type // of each field in the output struct type. DCHECK(is_base_binary_like(input_type->id())); FieldVector fields; - fields.reserve(group_names.size()); + fields.reserve(group_names_.size()); std::shared_ptr owned_type = input_type->GetSharedPtr(); - std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields), + std::transform(group_names_.begin(), group_names_.end(), std::back_inserter(fields), [&](const std::string& name) { return field(name, owned_type); }); - return struct_(std::move(fields)); + return struct_(fields); } + int64_t num_group() const { return group_names_.size(); } + std::shared_ptr regex() const { return regex_; } - private: + protected: explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true) - : regex(new RE2(pattern, MakeRE2Options(is_utf8))) {} + : regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {} + + Status Init() { + RETURN_NOT_OK(RegexStatus(*regex_)); + + const int group_count = regex_->NumberOfCapturingGroups(); + const auto& name_map = regex_->CapturingGroupNames(); + group_names_.reserve(group_count); + + for (int i = 0; i < group_count; i++) { + auto item = name_map.find(i + 1); // re2 starts counting from 1 + if (item == name_map.end()) { + // XXX should we instead just create fields with an empty name? + return Status::Invalid("Regular expression contains unnamed groups"); + } + group_names_.emplace_back(item->second); + } + return Status::OK(); + } + + std::shared_ptr regex_; + std::vector group_names_; }; Result ResolveExtractRegexOutput(KernelContext* ctx, @@ -2250,7 +2261,7 @@ struct ExtractRegexBase { explicit ExtractRegexBase(const ExtractRegexData& data) : data(data), - group_count(static_cast(data.group_names.size())), + group_count(static_cast(data.num_group())), found_values(group_count) { args.reserve(group_count); args_pointers.reserve(group_count); @@ -2265,7 +2276,7 @@ struct ExtractRegexBase { } bool Match(std::string_view s) { - return RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start, + return RE2::PartialMatchN(ToStringPiece(s), *data.regex(), args_pointers_start, group_count); } }; @@ -2284,11 +2295,10 @@ struct ExtractRegex : public ExtractRegexBase { } Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - // TODO: why is this needed? Type resolution should already be - // done and the output type set in the output variable - ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, data.ResolveOutputType(batch.GetTypes())); - DCHECK_NE(out_type.type, nullptr); - std::shared_ptr type = out_type.GetSharedPtr(); + ExtractRegexOptions options = ExtractRegexState::Get(ctx); + DCHECK_NE(out->array_data(), NULLPTR); + std::shared_ptr type = out->array_data()->type; + DCHECK_NE(type, NULLPTR); std::unique_ptr array_builder; RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder)); @@ -2347,6 +2357,126 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } +class ExtractRegexSpanData : public ExtractRegexData { + public: + static Result Make(const std::string& pattern) { + auto data = ExtractRegexSpanData(pattern, true); + ARROW_RETURN_NOT_OK(data.Init()); + return data; + } + + Result ResolveOutputType(const std::vector& types) const { + const DataType* input_type = types[0].type; + if (input_type == NULLPTR) { + return NULLPTR; + } + DCHECK(is_base_binary_like(input_type->id())); + const size_t field_count = group_names_.size(); + FieldVector fields; + fields.reserve(field_count); + const auto owned_type = input_type->GetSharedPtr(); + for (const auto& group_name : group_names_) { + auto type = is_binary_like(owned_type->id()) ? int32() : int64(); + // size list is 2 as every span contains position and length + fields.push_back(field(group_name + "_span", fixed_size_list(type, 2))); + } + return struct_(fields); + } + + private: + ExtractRegexSpanData(const std::string& pattern, const bool is_utf8) + : ExtractRegexData(pattern, is_utf8) {} +}; + +template +struct ExtractRegexSpan : ExtractRegexBase { + using ArrayType = typename TypeTraits::ArrayType; + using BuilderType = typename TypeTraits::BuilderType; + using ExtractRegexBase::ExtractRegexBase; + + static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + auto options = OptionsWrapper::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexSpanData::Make(options.pattern)); + return ExtractRegexSpan{data}.Extract(ctx, batch, out); + } + Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + DCHECK_NE(out->array_data(), NULLPTR); + std::shared_ptr out_type = out->array_data()->type; + DCHECK_NE(out_type, NULLPTR); + std::unique_ptr out_builder; + ARROW_RETURN_NOT_OK( + MakeBuilder(ctx->memory_pool(), out->type()->GetSharedPtr(), &out_builder)); + auto struct_builder = checked_pointer_cast(std::move(out_builder)); + std::vector span_builders; + std::vector array_builders; + span_builders.reserve(group_count); + array_builders.reserve(group_count); + for (int i = 0; i < group_count; i++) { + span_builders.push_back( + checked_cast(struct_builder->field_builder(i))); + array_builders.push_back(span_builders[i]->value_builder()); + } + auto visit_null = [&]() { return struct_builder->AppendNull(); }; + auto visit_value = [&](std::string_view element) -> Status { + if (Match(element)) { + for (int i = 0; i < group_count; i++) { + // https://github.com/google/re2/issues/24#issuecomment-97653183 + if (found_values[i].data() != NULLPTR) { + int64_t begin = found_values[i].data() - element.data(); + int64_t size = found_values[i].size(); + if (is_binary_like(batch.GetTypes()[0].id())) { + ARROW_RETURN_NOT_OK(checked_cast(array_builders[i]) + ->AppendValues({static_cast(begin), + static_cast(size)})); + } else { + ARROW_RETURN_NOT_OK(checked_cast(array_builders[i]) + ->AppendValues({begin, size})); + } + + ARROW_RETURN_NOT_OK(span_builders[i]->Append()); + } else { + ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull()); + } + } + ARROW_RETURN_NOT_OK(struct_builder->Append()); + } else { + ARROW_RETURN_NOT_OK(struct_builder->AppendNull()); + } + return Status::OK(); + }; + ARROW_RETURN_NOT_OK( + VisitArraySpanInline(batch[0].array, visit_value, visit_null)); + + ARROW_ASSIGN_OR_RAISE(auto out_array, struct_builder->Finish()); + out->value = out_array->data(); + return Status::OK(); + } +}; + +const FunctionDoc extract_regex_doc_span( + "likes extract_regex; however, it contains the position and length of results", "", + {"strings"}, "ExtractRegexSpanOptions", true); + +Result resolver(KernelContext* ctx, const std::vector& types) { + auto options = OptionsWrapper::Get(*ctx->state()); + ARROW_ASSIGN_OR_RAISE(auto span, ExtractRegexSpanData::Make(options.pattern)); + return span.ResolveOutputType(types); +} + +void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) { + auto func = std::make_shared("extract_regex_span", Arity::Unary(), + extract_regex_doc_span); + OutputType output_type(resolver); + for (const auto& type : BaseBinaryTypes()) { + ScalarKernel kernel({type}, output_type, + GenerateVarBinaryToVarBinary(type), + OptionsWrapper::Init); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + DCHECK_OK(func->AddKernel(std::move(kernel))); + } + DCHECK_OK(registry->AddFunction(func)); +} #endif // ARROW_WITH_RE2 // ---------------------------------------------------------------------- @@ -3457,6 +3587,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddAsciiStringSplitWhitespace(registry); #ifdef ARROW_WITH_RE2 AddAsciiStringSplitRegex(registry); + AddAsciiStringExtractRegexSpan(registry); #endif AddAsciiStringJoin(registry); AddAsciiStringRepeat(registry); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 38455dc146711..4023491dee5b0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -314,6 +314,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8Regex) { this->MakeArray({"\xfc\x40", "this \xfc\x40 that \xfc\x40"}), this->MakeArray({"bazz", "this bazz that \xfc\x40"}), &options); } + // TODO the following test is broken { ExtractRegexOptions options("(?P[\\xfc])(?P\\d)"); auto null_bitmap = std::make_shared("0"); @@ -370,6 +371,7 @@ TYPED_TEST(TestBinaryKernels, NonUtf8WithNullRegex) { this->template MakeArray({{"\x00\x40", 2}}), this->type(), R"(["bazz"])", &options); } + // TODO the following test is broken { ExtractRegexOptions options("(?P[\\x00])(?P\\d)"); auto null_bitmap = std::make_shared("0"); @@ -1958,6 +1960,29 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) { R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])", &options); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSapn) { + ExtractRegexSpanOptions options{"(?P[ab])(?P\\d)"}; + auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64(); + auto out_type = struct_({field("letter_span", fixed_size_list(type_fixe_size_list, 2)), + field("digit_span", fixed_size_list(type_fixe_size_list, 2))}); + this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "b2", "c3", null])", out_type, + R"([{"letter_span":[0,1], "digit_span":[1,1]}, {"letter_span":[0,1], "digit_span":[1,1]}, null, null])", + &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "c3", null, "b2"])", out_type, + R"([{"letter_span":[0,1], "digit_span": [1,1]}, null, null, {"letter_span":[0,1], "digit_span":[1,1]}])", + &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "b2"])", out_type, + R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [0,1], "digit_span": [1,1]}])", + &options); + this->CheckUnary( + "extract_regex_span", R"(["a1", "zb3z"])", out_type, + R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [1,1], "digit_span": [2,1]}])", + &options); +} TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) { // XXX Should we accept this or is it a user error? @@ -1966,12 +1991,23 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) { this->CheckUnary("extract_regex", R"(["oofoo", "bar", null])", type, R"([{}, null, null])", &options); } - +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoCapture) { + // XXX Should we accept this or is it a user error? + ExtractRegexSpanOptions options{"foo"}; + auto type = struct_({}); + this->CheckUnary("extract_regex_span", R"(["oofoo", "bar", null])", type, + R"([{}, null, null])", &options); +} TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoOptions) { Datum input = ArrayFromJSON(this->type(), "[]"); ASSERT_RAISES(Invalid, CallFunction("extract_regex", {input})); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("extract_regex_span", {input})); +} + TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) { Datum input = ArrayFromJSON(this->type(), "[]"); ExtractRegexOptions options{"invalid["}; @@ -1984,6 +2020,17 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegexInvalid) { Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"), CallFunction("extract_regex", {input}, &options)); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanInvalid) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ExtractRegexSpanOptions options{"invalid["}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("extract_regex_span", {input}, &options)); + options = ExtractRegexSpanOptions{"(.)"}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"), + CallFunction("extract_regex_span", {input}, &options)); +} #endif From 2b82534788f71cb74ec25d966f97fd841f9846bb Mon Sep 17 00:00:00 2001 From: arash andishgar Date: Wed, 26 Feb 2025 15:53:15 +0330 Subject: [PATCH 2/2] fix issue mentioned by comment --- cpp/src/arrow/compute/api_scalar.h | 2 - .../compute/kernels/scalar_string_ascii.cc | 126 ++++++++++-------- .../compute/kernels/scalar_string_test.cc | 32 ++++- 3 files changed, 93 insertions(+), 67 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 3e299ac134f79..8a0b4a0cd151d 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -272,8 +272,6 @@ class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions { /// Regular expression with named capture fields std::string pattern; - - /// Shows the matched string }; /// Options for IsIn and IndexIn functions class ARROW_EXPORT SetLookupOptions : public FunctionOptions { diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index 535bc7979a380..3eda5273dbfac 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -43,7 +43,6 @@ namespace compute { namespace internal { namespace { - // ---------------------------------------------------------------------- // re2 utilities @@ -2185,9 +2184,35 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) { using ExtractRegexState = OptionsWrapper; +struct BaseExtractRegexData { + Status Init() { + RETURN_NOT_OK(RegexStatus(*regex_)); + + const int group_count = regex_->NumberOfCapturingGroups(); + const auto& name_map = regex_->CapturingGroupNames(); + group_names_.reserve(group_count); + + for (int i = 0; i < group_count; i++) { + auto item = name_map.find(i + 1); // re2 starts counting from 1 + if (item == name_map.end()) { + // XXX should we instead just create fields with an empty name? + return Status::Invalid("Regular expression contains unnamed groups"); + } + group_names_.emplace_back(item->second); + } + return Status::OK(); + } + int64_t num_group() const { return group_names_.size(); } + std::unique_ptr regex_; + std::vector group_names_; + + protected: + explicit BaseExtractRegexData(const std::string& pattern, bool is_utf8 = true) + : regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {} +}; + // TODO cache this once per ExtractRegexOptions -class ExtractRegexData { - public: +struct ExtractRegexData : public BaseExtractRegexData { static Result Make(const ExtractRegexOptions& options, bool is_utf8 = true) { ExtractRegexData data(options.pattern, is_utf8); @@ -2197,12 +2222,9 @@ class ExtractRegexData { Result ResolveOutputType(const std::vector& types) const { const DataType* input_type = types[0].type; - // as mentioned here - // https://arrow.apache.org/docs/developers/cpp/development.html#code-style-linting-and-ci - // nullptr should not be used - if (input_type == NULLPTR) { + if (input_type == nullptr) { // No input type specified - return NULLPTR; + return nullptr; } // Input type is either [Large]Binary or [Large]String and is also the type // of each field in the output struct type. @@ -2212,35 +2234,12 @@ class ExtractRegexData { std::shared_ptr owned_type = input_type->GetSharedPtr(); std::transform(group_names_.begin(), group_names_.end(), std::back_inserter(fields), [&](const std::string& name) { return field(name, owned_type); }); - return struct_(fields); + return struct_(std::move(fields)); } - int64_t num_group() const { return group_names_.size(); } - std::shared_ptr regex() const { return regex_; } - protected: + private: explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true) - : regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {} - - Status Init() { - RETURN_NOT_OK(RegexStatus(*regex_)); - - const int group_count = regex_->NumberOfCapturingGroups(); - const auto& name_map = regex_->CapturingGroupNames(); - group_names_.reserve(group_count); - - for (int i = 0; i < group_count; i++) { - auto item = name_map.find(i + 1); // re2 starts counting from 1 - if (item == name_map.end()) { - // XXX should we instead just create fields with an empty name? - return Status::Invalid("Regular expression contains unnamed groups"); - } - group_names_.emplace_back(item->second); - } - return Status::OK(); - } - - std::shared_ptr regex_; - std::vector group_names_; + : BaseExtractRegexData(pattern, is_utf8) {} }; Result ResolveExtractRegexOutput(KernelContext* ctx, @@ -2251,7 +2250,7 @@ Result ResolveExtractRegexOutput(KernelContext* ctx, } struct ExtractRegexBase { - const ExtractRegexData& data; + const BaseExtractRegexData& data; const int group_count; std::vector found_values; std::vector args; @@ -2259,7 +2258,7 @@ struct ExtractRegexBase { const RE2::Arg** args_pointers_start; const RE2::Arg* null_arg = nullptr; - explicit ExtractRegexBase(const ExtractRegexData& data) + explicit ExtractRegexBase(const BaseExtractRegexData& data) : data(data), group_count(static_cast(data.num_group())), found_values(group_count) { @@ -2276,7 +2275,7 @@ struct ExtractRegexBase { } bool Match(std::string_view s) { - return RE2::PartialMatchN(ToStringPiece(s), *data.regex(), args_pointers_start, + return RE2::PartialMatchN(ToStringPiece(s), *data.regex_, args_pointers_start, group_count); } }; @@ -2291,7 +2290,7 @@ struct ExtractRegex : public ExtractRegexBase { static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { ExtractRegexOptions options = ExtractRegexState::Get(ctx); ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8)); - return ExtractRegex{data}.Extract(ctx, batch, out); + return ExtractRegex(data).Extract(ctx, batch, out); } Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { @@ -2357,8 +2356,7 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) { } DCHECK_OK(registry->AddFunction(std::move(func))); } -class ExtractRegexSpanData : public ExtractRegexData { - public: +struct ExtractRegexSpanData : public BaseExtractRegexData { static Result Make(const std::string& pattern) { auto data = ExtractRegexSpanData(pattern, true); ARROW_RETURN_NOT_OK(data.Init()); @@ -2378,20 +2376,26 @@ class ExtractRegexSpanData : public ExtractRegexData { for (const auto& group_name : group_names_) { auto type = is_binary_like(owned_type->id()) ? int32() : int64(); // size list is 2 as every span contains position and length - fields.push_back(field(group_name + "_span", fixed_size_list(type, 2))); + fields.push_back(field(group_name, fixed_size_list(type, 2))); } return struct_(fields); } private: ExtractRegexSpanData(const std::string& pattern, const bool is_utf8) - : ExtractRegexData(pattern, is_utf8) {} + : BaseExtractRegexData(pattern, is_utf8) {} }; template struct ExtractRegexSpan : ExtractRegexBase { using ArrayType = typename TypeTraits::ArrayType; using BuilderType = typename TypeTraits::BuilderType; + using offset_type = typename Type::offset_type; + using OffsetBuilderType = + typename TypeTraits::ArrowType>::BuilderType; + using OffsetCType = + typename TypeTraits::ArrowType>::CType; + using ExtractRegexBase::ExtractRegexBase; static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { @@ -2407,15 +2411,20 @@ struct ExtractRegexSpan : ExtractRegexBase { ARROW_RETURN_NOT_OK( MakeBuilder(ctx->memory_pool(), out->type()->GetSharedPtr(), &out_builder)); auto struct_builder = checked_pointer_cast(std::move(out_builder)); + ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length)); std::vector span_builders; - std::vector array_builders; + std::vector array_builders; span_builders.reserve(group_count); array_builders.reserve(group_count); for (int i = 0; i < group_count; i++) { span_builders.push_back( checked_cast(struct_builder->field_builder(i))); - array_builders.push_back(span_builders[i]->value_builder()); + array_builders.push_back( + checked_cast(span_builders[i]->value_builder())); + RETURN_NOT_OK(span_builders.back()->Reserve(batch[0].array.length)); + RETURN_NOT_OK(array_builders.back()->Reserve(2 * batch[0].array.length)); } + auto visit_null = [&]() { return struct_builder->AppendNull(); }; auto visit_value = [&](std::string_view element) -> Status { if (Match(element)) { @@ -2424,15 +2433,8 @@ struct ExtractRegexSpan : ExtractRegexBase { if (found_values[i].data() != NULLPTR) { int64_t begin = found_values[i].data() - element.data(); int64_t size = found_values[i].size(); - if (is_binary_like(batch.GetTypes()[0].id())) { - ARROW_RETURN_NOT_OK(checked_cast(array_builders[i]) - ->AppendValues({static_cast(begin), - static_cast(size)})); - } else { - ARROW_RETURN_NOT_OK(checked_cast(array_builders[i]) - ->AppendValues({begin, size})); - } - + array_builders[i]->UnsafeAppend(static_cast(begin)); + array_builders[i]->UnsafeAppend(static_cast(size)); ARROW_RETURN_NOT_OK(span_builders[i]->Append()); } else { ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull()); @@ -2453,11 +2455,19 @@ struct ExtractRegexSpan : ExtractRegexBase { } }; -const FunctionDoc extract_regex_doc_span( - "likes extract_regex; however, it contains the position and length of results", "", +const FunctionDoc extract_regex_span_doc( + "Extract substrings captured by a regex pattern and Save the result in the form of " + "(offset,length)", + "For each string in strings, match the regular expression and, if\n" + "successful, emit a struct with field names and values coming from the\n" + "regular expression's named capture groups, which are stored in a form of a\n " + "fixed_size_list(offset, length). If the input is null or the regular \n" + "expression Fails matching, a null output value is emitted.\n" + "Regular expression matching is done using the Google RE2 library.", {"strings"}, "ExtractRegexSpanOptions", true); -Result resolver(KernelContext* ctx, const std::vector& types) { +Result ResolveExtractRegexSpanOutputType( + KernelContext* ctx, const std::vector& types) { auto options = OptionsWrapper::Get(*ctx->state()); ARROW_ASSIGN_OR_RAISE(auto span, ExtractRegexSpanData::Make(options.pattern)); return span.ResolveOutputType(types); @@ -2465,8 +2475,8 @@ Result resolver(KernelContext* ctx, const std::vector& t void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) { auto func = std::make_shared("extract_regex_span", Arity::Unary(), - extract_regex_doc_span); - OutputType output_type(resolver); + extract_regex_span_doc); + OutputType output_type(ResolveExtractRegexSpanOutputType); for (const auto& type : BaseBinaryTypes()) { ScalarKernel kernel({type}, output_type, GenerateVarBinaryToVarBinary(type), diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 4023491dee5b0..79a7e788b8ccc 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -1960,29 +1960,47 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) { R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])", &options); } -TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSapn) { +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpan) { ExtractRegexSpanOptions options{"(?P[ab])(?P\\d)"}; auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64(); - auto out_type = struct_({field("letter_span", fixed_size_list(type_fixe_size_list, 2)), - field("digit_span", fixed_size_list(type_fixe_size_list, 2))}); + auto out_type = struct_({field("letter", fixed_size_list(type_fixe_size_list, 2)), + field("digit", fixed_size_list(type_fixe_size_list, 2))}); this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options); this->CheckUnary( "extract_regex_span", R"(["a1", "b2", "c3", null])", out_type, - R"([{"letter_span":[0,1], "digit_span":[1,1]}, {"letter_span":[0,1], "digit_span":[1,1]}, null, null])", + R"([{"letter":[0,1], "digit":[1,1]}, {"letter":[0,1], "digit":[1,1]}, null, null])", &options); this->CheckUnary( "extract_regex_span", R"(["a1", "c3", null, "b2"])", out_type, - R"([{"letter_span":[0,1], "digit_span": [1,1]}, null, null, {"letter_span":[0,1], "digit_span":[1,1]}])", + R"([{"letter":[0,1], "digit": [1,1]}, null, null, {"letter":[0,1], "digit":[1,1]}])", &options); this->CheckUnary( "extract_regex_span", R"(["a1", "b2"])", out_type, - R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [0,1], "digit_span": [1,1]}])", + R"([{"letter": [0,1], "digit": [1,1]}, {"letter": [0,1], "digit": [1,1]}])", &options); this->CheckUnary( "extract_regex_span", R"(["a1", "zb3z"])", out_type, - R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [1,1], "digit_span": [2,1]}])", + R"([{"letter": [0,1], "digit": [1,1]}, {"letter": [1,1], "digit": [2,1]}])", &options); } +TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanCaptureOption) { + ExtractRegexSpanOptions options{"(?Pfoo)?(?P\\d+)?"}; + auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64(); + auto out_type = struct_({field("foo", fixed_size_list(type_fixe_size_list, 2)), + field("digit", fixed_size_list(type_fixe_size_list, 2))}); + this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options); + this->CheckUnary("extract_regex_span", R"(["abcfoo"])", out_type, + R"([{"foo":null,"digit":null}])", &options); + options = ExtractRegexSpanOptions{"(?Pfoo)(?P\\d+)?"}; + this->CheckUnary("extract_regex_span", R"(["foo123","foo","123","abc","abcfoo"])", + out_type, + R"([{"foo":[0,3],"digit":[3,3]}, + {"foo":[0,3],"digit":null}, + null, + null, + {"foo":[3,3],"digit":null}])", + &options); +} TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) { // XXX Should we accept this or is it a user error?