Skip to content

Commit

Permalink
Clean up spawned task on HashAggregateExec drop
Browse files Browse the repository at this point in the history
  • Loading branch information
crepererum committed Oct 15, 2021
1 parent c22f575 commit 4159a5c
Showing 1 changed file with 85 additions and 4 deletions.
89 changes: 85 additions & 4 deletions datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use pin_project_lite::pin_project;

use async_trait::async_trait;

use super::common::AbortOnDropSingle;
use super::metrics::{
self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
};
Expand Down Expand Up @@ -339,6 +340,7 @@ pin_project! {
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
drop_helper: AbortOnDropSingle<()>,
}
}

Expand Down Expand Up @@ -561,7 +563,8 @@ impl GroupedHashAggregateStream {

let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
tokio::spawn(async move {

let join_handle = tokio::spawn(async move {
let result = compute_grouped_hash_aggregate(
mode,
schema_clone,
Expand All @@ -572,13 +575,16 @@ impl GroupedHashAggregateStream {
)
.await
.record_output(&baseline_metrics);
tx.send(result)

// failing here is OK, the receiver is gone and does not care about the result
tx.send(result).ok();
});

Self {
schema,
output: rx,
finished: false,
drop_helper: AbortOnDropSingle::new(join_handle),
}
}
}
Expand Down Expand Up @@ -738,6 +744,7 @@ pin_project! {
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
drop_helper: AbortOnDropSingle<()>,
}
}

Expand Down Expand Up @@ -789,7 +796,7 @@ impl HashAggregateStream {

let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
tokio::spawn(async move {
let join_handle = tokio::spawn(async move {
let result = compute_hash_aggregate(
mode,
schema_clone,
Expand All @@ -800,13 +807,15 @@ impl HashAggregateStream {
.await
.record_output(&baseline_metrics);

tx.send(result)
// failing here is OK, the receiver is gone and does not care about the result
tx.send(result).ok();
});

Self {
schema,
output: rx,
finished: false,
drop_helper: AbortOnDropSingle::new(join_handle),
}
}
}
Expand Down Expand Up @@ -1005,9 +1014,12 @@ mod tests {

use arrow::array::{Float64Array, UInt32Array};
use arrow::datatypes::DataType;
use futures::FutureExt;

use super::*;
use crate::physical_plan::expressions::{col, Avg};
use crate::test::assert_is_pending;
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
use crate::{assert_batches_sorted_eq, physical_plan::common};

use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
Expand Down Expand Up @@ -1230,4 +1242,73 @@ mod tests {

check_aggregates(input).await
}

/// Checks that dropping a pending `collect` future over a `HashAggregateExec`
/// built with no group-by expressions releases the references held on its
/// `BlockingExec` input.
///
/// The plan aggregates `AVG(a)` in `AggregateMode::Partial`. The future is
/// asserted to be pending, dropped, and then the test awaits
/// `assert_strong_count_converges_to_zero` on the handle obtained from
/// `BlockingExec::refs`.
// NOTE(review): presumably `BlockingExec` never produces a batch and `refs()`
// is a weak handle to state owned by its streams -- confirm in `crate::test::exec`.
#[tokio::test]
async fn test_drop_cancel_without_groups() -> Result<()> {
    let schema =
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

    // Empty group-by list; the element type is inferred from `try_new`.
    let groups = vec![];

    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
        col("a", &schema)?,
        "AVG(a)".to_string(),
        DataType::Float64,
    ))];

    let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
    // Grab the handle before `blocking_exec` is moved into the plan.
    let refs = blocking_exec.refs();

    // `groups` and `aggregates` are not used again, so they are moved into the
    // plan instead of being cloned.
    let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new(
        AggregateMode::Partial,
        groups,
        aggregates,
        blocking_exec,
        schema,
    )?);

    let mut fut = crate::physical_plan::collect(hash_aggregate_exec).boxed();

    assert_is_pending(&mut fut);
    // Dropping the future is the "cancellation" under test.
    drop(fut);
    assert_strong_count_converges_to_zero(refs).await;

    Ok(())
}

/// Checks that dropping a pending `collect` future over a `HashAggregateExec`
/// with one group-by expression releases the references held on its
/// `BlockingExec` input.
///
/// The plan groups by column `a` and aggregates `AVG(b)` in
/// `AggregateMode::Partial`. The future is asserted to be pending, dropped,
/// and then the test awaits `assert_strong_count_converges_to_zero` on the
/// handle obtained from `BlockingExec::refs`.
// NOTE(review): presumably `BlockingExec` never produces a batch and `refs()`
// is a weak handle to state owned by its streams -- confirm in `crate::test::exec`.
#[tokio::test]
async fn test_drop_cancel_with_groups() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Float32, true),
        Field::new("b", DataType::Float32, true),
    ]));

    let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
        vec![(col("a", &schema)?, "a".to_string())];

    let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
        col("b", &schema)?,
        "AVG(b)".to_string(),
        DataType::Float64,
    ))];

    let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1));
    // Grab the handle before `blocking_exec` is moved into the plan.
    let refs = blocking_exec.refs();

    // `groups` and `aggregates` are not used again, so they are moved into the
    // plan instead of being cloned.
    let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new(
        AggregateMode::Partial,
        groups,
        aggregates,
        blocking_exec,
        schema,
    )?);

    let mut fut = crate::physical_plan::collect(hash_aggregate_exec).boxed();

    assert_is_pending(&mut fut);
    // Dropping the future is the "cancellation" under test.
    drop(fut);
    assert_strong_count_converges_to_zero(refs).await;

    Ok(())
}
}

0 comments on commit 4159a5c

Please sign in to comment.