diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index 516749e82a53..5fdf02079496 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -43,6 +43,17 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::prelude::{SessionConfig, SessionContext}; use test_utils::stagger_batch_with_seed; +// Determines what Fuzz tests needs to run +// Ideally all tests should match, but in reality some tests +// passes only partial cases +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum JoinTestType { + // compare NestedLoopJoin and HashJoin + NljHj, + // compare HashJoin and SortMergeJoin, no need to compare SortMergeJoin and NestedLoopJoin + // because if existing variants both passed that means SortMergeJoin and NestedLoopJoin also passes + HjSmj, +} #[tokio::test] async fn test_inner_join_1k() { JoinFuzzTestCase::new( @@ -51,7 +62,7 @@ async fn test_inner_join_1k() { JoinType::Inner, None, ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -71,6 +82,30 @@ fn less_than_100_join_filter(schema1: Arc, _schema2: Arc) -> Joi JoinFilter::new(less_than_100, column_indices, intermediate_schema) } +fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { + let less_than_100 = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 1)), + Operator::Lt, + Arc::new(Column::new("x", 0)), + )) as _; + let column_indices = vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + schema1.field_with_name("x").unwrap().to_owned(), + schema2.field_with_name("x").unwrap().to_owned(), + ]); + + JoinFilter::new(less_than_100, column_indices, intermediate_schema) +} + #[tokio::test] async fn test_inner_join_1k_filtered() { JoinFuzzTestCase::new( @@ -79,7 +114,7 @@ async fn test_inner_join_1k_filtered() { JoinType::Inner, Some(Box::new(less_than_100_join_filter)), ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -91,7 +126,7 @@ async fn test_inner_join_1k_smjoin() { JoinType::Inner, None, ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -103,7 +138,7 @@ async fn test_left_join_1k() { JoinType::Left, None, ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -115,7 +150,7 @@ async fn test_left_join_1k_filtered() { JoinType::Left, Some(Box::new(less_than_100_join_filter)), ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -127,7 +162,7 @@ async fn test_right_join_1k() { JoinType::Right, None, ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } // Add support for Right filtered joins @@ -140,7 +175,7 @@ async fn test_right_join_1k_filtered() { JoinType::Right, Some(Box::new(less_than_100_join_filter)), ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -152,7 +187,7 @@ async fn test_full_join_1k() { JoinType::Full, None, ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -164,7 +199,7 @@ async fn test_full_join_1k_filtered() { JoinType::Full, Some(Box::new(less_than_100_join_filter)), ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -176,12 +211,13 @@ async fn test_semi_join_1k() { JoinType::LeftSemi, None, ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } // The test is flaky // https://github.com/apache/datafusion/issues/10886 +// SMJ produces 1 more row in the output #[ignore] #[tokio::test] async fn test_semi_join_1k_filtered() { @@ -189,9 +225,9 @@ async fn test_semi_join_1k_filtered() { make_staggered_batches(1000), make_staggered_batches(1000), JoinType::LeftSemi, - Some(Box::new(less_than_100_join_filter)), + Some(Box::new(col_lt_col_filter)), ) - .run_test() + .run_test(&[JoinTestType::HjSmj], false) .await } @@ -203,7 +239,7 @@ async fn test_anti_join_1k() { JoinType::LeftAnti, None, ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -217,7 +253,7 @@ async fn test_anti_join_1k_filtered() { JoinType::LeftAnti, Some(Box::new(less_than_100_join_filter)), ) - .run_test() + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } @@ -331,7 +367,7 @@ impl JoinFuzzTestCase { self.on_columns().clone(), self.join_filter(), self.join_type, - vec![SortOptions::default(), SortOptions::default()], + vec![SortOptions::default(); self.on_columns().len()], false, ) .unwrap(), @@ -381,9 +417,11 @@ impl JoinFuzzTestCase { ) } - /// Perform sort-merge join and hash join on same input - /// and verify two outputs are equal - async fn run_test(&self) { + /// Perform joins tests on same inputs and verify outputs are equal + /// `join_tests` - identifies what join types to test + /// if `debug` flag is set the test will save randomly generated inputs and outputs to user folders, + /// so it is easy to debug a test on top of the failed data + async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) { for batch_size in self.batch_sizes { let session_config = SessionConfig::new().with_batch_size(*batch_size); let ctx = SessionContext::new_with_config(session_config); @@ -394,17 +432,30 @@ impl JoinFuzzTestCase { let hj = self.hash_join(); let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + let nlj = self.nested_loop_join(); + let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); + // Get actual row counts(without formatting overhead) for HJ and SMJ let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); + let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); - assert_eq!( - hj_rows, smj_rows, - "SortMergeJoinExec and HashJoinExec produced different row counts" - ); + if debug { + println!("The debug is ON. Input data will be saved"); + let out_dir_name = &format!("fuzz_test_debug_batch_size_{batch_size}"); + Self::save_as_parquet(&self.input1, out_dir_name, "input1"); + Self::save_as_parquet(&self.input2, out_dir_name, "input2"); - let nlj = self.nested_loop_join(); - let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); + if join_tests.contains(&JoinTestType::NljHj) { + Self::save_as_parquet(&nlj_collected, out_dir_name, "nlj"); + Self::save_as_parquet(&hj_collected, out_dir_name, "hj"); + } + + if join_tests.contains(&JoinTestType::HjSmj) { + Self::save_as_parquet(&hj_collected, out_dir_name, "hj"); + Self::save_as_parquet(&smj_collected, out_dir_name, "smj"); + } + } // compare let smj_formatted = @@ -425,35 +476,106 @@ impl JoinFuzzTestCase { nlj_formatted.trim().lines().collect(); nlj_formatted_sorted.sort_unstable(); - // row level compare if any of joins returns the result - // the reason is different formatting when there is no rows - if smj_rows > 0 || hj_rows > 0 { - for (i, (smj_line, hj_line)) in smj_formatted_sorted + if join_tests.contains(&JoinTestType::NljHj) { + let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); + assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); + + let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size); + // row level compare if any of joins returns the result + // the reason is different formatting when there is no rows + for (i, (nlj_line, hj_line)) in nlj_formatted_sorted .iter() .zip(&hj_formatted_sorted) .enumerate() { assert_eq!( - (i, smj_line), + (i, nlj_line), (i, hj_line), - "SortMergeJoinExec and HashJoinExec produced different results" + "{}", + err_msg_contents.as_str() ); } } - for (i, (nlj_line, hj_line)) in nlj_formatted_sorted - .iter() - .zip(&hj_formatted_sorted) - .enumerate() - { - assert_eq!( - (i, nlj_line), - (i, hj_line), - "NestedLoopJoinExec and HashJoinExec produced different results" - ); + if join_tests.contains(&JoinTestType::HjSmj) { + let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); + assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); + + let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size); + // row level compare if any of joins returns the result + // the reason is different formatting when there is no rows + if smj_rows > 0 || hj_rows > 0 { + for (i, (smj_line, hj_line)) in smj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!( + (i, smj_line), + (i, hj_line), + "{}", + err_msg_contents.as_str() + ); + } + } } } } + + /// This method useful for debugging fuzz tests + /// It helps to save randomly generated input test data for both join inputs into the user folder + /// as a parquet files preserving partitioning. + /// Once the data is saved it is possible to run a custom test on top of the saved data and debug + /// + /// let ctx: SessionContext = SessionContext::new(); + /// let df = ctx + /// .read_parquet( + /// "/tmp/input1/*.parquet", + /// ParquetReadOptions::default(), + /// ) + /// .await + /// .unwrap(); + /// let left = df.collect().await.unwrap(); + /// + /// let df = ctx + /// .read_parquet( + /// "/tmp/input2/*.parquet", + /// ParquetReadOptions::default(), + /// ) + /// .await + /// .unwrap(); + /// + /// let right = df.collect().await.unwrap(); + /// JoinFuzzTestCase::new( + /// left, + /// right, + /// JoinType::LeftSemi, + /// Some(Box::new(less_than_100_join_filter)), + /// ) + /// .run_test() + /// .await + /// } + fn save_as_parquet(input: &[RecordBatch], output_dir: &str, out_name: &str) { + let out_path = &format!("{output_dir}/{out_name}"); + std::fs::remove_dir_all(out_path).unwrap_or(()); + std::fs::create_dir_all(out_path).unwrap(); + + input.iter().enumerate().for_each(|(idx, batch)| { + let mut file = + std::fs::File::create(format!("{out_path}/file_{}.parquet", idx)) + .unwrap(); + let mut writer = parquet::arrow::ArrowWriter::try_new( + &mut file, + input.first().unwrap().schema(), + None, + ) + .expect("creating writer"); + writer.write(batch).unwrap(); + writer.close().unwrap(); + }); + + println!("The data {out_name} saved as parquet into {out_path}"); + } } /// Return randomly sized record batches with: