Skip to content

Commit

Permalink
Refine
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Feb 21, 2025
1 parent d628fe6 commit 281ca6d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 29 deletions.
28 changes: 11 additions & 17 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2600,34 +2600,26 @@ class SwissJoin : public HashJoinImpl {
DCHECK_GT(build_side_batches_[batch_id].length, 0);

const HashJoinProjectionMaps* schema = schema_[1];
DCHECK_NE(hash_table_build_, nullptr);

ExecBatch input_batch;
ARROW_ASSIGN_OR_RAISE(
input_batch, KeyPayloadFromInput(/*side=*/1, &build_side_batches_[batch_id]));

// Split batch into key batch and optional payload batch
//
// Input batch is key-payload batch (key columns followed by payload
// columns). We split it into two separate batches.
//
// TODO: Change SwissTableForJoinBuild interface to use key-payload
// batch instead to avoid this operation, which involves increasing
// shared pointer ref counts.
//
ExecBatch key_batch({}, input_batch.length);
key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
key_batch.values[icol] = input_batch.values[icol];
}
arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;

DCHECK_NE(hash_table_build_, nullptr);
return hash_table_build_->PartitionBatch(static_cast<int64_t>(thread_id), batch_id,
key_batch, temp_stack);
}

Status PartitionFinished(size_t thread_id) {
RETURN_NOT_OK(status());

DCHECK_NE(hash_table_build_, nullptr);
return CancelIfNotOK(
start_task_group_callback_(task_group_build_, hash_table_build_->num_prtns()));
}
Expand All @@ -2638,9 +2630,14 @@ class SwissJoin : public HashJoinImpl {
}

const HashJoinProjectionMaps* schema = schema_[1];
DCHECK_NE(hash_table_build_, nullptr);
bool no_payload = hash_table_build_->no_payload();
ExecBatch key_batch, payload_batch;
key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
if (!no_payload) {
payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD));
}
arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack;
DCHECK_NE(hash_table_build_, nullptr);

for (int64_t batch_id = 0;
batch_id < static_cast<int64_t>(build_side_batches_.batch_count()); ++batch_id) {
Expand All @@ -2657,22 +2654,19 @@ class SwissJoin : public HashJoinImpl {
// batch instead to avoid this operation, which involves increasing
// shared pointer ref counts.
//
ExecBatch key_batch({}, input_batch.length);
key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY));
key_batch.length = input_batch.length;
for (size_t icol = 0; icol < key_batch.values.size(); ++icol) {
key_batch.values[icol] = input_batch.values[icol];
}

ExecBatch payload_batch({}, input_batch.length);
if (!no_payload) {
payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD));
payload_batch.length = input_batch.length;
for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) {
payload_batch.values[icol] =
input_batch.values[schema->num_cols(HashJoinProjection::KEY) + icol];
}
}

DCHECK_NE(hash_table_build_, nullptr);
RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->ProcessPartition(
thread_id, batch_id, static_cast<int>(prtn_id), key_batch,
no_payload ? nullptr : &payload_batch, temp_stack)));
Expand Down
19 changes: 7 additions & 12 deletions cpp/src/arrow/acero/swiss_join_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,22 +529,17 @@ class SwissTableForJoinBuild {
const std::vector<KeyColumnMetadata>& payload_types, MemoryPool* pool,
int64_t hardware_flags);

// In the first phase of parallel hash table build, threads pick unprocessed
// exec batches, partition the rows based on hash, and update all of the
// partitions with information related to that batch of rows.
  // In the first phase of parallel hash table build, each thread picks unprocessed exec
  // batches, hashes each batch and preserves the hashes, then partitions the rows based
  // on those hashes.
//
Status PartitionBatch(size_t thread_id, int64_t batch_id, const ExecBatch& key_batch,
arrow::util::TempVectorStack* temp_stack);

// In the first phase of parallel hash table build, threads pick unprocessed
// exec batches, partition the rows based on hash, and update all of the
// partitions with information related to that batch of rows.
  // In the second phase of parallel hash table build, each thread picks one particular
  // partition and, for every batch, updates that partition with the information from the
  // rows of the batch that fall into it.
//
// Status BuildBatch(size_t thread_id, int64_t batch_id, int prtn_id,
// const ExecBatch& key_batch, const ExecBatch*
// payload_batch_maybe_null, arrow::util::TempVectorStack*
// temp_stack);

Status ProcessPartition(size_t thread_id, int64_t batch_id, int prtn_id,
const ExecBatch& key_batch,
const ExecBatch* payload_batch_maybe_null,
Expand All @@ -556,7 +551,7 @@ class SwissTableForJoinBuild {
//
Status PreparePrtnMerge();

// Second phase of parallel hash table build.
// Third phase of parallel hash table build.
// Each partition can be processed by a different thread.
// Parallel step.
//
Expand Down

0 comments on commit 281ca6d

Please sign in to comment.