Skip to content

Commit

Permalink
More
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jan 30, 2024
1 parent bc4fe6f commit 9e3fdec
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ fn adjust_input_keys_ordering(
left,
right,
on,
filter,
join_type,
sort_options,
null_equals_null,
Expand All @@ -356,6 +357,7 @@ fn adjust_input_keys_ordering(
left.clone(),
right.clone(),
new_conditions.0,
filter.clone(),
*join_type,
new_conditions.1,
*null_equals_null,
Expand Down Expand Up @@ -635,6 +637,7 @@ pub(crate) fn reorder_join_keys_to_inputs(
left,
right,
on,
filter,
join_type,
sort_options,
null_equals_null,
Expand Down Expand Up @@ -664,6 +667,7 @@ pub(crate) fn reorder_join_keys_to_inputs(
left.clone(),
right.clone(),
new_join_on,
filter.clone(),
*join_type,
new_sort_options,
*null_equals_null,
Expand Down Expand Up @@ -1642,6 +1646,7 @@ pub(crate) mod tests {
left,
right,
join_on.clone(),
None,
*join_type,
vec![SortOptions::default(); join_on.len()],
false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ fn try_swapping_with_sort_merge_join(
Arc::new(new_left),
Arc::new(new_right),
new_on,
sm_join.filter.clone(),
sm_join.join_type,
sm_join.sort_options.clone(),
sm_join.null_equals_null,
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ pub fn sort_merge_join_exec(
left,
right,
join_on.clone(),
None,
*join_type,
vec![SortOptions::default(); join_on.len()],
false,
Expand Down
1 change: 1 addition & 0 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ async fn run_join_test(
left,
right,
on_columns.clone(),
None,
join_type,
vec![SortOptions::default(), SortOptions::default()],
false,
Expand Down
57 changes: 50 additions & 7 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use crate::{
};

use arrow::array::*;
use arrow::compute;
use arrow::compute::{concat_batches, take, SortOptions};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use arrow::error::ArrowError;
Expand Down Expand Up @@ -467,8 +468,9 @@ enum BufferedState {
Exhausted,
}

/// Represents a chunk of joined data from streamed and buffered side
struct StreamedJoinedChunk {
/// Index of batch buffered_data
/// Index of batch in buffered_data
buffered_batch_idx: Option<usize>,
/// Array builder for streamed indices
streamed_indices: UInt64Builder,
Expand All @@ -477,13 +479,17 @@ struct StreamedJoinedChunk {
}

/// A record batch from the streamed side of the join, together with the
/// current scan position and the join output indices accumulated for it.
struct StreamedBatch {
/// The streamed record batch
pub batch: RecordBatch,
/// The index of row in the streamed batch to compare with buffered batches.
/// This is the value appended as the streamed index when an output pair
/// is recorded.
pub idx: usize,
/// The join key arrays of streamed batch which are used to compare with buffered batches
/// and to produce output. They are produced by evaluating `on` expressions.
pub join_arrays: Vec<ArrayRef>,

/// Chunks of indices from buffered side (may be nulls) joined to streamed.
/// A new chunk is pushed when this list is empty or when the buffered batch
/// index passed for an output pair differs from `buffered_batch_idx`.
pub output_indices: Vec<StreamedJoinedChunk>,
/// Index of currently scanned batch from buffered data
pub buffered_batch_idx: Option<usize>,
}

Expand Down Expand Up @@ -516,6 +522,8 @@ impl StreamedBatch {
buffered_batch_idx: Option<usize>,
buffered_idx: Option<usize>,
) {
// If no current chunk exists or current chunk is not for current buffered batch,
// create a new chunk
if self.output_indices.is_empty() || self.buffered_batch_idx != buffered_batch_idx
{
self.output_indices.push(StreamedJoinedChunk {
Expand All @@ -527,6 +535,7 @@ impl StreamedBatch {
};
let current_chunk = self.output_indices.last_mut().unwrap();

// Append index of streamed batch and index of buffered batch into current chunk
current_chunk.streamed_indices.append_value(self.idx as u64);
if let Some(idx) = buffered_idx {
current_chunk.buffered_indices.append_value(idx as u64);
Expand Down Expand Up @@ -958,7 +967,9 @@ impl SMJStream {
/// Produce join and fill output buffer until reaching target batch size
/// or the join is finished
fn join_partial(&mut self) -> Result<()> {
// Whether to join streamed rows
let mut join_streamed = false;
// Whether to join buffered rows
let mut join_buffered = false;

// determine whether we need to join streamed/buffered rows
Expand Down Expand Up @@ -1006,11 +1017,13 @@ impl SMJStream {
{
let scanning_idx = self.buffered_data.scanning_idx();
if join_streamed {
// Join streamed row and buffered row
self.streamed_batch.append_output_pair(
Some(self.buffered_data.scanning_batch_idx),
Some(scanning_idx),
);
} else {
// Join nulls and buffered row
self.buffered_data
.scanning_batch_mut()
.null_joined
Expand Down Expand Up @@ -1074,6 +1087,7 @@ impl SMJStream {
}
buffered_batch.null_joined.clear();

// Take buffered (right) columns
let buffered_columns = buffered_batch
.batch
.columns()
Expand All @@ -1082,6 +1096,7 @@ impl SMJStream {
.collect::<Result<Vec<_>, ArrowError>>()
.map_err(Into::<DataFusionError>::into)?;

// Create null streamed (left) columns
let mut streamed_columns = self
.streamed_schema
.fields()
Expand All @@ -1092,8 +1107,22 @@ impl SMJStream {
streamed_columns.extend(buffered_columns);
let columns = streamed_columns;

self.output_record_batches
.push(RecordBatch::try_new(self.schema.clone(), columns)?);
let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?;

// Apply join filter if any
let output_batch = if let Some(f) = &self.filter {
let filter_result = f
.expression()
.evaluate(&output_batch)?
.into_array(output_batch.num_rows())?;
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;

compute::filter_record_batch(&output_batch, mask)?
} else {
output_batch
};

self.output_record_batches.push(output_batch);
}
Ok(())
}
Expand Down Expand Up @@ -1144,8 +1173,22 @@ impl SMJStream {
streamed_columns
};

self.output_record_batches
.push(RecordBatch::try_new(self.schema.clone(), columns)?);
let output_batch = RecordBatch::try_new(self.schema.clone(), columns)?;

// Apply join filter if any
let output_batch = if let Some(f) = &self.filter {
let filter_result = f
.expression()
.evaluate(&output_batch)?
.into_array(output_batch.num_rows())?;
let mask = datafusion_common::cast::as_boolean_array(&filter_result)?;

compute::filter_record_batch(&output_batch, mask)?
} else {
output_batch
};

self.output_record_batches.push(output_batch);
}

self.streamed_batch.output_indices.clear();
Expand Down

0 comments on commit 9e3fdec

Please sign in to comment.