Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support inner iejoin #12754

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
0e478de
init iejoinexec.
my-vegetable-has-exploded Sep 21, 2024
9b552cc
init executionplan.
my-vegetable-has-exploded Sep 22, 2024
12da70e
wip
my-vegetable-has-exploded Sep 28, 2024
eca6cf8
basic implement iejoinstream.
my-vegetable-has-exploded Sep 30, 2024
d7d3dfd
..
my-vegetable-has-exploded Oct 1, 2024
b3b0e69
basic init.
my-vegetable-has-exploded Oct 1, 2024
2dd0635
impl planner.
my-vegetable-has-exploded Oct 2, 2024
24e516f
fix column index.
my-vegetable-has-exploded Oct 3, 2024
a8b509b
add ut.
my-vegetable-has-exploded Oct 4, 2024
0c3a893
fix swap operator.
my-vegetable-has-exploded Oct 4, 2024
ffbf265
add sqllogicaltest.
my-vegetable-has-exploded Oct 4, 2024
f04021d
fix cargo.lock.
my-vegetable-has-exploded Oct 4, 2024
acd8474
rm useless dependcy.
my-vegetable-has-exploded Oct 4, 2024
007c00b
fix sort partition.
my-vegetable-has-exploded Oct 5, 2024
ca296d3
fix test string.
my-vegetable-has-exploded Oct 5, 2024
b6633a7
fix tests & clippy
my-vegetable-has-exploded Oct 5, 2024
8110ecd
fix test contain.
my-vegetable-has-exploded Oct 5, 2024
4d48810
fix sort removed.
my-vegetable-has-exploded Oct 5, 2024
44d5f76
add more tests.
my-vegetable-has-exploded Oct 6, 2024
246811a
test generate_series.
my-vegetable-has-exploded Oct 8, 2024
4c3bd6c
test generate_series.
my-vegetable-has-exploded Oct 8, 2024
8c819a9
add more comments.
my-vegetable-has-exploded Oct 8, 2024
1738495
add metric
my-vegetable-has-exploded Oct 13, 2024
a67a720
fix permutation len.
my-vegetable-has-exploded Oct 13, 2024
9fcd867
fix metric
my-vegetable-has-exploded Oct 13, 2024
698fb5c
fix comment.
my-vegetable-has-exploded Oct 13, 2024
b246e7e
little update.
my-vegetable-has-exploded Oct 13, 2024
cde1f8f
use left_order.
my-vegetable-has-exploded Oct 16, 2024
dea673a
fix tests.
my-vegetable-has-exploded Oct 16, 2024
7d03765
fix clippy.
my-vegetable-has-exploded Oct 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add more comments.
my-vegetable-has-exploded committed Oct 8, 2024
commit 8c819a923e1bda623f9473cf173bfb3820cdc849
81 changes: 51 additions & 30 deletions datafusion/physical-plan/src/joins/ie_join.rs
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@ use futures::{ready, Stream};
use parking_lot::Mutex;

