Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jun 12, 2024
1 parent 97ea05c commit 242190b
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 30 deletions.
44 changes: 44 additions & 0 deletions datafusion/core/tests/sql/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,47 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> {
}
Ok(())
}

#[tokio::test]
async fn test_smj_right_filtered() -> Result<()> {
let ctx: SessionContext = SessionContext::new();

let sql = "set datafusion.optimizer.prefer_hash_join = false;";
let _ = ctx.sql(sql).await?.collect().await?;

let sql = "set datafusion.execution.batch_size = 100";
let _ = ctx.sql(sql).await?.collect().await?;

let sql = "
select * from (
with t as (
select id, id % 5 id1 from (select unnest(range(0,10)) id)
), t1 as (
select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id)
)
select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1
) order by 1, 2, 3, 4
";

let actual = ctx.sql(sql).await?.collect().await?;

let expected: Vec<&str> = vec![
"+----+-----+----+-----+",
"| id | id1 | id | id1 |",
"+----+-----+----+-----+",
"| 5 | 0 | 0 | 2 |",
"| 6 | 1 | 1 | 3 |",
"| 7 | 2 | 2 | 4 |",
"| 8 | 3 | 3 | 5 |",
"| 9 | 4 | 4 | 6 |",
"| | | 5 | 7 |",
"| | | 6 | 8 |",
"| | | 7 | 9 |",
"| | | 8 | 10 |",
"| | | 9 | 11 |",
"+----+-----+----+-----+",
];
datafusion_common::assert_batches_eq!(expected, &actual);

