From 242190b89d68e5e36918c443c3a93327870b43b9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 12 Jun 2024 10:01:03 -0700 Subject: [PATCH] fix --- datafusion/core/tests/sql/joins.rs | 44 ++++++++++ .../src/joins/sort_merge_join.rs | 86 ++++++++++++------- 2 files changed, 100 insertions(+), 30 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index fad9b94b01120..1e690b45a09e0 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -235,3 +235,47 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { } Ok(()) } + +#[tokio::test] +async fn test_smj_right_filtered() -> Result<()> { + let ctx: SessionContext = SessionContext::new(); + + let sql = "set datafusion.optimizer.prefer_hash_join = false;"; + let _ = ctx.sql(sql).await?.collect().await?; + + let sql = "set datafusion.execution.batch_size = 100"; + let _ = ctx.sql(sql).await?.collect().await?; + + let sql = " + select * from ( + with t as ( + select id, id % 5 id1 from (select unnest(range(0,10)) id) + ), t1 as ( + select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) + ) + select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 + ) order by 1, 2, 3, 4 + "; + + let actual = ctx.sql(sql).await?.collect().await?; + + let expected: Vec<&str> = vec![ + "+----+-----+----+-----+", + "| id | id1 | id | id1 |", + "+----+-----+----+-----+", + "| 5 | 0 | 0 | 2 |", + "| 6 | 1 | 1 | 3 |", + "| 7 | 2 | 2 | 4 |", + "| 8 | 3 | 3 | 5 |", + "| 9 | 4 | 4 | 6 |", + "| | | 5 | 7 |", + "| | | 6 | 8 |", + "| | | 7 | 9 |", + "| | | 8 | 10 |", + "| | | 9 | 11 |", + "+----+-----+----+-----+", + ]; + datafusion_common::assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 8da345cdfca6e..77e07e17ea273 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -1149,6 +1149,7 @@ impl SMJStream { // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { for chunk in self.streamed_batch.output_indices.iter_mut() { + // The row indices of joined streamed batch let streamed_indices = chunk.streamed_indices.finish(); if streamed_indices.is_empty() { @@ -1163,6 +1164,7 @@ impl SMJStream { .map(|column| take(column, &streamed_indices, None)) .collect::, ArrowError>>()?; + // The row indices of joined buffered batch let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); let mut buffered_columns = if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { @@ -1174,6 +1176,8 @@ impl SMJStream { &buffered_indices, )? } else { + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. self.buffered_schema .fields() .iter() @@ -1241,7 +1245,7 @@ impl SMJStream { let maybe_filtered_join_mask: Option<(BooleanArray, Vec)> = get_filtered_join_mask( self.join_type, - streamed_indices, + &streamed_indices, mask, &self.streamed_batch.join_filter_matched_idxs, &self.buffered_data.scanning_offset, @@ -1254,29 +1258,43 @@ impl SMJStream { .extend(&filtered_join_mask.1); } - // Push the filtered batch to the output + // Push the filtered batch which contains rows passing join filter to the output let filtered_batch = compute::filter_record_batch(&output_batch, mask)?; + println!("output_batch: {:?}", output_batch); + println!("mask: {:?}", mask); + println!("filtered_batch: {:?}", filtered_batch); self.output_record_batches.push(filtered_batch); - // For outer joins, we need to push the null joined rows to the output. + // For outer joins, we need to push the null joined rows to the output if + // all joined rows are failed on the join filter. + // I.e., if all rows joined from a streamed row are failed with the join filter, + // we need to join it with null columns from buffered side. if matches!( self.join_type, JoinType::Left | JoinType::Right | JoinType::Full ) { - // The reverse of the selection mask. For the rows not pass join filter above, - // we need to join them (left or right) with null rows for outer joins. - let not_mask = if mask.null_count() > 0 { - // If the mask contains nulls, we need to use `prep_null_mask_filter` to - // handle the nulls in the mask as false to produce rows where the mask - // was null itself. - compute::not(&compute::prep_null_mask_filter(mask))? - } else { - compute::not(mask)? - }; + // We need to get the mask for row indices that the joined rows are failed + // on the join filter. I.e., for a row in streamed side, if all joined rows + // between it and all buffered rows are failed on the join filter, we need to + // output it with null columns from buffered side. For the mask here, it + // behaves like LeftAnti join. + let null_mask: BooleanArray = get_filtered_join_mask( + // Set a mask slot as true if all joined rows of same streamed index are failed on the join filter. + // The behavior is like LeftAnti join. + JoinType::LeftAnti, + &streamed_indices, + mask, + &self.streamed_batch.join_filter_matched_idxs, + &self.buffered_data.scanning_offset, + ) + .unwrap() + .0; + println!("null_mask: {:?}", null_mask); let null_joined_batch = - compute::filter_record_batch(&output_batch, ¬_mask)?; + compute::filter_record_batch(&output_batch, &null_mask)?; + println!("null_joined_batch: {:?}", null_joined_batch); let mut buffered_columns = self .buffered_schema @@ -1317,16 +1335,20 @@ impl SMJStream { RecordBatch::try_new(self.schema.clone(), columns.clone())?; self.output_record_batches.push(null_joined_streamed_batch); - // For full join, we also need to output the null joined rows from the buffered side + // For full join, we also need to output the null joined rows from the buffered side. + // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with + // streamed side, it won't be outputted by `freeze_buffered`. So we need to output it here. + // We need to check if a buffered row is joined with streamed side and output, if not, + // we need to output it with null columns from streamed side. if matches!(self.join_type, JoinType::Full) { // Handle not mask for buffered side further. // For buffered side, we want to output the rows that are not null joined with // the streamed side. i.e. the rows that are not null in the `buffered_indices`. let not_mask = if let Some(nulls) = buffered_indices.nulls() { - let mask = not_mask.values() & nulls.inner(); + let mask = null_mask.values() & nulls.inner(); BooleanArray::new(mask, None) } else { - not_mask + null_mask }; let null_joined_batch = @@ -1445,9 +1467,13 @@ fn get_buffered_columns( /// `streamed_indices` have the same length as `mask` /// `matched_indices` array of streaming indices that already has a join filter match /// `scanning_buffered_offset` current buffered offset across batches +/// +/// This return a tuple of: +/// - corrected mask with respect to the join type +/// - indices of rows in streamed batch that have a join filter match fn get_filtered_join_mask( join_type: JoinType, - streamed_indices: UInt64Array, + streamed_indices: &UInt64Array, mask: &BooleanArray, matched_indices: &HashSet, scanning_buffered_offset: &usize, @@ -2808,7 +2834,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 1, 1]), + &UInt64Array::from(vec![0, 0, 1, 1]), &BooleanArray::from(vec![true, true, false, false]), &HashSet::new(), &0, @@ -2819,7 +2845,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, true]), &HashSet::new(), &0, @@ -2830,7 +2856,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![false, true]), &HashSet::new(), &0, @@ -2841,7 +2867,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, false]), &HashSet::new(), &0, @@ -2852,7 +2878,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, true, true, true, true, true]), &HashSet::new(), &0, @@ -2866,7 +2892,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftSemi, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, false, false, false, false, true]), &HashSet::new(), &0, @@ -2885,7 +2911,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 1, 1]), + &UInt64Array::from(vec![0, 0, 1, 1]), &BooleanArray::from(vec![true, true, false, false]), &HashSet::new(), &0, @@ -2896,7 +2922,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, true]), &HashSet::new(), &0, @@ -2907,7 +2933,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![false, true]), &HashSet::new(), &0, @@ -2918,7 +2944,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 1]), + &UInt64Array::from(vec![0, 1]), &BooleanArray::from(vec![true, false]), &HashSet::new(), &0, @@ -2929,7 +2955,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, true, true, true, true, true]), &HashSet::new(), &0, @@ -2943,7 +2969,7 @@ mod tests { assert_eq!( get_filtered_join_mask( LeftAnti, - UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), + &UInt64Array::from(vec![0, 0, 0, 1, 1, 1]), &BooleanArray::from(vec![false, false, false, false, false, true]), &HashSet::new(), &0,