Skip to content

Commit

Permalink
fix issue mentioned by comment
Browse files Browse the repository at this point in the history
  • Loading branch information
arashandishgar committed Feb 26, 2025
1 parent 87e1f6b commit 2b82534
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 67 deletions.
2 changes: 0 additions & 2 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,6 @@ class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions {

/// Regular expression with named capture fields
std::string pattern;

/// Shows the matched string
};
/// Options for IsIn and IndexIn functions
class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
Expand Down
126 changes: 68 additions & 58 deletions cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ namespace compute {
namespace internal {

namespace {

// ----------------------------------------------------------------------
// re2 utilities

Expand Down Expand Up @@ -2185,9 +2184,35 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) {

using ExtractRegexState = OptionsWrapper<ExtractRegexOptions>;

struct BaseExtractRegexData {
Status Init() {
RETURN_NOT_OK(RegexStatus(*regex_));

const int group_count = regex_->NumberOfCapturingGroups();
const auto& name_map = regex_->CapturingGroupNames();
group_names_.reserve(group_count);

for (int i = 0; i < group_count; i++) {
auto item = name_map.find(i + 1); // re2 starts counting from 1
if (item == name_map.end()) {
// XXX should we instead just create fields with an empty name?
return Status::Invalid("Regular expression contains unnamed groups");
}
group_names_.emplace_back(item->second);
}
return Status::OK();
}
int64_t num_group() const { return group_names_.size(); }
std::unique_ptr<RE2> regex_;
std::vector<std::string> group_names_;

protected:
explicit BaseExtractRegexData(const std::string& pattern, bool is_utf8 = true)
: regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {}
};

// TODO cache this once per ExtractRegexOptions
class ExtractRegexData {
public:
struct ExtractRegexData : public BaseExtractRegexData {
static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
bool is_utf8 = true) {
ExtractRegexData data(options.pattern, is_utf8);
Expand All @@ -2197,12 +2222,9 @@ class ExtractRegexData {

Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const {
const DataType* input_type = types[0].type;
// as mentioned here
// https://arrow.apache.org/docs/developers/cpp/development.html#code-style-linting-and-ci
// nullptr should not be used
if (input_type == NULLPTR) {
if (input_type == nullptr) {
// No input type specified
return NULLPTR;
return nullptr;
}
// Input type is either [Large]Binary or [Large]String and is also the type
// of each field in the output struct type.
Expand All @@ -2212,35 +2234,12 @@ class ExtractRegexData {
std::shared_ptr<DataType> owned_type = input_type->GetSharedPtr();
std::transform(group_names_.begin(), group_names_.end(), std::back_inserter(fields),
[&](const std::string& name) { return field(name, owned_type); });
return struct_(fields);
return struct_(std::move(fields));
}
int64_t num_group() const { return group_names_.size(); }
std::shared_ptr<RE2> regex() const { return regex_; }

protected:
private:
explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true)
: regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {}

Status Init() {
RETURN_NOT_OK(RegexStatus(*regex_));

const int group_count = regex_->NumberOfCapturingGroups();
const auto& name_map = regex_->CapturingGroupNames();
group_names_.reserve(group_count);

for (int i = 0; i < group_count; i++) {
auto item = name_map.find(i + 1); // re2 starts counting from 1
if (item == name_map.end()) {
// XXX should we instead just create fields with an empty name?
return Status::Invalid("Regular expression contains unnamed groups");
}
group_names_.emplace_back(item->second);
}
return Status::OK();
}

std::shared_ptr<RE2> regex_;
std::vector<std::string> group_names_;
: BaseExtractRegexData(pattern, is_utf8) {}
};

Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
Expand All @@ -2251,15 +2250,15 @@ Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
}

struct ExtractRegexBase {
const ExtractRegexData& data;
const BaseExtractRegexData& data;
const int group_count;
std::vector<re2::StringPiece> found_values;
std::vector<RE2::Arg> args;
std::vector<const RE2::Arg*> args_pointers;
const RE2::Arg** args_pointers_start;
const RE2::Arg* null_arg = nullptr;

explicit ExtractRegexBase(const ExtractRegexData& data)
explicit ExtractRegexBase(const BaseExtractRegexData& data)
: data(data),
group_count(static_cast<int>(data.num_group())),
found_values(group_count) {
Expand All @@ -2276,7 +2275,7 @@ struct ExtractRegexBase {
}

bool Match(std::string_view s) {
return RE2::PartialMatchN(ToStringPiece(s), *data.regex(), args_pointers_start,
return RE2::PartialMatchN(ToStringPiece(s), *data.regex_, args_pointers_start,
group_count);
}
};
Expand All @@ -2291,7 +2290,7 @@ struct ExtractRegex : public ExtractRegexBase {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8));
return ExtractRegex{data}.Extract(ctx, batch, out);
return ExtractRegex(data).Extract(ctx, batch, out);
}

Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
Expand Down Expand Up @@ -2357,8 +2356,7 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) {
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
class ExtractRegexSpanData : public ExtractRegexData {
public:
struct ExtractRegexSpanData : public BaseExtractRegexData {
static Result<ExtractRegexSpanData> Make(const std::string& pattern) {
auto data = ExtractRegexSpanData(pattern, true);
ARROW_RETURN_NOT_OK(data.Init());
Expand All @@ -2378,20 +2376,26 @@ class ExtractRegexSpanData : public ExtractRegexData {
for (const auto& group_name : group_names_) {
auto type = is_binary_like(owned_type->id()) ? int32() : int64();
// size list is 2 as every span contains position and length
fields.push_back(field(group_name + "_span", fixed_size_list(type, 2)));
fields.push_back(field(group_name, fixed_size_list(type, 2)));
}
return struct_(fields);
}

private:
ExtractRegexSpanData(const std::string& pattern, const bool is_utf8)
: ExtractRegexData(pattern, is_utf8) {}
: BaseExtractRegexData(pattern, is_utf8) {}
};

template <typename Type>
struct ExtractRegexSpan : ExtractRegexBase {
using ArrayType = typename TypeTraits<Type>::ArrayType;
using BuilderType = typename TypeTraits<Type>::BuilderType;
using offset_type = typename Type::offset_type;
using OffsetBuilderType =
typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::BuilderType;
using OffsetCType =
typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::CType;

using ExtractRegexBase::ExtractRegexBase;

static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
Expand All @@ -2407,15 +2411,20 @@ struct ExtractRegexSpan : ExtractRegexBase {
ARROW_RETURN_NOT_OK(
MakeBuilder(ctx->memory_pool(), out->type()->GetSharedPtr(), &out_builder));
auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(out_builder));
ARROW_RETURN_NOT_OK(struct_builder->Reserve(batch[0].array.length));
std::vector<FixedSizeListBuilder*> span_builders;
std::vector<ArrayBuilder*> array_builders;
std::vector<OffsetBuilderType*> array_builders;
span_builders.reserve(group_count);
array_builders.reserve(group_count);
for (int i = 0; i < group_count; i++) {
span_builders.push_back(
checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(i)));
array_builders.push_back(span_builders[i]->value_builder());
array_builders.push_back(
checked_cast<OffsetBuilderType*>(span_builders[i]->value_builder()));
RETURN_NOT_OK(span_builders.back()->Reserve(batch[0].array.length));
RETURN_NOT_OK(array_builders.back()->Reserve(2 * batch[0].array.length));
}

auto visit_null = [&]() { return struct_builder->AppendNull(); };
auto visit_value = [&](std::string_view element) -> Status {
if (Match(element)) {
Expand All @@ -2424,15 +2433,8 @@ struct ExtractRegexSpan : ExtractRegexBase {
if (found_values[i].data() != NULLPTR) {
int64_t begin = found_values[i].data() - element.data();
int64_t size = found_values[i].size();
if (is_binary_like(batch.GetTypes()[0].id())) {
ARROW_RETURN_NOT_OK(checked_cast<Int32Builder*>(array_builders[i])
->AppendValues({static_cast<int32_t>(begin),
static_cast<int32_t>(size)}));
} else {
ARROW_RETURN_NOT_OK(checked_cast<Int64Builder*>(array_builders[i])
->AppendValues({begin, size}));
}

array_builders[i]->UnsafeAppend(static_cast<OffsetCType>(begin));
array_builders[i]->UnsafeAppend(static_cast<OffsetCType>(size));
ARROW_RETURN_NOT_OK(span_builders[i]->Append());
} else {
ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull());
Expand All @@ -2453,20 +2455,28 @@ struct ExtractRegexSpan : ExtractRegexBase {
}
};

const FunctionDoc extract_regex_doc_span(
"likes extract_regex; however, it contains the position and length of results", "",
// User-facing documentation for the "extract_regex_span" compute function:
// a one-line summary, a longer description, the argument names, the name of
// the options class and whether options are required.
const FunctionDoc extract_regex_span_doc(
    "Extract string spans captured by a regex pattern",
    ("For each string in strings, match the regular expression and, if\n"
     "successful, emit a struct with field names and values coming from the\n"
     "regular expression's named capture groups. Each struct field value is\n"
     "a fixed_size_list of two integers (int32 or int64, depending on the\n"
     "input type) holding the offset and the length of the substring\n"
     "matched by the corresponding capture group.\n"
     "If the input is null or the regular expression fails matching,\n"
     "a null output value is emitted.\n"
     "Regular expression matching is done using the Google RE2 library."),
    {"strings"}, "ExtractRegexSpanOptions", /*options_required=*/true);

Result<TypeHolder> resolver(KernelContext* ctx, const std::vector<TypeHolder>& types) {
// Output type resolver for the "extract_regex_span" kernels.
//
// Reads the ExtractRegexSpanOptions stored in the kernel state, builds an
// ExtractRegexSpanData from its pattern and delegates to its
// ResolveOutputType(). Any error from ExtractRegexSpanData::Make() is
// returned to the caller.
Result<TypeHolder> ResolveExtractRegexSpanOutputType(
    KernelContext* ctx, const std::vector<TypeHolder>& types) {
  // Bind by reference: only `pattern` is read, so the options object does
  // not need to be copied.
  const auto& options = OptionsWrapper<ExtractRegexSpanOptions>::Get(*ctx->state());
  ARROW_ASSIGN_OR_RAISE(auto span, ExtractRegexSpanData::Make(options.pattern));
  return span.ResolveOutputType(types);
}

void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) {
auto func = std::make_shared<ScalarFunction>("extract_regex_span", Arity::Unary(),
extract_regex_doc_span);
OutputType output_type(resolver);
extract_regex_span_doc);
OutputType output_type(ResolveExtractRegexSpanOutputType);
for (const auto& type : BaseBinaryTypes()) {
ScalarKernel kernel({type}, output_type,
GenerateVarBinaryToVarBinary<ExtractRegexSpan>(type),
Expand Down
32 changes: 25 additions & 7 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1960,29 +1960,47 @@ TYPED_TEST(TestBaseBinaryKernels, ExtractRegex) {
R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])",
&options);
}
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSapn) {
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpan) {
ExtractRegexSpanOptions options{"(?P<letter>[ab])(?P<digit>\\d)"};
auto type_fixe_size_list = is_binary_like(this->type()->id()) ? int32() : int64();
auto out_type = struct_({field("letter_span", fixed_size_list(type_fixe_size_list, 2)),
field("digit_span", fixed_size_list(type_fixe_size_list, 2))});
auto out_type = struct_({field("letter", fixed_size_list(type_fixe_size_list, 2)),
field("digit", fixed_size_list(type_fixe_size_list, 2))});
this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &options);
this->CheckUnary(
"extract_regex_span", R"(["a1", "b2", "c3", null])", out_type,
R"([{"letter_span":[0,1], "digit_span":[1,1]}, {"letter_span":[0,1], "digit_span":[1,1]}, null, null])",
R"([{"letter":[0,1], "digit":[1,1]}, {"letter":[0,1], "digit":[1,1]}, null, null])",
&options);
this->CheckUnary(
"extract_regex_span", R"(["a1", "c3", null, "b2"])", out_type,
R"([{"letter_span":[0,1], "digit_span": [1,1]}, null, null, {"letter_span":[0,1], "digit_span":[1,1]}])",
R"([{"letter":[0,1], "digit": [1,1]}, null, null, {"letter":[0,1], "digit":[1,1]}])",
&options);
this->CheckUnary(
"extract_regex_span", R"(["a1", "b2"])", out_type,
R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [0,1], "digit_span": [1,1]}])",
R"([{"letter": [0,1], "digit": [1,1]}, {"letter": [0,1], "digit": [1,1]}])",
&options);
this->CheckUnary(
"extract_regex_span", R"(["a1", "zb3z"])", out_type,
R"([{"letter_span": [0,1], "digit_span": [1,1]}, {"letter_span": [1,1], "digit_span": [2,1]}])",
R"([{"letter": [0,1], "digit": [1,1]}, {"letter": [1,1], "digit": [2,1]}])",
&options);
}
// Exercises "extract_regex_span" with optional capture groups: a group that
// takes no part in a match is expected to produce a null span, while a string
// the pattern does not match is expected to produce a null struct.
TYPED_TEST(TestBaseBinaryKernels, ExtractRegexSpanCaptureOption) {
  // Each span is a pair of integers: int32 for binary-like inputs, int64 otherwise.
  auto span_value_type = is_binary_like(this->type()->id()) ? int32() : int64();
  auto out_type = struct_({field("foo", fixed_size_list(span_value_type, 2)),
                           field("digit", fixed_size_list(span_value_type, 2))});

  // Both groups optional: the pattern can match without capturing anything.
  ExtractRegexSpanOptions all_optional{"(?P<foo>foo)?(?P<digit>\\d+)?"};
  this->CheckUnary("extract_regex_span", R"([])", out_type, R"([])", &all_optional);
  this->CheckUnary("extract_regex_span", R"(["abcfoo"])", out_type,
                   R"([{"foo":null,"digit":null}])", &all_optional);

  // "foo" is mandatory, only the digits are optional.
  ExtractRegexSpanOptions digits_optional{"(?P<foo>foo)(?P<digit>\\d+)?"};
  this->CheckUnary("extract_regex_span", R"(["foo123","foo","123","abc","abcfoo"])",
                   out_type,
                   R"([{"foo":[0,3],"digit":[3,3]},
{"foo":[0,3],"digit":null},
null,
null,
{"foo":[3,3],"digit":null}])",
                   &digits_optional);
}

TYPED_TEST(TestBaseBinaryKernels, ExtractRegexNoCapture) {
// XXX Should we accept this or is it a user error?
Expand Down

0 comments on commit 2b82534

Please sign in to comment.