Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-44615: [C++] Add possibility to extract spans/byte offsets directly for compute.extract_regex #45577

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"

namespace arrow {

namespace internal {
Expand Down Expand Up @@ -325,6 +324,9 @@ static auto kElementWiseAggregateOptionsType =
DataMember("skip_nulls", &ElementWiseAggregateOptions::skip_nulls));
// Options-type descriptor for ExtractRegexOptions; exposes its `pattern`
// data member under the name "pattern".
static auto kExtractRegexOptionsType = GetFunctionOptionsType<ExtractRegexOptions>(
DataMember("pattern", &ExtractRegexOptions::pattern));
// Options-type descriptor for ExtractRegexSpanOptions; like the descriptor for
// ExtractRegexOptions, it exposes the single `pattern` data member.
static auto kExtractRegexSpanOptionsType =
GetFunctionOptionsType<ExtractRegexSpanOptions>(
DataMember("pattern", &ExtractRegexSpanOptions::pattern));
// Options-type descriptor for JoinOptions; exposes `null_handling` and
// `null_replacement`.
static auto kJoinOptionsType = GetFunctionOptionsType<JoinOptions>(
DataMember("null_handling", &JoinOptions::null_handling),
DataMember("null_replacement", &JoinOptions::null_replacement));
Expand Down Expand Up @@ -438,6 +440,12 @@ ExtractRegexOptions::ExtractRegexOptions(std::string pattern)
ExtractRegexOptions::ExtractRegexOptions() : ExtractRegexOptions("") {}
constexpr char ExtractRegexOptions::kTypeName[];

// Build span-extraction options around the given regular expression; the
// options are tagged with the ExtractRegexSpanOptions type descriptor.
ExtractRegexSpanOptions::ExtractRegexSpanOptions(std::string pattern)
    : FunctionOptions(internal::kExtractRegexSpanOptionsType),
      pattern(std::move(pattern)) {}

// Default options carry an empty pattern.
ExtractRegexSpanOptions::ExtractRegexSpanOptions()
    : ExtractRegexSpanOptions(std::string()) {}

// Out-of-line definition of the static constexpr member declared in the class.
constexpr char ExtractRegexSpanOptions::kTypeName[];

JoinOptions::JoinOptions(NullHandlingBehavior null_handling, std::string null_replacement)
: FunctionOptions(internal::kJoinOptionsType),
null_handling(null_handling),
Expand Down Expand Up @@ -684,6 +692,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kDayOfWeekOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kElementWiseAggregateOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexSpanOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kJoinOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListSliceOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kMakeStructOptionsType));
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,15 @@ class ARROW_EXPORT ExtractRegexOptions : public FunctionOptions {
/// Regular expression with named capture fields
std::string pattern;
};
/// Options for the "extract_regex_span" function
class ARROW_EXPORT ExtractRegexSpanOptions : public FunctionOptions {
public:
/// Construct options from a regular expression pattern
explicit ExtractRegexSpanOptions(std::string pattern);
/// Construct options with an empty pattern
ExtractRegexSpanOptions();
/// Name under which this options type is identified
static constexpr char const kTypeName[] = "ExtractRegexSpanOptions";

/// Regular expression with named capture fields
std::string pattern;
};
/// Options for IsIn and IndexIn functions
class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
public:
Expand Down
197 changes: 169 additions & 28 deletions cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <string>

