diff --git a/datafusion/src/physical_plan/coalesce_partitions.rs b/datafusion/src/physical_plan/coalesce_partitions.rs index a1068386f0d2..1fd18d2c4f37 100644 --- a/datafusion/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/src/physical_plan/coalesce_partitions.rs @@ -30,6 +30,7 @@ use async_trait::async_trait; use arrow::record_batch::RecordBatch; use arrow::{datatypes::SchemaRef, error::Result as ArrowResult}; +use super::common::AbortOnDropMany; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{RecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; @@ -129,14 +130,20 @@ impl ExecutionPlan for CoalescePartitionsExec { // spawn independent tasks whose resulting streams (of batches) // are sent to the channel for consumption. + let mut join_handles = Vec::with_capacity(input_partitions); for part_i in 0..input_partitions { - spawn_execution(self.input.clone(), sender.clone(), part_i); + join_handles.push(spawn_execution( + self.input.clone(), + sender.clone(), + part_i, + )); } Ok(Box::pin(MergeStream { input: receiver, schema: self.schema(), baseline_metrics, + drop_helper: AbortOnDropMany(join_handles), })) } } @@ -168,7 +175,8 @@ pin_project! { schema: SchemaRef, #[pin] input: mpsc::Receiver>, - baseline_metrics: BaselineMetrics + baseline_metrics: BaselineMetrics, + drop_helper: AbortOnDropMany<()>, } } @@ -194,11 +202,15 @@ impl RecordBatchStream for MergeStream { #[cfg(test)] mod tests { + use arrow::datatypes::{DataType, Field, Schema}; + use futures::FutureExt; + use super::*; use crate::datasource::object_store::local::LocalFileSystem; - use crate::physical_plan::common; use crate::physical_plan::file_format::CsvExec; - use crate::test; + use crate::physical_plan::{collect, common}; + use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use crate::test::{self, assert_is_pending}; #[tokio::test] async fn merge() -> Result<()> { @@ -238,4 +250,24 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_drop_cancel() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); + let refs = blocking_exec.refs(); + let coaelesce_partitions_exec = + Arc::new(CoalescePartitionsExec::new(blocking_exec)); + + let fut = collect(coaelesce_partitions_exec); + let mut fut = fut.boxed(); + + assert_is_pending(&mut fut); + drop(fut); + assert_strong_count_converges_to_zero(refs).await; + + Ok(()) + } }