Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-45611: [C++][Acero] Improve Swiss join build performance by partitioning batches ahead to reduce contention #45612

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 108 additions & 81 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,8 @@ uint32_t SwissTableForJoin::payload_id_to_key_id(uint32_t payload_id) const {
}

Status SwissTableForJoinBuild::Init(SwissTableForJoin* target, int dop, int64_t num_rows,
bool reject_duplicate_keys, bool no_payload,
int64_t num_batches, bool reject_duplicate_keys,
bool no_payload,
const std::vector<KeyColumnMetadata>& key_types,
const std::vector<KeyColumnMetadata>& payload_types,
MemoryPool* pool, int64_t hardware_flags) {
Expand All @@ -1112,7 +1113,7 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* target, int dop, int64_t

// Make sure that we do not use many partitions if there are not enough rows.
//
constexpr int64_t min_num_rows_per_prtn = 1 << 18;
constexpr int64_t min_num_rows_per_prtn = 1 << 12;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the contention is eliminated, we can be a little more aggressive with parallelism.

log_num_prtns_ =
std::min(bit_util::Log2(dop_),
bit_util::Log2(bit_util::CeilDiv(num_rows, min_num_rows_per_prtn)));
Expand All @@ -1123,9 +1124,9 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* target, int dop, int64_t
pool_ = pool;
hardware_flags_ = hardware_flags;

batch_states_.resize(num_batches);
prtn_states_.resize(num_prtns_);
thread_states_.resize(dop_);
prtn_locks_.Init(dop_, num_prtns_);

RowTableMetadata key_row_metadata;
key_row_metadata.FromColumnMetadataVector(key_types,
Expand Down Expand Up @@ -1154,91 +1155,73 @@ Status SwissTableForJoinBuild::Init(SwissTableForJoin* target, int dop, int64_t
return Status::OK();
}

Status SwissTableForJoinBuild::PushNextBatch(int64_t thread_id,
const ExecBatch& key_batch,
const ExecBatch* payload_batch_maybe_null,
arrow::util::TempVectorStack* temp_stack) {
ARROW_DCHECK(thread_id < dop_);
Status SwissTableForJoinBuild::PartitionBatch(size_t thread_id, int64_t batch_id,
const ExecBatch& key_batch,
arrow::util::TempVectorStack* temp_stack) {
DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));
ThreadState& locals = thread_states_[thread_id];
BatchState& batch_state = batch_states_[batch_id];
uint16_t num_rows = static_cast<uint16_t>(key_batch.length);

// Compute hash
//
locals.batch_hashes.resize(key_batch.length);
RETURN_NOT_OK(Hashing32::HashBatch(
key_batch, locals.batch_hashes.data(), locals.temp_column_arrays, hardware_flags_,
temp_stack, /*start_row=*/0, static_cast<int>(key_batch.length)));
batch_state.hashes.resize(num_rows);
RETURN_NOT_OK(Hashing32::HashBatch(key_batch, batch_state.hashes.data(),
locals.temp_column_arrays, hardware_flags_,
temp_stack, /*start_row=*/0, num_rows));

// Partition on hash
//
locals.batch_prtn_row_ids.resize(locals.batch_hashes.size());
locals.batch_prtn_ranges.resize(num_prtns_ + 1);
int num_rows = static_cast<int>(locals.batch_hashes.size());
batch_state.prtn_ranges.resize(num_prtns_ + 1);
batch_state.prtn_row_ids.resize(num_rows);
if (num_prtns_ == 1) {
// We treat single partition case separately to avoid extra checks in row
// partitioning implementation for general case.
//
locals.batch_prtn_ranges[0] = 0;
locals.batch_prtn_ranges[1] = num_rows;
for (int i = 0; i < num_rows; ++i) {
locals.batch_prtn_row_ids[i] = i;
batch_state.prtn_ranges[0] = 0;
batch_state.prtn_ranges[1] = num_rows;
for (uint16_t i = 0; i < num_rows; ++i) {
batch_state.prtn_row_ids[i] = i;
}
} else {
PartitionSort::Eval(
static_cast<int>(locals.batch_hashes.size()), num_prtns_,
locals.batch_prtn_ranges.data(),
[this, &locals](int64_t i) {
num_rows, num_prtns_, batch_state.prtn_ranges.data(),
[this, &batch_state](int64_t i) {
// SwissTable uses the highest bits of the hash for block index.
// We want each partition to correspond to a range of block indices,
// so we also partition on the highest bits of the hash.
//
return locals.batch_hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_);
return batch_state.hashes[i] >> (SwissTable::bits_hash_ - log_num_prtns_);
},
[&locals](int64_t i, int pos) {
locals.batch_prtn_row_ids[pos] = static_cast<uint16_t>(i);
[&batch_state](int64_t i, int pos) {
batch_state.prtn_row_ids[pos] = static_cast<uint16_t>(i);
});
}

// Update hashes, shifting left to get rid of the bits that were already used
// for partitioning.
//
for (size_t i = 0; i < locals.batch_hashes.size(); ++i) {
locals.batch_hashes[i] <<= log_num_prtns_;
for (size_t i = 0; i < batch_state.hashes.size(); ++i) {
batch_state.hashes[i] <<= log_num_prtns_;
}

// For each partition:
// - map keys to unique integers using (this partition's) hash table
// - append payloads (if present) to (this partition's) row array
//
locals.temp_prtn_ids.resize(num_prtns_);

RETURN_NOT_OK(prtn_locks_.ForEachPartition(
thread_id, locals.temp_prtn_ids.data(),
/*is_prtn_empty_fn=*/
[&](int prtn_id) {
return locals.batch_prtn_ranges[prtn_id + 1] == locals.batch_prtn_ranges[prtn_id];
},
/*process_prtn_fn=*/
[&](int prtn_id) {
return ProcessPartition(thread_id, key_batch, payload_batch_maybe_null,
temp_stack, prtn_id);
}));

return Status::OK();
}

Status SwissTableForJoinBuild::ProcessPartition(int64_t thread_id,
const ExecBatch& key_batch,
const ExecBatch* payload_batch_maybe_null,
arrow::util::TempVectorStack* temp_stack,
int prtn_id) {
ARROW_DCHECK(thread_id < dop_);
Status SwissTableForJoinBuild::ProcessPartition(
size_t thread_id, int64_t batch_id, int prtn_id, const ExecBatch& key_batch,
const ExecBatch* payload_batch_maybe_null, arrow::util::TempVectorStack* temp_stack) {
DCHECK_LE(static_cast<int64_t>(thread_id), dop_);
DCHECK_LE(batch_id, static_cast<int64_t>(batch_states_.size()));
ThreadState& locals = thread_states_[thread_id];
BatchState& batch_state = batch_states_[batch_id];
PartitionState& prtn_state = prtn_states_[prtn_id];

int num_rows_new =
locals.batch_prtn_ranges[prtn_id + 1] - locals.batch_prtn_ranges[prtn_id];
batch_state.prtn_ranges[prtn_id + 1] - batch_state.prtn_ranges[prtn_id];
const uint16_t* row_ids =
locals.batch_prtn_row_ids.data() + locals.batch_prtn_ranges[prtn_id];
PartitionState& prtn_state = prtn_states_[prtn_id];
batch_state.prtn_row_ids.data() + batch_state.prtn_ranges[prtn_id];
size_t num_rows_before = prtn_state.key_ids.size();
// Insert new keys into hash table associated with the current partition
// and map existing keys to integer ids.
Expand All @@ -1247,7 +1230,7 @@ Status SwissTableForJoinBuild::ProcessPartition(int64_t thread_id,
SwissTableWithKeys::Input input(&key_batch, num_rows_new, row_ids, temp_stack,
&locals.temp_column_arrays, &locals.temp_group_ids);
RETURN_NOT_OK(prtn_state.keys.MapWithInserts(
&input, locals.batch_hashes.data(), prtn_state.key_ids.data() + num_rows_before));
&input, batch_state.hashes.data(), prtn_state.key_ids.data() + num_rows_before));
// Append input batch rows from current partition to an array of payload
// rows for this partition.
//
Expand Down Expand Up @@ -2504,6 +2487,13 @@ class SwissJoin : public HashJoinImpl {
}

void InitTaskGroups() {
task_group_partition_ = register_task_group_callback_(
[this](size_t thread_index, int64_t task_id) -> Status {
return PartitionTask(thread_index, task_id);
},
[this](size_t thread_index) -> Status {
return PartitionFinished(thread_index);
});
task_group_build_ = register_task_group_callback_(
[this](size_t thread_index, int64_t task_id) -> Status {
return BuildTask(thread_index, task_id);
Expand Down Expand Up @@ -2593,58 +2583,94 @@ class SwissJoin : public HashJoinImpl {
hash_table_build_ = std::make_unique<SwissTableForJoinBuild>();
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->Init(
&hash_table_, num_threads_, build_side_batches_.row_count(),
reject_duplicate_keys, no_payload, key_types, payload_types, pool_,
hardware_flags_)));
build_side_batches_.batch_count(), reject_duplicate_keys, no_payload, key_types,
payload_types, pool_, hardware_flags_)));

// Process all input batches
//
return CancelIfNotOK(
start_task_group_callback_(task_group_build_, build_side_batches_.batch_count()));
return CancelIfNotOK(start_task_group_callback_(task_group_partition_,
build_side_batches_.batch_count()));
}

Status BuildTask(size_t thread_id, int64_t batch_id) {
Status PartitionTask(size_t thread_id, int64_t batch_id) {
if (IsCancelled()) {
return Status::OK();
}

DCHECK_GT(build_side_batches_[batch_id].length, 0);

const HashJoinProjectionMaps* schema = schema_[1];
DCHECK_NE(hash_table_build_, nullptr);
bool no_payload = hash_table_build_->no_payload();

ExecBatch input_batch;
ARROW_ASSIGN_OR_RAISE(
input_batch, KeyPayloadFromInput(/*side=*/1, &build_side_batches_[batch_id]));

// Split batch into key batch and optional payload batch
//
// Input batch is key-payload batch (key columns followed by payload
// columns). We split it into two separate batches.
//
// TODO: Change SwissTableForJoinBuild interface to use key-payload
// batch instead to avoid this operation, which involves increasing
// shared pointer ref counts.
//
ExecBatch key_batch({}, input_batch.length);
key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
key_batch.values[icol] = input_batch.values[icol];
}
ExecBatch payload_batch({}, input_batch.length);
arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;

DCHECK_NE(hash_table_build_, nullptr);
return hash_table_build_->PartitionBatch(static_cast<int64_t>(thread_id), batch_id,
key_batch, temp_stack);
}

// Finish callback of the partition task group (registered in InitTaskGroups).
// Kicks off the build task group, scheduling one task per hash table
// partition as reported by hash_table_build_->num_prtns().
//
// thread_id is not used by this callback.
Status PartitionFinished(size_t thread_id) {
// Bail out before scheduling more work if status() reports an error
// (presumably the join's accumulated status -- confirm against its definition).
RETURN_NOT_OK(status());

DCHECK_NE(hash_table_build_, nullptr);
// The task count is the number of partitions, so each BuildTask invocation
// receives a partition id as its task id.
return CancelIfNotOK(
start_task_group_callback_(task_group_build_, hash_table_build_->num_prtns()));
}

Status BuildTask(size_t thread_id, int64_t prtn_id) {
if (IsCancelled()) {
return Status::OK();
}

const HashJoinProjectionMaps* schema = schema_[1];
DCHECK_NE(hash_table_build_, nullptr);
bool no_payload = hash_table_build_->no_payload();
ExecBatch key_batch, payload_batch;
key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
if (!no_payload) {
payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD));
for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
payload_batch.values[icol] =
input_batch.values[schema->num_cols(HashJoinProjection::KEY) + icol];
}
}
arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
DCHECK_NE(hash_table_build_, nullptr);
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->PushNextBatch(
static_cast<int64_t>(thread_id), key_batch, no_payload ? nullptr : &payload_batch,
temp_stack)));

