Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support inner iejoin #12754

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0e478de
init iejoinexec.
my-vegetable-has-exploded Sep 21, 2024
9b552cc
init executionplan.
my-vegetable-has-exploded Sep 22, 2024
12da70e
wip
my-vegetable-has-exploded Sep 28, 2024
eca6cf8
basic implement iejoinstream.
my-vegetable-has-exploded Sep 30, 2024
d7d3dfd
..
my-vegetable-has-exploded Oct 1, 2024
b3b0e69
basic init.
my-vegetable-has-exploded Oct 1, 2024
2dd0635
impl planner.
my-vegetable-has-exploded Oct 2, 2024
24e516f
fix column index.
my-vegetable-has-exploded Oct 3, 2024
a8b509b
add ut.
my-vegetable-has-exploded Oct 4, 2024
0c3a893
fix swap operator.
my-vegetable-has-exploded Oct 4, 2024
ffbf265
add sqllogicaltest.
my-vegetable-has-exploded Oct 4, 2024
f04021d
fix cargo.lock.
my-vegetable-has-exploded Oct 4, 2024
acd8474
rm useless dependcy.
my-vegetable-has-exploded Oct 4, 2024
007c00b
fix sort partition.
my-vegetable-has-exploded Oct 5, 2024
ca296d3
fix test string.
my-vegetable-has-exploded Oct 5, 2024
b6633a7
fix tests & clippy
my-vegetable-has-exploded Oct 5, 2024
8110ecd
fix test contain.
my-vegetable-has-exploded Oct 5, 2024
4d48810
fix sort removed.
my-vegetable-has-exploded Oct 5, 2024
44d5f76
add more tests.
my-vegetable-has-exploded Oct 6, 2024
246811a
test generate_series.
my-vegetable-has-exploded Oct 8, 2024
4c3bd6c
test generate_series.
my-vegetable-has-exploded Oct 8, 2024
8c819a9
add more comments.
my-vegetable-has-exploded Oct 8, 2024
1738495
add metric
my-vegetable-has-exploded Oct 13, 2024
a67a720
fix permutation len.
my-vegetable-has-exploded Oct 13, 2024
9fcd867
fix metric
my-vegetable-has-exploded Oct 13, 2024
698fb5c
fix comment.
my-vegetable-has-exploded Oct 13, 2024
b246e7e
little update.
my-vegetable-has-exploded Oct 13, 2024
cde1f8f
use left_order.
my-vegetable-has-exploded Oct 16, 2024
dea673a
fix tests.
my-vegetable-has-exploded Oct 16, 2024
7d03765
fix clippy.
my-vegetable-has-exploded Oct 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix tests & clippy
my-vegetable-has-exploded committed Oct 5, 2024
commit b6633a7e70122c4c8c7502b92d70ba9fe433ffa4
54 changes: 31 additions & 23 deletions datafusion/physical-plan/src/joins/ie_join.rs
Original file line number Diff line number Diff line change
@@ -156,8 +156,7 @@ impl IEJoinExec {
Arc::new([condition_parts[0].0.clone(), condition_parts[1].0.clone()]);
let right_conditions =
Arc::new([condition_parts[0].1.clone(), condition_parts[1].1.clone()]);
let operators =
Arc::new([condition_parts[0].2.clone(), condition_parts[1].2.clone()]);
let operators = Arc::new([condition_parts[0].2, condition_parts[1].2]);
let sort_options = Arc::new([
operator_to_sort_option(operators[0]),
operator_to_sort_option(operators[1]),
@@ -305,17 +304,17 @@ impl ExecutionPlan for IEJoinExec {
collect_iejoin_data(
Arc::clone(&self.left),
Arc::clone(&self.right),
self.left_conditions.clone(),
self.right_conditions.clone(),
context.clone(),
Arc::clone(&self.left_conditions),
Arc::clone(&self.right_conditions),
Arc::clone(&context),
)
});
Ok(Box::pin(IEJoinStream {
schema: Arc::clone(&self.schema),
filter: self.filter.clone(),
_join_type: self.join_type,
operators: self.operators.clone(),
sort_options: self.sort_options.clone(),
operators: Arc::clone(&self.operators),
sort_options: Arc::clone(&self.sort_options),
iejoin_data,
column_indices: self.column_indices.clone(),
pairs: Arc::clone(&self.pairs),
@@ -360,7 +359,7 @@ impl SortedBlock {
.sort_options
.iter()
.map(|(i, opt)| SortColumn {
values: self.array[*i].clone(),
values: Arc::clone(&self.array[*i]),
options: Some(*opt),
})
.collect::<Vec<_>>();
@@ -408,8 +407,8 @@ async fn collect_iejoin_data(
context: Arc<TaskContext>,
) -> Result<IEJoinData> {
// the left and right data are already sorted by condition 1 (the `try_iejoin` rewrite rule has done this), so collect them directly
let left_data = collect(left, context.clone()).await?;
let right_data = collect(right, context.clone()).await?;
let left_data = collect(left, Arc::clone(&context)).await?;
let right_data = collect(right, Arc::clone(&context)).await?;
let left_blocks = left_data
.iter()
.map(|batch| {
@@ -525,6 +524,7 @@ impl IEJoinStream {
Poll::Ready(Some(Ok(batch)))
}

#[allow(clippy::too_many_arguments)]
fn compute(
left_block: &SortedBlock,
right_block: &SortedBlock,
@@ -625,10 +625,14 @@ impl IEJoinStream {
let n = left_block.array[0].len() as i64;
let m = right_block.array[0].len() as i64;
// concat the left block and right block
let cond1 =
concat(&[&left_block.array[0].clone(), &right_block.array[0].clone()])?;
let cond2 =
concat(&[&left_block.array[1].clone(), &right_block.array[1].clone()])?;
let cond1 = concat(&[
&Arc::clone(&left_block.array[0]),
&Arc::clone(&right_block.array[0]),
])?;
let cond2 = concat(&[
&Arc::clone(&left_block.array[1]),
&Arc::clone(&right_block.array[1]),
])?;
// store index of left table and right table
// -i in (-n..-1) means it is index i in left table, j in (1..m) means it is index j in right table
let indexes = concat(&[
@@ -677,11 +681,13 @@ impl IEJoinStream {
let l1 = l1.slice(0..valid as usize);

// l1_indexes[i] = j means the ith element of l1 is the jth element of original recordbatch
let l1_indexes = l1.arrays()[1].clone().as_primitive::<Int64Type>().clone();
let l1_indexes = Arc::clone(&l1.arrays()[1])
.as_primitive::<Int64Type>()
.clone();

// mark the order of l1, the index i means this element is the ith element of l1(sorted by condition 1)
let permutation = UInt64Array::from(
std::iter::successors(Some(0 as u64), |&x| {
std::iter::successors(Some(0_u64), |&x| {
if x < (valid as u64) {
Some(x + 1)
} else {
@@ -694,9 +700,9 @@ impl IEJoinStream {
let mut l2 = SortedBlock::new(
vec![
// condition 2
l1.arrays()[2].clone(),
Arc::clone(&l1.arrays()[2]),
// index of original recordbatch
l1.arrays()[1].clone(),
Arc::clone(&l1.arrays()[1]),
// index of l1
Arc::new(permutation),
],
@@ -719,7 +725,9 @@ impl IEJoinStream {

Ok((
l1_indexes,
l2.arrays()[2].clone().as_primitive::<UInt64Type>().clone(),
Arc::clone(&l2.arrays()[2])
.as_primitive::<UInt64Type>()
.clone(),
))
}

@@ -738,14 +746,14 @@ impl IEJoinStream {
if l1_index < 0 {
// index from left table
// insert p into range_map
IEJoinStream::insert_range_map(&mut range_map, *p as u64);
IEJoinStream::insert_range_map(&mut range_map, *p);
continue;
}
// index from right table, remap to 0..m
let right_index = (l1_index - 1) as u64;
for range in range_map.range(0..(*p as u64)) {
for range in range_map.range(0..{ *p }) {
let (start, end) = range;
let (start, end) = (*start, std::cmp::min(*end, *p as u64));
let (start, end) = (*start, std::cmp::min(*end, *p));
for left_l1_index in start..end {
// get all p[i] in range(start, end) and remap it to original recordbatch index in left table
left_builder.append_value(
@@ -843,7 +851,7 @@ mod tests {
right,
ie_join_filter,
join_filter,
&join_type,
join_type,
partition_count,
)?;
let columns = columns(&ie_join.schema());
37 changes: 19 additions & 18 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
@@ -401,9 +401,9 @@ pub fn swap_binary_expr(expr: &PhysicalExprRef) -> PhysicalExprRef {
Some(binary) => {
if let Some(swapped_op) = binary.op().swap() {
Arc::new(BinaryExpr::new(
Arc::clone(&binary.right()),
Arc::clone(binary.right()),
swapped_op,
Arc::clone(&binary.left()),
Arc::clone(binary.left()),
))
} else {
Arc::clone(expr)
@@ -415,27 +415,26 @@ pub fn swap_binary_expr(expr: &PhysicalExprRef) -> PhysicalExprRef {

/// Checks whether the inequality condition is valid.
/// The inequality condition is valid if the expressions are not null and the expressions are not equal, and left expression is from left schema and right expression is from right schema.
/// TODO: Maybe we can reorder the expressions to make it satisfy this condition later, like (right.b < left.a) -> (left.a > right.b).
pub fn check_inequality_condition(inequality_condition: &JoinFilter) -> Result<()> {
if let Some(binary) = inequality_condition
.expression()
.as_any()
.downcast_ref::<BinaryExpr>()
{
if !(is_ineuqality_operator(&binary.op()) && *binary.op() != Operator::NotEq) {
if !(is_ineuqality_operator(binary.op()) && *binary.op() != Operator::NotEq) {
return plan_err!(
"Inequality conditions must be an inequality binary expression, but got {}",
binary.op()
);
}
let column_indices = &inequality_condition.column_indices;
// check if left expression is from left table
let left_expr_columns = collect_columns(&binary.left());
let left_expr_columns = collect_columns(binary.left());
let left_expr_in_left = left_expr_columns
.iter()
.all(|c| column_indices[c.index()].side == JoinSide::Left);
// check if right expression is from right table
let right_expr_columns = collect_columns(&binary.right());
let right_expr_columns = collect_columns(binary.right());
let right_expr_in_right = right_expr_columns
.iter()
.all(|c| column_indices[c.index()].side == JoinSide::Right);
@@ -482,7 +481,7 @@ pub fn inequality_conditions_to_sort_exprs(
// remap the columns in the join schema to the original tables, because we need to use the original column indices to sort the left and right tables independently
let (left_map, right_map): (Vec<_>, Vec<_>) = filter
.column_indices()
.into_iter()
.iter()
.enumerate()
.partition(|(_, index)| index.side == JoinSide::Left);
let left_map = HashMap::from_iter(
@@ -493,14 +492,14 @@ pub fn inequality_conditions_to_sort_exprs(
);
Ok((
PhysicalSortExpr::new(
map_columns(Arc::clone(&binary.left()), &left_map)?,
map_columns(Arc::clone(binary.left()), &left_map)?,
sort_option,
),
PhysicalSortExpr::new(
map_columns(Arc::clone(&binary.right()), &right_map)?,
map_columns(Arc::clone(binary.right()), &right_map)?,
sort_option,
),
binary.op().clone(),
*binary.op(),
))
})
.collect()
@@ -1775,7 +1774,9 @@ mod tests {
use arrow_schema::SortOptions;

use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
use datafusion_common::{
arrow_datafusion_err, arrow_err, assert_contains, ScalarValue,
};
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};

fn check(
@@ -2715,7 +2716,7 @@ mod tests {
let join_filter =
JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone());
let actual = format!("{:?}", check_inequality_condition(&join_filter));
assert_eq!(
assert_contains!(
actual,
"Err(Plan(\"Inequality conditions must be an inequality binary expression, but got !=\"))"
);
@@ -2728,7 +2729,7 @@ mod tests {
let join_filter =
JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone());
let actual = format!("{:?}", check_inequality_condition(&join_filter));
assert_eq!(
assert_contains!(
actual,
"Err(Plan(\"Inequality condition shouldn't be constant expression, but got x@0 > 8\"))"
);
@@ -2745,7 +2746,7 @@ mod tests {
let join_filter =
JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone());
let actual = format!("{:?}", check_inequality_condition(&join_filter));
assert_eq!(
assert_contains!(
actual,
"Err(Plan(\"Left/right side expression of inequality condition should be from left/right side of join, but got x@2 * y@1 and x@0\"))"
);
@@ -2763,8 +2764,8 @@ mod tests {
JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone());
let actual = format!("{:?}", check_inequality_condition(&join_filter));
assert_eq!(actual, "Ok(())");
let res = inequality_conditions_to_sort_exprs(&vec![join_filter])?;
let (left_expr, right_expr, operator) = res.get(0).unwrap();
let res = inequality_conditions_to_sort_exprs(&[join_filter])?;
let (left_expr, right_expr, operator) = res.first().unwrap();
assert_eq!(left_expr.to_string(), "x@1 + y@2 DESC NULLS LAST");
assert_eq!(right_expr.to_string(), "x@1 DESC NULLS LAST");
assert_eq!(*operator, Operator::GtEq);
@@ -2778,8 +2779,8 @@ mod tests {
JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone());
let actual = format!("{:?}", check_inequality_condition(&join_filter));
assert_eq!(actual, "Ok(())");
let res = inequality_conditions_to_sort_exprs(&vec![join_filter])?;
let (left_expr, right_expr, operator) = res.get(0).unwrap();
let res = inequality_conditions_to_sort_exprs(&[join_filter])?;
let (left_expr, right_expr, operator) = res.first().unwrap();
assert_eq!(left_expr.to_string(), "x@1 ASC NULLS LAST");
assert_eq!(right_expr.to_string(), "x@1 ASC NULLS LAST");
assert_eq!(*operator, Operator::Lt);