#include "arrow/array/builder_nested.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/kernels/scalar_string_internal.h"
#include "arrow/result.h"
#include "arrow/util/config.h"
Expand All @@ -42,7 +43,6 @@ namespace compute {
namespace internal {

namespace {

// ----------------------------------------------------------------------
// re2 utilities

Expand Down Expand Up @@ -2184,29 +2184,39 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) {

using ExtractRegexState = OptionsWrapper<ExtractRegexOptions>;

// TODO cache this once per ExtractRegexOptions
struct ExtractRegexData {
// Use unique_ptr<> because RE2 is non-movable (for ARROW_ASSIGN_OR_RAISE)
std::unique_ptr<RE2> regex;
std::vector<std::string> group_names;

static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
bool is_utf8 = true) {
ExtractRegexData data(options.pattern, is_utf8);
RETURN_NOT_OK(RegexStatus(*data.regex));
// Shared state for the regex-extraction kernels: a compiled RE2 pattern and
// the names of its capture groups.
// NOTE(review): this span comes from a rendered diff and appears to interleave
// removed lines (the ones referencing `data.`) with their replacements --
// confirm against the actual file before relying on it.
struct BaseExtractRegexData {
// Check the compiled regex with RegexStatus, then collect the capture-group
// names in group order. Returns Status::Invalid if any group is unnamed.
Status Init() {
RETURN_NOT_OK(RegexStatus(*regex_));

const int group_count = data.regex->NumberOfCapturingGroups();
const auto& name_map = data.regex->CapturingGroupNames();
data.group_names.reserve(group_count);
const int group_count = regex_->NumberOfCapturingGroups();
const auto& name_map = regex_->CapturingGroupNames();
group_names_.reserve(group_count);

for (int i = 0; i < group_count; i++) {
auto item = name_map.find(i + 1); // re2 starts counting from 1
if (item == name_map.end()) {
// XXX should we instead just create fields with an empty name?
return Status::Invalid("Regular expression contains unnamed groups");
}
data.group_names.emplace_back(item->second);
group_names_.emplace_back(item->second);
}
return Status::OK();
}
// Number of capture-group names currently stored in group_names_.
int64_t num_group() const { return group_names_.size(); }
// Compiled pattern, heap-allocated in the constructor below.
std::unique_ptr<RE2> regex_;
// Capture-group names, filled by Init(); index i holds the name of group i + 1.
std::vector<std::string> group_names_;

protected:
// Compile `pattern` using the options returned by MakeRE2Options(is_utf8).
// Protected: only derived classes construct this type, and they are expected
// to call Init() afterwards to validate the regex.
explicit BaseExtractRegexData(const std::string& pattern, bool is_utf8 = true)
: regex_(new RE2(pattern, MakeRE2Options(is_utf8))) {}
};

// TODO cache this once per ExtractRegexOptions
struct ExtractRegexData : public BaseExtractRegexData {
static Result<ExtractRegexData> Make(const ExtractRegexOptions& options,
bool is_utf8 = true) {
ExtractRegexData data(options.pattern, is_utf8);
ARROW_RETURN_NOT_OK(data.Init());
return data;
}

Expand All @@ -2220,16 +2230,16 @@ struct ExtractRegexData {
// of each field in the output struct type.
DCHECK(is_base_binary_like(input_type->id()));
FieldVector fields;
fields.reserve(group_names.size());
fields.reserve(group_names_.size());
std::shared_ptr<DataType> owned_type = input_type->GetSharedPtr();
std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields),
std::transform(group_names_.begin(), group_names_.end(), std::back_inserter(fields),
[&](const std::string& name) { return field(name, owned_type); });
return struct_(std::move(fields));
}

private:
explicit ExtractRegexData(const std::string& pattern, bool is_utf8 = true)
: regex(new RE2(pattern, MakeRE2Options(is_utf8))) {}
: BaseExtractRegexData(pattern, is_utf8) {}
};

Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
Expand All @@ -2240,17 +2250,17 @@ Result<TypeHolder> ResolveExtractRegexOutput(KernelContext* ctx,
}

struct ExtractRegexBase {
const ExtractRegexData& data;
const BaseExtractRegexData& data;
const int group_count;
std::vector<re2::StringPiece> found_values;
std::vector<RE2::Arg> args;
std::vector<const RE2::Arg*> args_pointers;
const RE2::Arg** args_pointers_start;
const RE2::Arg* null_arg = nullptr;

explicit ExtractRegexBase(const ExtractRegexData& data)
explicit ExtractRegexBase(const BaseExtractRegexData& data)
: data(data),
group_count(static_cast<int>(data.group_names.size())),
group_count(static_cast<int>(data.num_group())),
found_values(group_count) {
args.reserve(group_count);
args_pointers.reserve(group_count);
Expand All @@ -2265,7 +2275,7 @@ struct ExtractRegexBase {
}

bool Match(std::string_view s) {
return RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start,
return RE2::PartialMatchN(ToStringPiece(s), *data.regex_, args_pointers_start,
group_count);
}
};
Expand All @@ -2280,15 +2290,14 @@ struct ExtractRegex : public ExtractRegexBase {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options, Type::is_utf8));
return ExtractRegex{data}.Extract(ctx, batch, out);
return ExtractRegex(data).Extract(ctx, batch, out);
}

Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
// TODO: why is this needed? Type resolution should already be
// done and the output type set in the output variable
ARROW_ASSIGN_OR_RAISE(TypeHolder out_type, data.ResolveOutputType(batch.GetTypes()));
DCHECK_NE(out_type.type, nullptr);
std::shared_ptr<DataType> type = out_type.GetSharedPtr();
ExtractRegexOptions options = ExtractRegexState::Get(ctx);
DCHECK_NE(out->array_data(), NULLPTR);
std::shared_ptr<DataType> type = out->array_data()->type;
DCHECK_NE(type, NULLPTR);

std::unique_ptr<ArrayBuilder> array_builder;
RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type, &array_builder));
Expand Down Expand Up @@ -2347,6 +2356,137 @@ void AddAsciiStringExtractRegex(FunctionRegistry* registry) {
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
// Compiled regex and capture-group names for the "extract_regex_span" kernel.
struct ExtractRegexSpanData : public BaseExtractRegexData {
  /// Compile `pattern` and collect its capture-group names.
  ///
  /// \param pattern regular expression to compile
  /// \param is_utf8 forwarded to the RE2 options; kernels should pass
  ///        `Type::is_utf8` so that binary inputs are not matched in UTF-8 mode
  ///        (same convention as ExtractRegex). Defaults to true.
  /// \return the initialized data, or the error status reported by Init()
  static Result<ExtractRegexSpanData> Make(const std::string& pattern,
                                           bool is_utf8 = true) {
    ExtractRegexSpanData data(pattern, is_utf8);
    ARROW_RETURN_NOT_OK(data.Init());
    return data;
  }

  /// Compute the output struct type: one field per capture group, each a
  /// fixed_size_list of two integers (offset and length of the match).
  ///
  /// \param types input types; only `types[0]` is inspected
  /// \return the struct type, or a null type if the input type is unset
  Result<TypeHolder> ResolveOutputType(const std::vector<TypeHolder>& types) const {
    const DataType* input_type = types[0].type;
    if (input_type == NULLPTR) {
      // No input type specified
      return NULLPTR;
    }
    DCHECK(is_base_binary_like(input_type->id()));
    // The element type does not depend on the group, so pick it once.
    // It must agree with the offset-typed builder used by ExtractRegexSpan.
    const auto index_type = is_binary_like(input_type->id()) ? int32() : int64();
    FieldVector fields;
    fields.reserve(group_names_.size());
    for (const auto& group_name : group_names_) {
      // size list is 2 as every span contains position and length
      fields.push_back(field(group_name, fixed_size_list(index_type, 2)));
    }
    return struct_(std::move(fields));
  }

 private:
  ExtractRegexSpanData(const std::string& pattern, const bool is_utf8)
      : BaseExtractRegexData(pattern, is_utf8) {}
};

// Kernel implementation of "extract_regex_span" for one binary/string type.
// For every input value it emits a struct whose fields hold the
// (offset, length) of each named capture group.
template <typename Type>
struct ExtractRegexSpan : ExtractRegexBase {
  using ArrayType = typename TypeTraits<Type>::ArrayType;
  using BuilderType = typename TypeTraits<Type>::BuilderType;
  using offset_type = typename Type::offset_type;
  using OffsetBuilderType =
      typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::BuilderType;
  using OffsetCType =
      typename TypeTraits<typename CTypeTraits<offset_type>::ArrowType>::CType;

  using ExtractRegexBase::ExtractRegexBase;

  /// Kernel entry point: compile the pattern from the kernel options and run
  /// the extraction over `batch`, writing the result into `out`.
  static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    const auto& options = OptionsWrapper<ExtractRegexSpanOptions>::Get(ctx);
    // Compile with the encoding of the input type, as ExtractRegex does;
    // hard-coding UTF-8 here would apply UTF-8 matching to binary inputs.
    ARROW_ASSIGN_OR_RAISE(auto data,
                          ExtractRegexSpanData::Make(options.pattern, Type::is_utf8));
    return ExtractRegexSpan{data}.Extract(ctx, batch, out);
  }

  /// Build the output struct array for `batch[0]`.
  ///
  /// Null inputs and non-matching inputs produce a null struct; a capture
  /// group that did not participate in a match produces a null span.
  Status Extract(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
    DCHECK_NE(out->array_data(), NULLPTR);
    std::shared_ptr<DataType> out_type = out->array_data()->type;
    DCHECK_NE(out_type, NULLPTR);
    const int64_t length = batch[0].array.length;

    std::unique_ptr<ArrayBuilder> out_builder;
    ARROW_RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out_type, &out_builder));
    auto struct_builder = checked_pointer_cast<StructBuilder>(std::move(out_builder));
    ARROW_RETURN_NOT_OK(struct_builder->Reserve(length));

    // One fixed-size-list builder per capture group, plus its integer child.
    std::vector<FixedSizeListBuilder*> span_builders;
    std::vector<OffsetBuilderType*> array_builders;
    span_builders.reserve(group_count);
    array_builders.reserve(group_count);
    for (int i = 0; i < group_count; i++) {
      span_builders.push_back(
          checked_cast<FixedSizeListBuilder*>(struct_builder->field_builder(i)));
      array_builders.push_back(
          checked_cast<OffsetBuilderType*>(span_builders[i]->value_builder()));
      RETURN_NOT_OK(span_builders.back()->Reserve(length));
      // Two child values (offset, length) per row; needed for UnsafeAppend below
      RETURN_NOT_OK(array_builders.back()->Reserve(2 * length));
    }

    auto visit_null = [&]() { return struct_builder->AppendNull(); };
    auto visit_value = [&](std::string_view element) -> Status {
      if (!Match(element)) {
        return struct_builder->AppendNull();
      }
      for (int i = 0; i < group_count; i++) {
        // A null data pointer marks a group that did not participate, see
        // https://github.com/google/re2/issues/24#issuecomment-97653183
        if (found_values[i].data() != NULLPTR) {
          const int64_t begin = found_values[i].data() - element.data();
          const int64_t size = found_values[i].size();
          array_builders[i]->UnsafeAppend(static_cast<OffsetCType>(begin));
          array_builders[i]->UnsafeAppend(static_cast<OffsetCType>(size));
          ARROW_RETURN_NOT_OK(span_builders[i]->Append());
        } else {
          ARROW_RETURN_NOT_OK(span_builders[i]->AppendNull());
        }
      }
      return struct_builder->Append();
    };
    ARROW_RETURN_NOT_OK(
        VisitArraySpanInline<Type>(batch[0].array, visit_value, visit_null));

    ARROW_ASSIGN_OR_RAISE(auto out_array, struct_builder->Finish());
    out->value = out_array->data();
    return Status::OK();
  }
};