Ok(())
}
86 changes: 56 additions & 30 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,7 @@ impl SMJStream {
// for current streamed batch and clears staged output indices.
fn freeze_streamed(&mut self) -> Result<()> {
for chunk in self.streamed_batch.output_indices.iter_mut() {
// The row indices of joined streamed batch
let streamed_indices = chunk.streamed_indices.finish();

if streamed_indices.is_empty() {
Expand All @@ -1163,6 +1164,7 @@ impl SMJStream {
.map(|column| take(column, &streamed_indices, None))
.collect::<Result<Vec<_>, ArrowError>>()?;

// The row indices of joined buffered batch
let buffered_indices: UInt64Array = chunk.buffered_indices.finish();
let mut buffered_columns =
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
Expand All @@ -1174,6 +1176,8 @@ impl SMJStream {
&buffered_indices,
)?
} else {
// If buffered batch none, meaning it is null joined batch.
// We need to create null arrays for buffered columns to join with streamed rows.
self.buffered_schema
.fields()
.iter()
Expand Down Expand Up @@ -1241,7 +1245,7 @@ impl SMJStream {
let maybe_filtered_join_mask: Option<(BooleanArray, Vec<u64>)> =
get_filtered_join_mask(
self.join_type,
streamed_indices,
&streamed_indices,
mask,
&self.streamed_batch.join_filter_matched_idxs,
&self.buffered_data.scanning_offset,
Expand All @@ -1254,29 +1258,43 @@ impl SMJStream {
.extend(&filtered_join_mask.1);
}

// Push the filtered batch to the output
// Push the filtered batch which contains rows passing join filter to the output
let filtered_batch =
compute::filter_record_batch(&output_batch, mask)?;
println!("output_batch: {:?}", output_batch);
println!("mask: {:?}", mask);
println!("filtered_batch: {:?}", filtered_batch);
self.output_record_batches.push(filtered_batch);

// For outer joins, we need to push the null joined rows to the output.
// For outer joins, we need to push the null joined rows to the output if
// all joined rows are failed on the join filter.
// I.e., if all rows joined from a streamed row are failed with the join filter,
// we need to join it with null columns from buffered side.
if matches!(
self.join_type,
JoinType::Left | JoinType::Right | JoinType::Full
) {
// The reverse of the selection mask. For the rows not pass join filter above,
// we need to join them (left or right) with null rows for outer joins.
let not_mask = if mask.null_count() > 0 {
// If the mask contains nulls, we need to use `prep_null_mask_filter` to
// handle the nulls in the mask as false to produce rows where the mask
// was null itself.
compute::not(&compute::prep_null_mask_filter(mask))?
} else {
compute::not(mask)?
};
// We need to get the mask for row indices that the joined rows are failed
// on the join filter. I.e., for a row in streamed side, if all joined rows
// between it and all buffered rows are failed on the join filter, we need to
// output it with null columns from buffered side. For the mask here, it
// behaves like LeftAnti join.
let null_mask: BooleanArray = get_filtered_join_mask(
// Set a mask slot as true if all joined rows of same streamed index are failed on the join filter.
// The behavior is like LeftAnti join.
JoinType::LeftAnti,
&streamed_indices,
mask,
&self.streamed_batch.join_filter_matched_idxs,
&self.buffered_data.scanning_offset,
)
.unwrap()
.0;

println!("null_mask: {:?}", null_mask);
let null_joined_batch =
compute::filter_record_batch(&output_batch, &not_mask)?;
compute::filter_record_batch(&output_batch, &null_mask)?;
println!("null_joined_batch: {:?}", null_joined_batch);

let mut buffered_columns = self
.buffered_schema
Expand Down Expand Up @@ -1317,16 +1335,20 @@ impl SMJStream {
RecordBatch::try_new(self.schema.clone(), columns.clone())?;
self.output_record_batches.push(null_joined_streamed_batch);

// For full join, we also need to output the null joined rows from the buffered side
// For full join, we also need to output the null joined rows from the buffered side.
// Usually this is done by `freeze_buffered`. However, if a buffered row is joined with
// streamed side, it won't be outputted by `freeze_buffered`. So we need to output it here.
// We need to check if a buffered row is joined with streamed side and output, if not,
// we need to output it with null columns from streamed side.
if matches!(self.join_type, JoinType::Full) {
// Handle not mask for buffered side further.
// For buffered side, we want to output the rows that are not null joined with
// the streamed side. i.e. the rows that are not null in the `buffered_indices`.
let not_mask = if let Some(nulls) = buffered_indices.nulls() {
let mask = not_mask.values() & nulls.inner();
let mask = null_mask.values() & nulls.inner();
BooleanArray::new(mask, None)
} else {
not_mask
null_mask
};

let null_joined_batch =
Expand Down Expand Up @@ -1445,9 +1467,13 @@ fn get_buffered_columns(
/// `streamed_indices` have the same length as `mask`
/// `matched_indices` array of streaming indices that already has a join filter match
/// `scanning_buffered_offset` current buffered offset across batches
///
/// This return a tuple of:
/// - corrected mask with respect to the join type
/// - indices of rows in streamed batch that have a join filter match
fn get_filtered_join_mask(
join_type: JoinType,
streamed_indices: UInt64Array,
streamed_indices: &UInt64Array,
mask: &BooleanArray,
matched_indices: &HashSet<u64>,
scanning_buffered_offset: &usize,
Expand Down Expand Up @@ -2808,7 +2834,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 1, 1]),
&UInt64Array::from(vec![0, 0, 1, 1]),
&BooleanArray::from(vec![true, true, false, false]),
&HashSet::new(),
&0,
Expand All @@ -2819,7 +2845,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, true]),
&HashSet::new(),
&0,
Expand All @@ -2830,7 +2856,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![false, true]),
&HashSet::new(),
&0,
Expand All @@ -2841,7 +2867,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 1]),
&UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, false]),
&HashSet::new(),
&0,
Expand All @@ -2852,7 +2878,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, true, true, true, true, true]),
&HashSet::new(),
&0,
Expand All @@ -2866,7 +2892,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftSemi,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, false, false, false, false, true]),
&HashSet::new(),
&0,
Expand All @@ -2885,7 +2911,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
UInt64Array::from(vec![0, 0, 1, 1]),
&UInt64Array::from(vec![0, 0, 1, 1]),
&BooleanArray::from(vec![true, true, false, false]),
&HashSet::new(),
&0,
Expand All @@ -2896,7 +2922,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
UInt64Array::from(vec![0, 1]),
&UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, true]),
&HashSet::new(),
&0,
Expand All @@ -2907,7 +2933,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
UInt64Array::from(vec![0, 1]),
&UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![false, true]),
&HashSet::new(),
&0,
Expand All @@ -2918,7 +2944,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
UInt64Array::from(vec![0, 1]),
&UInt64Array::from(vec![0, 1]),
&BooleanArray::from(vec![true, false]),
&HashSet::new(),
&0,
Expand All @@ -2929,7 +2955,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, true, true, true, true, true]),
&HashSet::new(),
&0,
Expand All @@ -2943,7 +2969,7 @@ mod tests {
assert_eq!(
get_filtered_join_mask(
LeftAnti,
UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&UInt64Array::from(vec![0, 0, 0, 1, 1, 1]),
&BooleanArray::from(vec![false, false, false, false, false, true]),
&HashSet::new(),
&0,
Expand Down

0 comments on commit 242190b

Please sign in to comment.