for (int64_t batch_id = 0;
batch_id < static_cast<int64_t>(build_side_batches_.batch_count()); ++batch_id) {
ExecBatch input_batch;
ARROW_ASSIGN_OR_RAISE(
input_batch, KeyPayloadFromInput(/*side=*/1, &build_side_batches_[batch_id]));

// Split batch into key batch and optional payload batch
//
// Input batch is key-payload batch (key columns followed by payload
// columns). We split it into two separate batches.
//
// TODO: Change SwissTableForJoinBuild interface to use key-payload
// batch instead to avoid this operation, which involves increasing
// shared pointer ref counts.
//
key_batch.length = input_batch.length;
for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
key_batch.values[icol] = input_batch.values[icol];
}

if (!no_payload) {
payload_batch.length = input_batch.length;
for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
payload_batch.values[icol] =
input_batch.values[schema->num_cols(HashJoinProjection::KEY) + icol];
}
}

RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->ProcessPartition(
thread_id, batch_id, static_cast<int>(prtn_id), key_batch,
no_payload ? nullptr : &payload_batch, temp_stack)));
}

return Status::OK();
}
Expand Down Expand Up @@ -2897,6 +2923,7 @@ class SwissJoin : public HashJoinImpl {
const HashJoinProjectionMaps* schema_[2];

// Task scheduling
int task_group_partition_;
int task_group_build_;
int task_group_merge_;
int task_group_scan_;
Expand Down
Loading
Loading