// User-facing documentation for the "extract_regex_span" function.
// Each description line ends in "\n" with no stray spaces around it, so the
// rendered help text is not indented or padded unevenly.
const FunctionDoc extract_regex_span_doc(
    "Extract substrings captured by a regex pattern and save the result in the form of "
    "(offset, length)",
    ("For each string in `strings`, match the regular expression and, if\n"
     "successful, emit a struct with field names and values coming from the\n"
     "regular expression's named capture groups, which are stored in the form\n"
     "of a fixed_size_list(offset, length). If the input is null or the\n"
     "regular expression fails matching, a null output value is emitted.\n"
     "\n"
     "Regular expression matching is done using the Google RE2 library."),
    {"strings"}, "ExtractRegexSpanOptions", /*options_required=*/true);

// Output-type resolver for "extract_regex_span": compiles the pattern stored
// in the kernel state and derives the struct type from its capture groups.
// Any error from compiling or validating the pattern is returned to the caller.
Result<TypeHolder> ResolveExtractRegexSpanOutputType(
    KernelContext* ctx, const std::vector<TypeHolder>& types) {
  const std::string pattern =
      OptionsWrapper<ExtractRegexSpanOptions>::Get(*ctx->state()).pattern;
  ARROW_ASSIGN_OR_RAISE(auto regex_data, ExtractRegexSpanData::Make(pattern));
  return regex_data.ResolveOutputType(types);
}

