diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index adeeb0bf8eab..33c68273077a 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -50,6 +50,7 @@ use pin_project_lite::pin_project; use async_trait::async_trait; +use super::common::AbortOnDropSingle; use super::metrics::{ self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, }; @@ -339,6 +340,7 @@ pin_project! { #[pin] output: futures::channel::oneshot::Receiver>, finished: bool, + drop_helper: AbortOnDropSingle<()>, } } @@ -561,7 +563,8 @@ impl GroupedHashAggregateStream { let schema_clone = schema.clone(); let elapsed_compute = baseline_metrics.elapsed_compute().clone(); - tokio::spawn(async move { + + let join_handle = tokio::spawn(async move { let result = compute_grouped_hash_aggregate( mode, schema_clone, @@ -572,13 +575,16 @@ impl GroupedHashAggregateStream { ) .await .record_output(&baseline_metrics); - tx.send(result) + + // failing here is OK, the receiver is gone and does not care about the result + tx.send(result).ok(); }); Self { schema, output: rx, finished: false, + drop_helper: AbortOnDropSingle::new(join_handle), } } } @@ -738,6 +744,7 @@ pin_project! { #[pin] output: futures::channel::oneshot::Receiver>, finished: bool, + drop_helper: AbortOnDropSingle<()>, } } @@ -789,7 +796,7 @@ impl HashAggregateStream { let schema_clone = schema.clone(); let elapsed_compute = baseline_metrics.elapsed_compute().clone(); - tokio::spawn(async move { + let join_handle = tokio::spawn(async move { let result = compute_hash_aggregate( mode, schema_clone, @@ -800,13 +807,15 @@ impl HashAggregateStream { .await .record_output(&baseline_metrics); - tx.send(result) + // failing here is OK, the receiver is gone and does not care about the result + tx.send(result).ok(); }); Self { schema, output: rx, finished: false, + drop_helper: AbortOnDropSingle::new(join_handle), } } } @@ -1005,9 +1014,12 @@ mod tests { use arrow::array::{Float64Array, UInt32Array}; use arrow::datatypes::DataType; + use futures::FutureExt; use super::*; use crate::physical_plan::expressions::{col, Avg}; + use crate::test::assert_is_pending; + use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::{assert_batches_sorted_eq, physical_plan::common}; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -1230,4 +1242,73 @@ mod tests { check_aggregates(input).await } + + #[tokio::test] + async fn test_drop_cancel_without_groups() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + + let groups = vec![]; + + let aggregates: Vec> = vec![Arc::new(Avg::new( + col("a", &schema)?, + "AVG(a)".to_string(), + DataType::Float64, + ))]; + + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let refs = blocking_exec.refs(); + let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + blocking_exec, + schema, + )?); + + let fut = crate::physical_plan::collect(hash_aggregate_exec); + let mut fut = fut.boxed(); + + assert_is_pending(&mut fut); + drop(fut); + assert_strong_count_converges_to_zero(refs).await; + + Ok(()) + } + + #[tokio::test] + async fn test_drop_cancel_with_groups() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + ])); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "a".to_string())]; + + let aggregates: Vec> = vec![Arc::new(Avg::new( + col("b", &schema)?, + "AVG(b)".to_string(), + DataType::Float64, + ))]; + + let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); + let refs = blocking_exec.refs(); + let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates.clone(), + blocking_exec, + schema, + )?); + + let fut = crate::physical_plan::collect(hash_aggregate_exec); + let mut fut = fut.boxed(); + + assert_is_pending(&mut fut); + drop(fut); + assert_strong_count_converges_to_zero(refs).await; + + Ok(()) + } }