diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index cc35255dfe29..bc3f64b3dd6c 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -78,12 +78,15 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, SortExpr, + DescribeTable, DmlStatement, Extension, Filter, JoinType, RecursiveQuery, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::utils::{conjunction, split_conjunction}; use datafusion_physical_expr::LexOrdering; +use datafusion_physical_plan::joins::utils::JoinFilter; +use datafusion_physical_plan::joins::IEJoinExec; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; use datafusion_sql::utils::window_expr_common_partition_keys; @@ -1063,14 +1066,24 @@ impl DefaultPhysicalPlanner { session_state.config_options().optimizer.prefer_hash_join; let join: Arc = if join_on.is_empty() { - // there is no equal join condition, use the nested loop join - // TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins` - Arc::new(NestedLoopJoinExec::try_new( - physical_left, - physical_right, - join_filter, + // there is no equal join condition, try to use iejoin or use the nested loop join + if let Some(iejoin) = try_iejoin( + Arc::clone(&physical_left), + Arc::clone(&physical_right), + &join_filter, join_type, - )?) + session_state.config().target_partitions(), + )? 
{ + iejoin + } else { + // TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins` + Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + )?) + } } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && !prefer_hash_join @@ -1659,6 +1672,114 @@ pub fn create_physical_sort_expr( }) } +/// Try to create an IEJoin execution plan for join without equality conditions +pub fn try_iejoin( + left: Arc, + right: Arc, + filter: &Option, + join_type: &JoinType, + target_partitions: usize, +) -> Result>> { + if join_type != &JoinType::Inner { + // TODO: support other join types, only Inner join is supported currently + return Ok(None); + } + if let Some(filter) = filter { + // split filter into multiple conditions + let mut conditions = split_conjunction(filter.expression()); + // take first two inequality conditions, swap the binary expression if necessary + let inequality_conditions = conditions + .iter() + .enumerate() + .map(|(index, condition)| { + ( + index, + JoinFilter::new( + Arc::clone(condition), + filter.column_indices().to_vec(), + filter.schema().clone(), + ), + JoinFilter::new( + join_utils::swap_binary_expr(&Arc::clone(condition)), + filter.column_indices().to_vec(), + filter.schema().clone(), + ), + ) + }) + .map(|(index, condition, condition_swap)| { + ( + index, + condition.clone(), + join_utils::check_inequality_condition(&condition).is_ok(), + condition_swap.clone(), + join_utils::check_inequality_condition(&condition_swap).is_ok(), + ) + }) + .map( + |( + index, + condition, + condition_valid, + condition_swap, + condition_swap_valid, + )| { + if condition_valid { + (index, condition, true) + } else if condition_swap_valid { + (index, condition_swap, true) + } else { + (index, condition, false) + } + }, + ) + .filter(|(_, _, condition_valid)| *condition_valid) + .take(2) + .collect::>(); + // if inequality_conditions has 
less than 2 elements, return None + if inequality_conditions.len() < 2 { + return Ok(None); + } + // remove the taken inequality conditions from conditions + // remove from back to front to keep the index correct + for (index, _condition, _condition_valid) in inequality_conditions.iter().rev() { + conditions.remove(*index); + } + // create a new filter with the remaining conditions + let new_filter = conjunction(conditions); + let inequality_conditions = inequality_conditions + .iter() + .map(|(_, condition, _)| condition.clone()) + .collect::>(); + let sort_exprs = + join_utils::inequality_conditions_to_sort_exprs(&inequality_conditions)?; + // sort left and right by the condition 1 + let sorted_left = Arc::new(SortExec::new( + vec![sort_exprs[0].0.clone()], + Arc::clone(&left), + )); + let sorted_right = Arc::new(SortExec::new( + vec![sort_exprs[0].1.clone()], + Arc::clone(&right), + )); + Ok(Some(Arc::new(IEJoinExec::try_new( + sorted_left, + sorted_right, + inequality_conditions, + new_filter.map(|expr| { + join_utils::JoinFilter::new( + expr, + filter.column_indices().to_vec(), + filter.schema().clone(), + ) + }), + join_type, + target_partitions, + )?))) + } else { + Ok(None) + } +} + /// Create vector of physical sort expression from a vector of logical expression pub fn create_physical_sort_exprs( exprs: &[SortExpr], diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 4c37db4849a7..309dc1f4d346 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -39,6 +39,13 @@ use itertools::Itertools; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; +pub fn conjunction(exprs: Vec<&Arc>) -> Option> { + exprs + .iter() + .map(|expr| Arc::clone(expr)) + .reduce(|acc, expr| Arc::new(BinaryExpr::new(acc, Operator::And, expr))) +} + /// Assume the predicate is in the form of CNF, split the predicate to a Vec of PhysicalExprs. 
/// /// For example, split "a1 = a2 AND b1 <= b2 AND c1 != c2" into ["a1 = a2", "b1 <= b2", "c1 != c2"] @@ -214,6 +221,31 @@ pub fn collect_columns(expr: &Arc) -> HashSet { columns } +/// map physical columns according to given index mapping +pub fn map_columns( + expr: Arc, + mapping: &HashMap, +) -> Result> { + expr.transform(|expr| { + if let Some(column) = expr.as_any().downcast_ref::() { + let new_index = mapping.get(&column.index()).cloned(); + if let Some(new_index) = new_index { + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + new_index, + )))); + } else { + return datafusion_common::internal_err!( + "column index {} not found in mapping", + column.index() + ); + } + } + Ok(Transformed::no(expr)) + }) + .data() +} + /// Re-assign column indices referenced in predicate according to given schema. /// This may be helpful when dealing with projections. pub fn reassign_predicate_columns( @@ -547,4 +579,26 @@ pub(crate) mod tests { assert_eq!(collect_columns(&expr3), expected); Ok(()) } + + #[test] + fn test_map_columns() -> Result<()> { + let col1 = Arc::new(Column::new("col1", 0)); + let col2 = Arc::new(Column::new("col2", 1)); + let col3 = Arc::new(Column::new("col3", 2)); + let expr = Arc::new(BinaryExpr::new(col1, Operator::Plus, col2)) as _; + let mapping = HashMap::from([(0, 2), (1, 0)]); + let mapped = map_columns(expr, &mapping)?; + assert_eq!( + mapped.as_ref(), + Arc::new(BinaryExpr::new( + Arc::new(Column::new("col1", 2)), + Operator::Plus, + Arc::new(Column::new("col2", 0)) + )) + .as_any() + ); + // test mapping with non-existing index + assert!(map_columns(col3, &mapping).is_err()); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/joins/ie_join.rs b/datafusion/physical-plan/src/joins/ie_join.rs new file mode 100644 index 000000000000..12a3c1252a2b --- /dev/null +++ b/datafusion/physical-plan/src/joins/ie_join.rs @@ -0,0 +1,1160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor 
license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::collections::BTreeMap; +use std::fmt::Formatter; +use std::ops::Range; +use std::sync::Arc; +use std::task::Poll; + +use crate::joins::utils::{ + apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, + check_inequality_condition, check_join_is_valid, estimate_join_statistics, + inequality_conditions_to_sort_exprs, is_loose_inequality_operator, ColumnIndex, + JoinFilter, OnceAsync, OnceFut, +}; +use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::sorts::sort::sort_batch; +use crate::{ + collect, execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, + ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, +}; +use arrow::array::{make_comparator, AsArray, UInt64Builder}; + +use arrow::compute::concat; +use arrow::compute::kernels::sort::SortOptions; +use arrow::datatypes::{Int64Type, Schema, SchemaRef, UInt64Type}; +use arrow::record_batch::RecordBatch; +use arrow_array::{ArrayRef, Int64Array, UInt64Array}; +use datafusion_common::{plan_err, JoinSide, Result, Statistics}; +use datafusion_execution::TaskContext; +use datafusion_expr::{JoinType, Operator}; +use 
datafusion_physical_expr::equivalence::join_equivalence_properties; + +use datafusion_physical_expr::{Partitioning, PhysicalSortExpr, PhysicalSortRequirement}; +use futures::{ready, Stream}; +use parking_lot::Mutex; + +/// IEJoinExec is an optimized join for queries without any equijoin conditions in the `ON` clause but with two or more inequality conditions. +/// For a detailed description of the algorithm, see the IEJoin paper "Lightning Fast and Space Efficient Inequality Joins" (Khayyat et al., PVLDB 2015). +/// +/// Take this query q as an example: +/// +/// SELECT t1.t_id, t2.t_id +/// FROM west t1, west t2 +/// WHERE t1.time < t2.time AND t1.cost < t2.cost +/// +/// There is no equijoin condition in the `ON` clause, but there are two inequality conditions. +/// Currently, left table is t1, right table is t2. +/// +/// The brief idea of this algorithm is to convert the join into the problem of finding ordered pairs (the counterpart of inversion pairs) of a permutation. For a permutation a[0..n-1], a pair (i, j) such that i < j and a\[i\] < a\[j\] is called an ordered pair of the permutation. +/// +/// For example, for a[0..3] = [2, 1, 3, 0], there are 2 ordered pairs: (0, 2) and (1, 2), i.e. the values (2, 3) and (1, 3). +/// +/// To convert query q into the ordered-pair problem, we do the following steps: +/// 1. Sort t1 union t2 by time in ascending order, mark the sorted table as l1. +/// 2. Sort t1 union t2 by cost in ascending order, mark the sorted table as l2. +/// 3. For each element e_i in l2, find the index j in l1 such that l1\[j\] = e_i, mark the computed index as permutation array p. If p\[i\] = j, it means that the ith element in l2 is the jth element in l1. +/// 4. Compute the ordered pairs of permutation array p. For a pair (i, j) in l2, if i < j then e_i.cost < e_j.cost because l2 is sorted by cost in ascending order. And if p\[i\] < p\[j\], then e_i.time < e_j.time because l1 is sorted by time in ascending order. +/// 5. The result of query q is the pairs (i, j) in l2 such that i < j and p\[i\] < p\[j\] and e_i is from the left table (t1) and e_j is from the right table (t2). 
+/// +/// To get the final result, we need to get all the pairs (i, j) in l2 such that i < j and p\[i\] < p\[j\] and e_i is from the left table and e_j is from the right table. We can do this by the following steps: +/// 1. Traverse l2 from left to right; at offset j, maintain a BTreeSet or bitmap that records every p\[i\] with i < j, then find all the pairs (i, j) in l2 such that p\[i\] < p\[j\]. +/// See the more detailed examples in the `compute_permutation` and `build_join_indices` functions. +/// +/// To parallelize the above algorithm, we can first sort t1 and t2 by time (condition 1) and repartition the data into N partitions, then join each pair t1\[i\] and t2\[j\]. If the minimum time of t1\[i\] is greater than the maximum time of t2\[j\], we can skip the join of t1\[i\] and t2\[j\] because there is no join result between them according to condition 1. +#[derive(Debug)] +pub struct IEJoinExec { + /// left side, which have been sorted by condition 1 + pub(crate) left: Arc, + /// right side, which have been sorted by condition 1 + pub(crate) right: Arc, + /// inequality conditions for iejoin, for example, t1.time > t2.time and t1.cost < t2.cost, only support two inequality conditions, other conditions will be stored in `filter` + pub(crate) inequality_conditions: Vec, + /// filters which are applied while finding matching rows + pub(crate) filter: Option, + /// how the join is performed + pub(crate) join_type: JoinType, + /// the schema once the join is applied + schema: SchemaRef, + /// data for iejoin + iejoin_data: OnceAsync, + /// left condition, it represents `t1.time asc` and `t1.cost asc` in above example + left_conditions: Arc<[PhysicalSortExpr; 2]>, + /// right condition, it represents `t2.time asc` and `t2.cost asc` in above example + right_conditions: Arc<[PhysicalSortExpr; 2]>, + /// operator of the inequality condition + operators: Arc<[Operator; 2]>, + /// sort options of the inequality conditions, it represents `asc` and `asc` in above example + 
sort_options: Arc<[SortOptions; 2]>, + /// partition pairs, used to get the next pair of left and right blocks, IEJoinStream handles one pair of blocks each time + pairs: Arc>, + /// Information of index and left / right placement of columns + column_indices: Vec, + // TODO: add memory reservation? + /// execution metrics + metrics: ExecutionPlanMetricsSet, + /// cache holding plan properties like equivalences, output partitioning etc. + cache: PlanProperties, +} + +impl IEJoinExec { + /// Try to create a new [`IEJoinExec`] + pub fn try_new( + left: Arc, + right: Arc, + inequality_conditions: Vec, + filter: Option, + join_type: &JoinType, + target_partitions: usize, + ) -> Result { + let left_schema = left.schema(); + let right_schema = right.schema(); + check_join_is_valid(&left_schema, &right_schema, &[])?; + let (schema, column_indices) = + build_join_schema(&left_schema, &right_schema, join_type); + if inequality_conditions.len() != 2 { + return plan_err!( + "IEJoinExec only supports two inequality conditions, got {}", + inequality_conditions.len() + ); + } + for condition in &inequality_conditions { + check_inequality_condition(condition)?; + } + let schema = Arc::new(schema); + if !matches!(join_type, JoinType::Inner) { + return plan_err!( + "IEJoinExec only supports inner join currently, got {}", + join_type + ); + } + let cache = Self::compute_properties( + &left, + &right, + Arc::clone(&schema), + *join_type, + target_partitions, + ); + let condition_parts = + inequality_conditions_to_sort_exprs(&inequality_conditions)?; + let left_conditions = + Arc::new([condition_parts[0].0.clone(), condition_parts[1].0.clone()]); + let right_conditions = + Arc::new([condition_parts[0].1.clone(), condition_parts[1].1.clone()]); + let operators = Arc::new([condition_parts[0].2, condition_parts[1].2]); + let sort_options = Arc::new([ + operator_to_sort_option(operators[0]), + operator_to_sort_option(operators[1]), + ]); + + Ok(IEJoinExec { + left, + right, + 
inequality_conditions, + filter, + join_type: *join_type, + schema, + iejoin_data: Default::default(), + left_conditions, + right_conditions, + operators, + sort_options, + pairs: Arc::new(Mutex::new(0)), + column_indices, + metrics: Default::default(), + cache, + }) + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties( + left: &Arc, + right: &Arc, + schema: SchemaRef, + join_type: JoinType, + target_partitions: usize, + ) -> PlanProperties { + // Calculate equivalence properties: + let eq_properties = join_equivalence_properties( + left.equivalence_properties().clone(), + right.equivalence_properties().clone(), + &join_type, + schema, + &[false, false], + None, + // No on columns in iejoin + &[], + ); + + let output_partitioning = Partitioning::UnknownPartitioning(target_partitions); + + // Determine execution mode + let mut mode = execution_mode_from_children([left, right]); + if mode.is_unbounded() { + mode = ExecutionMode::PipelineBreaking; + } + + PlanProperties::new(eq_properties, output_partitioning, mode) + } +} + +impl DisplayAs for IEJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + let display_filter = self.filter.as_ref().map_or_else( + || "".to_string(), + |f| format!(", filter={}", f.expression()), + ); + let display_inequality_conditions = self + .inequality_conditions + .iter() + .map(|c| format!("({})", c.expression())) + .collect::>() + .join(", "); + write!( + f, + "IEJoinExec: mode={:?}, join_type={:?}, inequality_conditions=[{}]{}", + self.cache.execution_mode, + self.join_type, + display_inequality_conditions, + display_filter, + ) + } + } + } +} + +/// convert operator to sort option for iejoin +/// for left.a <= right.b, the sort option is ascending order +/// for left.a >= right.b, the sort option is 
descending order +pub fn operator_to_sort_option(op: Operator) -> SortOptions { + match op { + Operator::Lt | Operator::LtEq => SortOptions { + descending: false, + nulls_first: false, + }, + Operator::Gt | Operator::GtEq => SortOptions { + descending: true, + nulls_first: false, + }, + _ => panic!("Unsupported operator"), + } +} + +impl ExecutionPlan for IEJoinExec { + fn name(&self) -> &'static str { + "IEJoinExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn required_input_distribution(&self) -> Vec { + vec![Distribution::SinglePartition, Distribution::SinglePartition] + } + + fn required_input_ordering( + &self, + ) -> Vec> { + // sort left and right data by condition 1 to prune not intersected RecordBatch pairs + vec![ + Some(PhysicalSortRequirement::from_sort_exprs(vec![ + &self.left_conditions[0], + ])), + Some(PhysicalSortRequirement::from_sort_exprs(vec![ + &self.right_conditions[0], + ])), + ] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(IEJoinExec::try_new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.inequality_conditions.clone(), + self.filter.clone(), + &self.join_type, + self.cache.output_partitioning().partition_count(), + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let join_metrics = IEJoinMetrics::new(partition, &self.metrics); + let iejoin_data = self.iejoin_data.once(|| { + collect_iejoin_data( + Arc::clone(&self.left), + Arc::clone(&self.right), + Arc::clone(&self.left_conditions), + Arc::clone(&self.right_conditions), + join_metrics.clone(), + Arc::clone(&context), + ) + }); + Ok(Box::pin(IEJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + _join_type: self.join_type, + operators: Arc::clone(&self.operators), + sort_options: Arc::clone(&self.sort_options), + 
iejoin_data, + column_indices: self.column_indices.clone(), + pairs: Arc::clone(&self.pairs), + finished: false, + join_metrics, + })) + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn statistics(&self) -> Result { + estimate_join_statistics( + Arc::clone(&self.left), + Arc::clone(&self.right), + vec![], + &self.join_type, + &self.schema, + ) + } +} + +/// Metrics for iejoin +#[derive(Debug, Clone)] +struct IEJoinMetrics { + /// Total time for collecting init data of both sides + pub(crate) load_time: metrics::Time, + /// Number of batches of left side + pub(crate) left_input_batches: metrics::Count, + /// Number of batches of right side + pub(crate) right_input_batches: metrics::Count, + /// Number of rows of left side + pub(crate) left_input_rows: metrics::Count, + /// Number of rows of right side + pub(crate) right_input_rows: metrics::Count, + /// Memory used by collecting init data + pub(crate) load_mem_used: metrics::Gauge, + /// Total time for joining intersection blocks of input table + pub(crate) join_time: metrics::Time, + /// Number of batches produced by this operator + pub(crate) output_batches: metrics::Count, + /// Number of rows produced by this operator + pub(crate) output_rows: metrics::Count, + /// Number of pairs of left and right blocks are skipped because of no intersection + pub(crate) skipped_pairs: metrics::Count, +} + +impl IEJoinMetrics { + pub fn new(partition: usize, metrics: &ExecutionPlanMetricsSet) -> Self { + let load_time = MetricBuilder::new(metrics).subset_time("load_time", partition); + let left_input_batches = + MetricBuilder::new(metrics).counter("left_input_batches", partition); + let right_input_batches = + MetricBuilder::new(metrics).counter("right_input_batches", partition); + let left_input_rows = + MetricBuilder::new(metrics).counter("left_input_rows", partition); + let right_input_rows = + MetricBuilder::new(metrics).counter("right_input_rows", partition); + let load_mem_used = 
MetricBuilder::new(metrics).gauge("load_mem_used", partition); + let join_time = MetricBuilder::new(metrics).subset_time("join_time", partition); + let output_batches = + MetricBuilder::new(metrics).counter("output_batches", partition); + let output_rows = MetricBuilder::new(metrics).counter("output_rows", partition); + let skipped_pairs = + MetricBuilder::new(metrics).counter("skipped_pairs", partition); + Self { + load_time, + left_input_batches, + right_input_batches, + left_input_rows, + right_input_rows, + load_mem_used, + join_time, + output_batches, + output_rows, + skipped_pairs, + } + } +} + +#[derive(Debug, Clone)] +/// SortedBlock contains arrays that are sorted by specified columns +// TODO: use struct support spill? +pub struct SortedBlock { + pub data: RecordBatch, + pub sort_options: Vec<(usize, SortOptions)>, +} + +impl SortedBlock { + pub fn new(array: Vec, sort_options: Vec<(usize, SortOptions)>) -> Self { + let schema = Arc::new(Schema::new({ + array + .iter() + .enumerate() + .map(|(i, array)| { + arrow_schema::Field::new( + format!("col{}", i), + array.data_type().clone(), + true, + ) + }) + .collect::>() + })); + let data = RecordBatch::try_new(schema, array).unwrap(); + Self { data, sort_options } + } + + /// sort the block by the specified columns + pub fn sort_by_columns(&mut self) -> Result<()> { + let sort_exprs = self + .sort_options + .iter() + .map(|(i, sort_options)| PhysicalSortExpr { + expr: Arc::new(datafusion_physical_expr::expressions::Column::new( + &format!("col{}", *i), + *i, + )), + options: *sort_options, + }) + .collect::>(); + self.data = sort_batch(&self.data, &sort_exprs, None)?; + Ok(()) + } + + pub fn arrays(&self) -> &[ArrayRef] { + self.data.columns() + } + + pub fn data(&self) -> &RecordBatch { + &self.data + } + + pub fn slice(&self, range: Range) -> Self { + let data = self.data.slice(range.start, range.len()); + SortedBlock { + data, + sort_options: self.sort_options.clone(), + } + } +} + +/// IEJoinData contains 
all data blocks from left and right side, and the data evaluated by condition 1 and condition 2 from left and right side +#[derive(Debug)] +pub struct IEJoinData { + /// collected left data after sort by condition 1 + pub left_data: Vec, + /// collected right data after sort by condition 1 + pub right_data: Vec, + /// sorted blocks of left data, contains the evaluated result of condition 1 and condition 2 + pub left_blocks: Vec, + /// sorted blocks of right data, contains the evaluated result of condition 1 and condition 2 + pub right_blocks: Vec, +} + +async fn collect_iejoin_data( + left: Arc, + right: Arc, + left_conditions: Arc<[PhysicalSortExpr; 2]>, + right_conditions: Arc<[PhysicalSortExpr; 2]>, + join_metrics: IEJoinMetrics, + context: Arc, +) -> Result { + // the left and right data are sort by condition 1 already (the `try_iejoin` rewrite rule has done this), collect it directly + let left_data = collect(left, Arc::clone(&context)).await?; + let right_data = collect(right, Arc::clone(&context)).await?; + let left_blocks = left_data + .iter() + .map(|batch| { + join_metrics.left_input_batches.add(1); + join_metrics.left_input_rows.add(batch.num_rows()); + join_metrics + .load_mem_used + .add(batch.get_array_memory_size()); + let columns = left_conditions + .iter() + .map(|expr| expr.expr.evaluate(batch)?.into_array(batch.num_rows())) + .collect::>>()?; + Ok(SortedBlock::new(columns, vec![])) + }) + .collect::>>()?; + left_blocks.iter().for_each(|block| { + join_metrics + .load_mem_used + .add(block.data().get_array_memory_size()) + }); + let right_blocks = right_data + .iter() + .map(|batch| { + join_metrics.right_input_batches.add(1); + join_metrics.right_input_rows.add(batch.num_rows()); + join_metrics + .load_mem_used + .add(batch.get_array_memory_size()); + let columns = right_conditions + .iter() + .map(|expr| expr.expr.evaluate(batch)?.into_array(batch.num_rows())) + .collect::>>()?; + Ok(SortedBlock::new(columns, vec![])) + }) + .collect::>>()?; + 
right_blocks.iter().for_each(|block| { + join_metrics + .load_mem_used + .add(block.data().get_array_memory_size()) + }); + Ok(IEJoinData { + left_data, + right_data, + left_blocks, + right_blocks, + }) +} + +struct IEJoinStream { + /// input schema + schema: Arc, + /// join filter + filter: Option, + /// type of the join + /// Only support inner join currently + _join_type: JoinType, + /// operator of the inequality condition + operators: Arc<[Operator; 2]>, + /// sort options of the inequality condition + sort_options: Arc<[SortOptions; 2]>, + /// iejoin data + iejoin_data: OnceFut, + /// column indices + column_indices: Vec, + /// partition pair + pairs: Arc>, + /// finished + finished: bool, + /// join metrics + join_metrics: IEJoinMetrics, +} + +impl IEJoinStream { + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + if self.finished { + return Poll::Ready(None); + } + + let load_timer = self.join_metrics.load_time.timer(); + let iejoin_data = match ready!(self.iejoin_data.get_shared(cx)) { + Ok(data) => data, + Err(e) => return Poll::Ready(Some(Err(e))), + }; + load_timer.done(); + + // get the size of left and right blocks + let (n, m) = (iejoin_data.left_data.len(), iejoin_data.right_data.len()); + + // get pair of left and right blocks, add 1 to the pair + let pair = { + let mut pair = self.pairs.lock(); + let p = *pair; + *pair += 1; + p + }; + + // no more block pair to join + if pair >= (n * m) as u64 { + self.finished = true; + return Poll::Ready(None); + } + // get the index of left and right block + let (left_block_idx, right_block_idx) = + ((pair / m as u64) as usize, (pair % m as u64) as usize); + + // get the left and right block + let left_block = &(iejoin_data.left_blocks[left_block_idx]); + let right_block = &(iejoin_data.right_blocks[right_block_idx]); + + // no intersection between two blocks + if !IEJoinStream::check_intersection( + left_block, + right_block, + &self.sort_options[0], + ) { + 
self.join_metrics.skipped_pairs.add(1); + return Poll::Ready(Some(Ok(RecordBatch::new_empty(Arc::clone( + &self.schema, + ))))); + } + + let join_timer = self.join_metrics.join_time.timer(); + // compute the join result + // TODO: should return batches one by one if the result size larger than the batch size in config? + let batch = IEJoinStream::compute( + left_block, + right_block, + &self.sort_options, + &self.operators, + &iejoin_data.left_data[left_block_idx], + &iejoin_data.right_data[right_block_idx], + &self.filter, + &self.schema, + &self.column_indices, + )?; + join_timer.done(); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Poll::Ready(Some(Ok(batch))) + } + + #[allow(clippy::too_many_arguments)] + fn compute( + left_block: &SortedBlock, + right_block: &SortedBlock, + sort_options: &[SortOptions; 2], + operators: &[Operator; 2], + left_data: &RecordBatch, + right_data: &RecordBatch, + filter: &Option, + schema: &Arc, + column_indices: &[ColumnIndex], + ) -> Result { + let (l1_indexes, permutation) = IEJoinStream::compute_permutation( + left_block, + right_block, + sort_options, + operators, + )?; + + // compute the join indices statify the inequality conditions + let (left_indices, right_indices) = + IEJoinStream::build_join_indices(&l1_indexes, &permutation)?; + + // apply the filter to the join result + let (left_indices, right_indices) = if let Some(filter) = filter { + apply_join_filter_to_indices( + left_data, + right_data, + left_indices, + right_indices, + filter, + JoinSide::Left, + )? 
+ } else { + (left_indices, right_indices) + }; + + build_batch_from_indices( + schema, + left_data, + right_data, + &left_indices, + &right_indices, + column_indices, + JoinSide::Left, + ) + } + + /// check if there is an intersection between two sorted blocks + fn check_intersection( + left_block: &SortedBlock, + right_block: &SortedBlock, + sort_options: &SortOptions, + ) -> bool { + // filter all null result + if left_block.arrays()[0].null_count() == left_block.arrays()[0].len() + || right_block.arrays()[0].null_count() == right_block.arrays()[0].len() + { + return false; + } + let comparator = make_comparator( + &left_block.arrays()[0], + &right_block.arrays()[0], + *sort_options, + ) + .unwrap(); + // get the valid count of right block + let m = right_block.arrays()[0].len() - right_block.arrays()[0].null_count(); + // if the max valid element of right block is smaller than the min valid element of left block, there is no intersection + // for example, if left.a <= right.b, the left block is \[7, 8, 9\], the right block is \[2, 4, 6\], left\[0\] greater than right\[2\] so there is no intersection between left block and right block + // if left.a >= right.b, the left block is \[1, 0, 0\], the right block is \[6, 4, 2\], left\[0\] lesser than right\[2\] (because the sort options used in `make_comparator` is desc, so the compare result will be greater) so there is no intersection between left block and right block + if comparator(0, m - 1) == std::cmp::Ordering::Greater { + return false; + } + true + } + + /// this function computes l1_indexes array and the permutation array of condition 2 on condition 1 + /// for example, if condition 1 is left.a <= right.b, condition 2 is left.x <= right.y + /// for left table, we have: + /// | id | a | x | + /// |-------|----|---| + /// | left1 | 1 | 7 | + /// | left2 | 3 | 4 | + /// for right table, we have: + /// | id | b | y | + /// |----|---|---| + /// | right1 | 2 | 5 | + /// | right2 | 4 | 6 | + /// Sort by condition 
1, we get l1: + /// | value | 1 | 2 | 3 | 4 | + /// |-------|---|---|---|---| + /// | id | left1 | right1 | left2 | right2 | + /// The l1_indexes array is [-1, 1, -2, 2], the negative value means it is the index of left table, the positive value means it is the index of right table, the absolute value is the index of original recordbatch + /// Sort by condition 2, we get l2: + /// | value | 4 | 5 | 6 | 7 | + /// |-------|---|---|---|---| + /// | id | left2 | right1 | right2 | left1 | + /// Then the permutation array is [2, 1, 3, 0] + /// The first element of l2 is left2, which is the 3rd element(index 2) of l1. The second element of l2 is right1, which is the 2nd element(index 1) of l1. And so on. + fn compute_permutation( + left_block: &SortedBlock, + right_block: &SortedBlock, + sort_options: &[SortOptions; 2], + operators: &[Operator; 2], + ) -> Result<(Int64Array, UInt64Array)> { + // step1. sort the union block l1 + let n = left_block.arrays()[0].len() as i64; + let m = right_block.arrays()[0].len() as i64; + // concat the left block and right block + let cond1 = concat(&[ + &Arc::clone(&left_block.arrays()[0]), + &Arc::clone(&right_block.arrays()[0]), + ])?; + let cond2 = concat(&[ + &Arc::clone(&left_block.arrays()[1]), + &Arc::clone(&right_block.arrays()[1]), + ])?; + // store index of left table and right table + // -i in (-n..-1) means it is index i in left table, j in (1..m) means it is index j in right table + let indexes = concat(&[ + &Int64Array::from((1..=n).map(|i| -i).collect::>()), + &Int64Array::from((1..=m).collect::>()), + ])?; + let mut l1 = SortedBlock::new( + vec![cond1, indexes, cond2], + vec![ + // order by condition 1 + (0, sort_options[0]), + ( + 1, + SortOptions { + // if the operator is loose inequality, let the right index (> 0) in backward of left index (< 0) + // otherwise, let the right index (> 0) in forward of left index (< 0) + // for example, t1.time <= t2.time + // | value| 1 | 1 | 1 | 1 | 2 | + // 
|------|--------|--------|-------|-------|-------| + // | index| -2(l2) | -1(l2) | 1(r1) | 2(r2) | 3(r3) | + // if t1.time < t2.time + // |value| 1 | 1 | 1 | 1 | 2 | + // |-----|--------|--------|-------|-------|-------| + // |index| 2(r2) | 1(r1) | -1(l2) | -2(l1) | 3(r3) | + // according to this order request, if i < j then value\[i\](from left table) and value\[j\](from right table) match the condition(t1.time <= t2.time or t1.time < t2.time) + descending: !is_loose_inequality_operator(&operators[0]), + nulls_first: false, + }, + ), + ], + ); + l1.sort_by_columns()?; + // ignore the null values of the first condition + let valid = (l1.arrays()[0].len() - l1.arrays()[0].null_count()) as i64; + let l1 = l1.slice(0..valid as usize); + + // l1_indexes\[i\] = j means the ith element of l1 is the jth element of original recordbatch + let l1_indexes = Arc::clone(&l1.arrays()[1]) + .as_primitive::() + .clone(); + + // mark the order of l1, the index i means this element is the ith element of l1(sorted by condition 1) + let permutation = UInt64Array::from((0..valid as u64).collect::>()); + + let mut l2 = SortedBlock::new( + vec![ + // condition 2 + Arc::clone(&l1.arrays()[2]), + // index of original recordbatch + Arc::clone(&l1.arrays()[1]), + // index of l1 + Arc::new(permutation), + ], + vec![ + // order by condition 2 + (0, sort_options[1]), + ( + 1, + SortOptions { + // same as above + descending: !is_loose_inequality_operator(&operators[1]), + nulls_first: false, + }, + ), + ], + ); + l2.sort_by_columns()?; + let valid = (l2.arrays()[0].len() - l2.arrays()[0].null_count()) as usize; + let l2 = l2.slice(0..valid); + + Ok(( + l1_indexes, + Arc::clone(&l2.arrays()[2]) + .as_primitive::() + .clone(), + )) + } + + /// compute the join indices statify the inequality conditions + /// following the example in `compute_permutation`, the l1_indexes is \[1, -1, 2, -2\], the permutation is \[2, 1, 3, 0\] + /// range_map is empty at first + /// 1、 p\[0\] = 2, range_map is empty, 
l1_indexes\[2\] is greater than 0, it means 2nd element in l1 is from left table, insert(2) into range_map, range_map {(2, 3)} + /// 2、 p\[1\] = 1, no value less than p\[1\] in range_map, l1_indexes\[1\] is less than 0, it means 1st element in l1 is from right table, no need to insert(1) into range_map, range_map {(2, 3)} + /// 3、 p\[2\] = 3, found 2 less than p\[2\] in range_map, append all pairs (l1_indexes\[2\], l1_indexes\[3\]) to the indeices array, l1_indexes\[3\] is less than 0, it means 3rd element in l1 is from right table, no need to insert(3) into range_map, range_map {(1, 4)} + /// 4、 p\[3\] = 0, no value less than p\[3\] in range_map, insert(0) into range_map, range_map {(0, 1), (2, 3)} + /// The indices array is \[(2), (2)\] + fn build_join_indices( + l1_indexes: &Int64Array, + permutation: &UInt64Array, + ) -> Result<(UInt64Array, UInt64Array)> { + let mut left_builder = UInt64Builder::new(); + let mut right_builder = UInt64Builder::new(); + // left_order\[i\] = l means there are l elements from left table in l1\[0..=i\], also means element i is the l-th smallest element in left recordbatch. 
+ let mut left_order = UInt64Array::builder(l1_indexes.len()); + let mut l_pos = 0; + for ind in l1_indexes.values().iter() { + if *ind < 0 { + l_pos += 1; + } + left_order.append_value(l_pos); + } + let left_order = left_order.finish(); + // use btree map to maintain all p\[i\], for i in 0..j, map\[s\]=t means range \[s, t\) is valid + // our target is to find all pair(i, j) that i::new(); + for p in permutation.values().iter() { + // get the index of original recordbatch + let l1_index = unsafe { l1_indexes.value_unchecked(*p as usize) }; + if l1_index < 0 { + // index from left table + // insert p in to range_map + IEJoinStream::insert_range_map(&mut range_map, unsafe { + left_order.value_unchecked(*p as usize) + }); + continue; + } + // index from right table, remap to 0..m + let right_index = (l1_index - 1) as u64; + // r\[right_index] in right table and l\[0..=rp\] in left table statisfy comparsion requirement of condition1 + let rp = unsafe { left_order.value_unchecked(*p as usize) }; + for range in range_map.iter() { + let (end, start) = range; + if *start > rp { + break; + } + let (start, end) = (*start, std::cmp::min(*end, rp + 1)); + for left_index in start..end { + left_builder.append_value(left_index - 1); + // append right index + right_builder.append_value(right_index); + } + } + } + Ok((left_builder.finish(), right_builder.finish())) + } + + #[inline] + fn insert_range_map(range_map: &mut BTreeMap, p: u64) { + let mut range = (p, p + 1); + let mut need_insert = true; + let mut need_remove = false; + // merge it with prev consecutive range + // for example, if range_map is [(1, 2), (3, 4), (5, 6)], then insert(2) will make it [(1, 3), (3, 4), (5, 6)] + let mut iter = range_map.range_mut(p..); + let mut interval = iter.next(); + let mut move_next = false; + if let Some(ref interval) = interval { + if interval.0 == &p { + // merge prev range, update current range.start + range = (*interval.1, p + 1); + // remove prev range + need_remove = true; + // 
move to next range + move_next = true; + } + } + if move_next { + interval = iter.next(); + } + // if previous range is consecutive, merge them + // follow the example, [(1, 3), (3, 4), (5, 6)] will be merged into [(1, 4), (5, 6)] + if let Some(ref mut interval) = interval { + if *interval.1 == range.1 { + // merge into next range, update next range.start + *interval.1 = range.0; + // already merge into next range, no need to insert current range + need_insert = false; + } + } + if need_remove { + range_map.remove(&p); + } + // if this range is not consecutive with previous one, insert it + if need_insert { + range_map.insert(range.1, range.0); + } + } +} + +impl Stream for IEJoinStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +impl RecordBatchStream for IEJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{common, memory::MemoryExec, test::build_table_i32_with_nulls}; + + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{assert_batches_sorted_eq, ScalarValue}; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; + use datafusion_physical_expr::PhysicalExpr; + + use itertools::Itertools; + + #[test] + fn test_insert_range_map() { + let mut range_map = BTreeMap::new(); + // shuffle 0..8 and insert it into range_map + let values = (0..8).collect::>(); + // test for all permutation of 0..8 + for permutaion in values.iter().permutations(values.len()) { + range_map.clear(); + for v in permutaion.iter() { + IEJoinStream::insert_range_map(&mut range_map, **v as u64); + } + assert_eq!(range_map.len(), 1); + } + } + + fn build_table( + a: (&str, &Vec>), + b: (&str, &Vec>), + c: (&str, &Vec>), + ) -> Arc { + let batch = build_table_i32_with_nulls(a, b, c); + let schema = batch.schema(); + 
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + /// Returns the column names on the schema + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } + + async fn multi_partitioned_join_collect( + left: Arc, + right: Arc, + join_type: &JoinType, + ie_join_filter: Vec, + join_filter: Option, + context: Arc, + ) -> Result<(Vec, Vec)> { + let partition_count = 4; + + let ie_join = IEJoinExec::try_new( + left, + right, + ie_join_filter, + join_filter, + join_type, + partition_count, + )?; + let columns = columns(&ie_join.schema()); + let mut batches = vec![]; + for i in 0..partition_count { + let stream = ie_join.execute(i, Arc::clone(&context))?; + let more_batches = common::collect(stream).await?; + batches.extend( + more_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(), + ); + } + Ok((columns, batches)) + } + + #[tokio::test] + async fn test_ie_join() -> Result<()> { + let column_indices = vec![ + ColumnIndex { + index: 0, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 0, + side: JoinSide::Right, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("z", DataType::Int32, true), + ]); + // test left.x < right.x and left.y >= right.y + let filter1 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Lt, + Arc::new(Column::new("x", 2)), + )) as Arc; + let filter2 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("y", 1)), + Operator::GtEq, + Arc::new(Column::new("y", 3)), + )) as Arc; + let ie_filter = vec![ + JoinFilter::new(filter1, column_indices.clone(), intermediate_schema.clone()), + 
JoinFilter::new(filter2, column_indices.clone(), intermediate_schema.clone()), + ]; + let join_filter = Some(JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("z", 4)), + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + )), + column_indices.clone(), + intermediate_schema.clone(), + )); + // + let left = build_table( + ("x", &vec![Some(5), Some(9), None]), + ("y", &vec![Some(6), Some(10), Some(10)]), + ("z", &vec![Some(3), Some(5), Some(10)]), + ); + let right = build_table( + ( + "x", + &vec![ + Some(10), + Some(6), + Some(5), + Some(6), + Some(6), + Some(6), + Some(6), + Some(6), + ], + ), + ( + "y", + &vec![ + Some(9), + Some(6), + Some(5), + Some(5), + Some(6), + Some(7), + Some(6), + None, + ], + ), + ( + "z", + &vec![ + Some(7), + Some(3), + Some(5), + Some(5), + Some(7), + Some(7), + Some(8), + Some(9), + ], + ), + ); + let task_ctx = Arc::new(TaskContext::default()); + let (columns, batches) = multi_partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + &JoinType::Inner, + ie_filter, + join_filter, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["x", "y", "z", "x", "y", "z"]); + let expected = [ + "+---+----+---+----+---+---+", + "| x | y | z | x | y | z |", + "+---+----+---+----+---+---+", + "| 5 | 6 | 3 | 6 | 5 | 5 |", + "| 5 | 6 | 3 | 6 | 6 | 3 |", + "| 5 | 6 | 3 | 6 | 6 | 7 |", + "| 9 | 10 | 5 | 10 | 9 | 7 |", + "+---+----+---+----+---+---+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 6ddf19c51193..91866f2d3cf6 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -19,12 +19,14 @@ pub use cross_join::CrossJoinExec; pub use hash_join::HashJoinExec; +pub use ie_join::IEJoinExec; pub use nested_loop_join::NestedLoopJoinExec; // Note: SortMergeJoin is not used in plans yet pub use 
sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; +mod ie_join; mod nested_loop_join; mod sort_merge_join; mod stream_join_utils; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index ac718a95e9f4..36356626a71c 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -59,7 +59,7 @@ use arrow::array::{ UInt64Array, }; use arrow::compute::concat_batches; -use arrow::datatypes::{Schema, SchemaRef}; +use arrow::datatypes::{Schema, SchemaRef, UInt32Type, UInt64Type}; use arrow::record_batch::RecordBatch; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; @@ -731,13 +731,14 @@ pub(crate) fn build_side_determined_results( && need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) { // Calculate the indices for build and probe sides based on join type and build side: - let (build_indices, probe_indices) = calculate_indices_by_join_type( - build_hash_joiner.build_side, - prune_length, - &build_hash_joiner.visited_rows, - build_hash_joiner.deleted_offset, - join_type, - )?; + let (build_indices, probe_indices) = + calculate_indices_by_join_type::( + build_hash_joiner.build_side, + prune_length, + &build_hash_joiner.visited_rows, + build_hash_joiner.deleted_offset, + join_type, + )?; // Create an empty probe record batch: let empty_probe_batch = RecordBatch::new_empty(probe_schema); diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 89f3feaf07be..8a0305e32228 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -17,7 +17,7 @@ //! 
Join related functionality used both on logical and physical plans -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug}; use std::future::Future; use std::ops::{IndexMut, Range}; @@ -46,13 +46,16 @@ use datafusion_common::{ plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::Operator; use datafusion_physical_expr::equivalence::add_offset_to_expr; +use datafusion_physical_expr::expressions::BinaryExpr; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::utils::{collect_columns, merge_vectors}; +use datafusion_physical_expr::utils::{collect_columns, map_columns, merge_vectors}; use datafusion_physical_expr::{ LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, }; +use arrow::compute::kernels::sort::SortOptions; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use hashbrown::raw::RawTable; @@ -380,6 +383,130 @@ pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>; /// Reference for JoinOn. pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)]; +pub fn is_ineuqality_operator(op: &Operator) -> bool { + matches!( + op, + Operator::NotEq | Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq + ) +} + +pub fn is_loose_inequality_operator(op: &Operator) -> bool { + matches!(op, Operator::LtEq | Operator::GtEq) +} + +/// Swaps the left and right expressions of a binary expression, like `a < b` to `b > a`. +/// If this is not a binary expression or the operator can't be swapped, the expression is returned as is. 
+pub fn swap_binary_expr(expr: &PhysicalExprRef) -> PhysicalExprRef { + match expr.as_any().downcast_ref::() { + Some(binary) => { + if let Some(swapped_op) = binary.op().swap() { + Arc::new(BinaryExpr::new( + Arc::clone(binary.right()), + swapped_op, + Arc::clone(binary.left()), + )) + } else { + Arc::clone(expr) + } + } + None => Arc::clone(expr), + } +} + +/// Checks whether the inequality condition is valid. +/// The inequality condition is valid if the expressions are not null and the expressions are not equal, and left expression is from left schema and right expression is from right schema. +pub fn check_inequality_condition(inequality_condition: &JoinFilter) -> Result<()> { + if let Some(binary) = inequality_condition + .expression() + .as_any() + .downcast_ref::() + { + if !(is_ineuqality_operator(binary.op()) && *binary.op() != Operator::NotEq) { + return plan_err!( + "Inequality conditions must be an inequality binary expression, but got {}", + binary.op() + ); + } + let column_indices = &inequality_condition.column_indices; + // check if left expression is from left table + let left_expr_columns = collect_columns(binary.left()); + let left_expr_in_left = left_expr_columns + .iter() + .all(|c| column_indices[c.index()].side == JoinSide::Left); + // check if right expression is from right table + let right_expr_columns = collect_columns(binary.right()); + let right_expr_in_right = right_expr_columns + .iter() + .all(|c| column_indices[c.index()].side == JoinSide::Right); + if left_expr_columns.is_empty() || right_expr_columns.is_empty() { + return plan_err!( + "Inequality condition shouldn't be constant expression, but got {}", + inequality_condition.expression() + ); + } + if !left_expr_in_left || !right_expr_in_right { + return plan_err!("Left/right side expression of inequality condition should be from left/right side of join, but got {} and {}", + binary.left(), + binary.right() + ); + } + } else { + return plan_err!( + "Inequality conditions must be 
an inequality binary expression, but got {}", + inequality_condition.expression() + ); + } + Ok(()) +} + +/// convert inequality conditions to sort expressions of each side and the operator +/// for example, if the inequality condition is `a < b`, then the sort expressions for left and right side are `a asc` and `b asc` respectively +pub fn inequality_conditions_to_sort_exprs( + inequality_conditions: &[JoinFilter], +) -> Result> { + inequality_conditions + .iter() + .map(|filter| { + let expr = filter.expression(); + let binary = expr.as_any().downcast_ref::().unwrap(); + let sort_option = match binary.op() { + Operator::Lt | Operator::LtEq => SortOptions { + descending: false, + nulls_first: false, + }, + Operator::Gt | Operator::GtEq => SortOptions { + descending: true, + nulls_first: false, + }, + _ => unreachable!(), + }; + // remap the column in join schema to origin table, because we need to use the original column index to sort left and right table independently + let (left_map, right_map): (Vec<_>, Vec<_>) = filter + .column_indices() + .iter() + .enumerate() + .partition(|(_, index)| index.side == JoinSide::Left); + let left_map = HashMap::from_iter( + left_map.iter().map(|(idx, index)| (*idx, index.index)), + ); + let right_map = HashMap::from_iter( + right_map.iter().map(|(idx, index)| (*idx, index.index)), + ); + Ok(( + PhysicalSortExpr::new( + map_columns(Arc::clone(binary.left()), &left_map)?, + sort_option, + ), + PhysicalSortExpr::new( + map_columns(Arc::clone(binary.right()), &right_map)?, + sort_option, + ), + *binary.op(), + )) + }) + .collect() +} + /// Checks whether the schemas "left" and "right" and columns "on" represent a valid join. 
/// They are valid whenever their columns' intersection equals the set `on` pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> { @@ -1183,14 +1310,17 @@ pub(crate) fn get_final_indices_from_bit_map( (left_indices, right_indices) } -pub(crate) fn apply_join_filter_to_indices( +pub(crate) fn apply_join_filter_to_indices< + L: ArrowPrimitiveType, + R: ArrowPrimitiveType, +>( build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, - build_indices: UInt64Array, - probe_indices: UInt32Array, + build_indices: PrimitiveArray, + probe_indices: PrimitiveArray, filter: &JoinFilter, build_side: JoinSide, -) -> Result<(UInt64Array, UInt32Array)> { +) -> Result<(PrimitiveArray, PrimitiveArray)> { if build_indices.is_empty() && probe_indices.is_empty() { return Ok((build_indices, probe_indices)); }; @@ -1220,12 +1350,12 @@ pub(crate) fn apply_join_filter_to_indices( /// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. /// The resulting batch has [Schema] `schema`. 
-pub(crate) fn build_batch_from_indices( +pub(crate) fn build_batch_from_indices( schema: &Schema, build_input_buffer: &RecordBatch, probe_batch: &RecordBatch, - build_indices: &UInt64Array, - probe_indices: &UInt32Array, + build_indices: &PrimitiveArray, + probe_indices: &PrimitiveArray, column_indices: &[ColumnIndex], build_side: JoinSide, ) -> Result { @@ -1646,7 +1776,10 @@ mod tests { use arrow_schema::SortOptions; use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; - use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use datafusion_common::{ + arrow_datafusion_err, arrow_err, assert_contains, ScalarValue, + }; + use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; fn check( left: &[Column], @@ -2554,4 +2687,105 @@ mod tests { Ok(()) } + + #[test] + fn test_inequality_condition() -> Result<()> { + let column_indices = vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + ]); + // test left.x!=8, it will fail because of the not eq operator + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + )) as Arc; + let join_filter = + JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone()); + let actual = format!("{:?}", check_inequality_condition(&join_filter)); + assert_contains!( + actual, + "Inequality conditions must be an inequality binary expression, but got !=" + ); + // test left.x>8, it will fail because of the constant expression + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + )) as Arc; + let 
join_filter = + JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone()); + let actual = format!("{:?}", check_inequality_condition(&join_filter)); + assert_contains!( + actual, + "Inequality condition shouldn't be constant expression, but got x@0 > 8" + ); + // test rigth.x * left.y >= left.x, it will fail because of the left side expression contains column from right table + let filter = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 2)), + Operator::Multiply, + Arc::new(Column::new("y", 1)), + )), + Operator::GtEq, + Arc::new(Column::new("x", 0)), + )) as Arc; + let join_filter = + JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone()); + let actual = format!("{:?}", check_inequality_condition(&join_filter)); + assert_contains!( + actual, + "Left/right side expression of inequality condition should be from left/right side of join, but got x@2 * y@1 and x@0" + ); + // test left.x + left.y >= left.x, this will be ok + let filter = Arc::new(BinaryExpr::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Plus, + Arc::new(Column::new("y", 1)), + )), + Operator::GtEq, + Arc::new(Column::new("x", 2)), + )) as Arc; + let join_filter = + JoinFilter::new(filter, column_indices.clone(), intermediate_schema.clone()); + let actual = format!("{:?}", check_inequality_condition(&join_filter)); + assert_eq!(actual, "Ok(())"); + let res = inequality_conditions_to_sort_exprs(&[join_filter])?; + let (left_expr, right_expr, operator) = res.first().unwrap(); + assert_eq!(left_expr.to_string(), "x@1 + y@2 DESC NULLS LAST"); + assert_eq!(right_expr.to_string(), "x@1 DESC NULLS LAST"); + assert_eq!(*operator, Operator::GtEq); + // test left.x < left.x, this will be ok + let filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Lt, + Arc::new(Column::new("x", 2)), + )) as Arc; + let join_filter = + JoinFilter::new(filter, column_indices.clone(), 
intermediate_schema.clone()); + let actual = format!("{:?}", check_inequality_condition(&join_filter)); + assert_eq!(actual, "Ok(())"); + let res = inequality_conditions_to_sort_exprs(&[join_filter])?; + let (left_expr, right_expr, operator) = res.first().unwrap(); + assert_eq!(left_expr.to_string(), "x@1 ASC NULLS LAST"); + assert_eq!(right_expr.to_string(), "x@1 ASC NULLS LAST"); + assert_eq!(*operator, Operator::Lt); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index f5b4a096018f..079a75ef35d7 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -88,6 +88,29 @@ pub fn build_table_i32( .unwrap() } +/// returns record batch with 3 columns of i32 in memory +pub fn build_table_i32_with_nulls( + a: (&str, &Vec>), + b: (&str, &Vec>), + c: (&str, &Vec>), +) -> RecordBatch { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + Field::new(c.0, DataType::Int32, true), + ]); + + RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap() +} + /// returns memory table scan wrapped around record batch with 3 columns of i32 pub fn build_table_scan_i32( a: (&str, &Vec), diff --git a/datafusion/sqllogictest/test_files/iejoin.slt b/datafusion/sqllogictest/test_files/iejoin.slt new file mode 100644 index 000000000000..652f8624e9c5 --- /dev/null +++ b/datafusion/sqllogictest/test_files/iejoin.slt @@ -0,0 +1,206 @@ +# create tables +statement ok +CREATE TABLE east AS SELECT * FROM (VALUES + ('r1', 100, 140, 12, 2), + ('r2', 101, 100, 12, 8), + ('r3', 103, 90, 5, 4) +) east(rid, id, dur, rev, cores) + +statement ok +CREATE TABLE west AS SELECT * FROM (VALUES + ('s1', 404, 100, 6, 4), + ('s2', 498, 140, 11, 2), + ('s3', 676, 80, 10, 1), + ('s4', 742, 90, 5, 4) +) west(rid, t_id, 
time, cost, cores) + +# Qs +query TT +SELECT s1.rid, s2.rid +FROM west s1, west s2 +WHERE s1.time > s2.time +ORDER BY 1, 2 +---- +s1 s3 +s1 s4 +s2 s1 +s2 s3 +s2 s4 +s4 s3 + +# Qp +query TT +SELECT s1.rid, s2.rid +FROM west s1, west s2 +WHERE s1.time > s2.time AND s1.cost < s2.cost +ORDER BY 1, 2 +---- +s1 s3 +s4 s3 + +query TT +EXPLAIN +SELECT s1.rid, s2.rid +FROM west s1, west s2 +WHERE s1.time > s2.time AND s1.cost < s2.cost +ORDER BY 1, 2 +---- +logical_plan +01)Sort: s1.rid ASC NULLS LAST, s2.rid ASC NULLS LAST +02)--Projection: s1.rid, s2.rid +03)----Inner Join: Filter: s2.time < s1.time AND s2.cost > s1.cost +04)------SubqueryAlias: s1 +05)--------TableScan: west projection=[rid, time, cost] +06)------SubqueryAlias: s2 +07)--------TableScan: west projection=[rid, time, cost] +physical_plan +01)SortPreservingMergeExec: [rid@0 ASC NULLS LAST,rid@1 ASC NULLS LAST] +02)--SortExec: expr=[rid@0 ASC NULLS LAST,rid@1 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[rid@0 as rid, rid@3 as rid] +04)------IEJoinExec: mode=Bounded, join_type=Inner, inequality_conditions=[(time@0 > time@2), (cost@1 < cost@3)] +05)--------SortExec: expr=[time@1 DESC NULLS LAST], preserve_partitioning=[false] +06)----------MemoryExec: partitions=1, partition_sizes=[1] +07)--------SortExec: expr=[time@1 DESC NULLS LAST], preserve_partitioning=[false] +08)----------MemoryExec: partitions=1, partition_sizes=[1] + +# Qt +query TT +SELECT east.rid, west.rid +FROM east, west +WHERE east.dur < west.time AND east.rev > west.cost +ORDER BY 1, 2 +---- +r2 s2 + +# Test string comparisons +query TT +WITH weststr AS ( + SELECT rid, time::VARCHAR AS time, cost::VARCHAR as cost + FROM west +) +SELECT s1.rid, s2.rid +FROM weststr s1, weststr s2 +WHERE s1.time > s2.time AND s1.cost < s2.cost +ORDER BY 1, 2 +---- +s2 s1 +s3 s1 +s3 s2 +s4 s1 + + +statement ok +create table tt (x int, y int, z int); + +statement ok +insert into tt select nullif(r % 3, 0), nullif (r % 5, 0), r from 
unnest(generate_series(10)) AS tbl(r); + +query IIIIII +select * +from tt t1 join tt t2 +on t1.x < t2.x and t1.y < t2.y +order by t1.x nulls first, t1.y nulls first, t1.z, t2.x, t2.y, t2.z; +---- +1 1 1 2 2 2 +1 1 1 2 3 8 +1 2 7 2 3 8 + +statement ok +create table tt2 (x int); + +# left iejoin not implement yet. +# statement ok +# insert into tt2 select * from unnest(generate_series(9)); +# +# query II +# select t1.x, t1.y +# from ( +# select (case when x < 100 then null else 99 end) x, (case when x < 100 then 99 else 99 end) y +# from tt2 +# ) t1 left join tt2 t2 +# on t1.x < t2.x and t1.y < t2.x +# order by t1.x nulls first, t1.y nulls first; +# ---- +# NULL 99 +# NULL 99 +# NULL 99 +# NULL 99 +# NULL 99 +# NULL 99 +# NULL 99 +# NULL 99 +# NULL 99 +# NULL 99 + +# Test all nulls table +statement ok +CREATE TABLE test(x INT); + +statement ok +INSERT INTO test(x) VALUES (NULL), (NULL), (NULL); + +statement ok +CREATE TABLE all_null AS SELECT * FROM test; + +query II +SELECT * +FROM all_null AS a, all_null AS b +WHERE (a.x BETWEEN b.x AND b.x); +---- + +query II +SELECT * +FROM test AS a, all_null AS b +WHERE (a.x BETWEEN b.x AND b.x); +---- + +query II +SELECT * +FROM all_null AS a, test AS b +WHERE (a.x BETWEEN b.x AND b.x); +---- + +statement ok +DROP TABLE IF EXISTS lhs; + +statement ok +DROP TABLE IF EXISTS rhs; + +statement ok +CREATE TABLE lhs ( + id INT, + begin INT, + end INT +); + +statement ok +INSERT INTO lhs (id, begin, end) +SELECT + i AS id, + i AS begin, + i + 1 AS end +FROM unnest(generate_series(1, 10001)) tbl(i); + +statement ok +CREATE TABLE rhs ( + id INT, + begin INT, + end INT +); + +statement ok +INSERT INTO rhs (id, begin, end) +SELECT + i - 10000 AS id, + i AS begin, + i + 1 AS end +FROM unnest(generate_series(10001, 20001)) tbl(i); + +query II +SELECT lhs.begin, rhs.begin +FROM lhs, rhs +WHERE lhs.begin < rhs.end AND rhs.begin < lhs.end; +---- +10001 10001 + +# TODO: use metric to check no overlap blocks pair be pruned