// Register the unary scalar function "extract_regex_span" with one kernel per
// type returned by BaseBinaryTypes().
void AddAsciiStringExtractRegexSpan(FunctionRegistry* registry) {
  // The output type depends on the pattern, so it is resolved per call.
  OutputType resolved_output(ResolveExtractRegexSpanOutputType);
  auto span_func = std::make_shared<ScalarFunction>(
      "extract_regex_span", Arity::Unary(), extract_regex_span_doc);

  for (const auto& input_type : BaseBinaryTypes()) {
    ScalarKernel span_kernel({input_type}, resolved_output,
                             GenerateVarBinaryToVarBinary<ExtractRegexSpan>(input_type),
                             OptionsWrapper<ExtractRegexSpanOptions>::Init);
    // The kernel builds its own output (validity and data) through builders.
    span_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
    span_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
    DCHECK_OK(span_func->AddKernel(std::move(span_kernel)));
  }

  DCHECK_OK(registry->AddFunction(span_func));
}
#endif // ARROW_WITH_RE2

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -3457,6 +3597,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddAsciiStringSplitWhitespace(registry);
#ifdef ARROW_WITH_RE2
AddAsciiStringSplitRegex(registry);
AddAsciiStringExtractRegexSpan(registry);
#endif
AddAsciiStringJoin(registry);
AddAsciiStringRepeat(registry);
Expand Down
Loading
Loading