/// IEJoinExec is optimized join without any equijoin conditions in `ON` clause but with two or more inequality conditions.
/// For more detail algorithm, see https://vldb.org/pvldb/vol8/p2074-khayyat.pdf
/// For more detail algorithm, see <https://vldb.org/pvldb/vol8/p2074-khayyat.pdf>
///
/// Take this query q as an example:
///
@@ -63,26 +63,26 @@ use parking_lot::Mutex;
/// There is no equijoin condition in the `ON` clause, but there are two inequality conditions.
/// Currently, left table is t1, right table is t2.
///
/// The berif idea of this algorithm is converting it to ordered pair/inversion pair of permutation problem. For a permutation of a[0..n-1], for a pairs (i, j) such that i < j and a[i] < a[j], we call it an ordered pair of permutation.
/// The berif idea of this algorithm is converting it to ordered pair/inversion pair of permutation problem. For a permutation of a[0..n-1], for a pairs (i, j) such that i < j and a\[i\] < a\[j\], we call it an ordered pair of permutation.
///
/// For example, for a[0..4] = [2, 1, 3, 0], there are 2 ordered pairs: (2, 3), (1, 3)
///
/// To convert query q to ordered pair of permutation problem. We will do the following steps:
/// 1. Sort t1 union t2 by time in ascending order, mark the sorted table as l1.
/// 2. Sort t1 union t2 by cost in ascending order, mark the sorted table as l2.
/// 3. For each element e_i in l2, find the index j in l1 such that l1[j] = e_i, mark the computed index as permutation array p.
/// 4. Compute the inversion of permutation array p. For a pair (i, j) in l2, if i < j then e_i.cost < e_j.cost because l2 is sorted by cost in ascending order. And if p[i] < p[j], then e_i.time < e_j.time because l1 is sorted by time in ascending order.
/// 5. The result of query q is the pairs (i, j) in l2 such that i < j and p[i] < p[j] and e_i is from right table and e_j is from left table.
/// 3. For each element e_i in l2, find the index j in l1 such that l1\[j\] = e_i, mark the computed index as permutation array p. If p\[i\] = j, it means that the ith element in l2 is the jth element in l1.
/// 4. Compute the ordered pair of permutation array p. For a pair (i, j) in l2, if i < j then e_i.cost < e_j.cost because l2 is sorted by cost in ascending order. And if p\[i\] < p\[j\], then e_i.time < e_j.time because l1 is sorted by time in ascending order.
/// 5. The result of query q is the pairs (i, j) in l2 such that i < j and p\[i\] < p\[j\] and e_i is from right table and e_j is from left table.
///
/// To get the final result, we need to get all the pairs (i, j) in l2 such that i < j and p[i] < p[j] and e_i is from right table and e_j is from left table. We can do this by the following steps:
/// 1. Traverse l2 from left to right, at offset j, we can maintain BtreeSet or bitmap to record all the p[i] that i < j, then find all the pairs (i, j) in l2 such that p[i] < p[j].
/// To get the final result, we need to get all the pairs (i, j) in l2 such that i < j and p\[i\] < p\[j\] and e_i is from right table and e_j is from left table. We can do this by the following steps:
/// 1. Traverse l2 from left to right, at offset j, we can maintain BtreeSet or bitmap to record all the p\[i\] that i < j, then find all the pairs (i, j) in l2 such that p\[i\] < p\[j\].
///
/// To parallel the above algorithm, we can sort t1 and t2 by time (condition 1) firstly, and repartition the data into N partitions, then join t1[i] and t2[j] respectively. And if the minimum time of t1[i] is greater than the maximum time of t2[j], we can skip the join of t1[i] and t2[j] because there is no join result between them according to condition 1.
/// To parallel the above algorithm, we can sort t1 and t2 by time (condition 1) firstly, and repartition the data into N partitions, then join t1\[i\] and t2\[j\] respectively. And if the minimum time of t1\[i\] is greater than the maximum time of t2\[j\], we can skip the join of t1\[i\] and t2\[j\] because there is no join result between them according to condition 1.
#[derive(Debug)]
pub struct IEJoinExec {
/// left side
/// left side, which have been sorted by condition 1
pub(crate) left: Arc<dyn ExecutionPlan>,
/// right side
/// right side, which have been sorted by condition 1
pub(crate) right: Arc<dyn ExecutionPlan>,
/// inequality conditions for iejoin, for example, t1.time > t2.time and t1.cost < t2.cost, only support two inequality conditions, other conditions will be stored in `filter`
pub(crate) inequality_conditions: Vec<JoinFilter>,
@@ -92,16 +92,17 @@ pub struct IEJoinExec {
pub(crate) join_type: JoinType,
/// the schema once the join is applied
schema: SchemaRef,
/// data for iejoin
iejoin_data: OnceAsync<IEJoinData>,
/// left condition
/// left condition, it represents `t1.time asc` and `t1.cost asc` in above example
left_conditions: Arc<[PhysicalSortExpr; 2]>,
/// right condition
/// right condition, it represents `t2.time asc` and `t2.cost asc` in above example
right_conditions: Arc<[PhysicalSortExpr; 2]>,
/// operator of the inequality condition
operators: Arc<[Operator; 2]>,
/// sort options of the inequality condition
/// sort options of the inequality conditions, it represents `asc` and `asc` in above example
sort_options: Arc<[SortOptions; 2]>,
/// partition pairs
/// partition pairs, used to get the next pair of left and right blocks, IEJoinStream handles one pair of blocks each time
pairs: Arc<Mutex<u64>>,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
@@ -197,13 +198,13 @@ impl IEJoinExec {
schema,
&[false, false],
None,
// No on columns in nested loop join
// No on columns in iejoin
&[],
);

let output_partitioning = Partitioning::UnknownPartitioning(target_partitions);

// Determine execution mode:
// Determine execution mode
let mut mode = execution_mode_from_children([left, right]);
if mode.is_unbounded() {
mode = ExecutionMode::PipelineBreaking;
@@ -401,6 +402,7 @@ impl SortedBlock {
}
}

/// IEJoinData contains all data blocks from left and right side, and the data evaluated by condition 1 and condition 2 from left and right side
#[derive(Debug)]
pub struct IEJoinData {
/// collected left data after sort by condition 1
@@ -523,7 +525,7 @@ impl IEJoinStream {
}

// compute the join result
// TODO: return one batch if the result size larger than the batch size in config
// TODO: should return batches one by one if the result size larger than the batch size in config?
let batch = IEJoinStream::compute(
left_block,
right_block,
@@ -557,9 +559,11 @@ impl IEJoinStream {
operators,
)?;

// compute the join indices statify the inequality conditions
let (left_indices, right_indices) =
IEJoinStream::build_join_indices(&l1_indexes, &permutation)?;

// apply the filter to the join result
let (left_indices, right_indices) = if let Some(filter) = filter {
apply_join_filter_to_indices(
left_data,
@@ -605,27 +609,34 @@ impl IEJoinStream {
// get the valid count of right block
let m = right_block.arrays()[0].len() - right_block.arrays()[0].null_count();
// if the max valid element of right block is smaller than the min valid element of left block, there is no intersection
// for example, if left.a <= right.b, the left block is \[7, 8, 9\], the right block is \[2, 4, 6\], left\[0\] greater than right\[2\] so there is no intersection between left block and right block
// if left.a >= right.b, the left block is \[1, 0, 0\], the right block is \[6, 4, 2\], left\[0\] lesser than right\[2\] (because the sort options used in `make_comparator` is desc, so the compare result will be greater) so there is no intersection between left block and right block
if comparator(0, m - 1) == std::cmp::Ordering::Greater {
return false;
}
true
}

/// this function computes the permutation array of condition 2 on condition 1
/// this function computes l1_indexes array and the permutation array of condition 2 on condition 1
/// for example, if condition 1 is left.a <= right.b, condition 2 is left.x <= right.y
/// for left table, we have:
/// | id | a | x |
/// | id | a | x |
/// |-------|----|---|
/// | left1 | 1 | 7 |
/// | left2 | 3 | 4 |
/// for right table, we have:
/// | id | b | y |
/// |----|---|---|
/// | right1 | 2 | 5 |
/// | right2 | 4 | 6 |
/// Sort by condition 1, we get l1:
/// | value | 1 | 2 | 3 | 4 |
/// |-------|---|---|---|---|
/// | id | left1 | right1 | left2 | right2 |
/// The l1_indexes array is [1, -1, 2, -2], the negative value means it is the index of right table, the positive value means it is the index of left table, the absolute value is the index of original recordbatch
/// Sort by condition 2, we get l2:
/// | value | 4 | 5 | 6 | 7 |
/// |-------|---|---|---|---|
/// | id | left2 | right1 | right2 | left1 |
/// Then the permutation array is [2, 1, 3, 0]
/// The first element of l2 is left2, which is the 3rd element(index 2) of l1. The second element of l2 is right1, which is the 2nd element(index 1) of l1. And so on.
@@ -636,16 +647,16 @@ impl IEJoinStream {
operators: &[Operator; 2],
) -> Result<(Int64Array, UInt64Array)> {
// step1. sort the union block l1
let n = left_block.array[0].len() as i64;
let m = right_block.array[0].len() as i64;
let n = left_block.arrays()[0].len() as i64;
let m = right_block.arrays()[0].len() as i64;
// concat the left block and right block
let cond1 = concat(&[
&Arc::clone(&left_block.array[0]),
&Arc::clone(&right_block.array[0]),
&Arc::clone(&left_block.arrays()[0]),
&Arc::clone(&right_block.arrays()[0]),
])?;
let cond2 = concat(&[
&Arc::clone(&left_block.array[1]),
&Arc::clone(&right_block.array[1]),
&Arc::clone(&left_block.arrays()[1]),
&Arc::clone(&right_block.arrays()[1]),
])?;
// store index of left table and right table
// -i in (-n..-1) means it is index i in left table, j in (1..m) means it is index j in right table
@@ -677,11 +688,13 @@ impl IEJoinStream {
// otherwise, let the right index (> 0) in forward of left index (< 0)
// for example, t1.time <= t2.time
// | value| 1 | 1 | 1 | 1 | 2 |
// |------|--------|--------|-------|-------|-------|
// | index| -2(l2) | -1(l2) | 1(r1) | 2(r2) | 3(r3) |
// if t1.time < t2.time
// |value| 1 | 1 | 1 | 1 | 2 |
// |-----|--------|--------|-------|-------|-------|
// |index| 2(r2) | 1(r1) | -1(l2) | -2(l1) | 3(r3) |
// according to this order, if i < j then value[i](from left table) and value[j](from right table) match the condition
// according to this order request, if i < j then value\[i\](from left table) and value\[j\](from right table) match the condition(t1.time <= t2.time or t1.time < t2.time)
descending: !is_loose_inequality_operator(&operators[0]),
nulls_first: false,
},
@@ -694,7 +707,7 @@ impl IEJoinStream {
let valid = (l1.arrays()[0].len() - l1.arrays()[0].null_count()) as i64;
let l1 = l1.slice(0..valid as usize);

// l1_indexes[i] = j means the ith element of l1 is the jth element of original recordbatch
// l1_indexes\[i\] = j means the ith element of l1 is the jth element of original recordbatch
let l1_indexes = Arc::clone(&l1.arrays()[1])
.as_primitive::<Int64Type>()
.clone();
@@ -745,14 +758,22 @@ impl IEJoinStream {
))
}

/// compute the join indices statify the inequality conditions
/// following the example in `compute_permutation`, the l1_indexes is \[1, -1, 2, -2\], the permutation is \[2, 1, 3, 0\]
/// range_map is empty at first
/// 1、 p\[0\] = 2, range_map is empty, l1_indexes\[2\] is greater than 0, it means 2nd element in l1 is from left table, insert(2) into range_map, range_map {(2, 3)}
/// 2、 p\[1\] = 1, no value less than p\[1\] in range_map, l1_indexes\[1\] is less than 0, it means 1st element in l1 is from right table, no need to insert(1) into range_map, range_map {(2, 3)}
/// 3、 p\[2\] = 3, found 2 less than p\[2\] in range_map, append all pairs (l1_indexes\[2\], l1_indexes\[3\]) to the indeices array, l1_indexes\[3\] is less than 0, it means 3rd element in l1 is from right table, no need to insert(3) into range_map, range_map {(1, 4)}
/// 4、 p\[3\] = 0, no value less than p\[3\] in range_map, insert(0) into range_map, range_map {(0, 1), (2, 3)}
/// The indices array is \[(2), (2)\]
fn build_join_indices(
l1_indexes: &Int64Array,
permutation: &UInt64Array,
) -> Result<(UInt64Array, UInt64Array)> {
let mut left_builder = UInt64Builder::new();
let mut right_builder = UInt64Builder::new();
// maintain all p[i], for i in 0..j.
// our target is to find all pair(i, j) that i<j and p[i] < p[j] and i from left table and j from right table here
// use btree map to maintain all p\[i\], for i in 0..j, map\[s\]=t means range \[s, t\) is valid
// our target is to find all pair(i, j) that i<j and p\[i\] < p\[j\] and i from left table and j from right table here
let mut range_map = BTreeMap::<u64, u64>::new();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my test, use btreemap is a little faster than bitmap (https://github.com/my-vegetable-has-exploded/arrow-datafusion/compare/iejoin...my-vegetable-has-exploded:arrow-datafusion:iejoin-bitmap?expand=1), 8.5s -> 7.5s. Currently, the main cost is sorting though.

for p in permutation.values().iter() {
// get the index of original recordbatch
@@ -769,7 +790,7 @@ impl IEJoinStream {
let (start, end) = range;
let (start, end) = (*start, std::cmp::min(*end, *p));
for left_l1_index in start..end {
// get all p[i] in range(start, end) and remap it to original recordbatch index in left table
// get all p\[i\] in range(start, end) and remap it to original recordbatch index in left table
left_builder.append_value(
(-unsafe { l1_indexes.value_unchecked(left_l1_index as usize) }
- 1) as u64,
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
@@ -459,6 +459,8 @@ pub fn check_inequality_condition(inequality_condition: &JoinFilter) -> Result<(
Ok(())
}

/// convert inequality conditions to sort expressions of each side and the operator
/// for example, if the inequality condition is `a < b`, then the sort expressions for left and right side are `a asc` and `b asc` respectively
pub fn inequality_conditions_to_sort_exprs(
inequality_conditions: &[JoinFilter],
) -> Result<Vec<(PhysicalSortExpr, PhysicalSortExpr, Operator)>> {