From 281ca6de76ea6cddfc602a66076f4ebd63ca143b Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Thu, 20 Feb 2025 19:04:24 +0800 Subject: [PATCH] Refine --- cpp/src/arrow/acero/swiss_join.cc | 28 +++++++++-------------- cpp/src/arrow/acero/swiss_join_internal.h | 19 ++++++--------- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index dc4a75af87da1..ba7c16350f9cb 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -2600,27 +2600,18 @@ class SwissJoin : public HashJoinImpl { DCHECK_GT(build_side_batches_[batch_id].length, 0); const HashJoinProjectionMaps* schema = schema_[1]; - DCHECK_NE(hash_table_build_, nullptr); - ExecBatch input_batch; ARROW_ASSIGN_OR_RAISE( input_batch, KeyPayloadFromInput(/*side=*/1, &build_side_batches_[batch_id])); - // Split batch into key batch and optional payload batch - // - // Input batch is key-payload batch (key columns followed by payload - // columns). We split it into two separate batches. - // - // TODO: Change SwissTableForJoinBuild interface to use key-payload - // batch instead to avoid this operation, which involves increasing - // shared pointer ref counts.
- // ExecBatch key_batch({}, input_batch.length); key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY)); for (size_t icol = 0; icol < key_batch.values.size(); ++icol) { key_batch.values[icol] = input_batch.values[icol]; } arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack; + + DCHECK_NE(hash_table_build_, nullptr); return hash_table_build_->PartitionBatch(static_cast(thread_id), batch_id, key_batch, temp_stack); } @@ -2628,6 +2619,7 @@ class SwissJoin : public HashJoinImpl { Status PartitionFinished(size_t thread_id) { RETURN_NOT_OK(status()); + DCHECK_NE(hash_table_build_, nullptr); return CancelIfNotOK( start_task_group_callback_(task_group_build_, hash_table_build_->num_prtns())); } @@ -2638,9 +2630,14 @@ class SwissJoin : public HashJoinImpl { } const HashJoinProjectionMaps* schema = schema_[1]; + DCHECK_NE(hash_table_build_, nullptr); bool no_payload = hash_table_build_->no_payload(); + ExecBatch key_batch, payload_batch; + key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY)); + if (!no_payload) { + payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD)); + } arrow::util::TempVectorStack* temp_stack = &local_states_[thread_id].stack; - DCHECK_NE(hash_table_build_, nullptr); for (int64_t batch_id = 0; batch_id < static_cast(build_side_batches_.batch_count()); ++batch_id) { @@ -2657,22 +2654,19 @@ class SwissJoin : public HashJoinImpl { // batch instead to avoid this operation, which involves increasing // shared pointer ref counts.
// - ExecBatch key_batch({}, input_batch.length); - key_batch.values.resize(schema->num_cols(HashJoinProjection::KEY)); + key_batch.length = input_batch.length; for (size_t icol = 0; icol < key_batch.values.size(); ++icol) { key_batch.values[icol] = input_batch.values[icol]; } - ExecBatch payload_batch({}, input_batch.length); if (!no_payload) { - payload_batch.values.resize(schema->num_cols(HashJoinProjection::PAYLOAD)); + payload_batch.length = input_batch.length; for (size_t icol = 0; icol < payload_batch.values.size(); ++icol) { payload_batch.values[icol] = input_batch.values[schema->num_cols(HashJoinProjection::KEY) + icol]; } } - DCHECK_NE(hash_table_build_, nullptr); RETURN_NOT_OK(CancelIfNotOK(hash_table_build_->ProcessPartition( thread_id, batch_id, static_cast(prtn_id), key_batch, no_payload ? nullptr : &payload_batch, temp_stack))); diff --git a/cpp/src/arrow/acero/swiss_join_internal.h b/cpp/src/arrow/acero/swiss_join_internal.h index 48c5fc81ab0ec..7799c3a8349f0 100644 --- a/cpp/src/arrow/acero/swiss_join_internal.h +++ b/cpp/src/arrow/acero/swiss_join_internal.h @@ -529,22 +529,17 @@ class SwissTableForJoinBuild { const std::vector& payload_types, MemoryPool* pool, int64_t hardware_flags); - // In the first phase of parallel hash table build, threads pick unprocessed - // exec batches, partition the rows based on hash, and update all of the - // partitions with information related to that batch of rows. + // In the first phase of parallel hash table build, each thread picks unprocessed exec + // batches, hashes the batches and preserves the hashes, then partitions the rows based on + // hashes. // Status PartitionBatch(size_t thread_id, int64_t batch_id, const ExecBatch& key_batch, arrow::util::TempVectorStack* temp_stack); - // In the first phase of parallel hash table build, threads pick unprocessed - // exec batches, partition the rows based on hash, and update all of the - // partitions with information related to that batch of rows.
+ // In the second phase of parallel hash table build, each thread picks the given + // partition of all batches, and updates that particular partition with information + // related to that batch of rows. // - // Status BuildBatch(size_t thread_id, int64_t batch_id, int prtn_id, - // const ExecBatch& key_batch, const ExecBatch* - // payload_batch_maybe_null, arrow::util::TempVectorStack* - // temp_stack); - Status ProcessPartition(size_t thread_id, int64_t batch_id, int prtn_id, const ExecBatch& key_batch, const ExecBatch* payload_batch_maybe_null, @@ -556,7 +551,7 @@ class SwissTableForJoinBuild { // Status PreparePrtnMerge(); - // Second phase of parallel hash table build. + // Third phase of parallel hash table build. // Each partition can be processed by a different thread. // Parallel step. //