Skip to content

Commit

Permalink
fix: duplicate output for HashJoinExec in CollectLeft mode
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Mar 21, 2024
1 parent eda2ddf commit c2c7a96
Showing 1 changed file with 131 additions and 4 deletions.
135 changes: 131 additions & 4 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1640,26 +1640,73 @@ mod tests {
join_type: &JoinType,
null_equals_null: bool,
context: Arc<TaskContext>,
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
join_collect_with_partition_mode(
left,
right,
on,
join_type,
PartitionMode::Partitioned,
null_equals_null,
context,
)
.await
}

async fn join_collect_with_partition_mode(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
on: JoinOn,
join_type: &JoinType,
partition_mode: PartitionMode,
null_equals_null: bool,
context: Arc<TaskContext>,
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
let partition_count = 4;

let (left_expr, right_expr) =
on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip();

let join = HashJoinExec::try_new(
Arc::new(RepartitionExec::try_new(
let left_repartitioned: Arc<dyn ExecutionPlan> = match partition_mode {
PartitionMode::CollectLeft => Arc::new(CoalescePartitionsExec::new(left)),
PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new(
left,
Partitioning::Hash(left_expr, partition_count),
)?),
Arc::new(RepartitionExec::try_new(
PartitionMode::Auto => {
return internal_err!("Unexpected PartitionMode::Auto in join tests")
}
};

let right_repartitioned: Arc<dyn ExecutionPlan> = match partition_mode {
PartitionMode::CollectLeft => {
let partition_column_name = right.schema().field(0).name().clone();
let partition_expr = vec![Arc::new(Column::new_with_schema(
&partition_column_name,
&right.schema(),
)?) as _];
Arc::new(RepartitionExec::try_new(
right,
Partitioning::Hash(partition_expr, partition_count),
)?) as _
}
PartitionMode::Partitioned => Arc::new(RepartitionExec::try_new(
right,
Partitioning::Hash(right_expr, partition_count),
)?),
PartitionMode::Auto => {
return internal_err!("Unexpected PartitionMode::Auto in join tests")
}
};

let join = HashJoinExec::try_new(
left_repartitioned,
right_repartitioned,
on,
None,
join_type,
None,
PartitionMode::Partitioned,
partition_mode,
null_equals_null,
)?;

Expand Down Expand Up @@ -3312,6 +3359,86 @@ mod tests {
Ok(())
}

/// Checks `HashJoinExec` output when it runs with `PartitionMode::CollectLeft`.
///
/// The same pair of three-row tables is joined on `b1 = b2` once per join
/// type (inner, left, right, full) through `join_collect_with_partition_mode`,
/// and the collected batches are compared, order-insensitively, against the
/// expected rendering for that join type.
#[tokio::test]
async fn test_collect_left_multiple_partitions_join() -> Result<()> {
    let ctx = Arc::new(TaskContext::default());

    // Build side: `b1` holds 4, 5 and 7, so only 4 and 5 find a match.
    let left = build_table(
        ("a1", &vec![1, 2, 3]),
        ("b1", &vec![4, 5, 7]),
        ("c1", &vec![7, 8, 9]),
    );
    // Probe side: `b2` holds 4, 5 and 6, so the row with 6 stays unmatched.
    let right = build_table(
        ("a2", &vec![10, 20, 30]),
        ("b2", &vec![4, 5, 6]),
        ("c2", &vec![70, 80, 90]),
    );

    // Single equi-join condition: left.b1 = right.b2.
    let left_schema = left.schema();
    let right_schema = right.schema();
    let on = vec![(
        Arc::new(Column::new_with_schema("b1", &left_schema).unwrap()) as _,
        Arc::new(Column::new_with_schema("b2", &right_schema).unwrap()) as _,
    )];

    // Each entry pairs a join type with the table expected from it.
    let scenarios = vec![
        (
            JoinType::Inner,
            vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| 1 | 4 | 7 | 10 | 4 | 70 |",
                "| 2 | 5 | 8 | 20 | 5 | 80 |",
                "+----+----+----+----+----+----+",
            ],
        ),
        (
            JoinType::Left,
            vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| 1 | 4 | 7 | 10 | 4 | 70 |",
                "| 2 | 5 | 8 | 20 | 5 | 80 |",
                "| 3 | 7 | 9 | | | |",
                "+----+----+----+----+----+----+",
            ],
        ),
        (
            JoinType::Right,
            vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| | | | 30 | 6 | 90 |",
                "| 1 | 4 | 7 | 10 | 4 | 70 |",
                "| 2 | 5 | 8 | 20 | 5 | 80 |",
                "+----+----+----+----+----+----+",
            ],
        ),
        (
            JoinType::Full,
            vec![
                "+----+----+----+----+----+----+",
                "| a1 | b1 | c1 | a2 | b2 | c2 |",
                "+----+----+----+----+----+----+",
                "| | | | 30 | 6 | 90 |",
                "| 1 | 4 | 7 | 10 | 4 | 70 |",
                "| 2 | 5 | 8 | 20 | 5 | 80 |",
                "| 3 | 7 | 9 | | | |",
                "+----+----+----+----+----+----+",
            ],
        ),
    ];

    for (kind, expected_rows) in scenarios {
        // The column names returned alongside the batches are not checked here.
        let (_columns, output) = join_collect_with_partition_mode(
            left.clone(),
            right.clone(),
            on.clone(),
            &kind,
            PartitionMode::CollectLeft,
            false,
            ctx.clone(),
        )
        .await?;

        assert_batches_sorted_eq!(expected_rows, &output);
    }

    Ok(())
}

#[tokio::test]
async fn join_date32() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Expand Down

0 comments on commit c2c7a96

Please sign in to comment.